我正在将代码从MATLAB转换为python,以加快简单的操作。我写了一个包含嵌套循环和条件语句的函数。循环的目的是返回与数组y比较时数组x中最接近元素的索引列表。我正在按1e5项的顺序进行比较,这需要大约30秒才能运行。任何有助于加快此过程的帮助将不胜感激!我使用numba-pro自动即时编译器获得了部分成功:
@autojit()
def find_nearest(x,y,idx):
idx_old = 0
rng1 = range(y.shape[0])
rng2 = range(x.shape[0])
for i in rng1:
prev = abs(x[idx_old]-y[i])
for j in rng2:
if abs(x[j]-y[i]) < prev:
prev = abs(x[j]-y[i])
idx_old = j
idx[i] = idx_old
return idx
很抱歉成为这样的菜鸟,我是python的新手!
您的Numba代码没什么错,除了算法效率不高。更好的方法是对x
数组进行排序并执行二进制搜索,这与该答案以及以下答案非常相似:
def find_nearest(x, y):
indices = np.argsort(x)
loc = np.searchsorted(x[indices], y)
right = indices.take(loc, mode='clip')
left = indices.take(loc-1, mode='clip')
return np.where(abs(y-x[left]) < abs(y-x[right]), left, right)
在我的电脑,这是约80倍,甚至比对KDTree方法快x
且y
有10个6和10 5分别元素。大约有三分之二的时间都花在argsort
了阵列上,所以我认为在这里使用Numba不会带来太多收益。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句