欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 金融 > 手动实现CrossEntropyLoss()函数

手动实现CrossEntropyLoss()函数

2024/10/24 23:30:46 来源:https://blog.csdn.net/weixin_45259560/article/details/139655830  浏览:    关键词:手动实现CrossEntropyLoss()函数

在这里插入图片描述

根据index计算loss

class Our_CrossEntropy(torch.nn.Module):def __init__(self):super(Our_CrossEntropy, self).__init__()def forward(self, pre, target, ignore_index=-100):# 考虑igore_indexmask = (pre != ignore_index)	filter_x = pre[mask]filter_y = target[mask]if filter_y.size(0) == 0:return torch.tensor(0.0)# softmaxlogits = torch.nn.functional.softmax(filter_x, dim=1)# 把index标签转化为one-hoty_onehot = torch.nn.functional.one_hot(filter_y, num_classes=logits.shape[1])loss = y_onehot * torch.log(logits)loss = -torch.mean(torch.sum(loss, dim=1), dim=0)return loss

模型中使用示例

	......outputs = self.model(input_ids,attention_mask=attention_mask,decoder_input_ids=decoder_input_ids,encoder_outputs=encoder_outputs,decoder_attention_mask=decoder_attention_mask,head_mask=head_mask,decoder_head_mask=decoder_head_mask,cross_attn_head_mask=cross_attn_head_mask,past_key_values=past_key_values,inputs_embeds=inputs_embeds,decoder_inputs_embeds=decoder_inputs_embeds,use_cache=use_cache,output_attentions=output_attentions,output_hidden_states=output_hidden_states,return_dict=return_dict,)lm_logits = self.lm_head(outputs[0])lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)masked_lm_loss = None	# torch实现的交叉熵计算my_lm_loss = None		# 自定义实现的交叉熵计算if labels is not None:labels = labels.to(lm_logits.device)loss_fct = CrossEntropyLoss() # torch实现的交叉熵计算masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))loss_fmy = Our_CrossEntropy() # 自定义实现的交叉熵计算my_lm_loss = loss_fmy(lm_logits, labels)if not return_dict:output = (lm_logits,) + outputs[1:]return ((masked_lm_loss,) + output) if masked_lm_loss is not None else outputreturn Seq2SeqLMOutput(loss=masked_lm_loss,logits=lm_logits,past_key_values=outputs.past_key_values,decoder_hidden_states=outputs.decoder_hidden_states,decoder_attentions=outputs.decoder_attentions,cross_attentions=outputs.cross_attentions,encoder_last_hidden_state=outputs.encoder_last_hidden_state,encoder_hidden_states=outputs.encoder_hidden_states,encoder_attentions=outputs.encoder_attentions,).......

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com