欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 科技 > 能源 > 循环神经网络(RNN)

循环神经网络(RNN)

2025/4/25 2:59:48 来源:https://blog.csdn.net/pljnb/article/details/147423448  浏览:    关键词:循环神经网络(RNN)

循环神经网络(RNN)基本原理

一、RNN核心思想

目标:处理序列数据(如文本、时间序列),通过循环连接传递隐藏状态,捕捉序列的动态依赖关系。
核心特性

  • 参数共享:所有时间步共享同一组权重。
  • 记忆能力:隐藏状态 h t h_t ht 编码历史信息。

二、网络结构与数学公式

1. RNN展开结构

在这里插入图片描述

  • 输入:时间步 t t t 的输入 x t x_t xt(如词向量)。
  • 隐藏状态 h t h_t ht 融合当前输入与历史信息。
  • 输出 y t y_t yt 基于 h t h_t ht 生成预测。

2. 数学公式

  • 隐藏状态更新
    h t = tanh ⁡ ( W h h h t − 1 + W x h x t + b h ) h_t = \tanh(W_{hh} h_{t-1} + W_{xh} x_t + b_h) ht=tanh(Whhht1+Wxhxt+bh)

    • W h h ∈ R d h × d h W_{hh} \in \mathbb{R}^{d_h \times d_h} WhhRdh×dh: 隐藏状态权重
    • W x h ∈ R d x × d h W_{xh} \in \mathbb{R}^{d_x \times d_h} WxhRdx×dh: 输入权重
    • tanh ⁡ \tanh tanh: 激活函数(压缩到[-1,1])
  • 输出计算
    y t = W h y h t + b y y_t = W_{hy} h_t + b_y yt=Whyht+by

    • W h y ∈ R d h × d y W_{hy} \in \mathbb{R}^{d_h \times d_y} WhyRdh×dy: 输出权重

三、PyTorch代码实现

1. RNN模型定义

import torch
import torch.nn as nnclass SimpleRNN(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_size# 定义权重参数self.W_hh = nn.Parameter(torch.randn(hidden_size, hidden_size))self.W_xh = nn.Parameter(torch.randn(input_size, hidden_size))self.W_hy = nn.Parameter(torch.randn(hidden_size, output_size))self.b_h = nn.Parameter(torch.zeros(hidden_size))self.b_y = nn.Parameter(torch.zeros(output_size))def forward(self, x_seq):# x_seq形状: (seq_length, batch_size, input_size)batch_size = x_seq.size(1)h = torch.zeros(batch_size, self.hidden_size)  # 初始隐藏状态outputs = []for x_t in x_seq:  # 按时间步迭代# 更新隐藏状态h = torch.tanh(torch.mm(h, self.W_hh) + torch.mm(x_t, self.W_xh) + self.b_h)# 计算输出y_t = torch.mm(h, self.W_hy) + self.b_youtputs.append(y_t)return torch.stack(outputs), h

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词