俄罗斯贵宾会-俄罗斯贵宾会官网
做最好的网站

俄罗斯贵宾会[译]简明教程:Tensorflow模型的保存与恢复

当我们对模型进行了训练后,就需要把模型保存起来,便于在预测时直接用已经训练好的模型进行预测。

b) Checkpoint file:

它是一个二进制文件,包含所有的权重,偏置,导数和其他保存变量的值。文件后缀为: .ckpt。但自从0.11版本之后,Temsorflow作了改变,不再是一个单独的.ckpt文件,取而代之的是两个文件:

<<mymodel.data-00000-of-00001>>

<<mymodel.index>>

.data文件包含着训练好的变量的值,除此之外,Tensorflow还有一个名为checkpoint的文件,持续记录着最新的保存数据。

所以,总结下来,0.10之后的Tensorflow模型如下图所示:

而,0.11版本之前的Tensorflow模型,仅仅包含三个文件:

<<inception_v1.meta>>

<<inception_v1.ckpt>>

<<checkpoint>>

保存模型的权重和偏置值

假设我们已经训练好了模型,其中有关于weights和biases的值,例如:

import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')

然后我们初始化这些变量的值,假装是训练后被设置上的值:

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

最后进行保存:

# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)

这样在打印出:

保存的路径为: D:/todel/python/saver/save_net.ckpt

在那个目录下,我们看到:
俄罗斯贵宾会 1

这样,这些训练后的参数就被保存起来了。

完整的保存参数的代码为:

import tensorflow as tf
# 保存到文件
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')

init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# 创建saver
saver = tf.train.Saver()
save_path = saver.save(sess, "D:/todel/python/saver/save_net.ckpt")
print("保存的路径为:", save_path)

3. 导入预训练的模型

如果你想要使用别人训练好的模型做fine-tuning,有两件事需要做:

恢复模型的权重和偏置值

在我们训练好模型并把训练后的权重和偏置值保存了之后,当我们需要进行预测时,只要读取这个已经保存好的权重和偏置值就可以进行预测了。
当然,这里的模型结构还是需要进行创建的,因为我们保存的仅仅是权重值和偏置值。

首先定义要恢复的权重和偏置值的结构:

import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

注意:其中的name要跟之前保存时一致。

然后进行加载:

saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))

这样输出为:

weights: [[ 1.  2.  3.]
 [ 3.  4.  5.]]
biases: [[ 1.  2.  3.]]

就是前面我们保存的内容被恢复出来了。

完整的恢复代码为:

import tensorflow as tf
import numpy as np
# 定义权重和偏置值的结构,但其中的数值随便填
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

saver = tf.train.Saver()
sess = tf.Session()
# 不需要对变量进行初始化,因为这些变量的值我们会从saver中进行恢复
saver.restore(sess, "D:/todel/python/saver/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))

这篇tensorflow的教程,将解释:

这篇教程,假定读者对神经网络的训练有基本的了解。如果不是,请先阅读Tensorflow Tutorial 2: image classifier using convolutional neural network俄罗斯贵宾会,,然后阅读本文。

2. 保存一个Tensorflow模型:

假设,你正在训练一个卷积神经网络,用于图片分类。作为一个标准操作,你持续观测Loss function和Accuracy。一旦你看到网络收敛,你可以人为停止训练或者只训练固定数目的epochs。当训练完成之后,我们想要保存所有的变量和网络图(network graph)到一个文件,以便日后使用。因此,在Tensorflow中,为了保存graph和变量,我们应该新建一个tf.train.Saver()类。

谨记Tensorflow的变量只有在一个session中才是有效的。因此,你不得不在一个session中保存模型,使用刚刚新建的saver对象,调用save方法,如下:

这里,sess是一个session对象,“my-test-model”是你想要保存的模型的名字。完整的例子如下:

如果,我们想要在1000次迭代之后保存模型,可以传入表示步数的参数:

这行代码将添加‘-1000’至模型的名字,以下文件将被建立:

假设,训练时,我们每隔1000次迭代保存一次模型,因此,.meta文件第1000次迭代生成.meta文件后,我们不必要每次新建.meta文件(即在2000,3000次等迭代无须新建.meta文件)。我们仅仅保存最新的迭代模型。因为graph结构并没有改变,因此,也没必要写meta-graph,使用如下代码:

如果你想要只记录最新的4个模型,并每隔2个小时保存一个模型,可以使用这两个参数:max_to_keep和keep_checkpoint_every_n_hours,如下:

需要指出的是,如果我们在tf.train.Saver()中不指定任何事情,它将保存所有的变量。如果,我们不想保存所有的变量,仅仅是一部分。我们可以指定想要保存的变量或集合。当新建tf.train.Saver实例时,传递给它一个想要保存的变量的列表或者字典。看下面的例子:

可以保存Tensorflow Graph的指定的需要的部分。

4. 使用恢复模型

既然你已经理解如何保存并恢复Tensorflow模型,让我们养成一个规范去恢复任意预训练模型,并使用它做预测,fine-tuning或者进一步训练。不管什么时候使用Tensorflow,你将定义一个Graph,包含输入,一些超参数,如learning rate, global step等。一个标准的喂入数据和超参数的方式是使用placeholders。让我们构建一个小的使用placeholders的网络,并保存它。值得指出的是。当网络被保存。placeholders的值并未保存。

现在,当我们想要恢复模型时,不仅需要恢复graph和权重,也需要准备新的feed_dict去喂新的训练数据给网络。我们可以通过graph.get_tensor_by_name()等方法得到保存的ops和placeholder变量的引用。

如果我们仅仅想要在网络上跑不同的数据,可以通过feed_dict传递新的数据给网络。

如果想要增加更多的操作(增加更多的layers)到graph里,并训练它。当然,你也可以如下:

但是,可以只恢复一部分的graph然后增加一些操作进行fine-tuning么?当然可以。利用graph.get_tensor_by_name()方法得到相应操作的引用,在顶层构建网络。这里有个实际的例子。我们加载一个预训练的VGG网络,改变输出的单元数目为2,利用新的训练数据fine-tuning。

希望这篇文章能让你清晰地理解Tensorflow模型的保存和恢复。

转载请注明来源,谢谢。

本文由俄罗斯贵宾会发布于编程,转载请注明出处:俄罗斯贵宾会[译]简明教程:Tensorflow模型的保存与恢复

您可能还会对下面的文章感兴趣: