### kD-Tree kNN in python

We're taking this tree to the k-th dimension. They need paper there.
OK, first I will try to clear up the confusing names kD-Tree and kNN.

#### kD-Tree
A kD-Tree is a k-dimensional tree. It is not the tree data structure itself that has k dimensions, but the space that the tree is modeling. For example, if you were interested in how tall you are over time, you would have a two-dimensional space: height and age. If you were interested in how tall you were and how much you weighed over time, you would have a three-dimensional space: height-weight-age. There are many ways to represent these kinds of multidimensional data in software. A kD-Tree is often used when you want to group similar points into boxes for whatever reason. Of course, a box by definition is a two- (or three-) dimensional shape, and we are working with k-dimensional space, so it just won't do. To group k-dimensional space we use a hypercube (strictly speaking a hyper-rectangle, since its sides need not be equal). To be very brief, a hypercube is a box that works in k-dimensional space.

#### kNN
Said another way, the basic goal is to take a large number of k-dimensional points and group them together in hypercubes. Why do we want to do this? kNN is one reason. kNN stands for k-nearest neighbors. In short, it is a list of the k closest points to a given point. (Note that this k, the number of neighbors, is unrelated to the k in kD-Tree, the number of dimensions.) There are many reasons that we want to know similar items. For example, imagine we had the height-weight-age data we were talking about above. It might stand to reason that people of a similar height and weight at the same age as yourself have a similar chance of a heart attack. So, if we were able to find a group of people of similar height-weight-age, we might be able to predict your likelihood of a heart attack.

#### Example Problem
So in this example case we have data for a subset of people: we know their height-weight-age and whether or not they had a heart attack. We also know your height-weight-age and want to get the percentage of similar people who have had a heart attack.

(Visual example of a kD-Tree from Wikipedia.)

The Complete Source:  Also available for download.
```python
#!/usr/bin/python
import re
import sys
import math
import operator
from bisect import insort
from bisect import bisect

class Point:
    """Plain attribute record for one dataset entry.

    No fields are declared here; runknn attaches ``data`` (the coordinate
    list) and ``baseIndex`` (the entry's 0-based position in the loaded
    dataset) after construction.
    """
class Node:
    """One node of the kD-tree produced by buildKdHyperRectTree.

    Attributes (all assigned by buildKdHyperRectTree):
        pointCount: number of points in this node's subtree.
        points: the points themselves for a leaf, ``None`` for an interior node.
        leftChild / rightChild: child nodes, ``None`` for a leaf.
        hyperRect: HyperRect bounding every point in the subtree.
    """

    def __init__(self):
        # Safe defaults so a fresh node can be printed or inspected before the
        # tree builder fills it in; the builder overwrites every one of these.
        self.pointCount = 0
        self.points = None
        self.leftChild = None
        self.rightChild = None
        self.hyperRect = None

    def _printIndented(self, depth, text):
        """Print *text* prefixed by one "...." marker per tree level."""
        # A single pre-joined string keeps print() behaving the same on
        # Python 2 and Python 3.
        print(" ".join(["...."] * depth + [text]))

    def printOut(self, depth=0):
        """Print this node and, recursively, its children.

        Args:
            depth: nesting level of this node; each level adds one "...."
                marker in front of the node's lines.
        """
        rect = None if self.hyperRect is None else self.hyperRect.toString()
        self._printIndented(depth, "point count = %s rect = %s points = %s"
                            % (self.pointCount, rect, self.points))
        if self.leftChild is not None:
            self._printIndented(depth, "left: ")
            self.leftChild.printOut(depth + 1)
        if self.rightChild is not None:
            self._printIndented(depth, "right: ")
            self.rightChild.printOut(depth + 1)

    def buildBoundingHyperRect(self, points):
        """Set ``hyperRect`` to the bounding box of *points* (must be non-empty)."""
        self.hyperRect = HyperRect()
        self.hyperRect.buildBoundingHyperRect(points)

def getFastDistance(a, b):
    """Return the squared Euclidean distance between coordinate sequences *a* and *b*.

    No square root is taken, so the result is only comparable with other
    squared distances. Exactly ``len(b)`` coordinates are compared, so *a*
    must have at least that many entries.
    """
    differences = (a[i] - b[i] for i in range(len(b)))
    return sum(d * d for d in differences)

class Neighbors:
    """Bounded, distance-sorted collection of the nearest points found so far.

    ``points`` holds ``[squaredDistance, point]`` pairs in ascending distance
    order and never grows beyond ``k`` entries. ``minDistanceSquared`` is
    infinity until ``k`` entries have been collected. After that, it is the
    squared distance of the current k-th (worst kept) neighbour. Only points
    strictly closer than it are accepted.
    """

    def __init__(self, k=1):
        # Callers may also reassign k / points / minDistanceSquared directly
        # after construction (runknn does exactly that).
        self.k = k
        self.points = []
        self.minDistanceSquared = float("inf")

    def add(self, node, query):
        """Offer every point stored in leaf *node* as a neighbour of *query*.

        Args:
            node: a leaf whose ``points`` list holds objects with a ``data``
                coordinate sequence.
            query: the coordinate sequence being searched for.
        """
        if self.k < 1:
            # Nothing can be kept; also avoids indexing an empty list below.
            return
        for point in node.points:
            dist = getFastDistance(point.data, query)
            if dist < self.minDistanceSquared:
                # Find the slot by distance alone: insort() on [dist, point]
                # pairs would compare Point objects on ties, which raises
                # TypeError on Python 3.
                pos = bisect([item[0] for item in self.points], dist)
                self.points.insert(pos, [dist, point])
                if len(self.points) > self.k:
                    del self.points[self.k:]
                if len(self.points) == self.k:
                    self.minDistanceSquared = self.points[-1][0]

class HyperRect:
    """Axis-aligned bounding box (hyper-rectangle) in k-dimensional space.

    After buildBoundingHyperRect:
        ``low[j]`` and ``high[j]`` hold the smallest and largest coordinate
        seen in dimension ``j``.
        ``dims`` iterates over the dimension indices ``0 .. k-1``.
    """

    def buildBoundingHyperRect(self, points):
        """Shrink-wrap this box around *points*.

        *points* must be non-empty. Each element needs a ``data`` coordinate
        sequence with at least as many entries as the first point's.
        """
        self.k = len(points[0].data)
        self.dims = range(0, self.k)
        # list() copies the first point's coordinates so the bounds can be
        # updated in place even when ``data`` is a tuple.
        high = list(points[0].data)
        low = list(points[0].data)
        for point in points:
            for j in self.dims:
                value = point.data[j]
                if high[j] < value:
                    high[j] = value
                if low[j] > value:
                    low[j] = value
        self.high = high
        self.low = low

    def getWidestDimension(self):
        """Return the index of the dimension with the largest extent.

        Also stores that extent in ``self.widest`` and the index in
        ``self.widestDim``. Returns -1 (with ``widest == 0``) when every
        dimension has zero width, i.e. all bounded points coincide.
        """
        widest = 0
        widestDim = -1
        for i in self.dims:
            width = self.high[i] - self.low[i]
            if width > widest:
                widestDim = i
                widest = width
        self.widest = widest
        self.widestDim = widestDim
        return widestDim

    def getWidestDimensionWidth(self):
        """Return the extent of the widest dimension.

        Recomputed on every call, so it no longer depends on
        getWidestDimension having been called first.
        """
        self.getWidestDimension()
        return self.widest

    def toString(self):
        """Return a human-readable description of the bounds."""
        return "high = %s low = %s" % (self.high, self.low)

    def getMinDistance(self, query):
        """Return the squared distance from *query* to the nearest point of the box.

        The result is 0 when *query* lies inside or on the box.
        """
        total = 0
        for i in self.dims:
            delta = 0.0
            if self.high[i] < query[i]:
                delta = query[i] - self.high[i]
            elif self.low[i] > query[i]:
                delta = self.low[i] - query[i]
            total = total + delta * delta
        return total
def buildKdHyperRectTree(points, rootMin=3):
    """Recursively build a kD-tree over *points* and return its root Node.

    Every node records the bounding hyper-rectangle of its points. A node
    with at most *rootMin* points becomes a leaf that stores them. A larger
    node is sorted along its widest dimension and split just past the median:
    the left child keeps the median point.

    Args:
        points: list of Point-like objects with a ``data`` coordinate
            sequence. The list itself is not reordered.
        rootMin: maximum number of points in a leaf. Values below 1 are
            treated as 1, because a single point cannot be split further.

    Returns:
        The root Node, or ``None`` when *points* is ``None`` or empty.
    """
    if points is None or len(points) == 0:
        return None

    n = Node()
    # Bounding box (low/high corners) of every point handled by this node.
    n.buildBoundingHyperRect(points)
    n.pointCount = len(points)
    n.points = None
    n.leftChild = None
    n.rightChild = None

    if len(points) <= max(rootMin, 1):
        n.points = points  # leaf: keep the points themselves
        return n

    # Split along the widest dimension to make the children as compact as
    # possible. If all points coincide, splitDim is -1; sorting by the last
    # coordinate is then harmless, and the split below still shrinks both
    # halves, so the recursion terminates.
    splitDim = n.hyperRect.getWidestDimension()
    ordered = sorted(points, key=lambda p: p.data[splitDim])
    count = len(ordered)
    # The left half keeps the median. The cut is clamped so the right half is
    # never empty; otherwise a 2-point node would put every point on the left
    # and recurse forever.
    cut = min(count // 2 + 1, count - 1)
    n.leftChild = buildKdHyperRectTree(ordered[:cut], rootMin)
    n.rightChild = buildKdHyperRectTree(ordered[cut:], rootMin)
    return n

def getKNN(query, node, neighbours, distanceSquared):
    """Search the subtree rooted at *node* for nearest neighbours of *query*.

    Args:
        query: coordinate sequence being searched for.
        node: kD-tree node from buildKdHyperRectTree; ``None`` (empty tree or
            missing child) is ignored.
        neighbours: Neighbors collector, updated in place through its
            ``add`` method and ``minDistanceSquared`` bound.
        distanceSquared: lower bound on the squared distance from *query* to
            any point under *node*. The subtree is skipped unless this bound
            is strictly below the current k-th best distance.
    """
    # Prune: nothing in this subtree can beat the neighbours already kept.
    if node is None or distanceSquared >= neighbours.minDistanceSquared:
        return
    if node.leftChild is None:
        # Leaf: the builder gives every interior node a left child.
        neighbours.add(node, query)
        return

    left, right = node.leftChild, node.rightChild
    distLeft = left.hyperRect.getMinDistance(query)
    if right is None:
        getKNN(query, left, neighbours, distLeft)
        return
    distRight = right.hyperRect.getMinDistance(query)
    # Visit the nearer child first, so its results tighten the bound used to
    # prune the farther one.
    if distLeft < distRight:
        getKNN(query, left, neighbours, distLeft)
        getKNN(query, right, neighbours, distRight)
    else:
        getKNN(query, right, neighbours, distRight)
        getKNN(query, left, neighbours, distLeft)

def runknn(filename, k=3):
    """Print the *k* nearest other points for every point in a data file.

    The file is whitespace-separated. On each non-blank line, columns 2 and
    3 (1-based) are read as the point's coordinates and the first column is
    ignored; lines with fewer than three columns raise IndexError. Points
    are numbered from 1 in load order. Each output line has the form
    ``"<index> <n1>,<n2>,..."``, listing neighbours closest first. Fewer
    than *k* are listed when the dataset is smaller than ``k + 1`` points.
    """
    dataset = []
    with open(filename, "r") as f:
        for line in f:
            # str.split() already collapses runs of whitespace.
            items = line.split()
            if not items:
                continue  # tolerate blank lines
            p = Point()
            p.data = [float(items[1]), float(items[2])]
            p.baseIndex = len(dataset)
            dataset.append(p)

    # Pass a copy so the caller-visible dataset order is never disturbed.
    kd = buildKdHyperRectTree(dataset[:], 10)

    for point in dataset:
        neighbours = Neighbors()
        # One extra slot, because the query point itself is in the tree.
        neighbours.k = k + 1
        neighbours.points = []
        neighbours.minDistanceSquared = float("infinity")
        getKNN(point.data, kd, neighbours,
               kd.hyperRect.getMinDistance(point.data))
        # Exclude the query point by identity, not by position. A duplicate
        # coordinate may sort ahead of it.
        others = [item[1] for item in neighbours.points
                  if item[1] is not point][:k]
        names = [str(other.baseIndex + 1) for other in others]
        print(str(point.baseIndex + 1) + " " + ",".join(names))
if (len(sys.argv)==2):
runknn(sys.argv[1])
if (len(sys.argv)==3):
import profile
runknn.run("runknn(sys.argv[1])")```
ċ
kdtree_knn.py
(5k)
Michael Knight,
Apr 25, 2009, 5:14 AM
Comments