tensorflow保存模型与恢复

tensorflow 模型保存

tensorflow模型保存有两种方式。
一种tensorflow本身的保存方式,使用tf.train.Saver进行保存,即tf格式
一种是当使用tf.Keras.Model作为模型的时候还可以使用h5方式保存,即h5格式

tf格式

tf模式有两种保存文件的方式。
下面是如何使用tf格式保存模型

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import tensorflow as tf
from tensorflow.core.protobuf import saver_pb2
# Create some variables. v1 = tf.get_variable("v1", shape=[3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer = tf.zeros_initializer)

inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)

init_op = tf.global_variables_initializer()

saver = tf.train.Saver(write_version=saver_pb2.SaverDef.V1)
# saver = tf.train.Saver(write_version=saver_pb2.SaverDef.V2)
with tf.Session() as sess:
sess.run(init_op)
inc_v1.op.run()
dec_v2.op.run()
save_path = saver.save(sess, "tf_format1/model.ckpt")
print("Model saved in path: %s" % save_path)

如上writer_version分别可以为V1或者V2
V1版本是要被tensorflow抛弃的一个版本。默认和推荐的版本都是V2
V2会产生4个文件

  • checkpoint,记录最新模型的checkpoint
  • model.ckpt.data-00000-of-00001 , 网络权重信息
  • model.ckpt.index # 存储Variables的索引
  • model.ckpt.meta 一个协议缓冲,保存tensorflow中完整的graph、variables、operation、collection。当saver.save传入write_meta_graph=False时,将不会产生这个文件。keras模型保存时当使用tf格式存储时,这个参数设置的就是false。

具体使用方式可参考

h5格式

该种格式在使用Keras模型的时候才可以使用,使用model.save_weights方法然后传参save_format=h5, 后面使用的是hdf5来存储训练的模型。Keras还可以使用tf格式保存,其本质就是调用tf.train.Saver

tensorflow 模型恢复

具体可参考