TensorFlow模型文件保存和读取
一、模型文件的保存
在训练一个TensorFlow模型之后,我们可以将训练好的模型保存成文件,这样可以方便下一次对新的数据进行预测的时候直接加载训练好的模型即可获得结果,下面通过TensorFlow提供的tf.train.Saver函数,将一个模型保存成文件,一般习惯性的将TensorFlow的模型文件命名为*.ckpt文件。
[python] view plain copy
- <span style="font-size:14px;">import tensorflow as tf
- if __name__ == "__main__":
- #定义两个变量
- a = tf.Variable(tf.constant(1.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(2.0,shape=[1],name="b"))
- c = a + b
- init = tf.initialize_all_variables()
- sess = tf.Session()
- sess.run(init)
- #声明一个保存
- saver = tf.train.Saver()
- saver.save(sess,"./model.ckpt")</span>
如果,在运行程序的时候报ValueError: Parent directory of model.ckpt doesn't exist, can't save.,只需要将保存文件的路径由model.ckpt改成./model.ckpt即可。运行完上面的代码之后,我们会发现在当前的程序目录下产生四个文件checkpoint、model.ckpt.data-00000-of-00001、model.ckpt.index、model.ckpt.meta。会产生四个文件的原因,之前有介绍过TensorFlow的程序是由计算图所组成的,所以在持久化的时候TensorFlow会将计算图的结果和图上的参数值分成不同的文件进行保存。二、模型文件的读取
TensorFlow对于模型文件的读取方式也提供了几种方法,根据读取不同的文件来获取不同的信息。
1、加载model.ckpt文件来初始化变量
[python] view plain copy
- <span style="font-size:14px;"> a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))
- c = a + b
- saver = tf.train.Saver()
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(c))
- #[ 3.]</span>
在声明变量的时候,变量的名字,shape要与保存的模型文件一致,无论给变量a和b的初始值设置为什么,最后输出的结果总是3,因为在保存模型文件的时候,已经记录了变量的初始值。而且,在加载model.ckpt文件的时候也不需要对变量进行初始化操作。2、加载持久化图获取全部变量
[python] view plain copy
- <span style="font-size:14px;"> saver = tf.train.import_meta_graph("model.ckpt.meta")
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(tf.get_default_graph().get_tensor_by_name("a:0")))
- #[ 1.]
- print(sess.run(tf.get_default_graph().get_tensor_by_name("b:0")))
- #[ 2.]
- print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
- #[ 3.]</span>
通过加载model.cpt.meta文件和model.ckpt文件来获取全部的变量,然后通过变量的名称来获取变量的值,在通过变量的名字来获取变量的时候需要注意的是,比如说变量a的name为"a",但是在使用名字的时候不能直接使用“a”来获取变量的值,如果直接使用“a”的话,会报ValueError: The name 'a' refers to an Operation, not a Tensor. Tensor names must be of the form "<op_name>:<output_index>".错误的原因就是需要以<op_name>:<output_index>来获取变量的值,意思就是名字和下标结合,如"a:0"。还需要注意的就是在获取变量c的值的时候,不是通过c,而是通过"add:0",因为变量a和b直接求和,生成的变量c的名字TensorFlow默认为"add"。3、加载指定列表变量
[python] view plain copy
- <span style="font-size:14px;"> a = tf.Variable(tf.constant(3.0,shape=[1],name="a"))
- b = tf.Variable(tf.constant(4.0,shape=[1],name="b"))
- c = a + b
- saver = tf.train.Saver([a,b])
- sess = tf.Session()
- saver.restore(sess,"model.ckpt")
- print(sess.run(a))
- #[ 1.]
- print(sess.run(b))
- #[ 2.]</span>
通过在第一种方式的基础上,初始化Saver的时候指定一个列表,在初始化模型文件中的变量时,只会加载指定列表的变量。如果在上面代码的基础上,在最后在加一句print(sess.run(c)),输出的结果为[3.],明明没有指定加载c,为什么还能输出3呢?其实,原因也很简单,因为我们已经初始化了变量a和b,所以通过计算a+b自然就可以计算出c了。如果,指定加载列表为[a]而输出b的话,就会报tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value Variable_1
[[Node: _retval_Variable_1_0_0 = _Retval[T=DT_FLOAT, index=0, _device="/job:localhost/replica:0/task:0/cpu:0"](Variable_1)]],使用一个没有初始化的变量。
4、加载变量名的重命名
tensorfow提供了一种方法可以修改加载模型中的变量名,通过tf.train.Saver(),带参的形式来修改变量名称。
[python] view plain copy
- <span style="font-size:14px;"> #重新定义两个变量v1和v2
- v1 = tf.Variable(tf.constant(3.,shape=[1]),name="v1")
- v2 = tf.Variable(tf.constant(4.,shape=[1]),name="v2")
- #将模型中的变量名a重命名为v1,将模型中的变量名b重命名为v2
- save = tf.train.Saver({"a":v1,"b":v2})
- sess = tf.Session()
- save.restore(sess,"model.ckpt")
- print(sess.run(v1))
- print(sess.run(v2))</span>
通过传入一个字典,来修改TensorFlow的变量名,a和b是模型中的变量名称,而v1和v2是将变量a和b重命名之后的名称。如果,你用我的第一个程序来保存一个模型文件,通过上面的方法来修改变量的名称的时候,你会得到一个错误NotFoundError (see above for traceback): Key b not found in checkpoint,难道上面的代码有问题?其实,这个坑是在保存模型文件留下来的。要想解决这个问题,首先还是看重命名变量名称的这个程序,TensorFlow提供的tf.train.Saver({"a":v1,"b":v2})方法,它会去checkpoint这个文件中找变量名为a和b的变量,然后再修改变量名,找不到这两个变量自然就报错了。第一个程序的坑,在这句代码中tf.Variable(tf.constant(1.0,shape=[1],name="a")),也许细心的朋友已经发现问题了,其实我们只是将常量命名为了"a",并没有将变量命名,这样就导致了问题的发现。所以,要想解决这个问题,我们只需要将tf.tran.Saver中的{“a”:v1,"b":v2}修改为{"Variable":v1,"Variable_1":v2}或者将tf.Variable(tf.constant(1.0,shape=[1],name="a"))修改为tf.Variable(tf.constant(1.0,shape=[1]),name="a"),其中的Variable和Variable_1是TensorFlow默认的变量名称,我们可以通过a.name的方式查看TensorFlow中的变量名称。