t = tensor([0,1,2])
r = reverse_text(t)
test_eq(r, tensor([2,1,0]))
Numericalization is the step in which we convert tokens to integers. The first step is to build a mapping from tokens to indices, called a vocab.
If there are more than max_vocab tokens, the ones kept are the most frequent.
The size of the vocab is always a multiple of 8, if necessary by adding xxfake tokens.
count = Counter(['a', 'a', 'a', 'a', 'b', 'b', 'c', 'c', 'd'])
test_eq(set([x for x in make_vocab(count) if not x.startswith('xxfake')]),
set(defaults.text_spec_tok + 'a'.split()))
test_eq(len(make_vocab(count))%8, 0)
test_eq(set([x for x in make_vocab(count, min_freq=1) if not x.startswith('xxfake')]),
set(defaults.text_spec_tok + 'a b c d'.split()))
test_eq(set([x for x in make_vocab(count,max_vocab=12, min_freq=1) if not x.startswith('xxfake')]),
set(defaults.text_spec_tok + 'a b c'.split()))
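The selection rule tested above can be sketched in plain Python. This is a minimal illustration, not the library's make_vocab: the function name build_vocab, the SPECIAL_TOKENS list and the default values are assumptions made for the example, and here max_vocab counts only data tokens, whereas the tests above suggest the library counts special tokens as well.

```python
from collections import Counter

# Hypothetical stand-in for defaults.text_spec_tok (assumption for this sketch).
SPECIAL_TOKENS = ['xxunk', 'xxpad', 'xxbos']

def build_vocab(count, min_freq=3, max_vocab=60000, special_toks=SPECIAL_TOKENS):
    "Sketch: keep the most frequent tokens, then pad the vocab to a multiple of 8."
    # most_common sorts by decreasing frequency, so truncating keeps the most frequent tokens.
    kept = [tok for tok, c in count.most_common(max_vocab) if c >= min_freq]
    # Special tokens come first; skip data tokens that collide with them.
    vocab = list(special_toks) + [tok for tok in kept if tok not in special_toks]
    # Add placeholder tokens until the size is a multiple of 8.
    while len(vocab) % 8 != 0:
        vocab.append(f'xxfake{len(vocab)}')
    return vocab

count = Counter(['a', 'a', 'a', 'a', 'b', 'b', 'c', 'c', 'd'])
vocab = build_vocab(count, min_freq=1, max_vocab=3)
real_tokens = [tok for tok in vocab if not tok.startswith('xxfake')]
```

With min_freq=1 and max_vocab=3, the token 'd' is dropped because it is the least frequent, and two placeholder tokens bring the size up to 8.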
If no vocab is passed, one is created at setup from the data, using make_vocab with min_freq and max_vocab.
start = 'This is an example of text'
num = Numericalize(min_freq=1)
num.setup(L(start.split(), 'this is another text'.split()))
test_eq(set([x for x in num.vocab if not x.startswith('xxfake')]),
set(defaults.text_spec_tok + 'This is an example of text this another'.split()))
test_eq(len(num.vocab)%8, 0)
t = num(start.split())
test_eq(t, tensor([11, 9, 12, 13, 14, 10]))
test_eq(num.decode(t), start.split())
num = Numericalize(min_freq=2)
num.setup(L('This is an example of text'.split(), 'this is another text'.split()))
test_eq(set([x for x in num.vocab if not x.startswith('xxfake')]),
set(defaults.text_spec_tok + 'is text'.split()))
test_eq(len(num.vocab)%8, 0)
t = num(start.split())
test_eq(t, tensor([0, 9, 0, 0, 0, 10]))
test_eq(num.decode(t), f'{UNK} is {UNK} {UNK} {UNK} text'.split())
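The encode/decode round trip shown above can be illustrated without the library. The class below is a hypothetical sketch (SimpleNumericalizer and its method names are not part of the library); it assumes, as the tests above suggest, that unknown tokens map to index 0.

```python
class SimpleNumericalizer:
    "Sketch of a token <-> index mapping with a fallback index for unknown tokens."
    def __init__(self, vocab, unk_idx=0):
        self.vocab = list(vocab)
        self.unk_idx = unk_idx
        # Reverse mapping from token to its position in the vocab.
        self.o2i = {tok: i for i, tok in enumerate(self.vocab)}

    def encode(self, tokens):
        # Tokens missing from the vocab fall back to unk_idx.
        return [self.o2i.get(tok, self.unk_idx) for tok in tokens]

    def decode(self, ids):
        return [self.vocab[i] for i in ids]

num_sketch = SimpleNumericalizer(['xxunk', 'is', 'text'])
ids = num_sketch.encode('This is an example of text'.split())
decoded = num_sketch.decode(ids)
```

Decoding is lossy for out-of-vocabulary tokens: every token that was mapped to the unknown index comes back as 'xxunk'.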
dataset should be a collection of numericalized texts for this to work. lens can be passed to speed up the creation; otherwise, the LMDataLoader will do a full pass over the dataset to compute them. cache is used to avoid reloading items unnecessarily.
The LMDataLoader concatenates all texts (shuffled if shuffle=True) into one big stream, splits it into bs contiguous sequences, then goes through those seq_len tokens at a time.
bs,sl = 4,3
ints = L([0,1,2,3,4],[5,6,7,8,9,10],[11,12,13,14,15,16,17,18],[19,20],[21,22,23],[24]).map(tensor)
dl = LMDataLoader(ints, bs=bs, seq_len=sl)
test_eq(list(dl),
[[tensor([[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]]),
tensor([[1, 2, 3], [7, 8, 9], [13, 14, 15], [19, 20, 21]])],
[tensor([[3, 4, 5], [ 9, 10, 11], [15, 16, 17], [21, 22, 23]]),
tensor([[4, 5, 6], [10, 11, 12], [16, 17, 18], [22, 23, 24]])]])
dl = LMDataLoader(ints, bs=bs, seq_len=sl, shuffle=True)
for x,y in dl: test_eq(x[:,1:], y[:,:-1])
((x0,y0), (x1,y1)) = tuple(dl)
#Second batch begins where first batch ended
test_eq(y0[:,-1], x1[:,0])
test_eq(type(x0), LMTensorText)
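The batching scheme can be reproduced in a few lines of plain Python. This is a simplified sketch of the idea rather than the LMDataLoader implementation: the function name lm_batches is hypothetical, it works on lists instead of tensors, and it ignores shuffling and caching.

```python
from itertools import chain

def lm_batches(texts, bs, seq_len):
    "Sketch: concatenate texts, cut the stream into bs rows, walk the rows seq_len tokens at a time."
    stream = list(chain.from_iterable(texts))
    # One token is held back so that every input has a next-token target.
    row_len = (len(stream) - 1) // bs
    batches = []
    for start in range(0, row_len, seq_len):
        end = min(start + seq_len, row_len)
        # Row r covers stream[r*row_len : (r+1)*row_len]; targets are shifted by one token.
        xs = [stream[r * row_len + start: r * row_len + end] for r in range(bs)]
        ys = [stream[r * row_len + start + 1: r * row_len + end + 1] for r in range(bs)]
        batches.append((xs, ys))
    return batches

texts = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18],
         [19, 20], [21, 22, 23], [24]]
lm = lm_batches(texts, bs=4, seq_len=3)
```

On the same data as the test above, the sketch yields the same two batches, and each row of the second batch continues where the corresponding row of the first one stopped.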
For classification, we deal with the fact that texts don't all have the same length by using padding.
pad_idx is used for the padding, and the padding is applied to the pad_fields of the samples. The padding is applied at the beginning if pad_first is True, and if backwards is True, the tensors are flipped.
test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0),
[(tensor([1,2,3]),1), (tensor([4,5,0]),2), (tensor([6,0,0]), 3)])
test_eq(pad_input([(tensor([1,2,3]), (tensor([6]))), (tensor([4,5]), tensor([4,5])), (tensor([6]), (tensor([1,2,3])))], pad_idx=0, pad_fields=1),
[(tensor([1,2,3]),(tensor([6,0,0]))), (tensor([4,5]),tensor([4,5,0])), ((tensor([6]),tensor([1, 2, 3])))])
test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0, pad_first=True),
[(tensor([1,2,3]),1), (tensor([0,4,5]),2), (tensor([0,0,6]), 3)])
test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0, backwards=True),
[(tensor([3,2,1]),1), (tensor([5,4,0]),2), (tensor([6,0,0]), 3)])
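The padding rules exercised by these tests can be sketched on plain lists. The helper pad_sequences below is hypothetical and only mirrors the cases shown above; how pad_first and backwards combine is an assumption of the sketch, since no test above covers it.

```python
def pad_sequences(seqs, pad_idx=0, pad_first=False, backwards=False):
    "Sketch: pad every sequence to the length of the longest one."
    max_len = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        # Assumption: the tokens are reversed before padding, which matches the backwards test above.
        s = list(reversed(s)) if backwards else list(s)
        pad = [pad_idx] * (max_len - len(s))
        out.append(pad + s if pad_first else s + pad)
    return out

seqs = [[1, 2, 3], [4, 5], [6]]
padded_end = pad_sequences(seqs)
padded_start = pad_sequences(seqs, pad_first=True)
padded_backwards = pad_sequences(seqs, backwards=True)
```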
The difference with the base pad_input is that most of the padding is applied first (if pad_first=True) or at the end (if pad_first=False), but only in whole multiples of seq_len. The rest of the padding (fewer than seq_len tokens) is always applied at the end. This is to work with SequenceEncoder with recurrent models.
test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),1), (tensor([1,2,3]), 2), (tensor([1,2]), 3)], pad_idx=0, seq_len=2),
[(tensor([1,2,3,4,5,6]),1), (tensor([0,0,1,2,3,0]),2), (tensor([0,0,0,0,1,2]), 3)])
test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),), (tensor([1,2,3]),), (tensor([1,2]),)], pad_idx=0, seq_len=2),
[(tensor([1,2,3,4,5,6]),), (tensor([0,0,1,2,3,0]),), (tensor([0,0,0,0,1,2]),)])
test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),), (tensor([1,2,3]),), (tensor([1,2]),)], pad_idx=0, seq_len=2, pad_first=False),
[(tensor([1,2,3,4,5,6]),), (tensor([1,2,3,0,0,0]),), (tensor([1,2,0,0,0,0]),)])
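The chunked padding arithmetic can be made explicit on plain lists. The function pad_chunked is a hypothetical sketch that follows the three tests above; it is not the library's pad_input_chunk and it handles a single field only.

```python
def pad_chunked(seqs, pad_idx=0, seq_len=72, pad_first=True):
    "Sketch: pad in whole chunks of seq_len on one side, the remainder always at the end."
    max_len = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        missing = max_len - len(s)
        # Largest multiple of seq_len that fits in the missing length.
        chunk_pad = [pad_idx] * ((missing // seq_len) * seq_len)
        # Leftover padding, strictly shorter than seq_len.
        rest_pad = [pad_idx] * (missing % seq_len)
        if pad_first:
            out.append(chunk_pad + list(s) + rest_pad)
        else:
            out.append(list(s) + chunk_pad + rest_pad)
    return out

seqs = [[1, 2, 3, 4, 5, 6], [1, 2, 3], [1, 2]]
chunk_first = pad_chunked(seqs, seq_len=2)
chunk_last = pad_chunked(seqs, seq_len=2, pad_first=False)
```

For the sequence of length 3, three tokens are missing: one chunk of two goes in front and the single leftover token goes at the end.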
res is the result of sort_func applied on all elements of the dataset. You can pass it if available to make the init much faster by avoiding an initial pass over the whole dataset. For example if sorting by text length (as in the default sort_func, called _default_sort) you should pass a list with the length of each element in dataset to res to take advantage of this speed-up.
To get the same init speed-up for the validation set, val_res (a list of text lengths for your validation set) can be passed to the kwargs argument of SortedDL. Below is an example to reduce the init time by passing a list of text lengths for both the training set and the validation set:
# Pass the training dataset text lengths to SortedDL
srtd_dl=partial(SortedDL, res = train_text_lens)
# Pass the validation dataset text lengths
dl_kwargs = [{},{'val_res': val_text_lens}]
# init our Datasets
dsets = Datasets(...)
# init our Dataloaders
dls = dsets.dataloaders(...,dl_type = srtd_dl, dl_kwargs = dl_kwargs)
If shuffle is True, the sorted results are shuffled slightly, so that batches still contain items of roughly the same size but are not in the exact sorted order.
ds = [(tensor([1,2]),1), (tensor([3,4,5,6]),2), (tensor([7]),3), (tensor([8,9,10]),4)]
dl = SortedDL(ds, bs=2, before_batch=partial(pad_input, pad_idx=0))
test_eq(list(dl), [(tensor([[ 3, 4, 5, 6], [ 8, 9, 10, 0]]), tensor([2, 4])),
(tensor([[1, 2], [7, 0]]), tensor([1, 3]))])
ds = [(tensor(range(random.randint(1,10))),i) for i in range(101)]
dl = SortedDL(ds, bs=2, create_batch=partial(pad_input, pad_idx=-1), shuffle=True, num_workers=0)
batches = list(dl)
max_len = len(batches[0][0])
for b in batches:
    assert len(b[0]) <= max_len
    test_ne(b[0][-1], -1)
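The non-shuffled ordering can be sketched from a list of precomputed lengths, which is the role res plays above. The helper sorted_batches is hypothetical and leaves out the partial shuffling and the padding step; it only shows how sorting by decreasing length groups items of similar size.

```python
def sorted_batches(lengths, bs):
    "Sketch: group item indices into batches of bs, longest items first."
    # Sort indices by decreasing length; the items themselves are never touched.
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    return [order[i:i + bs] for i in range(0, len(order), bs)]

texts = [[1, 2], [3, 4, 5, 6], [7], [8, 9, 10]]
lengths = [len(t) for t in texts]  # precomputed once, like the res argument
index_batches = sorted_batches(lengths, bs=2)
```

On the four-item dataset used above, the first batch pairs the texts of length 4 and 3 and the second pairs those of length 2 and 1, so little padding is needed in either batch.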
To use the data block API, you will need this building block for texts.
For efficient tokenization, you probably want to use one of the factory methods. Otherwise, you can pass your custom tok_tfm that will deal with tokenization (if your texts are already tokenized, you can pass noop), a vocab, or leave it to be inferred on the texts using min_freq and max_vocab.
is_lm indicates whether we want to use the texts for language modeling or for another task. seq_len only needs tuning if is_lm=False, and is passed along to pad_input_chunk.
Here is an example using a sample of IMDB stored as a CSV file:
path = untar_data(URLs.IMDB_SAMPLE)
df = pd.read_csv(path/'texts.csv')
imdb_clas = DataBlock(
blocks=(TextBlock.from_df('text', seq_len=72), CategoryBlock),
get_x=ColReader('text'), get_y=ColReader('label'), splitter=ColSplitter())
dls = imdb_clas.dataloaders(df, bs=64)
dls.show_batch(max_n=2)
vocab, is_lm, seq_len, min_freq and max_vocab are passed to the main init, the other arguments to Tokenizer.from_df.
vocab, is_lm, seq_len, min_freq and max_vocab are passed to the main init, the other arguments to Tokenizer.from_folder.
You should not use the init directly but one of the following factory methods. All those factory methods accept as arguments:
- text_vocab: the vocabulary used for numericalizing texts (if not passed, it's inferred from the data)
- tok_tfm: if passed, uses this tok_tfm instead of the default
- seq_len: the sequence length used for batch
- bs: the batch size
- val_bs: the batch size for the validation DataLoader (defaults to bs)
- shuffle_train: if we shuffle the training DataLoader or not
- device: the PyTorch device to use (defaults to default_device())
If valid_pct is provided, a random split is performed (with an optional seed) by setting aside that percentage of the data for the validation set (instead of looking at the grandparents folder). If a vocab is passed, only the folders with names in vocab are kept.
Here is an example on a sample of the IMDB movie review dataset:
path = untar_data(URLs.IMDB)
dls = TextDataLoaders.from_folder(path)
dls.show_batch(max_n=3)
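The valid_pct option described above amounts to a seeded random split of the item indices. The function below is a hypothetical sketch using the standard library only; the library's own splitter may draw its permutation differently, so the exact indices will not match.

```python
import random

def random_split(n_items, valid_pct=0.2, seed=None):
    "Sketch: shuffle the indices and set aside valid_pct of them for validation."
    rng = random.Random(seed)  # a private generator keeps the split reproducible
    idxs = list(range(n_items))
    rng.shuffle(idxs)
    cut = int(valid_pct * n_items)
    return idxs[cut:], idxs[:cut]  # (training indices, validation indices)

train_idx, valid_idx = random_split(10, valid_pct=0.2, seed=42)
```

Passing the same seed twice returns the same split, which is what makes runs comparable.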
seed can optionally be passed for reproducibility. text_col, label_col and optionally valid_col are indices or names of columns for texts/labels and the validation flag. label_delim can be passed for a multi-label problem if your labels are in one column, separated by a particular character. y_block should be passed to indicate your type of targets, in case the library did not infer it properly.
Here are examples on subsets of IMDB:
dls = TextDataLoaders.from_df(df, path=path, text_col='text', label_col='label', valid_col='is_valid')
dls.show_batch(max_n=3)
dls = TextDataLoaders.from_df(df, path=path, text_col='text', is_lm=True, valid_col='is_valid')
dls.show_batch(max_n=3)
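For the multi-label case, the effect of label_delim on a single cell of the label column can be pictured as a plain string split. The helper split_labels and the sample labels are hypothetical and only illustrate the idea; they are not the library's code path.

```python
def split_labels(label, label_delim=None):
    "Sketch: turn a delimited label string into a list of labels."
    if label_delim is None:
        return label  # single-label problem: keep the value as is
    # Drop empty pieces so a trailing delimiter does not create an empty label.
    return [part for part in label.split(label_delim) if part]

single = split_labels('positive')
multi = split_labels('comedy;drama', label_delim=';')
```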
Opens the CSV file with header and delimiter, then passes all the other arguments to TextDataLoaders.from_df.
dls = TextDataLoaders.from_csv(path=path, csv_fname='texts.csv', text_col='text', label_col='label', valid_col='is_valid')
dls.show_batch(max_n=3)