添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接

本文以 PASCAL VOC2012 数据集为例子进行说明。(下载地址: PASCAL VOC2012 )

Pytorch 自定义数据集见文档: TorchVision Object Detection Finetuning Tutorial

本文将以PASCAL VOC为基础自定义一个数据集 VOCDataset ,并随机选取五张图片给将其对应的标注转化为矩形框画在图片上。

本文详细代码见: pytorch-tutorial/01-common/custom_dataset at main · simo-an/pytorch-tutorial (github.com)

定义一些工具类

定义类别数据,共有20中目标类别

class_dict = {
    "aeroplane": 1,
    "bicycle": 2,
    "bird": 3,
    "boat": 4,
    "bottle": 5,
    "bus": 6,
    "car": 7,
    "cat": 8,
    "chair": 9,
    "cow": 10,
    "diningtable": 11,
    "dog": 12,
    "horse": 13,
    "motorbike": 14,
    "person": 15,
    "pottedplant": 16,
    "sheep": 17,
    "sofa": 18,
    "train": 19,
    "tvmonitor": 20

将xml 转化为 json

def parse_xml_to_dict(xml):
    if len(xml) == 0:
        return {xml.tag: xml.text}
    result = {}
    for child in xml:
        child_result = parse_xml_to_dict(child)
        if child.tag != 'object':
            result[child.tag] = child_result[child.tag]
        else:
            if child.tag not in result:
                result[child.tag] = []
            result[child.tag].append(child_result[child.tag])
    return {xml.tag: result}

在图片上画出矩形框(参考代码: vision/utils.py at main · pytorch/vision (github.com))

def draw_bounding_boxes(
    image,
    boxes: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    img_to_draw = Image.fromarray(image)
    img_boxes = boxes.to(torch.int64).tolist()
    draw = ImageDraw.Draw(img_to_draw)
    class_map = [k for k, v in class_dict.items()]
    for i, bbox in enumerate(img_boxes):
        draw.rectangle(bbox, width=2, outline='red')
        margin = 2
        draw.text((bbox[0] + margin, bbox[1] + margin),
                  class_map[labels[i] - 1], fill='red')
    return np.array(img_to_draw)

生成自定义数据集

一些需要导入的基本库

import torch
import utils
from torch.utils.data import Dataset
from PIL import Image
from os import path
from lxml import etree

按照文档要求,在VOCDataset中实现三个方法__len____getitem__、以及get_height_and_width

初始化 VOCDataset 类

构造函数定义如下

voc_root: voc 数据集的根目录 year: 哪一个年份的数据集 transforms: 数据预处理 text_name: train.txt or val.txt 该txt文件在数据集的 VOCdevkit\VOC2012\ImageSets\Main 文件夹下 def __init__(self, voc_root, year='2012', transforms=None, text_name='train.txt'):

在构造函数中,我们主要完成以下三个功能

  • 设置图片路径image_root和标注路径anno_root
  • 设置此次要训练的样本所有标注文件路径列表xml_list
  • 设置要检测的目标类别信息class_dict
  • 设置图片路径image_root和标注路径anno_root

            # 设置数据集、图片、标注的根目录
            self.root = path.join(voc_root, 'VOCdevkit', f'VOC{year}')
            self.image_root = path.join(self.root, 'JPEGImages')
            self.anno_root = path.join(self.root, 'Annotations')
    

    设置此次要训练的样本所有标注文件路径列表xml_list

            # 根据 text_name 拿到对应的标注xml文件路径
            text_path = path.join(self.root, 'ImageSets','Main', text_name)
            # 读取txt文件的每一行并生成xml标注文件路径存放在xml_list中
            with open(text_path) as file_reader:
                self.xml_list = [
                    path.join(self.anno_root, f'{line.strip()}.xml')
                    for line in file_reader.readlines() if len(line.strip()) > 0
    

    设置要检测的目标类别信息class_dict

            self.class_dict = utils.class_dict
    

    一般使用 0 来表示当前类别是背景

    获取所有样例条数

        def __len__(self):
            return len(self.xml_list)
    

    样本的条数即标注文件列表长度

    根据索引获取指定样本

    函数定义如下

        def __getitem__(self, idx):
    

    传入的即为样本的索引值,其取值范围为 0 ~ len(xml_list)

    获取指定样本需要分为如下两大步

  • 获取图片信息(标注信息、索引、区域面积等)
  • 首先我们需要根据索引拿到对应标注信息,并将其转化为json格式 定义一个获取json格式的annotation的方法

        def get_annotation(self, idx):
            xml_path = self.xml_list[idx]
            assert path.exists(xml_path), f'file {xml_path} not found'
            xml_reader = open(xml_path)
            xml_text = xml_reader.read()
            xml = etree.fromstring(xml_text)
            annotation = utils.parse_xml_to_dict(xml)['annotation']
    

    获取annotation

            annotation = self.get_annotation(idx)
    

    然后我们就可以从annotation中拿到文件名称并获取到文件

            image_path = path.join(self.image_root, annotation['filename'])
            image = Image.open(image_path)
    

    获取图片信息

    声明需要获取的所有信息

            # 生成 target
            target = {
                'boxes': [], # 标注的左上、右下坐标(xmin, ymin, xmax, ymax)
                'labels': [],# 标注类别
                'image_id': [], # 图片索引
                'area': [], # 含有目标区域的面积 (xmax-xmin) * (ymax-ymin)
                'iscrowd': [], # 是不是一堆密集的东西在一起
    

    便利所有的object

    for obj in annotation['object']: bndbox = obj['bndbox'] xmin = float(bndbox['xmin']) ymin = float(bndbox['ymin']) xmax = float(bndbox['xmax']) ymax = float(bndbox['ymax']) target['boxes'].append([xmin, ymin, xmax, ymax]) # 设置有目标的坐标信息 target['labels'].append(self.class_dict[obj['name']]) # 获取对应的标签 target['area'].append((xmax - xmin) * (ymax - ymin)) # 计算面积 # 使用 difficult(当前目标是否难以识别) 字段来设置 iscrowd if 'difficult' in obj: target['iscrowd'].append(int(obj['difficult'])) else: target['iscrowd'].append(0)

    将所有信息转化为Tensor

            # Convert to tensor
            target['boxes'] = torch.as_tensor(target['boxes'])
            target['labels'] = torch.as_tensor(target['labels'])
            target['iscrowd'] = torch.as_tensor(target['iscrowd'])
            target['area'] = torch.as_tensor(target['area'])
            target['image_id'] = torch.tensor([idx])
    

    如果有设置数据预处理器,则在返回数据前调用

            if self.transforms is not None:
                image = self.transforms(image)
    

    返回图片以及对应的信息

            return image, target
    

    根据索引获取当前图片的宽高

    在标注信息里面含有图片宽高信息,所以可以很容易获取到

        def get_height_and_width(self, idx):
            annotation = annotation = self.get_annotation(idx)
            # 从 annotation 中取出宽高并返回
            width = int(annotation['size']['width'])
            height = int(annotation['size']['height'])
            return height, width
    

    以上我们就完成了数据集的定义,下面我们将使用实例代码来使用这个数据集

    使用自定义数据集并画上标注框

    导入一些基本库

    import os
    import random
    import utils
    import main
    import matplotlib.pyplot as plt
    import numpy as np
    

    定义transformer,将数据转化为Tensor

    data_transform = ts.Compose([ts.ToTensor()])
    

    由于ToTensor会将数据标准化,为了代码简洁,这里不使用

    拿到数据集并将目标框以及类别画出来

    train_data_set = VOCDataset(os.getcwd(), '2012', None, 'train.txt')
    for index in random.sample(range(0, len(train_data_set)), k=5):
        image, target = train_data_set[index]
        image = utils.draw_bounding_boxes(
            np.array(image),
            target['boxes'],
            target['labels'],
        plt.imshow(image)
        plt.show()
    

    这样就完成了整个流程了!

    运行与测试

    可见运行结果正确!

    分类:
    人工智能
    标签: