from fastai.vision.core import *
from fastai.vision.data import *
To build a `DataBlock` you need to give the library four things: the types of your input/labels, and at least two functions: `get_items` and `splitter`. You may also need to include `get_x` and `get_y`, or a more generic list of `getters` that are applied to the results of `get_items`.

Once those are provided, you automatically get a `Datasets` or a `DataLoaders`.

You can create a `DataBlock` by passing functions:
mnist = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                  get_items=get_image_files,
                  splitter=GrandparentSplitter(),
                  get_y=parent_label)
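When `get_x` is not given, the results of `get_items` are passed through unchanged. Here is an equivalent, minimal sketch that makes both getters explicit (the `mnist_explicit` name is just for illustration); a single `getters` list of the same functions would work too, instead of `get_x`/`get_y`:

mnist_explicit = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                           get_items=get_image_files,
                           splitter=GrandparentSplitter(),
                           get_x=noop,          # pass each filename through to ImageBlock
                           get_y=parent_label)  # the parent folder name becomes the category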
Each type comes with default transforms that will be applied:

- at the base level to create items in a tuple (usually input,target) from the base elements (like filenames)
- at the item level of the datasets
- at the batch level

They are called respectively type transforms, item transforms, and batch transforms. In the case of MNIST, the type transforms are the method to create a `PILImageBW` (for the input) and the `Categorize` transform (for the target), the item transform is `ToTensor`, and the batch transforms are `Cuda` and `IntToFloatTensor`. You can add any other transforms by passing them in `DataBlock.datasets` or `DataBlock.dataloaders`.
test_eq(mnist.type_tfms[0], [PILImageBW.create])
test_eq(mnist.type_tfms[1].map(type), [Categorize])
test_eq(mnist.default_item_tfms.map(type), [ToTensor])
test_eq(mnist.default_batch_tfms.map(type), [IntToFloatTensor])
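Extra transforms can also be declared up front on the `DataBlock` itself via `item_tfms` and `batch_tfms`. A minimal sketch, where the `mnist_aug` name and the particular transforms chosen are illustrative assumptions (the imports are only needed if `Resize` and `Normalize` are not already in scope):

from fastai.data.transforms import Normalize
from fastai.vision.augment import Resize

mnist_aug = DataBlock(blocks=(ImageBlock(cls=PILImageBW), CategoryBlock),
                      get_items=get_image_files,
                      splitter=GrandparentSplitter(),
                      get_y=parent_label,
                      item_tfms=Resize(28),   # resize each image before batching
                      batch_tfms=Normalize()) # stats are set up from one batch
dls = mnist_aug.dataloaders(untar_data(URLs.MNIST_TINY), bs=8)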
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(dsets.vocab, ['3', '7'])
x,y = dsets.train[0]
test_eq(x.size,(28,28))
show_at(dsets.train, 0, cmap='Greys', figsize=(2,2));
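Passing keyword arguments that `DataBlock` does not recognize raises an error: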
test_fail(lambda: DataBlock(wrong_kwarg=42, wrong_kwarg2='foo'))
We can pass any number of blocks to `DataBlock`; we can then define which blocks are the inputs and which are the targets by changing `n_inp`. For example, defining `n_inp=2` will consider the first two blocks passed as inputs and the others as targets.
mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                  get_y=parent_label)
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(mnist.n_inp, 2)
test_eq(len(dsets.train[0]), 3)
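With `n_inp=2` there is only one target block left, so `get_y` must contain exactly one function; passing two fails: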
test_fail(lambda: DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                            get_y=[parent_label, noop],
                            n_inp=2), msg='get_y contains 2 functions, but must contain 1 (one for each output)')
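Conversely, with `n_inp=1` two target blocks remain, so `get_y` needs two functions; note that a `Pipeline` of several transforms counts as a single function: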
mnist = DataBlock((ImageBlock, ImageBlock, CategoryBlock), get_items=get_image_files, splitter=GrandparentSplitter(),
                  n_inp=1,
                  get_y=[noop, Pipeline([noop, parent_label])])
dsets = mnist.datasets(untar_data(URLs.MNIST_TINY))
test_eq(len(dsets.train[0]), 3)
Besides stepping through the transformations, `summary()` can also call `dls.show_batch(...)` for you as a shortcut to see the data. E.g.

pets.summary(path/"images", bs=8, show_batch=True, unique=True,...)

is a shortcut to:

pets.summary(path/"images", bs=8)
dls = pets.dataloaders(path/"images", bs=8)
dls.show_batch(unique=True,...) # See the effect of different tfms on the same image.