tf.contrib.learn快速入门
发布日期:2021-06-30 22:47:22 浏览次数:2 分类:技术文章

本文共 7899 字,大约阅读时间需要 26 分钟。

tf.contrib.learn快速入门

TensorFlow的高级机器学习API(tf.contrib.learn)可以轻松配置,训练和评估各种机器学习模型。在本教程中,您将使用tf.contrib.learn构建  分类器并在进行训练, 以基于萼片/花瓣几何来预测花种。您将编写代码以执行以下五个步骤:

  1. 将包含Iris训练/测试数据的CSV加载到TensorFlow中 Dataset
  2. 构建
  3. 使用训练数据拟合模型
  4. 评估模型的准确性
  5. 分类新样本

注意: 在开始使用本教程之前,请记住

完整的神经网络源代码

这是神经网络分类器的完整代码:

from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import urllib import numpy as np import tensorflow as tf # Data sets IRIS_TRAINING = "iris_training.csv" IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv" IRIS_TEST = "iris_test.csv" IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv" def main():   # If the training and test sets aren't stored locally, download them.   if not os.path.exists(IRIS_TRAINING):     raw = urllib.urlopen(IRIS_TRAINING_URL).read()     with open(IRIS_TRAINING, "w") as f:       f.write(raw)   if not os.path.exists(IRIS_TEST):     raw = urllib.urlopen(IRIS_TEST_URL).read()     with open(IRIS_TEST, "w") as f:       f.write(raw)   # Load datasets.   training_set = tf.contrib.learn.datasets.base.load_csv_with_header(       filename=IRIS_TRAINING,       target_dtype=np.int,       features_dtype=np.float32)   test_set = tf.contrib.learn.datasets.base.load_csv_with_header(       filename=IRIS_TEST,       target_dtype=np.int,       features_dtype=np.float32)   # Specify that all features have real-value data   feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]   # Build 3 layer DNN with 10, 20, 10 units respectively.   classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,                                               hidden_units=[10, 20, 10],                                               n_classes=3,                                               model_dir="/tmp/iris_model")   # Define the training inputs   def get_train_inputs():     x = tf.constant(training_set.data)     y = tf.constant(training_set.target)     return x, y   # Fit model.   classifier.fit(input_fn=get_train_inputs, steps=2000)   # Define the test inputs   def get_test_inputs():     x = tf.constant(test_set.data)     y = tf.constant(test_set.target)     return x, y   # Evaluate accuracy.   accuracy_score = classifier.evaluate(input_fn=get_test_inputs,                                        steps=1)["accuracy"]   print("\nTest Accuracy: {0:f}\n".format(accuracy_score))   # Classify two new flower samples.   def new_samples():     return np.array(       [[6.4, 3.2, 4.5, 1.5],        [5.8, 3.1, 5.0, 1.7]], dtype=np.float32)   predictions = list(classifier.predict(input_fn=new_samples))   print(       "New Samples, Class Predictions:    {}\n"       .format(predictions)) if __name__ == "__main__":     main()

以下部分详细介绍了代码。

将Iris CSV数据加载到TensorFlow中

包含150行数据,包括来自每个的三个相关鸢尾种类50个样品: 山鸢尾虹膜锦葵,和变色鸢尾

花瓣几何比较了三种虹膜物种:鸢尾花,鸢尾花和鸢尾花从左到右, (由 ,CC BY-SA 3.0), (由 ,CC BY-SA 3.0)和(由,CC BY-SA 2.0))。

每行包含以下每个花样品的数据: 长度,萼片宽度, 长度,花瓣宽度和花种。花种以整数表示,0表示Iris setosa,1表示Iris versicolor,2表示Iris virginica

萼片长度 萼片宽度 花瓣长度 花瓣宽度 种类
5.1 3.5 1.4 0.2 0
4.9 3.0 1.4 0.2 0
4.7 3.2 1.3 0.2 0
... ... ... ... ...
7 3.2 4.7 1.4 1
6.4 3.2 4.5 1.5 1
6.9 3.1 4.9 1.5 1
... ... ... ... ...
6.5 3.0 5.2 2.0 2
6.2 3.4 5.4 2.3 2
5.9 3.0 5.1 1.8 2

对于本教程,Iris数据已被随机分为两个独立的CSV:

  • 120个样本的训练集(
  • 一个30个样本的测试集()。

要开始,首先导入所有必要的模块,并定义下载和存储数据集的位置:

from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import urllib import tensorflow as tf import numpy as np IRIS_TRAINING = "iris_training.csv" IRIS_TRAINING_URL = "http://download.tensorflow.org/data/iris_training.csv" IRIS_TEST = "iris_test.csv" IRIS_TEST_URL = "http://download.tensorflow.org/data/iris_test.csv"

然后,如果培训和测试集尚未存储在本地,请下载。

if not os.path.exists(IRIS_TRAINING):   raw = urllib.urlopen(IRIS_TRAINING_URL).read()   with open(IRIS_TRAINING,'w') as f:     f.write(raw) if not os.path.exists(IRIS_TEST):   raw = urllib.urlopen(IRIS_TEST_URL).read()   with open(IRIS_TEST,'w') as f:     f.write(raw)

接下来,Dataset使用 方法将训练和测试集加载到learn.datasets.baseload_csv_with_header()方法需要三个必需的参数:

  • filename,它将文件路径作为CSV文件
  • target_dtype,它采用 集的目标值的 
  • features_dtype,它采用 集的特征值的 

在这里,目标(你正在训练模型预测的值)是花种,它是0-2的整数,所以适当的numpy数据类型是np.int

# Load datasets. training_set = tf.contrib.learn.datasets.base.load_csv_with_header(     filename=IRIS_TRAINING,     target_dtype=np.int,     features_dtype=np.float32) test_set = tf.contrib.learn.datasets.base.load_csv_with_header(     filename=IRIS_TEST,     target_dtype=np.int,     features_dtype=np.float32)

Datasettf.contrib.learn中的s被  ; 您可以通过datatarget 字段访问要素数据和目标值在这里,training_set.datatraining_set.target包含用于训练集,分别特征数据和目标值,并test_set.datatest_set.target含有特征数据和目标值的测试集。

稍后,在  您将使用training_set.data和 training_set.target训练您的模型,并在 您将使用test_set.data和 test_set.target但首先,您将在下一节中构建您的模型。

构建深层神经网络分类器

tf.contrib.learn提供了各种预定义的模型,称为 ,您可以使用“开箱即用”来对数据运行培训和评估操作。在这里,您将配置深层神经网络分类器模型以适应Iris数据。使用tf.contrib.learn,您可以使用几行代码实例化 

# Specify that all features have real-value data feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)] # Build 3 layer DNN with 10, 20, 10 units respectively. classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,                                             hidden_units=[10, 20, 10],                                             n_classes=3,                                             model_dir="/tmp/iris_model")

上面的代码首先定义了模型的特征列,它们指定数据集中的要素的数据类型。所有的特征数据是连续的,所以tf.contrib.layers.real_valued_column使用相应的功能来构造特征列。数据集中有四个特征(萼片宽度,萼片高度,花瓣宽度和花瓣高度),因此dimension 必须设置4为保存所有数据。

然后,代码DNNClassifier使用以下参数创建一个模型:

  • feature_columns=feature_columns上面定义的特征列集合。
  • hidden_units=[10, 20, 10]三个 ,分别含有10,20和10个神经元。
  • n_classes=3三个目标课程,代表三种虹膜物种。
  • model_dir=/tmp/iris_modelTensorFlow将在模型训练期间保存检查点数据的目录。有关使用TensorFlow进行日志记录和监视的更多信息,请参阅

描述培训输入管道

所述tf.contrib.learnAPI使用输入功能,其创建用于生成模型数据中TensorFlow操作。在这种情况下,数据足够小,可以将其存储在以下代码生成最简单的输入管道:

# Define the training inputs def get_train_inputs():   x = tf.constant(training_set.data)   y = tf.constant(training_set.target)   return x, y

将DNNC分类器安装到Iris训练数据

现在您已经配置了DNN classifier模型,您可以使用该方法将其适用于Iris训练数据通过get_train_inputs作为input_fn训练的步骤和数目(这里,2000年):

# Fit model. classifier.fit(input_fn=get_train_inputs, steps=2000)

模型的状态保留在这里classifier,这意味着你可以反复训练,如果你喜欢。例如,以上相当于以下内容:

classifier.fit(x=training_set.data, y=training_set.target, steps=1000) classifier.fit(x=training_set.data, y=training_set.target, steps=1000)

但是,如果您希望在列车时跟踪模型,则可能需要使用TensorFlow  来执行日志记录操作。有关 此主题的更多信息,请参阅教程 

评估模型精度

您的DNNClassifier模型适合Iris训练数据; 现在,您可以使用该方法检查其对Iris测试数据的准确性 喜欢fit, evaluate需要一个构建其输入管道的输入函数。evaluate 返回一个dict与评估结果。下面的代码经过光圈测试DATA- test_set.datatest_set.target-to evaluate并打印accuracy从结果:

# Define the test inputs def get_test_inputs():   x = tf.constant(test_set.data)   y = tf.constant(test_set.target)   return x, y # Evaluate accuracy. accuracy_score = classifier.evaluate(input_fn=get_test_inputs,                                      steps=1)["accuracy"] print("\nTest Accuracy: {0:f}\n".format(accuracy_score))
注意:
这里的steps论据evaluate很重要。 通常运行直到它到达输入的末尾。这对于评估一组文件是完美的,但是这里使用的常量将永远不会抛出OutOfRangeError或 StopIteration正在期待。

当您运行完整的脚本时,它会打印一些接近:

Test Accuracy: 0.966667

您的准确性结果可能有所不同,但应高于90%。对于相对较小的数据集来说不错!

分类新样本

使用估计器的predict()方法对新样本进行分类。例如,说你有这两个新花样:

萼片长度 萼片宽度 花瓣长度 花瓣宽度
6.4 3.2 4.5 1.5
5.8 3.1 5 1.7

您可以使用该predict()方法预测其物种predict返回一个生成器,可以很容易地转换成一个列表。以下代码检索并打印类预测:

# Classify two new flower samples. def new_samples():   return np.array(     [[6.4, 3.2, 4.5, 1.5],      [5.8, 3.1, 5.0, 1.7]], dtype=np.float32) predictions = list(classifier.predict(input_fn=new_samples)) print(     "New Samples, Class Predictions:    {}\n"     .format(predictions))

您的结果应如下所示:

New Samples, Class Predictions:    [1 2]

因此,该模型预测第一个样品是Iris versicolor,第二个样品是Iris virginica

其他资源

  • 有关tf.contrib.learn的更多参考资料,请参阅官方 

  • 要了解有关使用tf.contrib.learn创建线性模型的更多信息,请参阅 

  • 要使用tf.contrib.learn API构建自己的Estimator,请查看 

  • 要在浏览器中实验神经网络建模和可视化,请查看

  • 有关神经网络的更多高级教程,请参阅 

转载地址:https://lqfarmer.blog.csdn.net/article/details/77239217 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:使用tf.contrib.learn构建输入函数
下一篇:tensorflow的作用机制

发表评论

最新留言

感谢大佬
[***.8.128.20]2024年05月01日 03时01分38秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章