高维Python Numpy矩阵乘法

啊哈

我正在尝试寻找一种矩阵运算numpy,以加快以下计算的速度。

我有两个3D矩阵AB第一维表示示例,并且它们两个都有n_examples示例。我要实现的是对A和B中的每个示例进行点乘并求和:

import numpy as np

n_examples = 10
A = np.random.randn(n_examples, 20,30)
B = np.random.randn(n_examples, 30,5)
sum = np.zeros([20,5])
for i in range(len(A)):
  sum += np.dot(A[i],B[i])
索洛GP卡斯特罗

这是典型的应用程序np.tensordot()

sum = np.tensordot(A, B, [[0,2],[0,1]])

定时

使用以下代码:

import numpy as np

n_examples = 100
A = np.random.randn(n_examples, 20,30)
B = np.random.randn(n_examples, 30,5)

def sol1():
    sum = np.zeros([20,5])
    for i in range(len(A)):
      sum += np.dot(A[i],B[i])
    return sum

def sol2():
    return np.array(map(np.dot, A,B)).sum(0)

def sol3():
    return np.einsum('nmk,nkj->mj',A,B)

def sol4():
    return np.tensordot(A, B, [[2,0],[1,0]])

def sol5():
    return np.tensordot(A, B, [[0,2],[0,1]])

结果:

timeit sol1()
1000 loops, best of 3: 1.46 ms per loop

timeit sol2()
100 loops, best of 3: 4.22 ms per loop

timeit sol3()
1000 loops, best of 3: 1.87 ms per loop

timeit sol4()
10000 loops, best of 3: 205 µs per loop

timeit sol5()
10000 loops, best of 3: 172 µs per loop

在我的计算机上,这tensordot()是最快的解决方案,改变轴的评估顺序不会改变结果,也不会改变性能。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

高维Python Numpy矩阵乘法

来自分类Dev

Python NUMPY HUGE矩阵乘法

来自分类Dev

Python numpy:矩阵乘法给出错误的结果

来自分类Dev

比较Python,Numpy,Numba和C ++进行矩阵乘法

来自分类Dev

无法复制比较Python,Numpy和Numba矩阵乘法的结果

来自分类Dev

Python numpy:矩阵乘法给出错误的结果

来自分类Dev

使用numpy中的标量乘法对二维矩阵进行归一化

来自分类Dev

奇数维矩阵的矩阵乘法速度

来自分类Dev

Python张量矩阵乘法

来自分类Dev

numpy中的矩阵乘法

来自分类Dev

Numpy的多重矩阵乘法

来自分类Dev

numpy矩阵乘法行为

来自分类Dev

当到达numpy中两个不同维的数组时,点积和乘法矩阵是否相同

来自分类Dev

Python平行矩阵向量乘法

来自分类Dev

Python中的多维矩阵乘法

来自分类Dev

python的N×M矩阵乘法

来自分类Dev

Python语言中的矩阵乘法

来自分类Dev

使用Python进行矩阵乘法

来自分类Dev

Python矩阵乘法和缓存

来自分类Dev

Python:分而治之的递归矩阵乘法

来自分类Dev

python中用于矩形矩阵的稀疏矩阵矩阵乘法

来自分类Dev

Numpy矩阵的麻烦乘法列

来自分类Dev

Numpy中的乘法块矩阵

来自分类Dev

没有numpy的矩阵乘法

来自分类Dev

numpy中的矩阵加法/乘法

来自分类Dev

Numpy 数组和矩阵乘法

来自分类Dev

numpy中一维数组的乘法

来自分类Dev

二维数组的numpy列表乘法

来自分类Dev

一维数组在numpy中的乘法