Callbacks which work with a learner's data
You don't normally need to use this callback, because fastai's `DataLoader`
will handle passing data to a device for you. However, if you already have a plain PyTorch `DataLoader` and can't change it for some reason, you can use this callback instead.
# Smoke test for CudaCallback: build a synthetic learner with the callback
# attached (passed as a class via `cbs`).
learn = synth_learner(cbs=CudaCallback)
# Bare expression — presumably here so a notebook cell displays the model.
learn.model
# Run one round of fitting so the callback gets a chance to act.
learn.fit(1)
# After fitting, the model's first parameter must sit on a CUDA device.
# NOTE(review): this check can only pass on a machine with a GPU — confirm
# the cell is skipped on CPU-only runners.
test_eq(next(learn.model.parameters()).device.type, 'cuda')
# Demo of `weighted_dataloaders`: the items are the floats 0..n-1 and item i
# is given weight i via `wgts=range(n)`.
n = 160
dsets = Datasets(torch.arange(n, dtype=torch.float32))
dls = dsets.weighted_dataloaders(wgts=range(n), bs=16)

# CollectDataCallback exposes what it gathered during fitting as
# `learn.collect_data.data`.
learn = synth_learner(data=dls, cbs=CollectDataCallback)
learn.fit(1)

# Take element [0][0] of every collected entry and join them into one
# collection, then plot how often each value was drawn.
# NOTE(review): presumably larger values dominate the histogram because they
# carry larger weights — confirm against the rendered plot.
t = concat(*learn.collect_data.data.itemgot(0, 0))
plt.hist(t);
# Demo of `partial_dataloaders`: request 32 items per epoch in batches of 16.
# Relies on `dsets` built earlier in the notebook.
dls = dsets.partial_dataloaders(partial_n=32, bs=16)
# 32 items / 16 per batch -> the first loader should yield exactly 2 batches.
assert len(dls[0])==2
# Every batch drawn from the first loader must be a full batch of 16.
for batch in dls[0]:
    assert len(batch[0])==16