Functions and transforms to help gather text data in a `Datasets`

Backwards

Reversing the text can give higher accuracy when the resulting backward model is ensembled with a forward model. All that is needed is a type_tfm that reverses the text as it is brought in:

reverse_text

reverse_text(x)

t = tensor([0,1,2])
r = reverse_text(t)
test_eq(r, tensor([2,1,0]))
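A minimal plain-Python sketch of the two ideas in this section, under stated assumptions: a backward model reads each sequence of token ids in reverse, and its predictions are combined with a forward model's. `reverse_ids` and `ensemble_probs` are hypothetical helpers, not fastai functions, and averaging class probabilities is only one possible way to build the ensemble.

```python
from typing import List, Sequence


def reverse_ids(ids: Sequence[int]) -> List[int]:
    """Return the token ids in reverse order, as a backward model would read them."""
    return list(ids)[::-1]


def ensemble_probs(fwd: Sequence[float], bwd: Sequence[float]) -> List[float]:
    """Average per-class probabilities from a forward and a backward model."""
    if len(fwd) != len(bwd):
        raise ValueError("both models must score the same classes")
    return [(f + b) / 2 for f, b in zip(fwd, bwd)]


print(reverse_ids([0, 1, 2]))                    # [2, 1, 0]
print(ensemble_probs([0.75, 0.25], [0.5, 0.5]))  # [0.625, 0.375]
```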

Numericalizing

Numericalization is the step in which we convert tokens to integers. The first step is to build a mapping from tokens to indices, called a vocab.

make_vocab

make_vocab(count, min_freq=3, max_vocab=60000, special_toks=None)

Create a vocab of at most max_vocab tokens from the Counter count, keeping only items that appear at least min_freq times

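If there are more than max_vocab tokens, the most frequent ones are kept. Special tokens come first, and the vocab is padded with xxfake tokens so that its length is a multiple of 8.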
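The vocab-building rules above (frequency filter, most frequent first, special tokens at the front, padding to a multiple of 8) can be sketched with `collections.Counter`. This is a hypothetical re-implementation for illustration, not fastai's code; `SPECIAL_TOKS` is assumed to match `defaults.text_spec_tok`, and the exact number of `xxfake` filler tokens may differ from the library's.

```python
from collections import Counter
from typing import List, Optional, Sequence

# Assumed to mirror fastai's defaults.text_spec_tok; check your installed version.
SPECIAL_TOKS = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'xxfld',
                'xxrep', 'xxwrep', 'xxup', 'xxmaj']


def make_vocab_sketch(count: Counter, min_freq: int = 3, max_vocab: int = 60000,
                      special_toks: Optional[Sequence[str]] = None) -> List[str]:
    special_toks = list(SPECIAL_TOKS if special_toks is None else special_toks)
    # Most frequent tokens first, keeping only those seen at least min_freq times
    frequent = [tok for tok, c in count.most_common(max_vocab) if c >= min_freq]
    # Special tokens always go at the front, without duplicates
    vocab = special_toks + [tok for tok in frequent if tok not in special_toks]
    vocab = vocab[:max_vocab]
    # Pad with filler tokens until the length is a multiple of 8
    return vocab + ['xxfake'] * (-len(vocab) % 8)


count = Counter(['a', 'a', 'a', 'a', 'b', 'b', 'c', 'c', 'd'])
print([t for t in make_vocab_sketch(count) if t != 'xxfake'][-1])  # a
```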

count = Counter(['a', 'a', 'a', 'a', 'b', 'b', 'c', 'c', 'd'])
test_eq(set([x for x in make_vocab(count) if not x.startswith('xxfake')]), 
        set(defaults.text_spec_tok + 'a'.split()))
test_eq(len(make_vocab(count))%8, 0)
test_eq(set([x for x in make_vocab(count, min_freq=1) if not x.startswith('xxfake')]), 
        set(defaults.text_spec_tok + 'a b c d'.split()))
test_eq(set([x for x in make_vocab(count,max_vocab=12, min_freq=1) if not x.startswith('xxfake')]), 
        set(defaults.text_spec_tok + 'a b c'.split()))

class TensorText

TensorText(x, **kwargs) :: TensorBase

Semantic type for a tensor representing text

class LMTensorText

LMTensorText(x, **kwargs) :: TensorText

Semantic type for a tensor representing text in language modeling

class Numericalize

Numericalize(vocab=None, min_freq=3, max_vocab=60000, special_toks=None, pad_tok=None) :: Transform

Reversible transform of tokenized texts to numericalized ids

If no vocab is passed, one is created at setup from the data, using make_vocab with min_freq and max_vocab.
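To show what "setup" and the reversible encode/decode amount to, here is a hypothetical `NumericalizeSketch` class in plain Python. It assumes unknown tokens map to index 0 (`xxunk`), which is consistent with the tests below, and it skips the padding of the vocab to a multiple of 8; it is an illustration, not the library's `Numericalize`.

```python
from collections import Counter
from typing import Dict, Iterable, List, Sequence

# Assumed to mirror fastai's defaults.text_spec_tok; check your installed version.
SPECIAL_TOKS = ['xxunk', 'xxpad', 'xxbos', 'xxeos', 'xxfld',
                'xxrep', 'xxwrep', 'xxup', 'xxmaj']


class NumericalizeSketch:
    """Hypothetical stand-in for Numericalize: tokens -> ids -> tokens."""

    def __init__(self, vocab: Sequence[str]):
        self.vocab = list(vocab)
        # Filler tokens never come out of a tokenizer, so they get no reverse entry
        self.o2i: Dict[str, int] = {t: i for i, t in enumerate(self.vocab) if t != 'xxfake'}

    @classmethod
    def from_texts(cls, texts: Iterable[Sequence[str]], min_freq: int = 3) -> 'NumericalizeSketch':
        """Build the vocab from tokenized texts, as happens at setup time."""
        count = Counter(tok for text in texts for tok in text)
        frequent = [t for t, c in count.most_common() if c >= min_freq]
        return cls(SPECIAL_TOKS + [t for t in frequent if t not in SPECIAL_TOKS])

    def encode(self, tokens: Sequence[str]) -> List[int]:
        # Unknown tokens map to index 0, which holds 'xxunk' in this sketch
        return [self.o2i.get(t, 0) for t in tokens]

    def decode(self, ids: Sequence[int]) -> List[str]:
        return [self.vocab[i] for i in ids]


texts = ['This is an example of text'.split(), 'this is another text'.split()]
num = NumericalizeSketch.from_texts(texts, min_freq=2)
print(num.encode(texts[0]))  # [0, 9, 0, 0, 0, 10]
```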

start = 'This is an example of text'
num = Numericalize(min_freq=1)
num.setup(L(start.split(), 'this is another text'.split()))
test_eq(set([x for x in num.vocab if not x.startswith('xxfake')]), 
        set(defaults.text_spec_tok + 'This is an example of text this another'.split()))
test_eq(len(num.vocab)%8, 0)
t = num(start.split())

test_eq(t, tensor([11, 9, 12, 13, 14, 10]))
test_eq(num.decode(t), start.split())
num = Numericalize(min_freq=2)
num.setup(L('This is an example of text'.split(), 'this is another text'.split()))
test_eq(set([x for x in num.vocab if not x.startswith('xxfake')]), 
        set(defaults.text_spec_tok + 'is text'.split()))
test_eq(len(num.vocab)%8, 0)
t = num(start.split())
test_eq(t, tensor([0, 9, 0, 0, 0, 10]))
test_eq(num.decode(t), f'{UNK} is {UNK} {UNK} {UNK} text'.split())

class LMDataLoader

LMDataLoader(dataset, lens=None, cache=2, bs=64, seq_len=72, num_workers=0, shuffle=False, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None) :: TfmdDL

A DataLoader suitable for language modeling

dataset should be a collection of numericalized texts. lens can be passed to speed up creation; otherwise, the LMDataLoader does a full pass over the dataset to compute them. cache is used to avoid reloading items unnecessarily.

The LMDataLoader concatenates all texts (shuffled first if shuffle=True) into one big stream, splits it into bs contiguous streams, then goes through those streams seq_len tokens at a time.
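The stream-splitting described above can be sketched on plain Python lists. `lm_batches` is a hypothetical helper that does no shuffling and no caching; it only mirrors the concatenate, split into bs rows, step by seq_len logic, with each target shifted one token ahead of its input. Run on the same data as the test below, it yields the same two batches.

```python
from typing import List, Sequence, Tuple

Batch = Tuple[List[List[int]], List[List[int]]]


def lm_batches(texts: Sequence[Sequence[int]], bs: int, seq_len: int) -> List[Batch]:
    # 1. Concatenate every text into one long stream of token ids
    stream = [tok for text in texts for tok in text]
    # 2. Split it into bs contiguous rows; the final token is only ever a target
    per_row = (len(stream) - 1) // bs
    rows = [stream[i * per_row: (i + 1) * per_row + 1] for i in range(bs)]
    # 3. Walk through the rows seq_len tokens at a time; targets are shifted by one
    batches = []
    for start in range(0, per_row, seq_len):
        end = min(start + seq_len, per_row)
        x = [row[start:end] for row in rows]
        y = [row[start + 1:end + 1] for row in rows]
        batches.append((x, y))
    return batches


texts = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9, 10], [11, 12, 13, 14, 15, 16, 17, 18],
         [19, 20], [21, 22, 23], [24]]
print(lm_batches(texts, bs=4, seq_len=3)[0][0])  # [[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]]
```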

bs,sl = 4,3
ints = L([0,1,2,3,4],[5,6,7,8,9,10],[11,12,13,14,15,16,17,18],[19,20],[21,22,23],[24]).map(tensor)
dl = LMDataLoader(ints, bs=bs, seq_len=sl)
test_eq(list(dl),
    [[tensor([[0, 1, 2], [6, 7, 8], [12, 13, 14], [18, 19, 20]]),
      tensor([[1, 2, 3], [7, 8, 9], [13, 14, 15], [19, 20, 21]])],
     [tensor([[3, 4, 5], [ 9, 10, 11], [15, 16, 17], [21, 22, 23]]),
      tensor([[4, 5, 6], [10, 11, 12], [16, 17, 18], [22, 23, 24]])]])
dl = LMDataLoader(ints, bs=bs, seq_len=sl, shuffle=True)
for x,y in dl: test_eq(x[:,1:], y[:,:-1])
((x0,y0), (x1,y1)) = tuple(dl)
#Second batch begins where first batch ended
test_eq(y0[:,-1], x1[:,0]) 
test_eq(type(x0), LMTensorText)

Classification

For classification, we deal with the fact that texts don't all have the same length by using padding.

pad_input

pad_input(samples, pad_idx=1, pad_fields=0, pad_first=False, backwards=False)

Function that collects samples and adds padding

pad_idx is used for the padding, and the padding is applied to the pad_fields of the samples. The padding is applied at the beginning if pad_first is True, and if backwards is True, the tensors are flipped.
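Here is a plain-Python sketch of this padding logic, using lists instead of tensors. `pad_samples` is a hypothetical helper, not the library function. Its handling of backwards (pad the opposite side, then flip, so padding still ends up at the end) is inferred from the backwards test below, and it reproduces the expected outputs of these tests.

```python
from typing import Any, List, Sequence, Tuple, Union


def pad_samples(samples: Sequence[Tuple[Any, ...]], pad_idx: int = 1,
                pad_fields: Union[int, Sequence[int]] = 0,
                pad_first: bool = False, backwards: bool = False) -> List[Tuple[Any, ...]]:
    fields = (pad_fields,) if isinstance(pad_fields, int) else tuple(pad_fields)
    # Each padded field is brought to the longest length of that field in the batch
    max_len = {f: max(len(s[f]) for s in samples) for f in fields}
    # Flipping afterwards moves padding to the other side, so pad the opposite side first
    if backwards:
        pad_first = not pad_first
    out = []
    for sample in samples:
        new = []
        for i, value in enumerate(sample):
            if i in fields:
                pad = [pad_idx] * (max_len[i] - len(value))
                value = pad + list(value) if pad_first else list(value) + pad
                if backwards:
                    value = value[::-1]
            new.append(value)  # fields not in pad_fields (e.g. labels) pass through
        out.append(tuple(new))
    return out


batch = [([1, 2, 3], 1), ([4, 5], 2), ([6], 3)]
print(pad_samples(batch, pad_idx=0, backwards=True))  # [([3, 2, 1], 1), ([5, 4, 0], 2), ([6, 0, 0], 3)]
```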

test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0), 
        [(tensor([1,2,3]),1), (tensor([4,5,0]),2), (tensor([6,0,0]), 3)])
test_eq(pad_input([(tensor([1,2,3]), (tensor([6]))), (tensor([4,5]), tensor([4,5])), (tensor([6]), (tensor([1,2,3])))], pad_idx=0, pad_fields=1), 
        [(tensor([1,2,3]),(tensor([6,0,0]))), (tensor([4,5]),tensor([4,5,0])), ((tensor([6]),tensor([1, 2, 3])))])
test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0, pad_first=True), 
        [(tensor([1,2,3]),1), (tensor([0,4,5]),2), (tensor([0,0,6]), 3)])
test_eq(pad_input([(tensor([1,2,3]),1), (tensor([4,5]), 2), (tensor([6]), 3)], pad_idx=0, backwards=True), 
        [(tensor([3,2,1]),1), (tensor([5,4,0]),2), (tensor([6,0,0]), 3)])

pad_input_chunk

pad_input_chunk(samples, pad_idx=1, pad_first=True, seq_len=72)

Pad samples by adding padding by chunks of size seq_len

The difference from the base pad_input is that most of the padding is applied at the beginning (if pad_first=True) or at the end (if pad_first=False), but only in a whole multiple of seq_len. The remaining padding (fewer than seq_len tokens) is always applied at the end, so with pad_first=False all of the padding ends up at the end, as the last test below shows. This is designed to work with SentenceEncoder in recurrent models.
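The chunked padding rule can be sketched on plain lists: split the number of missing tokens into whole multiples of seq_len plus a remainder. `pad_input_chunk_sketch` is a hypothetical helper that works on bare sequences rather than (input, target) tuples; it reproduces the expected outputs of the tests below. One visible effect with pad_first=True is that every text starts on a seq_len boundary.

```python
from typing import List, Sequence


def pad_input_chunk_sketch(seqs: Sequence[Sequence[int]], pad_idx: int = 1,
                           pad_first: bool = True, seq_len: int = 72) -> List[List[int]]:
    max_len = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        n_pad = max_len - len(s)
        chunk = [pad_idx] * (n_pad // seq_len * seq_len)  # whole multiples of seq_len
        rest = [pad_idx] * (n_pad % seq_len)              # leftover, fewer than seq_len
        s = list(s)
        out.append(chunk + s + rest if pad_first else s + chunk + rest)
    return out


print(pad_input_chunk_sketch([[1, 2, 3, 4, 5, 6], [1, 2, 3], [1, 2]], pad_idx=0, seq_len=2))
# [[1, 2, 3, 4, 5, 6], [0, 0, 1, 2, 3, 0], [0, 0, 0, 0, 1, 2]]
```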

test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),1), (tensor([1,2,3]), 2), (tensor([1,2]), 3)], pad_idx=0, seq_len=2), 
        [(tensor([1,2,3,4,5,6]),1), (tensor([0,0,1,2,3,0]),2), (tensor([0,0,0,0,1,2]), 3)])
test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),), (tensor([1,2,3]),), (tensor([1,2]),)], pad_idx=0, seq_len=2), 
        [(tensor([1,2,3,4,5,6]),), (tensor([0,0,1,2,3,0]),), (tensor([0,0,0,0,1,2]),)])
test_eq(pad_input_chunk([(tensor([1,2,3,4,5,6]),), (tensor([1,2,3]),), (tensor([1,2]),)], pad_idx=0, seq_len=2, pad_first=False), 
        [(tensor([1,2,3,4,5,6]),), (tensor([1,2,3,0,0,0]),), (tensor([1,2,0,0,0,0]),)])

class SortedDL

SortedDL(dataset, sort_func=None, res=None, bs=64, shuffle=False, num_workers=None, verbose=False, do_setup=True, pin_memory=False, timeout=0, batch_size=None, drop_last=False, indexed=None, n=None, device=None, wif=None, before_iter=None, after_item=None, before_batch=None, after_batch=None, after_iter=None, create_batches=None, create_item=None, create_batch=None, retain=None, get_idxs=None, sample=None, shuffle_fn=None, do_batch=None) :: TfmdDL

A DataLoader that goes through the items in the order given by sort_func

res is the result of sort_func applied to every element of the dataset. If it is available, pass it to make initialization much faster by skipping an initial pass over the whole dataset. For example, when sorting by text length (as the default sort_func, _default_sort, does), pass a list with the length of each element in dataset as res to benefit from this speed-up.
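Why a precomputed res helps can be seen in a small sketch: once the sort keys are known, ordering the dataset only needs the indices, not the items. `sorted_batches` is a hypothetical helper that sorts indices by res in descending order (largest items first, consistent with the SortedDL test in this section) and groups them into batches of size bs.

```python
from typing import List, Sequence


def sorted_batches(res: Sequence[int], bs: int) -> List[List[int]]:
    """Group dataset indices into batches, largest res value first."""
    order = sorted(range(len(res)), key=lambda i: res[i], reverse=True)
    return [order[i:i + bs] for i in range(0, len(order), bs)]


# res would typically be the text lengths, e.g. [len(t) for t in texts]
texts = [[1, 2], [3, 4, 5, 6], [7], [8, 9, 10]]
res = [len(t) for t in texts]
print(sorted_batches(res, bs=2))  # [[1, 3], [0, 2]]
```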

To get the same init speed-up for the validation set, val_res (a list of text lengths for your validation set) can be passed to the kwargs argument of SortedDL. Below is an example to reduce the init time by passing a list of text lengths for both the training set and the validation set:

# Pass the training dataset text lengths to SortedDL
srtd_dl = partial(SortedDL, res=train_text_lens)

# Pass the validation dataset text lengths
dl_kwargs = [{}, {'val_res': val_text_lens}]

# init our Datasets
dsets = Datasets(...)

# init our Dataloaders
dls = dsets.dataloaders(..., dl_type=srtd_dl, dl_kwargs=dl_kwargs)

If shuffle is True, the sorted order is shuffled a little so that batches still contain items of roughly the same size, but not in exact sorted order.
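One way to get "roughly sorted" batches is sketched below: shuffle the indices, sort within large windows, cut the result into batches, then shuffle the batch order. This is an illustration of the idea only and is not claimed to match SortedDL's exact procedure; `shuffled_sorted_batches` and its `chunk_mult` window size are assumptions. Because each window is a whole number of batches, every batch stays internally sorted.

```python
import random
from typing import List, Optional, Sequence


def shuffled_sorted_batches(res: Sequence[int], bs: int, chunk_mult: int = 50,
                            seed: Optional[int] = None) -> List[List[int]]:
    rng = random.Random(seed)
    idxs = list(range(len(res)))
    rng.shuffle(idxs)
    # Sort inside windows of bs * chunk_mult items so batches hold similar sizes
    sz = bs * chunk_mult
    windows = [sorted(idxs[i:i + sz], key=lambda j: res[j], reverse=True)
               for i in range(0, len(idxs), sz)]
    order = [j for w in windows for j in w]
    batches = [order[i:i + bs] for i in range(0, len(order), bs)]
    # Shuffle the batch order so the model does not always see sizes in the same sequence
    rng.shuffle(batches)
    return batches


gen = random.Random(0)
res = [gen.randint(1, 10) for _ in range(101)]
print(len(shuffled_sorted_batches(res, bs=2, chunk_mult=5, seed=1)))  # 51
```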

ds = [(tensor([1,2]),1), (tensor([3,4,5,6]),2), (tensor([7]),3), (tensor([8,9,10]),4)]
dl = SortedDL(ds, bs=2, before_batch=partial(pad_input, pad_idx=0))
test_eq(list(dl), [(tensor([[ 3,  4,  5,  6], [ 8,  9, 10,  0]]), tensor([2, 4])), 
                   (tensor([[1, 2], [7, 0]]), tensor([1, 3]))])
ds = [(tensor(range(random.randint(1,10))),i) for i in range(101)]
dl = SortedDL(ds, bs=2, create_batch=partial(pad_input, pad_idx=-1), shuffle=True, num_workers=0)
batches = list(dl)
max_len = len(batches[0][0])
for b in batches:
    assert len(b[0]) <= max_len
    test_ne(b[0][-1], -1)

TransformBlock for text

To use the data block API, you will need this building block for texts.

class TextBlock

TextBlock(tok_tfm, vocab=None, is_lm=False, seq_len=72, backwards=False, min_freq=3, max_vocab=60000, special_toks=None, pad_tok=None) :: TransformBlock

A TransformBlock for texts

For efficient tokenization, you probably want to use one of the factory methods. Otherwise, you can pass a custom tok_tfm that handles tokenization (if your texts are already tokenized, you can pass noop), and either pass a vocab or let it be inferred from the texts using min_freq and max_vocab.

is_lm indicates whether the texts are used for language modeling or for another task. seq_len only needs tuning if is_lm=False, in which case it is passed along to pad_input_chunk.

TextBlock.from_df

TextBlock.from_df(text_cols, vocab=None, is_lm=False, seq_len=72, backwards=False, min_freq=3, max_vocab=60000, tok=None, rules=None, sep=' ', n_workers=64, mark_fields=None, res_col_name='text', **kwargs)

Build a TextBlock from a dataframe using text_cols

Here is an example using a sample of IMDB stored as a CSV file:

from fastai.text.all import *
import pandas as pd
path = untar_data(URLs.IMDB_SAMPLE)
df = pd.read_csv(path/'texts.csv')

imdb_clas = DataBlock(
    blocks=(TextBlock.from_df('text', seq_len=72), CategoryBlock),
    get_x=ColReader('text'), get_y=ColReader('label'), splitter=ColSplitter())

dls = imdb_clas.dataloaders(df, bs=64)
dls.show_batch(max_n=2)
label text is_valid
0 negative Un-bleeping-believable! Meg Ryan doesn't even look her usual pert lovable self in this, which normally makes me forgive her shallow ticky acting schtick. Hard to believe she was the producer on this dog. Plus Kevin Kline: what kind of suicide trip has his career been on? Whoosh... Banzai!!! Finally this was directed by the guy who did Big Chill? Must be a replay of Jonestown - hollywood style. Wooofff! False
1 positive This is a extremely well-made film. The acting, script and camera-work are all first-rate. The music is good, too, though it is mostly early in the film, when things are still relatively cheery. There are no really superstars in the cast, though several faces will be familiar. The entire cast does an excellent job with the script.<br /><br />But it is hard to watch, because there is no good end to a situation like the one presented. It is now fashionable to blame the British for setting Hindus and Muslims against each other, and then cruelly separating them into two countries. There is som... False

vocab, is_lm, seq_len, min_freq and max_vocab are passed to the main init; the other arguments are passed to Tokenizer.from_df.

TextBlock.from_folder

TextBlock.from_folder(path, vocab=None, is_lm=False, seq_len=72, backwards=False, min_freq=3, max_vocab=60000, tok=None, rules=None, extensions=None, folders=None, output_dir=None, skip_if_exists=True, output_names=None, n_workers=64, encoding='utf8', **kwargs)

Build a TextBlock from a path

vocab, is_lm, seq_len, min_freq and max_vocab are passed to the main init; the other arguments are passed to Tokenizer.from_folder.

class TextDataLoaders

TextDataLoaders(*loaders, path='.', device=None) :: DataLoaders

Basic wrapper around several DataLoaders with factory methods for NLP problems

You should not use the init directly but one of the following factory methods. All those factory methods accept as arguments:

  • text_vocab: the vocabulary used for numericalizing texts (if not passed, it's inferred from the data)
  • tok_tfm: if passed, uses this tok_tfm instead of the default
  • seq_len: the sequence length used for batch
  • bs: the batch size
  • val_bs: the batch size for the validation DataLoader (defaults to bs)
  • shuffle_train: if we shuffle the training DataLoader or not
  • device: the PyTorch device to use (defaults to default_device())

TextDataLoaders.from_folder

TextDataLoaders.from_folder(path, train='train', valid='valid', valid_pct=None, seed=None, vocab=None, text_vocab=None, is_lm=False, tok_tfm=None, seq_len=72, backwards=False, bs=64, val_bs=None, shuffle_train=True, device=None)

Create from an ImageNet-style dataset in path with train and valid subfolders (or provide valid_pct)

If valid_pct is provided, a random split is performed (with an optional seed) by setting aside that percentage of the data for the validation set, instead of using the grandparent folder to decide the split. If a vocab is passed, only the folders whose names are in vocab are kept.
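Both rules in the previous paragraph can be sketched on a list of file paths: keep only files whose parent folder (the label) is in vocab, then set aside a seeded random valid_pct of them. `split_and_filter` is a hypothetical helper, and the folder layout in the usage example (`imdb/train/<label>/<n>.txt`, with an extra `unsup` folder) is an assumption for illustration.

```python
import random
from pathlib import PurePosixPath
from typing import List, Optional, Sequence, Tuple


def split_and_filter(files: Sequence[str], valid_pct: float = 0.2, seed: Optional[int] = None,
                     vocab: Optional[Sequence[str]] = None) -> Tuple[List[str], List[str]]:
    paths = [PurePosixPath(f) for f in files]
    if vocab is not None:
        # The label is the name of the folder that directly contains the file
        paths = [p for p in paths if p.parent.name in vocab]
    idxs = list(range(len(paths)))
    random.Random(seed).shuffle(idxs)
    # Seeded shuffle, then the first valid_pct of the indices go to validation
    valid_idxs = set(idxs[:round(valid_pct * len(paths))])
    train = [str(p) for i, p in enumerate(paths) if i not in valid_idxs]
    valid = [str(p) for i, p in enumerate(paths) if i in valid_idxs]
    return train, valid


files = [f'imdb/train/{lbl}/{i}.txt' for lbl in ('pos', 'neg', 'unsup') for i in range(10)]
train, valid = split_and_filter(files, valid_pct=0.2, seed=42, vocab=['pos', 'neg'])
print(len(train), len(valid))  # 16 4
```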

Here is an example on the IMDB movie review dataset:

path = untar_data(URLs.IMDB)
dls = TextDataLoaders.from_folder(path)
dls.show_batch(max_n=3)
text category
0 ▁xxbos ▁xxmaj ▁match ▁1: ▁xxmaj ▁tag ▁xxmaj ▁team ▁xxmaj ▁table ▁xxmaj ▁match ▁xxmaj ▁ bub ba ▁xxmaj ▁ray ▁and ▁xxmaj ▁spike ▁xxmaj ▁dudley ▁vs ▁xxmaj ▁eddie ▁xxmaj ▁guerrero ▁and ▁xxmaj ▁chris ▁xxmaj ▁benoit ▁xxmaj ▁ bub ba ▁xxmaj ▁ray ▁and ▁xxmaj ▁spike ▁xxmaj ▁dudley ▁started ▁things ▁off ▁with ▁a ▁xxmaj ▁tag ▁xxmaj ▁team ▁xxmaj ▁table ▁xxmaj ▁match ▁against ▁xxmaj ▁eddie ▁xxmaj ▁guerrero ▁and ▁xxmaj ▁chris ▁xxmaj ▁benoit . ▁xxmaj ▁according ▁to ▁the ▁rules ▁of ▁the ▁match , ▁both ▁opponents ▁have ▁to ▁go ▁through ▁tables ▁in ▁order ▁to ▁get ▁the ▁win . ▁xxmaj ▁benoit ▁and ▁xxmaj ▁guerrero ▁heated ▁up ▁early ▁on ▁by ▁taking ▁turns ▁hammer ing ▁first ▁xxmaj ▁spike ▁and ▁then ▁xxmaj ▁ bub ba ▁xxmaj ▁ray . ▁a ▁xxmaj ▁german ▁su plex ▁by ▁xxmaj ▁benoit ▁to ▁xxmaj ▁ bub ba ▁took ▁the ▁wind ▁out ▁of ▁the ▁xxmaj ▁dudley ▁brother . ▁xxmaj ▁spike ▁tried ▁to ▁help ▁his ▁brother , ▁but ▁the pos
1 xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad neg
2 xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad xxpad pos

TextDataLoaders.from_df

TextDataLoaders.from_df(df, path='.', valid_pct=0.2, seed=None, text_col=0, label_col=1, label_delim=None, y_block=None, text_vocab=None, is_lm=False, valid_col=None, tok_tfm=None, seq_len=72, backwards=False, bs=64, val_bs=None, shuffle_train=True, device=None)

Create from df in path with valid_pct

seed can optionally be passed for reproducibility. text_col, label_col and optionally valid_col are the indices or names of the columns holding the texts, the labels and the validation flag. label_delim can be passed for a multi-label problem if your labels are in one column, separated by a particular character. y_block should be passed to indicate your type of targets in case the library does not infer it properly.
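To make the label_delim case concrete, this sketch splits one label column on a delimiter and builds multi-hot target vectors. `multi_hot` is a hypothetical helper illustrating the idea, not what the library does internally, and the genre labels are made-up example data.

```python
from typing import List, Sequence, Tuple


def multi_hot(labels: Sequence[str], delim: str) -> Tuple[List[str], List[List[int]]]:
    # Split each label string into its individual labels, ignoring empty pieces
    split = [[l for l in s.split(delim) if l] for s in labels]
    classes = sorted({l for ls in split for l in ls})
    idx = {c: i for i, c in enumerate(classes)}
    vecs = []
    for ls in split:
        v = [0] * len(classes)
        for l in ls:
            v[idx[l]] = 1  # mark every label present in this row
        vecs.append(v)
    return classes, vecs


classes, vecs = multi_hot(['drama;comedy', 'comedy', 'horror'], ';')
print(classes)  # ['comedy', 'drama', 'horror']
print(vecs)     # [[1, 1, 0], [1, 0, 0], [0, 0, 1]]
```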

Here are examples on subsets of IMDB:

dls = TextDataLoaders.from_df(df, path=path, text_col='text', label_col='label', valid_col='is_valid')
dls.show_batch(max_n=3)
dls = TextDataLoaders.from_df(df, path=path, text_col='text', is_lm=True, valid_col='is_valid')
dls.show_batch(max_n=3)

TextDataLoaders.from_csv

TextDataLoaders.from_csv(path, csv_fname='labels.csv', header='infer', delimiter=None, valid_pct=0.2, seed=None, text_col=0, label_col=1, label_delim=None, y_block=None, text_vocab=None, is_lm=False, valid_col=None, tok_tfm=None, seq_len=72, backwards=False, bs=64, val_bs=None, shuffle_train=True, device=None)

Create from csv file in path/csv_fname

Opens the csv file with header and delimiter, then passes all the other arguments to TextDataLoaders.from_df.
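As a rough picture of what reading the CSV and honoring valid_col involves, here is a sketch using Python's `csv` module rather than pandas. `read_labelled_csv` is a hypothetical helper; the column names default to the label, text and is_valid columns of the IMDB sample CSV shown earlier, and it assumes the validation flag is stored as the strings True/False.

```python
import csv
import io
from typing import List, TextIO, Tuple


def read_labelled_csv(f: TextIO, text_col: str = 'text', label_col: str = 'label',
                      valid_col: str = 'is_valid', delimiter: str = ','
                      ) -> Tuple[List[Tuple[str, str]], List[Tuple[str, str]]]:
    reader = csv.DictReader(f, delimiter=delimiter)
    train, valid = [], []
    for row in reader:
        item = (row[text_col], row[label_col])
        # Rows whose flag is "True" go to the validation set, all others to training
        if row[valid_col].strip().lower() == 'true':
            valid.append(item)
        else:
            train.append(item)
    return train, valid


data = "label,text,is_valid\nnegative,Bad film,False\npositive,Great film,True\n"
print(read_labelled_csv(io.StringIO(data)))
# ([('Bad film', 'negative')], [('Great film', 'positive')])
```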

dls = TextDataLoaders.from_csv(path=path, csv_fname='texts.csv', text_col='text', label_col='label', valid_col='is_valid')
dls.show_batch(max_n=3)