Callbacks that make decisions depending on how a monitored metric/loss behaves

class TerminateOnNaNCallback[source]

TerminateOnNaNCallback(before_fit=None, before_epoch=None, before_train=None, before_batch=None, after_pred=None, after_loss=None, before_backward=None, after_backward=None, after_step=None, after_cancel_batch=None, after_batch=None, after_cancel_train=None, after_train=None, before_validate=None, after_cancel_validate=None, after_validate=None, after_cancel_epoch=None, after_epoch=None, after_cancel_fit=None, after_fit=None) :: Callback

A Callback that terminates training if the loss is NaN.

learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())
epoch train_loss valid_loss time
0 733056513731005412711487968341131264.000000 00:00
assert len(learn.recorder.losses) < 10 * len(learn.dls.train)
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l) 

class TrackerCallback[source]

TrackerCallback(monitor='valid_loss', comp=None, min_delta=0.0) :: Callback

A Callback that keeps track of the best value in monitor.

When implementing a Callback whose behavior depends on the best value of a metric or loss, subclass this Callback and use its best (the best value so far) and new_best (True if there was a new best value this epoch) attributes.

comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to beat the current best (in the direction given by comp) by at least that amount.
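To make the comp/min_delta semantics concrete, here is a minimal, self-contained sketch of the tracking logic in plain Python (illustrative only; the Tracker class and its update method are hypothetical helpers, not fastai's API):

```python
import numpy as np

class Tracker:
    "Minimal sketch of TrackerCallback's comparison logic (not fastai's actual class)."
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0.0):
        if comp is None:
            comp = np.less if 'loss' in monitor or 'error' in monitor else np.greater
        if comp == np.less: min_delta *= -1  # flip sign so one test works in both directions
        self.comp, self.min_delta = comp, min_delta
        self.best = float('inf') if comp == np.less else -float('inf')
        self.new_best = False

    def update(self, val):
        "Set `new_best` if `val` beats `best` by at least `min_delta`, and record it."
        self.new_best = bool(self.comp(val - self.min_delta, self.best))
        if self.new_best: self.best = val
        return self.new_best

t = Tracker(min_delta=0.1)
t.update(24.20)  # first value is always a new best
t.update(24.15)  # lower, but not by min_delta=0.1, so new_best is False
```

Note that with min_delta=0.1 a loss must drop by more than 0.1 below the current best to count as an improvement; with a metric like accuracy it must rise by that amount.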

class EarlyStoppingCallback[source]

EarlyStoppingCallback(monitor='valid_loss', comp=None, min_delta=0.0, patience=1) :: TrackerCallback

A TrackerCallback that terminates training when the monitored quantity stops improving.

comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to beat the current best (in the direction given by comp) by at least that amount. patience is the number of epochs you're willing to wait without improvement before stopping.
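The interaction of min_delta and patience can be sketched in plain Python (an illustrative EarlyStopper helper, not fastai's implementation; it assumes a loss, where lower is better):

```python
class EarlyStopper:
    "Stop after `patience` epochs without the monitored value improving by `min_delta`."
    def __init__(self, min_delta=0.0, patience=1):
        self.best, self.wait = float('inf'), 0  # assumes lower is better
        self.min_delta, self.patience = min_delta, patience

    def step(self, val):
        "Call once per epoch; returns True when training should stop."
        if val < self.best - self.min_delta:
            self.best, self.wait = val, 0
        else:
            self.wait += 1
        return self.wait >= self.patience
```

This mirrors the run below: none of the later valid losses beat epoch 0's value by min_delta=0.1, so training stops once patience=2 epochs have passed without improvement.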

learn = synth_learner(n_trn=2, metrics=F.mse_loss)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))
epoch train_loss valid_loss mse_loss time
0 19.993200 24.202908 24.202908 00:00
1 20.007574 24.202845 24.202845 00:00
2 20.021687 24.202751 24.202751 00:00
No improvement since epoch 0: early stopping
learn.validate()
(#2) [24.20275115966797,24.20275115966797]
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))
epoch train_loss valid_loss time
0 12.963860 10.800257 00:00
1 12.936502 10.800226 00:00
2 12.926699 10.800186 00:00
No improvement since epoch 0: early stopping

class SaveModelCallback[source]

SaveModelCallback(monitor='valid_loss', comp=None, min_delta=0.0, fname='model', every_epoch=False, with_opt=False) :: TrackerCallback

A TrackerCallback that saves the model's best version during training and loads it at the end.

comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to beat the current best (in the direction given by comp) by at least that amount. The model will be saved in learn.path/learn.model_dir/fname.pth, either every epoch (if every_epoch=True) or at each improvement of the monitored quantity.

learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
learn.fit(n_epoch=2, cbs=SaveModelCallback())
assert (Path.cwd()/'tmp/models/model.pth').exists()
learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()
shutil.rmtree(Path.cwd()/'tmp')
epoch train_loss valid_loss time
0 22.050220 19.015476 00:00
1 21.870991 18.548334 00:00
Better model found at epoch 0 with valid_loss value: 19.01547622680664.
Better model found at epoch 1 with valid_loss value: 18.5483341217041.
epoch train_loss valid_loss time
0 21.132572 17.913414 00:00
1 20.745022 17.148394 00:00


class ReduceLROnPlateau[source]

ReduceLROnPlateau(monitor='valid_loss', comp=None, min_delta=0.0, patience=1, factor=10.0, min_lr=0) :: TrackerCallback

A TrackerCallback that reduces the learning rate when a metric has stopped improving.

learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))
epoch train_loss valid_loss time
0 19.359529 19.566990 00:00
1 19.375505 19.566935 00:00
2 19.362509 19.566856 00:00
3 19.386513 19.566845 00:00
Epoch 2: reducing lr to 1e-08
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))
epoch train_loss valid_loss time
0 11.854585 8.217508 00:00
1 11.875463 8.217495 00:00
2 11.885604 8.217478 00:00
3 11.876034 8.217473 00:00
4 11.872295 8.217467 00:00
5 11.874965 8.217461 00:00
Epoch 2: reducing lr to 1e-08