欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 艺术 > 基于PyTorch的DETR(Detection Transformer)目标检测模型

基于PyTorch的DETR(Detection Transformer)目标检测模型

2025/4/19 10:04:27 来源:https://blog.csdn.net/hanfeng5268/article/details/147238014  浏览:    关键词:基于PyTorch的DETR(Detection Transformer)目标检测模型

以下是一个基于PyTorch的DETR(Detection Transformer)目标检测模型的实现代码。

文章目录

  • 1. 安装必要的依赖
  • 2. 完整代码实现
  • 3. 代码说明
  • 4. 使用说明
  • 5. 注意事项

1. 安装必要的依赖

在运行代码之前,请确保安装了以下库:

pip install torch torchvision numpy

2. 完整代码实现

以下是DETR的完整实现代码:

import copy
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch import Tensor
from torch.utils.data import Dataset, DataLoader


# DETR model
class DETR(nn.Module):
    """Detection head that combines a CNN backbone with a transformer.

    Args:
        backbone: Feature extractor exposing ``num_channels``. Calling it may
            return either ``(features, pos)`` -- a list of objects offering
            ``decompose()`` plus a list of positional encodings -- or plain
            feature map tensor(s) of shape ``(N, C, H, W)``.
        transformer: Module exposing ``d_model``. It is called as
            ``transformer(src, mask, query_embed, pos_embed)`` and the first
            item of its result is used as the stacked decoder output.
        num_classes: Number of object classes; one extra "no object" logit is
            added on top.
        num_queries: Number of learned object queries.
        aux_loss: When true, the predictions of all but the last decoder
            layer are returned under ``'aux_outputs'``.
    """

    def __init__(self, backbone, transformer, num_classes, num_queries, aux_loss=False):
        super().__init__()
        self.num_queries = num_queries
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.class_embed = nn.Linear(hidden_dim, num_classes + 1)  # +1 for background
        self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3)  # 4 box coordinates
        self.query_embed = nn.Embedding(num_queries, hidden_dim)
        # 1x1 convolution that maps backbone channels to the transformer width.
        self.input_proj = nn.Conv2d(backbone.num_channels, hidden_dim, kernel_size=1)
        self.backbone = backbone
        self.aux_loss = aux_loss

    def forward(self, samples: Tensor):
        """Run detection on a batch of images.

        Returns:
            A dict with ``'pred_logits'`` and ``'pred_boxes'`` taken from the
            last decoder layer (boxes are passed through a sigmoid), plus
            ``'aux_outputs'`` when ``aux_loss`` is enabled.

        Raises:
            ValueError: If the backbone feature decomposes into a ``None``
                mask, or if a positional encoding has to be generated for a
                channel count that is not divisible by 4.
        """
        features, pos_list = self._split_backbone_output(self.backbone(samples))

        # Only the last (deepest) feature map is fed to the transformer.
        last_feature = features[-1]
        if hasattr(last_feature, "decompose"):
            src, mask = last_feature.decompose()
            if mask is None:
                raise ValueError("backbone feature decomposed into a None mask")
        else:
            # Plain tensor: nothing is padded, so the padding mask is all False.
            src = last_feature
            mask = torch.zeros(
                (src.shape[0], src.shape[2], src.shape[3]),
                dtype=torch.bool,
                device=src.device,
            )

        src = self.input_proj(src)
        if pos_list is None:
            # The backbone supplied no positional encoding; build one here.
            pos = self._sine_position_embedding(mask, src.shape[1]).to(src.dtype)
        else:
            pos = pos_list[-1]

        hs = self.transformer(src, mask, self.query_embed.weight, pos)[0]
        outputs_class = self.class_embed(hs)
        outputs_coord = self.bbox_embed(hs).sigmoid()
        out = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
        if self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(outputs_class, outputs_coord)
        return out

    @staticmethod
    def _split_backbone_output(backbone_out):
        """Normalise the backbone result to ``(features, pos_or_None)``."""
        if (
            isinstance(backbone_out, (list, tuple))
            and len(backbone_out) == 2
            and isinstance(backbone_out[0], (list, tuple))
        ):
            # ``(features, pos)`` pair, each a sequence with one entry per level.
            return backbone_out[0], backbone_out[1]
        if isinstance(backbone_out, Tensor):
            return [backbone_out], None
        return backbone_out, None

    @staticmethod
    def _sine_position_embedding(mask, num_channels, temperature=10000.0):
        """Build a normalised 2-D sine/cosine positional encoding.

        Args:
            mask: Bool tensor ``(N, H, W)``; true marks padded positions.
            num_channels: Channel count of the result, split evenly between
                the vertical and the horizontal coordinate.
            temperature: Base of the geometric frequency progression.

        Returns:
            Float tensor of shape ``(N, num_channels, H, W)``.
        """
        if num_channels % 4 != 0:
            raise ValueError(
                f"num_channels must be divisible by 4 to build a sine "
                f"positional encoding, got {num_channels}"
            )
        num_pos_feats = num_channels // 2
        valid = ~mask
        y_embed = valid.cumsum(1, dtype=torch.float32)
        x_embed = valid.cumsum(2, dtype=torch.float32)

        # Normalise coordinates to (0, 2*pi]; eps guards fully padded rows.
        eps = 1e-6
        scale = 2 * math.pi
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * scale

        # Consecutive channel pairs share one frequency (sin / cos).
        dim_t = torch.tensor(
            [temperature ** (2 * (i // 2) / num_pos_feats) for i in range(num_pos_feats)],
            dtype=torch.float32,
            device=mask.device,
        )
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
        return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [
            {'pred_logits': a, 'pred_boxes': b}
            for a, b in zip(outputs_class[:-1], outputs_coord[:-1])
        ]


# Backbone
class Backbone(nn.Module):
    """torchvision ResNet wrapper that returns intermediate feature maps.

    Args:
        name: Name of a model constructor in ``torchvision.models`` that
            accepts ``replace_stride_with_dilation`` (e.g. ``'resnet50'``).
        train_backbone: When false every parameter is frozen. When true only
            parameters of ``layer2``, ``layer3`` and ``layer4`` stay trainable.
        return_layers: Iterable of stage names (for a dict, its keys) whose
            outputs ``forward`` should return.
        dilation: Passed as the last entry of ``replace_stride_with_dilation``.

    Raises:
        ValueError: If ``return_layers`` names a stage this wrapper does not
            execute.
    """

    # Stages executed by ``forward``, in execution order.
    _STAGES = ('conv1', 'bn1', 'relu', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4')
    _TRAINABLE_STAGES = ('layer2', 'layer3', 'layer4')

    def __init__(self, name: str, train_backbone: bool, return_layers, dilation):
        super().__init__()
        unknown = [layer for layer in return_layers if layer not in self._STAGES]
        if unknown:
            raise ValueError(
                f"return_layers contains unknown stages {unknown}; "
                f"expected names from {list(self._STAGES)}"
            )

        backbone = getattr(torchvision.models, name)(
            replace_stride_with_dilation=[False, False, dilation]
        )
        for param_name, parameter in backbone.named_parameters():
            trainable = train_backbone and any(
                stage in param_name for stage in self._TRAINABLE_STAGES
            )
            if not trainable:
                parameter.requires_grad_(False)

        self.body = torch.nn.ModuleDict(dict(backbone.named_children()))
        self.return_layers = return_layers
        # Width of the last stage, read from the classifier input size;
        # falls back to 2048 when the model has no ``fc`` layer.
        self.num_channels = getattr(getattr(backbone, 'fc', None), 'in_features', 2048)

    def forward(self, tensor_list: Tensor):
        """Return the outputs of the requested stages, in execution order."""
        wanted = set(self.return_layers)
        xs = tensor_list
        out = []
        for stage in self._STAGES:
            xs = self.body[stage](xs)
            if stage in wanted:
                out.append(xs)
        return out


# Transformer
class Transformer(nn.Module):
    """Encoder-decoder transformer operating on flattened image features.

    Args:
        d_model: Width of every token embedding.
        nhead: Number of attention heads.
        num_encoder_layers: Number of stacked encoder layers.
        num_decoder_layers: Number of stacked decoder layers.
        dim_feedforward: Hidden width of the feed-forward sub-layers.
        dropout: Dropout probability used inside the layers.
        activation: Name of the feed-forward activation.
        normalize_before: Selects pre-norm (true) or post-norm (false) layers;
            the encoder only gets a final LayerNorm in the pre-norm case.
        return_intermediate_dec: Forwarded to the decoder as
            ``return_intermediate``.
    """

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(
            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
        )
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(
            decoder_layer, num_decoder_layers, decoder_norm,
            return_intermediate=return_intermediate_dec,
        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for parameter in self.parameters():
            if parameter.dim() > 1:
                nn.init.xavier_uniform_(parameter)

    def forward(self, src, mask, query_embed, pos_embed):
        """Encode the feature map and decode the object queries.

        Args:
            src: Feature map of shape ``(N, C, H, W)``.
            mask: Padding mask whose dimensions after the first are flattened
                into one before it is passed to the attention layers.
            query_embed: Query embeddings of shape ``(num_queries, C)``.
            pos_embed: Positional encoding, flattened the same way as ``src``.

        Returns:
            ``(hs, memory)``: the decoder output with its dimensions 1 and 2
            swapped, and the encoder output reshaped back to ``(N, C, H, W)``.
        """
        # flatten NxCxHxW to HWxNxC
        bs, c, h, w = src.shape
        src = src.flatten(2).permute(2, 0, 1)
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        # Every image in the batch gets its own copy of the query embeddings.
        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
        mask = mask.flatten(1)

        # The decoder input starts at zero; queries enter as positional terms.
        tgt = torch.zeros_like(query_embed)
        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)
        hs = self.decoder(
            tgt, memory, memory_key_padding_mask=mask,
            pos=pos_embed, query_pos=query_embed,
        )
        return hs.transpose(1, 2), memory.permute(1, 2, 0).view(bs, c, h, w)


# Transformer encoder
class TransformerEncoder(nn.Module):
    """Stack of independent copies of one encoder layer.

    Args:
        encoder_layer: Layer prototype; it is deep-copied ``num_layers`` times.
        num_layers: Number of copies in the stack.
        norm: Optional module applied to the output of the last layer.
    """

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # Deep copies give every layer its own parameters.
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.norm = norm

    def forward(self, src, mask: Tensor = None, src_key_padding_mask: Tensor = None,
                pos: Tensor = None):
        """Pass ``src`` through every layer, then through ``norm`` if set."""
        output = src
        for layer in self.layers:
            output = layer(
                output,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
                pos=pos,
            )
        if self.norm is None:
            return output
        return self.norm(output)


# Transformer decoder
class TransformerDecoder(nn.Module):
    """Stack of independent copies of one decoder layer.

    Args:
        decoder_layer: Layer prototype; it is deep-copied ``num_layers`` times.
            Each copy is called as ``layer(tgt, query_pos, memory,
            memory_key_padding_mask=..., pos=...)``.
        num_layers: Number of copies in the stack.
        norm: Optional module applied to the layer outputs.
        return_intermediate: When true the output of every layer is returned,
            not only the last one.
    """

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(decoder_layer) for _ in range(num_layers)]
        )
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory, tgt_mask: Tensor = None, memory_mask: Tensor = None,
                tgt_key_padding_mask: Tensor = None, memory_key_padding_mask: Tensor = None,
                pos: Tensor = None, query_pos: Tensor = None):
        """Decode ``tgt`` against ``memory``.

        ``tgt_mask``, ``memory_mask`` and ``tgt_key_padding_mask`` are accepted
        for signature compatibility but are not forwarded to the layers.

        Returns:
            With ``return_intermediate`` the (normalised) outputs of all
            layers stacked along a new first dimension; otherwise the final
            output with a leading dimension of size one.
        """
        output = tgt
        intermediate = []

        for layer in self.layers:
            # Keyword arguments keep the mask and the positional encoding
            # from being swapped against the layer's parameter order.
            output = layer(
                output,
                query_pos,
                memory,
                memory_key_padding_mask=memory_key_padding_mask,
                pos=pos,
            )
            if self.return_intermediate:
                # Record the result of this layer, after it has run.
                intermediate.append(output if self.norm is None else self.norm(output))

        if self.norm is not None:
            output = self.norm(output)

        if self.return_intermediate and intermediate:
            return torch.stack(intermediate)
        return output.unsqueeze(0)


# Transformer encoder layer
class TransformerEncoderLayer(nn.Module):
    """Encoder block: self-attention followed by a feed-forward network.

    Args:
        d_model: Width of the token embeddings.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability (attention, feed-forward and residuals).
        activation: Name of the feed-forward activation.
        normalize_before: True applies LayerNorm before each sub-layer
            (pre-norm); false applies it after the residual sum (post-norm).
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Tensor):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _self_attention(self, x, src_mask, src_key_padding_mask, pos):
        """Self-attention: positions are added to queries and keys only."""
        q = k = self.with_pos_embed(x, pos)
        return self.self_attn(
            q, k, value=x, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
        )[0]

    def _feed_forward(self, x):
        """Apply linear -> activation -> dropout -> linear."""
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def forward_post(self, src, src_mask: Tensor = None, src_key_padding_mask: Tensor = None,
                     pos: Tensor = None):
        """Post-norm variant: normalise after each residual connection."""
        attended = self._self_attention(src, src_mask, src_key_padding_mask, pos)
        src = self.norm1(src + self.dropout1(attended))
        src = self.norm2(src + self.dropout2(self._feed_forward(src)))
        return src

    def forward_pre(self, src, src_mask: Tensor = None, src_key_padding_mask: Tensor = None,
                    pos: Tensor = None):
        """Pre-norm variant: normalise the input of each sub-layer."""
        normed = self.norm1(src)
        attended = self._self_attention(normed, src_mask, src_key_padding_mask, pos)
        src = src + self.dropout1(attended)
        normed = self.norm2(src)
        src = src + self.dropout2(self._feed_forward(normed))
        return src

    def forward(self, src, src_mask: Tensor = None, src_key_padding_mask: Tensor = None,
                pos: Tensor = None):
        """Dispatch to the pre-norm or post-norm variant."""
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


# Transformer decoder layer
class TransformerDecoderLayer(nn.Module):
    """Decoder block: self-attention, cross-attention, feed-forward network.

    Args:
        d_model: Width of the token embeddings.
        nhead: Number of attention heads.
        dim_feedforward: Hidden width of the feed-forward network.
        dropout: Dropout probability (attention, feed-forward and residuals).
        activation: Name of the feed-forward activation.
        normalize_before: True applies LayerNorm before each sub-layer
            (pre-norm); false applies it after the residual sum (post-norm).

    Note:
        ``forward`` takes ``(tgt, query_pos, memory, memory_key_padding_mask,
        pos)`` in that order; pass the last two by keyword to avoid mixing
        them up.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Tensor):
        """Return ``tensor + pos``, or ``tensor`` itself when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _self_attention(self, x, query_pos):
        """Attention among the queries; positions go to queries and keys only."""
        q = k = self.with_pos_embed(x, query_pos)
        return self.self_attn(q, k, value=x)[0]

    def _cross_attention(self, x, query_pos, memory, memory_key_padding_mask, pos):
        """Attention from the queries to the encoder memory."""
        return self.multihead_attn(
            query=self.with_pos_embed(x, query_pos),
            key=self.with_pos_embed(memory, pos),
            value=memory,
            key_padding_mask=memory_key_padding_mask,
        )[0]

    def _feed_forward(self, x):
        """Apply linear -> activation -> dropout -> linear."""
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

    def forward_post(self, tgt, query_pos, memory, memory_key_padding_mask: Tensor = None,
                     pos: Tensor = None):
        """Post-norm variant: normalise after each residual connection."""
        tgt = self.norm1(tgt + self.dropout1(self._self_attention(tgt, query_pos)))
        attended = self._cross_attention(tgt, query_pos, memory, memory_key_padding_mask, pos)
        tgt = self.norm2(tgt + self.dropout2(attended))
        tgt = self.norm3(tgt + self.dropout3(self._feed_forward(tgt)))
        return tgt

    def forward_pre(self, tgt, query_pos, memory, memory_key_padding_mask: Tensor = None,
                    pos: Tensor = None):
        """Pre-norm variant: normalise the input of each sub-layer."""
        normed = self.norm1(tgt)
        tgt = tgt + self.dropout1(self._self_attention(normed, query_pos))
        normed = self.norm2(tgt)
        attended = self._cross_attention(normed, query_pos, memory, memory_key_padding_mask, pos)
        tgt = tgt + self.dropout2(attended)
        normed = self.norm3(tgt)
        tgt = tgt + self.dropout3(self._feed_forward(normed))
        return tgt

    def forward(self, tgt, query_pos, memory, memory_key_padding_mask: Tensor = None,
                pos: Tensor = None):
        """Dispatch to the pre-norm or post-norm variant."""
        if self.normalize_before:
            return self.forward_pre(tgt, query_pos, memory, memory_key_padding_mask, pos)
        return self.forward_post(tgt, query_pos, memory, memory_key_padding_mask, pos)


# Multi-layer perceptron (MLP)
class MLP(nn.Module):
    """Feed-forward network of ``num_layers`` linear layers.

    ReLU is applied after every layer except the last one.

    Args:
        input_dim: Size of the input features.
        hidden_dim: Size of every hidden layer.
        output_dim: Size of the output features.
        num_layers: Total number of linear layers.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        hidden = [hidden_dim] * (num_layers - 1)
        in_dims = [input_dim] + hidden
        out_dims = hidden + [output_dim]
        self.layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(in_dims, out_dims)
        )

    def forward(self, x):
        """Apply the layers in order; the last layer has no activation."""
        last_index = self.num_layers - 1
        for index, layer in enumerate(self.layers):
            x = layer(x)
            if index < last_index:
                x = F.relu(x)
        return x


# Helper functions
def _get_clones(module, N):
    """Return a ``ModuleList`` holding ``N`` independent deep copies of ``module``."""
    clones = []
    for _ in range(N):
        clones.append(copy.deepcopy(module))
    return nn.ModuleList(clones)


def _get_activation_fn(activation):
    """Map an activation name to the matching ``torch.nn.functional`` callable.

    Args:
        activation: ``"relu"`` or ``"gelu"``.

    Raises:
        RuntimeError: For any other value.
    """
    for known_name, function in (("relu", F.relu), ("gelu", F.gelu)):
        if activation == known_name:
            return function
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


# Example: build the model and run a forward pass
if __name__ == "__main__":
    # Model hyper-parameters
    num_classes = 91  # value used for the COCO dataset
    num_queries = 100  # number of predictions per image
    hidden_dim = 256
    nhead = 8
    num_encoder_layers = 6
    num_decoder_layers = 6
    dim_feedforward = 2048
    dropout = 0.1
    activation = "relu"

    # Build the model
    backbone = Backbone(
        'resnet50',
        train_backbone=True,
        return_layers={'layer2': 'feat1', 'layer3': 'feat2', 'layer4': 'feat3'},
        dilation=False,
    )
    transformer = Transformer(
        d_model=hidden_dim,
        nhead=nhead,
        num_encoder_layers=num_encoder_layers,
        num_decoder_layers=num_decoder_layers,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
        activation=activation,
        normalize_before=False,
        return_intermediate_dec=True,
    )
    model = DETR(
        backbone,
        transformer,
        num_classes=num_classes,
        num_queries=num_queries,
        aux_loss=True,
    )

    # Example input: 2 images, 3 channels, 800x1200 resolution
    dummy_input = torch.randn(2, 3, 800, 1200)

    # This is an inference demo: disable dropout and skip gradient tracking
    # so the forward pass is deterministic and uses less memory.
    model.eval()
    with torch.no_grad():
        outputs = model(dummy_input)

    # Print the output shapes
    print("Predicted logits shape:", outputs['pred_logits'].shape)  # [batch_size, num_queries, num_classes+1]
    print("Predicted boxes shape:", outputs['pred_boxes'].shape)    # [batch_size, num_queries, 4]

3. 代码说明

  1. Backbone:使用ResNet-50作为特征提取器,提取多尺度特征。
  2. Transformer:核心部分,包括Encoder和Decoder,用于处理特征和生成目标预测。
  3. Detection Head:将Transformer的输出映射到类别和边界框。
  4. 辅助函数:包括多层感知机(MLP)、模块克隆(_get_clones)和激活函数选择(_get_activation_fn);上述代码未单独实现位置编码模块。

4. 使用说明

  • 训练:需要准备一个目标检测数据集(如COCO),并实现数据加载器。
  • 推理:将输入图像传递给模型,输出预测的类别和边界框。

5. 注意事项

  • 代码中的 num_classes 取 91,对应 COCO 数据集的类别 ID 范围(实际标注的类别为 80 个);如果使用其他数据集,请修改 num_classes。
  • num_queries是DETR中预测目标的数量,可以根据需要调整。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com

热搜词