我想使用JAX在CPU上加速numpy代码,之后在GPU上加速。这是在本地计算机(仅CPU)上运行的示例代码:
import jax.numpy as jnp
from jax import random, jix
import numpy as np
import time
size = 3000
key = random.PRNGKey(0)
x = random.normal(key, (size,size), dtype=jnp.float64)
start=time.time()
test = jnp.dot(x, x.T).block_until_ready()
print('Time of jnp: {}s'.format(time.time() - start))
x2=np.random.normal((size,size))
start=time.time()
test2 = np.dot(x2, x2.T)
print('Time of np: {}s'.format(time.time() - start))
我收到警告,时间成本如下:
/.../lib/python3.7/site-packages/jax/lib/xla_bridge.py:130:
UserWarning: No GPU/TPU found, falling back to CPU.
warnings.warn('No GPU/TPU found, falling back to CPU.')
Time: 0.45157814025878906s
Time: 0.005244255065917969s
我在这里做错了吗?JAX是否还应该在CPU上加速numpy代码?
Jax和Numpy之间可能存在性能差异,但是在原始帖子中,时间差异主要归因于数组创建中的一个错误。Jax使用的数组的形状为3000x3000,而Numpy使用的数组是长度为2的一维数组。to的第一个参数numpy.random.normal
为loc
(即,从中采样的高斯平均值)。关键字参数size=
应用于指示数组的形状。
numpy.random.normal(loc=0.0, scale=1.0, size=None)
进行此更改后,Jax和Numpy之间的性能就会有所不同。
import time
import jax
import jax.numpy as jnp
import numpy as np
size = 3000
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (size, size), dtype=jnp.float64)
start = time.time()
test = jnp.dot(x, x.T).block_until_ready()
print("Time of jnp: {:0.4f} s".format(time.time() - start))
x2 = np.random.normal(size=(size, size)).astype(np.float64)
start = time.time()
test2 = np.dot(x2, x2.T)
print("Time of np: {:0.4f} s".format(time.time() - start))
一轮的输出为
Time of jnp: 2.3315 s
Time of np: 2.8811 s
在测量定时性能时,应该收集多个运行,因为函数的性能是时间的分散而不是单个值。这可以通过Python标准库timeit.timeit
函数或%timeit
IPython和Jupyter Notebook中的魔术来完成。
import time
import jax
import jax.numpy as jnp
import numpy as np
size = 3000
key = jax.random.PRNGKey(0)
xjnp = jax.random.normal(key, shape=(size, size), dtype=jnp.float64)
xnp = np.random.normal(size=(size, size)).astype(np.float64)
%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.03 s ± 39.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.dot(xnp, xnp.T)
# 3.41 s ± 501 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
xjnp = xjnp.astype(jnp.float32)
xnp = xnp.astype(np.float32)
%timeit jnp.dot(xjnp, xjnp.T).block_until_ready()
# 2.05 s ± 74.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit np.dot(xnp, xnp.T)
# 1.73 s ± 383 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
看起来Numpy中有一个针对32位浮点数的优化点运算。
本文收集自互联网,转载请注明来源。
如有侵权,请联系[email protected] 删除。
我来说两句