fn
기본적 cosine distance
으로 두 개의 큰 숫자 배열 (10000, 100) 및 (5000, 100) 행 단위로 계산 하는 함수를 적용하고 싶습니다 . 즉,이 배열의 각 행 조합에 대한 값을 계산합니다.
내 구현 :
import math
def fn(v1,v2):
sumxx, sumxy, sumyy = 0, 0, 0
for i in range(len(v1)):
x = v1[i]; y = v2[i]
sumxx += x*x
sumyy += y*y
sumxy += x*y
return sumxy/math.sqrt(sumxx*sumyy)
val = []
for i in range(array1.shape[0]):
for j in range(array2.shape[0]):
val.append(fn(array1[i, :], array2[j, :]))
이 기능은 매우 빠르고 몇 ms 밖에 걸리지 않습니다.
CPU times: user 4 ms, sys: 0 ns, total: 4 ms
Wall time: 1.24 ms
이를 수행하는 효율적인 방법이 있습니까?
접근 방법 # 1 : 우리는 간단하게 사용할 수있는 Scipy's cdist
그와 cosine
거리 기능 -
from scipy.spatial.distance import cdist
val_out = 1 - cdist(array1, array2, 'cosine')
접근 방법 # 2 : 사용하는 또 다른 방법 matrix-multiplication
-
def cosine_vectorized(array1, array2):
sumyy = (array2**2).sum(1)
sumxx = (array1**2).sum(1, keepdims=1)
sumxy = array1.dot(array2.T)
return (sumxy/np.sqrt(sumxx))/np.sqrt(sumyy)
접근법 # 3 :np.einsum
다른 것에 대한 자기 제곱 합계를 계산하는 데 사용 -
def cosine_vectorized_v2(array1, array2):
sumyy = np.einsum('ij,ij->i',array2,array2)
sumxx = np.einsum('ij,ij->i',array1,array1)[:,None]
sumxy = array1.dot(array2.T)
return (sumxy/np.sqrt(sumxx))/np.sqrt(sumyy)
접근 방법 # 4 : 다른 방법에 대한 계산 부담을 줄이기 위해 numexpr
모듈 가져 오기 square-root
-
import numexpr as ne
def cosine_vectorized_v3(array1, array2):
sumyy = np.einsum('ij,ij->i',array2,array2)
sumxx = np.einsum('ij,ij->i',array1,array1)[:,None]
sumxy = array1.dot(array2.T)
sqrt_sumxx = ne.evaluate('sqrt(sumxx)')
sqrt_sumyy = ne.evaluate('sqrt(sumyy)')
return ne.evaluate('(sumxy/sqrt_sumxx)/sqrt_sumyy')
런타임 테스트
# Using same sizes as stated in the question
In [185]: array1 = np.random.rand(10000,100)
...: array2 = np.random.rand(5000,100)
...:
In [194]: %timeit 1 - cdist(array1, array2, 'cosine')
1 loops, best of 3: 366 ms per loop
In [195]: %timeit cosine_vectorized(array1, array2)
1 loops, best of 3: 287 ms per loop
In [196]: %timeit cosine_vectorized_v2(array1, array2)
1 loops, best of 3: 283 ms per loop
In [197]: %timeit cosine_vectorized_v3(array1, array2)
1 loops, best of 3: 217 ms per loop
이 기사는 인터넷에서 수집됩니다. 재 인쇄 할 때 출처를 알려주십시오.
침해가 발생한 경우 연락 주시기 바랍니다[email protected] 삭제
몇 마디 만하겠습니다