给定一个data
整数的Nx2 numpy numpy数组(我们可以假设data
没有重复的行),我只需要保留其元素满足关系的行
(data[i,0] == data[j,1]) & (data[i,1] == data[j,0])
例如与
import numpy as np
data = np.array([[1, 2],
[2, 1],
[7, 3],
[6, 6],
[5, 6]])
我应该回来
array([[1, 2], # because 2,1 is present
[2, 1], # because 1,2 is present
[6, 6]]) # because 6,6 is present
一种详细的方法是
def filter_symmetric_pairs(data):
result = np.empty((0,2))
for i in range(len(data)):
for j in range(len(data)):
if (data[i,0] == data[j,1]) & (data[i,1] == data[j,0]):
result = np.vstack([result, data[i,:]])
return result
我想出了一个更简洁的方法:
def filter_symmetric_pairs(data):
return data[[row.tolist() in data[:,::-1].tolist() for row in data]]
有人可以建议一个更好的numpy成语吗?
这是您可以使用的几种不同方法。第一个是“显而易见的”二次方解决方案,它很简单,但是如果输入数组很大,可能会给您带来麻烦。只要输入中没有很大范围的数字,第二个就应该可以工作,并且它的优点是可以处理线性数量的内存。
import numpy as np
# Input data
data = np.array([[1, 2],
[2, 1],
[7, 3],
[6, 6],
[5, 6]])
# Method 1 (quadratic memory)
d0, d1 = data[:, 0, np.newaxis], data[:, 1]
# Compare all values in first column to all values in second column
c = d0 == d1
# Find where comparison matches both ways
c &= c.T
# Get matching elements
res = data[c.any(0)]
print(res)
# [[1 2]
# [2 1]
# [6 6]]
# Method 2 (linear memory)
# Convert pairs into single values
# (assumes positive values, otherwise shift first)
n = data.max() + 1
v = data[:, 0] + (n * data[:, 1])
# Symmetric values
v2 = (n * data[:, 0]) + data[:, 1]
# Find where symmetric is present
m = np.isin(v2, v)
res = data[m]
print(res)
# [[1 2]
# [2 1]
# [6 6]]
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句