PyTorch实现NMS算法
- 介绍
- 示例代码
介绍
参考链接1:NMS 算法源码实现
参考链接2: Python实现NMS(非极大值抑制)对边界框进行过滤。
目标检测算法(主流的有 RCNN 系、YOLO 系、SSD 等)在进行目标检测任务时,可能对同一目标有多次预测得到不同的检测框,非极大值抑制(NMS) 算法则可以确保对每个对象只得到一个检测,简单来说就是“消除冗余检测”。
示例代码
以下代码实现在 PyTorch 中实现非极大值抑制(NMS)。这个函数接受三个参数:boxes
(边界框),scores
(每个边界框的得分),和 iou_threshold
(交并比阈值)。假设输入的边界框格式为 [x1, y1, x2, y2]
,其中 (x1, y1)
是左上角坐标,(x2, y2)
是右下角坐标。
import torchdef nms(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float):"""Perform Non-Maximum Suppression (NMS) on bounding boxes.Args:boxes (torch.Tensor): A tensor of shape (N, 4) containing the bounding boxesof shape [x1, y1, x2, y2], where N is the number of boxes.scores (torch.Tensor): A tensor of shape (N,) containing the scores of the boxes.iou_threshold (float): The IoU threshold for suppressing boxes.Returns:torch.Tensor: A tensor of indices of the boxes to keep."""# Get the areas of the boxesx1 = boxes[:, 0]y1 = boxes[:, 1]x2 = boxes[:, 2]y2 = boxes[:, 3]areas = (x2 - x1) * (y2 - y1)# Sort the scores in descending order and get the sorted indices_, order = scores.sort(0, descending=True)keep = []while order.numel() > 0:if order.numel() == 1:i = order.item()keep.append(i)breakelse:i = order[0].item()keep.append(i)# Compute the IoU of the kept box with the restxx1 = torch.max(x1[i], x1[order[1:]])yy1 = torch.max(y1[i], y1[order[1:]])xx2 = torch.min(x2[i], x2[order[1:]])yy2 = torch.min(y2[i], y2[order[1:]])w = torch.clamp(xx2 - xx1, min=0)h = torch.clamp(yy2 - yy1, min=0)inter = w * hiou = inter / (areas[i] + areas[order[1:]] - inter)# Keep the boxes with IoU less than the thresholdinds = torch.where(iou <= iou_threshold)[0]order = order[inds + 1]return torch.tensor(keep, dtype=torch.long)
代码工作原理:
- 计算每个边界框的面积。
- 根据得分对边界框进行降序排序。
- 依次选择得分最高的边界框,并计算它与其他边界框的 IoU。
- 保留 IoU 小于阈值的边界框,并继续处理剩余的边界框。
- 返回保留的边界框的索引。