Google JAX

Google JAX,是Google开发的用于变换数值函数的Python机器学习框架[3][4][5]。它结合了修改版本的Autograd(自动通过函数的微分获得其梯度函数)[6],和TensorFlow的XLA(加速线性代数[7]。它被设计为尽可能的遵从NumPy的结构和工作流程,并协同工作于各种现存的框架如TensorFlowPyTorch[8][9]

JAX
開發者Google
首次发布2019年10月31日2019-10-31[1]
当前版本
  • 0.4.24 (2024年2月6日;穩定版本)[2]
源代码库github.com/google/jax
编程语言Python, C++
操作系统Linux, macOS, Windows
平台Python, NumPy
类型机器学习
许可协议Apache 2.0
网站jax.readthedocs.io/en/latest/ 编辑维基数据

主要功能

JAX的主要功能是[3]

  • grad:自动微分,
  • jit:即时编译,
  • vmap:自动向量化,
  • pmap:SPMD编程。

grad

下面的代码演示grad函数的自动微分。

# 导入库
from jax import grad
import jax.numpy as jnp

# 定义logistic函数
def logistic(x):  
    return jnp.exp(x) / (jnp.exp(x) + 1)

# 获得logistic函数的梯度函数
grad_logistic = grad(logistic)

# 求值logistic函数在x = 1处的梯度 
grad_log_out = grad_logistic(1.0)   
print(grad_log_out)

最终的输出为:

0.19661194

jit

下面的代码演示jit函数的优化。

# 导入库
from jax import jit
import jax.numpy as jnp

# 定义cube函数
def cube(x):
    return x * x * x

# 生成数据
x = jnp.ones((10000, 10000))

# 创建cube函数的jit版本
jit_cube = jit(cube)

# 应用cube函数和jit_cube函数于相同数据来比较其速度
cube(x)
jit_cube(x)

可见jit_cube的运行时间显著的短于cube

vmap

下面的代码展示vmap函数的通过SIMD的向量化。

# 导入库
from functools import partial
from jax import vmap
import jax.numpy as jnp

# 定义函数
def grads(self, inputs):
    in_grad_partial = partial(self._net_grads, self._net_params)
    grad_vmap = vmap(in_grad_partial)
    rich_grads = grad_vmap(inputs)
    flat_grads = np.asarray(self._flatten_batch(rich_grads))
    assert flat_grads.ndim == 2 and flat_grads.shape[0] == inputs.shape[0]
    return flat_grads

pmap

下面的代码展示pmap函数的对矩阵乘法的并行化。

# 从JAX导入pmap和random;导入JAX NumPy
from jax import pmap, random
import jax.numpy as jnp

# 生成2个维度为5000 x 6000的随机数矩阵,每设备一个
random_keys = random.split(random.PRNGKey(0), 2)
matrices = pmap(lambda key: random.normal(key, (5000, 6000)))(random_keys)

# 没有数据传输,并行的在每个CPU/GPU上进行局部矩阵乘法 
outputs = pmap(lambda x: jnp.dot(x, x.T))(matrices)

# 没有数据传输,并行的在每个CPU/GPU上分别求取这两个矩阵的均值
means = pmap(jnp.mean)(outputs)
print(means)

最终的输出为:

[1.1566595 1.1805978]

使用JAX的库

一些Python库使用JAX作为后端,这包括:

参见

引用

  1. .
  2. https://github.com/google/jax/releases/tag/jax-v0.4.24.
  3. Bradbury, James; Frostig, Roy; Hawkins, Peter; Johnson, Matthew James; Leary, Chris; MacLaurin, Dougal; Necula, George; Paszke, Adam; Vanderplas, Jake; Wanderman-Milne, Skye; Zhang, Qiao, , Astrophysics Source Code Library (Google), 2022-06-18 [2022-06-18], Bibcode:2021ascl.soft11002B, (原始内容存档于2022-06-18)
  4. Frostig, Roy; Johnson, Matthew James; Leary, Chris. (PDF). MLsys. 2018-02-02: 1–3. (原始内容存档 (PDF)于2022-06-21).
  5. . www.deepmind.com. [2022-06-18]. (原始内容存档于2022-06-18) (英语).
  6. . [2023-09-23]. (原始内容存档于2022-07-18).
  7. . [2023-09-23]. (原始内容存档于2022-09-01).
  8. Lynley, Matthew. . Business Insider. [2022-06-21]. (原始内容存档于2022-06-21) (美国英语).
  9. . Analytics India Magazine. 2022-04-25 [2022-06-18]. (原始内容存档于2022-06-18) (美国英语).
  10. , Google, 2022-07-29 [2022-07-29], (原始内容存档于2022-09-03)
  11. Kidger, Patrick, , 2022-07-29 [2022-07-29], (原始内容存档于2023-09-19)
  12. Kidger, Patrick, , 2023-08-05 [2023-08-08], (原始内容存档于2023-08-10)
  13. , DeepMind, 2022-07-28 [2022-07-29], (原始内容存档于2023-06-07)
  14. , Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10)
  15. , DeepMind, 2022-07-29 [2022-07-29], (原始内容存档于2023-04-26)
  16. , DeepMind, 2023-08-08 [2023-08-08], (原始内容存档于2022-11-23)
  17. , Google, 2023-08-08 [2023-08-08], (原始内容存档于2023-08-10)
  18. . [2022-08-31]. (原始内容存档于2022-08-31).
  19. . [2022-08-31]. (原始内容存档于2022-08-31).

外部链接


This article is issued from Wikipedia. The text is licensed under Creative Commons - Attribution - Sharealike. Additional terms may apply for the media files.