我有以下数据结构:
[[[ 512 520 1 130523]]
[[ 520 614 573 7448]]
[[ 614 616 615 210]]
[[ 616 622 619 269]]
[[ 622 624 623 162]]
[[ 625 770 706 8822]]
[[ 770 776 773 241]]]
我正在尝试返回相同形状的对象,但仅返回具有3个最大的第4列的行(如果这有意义)(因此在这种情况下,将是第1、2和6行)
最优雅的方法是什么?
您可以对数组进行排序,但是从NumPy 1.8开始,有一种更快的方法来查找N个最大值(尤其是当data
它很大时):
import numpy as np
data = np.array([[[ 512, 520, 1, 130523]],
[[ 520, 614, 573, 7448]],
[[ 614, 616, 615, 210]],
[[ 616, 622, 619, 269]],
[[ 622, 624, 623, 162]],
[[ 625, 770, 706, 8822]],
[[ 770, 776, 773, 241]]])
idx = np.argpartition(-data[...,-1].flatten(), 3)
print(data[idx[:3]])
产量
[[[ 520 614 573 7448]]
[[ 512 520 1 130523]]
[[ 625 770 706 8822]]]
np.argpartition
执行部分排序。它以部分排序的顺序返回数组的索引,以使每个kth
项目都处于其最终排序位置。实际上,每组k
项目都相对于其他组进行了排序,但每个组本身都不进行排序(因此节省了一些时间)。
请注意,前3行的返回顺序与出现的顺序不同data
。
为了进行比较,以下是通过使用np.argsort
(执行完整排序)可以找到前3行的方法:
idx = np.argsort(data[..., -1].flatten())
print(data[idx[-3:]])
产量
[[[ 520 614 573 7448]]
[[ 625 770 706 8822]]
[[ 512 520 1 130523]]]
注意:np.argsort
对于小型阵列,速度更快:
In [63]: %timeit idx = np.argsort(data[..., -1].flatten())
100000 loops, best of 3: 2.6 µs per loop
In [64]: %timeit idx = np.argpartition(-data[...,-1].flatten(), 3)
100000 loops, best of 3: 5.61 µs per loop
但是np.argpartition
对于大型阵列更快:
In [92]: data2 = np.tile(data, (10**3,1,1))
In [93]: data2.shape
Out[93]: (7000, 1, 4)
In [94]: %timeit idx = np.argsort(data2[..., -1].flatten())
10000 loops, best of 3: 164 µs per loop
In [95]: %timeit idx = np.argpartition(-data2[...,-1].flatten(), 3)
10000 loops, best of 3: 49.5 µs per loop
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句