我有一个二维的numpy数组,我希望将每个元素四舍五入到序列中最接近的数字。阵列具有形状(28000, 24)
。
例如,该序列将为[0, 0.05, 0.2, 0.33, 0.5]
。
例如,原件0.27
将四舍五入为0.33
,并将0.42
四舍五入为0.5
到目前为止,这是我使用的方式,但是使用双循环当然会非常慢。
MWE:
arr = np.array([[0.14, 0.18], [0.20, 0.27]])
new = []
sequence = np.array([0, 0.05, 0.2, 0.33, 0.5])
for i in range(len(arr)):
row = []
for j in range(len(arr[0])):
temp = (arr[i][j] - sequence)**2
row.append(list(sequence[np.where(temp == min(temp))])[0])
new.append(row)
结果:
[[0.2000001, 0.2000001], [0.2000001, 0.33000001]]
动机:
在机器学习中,我正在做出预测。由于结果反映出专家的信心,因此2/3可能为1(因此为0.66)。因此,在该数据中,将发生相对较多的0、0.1、0.2、0.33、0.66、0.75等。但是我的预测约为0.1724。通过将这种情况下的数值四舍五入到0.2,可以消除很多预测误差。
如何优化舍入所有元素?
更新:现在我已经预分配了内存,因此不必进行常量追加。
# new = [[0]*len(arr[0])] * len(arr), then unloading into new[i][j],
# instead of appending
时间:
Original problem: 36.62 seconds
Pre-allocated array: 15.52 seconds
shx2 SOLUTION 1 (extra dimension): 0.47 seconds
shx2 SOLUTION 2 (better for big arrays): 4.39 seconds
Jaime's np.digitize: 0.02 seconds
可以构建另一个中间存储空间不大于要处理的数组的真正矢量化解决方案np.digitize
。
>>> def round_to_sequence(arr, seq):
... rnd_thresholds = np.add(seq[:-1], seq[1:]) / 2
... arr = np.asarray(arr)
... idx = np.digitize(arr.ravel(), rnd_thresholds).reshape(arr.shape)
... return np.take(seq, idx)
...
>>> round_to_sequence([[0.14, 0.18], [0.20, 0.27]],
... [0, 0.05, 0.2, 0.33, 0.5])
array([[ 0.2 , 0.2 ],
[ 0.2 , 0.33]])
更新所以发生了什么...该函数的第一行指出了序列中各项之间的中点是什么。此值是四舍五入的阈值:在其以下,您必须向下舍入,在其上方,您必须舍入。我使用np.add
,而不是更清晰,seq[:-1] + seq[1:]
因此它可以接受列表或元组,而无需将其显式转换为numpy数组。
>>> seq = [0, 0.05, 0.2, 0.33, 0.5]
>>> rnd_threshold = np.add(seq[:-1], seq[1:]) / 2
>>> rnd_threshold
array([ 0.025, 0.125, 0.265, 0.415])
接下来,我们使用np.digitize
数组中的每个项目找出由这些阈值界定的bin。np.digitize
只需要一维数组,因此我们必须做一些.ravel
额外的.reshape
事情以保持数组的原始形状。照原样,它使用标准惯例,即对限制项进行四舍五入,您可以使用right
关键字参数来逆转此行为。
>>> arr = np.array([[0.14, 0.18], [0.20, 0.27]])
>>> idx = np.digitize(arr.ravel(), seq).reshape(arr.shape)
>>> idx
array([[2, 2],
[3, 3]], dtype=int64)
现在,我们要做的就是创建一个形状为的数组idx
,使用其条目索引要舍入的值序列。可以用实现此目的seq[idx]
,但是使用(总是?)通常更快(请参见此处)np.take
。
>>> np.take(seq, idx)
array([[ 0.2 , 0.2 ],
[ 0.33, 0.33]])
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句