欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 房产 > 家装 > 批量归一化

批量归一化

2024/10/25 1:34:09 来源:https://blog.csdn.net/qq_58317297/article/details/140286985  浏览:    关键词:批量归一化

一、问题

1、在反向传播过程中,梯度通过链式法则从输出层传递到输入层,由于链式法则的乘积形式,如果每一层的梯度范数小于1,那么经过多层的乘积后,梯度会指数级减小,从而导致梯度消失。梯度在较低层(靠近输入层)较小,而在高层(靠近输出层)较大的现象通常被称为“梯度消失”问题

2、初始化一般是一个均值为0,方差为1的的分布,如果不是那么适合的话,我们可以去学习一个新的方差和均值,能更好的进行神经网络的学习,BN的作用是把我们的均值方差拉的比较好,

二、批量归一化

1、固定小批量的均值和方差

(1)小批量的均值和方差

(2)标准化:使用计算得到的小批量均值和方差,对该小批量的数据进行标准化

ϵ 是一个小的常数,用于防止分母为零

(3)缩放和平移(γ和β是可学习的参数)

        yi​=γx^i​+β

2、对于每个特征(每个列),如果是全连接,就会有一个对应的伽玛和贝塔

三、批量归一化层

1、如果作用在全连接层和卷积层的输出上,批量归一化层就作用在激活函数之前,如果作用在激活函数之后,假如我们用的是relu的激活函数,他直接就把我们的梯度变成了一个正数,那就相当于是一个没有用的工作了

2、我们可以把批量归一层,看作是一个线性变换,然后再作用在激活函数的非线性变换

3、对于卷积层,他的通道就相当于特征,例如1*1卷积层,可以把它的通道为看作是这一个像素的特征

4、对于每一个全连接的输出和输入的每一个特征作标量的均值和标量的方差,把特征变为均值为零方差为一,再用学到的伽马和贝塔。把均值和方差再做一次处理

四、作用

1、批量归一化通过减少内部协变量偏移,加速训练,缓解梯度消失和爆炸问题,以及提供正则化效果,极大地改善了深度神经网络的训练和性能。

2、因为每个batch的均值和方差都不太一样,就相当于加入了噪音控制了模型复杂度,他首先根据小批量的方差和均值,增加了随机的偏移和随机缩放,然后再通过一个学习的稳定均值和方差,使模型变化不那么剧烈,并具有随机性

3、批量归一化通过标准化每一层的小批量输入,使得输入的数据均值方差保持在相对稳定的范围内,这样有助于梯度的平稳,支持使用较高的学习率。在原来我们会考虑到梯度爆炸或梯度消失的情况,学习率过大的话,上层的梯度比较大,可能会直接爆炸,学习率过小下层梯度小可能训练不动,但是加入批量归一化之后,我们将每一层的输入都固定在一个比较稳定的范围内,这样就可以使用更大的学习率,不用担心出现之前的那样问题。

五、总结

1、在模型训练过程中,批量规范化利用小批量的均值和标准差,不断调整神经网络的中间输出,使整个神经网络各层的中间输出值更加稳定。

2、批量规范化在全连接层和卷积层的使用略有不同。

3、批量规范化层和暂退层一样,在训练模式和预测模式下计算不同。

4、批量规范化有许多有益的副作用,主要是正则化(加了“噪音”)。另一方面,”减少内部协变量偏移“的原始动机似乎不是一个有效的解释。

六、训练与推理

1、训练期间:BN会针对每个小批量的数据独立计算均值和方差,这些均值和方差只用于当前小批量数据的标准化。每一个新的小批量都会重新计算其均值和方差。

2、在推理期间(即模型训练完成后进行预测时),不能再使用每个小批量的均值和方差,因为推理阶段通常输入的是单个数据样本或者是非训练过程中使用的小批量数据。因此,推理期间BN使用在训练过程中累积的均值和方差的滑动平均值。这些滑动平均值是通过在整个训练过程中计算和更新的,用于稳定模型的输出。

七、代码实现

1、具有张量的批量规范化层

import torch
from torch import nn
from d2l import torch as d2l#moving_mean, moving_var全局均值和方差;gamma, betaγ和β;momentum步长
def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的全局平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差,0是批量。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化,eps防止分母为0X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data
class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()#全连接层if num_dims == 2:shape = (1, num_features)#卷积层else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)#为啥这里没写迭代呢我不知道# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

2、使用批量规范化层的 LeNet(批量规范化是在卷积层或全连接层之后、相应的激活函数之前应用的)

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))
#主要区别在于学习率大得多
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com