# Base image for every demo below: the test image resized to 600x400, re-wrapped as a PILImage.
img = PILImage(
    PILImage.create(TEST_IMAGE).resize((600, 400))
)
As for all `Transform`s, you can pass `encodes` and `decodes` at init, or subclass and implement them. You can do the same for the `before_call` method, which is called at each `__call__`. Note that to have a consistent state for inputs and targets, a `RandTransform` must be applied at the tuple level.
By default the `before_call` behavior is to execute the transform with probability `p` (if subclassing and wanting to tweak that behavior, the attribute `self.do`, if it exists, is looked for to decide if the transform is executed or not).
`RandTransform` is only applied to the training set by default, so you have to pass `split_idx=0` if you are calling it directly and not through a `Datasets`. That behavior can be changed by setting the attribute `split_idx` of the transform to `None`.
def _add1(x):
    "Toy `encodes`: return `x` incremented by one."
    return x + 1

dumb_tfm = RandTransform(enc=_add1, p=0.5)
start, d1, d2 = 2, False, False
# With p=0.5 over 40 calls both outcomes are expected; `dumb_tfm.do` records whether
# the transform fired on the last call.
for _ in range(40):
    t = dumb_tfm(start, split_idx=0)
    if dumb_tfm.do:
        test_eq(t, start + 1)
        d1 = True
    else:
        test_eq(t, start)
        d2 = True
assert d1 and d2
dumb_tfm
# Original image next to its left-right mirror.
_, axs = subplots(1, 2)
show_image(img, title='original', ctx=axs[0])
show_image(img.flip_lr(), title='flipped', ctx=axs[1]);

# p=1. means the flip always fires; split_idx=0 marks the call as a training-set call.
tflip = FlipItem(p=1.)
test_eq(tflip(bbox, split_idx=0),
        tensor([[1., 0., 0., 1.]]) - 1)
By default each of the 8 dihedral transformations (including noop) has the same probability of being picked when the transform is applied. You can customize this behavior by passing your own `draw` function. To force a specific flip, you can also pass an integer between 0 and 7.
# Eight independent DihedralItem draws on the same image (p=1., training split).
_, axs = subplots(2, 4)
for ax in axs.flatten():
    transformed = DihedralItem(p=1.)(img, split_idx=0)
    show_image(transformed, ctx=ax)
# `crop_pad` to three square target sizes, one panel each.
_, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, sz in zip(axs.flatten(), (300, 500, 700)):
    show_image(img.crop_pad(sz), title=f'Size {sz}', ctx=ax);
# Same (600, 700) crop_pad target, rendered once per padding mode.
_, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, mode in zip(axs.flatten(), (PadMode.Zeros, PadMode.Border, PadMode.Reflection)):
    padded = img.crop_pad((600, 700), pad_mode=mode)
    show_image(padded, title=mode, ctx=ax);
# Three calls to one RandomCrop(200) instance, made without a split_idx.
_, axs = plt.subplots(1, 3, figsize=(12, 4))
f = RandomCrop(200)
for ax in axs:
    show_image(f(img), ctx=ax);
On the validation set, we take a center crop.
# The same transform with split_idx=1, i.e. validation-set behaviour.
_, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax in axs:
    show_image(f(img, split_idx=1), ctx=ax);
# ResizeMethod members compare equal to their lowercase string names.
test_eq(ResizeMethod.Squish, 'squish')
Resize(size=224)
`size` can be an integer (in which case images will be resized to a square) or a tuple. Depending on the `method`:
- `ResizeMethod.Squish`: we squish any rectangle to `size`
- `ResizeMethod.Pad`: we resize so that the larger dimension is a match and use padding with `pad_mode`
- `ResizeMethod.Crop`: we resize so that the shorter dimension is a match and crop (randomly on the training set, center crop for the validation set)

When doing the resize, we use `resamples[0]` for images and `resamples[1]` for segmentation masks.
# One panel per ResizeMethod, resizing to 256 with training-set behaviour (split_idx=0).
_, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, method in zip(axs.flatten(), (ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop)):
    rsz = Resize(256, method=method)
    resized = rsz(img, split_idx=0)
    show_image(resized, title=method, ctx=ax);
On the validation set, the crop is always a center crop (on the dimension that's cropped).
# Same three methods with validation-set behaviour (split_idx=1).
_, axs = plt.subplots(1, 3, figsize=(12, 4))
for ax, method in zip(axs.flatten(), (ResizeMethod.Squish, ResizeMethod.Pad, ResizeMethod.Crop)):
    rsz = Resize(256, method=method)
    resized = rsz(img, split_idx=1)
    show_image(resized, title=method, ctx=ax);
The crop is picked as a random scale in range `(min_scale,1)` and `ratio` in the range passed, then the resize is done with `resamples[0]` for images and `resamples[1]` for segmentation masks. On the validation set, we center crop the image if its ratio isn't in the range (to the minimum or maximum value), then resize.
# Nine successive calls to one RandomResizedCrop(256); `cropped` keeps the last result.
crop = RandomResizedCrop(256)
_, axs = plt.subplots(3, 3, figsize=(9, 9))
for ax in axs.flatten():
    cropped = crop(img)
    show_image(cropped, ctx=ax);
Squish is used on the validation set, removing a `val_xtra` proportion of each side first.
# Validation-set calls (split_idx=1) of the same RandomResizedCrop.
_, axs = subplots(1, 3)
for ax in axs.flatten():
    show_image(crop(img, split_idx=1), ctx=ax);
test_eq(cropped.shape, [256, 256])

# Per the checks below, RatioResize(256) brings size[0] of `img` to 256, and size[1]
# once the image has gone through `dihedral(3)`.
ratio_rsz = RatioResize(256)
ratio_rsz(img)
test_eq(ratio_rsz(img).size[0], 256)
test_eq(ratio_rsz(img.dihedral(3)).size[1], 256)
# Float copy of `img`, permuted with (2, 0, 1) and divided by 255
# (assumes `array(img)` is height x width x channels, giving channels-first — TODO confirm).
timg = TensorImage(array(img)).permute(2, 0, 1).float() / 255.

def _batch_ex(bs):
    "Return a `TensorImage` batch holding `bs` copies of `timg`."
    batch = timg[None].expand(bs, *timg.shape)
    # `expand` returns a view sharing storage; clone so every item is an independent copy.
    return TensorImage(batch.clone())
Multiply all the matrices returned by `aff_fs` before doing the corresponding affine transformation on a basic grid corresponding to `size`, then apply all `coord_fs` on the resulting flow of coordinates before finally doing an interpolation with `mode` and `pad_mode`.
# GPU random resized crop applied to a batch of 8 copies of `timg`; one panel per item.
t = _batch_ex(8)
rrc = RandomResizedCropGPU(224, p=1.)
y = rrc(t)
_, axs = plt.subplots(2, 4, figsize=(12, 6))
# Bind `i` via enumerate: the loop previously indexed `y[i]` without defining `i`,
# so it either raised NameError or reused a stale global and showed one item 8 times.
for i, ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)
x = torch.zeros(5, 2, 3)

def def_draw(x):
    "Draw one random integer in [0, 8) per entry along the first dimension of `x`."
    return torch.randint(0, 8, (x.size(0),))

# Default draw: every value stays within [0, 7].
t = _draw_mask(x, def_draw)
assert (0. <= t).all() and (t <= 7).all()

# Forced draw of 1: values stay within [0, 1] (presumably 0 where the transform is skipped).
t = _draw_mask(x, def_draw, 1)
assert (0. <= t).all() and (t <= 1).all()

# With p=1 every item gets the forced value, whether given as a scalar or per item.
test_eq(_draw_mask(x, def_draw, 1, p=1), tensor([1., 1, 1, 1, 1]))
test_eq(_draw_mask(x, def_draw, [0, 1, 2, 3, 4], p=1), tensor([0., 1, 2, 3, 4]))

# batch=True: one decision for the whole batch, so the mask is all zeros or all ones.
for i in range(5):
    t = _draw_mask(x, def_draw, 1, batch=True)
    assert (t == torch.zeros(5)).all() or (t == torch.ones(5)).all()
x = flip_mat(torch.randn(100, 4, 3))
# Over 100 matrices both -1 and 1 should appear at [0, 0]. This can spuriously fail with
# probability 2*2**(-100) (every draw identical).
test_eq(set(x[:, 0, 0].numpy()), {-1, 1})
# flip_batch(p=1.) always flips. The expected values below show the x coordinate being
# mirrored consistently for an image, a set of points and a bounding box.
t = _pnt2tensor([[1, 0], [2, 1]], (3, 3))
y = TensorImage(t[None, None]).flip_batch(p=1.)
test_eq(y, _pnt2tensor([[1, 0], [0, 1]], (3, 3))[None, None])

pnts = TensorPoint((tensor([[1., 0.], [2, 1]]) - 1)[None])
test_eq(pnts.flip_batch(p=1.), tensor([[[1., 0.], [0, 1]]]) - 1)

bbox = TensorBBox(((tensor([[1., 0., 2., 1]]) - 1)[None]))
test_eq(bbox.flip_batch(p=1.), tensor([[[0., 0., 1., 1.]]]) - 1)
Flip(p=0.3)

# The Flip transform (p=1., training split) must reproduce the flip_batch results.
flip = Flip(p=1.)
t = _pnt2tensor([[1, 0], [2, 1]], (3, 3))
y = flip(TensorImage(t[None, None]), split_idx=0)
test_eq(y, _pnt2tensor([[1, 0], [0, 1]], (3, 3))[None, None])

pnts = TensorPoint((tensor([[1., 0.], [2, 1]]) - 1)[None])
test_eq(flip(pnts, split_idx=0), tensor([[[1., 0.], [0, 1]]]) - 1)

bbox = TensorBBox(((tensor([[1., 0., 2., 1]]) - 1)[None]))
test_eq(flip(bbox, split_idx=0), tensor([[[0., 0., 1., 1.]]]) - 1)
t = _batch_ex(8)
draw = DeterministicDraw(list(range(8)))
# Call number i yields i % 8 for every batch item, cycling through the list.
for i in range(15):
    test_eq(draw(t), torch.zeros(8) + i % 8)
# NOTE(review): this instance receives a dict positionally and is replaced two lines below
# without being used — confirm it is only a constructor smoke test.
dih = DeterministicFlip({'p': .3})
t = _batch_ex(8)
dih = DeterministicFlip()
_, axs = plt.subplots(2, 4, figsize=(12, 6))
# Each call advances the deterministic draw; show the first item after every call.
for i, ax in enumerate(axs.flatten()):
    y = dih(t)
    show_image(y[0], title=f'Call {i}', ctx=ax)
`draw` can be specified if you want to customize which flip is picked when the transform is applied (default is a random number between 0 and 7). It can be an integer between 0 and 7, a list of such integers (which then should have a length equal to the size of the batch) or a callable that returns an integer between 0 and 7.
t = _batch_ex(8)
# Exercise both entry points with one forced dihedral index per item. The result from the
# Dihedral transform is overwritten; only the tensor-method result is displayed.
dih = Dihedral(p=1., draw=list(range(8)))
y = dih(t)
y = t.dihedral_batch(p=1., draw=list(range(8)))
_, axs = plt.subplots(2, 4, figsize=(12, 5))
for i, ax in enumerate(axs.flatten()):
    show_image(y[i], title=f'Flip {i}', ctx=ax)
t = _batch_ex(8)
dih = DeterministicDihedral()
_, axs = plt.subplots(2, 4, figsize=(12, 6))
# One call per panel; show the first batch item after each call.
for i, ax in enumerate(axs.flatten()):
    y = dih(t)
    show_image(y[0], title=f'Call {i}', ctx=ax)
`draw` can be specified if you want to customize which angle is picked when the transform is applied (default is a random float between `-max_deg` and `max_deg`). It can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# One forced rotation angle (in degrees) per batch item.
thetas = [-30, -15, 0, 15, 30]
y = _batch_ex(5).rotate(p=1., draw=thetas)
_, axs = plt.subplots(1, 5, figsize=(15, 3))
for i, (ax, theta) in enumerate(zip(axs.flatten(), thetas)):
    show_image(y[i], title=f'{theta} degrees', ctx=ax)
`draw`, `draw_x` and `draw_y` can be specified if you want to customize which scale and center are picked when the transform is applied (default is a random float between 1 and `max_zoom` for the first, between 0 and 1 for the last two). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
`draw_x` and `draw_y` are expected to be the position of the center in pct, 0 meaning the most left/top possible and 1 meaning the most right/bottom possible.
# Zoom around the exact center (draw_x = draw_y = 0.5) with one scale per batch item.
scales = [0.8, 1., 1.1, 1.25, 1.5]
n = len(scales)
y = _batch_ex(n).zoom(p=1., draw=scales, draw_x=0.5, draw_y=0.5)
fig, axs = plt.subplots(1, n, figsize=(12, 3))
fig.suptitle('Center zoom with different scales')
for i, (ax, scale) in enumerate(zip(axs.flatten(), scales)):
    show_image(y[i], title=f'scale {scale}', ctx=ax)
# Fixed 1.5 scale; draw_x/draw_y keep their defaults, so each item gets its own center.
y = _batch_ex(4).zoom(draw=1.5, p=1.)
fig, axs = plt.subplots(1, 4, figsize=(12, 3))
fig.suptitle('Constant scale and different random centers')
for i, ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)
`draw_x` and `draw_y` can be specified if you want to customize the magnitudes that are picked when the transform is applied (default is a random float between `-magnitude` and `magnitude`). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# Vertical warping only: draw_x is pinned to 0, draw_y gets one magnitude per item.
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_y=scales, draw_x=0.)
y = warp(_batch_ex(5), split_idx=0)
fig, axs = plt.subplots(1, 5, figsize=(15, 3))
fig.suptitle('Vertical warping')
for i, (ax, magnitude) in enumerate(zip(axs.flatten(), scales)):
    show_image(y[i], title=f'magnitude {magnitude}', ctx=ax)
# Horizontal warping only: draw_y is pinned to 0, draw_x gets one magnitude per item.
scales = [-0.4, -0.2, 0., 0.2, 0.4]
warp = Warp(p=1., draw_x=scales, draw_y=0.)
y = warp(_batch_ex(5), split_idx=0)
fig, axs = plt.subplots(1, 5, figsize=(15, 3))
fig.suptitle('Horizontal warping')
for i, (ax, magnitude) in enumerate(zip(axs.flatten(), scales)):
    show_image(y[i], title=f'magnitude {magnitude}', ctx=ax)
Lighting transforms are transforms that affect how light is represented in an image. These don't change the location of the object like previous transforms, but instead simulate how light could change in a scene. The SimCLR paper evaluates these transforms against other transforms for their use case of self-supervised image classification; note they use "color" and "color distortion" to refer to a combination of these transforms.
Most lighting transforms work better in "logit space", as we do not want to blow out the image by going over maximum or minimum brightness. Taking the sigmoid of the logit allows us to get back to "linear space".
# Inputs 0.00, 0.01, ..., 1.00 as a TensorImage.
x = TensorImage(torch.tensor([.01 * i for i in range(0, 101)]))

def f_lin(x):
    "Contrast x2 around 0.5 in linear space, clamped to [0, 1] (blue line)."
    return (2 * (x - 0.5) + 0.5).clamp(0, 1)

def f_log(x):
    "Double the input; applied through `lighting` (red line)."
    return 2 * x

plt.plot(x, f_lin(x), 'b', x, x.lighting(f_log), 'r')
The above graph shows the results of doing a contrast transformation in both linear and logit space. Notice how the blue linear plot has to be clamped, and we have lost information on how large 0.0 is by comparison to 0.2. In the red plot, however, the values curve, so we keep this relative information.
Brightness refers to the amount of light on a scene. This can be zero, in which case the image is completely black, or one, in which case the image is completely white. This may be especially useful if you expect your dataset to have over- or under-exposed images.
# Repr of a Brightness transform with max_lighting 0.5, applied with probability 0.8.
Brightness(max_lighting=0.5, p=0.8)
`draw` can be specified if you want to customize the magnitude that is picked when the transform is applied (default is a random float between `0.5*(1-max_lighting)` and `0.5*(1+max_lighting)`). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# One forced brightness value per batch item.
scales = [0.1, 0.3, 0.5, 0.7, 0.9]
y = _batch_ex(5).brightness(p=1., draw=scales)
fig, axs = plt.subplots(1, 5, figsize=(15, 3))
for i, (ax, scale) in enumerate(zip(axs.flatten(), scales)):
    show_image(y[i], title=f'scale {scale}', ctx=ax)
Contrast pushes pixels to either the maximum or minimum values. The minimum value for contrast is a solid gray image. As an example, take a picture of a bright light source in a dark room. Your eyes should be able to see some detail in the room, but the photo taken should instead have much higher contrast, with all of the detail in the background lost to darkness. This is one example of what this transform can help simulate.
`draw` can be specified if you want to customize the magnitude that is picked when the transform is applied (default is a random float taken with the log uniform distribution between `(1-max_lighting)` and `1/(1-max_lighting)`). Each can be a float, a list of floats (which then should have a length equal to the size of the batch) or a callable that returns a float.
# One forced contrast scale per batch item.
scales = [0.65, 0.8, 1., 1.25, 1.55]
y = _batch_ex(5).contrast(draw=scales, p=1.)
fig, axs = plt.subplots(1, 5, figsize=(15, 3))
for i, (ax, scale) in enumerate(zip(axs.flatten(), scales)):
    show_image(y[i], title=f'scale {scale}', ctx=ax)
# Sum of the RGB grayscale weights, formatted to three decimals.
f"{sum((0.2989, 0.5870, 0.1140)):.3f}"
The above is just one way to convert to grayscale. We chose this one because it was fast. Notice that the sum of the weight of each channel is 1.
# One forced saturation scale per batch item.
scales = [0., 0.5, 1., 1.5, 2.0]
y = _batch_ex(5).saturation(draw=scales, p=1.)
fig, axs = plt.subplots(1, 5, figsize=(15, 3))
for i, (ax, scale) in enumerate(zip(axs.flatten(), scales)):
    show_image(y[i], title=f'scale {scale}', ctx=ax)
Saturation controls the amount of color in the image, but not the lightness or darkness of an image. It has no effect on neutral colors such as whites, grays and blacks. At zero saturation you actually get a grayscale image. Pushing saturation past one causes more neutral colors to take on any underlying chromatic color.
fig, axs = plt.subplots(figsize=(20, 4), ncols=5)
for ax in axs:
    ax.set_ylabel('Hue')
    ax.set_xlabel('Saturation')
    ax.set_yticklabels([])
    ax.set_xticklabels([])

# Batch of one HSV image: channel 0 (hue) varies down the rows, channel 1 (saturation)
# varies across the columns, channel 2 (value) starts at 1 everywhere.
hsv = torch.stack([
    torch.arange(0, 2.1, 0.01)[:, None].repeat(1, 210),
    torch.arange(0, 1.05, 0.005)[None].repeat(210, 1),
    torch.ones([210, 210]),
])[None]

# The in-place mul_ accumulates, so panel i is drawn with V = 0.8**i.
for i, ax in enumerate(axs):
    if i > 0:
        hsv[:, 2].mul_(0.80)
    ax.set_title(f'V={0.8**i:.1f}')
    ax.imshow(hsv2rgb(hsv)[0].permute(1, 2, 0))
For the Hue transform we are using HSV space instead of logit space. HSV stands for hue, saturation and value. Hue in HSV space just cycles through colors of the rainbow. Notice how there is no maximum, because the colors just repeat.
Above are some examples of Hue (H) and Saturation (S) at various Values (V). One property of note in HSV space is that V controls the color you get at minimum saturation.
# One forced hue scale per batch item.
scales = [0.5, 0.75, 1., 1.5, 1.75]
y = _batch_ex(len(scales)).hue(draw=scales, p=1.)
fig, axs = plt.subplots(1, len(scales), figsize=(15, 3))
for i, (ax, scale) in enumerate(zip(axs.flatten(), scales)):
    show_image(y[i], title=f'scale {scale}', ctx=ax)
Random Erasing Data Augmentation. This variant, designed by Ross Wightman, is applied to either a batch or single image tensor after it has been normalized.
Since this should be applied after normalization, we'll define a helper to apply a function inside normalization.
# ImageNet normalization kept on CPU, used to normalize/denormalize around the cutout.
nrm = Normalize.from_stats(*imagenet_stats, cuda=False)
# Two fixed areas to fill, passed as `areas` to cutout_gaussian.
f = partial(cutout_gaussian,
            areas=[(100, 200, 100, 200), (200, 300, 200, 300)])
show_image(norm_apply_denorm(timg, f, nrm)[0]);
Args:
- p: The probability that the Random Erasing operation will be performed
- sl: Minimum proportion of erased area
- sh: Maximum proportion of erased area
- min_aspect: Minimum aspect ratio of erased area
- max_count: maximum number of erasing blocks per image, area per box is scaled by count
# Six independent training-split RandomErasing calls on the single image `timg`.
tfm = RandomErasing(p=1., max_count=6)
_, axs = subplots(2, 3, figsize=(12, 6))
f = partial(tfm, split_idx=0)
for i, ax in enumerate(axs.flatten()):
    erased = norm_apply_denorm(timg, f, nrm)
    show_image(erased[0], ctx=ax)
# One RandomErasing call on a batch of 6; show each item of the result.
y = norm_apply_denorm(_batch_ex(6), f, nrm)
_, axs = plt.subplots(2, 3, figsize=(12, 6))
for i, ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)
# The same transform called with split_idx=1 (validation) for comparison.
tfm = RandomErasing(p=1., max_count=6)
_, axs = subplots(2, 3, figsize=(12, 6))
f = partial(tfm, split_idx=1)
for i, ax in enumerate(axs.flatten()):
    erased = norm_apply_denorm(timg, f, nrm)
    show_image(erased[0], ctx=ax)
# Two deterministic affine transforms are merged into a single transform whose matrix is
# the product of theirs.
tfms = [Rotate(draw=10., p=1), Zoom(draw=1.1, draw_x=0.5, draw_y=0.5, p=1.)]
comp = setup_aug_tfms([Rotate(draw=10., p=1), Zoom(draw=1.1, draw_x=0.5, draw_y=0.5, p=1.)])
test_eq(len(comp), 1)
x = torch.randn(4, 3, 5, 5)
merged_mat = comp[0]._get_affine_mat(x)[..., :2]
test_close(merged_mat,
           tfms[0]._get_affine_mat(x)[..., :2] @ tfms[1]._get_affine_mat(x)[..., :2])
# We can't test that the output of `comp` equals applying `tfms` one after the other:
# the merged version interpolates once, the sequence twice, so the pixels differ.
# Six transforms are grouped into one affine/coord transform and one lighting transform.
tfms = [Rotate(), Zoom(), Warp(), Brightness(), Flip(), Contrast()]
comp = setup_aug_tfms(tfms)
aff_tfm, lig_tfm = comp
test_eq(len(aff_tfm.aff_fs + aff_tfm.coord_fs + lig_tfm.fs), 6)
test_eq(len(aff_tfm.aff_fs), 3)
test_eq(len(aff_tfm.coord_fs), 1)
test_eq(len(lig_tfm.fs), 2)
Random flip (or dihedral if `flip_vert=True`) with `p=0.5` is added when `do_flip=True`. With `p_affine` we apply a random rotation of `max_rotate` degrees, a random zoom between `min_zoom` and `max_zoom` and a perspective warping of `max_warp`. With `p_lighting` we apply a change in brightness and contrast of `max_lighting`. Custom `xtra_tfms` can be added. `size`, `mode` and `pad_mode` will be used for the interpolation. `max_rotate`, `max_lighting` and `max_warp` are multiplied by `mult` so you can more easily increase or decrease augmentation with a single parameter.
# Chain every default augmentation (doubled via mult=2) over a batch of 9, training split,
# then show the first three items.
tfms = aug_transforms(pad_mode='zeros', mult=2, min_scale=0.5)
y = _batch_ex(9)
for t in tfms:
    y = t(y, split_idx=0)
_, axs = plt.subplots(1, 3, figsize=(12, 3))
for i, ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)
# Same pipeline with batch=True, then show the first three items.
tfms = aug_transforms(pad_mode='zeros', mult=2, batch=True)
y = _batch_ex(9)
for t in tfms:
    y = t(y, split_idx=0)
_, axs = plt.subplots(1, 3, figsize=(12, 3))
for i, ax in enumerate(axs.flatten()):
    show_image(y[i], ctx=ax)
# Segmentation example: one CamVid image and its mask, repeated 10 times.
camvid = untar_data(URLs.CAMVID_TINY)
fns = get_image_files(camvid/'images')
cam_fn = fns[0]
mask_fn = camvid/'labels'/f'{cam_fn.stem}_P{cam_fn.suffix}'

def _cam_lbl(fn):
    "Ignore `fn` and always return the path of the single mask `mask_fn`."
    return mask_fn

cam_dsrc = Datasets([cam_fn]*10, [PILImage.create, [_cam_lbl, PILMask.create]])
cam_tdl = TfmdDL(cam_dsrc.train,
                 after_item=ToTensor(),
                 after_batch=[IntToFloatTensor(), *aug_transforms()],
                 bs=9)
cam_tdl.show_batch(max_n=9, vmin=1, vmax=30)
# Keypoint example: one MNIST image resized with Resize((35, 28)) plus fixed keypoints.
mnist = untar_data(URLs.MNIST_TINY)
mnist_fn = 'images/mnist3.png'
pnts = np.array([[0, 0], [0, 35], [28, 0], [28, 35], [9, 17]])

def _pnt_lbl(fn)->None:
    "Ignore `fn` and return the fixed keypoints `pnts` as a `TensorPoint`."
    # NOTE(review): `->None` does not describe the returned value, but is kept as-is since
    # the pipeline may read return annotations — confirm before changing.
    return TensorPoint.create(pnts)

pnt_dsrc = Datasets([mnist_fn]*10, [[PILImage.create, Resize((35, 28))], _pnt_lbl])
pnt_tdl = TfmdDL(pnt_dsrc.train,
                 after_item=[PointScaler(), ToTensor()],
                 after_batch=[IntToFloatTensor(), *aug_transforms(max_warp=0)],
                 bs=9)
pnt_tdl.show_batch(max_n=9)
# Bounding-box example: annotation `idx` of COCO_TINY, image repeated 10 times.
coco = untar_data(URLs.COCO_TINY)
images, lbl_bbox = get_annotations(coco/'train.json')
idx = 2
coco_fn = coco/'train'/images[idx]
bbox = lbl_bbox[idx]

def _coco_bb(x):
    "Ignore `x` and return the boxes `bbox[0]` as a `TensorBBox`."
    return TensorBBox.create(bbox[0])

def _coco_lbl(x):
    "Ignore `x` and return the labels `bbox[1]`."
    return bbox[1]

coco_dsrc = Datasets([coco_fn]*10,
                     [PILImage.create, [_coco_bb], [_coco_lbl, MultiCategorize(add_na=True)]],
                     n_inp=1)
coco_tdl = TfmdDL(coco_dsrc, bs=9,
                  after_item=[BBoxLabeler(), PointScaler(), ToTensor()],
                  after_batch=[IntToFloatTensor(), *aug_transforms()])
coco_tdl.show_batch(max_n=9)
coco_tdl.after_batch