In this tutorial, we will be looking at the training script of timm. There are various features that timm has to offer and some of them have been listed below:

  1. Auto Augmentation paper
  2. Augmix
  3. Distributed Training on multiple GPUs
  4. Mixed precision training
  5. Auxiliary Batch Norm for AdvProp paper
  6. Synchronized Batch Norm
  7. Mixup and Cutmix with an ability to switch between the two & also turn-off augmentation at a certain epoch

timm also supports multiple optimizers & schedulers. In this tutorial we will be only be looking at the above 7 features and look at how you could utilize timm to use these features for your own experiments on a custom dataset.

As part of this tutorial, we will first start out with a general introduction to the training script and look at the various key steps that occur inside this script at a high-level. Then, we will look at some of the details of the above 7 features to get a further understanding of train.py.

Training args

The training script in timm can accept ~100 arguments. You can find more about these by running python train.py --help. These arguments are to define Dataset/Model parameters, Optimizer parameters, Learnining Rate scheduler parameters, Augmentation and regularization, Batch Norm parameters, Model exponential moving average parameters, and some miscellaneaous parameters such as --seed, --tta etc.

As part of this tutorial, we will be looking at how the training script makes use of these parameters from a high-level view. This could be beneficial for you to able to run your own experiments on ImageNet or any other custom dataset using timm.

Required args

The only argument required by timm training script is the path to the training data such as ImageNet which is structured in the following way:

imagenette2-320
├── train
│   ├── n01440764
│   ├── n02102040
│   ├── n02979186
│   ├── n03000684
│   ├── n03028079
│   ├── n03394916
│   ├── n03417042
│   ├── n03425413
│   ├── n03445777
│   └── n03888257
└── val
    ├── n01440764
    ├── n02102040
    ├── n02979186
    ├── n03000684
    ├── n03028079
    ├── n03394916
    ├── n03417042
    ├── n03425413
    ├── n03445777
    └── n03888257

So to start training on this imagenette2-320 we could just do something like python train.py <path_to_imagenette2-320_dataset>.

Default args

The various default args, in the training script are setup for you and what get's passed to the training script looks something like this:

Namespace(aa=None, amp=False, apex_amp=False, aug_splits=0, batch_size=32, bn_eps=None, bn_momentum=None, bn_tf=False, channels_last=False, clip_grad=None, color_jitter=0.4, cooldown_epochs=10, crop_pct=None, cutmix=0.0, cutmix_minmax=None, data_dir='../imagenette2-320', dataset='', decay_epochs=30, decay_rate=0.1, dist_bn='', drop=0.0, drop_block=None, drop_connect=None, drop_path=None, epochs=200, eval_metric='top1', gp=None, hflip=0.5, img_size=None, initial_checkpoint='', input_size=None, interpolation='', jsd=False, local_rank=0, log_interval=50, lr=0.01, lr_cycle_limit=1, lr_cycle_mul=1.0, lr_noise=None, lr_noise_pct=0.67, lr_noise_std=1.0, mean=None, min_lr=1e-05, mixup=0.0, mixup_mode='batch', mixup_off_epoch=0, mixup_prob=1.0, mixup_switch_prob=0.5, model='resnet101', model_ema=False, model_ema_decay=0.9998, model_ema_force_cpu=False, momentum=0.9, native_amp=False, no_aug=False, no_prefetcher=False, no_resume_opt=False, num_classes=None, opt='sgd', opt_betas=None, opt_eps=None, output='', patience_epochs=10, pin_mem=False, pretrained=False, ratio=[0.75, 1.3333333333333333], recount=1, recovery_interval=0, remode='const', reprob=0.0, resplit=False, resume='', save_images=False, scale=[0.08, 1.0], sched='step', seed=42, smoothing=0.1, split_bn=False, start_epoch=None, std=None, sync_bn=False, torchscript=False, train_interpolation='random', train_split='train', tta=0, use_multi_epochs_loader=False, val_split='validation', validation_batch_size_multiplier=1, vflip=0.0, warmup_epochs=3, warmup_lr=0.0001, weight_decay=0.0001, workers=4)

Notice, that args is a Namespace which means we can set more along the way if needed by doing something like args.new_variable="some_value".

To get a one-line introduction of these various arguments, we can just do something like python train.py --help.

The training script in 20 steps

In this section we will look at the various steps from a high level perspective that occur inside the training script. These steps have been outlined below in the correct order:

  1. Setup up distributed training parameters if args.distributed is True.
  2. Setup manual seed for reproducible results.
  3. Create Model: Create the model to train using timm.create_model function.
  4. Setup data config based on model's default config. In general the default config of the model looks something like:
    {'url': '', 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), 'crop_pct': 0.875, 'interpolation': 'bicubic', 'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225), 'first_conv': 'conv1', 'classifier': 'fc'}
    
  5. Setup augmentation batch splits and if the number of augmentation batch splits is more than 1, and if so, convert all model BatchNormlayers to Split Batch Normalization layers.
  6. If we are using multiple GPUs for training, then setup either apex syncBN or PyTorch native SyncBatchNorm to set up Synchronized Batch Normalization. This means that rather than normalizing the data on each individual GPU, we normalize the whole batch at one spread across multiple GPUs.

  7. Make model exportable using torch.jit if requested.

  8. Initialize optimizer based on arguments passed to the training script.
  9. Setup mixed Precision - either using apex.amp or using native torch amp - torch.cuda.amp.autocast.
  10. Load model weights if resuming from a model checkpoint.
  11. Setup exponential moving average of model weights. This is similar to Stochastic Weight Averaging.
  12. Setup distributed training based on parameters from step-1.
  13. Setup learning rate scheduler.
  14. Create training and validation dataset.
  15. Setup Mixup/Cutmix data augmentation.
  16. Convert training dataset to `AugmixDataset` if number of augmentation batch splits from step-5 is greater than 1.
  17. Create training data loader and Validation dataloader. 18. Setup Loss function.
  18. Setup model checkpointing and evaluation metrics.
  19. Train and Validate the model and also store the eval metrics to an output file.

Some key timm features

Auto-Augment

To enable auto augmentation during training -

python train.py ./imagenette2-320 --aa 'v0'

Augmix

A brief introduction about augmix has been presented here. To enable augmix during training, simply do:

python train.py ./imagenette2-320 --aug-splits 3 --jsd

timm also supports augmix with RandAugment and AutoAugment like so:

python train.py ./imagenette2-320 --aug-splits 3 --jsd --aa rand-m9-mstd0.5-inc1

Distributed Training on multiple GPUs

To train models on multiple GPUs, simply replace python train.py with ./distributed_train.sh <num-gpus> like so:

./distributed_train.sh 4 ./imagenette2-320 --aug-splits 3 --jsd

This trains the model using AugMix data augmentation on 4 GPUs.

Mixed precision training

To enable mixed precision training, simply add the --amp flag. timm will automatically implement mixed precision training either using apex or PyTorch Native mixed precision training.

python train.py ../imagenette2-320 --aug-splits 3 --jsd --amp

Auxiliary Batch Norm/ SplitBatchNorm

From the paper,

markdown
Batch normalization serves as an essential component for many state-of-the-art computer vision models. Specifically, BN normalizes input features by the mean and variance computed within each mini-batch. **One intrinsic assumption of utilizing BN is that the input features should come from a single or similar distributions.** This normalization behavior could be problematic if the mini-batch contains data from different distributions, there- fore resulting in inaccurate statistics estimation.

To disentangle this mixture distribution into two simpler ones respectively for the clean and adversarial images, we hereby propose an auxiliary BN to guarantee its normalization statistics are exclusively preformed on the adversarial examples.

To enable split batch norm,

python train.py ./imagenette2-320 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --split-bn

Using the above command, timm now has separate batch normalization layer for each augmentation split.

Synchronized Batch Norm

Synchronized batch norm is only used when training on multiple GPUs. From papers with code:

markdown
Synchronized Batch Normalization (SyncBN) is a type of batch normalization used for multi-GPU training. Standard batch normalization only normalizes the data within each device (GPU). SyncBN normalizes the input within the whole mini-batch.

To enable, simply add --sync-bn flag like so:

./distributed_train.sh 4 ../imagenette2-320 --aug-splits 3 --jsd --sync-bn

Mixup and Cutmix

To enable either mixup or cutmix, simply add the --mixup or --cutmix flag with probability of applying the augmentation.

For example, to enable mixup,

train.py ../imagenette2-320 --mixup 0.5

Or for Cutmix,

train.py ../imagenette2-320 --cutmix 0.5

It is also possible to enable both,

python train.py ../imagenette2-320 --mixup 0.5 --cutmix 0.5 --mixup-switch-prob 0.3

The above command will use either Mixup or Cutmix as data augmentation techniques and apply it to the batch with 50% probability. It will also switch between the two with 30% probability.

There is also a parameter to turn off Mixup/Cutmix augmentation at a certail epoch:

python train.py ../imagenette2-320 --mixup 0.5 --cutmix 0.5 --mixup-switch-prob 0.3 --mixup-off-epoch 10

The above command only applies the Mixup/Cutmix data augmentation for the first 10 epochs.