tensorflow给图片打标签_数字图片分类实例玩转R中的Tensorflow
发布日期:2021-10-30 18:55:19 浏览次数:1 分类:技术文章

本文共 2727 字,大约阅读时间需要 9 分钟。

01
开篇 Introduction       
  • Tensorflow in R 系列,将分享如何使用R语言在Tensorflow/Keras 框架中训练深度学习模型。

  • MNIST 全称为 Modified National Institute of Standards and Technology。这个名词一点也不重要。

  • MNIST 数据为 7万张(6万张训练+1万张测试 0-9的手写数字图片。建立模型预测图片中的数字是多少。

安装 R 和 R studio

此次省略300字,建议使用云计算平台如Kaggle Kernel/Google Codelab/Google Cloud 等

安装 keras package

ff6975ef664a2fa6e4ea11f12bffce92.png

查看 tensorflow 版本

17a3c522d794ede08db9cda47461b267.png

02
导入数据 Input data       

导入4个数据集,分别为:

  • x_train:  6万张训练数字图片

  • y_train   6万个训练数字0-9标签

  • x_test:1万张测试数字图片

  • y_test:1万个测试数字0-9标签

为什么有4个数据集 ?

  • 带x的通常为特征(feature)。带y的为标签(label)。

  • 训练数据是用来训练模型。测试数据不参加建模,而是模型建立后是用来测试模型的效果。

b5a63180c839ba4f119455d1b825b34f.png

这些图片长这个样

c89245243dbb839f7449cecd185fd37d.png

b3b0da2dad94fd48bb44ab01f4152ed8.png

03
数据处理 Data cleaning  

reshape:将每个2维的28 × 28 的图片变成1维数据 1× 784 的数据

rescale:将每个由0到255的像素(pixel)转为0到1:原来是0的,现在 0/255=0 原来是255的,现在255/255=1。原来为200,现在200/255=0.78

33669de3164f2cbb21451342b6fab967.png

embedding:

这里对标签作 0,1 embedding 处理。

处理后 y_train 变成了 6万行 ,每行10 个 0或1 的数据。

处理后 y_test 变成了 1万行 ,每行10 个 0或1 的数据

b02c63d0a924c15fc3370f22db6c4aea.png

数据处理前

  • x_train: 6万张训练数字图片 60000 * 28 * 28 形状的 0-255的数字

  • y_train:6万个训练数字0-9标签 60000 形状的 0-9的数字

  • x_test:1万个测试数字图片 10000 * 28 * 28 形状的 0-255的数字

  • y_test:1万个测试数字0-9标签 10000 形状的 0-9的数字

数据处理后

  • x_train: 6万张训练数字图片 60000 * 784 形状的 0到1的数字

  • y_train:6万个训练数字0-9标签 60000 * 10 形状的 0或1的数字

  • x_test:1万个测试数字图片 10000 * 784 形状的 0到1的数字

  • y_test:1万个测试数字0-9标签 10000 * 10 形状的 0或1的数字

04
建立模型 modeling       

建立深度神经网络模型(deep neural network)

网络结构介绍:

输入层:每个图片的形状为784位数字的输入层

第一层:使用 'relu' 的256个tensor 的隐藏层 (relu 是什么?后续文章再聊)

第二层:使用 'relu' 的128个tensor 的隐藏层

输出层:使用 'softmax' 的 10个 加总为1 的 0到1的概率 的 输出层 (softmax 是什么?后续文章再聊)

b333191a1e9c087119807e738db21ee8.png

神经网络型图:

41877b5f86c0fcac8a69a5d34d802add.png

神经网络公式:

公式是我们设计模型的时候定义的。比如图中的模型。W11-W33 9个weight 和 b1-b3 3个bias 经过训练得出。所以模型训练的Learnable Parameters=9+3=12

4de8473b98d6f52c3ea9349d1732c222.png

模型的架构:

Learnable_Parameters=input*output+bias

第一层:使用'relu' 的256个tensor 的隐藏层:

Learnable_Parameters:200960=784*256 + 256

第二层:使用'relu'的128个tensor 的隐藏层:

Learnable_Parameters:32896=256*128+128

输出层:使用 'softmax' 的 10个 0到1的概率 的 输出层: 

Learnable_Parameters :1290=128*10+10

总Learnable_Parameters :

235146=200960+32896+129

24f75ac037c8fd33e98c757365b786fd.png

9c9a55d422cdaf6d1ab03ffd20de342b.png

05
Complie模型     

loss function是categorial_crossentropy

(loss function 是什么?后续文章再聊)

optimizer是optimize_rmsprop

(optimizer 是什么?后续文章再聊)

metrics 为 accuracy,metrics是评估模型的指标。大多数情况都选accuracy。accuracy=正确预测的个数/总预测个数

5a89738b703041408f0f95d1556d315a.png

06
训练模型 trainning       

一堆数据处理转换。模型设计后 。终于可以开始训练模型了。

x_train为训练数据集特征

(6万张照片)

y_train 为训练数据集标签

(6万个数字)

每次读入128张图片。训练10次。

6万张照片80%用来训练。20%用来验证。

训练时间大概为5分钟。

e93545944d95e4744d9f6014a32862aa.png

07
模型效果 performance

可见 经过 10次训练后。最终在验证集的accuracy表现为97%。从图中可见其实经过6次的训练。在验证集的表现以达到97%

312e964f1253343999145972caf69f04.png

a3a05e74ba827015abe0429f1fc2f898.png

08
模型对比 benchmark       

Naive benchmark:

如果我们什么都不知道,瞎猜0-9的话。准确度是10%

决策树模型 Decision tree benchmark:

使用决策树模型。准确度是61%。训练时间大概为10分钟。

随机森林模型 random forest benchmark:

使用随机森林模型。准确度是92%。训练时间大概为15分钟

tensorflow神经网络模型的准确度是97%

dd98c80fda6bf73fc09086e0fe427072.png

327adf46bb2b089ae7a3d8fe61a72d4e.png

09
总结 summary       

 使用tensorflow 神经网络模型将准确率提高到97%。可以得到如此高的准确率,主要是图片比较简单。只有0-9的标准数字。对于更加困难的问题。比如在自动驾驶中需要精准的物体识别等问题。将需要更加复杂的神经网络模型。

代码:https://tduan.netlify.com/post/tensorflow-in-r-1-mnist-image-classification/

如果您喜欢本文。请分享出去。

后续分享:

Tensorflow in R 系列(2) :时装分类 Fashion-MNIST image classification with CNN

db8195c6e096344760bfb6bdfecc112f.png

转载地址:https://blog.csdn.net/weixin_39849239/article/details/111167380 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:3500个常用汉字表_【Excel教程】10个常用的Excel透视表技巧!一次送给你
下一篇:高速公路etc门架最新要求_交控信息公司助力集团完成高速公路ETC门架收费切换...

发表评论

最新留言

逛到本站,mark一下
[***.202.152.39]2024年02月28日 21时45分26秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章

mysql居左查询abcd_MySql速查手册 2019-04-21
loadrunner 错误: 无法找到 java.exe_LoadRunner错误及解决方法总结 2019-04-21
Java小魔女芭芭拉_沉迷蘑菇不可自拔,黏土人《小魔女学园》苏西·曼芭芭拉 图赏... 2019-04-21
php+mysql记事本_一个简单记事本php操作mysql辅助类创建 2019-04-21
300小时成为java程序员_直击面试现场: Java程序员3轮6小时面试, 成功拿到阿里offer!... 2019-04-21
中国网建java发送短信_短信验证登陆-中国网建提供的SMS短信平台 2019-04-21
隔行变色java代码_jquery入门—选择器实现隔行变色实例代码 2019-04-21
角标越界 Java_【新人求助】利用占位符操作数据库是总是提示数组角标越界是怎么回事 - Java论坛 - 51CTO技术论坛_中国领先的IT技术社区... 2019-04-21
java类中声明log对象_用于Android环境,java环境的log打印,可打印任何类型数据 2019-04-21
db2与mysql编目_DB2编目、联邦数据库 - Goopand's OS Space - OSCHINA - 中文开源技术交流社区... 2019-04-21
atomikosdatasourcebean mysql_SpringBoot2整合JTA组件实现多数据源事务管理 2019-04-21
webpack 入口文件 php,如何实现webpack多入口文件打包配置 2019-04-21
php tire树,Immutable.js源码之List 类型的详细解析(附示例) 2019-04-21
matlab转差频率控制,转差频率控制的异步电机调速系统的研究 2019-04-21
oracle错误1327,Oracle中的PGA监控报警分析(r11笔记第97天) 2019-04-21
php函数内的循环,PHP 循环列出目录内容的函数代码 2019-04-21
oracle树状排序,Oracle树状结构查询 2019-04-21
深度linux内核升级,深度操作系统 2020.11.11 更新发布:内核升级 2019-04-21
android 解压gzip,Response gzip 解压的问题 2019-04-21
html表格中的滚动字幕,html – 向表格主体添加滚动条 2019-04-21