[PyTorch] Recurrent Neural Networks

Source: https://blog.csdn.net/Glass_Gun/article/details/142635961

What Is a Recurrent Neural Network?

RNN: Recurrent Neural Networks

  • A model that handles variable-length inputs
  • Commonly used in NLP and time-series tasks, where the input data has sequential dependencies (see the sketch below)
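
The sketch below is not from the original post; it simply shows with PyTorch's built-in nn.RNN that the same module can consume sequences of different lengths, because the recurrence is applied one time step at a time (the sizes 57 and 128 match the name-classification example later in this post):

import torch
import torch.nn as nn

rnn = nn.RNN(input_size=57, hidden_size=128)   # one recurrent layer

seq_short = torch.randn(4, 1, 57)    # 4 time steps, batch size 1, 57 features per step
seq_long = torch.randn(11, 1, 57)    # 11 time steps, same model, no changes needed

out_short, h_short = rnn(seq_short)  # out_short: (4, 1, 128),  h_short: (1, 1, 128)
out_long, h_long = rnn(seq_long)     # out_long: (11, 1, 128),  h_long: (1, 1, 128)
print(out_short.shape, out_long.shape)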

RNN Network Structure

References
Recurrent Neural Networks Tutorial, Part 1 – Introduction to RNNs
Understanding LSTM Networks
(Figure: RNN network structure diagram)

Name Classification with an RNN

Problem definition: given a name of arbitrary length (a string) as input, predict which country the name comes from (an 18-class classification task).
数据: https://download.pytorch.org/tutorial/data.zip
Jackie Chan —— 成龙
Jay Chou —— 周杰伦
Tingsong Yue —— 余霆嵩
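
As a quick sanity check (a hypothetical snippet, assuming data.zip has been unzipped to ./data/names/ with one .txt file per language and one name per line), you can count the category files to confirm the 18 classes:

import glob
import os

files = glob.glob(os.path.join("data", "names", "*.txt"))   # one file per language
categories = [os.path.splitext(os.path.basename(f))[0] for f in files]
print(len(categories))   # expected: 18
print(categories)        # e.g. ['Arabic', 'Chinese', 'English', ...]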

How an RNN Handles Variable-Length Input

Question: how can a computer map a variable-length string to a fixed-size classification vector?
Chou (string) → RNN → Chinese (class label)

  1. A character of the word → a number (one-hot code)
  2. The number → the model
  3. The next character → a number → the model
  4. The last character → a number → the model → a classification vector
# Pseudocode
# Chou (string) → RNN → Chinese (class label)
for string in [C, h, o, u]:
    1. one-hot: string → [0, 0, ..., 1, ..., 0]    # first encode each letter as a one-hot vector
    2. y, h = model([0, 0, ..., 1, ..., 0], h)     # h is the hidden-state information
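
A runnable version of step 1, mirroring the lineToTensor function in the full code below (the vocabulary is the 52 ASCII letters plus " .,;'", 57 characters in total):

import string
import torch

all_letters = string.ascii_letters + " .,;'"     # 57 characters
n_letters = len(all_letters)

name = "Chou"
one_hot = torch.zeros(len(name), 1, n_letters)   # shape (4, 1, 57): one vector per character
for i, ch in enumerate(name):
    one_hot[i][0][all_letters.find(ch)] = 1
print(one_hot.shape)   # torch.Size([4, 1, 57])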

x_t: input at time t, shape = (1, 57)
s_t: hidden state at time t, shape = (1, 128)
o_t: output at time t, shape = (1, 18)
U: weight of the input-to-hidden linear layer, shape = (128, 57)
W: weight of the hidden-to-hidden linear layer, shape = (128, 128)
V: weight of the hidden-to-output linear layer, shape = (18, 128)
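
Putting the shapes together, here is a minimal sketch (mirroring the RNN class in the code below) of the recurrence they describe: s_t = tanh(U·x_t + W·s_{t-1}) and o_t = log_softmax(V·s_t):

import torch
import torch.nn as nn

U = nn.Linear(57, 128)    # input -> hidden,  weight shape (128, 57)
W = nn.Linear(128, 128)   # hidden -> hidden, weight shape (128, 128)
V = nn.Linear(128, 18)    # hidden -> output, weight shape (18, 128)

x_t = torch.zeros(1, 57)        # one-hot character at time t
s_prev = torch.zeros(1, 128)    # hidden state from the previous step

s_t = torch.tanh(U(x_t) + W(s_prev))       # (1, 128)
o_t = torch.log_softmax(V(s_t), dim=1)     # (1, 18): log-probabilities over the 18 classes
print(s_t.shape, o_t.shape)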

The full code is as follows:

# -*- coding: utf-8 -*-
"""
# @file name  : rnn_demo.py
# @author     : TingsongYu https://github.com/TingsongYu
# @date       : 2019-12-09
# @brief      : RNN name classification
"""
from io import open
import glob
import unicodedata
import string
import math
import os
import time
import torch.nn as nn
import torch
import random
import matplotlib.pyplot as plt
import torch.utils.data
import sys
# get the project root path so that tools/ can be imported
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR)

from tools.common_tools import set_seed

set_seed(1)  # set the random seed
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# select the device to run on
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")


# Read a file and split into lines
def readLines(filename):
    lines = open(filename, encoding='utf-8').read().strip().split('\n')
    return [unicodeToAscii(line) for line in lines]


def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in all_letters
    )


# Find letter index from all_letters, e.g. "a" = 0
def letterToIndex(letter):
    return all_letters.find(letter)


# Just for demonstration, turn a letter into a <1 x n_letters> Tensor
def letterToTensor(letter):
    tensor = torch.zeros(1, n_letters)
    tensor[0][letterToIndex(letter)] = 1
    return tensor


# Turn a line into a <line_length x 1 x n_letters>,
# or an array of one-hot letter vectors
def lineToTensor(line):
    tensor = torch.zeros(len(line), 1, n_letters)
    for li, letter in enumerate(line):
        tensor[li][0][letterToIndex(letter)] = 1
    return tensor


def categoryFromOutput(output):
    top_n, top_i = output.topk(1)
    category_i = top_i[0].item()
    return all_categories[category_i], category_i


def randomChoice(l):
    return l[random.randint(0, len(l) - 1)]


def randomTrainingExample():
    category = randomChoice(all_categories)                 # pick a category
    line = randomChoice(category_lines[category])           # pick one sample from it
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = lineToTensor(line)    # str to one-hot
    return category, line, category_tensor, line_tensor


def timeSince(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


# Just return an output given a line
def evaluate(line_tensor):
    hidden = rnn.initHidden()
    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)
    return output


def predict(input_line, n_predictions=3):
    print('\n> %s' % input_line)
    with torch.no_grad():
        output = evaluate(lineToTensor(input_line))

        # Get top N categories
        topv, topi = output.topk(n_predictions, 1, True)
        for i in range(n_predictions):
            value = topv[0][i].item()
            category_index = topi[0][i].item()
            print('(%.2f) %s' % (value, all_categories[category_index]))


def get_lr(iter, learning_rate):
    lr_iter = learning_rate if iter < n_iters else learning_rate * 0.1
    return lr_iter


# define the network structure
class RNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size

        self.u = nn.Linear(input_size, hidden_size)     # input  -> hidden
        self.w = nn.Linear(hidden_size, hidden_size)    # hidden -> hidden
        self.v = nn.Linear(hidden_size, output_size)    # hidden -> output

        self.tanh = nn.Tanh()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, hidden):
        u_x = self.u(inputs)

        hidden = self.w(hidden)
        hidden = self.tanh(hidden + u_x)

        output = self.softmax(self.v(hidden))
        return output, hidden

    def initHidden(self):
        return torch.zeros(1, self.hidden_size)


def train(category_tensor, line_tensor):
    hidden = rnn.initHidden()
    rnn.zero_grad()

    line_tensor = line_tensor.to(device)
    hidden = hidden.to(device)
    category_tensor = category_tensor.to(device)

    for i in range(line_tensor.size()[0]):
        output, hidden = rnn(line_tensor[i], hidden)

    loss = criterion(output, category_tensor)
    loss.backward()

    # Add parameters' gradients to their values, multiplied by learning rate
    for p in rnn.parameters():
        # p.data.add_(-learning_rate, p.grad.data)  # this call signature is deprecated
        p.data.add_(p.grad.data, alpha=-learning_rate)

    return output, loss.item()


if __name__ == "__main__":
    print(device)

    # config
    data_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_data", "names"))
    if not os.path.exists(data_dir):
        raise Exception("\n{} does not exist. Please download 08-05-数据-20200724.zip, place it under\n{}  and unzip it.".format(
            data_dir, os.path.dirname(data_dir)))

    path_txt = os.path.join(data_dir, "*.txt")
    all_letters = string.ascii_letters + " .,;'"
    n_letters = len(all_letters)    # 52 letters + 5 punctuation marks = 57 characters
    print_every = 5000
    plot_every = 5000
    learning_rate = 0.005
    n_iters = 200000

    # step 1: data
    # Build the category_lines dictionary, a list of names per language
    category_lines = {}
    all_categories = []
    for filename in glob.glob(path_txt):
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        lines = readLines(filename)
        category_lines[category] = lines

    n_categories = len(all_categories)

    # step 2: model
    n_hidden = 128
    rnn = RNN(n_letters, n_hidden, n_categories)
    rnn.to(device)

    # step 3: loss
    criterion = nn.NLLLoss()

    # step 4: optimization is done by hand inside train()

    # step 5: iteration
    current_loss = 0
    all_losses = []
    start = time.time()
    for iter in range(1, n_iters + 1):
        # sample
        category, line, category_tensor, line_tensor = randomTrainingExample()

        # training
        output, loss = train(category_tensor, line_tensor)
        current_loss += loss

        # Print iter number, loss, name and guess
        if iter % print_every == 0:
            guess, guess_i = categoryFromOutput(output)
            correct = '✓' if guess == category else '✗ (%s)' % category
            print('Iter: {:<7} time: {:>8s} loss: {:.4f} name: {:>10s}  pred: {:>8s} label: {:>8s}'.format(
                iter, timeSince(start), loss, line, guess, correct))

        # Add current loss avg to list of losses
        if iter % plot_every == 0:
            all_losses.append(current_loss / plot_every)
            current_loss = 0

    path_model = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rnn_state_dict.pkl"))
    os.makedirs(os.path.dirname(path_model), exist_ok=True)   # make sure the output directory exists before saving
    torch.save(rnn.state_dict(), path_model)

    plt.plot(all_losses)
    plt.show()

    predict('Yue Tingsong')
    predict('Yue tingsong')
    predict('yutingsong')
    predict('test your name')
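
After training, the saved weights can be reloaded for inference. This is a small sketch, not part of the original script, and it assumes the RNN class, n_letters, n_hidden, n_categories, path_model and predict defined above are in scope:

rnn = RNN(n_letters, n_hidden, n_categories)
rnn.load_state_dict(torch.load(path_model, map_location=device))
rnn.eval()

predict('Yue Tingsong')   # prints the top-3 predicted languages with their log-probabilities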
