本节课程地址:63 束搜索【动手学深度学习v2】_哔哩哔哩_bilibili
本节教材地址:9.8. 束搜索 — 动手学深度学习 2.0.0 documentation
本节开源代码:...>d2l-zh>pytorch>chapter_multilayer-perceptrons>beam-search.ipynb
束搜索
在 9.7节 中,我们逐个预测输出序列, 直到预测序列中出现特定的序列结束词元“<eos>”。 本节将首先介绍贪心搜索(greedy search)策略, 并探讨其存在的问题,然后对比其他替代策略: 穷举搜索(exhaustive search)和束搜索(beam search)。
在正式介绍贪心搜索之前,我们使用与 9.7节 中 相同的数学符号定义搜索问题。 在任意时间步 ,解码器输出 的概率取决于 时间步 之前的输出子序列 和对输入序列的信息进行编码得到的上下文变量 。 为了量化计算代价,用 表示输出词表, 其中包含“<eos>”, 所以这个词汇集合的基数 就是词表的大小。 我们还将输出序列的最大词元数指定为 。 因此,我们的目标是从所有 个 可能的输出序列中寻找理想的输出。 当然,对于所有输出序列,在“<eos>”之后的部分(非本句) 将在实际输出中丢弃。
贪心搜索
首先,让我们看看一个简单的策略:贪心搜索, 该策略已用于 9.7节 的序列预测。 对于输出序列的每一时间步 , 我们都将基于贪心搜索从 中找到具有最高条件概率的词元,即:
(9.8.1)
一旦输出序列包含了“<eos>”或者达到其最大长度 ,则输出完成。
如 图9.8.1 中, 假设输出中有四个词元“A”“B”“C”和“<eos>”。 每个时间步下的四个数字分别表示在该时间步 生成“A”“B”“C”和“<eos>”的条件概率。 在每个时间步,贪心搜索选择具有最高条件概率的词元。 因此,将在 图9.8.1 中预测输出序列“A”“B”“C”和“<eos>”。 这个输出序列的条件概率是 0.5×0.4×0.4×0.6=0.048 。
那么贪心搜索存在的问题是什么呢? 现实中,最优序列(optimal sequence)应该是最大化 值的输出序列,这是基于输入序列生成输出序列的条件概率。 然而,贪心搜索无法保证得到最优序列。
图9.8.2 中的另一个例子阐述了这个问题。 与 图9.8.1 不同,在时间步2中, 我们选择 图9.8.2 中的词元“C”, 它具有第二高的条件概率。 由于时间步3所基于的时间步1和2处的输出子序列已从 图9.8.1 中的“A”和“B”改变为 图9.8.2 中的“A”和“C”, 因此时间步3处的每个词元的条件概率也在 图9.8.2 中改变。 假设我们在时间步3选择词元“B”, 于是当前的时间步4基于前三个时间步的输出子序列“A”“C”和“B”为条件, 这与 图9.8.1 中的“A”“B”和“C”不同。 因此,在 图9.8.2 中的时间步4生成 每个词元的条件概率也不同于 图9.8.1 中的条件概率。 结果, 图9.8.2 中的输出序列 “A”“C”“B”和“<eos>”的条件概率为 0.5×0.3×0.6×0.6=0.054 , 这大于 图9.8.1 中的贪心搜索的条件概率。 这个例子说明:贪心搜索获得的输出序列 “A”“B”“C”和“<eos>” 不一定是最佳序列。
穷举搜索
如果目标是获得最优序列, 我们可以考虑使用穷举搜索(exhaustive search): 穷举地列举所有可能的输出序列及其条件概率, 然后计算输出条件概率最高的一个。
虽然我们可以使用穷举搜索来获得最优序列, 但其计算量 可能高的惊人。 例如,当 和 时, 我们需要评估 序列, 这是一个极大的数,现有的计算机几乎不可能计算它。 然而,贪心搜索的计算量 通常要显著地小于穷举搜索。 例如,当 和 时, 我们只需要评估 个序列。
束搜索
那么该选取哪种序列搜索策略呢? 如果精度最重要,则显然是穷举搜索。 如果计算成本最重要,则显然是贪心搜索。 而束搜索的实际应用则介于这两个极端之间。
束搜索(beam search)是贪心搜索的一个改进版本。 它有一个超参数,名为束宽(beam size) 。 在时间步1,我们选择具有最高条件概率的 个词元。 这 个词元将分别是 个候选输出序列的第一个词元。 在随后的每个时间步,基于上一时间步的 个候选输出序列, 我们将继续从 个可能的选择中 挑出具有最高条件概率的 个候选输出序列。
图9.8.3 演示了束搜索的过程。 假设输出的词表只包含五个元素: Y=A,B,C,D,E , 其中有一个是“<eos>”。 设置束宽为2,输出序列的最大长度为3。 在时间步1,假设具有最高条件概率 的词元是 A 和 C 。 在时间步2,我们计算所有 为:
从这十个值中选择最大的两个, 比如 和 。 然后在时间步3,我们计算所有 为:
从这十个值中选择最大的两个, 即 和 , 我们会得到六个候选输出序列: (1) A ;(2) C ;(3) A,B ;(4) C,E ;(5) A,B,D ;(6) C,E,D 。
最后,基于这六个序列(例如,丢弃包括“<eos>”和之后的部分), 我们获得最终候选输出序列集合。 然后我们选择其中条件概率乘积最高的序列作为输出序列:
(9.8.4)
其中 是最终候选序列的长度, 通常设置为 0.75 。 因为一个较长的序列在 (9.8.4) 的求和中会有更多的对数项, 因此分母中的 用于惩罚长序列。
束搜索的计算量为 , 这个结果介于贪心搜索和穷举搜索之间。 实际上,贪心搜索可以看作一种束宽为1的特殊类型的束搜索。 通过灵活地选择束宽,束搜索可以在正确率和计算代价之间进行权衡。
小结
- 序列搜索策略包括贪心搜索、穷举搜索和束搜索。
- 贪心搜索所选取序列的计算量最小,但精度相对较低。
- 穷举搜索所选取序列的精度最高,但计算量最大。
- 束搜索通过灵活选择束宽,在正确率和计算代价之间进行权衡。
练习
- 我们可以把穷举搜索看作一种特殊的束搜索吗?为什么?
解:
束搜索本来是一种启发式的搜索算法,通过限制搜索空间中的候选路径来降低计算复杂度;而穷举搜索则是探索所有可能的路径,以确保找到最优解。因此,穷举搜索可以看作是束宽无限大的束搜索,这种情况下,束搜索的启发式约束不起作用,因为所有的路径都被探索,最终等价于穷举搜索。 - 在 9.7节 的机器翻译问题中应用束搜索。 束宽是如何影响预测的速度和结果的?
解:
预测过程应用束搜索的代码如下。 随着束宽不断增加,预测用时增加,速度降低;预测结果上,随束宽增加(从1到4),bleu整体先升后降。
import torch
from torch import nn
from d2l import torch as d2l
import torch.nn.functional as F
class Seq2SeqDecoder(d2l.Decoder):def __init__(self, vocab_size, embed_size, num_hiddens, num_layers,dropout=0, **kwargs):super(Seq2SeqDecoder, self).__init__(**kwargs)self.embedding = nn.Embedding(vocab_size, embed_size)self.rnn = nn.GRU(embed_size + num_hiddens, num_hiddens, num_layers,dropout=dropout)self.dense = nn.Linear(num_hiddens, vocab_size)def init_state(self, enc_outputs, *args):return enc_outputs[1]def forward(self, X, state):X = self.embedding(X).permute(1, 0, 2)context = state[-1].repeat(X.shape[0], 1, 1)X_and_context = torch.cat((X, context), 2)output, state = self.rnn(X_and_context, state)output = self.dense(output).permute(1, 0, 2)return output, state
embed_size, num_hiddens, num_layers, dropout = 32, 32, 2, 0.1
batch_size, num_steps = 64, 10
lr, num_epochs, device = 0.005, 300, d2l.try_gpu()train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
encoder = d2l.Seq2SeqEncoder(len(src_vocab), embed_size, num_hiddens, num_layers,dropout)
decoder = Seq2SeqDecoder(len(tgt_vocab), embed_size, num_hiddens, num_layers,dropout)
net = d2l.EncoderDecoder(encoder, decoder)
d2l.train_seq2seq(net, train_iter, lr, num_epochs, tgt_vocab, device)
输出结果:
loss 0.019, 8768.7 tokens/sec on cpu
def predict_seq2seq_beam(net, src_sentence, src_vocab, tgt_vocab, num_steps,device, beam_size, alpha=0.75, save_attention_weights=False):"""序列到序列模型的束搜索预测"""net.eval()src_tokens = src_vocab[src_sentence.lower().split(' ')] + [src_vocab['<eos>']]enc_valid_len = torch.tensor([len(src_tokens)], device=device)src_tokens = d2l.truncate_pad(src_tokens, num_steps, src_vocab['<pad>'])enc_X = torch.unsqueeze(torch.tensor(src_tokens, dtype=torch.long, device=device), dim=0)enc_outputs = net.encoder(enc_X, enc_valid_len)dec_state = net.decoder.init_state(enc_outputs, enc_valid_len)dec_X = torch.unsqueeze(torch.tensor([tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)# 保存束搜索的候选序列、累积概率以及注意力权重# (当前输入, 输出序列, 对数概率)beam = [[dec_X, [], 0]] attention_weight_seq = []for _ in range(num_steps):# 存储每个候选序列的下一个时间步的所有可能候选all_candidates = []for beam_input, beam_output, beam_log_prob in beam:Y, dec_state = net.decoder(beam_input, dec_state)# 对结果Y进行softmax并取对数概率Y = F.log_softmax(Y, dim=-1)# 取前 beam_size 个topk_probs, topk_indices = Y.topk(beam_size, dim=-1)# 对于每个候选词元,扩展当前的候选序列for i in range(beam_size):candidate_input = topk_indices[:, :, i] # 第 i 个候选词元candidate_prob = topk_probs[:, :, i] # 对应的对数概率candidate_output = beam_output + [candidate_input.squeeze(dim=0).item()]candidate_log_prob = beam_log_prob + candidate_prob.item()all_candidates.append([candidate_input, candidate_output, candidate_log_prob])# 从所有候选中选择对数概率最大的 beam_size 个候选序列ordered = sorted(all_candidates, key=lambda x: x[2], reverse=True)beam = ordered[:beam_size]# 保存注意力权重(稍后讨论)if save_attention_weights:attention_weight_seq.append(net.decoder.attention_weights)# 对于候选去除<eos>和之后的部分,并且根据序列长度调整最终的对数概率processed_candidates = []for candidate in all_candidates:candidate_input, candidate_output, candidate_log_prob = candidateif tgt_vocab['<eos>'] in candidate_output:eos_index = candidate_output.index(tgt_vocab['<eos>'])candidate_output = candidate_output[:eos_index]L = len(candidate_output) # 去除后的序列长度adjusted_log_prob = candidate_log_prob * (1 / (L ** alpha))processed_candidates.append([candidate_output, adjusted_log_prob])# 根据调整后的对数概率选择得分最高的序列best_candidate = max(processed_candidates, key=lambda x: x[1])output_seq = best_candidate[0]return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq
import timeengs = ['go .', "i lost .", 'he\'s calm .', 'i\'m home .']
fras = ['va !', 'j\'ai perdu .', 'il est calme .', 'je suis chez moi .']
start = time.time()
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq_beam(net, eng, src_vocab, tgt_vocab, num_steps, device, beam_size=1)print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')
end = time.time()
print(f'{end-start:.3f} s')
输出结果:
go . => va !, bleu 1.000
i lost . => j'ai <unk> <unk> gagné ?, bleu 0.000
he's calm . => il est paresseux ., bleu 0.658
i'm home . => je suis chez chez suis malade ., bleu 0.574
0.030 s
start = time.time()
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq_beam(net, eng, src_vocab, tgt_vocab, num_steps, device, beam_size=2)print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')
end = time.time()
print(f'{end-start:.3f} s')
输出结果:
go . => va !, bleu 1.000
i lost . => j'ai <unk> emporté ., bleu 0.000
he's calm . => il est paresseux ., bleu 0.658
i'm home . => je suis chez suis chez suis bras !, bleu 0.448
0.049 s
start = time.time()
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq_beam(net, eng, src_vocab, tgt_vocab, num_steps, device, beam_size=3)print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')
end = time.time()
print(f'{end-start:.3f} s')
输出结果:
go . => va !, bleu 1.000
i lost . => j'ai perdu ., bleu 1.000
he's calm . => il est emporté ., bleu 0.658
i'm home . => je suis suis paresseux de confiance aboient aboient question !, bleu 0.258
0.068 s
start = time.time()
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq_beam(net, eng, src_vocab, tgt_vocab, num_steps, device, beam_size=4)print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')
end = time.time()
print(f'{end-start:.3f} s')
输出结果:
go . => va !, bleu 1.000
i lost . => j'ai <unk> emporté ., bleu 0.000
he's calm . => il est retard, bleu 0.492
i'm home . => je suis paresseux capté capté tomber !, bleu 0.342
0.086 s
start = time.time()
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq_beam(net, eng, src_vocab, tgt_vocab, num_steps, device, beam_size=5)print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')
end = time.time()
print(f'{end-start:.3f} s')
输出结果:
go . => va !, bleu 1.000
i lost . => j'ai <unk> emporté ., bleu 0.000
he's calm . => il est ai, bleu 0.492
i'm home . => je suis capté ai paresseux ., bleu 0.473
0.111 s
start = time.time()
for eng, fra in zip(engs, fras):translation, attention_weight_seq = predict_seq2seq_beam(net, eng, src_vocab, tgt_vocab, num_steps, device, beam_size=6)print(f'{eng} => {translation}, bleu {d2l.bleu(translation, fra, k=2):.3f}')
end = time.time()
print(f'{end-start:.3f} s')
输出结果:
go . => va !, bleu 1.000
i lost . => j'ai <unk> aboient ai ., bleu 0.000
he's calm . => il est ai, bleu 0.492
i'm home . => je suis paresseux !, bleu 0.418
0.137 s
3. 在 8.5节 中,我们基于用户提供的前缀, 通过使用语言模型来生成文本。这个例子中使用了哪种搜索策略?可以改进吗?
解:
8.5节 中用的也是贪心搜索,改成束搜索的代码如下。
输出结果的困惑度都是最低1.0,但是预测结果上看似乎没有改进。
%matplotlib inline
import math
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))W_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params
def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )
def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = stateoutputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h)Y = torch.mm(H, W_hq) + b_qoutputs.append(Y)return torch.cat(outputs, dim=0), (H,)
num_hiddens = 512
net = d2l.RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
def predict_ch8_beam(prefix, num_preds, net, vocab, device, beam_size):"""在prefix后面应用束搜索生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))for y in prefix[1:]: # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])# 保存束搜索的候选序列、累积概率# (输出序列,对数概率)beam = [[outputs.copy(), 0]]for _ in range(num_preds): # 预测num_preds步all_candidates =[]for beam_output, beam_log_prob in beam:y, state = net(get_input(), state)y = F.log_softmax(y, dim=-1)# 取前 beam_size 个候选topk_probs, topk_indices = y.topk(beam_size)for i in range(beam_size):candidate_output = beam_output + [topk_indices[0, i].item()]candidate_log_prob = beam_log_prob + topk_probs[0, i].item()all_candidates.append([candidate_output, candidate_log_prob])# 从所有候选中选择对数概率最大的 beam_size 个候选序列ordered = sorted(all_candidates, key=lambda x: x[1], reverse=True)beam = ordered[:beam_size]# 根据调整后的对数概率选择得分最高的序列best_candidate = max(beam, key=lambda x: x[1])outputs = best_candidate[0]return ''.join([vocab.idx_to_token[i] for i in outputs])
def train_ch8_beam(net, train_iter, vocab, lr, num_epochs, device, beam_size,use_random_iter=False):loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', ylabel='perplexity',legend=['train'], xlim=[10, num_epochs])# 初始化if isinstance(net, nn.Module):updater = torch.optim.SGD(net.parameters(), lr)else:updater = lambda batch_size: d2l.sgd(net.params, lr, batch_size)predict = lambda prefix: predict_ch8_beam(prefix, 50, net, vocab, device, beam_size)# 训练和预测for epoch in range(num_epochs):ppl, speed = d2l.train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter)if (epoch + 1) % 10 == 0:print(predict('time traveller'))animator.add(epoch + 1, [ppl])print(f'困惑度 {ppl:.1f}, {speed:.1f} 词元/秒 {str(device)}')print(predict('time traveller'))print(predict('traveller'))
num_epochs, lr = 500, 1
train_ch8_beam(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(), beam_size=1)
输出结果:
困惑度 1.0, 18141.4 词元/秒 cpu
time traveller oioyaoaomtotaeaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
travellerywioefiootyataaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa
net = d2l.RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
train_ch8_beam(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(), beam_size=5)
输出结果:
困惑度 1.0, 18363.3 词元/秒 cpu
time traveller ooeeyaeayyeeeyeeeyyaeaeeaeeyeecyeyyeeeaaeeeeeeyee
traveller oae aaeeaecyaaeaeyeyyeeaeefeaeeyeeyeeaaeyeae seey
net = d2l.RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
train_ch8_beam(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(), beam_size=10)
输出结果:
困惑度 1.0, 19986.3 词元/秒 cpu
time traveller oeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeeee
travelleryoeeeeeeeeeeeeeeeseeeeeeeeeeeeeeeseeeeeeeeeeeeeeee
net = d2l.RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
train_ch8_beam(net, train_iter, vocab, lr, num_epochs, d2l.try_gpu(), beam_size=20)
输出结果:
困惑度 1.0, 20103.8 词元/秒 cpu
time travelleryoaaiaiaiiiaaaidaaiiaiiiiaaiaiidadaiiiaiiiiaiiaiid
travelleryoaiiiiaididaiiaiadaiiiiadaiiiiaiiiiaiidiiciiaiiia