tensorflow保存读取-【老鱼学tensorflow】

2018-03-01 07:50:47来源:cnblogs.com作者:dreampursuer人点击

分享

当我们对模型进行了训练后,就需要把模型保存起来,便于在预测时直接用已经训练好的模型进行预测。

保存模型的权重和偏置值

假设我们已经训练好了模型,其中有关于weights和biases的值,例如:

import tensorflow as tf# 保存到文件W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')

然后我们初始化这些变量的值,假装是训练后被设置上的值:

init = tf.global_variables_initializer()sess = tf.Session()sess.run(init)

最后进行保存:

# 创建saversaver = tf.train.Saver()save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")print("保存的路径为:", save_path)

这样在打印出:

保存的路径为: D:/todel/python/saver/save_net.ckpt

在那个目录下,我们看到:

这样,这些训练后的参数就被保存起来了。

完整的保存参数的代码为:

import tensorflow as tf# 保存到文件W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')init = tf.global_variables_initializer()sess = tf.Session()sess.run(init)# 创建saversaver = tf.train.Saver()save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")print("保存的路径为:", save_path)

恢复模型的权重和偏置值

在我们训练好模型并把训练后的权重和偏置值保存了之后,当我们需要进行预测时,只要读取这个已经保存好的权重和偏置值就可以进行预测了。
当然,这里的模型结构还是需要进行创建的,因为我们保存的仅仅是权重值和偏置值。

首先定义要恢复的权重和偏置值的结构:

import tensorflow as tfimport numpy as np# 定义权重和偏置值的结构,但其中的数值随便填W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

注意:其中的name要跟之前保存时一致。

然后进行加载:

saver = tf.train.Saver()sess = tf.Session()# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")print("weights:", sess.run(W))print("biases:", sess.run(b))

这样输出为:

weights: [[ 1.  2.  3.] [ 3.  4.  5.]]biases: [[ 1.  2.  3.]]

就是前面我们保存的内容被恢复出来了。

完整的恢复代码为:

import tensorflow as tfimport numpy as np# 定义权重和偏置值的结构,但其中的数值随便填W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")saver = tf.train.Saver()sess = tf.Session()# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")print("weights:", sess.run(W))print("biases:", sess.run(b))

最新文章

123

最新摄影

微信扫一扫

第七城市微信公众平台