本文共 6989 字,大约阅读时间需要 23 分钟。
Hook 函数是在不改变主体的情况下,实现额外功能。由于 PyTorch 是基于动态图实现的,因此在一次迭代运算结束后,一些中间变量如非叶子节点的梯度和特征图,会被释放掉。在这种情况下想要提取和记录这些中间变量,就需要使用 Hook 函数。
PyTorch 提供了 4 种 Hook 函数。
1 Hook 函数
1-1 torch.Tensor.register_hook(hook)
功能: 注册一个反向传播 hook 函数,仅输入一个参数,为张量的梯度。
hook函数:hook(grad)
参数:
grad:张量的梯度import torch# x,y 为leaf节点,也就是说,在计算的时候,PyTorch只会保留此节点的梯度值x = torch.tensor([3.], requires_grad=True)y = torch.tensor([5.], requires_grad=True)# a,b均为中间值,在计算梯度时,此部分会被释放掉a = x + yb = x * yc = a * b# 新建列表,用于存储Hook函数保存的中间梯度值a_grad = []def hook_grad(grad): a_grad.append(grad)# register_hook的参数为一个函数handle = a.register_hook(hook_grad)c.backward()# 只有leaf节点才会有梯度值print('gradient:',x.grad, y.grad, a.grad, b.grad, c.grad)# Hook函数保留下来的中间节点a的梯度print('a_grad:', a_grad[0])# 移除Hook函数handle.remove()
1-2 torch.nn.Module.register_forward_hook(hook)
功能: 注册 module 的前向传播hook函数,可用于获取中间的 feature map。
hook函数:hook(module, input, output)
参数:
module:当前网络层 input:当前网络层输入数据 output:当前网络层输出数据import torchimport torch.nn as nn# 构建网网络,一个卷积层一个池化层class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x# 初始化网络net = Net()# detach将张量分离net.conv1.weight[0].detach().fill_(1)net.conv1.weight[1].detach().fill_(2)net.conv1.bias.detach().zero_()# 构建两个列表用于保存信息fmap_block = []input_block = []def forward_hook(module, data_input, data_output): fmap_block.append(data_output) input_block.append(data_input)# 注册Hooknet.conv1.register_forward_hook(forward_hook)# 输入数据fake_img = torch.ones((1, 1, 4, 4))output = net(fake_img)# 观察结果# 卷积神经网络输出维度和结果print("output share:{}\noutput value:{}\n".format(output.size(),output))# 卷积神经网络Hook函数返回的结果print("feature map share:{}\noutput value:{}\n".format(fmap_block[0].shape,fmap_block[0]))# 输入的信息print("input share:{}\ninput value:{}\n".format(input_block[0][0].size(),input_block[0][0]))
1-3 torch.Tensor.register_forward_pre_hook()
功能: 注册 module 的前向传播前的hook函数,可用于获取输入数据。
hook函数:hook(module, input)
参数:
module:当前网络层 input:当前网络层输入数据1-4 torch.Tensor.register_backward_hook()
功能: 注册 module 的反向传播的hook函数,可用于获取梯度。
hook函数:hook(module, grad_input, grad_output)
参数:
module:当前网络层 input:当前网络层输入的梯度数据 output:当前网络层输出的梯度数据class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x def forward_hook(module, data_input, data_output): fmap_block.append(data_output) input_block.append(data_input) def forward_pre_hook(module, data_input): print("forward_pre_hook input:{}".format(data_input)) def backward_hook(module, grad_input, grad_output): print("backward hook input:{}".format(grad_input)) print("backward hook output:{}".format(grad_output)) # 初始化网络 net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 注册hook fmap_block = list() input_block = list() net.conv1.register_forward_hook(forward_hook) net.conv1.register_forward_pre_hook(forward_pre_hook) net.conv1.register_backward_hook(backward_hook) # inference fake_img = torch.ones((1, 1, 4, 4)) # batch size * channel * H * W output = net(fake_img) loss_fnc = nn.L1Loss() target = torch.randn_like(output) loss = loss_fnc(target, output) loss.backward()
2 hook函数实现机制
hook函数实现的原理是在module的__call()__函数进行拦截,__call()__函数可以分为 4 个部分:
第 1 部分是实现 _forward_pre_hooks 第 2 部分是实现 forward 前向传播 第 3 部分是实现 _forward_hooks 第 4 部分是实现 _backward_hooks 由于卷积层也是一个module,因此可以记录_forward_hooks。def __call__(self, *input, **kwargs): # 第 1 部分是实现 _forward_pre_hooks for hook in self._forward_pre_hooks.values(): result = hook(self, input) if result is not None: if not isinstance(result, tuple): result = (result,) input = result # 第 2 部分是实现 forward 前向传播 if torch._C._get_tracing_state(): result = self._slow_forward(*input, **kwargs) else: result = self.forward(*input, **kwargs) # 第 3 部分是实现 _forward_hooks for hook in self._forward_hooks.values(): hook_result = hook(self, input, result) if hook_result is not None: result = hook_result # 第 4 部分是实现 _backward_hooks if len(self._backward_hooks) > 0: var = result while not isinstance(var, torch.Tensor): if isinstance(var, dict): var = next((v for v in var.values() if isinstance(v, torch.Tensor))) else: var = var[0] grad_fn = var.grad_fn if grad_fn is not None: for hook in self._backward_hooks.values(): wrapper = functools.partial(hook, self) functools.update_wrapper(wrapper, hook) grad_fn.register_hook(wrapper) return result
3 示例
3-1 Hook 函数提取网络的特征图
下面通过hook函数获取 AlexNet 每个卷积层的所有卷积核参数,以形状作为 key,value 对应该层多个卷积核的 list。然后取出每层的第一个卷积核,形状是 [1, in_channle, h, w],转换为 [in_channle, 1, h, w],使用 TensorBoard 进行可视化,代码如下:
writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix") # 数据 path_img = "imgs/lena.png" # your path to image normMean = [0.49139968, 0.48215827, 0.44653124] normStd = [0.24703233, 0.24348505, 0.26158768] norm_transform = transforms.Normalize(normMean, normStd) img_transforms = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), norm_transform ]) img_pil = Image.open(path_img).convert('RGB') if img_transforms is not None: img_tensor = img_transforms(img_pil) img_tensor.unsqueeze_(0) # chw --> bchw # 模型 alexnet = models.alexnet(pretrained=True) # 注册hook fmap_dict = dict() for name, sub_module in alexnet.named_modules(): if isinstance(sub_module, nn.Conv2d): key_name = str(sub_module.weight.shape) fmap_dict.setdefault(key_name, list()) # 由于AlexNet 使用 nn.Sequantial 包装,所以 name 的形式是:features.0 features.1 n1, n2 = name.split(".") def hook_func(m, i, o): key_name = str(m.weight.shape) fmap_dict[key_name].append(o) alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func) # forward output = alexnet(img_tensor) # add image for layer_name, fmap_list in fmap_dict.items(): fmap = fmap_list[0]# 取出第一个卷积核的参数 fmap.transpose_(0, 1) # 把 BCHW 转换为 CBHW nrow = int(np.sqrt(fmap.shape[0])) fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow) writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)
转载地址:https://liumin.blog.csdn.net/article/details/117290328 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!