为什么jax.numpy.dot()在CPU上的运行速度比numpy.dot()更慢?

fc

我想使用JAX在CPU上加速numpy代码,之后在GPU上加速。这是在本地计算机(仅CPU)上运行的示例代码:

import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time

size = 3000

key = random.PRNGKey(0)
x =  random.normal(key, (size,size), dtype=jnp.float64)

start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))

x2=np.random.normal((size,size))

start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))

我收到警告,时间成本如下:

/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130: 
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s

我在这里做错了吗?JAX是否还应该在CPU上加速numpy代码?

雅各布

Jax和Numpy之间可能存在性能差异,但是在原始帖子中,时间差异主要归因于数组创建中的一个错误。Jax使用的数组的形状为3000x3000,而Numpy使用的数组是长度为2的一维数组。to的第一个参数numpy.random.normalloc(即,从中采样的高斯平均值)。关键字参数size=应用于指示数组的形状。

numpy.random.normal(loc=0.0, scale=1.0, size=None)

进行此更改后,Jax和Numpy之间的性能就会有所不同。

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)

start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))

x2 = np.random.normal(size=(size, size)).astype(np.float64)

start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))

一轮的输出为

Time of jnp: 2.3315 s
Time of np: 2.8811 s

在测量定时性能时,应该收集多个运行,因为函数的性能是时间的分散而不是单个值。这可以通过Python标准库timeit.timeit函数或%timeitIPython和Jupyter Notebook中魔术来完成

import time
import jax
import jax.numpy as jnp
import numpy as np

size = 3000

key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)

%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

看起来Numpy中有一个针对32位浮点数的优化点运算。

本文收集自互联网,转载请注明来源。

如有侵权,请联系[email protected] 删除。

编辑于
0

我来说两句

0条评论
登录后参与评论

相关文章

来自分类Dev

为什么numpy.dot()抛出ValueError:形状未对齐?

来自分类Dev

numpy.dot-> MemoryError,my_dot->非常慢,但是可以。为什么?

来自分类Dev

为什么在NumPy中填充FFT会使运行速度慢得多?

来自分类Dev

numpy.dot的逆

来自分类Dev

多维数组上的Numpy np.dot()

来自分类Dev

为什么X.dot(XT)在numpy中需要这么多的内存?

来自分类Dev

为什么numpy.dot与矩阵乘法的这些GPU实现一样快?

来自分类Dev

Numpy等价于dot(A,B,3)

来自分类Dev

numpy.dot的意外结果

来自分类Dev

为什么Google Cloud Computing Engine比numpy.dot操作规格较低的专用服务器慢100倍

来自分类Dev

Python矩阵提供numpy.dot()

来自分类Dev

Numpy.dot()尺寸未对齐

来自分类Dev

加速列表理解内的numpy.dot

来自分类Dev

在循环中简化 numpy.dot

来自分类Dev

Python:用于多维数组的numpy.dot / numpy.tensordot

来自分类Dev

PyTorch:什么是PyTorch中的numpy.linalg.multi_dot()等效项

来自分类Dev

为什么numpy的root失败?

来自分类Dev

为什么numpy视图向后?

来自分类Dev

为什么 numpy 导入失败?

来自分类Dev

numpy.dot速度很慢,但已经安装了blas和lapack,该如何解决?

来自分类Dev

numpy.dot速度很慢,但已经安装了blas和lapack,该如何解决?

来自分类Dev

为什么numpy向量化不能提高我的代码速度

来自分类Dev

为什么嵌套字典会减慢 numpy 保存速度?

来自分类Dev

相同的规格,但 Ubuntu 比 Windows 更慢。为什么?

来自分类Dev

为什么我的Python脚本运行速度比HeapSort实现上的速度慢?

来自分类Dev

为什么在这段代码中 CPU 运行速度比 GPU 快?

来自分类Dev

如何获得比numpy.dot更快的代码进行矩阵乘法?

来自分类Dev

我如何找出A * B是Numpy中的Hadamard或Dot产品?

来自分类Dev

使用numpy.dot乘以两个巨大的矩阵

Related 相关文章

  1. 1

    为什么numpy.dot()抛出ValueError:形状未对齐?

  2. 2

    numpy.dot-> MemoryError,my_dot->非常慢,但是可以。为什么?

  3. 3

    为什么在NumPy中填充FFT会使运行速度慢得多?

  4. 4

    numpy.dot的逆

  5. 5

    多维数组上的Numpy np.dot()

  6. 6

    为什么X.dot(XT)在numpy中需要这么多的内存?

  7. 7

    为什么numpy.dot与矩阵乘法的这些GPU实现一样快?

  8. 8

    Numpy等价于dot(A,B,3)

  9. 9

    numpy.dot的意外结果

  10. 10

    为什么Google Cloud Computing Engine比numpy.dot操作规格较低的专用服务器慢100倍

  11. 11

    Python矩阵提供numpy.dot()

  12. 12

    Numpy.dot()尺寸未对齐

  13. 13

    加速列表理解内的numpy.dot

  14. 14

    在循环中简化 numpy.dot

  15. 15

    Python:用于多维数组的numpy.dot / numpy.tensordot

  16. 16

    PyTorch:什么是PyTorch中的numpy.linalg.multi_dot()等效项

  17. 17

    为什么numpy的root失败?

  18. 18

    为什么numpy视图向后?

  19. 19

    为什么 numpy 导入失败?

  20. 20

    numpy.dot速度很慢,但已经安装了blas和lapack,该如何解决?

  21. 21

    numpy.dot速度很慢,但已经安装了blas和lapack,该如何解决?

  22. 22

    为什么numpy向量化不能提高我的代码速度

  23. 23

    为什么嵌套字典会减慢 numpy 保存速度?

  24. 24

    相同的规格,但 Ubuntu 比 Windows 更慢。为什么?

  25. 25

    为什么我的Python脚本运行速度比HeapSort实现上的速度慢?

  26. 26

    为什么在这段代码中 CPU 运行速度比 GPU 快?

  27. 27

    如何获得比numpy.dot更快的代码进行矩阵乘法?

  28. 28

    我如何找出A * B是Numpy中的Hadamard或Dot产品?

  29. 29

    使用numpy.dot乘以两个巨大的矩阵

热门标签

归档