我正在计算int8
s向量中最频繁的数字。当我设置int
s的计数器数组时,Numba抱怨:
@jit(nopython=True)
def freq_int8(y):
"""Find most frequent number in array"""
count = np.zeros(256, dtype=int)
for val in y:
count[val] += 1
return ((np.argmax(count)+128) % 256) - 128
调用它,我得到以下错误:
TypingError: Invalid usage of Function(<built-in function zeros>) with parameters (int64, Function(<class 'int'>))
如果我删除dtype=int
它可以正常工作,我会得到不错的加速。但是,我对为什么声明一个int
s数组不起作用感到困惑。有没有已知的解决方法,在这里有没有值得提高的效率?
背景:我正在尝试从一些numpy繁重的代码中节省几微秒的时间。我尤其受到numpy.median
Numba的伤害,一直在研究Numba,但是我一直在努力提高median
。找到频率最高的数字是可以接受的替代选择median
,在这里,我已经获得了一些性能。上面的numba代码也比numpy.bincount
。
更新:输入接受的答案后,这是median
forint8
向量的实现。它比numpy.median
:大约快一个数量级:
@jit(nopython=True)
def median_int8(y):
N2 = len(y)//2
count = np.zeros(256, dtype=np.int32)
for val in y:
count[val] += 1
cs = 0
for i in range(-128, 128):
cs += count[i]
if cs > N2:
return float(i)
elif cs == N2:
j = i+1
while count[j] == 0:
j += 1
return (i + j)/2
出人意料的是,短向量的性能差异甚至更大,这显然是由于numpy
向量的开销所致:
>>> a = np.random.randint(-128, 128, 10)
>>> %timeit np.median(a)
The slowest run took 7.03 times longer than the fastest. This could mean that an intermediate result is being cached.
10000 loops, best of 3: 20.8 µs per loop
>>> %timeit median_int8(a)
The slowest run took 11.67 times longer than the fastest. This could mean that an intermediate result is being cached.
1000000 loops, best of 3: 593 ns per loop
这个开销是如此之大,我想知道是否有什么问题。
请注意,找到最频繁的数字通常称为mode,它与中位数相似,也就是平均值……在这种情况下,np.mean
速度会快得多。除非您的数据中有某些约束或特殊性,否则无法保证该模式近似于中位数。
如果您仍然想计算整数列表的模式np.bincount
,则正如您所提到的,就足够了(如果numba更快,那么就不会太多了):
count = np.bincount(y, minlength=256)
result = ((np.argmax(count)+128) % 256) - 128
请注意,我已经将minlength
参数添加到np.bincount
,以便它返回与代码中相同的256个长度列表。但实际上在实践中完全没有必要,因为您只想要argmax
,np.bincount
(不带minlength
)将返回一个列表,该列表的长度是中的最大数量y
。
至于numba错误,替换dtype=int
为dtype=np.int32
应该可以解决该问题。int
是python函数,您要nopython
在numba标头中指定。如果删除nopython
,则dtype=int
或dtype='i'
也将起作用(具有相同的效果)。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句