文档

PAI-TF超参支持

更新时间:
重要

本文中含有需要您注意的重要提示信息,忽略该信息可能对您的业务造成影响,请务必仔细阅读。

PAI-TF支持通过超参TXT文件或Command传入相应的超参配置,从而在模型试验时可以尝试不同的Learning Rate及Batch Size等。

警告

公共云GPU服务器即将过保下线,您可以继续提交CPU版本的TensorFlow任务。如需使用GPU进行模型训练,请前往DLC提交任务,具体操作请参见创建训练任务

超参文件

您可以通过一个本地文件配置相应的超参信息,格式如下。

batch_size=10
learning_rate=0.01

TensorFlow Python SDK提供了相应的参数以便获取相应的超参,您可以通过tf.app.flags.FLAGS读取所需的超参,再将其传入运行脚本中,即可在模型训练文件中读取到相应的超参定义。具体方法如下:

  1. 假设上面定义的超参文件存储在oss://xxx.oss-cn-beijing.aliyuncs.com/tf/hyper_para.txt,参考如下Python代码读取超参。

    import tensorflow as tf
    tf.app.flags.DEFINE_string("learning_rate", "", "learning_rate")
    tf.app.flags.DEFINE_string("batch_size", "", "batch size")
    FAGS = tf.app.flags.FLAGS
    print("learning rate:" + FAGS.learning_rate)
    print("batch size:" + FAGS.batch_size)
  2. 通过-DhyperParameters将超参传入到运行脚本中,示例如下。

    pai -name tensorflow1120_ext
        -Dscript='oss://xxx.oss-cn-beijing.aliyuncs.com/tf/hello_hyperpara.py'
        -Dbuckets='oss://xxx.oss-cn-beijing.aliyuncs.com/'
        -DhyperParameters='oss://xxx.oss-cn-beijing.aliyuncs.com/tf/hyper_para.txt'
        -Darn='acs:ram::111***:role/***role';

字符串形式参数

PAI-TF也支持以字符串形式传入参数,您可以直接将字符串通过userDefinedParameters传入,示例如下。

pai -name tensorflow1120_ext
    -Dscript='oss://xxx.oss-cn-beijing.aliyuncs.com/tf/hello_hyperpara.py'
    -Dbuckets='oss://xxx.oss-cn-beijing.aliyuncs.com/'
    -DuserDefinedParameters="--batch_size=10 --learning_rate=0.01"
    -Darn='acs:ram::111***:role/***role';
说明

以字符串传入的参数,使用KV格式,每一个KV前面需要以“--”作为前缀。