前两篇 * 使用TensorFlow C++ API构建线上预测服务 - 篇1 * 使用TensorFlow C++ API构建线上预测服务 - 篇2

在线下训练时,为了效率考虑,我们经常把数据转成TFRecord格式,然后直接调用TensorFlow提供的Reader来读入TFRecord数据。这样在生成的graph.pb中,Reader会对应多个节点,如果在c++中直接导入这个graph.pb我们就不能使用std::vector<std::pair<std::string, tensorflow::Tensor>>作为session.Run(...)的输入了。这篇文章讲解一下怎样处理这种情况。

Freeze Graph

回顾一下上篇讲到的怎样使用freeze graph。

python ../../python/freeze_graph.py \
--checkpoint_dir='./checkpoint' \
--output_node_names='predict/add' \
--output_dir='./model'
其实,这里还有一个可选输入,即--graph_pb,如果设定这个,相当于不用meta文件里的graph,而是用这个网络去freeze。 这个参数不一定非要用训练时保存的网络,可以指定任何网络。讲到这里你可能就明白我们的方案是什么了。
python ../../python/freeze_graph.py \
--checkpoint_dir='./checkpoint' \
--graph_pb='./model/predict_graph.pb' \
--output_node_names='predict/add' \
--output_dir='./model'

具体方案

由于session.Run(...)只能接受std::vector<std::pair<std::string, tensorflow::Tensor>>作为网络输入,那么我们可以构造一个新网络,这个网络和训练时的网络几乎一样,只不过输入部分不使用Reader,而是用tf.Placeholder代替。我们把新网络保存成predict_graph.pb,把它和训练产出的checkpoint进行freeze,即可得到可以用c++导入的一个新网络pb,用这个pb上线就可以了。

例子

训练网络

# input
with tf.name_scope('input'):
filename_queue = tf.train.string_input_producer(
tf.train.match_filenames_once(file_name), num_epochs=max_epoch)
serialized_example = self.Decode(filename_queue)
capacity = thread_num * batch_size + min_after_dequeue
batch_serialized_example = tf.train.shuffle_batch(
[serialized_example],
batch_size=batch_size,
num_threads=thread_num,
capacity=capacity,
min_after_dequeue=min_after_dequeue)
features = {}
features['label'] = tf.FixedLenFeature([], tf.float32)
features['sparse_id'] = tf.VarLenFeature(tf.int64)
features['sparse_val'] = tf.VarLenFeature(tf.float32)
instance = tf.parse_example(batch_serialized_example, features)
label = instance['label']
sparse_id = instance['sparse_id']
sparse_val = instance['sparse_val']

# network
with tf.variable_scope("emb_layer"):
embedding_variable = tf.Variable(tf.truncated_normal([100000, 50], stddev=0.05), name='emb_var')
embedding = tf.nn.embedding_lookup_sparse(embedding_variable, sparse_id, sparse_val], "mod", combiner="sum")
...

预测网络

# input
with tf.name_scope('input'):
with tf.variable_scope('sparse_field'):
with tf.variable_scope('index'):
sparse_index = tf.placeholder(tf.int64)
with tf.variable_scope('id'):
sparse_ids = tf.placeholder(tf.int64)
with tf.variable_scope('value'):
sparse_vals = tf.placeholder(tf.float32)
with tf.variable_scope('shape'):
sparse_shape = tf.placeholder(tf.int64)
sparse_id = tf.SparseTensor(sparse_index, sparse_ids, self.sparse_shape)
sparse_val = tf.SparseTensor(sparse_index, sparse_vals, sparse_shape)
with tf.variable_scope('label'):
label = tf.placeholder(tf.float32)

# network
with tf.variable_scope("emb_layer"):
embedding_variable = tf.Variable(tf.truncated_normal([100000, 50], stddev=0.05), name='emb_var')
embedding = tf.nn.embedding_lookup_sparse(embedding_variable, sparse_id, sparse_val], "mod", combiner="sum")
...

使用训练网络训练后保存checkpoint,然后保存预测网络的graph.pb,直接调用freeze把两者生成一个新的graph.pb即可,c++线上预测时只需为预测网络的输入部分构造所需几个Tensor作为输入即可。