目录
1 简介
2 结果展示
3 前提条件
使用 pip 安装
4 代码详解
指定标注文件路径
获取并显示类别信息
统计每个类别的图像和标注数量
5 完整代码
1 简介
COCO(Common Objects in Context)是一个广泛用于计算机视觉任务(如物体检测、实例分割)的数据集格式。pycocotools 是一个 Python 库,专门用于加载和操作 COCO 数据集的标注文件(通常是 JSON 格式)。本教程将通过一个示例代码,展示如何使用 pycocotools 加载 COCO 标注文件,提取类别信息,并统计每个类别的图像数量和标注框数量。
2 结果展示
3 前提条件
在运行代码之前,您需要:
Python 环境:确保已安装 Python(推荐版本 3.6 或更高)。
pycocotools 库:需要安装此库。
COCO 格式的标注文件:一个符合 COCO 数据集格式的 JSON 文件(例如 train.json)
使用 pip 安装
pip install pycocotools
pip install pycocotools-windows
4 代码详解
指定标注文件路径
annFile 是一个字符串,表示 COCO 标注文件的路径。
注意:路径中的 r 表示原始字符串,避免反斜杠 \ 被转义。
确保此文件存在且符合 COCO 格式(包含 images、annotations 和 categories 等字段)。
annFile = r"E:\AIrailway\CPSSFO_DATASET\ann\train.json"
获取并显示类别信息
coco.getCatIds():返回所有类别的 ID 列表(整数)。
coco.loadCats(coco.getCatIds()):根据类别 ID 加载详细信息,返回一个字典列表,每个字典包含类别名称(name)、ID(id)等。
cat_nms = [cat['name'] for cat in cats]:提取类别名称,生成一个列表。
print:输出类别总数和名称列表。
cats = coco.loadCats(coco.getCatIds())
cat_nms = [cat['name'] for cat in cats]
print('number of categories: ', len(cat_nms))
print('COCO categories: \n', cat_nms)
统计每个类别的图像和标注数量
for cat_name in cat_nms:遍历所有类别名称。
coco.getCatIds(catNms=cat_name):根据类别名称获取对应的 ID。
coco.getImgIds(catIds=catId):获取包含该类别的所有图像 ID。
coco.getAnnIds(catIds=catId):获取该类别的所有标注框 ID
print("{:<15} {:<6d} {:<10d}".format(...)):格式化输出:
{:<15}:类别名称,左对齐,占 15 个字符。
{:<6d}:图像数量,左对齐,占 6 个字符。
{:<10d}:标注框数量,左对齐,占 10 个字符。
for cat_name in cat_nms:catId = coco.getCatIds(catNms=cat_name) # 获取类别 IDimgId = coco.getImgIds(catIds=catId) # 获取图像 IDannId = coco.getAnnIds(catIds=catId) # 获取标注框 IDprint("{:<15} {:<6d} {:<10d}".format(cat_name, len(imgId), len(annId)))
5 完整代码
from pycocotools.coco import COCOannFile = r"E:\AIrailway\CPSSFO_DATASET\ann\train.json"#json# initialize COCO api for instance annotations
coco = COCO(annFile)# display COCO categories and supercategories
cats = coco.loadCats(coco.getCatIds())
cat_nms = [cat['name'] for cat in cats]
print('number of categories: ', len(cat_nms))
print('COCO categories: \n', cat_nms)# 统计各类的图片数量和标注框数量
for cat_name in cat_nms:catId = coco.getCatIds(catNms=cat_name) # 1~90imgId = coco.getImgIds(catIds=catId) # 图片的idannId = coco.getAnnIds(catIds=catId) # 标注框的idprint("{:<15} {:<6d} {:<10d}".format(cat_name, len(imgId), len(annId)))