本文为您介绍PAI-TF模型导出和部署相关说明。包括导出SaveModel通用模型、保存和恢复检查点以及如何将TF模型部署到EAS。

导出SaveModel通用模型

  • SavedModel格式

    SavedModel是目前官方推荐的模型保存的格式(SessionBundle自Tensorflow 1.0以后不再推荐使用),目录结构如下。

    assets/assets.extra/variables/ variables.data-?????-of-????? variables.indexsaved_model.pb

    目录中各个子目录和文件的含义请参见TensorFlow SavedModel官方文档介绍

  • 导出SavedModel
    代码片段:
    class Softmax(object):
        def __init__(self):
            self.weights_ = tf.Variable(tf.zeros([FLAGS.image_size, FLAGS.num_classes]),
                    name='weights')
            self.biases_ = tf.Variable(tf.zeros([FLAGS.num_classes]),
                    name='biases')
        # ...
        def signature_def(self):
            images = tf.placeholder(tf.uint8, [None, FLAGS.image_size],
                name='input')
            normalized_images = tf.scalar_mul(1.0 / FLAGS.image_depth,
                tf.to_float(images))
            scores = self.scores(normalized_images)
            tensor_info_x = tf.saved_model.utils.build_tensor_info(images)
            tensor_info_y = tf.saved_model.utils.build_tensor_info(scores)
            return tf.saved_model.signature_def_utils.build_signature_def(
                    inputs={'images': tensor_info_x},
                    outputs={'scores': tensor_info_y},
                    method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        def savedmodel(self, sess, signature, path):
            export_dir = os.path.join(path, str(FLAGS.model_version))
            builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
            builder.add_meta_graph_and_variables(
                sess, [tf.saved_model.tag_constants.SERVING],
                signature_def_map={
                    'predict_images':
                        signature,
                },
                clear_devices=True)
            builder.save()
    #...
    model = Softmax()
    signature = model.signature_def()
    #...
    model.savedmodel(sess, signature, mnist.export_path())
    代码说明:
    • Softmax类封装了机器学习模型,其中weightsbiases是其最主要的模型参数。
    • signature_def方法描述了预测时,如何从一个placeholder经过数据标准化和前向计算得到输出的逻辑,并分别作为输入和输出构建出一个SignatureDef

    导出SavedModel至OSS:

    训练并导出模型的命令如下。
    PAI -name tensorflow
        -Dscript="file://path/to/mnist_savedmodel_oss.py"
        -Dbuckets="oss://mnistdataset/?host=oss-test.aliyun-inc.com&role_arn=acs:ram::127488******:role/odps"
        -DcheckpointDir="oss://mnistdataset/?host=oss-test.aliyun-inc.com&role_arn=acs:ram::127488*********:role/odps";

保存和恢复检查点

  • Checkpoint存储
    非交互式TensorFlow存储模型的示例程序如下。
    # -*- coding: utf-8 -*-
    # usage
    # pai -name tensorflow -DcheckpointDir="oss://tftest/examples/?host=oss-test.aliyun-inc.com&role_arn=acs:ram::****:role/odps" -Dscript="file:///path/to/save_model.py";
    import tensorflow as tf
    import json
    import os
    tf.app.flags.DEFINE_string("checkpointDir", "", "oss info")
    FLAGS = tf.app.flags.FLAGS
    print("checkpoint dir:" + FLAGS.checkpointDir)
    # 定义变量
    counter = tf.Variable(1, name="counter")
    one = tf.constant(2)
    sum = tf.add(counter, one)
    new_counter = tf.assign(counter, sum)
    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        ret = sess.run(new_counter)
        print("Counter:%d" % ret)
        ckp_path = os.path.join(FLAGS.checkpointDir, "model.ckpt")
        save_path = saver.save(sess, ckp_path)
        print("Model saved in file: %s" % save_path)
        coord.request_stop()
        coord.join(threads)

    tf.app.flags.DEFINE_string()tf.app.flags.FLAGS可以获取PAI命令中的checkpointDir参数,checkpointDir指定了模型将要存储到OSS上。

    以下代码完成了new_counter的计算,并将名称为counter的变量存储到模型中(值为3),save_path = saver.save(sess, ckp_path)将模型写到OSS指定路径。
    ret = sess.run(new_counter)
    print("Counter:%d" % ret)
    ckp_path = os.path.join(FLAGS.checkpointDir, "model.ckpt")
    save_path = saver.save(sess, ckp_path)
    print("Model saved in file: %s" % save_path)
  • Checkpoint恢复
    TensorFlow的Saver类也可以用于模型的恢复,TensorFlow恢复模型的示例如下。
    # -*- coding: utf-8 -*-
    # usage
    # pai -name tensorflow -Dbuckets="oss://tftest/examples/?host=oss-test.aliyun-inc.com&role_arn=acs:ram::***:role/odps" -Dscript="file:///path/to/restore_model.py";
    import tensorflow as tf
    import json
    import os
    tf.app.flags.DEFINE_string("buckets", "", "oss info")
    FLAGS = tf.app.flags.FLAGS
    print("buckets:" + FLAGS.buckets)
    # 定义变量
    counter = tf.Variable(1, name="counter")
    saver = tf.train.Saver()
    init_op = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init_op)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        ret = sess.run(counter)
        print("Before restore counter:%d" % ret)
        print("Model restore from file")
        ckp_path = os.path.join(FLAGS.buckets, "model.ckpt")
        saver.restore(sess, ckp_path)
        ret = sess.run(counter)
        print("After restore counter:%d" % ret)
        coord.request_stop()
        coord.join(threads)

    tf.app.flags.DEFINE_string()tf.app.flags.FLAGS可以获取PAI命令中的buckets参数,buckets指定了模型将要从OSS上恢复模型。

    以下代码中,首先定义了名称counter的变量,初始值为1。调用saver.restore(sess, ckp_path),根据给定的OSS路径恢复已存储的模型,最后执行ret = sess.run(counter)得到恢复后的变量的值也是3。
    ret = sess.run(counter)
    print("Before restore counter:%d" % ret)
    print("Model restore from file")
    ckp_path = os.path.join(FLAGS.buckets, "model.ckpt")
    saver.restore(sess, ckp_path)
    ret = sess.run(counter)
    print("After restore counter:%d" % ret)

TF模型部署到EAS

EAS是PAI平台自研的模型部署工具,支持深度学习框架生成的模型,特别是部署TensorFlow SavedModel函数生成的模型。EAS有两种模型部署方式,一种是通过PAI-EAS的线上服务进行部署,另一种是通过EAS CMD进行部署:
  • 线上服务部署方式
    1. 将模型存储于OSS中。
    2. 登录PAI控制台
    3. 在左侧导航栏,单击EAS-模型在线服务
    4. 在顶部菜单栏处,选择地域。
    5. PAI EAS 模型在线服务页面,单击模型上传部署
    6. 在右侧的配置页面,设置Processor 种类TensorFlow1.12TensorFlow1.14,并选择您已上传的OSS中的模型文件。模型上传部署
    7. 单击下一步,配置模型部署信息并确认,然后单击部署

      系统会把SavedModel格式的TensorFlow模型打包上传,完成模型服务的部署。

  • EASCMD部署方式

    详情请参见命令使用说明