欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 旅游 > 写一个简单的SSD算法

写一个简单的SSD算法

2024/10/24 16:24:08 来源:https://blog.csdn.net/m0_73426548/article/details/142282025  浏览:    关键词:写一个简单的SSD算法

一:创建数据集

1.导入相关包

import torch
import numpy as np
from PIL import Image

2.继承torch.utils.data.Dataset读入数据

class PikaDataset(torch.utils.data.Dataset):# part:接收train或者text,选择读取训练集还是测试集def __init__(self, part):self.part=part# 读取数据self.data=np.loadtxt(r'./data/pika/%s.csv' % part,delimiter=',')def __getitem__(self,index):# 根据index读取图片,返回图片targetx=Image.open(r'./data/pika/%s/%s.jpg'%(self.part,index)).convert('RGB')# 把图片变成ndarry格式[256,256,3]x=np.array(x)# 因为我们要使用的是torch所以要将图片变成[3,225,225]格式x=x.transpose((2,0,1))  # 变成了[3,225,225]# 将数据变成tensor类型x=torch.tensor(x)x.float()  # 改变成float类型y=torch.FloatTensor(self.data[index])return x,ydef __len__(self):return len(self.data)

 3.读取训练集和测试集

loader_train = torch.utils.data.DataLoader(dataset=PikaDataset(part='train'), batch_size=32,shuffle=True,drop_last=True)
# bath_size:表示一组读取多少张图片 shuffle:是否打乱数
# drop_last:如果数据集大小不能被批大小整除,则设置为“true”以除去最后一个未完成的批。如果“false”那么最后一批将更小。(默认:falseloader_test = torch.utils.data.DataLoader(dataset=PikaDataset(part='test'), batch_size=32,shuffle=True,drop_last=True)

二:生成候选框

# 获取候选框方法:将图片分成N*N的形式,这里是32*32,那么就生成了32*32个候选框
def get_anchors():# 需要生成32*32的image_size=32# 候选框大小anchor_size=0.15# 生成(0,32)大小的等差数列step=(np.arange(image_size)+0.5)/image_size# 生成中心点point=[]# 记录中心点坐标for i in range(image_size):for j in range(image_size):point.append([step[j],step[i]])# 根据中心点,生成所有anchors的坐标anchors=torch.empty(len(point),4)for i in range(len(point)):# 左上角x=中心点x坐标-候选框大小的一半anchors[i,0]=point[i][0]-anchor_size/2# 右下角y=中心点y坐标-候选框一半anchors[i,1]=point[i][1]-anchor_size/2# 右下角x=中心点x坐标+候选框大小的一半anchors[i,2]=point[i][0]+anchor_size/2# 左上角y=中心点y坐标+候选框一半anchors[i,3]=point[i][1]+anchor_size/2# 返回的候选框格式为[左上角x,右下角y,右下角x,左上角y]return anchors 
anchors=git_anchors()       

三:计算IOU(交并比)

def get_iou(y):# 计算出候选框的面积anchors_w=anchors[:,2]-anchors[:,0]anchors_h=anchors[:,3]-anchors[:,1]anchors_s=anchors_w*anchors_h# 计算真实框的面积y_w=y[2]-y[0]y_h=y[3]-y[1]y_s=y_w*y_h# 计算候选框和真是框的交集的小方框,用cross保存整个小方框的左上角和右下角坐标cross=torch.empty(anchors.shape)# 小方框的左上角坐标cross[:,0]=torch.max(anchors[:,0],y[0])cross[:,1]=torch.max(anchors[:,1],y[1])# 小方框的右下角坐标cross[:,2]=torch.min(anchors[:,2],y[2])cross[:,3]=torch.min(anchors[:,3],y[3])# 计算小方框的面积,没有交集的可能为负数,所有加上clamp函数cross_w=(cross[:,2]-cross[:,0]).clamp(min=0)cross_h=(cross[:,3]-cross[:,1]).clamp(min=0)cross_s=cross_w*cross_h# 并集面积union_s=anchors_s+y_s-cross_sreturn cross_s/union_s

四:计算target

def get_target(y):# anchors: [1024, 4]# y -> [32, 4]# target -> [32, 1024]target = torch.zeros(32, 1024)for i in range(32):target[i] = get_iou(y[i])return target
target=get_target(y)

五:定义模型

class Model(torch.nn.Module):def __init__(self):super().__init__()def get_block(in_channels, out_channels):block = torch.nn.Sequential(torch.nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=1),torch.nn.BatchNorm2d(num_features=out_channels),torch.nn.ReLU(),torch.nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=1),torch.nn.BatchNorm2d(num_features=out_channels),torch.nn.ReLU(),torch.nn.MaxPool2d(kernel_size=2)    )return block
#         self.block = get_blockself.cnn = torch.nn.Sequential(# [32, 3, 256, 256] -> [32, 16, 128, 128]get_block(in_channels=3,  out_channels=16), # [32, 16, 128, 128] -> [32, 32, 64, 64]get_block(in_channels=16,  out_channels=32), # [32, 32, 64, 64] -> [32, 64, 32, 32]get_block(in_channels=32,  out_channels=64), # [32, 64, 32, 32] -> [32, 128, 16, 16]get_block(in_channels=64,  out_channels=128), # [32, 128, 16, 16] -> [32, 256, 8, 8]get_block(in_channels=128,  out_channels=256), # [32, 256, 8, 8] -> [32, 512, 4, 4]get_block(in_channels=256,  out_channels=512), # [32, 512, 4, 4] -> [32, 1024, 2, 2]get_block(in_channels=512,  out_channels=1024))# [32, 1024, 2, 2]  -> [32, 1024, 1, 1]self.predictor = torch.nn.Conv2d(in_channels=1024, out_channels=1024,kernel_size=2,padding=0)def forward(self, x):# [32, 3, 256, 256] -> [32, 1024, 2, 2]x = self.cnn(x)# [32, 1024, 2, 2] - > [32, 1024, 1, 1]pred = self.predictor(x)pred = pred.squeeze()return pred
model=Model()

六:训练

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
get_loss = torch.nn.MSELoss()# 开始训练
for epoch in range(10):model.train()for i, (x, y) in enumerate(loader_train):# x [32, 3, 256, 256]# y [32, 4]# 预测 [32, 1024]pred = model(x)# 获取targettarget = get_target(y)# 计算损失  loss = get_loss(pred, target)optimizer.zero_grad()loss.backward()optimizer.step()if i % 10 == 0:print(epoch, i, loss.item())torch.save(model, './super_simple_ssd.pth')

七:预测

def predict(x):# x [32, 3, 256, 256]model.eval()# [32, 1024]pred = model(x)# [32, 1024] -> [32]# 得到的是最大值的索引, 这个索引就代表对应的anchorspred = pred.argmax(dim=1)return pred

八:保存模型数据

model = torch.load('./super_simple_ssd.pth')
pred = predict(x)

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com