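Image augmentation creates additional, varied training samples from existing images, which makes the model more robust. The main methods are:

  • Mirror (left-right) and top-bottom flips
  • Random cropping and stretching
  • Changes in brightness, contrast, saturation, and hue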

Import the required packages

import torch
import torchvision
from torch import nn

import PIL.Image  # import the submodule explicitly so that PIL.Image.open is available
from matplotlib import pyplot as plt

Download and open the sample image

!wget https://bastudypic.oss-cn-hongkong.aliyuncs.com/auto-upload/images/cat.jpg
img = PIL.Image.open('cat.jpg')
plt.imshow(img);

Define a function that displays a grid of images

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor image: imshow expects (H, W) or (H, W, C), so a (C, H, W)
            # tensor would need img.permute(1, 2, 0) first
            ax.imshow(img.numpy())
        else:
            # PIL Image
            ax.imshow(img)
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes

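Helper function for demonstrating an augmentation: it applies aug to the same image num_rows × num_cols times and shows the results in a grid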

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale=scale)

Flip left-right with 50% probability

apply(img, torchvision.transforms.RandomHorizontalFlip())

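What RandomHorizontalFlip does to the pixels can be sketched in NumPy. This is a minimal, hypothetical re-implementation (the helper name random_hflip is made up, and it works on an H × W × C array rather than a PIL image): with probability p the column axis is reversed, otherwise the image is returned unchanged.

```python
import numpy as np

def random_hflip(arr, p=0.5, rng=None):
    """Mirror an H x W (x C) image array left-right with probability p.

    Hypothetical helper that mimics RandomHorizontalFlip on a NumPy array.
    """
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < p:
        # Reverse the width axis: pixel (r, c) moves to (r, W - 1 - c)
        return arr[:, ::-1].copy()
    return arr

# A made-up 2 x 3 image with 2 channels, just to make the flip visible
img = np.arange(12, dtype=np.uint8).reshape(2, 3, 2)
print(random_hflip(img, p=1.0)[:, :, 0])
```

Running it many times with the default p=0.5 flips roughly half of the outputs, which is why the grid above shows a mix of original and mirrored cats.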

Flip top-bottom with 50% probability

apply(img, torchvision.transforms.RandomVerticalFlip())

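The deterministic counterparts of the two random flips exist in Pillow as ImageOps.flip (top-bottom) and ImageOps.mirror (left-right); the random transforms essentially decide, per call, whether to apply such a flip. The tiny 1 × 3 test image below is made up for illustration.

```python
from PIL import Image, ImageOps

# Made-up 1-pixel-wide, 3-pixel-tall grayscale image: 0 (top), 128, 255 (bottom)
img = Image.new("L", (1, 3))
img.putdata([0, 128, 255])

flipped = ImageOps.flip(img)     # top-bottom, what RandomVerticalFlip does when it fires
mirrored = ImageOps.mirror(img)  # left-right, what RandomHorizontalFlip does when it fires
print(list(flipped.getdata()))   # [255, 128, 0]
print(list(mirrored.getdata()))  # [0, 128, 255]: a 1-pixel-wide image has nothing to mirror
```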

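Random cropping: crop a region whose area is a random 10% to 100% of the original image and whose aspect ratio (width/height) is random between 0.5 and 2, then resize the crop to 200 × 200 pixels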

shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

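How RandomResizedCrop picks its crop box can be sketched in plain Python. This is a simplified sketch modeled on my reading of torchvision's implementation, not its actual code, and sample_crop_box is a hypothetical helper: the area fraction is drawn uniformly from scale, the aspect ratio is drawn uniformly in log space from ratio, up to 10 attempts are made to find a box that fits inside the image, and otherwise a centered crop is used. The chosen box is then resized to the output size ((200, 200) above).

```python
import math
import random

def sample_crop_box(width, height, scale=(0.1, 1.0), ratio=(0.5, 2.0), rng=random):
    """Return (top, left, crop_height, crop_width) for a random crop.

    Simplified sketch modeled on RandomResizedCrop's sampling;
    sample_crop_box is a hypothetical helper, not a torchvision API.
    """
    area = width * height
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(10):
        # Fraction of the original area, drawn uniformly from `scale`
        target_area = area * rng.uniform(scale[0], scale[1])
        # Aspect ratio w/h, drawn uniformly in log space so 0.5 and 2 are symmetric
        aspect = math.exp(rng.uniform(log_lo, log_hi))
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            # randint is inclusive, so every valid top-left corner is possible
            top = rng.randint(0, height - h)
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Fallback after 10 failed attempts: a centered crop with the ratio clamped
    in_ratio = width / height
    if in_ratio < ratio[0]:
        w, h = width, int(round(width / ratio[0]))
    elif in_ratio > ratio[1]:
        w, h = int(round(height * ratio[1])), height
    else:
        w, h = width, height
    return (height - h) // 2, (width - w) // 2, h, w

# A made-up 500 x 400 (width x height) image size, with a seeded RNG
top, left, h, w = sample_crop_box(500, 400, rng=random.Random(0))
print(top, left, h, w)
# With Pillow, the box would then be cut out and resized to the output size:
# img.crop((left, top, left + w, top + h)).resize((200, 200))
```

Sampling the ratio in log space means that a crop twice as wide as it is tall is as likely as one twice as tall as it is wide.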

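Change the colors, which covers brightness, contrast, saturation, and hue. For example, randomly change the brightness to between 50% and 150% of the original:

  • brightness=0.5 means the brightness factor is drawn uniformly from the interval (1 - 0.5, 1 + 0.5), i.e. 0.5 to 1.5; setting contrast, saturation, and hue to 0 leaves them unchanged
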
apply(img,
    torchvision.transforms.ColorJitter(brightness=0.5, contrast=0,
                                       saturation=0, hue=0))

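The brightness part of ColorJitter amounts to multiplying all pixel values by a random factor from [max(0, 1 - brightness), 1 + brightness]. The NumPy sketch below is hypothetical (random_brightness is not a torchvision function); it clips and truncates back to uint8, so its rounding may differ slightly from torchvision's.

```python
import numpy as np

def random_brightness(arr, brightness=0.5, rng=None):
    """Scale a uint8 image by a factor drawn from [max(0, 1 - b), 1 + b].

    Hypothetical sketch of ColorJitter's brightness step; returns the
    adjusted image together with the factor that was used.
    """
    rng = np.random.default_rng() if rng is None else rng
    low = max(0.0, 1.0 - brightness)
    factor = rng.uniform(low, 1.0 + brightness)
    # Same factor for every pixel and channel; values above 255 are clipped,
    # and the cast back to uint8 truncates any fractional part
    scaled = arr.astype(np.float32) * factor
    out = np.clip(scaled, 0, 255).astype(np.uint8)
    return out, factor

img = np.full((2, 2, 3), 200, dtype=np.uint8)  # made-up flat gray image
out, factor = random_brightness(img, rng=np.random.default_rng(0))
print(factor, out[0, 0])
```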

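Change the hue. With hue=0.5 the hue shift is drawn from [-0.5, 0.5], measured as a fraction of the full color wheel; 0.5 is the largest value ColorJitter accepts for hue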

apply(img,
    torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0,
                                       hue=0.5))

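A hue shift rotates each pixel around the color wheel while keeping saturation and value fixed. The per-pixel sketch below uses Python's colorsys module and a hypothetical helper shift_hue; torchvision converts the whole image to HSV instead, so treat this as an illustration of the idea rather than its exact arithmetic.

```python
import colorsys

def shift_hue(rgb, hue_factor):
    """Rotate the hue of one RGB pixel (components in [0, 1]) by hue_factor.

    hue_factor is a fraction of the full color wheel; a shift of +-0.5
    already reaches the opposite color, so larger values are rejected.
    """
    if not -0.5 <= hue_factor <= 0.5:
        raise ValueError("hue_factor must lie in [-0.5, 0.5]")
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    # Shift H and wrap around the color wheel; S and V are untouched
    return colorsys.hsv_to_rgb((h + hue_factor) % 1.0, s, v)

print(shift_hue((1.0, 0.0, 0.0), 1 / 3))  # pure red rotated by 120 degrees: ≈ (0.0, 1.0, 0.0)
```

Gray pixels have zero saturation, so a hue shift leaves them unchanged; only colored regions of the image change.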

Brightness, contrast, saturation, and hue can of course also be jittered all at once

color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5,
                                               saturation=0.5, hue=0.5)
apply(img, color_aug)

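ColorJitter draws one factor per enabled property and applies the adjustments in a random order. A rough Pillow-based sketch, assuming a hypothetical helper color_jitter and leaving out hue for brevity, uses ImageEnhance.Brightness, ImageEnhance.Contrast, and ImageEnhance.Color (Pillow's saturation control):

```python
import random
from PIL import Image, ImageEnhance

def color_jitter(img, brightness=0.5, contrast=0.5, saturation=0.5, rng=random):
    """Hypothetical sketch: draw one factor per property from
    [max(0, 1 - x), 1 + x] and apply the adjustments in random order.
    Hue is omitted here for brevity."""
    steps = [
        (ImageEnhance.Brightness, brightness),
        (ImageEnhance.Contrast, contrast),
        (ImageEnhance.Color, saturation),  # ImageEnhance.Color adjusts saturation
    ]
    rng.shuffle(steps)  # random order of the adjustments
    for enhancer_cls, amount in steps:
        factor = rng.uniform(max(0.0, 1.0 - amount), 1.0 + amount)
        img = enhancer_cls(img).enhance(factor)  # factor 1.0 keeps the image as-is
    return img

img = Image.new("RGB", (8, 8), (120, 60, 30))  # made-up solid-color image
out = color_jitter(img, rng=random.Random(0))
print(out.size, out.mode)
```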

Several augmentation methods can also be combined with Compose, which applies them in list order

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

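Compose calls the listed transforms one after another, so in augs above the image is first possibly flipped, then color-jittered, then randomly cropped and resized. A minimal stand-in class (MiniCompose is a hypothetical name, not part of torchvision) makes the ordering explicit:

```python
class MiniCompose:
    """Hypothetical stand-in for torchvision.transforms.Compose:
    call each transform in list order, feeding each output into the next."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        for t in self.transforms:
            x = t(x)
        return x

# Plain functions make the order easy to check by hand
pipeline = MiniCompose([lambda x: x + 1, lambda x: x * 10])
print(pipeline(2))  # (2 + 1) * 10 = 30
```

Reversing the list would give 2 * 10 + 1 = 21 instead; in general the order of augmentations in the list can change the result.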

Last modified: July 15, 2021, 03:14 PM