Minist手写体识别_完整项目_CodingPark编程公园
发布日期:2021-06-29 15:45:53
浏览次数:2
分类:技术文章
本文共 5229 字,大约阅读时间需要 17 分钟。
预备条件
- 本项目利用Pytorch开发
——————————utils.py—工具文件——————————
用于绘图、绘制loss曲线、one_hotimport torchfrom matplotlib import pyplot as plt # 绘图# makecurve# to show loss picturedef plot_curve(data): fig = plt.figure() # 设置绘图区域的大小和像素 plt.plot(range(len(data)),data,color = 'blue') # 将实际值的折线设置为蓝色 plt.legend(['value'],loc = 'upper right') # 显示图例的位置,自适应方式 plt.xlabel('step') plt.ylabel('value') plt.show()# draw image# designed to show picture meterialdef plot_img(img, label, name): fig = plt.figure() # plt.figure()用来画图,create a figure;自定义画布大小,表示figure 的大小为宽、长(单位为inch) for i in range(6): plt.subplot(2,3,i + 1) # 表示整个figure分为2行3列 plt.tight_layout() plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none') plt.title("{}: {}".format(name, label[i].item())) plt.xticks([]) plt.yticks([]) plt.show()#one-hotdef one_hot(label, depth=10): out = torch.zeros(label.size(0), depth) idx = torch.LongTensor(label).view(-1, 1) out.scatter_(dim=1, index=idx, value=1) return out
——————————minist.py—项目文件——————————
主项目文件import torchfrom torch import nnfrom torch.nn import functional as Ffrom torch import optim # Optimization Toolbox 优化工具包import torchvision # vision 视觉from matplotlib import pyplot as plt # 绘图from utils import plot_img, plot_curve, one_hot'''step1. load dataset 加载数据集'''# 'mnist_data':加载mnist数据集,路径# train=True:选择训练集还是测试# download=True:如果当前文件没有mnist文件就会自动从网上去下载# torchvision.transforms.ToTensor():下载好的数据一般是numpy格式,转换成Tensor# torchvision.transforms.Normalisze((0.1307,), (0.3081,)):正则化过程,为了让数据更好的在0的附近均匀的分布# 上面一行可注释掉:但是性能会差到百分之70,加上是百分之80,更加方便神经网络去优化# batch size if the total data number every batch.batch_size = 512# extract and transform# train训练集train_loader = torch.utils.data.DataLoader( #DataLoader批量处理 datasets加载一张 torchvision.datasets.MNIST('mnist_data',train=True, download=True, # datasets加载 Mnist 数据集 transform=torchvision.transforms.Compose([ # 转化 torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,),(0.3081,))# 均值是0.1307, # 标准差是0.3081,这些系数都是数据集提供方计算好的数据 ])),batch_size = batch_size, shuffle = True) # batch_size=batch_size:表示一次加载多少张图片# shuffle = True 加载的时候做一个随机的打散# test训练集test_loader = torch.utils.data.DataLoader( torchvision.datasets.MNIST('mnist_data/',train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.1307,),(0.3081,)), ])),batch_size = batch_size, shuffle = False)# 尝试显示加载项x,y =next(iter(train_loader))print(x.shape, y.shape, x.min(), x.max())plot_img(x, y, 'TEAM-AG_MNISTTrain')'''step2. Net Creat 网络创建'''# 生成 网class Net(nn.Module): def __init__(self): super(Net, self).__init__() # define three layers, # fc = 全连接 fc stands for fully connected layer. conv is for convolution layer(nn.Con2d() # xw+b self.fc1 = nn.Linear(28*28, 256) self.fc2 = nn.Linear(256, 64) self.fc3 = nn.Linear(64, 10)# 生成 向前训练路径 def forward(self, x): # x: [b, 1, 28, 28] # h1 = relu(xw1+b1) x = F.relu(self.fc1(x)) # h2 = relu(h1w2+b2) x = F.relu(self.fc2(x)) # h3 = h2w3+b3 x = self.fc3(x) return x'''step3. Train 训练'''# 实例化网net = Net() # 初始化net# [w1, b1, w2, b2, w3, b3]# 来个优化器optimizer = optim.SGD(net.parameters(), lr = 0.01, momentum =0.9) # SGD成优化器train_loss = [] # train_loss记录# 一次一次轮# 一轮一轮又一轮for epoch in range(3): for batch_idx, (x, y) in enumerate(train_loader): # x: [b, 1, 28, 28], y: [512] x = x.view(x.size(0), 28*28) # 打平 # [b, 1, 28, 28] => [b, 784]# 走网络 out = net(x) # 有了 制造品 我们要 VS 一下 定标品 y_onehot = one_hot(y) # 给 y 编码one-hot # loss = mse(out, y_onehot)# 走Loss loss = F.mse_loss(out, y_onehot)# 走素质三连 optimizer.zero_grad() #首先使之梯度为0 loss.backward() # w' = w - lr*grad optimizer.step() train_loss.append(loss.item()) if batch_idx % 10 == 0: print(epoch, batch_idx, loss.item())plot_curve(train_loss) # 画 loss 曲线# 这是已经完成了 一轮一轮训练 so we get optimal [w1, b1, w2, b2, w3, b3]'''step4. accuracy test 准确度测试'''total_correct = 0for x, y in test_loader: x = x.view(x.size(0), 28*28) out = net(x) # out: [b, 10] => pred: [b] pred = out.argmax(dim=1)# pred的准确度 correct = pred.eq(y).sum().float().item() # 当前batch中与y标签等,也就是预测对的总个数合计,item()取出它的数值 item()变numpy total_correct += correcttotal_num = len(test_loader.dataset)acc = total_correct / total_numprint('test acc : ' ,acc)x, y = next(iter(test_loader)) # 取一个batch,查看预测结果out = net(x.view(x.size(0), 28*28))pred = out.argmax(dim=1) # 取得[b, 10]的10个值的最大值所在位置的索引plot_img(x, pred,'TEAM-AG_MNISTTest')
转载地址:https://codingpark.blog.csdn.net/article/details/105394828 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
逛到本站,mark一下
[***.202.152.39]2024年04月19日 19时54分00秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
硬货 | Redis 性能问题分析
2019-04-29
Kafka为什么这么快?
2019-04-29
灵魂四连问:API 接口应该如何设计?如何保证安全?如何签名?如何防重?
2019-04-29
一个依赖搞定 Spring Boot 反爬虫,防止接口盗刷!
2019-04-29
酸爽!IDEA 中这么玩 MyBatis,让编码速度飞起!
2019-04-29
已拿 Offer!字节跳动面试经验分享
2019-04-29
Windows路由表透析
2019-04-29
Java LockSupport 实战
2019-04-29
线程面试题实战与分析——各种锁的灵活运用
2019-04-29
Java 生产者和消费者面试题
2019-04-29
生产者消费者问题
2019-04-29
哲学家就餐问题
2019-04-29
本机电脑连接虚拟机redis失败解决方法
2019-04-29
JAVA学习:将字符串转成数字
2019-04-29
webrtc 中的 Android 端 jni
2019-04-29
webrtc Android 端 video 软解码创建
2019-04-29
如何构建私有的智能视觉系统
2019-04-29
OpenNCC智能视觉系统-基于Paddle的OCR模型迁移训练(一)
2019-04-29
dvsdk_3_10_00-19 编译
2019-04-29
DMAI GStreamer Plug-In 编译
2019-04-29