欢迎来到尧图网

客户服务 关于我们

您的位置:首页 > 财经 > 创投人物 > 计算DOTA文件的IOU

计算DOTA文件的IOU

2024/10/23 23:22:06 来源:https://blog.csdn.net/m0_67947599/article/details/143027325  浏览:    关键词:计算DOTA文件的IOU

背景

在目标检测任务中,评估不同对象之间的重叠情况是至关重要的,而IOU(Intersection Over Union)是衡量这种重叠程度的重要指标。本文将介绍如何编写一个Python脚本,通过并行化处理DOTA格式的标注文件,统计同类别对象之间的IOU超过某个阈值的对数。

代码功能

本文代码的核心功能包括:

  1. 解析DOTA格式标注文件,提取对象类别和多边形坐标。
  2. 计算同类对象之间的IOU,并统计超过设定阈值的情况。
  3. 使用多进程并行化处理,提高对大规模数据的处理效率。
  4. 将最终的结果保存为CSV格式,方便后续分析。

完整代码

import os
import logging
from shapely.geometry import Polygon
import numpy as np
from itertools import combinations
from concurrent.futures import ProcessPoolExecutor, as_completed
import argparse
import pandas as pd# 配置日志
logging.basicConfig(level=logging.INFO,format='%(asctime)s [%(levelname)s] %(message)s',handlers=[logging.StreamHandler()]
)def parse_args():"""解析命令行参数。"""parser = argparse.ArgumentParser(description='进行标注文件的IOU分析,统计同类别对象之间的重叠数量。')parser.add_argument('--anno_folder', type=str, required=True, help='标注文件夹路径')parser.add_argument('--output_csv', type=str, default='iou_overlap_results.csv', help='输出CSV文件路径')parser.add_argument('--iou_threshold', type=float, default=0.01, help='IOU阈值,超过此值视为重叠')parser.add_argument('--num_workers', type=int, default=None, help='并行处理的进程数,默认为CPU核心数')parser.add_argument('--by_class', action='store_true', help='是否按类别统计重叠数量')return parser.parse_args()def parse_annotation_file(file_path, class_map=None):"""解析标注文件,提取对象的类别和多边形坐标。参数:file_path (str): 标注文件的路径。class_map (set, optional): 需要筛选的类别集合。默认为None,表示不筛选。返回:list of tuples: 每个元组包含类别和对应的Shapely多边形。"""objects = []try:with open(file_path, 'r') as file:for line_num, line in enumerate(file, 1):parts = line.strip().split()if len(parts) < 9:logging.warning(f"{file_path} 第{line_num}行格式不正确,跳过。")continuetry:# 假设坐标为前8个元素,类别为第9个元素coords = list(map(float, parts[:8]))dota_type = parts[8]if class_map and dota_type not in class_map:continue# 将坐标转换为Shapely多边形polygon = Polygon(np.array(coords).reshape(-1, 2))if not polygon.is_valid:logging.warning(f"{file_path} 第{line_num}行的多边形无效,跳过。")continueobjects.append((dota_type, polygon))except ValueError as ve:logging.error(f"{file_path} 第{line_num}行坐标转换错误: {ve}")except Exception as e:logging.error(f"读取文件 {file_path} 时发生错误: {e}")return objectsdef compute_iou(poly1, poly2):"""计算两个多边形的IOU。参数:poly1 (Polygon): 第一个多边形。poly2 (Polygon): 第二个多边形。返回:float: 两个多边形的IOU值。"""intersection = poly1.intersection(poly2).areaunion = poly1.union(poly2).areaif union == 0:return 0return intersection / uniondef analyze_file(file_path, by_class=False, class_map=None, iou_threshold=0.01):"""分析单个标注文件,统计同类别对象之间的IOU超过阈值的对数。参数:file_path (str): 标注文件的路径。by_class (bool, optional): 是否按类别统计。默认为False。class_map (set, optional): 需要筛选的类别集合。默认为None,表示所有类别。iou_threshold (float, optional): IOU阈值。默认为0.01。返回:list of dicts: 每个字典包含文件名、类别(如果按类别统计)和重叠对数。"""filename = os.path.basename(file_path)objects = parse_annotation_file(file_path, class_map)results = []if by_class:# 按类别分组class_dict = {}for dota_type, polygon in objects:class_dict.setdefault(dota_type, []).append(polygon)for dota_type, polygons in class_dict.items():overlap_count = 0num_objects = len(polygons)if num_objects < 2:# 少于两个对象,无需比较results.append({'filename': filename,'class': dota_type,'overlap_count': 0})continue# 使用组合生成所有可能的对象对for poly1, poly2 in combinations(polygons, 2):iou = compute_iou(poly1, poly2)if iou > iou_threshold:overlap_count += 1results.append({'filename': filename,'class': dota_type,'overlap_count': overlap_count})else:# 不按类别,统计所有对象之间的重叠polygons = [polygon for _, polygon in objects]overlap_count = 0num_objects = len(polygons)if num_objects >= 2:for poly1, poly2 in combinations(polygons, 2):iou = compute_iou(poly1, poly2)if iou > iou_threshold:overlap_count += 1results.append({'filename': filename,'overlap_count': overlap_count})return resultsdef main():args = parse_args()anno_folder = args.anno_folderoutput_csv = args.output_csviou_threshold = args.iou_thresholdnum_workers = args.num_workersby_class = args.by_class# 定义类别映射,如果需要筛选特定类别,可以在这里修改# 例如:class_map = {'embankment_dota', 'gravity_dota'}class_map = None  # 设置为None表示分析所有类别# class_map = {'embankment_dota'}  # 只分析 'embankment_dota' 类别# 获取所有标注文件all_files = [os.path.join(anno_folder, f) for f in os.listdir(anno_folder) if os.path.isfile(os.path.join(anno_folder, f))]logging.info(f"找到 {len(all_files)} 个标注文件。")# 准备并行处理results = []with ProcessPoolExecutor(max_workers=num_workers) as executor:future_to_file = {executor.submit(analyze_file, file_path, by_class, class_map, iou_threshold): file_path for file_path in all_files}for future in as_completed(future_to_file):file_path = future_to_file[future]try:file_results = future.result()results.extend(file_results)logging.info(f"完成分析文件: {os.path.basename(file_path)}")except Exception as exc:logging.error(f"分析文件 {os.path.basename(file_path)} 时发生异常: {exc}")# 将结果写入CSVif results:df = pd.DataFrame(results)df.to_csv(output_csv, index=False)logging.info(f"分析结果已保存到 {output_csv}")else:logging.warning("没有生成任何分析结果。")if __name__ == '__main__':main()

代码详解

接下来我们将详细解释该脚本的每个部分。

1. 配置日志和命令行参数解析

我们首先配置了日志系统,以便记录运行时的相关信息。日志系统可以帮助我们实时跟踪程序的执行状态和潜在问题。

import logginglogging.basicConfig(level=logging.INFO,format='%(asctime)s [%(levelname)s] %(message)s',handlers=[logging.StreamHandler()]
)

然后,定义了命令行参数解析函数 parse_args(),用于接收用户输入的文件夹路径、IOU阈值、输出文件路径、并行进程数等参数:

def parse_args():parser = argparse.ArgumentParser(description='进行标注文件的IOU分析,统计同类别对象之间的重叠数量。')parser.add_argument('--anno_folder', type=str, required=True, help='标注文件夹路径')parser.add_argument('--output_csv', type=str, default='iou_overlap_results.csv', help='输出CSV文件路径')parser.add_argument('--iou_threshold', type=float, default=0.01, help='IOU阈值,超过此值视为重叠')parser.add_argument('--num_workers', type=int, default=None, help='并行处理的进程数,默认为CPU核心数')parser.add_argument('--by_class', action='store_true', help='是否按类别统计重叠数量')return parser.parse_args()

2. 解析DOTA格式文件

DOTA标注文件包含多个对象的坐标和类别,通常以文本行的形式存储。我们通过 parse_annotation_file() 函数读取文件内容,并提取每个对象的类别和多边形坐标。

def parse_annotation_file(file_path, class_map=None):"""解析标注文件,提取对象的类别和多边形坐标。参数:file_path (str): 标注文件的路径。class_map (set, optional): 需要筛选的类别集合。默认为None,表示不筛选。返回:list of tuples: 每个元组包含类别和对应的Shapely多边形。"""objects = []try:with open(file_path, 'r') as file:for line_num, line in enumerate(file, 1):parts = line.strip().split()if len(parts) < 9:logging.warning(f"{file_path} 第{line_num}行格式不正确,跳过。")continuetry:# 假设坐标为前8个元素,类别为第9个元素coords = list(map(float, parts[:8]))dota_type = parts[8]if class_map and dota_type not in class_map:continue# 将坐标转换为Shapely多边形polygon = Polygon(np.array(coords).reshape(-1, 2))if not polygon.is_valid:logging.warning(f"{file_path} 第{line_num}行的多边形无效,跳过。")continueobjects.append((dota_type, polygon))except ValueError as ve:logging.error(f"{file_path} 第{line_num}行坐标转换错误: {ve}")except Exception as e:logging.error(f"读取文件 {file_path} 时发生错误: {e}")return objects

通过这个函数,我们可以将文件中的每个对象转化为一个Shapely库支持的多边形对象,方便后续计算IOU。

3. 计算IOU

IOU(交并比)的计算公式如下:

IOU = \frac{\text{Intersection Area}}{\text{Union Area}}

我们利用Shapely库中的 intersection()union() 方法来计算两个多边形的交集和并集面积。

def compute_iou(poly1, poly2):intersection = poly1.intersection(poly2).areaunion = poly1.union(poly2).areaif union == 0:return 0return intersection / union

4. 文件分析与并行化处理

analyze_file() 函数用于分析单个标注文件,统计同类别对象之间的IOU超过设定阈值的对数。支持按类别统计或整体统计。

def analyze_file(file_path, by_class=False, class_map=None, iou_threshold=0.01):"""分析单个标注文件,统计同类别对象之间的IOU超过阈值的对数。参数:file_path (str): 标注文件的路径。by_class (bool, optional): 是否按类别统计。默认为False。class_map (set, optional): 需要筛选的类别集合。默认为None,表示所有类别。iou_threshold (float, optional): IOU阈值。默认为0.01。返回:list of dicts: 每个字典包含文件名、类别(如果按类别统计)和重叠对数。"""filename = os.path.basename(file_path)objects = parse_annotation_file(file_path, class_map)results = []if by_class:# 按类别分组class_dict = {}for dota_type, polygon in objects:class_dict.setdefault(dota_type, []).append(polygon)for dota_type, polygons in class_dict.items():overlap_count = 0num_objects = len(polygons)if num_objects < 2:# 少于两个对象,无需比较results.append({'filename': filename,'class': dota_type,'overlap_count': 0})continue# 使用组合生成所有可能的对象对for poly1, poly2 in combinations(polygons, 2):iou = compute_iou(poly1, poly2)if iou > iou_threshold:overlap_count += 1results.append({'filename': filename,'class': dota_type,'overlap_count': overlap_count})else:# 不按类别,统计所有对象之间的重叠polygons = [polygon for _, polygon in objects]overlap_count = 0num_objects = len(polygons)if num_objects >= 2:for poly1, poly2 in combinations(polygons, 2):iou = compute_iou(poly1, poly2)if iou > iou_threshold:overlap_count += 1results.append({'filename': filename,'overlap_count': overlap_count})return results

推荐工具

在本文代码中,我们使用了以下Python库,它们在处理几何计算、多进程处理、文件解析等方面发挥了重要作用。如果你对这些库不太熟悉,可以通过以下链接获取更多信息和文档。

  1. Shapely - 进行几何对象的构造和操作,比如多边形的交集、并集等计算。

    • 官方文档:Shapely Documentation
    • 安装方法:pip install shapely
  2. NumPy - 科学计算库,用于处理数值数组。在这里,我们用它来将多边形的坐标转换为二维数组。

    • 官方文档:NumPy Documentation
    • 安装方法:pip install numpy
  3. itertools - Python标准库中的组合工具,用于生成多边形配对,计算它们之间的IOU。

    • 官方文档:itertools Documentation
  4. concurrent.futures - Python标准库中的并发工具,用于多进程并行处理标注文件。

    • 官方文档:concurrent.futures Documentation

结论

通过这篇博客,我们详细介绍了如何使用Python并行化处理DOTA格式的标注文件,并统计对象之间的IOU重叠情况。该脚本不仅具有较强的灵活性(支持按类别或整体统计),还充分利用多进程加速大数据量的处理。希望这篇博客能够帮助你在实际项目中更高效地处理和分析目标检测任务中的标注文件。

---

希望这篇博客对您有所帮助,如果您喜欢这篇文章,请点赞或关注,我会持续分享更多实用的 Python 技术内容!

---

版权声明:

本网仅为发布的内容提供存储空间,不对发表、转载的内容提供任何形式的保证。凡本网注明“来源:XXX网络”的作品,均转载自其它媒体,著作权归作者所有,商业转载请联系作者获得授权,非商业转载请注明出处。

我们尊重并感谢每一位作者,均已注明文章来源和作者。如因作品内容、版权或其它问题,请及时与我们联系,联系邮箱:809451989@qq.com,投稿邮箱:809451989@qq.com