python调用训练好的模型_TensorFlow 调用预训练好的模型—— Python 实现
发布日期:2021-10-31 18:34:20 浏览次数:22 分类:技术文章

本文共 1880 字,大约阅读时间需要 6 分钟。

1. 准备预训练好的模型

TensorFlow 预训练好的模型被保存为以下四个文件

1

data 文件是训练好的参数值,meta 文件是定义的神经网络图,checkpoint 文件是所有模型的保存路径,如下所示,为简单起见只保留了一个模型。

model_checkpoint_path: "/home/senius/python/c_python/test/model-40"

all_model_checkpoint_paths: "/home/senius/python/c_python/test/model-40"

复制代码

2. 导入模型图、参数值和相关变量

import tensorflow as tf

import numpy as np

sess = tf.Session()

X = None # input

yhat = None # output

def load_model():

"""

Loading the pre-trained model and parameters.

"""

global X, yhat

modelpath = r'/home/senius/python/c_python/test/'

saver = tf.train.import_meta_graph(modelpath + 'model-40.meta')

saver.restore(sess, tf.train.latest_checkpoint(modelpath))

graph = tf.get_default_graph()

X = graph.get_tensor_by_name("X:0")

yhat = graph.get_tensor_by_name("tanh:0")

print('Successfully load the pre-trained model!')

复制代码通过 saver.restore 我们可以得到预训练的所有参数值,然后再通过 graph.get_tensor_by_name 得到模型的输入张量和我们想要的输出张量。

3. 运行前向传播过程得到预测值

def predict(txtdata):

"""

Convert data to Numpy array which has a shape of (-1, 41, 41, 41 3).

Test a single example.

Arg:

txtdata: Array in C.

Returns:

Three coordinates of a face normal.

"""

global X, yhat

data = np.array(txtdata)

data = data.reshape(-1, 41, 41, 41, 3)

output = sess.run(yhat, feed_dict={X: data}) # (-1, 3)

output = output.reshape(-1, 1)

ret = output.tolist()

return ret

复制代码通过 feed_dict 喂入测试数据,然后 run 输出的张量我们就可以得到预测值。

4. 测试

load_model()

testdata = np.fromfile('/home/senius/python/c_python/test/04t30t00.npy', dtype=np.float32)

testdata = testdata.reshape(-1, 41, 41, 41, 3) # (150, 41, 41, 41, 3)

testdata = testdata[0:2, ...] # the first two examples

txtdata = testdata.tolist()

output = predict(txtdata)

print(output)

# [[-0.13345889747142792], [0.5858198404312134], [-0.7211828231811523],

# [-0.03778800368309021], [0.9978875517845154], [0.06522832065820694]]

复制代码本例输入是一个三维网格模型处理后的 [41, 41, 41, 3] 的数据,输出一个表面法向量坐标 (x, y, z)。

获取更多精彩,请关注「seniusen」!1

转载地址:https://blog.csdn.net/weixin_40006779/article/details/109984848 如侵犯您的版权,请留言回复原文章的地址,我们会给您删除此文章,给您带来不便请您谅解!

上一篇:angular高级编程第3版下载_UNIX系统编程宝典,每一本都值得程序员珍藏
下一篇:safari快捷图标不见了_桌面图标不见了怎么办?这里有妙招

发表评论

最新留言

感谢大佬
[***.8.128.20]2024年03月17日 09时00分54秒

关于作者

    喝酒易醉,品茶养心,人生如梦,品茶悟道,何以解忧?唯有杜康!
-- 愿君每日到此一游!

推荐文章

mysql service5.7_Mysql5.7服务下载安装 2019-04-21
mysql查看线程完整执行命令_MySQL-查看运行的线程-SHOW PROCESSLIST 2019-04-21
mysql 更新数据 字符串_批量替换 MySQL 指定字段中的字符串 2019-04-21
web开发 mysql安装_mysqlinstallerwebcommunity5.7.21.0.msi安装图文教程 2019-04-21
mysql concat 整数型_MySQL 数字类型转换函数(concat/cast) 2019-04-21
mysql单元格函数是_MySQL常用内置函数 2019-04-21
mysql 怎么字段分裂_你可以分裂/爆炸MySQL查询中的字段吗? 2019-04-21
mysql server卸载出错_Mysql卸载问题Start Server卡住报错解决方法 2019-04-21
全国省市区 mysql_2017全国省市区数据库【含三款数据库】 2019-04-21
druid加载MySQL驱动原理_你好,想知道mybatis+druid+jdbc 原理介绍? 2019-04-21
mysql 怎样链接jdbc_jdbc怎么链接mysql数据库 2019-04-21
mysql学生课程表试题_Mysql练习之 学生表、课程表 、教师表、成绩表 50道练习题... 2019-04-21
java exec封装_Java 执行系统命令工具类(commons-exec) 2019-04-21
php sha512解密,PHP加密函数 sha256 sha512 sha256_file() sha512_file() 2019-04-21
mysql里可以用cube吗_sql server的cube操作符使用详解_mysql 2019-04-21
php mysql 图书_使用PHP+MySQL来对图书管理系统进行构建 2019-04-21
单片机c语言 int1,51单片机into、int1中断计数c语言源程序.doc 2019-04-21
c语言课程设计工资管理建库,C语言课程设计工资管理系统参考.doc 2019-04-21
c语言case中途跳出,break语句在switch结构语句中的作用是终止某个case,并跳出switch结构语句。... 2019-04-21
c51写c语言外部ram头文件,C51中访问外部RAM的方法 2019-04-21