learn = synth_learner()
learn.fit(10, lr=100, cbs=TerminateOnNaNCallback())
assert len(learn.recorder.losses) < 10 * len(learn.dls.train)
for l in learn.recorder.losses:
    assert not torch.isinf(l) and not torch.isnan(l)
When implementing a Callback whose behavior depends on the best value of a metric or loss, subclass this Callback and use its best (the best value so far) and new_best (there was a new best value this epoch) attributes.
comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to improve on the current best (depending on comp) by at least that amount.
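The best/new_best bookkeeping and the comp/min_delta rule can be sketched in plain Python. This is an illustrative stand-in, not fastai's actual TrackerCallback code; the class name BestTracker is invented for the example:

```python
import numpy as np

# Illustrative sketch of the tracking rule described above (not fastai's code).
# comp defaults to np.less when 'loss' is in the monitored name, np.greater
# otherwise; min_delta shifts the current best so a candidate must improve on
# it by at least that margin before counting as a new best.
class BestTracker:
    def __init__(self, monitor='valid_loss', comp=None, min_delta=0.):
        if comp is None:
            comp = np.less if 'loss' in monitor else np.greater
        self.comp = comp
        # negate min_delta for np.less so candidates must fall *below* best - min_delta
        self.min_delta = min_delta if comp == np.greater else -min_delta
        self.best = float('inf') if comp == np.less else -float('inf')
        self.new_best = False

    def update(self, val):
        self.new_best = bool(self.comp(val, self.best + self.min_delta))
        if self.new_best:
            self.best = val
        return self.new_best

t = BestTracker(monitor='valid_loss', min_delta=0.1)
print(t.update(1.0))   # True: anything beats the initial inf
print(t.update(0.95))  # False: not below 1.0 - 0.1
print(t.update(0.85))  # True: 0.85 < 0.9
```

A subclass of the real TrackerCallback gets this behavior for free and can simply check self.new_best after each epoch.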
comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to improve on the current best (depending on comp) by at least that amount. patience is the number of epochs you're willing to wait without improvement.
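The patience rule can be sketched in plain Python. This is an illustrative stand-in, not fastai's EarlyStoppingCallback code; the function name early_stopping_epochs is invented for the example:

```python
# Illustrative sketch of the early-stopping rule described above: stop once
# `patience` consecutive epochs pass without the monitored value improving on
# the best by at least min_delta (assuming a loss, so lower is better).
def early_stopping_epochs(values, min_delta=0.1, patience=2):
    best, wait = float('inf'), 0
    for epoch, v in enumerate(values):
        if v < best - min_delta:
            best, wait = v, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch + 1  # number of epochs actually run
    return len(values)

# the loss barely moves, so training halts once patience runs out
print(early_stopping_epochs([1.0, 0.99, 0.98, 0.97, 0.5]))  # 3
```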
learn = synth_learner(n_trn=2, metrics=F.mse_loss)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='mse_loss', min_delta=0.1, patience=2))
learn.validate()
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=200, lr=1e-7, cbs=EarlyStoppingCallback(monitor='valid_loss', min_delta=0.1, patience=2))
comp is the comparison operator used to determine if a value is better than another (defaults to np.less if 'loss' is in the name passed in monitor, np.greater otherwise) and min_delta is an optional float that requires a new value to improve on the current best (depending on comp) by at least that amount. The model is saved in learn.path/learn.model_dir/name.pth, either at every epoch if every_epoch is set, or at each improvement of the monitored quantity.
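The save-path convention above can be sketched as follows. This is an illustrative helper, not fastai's implementation; the function name model_save_path is invented for the example, and with every_epoch the epoch number is appended to the file stem, matching the model_0.pth, model_1.pth files asserted below:

```python
from pathlib import Path

# Illustrative sketch of the save-path convention described above:
# models go to path/model_dir/fname.pth; with every_epoch the epoch
# number is appended to the stem.
def model_save_path(path, model_dir, fname, epoch=None, every_epoch=False):
    stem = f'{fname}_{epoch}' if every_epoch else fname
    return Path(path)/model_dir/f'{stem}.pth'

print(model_save_path('tmp', 'models', 'model'))                             # e.g. tmp/models/model.pth
print(model_save_path('tmp', 'models', 'model', epoch=0, every_epoch=True))  # e.g. tmp/models/model_0.pth
```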
learn = synth_learner(n_trn=2, path=Path.cwd()/'tmp')
learn.fit(n_epoch=2, cbs=SaveModelCallback())
assert (Path.cwd()/'tmp/models/model.pth').exists()
learn.fit(n_epoch=2, cbs=SaveModelCallback(every_epoch=True))
for i in range(2): assert (Path.cwd()/f'tmp/models/model_{i}.pth').exists()
shutil.rmtree(Path.cwd()/'tmp')
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=4, lr=1e-7, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2))
learn = synth_learner(n_trn=2)
learn.fit(n_epoch=6, lr=5e-8, cbs=ReduceLROnPlateau(monitor='valid_loss', min_delta=0.1, patience=2, min_lr=1e-8))