Minist手写体识别_完整项目_CodingPark编程公园
发布日期:2021-06-29 15:45:53 浏览次数:2 分类:技术文章

本文共 5229 字,大约阅读时间需要 17 分钟。

预备条件

  • 本项目利用Pytorch开发
    在这里插入图片描述

——————————utils.py—工具文件——————————

用于绘图、绘制loss曲线、one_hot

import torchfrom matplotlib import pyplot as plt        # 绘图# makecurve# to show loss picturedef plot_curve(data):    fig = plt.figure()  # 设置绘图区域的大小和像素    plt.plot(range(len(data)),data,color = 'blue')  # 将实际值的折线设置为蓝色    plt.legend(['value'],loc = 'upper right')   # 显示图例的位置,自适应方式    plt.xlabel('step')    plt.ylabel('value')    plt.show()# draw image# designed to show picture meterialdef plot_img(img, label, name):    fig = plt.figure()  # plt.figure()用来画图,create a figure;自定义画布大小,表示figure 的大小为宽、长(单位为inch)    for i in range(6):        plt.subplot(2,3,i + 1)  # 表示整个figure分为2行3列        plt.tight_layout()        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')        plt.title("{}: {}".format(name, label[i].item()))        plt.xticks([])        plt.yticks([])    plt.show()#one-hotdef one_hot(label, depth=10):    out = torch.zeros(label.size(0), depth)    idx = torch.LongTensor(label).view(-1, 1)    out.scatter_(dim=1, index=idx, value=1)    return out

——————————minist.py—项目文件——————————

主项目文件

import torchfrom torch import nnfrom torch.nn import functional as Ffrom torch import optim                          # Optimization Toolbox 优化工具包import torchvision                              # vision 视觉from matplotlib import pyplot as plt            # 绘图from utils import plot_img, plot_curve, one_hot'''step1. load dataset   加载数据集'''# 'mnist_data':加载mnist数据集,路径# train=True:选择训练集还是测试# download=True:如果当前文件没有mnist文件就会自动从网上去下载# torchvision.transforms.ToTensor():下载好的数据一般是numpy格式,转换成Tensor# torchvision.transforms.Normalisze((0.1307,), (0.3081,)):正则化过程,为了让数据更好的在0的附近均匀的分布# 上面一行可注释掉:但是性能会差到百分之70,加上是百分之80,更加方便神经网络去优化# batch size if the total data number every batch.batch_size = 512# extract and transform# train训练集train_loader = torch.utils.data.DataLoader(                                                        #DataLoader批量处理  datasets加载一张      torchvision.datasets.MNIST('mnist_data',train=True, download=True,      # datasets加载 Mnist 数据集                               transform=torchvision.transforms.Compose([   # 转化                                   torchvision.transforms.ToTensor(),                                   torchvision.transforms.Normalize((0.1307,),(0.3081,))# 均值是0.1307,                                       # 标准差是0.3081,这些系数都是数据集提供方计算好的数据                               ])),batch_size = batch_size, shuffle = True)                   # batch_size=batch_size:表示一次加载多少张图片# shuffle = True 加载的时候做一个随机的打散# test训练集test_loader = torch.utils.data.DataLoader(    torchvision.datasets.MNIST('mnist_data/',train=False, download=True,                               transform=torchvision.transforms.Compose([                                   torchvision.transforms.ToTensor(),                                   torchvision.transforms.Normalize((0.1307,),(0.3081,)),                               ])),batch_size = batch_size, shuffle = False)# 尝试显示加载项x,y =next(iter(train_loader))print(x.shape, y.shape, x.min(), x.max())plot_img(x, y, 'TEAM-AG_MNISTTrain')'''step2. Net Creat   网络创建'''# 生成 网class Net(nn.Module):    def __init__(self):        super(Net, self).__init__()        # define three layers,        # fc = 全连接 fc stands for fully connected layer. conv is for convolution layer(nn.Con2d()        # xw+b         self.fc1 = nn.Linear(28*28, 256)        self.fc2 = nn.Linear(256, 64)        self.fc3 = nn.Linear(64, 10)# 生成 向前训练路径    def forward(self, x):        # x: [b, 1, 28, 28]        # h1 = relu(xw1+b1)        x = F.relu(self.fc1(x))        # h2 = relu(h1w2+b2)        x = F.relu(self.fc2(x))        # h3 = h2w3+b3        x = self.fc3(x)        return x'''step3.  Train  训练'''# 实例化网net = Net()                 # 初始化net# [w1, b1, w2, b2, w3, b3]# 来个优化器optimizer = optim.SGD(net.parameters(), lr = 0.01, momentum =0.9)     # SGD成优化器train_loss = []   # train_loss记录# 一次一次轮# 一轮一轮又一轮for epoch in range(3):    for batch_idx, (x, y) in enumerate(train_loader):        # x: [b, 1, 28, 28], y: [512]        x = x.view(x.size(0), 28*28)        # 打平 # [b, 1, 28, 28] => [b, 784]# 走网络        out = net(x)        # 有了 制造品  我们要 VS 一下 定标品        y_onehot = one_hot(y)               # 给 y 编码one-hot        # loss = mse(out, y_onehot)# 走Loss        loss = F.mse_loss(out, y_onehot)# 走素质三连        optimizer.zero_grad()  #首先使之梯度为0        loss.backward()        # w' = w - lr*grad        optimizer.step()        train_loss.append(loss.item())        if batch_idx % 10 == 0:            print(epoch, batch_idx, loss.item())plot_curve(train_loss)              # 画 loss 曲线# 这是已经完成了 一轮一轮训练 so we get optimal [w1, b1, w2, b2, w3, b3]'''step4.  accuracy test 准确度测试'''total_correct = 0for x, y in test_loader:    x = x.view(x.size(0), 28*28)    out = net(x)    # out: [b, 10] => pred: [b]    pred = out.argmax(dim=1)# pred的准确度    correct = pred.eq(y).sum().float().item()  # 当前batch中与y标签等,也就是预测对的总个数合计,item()取出它的数值     item()变numpy    total_correct += correcttotal_num = len(test_loader.dataset)acc = total_correct / total_numprint('test acc : ' ,acc)x, y = next(iter(test_loader))      # 取一个batch,查看预测结果out = net(x.view(x.size(0), 28*28))pred = out.argmax(dim=1)    # 取得[b, 10]的10个值的最大值所在位置的索引plot_img(x, pred,'TEAM-AG_MNISTTest')

在这里插入图片描述

转载地址:https://codingpark.blog.csdn.net/article/details/105394828 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:CNN-LeNet5_完整项目_CodingPark编程公园
下一篇:win10下pytorch-gpu安装以及CUDA详细安装过程

发表评论

最新留言

逛到本站,mark一下
[***.202.152.39]2024年04月19日 19时54分00秒