Object Detection Data Augmentation with the Albumentations Library


Reposted by a site user, 2025-07-23
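This article shows how to use the Albumentations library for data augmentation in order to reduce annotation work. It first covers installation and a few caveats, then walks through batch augmentation for three common dataset formats (COCO, YOLO, and VOC), including defining the augmentation classes, choosing the augmentation options, and saving the results. It also discusses how augmentation helps a model cope with competition environments.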


Data Augmentation with the Albumentations Library

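 For a self-collected dataset, annotation is the biggest headache. To label less data yet still have more data for training a better model, you can use the open-source Albumentations library (GitHub: https://github.com/albumentations-team/albumentations). It produces the matching annotation files directly, with no second round of labeling, so you can generate more targeted data. For example, for detection from a moving car you may want to blur and scale the images to make the model more robust. Of course, the primary goal is still to save effort.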


I. Installing the Albumentations Library

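  Installing directly with pip upgrades opencv, numpy, scipy, and other dependencies to their latest versions; if that causes version errors, they have to be reinstalled. Installing from source lets you set the dependency versions in setup.py.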

In [1]
# Install from source
# !unzip -d /home/aistudio/work /home/aistudio/work/albumentations-master.zip
%cd /home/aistudio/work/albumentations-master
!python3 setup.py install

# Install with pip
# !pip install -U albumentations
# !pip install opencv-python-headless==4.1.1.26
# !pip install numpy==1.16.4
# !pip install scipy

^Cinterrupted

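II. Notes

1. If you install albumentations directly with pip, you get the latest version, and some other libraries such as numpy and opencv are changed as well. When I installed from source on AI Studio, the process hung while downloading one of the dependencies, so I recommend running everything on your own machine. All the code here was run successfully locally in PyCharm.

2. On AI Studio, the opencv functions that display images raise errors, so set is_show to True only when running locally. AI Studio also creates a ".ipynb_checkpoints" entry in some directories, which can break directory traversal, so the code filters entries by file suffix. Running locally is still recommended.

3. This project only does batch augmentation for three formats. For more Albumentations usage, see the GitHub page or search elsewhere. If the three format specifications are unclear, refer to the examples or look them up.

4. The project follows the author's own conventions: image names start at 0000.webp and continue with 0001.webp and so on.

5. The annotation boxes produced by Albumentations' spatial-level augmentations, such as affine transforms, are sometimes slightly off; pixel-level changes do not affect the boxes. The library's authors have said on GitHub that they are working on fixes, so it is best to install the latest version. It also helps to set is_show to True in the code below: each augmented image is then displayed with its boxes drawn, and you choose by keypress whether to save it. On AI Studio, however, opencv's interactive functions such as imshow and waitKey do not seem to work properly, so this check is better done locally. After each format's augmentation, the augmented images are displayed with matplotlib without the boxes; the augmented files can be inspected in the corresponding directories.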
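
Notes 2 and 4 can be illustrated with a small standard-library sketch: filter directory entries by suffix so that ".ipynb_checkpoints" is skipped, then build the next zero-padded file name. The helper names `list_images` and `numbered_name` and the suffix set are assumptions made for this illustration; they are not part of the project code.

```python
import os
import tempfile

# Assumed set of image suffixes for this sketch; adjust to your own data.
IMAGE_SUFFIXES = {"jpg", "jpeg", "png", "webp"}


def list_images(directory):
    """Return sorted image file names, skipping entries such as .ipynb_checkpoints."""
    return sorted(
        name for name in os.listdir(directory)
        if name.rsplit(".", 1)[-1].lower() in IMAGE_SUFFIXES
    )


def numbered_name(index, suffix, width=4):
    """Build a zero-padded file name, e.g. index 45 -> '0045.webp'."""
    return f"{str(index).zfill(width)}.{suffix}"


# Demo in a temporary directory that imitates an AI Studio folder.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("0000.webp", "0001.webp"):
        open(os.path.join(tmp, name), "w").close()
    os.mkdir(os.path.join(tmp, ".ipynb_checkpoints"))

    images = list_images(tmp)
    next_name = numbered_name(len(images), "webp")
    print(images, next_name)
```

Counting only the filtered images, rather than every directory entry, keeps the numbering contiguous even when a checkpoint folder is present.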

In [36]
# Start with a quick test on a single image. Its only purpose is to choose which augmentations to use.
# For example, if the detection target depends on color, color-changing augmentations may be unsuitable;
# with mirroring, a turn-left sign can become a turn-right sign. Choose carefully!
import albumentations as A
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Read the original image
original_image = cv2.imread('/home/aistudio/work/0000.webp')

# Pixel-level transforms
transform_Pixel = A.Compose([
    # A.CLAHE(p=1),  # histogram equalization (CLAHE)
    # A.ChannelDropout(p=1),  # randomly drop channels
    # A.ChannelShuffle(p=1),  # randomly permute channels
    A.ColorJitter(p=1),  # randomly change brightness, contrast, saturation, and hue
])

# Spatial-level transforms
transform_Spatial = A.Compose([
    # A.RandomCrop(width=256, height=256),
    A.HorizontalFlip(p=1),
    A.RandomBrightnessContrast(brightness_limit=0.5, contrast_limit=0.5, p=1),  # pixel-level, combined with the spatial transform
    # A.SafeRotate(limit=60, p=1),
    # A.Rotate(limit=45, p=1),
    # A.Affine(p=1),
    # A.GridDistortion(p=1),
])

# Apply the augmentation
transformed = transform_Spatial(image=original_image)

# Get the augmented image
transformed_image = transformed["image"]

# OpenCV reads BGR; convert to RGB for matplotlib
transformed_image = cv2.cvtColor(transformed_image, cv2.COLOR_BGR2RGB)
original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)

plt.subplot(1, 2, 1), plt.title("original image"), plt.axis('off')
plt.imshow(original_image)
plt.subplot(1, 2, 2), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image)
plt.show()

In [ ]
# Unzip the example dataset directories.
# TestImage is the root directory for the dataset formats; it contains examples in COCO, YOLO, and VOC format.
!unzip -d /home/aistudio/work /home/aistudio/work/DataProcess.zip
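
To make the effect of a spatial transform on annotations concrete, here is a minimal pure-Python sketch of what a horizontal flip does to a COCO-style box, plus a bounds check of the kind worth running on augmented boxes. The function names and the sample values are illustrative assumptions; Albumentations performs the equivalent bookkeeping internally when bboxes are passed to the pipeline.

```python
def hflip_coco_bbox(bbox, image_width):
    """Mirror a COCO box [x_min, y_min, width, height] about the vertical center line."""
    x_min, y_min, width, height = bbox
    return [image_width - x_min - width, y_min, width, height]


def bbox_inside_image(bbox, image_width, image_height):
    """Check that a COCO box has positive size and lies fully inside the image."""
    x_min, y_min, width, height = bbox
    return (
        width > 0 and height > 0
        and x_min >= 0 and y_min >= 0
        and x_min + width <= image_width
        and y_min + height <= image_height
    )


# Illustrative values: one box in a 1920x1080 image.
box = [541, 517, 79, 102]
flipped = hflip_coco_bbox(box, image_width=1920)
print(flipped)  # [1300, 517, 79, 102]
print(bbox_inside_image(flipped, 1920, 1080))  # True
```

The box moves to the mirrored position, but its class label does not change, which is exactly why a flipped turn-left sign ends up labeled incorrectly.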

III. Batch Augmentation for COCO, YOLO, and VOC Formats

1. COCO-format augmentation

  The COCO layout is as follows:

COCO
 |-- annotations
  |-- train.json
  |-- val.json

 |-- train
  |-- 0000.webp
  |-- 0001.webp
  |-- .....webp

 |-- val
  |-- 0000.webp
  |-- 0001.webp
  |-- .....webp
  Only a few val images are used in this example.

In [ ]
# Enter the COCO-format directory
%cd /home/aistudio/work/TestImage/COCO

In [17]
# Define the augmentation class
import json
import os

import albumentations as A
import cv2


class COCOAug(object):

    def __init__(self,
                 anno_path=None,
                 pre_image_path=None,
                 save_image_path=None,
                 anno_mode='train',
                 is_show=True,
                 start_filename_id=None,
                 start_anno_id=None,
                 ):
        """
        :param anno_path: directory containing the json files
        :param pre_image_path: directory of the images to augment
        :param save_image_path: directory where augmented images are saved
        :param anno_mode: 'train' or 'val'; selects the matching json file [train.json, val.json]
        :param is_show: show each augmented image with its boxes and labels drawn (imshow)
        :param start_filename_id: number of the first new image, also used as its image id and then
                                  incremented by 1; defaults to the number of existing images
        :param start_anno_id: first new annotation id, then incremented by 1;
                              defaults to the number of existing annotations
        """
        self.anno_path = anno_path
        self.aug_image_path = pre_image_path
        self.save_image_path = save_image_path
        self.anno_mode = anno_mode
        self.is_show = is_show
        self.start_filename_id = start_filename_id
        self.start_anno_id = start_anno_id

        # Augmentation options
        self.aug = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
            A.GaussianBlur(p=0.7),  # Gaussian blur
            A.GaussNoise(p=0.7),  # Gaussian noise
            A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # histogram equalization (CLAHE)
            A.Equalize(p=0.5),  # equalize the image histogram
            A.HorizontalFlip(p=1),
            A.OneOf([
                # A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
                # A.ChannelShuffle(p=0.3),  # randomly permute channels
                # A.ColorJitter(p=0.3),  # randomly change brightness, contrast, saturation, and hue
                # A.ChannelDropout(p=0.3),  # randomly drop channels
            ], p=0.),
            # A.Downscale(p=0.1),  # degrade quality by downscaling and upscaling
            A.Emboss(p=0.2),  # emboss the image and blend the result with the original
        ],
            # coco: [x_min, y_min, width, height]
            # min_area: box area in pixels; a box whose area after augmentation is below this value is dropped.
            # min_visibility: in [0, 1]; a box is dropped if the ratio of its area after augmentation
            #                 to its area before augmentation is below this value.
            A.BboxParams(format='coco', min_area=0., min_visibility=0., label_fields=['category_id'])
        )

        # Open the json file
        with open(os.path.join(self.anno_path, f"{self.anno_mode}.json"), 'r', encoding='utf-8') as load_f:
            self.load_dict = json.load(load_f)  # ['images', 'annotations', 'categories']
            self.labels = []  # list of label names
            for anno in self.load_dict['categories']:
                self.labels.append(anno['name'])
            print("--------- * ---------")
            if self.start_filename_id is None:
                self.start_filename_id = len(self.load_dict['images'])
                print("the start_filename_id is not set, default: len(images)")
            if self.start_anno_id is None:
                self.start_anno_id = len(self.load_dict['annotations'])
                print("the start_anno_id is not set, default: len(annotations)")
            print("len(images)     : ", self.start_filename_id)
            print("len(annotations): ", self.start_anno_id)
            print("categories: ", self.load_dict['categories'])
            print("labels: ", self.labels)
            print("--------- * ---------")

    def image_aug(self, max_len=4):
        """
        json layout:
        "images": [{"file_name": "013856.webp", "height": 1080, "width": 1920, "id": 13856},...]
        "annotations": [{"image_id": 13856, "id": 0, "category_id": 2, "bbox": [541, 517, 79, 102],
                         "area": 8058, "iscrowd": 0, "segmentation": []}, ...]
        "categories": [{"id": 0, "name": "Motor Vehicle"}, ...]
        :param max_len: width of the file name; 4 covers 0000~9999, use 5 for 00000~99999
        :return: None
        """
        # Augmented entries are appended to the loaded dict in place
        aug_data = self.load_dict
        # Counters start from the configured ids
        cnt_filename = self.start_filename_id
        cnt_anno_id = self.start_anno_id

        # Iterate over a copy of the image list, because new images are appended inside the loop
        for item in self.load_dict['images'][:]:
            image_name = item['file_name']
            image_suffix = image_name.split(".")[-1]  # file suffix, e.g. webp
            image_id = item['id']
            bboxes_list = []
            category_id_list = []
            # Collect every box of this image; boxes and category ids stay index-aligned
            for anno in self.load_dict['annotations']:
                if anno['image_id'] == image_id:
                    bboxes_list.append(anno['bbox'])
                    category_id_list.append(anno['category_id'])

            # Read the image
            image = cv2.imread(os.path.join(self.aug_image_path, image_name))

            # Run the pipeline; the result has the keys "image", "bboxes", "category_id"
            augmented = self.aug(image=image, bboxes=bboxes_list, category_id=category_id_list)
            aug_image = augmented['image']
            aug_category_id = [int(c) for c in augmented['category_id']]
            height, width = aug_image.shape[:2]
            # Round the augmented boxes to integers
            aug_bboxes = [[int(v + 0.5) for v in bbox[:4]] for bbox in augmented['bboxes']]

            # Optionally show the result to check the boxes (needs a local GUI)
            if self.is_show:
                tl = 2
                aug_image_copy = aug_image.copy()  # draw on a copy so the saved image stays clean
                for bbox, category_id in zip(aug_bboxes, aug_category_id):
                    text = f"{self.labels[category_id]}"
                    t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tl)[0]
                    cv2.rectangle(aug_image_copy, (bbox[0], bbox[1] - 3),
                                  (bbox[0] + t_size[0], bbox[1] - t_size[1] - 3),
                                  (0, 0, 255), -1, cv2.LINE_AA)  # filled
                    cv2.putText(aug_image_copy, text, (bbox[0], bbox[1] - 2), 0, tl / 3, (255, 255, 255), tl,
                                cv2.LINE_AA)
                    cv2.rectangle(aug_image_copy, (bbox[0], bbox[1]),
                                  (bbox[0] + bbox[2], bbox[1] + bbox[3]),
                                  (255, 255, 0), 2)
                cv2.imshow('aug_image_show', aug_image_copy)
                # Press 's' to save this augmentation; any other key skips it
                key = cv2.waitKey(0)
                cv2.destroyWindow('aug_image_show')
                if key & 0xff != ord('s'):
                    continue

            # New file name, e.g. cnt_filename=45 -> 0045.<image_suffix>
            new_filename = str(cnt_filename).zfill(max_len) + f'.{image_suffix}'
            # Save the augmented image
            cv2.imwrite(os.path.join(self.save_image_path, new_filename), aug_image)
            # Register the augmented image
            dict_image = {
                "file_name": new_filename,
                "height": height,
                "width": width,
                "id": cnt_filename
            }
            aug_data['images'].append(dict_image)
            # Write the augmented boxes (not the original ones) into the new annotations
            for bbox, category_id in zip(aug_bboxes, aug_category_id):
                dict_anno = {'image_id': cnt_filename,
                             'id': cnt_anno_id,
                             'category_id': category_id,
                             'bbox': bbox,
                             'area': int(bbox[2] * bbox[3]),
                             'iscrowd': 0,
                             "segmentation": []
                             }
                aug_data['annotations'].append(dict_anno)
                # One new annotation id per box
                cnt_anno_id += 1
            # One new image id per image
            cnt_filename += 1

        # Save the augmented json file
        with open(os.path.join(self.anno_path, f'aug_{self.anno_mode}.json'), 'w') as ft:
            json.dump(aug_data, ft)
In [19]
# Augment the example dataset; on success the results are saved in the corresponding directories
import os
import json
import matplotlib.pyplot as plt
import cv2

# Image paths
PRE_IMAGE_PATH = '/home/aistudio/work/TestImage/COCO/val'
SAVE_IMAGE_PATH = '/home/aistudio/work/TestImage/COCO/val'
# Annotation path
ANNO_PATH = '/home/aistudio/work/TestImage/COCO/annotations'
mode = 'val'  # ['train', 'val']

aug = COCOAug(
    anno_path=ANNO_PATH,
    pre_image_path=PRE_IMAGE_PATH,
    save_image_path=SAVE_IMAGE_PATH,
    anno_mode=mode,
    is_show=False,
)
aug.image_aug()
# cv2.destroyAllWindows()

original_image1 = cv2.imread('/home/aistudio/work/TestImage/COCO/val/0000.webp')
transformed_image1 = cv2.imread('/home/aistudio/work/TestImage/COCO/val/0002.webp')
original_image2 = cv2.imread('/home/aistudio/work/TestImage/COCO/val/0001.webp')
transformed_image2 = cv2.imread('/home/aistudio/work/TestImage/COCO/val/0003.webp')

original_image1 = cv2.cvtColor(original_image1, cv2.COLOR_BGR2RGB)
transformed_image1 = cv2.cvtColor(transformed_image1, cv2.COLOR_BGR2RGB)
original_image2 = cv2.cvtColor(original_image2, cv2.COLOR_BGR2RGB)
transformed_image2 = cv2.cvtColor(transformed_image2, cv2.COLOR_BGR2RGB)

plt.subplot(2, 2, 1), plt.title("original image"), plt.axis('off')
plt.imshow(original_image1)
plt.subplot(2, 2, 2), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image1)
plt.subplot(2, 2, 3), plt.title("original image"), plt.axis('off')
plt.imshow(original_image2)
plt.subplot(2, 2, 4), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image2)
plt.show()

--------- * ---------
the start_filename_id is not set, default: len(images)
the start_anno_id is not set, default: len(annotations)
len(images)     :  2
len(annotations):  2
categories:  [{'id': 0, 'name': 'side-walk'}, {'id': 1, 'name': 'speed-limit'}, {'id': 2, 'name': 'turn-left'}, {'id': 3, 'name': 'slope'}, {'id': 4, 'name': 'speed'}]
labels:  ['side-walk', 'speed-limit', 'turn-left', 'slope', 'speed']
--------- * ---------
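
The min_area and min_visibility thresholds described in the BboxParams comments can be pictured with a pure-Python sketch. It follows the wording of those comments for COCO-style boxes; it is an assumption-laden illustration, not the library's actual implementation, and the function names are made up for this example.

```python
def clip_coco_bbox(bbox, image_width, image_height):
    """Clip a COCO box [x_min, y_min, width, height] to the image bounds."""
    x_min, y_min, width, height = bbox
    x0 = max(0.0, x_min)
    y0 = max(0.0, y_min)
    x1 = min(float(image_width), x_min + width)
    y1 = min(float(image_height), y_min + height)
    return [x0, y0, max(0.0, x1 - x0), max(0.0, y1 - y0)]


def keep_bbox(bbox, image_width, image_height, min_area=0.0, min_visibility=0.0):
    """Decide whether a transformed box survives the two thresholds (illustrative rule)."""
    full_area = bbox[2] * bbox[3]
    clipped = clip_coco_bbox(bbox, image_width, image_height)
    visible_area = clipped[2] * clipped[3]
    if full_area <= 0 or visible_area <= 0:
        return False  # nothing of the box is left inside the image
    return visible_area >= min_area and visible_area / full_area >= min_visibility


# A 100x100 box that sticks out of a 1920x1080 image: only a 20x100 strip stays visible.
shifted = [1900, 500, 100, 100]
print(keep_bbox(shifted, 1920, 1080, min_visibility=0.1))  # True, 20% is visible
print(keep_bbox(shifted, 1920, 1080, min_visibility=0.3))  # False
print(keep_bbox(shifted, 1920, 1080, min_area=2500))       # False, 2000 px remain
```

With both thresholds left at 0, as in the class above, only boxes that fall entirely outside the image would be discarded under this rule; raising them is useful once crops or rotations are enabled.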

2. YOLO-format augmentation

  The YOLO layout is as follows:

YOLO
 |-- images
  |-- 0000.webp
  |-- 0001.webp
  |-- .....webp

 |-- labels
  |-- 0000.txt
  |-- 0001.txt
  |-- .....txt
  Only a small amount of data is used in this example.

In [ ]
# Enter the YOLO-format directory
%cd /home/aistudio/work/TestImage/YOLO/

In [29]
# Define the class
import os

import albumentations as A
import cv2


class YOLOAug(object):

    def __init__(self,
                 pre_image_path=None,
                 pre_label_path=None,
                 aug_save_image_path=None,
                 aug_save_label_path=None,
                 labels=None,
                 is_show=True,
                 start_filename_id=None,
                 max_len=4):
        """
        :param pre_image_path: directory of the original images
        :param pre_label_path: directory of the original label files
        :param aug_save_image_path: directory where augmented images are saved
        :param aug_save_label_path: directory where augmented labels are saved
        :param labels: list of class names (set it for your own data); used when displaying images
        :param is_show: show each augmented image with its boxes drawn
        :param start_filename_id: number of the first new file; defaults to the number of entries in pre_image_path
        :param max_len: width of the zero-padded file names
        """
        self.pre_image_path = pre_image_path
        self.pre_label_path = pre_label_path
        self.aug_save_image_path = aug_save_image_path
        self.aug_save_label_path = aug_save_label_path
        self.labels = labels
        self.is_show = is_show
        self.start_filename_id = start_filename_id
        self.max_len = max_len

        # Augmentation options
        self.aug = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
            # A.GaussianBlur(p=0.7),
            # A.GaussNoise(p=0.7),
            # A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # histogram equalization (CLAHE)
            # A.Equalize(p=0.5),  # equalize the image histogram
            A.HorizontalFlip(p=1),
            A.OneOf([
                # A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
                # A.ChannelShuffle(p=0.3),  # randomly permute channels
                # A.ColorJitter(p=0.3),  # randomly change brightness, contrast, saturation, and hue
                # A.ChannelDropout(p=0.3),  # randomly drop channels
            ], p=0.),
            # A.Downscale(p=0.1),  # degrade quality by downscaling and upscaling
            A.Emboss(p=0.2),  # emboss the image and blend the result with the original
        ],
            # yolo: [x_center, y_center, width, height], normalized to [0, 1]
            # min_area: box area in pixels; a box whose area after augmentation is below this value is dropped.
            # min_visibility: in [0, 1]; a box is dropped if the ratio of its area after augmentation
            #                 to its area before augmentation is below this value.
            A.BboxParams(format='yolo', min_area=0., min_visibility=0., label_fields=['category_id'])
        )
        print("--------*--------")
        image_len = len(os.listdir(self.pre_image_path))
        print("the length of images: ", image_len)
        if self.start_filename_id is None:
            print("the start_filename id is not set, default: len(image)", image_len)
            self.start_filename_id = image_len
        print("--------*--------")

    def get_data(self, image_name):
        """
        Load an image and its label information.
        :param image_name: image file name, e.g. 0000.webp
        :return: dict with the keys 'image', 'bboxes', 'category_id'
        """
        image = cv2.imread(os.path.join(self.pre_image_path, image_name))
        label_name = os.path.splitext(image_name)[0] + '.txt'
        with open(os.path.join(self.pre_label_path, label_name), 'r', encoding='utf-8') as f:
            label_txt = f.readlines()

        label_list = []
        cls_id_list = []
        for label in label_txt:
            label_info = label.strip().split(' ')
            cls_id_list.append(int(label_info[0]))
            label_list.append([float(x) for x in label_info[1:]])

        anno_info = {'image': image, 'bboxes': label_list, 'category_id': cls_id_list}
        return anno_info

    def aug_image(self):
        image_list = os.listdir(self.pre_image_path)
        file_name_id = self.start_filename_id
        for image_filename in image_list:
            image_suffix = image_filename.split('.')[-1]
            # AI Studio may create .ipynb_checkpoints entries; filter by suffix to avoid errors
            if image_suffix not in ['jpg', 'png', 'webp']:
                continue
            aug_anno = self.get_data(image_filename)
            # Run the pipeline
            aug_info = self.aug(**aug_anno)  # {'image': , 'bboxes': , 'category_id': }
            aug_image = aug_info['image']
            aug_bboxes = aug_info['bboxes']
            aug_category_id = [int(c) for c in aug_info['category_id']]

            stem = str(file_name_id).zfill(self.max_len)
            new_image_filename = f'{stem}.{image_suffix}'
            new_label_filename = f'{stem}.txt'
            print(f"aug_image_{new_image_filename}: ")

            aug_image_copy = aug_image.copy()
            for cls_id, bbox in zip(aug_category_id, aug_bboxes):
                print(" --- --- cls_id: ", cls_id)
                if self.is_show:
                    tl = 2
                    h, w = aug_image_copy.shape[:2]
                    # Convert the normalized box to pixel corners for drawing
                    x_center = int(bbox[0] * w)
                    y_center = int(bbox[1] * h)
                    width = int(bbox[2] * w)
                    height = int(bbox[3] * h)
                    xmin = int(x_center - width / 2)
                    ymin = int(y_center - height / 2)
                    xmax = int(x_center + width / 2)
                    ymax = int(y_center + height / 2)
                    text = f"{self.labels[cls_id]}"
                    t_size = cv2.getTextSize(text, 0, fontScale=tl / 3, thickness=tl)[0]
                    cv2.rectangle(aug_image_copy, (xmin, ymin - 3), (xmin + t_size[0], ymin - t_size[1] - 3),
                                  (0, 0, 255), -1, cv2.LINE_AA)  # filled
                    cv2.putText(aug_image_copy, text, (xmin, ymin - 2), 0, tl / 3, (255, 255, 255), tl, cv2.LINE_AA)
                    cv2.rectangle(aug_image_copy, (xmin, ymin), (xmax, ymax), (255, 255, 0), 2)

            if self.is_show:
                window_name = f'aug_image_{new_image_filename}'
                cv2.imshow(window_name, aug_image_copy)
                # Press 's' to save this augmentation; any other key skips it
                key = cv2.waitKey(0)
                cv2.destroyWindow(window_name)
                if key & 0xff != ord('s'):
                    continue

            # Save the augmented image and labels
            cv2.imwrite(os.path.join(self.aug_save_image_path, new_image_filename), aug_image)
            with open(os.path.join(self.aug_save_label_path, new_label_filename), 'w', encoding='utf-8') as lf:
                for cls_id, bbox in zip(aug_category_id, aug_bboxes):
                    # Keep six decimal places
                    coords = ' '.join(f'{value:.6f}' for value in bbox[:4])
                    lf.write(f'{cls_id} {coords}\n')
            file_name_id += 1
In [30]
# Augment the example dataset; on success the results are saved in the corresponding directories
import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Original image and label paths
PRE_IMAGE_PATH = '/home/aistudio/work/TestImage/YOLO/images'
PRE_LABEL_PATH = '/home/aistudio/work/TestImage/YOLO/labels'
# Paths where the augmented images and labels are saved
AUG_SAVE_IMAGE_PATH = '/home/aistudio/work/TestImage/YOLO/images'
AUG_SAVE_LABEL_PATH = '/home/aistudio/work/TestImage/YOLO/labels'
# Class list; change it to match your own data
labels = ['side-walk', 'speed-limit', 'turn-left', 'slope', 'speed']

aug = YOLOAug(pre_image_path=PRE_IMAGE_PATH,
              pre_label_path=PRE_LABEL_PATH,
              aug_save_image_path=AUG_SAVE_IMAGE_PATH,
              aug_save_label_path=AUG_SAVE_LABEL_PATH,
              labels=labels,
              is_show=False)
aug.aug_image()

original_image1 = cv2.imread('/home/aistudio/work/TestImage/YOLO/images/0000.webp')
transformed_image1 = cv2.imread('/home/aistudio/work/TestImage/YOLO/images/0003.webp')
original_image2 = cv2.imread('/home/aistudio/work/TestImage/YOLO/images/0001.webp')
transformed_image2 = cv2.imread('/home/aistudio/work/TestImage/YOLO/images/0004.webp')

original_image1 = cv2.cvtColor(original_image1, cv2.COLOR_BGR2RGB)
transformed_image1 = cv2.cvtColor(transformed_image1, cv2.COLOR_BGR2RGB)
original_image2 = cv2.cvtColor(original_image2, cv2.COLOR_BGR2RGB)
transformed_image2 = cv2.cvtColor(transformed_image2, cv2.COLOR_BGR2RGB)

plt.subplot(2, 2, 1), plt.title("original image"), plt.axis('off')
plt.imshow(original_image1)
plt.subplot(2, 2, 2), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image1)
plt.subplot(2, 2, 3), plt.title("original image"), plt.axis('off')
plt.imshow(original_image2)
plt.subplot(2, 2, 4), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image2)
plt.show()

--------*--------
the length of images:  3
the start_filename id is not set, default: len(image) 3
--------*--------
aug_image_0003.webp: 
 --- --- cls_id:  0
 --- --- cls_id:  3
aug_image_0004.webp: 
 --- --- cls_id:  0
 --- --- cls_id:  3
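
Because YOLO labels are normalized, it is easy to lose track of what a line in a .txt file means in pixels. The following standalone sketch parses a label line, converts it to pixel corners, and formats it back with six decimal places. The helper names and the 640x480 image size are assumptions chosen for illustration.

```python
def parse_yolo_line(line):
    """Parse 'cls x_center y_center width height' into (class id, normalized box)."""
    parts = line.split()
    return int(parts[0]), [float(value) for value in parts[1:5]]


def yolo_to_corners(bbox, image_width, image_height):
    """Convert a normalized YOLO box to pixel corners [xmin, ymin, xmax, ymax]."""
    x_center, y_center, width, height = bbox
    x_min = (x_center - width / 2) * image_width
    y_min = (y_center - height / 2) * image_height
    x_max = (x_center + width / 2) * image_width
    y_max = (y_center + height / 2) * image_height
    return [round(x_min), round(y_min), round(x_max), round(y_max)]


def format_yolo_line(cls_id, bbox):
    """Format a label line with six decimal places per coordinate."""
    return f"{cls_id} " + " ".join(f"{value:.6f}" for value in bbox)


cls_id, bbox = parse_yolo_line("3 0.5 0.5 0.25 0.5")
print(yolo_to_corners(bbox, image_width=640, image_height=480))  # [240, 120, 400, 360]
print(format_yolo_line(cls_id, bbox))  # 3 0.500000 0.500000 0.250000 0.500000
```

Formatting with `:.6f` avoids the pitfalls of truncating `str(value)`, which can cut a number written in scientific notation in the wrong place.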

3. VOC-format augmentation

  The VOC layout is as follows:

VOC
 |-- images
  |-- 0000.webp
  |-- 0001.webp
  |-- .....webp

 |-- labels
  |-- 0000.xml
  |-- 0001.xml
  |-- .....xml
  Only a small amount of data is used in this example.

In [ ]
# Enter the VOC-format directory
%cd /home/aistudio/work/TestImage/VOC/

In [33]
# Define the class
import os
import xml.etree.ElementTree as ET

import albumentations as A
import cv2


class VOCAug(object):

    def __init__(self,
                 pre_image_path=None,
                 pre_xml_path=None,
                 aug_image_save_path=None,
                 aug_xml_save_path=None,
                 start_aug_id=None,
                 labels=None,
                 max_len=4,
                 is_show=False):
        """
        :param pre_image_path: directory of the original images
        :param pre_xml_path: directory of the original xml files
        :param aug_image_save_path: directory where augmented images are saved
        :param aug_xml_save_path: directory where augmented xml files are saved
        :param start_aug_id: number of the first new file; defaults to the number of entries in pre_xml_path
        :param labels: list of label names; also used when displaying augmented images
        :param max_len: width of the zero-padded file names
        :param is_show: show each augmented image with its boxes drawn
        """
        self.pre_image_path = pre_image_path
        self.pre_xml_path = pre_xml_path
        self.aug_image_save_path = aug_image_save_path
        self.aug_xml_save_path = aug_xml_save_path
        self.start_aug_id = start_aug_id
        self.labels = labels
        self.max_len = max_len
        self.is_show = is_show

        print(self.labels)
        assert self.labels is not None, "labels is None!!!"

        # Augmentation options
        self.aug = A.Compose([
            A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=1),
            A.GaussianBlur(p=0.7),  # Gaussian blur
            A.GaussNoise(p=0.7),  # Gaussian noise
            A.CLAHE(clip_limit=2.0, tile_grid_size=(4, 4), p=0.5),  # histogram equalization (CLAHE)
            A.Equalize(p=0.5),  # equalize the image histogram
            A.OneOf([
                # A.RGBShift(r_shift_limit=50, g_shift_limit=50, b_shift_limit=50, p=0.5),
                # A.ChannelShuffle(p=0.3),  # randomly permute channels
                # A.ColorJitter(p=0.3),  # randomly change brightness, contrast, saturation, and hue
                # A.ChannelDropout(p=0.3),  # randomly drop channels
            ], p=0.),
            # A.Downscale(p=0.1),  # degrade quality by downscaling and upscaling
            A.Emboss(p=0.2),  # emboss the image and blend the result with the original
        ],
            # pascal_voc: [xmin, ymin, xmax, ymax] in pixel coordinates (not normalized)
            # min_area: box area in pixels; a box whose area after augmentation is below this value is dropped.
            # min_visibility: in [0, 1]; a box is dropped if the ratio of its area after augmentation
            #                 to its area before augmentation is below this value.
            A.BboxParams(format='pascal_voc', min_area=0., min_visibility=0., label_fields=['category_id'])
        )
        print('--------------*--------------')
        print("labels: ", self.labels)
        if self.start_aug_id is None:
            self.start_aug_id = len(os.listdir(self.pre_xml_path))
            print("the start_aug_id is not set, default: len(images)", self.start_aug_id)
        print('--------------*--------------')

    def _is_kept(self, obj):
        # Objects with an unknown label or marked as difficult are ignored
        return obj.find('name').text in self.labels and int(obj.find('difficult').text) != 1

    def get_xml_data(self, xml_filename):
        tree = ET.parse(os.path.join(self.pre_xml_path, xml_filename))
        root = tree.getroot()
        image_name = root.find('filename').text
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)
        bboxes = []
        cls_id_list = []
        for obj in root.iter('object'):
            if not self._is_kept(obj):
                continue
            cls_name = obj.find('name').text  # label
            xml_box = obj.find('bndbox')
            xmin = int(xml_box.find('xmin').text)
            ymin = int(xml_box.find('ymin').text)
            xmax = int(xml_box.find('xmax').text)
            ymax = int(xml_box.find('ymax').text)
            # Clamp boxes that exceed the image bounds
            if xmax > w:
                xmax = w
            if ymax > h:
                ymax = h
            bboxes.append([xmin, ymin, xmax, ymax])
            cls_id_list.append(self.labels.index(cls_name))

        # Read the image
        image = cv2.imread(os.path.join(self.pre_image_path, image_name))
        return bboxes, cls_id_list, image, image_name

    def aug_image(self):
        xml_list = os.listdir(self.pre_xml_path)
        cnt = self.start_aug_id
        for xml in xml_list:
            # AI Studio may create .ipynb_checkpoints entries; filter by suffix to avoid errors
            file_suffix = xml.split('.')[-1]
            if file_suffix not in ['xml']:
                continue
            bboxes, cls_id_list, image, image_name = self.get_xml_data(xml)
            anno_dict = {'image': image, 'bboxes': bboxes, 'category_id': cls_id_list}
            # Run the pipeline; the result has the keys "image", "bboxes", "category_id"
            augmented = self.aug(**anno_dict)
            # Save the result; only advance the counter when something was saved
            if self.save_aug_data(augmented, image_name, cnt):
                cnt += 1

    def save_aug_data(self, augmented, image_name, cnt):
        aug_image = augmented['image']
        aug_bboxes = augmented['bboxes']
        aug_category_id = augmented['category_id']

        # Suffix of the image file
        image_suffix = image_name.split(".")[-1]
        # Name of the original xml file
        pre_xml_name = os.path.splitext(image_name)[0] + '.xml'
        # Names of the new image and xml files
        stem = str(cnt).zfill(self.max_len)
        new_image_name = stem + "." + image_suffix
        new_xml_name = stem + '.xml'
        # Size of the augmented image
        new_image_height, new_image_width = aug_image.shape[:2]
        # Draw on a copy so the saved image stays clean
        aug_image_copy = aug_image.copy()

        # Edit the original xml to produce the augmented xml
        aug_tree = ET.parse(os.path.join(self.pre_xml_path, pre_xml_name))
        root = aug_tree.getroot()
        root.find('filename').text = new_image_name
        # Update the image size
        size = root.find('size')
        size.find('width').text = str(new_image_width)
        size.find('height').text = str(new_image_height)

        # Update every box. Only the objects kept by get_xml_data were augmented, so the same
        # filter is applied here to keep objects and boxes aligned. This assumes the pipeline
        # drops no boxes, which holds for the pixel-level transforms configured above.
        kept_objects = [obj for obj in root.iter('object') if self._is_kept(obj)]
        for obj, bbox, cls_id in zip(kept_objects, aug_bboxes, aug_category_id):
            xmin, ymin, xmax, ymax = [int(v) for v in bbox[:4]]
            label = self.labels[int(cls_id)]
            obj.find('name').text = label
            xml_box = obj.find('bndbox')
            xml_box.find('xmin').text = str(xmin)
            xml_box.find('ymin').text = str(ymin)
            xml_box.find('xmax').text = str(xmax)
            xml_box.find('ymax').text = str(ymax)
            if self.is_show:
                tl = 2
                t_size = cv2.getTextSize(label, 0, fontScale=tl / 3, thickness=tl)[0]
                cv2.rectangle(aug_image_copy, (xmin, ymin - 3),
                              (xmin + t_size[0], ymin - t_size[1] - 3),
                              (0, 0, 255), -1, cv2.LINE_AA)  # filled
                cv2.putText(aug_image_copy, label, (xmin, ymin - 2), 0, tl / 3, (255, 255, 255), tl,
                            cv2.LINE_AA)
                cv2.rectangle(aug_image_copy, (xmin, ymin), (xmax, ymax), (255, 255, 0), 2)

        if self.is_show:
            cv2.imshow('aug_image_show', aug_image_copy)
            # Press 's' to save this augmentation; any other key skips it
            key = cv2.waitKey(0)
            cv2.destroyWindow('aug_image_show')
            if key & 0xff != ord('s'):
                return False

        # Save the augmented image
        cv2.imwrite(os.path.join(self.aug_image_save_path, new_image_name), aug_image)
        # Save the augmented xml file
        aug_tree.write(os.path.join(self.aug_xml_save_path, new_xml_name))
        return True
In [34]
import os
import cv2
import albumentations as A
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt

# Original xml and image paths
PRE_IMAGE_PATH = '/home/aistudio/work/TestImage/VOC/images'
PRE_XML_PATH = '/home/aistudio/work/TestImage/VOC/labels'
# Paths where the augmented xml files and images are saved
AUG_SAVE_IMAGE_PATH = '/home/aistudio/work/TestImage/VOC/images'
AUG_SAVE_XML_PATH = '/home/aistudio/work/TestImage/VOC/labels'
# Label list
LABELS = ['zu', 'pai', 'lan']

aug = VOCAug(
    pre_image_path=PRE_IMAGE_PATH,
    pre_xml_path=PRE_XML_PATH,
    aug_image_save_path=AUG_SAVE_IMAGE_PATH,
    aug_xml_save_path=AUG_SAVE_XML_PATH,
    start_aug_id=None,
    labels=LABELS,
    is_show=False,
)
aug.aug_image()
# cv2.destroyAllWindows()

original_image1 = cv2.imread('/home/aistudio/work/TestImage/VOC/images/0000.webp')
transformed_image1 = cv2.imread('/home/aistudio/work/TestImage/VOC/images/0003.webp')
original_image2 = cv2.imread('/home/aistudio/work/TestImage/VOC/images/0001.webp')
transformed_image2 = cv2.imread('/home/aistudio/work/TestImage/VOC/images/0004.webp')

original_image1 = cv2.cvtColor(original_image1, cv2.COLOR_BGR2RGB)
transformed_image1 = cv2.cvtColor(transformed_image1, cv2.COLOR_BGR2RGB)
original_image2 = cv2.cvtColor(original_image2, cv2.COLOR_BGR2RGB)
transformed_image2 = cv2.cvtColor(transformed_image2, cv2.COLOR_BGR2RGB)

plt.subplot(2, 2, 1), plt.title("original image"), plt.axis('off')
plt.imshow(original_image1)
plt.subplot(2, 2, 2), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image1)
plt.subplot(2, 2, 3), plt.title("original image"), plt.axis('off')
plt.imshow(original_image2)
plt.subplot(2, 2, 4), plt.title("transformed image"), plt.axis('off')
plt.imshow(transformed_image2)
plt.show()

['zu', 'pai', 'lan']
--------------*--------------
labels:  ['zu', 'pai', 'lan']
the start_aug_id is not set, default: len(images) 3
--------------*--------------
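
The VOC workflow boils down to reading boxes from an xml tree, transforming them, and writing them back under a new file name. The sketch below shows that round trip with only the standard library, using a hand-written sample annotation and a horizontal mirror as the stand-in transform. The sample xml, the helper names, and the file name 0003.webp are assumptions for illustration only.

```python
import xml.etree.ElementTree as ET

# Hand-written sample annotation (assumed values, not taken from the dataset).
SAMPLE_XML = """<annotation>
  <filename>0000.webp</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>zu</name><difficult>0</difficult>
    <bndbox><xmin>100</xmin><ymin>120</ymin><xmax>200</xmax><ymax>260</ymax></bndbox>
  </object>
</annotation>"""

CORNER_TAGS = ("xmin", "ymin", "xmax", "ymax")


def read_boxes(root):
    """Return every box as [xmin, ymin, xmax, ymax] in pixel coordinates."""
    return [
        [int(obj.find("bndbox").find(tag).text) for tag in CORNER_TAGS]
        for obj in root.iter("object")
    ]


def write_boxes(root, new_filename, boxes):
    """Store a new file name and new box corners in the tree, in object order."""
    root.find("filename").text = new_filename
    for obj, box in zip(root.iter("object"), boxes):
        for tag, value in zip(CORNER_TAGS, box):
            obj.find("bndbox").find(tag).text = str(int(value))


root = ET.fromstring(SAMPLE_XML)
image_width = int(root.find("size/width").text)
# Mirror each box horizontally: the new xmin comes from the old xmax and vice versa.
mirrored = [[image_width - xmax, ymin, image_width - xmin, ymax]
            for xmin, ymin, xmax, ymax in read_boxes(root)]
write_boxes(root, "0003.webp", mirrored)
print(read_boxes(root))  # [[440, 120, 540, 260]]
```

Note that the corners swap roles under mirroring; writing `image_width - xmin` into xmin would produce a box with xmin greater than xmax.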

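IV. Summary

  This project performs batch augmentation, both pixel-level and spatial-level, for three common dataset formats. It is best suited to competitions where you have to build your own dataset. While taking part in the deep learning track of the 2024 China Robot and Artificial Intelligence Competition (the date of writing is 20240721), the author had to build a dataset from the newly provided materials, annotate it, train a model, and finally deploy it to a Jetson Nano. While tuning the car, it became clear that the self-collected data needed some augmentation to cope with real-world conditions:
    ① Lighting. The environment at an on-site competition is unpredictable, so lighting matters a great deal; Albumentations can vary the brightness of the dataset. Because the environment varies so much, the dataset also has to be large enough.
    ② Blur and object size. The competition is a race, so the car runs as fast as possible, which can make camera frames blurry and low in sharpness. The target also does not necessarily appear small and then grow; it may come into view part by part until the whole object is visible. Albumentations can blur, add noise, scale, and randomly crop the data to match these conditions.
  What is new in this project is that the dataset does not have to be re-annotated: the matching annotation files are generated directly. Most uses of Albumentations I had seen apply only simple augmentation to the images, yet the library also provides interfaces for several annotation formats, so why not use them? This project therefore collects batch augmentation for the three most common annotation formats, which should be especially helpful to people preparing for competitions.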

Source: https://www.php.cn/faq/1423481.html