Custom fastai layers and basic functions to grab them.

Basic manipulations and resize

module[source]

module(*flds, **defaults)

Decorator to create an nn.Module using f as forward method
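
A minimal usage sketch (AddBias is a made-up layer just for illustration): the names passed to module become attributes set in __init__, and the decorated function becomes the forward method.

@module('bias')
def AddBias(self, x): return x + self.bias

tst = AddBias(2)
test_eq(tst(tensor([1.,2.])), tensor([3.,4.]))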

class Identity[source]

Identity() :: Module

Do nothing at all

test_eq(Identity()(1), 1)

class Lambda[source]

Lambda(func) :: Module

An easy way to create a pytorch layer for a simple func

def _add2(x): return x+2
tst = Lambda(_add2)
x = torch.randn(10,20)
test_eq(tst(x), x+2)
tst2 = pickle.loads(pickle.dumps(tst))
test_eq(tst2(x), x+2)
tst
Lambda(func=_add2)

class PartialLambda[source]

PartialLambda(func) :: Lambda

Layer that applies partial(func, **kwargs)

def test_func(a,b=2): return a+b
tst = PartialLambda(test_func, b=5)
test_eq(tst(x), x+5)

class Flatten[source]

Flatten(full=False) :: Module

Flatten x to a single dimension, e.g. at end of a model. full for rank-1 tensor

tst = Flatten()
x = torch.randn(10,5,4)
test_eq(tst(x).shape, [10,20])
tst = Flatten(full=True)
test_eq(tst(x).shape, [200])

class View[source]

View(*size) :: Module

Reshape x to size

tst = View(10,5,4)
test_eq(tst(x).shape, [10,5,4])

class ResizeBatch[source]

ResizeBatch(*size) :: Module

Reshape x to size, keeping batch dim the same size

tst = ResizeBatch(5,4)
test_eq(tst(x).shape, [10,5,4])

class Debugger[source]

Debugger() :: Module

A module to debug inside a model.
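
A usage sketch (not from the library docs): dropping a Debugger between layers stops in the Python debugger during the forward pass so you can inspect intermediate activations; this is only useful in an interactive session.

model = nn.Sequential(nn.Linear(4, 8), Debugger(), nn.Linear(8, 2))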

sigmoid_range[source]

sigmoid_range(x, low, high)

Sigmoid function with range (low, high)

test = tensor([-10.,0.,10.])
assert torch.allclose(sigmoid_range(test, -1,  2), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test, -5, -1), tensor([-5.,-3.,-1.]), atol=1e-4, rtol=1e-4)
assert torch.allclose(sigmoid_range(test,  2,  4), tensor([2.,  3., 4.]), atol=1e-4, rtol=1e-4)

class SigmoidRange[source]

SigmoidRange(low, high) :: Module

Sigmoid module with range (low, high)

tst = SigmoidRange(-1, 2)
assert torch.allclose(tst(test), tensor([-1.,0.5, 2.]), atol=1e-4, rtol=1e-4)

Pooling layers

class AdaptiveConcatPool1d[source]

AdaptiveConcatPool1d(size=None) :: Module

Layer that concats AdaptiveAvgPool1d and AdaptiveMaxPool1d
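
A quick shape check for the 1d version (a sketch mirroring the 2d test below), on a bs x nf x len input:

tst = AdaptiveConcatPool1d()
x1d = torch.randn(10, 5, 12)
test_eq(tst(x1d).shape, [10, 10, 1])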

class AdaptiveConcatPool2d[source]

AdaptiveConcatPool2d(size=None) :: Module

Layer that concats AdaptiveAvgPool2d and AdaptiveMaxPool2d

If the input is bs x nf x h x w, the output will be bs x 2*nf x 1 x 1 if no size is passed, or bs x 2*nf x size x size otherwise.

tst = AdaptiveConcatPool2d()
x = torch.randn(10,5,4,4)
test_eq(tst(x).shape, [10,10,1,1])
max1 = torch.max(x,    dim=2, keepdim=True)[0]
maxp = torch.max(max1, dim=3, keepdim=True)[0]
test_eq(tst(x)[:,:5], maxp)
test_eq(tst(x)[:,5:], x.mean(dim=[2,3], keepdim=True))
tst = AdaptiveConcatPool2d(2)
test_eq(tst(x).shape, [10,10,2,2])

class PoolType[source]

PoolType()

adaptive_pool[source]

adaptive_pool(pool_type)

class PoolFlatten[source]

PoolFlatten(pool_type='Avg') :: Sequential

Combine nn.AdaptiveAvgPool2d and Flatten.

tst = PoolFlatten()
test_eq(tst(x).shape, [10,5])
test_eq(tst(x), x.mean(dim=[2,3]))

BatchNorm layers

BatchNorm[source]

BatchNorm(nf, ndim=2, norm_type=<NormType.Batch: 1>, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

BatchNorm layer with nf features and ndim initialized depending on norm_type.

InstanceNorm[source]

InstanceNorm(nf, ndim=2, norm_type=<NormType.Instance: 5>, affine=True, eps:float=1e-05, momentum:float=0.1, track_running_stats:bool=False)

InstanceNorm layer with nf features and ndim initialized depending on norm_type.

kwargs are passed to nn.BatchNorm and can be eps, momentum, affine and track_running_stats.

tst = BatchNorm(15)
assert isinstance(tst, nn.BatchNorm2d)
test_eq(tst.weight, torch.ones(15))
tst = BatchNorm(15, norm_type=NormType.BatchZero)
test_eq(tst.weight, torch.zeros(15))
tst = BatchNorm(15, ndim=1)
assert isinstance(tst, nn.BatchNorm1d)
tst = BatchNorm(15, ndim=3)
assert isinstance(tst, nn.BatchNorm3d)
tst = InstanceNorm(15)
assert isinstance(tst, nn.InstanceNorm2d)
test_eq(tst.weight, torch.ones(15))
tst = InstanceNorm(15, norm_type=NormType.InstanceZero)
test_eq(tst.weight, torch.zeros(15))
tst = InstanceNorm(15, ndim=1)
assert isinstance(tst, nn.InstanceNorm1d)
tst = InstanceNorm(15, ndim=3)
assert isinstance(tst, nn.InstanceNorm3d)

If affine is False, the weight will be None

test_eq(BatchNorm(15, affine=False).weight, None)
test_eq(InstanceNorm(15, affine=False).weight, None)

class BatchNorm1dFlat[source]

BatchNorm1dFlat(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) :: BatchNorm1d

nn.BatchNorm1d, but first flattens leading dimensions

tst = BatchNorm1dFlat(15)
x = torch.randn(32, 64, 15)
y = tst(x)
mean = x.mean(dim=[0,1])
test_close(tst.running_mean, 0*0.9 + mean*0.1)
var = (x-mean).pow(2).mean(dim=[0,1])
test_close(tst.running_var, 1*0.9 + var*0.1, eps=1e-4)
test_close(y, (x-mean)/torch.sqrt(var+1e-5) * tst.weight + tst.bias, eps=1e-4)

class LinBnDrop[source]

LinBnDrop(n_in, n_out, bn=True, p=0.0, act=None, lin_first=False) :: Sequential

Module grouping BatchNorm1d, Dropout and Linear layers

The BatchNorm layer is skipped if bn=False, as is the dropout if p=0. Optionally, you can add an activation after the linear layer with act.

tst = LinBnDrop(10, 20)
mods = list(tst.children())
test_eq(len(mods), 2)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Linear)

tst = LinBnDrop(10, 20, p=0.1)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.BatchNorm1d)
assert isinstance(mods[1], nn.Dropout)
assert isinstance(mods[2], nn.Linear)

tst = LinBnDrop(10, 20, act=nn.ReLU(), lin_first=True)
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[0], nn.Linear)
assert isinstance(mods[1], nn.ReLU)
assert isinstance(mods[2], nn.BatchNorm1d)

tst = LinBnDrop(10, 20, bn=False)
mods = list(tst.children())
test_eq(len(mods), 1)
assert isinstance(mods[0], nn.Linear)

Inits

sigmoid[source]

sigmoid(input, eps=1e-07)

Same as torch.sigmoid, plus clamping to (eps, 1-eps)
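
A quick sketch of the clamping: extreme logits stay strictly inside (0, 1) instead of saturating, while ordinary values match torch.sigmoid.

x = tensor([-100., 0., 100.])
y = sigmoid(x)
assert y.min() > 0 and y.max() < 1
test_eq(y[1], tensor(0.5))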

sigmoid_[source]

sigmoid_(input, eps=1e-07)

Same as torch.sigmoid_, plus clamping to (eps, 1-eps)

vleaky_relu[source]

vleaky_relu(input, inplace=True)

F.leaky_relu with 0.3 slope

init_default[source]

init_default(m, func=kaiming_normal_)

Initialize m weights with func and set bias to 0.
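
A small sketch, assuming init_default returns the module it initializes:

conv = init_default(nn.Conv2d(3, 8, 3), nn.init.kaiming_normal_)
test_eq(conv.bias, torch.zeros(8))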

init_linear[source]

init_linear(m, act_func=None, init='auto', bias_std=0.01)

Convolutions

class ConvLayer[source]

ConvLayer(ni, nf, ks=3, stride=1, padding=None, bias=None, ndim=2, norm_type=<NormType.Batch: 1>, bn_1st=True, act_cls=ReLU, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, groups:int=1, padding_mode:str='zeros') :: Sequential

Create a sequence of convolutional (ni to nf), ReLU (if use_activ) and norm_type layers.

The convolution uses ks (kernel size), stride, padding and bias. padding will default to the appropriate value ((ks-1)//2 if it's not a transposed conv) and bias will default to True if norm_type is Spectral or Weight, and to False if it's Batch or BatchZero. Note that if you don't want any normalization, you should pass norm_type=None.

This defines a conv layer with ndim (1, 2 or 3) that will be a ConvTranspose if transpose=True. act_cls is the class of the activation function to use (it is instantiated inside). Pass act_cls=None if you don't want an activation function. If you want to quickly change your default activation, you can change the value of defaults.activation.

init is used to initialize the weights (the bias are initialized to 0) and xtra is an optional layer to add at the end.

tst = ConvLayer(16, 32)
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq(mods[1].weight, torch.ones(32))
test_eq(mods[0].padding, (1,1))
x = torch.randn(64, 16, 8, 8)#.cuda()
test_eq(tst(x).shape, [64,32,8,8])
tst = ConvLayer(16, 32, stride=2)
test_eq(tst(x).shape, [64,32,4,4])
tst = ConvLayer(16, 32, padding=0)
test_eq(tst(x).shape, [64,32,6,6])
#With batch norm (the default norm_type), the conv layer gets no bias
assert mods[0].bias is None
#But can be overridden with `bias=True`
tst = ConvLayer(16, 32, bias=True)
assert first(tst.children()).bias is not None
#For no norm, or spectral/weight, bias is True by default
for t in [None, NormType.Spectral, NormType.Weight]:
    tst = ConvLayer(16, 32, norm_type=t)
    assert first(tst.children()).bias is not None
tst = ConvLayer(16, 32, ndim=3)
assert isinstance(list(tst.children())[0], nn.Conv3d)
tst = ConvLayer(16, 32, ndim=1, transpose=True)
assert isinstance(list(tst.children())[0], nn.ConvTranspose1d)
tst = ConvLayer(16, 32, ndim=3, act_cls=None)
mods = list(tst.children())
test_eq(len(mods), 2)
tst = ConvLayer(16, 32, ndim=3, act_cls=partial(nn.LeakyReLU, negative_slope=0.1))
mods = list(tst.children())
test_eq(len(mods), 3)
assert isinstance(mods[2], nn.LeakyReLU)
# def linear(in_features, out_features, bias=True, act_cls=None, init='auto'):
#     "Linear layer followed by optional activation, with optional auto-init"
#     res = nn.Linear(in_features, out_features, bias=bias)
#     if act_cls: act_cls = act_cls()
#     init_linear(res, act_cls, init=init)
#     if act_cls: res = nn.Sequential(res, act_cls)
#     return res
# @delegates(ConvLayer)
# def conv1d(ni, nf, ks, stride=1, ndim=1, norm_type=None, **kwargs):
#     "Convolutional layer followed by optional activation, with optional auto-init"
#     return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
# @delegates(ConvLayer)
# def conv2d(ni, nf, ks, stride=1, ndim=2, norm_type=None, **kwargs):
#     "Convolutional layer followed by optional activation, with optional auto-init"
#     return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)
# @delegates(ConvLayer)
# def conv3d(ni, nf, ks, stride=1, ndim=3, norm_type=None, **kwargs):
#     "Convolutional layer followed by optional activation, with optional auto-init"
#     return ConvLayer(ni, nf, ks, stride=stride, ndim=ndim, norm_type=norm_type, **kwargs)

AdaptiveAvgPool[source]

AdaptiveAvgPool(sz=1, ndim=2)

nn.AdaptiveAvgPool layer for ndim

MaxPool[source]

MaxPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False)

nn.MaxPool layer for ndim

AvgPool[source]

AvgPool(ks=2, stride=None, padding=0, ndim=2, ceil_mode=False)

nn.AvgPool layer for ndim
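
Quick sanity checks (a sketch): each helper instantiates the matching nn pooling module for the requested ndim.

assert isinstance(AdaptiveAvgPool(1, ndim=1), nn.AdaptiveAvgPool1d)
assert isinstance(MaxPool(ndim=3), nn.MaxPool3d)
assert isinstance(AvgPool(ks=3, stride=2), nn.AvgPool2d)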

fastai loss functions

The following class is the base class to wrap a loss function. It provides several added functionalities:

  • it flattens the tensors before trying to take the losses since it's more convenient (with a potential transpose to put the axis at the end)
  • it has a potential activation method that tells the library if there is an activation fused in the loss (useful for inference and methods such as Learner.get_preds or Learner.predict)
  • it has a potential decodes method that is used on predictions in inference (for instance, an argmax in classification)
F.binary_cross_entropy_with_logits(torch.randn(4,5), torch.randint(0, 2, (4,5)).float(), reduction='none')
tensor([[0.4444, 1.1849, 1.1411, 2.2376, 0.4800],
        [3.0970, 0.2376, 0.2159, 2.0667, 0.5246],
        [0.7885, 0.7743, 0.5355, 0.6340, 1.5417],
        [0.5340, 0.4066, 0.9115, 0.5817, 0.2920]])

class BaseLoss[source]

BaseLoss(loss_cls, *args, axis=-1, flatten=True, floatify=False, is_2d=True, **kwargs)

Same as loss_cls, but flattens input and target.

The args and kwargs will be passed to loss_cls during the initialization to instantiate a loss function. axis is put at the end for losses like softmax that are often performed on the last axis. If floatify=True the targs will be converted to float (useful for losses that only accept float targets like BCEWithLogitsLoss) and is_2d determines if we flatten while keeping the first dimension (batch size) or completely flatten the input. We want the first for losses like Cross Entropy, and the second for pretty much anything else.

class CrossEntropyLossFlat[source]

CrossEntropyLossFlat(*args, axis=-1, weight=None, ignore_index=-100, reduction='mean', flatten=True, floatify=False, is_2d=True) :: BaseLoss

Same as nn.CrossEntropyLoss, but flattens input and target.

tst = CrossEntropyLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randint(0, 10, (32,5))
#nn.CrossEntropyLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda: nn.CrossEntropyLoss()(output,target))

#Associated activation is softmax
test_eq(tst.activation(output), F.softmax(output, dim=-1))
#This loss function has a decodes which is argmax
test_eq(tst.decodes(output), output.argmax(dim=-1))
tst = CrossEntropyLossFlat(axis=1)
output = torch.randn(32, 5, 128, 128)
target = torch.randint(0, 5, (32, 128, 128))
_ = tst(output, target)

test_eq(tst.activation(output), F.softmax(output, dim=1))
test_eq(tst.decodes(output), output.argmax(dim=1))

class BCEWithLogitsLossFlat[source]

BCEWithLogitsLossFlat(*args, axis=-1, floatify=True, thresh=0.5, weight=None, reduction='mean', pos_weight=None, flatten=True, is_2d=True) :: BaseLoss

Same as nn.BCEWithLogitsLoss, but flattens input and target.

tst = BCEWithLogitsLossFlat()
output = torch.randn(32, 5, 10)
target = torch.randn(32, 5, 10)
#nn.BCEWithLogitsLoss would fail with those two tensors, but not our flattened version.
_ = tst(output, target)
test_fail(lambda: nn.BCEWithLogitsLoss()(output,target))
output = torch.randn(32, 5)
target = torch.randint(0,2,(32, 5))
#nn.BCEWithLogitsLoss would fail with int targets but not our flattened version.
_ = tst(output, target)
test_fail(lambda: nn.BCEWithLogitsLoss()(output,target))

#Associated activation is sigmoid
test_eq(tst.activation(output), torch.sigmoid(output))

BCELossFlat[source]

BCELossFlat(*args, axis=-1, floatify=True, weight=None, reduction='mean')

Same as nn.BCELoss, but flattens input and target.

tst = BCELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda: nn.BCELoss()(output,target))

MSELossFlat[source]

MSELossFlat(*args, axis=-1, floatify=True, reduction='mean')

Same as nn.MSELoss, but flattens input and target.

tst = MSELossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0,2,(32, 5, 10))
_ = tst(output, target)
test_fail(lambda: nn.MSELoss()(output,target))

L1LossFlat[source]

L1LossFlat(*args, axis=-1, floatify=True, reduction='mean')

Same as nn.L1Loss, but flattens input and target.
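
A quick sketch mirroring the MSELossFlat example: thanks to floatify=True, integer targets are converted to float before the flattened loss is computed.

tst = L1LossFlat()
output = torch.sigmoid(torch.randn(32, 5, 10))
target = torch.randint(0, 2, (32, 5, 10))
_ = tst(output, target)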

class LabelSmoothingCrossEntropy[source]

LabelSmoothingCrossEntropy(eps:float=0.1, reduction='mean') :: Module

Same as nn.Module, but no need for subclasses to call super().__init__
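
For reference, a sketch of the standard label-smoothing formulation (not quoted from the library docs): with eps the smoothing parameter and c the number of classes,

loss = (1-eps) * CE(output, target) + (eps/c) * sum of CE(output, j) over all classes j

i.e. the usual cross entropy mixed with a uniform distribution over the classes.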

On top of the formula we define:

  • a reduction attribute, that will be used when we call Learner.get_preds
  • an activation function that represents the activation fused in the loss (since we use cross entropy behind the scenes). It will be applied to the output of the model when calling Learner.get_preds or Learner.predict
  • a decodes function that converts the output of the model to a format similar to the target (here indices). This is used in Learner.predict and Learner.show_results to decode the predictions

class LabelSmoothingCrossEntropyFlat[source]

LabelSmoothingCrossEntropyFlat(*args, axis=-1, eps=0.1, reduction='mean', flatten=True, floatify=False, is_2d=True) :: BaseLoss

Same as LabelSmoothingCrossEntropy, but flattens input and target.

Embeddings

trunc_normal_[source]

trunc_normal_(x, mean=0.0, std=1.0)

Truncated normal initialization (approximation)

class Embedding[source]

Embedding(ni, nf) :: Embedding

Embedding layer with truncated normal initialization

Truncated normal initialization bounds the distribution to avoid extreme values. For a given standard deviation std, the bounds are roughly -2*std, 2*std.

tst = Embedding(10, 30)
assert tst.weight.min() > -0.02
assert tst.weight.max() < 0.02
test_close(tst.weight.mean(), 0, 1e-2)
test_close(tst.weight.std(), 0.01, 0.1)

Self attention

class SelfAttention[source]

SelfAttention(n_channels) :: Module

Self attention layer for n_channels.

Self-attention layer as introduced in Self-Attention Generative Adversarial Networks.

Initially, no change is made to the input. This is controlled by a trainable parameter named gamma as we return x + gamma * out.

tst = SelfAttention(16)
x = torch.randn(32, 16, 8, 8)
test_eq(tst(x),x)

Then, during training, gamma will probably change since it's a trainable parameter. Let's see what happens when it gets a nonzero value.

tst.gamma.data.fill_(1.)
y = tst(x)
test_eq(y.shape, [32,16,8,8])

The attention mechanism requires three matrix multiplications (here represented by 1x1 convs). The multiplications are done on the channel level (the second dimension in our tensor) and we flatten the feature map (which is 8x8 here). Following the paper, we call f, g and h the results of those multiplications.

q,k,v = tst.query[0].weight.data,tst.key[0].weight.data,tst.value[0].weight.data
test_eq([q.shape, k.shape, v.shape], [[2, 16, 1], [2, 16, 1], [16, 16, 1]])
f,g,h = map(lambda m: x.view(32, 16, 64).transpose(1,2) @ m.squeeze().t(), [q,k,v])
test_eq([f.shape, g.shape, h.shape], [[32,64,2], [32,64,2], [32,64,16]])

The key part of the attention layer is to compute attention weights for each location in the feature map (here 8x8 = 64). Those are positive numbers that sum to 1 and tell the model to pay attention to this or that part of the picture. We take the product of f and the transpose of g (to get something of size bs by 64 by 64), then apply a softmax on the first dimension (to get positive numbers that sum to 1). The result can then be multiplied with h transposed to get an output of size bs by channels by 64, which can then be viewed as an output of the same size as the original input.

The final result is then x + gamma * out as we saw before.

beta = F.softmax(torch.bmm(f, g.transpose(1,2)), dim=1)
test_eq(beta.shape, [32, 64, 64])
out = torch.bmm(h.transpose(1,2), beta)
test_eq(out.shape, [32, 16, 64])
test_close(y, x + out.view(32, 16, 8, 8), eps=1e-4)

class PooledSelfAttention2d[source]

PooledSelfAttention2d(n_channels) :: Module

Pooled self attention layer for 2d.

Self-attention layer used in the Big GAN paper.

It uses the same attention as in SelfAttention but adds a max pooling of stride 2 before computing the matrices g and h: the attention is computed against the 2x2 max-pooled feature map rather than the full feature map. There is also a final matrix product applied to the output before returning gamma * out + x.
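
A shape-only sketch: like SelfAttention, the pooled variant returns a tensor of the same shape as its input.

tst = PooledSelfAttention2d(16)
x = torch.randn(32, 16, 8, 8)
test_eq(tst(x).shape, [32, 16, 8, 8])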

class SimpleSelfAttention[source]

SimpleSelfAttention(n_in:int, ks=1, sym=False) :: Module

Same as nn.Module, but no need for subclasses to call super().__init__
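
A shape-only sketch: like the other attention layers, SimpleSelfAttention returns a tensor of the same shape as its input.

tst = SimpleSelfAttention(16)
x = torch.randn(32, 16, 8, 8)
test_eq(tst(x).shape, [32, 16, 8, 8])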

PixelShuffle

PixelShuffle was introduced in this article to avoid checkerboard artifacts when upsampling images. If we want an output with ch_out filters, we use a convolution with ch_out * (r**2) filters, where r is the upsampling factor. Then we reorganize those filters as in the picture below:

[Figure: PixelShuffle filter reorganization]

icnr_init[source]

icnr_init(x, scale=2, init=kaiming_normal_)

ICNR init of x, with scale and init function

ICNR init was introduced in this article. It suggests initializing the convolution that will be used in PixelShuffle so that each of the r**2 channels gets the same weight (so that, in the picture above, the 9 colors in a 3 by 3 window are initially the same).

tst = torch.randn(16*4, 32, 1, 1)
tst = icnr_init(tst)
for i in range(0,16*4,4):
    test_eq(tst[i],tst[i+1])
    test_eq(tst[i],tst[i+2])
    test_eq(tst[i],tst[i+3])

class PixelShuffle_ICNR[source]

PixelShuffle_ICNR(ni, nf=None, scale=2, blur=False, norm_type=<NormType.Weight: 3>, act_cls=ReLU) :: Sequential

Upsample by scale from ni filters to nf (default ni), using nn.PixelShuffle.

The convolutional layer is initialized with icnr_init and passed act_cls and norm_type (the default of weight normalization seemed to be what's best for super-resolution problems, in our experiments).

The blur option comes from Super-Resolution using Convolutional Neural Networks without Any Checkerboard Artifacts where the authors add a little bit of blur to completely get rid of checkerboard artifacts.

psfl = PixelShuffle_ICNR(16, norm_type=None) #Deactivate weight norm as it changes the weight
x = torch.randn(64, 16, 8, 8)
y = psfl(x)
test_eq(y.shape, [64, 16, 16, 16])
#ICNR init makes every 2x2 window (stride 2) have the same elements
for i in range(0,16,2):
    for j in range(0,16,2):
        test_eq(y[:,:,i,j],y[:,:,i+1,j])
        test_eq(y[:,:,i,j],y[:,:,i  ,j+1])
        test_eq(y[:,:,i,j],y[:,:,i+1,j+1])

Sequential extensions

sequential[source]

sequential(*args)

Create an nn.Sequential, wrapping items with Lambda if needed
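
A small sketch (_double is just an illustrative function): plain callables get wrapped in Lambda so they can sit inside the nn.Sequential.

def _double(x): return x*2
seq = sequential(nn.ReLU(), _double)
x = torch.randn(4, 3)
test_eq(seq(x), F.relu(x)*2)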

class SequentialEx[source]

SequentialEx(*layers) :: Module

Like nn.Sequential, but with ModuleList semantics, and can access module input

This is useful for writing layers that need to remember the input (like a resnet block) in a sequential way.

class MergeLayer[source]

MergeLayer(dense:bool=False) :: Module

Merge a shortcut with the result of the module by adding them or concatenating them if dense=True.

res_block = SequentialEx(ConvLayer(16, 16), ConvLayer(16,16))
res_block.append(MergeLayer()) # just to test append - normally it would be in init params
x = torch.randn(32, 16, 8, 8)
y = res_block(x)
test_eq(y.shape, [32, 16, 8, 8])
test_eq(y, x + res_block[1](res_block[0](x)))

Concat

Equivalent to keras.layers.Concatenate, it will concatenate the outputs of a ModuleList over a given dimension (the filter dimension by default)

class Cat[source]

Cat(layers, dim=1) :: ModuleList

Concatenate layers outputs over a given dim

layers = [ConvLayer(2,4), ConvLayer(2,4), ConvLayer(2,4)] 
x = torch.rand(1,2,8,8) 
cat = Cat(layers) 
test_eq(cat(x).shape, [1,12,8,8]) 
test_eq(cat(x), torch.cat([l(x) for l in layers], dim=1))

Ready-to-go models

class SimpleCNN[source]

SimpleCNN(filters, kernel_szs=None, strides=None, bn=True) :: Sequential

Create a simple CNN with filters.

The model is a succession of convolutional layers from (filters[0],filters[1]) to (filters[n-2],filters[n-1]) (if n is the length of the filters list) followed by a PoolFlatten. kernel_szs and strides default to a list of 3s and a list of 2s. If bn=True the convolutional layers are successions of conv-relu-batchnorm, otherwise conv-relu.

tst = SimpleCNN([8,16,32])
mods = list(tst.children())
test_eq(len(mods), 3)
test_eq([[m[0].in_channels, m[0].out_channels] for m in mods[:2]], [[8,16], [16,32]])

Test kernel sizes

tst = SimpleCNN([8,16,32], kernel_szs=[1,3])
mods = list(tst.children())
test_eq([m[0].kernel_size for m in mods[:2]], [(1,1), (3,3)])

Test strides

tst = SimpleCNN([8,16,32], strides=[1,2])
mods = list(tst.children())
test_eq([m[0].stride for m in mods[:2]], [(1,1),(2,2)])

class ProdLayer[source]

ProdLayer() :: Module

Merge a shortcut with the result of the module by multiplying them.

SEModule[source]

SEModule(ch, reduction, act_cls=ReLU)
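
A shape-only sketch of the squeeze-and-excitation module: it rescales channels, so the output keeps the input shape.

tst = SEModule(16, reduction=4)
x = torch.randn(8, 16, 8, 8)
test_eq(tst(x).shape, [8, 16, 8, 8])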

class ResBlock[source]

ResBlock(expansion, ni, nf, stride=1, groups=1, reduction=None, nh1=None, nh2=None, dw=False, g2=1, sa=False, sym=False, norm_type=<NormType.Batch: 1>, act_cls=ReLU, ndim=2, ks=3, pool=AvgPool, pool_first=True, padding=None, bias=None, bn_1st=True, transpose=False, init='auto', xtra=None, bias_std=0.01, dilation:Union[int, Tuple[int, int]]=1, padding_mode:str='zeros') :: Module

ResNet block from ni to nf with stride

This is a resnet block (normal or bottleneck depending on expansion, 1 for the normal block and 4 for the traditional bottleneck) that implements the tweaks from Bag of Tricks for Image Classification with Convolutional Neural Networks. In particular, the last batchnorm layer (if that is the selected norm_type) is initialized with a weight (or gamma) of zero to facilitate the flow from the beginning to the end of the network. It also implements optional Squeeze and Excitation and grouped convs for ResNeXT and similar models (use dw=True for depthwise convs).

The kwargs are passed to ConvLayer along with norm_type.
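
A shape sketch: with expansion=1, ResBlock(1, ni, nf) maps ni channels to nf, and stride=2 halves the spatial size (the shortcut is pooled and projected to match).

x = torch.randn(8, 16, 8, 8)
tst = ResBlock(1, 16, 16)
test_eq(tst(x).shape, [8, 16, 8, 8])
tst = ResBlock(1, 16, 32, stride=2)
test_eq(tst(x).shape, [8, 32, 4, 4])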

SEBlock[source]

SEBlock(expansion, ni, nf, groups=1, reduction=16, stride=1, **kwargs)

SEResNeXtBlock[source]

SEResNeXtBlock(expansion, ni, nf, groups=32, reduction=16, stride=1, base_width=4, **kwargs)

SeparableBlock[source]

SeparableBlock(expansion, ni, nf, reduction=16, stride=1, base_width=4, **kwargs)

Swish and Mish

swish[source]

swish(x, inplace=False)

class Swish[source]

Swish() :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

class MishJitAutoFn[source]

MishJitAutoFn() :: Function

Records operation history and defines formulas for differentiating ops.

See the Note on extending the autograd engine for more details on how to use this class: https://pytorch.org/docs/stable/notes/extending.html#extending-torch-autograd

Every operation performed on Tensors creates a new function object that performs the computation and records that it happened. The history is retained in the form of a DAG of functions, with edges denoting data dependencies (input <- output). Then, when backward is called, the graph is processed in the topological ordering, by calling the backward methods of each Function object, and passing returned gradients on to the next Functions.

Normally, the only way users interact with functions is by creating subclasses and defining new operations. This is a recommended way of extending torch.autograd.

Examples::

>>> class Exp(Function):
>>>
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> #Use it by calling the apply method:
>>> output = Exp.apply(input)

mish[source]

mish(x)

class Mish[source]

Mish() :: Module

Same as nn.Module, but no need for subclasses to call super().__init__
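
A sketch relating the two activations to their closed forms, swish(x) = x * sigmoid(x) and mish(x) = x * tanh(softplus(x)):

x = torch.randn(4, 5)
test_close(swish(x), x * torch.sigmoid(x), eps=1e-5)
test_close(mish(x), x * torch.tanh(F.softplus(x)), eps=1e-5)
test_eq(Mish()(x).shape, [4, 5])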

Helper functions for submodules

It's easy to get the list of all parameters of a given model. For when you want all submodules (like linear/conv layers) without forgetting lone parameters, the following class wraps those in fake modules.

class ParameterModule[source]

ParameterModule(p) :: Module

Register a lone parameter p in a module.

children_and_parameters[source]

children_and_parameters(m)

Return the children of m and its direct parameters not registered in modules.

class TstModule(Module):
    def __init__(self): self.a,self.lin = nn.Parameter(torch.randn(1)),nn.Linear(5,10)

tst = TstModule()
children = children_and_parameters(tst)
test_eq(len(children), 2)
test_eq(children[0], tst.lin)
assert isinstance(children[1], ParameterModule)
test_eq(children[1].val, tst.a)
class A(Module): pass
assert not A().has_children
assert TstModule().has_children

flatten_model[source]

flatten_model(m)

Return the list of all submodules and parameters of m

tst = nn.Sequential(TstModule(), TstModule())
children = flatten_model(tst)
test_eq(len(children), 4)
assert isinstance(children[1], ParameterModule)
assert isinstance(children[3], ParameterModule)

class NoneReduce[source]

NoneReduce(loss_func)

A context manager to evaluate loss_func with none reduce.

x,y = torch.randn(5),torch.randn(5)
loss_fn = nn.MSELoss()
with NoneReduce(loss_fn) as loss_func:
    loss = loss_func(x,y)
test_eq(loss.shape, [5])
test_eq(loss_fn.reduction, 'mean')

loss_fn = F.mse_loss
with NoneReduce(loss_fn) as loss_func:
    loss = loss_func(x,y)
test_eq(loss.shape, [5])
test_eq(loss_fn, F.mse_loss)

in_channels[source]

in_channels(m)

Return the shape of the first weight layer in m.

test_eq(in_channels(nn.Sequential(nn.Conv2d(5,4,3), nn.Conv2d(4,3,3))), 5)
test_eq(in_channels(nn.Sequential(nn.AvgPool2d(4), nn.Conv2d(4,3,3))), 4)
test_eq(in_channels(nn.Sequential(BatchNorm(4), nn.Conv2d(4,3,3))), 4)
test_eq(in_channels(nn.Sequential(InstanceNorm(4), nn.Conv2d(4,3,3))), 4)
test_eq(in_channels(nn.Sequential(InstanceNorm(4, affine=False), nn.Conv2d(4,3,3))), 4)
test_fail(lambda : in_channels(nn.Sequential(nn.AvgPool2d(4))))