learn = synth_learner()
# by default, ShortEpochCallback cuts both the training and validation phases
# short after a small fraction of the batches (pct=0.01)
learn.fit(1, cbs=ShortEpochCallback())
learn = synth_learner()
# with short_valid=False, only training is cut short; validation runs in full
learn.fit(1, cbs=ShortEpochCallback(short_valid=False))
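Under the hood, the callback simply cancels the current phase once the given fraction of batches has run. A minimal sketch of the idea (the class name here is made up for illustration; see the fastai source for the actual implementation):

from fastai.callback.core import Callback, CancelTrainException, CancelValidException

class ShortEpochSketch(Callback):
    "Illustrative sketch: stop each phase after pct of its batches"
    def __init__(self, pct=0.01, short_valid=True):
        self.pct, self.short_valid = pct, short_valid
    def after_batch(self):
        if self.iter / self.n_iter < self.pct: return      # keep running
        if self.training: raise CancelTrainException()     # cut training short
        if self.short_valid: raise CancelValidException()  # cut validation short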
learn = synth_learner()
# accumulate gradients over two batches' worth of samples before each optimizer step
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=2*learn.dls.bs))
# ensure train_loss decreased
assert learn.recorder.values[-1][0] < learn.recorder.values[0][0]
# n_acc=1e6 samples are never accumulated on this small dataset, so the
# optimizer never steps and the weights never change during this fit
learn.fit(2, lr=0.01, cbs=GradientAccumulation(n_acc=1e6))
# ensure valid_loss didn't change (same weights)
assert learn.recorder.values[-1][1] == learn.recorder.values[0][1]
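To see why the weights never change with such a large n_acc: gradient accumulation sums gradients over several mini-batches and only lets the optimizer step once n_acc samples have been processed. A self-contained plain-PyTorch sketch of the technique (illustrative, not fastai's implementation):

import torch
import torch.nn as nn

model = nn.Linear(1, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
loss_func = nn.MSELoss()
xs, ys = torch.randn(64, 1), torch.randn(64, 1)

n_acc, count = 32, 0                             # step once per 32 samples
for xb, yb in zip(xs.split(8), ys.split(8)):     # mini-batches of 8
    # scale the loss so the summed gradient matches one big batch of n_acc
    loss = loss_func(model(xb), yb) / (n_acc / xb.shape[0])
    loss.backward()                              # grads accumulate in .grad
    count += xb.shape[0]
    if count >= n_acc:                           # never true if n_acc is huge,
        opt.step()                               # so the weights never update
        opt.zero_grad()
        count = 0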
BnFreeze is useful when you'd like to train two separate models that have a common feature extractor/body. The only part of the model that's different is the head that you attach for transfer learning. Learner.freeze() doesn't suffice here, as the BatchNorm layers are trainable by default and the running mean and std of batches are tracked. For the feature extractors to fully match, you need to set train_bn=False, and these stats need to be frozen as well, which is precisely the function of BnFreeze.
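Conceptually, freezing these statistics amounts to putting every BatchNorm layer in eval mode whenever training starts, so the forward pass uses (and no longer updates) the stored running mean and std. A minimal plain-PyTorch sketch of that effect (the helper name is hypothetical; fastai's BnFreeze handles this for you):

import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm  # base class of BatchNorm1d/2d/3d

def freeze_bn_stats(model: nn.Module) -> None:
    "Hypothetical helper: keep BatchNorm running stats fixed during training"
    for module in model.modules():
        if isinstance(module, _BatchNorm):
            module.eval()  # eval mode stops updates to running_mean/running_var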
from fastai.vision.all import *
path = untar_data(URLs.MNIST_TINY)
dls = ImageDataLoaders.from_folder(path, valid_pct=0.2)
We first demonstrate the mismatch of the running stats when using only train_bn=False, by creating a Learner...:
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False)
...and grab the first BatchNorm layer, and store its running mean:
m = learn1.model[0][1].running_mean.clone()
You can see that the running mean has now changed after training:
learn1.fit(1, lr=0.02)
test_ne(learn1.model[0][1].running_mean, m)
When we use the BnFreeze callback, the running statistics will not be changed during training. This is often important for getting good results from transfer learning.
learn1 = cnn_learner(deepcopy(dls), resnet18, pretrained=True, train_bn=False, cbs=BnFreeze)
m = learn1.model[0][1].running_mean.clone()
learn1.fit(1, lr=0.02)
test_eq(learn1.model[0][1].running_mean, m)