Basic pytorch functions used in the fastai library
from PIL import Image

Arrays and show

subplots[source]

subplots(nrows=1, ncols=1, figsize=None, imsize=3, add_vert=0, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None, **kwargs)

show_image[source]

show_image(im, ax=None, figsize=None, title=None, ctx=None, cmap=None, norm=None, aspect=None, interpolation=None, alpha=None, vmin=None, vmax=None, origin=None, extent=None, filternorm=True, filterrad=4.0, resample=None, url=None, data=None, **kwargs)

Show a PIL or PyTorch image on ax.

show_image can show PIL images...

im = Image.open(TEST_IMAGE_BW)
ax = show_image(im, cmap="Greys")

...and color images with standard HWC dim order (as returned by np.array on a PIL image)...

im2 = np.array(Image.open(TEST_IMAGE))
ax = show_image(im2, figsize=(2,2))

...and color images with CHW dim order (the PyTorch convention, obtained here with permute(2,0,1))...

im3 = torch.as_tensor(im2).permute(2,0,1)
ax = show_image(im3, figsize=(2,2))

show_titled_image[source]

show_titled_image(o, ax=None, figsize=None, title=None, ctx=None, cmap=None, norm=None, aspect=None, interpolation=None, alpha=None, vmin=None, vmax=None, origin=None, extent=None, filternorm=True, filterrad=4.0, resample=None, url=None, data=None, **kwargs)

Call show_image, destructuring o into (img, title)

show_titled_image((im3,'A puppy'), figsize=(2,2))

show_images[source]

show_images(ims, nrows=1, ncols=None, titles=None, figsize=None, imsize=3, add_vert=0, sharex=False, sharey=False, squeeze=True, subplot_kw=None, gridspec_kw=None)

Show all images ims as subplots with nrows rows, using titles

show_images((im,im3), titles=('number','puppy'), imsize=2)

ArrayImage, ArrayImageBW and ArrayMask are subclasses of ndarray that know how to show themselves.

class ArrayBase[source]

ArrayBase() :: ndarray

An ndarray that can modify casting behavior

class ArrayImageBase[source]

ArrayImageBase() :: ArrayBase

Base class for arrays representing images

class ArrayImage[source]

ArrayImage() :: ArrayImageBase

An array representing an image

class ArrayImageBW[source]

ArrayImageBW() :: ArrayImage

An array representing a black-and-white image

class ArrayMask[source]

ArrayMask() :: ArrayImageBase

An array representing an image mask

im = Image.open(TEST_IMAGE)
im_t = cast(im, ArrayImage)
test_eq(type(im_t), ArrayImage)
ax = im_t.show(figsize=(2,2))
test_fig_exists(ax)

Basics

Tensor.__array_eq__[source]

Tensor.__array_eq__(b)

tensor[source]

tensor(x, *rest, dtype=None, device=None, requires_grad=False, pin_memory=False)

Like torch.as_tensor, but handle lists too, and can pass multiple vector elements directly.

test_eq(tensor(torch.tensor([1,2,3])), torch.tensor([1,2,3]))
test_eq(tensor(array([1,2,3])), torch.tensor([1,2,3]))
test_eq(tensor(1,2,3), torch.tensor([1,2,3]))
test_eq_type(tensor(1.0), torch.tensor(1.0))

set_seed[source]

set_seed(s, reproducible=False)

Set random seed for random, torch, and numpy (where available)

set_seed(2*33)
a1 = np.random.random()
a2 = torch.rand(())
a3 = random.random()
set_seed(2*33)
b1 = np.random.random()
b2 = torch.rand(())
b3 = random.random()
test_eq(a1,b1)
test_eq(a2,b2)
test_eq(a3,b3)

unsqueeze[source]

unsqueeze(x, dim=-1, n=1)

Same as torch.unsqueeze but can add n dims

t = tensor([1])
t2 = unsqueeze(t, n=2)
test_eq(t2,t[:,None,None])

unsqueeze_[source]

unsqueeze_(x, dim=-1, n=1)

Same as torch.unsqueeze_ but can add n dims

t = tensor([1])
unsqueeze_(t, n=2)
test_eq(t, tensor([1]).view(1,1,1))

apply[source]

apply(func, x, *args, **kwargs)

Apply func recursively to x, passing on args

maybe_gather[source]

maybe_gather(x, axis=0)

Gather copies of x on axis (if training is distributed)

to_detach[source]

to_detach(b, cpu=True, gather=True)

Recursively detach lists of tensors in b; put them on the CPU if cpu=True.

gather only applies during distributed training and the result tensor will be the one gathered across processes if gather=True (as a result, the batch size will be multiplied by the number of processes).

to_half[source]

to_half(b)

Recursively map lists of tensors in b to FP16.

to_float[source]

to_float(b)

Recursively map lists of int tensors in b to float.

default_device[source]

default_device(use_cuda=-1)

Return or set default device; use_cuda: None - CUDA if available; True - error if not available; False - CPU

_td = torch.device(torch.cuda.current_device())
test_eq(default_device(None), _td)
test_eq(default_device(True), _td)
test_eq(default_device(False), torch.device('cpu'))
default_device(None);

to_device[source]

to_device(b, device=None)

Recursively put b on device.

t = to_device((3,(tensor(3),tensor(2))))
t1,(t2,t3) = t
test_eq_type(t,(3,(tensor(3).cuda(),tensor(2).cuda())))
test_eq(t2.type(), "torch.cuda.LongTensor")
test_eq(t3.type(), "torch.cuda.LongTensor")

to_cpu[source]

to_cpu(b)

Recursively map lists of tensors in b to the cpu.

t3 = to_cpu(t3)
test_eq(t3.type(), "torch.LongTensor")
test_eq(t3, 2)

to_np[source]

to_np(x)

Convert a tensor to a numpy array.

t3 = to_np(t3)
test_eq(type(t3), np.ndarray)
test_eq(t3, 2)

to_concat[source]

to_concat(xs, dim=0)

Concatenate the elements in xs (recursively if they are tuples/lists of tensors)

test_eq(to_concat([tensor([1,2]), tensor([3,4])]), tensor([1,2,3,4]))
test_eq(to_concat([tensor([[1,2]]), tensor([[3,4]])], dim=1), tensor([[1,2,3,4]]))
test_eq_type(to_concat([(tensor([1,2]), tensor([3,4])), (tensor([3,4]), tensor([5,6]))]), (tensor([1,2,3,4]), tensor([3,4,5,6])))
test_eq_type(to_concat([[tensor([1,2]), tensor([3,4])], [tensor([3,4]), tensor([5,6])]]), [tensor([1,2,3,4]), tensor([3,4,5,6])])
test_eq_type(to_concat([(tensor([1,2]),), (tensor([3,4]),)]), (tensor([1,2,3,4]),))

test_eq(to_concat([tensor([[1,2]]), tensor([[3,4], [5,6]])], dim=1), [tensor([1]),tensor([3, 5]),tensor([4, 6])])
test_eq(type(to_concat([dict(foo=tensor([1,2]), bar=tensor(3,4))])), dict)

Tensor subtypes

Tensor.set_meta[source]

Tensor.set_meta(x, copy_meta=False)

Set all metadata in __dict__

Tensor.get_meta[source]

Tensor.get_meta(n, d=None)

Get n from self._meta if it exists, and return the default d otherwise

Tensor.as_subclass[source]

Tensor.as_subclass(typ)

Cast to typ and include __dict__ and meta

Tensor.set_meta and Tensor.as_subclass work together to maintain _meta after casting.

# Check that `as_subclass` carries `_meta` over to the new subclass and keeps the autograd flag.
class _T(Tensor): pass
t = tensor(1.).requires_grad_()
t._meta = {'img_size': 1}
t2 = t.as_subclass(_T)
test_eq(t._meta, t2._meta)
test_eq(t2.get_meta('img_size'), 1)
# Assert on the `requires_grad` attribute. `requires_grad_` is the in-place setter method;
# a bound method is always truthy, so asserting on it could never fail.
assert t2.requires_grad

class TensorBase[source]

TensorBase(x, **kwargs) :: Tensor

class TensorCategory[source]

TensorCategory(x, **kwargs) :: TensorBase

class TensorMultiCategory[source]

TensorMultiCategory(x, **kwargs) :: TensorCategory

# A throwaway TensorBase subclass, used to check which operations keep the subclass type.
class _T(TensorBase): pass
t = _T(range(5))
test_eq(t[0], 0)
# `gi` indexing (with an int or a slice) is expected to return a `_T`, not a plain Tensor.
test_eq_type(t.gi(0), _T(0))
test_eq_type(t.gi(slice(2)), _T([0,1]))
# Arithmetic is expected to preserve the subclass as well.
test_eq_type(t+1, _T(range(1,6)))
# The repr shows the subclass name rather than `tensor`.
test_eq(repr(t), '_T([0, 1, 2, 3, 4])')

# A pickle round-trip (of a locally created object) must restore the `_T` type.
test_eq(type(pickle.loads(pickle.dumps(t))), _T)
t = tensor([1,2,3])
m = TensorBase([False,True,True])
test_eq(t[m], tensor([2,3]))
t = tensor([[1,2,3],[1,2,3]])
m = cast(tensor([[False,True,True],
                 [False,True,True]]), TensorBase)
test_eq(t[m], tensor([2,3,2,3]))
t = tensor([[1,2,3],[1,2,3]])
t._meta = {'img_size': 1}
t2 = cast(t, TensorBase)
test_eq(t2._meta, t._meta)
x = retain_type(tensor([4,5,6]), t2)
test_eq(x._meta, t._meta)
t3 = TensorBase([[1,2,3],[1,2,3]], img_size=1)
test_eq(t3._meta, t._meta)
t4 = t2+1
t4._meta['img_size'] = 2
test_eq(t2._meta, {'img_size': 1})
test_eq(t4._meta, {'img_size': 2})

class TensorImageBase[source]

TensorImageBase(x, **kwargs) :: TensorBase

class TensorImage[source]

TensorImage(x, **kwargs) :: TensorImageBase

class TensorImageBW[source]

TensorImageBW(x, **kwargs) :: TensorImage

class TensorMask[source]

TensorMask(x, **kwargs) :: TensorImageBase

im = Image.open(TEST_IMAGE)
im_t = cast(array(im), TensorImage)
test_eq(type(im_t), TensorImage)
im_t2 = cast(tensor(1), TensorMask)
test_eq(type(im_t2), TensorMask)
test_eq(im_t2, tensor(1))
ax = im_t.show(figsize=(2,2))
test_fig_exists(ax)
test_eq_type(to_concat([TensorImage([1,2]), TensorImage([3,4])]), TensorImage([1,2,3,4]))

class TitledTensorScalar[source]

TitledTensorScalar(x, **kwargs) :: TensorBase

A tensor containing a scalar that has a show method

L.tensored[source]

L.tensored()

Same as mapped(tensor): convert every element of the L to a tensor.

There are shortcuts for torch.stack and torch.cat if your L contains tensors or something convertible. You can manually convert with tensored.

t = L(([1,2],[3,4]))
test_eq(t.tensored(), [tensor(1,2),tensor(3,4)])

L.stack[source]

L.stack(dim=0)

Same as torch.stack

test_eq(t.stack(), tensor([[1,2],[3,4]]))

L.cat[source]

L.cat(dim=0)

Same as torch.cat

test_eq(t.cat(), tensor([1,2,3,4]))

Chunks

concat[source]

concat(*ls)

Concatenate tensors, arrays, lists, or tuples

a,b,c = [1],[1,2],[1,1,2]
test_eq(concat(a,b), c)
test_eq_type(concat(tuple (a),tuple (b)), tuple (c))
test_eq_type(concat(array (a),array (b)), array (c))
test_eq_type(concat(tensor(a),tensor(b)), tensor(c))
test_eq_type(concat(TensorBase(a),TensorBase(b)), TensorBase(c))
test_eq_type(concat([1,1],1), [1,1,1])
test_eq_type(concat(1,1,1), L(1,1,1))
test_eq_type(concat(L(1,2),1), L(1,2,1))

class Chunks[source]

Chunks(chunks, lens=None)

Slice and int indexing into a list of lists

docs = L(list(string.ascii_lowercase[a:b]) for a,b in ((0,3),(3,7),(7,8),(8,16),(16,24),(24,26)))

b = Chunks(docs)
test_eq([b[ o] for o in range(0,5)], ['a','b','c','d','e'])
test_eq([b[-o] for o in range(1,6)], ['z','y','x','w','v'])
test_eq(b[6:13], 'g,h,i,j,k,l,m'.split(','))
test_eq(b[20:77], 'u,v,w,x,y,z'.split(','))
test_eq(b[:5], 'a,b,c,d,e'.split(','))
test_eq(b[:2], 'a,b'.split(','))
t = torch.arange(26)
docs = L(t[a:b] for a,b in ((0,3),(3,7),(7,8),(8,16),(16,24),(24,26)))
b = Chunks(docs)
test_eq([b[ o] for o in range(0,5)], range(0,5))
test_eq([b[-o] for o in range(1,6)], [25,24,23,22,21])
test_eq(b[6:13], torch.arange(6,13))
test_eq(b[20:77], torch.arange(20,26))
test_eq(b[:5], torch.arange(5))
test_eq(b[:2], torch.arange(2))
docs = L(TensorBase(t[a:b]) for a,b in ((0,3),(3,7),(7,8),(8,16),(16,24),(24,26)))
b = Chunks(docs)
test_eq_type(b[:2], TensorBase(range(2)))
test_eq_type(b[:5], TensorBase(range(5)))
test_eq_type(b[9:13], TensorBase(range(9,13)))

Simple types

show_title[source]

show_title(o, ax=None, ctx=None, label=None, color='black', **kwargs)

Set title of ax to o, or print o if ax is None

test_stdout(lambda: show_title("title"), "title")
# ensure that col names are unique when showing to a pandas series
assert show_title("title", ctx=pd.Series(dict(a=1)), label='a').equals(pd.Series(dict(a=1,a_='title')))

class ShowTitle[source]

ShowTitle()

Base class that adds a simple show

class TitledInt[source]

TitledInt() :: Int

An int with show

class TitledStr[source]

TitledStr() :: Str

An str with show

class TitledFloat[source]

TitledFloat(x=0) :: Float

A float with show

test_stdout(lambda: TitledStr('s').show(), 's')
test_stdout(lambda: TitledInt(1).show(), '1')

class TitledTuple[source]

TitledTuple(x=None, *rest) :: fastuple

A fastuple with show

TitledStr.truncate[source]

TitledStr.truncate(n)

Truncate self to n

Other functions

DataFrame.__init__[source]

DataFrame.__init__(data=None, index=None, columns=None, dtype=None, copy=False)

get_empty_df[source]

get_empty_df(n)

Return n empty rows of a dataframe

display_df[source]

display_df(df)

Display df in a notebook or defaults to print

get_first[source]

get_first(c)

Get the first element of c, even if c is a dataframe

one_param[source]

one_param(m)

First parameter in m

item_find[source]

item_find(x, idx=0)

Recursively takes the idx-th element of x

find_device[source]

find_device(b)

Recursively search the device of b.

t2 = to_device(tensor(0))
dev = default_device()
test_eq(find_device(t2), dev)
test_eq(find_device([t2,t2]), dev)
test_eq(find_device({'a':t2,'b':t2}), dev)
test_eq(find_device({'a':[[t2],[t2]],'b':t2}), dev)

find_bs[source]

find_bs(b)

Recursively search the batch size of b.

x = torch.randn(4,5)
test_eq(find_bs(x), 4)
test_eq(find_bs([x, x]), 4)
test_eq(find_bs({'a':x,'b':x}), 4)
test_eq(find_bs({'a':[[x],[x]],'b':x}), 4)

np_func[source]

np_func(f)

Convert a function taking and returning numpy arrays to one taking and returning tensors

This decorator is particularly useful for using numpy functions as fastai metrics, for instance:

from sklearn.metrics import f1_score

@np_func
def f1(inp, targ):
    "sklearn's F1 score; `np_func` converts the tensor arguments to numpy and the result back to a tensor."
    return f1_score(targ, inp)

a1, a2 = array([0,1,1]), array([1,0,1])
t = f1(tensor(a1), tensor(a2))
test_eq(f1_score(a1, a2), t)
assert isinstance(t, Tensor)

class Module[source]

Module() :: Module

Same as nn.Module, but no need for subclasses to call super().__init__

class _T(Module):
    "A single `nn.Linear(1,1)` layer; fastai's `Module` makes the `super().__init__()` call unnecessary."
    def __init__(self):
        self.f = nn.Linear(1,1)

    def forward(self, x):
        return self.f(x)

t = _T()
t(tensor([1.]))
tensor([-1.0893], grad_fn=<AddBackward0>)

get_model[source]

get_model(model)

Return the model maybe wrapped inside model.

one_hot[source]

one_hot(x, c)

One-hot encode x with c classes.

test_eq(one_hot([1,4], 5), tensor(0,1,0,0,1).byte())
test_eq(one_hot(torch.tensor([]), 5), tensor(0,0,0,0,0).byte())
test_eq(one_hot(2, 5), tensor(0,0,1,0,0).byte())

one_hot_decode[source]

one_hot_decode(x, vocab=None)

test_eq(one_hot_decode(tensor(0,1,0,0,1)), [1,4])
test_eq(one_hot_decode(tensor(0,0,0,0,0)), [   ])
test_eq(one_hot_decode(tensor(0,0,1,0,0)), [2  ])

params[source]

params(m)

Return all parameters of m

trainable_params[source]

trainable_params(m)

Return all trainable parameters of m

m = nn.Linear(4,5)
test_eq(trainable_params(m), [m.weight, m.bias])
m.weight.requires_grad_(False)
test_eq(trainable_params(m), [m.bias])

norm_bias_params[source]

norm_bias_params(m, with_bias=True)

Return all bias and BatchNorm parameters

for norm_func in [nn.BatchNorm1d, partial(nn.InstanceNorm1d, affine=True)]:
    model = nn.Sequential(nn.Linear(10,20), norm_func(20), nn.Conv1d(3,4, 3))
    test_eq(norm_bias_params(model), [model[0].bias, model[1].weight, model[1].bias, model[2].bias])
    model = nn.ModuleList([nn.Linear(10,20, bias=False), nn.Sequential(norm_func(20), nn.Conv1d(3,4,3))])
    test_eq(norm_bias_params(model), [model[1][0].weight, model[1][0].bias, model[1][1].bias])
    model = nn.ModuleList([nn.Linear(10,20), nn.Sequential(norm_func(20), nn.Conv1d(3,4,3))])
    test_eq(norm_bias_params(model, with_bias=False), [model[1][0].weight, model[1][0].bias])

batch_to_samples[source]

batch_to_samples(b, max_n=10)

'Transposes' a batch to (at most max_n) samples

t = tensor([1,2,3])
test_eq(batch_to_samples([t,t+1], max_n=2), ([1,2],[2,3]))
test_eq(batch_to_samples(tensor([1,2,3]), 10), [1, 2, 3])
test_eq(batch_to_samples([tensor([1,2,3]), tensor([4,5,6])], 10), [(1, 4), (2, 5), (3, 6)])
test_eq(batch_to_samples([tensor([1,2,3]), tensor([4,5,6])], 2), [(1, 4), (2, 5)])
test_eq(batch_to_samples([tensor([1,2,3]), [tensor([4,5,6]),tensor([7,8,9])]], 10), 
        [(1, (4, 7)), (2, (5, 8)), (3, (6, 9))])
test_eq(batch_to_samples([tensor([1,2,3]), [tensor([4,5,6]),tensor([7,8,9])]], 2), [(1, (4, 7)), (2, (5, 8))])

t = fastuple(tensor([1,2,3]),TensorBase([2,3,4]))
test_eq_type(batch_to_samples(t)[0][1], TensorBase(2))
test_eq(batch_to_samples(t).map(type), [fastuple]*3)

Tensor.interp_1d[source]

Tensor.interp_1d(x:Tensor, xp, fp)

Same as np.interp

brks = tensor(0,1,2,4,8,64).float()
ys = tensor(range_of(brks)).float()
ys /= ys[-1].item()
pts = tensor(0.2,0.5,0.8,3,5,63)

preds = pts.interp_1d(brks, ys)
test_close(preds.numpy(), np.interp(pts.numpy(), brks.numpy(), ys.numpy()))

plt.scatter(brks,ys)
plt.scatter(pts,preds)
plt.legend(['breaks','preds']);

Tensor.pca[source]

Tensor.pca(x:Tensor, k=2)

Compute PCA of x with k dimensions.

logit[source]

logit(x)

Logit of x, clamped to avoid inf.

num_distrib[source]

num_distrib()

Return the number of processes in distributed training (if applicable).

rank_distrib[source]

rank_distrib()

Return the distributed rank of this process (if applicable).

distrib_barrier[source]

distrib_barrier()

Place a synchronization barrier in distributed training so that ALL sub-processes in the pytorch process group must arrive here before proceeding.

Path.save_array[source]

Path.save_array(p:Path, o, complib='lz4', lvl=3)

Save numpy array to a compressed pytables file, using compression level lvl

Compression lib can be any of: blosclz, lz4, lz4hc, snappy, zlib or zstd.

Path.load_array[source]

Path.load_array(p:Path)

Load a numpy array from a pytables file

inspect.getdoc(load_array)
'Save numpy array to a `pytables` file'
str(inspect.signature(load_array))
'(p: pathlib.Path)'

base_doc[source]

base_doc(elt)

Print a base documentation of elt

doc[source]

doc(elt)

Try to use doc from nbdev and fall back to base_doc

nested_reorder[source]

nested_reorder(t, idxs)

Reorder all tensors in t using idxs

x = tensor([0,1,2,3,4,5])
idxs = tensor([2,5,1,0,3,4])
test_eq_type(nested_reorder(([x], x), idxs), ([idxs], idxs))

y = L(0,1,2,3,4,5)
z = L(i.item() for i in idxs)
test_eq_type(nested_reorder((y, x), idxs), (z,idxs))

Image helpers

make_cross_image[source]

make_cross_image(bw=True)

Create a tensor containing a cross image, either bw (True) or color

plt.imshow(make_cross_image(), cmap="Greys");
plt.imshow(make_cross_image(False).permute(1,2,0));

show_image_batch[source]

show_image_batch(b, show=show_titled_image, items=9, cols=3, figsize=None, **kwargs)

Display batch b in a grid of size items with cols width

show_image_batch(([Image.open(TEST_IMAGE_BW),Image.open(TEST_IMAGE)],['bw','color']), items=2)

Model init

requires_grad[source]

requires_grad(m)

Check if the first parameter of m requires grad or not

tst = nn.Linear(4,5)
assert requires_grad(tst)
for p in tst.parameters(): p.requires_grad_(False)
assert not requires_grad(tst)

init_default[source]

init_default(m, func=kaiming_normal_)

Initialize m weights with func and set bias to 0.

tst = nn.Linear(4,5)
tst.weight.data.uniform_(-1,1)
tst.bias.data.uniform_(-1,1)
tst = init_default(tst, func = lambda x: x.data.fill_(1.))
test_eq(tst.weight, torch.ones(5,4))
test_eq(tst.bias, torch.zeros(5))

cond_init[source]

cond_init(m, func)

Apply init_default to m unless it's a batchnorm module

tst = nn.Linear(4,5)
tst.weight.data.uniform_(-1,1)
tst.bias.data.uniform_(-1,1)
cond_init(tst, func = lambda x: x.data.fill_(1.))
test_eq(tst.weight, torch.ones(5,4))
test_eq(tst.bias, torch.zeros(5))

tst = nn.BatchNorm2d(5)
init = [tst.weight.clone(), tst.bias.clone()]
cond_init(tst, func = lambda x: x.data.fill_(1.))
test_eq(tst.weight, init[0])
test_eq(tst.bias, init[1])

apply_leaf[source]

apply_leaf(m, f)

Apply f recursively to the children of m, including nested submodules.

tst = nn.Sequential(nn.Linear(4,5), nn.Sequential(nn.Linear(4,5), nn.Linear(4,5)))
apply_leaf(tst, partial(init_default, func=lambda x: x.data.fill_(1.)))
for l in [tst[0], *tst[1]]: test_eq(l.weight, torch.ones(5,4))
for l in [tst[0], *tst[1]]: test_eq(l.bias,   torch.zeros(5))

apply_init[source]

apply_init(m, func=kaiming_normal_)

Initialize all non-batchnorm layers of m with func.

tst = nn.Sequential(nn.Linear(4,5), nn.Sequential(nn.Linear(4,5), nn.BatchNorm1d(5)))
init = [tst[1][1].weight.clone(), tst[1][1].bias.clone()]
apply_init(tst, func=lambda x: x.data.fill_(1.))
for l in [tst[0], tst[1][0]]: test_eq(l.weight, torch.ones(5,4))
for l in [tst[0], tst[1][0]]: test_eq(l.bias,   torch.zeros(5))
test_eq(tst[1][1].weight, init[0])
test_eq(tst[1][1].bias,   init[1])

autograd jit functions

script_use_ctx[source]

script_use_ctx(f)

Decorator: create jit script and pass everything in ctx.saved_variables to f, after *args

script_save_ctx[source]

script_save_ctx(static, *argidx)

Decorator: create jit script and save args with indices argidx using ctx.save_for_backward

script_fwd[source]

script_fwd(*argidx)

Decorator: create static jit script and save args with indices argidx using ctx.save_for_backward

script_bwd[source]

script_bwd(f)

Decorator: create static jit script and pass everything in ctx.saved_variables to f, after *args

grad_module[source]

grad_module()

Decorator: convert cls into an autograd function