全部产品
云市场

TensorFlow模型导出示例

更新时间:2019-05-27 14:57:27

SavedModel格式简介

Tensorflow模型部署成在线服务需首先将模型导出成官方定义的Savedmodel格式,SavedModel格式的模型是目前Tensorflow官方推荐的导出模型格式。SavedModel格式模型目录结构如下所示:

  1. assets/
  2. variables/
  3. variables.data-00000-of-00001
  4. variables.index
  5. saved_model.pb|saved_model.pbtxt

其中:

  • assets是一个可选目录,用于存放预测时的辅助文档信息;
  • variables存放tf.train.Saver时保存的变量信息;
  • saved_model.pb或saved_model.pbtxt存放MetaGraphDef,存储训练预测模型的程序逻辑和SignatureDef用于标记预测时的输入和输出。

导出SavedModel

使用Tensorflow导出SavedModel格式的模型也非常简单,可参考Saving and Restoring,若模型比较简单,则用户可使用简单的方式快速导出savedmodel,代码片段示例如下所示:

  1. tf.saved_model.simple_save(
  2. session,
  3. "./savedmodel/",
  4. inputs={"image": x}, ## x是模型的输入变量
  5. outputs={"scores": y} ## y是模型的输出
  6. )

在请求在线预测服务时,请求中需指定模型的signature_name,使用simple_save()方法导出的模型中,signature_name默认为:serving_default

或模型较为复杂,也可使用手工的方式来导出saved_model,代码片段示例如下所示:

  1. print 'Exporting trained model to', export_path
  2. builder = tf.saved_model.builder.SavedModelBuilder(export_path)
  3. tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  4. tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
  5. prediction_signature = (
  6. tf.saved_model.signature_def_utils.build_signature_def(
  7. inputs={'images': tensor_info_x},
  8. outputs={'scores': tensor_info_y},
  9. method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
  10. legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
  11. builder.add_meta_graph_and_variables(
  12. sess, [tf.saved_model.tag_constants.SERVING],
  13. signature_def_map={
  14. 'predict_images':
  15. prediction_signature,
  16. },
  17. legacy_init_op=legacy_init_op)
  18. builder.save()
  19. print 'Done exporting!'

其中:

  • export_path为导出模型的路径;
  • prediction_signature是模型为输入和输出构建出的SignatureDef,具体可参考SignatureDef,在上例中signature_name为predict_images
  • builder.add_meta_graph_and_variables方法描述了导出模型的参数,特别注意tf.saved_model.tag_constants.SERVING**。
  • Tensorflow导出SavedModel格式模型的完整代码下载saved_model.tar.gz
  • More detail:TensorFlow SavedModel

Keras模型转换成Savedmodel

用户通常会使用keras的model.save()方法来将keras模型导出成h5格式,将h5格式的模型转换成Savedmodel同样简单,只需要调用load_model()方法将h5模型加载,继而再导出成Savedmodel格式即可,代码片段示例如下所示:

  1. import tensorflow as tf
  2. with tf.device("/cpu:0"):
  3. model = tf.keras.models.load_model('./mnist.h5')
  4. tf.saved_model.simple_save(
  5. tf.keras.backend.get_session(),
  6. "./h5_savedmodel/",
  7. inputs={"image": model.input},
  8. outputs={"scores": model.output}
  9. )

Checkpoint转换成Savedmodel

训练过程中使用tf.train.Saver()保存的模型格式为checkpoint格式,同样需要转换成Savedmodel才可进行在线预测,转换的方式也非常简单,可以saver.restore()方法将checkpoint加载成tf session,即而用上述方法转换成saved_model即可,代码片段示例如下所示:

  1. import tensorflow as tf
  2. # variable define ...
  3. saver = tf.train.Saver()
  4. with tf.Session() as sess:
  5. # Initialize v1 since the saver will not.
  6. saver.restore(sess, "./lr_model/model.ckpt")
  7. tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
  8. tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
  9. tf.saved_model.simple_save(
  10. sess,
  11. "./savedmodel/",
  12. inputs={"image": tensor_info_x},
  13. outputs={"scores": tensor_info_y}
  14. )