Pytorch框架_经过CNN后的维度如何查看_闭关修炼系列_CodingPark编程公园
发布日期:2021-06-29 15:49:33
浏览次数:2
分类:技术文章
本文共 3545 字,大约阅读时间需要 11 分钟。
CNN模型
import torchfrom torch import nnfrom torch.nn import functional as F'''step2. Net Creat 网络创建'''class Lenet5(nn.Module): def __init__(self): super(Lenet5, self).__init__() self.conv_unit = nn.Sequential( # Sequential不需要给每个层编号,同样也就没有赋值 ,因为总的来说它时一个整体; # x: [b, 3, 32, 32]=> [b, 16, size] nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=0), # in_channels=3代表 3个通道in 其实就是RGB这三个,out_channels =16 nn.MaxPool2d(kernel_size=2, stride=2, padding=0), # kernel_size=2 kernel一次看一个长宽各2的窗口 nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=0), nn.MaxPool2d(kernel_size=2, stride=2, padding=0), ) # flatten打平 从32*5*5 -> 10 # 全连接层 self.fc_unit = nn.Sequential( nn.Linear(32*5*5, 120), nn.ReLU(), nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10) ) # [b, 3, 32, 32] 测试1,看conv_unit后的out_channels数 # tmp = torch.randn(2, 3, 32, 32) # out = self.conv_unit(tmp) # # [b, 16, 5, 5] # print('conv out: ',out.shape) # use Cross Entropy Loss # self.criteon = nn.MSELoss() # 处理逼近问题 # self.criteon = nn.CrossEntropyLoss() # 处理分类问题 def forward(self, x): # 前向路径 #param(参数) x:[b, 3, 32, 32] batchsz = x.size(0) #param(参数) x:[b, 3, 32, 32] 其中的 b # [b, 3, 32, 32] => [b, 16, 5, 5] x = self.conv_unit(x) # flatten打平 从16*5*5 -> 10 x = x.view(batchsz, 32*5*5) # [b, 16*5*5] => [b, 10] logits = self.fc_unit(x) # logits: 网络最后一般送入softmax 那么 softmax前的变量 统称 logits return logits# def main():# net = Lenet5()# tmp = torch.randn(2, 3, 32, 32)# out = net(tmp)# # [b, 16, 5, 5]# print('lenet5 out: ', out.shape)### if __name__ == '__main__':# main()
查看经过CNN后的维度
[b, 3, 32, 32]
测试 : 为了看conv_unit后的out_channels数, 其中b为batch_size
结果展示
我自己从0又写了一遍
这次我写的是Minist手写体识别
方法1
仅通过__init__看结果
__ init __嘛~,你用那个类,它自动就会走 __ init __
import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torchvision import datasets, transformsimport torch.utils.dataclass TeacherModel(nn.Module): def __init__(self): super(TeacherModel, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout2d(0.3) self.dropout2 = nn.Dropout2d(0.5) tmp = torch.randn(2, 1, 28, 28) out = self.conv1(tmp) out = self.conv2(out) out = F.max_pool2d(out, 2) out = torch.flatten(out, 1) print('out->', out.shape)net = TeacherModel()
方法2
通过__init__ 与 def forward(self, x)看结果
import torchimport torch.nn as nnimport torch.nn.functional as Ffrom torchvision import datasets, transformsimport torch.utils.dataclass TeacherModel(nn.Module): def __init__(self): super(TeacherModel, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv2 = nn.Conv2d(32, 64, 3, 1) self.dropout1 = nn.Dropout2d(0.3) self.dropout2 = nn.Dropout2d(0.5) def forward(self, x): batchsz = x.size(0) x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1) return xtmp = torch.randn(2, 1, 28, 28)net = TeacherModel()out = net(tmp)print('out:', out.shape)
转载地址:https://codingpark.blog.csdn.net/article/details/113425915 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
不错!
[***.144.177.141]2024年04月03日 08时12分22秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
GD32替换STM32,这些细节一定要知道。
2019-04-29
华为员工离职心声:菊厂15年退休,感恩,让我实现了财务自由!
2019-04-29
春晚上的“拓荒牛”
2019-04-29
嵌入式驱动自学者的亲身感受,有什么建议?
2019-04-29
华为被超越!这家公司成中国最大智能手机制造商,不是小米!
2019-04-29
腾讯机器狗,站起来了!
2019-04-29
我用自己创造的深度学习框架进入腾讯,爽!
2019-04-29
芯片为什么持续缺货?
2019-04-29
又涨了?2021 年 3 月程序员工资统计新出炉
2019-04-29
初入行的C++程序员,如何快速摆脱CRUD阶段?
2019-04-29
研究生跟了一个很棒的导师是种怎样的体验?
2019-04-29
学会扶墙的机器人:没有什么能让我倒下!
2019-04-29
美国无人机在火星首飞成功,创造历史,3米飞行高度悬停30秒
2019-04-29
单片机的几种数字滤波算法
2019-04-29
用单片机控制导弹?
2019-04-29
各种滤波器合集!
2019-04-29
国产CPU深度研究报告(干货,110页)
2019-04-29
在电路中,耦合是什么?有哪些方式?
2019-04-29
变局之际,聊聊物联网的过去、现在和未来
2019-04-29
缺货涨价很久的MCU的国产和国外厂家汇总!(80家)
2019-04-29