欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 教育 > 幼教 > 【深度学习框架】JAX:高效的数值计算与深度学习框架

【深度学习框架】JAX:高效的数值计算与深度学习框架

2025/2/25 0:14:49 来源:https://blog.csdn.net/IT_ORACLE/article/details/145507350  浏览:    关键词:【深度学习框架】JAX:高效的数值计算与深度学习框架

1. 什么是 JAX?

JAX 是由 Google Research 开发的 高性能数值计算库,主要用于 机器学习、深度学习科学计算。它基于 NumPy 的 API,但提供了 自动微分(Autograd)XLA 编译加速高效的 GPU/TPU 计算,使其成为 TensorFlow 和 PyTorch 的强劲竞争者


2. JAX 的核心特点

1️⃣ 自动微分(Autograd)

  • JAX 提供 前向(Forward-mode)反向(Reverse-mode) 自动微分,适用于各种梯度计算任务,如 深度学习、强化学习、物理模拟 等。

2️⃣ JIT 编译(加速计算)

  • JAX 使用 XLA(Accelerated Linear Algebra) 进行 Just-In-Time(JIT)编译,大幅提升计算速度,类似 TensorFlow 的 Graph Execution。

3️⃣ 并行计算(Vectorization & GPU/TPU 加速)

  • vmap(自动向量化):自动将标量操作向量化,提高效率。
  • pmap(多 GPU/TPU 并行化):轻松实现数据并行计算。

4️⃣ 兼容 NumPy

  • JAX 的 API 设计 类似 NumPy,但所有计算都是 不可变(Immutable) 的,适合 函数式编程

3. 安装 JAX

JAX 可以通过 pip 安装:

# 仅支持 CPU
pip install jax# 支持 GPU(CUDA 版本)
pip install jax[cuda12_pip] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

注意:如果使用 GPU,需要安装 正确版本的 CUDA 和 cuDNN


4. JAX 基本用法

1️⃣ JAX 作为 NumPy 替代

JAX 提供 jax.numpy,它的 API 近似于 NumPy,但支持 GPU/TPU 加速:

import jax.numpy as jnp# 创建 JAX 数组(不可变)
x = jnp.array([1.0, 2.0, 3.0])# 矩阵运算(在 GPU 上计算)
y = jnp.dot(x, x)
print(y)  # 14.0


2️⃣ 自动微分(grad)

JAX 提供 jax.grad() 计算标量函数的梯度:

import jax# 定义函数:y = x^2
def f(x):return x ** 2# 计算导数 dy/dx
grad_f = jax.grad(f)# 在 x=3 处求导
print(grad_f(3.0))  # 6.0

⚠️ 重要说明
  • grad(f) 只能用于标量输出(如损失函数)。
  • 如果是 向量输出,可以使用 jax.jacfwd()jax.jacrev() 计算 雅可比矩阵

3️⃣ JIT 编译(加速计算)

使用 jax.jit() 可以 即时编译 代码,提高计算速度:

import jax.numpy as jnp
from jax import jit# 定义函数
def slow_func(x):return jnp.sin(x) + jnp.cos(x)# JIT 编译加速
fast_func = jit(slow_func)# 测试计算
import timex = jnp.linspace(0, 10, 1000)# 普通计算
start = time.time()
slow_func(x).block_until_ready()  # 确保计算完成
print("Normal:", time.time() - start)# JIT 编译后计算
start = time.time()
fast_func(x).block_until_ready()
print("JIT Compiled:", time.time() - start)

 运行结果

Normal: 0.04390454292297363
JIT Compiled: 0.01696181297302246


4️⃣ 向量化计算(vmap)

使用 jax.vmap() 自动向量化 计算,提高批量处理效率:

from jax import vmap
import jax.numpy as jnp# 定义标量函数
def f(x):return x ** 2# 直接计算(需要手写 for 循环)
x = jnp.array([1.0, 2.0, 3.0])
print(f(x))  # 错误!f 只能处理标量# 使用 vmap 自动向量化
vectorized_f = vmap(f)
print(vectorized_f(x))  # [1.0, 4.0, 9.0]

运行结果 

[1. 4. 9.]
[1. 4. 9.]


5️⃣ 并行计算(pmap)

使用 jax.pmap() 进行 多 GPU/TPU 并行计算

from jax import pmap
import jax.numpy as jnp# 定义一个简单的计算函数
def f(x):return x ** 2 + 2 * x + 1# 在多个设备上并行计算
parallel_f = pmap(f)# 输入数据(假设有 2 张 GPU)
x = jnp.array([1.0, 2.0])
print(parallel_f(x))  # [4.0, 9.0]


5. 使用 JAX 训练深度学习模型

使用 JAX 训练一个简单的 逻辑回归模型

import jax.numpy as jnp
import jax
from jax import grad, jit
from jax.scipy.special import expit as sigmoid# 生成随机数据
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (100, 2))  # 100 个样本,每个 2 维
y = (X[:, 0] + X[:, 1] > 0).astype(jnp.float32)  # 线性分类任务# 初始化权重
w = jax.random.normal(key, (2,))
b = 0.0# 定义损失函数(交叉熵)
def loss_fn(w, b, X, y):logits = jnp.dot(X, w) + breturn -jnp.mean(y * jnp.log(sigmoid(logits)) + (1 - y) * jnp.log(1 - sigmoid(logits)))# 计算梯度
grad_fn = grad(loss_fn)# 训练循环
lr = 0.1
for i in range(100):grads = grad_fn(w, b, X, y)w -= lr * grads[0]b -= lr * grads[1]print("训练完成:", w, b)

 运行结果

训练完成: [1.4514779 3.0926886] -0.20143564


6. JAX vs. PyTorch vs. TensorFlow

特性JAXPyTorchTensorFlow
计算方式函数式命令式符号式+命令式
GPU/TPU✅ 强大✅ 强大✅ 强大
自动微分✅ 强大(Autograd)✅(torch.autograd)✅(tf.GradientTape)
JIT 编译✅(XLA)✅(XLA)
并行计算pmap❌ 需要 DDPtf.distribute
适用场景数学、优化、强化学习、深度学习深度学习、CV、NLP大规模 AI 训练

7. 总结

  • JAX 适合需要高性能计算的 AI 研究,尤其是 强化学习、物理模拟、自动微分优化 等任务。
  • JIT 编译、自动微分和 GPU/TPU 并行化 让 JAX 比 NumPy、PyTorch 更高效。
  • JAX 代码风格简洁,与 NumPy 兼容,但 学习曲线比 PyTorch/TensorFlow 稍陡峭

JAX 是一个未来 AI 计算的重要工具,适用于高效数值计算和深度学习,尤其适合 Google Cloud、TPU 和科学计算 领域!

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词