kD-Tree kNN in Python

We're taking this tree to the k-th dimension. They need paper there.
OK, first I will try to explain away the problems with the names kD-Tree and kNN.


 A kD-Tree is a k-dimensional tree. It is not the tree data structure itself that has k dimensions but the space that the tree is modeling. For example, if you were interested in how tall you are over time you would have a two dimensional space: height and age. If you were interested in how tall you were and how much you weighed over time you would have a three dimensional space: height-weight-age. There are many ways to represent this kind of multidimensional data in software. A kD-Tree is often used when you want to group like points into boxes for whatever reason. Now of course, an everyday box has only two or three dimensions and we are working with k dimensional space, so it just won't do. In order to group points in k dimensional space we use a hypercube (strictly speaking a hyper-rectangle, since its sides can differ in length, which is why the code calls it HyperRect). To be very brief, a hypercube is a box that works in k dimensional space.
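To make the box idea concrete, here is a minimal Python 3 sketch of the two steps a kD-Tree repeats: find the bounding hyper-rectangle of a set of points, then split the points at the median of the widest dimension. The function names and the sample people are assumptions made up for illustration, and the split here is a plain half-and-half cut rather than the exact cut used in the full listing.

```python
# Minimal sketch (Python 3). Names and sample values are illustrative only.

def bounding_box(points):
    """Return (low, high): the per-dimension minimum and maximum."""
    dims = range(len(points[0]))
    low = [min(p[d] for p in points) for d in dims]
    high = [max(p[d] for p in points) for d in dims]
    return low, high

def widest_dimension(low, high):
    """Index of the dimension in which the box is widest."""
    widths = [h - l for l, h in zip(low, high)]
    return widths.index(max(widths))

def split_at_median(points, dim):
    """Sort along `dim` and cut the sorted list in half."""
    ordered = sorted(points, key=lambda p: p[dim])
    median = len(ordered) // 2
    return ordered[:median], ordered[median:]

# (height cm, weight kg, age years): made-up values
people = [(150, 50, 20), (160, 90, 25), (170, 60, 30), (180, 100, 35)]
low, high = bounding_box(people)
dim = widest_dimension(low, high)
left, right = split_at_median(people, dim)
print("box:", low, high)
print("split on dimension", dim)
print("left:", left, "right:", right)
```

Repeating the same two steps on each half until the groups are small gives the tree.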


Said another way, the basic goal is to take a large number of k-dimensional points and group them together in hypercubes. Why do we want to do this? kNN is one reason. kNN stands for k-nearest neighbors. In short, it is a list of the k closest points to a given point. Note that this k is the number of neighbors and has nothing to do with the k in kD-Tree, which is the number of dimensions. There are many reasons to want to find similar items. For example, imagine we had the height-weight-age data we were talking about above. It might stand to reason that being of a similar height and weight, at the same age, is correlated with the chance of a heart attack. So, if we were able to find a group of people with a height-weight-age like yours, we might be able to predict your likelihood of a heart attack.
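Before involving any tree, kNN can be computed by brute force: measure the distance from the query to every point and keep the k smallest. The Python 3 sketch below does exactly that, as a baseline that a kD-Tree is meant to speed up. The function names and sample points are assumptions for illustration.

```python
import heapq

def squared_distance(a, b):
    """Squared Euclidean distance; the ordering matches true distance."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def brute_force_knn(points, query, k):
    """Return the k points closest to query, nearest first."""
    return heapq.nsmallest(k, points, key=lambda p: squared_distance(p, query))

points = [(0, 0), (1, 1), (5, 5), (2, 2), (9, 9)]
nearest = brute_force_knn(points, (1, 1.5), 3)
print(nearest)
```

This looks at every point for every query, which is the cost the tree tries to avoid by skipping whole boxes that are too far away.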

Example Problem

So in this example case our data is a set of people for whom we know height-weight-age and whether or not they had a heart attack. We also know your height-weight-age and want to get the percentage of like people that have had a heart attack.
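As a rough sketch of the problem statement, the Python 3 snippet below takes a handful of people, finds the k most similar to you, and reports what percentage of them had a heart attack. Every record here is invented purely for illustration, and the function name heart_attack_rate is an assumption, not something from the source listing.

```python
def squared_distance(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

def heart_attack_rate(people, you, k):
    """people: list of ((height, weight, age), had_heart_attack) records.
    Returns the percentage of your k nearest neighbours with a heart attack."""
    nearest = sorted(people, key=lambda rec: squared_distance(rec[0], you))[:k]
    hits = sum(1 for _, had in nearest if had)
    return 100.0 * hits / len(nearest)

# (height cm, weight kg, age years), had a heart attack? All values made up.
people = [
    ((170, 70, 40), False),
    ((172, 95, 42), True),
    ((168, 72, 39), False),
    ((185, 110, 60), True),
    ((171, 90, 41), True),
]
you = (171, 88, 41)
rate = heart_attack_rate(people, you, 3)
print("%.1f%% of your 3 nearest neighbours had a heart attack" % rate)
```

Note that the three dimensions are in different units, so in practice you would probably want to scale them before measuring distance; the sketch skips that step.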

(Figure: visual example of a kD-Tree, from Wikipedia.)

The Complete Source:  Also available for download.
import re
import sys
import math
import operator
from bisect import insort
from bisect import bisect
class Point:pass
class Node:
	#node.pointCount = len(points)
	#node.points = None
	#node.leftChild = None
	#node.rightChild = None
	def printOut(self,depth = 0):
		# indent by depth, print this node, then recurse into its children
		for i in range(0,depth):
			print "....",
		print "point count =", self.pointCount, "rect =", self.hyperRect.toString(), "points =",self.points
		if (self.leftChild is not None):
			for i in range(0,depth):
				print "....",
			print "left: "
			self.leftChild.printOut(depth+1)
		if (self.rightChild is not None):
			for i in range(0,depth):
				print "....",
			print "right: "
			self.rightChild.printOut(depth+1)
	def buildBoundingHyperRect(self,points):
		self.hyperRect = HyperRect()
		self.hyperRect.buildBoundingHyperRect(points) # compute high/low bounds for these points
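def getFastDistance(a,b):
	# squared Euclidean distance; skipping the square root is faster
	# and does not change which points are nearest
	dim = len(b)
	total = 0
	for i in range(0,dim):
		delta = a[i] - b[i]
		total = total + (delta *delta)
	return total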
class Neighbors:
	def add(self,node,query):
		# test every point stored in this leaf against the query
		for i in range(0,node.pointCount):
			dist = getFastDistance(node.points[i].data,query)
			if (dist < self.minDistanceSquared):
				item = [dist,node.points[i]]
				insort(self.points,item) # keep the list sorted by distance
				if (len(self.points) > self.k):
					self.points = self.points[0:self.k]
		if (len(self.points) == self.k):
			self.minDistanceSquared = self.points[self.k-1][0]
class HyperRect:
	def buildBoundingHyperRect(self,points):
		self.k = len(points[0].data)
		self.dims  = range(0,self.k) 
		high = points[0].data[:]
		low = points[0].data[:]
		for i in range(0,len(points)):
			for j in self.dims:
				point = points[i].data[j]
				if (high[j]  < point):
					high[j] = point
				if (low[j] > point):
					low[j] = point	
		self.high = high
		self.low = low
	def getWidestDimension(self):
		widest =0
		widestDim =-1
		for i in self.dims:
			width = self.high[i] -  self.low[i]
			if (width > widest):
				widestDim =i
				widest = width
		self.widest = widest;
		self.widestDim = widestDim;	
		return self.widestDim
	def getWidestDimensionWidth(self):
		return self.widest
	def toString(self):
		return "high =",self.high,"low =",self.low,
	def getMinDistance(self,query):
		total = 0
		for i in self.dims:
			delta = 0.0
			if (self.high[i] < query[i]):
				delta = query[i] - self.high[i]
			elif (self.low[i] > query[i]):
				delta = self.low[i] - query[i]
			total = total + (delta*delta)
		return total;
def buildKdHyperRectTree(points,rootMin=3):
	global nodes;
	if (points is None or len(points) ==0):
		return None
	n = Node() # make a new node
	n.buildBoundingHyperRect(points) 	# build the hyper rect for these points
						# this will find the top left and bottom
						# right of all given points.
	leaf = len(points) <= rootMin; #check if this node is small enough to be a leaf
	splitDim  = -1
	if (not leaf):
		splitDim = n.hyperRect.getWidestDimension()    # get the widest dimension to split n
						            	# to maximize splitting effect
		if (n.hyperRect.getWidestDimensionWidth() == 0.0): # do we have a bunch of children at the same point?
			leaf = True 
	#init the node
	n.pointCount = len(points)
	n.points = None
	n.leftChild = None
	n.rightChild = None
	if (leaf or len(points)==0):
		n.points = points # we are a leaf so just store all points in the rect
	else:
		points.sort(key=lambda p: p.data[splitDim]) # sort by the best split attribute
		median = len(points)/2 	# get the median
								# and split left for small, right for larger
		n.leftChild = buildKdHyperRectTree(points[0:(median+1)],rootMin)
		if (median +1 < len(points)):
			n.rightChild = buildKdHyperRectTree(points[median+1:], rootMin)
	return n;
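def getKNN(query,node, neighbours,distanceSquared):	
	if (neighbours.minDistanceSquared > distanceSquared): # skip rects that cannot hold a closer point
		if (node.leftChild is None):
			neighbours.add(node,query) # leaf: test its points directly
		else:
			distLeft = node.leftChild.hyperRect.getMinDistance(query)
			distRight = node.rightChild.hyperRect.getMinDistance(query)
			if (distLeft < distRight): # search the nearer child first
				getKNN(query,node.leftChild,neighbours,distLeft)
				getKNN(query,node.rightChild,neighbours,distRight)
			else:
				getKNN(query,node.rightChild,neighbours,distRight)
				getKNN(query,node.leftChild,neighbours,distLeft)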
def runknn(filename):
	f = open(filename,"r")
	pattern = re.compile("[ ]+")
	dataset = []
	index = 0
	for line in f:
		cleanLine = pattern.sub(" ",line) # collapse runs of spaces
		items = cleanLine.split()
		p = Point()
		p.data = [float(items[1]),float(items[2])]
		p.baseIndex = index
		dataset.append(p)
		index = index +1
	f.close()
	kd = buildKdHyperRectTree(dataset[:],10)
	for point in dataset:
		neighbours = Neighbors()
		neighbours.k = 4
		neighbours.points = []
		neighbours.minDistanceSquared = float("infinity")
		getKNN(point.data,kd,neighbours,0) # fills neighbours.points, nearest first
		name = str(point.baseIndex+1)
		answer = name+" "
		for i in range(1,4):	
			name = str(neighbours.points[i][1].baseIndex+1)
			answer = answer+name
			if (i!=3):
				answer= answer+","
		print answer
if (len(sys.argv)==2):
	runknn(sys.argv[1])
if (len(sys.argv)==3):
	import profile
	profile.run("runknn(sys.argv[1])")
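
The listing above is written for Python 2. For readers who want to try the idea on Python 3, here is a compact sketch of the same kind of search: build a tree of bounding boxes, then skip any box whose nearest edge is already further away than the current k-th best point. It is not a port of the listing; the class layout, the leaf_size parameter, the half-and-half split and the use of heapq are all assumptions made for illustration. The result is checked against a brute-force search on random points.

```python
import heapq
import random

def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))

class Node:
    def __init__(self, points, leaf_size):
        dims = range(len(points[0]))
        self.low = [min(p[d] for p in points) for d in dims]
        self.high = [max(p[d] for p in points) for d in dims]
        self.points = None
        self.left = self.right = None
        widths = [h - l for l, h in zip(self.low, self.high)]
        if len(points) <= leaf_size or max(widths) == 0:
            self.points = points            # leaf: keep the points here
            return
        dim = widths.index(max(widths))     # split the widest dimension
        ordered = sorted(points, key=lambda p: p[dim])
        median = len(ordered) // 2
        self.left = Node(ordered[:median], leaf_size)
        self.right = Node(ordered[median:], leaf_size)

    def min_sq_dist(self, query):
        """Squared distance from query to the closest part of this box."""
        total = 0.0
        for lo, hi, q in zip(self.low, self.high, query):
            delta = max(lo - q, 0, q - hi)  # 0 when q lies inside [lo, hi]
            total += delta * delta
        return total

def knn(node, query, k, best=None):
    """best is a max-heap of (-squared distance, point), at most k long."""
    if best is None:
        best = []
    # prune: this box cannot contain anything closer than the k-th best
    if len(best) == k and node.min_sq_dist(query) >= -best[0][0]:
        return best
    if node.points is not None:
        for p in node.points:
            d = sq_dist(p, query)
            if len(best) < k:
                heapq.heappush(best, (-d, p))
            elif d < -best[0][0]:
                heapq.heapreplace(best, (-d, p))
        return best
    # visit the nearer child first so the search radius shrinks sooner
    for child in sorted((node.left, node.right),
                        key=lambda c: c.min_sq_dist(query)):
        knn(child, query, k, best)
    return best

random.seed(0)
data = [(random.random(), random.random()) for _ in range(200)]
root = Node(data, leaf_size=10)
query = (0.5, 0.5)
found = sorted(-neg_d for neg_d, _ in knn(root, query, 4))
expected = sorted(sq_dist(p, query) for p in data)[:4]
print("tree search agrees with brute force:", found == expected)
```
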
Michael Knight,
Apr 25, 2009, 5:14 AM