欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 新闻 > 资讯 > Pytorch的自动求导模块

Pytorch的自动求导模块

2025/1/5 5:11:17 来源:https://blog.csdn.net/qq_18055167/article/details/144835883  浏览:    关键词:Pytorch的自动求导模块

文章目录

  • torch.autograd.backward()
    • 基本用法
    • 非标量张量的反向传播
    • 保留计算图
    • 指定输入张量
    • 高阶梯度计算
  • 与 y.backward() 的区别
  • torch.autograd.grad()
    • 基本用法
    • 非标量张量的梯度
    • 高阶梯度计算
    • 多输入、多输出的梯度计算
    • 未使用的输入张量
    • 保留计算图
  • 与 backward() 的区别

torch.autograd.backward()

该函数实现自动求导梯度,函数如下:

torch.autograd.backward(tensors, grad_tensors=None, retain_graph=False, create_graph=False, inputs=None)

参数介绍:

  • tensors: 需要对其进行反向传播的目标张量(或张量列表),例如:loss。
    这些张量通常是计算图的最终输出。
  • grad_tensors:与 tensors 对应的梯度权重(或权重列表)。
    如果 tensors 是标量张量(单个值),可以省略此参数。
    如果 tensors 是非标量张量(如向量或矩阵),则必须提供 grad_tensors,表示每个张量的梯度权重。例如:当有多个loss需要计算梯度时,需要设置每个loss的权值。
  • retain_graph:是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次反向传播,需设置为 True。
  • create_graph: 是否创建一个新的计算图,用于高阶梯度计算
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • inputs: 指定需要计算梯度的输入张量(或张量列表)。
    如果指定了此参数,只有这些张量的 .grad 属性会被更新,而不是整个计算图中的所有张量。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.backward() 触发反向传播  
torch.autograd.backward(y)  # 查看梯度  
print(x.grad)  # 输出:4.0 (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的反向传播

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_tensors 权重  
grad_tensors = torch.tensor([1.0, 1.0, 1.0])  # 权重  
torch.autograd.backward(y, grad_tensors=grad_tensors)  # 查看梯度  
print(x.grad)  # 输出:[2.0, 4.0, 6.0] (dy/dx = 2x)

保留计算图

如果需要多次调用反向传播,可以设置 retain_graph=True。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2, 当 x=2 时,dy/dx=12)  # 第二次反向传播  
torch.autograd.backward(y, retain_graph=True)  
print(x.grad)  # 输出:24.0 (梯度累积,12.0 + 12.0)

指定输入张量

通过 inputs 参数,可以只计算指定张量的梯度,而忽略其他张量。

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2 + z ** 3  # y = x^2 + z^3  # 只计算 x 的梯度  
torch.autograd.backward(y, inputs=[x])  
print(x.grad)  # 输出:4.0 (dy/dx = 2x)  
print(z.grad)  # 输出:None (未计算 z 的梯度)

高阶梯度计算

通过设置 create_graph=True,可以构建新的计算图,用于计算高阶梯度。

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次反向传播,创建新的计算图  
torch.autograd.backward(y, create_graph=True)  
print(x.grad)  # 输出:12.0 (dy/dx = 3x^2)  # 计算二阶梯度  
x_grad = x.grad  
x_grad.backward()  
print(x.grad)  # 输出:18.0 (d^2y/dx^2 = 6x)

与 y.backward() 的区别

  • 灵活性:

    • torch.autograd.backward() 更灵活,可以对多个张量同时进行反向传播,并指定梯度权重。
    • y.backward() 是对单个张量的简单封装,适合常见场景。对多个loss求导时,需要指定gradient和grad_outputs相同作用。
  • 梯度权重:

    • torch.autograd.backward() 需要显式提供 grad_tensors 参数(如果目标张量是非标量)。
    • y.backward() 会自动处理标量张量,非标量张量需要手动传入权重。
  • 输入控制:

    • torch.autograd.backward() 可以通过 inputs 参数指定只计算某些张量的梯度。
    • y.backward() 无法直接控制,只会更新计算图中所有相关张量的 .grad。

torch.autograd.grad()

torch.autograd.grad() 是 PyTorch 中用于计算张量梯度的函数,与 backward() 不同的是,它不会更新张量的 .grad 属性,而是直接返回计算的梯度值。它适用于需要手动获取梯度值而不修改计算图中张量的 .grad 属性的场景。

torch.autograd.grad(  outputs,   inputs,   grad_outputs=None,   retain_graph=False,   create_graph=False,   only_inputs=True,   allow_unused=False  
)

参数介绍:

  • outputs:
    目标张量(或张量列表),即需要对其进行求导的输出张量。
  • inputs:
    需要计算梯度的输入张量(或张量列表)。
    这些张量必须启用了 requires_grad=True。
  • grad_outputs:
    与 outputs 对应的梯度权重(或权重列表)。
    如果 outputs 是标量张量,可以省略此参数;如果是非标量张量,则需要提供权重,表示每个输出张量的梯度权重。
  • retain_graph:
    是否保留计算图。
    默认值为 False,即反向传播后会释放计算图。如果需要多次计算梯度,需设置为 True。
  • create_graph:
    是否创建一个新的计算图,用于高阶梯度计算。
    默认值为 False,如果需要计算二阶或更高阶梯度,需设置为 True。
  • only_inputs:
    是否只对 inputs 中的张量计算梯度。
    默认值为 True,表示只计算 inputs 的梯度。
  • allow_unused:
    是否允许 inputs 中的某些张量未被 outputs 使用。
    默认值为 False,如果某些 inputs 未被 outputs 使用,会抛出错误。如果设置为 True,未使用的张量的梯度会返回 None。

返回值:

  • 返回一个元组,包含 inputs 中每个张量的梯度值。
  • 如果某个输入张量未被 outputs 使用,且 allow_unused=True,则对应的梯度为 None。

基本用法

import torch  # 定义张量并启用梯度计算  
x = torch.tensor(2.0, requires_grad=True)  
y = x ** 2  # y = x^2  # 使用 torch.autograd.grad() 计算梯度  
grad = torch.autograd.grad(y, x)  
print(grad)  # 输出:(4.0,) (dy/dx = 2x, 当 x=2 时,dy/dx=4)

非标量张量的梯度

当目标张量是非标量时,需要提供 grad_outputs 参数:

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  
y = x ** 2  # y = [x1^2, x2^2, x3^2]  # 指定 grad_outputs 权重  
grad_outputs = torch.tensor([1.0, 1.0, 1.0])  # 权重  
grad = torch.autograd.grad(y, x, grad_outputs=grad_outputs)  
print(grad)  # 输出:(tensor([2.0, 4.0, 6.0]),) (dy/dx = 2x)

高阶梯度计算

通过设置 create_graph=True,可以计算高阶梯度:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad = torch.autograd.grad(y, x, create_graph=True)  
print(grad)  # 输出:(12.0,) (dy/dx = 3x^2)  # 计算二阶梯度  
grad2 = torch.autograd.grad(grad[0], x)  
print(grad2)  # 输出:(6.0,) (d^2y/dx^2 = 6x)

多输入、多输出的梯度计算

可以对多个输入和输出同时计算梯度:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y1 = x ** 2 + z ** 3  # y1 = x^2 + z^3  
y2 = x * z  # y2 = x * z  # 对多个输入计算梯度  
grads = torch.autograd.grad([y1, y2], [x, z], grad_outputs=[torch.tensor(1.0), torch.tensor(1.0)])  
print(grads)  # 输出:(7.0, 11.0) (dy1/dx + dy2/dx, dy1/dz + dy2/dz)

未使用的输入张量

如果某些输入张量未被目标张量使用,需设置 allow_unused=True:

x = torch.tensor(2.0, requires_grad=True)  
z = torch.tensor(3.0, requires_grad=True)  
y = x ** 2  # y = x^2  # z 未被 y 使用  
grad = torch.autograd.grad(y, [x, z], allow_unused=True)  
print(grad)  # 输出:(4.0, None) (dy/dx = 4, z 未被使用,梯度为 None)

保留计算图

如果需要多次计算梯度,可以设置 retain_graph=True:

x = torch.tensor(2.0, requires_grad=True)  
y = x ** 3  # y = x^3  # 第一次计算梯度  
grad1 = torch.autograd.grad(y, x, retain_graph=True)  
print(grad1)  # 输出:(12.0,)  # 第二次计算梯度  
grad2 = torch.autograd.grad(y, x)  
print(grad2)  # 输出:(12.0,)

与 backward() 的区别

  • 梯度存储
    • torch.autograd.grad() 不会修改张量的 .grad 属性,而是直接返回梯度值。
    • backward() 会将计算的梯度累积到 .grad 属性中。
  • 灵活性:
    • torch.autograd.grad() 可以对多个输入和输出同时计算梯度,并支持未使用的输入张量。
    • backward() 只能对单个输出张量进行反向传播。
  • 高阶梯度:
    • torch.autograd.grad() 支持通过 create_graph=True 计算高阶梯度。
    • backward() 也支持高阶梯度,但需要手动设置 create_graph=True。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com