TensorFlow:使用传统神经网络进行 MNIST 手写数字图像识别的最佳实践
发布日期:2022-02-14 16:09:23
浏览次数:19
分类:技术文章
本文共 4804 字,大约阅读时间需要 16 分钟。
文件列表如下:
文件说明:mnist_inference 用于定义前向传播算法及其相关参数,mnist_train 用于模型训练与持久化,mnist_eval 用于模型加载与验证
mnist_inference.py
# _*_ coding:utf-8 _*_
"""Forward-propagation graph and network parameters for the MNIST classifier.

The module exposes the layer sizes as constants and two helpers:
``get_weight_variable`` (creates a weight matrix and registers its
regularization loss) and ``inference`` (builds the two-layer network).
"""
import tensorflow as tf

# Network parameters.
INPUT_NODE = 784     # size of one flattened input vector
OUTPUT_NODE = 10     # number of output units (one per class)
LAYER1_NODE = 500    # number of hidden units


def get_weight_variable(shape, regularizer):
    """Create (or reuse) the ``weights`` variable of the current variable scope.

    ``tf.get_variable`` is used for both creating and loading variables, so at
    load time the variable can be mapped by name to its moving-average value.

    Args:
        shape: Shape of the weight matrix.
        regularizer: Callable mapping the weight tensor to a scalar loss, or
            ``None`` to skip regularization.

    Returns:
        The weight variable.
    """
    weights = tf.get_variable(
        "weights",
        shape=shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1),
    )
    if regularizer is not None:
        # 'losses' is a custom collection name; the training code sums it up.
        tf.add_to_collection('losses', regularizer(weights))
    return weights


def inference(input_tensor, regularizer):
    """Build the forward pass of the network and return the output logits.

    Args:
        input_tensor: Batch of inputs; multiplied by a weight matrix of shape
            ``[INPUT_NODE, LAYER1_NODE]``.
        regularizer: Passed through to ``get_weight_variable`` for both layers;
            may be ``None``.

    Returns:
        The un-activated output of the second layer (logits).
    """
    with tf.variable_scope('layer1'):
        weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)

    with tf.variable_scope('layer2'):
        weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
        biases = tf.get_variable(
            "biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
        # The output layer stays linear: callers feed the result to a softmax
        # cross-entropy op / argmax, and a ReLU here would clamp every
        # negative logit to zero and distort the loss.
        layer2 = tf.matmul(layer1, weights) + biases

    return layer2
mnist_train.py
# _*_ coding:utf-8 _*_
"""Train the MNIST network defined in ``mnist_inference`` and save checkpoints."""
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference

# Network training configuration.
# The loss is optimized with mini-batch gradient descent.
BATCH_SIZE = 1000
# Exponentially decayed learning rate: base value and decay factor.
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
TRAINING_STEPS = 300000
# L2 regularization strength, used to limit overfitting.
# NOTE: the misspelled name is kept so existing references keep working;
# the correctly spelled alias below refers to the same value.
REGULARAZTION_RATE = 0.001
REGULARIZATION_RATE = REGULARAZTION_RATE
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "save/"
MODEL_NAME = "mnist.ckpt"


def train(mnist):
    """Build the training graph, run the training loop and save checkpoints.

    Args:
        mnist: Dataset object providing ``train.num_examples`` and
            ``train.next_batch(batch_size)``; labels are expected one-hot
            (they are converted to class indices with ``argmax``).
    """
    x = tf.placeholder(
        tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(
        tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARAZTION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Maintain moving averages of every trainable variable.
    variables_averages = tf.train.ExponentialMovingAverage(
        MOVING_AVERAGE_DECAY, global_step)
    variables_averages_op = variables_averages.apply(tf.trainable_variables())

    # y holds OUTPUT_NODE values per example while the sparse cross-entropy op
    # expects the index of the correct class, hence the argmax on the labels.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # Total loss = data loss + regularization terms collected in 'losses'.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    print(mnist.train.num_examples)
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
    )
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # A single op that triggers both the gradient step and the average update.
    with tf.control_dependencies([train_step, variables_averages_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)
    # Make sure the checkpoint directory exists before the first save.
    if not os.path.isdir(MODEL_SAVE_PATH):
        os.makedirs(MODEL_SAVE_PATH)

    with tf.compat.v1.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            # Report and checkpoint every 1000 iterations, and also on the
            # very last one so the final weights are never lost.
            if i % 1000 == 0 or i == TRAINING_STEPS - 1:
                print("{0}steps,loss:{1}".format(step, loss_value))
                saver.save(sess, checkpoint_prefix, global_step=global_step)


def main(argv=None):
    """Load the MNIST data from the ``data`` directory and start training."""
    mnist = input_data.read_data_sets("data", one_hot=True)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()
mnist_eval.py
# _*_ coding:utf-8 _*_
"""Periodically load the newest checkpoint and report validation accuracy."""
import os
import time

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference
import mnist_train

# Seconds to wait between two evaluation rounds.
EVAL_INTERVAL_SECONDS = 10


def evaluate(mnist):
    """Evaluate the latest checkpoint on the validation set in an endless loop.

    Every ``EVAL_INTERVAL_SECONDS`` seconds the newest checkpoint found in
    ``mnist_train.MODEL_SAVE_PATH`` is restored and the accuracy on
    ``mnist.validation`` is printed.

    Args:
        mnist: Dataset object providing ``validation.images`` and
            ``validation.labels`` (one-hot; compared via ``argmax``).
    """
    with tf.Graph().as_default() as g:
        x = tf.placeholder(
            tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
        y_ = tf.placeholder(
            tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
        validate_feed = {x: mnist.validation.images,
                         y_: mnist.validation.labels}

        # The regularization loss is irrelevant for evaluation, so no
        # regularizer is passed.
        y = mnist_inference.inference(x, None)
        correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Restore variables under their moving-average names so the averaged
        # values are loaded into the model variables.
        variables_averages = tf.train.ExponentialMovingAverage(
            mnist_train.MOVING_AVERAGE_DECAY)
        variables_to_restore = variables_averages.variables_to_restore()
        saver = tf.train.Saver(variables_to_restore)

        while True:
            with tf.compat.v1.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(
                    mnist_train.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)
                    # The step is the numeric suffix of the checkpoint file
                    # name; basename keeps this independent of the path
                    # separator of the platform.
                    checkpoint_name = os.path.basename(
                        ckpt.model_checkpoint_path)
                    global_step = checkpoint_name.split("-")[-1]
                    accuracy_score = sess.run(
                        accuracy, feed_dict=validate_feed)
                    # The reported value is an accuracy, not a loss.
                    print("{0}steps,accuracy:{1}".format(
                        global_step, accuracy_score))
                else:
                    # Keep polling, but tell the user why nothing is reported.
                    print("No checkpoint found in {0}".format(
                        mnist_train.MODEL_SAVE_PATH))
            time.sleep(EVAL_INTERVAL_SECONDS)


def main(argv=None):
    """Load the MNIST data from the ``data`` directory and start evaluating."""
    mnist = input_data.read_data_sets("data", one_hot=True)
    evaluate(mnist)


if __name__ == '__main__':
    tf.app.run()
转载地址:https://blog.csdn.net/qq_29590285/article/details/106132917 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!
发表评论
最新留言
表示我来过!
[***.240.166.169]2024年03月20日 22时33分21秒
关于作者
喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!
推荐文章
安卓——WIFI连接
2019-04-26
安卓——关于一些ui界面设置(直续更新ing)
2019-04-26
刷门禁——判断卡号是否一样(String==String)出现False
2019-04-26
好久没刷题了(阿里测试题)
2019-04-26
安卓界面——最开始界面的加载
2019-04-26
安卓——屏蔽陌生来电
2019-04-26
安卓——小笔记
2019-04-26
客户端面试万金油
2019-04-26
【u3d泰斗破坏神】05 --- 角色移动 velocity 的相关问题
2019-04-26
【u3d泰斗破坏神】06 --- Loading界面进度条Slider的使用
2019-04-26
【u3d泰斗破坏神】07 --- 角色攻击动画拆分、状态机设计
2019-04-26
【u3d泰斗破坏神】08 --- UGUI 制作艺术字体
2019-04-26
【u3d泰斗破坏神】09 --- 角色血条的制作、掉血特效
2019-04-26
Unity Shader 入门精要(01) -- 渲染流水线
2019-04-26
Unity Shader 入门精要(02) -- shader的编码基础
2019-04-26
Unity Shader 入门精要(03) -- Unity的基础光照
2019-04-26
Unity Shader 入门精要(04) -- 基础纹理
2019-04-26
Unity3D 移动平台的资源路径问题
2019-04-26
二分查找(折半查找)
2019-04-26
线段树
2019-04-26