Patch the parallel models so they work with RNNs
Convenience functions to set up/tear down torch distributed data parallel mode.
We need to change the dataloaders so that they only get one part of the batch each (otherwise there is no point in using distributed training).
```python
import torch
from fastai.data.all import *
from fastai.distributed import *
from fastcore.test import test_eq

dl = TfmdDL(list(range(50)), bs=16, num_workers=2)
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    # Each of the 4 ranks sees every 4th item; the 50 items are padded to
    # 52 (13 per rank), so the final indices wrap around modulo 50.
    test_eq(list(dl1)[0], torch.arange(i, 52, 4)%50)
```
```python
dl = TfmdDL(list(range(50)), bs=16, num_workers=2, shuffle=True)
res = []
for i in range(4):
    dl1 = DistributedDL.from_dl(dl, i, 4)
    dl1.set_epoch(0)  # seeds the shuffle identically across all ranks
    res += list(dl1)[0].tolist()
# All items should only be accessed once (except 0 and 1 for the final
# padding cycle) with the seeded shuffle
test_eq(sorted(res), [0,0,1,1] + list(range(2, 50)))
```
Attach or remove a callback which adapts the model to use DistributedDL to train in distributed data parallel mode.
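For example, if you want to manage the callback yourself instead of using the context manager below, a minimal sketch (assuming the patched Learner.to_distributed/detach_distributed helpers described here, and a learn object already built):

```python
# Sketch: manually attach/detach the distributed training callback.
learn = learn.to_distributed(cuda_id)   # adds the DistributedTrainer callback
learn.fit(1)
learn = learn.detach_distributed()      # removes it again after training
```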
distrib_ctx
distrib_ctx(cuda_id) is a context manager that prepares a learner to train in distributed data parallel mode. It assumes the distributed process-group environment variables (such as RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT) have all been set up properly, for example by scripts launched with python -m fastai.launch.
Typical usage:

```python
with learn.distrib_ctx(): learn.fit(.....)
```
It attaches a DistributedTrainer callback and DistributedDL data loader to the learner, then executes learn.fit(.....). Upon exiting the context, it removes the DistributedTrainer and DistributedDL, and destroys any locally created distributed process group. The process is still attached to the GPU though.
rank0_first(f) calls f() in the rank-0 process first, then in parallel on the rest, in distributed training mode. In single-process, non-distributed training mode, f() is called only once, as expected.
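Under the hood this is a simple barrier pattern. Here is a minimal illustration using raw torch.distributed (a sketch of the idea, not fastai's actual implementation):

```python
import torch.distributed as dist

def rank0_first_sketch(f):
    # Illustration only, not fastai's implementation.
    if not (dist.is_available() and dist.is_initialized()):
        return f()               # single-process mode: just call f once
    if dist.get_rank() == 0:
        res = f()                # rank 0 runs first...
    dist.barrier()               # ...while every other rank waits here
    if dist.get_rank() != 0:
        res = f()                # then the remaining ranks run in parallel
    return res
```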
One application of rank0_first() is to make fresh downloads via untar_data() safe in distributed training scripts launched by python -m fastai.launch <script>:
```python
path = untar_data(URLs.IMDB)
```

becomes:

```python
path = rank0_first(lambda: untar_data(URLs.IMDB))
```
Some learner factory methods may use untar_data() to download pretrained models by default:

```python
learn = text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy)
```

becomes:

```python
learn = rank0_first(lambda: text_classifier_learner(dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))
```
Otherwise, multiple processes will download at the same time and corrupt the data.
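Putting the pieces together, a minimal sketch of a launchable training script (a hypothetical train.py; the dataset and hyperparameters are placeholders):

```python
# Hypothetical train.py, launched as: python -m fastai.launch train.py
from fastai.text.all import *
from fastai.distributed import *

# Rank 0 downloads first so the other processes don't corrupt the data
path = rank0_first(lambda: untar_data(URLs.IMDB))
dls = TextDataLoaders.from_folder(path, valid='test')
learn = rank0_first(lambda: text_classifier_learner(
    dls, AWD_LSTM, drop_mult=0.5, metrics=accuracy))

# DistributedTrainer and DistributedDL stay attached for the duration of fit
with learn.distrib_ctx():
    learn.fit_one_cycle(1, 2e-2)
```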