KDTree 示例
日期 | 2011-03-24(最后修改),2008-09-23(创建) |
'注意:scipy 中有一个 kd-tree 的实现:https://docs.scipy.org.cn/scipy/docs/scipy.spatial.kdtree.KDTree/ 建议使用它而不是下面的代码。'
这是一个关于如何在 Python 中使用 NumPy 构建和搜索 kd-tree 的示例。kd-tree 用于在多维空间中搜索相邻的数据点。在 kd-tree 中为全部 n 个点搜索最近邻,其时间复杂度相对于样本大小 n 为 O(n log n)。
构建 kd-tree
In [ ]
#!python numbers=disable
# Copyleft 2008 Sturla Molden
# University of Oslo
#import psyco
import numpy
def kdtree(data, leafsize=10):
    """Build a kd-tree for O(n log n) nearest-neighbour search.

    The tree is a flat list of 6-tuples
    ``(leaf_idx, leaf_data, left_hrect, right_hrect, left, right)``:

    * internal nodes: ``leaf_idx`` and ``leaf_data`` are ``None``;
      ``left_hrect``/``right_hrect`` are the (2, ndim) bounding
      hyper-rectangles of the two children (row 0 = minima, row 1 = maxima)
      and ``left``/``right`` are the children's positions in the list;
    * leaves: a copy of the original column indices and of the data
      columns, ``None`` for both hyper-rectangles and ``0`` for both
      child pointers.

    ``tree[0]`` is always the (internal) root.

    Parameters
    ----------
    data : 2D ndarray, shape (ndim, ndata), preferably C order
        The points, one per column.  NOTE: the columns are reordered in
        place while splitting, so pass a copy if the caller still needs the
        original order.
    leafsize : int
        Maximum number of data points stored in a leaf; must be >= 1.

    Returns
    -------
    list of tuple
        The kd-tree in the layout described above.

    Raises
    ------
    ValueError
        If ``data`` has no points or ``leafsize < 1``.
    """
    ndim, ndata = data.shape
    if ndata == 0:
        raise ValueError("kdtree: data must contain at least one point")
    if leafsize < 1:
        # a one-point node would otherwise be split into 0 + 1 points forever
        raise ValueError("kdtree: leafsize must be at least 1")

    # bounding hyper-rectangle of all points
    hrect = numpy.zeros((2, ndim))
    hrect[0, :] = data.min(axis=1)
    hrect[1, :] = data.max(axis=1)

    # root: sort on dimension 0 (stable sort) and split at the median
    idx = numpy.argsort(data[0, :], kind='mergesort')
    data[:, :] = data[:, idx]
    half = ndata // 2  # floor division keeps this an integer index on Python 3
    splitval = data[0, half]

    left_hrect = hrect.copy()
    right_hrect = hrect.copy()
    left_hrect[1, 0] = splitval
    right_hrect[0, 0] = splitval

    tree = [(None, None, left_hrect, right_hrect, None, None)]

    # pending nodes: (data view, original indices, depth, parent position, is-left-child)
    stack = [(data[:, :half], idx[:half], 1, 0, True),
             (data[:, half:], idx[half:], 1, 0, False)]

    # split the data in halves until every piece fits in a leaf
    while stack:
        data, didx, depth, parent, leftbranch = stack.pop()
        ndata = data.shape[1]
        nodeptr = len(tree)  # position this node takes once appended

        # point the parent's left/right child slot at this node
        _didx, _data, _left_hrect, _right_hrect, left, right = tree[parent]
        if leftbranch:
            tree[parent] = (_didx, _data, _left_hrect, _right_hrect, nodeptr, right)
        else:
            tree[parent] = (_didx, _data, _left_hrect, _right_hrect, left, nodeptr)

        if ndata <= leafsize:
            # leaf: store copies, because the views alias the working array
            tree.append((didx.copy(), data.copy(), None, None, 0, 0))
            continue

        # internal node: split at the median of the cycling dimension
        splitdim = depth % ndim
        idx = numpy.argsort(data[splitdim, :], kind='mergesort')
        data[:, :] = data[:, idx]
        didx = didx[idx]
        half = ndata // 2
        stack.append((data[:, :half], didx[:half], depth + 1, nodeptr, True))
        stack.append((data[:, half:], didx[half:], depth + 1, nodeptr, False))
        splitval = data[splitdim, half]

        # this node's region is the parent's hyper-rectangle on our side
        region = _left_hrect if leftbranch else _right_hrect
        left_hrect = region.copy()
        right_hrect = region.copy()
        left_hrect[1, splitdim] = splitval
        right_hrect[0, splitdim] = splitval
        tree.append((None, None, left_hrect, right_hrect, None, None))

    return tree
搜索 kd-tree
In [ ]
#!python numbers=disable
def intersect(hrect, r2, centroid):
    """Check whether a hyper-rectangle intersects a hypersphere.

    Parameters
    ----------
    hrect : ndarray, shape (2, ndim)
        Row 0 holds the minimum corner, row 1 the maximum corner.
    r2 : float
        *Squared* radius of the hypersphere.
    centroid : ndarray, shape (ndim,)
        Centre of the hypersphere (not modified).

    Returns
    -------
    bool
        True when the point of ``hrect`` closest to ``centroid`` lies
        strictly within squared distance ``r2`` of it.
    """
    # clamping the centre into the box yields the box's closest point
    nearest = numpy.clip(centroid, hrect[0, :], hrect[1, :])
    return ((nearest - centroid)**2).sum() < r2
def quadratic_knn_search(data, lidx, ldata, K):
    """Find the K nearest neighbours of a query point among ``ldata``.

    Brute-force scan over the candidate columns, e.g. one kd-tree leaf.

    Parameters
    ----------
    data : ndarray
        The query point, either 1-D with shape (ndim,) or 2-D with the
        point in its columns.  2-D input is compared against the first
        ``ldata.shape[1]`` columns, so it must have one column or at least
        that many identical columns.
    lidx : ndarray, shape (n,)
        Identifiers (original column indices) of the candidates.
    ldata : ndarray, shape (ndim, n)
        Candidate points, one per column.
    K : int
        Number of neighbours wanted; clamped to ``n``.

    Returns
    -------
    list of (float, index) tuples
        ``(squared distance, lidx entry)`` pairs in increasing distance
        order; empty when ``ldata`` has no columns.
    """
    ndata = ldata.shape[1]
    K = min(K, ndata)
    query = numpy.asarray(data)
    if query.ndim == 1:
        # a single point: broadcast it against every candidate column
        query = query.reshape((ldata.shape[0], 1))
    else:
        query = query[:, :ndata]
    sqd = ((ldata - query)**2).sum(axis=0)
    # stable sort so equidistant candidates keep their order
    idx = numpy.argsort(sqd, kind='mergesort')[:K]
    # a real list: callers index the result, and zip() is lazy on Python 3
    return list(zip(sqd[idx], lidx[idx]))
def search_kdtree(tree, datapoint, K):
    """Find the K nearest neighbours of a point in a kd-tree from ``kdtree``.

    Parameters
    ----------
    tree : list
        Node list returned by ``kdtree``.
    datapoint : ndarray, shape (ndim, m)
        The query point in column 0 (used for pruning).  The whole array is
        handed to ``quadratic_knn_search`` at each leaf.
    K : int
        Number of neighbours wanted (>= 1).

    Returns
    -------
    list of (float, index) tuples
        The K best ``(squared distance, original column index)`` pairs in
        increasing order.  Slots that cannot be filled because the tree
        holds fewer than K points stay ``(inf, None)``.
    """
    stack = [tree[0]]
    knn = [(numpy.inf, None)] * K
    _datapt = datapoint[:, 0]
    while stack:
        leaf_idx, leaf_data, left_hrect, right_hrect, left, right = stack.pop()

        if leaf_idx is not None:
            # leaf: brute-force its points; list() also accepts a lazy zip
            _knn = list(quadratic_knn_search(datapoint, leaf_idx, leaf_data, K))
            # an empty leaf (possible for a one-point tree) yields nothing
            if _knn and _knn[0][0] < knn[-1][0]:
                knn = sorted(knn + _knn)[:K]
            continue

        # internal node: descend only into children whose region is closer
        # than the current K-th best squared distance
        if intersect(left_hrect, knn[-1][0], _datapt):
            stack.append(tree[left])
        if intersect(right_hrect, knn[-1][0], _datapt):
            stack.append(tree[right])
    return knn
def knn_search(data, K, leafsize=2048):
    """Find the K nearest neighbours of every point in ``data`` with a kd-tree.

    Parameters
    ----------
    data : ndarray, shape (ndim, ndata)
        The points, one per column (not modified).
    K : int
        Number of neighbours per point.
    leafsize : int
        Maximum number of points per kd-tree leaf.

    Returns
    -------
    list
        One entry per column of ``data``: a list of ``(squared distance,
        column index)`` pairs, nearest first.  K+1 neighbours are searched
        and the nearest hit is dropped, on the assumption that it is the
        query point itself.  With duplicate points, it may be a different
        point at distance 0.
    """
    ndata = data.shape[1]
    param = data.shape[0]

    # kdtree() reorders its input in place, so build it on a copy
    tree = kdtree(data.copy(), leafsize=leafsize)

    knn = []
    for i in range(ndata):
        # one (param, 1) column broadcasts against every leaf, so there is no
        # need to replicate the point leafsize times
        point = data[:, i].reshape((param, 1))
        neighbours = search_kdtree(tree, point, K + 1)
        knn.append(neighbours[1:])
    return knn
def radius_search(tree, datapoint, radius):
    """Find all points of a kd-tree within ``radius`` of ``datapoint``.

    Parameters
    ----------
    tree : list
        Node list returned by ``kdtree``.
    datapoint : ndarray
        The query point, shape (ndim,) or (ndim, 1).
    radius : float
        Search radius as a plain Euclidean distance (not squared).

    Returns
    -------
    list of (float, index) tuples
        ``(distance, original column index)`` for every point whose
        distance is ``<= radius``, in traversal order (unsorted).
    """
    point = numpy.asarray(datapoint, dtype=float).reshape(-1)
    column = point.reshape((-1, 1))
    r2 = radius ** 2  # boxes are tested with squared distances

    def reachable(hrect):
        # closest point of the box to the query.  Use <= so that, like the
        # leaf test below, points exactly on the sphere are not pruned.
        nearest = numpy.clip(point, hrect[0, :], hrect[1, :])
        return ((nearest - point) ** 2).sum() <= r2

    stack = [tree[0]]
    inside = []
    while stack:
        leaf_idx, leaf_data, left_hrect, right_hrect, left, right = stack.pop()

        if leaf_idx is not None:
            distance = numpy.sqrt(((leaf_data - column) ** 2).sum(axis=0))
            near = numpy.where(distance <= radius)
            if len(near[0]):
                inside += list(zip(distance[near], leaf_idx[near]))
            continue

        if reachable(left_hrect):
            stack.append(tree[left])
        if reachable(right_hrect):
            stack.append(tree[right])
    return inside
与 kd-tree 相比,直接的穷举搜索的时间复杂度相对于样本大小是二次的,即 O(n²)。当样本量非常小时,穷举搜索可能比 kd-tree 更快;在我的电脑上,这个界限大约是 500 个样本或更少。
In [ ]
#!python numbers=disable
def knn_search(data, K):
    """Find the K nearest neighbours of every point in ``data`` by brute force.

    O(n**2): every point is compared with all others through
    ``quadratic_knn_search``.  This is only faster than the kd-tree version
    for small samples.  NOTE: it has the same name as the kd-tree
    ``knn_search`` and replaces it when both run in one namespace.

    Parameters
    ----------
    data : ndarray, shape (ndim, ndata)
        The points, one per column.
    K : int
        Number of neighbours per point.

    Returns
    -------
    list
        One entry per column: a list of ``(squared distance, column index)``
        pairs, nearest first, with the nearest hit dropped.  That hit is
        assumed to be the point itself.
    """
    ndata = data.shape[1]
    idx = numpy.arange(ndata)
    knn = []
    for i in range(ndata):
        # pass a (ndim, 1) column: it broadcasts against all of ``data``,
        # whereas a 1-D slice breaks the column indexing in the search
        _knn = list(quadratic_knn_search(data[:, i:i + 1], idx, data, K + 1))
        knn.append(_knn[1:])
    return knn
虽然创建 KD 树非常快,但搜索它可能很耗时。由于 Python 令人诟病的“全局解释器锁”(GIL),无法用线程并行执行多个搜索。也就是说,Python 线程可以用于异步(并发)处理,但不能让 Python 代码在多个 CPU 上真正并行运行。不过,我们可以使用多个进程(多个解释器)。pyprocessing 包使这变得很容易:它的 API 与 Python 的 threading 和 Queue 标准模块类似,但使用进程而不是线程。从 Python 2.6 开始,pyprocessing 以“multiprocessing”模块的名称被纳入 Python 标准库。使用多个进程会带来一些小的开销,包括进程创建、进程启动、进程间通信(IPC)和进程终止。但是,由于各进程在独立的地址空间中运行,因此不会产生内存争用。在下面的示例中,与计算量相比,使用多个进程的开销非常小,因此加速比接近计算机上的 CPU 数量。
In [ ]
#!python numbers=disable
# Prefer the standard-library module; fall back to the legacy pyprocessing
# package only where multiprocessing is unavailable (Python < 2.6).  An
# unconditional "import processing" would override the alias and fail on
# installs without pyprocessing.
try:
    import multiprocessing as processing
except ImportError:
    import processing
import ctypes, os
def __num_processors():
    """Return the number of logical processors (at least 1).

    Uses ``multiprocessing.cpu_count()``, which is portable across Windows,
    Linux and macOS.  If that is unavailable, it falls back to the Windows
    ``NUMBER_OF_PROCESSORS`` environment variable, and then to 1.  Loading
    libc via ``ctypes.cdll.libc`` does not work on Linux or macOS, so that
    approach is not used.
    """
    try:
        import multiprocessing
        return multiprocessing.cpu_count()
    except (ImportError, NotImplementedError):
        pass
    try:
        return max(1, int(os.environ.get('NUMBER_OF_PROCESSORS', 1)))
    except ValueError:
        return 1
def __search_kdtree(tree, data, K, leafsize):
    """Search the kd-tree for every query point of one chunk (worker side).

    Parameters
    ----------
    tree : list
        Node list returned by ``kdtree``.
    data : ndarray, shape (param, m)
        Chunk of query points, one per column.
    K : int
        Number of neighbours per point; K+1 are searched and the nearest
        hit (assumed to be the point itself) is dropped.
    leafsize : int
        Kept for signature compatibility.  It is no longer needed, because
        one (param, 1) column broadcasts against every leaf.

    Returns
    -------
    list
        One neighbour list per column of ``data``.
    """
    param = data.shape[0]
    knn = []
    for i in range(data.shape[1]):
        point = data[:, i].reshape((param, 1))
        knn.append(search_kdtree(tree, point, K + 1)[1:])
    return knn
def __remote_process(rank, qin, qout, tree, K, leafsize):
    """Worker loop: take numbered chunks from ``qin`` and post results to ``qout``.

    Loops forever; ``knn_search_parallel`` terminates the worker when it is
    done.  ``rank`` is the worker number and is not used by the loop.
    """
    while True:
        # block until a chunk arrives; nc is its sequence number
        nc, data = qin.get()
        knn = __search_kdtree(tree, data, K, leafsize)
        # tag the result with nc so the parent can restore input order
        qout.put((nc, knn))
def knn_search_parallel(data, K, leafsize=2048):
    """Find the K nearest neighbours of every point in ``data`` using all CPUs.

    Builds one kd-tree, then hands chunks of query points to one worker
    process per logical processor.  NOTE: on platforms that start processes
    by spawning (e.g. Windows), call this from under an
    ``if __name__ == '__main__':`` guard.

    Parameters
    ----------
    data : ndarray, shape (ndim, ndata)
        The points, one per column (not modified).
    K : int
        Number of neighbours per point.
    leafsize : int
        Maximum number of points per kd-tree leaf.

    Returns
    -------
    list
        One neighbour list per column of ``data``, in column order.
    """
    ndata = data.shape[1]
    nproc = __num_processors()

    # kdtree() reorders its input in place, so build it on a copy
    tree = kdtree(data.copy(), leafsize=leafsize)

    # about four chunks per processor, but at least 100 points per chunk;
    # integer division keeps the sizes usable as slice bounds on Python 3
    chunk_size = max(ndata // (4 * nproc), 100)
    nchunks = (ndata + chunk_size - 1) // chunk_size

    # queues large enough to hold every chunk, so puts never block
    qin = processing.Queue(maxsize=nchunks)
    qout = processing.Queue(maxsize=nchunks)
    pool = [processing.Process(target=__remote_process,
                               args=(rank, qin, qout, tree, K, leafsize))
            for rank in range(nproc)]
    for p in pool:
        p.start()

    try:
        # feed numbered chunks to the workers
        for nc in range(nchunks):
            cur = nc * chunk_size
            qin.put((nc, data[:, cur:cur + chunk_size]))
        # one result per chunk; workers may finish in any order
        results = [qout.get() for _ in range(nchunks)]
    finally:
        # workers loop forever, so always stop and reap them
        for p in pool:
            p.terminate()
        for p in pool:
            p.join()

    # restore chunk order, then flatten into one list per data point
    knn = []
    for _, chunk in sorted(results, key=lambda item: item[0]):
        knn.extend(chunk)
    return knn
In [ ]
#!python numbers=disable
try:
    from time import perf_counter as clock  # Python 3.3+
except ImportError:
    from time import clock  # older Pythons; time.clock was removed in 3.8


def test():
    """Benchmark: K nearest neighbours of 10000 random 12-D points.

    Calls whichever ``knn_search`` definition is current in this namespace.
    """
    K = 11
    ndata = 10000
    ndim = 12
    data = 10 * numpy.random.rand(ndata * ndim).reshape((ndim, ndata))
    knn_search(data, K)


if __name__ == '__main__':
    t0 = clock()
    test()
    t1 = clock()
    # parenthesise the difference: '%' binds tighter than '-'
    print("Elapsed time %.2f seconds" % (t1 - t0))

    #import profile # using Python's profiler is not useful if you are
    #profile.run('test()') # running the parallel search.
In [ ]