TensorFlow经典入门示例MNIST(识别手写的数字图片)
发布日期:2021-07-01 05:38:15 浏览次数:2 分类:技术文章

本文共 3720 字,大约阅读时间需要 12 分钟。

文章目录

TensorFlow是流行的AI框架之一, 那么使用TensorFlow究竟能达成哪些人工智能的工作呢? 官方曾经提供了一个经典的演示示例,就是使用该框架实现对手写的阿拉伯数字图片的识别,该项目称为MNIST。这里的“经典”有两层意思:

  1. MNIST是最典型也是最早的演示示例, 通俗易懂,具有代表性
  2. 经典也代表了过时,目前官方推荐的示例是识别衣服、裤子和包包的图片,项目名是Fasion MNIST(时尚版的MNIST)
    本篇基于TensorFlow 2版本介绍,相关内容是在安装了TensorFlow基础之上进行,关于TensorFlow的安装,参考:

MNIST是什么?

MNIST的全写是: Mixed National Institue of Standards and Technology database,翻译的意思是:美国美国国家标准与技术研究所数据库。

MNIST是Yan Lecun等于1998年在论文(手写数字MNIST数据库)提出的概念, 另外一篇重要的论文是Gradient-based learning applied to document recognition(梯度学习在文档识别中的应用)。

MNIST的组成

MNIST的数据包括训练集(training set)和测试集两部分(test set),总共70000张图片数据,具体如下:

  • 训练集, 包括60000张手写图片及对应的数字标签。
  • 测试集, 包括10000张手写图片及对应的数字标签。

这些图片数据由来自 250 个不同人手写的数字构成, 其中 50% 是高中学生, 50% 来自人口普查局 (the Census Bureau) 的工作人员。这些图片的满足:

  • 0-9 的10个数字之一
  • 像素是28*28的灰度图
  • 每个像素值在0-255之间, 0 是黑色,255是白色。

示例图片的样貌如下:

在这里插入图片描述

MNIST存储方式

这70000张图片及对应的数字标签分为四个文件进行存储,分别是:

  • train-images-idx3-ubyte.gz ,60000张训练集图片
  • train-labels-idx1-ubyte.gz ,60000张训练集图片对应的标签
  • t10k-images-idx3-ubyte.gz,10000张测试集图片
  • t10k-labels-idx1-ubyte.gz ,10000张测试集图片对应的标签

图片文件的存储格式是:

  • 第1-4个byte(字节,1byte=8bit),即前32bit存的是文件的magic number,对应的十进制大小是2051;(magic number,魔幻数,用来标识文件的特性)
  • 第5-8个byte存的是图片的数量,即60000;
  • 第9-12个byte存的是每张图片行数/高度,即28;
  • 第13-16个byte存的是每张图片的列数/宽度,即28。
  • 从第17个byte开始,每个byte存储一张图片中的一个像素点的值。

标签文件测存储格式是:

  • 第1-4个byte存的是文件的magic number,对应的十进制大小是2049;
  • 第5-8个byte存的是number of items,即label数量60000;
  • 从第9个byte开始,每个byte存一个图片的label信息,即数字0-9中的一个。

使用Keras编码

Keras是一个由Python编写的开源人工神经网络库,可以作为Tensorflow、Microsoft-CNTK和Theano的高阶应用程序接口,进行深度学习模型的设计、调试、评估、应用和可视化。

自2017年起,Keras得到了Tensorflow团队的支持,其大部分组件被整合至Tensorflow的Python API中。在2018年Tensorflow 2.0.0公开后,Keras被正式确立为Tensorflow高阶API,即tf.keras。
所以,在TensorFlow中使用Keras, 不需要单独学习Keras,只需要调用tf.keras模块相关的接口即可。

新建mnist.py文件,内容如下:

import osos.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' #不显示AVX,CUDA等警告, 没有GPU的机器可以加上#1. 导入TensorFlowimport tensorflow as tfmnist = tf.keras.datasets.mnist #2. 通过keras使用数据集(x_train, y_train), (x_test, y_test) = mnist.load_data() # 载入数据x_train, x_test = x_train / 255.0, x_test / 255.0  # 3. 数据归一化, 范围0到1之间,因为像素值的范围是0~255# 4. 搭建模型,选择优化器和损失函数model = tf.keras.models.Sequential([  tf.keras.layers.Flatten(input_shape=(28, 28)),  tf.keras.layers.Dense(128, activation='relu'),  tf.keras.layers.Dropout(0.2),  tf.keras.layers.Dense(10, activation='softmax')])#5. 模型编译model.compile(optimizer='adam',              loss='sparse_categorical_crossentropy',              metrics=['accuracy']) #6. 训练并验证模型model.fit(x_train, y_train, epochs=5)#7. 评估模型model.evaluate(x_test,  y_test, verbose=2)

以上源码在GitHub中的地址是:

这里将该文件放置在 D:\demoworkspace\tensorflow\tutorial 目录中, 通过 py mnist.py命令执行的效果是:

在这里插入图片描述

以上输出的大致意思是: 使用60000笔数据进行训练,训练5轮之后,模型的准确率达到了 97.32%。
Epoch的中文意思是时期、时代。在深度学习中的意思是训练集的全部数据对模型进行一次完整的训练。一般训练的轮次越多,准确率越高,花费的资源和时间也就越多,该实例中,如果训练10轮次,准确率超过98%。

运行说明与问题

以上代码会从 下载MNIST的数据,文件会下载到 C:\Users\电脑用户名.keras\datasets 路径下, 类似:

在这里插入图片描述

mnist.npz大概11M,包括上面介绍的四部分内容,分别是训练集和测试集的图片和标签。
以上是Google的下载地址,国内网络下载速度会比较慢, 如果失败可以多跑两次,如果还下不来,可以到以下百度网盘下载。

链接:

提取码:6q4w

图形化显示和理解

仅仅以上的代码和执行的结果, 对理解AI以及模型、训练和测试还是比较晦涩的, 只是知道使用数据训练了, 训练后使用测试数据进行验证的准确度时多少,但实际模型干什么了呢?

结合Python的可视化库matplotlib,可以更好的理解以上代码和深度学习的过程。matplotlib类似MATLAB,MATLAB就比较知名了,其一种高级的计算语言和交互式环境,可以用来算法开发、数据分析和可视化等。由美国MathWorks公司出品。MTALAB是 Matrix和Laboratory的组合,矩阵实验室。matplotlib的功能类似MATLAB,而且开源。

安装matplotlib

首选使用 pip list查看是否已经安装了该模块, 如果没有安装,使用如下安装命令:

pip install matplotlib

折线图简单示例

以绘制一个简单的折线图为例, 代码如下:

import matplotlib.pyplot as pltx_data = [1,2,3,4,5,6,7,8] # X轴坐标点y_data = [1,2,3,4,5,6,7,8]  # Y轴坐标点plt.plot(x_data,y_data)plt.show()

运行效果如下:

在这里插入图片描述

显示MNIST中的图

#1. 导入TensorFlowimport tensorflow as tf#1.2 matplotlibimport matplotlib.pyplot as pltmnist = tf.keras.datasets.mnist #2. 通过keras使用数据集(x_train, y_train), (x_test, y_test) = mnist.load_data() # 载入数据#3. 显示第一张图plt.figure()plt.imshow(x_train[0])plt.show()

运行效果:

在这里插入图片描述

待续。。。。。

转载地址:https://oscar.blog.csdn.net/article/details/105372881 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:Spring Boot整合H2内存数据库配置及常见问题处理
下一篇:跨站脚本攻击(XSS)及防范措施

发表评论

最新留言

很好
[***.229.124.182]2024年04月10日 13时16分57秒