from fastai.test_utils import *
Hooks are functions you can attach to a particular layer in your model and that will be executed in the forward pass (for forward hooks) or backward pass (for backward hooks). Here we begin with an introduction to hooks, but you should jump to `HookCallback` if you quickly want to implement one (and read the following example, `ActivationStats`).
Forward hooks are functions that take three arguments: the layer it's applied to, the input of that layer and the output of that layer.
# Minimal forward-hook demo on a single linear layer.
tst_model = nn.Linear(5,3)
# A forward hook receives the module, its input and its output; this one only prints them.
def example_forward_hook(m,i,o): print(m,i,o)
x = torch.randn(4,5)
# Registering returns a handle; it is used below to detach the hook again.
hook = tst_model.register_forward_hook(example_forward_hook)
# The hook is registered at this point, so this forward pass triggers the print.
y = tst_model(x)
hook.remove()
Backward hooks are functions that take three arguments: the layer it's applied to, the gradients of the loss with respect to the input, and the gradients with respect to the output.
# A backward hook receives the module and two gradient arguments (named `gi` and `go` here); this one only prints them.
def example_backward_hook(m,gi,go): print(m,gi,go)
# `tst_model` is the layer created by the preceding forward-hook snippet.
# NOTE(review): recent PyTorch releases deprecate `register_backward_hook` in favor of
# `register_full_backward_hook` — confirm against the PyTorch version this file targets.
hook = tst_model.register_backward_hook(example_backward_hook)
x = torch.randn(4,5)
y = tst_model(x)
# Reduce the output to a scalar so that `backward()` can be called on it.
loss = y.pow(2).mean()
# The hook is registered at this point, so this backward pass triggers the print.
loss.backward()
hook.remove()
Hooks can change the input/output of a layer, or the gradients, print values or shapes. If you want to store something related to these inputs/outputs, it's best to have your hook associated with a class so that it can put it in the state of an instance of that class.
This will be called during the forward pass if `is_forward=True`, the backward pass otherwise, and will optionally `detach`, `gather` and put on the `cpu` the (gradient of the) input/output of the model before passing them to `hook_func`. The result of `hook_func` will be stored in the `stored` attribute of the `Hook`.
tst_model = nn.Linear(5,3)
# The hook function returns the layer output; the check below expects to find it in `hook.stored`.
hook = Hook(tst_model, lambda m,i,o: o)
# `x` is not defined in this snippet: it is the tensor left over from an earlier one.
y = tst_model(x)
test_eq(hook.stored, y)
tst_model = nn.Linear(5,10)
x = torch.randn(4,5)
# Computed before the hook is attached, so this forward pass prints nothing.
y = tst_model(x)
hook = Hook(tst_model, example_forward_hook)
# With the hook attached, a forward pass should print the module, the input tuple and the detached output.
test_stdout(lambda: tst_model(x), f"{tst_model} ({x},) {y.detach()}")
hook.remove()
# After `remove()` the same forward pass should print nothing.
test_stdout(lambda: tst_model(x), "")
tst_model = nn.Linear(5, 10)
x = torch.randn(4, 5)
# Reference output, computed while no hook is attached (so nothing is printed here).
y = tst_model(x)

# Used as a context manager, the hook is only active inside the `with` body ...
with Hook(tst_model, example_forward_hook) as h:
    test_stdout(
        lambda: tst_model(x),
        f"{tst_model} ({x},) {y.detach()}",
    )

# ... and the forward pass is silent again once the block has been left.
test_stdout(lambda: tst_model(x), "")
The activations stored are the gradients if `grad=True`, otherwise the output of `module`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU.
tst_model = nn.Linear(5, 10)
x = torch.randn(4, 5)

# Default call: the stored value is the module output, without gradient tracking.
with hook_output(tst_model) as h:
    y = tst_model(x)
    test_eq(y, h.stored)
    assert not h.stored.requires_grad

# `grad=True`: the stored value is the gradient captured during the backward pass.
with hook_output(tst_model, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    # d(mean(y**2))/dy = 2*y / number of elements of y
    test_close(2*y / y.numel(), h.stored[0])
# This check needs a GPU: skip it on CPU-only machines instead of failing inside `.cuda()`.
if torch.cuda.is_available():
    with hook_output(tst_model, cpu=True) as h:
        # `.cuda()` moves `tst_model` to the GPU, so the forward pass runs there ...
        y = tst_model.cuda()(x.cuda())
        # ... while `cpu=True` requires what the hook stored to live on the CPU.
        test_eq(h.stored.device, torch.device('cpu'))
layers = [nn.Linear(5,10), nn.ReLU(), nn.Linear(10,3)]
tst_model = nn.Sequential(*layers)
# The hook function returns each layer's output.
hooks = Hooks(tst_model, lambda m,i,o: o)
# `x` is not defined in this snippet: it is the tensor left over from an earlier one.
y = tst_model(x)
# The checks below expect one stored entry per layer, in the order the layers are applied.
test_eq(hooks.stored[0], layers[0](x))
test_eq(hooks.stored[1], F.relu(layers[0](x)))
test_eq(hooks.stored[2], y)
hooks.remove()
layers = [nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 3)]
tst_model = nn.Sequential(*layers)

# Same checks as with an explicit `remove()`, but here `Hooks` is given the list of
# layers and is used as a context manager.
with Hooks(layers, lambda m,i,o: o) as h:
    y = tst_model(x)
    test_eq(h.stored[0], layers[0](x))
    test_eq(h.stored[1], F.relu(layers[0](x)))
    test_eq(h.stored[2], y)
The activations stored are the gradients if `grad=True`, otherwise the output of `modules`. If `detach=True` they are detached from their history, and if `cpu=True`, they're put on the CPU.
layers = [nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 3)]
tst_model = nn.Sequential(*layers)
x = torch.randn(4, 5)

# Default call: one stored output per layer, none of them tracking gradients.
with hook_outputs(layers) as h:
    y = tst_model(x)
    test_eq(h.stored[0], layers[0](x))
    test_eq(h.stored[1], F.relu(layers[0](x)))
    test_eq(h.stored[2], y)
    for s in h.stored:
        assert not s.requires_grad

# `grad=True`: compare what was stored against the gradient worked out by hand,
# walking back from the last layer to the first.
with hook_outputs(layers, grad=True) as h:
    y = tst_model(x)
    loss = y.pow(2).mean()
    loss.backward()
    # Gradient of mean(y**2) with respect to y.
    g = 2*y / y.numel()
    test_close(g, h.stored[2][0])
    # Back through the last linear layer: multiply by its weight matrix.
    g = g @ layers[2].weight.data
    test_close(g, h.stored[1][0])
    # Back through the ReLU: the gradient only passes where its input was positive.
    g = g * (layers[0](x) > 0).float()
    test_close(g, h.stored[0][0])
# This check needs a GPU: skip it on CPU-only machines instead of failing inside `.cuda()`.
if torch.cuda.is_available():
    with hook_outputs(tst_model, cpu=True) as h:
        # `.cuda()` moves `tst_model` to the GPU, so the forward pass runs there ...
        y = tst_model.cuda()(x.cuda())
        # ... while `cpu=True` requires every stored activation to live on the CPU.
        for s in h.stored:
            test_eq(s.device, torch.device('cpu'))
# `model_sizes`: one size per layer; in the expected values the stride-2 layer
# halves the spatial dimensions (64 -> 32).
m = nn.Sequential(
    ConvLayer(3, 16),
    ConvLayer(16, 32, stride=2),
    ConvLayer(32, 32),
)
test_eq(model_sizes(m), [[1, 16, 64, 64], [1, 32, 32, 32], [1, 32, 32, 32]])

# `num_features_model`: expected to match the channel count of the last layer.
m = nn.Sequential(nn.Conv2d(5, 4, 3), nn.Conv2d(4, 3, 3))
test_eq(num_features_model(m), 3)

m = nn.Sequential(
    ConvLayer(3, 16),
    ConvLayer(16, 32, stride=2),
    ConvLayer(32, 32),
)
test_eq(num_features_model(m), 32)
To make hooks easy to use, we wrapped a version in a Callback where you just have to implement a `hook` function (plus any element you might need).
# Layers with parameters are expected to be reported as such ...
assert has_params(nn.Linear(3, 4))
assert has_params(nn.LSTM(4, 5, 2))
# ... while a parameter-free activation is not.
assert not has_params(nn.ReLU())
You can either subclass and implement a `hook` function (along with any event you want) or pass a `hook` function when initializing. Such a function needs to take three arguments: a layer, input and output (for a backward hook, input means gradient with respect to the inputs, output, gradient with respect to the output) and can either modify them or update the state according to them.
If not provided, `modules` will default to the layers of `self.model` that have a `weight` attribute. Depending on `do_remove`, the hooks will be properly removed at the end of training (or in case of error). `is_forward`, `detach` and `cpu` are passed to `Hooks`.
The function called at each forward (or backward) pass is `self.hook` and must be implemented when subclassing this callback.
class TstCallback(HookCallback):
    "Forward-hook check: the first stored hook result must equal the prediction of the batch."

    def hook(self, m, i, o):
        # Keep the module output unchanged as the stored value.
        return o

    def after_batch(self):
        test_eq(self.hooks.stored[0], self.pred)


learn = synth_learner(n_trn=5, cbs=TstCallback())
learn.fit(1)
class TstCallback(HookCallback):
    "Backward-hook check: the first stored gradient must equal the hand-computed gradient on training batches."

    def __init__(self, modules=None, remove_end=True, detach=True, cpu=False):
        # The literal `None` and `False` fill the positional slots around `remove_end`
        # (presumably `every` and `is_forward` — confirm against `HookCallback.__init__`).
        super().__init__(modules, None, remove_end, False, detach, cpu)

    def hook(self, m, i, o):
        return o

    def after_batch(self):
        # Only checked on training batches.
        if not self.training:
            return
        expected_grad = 2*(self.pred-self.y)/self.pred.shape[0]
        test_eq(self.hooks.stored[0][0], expected_grad)


learn = synth_learner(n_trn=5, cbs=TstCallback())
learn.fit(1)
# Each expected value is a pair: (number of parameters, flag).
# Linear: weights (+ bias unless disabled).
test_eq(total_params(nn.Linear(10, 32)), (32*10+32, True))
test_eq(total_params(nn.Linear(10, 32, bias=False)), (32*10, True))

# BatchNorm: two values per feature when affine, no parameters otherwise.
test_eq(total_params(nn.BatchNorm2d(20)), (20*2, True))
test_eq(total_params(nn.BatchNorm2d(20, affine=False)), (0, False))

# Conv2d: in_channels * out_channels * kernel height * kernel width (+ bias).
test_eq(total_params(nn.Conv2d(16, 32, 3)), (16*32*3*3 + 32, True))
test_eq(total_params(nn.Conv2d(16, 32, 3, bias=False)), (16*32*3*3, True))

# LSTM: the first input-to-hidden block maps 20 -> 10, the three other blocks map
# 10 -> 10; the factor 4 accounts for the four gates.
test_eq(
    total_params(nn.LSTM(20, 10, 2)),
    (4 * (20*10 + 10) + 3 * 4 * (10*10 + 10), True),
)
def _m():
    "Build a fresh Linear -> ReLU -> BatchNorm1d -> Linear model (1 input feature, 1 output)."
    blocks = (nn.Linear(1,50), nn.ReLU(), nn.BatchNorm1d(50), nn.Linear(50, 1))
    return nn.Sequential(*blocks)
sample_input = torch.randn((16, 1))
# Expected rows are (layer class name, parameter count, flag, output size), one per layer of `_m()`.
expected_info = [
    ('Linear', 100, True, [1, 50]),
    ('ReLU', 0, False, [1, 50]),
    ('BatchNorm1d', 100, True, [1, 50]),
    ('Linear', 51, True, [1, 1]),
]
test_eq(layer_info(synth_learner(model=_m()), sample_input), expected_info)
# Build a learner around a fresh `_m()` model and call its `summary()` method.
learn = synth_learner(model=_m())
learn.summary()
This is an example of a `HookCallback` that stores the means, stds and histograms of activations that go through the network.
@delegates()
class ActivationStats(HookCallback):
    "Callback that records the mean, std and near-zero fraction (optionally a histogram) of activations."
    run_before=TrainEvalCallback

    def __init__(self, with_hist=False, **kwargs):
        # Everything except `with_hist` is forwarded untouched to `HookCallback`.
        super().__init__(**kwargs)
        self.with_hist = with_hist

    def before_fit(self):
        "Start the fit with an empty `self.stats`."
        super().before_fit()
        self.stats = L()

    def hook(self, m, i, o):
        "Summarise the output `o` as a dict with 'mean', 'std', 'near_zero' and, if `with_hist`, 'hist'."
        acts = o.float()
        # Fraction of activations that are <= 0.05.
        near_zero = (acts<=0.05).long().sum().item()/acts.numel()
        summary = {'mean': acts.mean().item(), 'std': acts.std().item(), 'near_zero': near_zero}
        if self.with_hist:
            # 40 bins over the range [0, 10].
            summary['hist'] = acts.histc(40,0,10)
        return summary

    def after_batch(self):
        "Append the stored hook results to `self.stats` on the training batches selected by `every`."
        if self.training:
            if self.every is None or self.train_iter%self.every == 0:
                self.stats.append(self.hooks.stored)
        # The parent's `after_batch` runs whether or not this batch was recorded.
        super().after_batch()

    def layer_stats(self, idx):
        "Recorded 'mean', 'std' and 'near_zero' values of layer `idx`, one list per statistic."
        per_batch = self.stats.itemgot(idx)
        return L(per_batch.itemgot(name) for name in ('mean','std','near_zero'))

    def hist(self, idx):
        "Histograms recorded for layer `idx`, stacked, transposed and passed through `log1p`."
        hists = self.stats.itemgot(idx).itemgot('hist')
        stacked = torch.stack(tuple(hists))
        # After the transpose, each column holds the histogram of one recorded batch.
        return stacked.t().float().log1p()

    def color_dim(self, idx, figsize=(10,5), ax=None):
        "The 'colorful dimension' plot"
        img = self.hist(idx)
        if ax is None:
            # No axis supplied: create a figure and draw on its first axis.
            ax = subplots(figsize=figsize)[1][0]
        ax.imshow(img, origin='lower')
        ax.axis('off')

    def plot_layer_stats(self, idx):
        "Plot the three statistics of layer `idx` side by side."
        titles = ('mean','std','% near zero')
        _,axs = subplots(1, 3, figsize=(12,3))
        for values,ax,title in zip(self.layer_stats(idx),axs,titles):
            ax.plot(values)
            ax.set_title(title)
# Attach the callback with `every=4` and fit for one epoch.
learn = synth_learner(n_trn=5, cbs = ActivationStats(every=4))
learn.fit(1)
# Bare expression: in a notebook this displays the collected stats.
learn.activation_stats.stats
The first line contains the means of the outputs of the model for each batch in the training set, the second line their standard deviations.
import math
def test_every(n_tr, every):
    "Fit a learner built with `n_trn=n_tr` and check how many stats `ActivationStats(every=every)` collected"
    stats_cb = ActivationStats(every=every)
    lrn = synth_learner(n_trn=n_tr, cbs=stats_cb)
    lrn.fit(1)
    # The test expects ceil(n_tr / every) recorded entries.
    n_expected = math.ceil(n_tr / every)
    test_eq(n_expected, len(lrn.activation_stats.stats))


# Values of `n_tr` just below, at, and just above a multiple of 4, each also checked with every=1.
for n_tr in [11, 12, 13]:
    test_every(n_tr, 4)
    test_every(n_tr, 1)