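Image augmentation applies random transformations to existing training images to generate additional samples, which improves the robustness of a model. Common methods include:
- Horizontal (mirror) and vertical flips
- Random cropping and resizing
- Changes to brightness, contrast, saturation, and hue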
Import the required packages:
```python
import torch
import torchvision
from torch import nn
import PIL.Image  # import the submodule explicitly so that PIL.Image.open is available
from matplotlib import pyplot as plt
```
Download and load the sample image:
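```python
# `!wget` runs a shell command from Jupyter/IPython; outside a notebook, download the file manually
!wget https://bastudypic.oss-cn-hongkong.aliyuncs.com/auto-upload/images/cat.jpg
img = PIL.Image.open('cat.jpg')
plt.imshow(img);
```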
Define a helper function that displays a list of images in a grid:
```python
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):
    """Plot a list of images in a num_rows x num_cols grid."""
    figsize = (num_cols * scale, num_rows * scale)
    _, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
    axes = axes.flatten()
    for i, (ax, img) in enumerate(zip(axes, imgs)):
        if torch.is_tensor(img):
            # Tensor image (assumed to be laid out as H x W x C)
            ax.imshow(img.numpy())
        else:
            # PIL image
            ax.imshow(img)
        # Hide the axis ticks
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        if titles:
            ax.set_title(titles[i])
    return axes
```
A helper that applies an augmentation `aug` to the same image several times and displays the results:
```python
def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    # Each call to aug(img) draws fresh random parameters, so the panels can differ
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    show_images(Y, num_rows, num_cols, scale=scale)
```
Flip left-right with 50% probability:
```python
apply(img, torchvision.transforms.RandomHorizontalFlip())
```
Flip upside-down with 50% probability:
```python
apply(img, torchvision.transforms.RandomVerticalFlip())
```
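Conceptually, a horizontal flip reverses the pixel order within each row, and a vertical flip reverses the order of the rows. The sketch below illustrates this on a tiny image stored as a nested Python list (a hypothetical stand-in for real image data); `hflip`, `vflip`, and `random_flip` are illustrative helpers, not torchvision's implementation, though `random_flip` mimics the "flip with probability p" behavior of the two transforms above.

```python
import random

def hflip(img):
    """Mirror left-right: reverse the pixels in every row."""
    return [row[::-1] for row in img]

def vflip(img):
    """Flip upside-down: reverse the order of the rows."""
    return img[::-1]

def random_flip(img, flip_fn, p=0.5, rng=random):
    """Apply flip_fn with probability p; otherwise return the image unchanged."""
    return flip_fn(img) if rng.random() < p else img

# A 2x3 "image" whose values encode their positions
img = [[1, 2, 3],
       [4, 5, 6]]
print(hflip(img))  # [[3, 2, 1], [6, 5, 4]]
print(vflip(img))  # [[4, 5, 6], [1, 2, 3]]
print(random_flip(img, hflip))  # flipped or unchanged, each with probability 0.5
```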
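Random cropping: each call crops a region covering a random 10% to 100% of the original image area, with an aspect ratio (width/height) between 0.5 and 2, and then resizes the crop to 200 × 200 pixels:
```python
shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)
```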
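How `scale` and `ratio` turn into a concrete crop box can be sketched as follows. The crop area is a uniformly random fraction of the image area, and the aspect ratio is sampled log-uniformly (this matches our understanding of torchvision's implementation); candidates that do not fit inside the image are retried. `sample_crop` is a hypothetical name, and the fallback to the whole image is a simplification (torchvision falls back to a center crop).

```python
import math
import random

def sample_crop(height, width, scale=(0.1, 1.0), ratio=(0.5, 2.0),
                attempts=10, rng=random):
    """Return (top, left, h, w) of a random crop box inside a height x width image."""
    area = height * width
    log_lo, log_hi = math.log(ratio[0]), math.log(ratio[1])
    for _ in range(attempts):
        target_area = area * rng.uniform(*scale)          # fraction of the original area
        aspect = math.exp(rng.uniform(log_lo, log_hi))    # log-uniform aspect ratio w/h
        w = int(round(math.sqrt(target_area * aspect)))
        h = int(round(math.sqrt(target_area / aspect)))
        if 0 < w <= width and 0 < h <= height:
            top = rng.randint(0, height - h)   # randint is inclusive on both ends
            left = rng.randint(0, width - w)
            return top, left, h, w
    # Simplified fallback: use the whole image
    return 0, 0, height, width

print(sample_crop(300, 400))
```

The sampled box is then resized to the requested output size (200 × 200 above), so the cropped content is usually stretched or squeezed.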
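Color changes cover brightness, contrast, saturation, and hue. For example, randomly change the brightness to between 50% and 150% of the original:
- `brightness=0.5` means the brightness factor is sampled uniformly from the interval [1 - 0.5, 1 + 0.5] = [0.5, 1.5]; setting a parameter to 0 leaves that property unchanged.

```python
apply(img,
      torchvision.transforms.ColorJitter(brightness=0.5, contrast=0,
                                         saturation=0, hue=0))
```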
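Per the torchvision documentation, the brightness factor is drawn uniformly from [max(0, 1 - brightness), 1 + brightness], and adjusting brightness amounts to scaling pixel values by that factor. A minimal sketch on a flat list of 8-bit grayscale values; `brightness_range` and `adjust_brightness` are hypothetical helpers, and the rounding and clipping details of the real library may differ.

```python
import random

def brightness_range(brightness):
    """Interval from which a ColorJitter-style brightness factor is sampled."""
    return (max(0.0, 1 - brightness), 1 + brightness)

def adjust_brightness(pixels, factor, max_value=255):
    """Scale each pixel value by factor and clip the result to [0, max_value]."""
    return [min(max_value, max(0, round(p * factor))) for p in pixels]

lo, hi = brightness_range(0.5)
print(lo, hi)                               # 0.5 1.5
print(adjust_brightness([100, 200], 1.5))   # [150, 255]
factor = random.uniform(lo, hi)             # a fresh factor for each augmented copy
print(adjust_brightness([100, 200], factor))
```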
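Change the hue (`hue=0.5` is the largest allowed value, so the hue can shift by up to half a turn of the color wheel in either direction):
```python
apply(img,
      torchvision.transforms.ColorJitter(brightness=0, contrast=0, saturation=0,
                                         hue=0.5))
```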
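ColorJitter draws the hue shift uniformly from [-hue, hue] (with 0 ≤ hue ≤ 0.5) and cyclically shifts the hue channel in HSV space. The sketch below shows that rotation for a single RGB pixel using the standard-library `colorsys` module; `shift_hue` is a hypothetical helper that illustrates the idea, not torchvision's tensor implementation.

```python
import colorsys

def shift_hue(rgb, hue_factor):
    """Rotate the hue of one RGB pixel (components in [0, 1]) by hue_factor.

    hue_factor is a fraction of a full turn around the color wheel, so 0.5
    maps a color to its complement; the hue wraps around modulo 1.
    Saturation and value are left unchanged.
    """
    h, s, v = colorsys.rgb_to_hsv(*rgb)
    return colorsys.hsv_to_rgb((h + hue_factor) % 1.0, s, v)

print(shift_hue((1.0, 0.0, 0.0), 0.5))   # red -> cyan: (0.0, 1.0, 1.0)
print(shift_hue((0.5, 0.5, 0.5), 0.3))   # gray has no hue, so it is unchanged
```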
Brightness, contrast, saturation, and hue can also be randomized together:
```python
color_aug = torchvision.transforms.ColorJitter(brightness=0.5, contrast=0.5,
                                               saturation=0.5, hue=0.5)
apply(img, color_aug)
```
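When several properties are jittered at once, ColorJitter samples one factor per property and applies the adjustments in a random order on every call. The parameter sampling can be sketched like this; `sample_jitter_params` is a hypothetical name, and the ranges follow the ColorJitter documentation ([max(0, 1 - x), 1 + x] for brightness, contrast, and saturation; [-hue, hue] for hue).

```python
import random

def sample_jitter_params(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5,
                         rng=random):
    """Draw one random order and one factor per color property."""
    order = ["brightness", "contrast", "saturation", "hue"]
    rng.shuffle(order)  # the four adjustments are applied in this random order
    factors = {
        "brightness": rng.uniform(max(0.0, 1 - brightness), 1 + brightness),
        "contrast": rng.uniform(max(0.0, 1 - contrast), 1 + contrast),
        "saturation": rng.uniform(max(0.0, 1 - saturation), 1 + saturation),
        "hue": rng.uniform(-hue, hue),  # fraction of a turn around the color wheel
    }
    return order, factors

order, factors = sample_jitter_params()
print(order, factors)
```

A factor of 1 (or a hue shift of 0) leaves the corresponding property unchanged, which is why passing 0 for a parameter disables it.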
Several augmentations can also be combined with `Compose`, which applies them in sequence:
```python
augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)
```
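`Compose` calls its transforms one after another, feeding each output into the next, so `augs` flips, then color-jitters, then crops. The hypothetical `ComposeSketch` class below is a minimal re-implementation of that idea, shown only to make the chaining explicit; numeric functions stand in for image transforms to show that order can matter.

```python
class ComposeSketch:
    """Minimal stand-in for the chaining behavior of transforms.Compose."""

    def __init__(self, transforms):
        self.transforms = list(transforms)

    def __call__(self, x):
        # Apply each transform to the output of the previous one
        for t in self.transforms:
            x = t(x)
        return x

add_one = lambda x: x + 1
double = lambda x: x * 2
print(ComposeSketch([add_one, double])(3))  # (3 + 1) * 2 = 8
print(ComposeSketch([double, add_one])(3))  # 3 * 2 + 1 = 7
```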