Helper functions to get data into a `DataLoaders` in the tabular application, and the higher-level class `TabularDataLoaders`
The main class to get your data ready for model training is `TabularDataLoaders`,
together with its factory methods. Check out the tabular tutorial for examples of use.
This class should not be used directly; prefer one of the factory methods instead. All of those factory methods accept the following arguments:
- `cat_names`: the names of the categorical variables
- `cont_names`: the names of the continuous variables
- `y_names`: the names of the dependent variables
- `y_block`: the `TransformBlock` to use for the target
- `valid_idx`: the indices to use for the validation set (defaults to a random split otherwise)
- `bs`: the batch size
- `val_bs`: the batch size for the validation `DataLoader` (defaults to `bs`)
- `shuffle_train`: whether to shuffle the training `DataLoader`
- `n`: overrides the number of elements in the dataset
- `device`: the PyTorch device to use (defaults to `default_device()`)
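To make the `valid_idx` behaviour concrete, here is a minimal sketch in plain Python of how an explicit list of indices, or a random fallback, partitions the rows of a dataset. It is an illustration only, not fastai's implementation: the helper `split_indices`, its `valid_pct` and `seed` parameters, and the 20% fallback fraction are assumptions made for this example.

```python
import random

def split_indices(n_rows, valid_idx=None, valid_pct=0.2, seed=None):
    """Return (train_idx, valid_idx) for a dataset with n_rows rows.

    Hypothetical helper for illustration only; not part of fastai.
    """
    if valid_idx is None:
        # No explicit indices given: fall back to a random split.
        # valid_pct=0.2 is an assumed fraction for this sketch.
        rng = random.Random(seed)
        shuffled = list(range(n_rows))
        rng.shuffle(shuffled)
        valid_idx = shuffled[:int(n_rows * valid_pct)]
    valid_set = set(valid_idx)
    # Every row that is not in the validation set goes to training.
    train_idx = [i for i in range(n_rows) if i not in valid_set]
    return train_idx, sorted(valid_set)

# Explicit validation rows: 800 through 999 of a 1000-row table.
train_idx, valid_idx = split_indices(1000, valid_idx=list(range(800, 1000)))
print(len(train_idx), len(valid_idx))  # 800 200
```

Passing explicit indices makes the split reproducible across runs, while the random fallback needs a fixed seed to give the same partition twice.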
Let's look at an example with the adult dataset:
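from fastai.tabular.all import *  # provides untar_data, URLs, pd and TabularDataLoaders

# Download the sample of the adult dataset and load it into a DataFrame
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
df.head()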
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_df(df, path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                 y_names="salary", valid_idx=list(range(800,1000)), bs=64)
dls.show_batch()
The same arguments work when loading directly from the CSV file with `from_csv`:
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_csv(path/'adult.csv', path=path, procs=procs, cat_names=cat_names, cont_names=cont_names,
                                  y_names="salary", valid_idx=list(range(800,1000)), bs=64)