深度学习(可视化卷积核)-牛翰网

深度学习(可视化卷积核)

     

可视化卷积核参数对理解卷积神经网络的工作原理、优化模型性能、提高模型泛化能力有一定帮助作用。

下面以resnet18为例,可视化了部分卷积核参数。

import torchvision
from matplotlib import pyplot as plt
import torch

model = torchvision.models.resnet18(pretrained=True)
#model = torchvision.models.efficientnet_b0(pretrained=True)

num = 1
# 遍历模型的每一层
for name, module in model.named_modules():
    # 判断是否为卷积层
    if isinstance(module, torch.nn.Conv2d):
        # 输出卷积层名称和权重
        print(f"layer {name} : {module.weight.data.shape}")
        _,_,H,W = module.weight.data.shape
        if H >=3 and W >=3:
            plt.subplot(5,4,num)
            data = module.weight.data.numpy()
            plt.imshow(data[0,0,:,:])  #太多了,只显示一个卷积核
            num+=1

plt.show()            

 结果如下:

请登录后发表评论

    没有回复内容