TensorFlow 模型持久化
主要类 tf.train.Saver
- 方法 saver.save (saver 为 tf.train.Saver 的实例) 会保存以下三部分内容:
  - 计算图的结构
  - 每个变量的取值
  - 一个目录下所有模型文件的列表
# Example: build a tiny graph (v1 + v2) and persist it with tf.train.Saver.
import os

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='v2')
result = v1 + v2

init_op = tf.global_variables_initializer()

saver1 = tf.train.Saver()

# Saver.save does not create missing parent directories, so make sure the
# checkpoint directory exists before saving.
os.makedirs("./SaveModels", exist_ok=True)

with tf.Session() as sess:
    sess.run(init_op)
    print("All variables:", tf.global_variables())
    saver1.save(sess, "./SaveModels/model1.ckpt")
# Example: re-declare the same graph, then load the variable values from the
# checkpoint instead of running an initializer.
v1 = tf.Variable(tf.constant(0.0, shape=[1]), name='v1')
v2 = tf.Variable(tf.constant(0.0, shape=[1]), name='v2')
result = v1 + v2

restorer = tf.train.Saver()

with tf.Session() as sess:
    # restore() assigns the saved values, so no init op is run here.
    restorer.restore(sess, './SaveModels/model1.ckpt')
    restored_sum = sess.run(result)
    print(restored_sum)
# Example: load the graph structure from the .meta file, so the variables do
# not have to be re-declared in Python before restoring their values.
saver = tf.train.import_meta_graph('./SaveModels/model1.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess, './SaveModels/model1.ckpt')
    # Look the output tensor up by its name in the imported graph.
    add_tensor = tf.get_default_graph().get_tensor_by_name("add:0")
    print(sess.run(add_tensor))
tf.train.Saver(paramNeedLoad: list)
其中 paramNeedLoad 是需要加载的变量列表 (list), 只有列表中的变量会被加载.
同时也支持在加载变量时重命名 (传入一个 "保存时的变量名 -> 当前变量" 的字典), 其目的之一是为了方便使用变量的滑动平均值. 因为滑动平均值是通过影子变量维护的, 如果加载模型时直接将影子变量的取值映射到变量自身, 那么使用训练好的模型时, 就不需要再调用函数来获取变量的滑动平均值了.
# Example: restore checkpoint entries into variables that carry different
# names in the current graph.
V1 = tf.Variable(tf.constant(0.0, shape=[1]), name='other-v1')
V2 = tf.Variable(tf.constant(0.0, shape=[1]), name='other-v2')
result = V1 + V2

# Keys are the names stored in the checkpoint; values are the variables in
# this graph that receive the stored values.
checkpoint_name_to_var = {"v1": V1, "v2": V2}
saver = tf.train.Saver(checkpoint_name_to_var)

with tf.Session() as sess:
    saver.restore(sess, './SaveModels/model1.ckpt')
    print(V1)
    print(V2)
    print("Result=", sess.run(result))

"""
<tf.Variable 'other-v1:0' shape=(1,) dtype=float32_ref>
<tf.Variable 'other-v2:0' shape=(1,) dtype=float32_ref>
Result= [ 3.]
"""
# Example: maintain exponential moving averages of v1/v2 and save both the
# variables and their shadow (average) variables to a checkpoint.
import os

v1 = tf.Variable(0, dtype=tf.float32, name='v1')
v2 = tf.Variable(0, dtype=tf.float32, name='v2')

print("Before: ")
for tempV in tf.global_variables():
    print(tempV.name)

ema = tf.train.ExponentialMovingAverage(0.99)
# apply() adds one shadow variable per tracked variable; they show up in the
# "After" listing below as "<name>/ExponentialMovingAverage".
maintain_avg_op = ema.apply(tf.global_variables())

print("After: ")
for tempV in tf.global_variables():
    print(tempV.name)

# Created after ema.apply() so the shadow variables are saved as well.
saver = tf.train.Saver()

# Saver.save does not create missing parent directories, so make sure the
# checkpoint directory exists before saving.
os.makedirs("./SaveModels", exist_ok=True)

with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    sess.run(tf.assign(v1, 10))
    sess.run(tf.assign(v2, 5))
    # One averaging step with decay 0.99 moves each shadow value from 0 to
    # about 1% of the variable's value (see the sample output below).
    sess.run(maintain_avg_op)

    saver.save(sess, "./SaveModels/model1.ckpt")
    print(sess.run([v1, v2, ema.average(v1), ema.average(v2)]))

"""
Before:
v1:0
v2:0
After:
v1:0
v2:0
v1/ExponentialMovingAverage:0
v2/ExponentialMovingAverage:0
2018-08-17 11:03:37.258476: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
[10.0, 5.0, 0.099999905, 0.049999952]
"""
# Example: load the saved moving averages directly into ordinary variables by
# mapping the shadow-variable names in the checkpoint onto them.

# Checkpoint written by the moving-average save example.
SAVE_MODEL_PATH = "./SaveModels/model1.ckpt"

avgV1 = tf.Variable(0, dtype=tf.float32, name='avg-v1')
avgV2 = tf.Variable(0, dtype=tf.float32, name='avg-v2')

# Keys are the shadow-variable names stored in the checkpoint; values are the
# variables that receive the averaged values.
saver = tf.train.Saver({"v1/ExponentialMovingAverage": avgV1,
                        "v2/ExponentialMovingAverage": avgV2})

with tf.Session() as sess:
    saver.restore(sess, SAVE_MODEL_PATH)
    print("avg v1", sess.run(avgV1))
    print("avg v2", sess.run(avgV2))

"""
avg v1 0.0999999
avg v2 0.05
"""
从上面的代码可以发现, 如果 Saver 中的字典包含很多变量, 手动逐个指定就不太方便. 为此, tf.train.ExponentialMovingAverage 提供了 variables_to_restore() 方法, 可以自动生成 Saver 所需要的变量重命名字典:
# Example: let ExponentialMovingAverage build the rename dictionary instead of
# writing the shadow-variable names by hand.

# Checkpoint written by the moving-average save example.
SAVE_MODEL_PATH = "./SaveModels/model1.ckpt"

avgV1 = tf.Variable(0, dtype=tf.float32, name='v1')
avgV2 = tf.Variable(0, dtype=tf.float32, name='v2')

ema = tf.train.ExponentialMovingAverage(0.99)

# Maps "<name>/ExponentialMovingAverage" to the variable called <name> in this
# graph (see the printed dictionary in the sample output below).
renameDict = ema.variables_to_restore()
print("renameDict=", renameDict)

saver = tf.train.Saver(renameDict)

with tf.Session() as sess:
    saver.restore(sess, SAVE_MODEL_PATH)
    print("avg v1", sess.run(avgV1))
    print("avg v2", sess.run(avgV2))

"""
renameDict= {'v2/ExponentialMovingAverage': <tf.Variable 'v2:0' shape=() dtype=float32_ref>, 'v1/ExponentialMovingAverage': <tf.Variable 'v1:0' shape=() dtype=float32_ref>}
avg v1 0.0999999
avg v2 0.05
"""
convert_variables_to_constants 转成常量保存
之前的方法会保存运行 TensorFlow 程序所需要的全部信息,然而有时并不需要某些信息.
利用 convert_variables_to_constants 函数, 可以把计算图中的变量及其取值转成常量, 并且只保留计算指定输出节点所需要的那部分计算图, 最后整体保存到一个文件中:
# Example: freeze the variables of a small graph into constants and write the
# resulting GraphDef to a single .pb file.
import os

# graph_util is used below but was never imported in the snippet.
from tensorflow.python.framework import graph_util

v3 = tf.Variable(tf.constant(3.0, shape=[1]), name='v3')
v4 = tf.Variable(tf.constant(4.0, shape=[1]), name='v4')
result = v3 + v4

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)

    graph_def = tf.get_default_graph().as_graph_def()
    print("Result: ", result)

    # 'add' is the node name (not the tensor name 'add:0') of the output to keep.
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    print('output_graph_def: ', output_graph_def)

"""
Sample output
output_graph_def:  node {
  name: "v3"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 3.0
      }
    }
  }
}
node {
  name: "v3/read"
  op: "Identity"
  input: "v3"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@v3"
      }
    }
  }
}
node {
  name: "v4"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_FLOAT
        tensor_shape {
          dim {
            size: 1
          }
        }
        float_val: 4.0
      }
    }
  }
}
node {
  name: "v4/read"
  op: "Identity"
  input: "v4"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
  attr {
    key: "_class"
    value {
      list {
        s: "loc:@v4"
      }
    }
  }
}
node {
  name: "add"
  op: "Add"
  input: "v3/read"
  input: "v4/read"
  attr {
    key: "T"
    value {
      type: DT_FLOAT
    }
  }
}
library {
}
"""

# Opening the file for writing fails if the directory is missing, so create
# it first.
os.makedirs("./SaveModels", exist_ok=True)

with tf.gfile.GFile("./SaveModels/model3.pb", 'wb') as f:
    outputData = output_graph_def.SerializeToString()
    print('Serialize Data: ', outputData)
    f.write(outputData)
以上是保存节点代码, 以下是加载节点并执行结果的代码:
# Example: load the frozen GraphDef from the .pb file and evaluate its output.
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    filePath = './SaveModels/model3.pb'

    # GFile replaces the deprecated FastGFile alias and matches the class used
    # when the file was written.
    with gfile.GFile(filePath, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # return_elements takes tensor names, hence "add:0" rather than "add".
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))