欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 文旅 > 手游 > 深度学习 —— 个人学习笔记16(目标检测和边界框、目标检测数据集)

深度学习 —— 个人学习笔记16(目标检测和边界框、目标检测数据集)

2024/10/24 14:28:08 来源:https://blog.csdn.net/qq_41159013/article/details/141054303  浏览:    关键词:深度学习 —— 个人学习笔记16(目标检测和边界框、目标检测数据集)

声明

  本文章为个人学习使用,版面观感若有不适请谅解,文中知识仅代表个人观点,若出现错误,欢迎各位批评指正。

三十二、目标检测和边界框

import torch
import matplotlib.pyplot as plt
from matplotlib_inline import backend_inlinedef show_images(imgs, titles=None):plt.imshow(imgs)backend_inline.set_matplotlib_formats('svg')plt.rcParams['figure.figsize'] = (6.5, 3.5)plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']plt.title(titles)plt.show()img = plt.imread('E:\\cat\\catdog.jpg')
show_images(img, titles='原图')def box_corner_to_center(boxes):"""从(左上,右下)转换到(中间,宽度,高度)"""x1, y1, x2, y2 = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]cx = (x1 + x2) / 2cy = (y1 + y2) / 2w = x2 - x1h = y2 - y1boxes = torch.stack((cx, cy, w, h), axis=-1)return boxesdef box_center_to_corner(boxes):"""从(中间,宽度,高度)转换到(左上,右下)"""cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]x1 = cx - 0.5 * wy1 = cy - 0.5 * hx2 = cx + 0.5 * wy2 = cy + 0.5 * hboxes = torch.stack((x1, y1, x2, y2), axis=-1)return boxes# bbox是边界框的英文缩写
dog1_bbox, dog2_bbox = [55.0, 265.0, 252.0, 590.0], [400.0, 18.0, 656.0, 590.0]
cat1_bbox, cat2_bbox = [231.0, 188.0, 443.0, 595.0], [650.0, 226.0, 905.0, 590.0]
boxes = torch.tensor((dog1_bbox, dog2_bbox, cat1_bbox, cat2_bbox))print('测试函数正确性 : ', box_center_to_corner(box_corner_to_center(boxes)) == boxes)def bbox_to_rect(bbox, color):# 将边界框(左上x,左上y,右下x,右下y)格式转换成 matplotlib 格式:# ((左上x,左上y),宽,高)return plt.Rectangle(xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],fill=False, edgecolor=color, linewidth=2)fig = plt.imshow(img)
fig.axes.add_patch(bbox_to_rect(cat1_bbox, 'red'))
fig.axes.add_patch(bbox_to_rect(cat2_bbox, 'red'))
fig.axes.add_patch(bbox_to_rect(dog2_bbox, 'blue'))
fig.axes.add_patch(bbox_to_rect(dog1_bbox, 'blue'))plt.axis('off')
plt.suptitle('标记后')
plt.show()


三十三、目标检测数据集

import os
import pandas as pd
import torch
import torchvision
from matplotlib import pyplot as pltdef show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):numpy = lambda x, *args, **kwargs: x.detach().numpy(*args, **kwargs)figsize = (num_cols * scale, num_rows * scale)_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):try:img = numpy(img)except:passax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axesdef bbox_to_rect(bbox, color):return plt.Rectangle(xy=(bbox[0], bbox[1]), width=bbox[2]-bbox[0], height=bbox[3]-bbox[1],fill=False, edgecolor=color, linewidth=2)def show_bboxes(axes, bboxes, labels=None, colors=None):def make_list(obj, default_values=None):if obj is None:obj = default_valueselif not isinstance(obj, (list, tuple)):obj = [obj]return objnumpy = lambda x, *args, **kwargs: x.detach().numpy(*args, **kwargs)labels = make_list(labels)colors = make_list(colors, ['b', 'g', 'r', 'm', 'c'])for i, bbox in enumerate(bboxes):color = colors[i % len(colors)]rect = bbox_to_rect(numpy(bbox), color)axes.add_patch(rect)if labels and len(labels) > i:text_color = 'k' if color == 'w' else 'w'axes.text(rect.xy[0], rect.xy[1], labels[i],va='center', ha='center', fontsize=9, color=text_color,bbox=dict(facecolor=color, lw=0))def read_data_bananas(is_train=True):"""读取香蕉检测数据集中的图像和标签"""data_dir = 'E:\\banana-detection'csv_fname = os.path.join(data_dir, 'bananas_train' if is_trainelse 'bananas_val', 'label.csv')csv_data = pd.read_csv(csv_fname)csv_data = csv_data.set_index('img_name')images, targets = [], []for img_name, target in csv_data.iterrows():images.append(torchvision.io.read_image(os.path.join(data_dir, 'bananas_train' if is_train else'bananas_val', 'images', f'{img_name}')))# 这里的 target 包含(类别,左上角 x,左上角 y,右下角 x,右下角 y),# 其中所有图像都具有相同的香蕉类(索引为0)targets.append(list(target))return images, torch.tensor(targets).unsqueeze(1) / 256class BananasDataset(torch.utils.data.Dataset):"""一个用于加载香蕉检测数据集的自定义数据集"""def __init__(self, is_train):self.features, self.labels = read_data_bananas(is_train)print('read ' + str(len(self.features)) + (f' training examples' ifis_train else f' validation examples'))def __getitem__(self, idx):return (self.features[idx].float(), self.labels[idx])def __len__(self):return len(self.features)def load_data_bananas(batch_size):""" 加载香蕉检测数据集 """train_iter = torch.utils.data.DataLoader(BananasDataset(is_train=True),batch_size, shuffle=True)val_iter = torch.utils.data.DataLoader(BananasDataset(is_train=False),batch_size)return train_iter, val_iterbatch_size, edge_size = 32, 256
train_iter, _ = load_data_bananas(batch_size)
batch = next(iter(train_iter))
print(f'(批量大小、通道数、高度、宽度) : {batch[0].shape}\n'f'(批量大小、数据集的任何图像中边界框可能出现的最大数量、5) : {batch[1].shape}')imgs = (batch[0][0:10].permute(0, 2, 3, 1)) / 255
axes = show_images(imgs, 2, 5, scale=2)
for ax, label in zip(axes, batch[1][0:10]):show_bboxes(ax, [label[0][1:5] * edge_size], colors=['r'])plt.rcParams['font.sans-serif'] = ['Microsoft YaHei']
plt.suptitle('数据集展示')
plt.show()



  文中部分知识参考:B 站 —— 跟李沐学AI;百度百科

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com