文档

TensorFlow模型如何导出为SavedModel

更新时间:

本文为您介绍如何将TensorFlow模型导出为SavedModel格式。

SavedModel格式

使用EAS预置官方Processor将TensorFlow模型部署为在线服务,必须先将模型导出为官方定义的SavedModel格式(TensorFlow官方推荐的导出模型格式)。SavedModel模型格式的目录结构如下。

assets/
variables/
    variables.data-00000-of-00001
    variables.index
saved_model.pb|saved_model.pbtxt

其中:

  • assets表示一个可选目录,用于存储预测时的辅助文档信息。

  • variables存储tf.train.Saver保存的变量信息。

  • saved_model.pbsaved_model.pbtxt存储MetaGraphDef(存储训练预测模型的程序逻辑)和SignatureDef(用于标记预测时的输入和输出)。

导出SavedModel

使用TensorFlow导出SavedModel格式的模型请参见Saving and Restoring。如果模型比较简单,则可以使用如下方式快速导出SavedModel。

tf.saved_model.simple_save(
  session,
  "./savedmodel/",
  inputs={"image": x},   ## x表示模型的输入变量。
  outputs={"scores": y}  ## y表示模型的输出。
)

请求在线预测服务时,请求中需要指定模型signature_name,使用simple_save()方法导出的模型中,signature_name默认为serving_default

如果模型比较复杂,则可以使用手工方式导出SavedModel,代码示例如下。

print('Exporting trained model to', export_path)
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
    tf.saved_model.signature_def_utils.build_signature_def(
        inputs={'images': tensor_info_x},
        outputs={'scores': tensor_info_y},
        method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
)

legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')

builder.add_meta_graph_and_variables(
    sess, [tf.saved_model.tag_constants.SERVING],
    signature_def_map={
        'predict_images': prediction_signature,
    },
    legacy_init_op=legacy_init_op
)

builder.save()
print('Done exporting!')

其中:

  • export_path表示导出模型的路径。

  • prediction_signature表示模型为输入和输出构建的SignatureDef,详情请参见SignatureDef。示例中的signature_name为predict_images

  • builder.add_meta_graph_and_variables方法表示导出模型的参数。

说明
  • 导出预测所需的模型时,必须指定导出模型的Tag为tf.saved_model.tag_constants.SERVING。

  • 有关TensorFlow模型的更多信息,请参见TensorFlow SavedModel

Keras模型转换为SavedModel

使用Keras的model.save()方法会将Keras模型导出为H5格式,需要将其转换为SavedModel才能进行在线预测。您可以先调用load_model()方法加载H5模型,再将其导出为SavedModel格式,代码示例如下。

import tensorflow as tf
with tf.device("/cpu:0"):
    model = tf.keras.models.load_model('./mnist.h5')
    tf.saved_model.simple_save(
      tf.keras.backend.get_session(),
      "./h5_savedmodel/",
      inputs={"image": model.input},
      outputs={"scores": model.output}
    )

Checkpoint转换为Savedmodel

训练过程中使用tf.train.Saver()方法保存的模型格式为checkpoint,需要将其转换为SavedModel才能进行在线预测。您可以先调用saver.restore()方法将Checkpoint加载为tf.Session,再将其导出为SavedModel格式,代码示例如下。

import tensorflow as tf
# variable define ...
saver = tf.train.Saver()
with tf.Session() as sess:
  # Initialize v1 since the saver will not.
    saver.restore(sess, "./lr_model/model.ckpt")
    tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
    tensor_info_y = tf.saved_model.utils.build_tensor_info(y)
    tf.saved_model.simple_save(
      sess,
      "./savedmodel/",
      inputs={"image": tensor_info_x},
      outputs={"scores": tensor_info_y}
    )

  • 本页导读 (1)
文档反馈