TensorFlow Model Persistence
The main class is tf.train.Saver; a model is written to disk with its saver.save method.
Saving produces three files:

- ckpt.meta: the structure of the computation graph
- ckpt: the value of every variable (in newer TF 1.x releases this part is split into .data-* and .index files)
- checkpoint: a text file listing all model checkpoint files in the directory
Saving the Graph

Saving the computation graph and the variable values:
```python
import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()
saver1 = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    print("All variables:", tf.global_variables())
    saver1.save(sess, "./SaveModels/model1.ckpt")
```
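After running the code above, the files described earlier appear under ./SaveModels/. As a quick sanity check, the contents of a checkpoint can be inspected with tf.train.NewCheckpointReader; the following is a minimal sketch (not part of the original example), assuming the checkpoint just saved exists:

```python
import tensorflow as tf

# Minimal sketch: list what the checkpoint written above actually contains.
reader = tf.train.NewCheckpointReader('./SaveModels/model1.ckpt')

# Variable name -> shape for every variable stored in the checkpoint.
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape, reader.get_tensor(name))
```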
Reading variables from the computation graph

Use the saver.restore method:
```python
v1 = tf.Variable(tf.constant(0.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(0.0, shape=[1]), name='v2')
result = v1 + v2

saver2 = tf.train.Saver()

with tf.Session() as sess:
    # No initialization is needed; restore() loads the saved values.
    saver2.restore(sess, './SaveModels/model1.ckpt')
    print(sess.run(result))
```
Loading the persisted graph (including its operations)
```python
saver = tf.train.import_meta_graph('./SaveModels/model1.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, './SaveModels/model1.ckpt')
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))
```
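The name "add:0" refers to the first output of the operation named "add" (the v1 + v2 node). When the tensor names in a saved graph are not known in advance, they can be discovered by listing the operations of the imported graph; a small sketch, assuming the same meta file as above:

```python
import tensorflow as tf

tf.train.import_meta_graph('./SaveModels/model1.ckpt.meta')

# Each tensor name is "<op_name>:<output_index>", e.g. "add:0".
for op in tf.get_default_graph().get_operations():
    print(op.name, [t.name for t in op.outputs])
```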
Loading only specified variables

To load only some variables, pass them to the constructor: tf.train.Saver(paramNeedLoad), where paramNeedLoad is the list of variables to load (this is the constructor's var_list argument); see the sketch below.
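For instance, a Saver built from the list [v1] restores only v1 from the checkpoint saved earlier; any variable not in the list still has to be initialized by hand. This is a minimal sketch, not from the original post, reusing ./SaveModels/model1.ckpt:

```python
import tensorflow as tf

v1 = tf.Variable(tf.constant(0.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(0.0, shape=[1]), name='v2')

# Only v1 is listed, so only v1 is read from the checkpoint.
saver = tf.train.Saver([v1])

with tf.Session() as sess:
    sess.run(v2.initializer)  # v2 is not restored and must be initialized
    saver.restore(sess, './SaveModels/model1.ckpt')
    print(sess.run([v1, v2]))  # v1 comes from the checkpoint, v2 stays 0.0
```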
Saver also supports renaming variables while loading. One motivation is convenient use of a variable's moving average: the moving average is maintained in a shadow variable, so if the shadow variable is mapped directly onto the variable itself when the model is loaded, the trained model can be used without calling a function to fetch the moving average.
```python
# The variables in the checkpoint are named 'v1' and 'v2'; here they are
# restored into variables with different names via a renaming dictionary.
V1 = tf.Variable(tf.constant(0.0, shape=[1]), name='other-v1')
V2 = tf.Variable(tf.constant(0.0, shape=[1]), name='other-v2')
result = V1 + V2

saver = tf.train.Saver({"v1": V1, "v2": V2})

with tf.Session() as sess:
    saver.restore(sess, './SaveModels/model1.ckpt')
    print(V1)
    print(V2)
    print("Result=", sess.run(result))

"""
<tf.Variable 'other-v1:0' shape=(1,) dtype=float32_ref>
<tf.Variable 'other-v2:0' shape=(1,) dtype=float32_ref>
Result= [ 3.]
"""
```
Example: saving and then restoring moving averages:
```python
v1 = tf.Variable(0, dtype=tf.float32, name='v1')
v2 = tf.Variable(0, dtype=tf.float32, name='v2')

print("Before: ")
for tempV in tf.global_variables():
    print(tempV.name)

ema = tf.train.ExponentialMovingAverage(0.99)
# apply() creates a shadow variable for every variable it is given.
maintain_avg_op = ema.apply(tf.global_variables())

print("After: ")
for tempV in tf.global_variables():
    print(tempV.name)

saver = tf.train.Saver()

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v1, 10))
    sess.run(tf.assign(v2, 5))
    sess.run(maintain_avg_op)

    saver.save(sess, "./SaveModels/model1.ckpt")
    print(sess.run([v1, v2, ema.average(v1), ema.average(v2)]))

"""
Before: 
v1:0
v2:0
After: 
v1:0
v2:0
v1/ExponentialMovingAverage:0
v2/ExponentialMovingAverage:0
[10.0, 5.0, 0.099999905, 0.049999952]
"""
```
```python
SAVE_MODEL_PATH = './SaveModels/model1.ckpt'  # path used by the save example above

avgV1 = tf.Variable(0, dtype=tf.float32, name='avg-v1')
avgV2 = tf.Variable(0, dtype=tf.float32, name='avg-v2')

# Map the shadow variables in the checkpoint onto the local variables.
saver = tf.train.Saver({"v1/ExponentialMovingAverage": avgV1,
                        "v2/ExponentialMovingAverage": avgV2})

with tf.Session() as sess:
    saver.restore(sess, SAVE_MODEL_PATH)
    print("avg v1", sess.run(avgV1))
    print("avg v2", sess.run(avgV2))

"""
avg v1 0.0999999
avg v2 0.05
"""
```
As the code above shows, building the renaming dictionary by hand becomes inconvenient when there are many variables. The tf.train.ExponentialMovingAverage class provides variables_to_restore(), which generates this renaming dictionary automatically. The previous code can be rewritten as:
```python
SAVE_MODEL_PATH = './SaveModels/model1.ckpt'  # same checkpoint as before

avgV1 = tf.Variable(0, dtype=tf.float32, name='v1')
avgV2 = tf.Variable(0, dtype=tf.float32, name='v2')

ema = tf.train.ExponentialMovingAverage(0.99)
renameDict = ema.variables_to_restore()
print("renameDict=", renameDict)

saver = tf.train.Saver(renameDict)

with tf.Session() as sess:
    saver.restore(sess, SAVE_MODEL_PATH)
    print("avg v1", sess.run(avgV1))
    print("avg v2", sess.run(avgV2))

"""
renameDict= {'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>,
             'v1/ExponentialMovingAverage': <tf.Variable 'v1:0' shape=() dtype=float32_ref>}
avg v1 0.0999999
avg v2 0.05
"""
```
convert_variables_to_constants: saving variables as constants
The methods above save all the information needed to run a TensorFlow program, but not all of it is always required.
For testing or offline inference, for example, it is enough to know how to compute the output layer from the input layer by forward propagation; auxiliary nodes such as variable initialization are unnecessary.
The graph_util.convert_variables_to_constants function saves variables and their values as constants in the graph itself.
```python
from tensorflow.python.framework import graph_util

v3 = tf.Variable(tf.constant(3.0, shape=[1]), name='v3')
v4 = tf.Variable(tf.constant(4.0, shape=[1]), name='v4')
result = v3 + v4

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    graph_def = tf.get_default_graph().as_graph_def()
    print("Result: ", result)

    # Freeze the graph: every variable reachable from 'add' becomes a constant.
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    print('output_graph_def: ', output_graph_def)
    """
    Output:
    output_graph_def:  node {
      name: "v3"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
              dim {
                size: 1
              }
            }
            float_val: 3.0
          }
        }
      }
    }
    node {
      name: "v3/read"
      op: "Identity"
      input: "v3"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "_class"
        value {
          list {
            s: "loc:@v3"
          }
        }
      }
    }
    node {
      name: "v4"
      op: "Const"
      attr {
        key: "dtype"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "value"
        value {
          tensor {
            dtype: DT_FLOAT
            tensor_shape {
              dim {
                size: 1
              }
            }
            float_val: 4.0
          }
        }
      }
    }
    node {
      name: "v4/read"
      op: "Identity"
      input: "v4"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
      attr {
        key: "_class"
        value {
          list {
            s: "loc:@v4"
          }
        }
      }
    }
    node {
      name: "add"
      op: "Add"
      input: "v3/read"
      input: "v4/read"
      attr {
        key: "T"
        value {
          type: DT_FLOAT
        }
      }
    }
    library {
    }
    """

    # Serialize the frozen graph to a single .pb file.
    with tf.gfile.GFile("./SaveModels/model3.pb", 'wb') as f:
        outputData = output_graph_def.SerializeToString()
        print('Serialize Data: ', outputData)
        f.write(outputData)
```
The code above saves the frozen graph; the code below loads it and evaluates the result:
```python
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    filePath = './SaveModels/model3.pb'
    with gfile.FastGFile(filePath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the frozen graph and fetch the "add:0" tensor back.
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))
```
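Since a frozen .pb graph is normally used for inference, it is often imported together with an input_map that substitutes new input tensors for nodes inside it. The following is only a hedged sketch, assuming the model3.pb produced above (whose graph contains the constant node "v3"); it remaps that constant to a placeholder:

```python
import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    with gfile.FastGFile('./SaveModels/model3.pb', 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Feed a placeholder in place of the frozen constant "v3" so the same
    # graph can be evaluated for arbitrary inputs.
    x = tf.placeholder(tf.float32, shape=[1], name='new-input')
    result = tf.import_graph_def(graph_def,
                                 input_map={"v3:0": x},
                                 return_elements=["add:0"])
    # v4 is frozen at 4.0, so feeding 30.0 should print [array([34.], dtype=float32)]
    print(sess.run(result, feed_dict={x: [30.0]}))
```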