全部产品
机器学习PAI

深度学习

更新时间:2017-07-21 14:45:07   分享:   


深度学习


目录


深度学习框架说明

阿里云机器学习平台上支持深度学习框架,同时后端提供了功能强大的GPU(型号M40)计算集群,用户可以使用这些框架及硬件资源来运行深度学习算法。 目前支持的框架包括 TensorFlow(支持1.0、1.1、1.2版本),MXNet 0.9.5, Caffe rc3。TensorFlow 和MXNet 支持用户自己编写的Python 代码, Caffe 支持用户自定义网络文件。

在使用深度学习框架训练数据之前,需要将训练的数据上传至阿里云对象存储OSS中,算法在运行时从指定的OSS目录中读取数据。需要注意的是阿里云机器学习目前只在华东2部署了GPU 集群,算法在执行时访问华东2 OSS中数据时不产生流量费用,访问其它地域的OSS会产生流量费用。

深度学习开通

目前机器学习平台深度学习相关功能处于公测阶段,深度学习组件包含TensorFlow、Caffe、MXNet三个框架,开通方式如下,进入机器学习控制台,在相应项目下勾选GPU资源即可使用。

开通GPU资源的项目会被分配到公共的资源池,可以动态的调用底层的GPU计算资源。

OSS上传数据说明

使用深度学习处理数据时,数据先存储到OSS的bucket中。第一步要创建OSS Bucket。 由于深度学习的GPU集群在华东2,建议您创建 OSS Bucket 时选择华东2地区。这样在数据传输时就可以使用阿里云经典网络,算法运行时不需要收取流量费用。Bucket 创建好之后,可以在OSS管理控制台 来创建文件夹,组织数据目录,上传数据了。

OSS支持多种方式上传数据, API或SDK详细见:https://help.aliyun.com/document_detail/31848.html?spm=5176.doc31848.6.580.a6es2a

OSS还提供了大量的常用工具用来帮助用户更加高效的使用OSS。工具列表请参见: https://help.aliyun.com/document_detail/44075.html?spm=5176.doc32184.6.1012.XlMMUx

建议您使用 ossutil 或 osscmd ,这是两个命令行工具,通过命令的方式来上传、下载文件,还支持断点续传。

注:在使用工具时需要配置 AccessKey 和ID,登录后,可以在Access Key 管理控制台创建或查看。

读OSSBucket

用户在机器学习平台中使用“读OSS Bucket”组件时,需要授予一个名称为“AliyunODPSPAIDefaultRole” 的系统默认角色给数加的服务账号,当且仅当该角色被正确授权后,机器学习平台的算法才能正确地读、写OSS bucket。

  • 注:由于机器学习平台运行在MaxCompute框架之上,与MaxCompute共用服务账号。在授权时,默认的角色授予给MaxCompute服务账号。

需要在“设置”菜单完成对OSS读写权限的授权,详情见RAM授权。

RAM授权

1.RAM授权,可以使机器学习平台获得OSS的访问权限,点击“这里”进入RAM入口,如图。

2.点击“前往RAM进行授权”,进入如下界面,点击同意即可。

注:如果您想查看AliyunODPSPAIDefaultRole的相关详细策略信息,可以登录RAM控制台来查看。 默认角色AliyunODPSPAIDefaultRole包含的权限信息如下:

权限名称(Action) 权限说明
oss:PutObject 上传文件或文件夹对象
oss:GetObject 获取文件或文件夹对象
oss:ListObjects 查询文件列表信息
oss:DeleteObjects 删除对象

3.回到阿里云机器学习界面,点击刷新,RAM信息会自动录入组件中。如图

4.在使用深度学习框架过程中,需要组件“读OSSBucket”与相应的深度学习组件相连,用来获得OSS的读写权限。

TensorFlow

背景

TensorFlow(以下简称TF)是Google开源的一套机器学习框架,算法开发者通过简单的学习就能快速上手。阿里云机器学习平台将TF框架集成到产品中。用户可以自由的利用TF进行代码编写,TF的计算引擎为GPU集群,用户可以灵活的对计算资源进行调整。

参数说明

(1)参数设置

  • Python代码文件:程序执行文件,多个文件可通过tar.gz打包上传。
  • Python主文件:指定代码文件压缩包中的主文件,可选。
  • 数据源目录:选择OSS上的数据源。
  • 配置文件超参及用户自定义参数:PAI Tensorflow支持用户通过Command传入相应的超参配置,这样用户可以在做模型试验的时候可以尝试不同的learning rate, batch size等。
  • 输出目录:输出的模型路径。

(2)执行调优

用户可以根据自身任务的复杂程度指定GPU卡数

PAI命令

  1. 实际使用中,并不需要指定所有参数(不要直接复制下面的命令....),各个参数的含义可以参考后面的表格
  2. PAI -name tensorflow_ext -Dbuckets="oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_tensorflow/mnist/" -DgpuRequired="100" -Darn="acs:ram::166408185518****:role/aliyunodpspaidefaultrole" -Dscript="oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_tensorflow/mnist_ext.py";

各个参数的具体含义如下表:

参数名称 参数描述 参数值格式 默认值
script 必选,TF算法文件,可以是单个文件或者tar.gz压缩包 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_tensorflow/mnist_ext.py -
entryFile 可选,算法入口文件名,当script为tar.gz压缩包时,该参数必填 train.py
buckets 必选,输入OSS bucket,可指定多个,以逗号分割, 每个bucket须以”/“结尾 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_tensorflow/mnist/
arn 必选, OSS role_arn
gpuRequired 必选,标识使用GPU资源量 200 100
checkpointDir 可选,TF checkpoint目录 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_tensorflow/mnist/
hyperParameters 可选,命令行超参数路径 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_tensorflow/mnist/hyper_parameters.txt
  • 参数scriptentryFile用于指定要执行的TF算法脚本,如果算法比较复杂,分成了多个文件,可以将多个算法文件打包成tar.gz格式,并利用entryFile参数指定该算法的入口文件。
  • 参数checkpointDir用于指定算法将要写入的OSS路径,在Tensorflow保存模型时需要指定

  • 参数buckets用于指定算法将要读取的OSS路径,使用OSS需要指定arn参数。

案例

Mnist手写字识别是TF官方的案例,通过训练手写体1~9的数字生成模型,通过模型进行预测。

1.首先在OSS端上传数据的执行python文件以及训练数据集。本案例在OSS华东2 region创建bucket,bucket名为tfmnist,上传python脚本以及训练数据。

2.拖拽读OSS Bucket和TensorFlow组件,拼接成如下实验,需要设置好OSS Bucket的地区,并且完成RAM授权。如下图:

3.配置TensorFlow组件参数,将python执行文件以及数据源文件路径配置正确,如图:

4.点击运行,直到两个组件运行完成,如下图:

5.可以右键TensorFlow组件,查看运行日志。

MXNet

背景

MXNet是一个深度学习框架,支持命令和符号编程,可以运行在CPU和GPU集群,MXNet是cxxnet的下一代,cxxnet借鉴了minerva的思想。

参数说明

(1)参数设置

  • Python代码文件:程序执行文件,多个文件可通过tar.gz打包上传。
  • Python主文件:指定代码文件压缩包中的主文件,可选。
  • 数据源目录:选择OSS上的数据源。
  • 配置文件超参及用户自定义参数:PAI MXNet支持用户通过Command传入相应的超参配置,这样用户可以在做模型试验的时候可以尝试不同的learning rate, batch size等。
  • 输出目录:输出的模型路径。

(2)执行调优

用户可以根据自身任务的复杂程度指定GPU卡数

PAI命令

  1. 实际使用中,并不需要指定所有参数(不要直接复制下面的命令....),各个参数的含义可以参考后面的表格
  2. pai -name mxnet_ext -Dscript="oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/mxnet-ext-code/mxnet_cifar10_demo.tar.gz" -DentryFile="train_cifar10.py" -Dbuckets="oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com" -DcheckpointDir="oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/mxnet-ext-model/" -DhyperParameters="oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/mxnet-ext-code/hyperparam.txt.single" -Darn="acs:ram::1664081855183111:role/role-for-pai";

各个参数的具体含义如下表:

参数名称 参数描述 参数值格式 默认值
script 必选,TF算法文件,可以是单个文件或者tar.gz压缩包 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_mxnet/mnist_ext.py oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/smoke_mxnet/mnist_ext.py -
entryFile 可选,算法入口文件名,当script为tar.gz压缩包时,该参数必填 train.py
buckets 必选,输入bucket,可多个,以逗号隔开, 每个bucket须以”/“结尾 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com
hyperParameters 可选,命令行超参数路径 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/mxnet-ext-code/
gpuRequired 可选,标识使用GPU资源量 200 100
checkpointDir 可选, checkpoint目录 oss://imagenet.oss-cn-shanghai-internal.aliyuncs.com/mxnet-ext-code/

案例

CIFAR-10是MXNet官方提供的基于图片的10分类场景的案例,通过对于6万张32*32的图片进行训练生成模型,可以对飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车进行自动分类。详细内容见:https://www.cs.toronto.edu/~kriz/cifar.html

1.首先在OSS端上传数据的执行python文件以及训练数据集。本案例在OSS华东2 region创建bucket,bucket名为tfmnist,上传python脚本以及训练数据。

2.拖拽读OSS Bucket和MXNet组件,拼接成如下实验,需要设置好OSS Bucket的地区,并且完成RAM授权。如下图:

3.配置MXNet组件参数,将python执行文件以及数据源文件路径配置正确,如图:

  • Python代码文件选择.tar.gz文件
  • Python主文件选择tar包中的执行入口文件
  • 超惨自定义参数文件选择.txt.single文件
  • checkpoint为模型输出目录

4.点击运行,直到两个组件运行完成,如下图:

5.可以右键MXNet组件,查看运行日志。

6.最终在checkpoint地址下生成模型如图:


格式转换

目前PAI Caffe不支持自定义格式训练数据,数据需要通过格式转换组件进行转换方可使用。

  • 输入桩接读oss组件。

  • 参数

    • 输入oss路径,oss的训练数据的file_list(如bucket.hz.aliyun.com/train_img/train_file_list.txt )file_list格式如下:
      1. bucket/ilsvrc12_val/ILSVRC2012_val_00029021.JPEG 817
      2. bucket/ilsvrc12_val/ILSVRC2012_val_00021046.JPEG 913
      3. bucket/ilsvrc12_val/ILSVRC2012_val_00041166.JPEG 486
      4. bucket/ilsvrc12_val/ILSVRC2012_val_00029527.JPEG 327
      5. bucket/ilsvrc12_val/ILSVRC2012_val_00042825.JPEG 138
    • 输出oss目录,如bucket_name.oss-cn-hangzhou-zmf.aliyuncs.com/ilsvrc12_val_convert ,会输出转换后的data_file_list.txt和对应的数据文件。data_file_list格式如:
      1. bucket/ilsvrc12_val_convert/train_data_00_01
      2. bucket/ilsvrc12_val_convert/train_data_00_02
    • 编码类型,选项,可选jpg,png,raw等。
    • 是否shuffle,勾选
    • 文件前缀,默认为data
    • resize_height,默认为256
    • resize_width,默认为256
    • 是否灰度,默认为否
    • 是否需要产生图片mean文件,默认否
转换组件PAI命令示例
  1. pai -name convert_image_oss2oss
  2. -Darn=acs:ram::1607128916545079:role/test-1
  3. -DossImageList=bucket_name.oss-cn-hangzhou-zmf.aliyuncs.com/image_list.txt
  4. -DossOutputDir=bucket_name.oss-cn-hangzhou-zmf.aliyuncs.com/your/dir
  5. -DencodeType=jpg
  6. -Dshuffle=true
  7. -DdataFilePrefix=train
  8. -DresizeHeight=256
  9. -DresizeWidth=256
  10. -DisGray=false
  11. -DimageMeanFile=false
PAI命令参数
参数名称 参数描述 取值范围 是否必选,默认值/行为
ossHost 对应的oss host地址 形式如“oss-test.aliyun-inc.com” 可选,默认值为“oss-cn-hangzhou-zmf.aliyuncs.com”,即为对内oss使用的host
arn OSS Bucket默认Role对应的ARN 形式如“acs:ram::XXXXXXXXXXXXXXXX:role/ossaccessroleforodps”,中间xxx代表生成的rolearn的16位数字 必选
ossImageList 图片文件列表 形式如“bucket_name/image_list.txt” 必选
ossOutputDir 输出oss目录 形式如“bucket_name/your/dir” 必选
encodeType 编码类型 如jpg,png,raw 可选,默认值为jpg
shuffle 是否shuffle数据 bool值 可选,默认值为true
dataFilePrefix 数据文件前缀 string类型,如train或val 必选
resizeHeight 图像resize的height int类型,用户自定义 可选,默认值为256
resizeWidth 图像resize的width int类型,用户自定义 可选,默认值为256
isGray 图像是否为灰度图 bool值 可选,默认值为false
imageMeanFile 是否需要生成imagemean文件 bool值 可选,默认值为false

Caffe

背景

caffe是一个清晰,可读性高,快速的深度学习框架。作者是贾扬清,加州大学伯克利的ph.D,现就职于Facebook。caffe的官网是http://caffe.berkeleyvision.org/。

参数说明

  • 首先配置OSS访问权限

  • 唯一的一个参数就是solver.prototxt文件的oss路径。其中solver由于并行化的修改,同开源caffe略有不同,需要注意一下几点

    1. net: “bucket.hz.aliyun.com/alexnet/train_val.prototxt” net的文件位置是oss路径
    2. type: “ParallelSGD” type类型是ParallelSGD,注意这是一个字符串
    3. model_average_iter_interval: 1 多卡下表示同步的频率,1表示每轮都同步一次
    4. snapshot_prefix: “bucket/snapshot/alexnet_train” 模型输出到oss的目录
  1. net: "bucket/alexnet/train_val.prototxt"
  2. test_iter: 1000
  3. test_interval: 1000
  4. base_lr: 0.01
  5. lr_policy: "step"
  6. gamma: 0.1
  7. stepsize: 100000
  8. display: 20
  9. max_iter: 450000
  10. momentum: 0.9
  11. weight_decay: 0.0005
  12. snapshot: 10000
  13. snapshot_prefix: "bucket/snapshot/alexnet_train"
  14. solver_mode: GPU
  15. type: "ParallelSGD"
  16. model_average_iter_interval: 1
  • train_val中的datalayer需使用BinaryDataLayer,请参考如下示例。
  1. layer {
  2. name: "data"
  3. type: "BinaryData"
  4. top: "data"
  5. top: "label"
  6. include {
  7. phase: TRAIN
  8. }
  9. transform_param {
  10. mirror: true
  11. crop_size: 227
  12. mean_file: "bucket/imagenet_mean.binaryproto"
  13. }
  14. binary_data_param {
  15. source: "bucket/ilsvrc12_train_binary/data_file_list.txt"
  16. batch_size: 256
  17. num_threads: 10
  18. }
  19. }
  20. layer {
  21. name: "data"
  22. type: "BinaryData"
  23. top: "data"
  24. top: "label"
  25. include {
  26. phase: TEST
  27. }
  28. transform_param {
  29. mirror: false
  30. crop_size: 227
  31. mean_file: "bucket/imagenet_mean.binaryproto"
  32. }
  33. binary_data_param {
  34. source: "bucket/ilsvrc12_val_binary/data_file_list.txt"
  35. batch_size: 50
  36. num_threads: 10
  37. }
  38. }

新的data Layer的名称为“ BinaryData”,其中也支持transform param对输入图像数据进行变换,参数和caffe原生参数保持一致;

其中binary_data_param为数据层本身的参数配置。

binary_data_param中包括以下特殊的参数:

  1. source:数据来源,其中路径为oss中filelist的路径,从bucket名称开始,不包含oss://

  2. num_threads:读取oss数据时并发的线程数目,默认值为10,用户可以根据自己的需求进行调整

PAI命令

  1. pai -name pluto_train_oss
  2. -DossHost=oss-cn-hangzhou-zmf.aliyuncs.com
  3. -Darn=acs:ram::1607128916545079:role/test-1
  4. -DsolverPrototxtFile=bucket_name.oss-cn-hangzhou-zmf.aliyuncs.com/solver.prototxt
  5. -DgpuRequired=1
PAI命令参数
参数名称 参数描述 取值范围 是否必选,默认值/行为
ossHost 对应的oss host地址 形式如“oss-test.aliyun-inc.com” 可选,默认值为“oss-cn-hangzhou-zmf.aliyuncs.com”,即为对内oss使用的host
arn OSS Bucket默认Role对应的ARN 形式如“acs:ram::XXXXXXXXXXXXXXXX:role/ossaccessroleforodps”,中间xxx代表生成的rolearn的16位数字 必选
solverPrototxtFile solver文件 solver文件在oss中的路径,从bucket name开始 必选
gpuRequired GPU卡个数 整型值 可选,默认值为1

案例

利用Caffe实现mnist的数据训练。1.准备数据源在本页的“深度学习案例代码及数据下载”页找到Caffe数据下载并解压。将数据导入OSS中,本案例路径如下图,请配合代码中的路径理解:

2.实现实验拖拉Caffe组件拼接成如下实验:

将solver oss路径指向mnist_solver_dnn_binary.prototxt文件。点击运行。3.查看日志右键Caffe组件,查看日志:点击logview链接-》ODPS Tasks-》VlinuxTask-》StdErr,即可查看训练过程产生的日志:

深度学习案例代码及数据下载

TensorFlow相关下载

MXNet相关下载

Caffe相关下载

TensorFlow支持的第三方库

  1. aorun (0.1)
  2. appdirs (1.4.3)
  3. backports-abc (0.5)
  4. backports.shutil-get-terminal-size (1.0.0)
  5. backports.ssl-match-hostname (3.5.0.1)
  6. bleach (2.0.0)
  7. boto (2.46.1)
  8. brocas-lm (1.0)
  9. bz2file (0.98)
  10. certifi (2017.1.23)
  11. chardet (2.3.0)
  12. configparser (3.5.0)
  13. cycler (0.10.0)
  14. dask (0.14.0)
  15. decorator (4.0.11)
  16. Distance (0.1.3)
  17. docutils (0.13.1)
  18. easygui (0.98.1)
  19. entrypoints (0.2.2)
  20. enum34 (1.1.6)
  21. funcsigs (1.0.2)
  22. functools32 (3.2.3.post2)
  23. future (0.16.0)
  24. futures (3.0.5)
  25. gensim (1.0.1)
  26. h5py (2.6.0)
  27. html5lib (0.999999999)
  28. iniparse (0.4)
  29. ipykernel (4.5.2)
  30. ipyparallel (6.0.2)
  31. ipython (5.3.0)
  32. ipython-genutils (0.1.0)
  33. ipywidgets (6.0.0)
  34. javapackages (1.0.0)
  35. jieba (0.38)
  36. Jinja2 (2.9.5)
  37. jsonschema (2.6.0)
  38. jupyter (1.0.0)
  39. jupyter-client (5.0.0)
  40. jupyter-console (5.1.0)
  41. jupyter-core (4.3.0)
  42. Keras (1.2.2)
  43. kitchen (1.1.1)
  44. langid (1.1.6)
  45. langtable (0.0.31)
  46. lxml (3.7.3)
  47. MarkupSafe (1.0)
  48. matplotlib (2.0.0)
  49. mistune (0.7.3)
  50. mlpy (0.1.0)
  51. mock (2.0.0)
  52. nbconvert (5.1.1)
  53. nbformat (4.3.0)
  54. networkx (1.11)
  55. nltk (3.2.2)
  56. nltk-tgrep (1.0.6)
  57. nose (1.3.7)
  58. notebook (4.4.1)
  59. numpy (1.12.1)
  60. olefile (0.44)
  61. opencv-helpers (1.1)
  62. opencv-python (3.2.0.6)
  63. opencvutils (0.5.8)
  64. packaging (16.8)
  65. pandas (0.19.2)
  66. pandocfilters (1.4.1)
  67. pathlib2 (2.2.1)
  68. pbr (2.0.0)
  69. pexpect (4.2.1)
  70. pickleshare (0.7.4)
  71. Pillow (4.0.0)
  72. pip (9.0.1)
  73. prompt-toolkit (1.0.13)
  74. protobuf (3.1.0)
  75. ptyprocess (0.5.1)
  76. PyBrain (0.3)
  77. pycurl (7.19.0)
  78. pygame (1.9.3)
  79. Pygments (2.2.0)
  80. pygobject (3.14.0)
  81. pygpgme (0.3)
  82. pyliblzma (0.5.3)
  83. pyparsing (2.2.0)
  84. pysummarize (0.6.0)
  85. python-dateutil (2.6.0)
  86. pytorch-extras (0.1.3)
  87. pytz (2016.10)
  88. pyxattr (0.5.1)
  89. PyYAML (3.12)
  90. pyzmq (16.0.2)
  91. qtconsole (4.2.1)
  92. requests (2.13.0)
  93. scandir (1.5)
  94. scikit-image (0.12.3)
  95. scikit-learn (0.18.1)
  96. scikit-sound (0.1.4)
  97. scikit-stack (3.0)
  98. scikit-surprise (1.0.2)
  99. scikit-tensor (0.1)
  100. scikit-video (0.1.2)
  101. scipy (0.19.0)
  102. setuptools (34.3.3)
  103. simplegeneric (0.8.1)
  104. singledispatch (3.4.0.3)
  105. six (1.10.0)
  106. slip (0.4.0)
  107. slip.dbus (0.4.0)
  108. smart-open (1.4.0)
  109. subprocess32 (3.2.7)
  110. tensorflow (1.0.0)
  111. tensorlayer (1.3.11)
  112. terminado (0.6)
  113. testpath (0.3)
  114. textblob (0.12.0)
  115. textblob-aptagger (0.2.0)
  116. tflearn (0.3)
  117. Theano (0.8.2)
  118. toolz (0.8.2)
  119. torch (0.1.11.post4)
  120. torchfcn (1.3)
  121. torchfile (0.1.0)
  122. torchvision (0.1.8)
  123. tornado (4.4.2)
  124. tqdm (4.11.2)
  125. traitlets (4.3.2)
  126. urlgrabber (3.10.2)
  127. visdom (0.1.2)
  128. wcwidth (0.1.7)
  129. webencodings (0.5)
  130. wheel (0.29.0)
  131. widgetsnbextension (2.0.0)
  132. word2veckeras (0.0.5.2)
  133. yum-langpacks (0.4.2)
  134. yum-metadata-parser (1.1.4)

TensorFlow读取数据方法说明

低效的IO方式

最近通过观察PAI平台上TensoFlow用户的运行情况,发现大家在数据IO这方面还是有比较大的困惑,主要是因为很多同学没有很好的理解本地执行TensorFlow代码和分布式云端执行TensorFlow的区别。本地读取数据是server端直接从client端获得graph进行计算,而云端服务server在获得graph之后还需要将计算下发到各个worker处理(具体原理可以参考视频教程-Tensorflow高级篇:https://tianchi.aliyun.com/competition/new_articleDetail.html)。

本文通过读取一个简单的CSV文件为例,帮助大家快速了解如何使用TensorFlow高效的读取数据。CSV文件如下:

  1. 1,1,1,1,1
  2. 2,2,2,2,2
  3. 3,3,3,3,3

首先我们来看下大家容易产生问题的几个地方。

1.不建议用python本地读取文件的方式

PAI支持python的自带IO方式,但是需要将数据源和代码打包上传的方式使用,这种读取方式是将数据写入内存之后再计算,效率比较低,不建议使用。范例代码如下:

  1. import csv
  2. csv_reader=csv.reader(open('csvtest.csv'))
  3. for row in csv_reader:
  4. print(row)

2.尽量不要用第三方库的读取文件方法

很多同学使用第三方库的一些数据IO的方式进行数据读取,比如TFLearn、Panda的数据IO方式,这些方法很多都是通过封装PYTHON的读取方式实现的,所以在PAI平台使用的时候也会造成效率低下问题。

3.尽量不要用preload的方式读取文件

很多人在用PAI的服务的时候表示GPU并没有比本地的CPU速度快的明显,主要问题可能就出在数据IO这块。preload的方式是先把数据全部都读到内存中,然后再通过session计算,比如feed的读取方式。这样要先进行数据读取,再计算,不同步造成性能浪费,同时因为内存限制也无法支持大数据量的计算。举个例子:假设我们的硬盘中有一个图片数据集0001.jpg,0002.jpg,0003.jpg……我们只需要把它们读取到内存中,然后提供给GPU或是CPU进行计算就可以了。这听起来很容易,但事实远没有那么简单。事实上,我们必须要把数据先读入后才能进行计算,假设读入用时0.1s,计算用时0.9s,那么就意味着每过1s,GPU都会有0.1s无事可做,这就大大降低了运算的效率。

下面我们看下高效的读取方式。

高效的IO方式

高效的TensorFlow读取方式是将数据读取转换成OP,通过session run的方式拉去数据。另外,读取线程源源不断地将文件系统中的图片读入到一个内存的队列中,而负责计算的是另一个线程,计算需要数据时,直接从内存队列中取就可以了。这样就可以解决GPU因为IO而空闲的问题!

下面我们看下代码,如何在PAI平台通过OP的方式读取数据:

  1. import argparse
  2. import tensorflow as tf
  3. import os
  4. FLAGS=None
  5. def main(_):
  6. dirname = os.path.join(FLAGS.buckets, "csvtest.csv")
  7. reader=tf.TextLineReader()
  8. filename_queue=tf.train.string_input_producer([dirname])
  9. key,value=reader.read(filename_queue)
  10. record_defaults=[[''],[''],[''],[''],['']]
  11. d1, d2, d3, d4, d5= tf.decode_csv(value, record_defaults, ',')
  12. init=tf.initialize_all_variables()
  13. with tf.Session() as sess:
  14. sess.run(init)
  15. coord = tf.train.Coordinator()
  16. threads = tf.train.start_queue_runners(sess=sess,coord=coord)
  17. for i in range(4):
  18. print(sess.run(d2))
  19. coord.request_stop()
  20. coord.join(threads)
  21. if __name__ == '__main__':
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument('--buckets', type=str, default='',
  24. help='input data path')
  25. parser.add_argument('--checkpointDir', type=str, default='',
  26. help='output model path')
  27. FLAGS, _ = parser.parse_known_args()
  28. tf.app.run(main=main)
  • dirname:OSS文件路径,可以是数组,方便下一阶段shuffle
  • reader:TF内置各种reader API,可以根据需求选用
  • tf.train.string_input_producer:将文件生成队列
  • tf.decode_csv:是一个splite功能的OP,可以拿到每一行的特定参数
  • 通过OP获取数据,在session中需要tf.train.Coordinator()和tf.train.start_queue_runners(sess=sess,coord=coord)

在代码中,我们的输入是3行5个字段:

  1. 1,1,1,1,1
  2. 2,2,2,2,2
  3. 3,3,3,3,3

我们循环输出4次,打印出第2个字段。结果如图:

输出结果也证明了数据结构是成队列。

其它

本文导读目录
本文导读目录
以上内容是否对您有帮助?