1. 什么是 JAX?
JAX 是由 Google Research 开发的 高性能数值计算库,主要用于 机器学习、深度学习 和 科学计算。它基于 NumPy 的 API,但提供了 自动微分(Autograd)、XLA 编译加速 和 高效的 GPU/TPU 计算,使其成为 TensorFlow 和 PyTorch 的强劲竞争者。
2. JAX 的核心特点
1️⃣ 自动微分(Autograd)
- JAX 提供 前向(Forward-mode) 和 反向(Reverse-mode) 自动微分,适用于各种梯度计算任务,如 深度学习、强化学习、物理模拟 等。
2️⃣ JIT 编译(加速计算)
- JAX 使用 XLA(Accelerated Linear Algebra) 进行 Just-In-Time(JIT)编译,大幅提升计算速度,类似 TensorFlow 的 Graph Execution。
3️⃣ 并行计算(Vectorization & GPU/TPU 加速)
vmap
(自动向量化):自动将标量操作向量化,提高效率。pmap
(多 GPU/TPU 并行化):轻松实现数据并行计算。
4️⃣ 兼容 NumPy
- JAX 的 API 设计 类似 NumPy,但所有计算都是 不可变(Immutable) 的,适合 函数式编程。
3. 安装 JAX
JAX 可以通过 pip
安装:
# 仅支持 CPU
pip install jax# 支持 GPU(CUDA 版本)
pip install jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
注意:如果使用 GPU,需要安装 正确版本的 CUDA 和 cuDNN。
4. JAX 基本用法
1️⃣ JAX 作为 NumPy 替代
JAX 提供 jax.numpy
,它的 API 近似于 NumPy,但支持 GPU/TPU 加速:
import jax.numpy as jnp# 创建 JAX 数组(不可变)
x = jnp.array([1.0, 2.0, 3.0])# 矩阵运算(在 GPU 上计算)
y = jnp.dot(x, x)
print(y) # 14.0
2️⃣ 自动微分(grad)
JAX 提供 jax.grad()
计算标量函数的梯度:
import jax# 定义函数:y = x^2
def f(x):return x ** 2# 计算导数 dy/dx
grad_f = jax.grad(f)# 在 x=3 处求导
print(grad_f(3.0)) # 6.0
⚠️ 重要说明
grad(f)
只能用于标量输出(如损失函数)。- 如果是 向量输出,可以使用
jax.jacfwd()
或jax.jacrev()
计算 雅可比矩阵。
3️⃣ JIT 编译(加速计算)
使用 jax.jit()
可以 即时编译 代码,提高计算速度:
import jax.numpy as jnp
from jax import jit# 定义函数
def slow_func(x):return jnp.sin(x) + jnp.cos(x)# JIT 编译加速
fast_func = jit(slow_func)# 测试计算
import timex = jnp.linspace(0, 10, 1000)# 普通计算
start = time.time()
slow_func(x).block_until_ready() # 确保计算完成
print("Normal:", time.time() - start)# JIT 编译后计算
start = time.time()
fast_func(x).block_until_ready()
print("JIT Compiled:", time.time() - start)
运行结果
Normal: 0.04390454292297363
JIT Compiled: 0.01696181297302246
4️⃣ 向量化计算(vmap)
使用 jax.vmap()
自动向量化 计算,提高批量处理效率:
from jax import vmap
import jax.numpy as jnp# 定义标量函数
def f(x):return x ** 2# 直接计算(需要手写 for 循环)
x = jnp.array([1.0, 2.0, 3.0])
print(f(x)) # 错误!f 只能处理标量# 使用 vmap 自动向量化
vectorized_f = vmap(f)
print(vectorized_f(x)) # [1.0, 4.0, 9.0]
运行结果
[1. 4. 9.]
[1. 4. 9.]
5️⃣ 并行计算(pmap)
使用 jax.pmap()
进行 多 GPU/TPU 并行计算:
from jax import pmap
import jax.numpy as jnp# 定义一个简单的计算函数
def f(x):return x ** 2 + 2 * x + 1# 在多个设备上并行计算
parallel_f = pmap(f)# 输入数据(假设有 2 张 GPU)
x = jnp.array([1.0, 2.0])
print(parallel_f(x)) # [4.0, 9.0]
5. 使用 JAX 训练深度学习模型
使用 JAX 训练一个简单的 逻辑回归模型:
import jax.numpy as jnp
import jax
from jax import grad, jit
from jax.scipy.special import expit as sigmoid# 生成随机数据
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 2)) # 100 个样本,每个 2 维
y = (X[:, 0] + X[:, 1] > 0).astype(jnp.float32) # 线性分类任务# 初始化权重
w = jax.random.normal(key, (2,))
b = 0.0# 定义损失函数(交叉熵)
def loss_fn(w, b, X, y):logits = jnp.dot(X, w) + breturn -jnp.mean(y * jnp.log(sigmoid(logits)) + (1 - y) * jnp.log(1 - sigmoid(logits)))# 计算梯度
grad_fn = grad(loss_fn)# 训练循环
lr = 0.1
for i in range(100):grads = grad_fn(w, b, X, y)w -= lr * grads[0]b -= lr * grads[1]print("训练完成:", w, b)
运行结果
训练完成: [1.4514779 3.0926886] -0.20143564
6. JAX vs. PyTorch vs. TensorFlow
特性 | JAX | PyTorch | TensorFlow |
---|---|---|---|
计算方式 | 函数式 | 命令式 | 符号式+命令式 |
GPU/TPU | ✅ 强大 | ✅ 强大 | ✅ 强大 |
自动微分 | ✅ 强大(Autograd) | ✅(torch.autograd) | ✅(tf.GradientTape) |
JIT 编译 | ✅(XLA) | ❌ | ✅(XLA) |
并行计算 | ✅ pmap | ❌ 需要 DDP | ✅ tf.distribute |
适用场景 | 数学、优化、强化学习、深度学习 | 深度学习、CV、NLP | 大规模 AI 训练 |
7. 总结
- JAX 适合需要高性能计算的 AI 研究,尤其是 强化学习、物理模拟、自动微分优化 等任务。
- JIT 编译、自动微分和 GPU/TPU 并行化 让 JAX 比 NumPy、PyTorch 更高效。
- JAX 代码风格简洁,与 NumPy 兼容,但 学习曲线比 PyTorch/TensorFlow 稍陡峭。
JAX 是一个未来 AI 计算的重要工具,适用于高效数值计算和深度学习,尤其适合 Google Cloud、TPU 和科学计算 领域!