欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 健康 > 美食 > PyTorch实现NMS算法

PyTorch实现NMS算法

2024/10/23 22:51:46 来源:https://blog.csdn.net/qq_36892712/article/details/139840971  浏览:    关键词:PyTorch实现NMS算法

PyTorch实现NMS算法

  • 介绍
    • 示例代码

介绍

参考链接1:NMS 算法源码实现
参考链接2: Python实现NMS(非极大值抑制)对边界框进行过滤。
目标检测算法(主流的有 RCNN 系、YOLO 系、SSD 等)在进行目标检测任务时,可能对同一目标有多次预测得到不同的检测框,非极大值抑制(NMS) 算法则可以确保对每个对象只得到一个检测,简单来说就是“消除冗余检测”。

示例代码

以下代码实现在 PyTorch 中实现非极大值抑制(NMS)。这个函数接受三个参数:boxes(边界框),scores(每个边界框的得分),和 iou_threshold(交并比阈值)。假设输入的边界框格式为 [x1, y1, x2, y2],其中 (x1, y1) 是左上角坐标,(x2, y2) 是右下角坐标。

import torchdef nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):"""Perform Non-Maximum Suppression (NMS) on bounding boxes.Args:boxes (torch.Tensor): A tensor of shape (N, 4) containing the bounding boxesof shape [x1, y1, x2, y2], where N is the number of boxes.scores (torch.Tensor): A tensor of shape (N,) containing the scores of the boxes.iou_threshold (float): The IoU threshold for suppressing boxes.Returns:torch.Tensor: A tensor of indices of the boxes to keep."""# Get the areas of the boxesx1 = boxes[:, 0]y1 = boxes[:, 1]x2 = boxes[:, 2]y2 = boxes[:, 3]areas = (x2 - x1) * (y2 - y1)# Sort the scores in descending order and get the sorted indices_, order = scores.sort(0, descending=True)keep = []while order.numel() > 0:if order.numel() == 1:i = order.item()keep.append(i)breakelse:i = order[0].item()keep.append(i)# Compute the IoU of the kept box with the restxx1 = torch.max(x1[i], x1[order[1:]])yy1 = torch.max(y1[i], y1[order[1:]])xx2 = torch.min(x2[i], x2[order[1:]])yy2 = torch.min(y2[i], y2[order[1:]])w = torch.clamp(xx2 - xx1, min=0)h = torch.clamp(yy2 - yy1, min=0)inter = w * hiou = inter / (areas[i] + areas[order[1:]] - inter)# Keep the boxes with IoU less than the thresholdinds = torch.where(iou <= iou_threshold)[0]order = order[inds + 1]return torch.tensor(keep, dtype=torch.long)

代码工作原理:

  1. 计算每个边界框的面积。
  2. 根据得分对边界框进行降序排序。
  3. 依次选择得分最高的边界框,并计算它与其他边界框的 IoU。
  4. 保留 IoU 小于阈值的边界框,并继续处理剩余的边界框。
  5. 返回保留的边界框的索引。

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com