博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
TensorFlow学习笔记(8)--网络模型的保存和读取【转】
阅读量:7020 次
发布时间:2019-06-28

本文共 2566 字,大约阅读时间需要 8 分钟。

转自:http://blog.csdn.net/lwplwf/article/details/62419087

 

之前的笔记里实现了softmax回归分类、简单的含有一个隐层的神经网络、卷积神经网络等等,但是这些代码在训练完成之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练好的神经网络模型持久化,这就是这篇笔记里要写的东西。

TensorFlow提供了一个非常简单的API,即tf.train.Saver类来保存和还原一个神经网络模型。


下面代码给出了保存TensorFlow模型的方法:

import tensorflow as tf# 声明两个变量v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer() # 初始化全部变量 saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型 with tf.Session() as sess: sess.run(init_op) print("v1:", sess.run(v1)) # 打印v1、v2的值一会读取之后对比 print("v2:", sess.run(v2)) saver_path = saver.save(sess, "save/model.ckpt") # 将模型保存到save/model.ckpt文件 print("Model saved in file:", saver_path)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

这段代码中,通过saver.save函数将TensorFlow模型保存到了save/model.ckpt文件中,这里代码中指定路径为"save/model.ckpt",也就是保存到了当前程序所在文件夹里面的save文件夹中。

TensorFlow模型会保存在后缀为.ckpt的文件中。保存后在save这个文件夹中实际会出现3个文件,因为TensorFlow会将计算图的结构和图上参数取值分开保存。

  • model.ckpt.meta文件保存了TensorFlow计算图的结构,可以理解为神经网络的网络结构
  • model.ckpt文件保存了TensorFlow程序中每一个变量的取值
  • checkpoint文件保存了一个目录下所有的模型文件列表

这里写图片描述


下面代码给出了加载TensorFlow模型的方法:

可以对比一下v1、v2的值是随机初始化的值还是和之前保存的值是一样的?

import tensorflow as tf# 使用和保存模型代码中一样的方式来声明变量v1 = tf.Variable(tf.random_normal([1, 2]), name="v1") v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") saver = tf.train.Saver() # 声明tf.train.Saver类用于保存模型 with tf.Session() as sess: saver.restore(sess, "save/model.ckpt") # 即将固化到硬盘中的Session从保存路径再读取出来 print("v1:", sess.run(v1)) # 打印v1、v2的值和之前的进行对比 print("v2:", sess.run(v2)) print("Model Restored")
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

运行结果:

v1: [[ 0.76705766  1.82217288]]v2: [[-0.98012197  1.2369734   0.5797025 ] [ 2.50458145  0.81897354  0.07858191]]Model Restored
  • 1
  • 2
  • 3
  • 4
  • 1
  • 2
  • 3
  • 4

这段加载模型的代码基本上和保存模型的代码是一样的。也是先定义了TensorFlow计算图上所有的运算,并声明了一个tf.train.Saver类。两段唯一的不同是,在加载模型的代码中没有运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。 

也就是说使用TensorFlow完成了一次模型的保存和读取的操作。



如果不希望重复定义图上的运算,也可以直接加载已经持久化的图:

import tensorflow as tf# 在下面的代码中,默认加载了TensorFlow计算图上定义的全部变量# 直接加载持久化的图saver = tf.train.import_meta_graph("save/model.ckpt.meta") with tf.Session() as sess: saver.restore(sess, "save/model.ckpt") # 通过张量的名称来获取张量 print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

运行程序,输出:

[[ 0.76705766  1.82217288]]
  • 1
  • 1

有时可能只需要保存或者加载部分变量。 

比如,可能有一个之前训练好的5层神经网络模型,但现在想写一个6层的神经网络,那么可以将之前5层神经网络中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。

为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存或者加载的变量。比如在加载模型的代码中使用saver = tf.train.Saver([v1])命令来构建tf.train.Saver类,那么只有变量v1会被加载进来。

…未完待续

你可能感兴趣的文章
symantec5220牛刀小试系列(4)
查看>>
Go笔记-flag参数解析
查看>>
二分查找例子 记录一下
查看>>
printf的返回值问题(转)
查看>>
常用选择器
查看>>
eclipse 调试多线程
查看>>
各种排序算法
查看>>
条码控件商IDAutomation极大改善了Barcode Image Generator性能
查看>>
AxureRP7.0基础教程系列 部件详解 Menu 菜单
查看>>
SpringMVC 搭建及详解
查看>>
PCC-S-02201
查看>>
solr的分词器
查看>>
源码包安装
查看>>
docker hello-world
查看>>
5个词 带你回忆2014 IT安全圈
查看>>
微信小程序从入门到项目实战
查看>>
CentOS7 安装执行 VmwareCore
查看>>
EL表达式详解
查看>>
docker容器
查看>>
系统性能指标查看方法-Linux
查看>>