全部产品
云市场

DSW读写OSS

更新时间:2019-08-13 22:58:54

阿里云对象存储服务(Object Storage Service,简称 OSS[https://www.aliyun.com/product/oss]),是阿里云提供的海量、安全、低成本、高可靠的云存储服务。PAI DSW产品除了内置的存储NAS,与OSS也是连通的。本文旨在介绍在PAI DSW上如何快速读写存储在OSS中的文件。

OSS python sdk

一般场景下,用户可以直接使用OSS的python api读写OSS中的数据。
完整的文档请参考:https://aliyun-oss-python-sdk.readthedocs.io/en/stable/oss2.html

DSW上已经预装oss2 python包,使用之前参考如下示例进行鉴权以及初始化:

  1. import oss2
  2. auth = oss2.Auth('your-access-key-id', 'your-access-key-secret')
  3. # OSS域名说明
  4. # 北京CPU实例使用: oss-cn-beijing.aliyuncs.com
  5. # 北京GPU实例使用: oss-cn-beijing-internal.aliyuncs.com
  6. # 上海GPU P100实例和CPU实例使用: oss-cn-shanghai.aliyuncs.com
  7. # 上海GPU M40实例使用: oss-cn-shanghai-internal.aliyuncs.com
  8. # bucket_name没有oss://开头
  9. bucket = oss2.Bucket(auth, 'http://oss-cn-beijing-internal.aliyuncs.com', '<your_bucket_name>')

然后可以读写OSS上的数据,示例代码如下:

  1. # 读一个完整文件中的数据
  2. result = bucket.get_object('path/to/your_file')
  3. print(result.read())
  4. #也可以支持按range读数据
  5. result = bucket.get_object('path/to/your_file', byte_range=(0, 99))
  6. # 写数据到OSS
  7. bucket.put_object('path/to/your_file', 'content of the object')
  8. # 也可以对文件进行append
  9. result = bucket.append_object('path/to/your_file', 0, 'content of the object')
  10. result = bucket.append_object('path/to/your_file', result.next_position, 'content of the object')

Tensorflow OSS IO

对于tensorflow用户,DSW上可以使用 tensorflow_io.oss [https://github.com/tensorflow/io/blob/master/tensorflow_io/oss/README.md] 模块直接读取OSS中的数据,解决了执行tensorflow训练任务时需要频繁拷贝数据文件、模型文件的问题。

DSW上已经预装了tensorflow_io.oss包,使用之前首先需要import tensorflow.oss,按以下格式拼接好oss_bucket的url。

  1. import tensorflow as tf
  2. import tensorflow_io.oss
  3. access_id="<your_ak_id>"
  4. access_key="<your_ak_key>"
  5. # OSS域名说明
  6. # 北京CPU实例使用: oss-cn-beijing.aliyuncs.com
  7. # 北京GPU实例使用: oss-cn-beijing-internal.aliyuncs.com
  8. # 上海GPU P100实例和CPU实例使用: oss-cn-shanghai.aliyuncs.com
  9. # 上海GPU M40实例使用: oss-cn-shanghai-internal.aliyuncs.com
  10. host = "oss-cn-beijing-internal.aliyuncs.com"
  11. bucket="oss://<your_bucket_name>"
  12. oss_bucket_root="{}\x01id={}\x02key={}\x02host={}/".format(bucket, access_id, access_key, host)

然后使用GFile读写OSS文本文件,示例代码如下:

  1. oss_file = oss_bucket_root + "test.txt"
  2. with tf.gfile.GFile(oss_file, "w") as f:
  3. f.write("xxxxxxxxx")
  4. with tf.gfile.GFile(oss_file, "r") as f:
  5. print(f.read())

也可以直接使用TextLineDataset读取OSS中的数据,示例代码如下:

  1. # Test textline reader op
  2. oss_file = oss_bucket_root + "test.txt"
  3. dataset = tf.data.TextLineDataset([oss_file])
  4. iterator = dataset.make_initializable_iterator()
  5. a = iterator.get_next()
  6. with tf.Session() as sess:
  7. tf.global_variables_initializer().run()
  8. sess.run(iterator.initializer)
  9. print(sess.run(a))

说明:以上示例代码均使用tensorflow1.0的API,Tensorflow2.0仍然处于beta阶段,API迁移请参考: https://www.tensorflow.org/beta/guide/migration_guide

PyTorch使用OSS python api

对于PyTorch用户,我们可以使用OSS存储训练数据、日志、模型等,PyTorch可以直接用OSS python api读写OSS中的数据。

训练数据加载

我们可以将数据放在一个 OSS bucket 上,并将数据路径和对应的 label 放在同一个 OSS bucket 上的一个索引文件中。通过自定义一个Dataset,在 Pytorch 中使用 DataLoader API 进行多进程并行数据读取。

下面代码是一个简单的示例,其中 endpoint 为 OSS 的 host,bucket 为 bucket 名称,auth 为鉴权对象,index_file 为索引文件的路径。

示例索引文件格式为每条样本用 , 分割,样本路径和label之间用 : 分割

  1. import io
  2. import oss2
  3. import PIL
  4. import torch
  5. class OSSDataset(torch.utils.data.dataset.Dataset):
  6. def __init__(self, endpoint, bucket, auth, index_file):
  7. self._bucket = oss2.Bucket(auth, endpoint, bucket)
  8. self._indices = self._bucket.get_object(index_file).read().split(',')
  9. def __len__(self):
  10. return len(self._indices)
  11. def __getitem__(self, index):
  12. img_path, label = self._indices(index).strip().split(':')
  13. img_str = self._bucket.get_object(img_path)
  14. img_buf = io.BytesIO()
  15. img_buf.write(img_str.read())
  16. img_buf.seek(0)
  17. img = Image.open(img_buf).convert('RGB')
  18. img_buf.close()
  19. return img, label
  20. dataset = OSSDataset(endpoint, bucket, index_file)
  21. data_loader = torch.utils.data.DataLoader(
  22. dataset,
  23. batch_size=batch_size,
  24. num_workers=num_loaders,
  25. pin_memory=True)

写日志到OSS

可以编写一个 StreamHandler 来封装 logging 日志的输出。其中 endpoint 是 OSS 的 host,bucket 为 OSS bucket 的名称,auth 为鉴权对象,log_file 为日志文件的路径。

需要注意,不允许多个进程写同一个日志文件。

  1. import oss2
  2. import logging
  3. class OSSLoggingHandler(logging.StreamHandler):
  4. def __init__(self, endpoint, bucket, auth, log_file):
  5. OSSLoggingHandler.__init__(self)
  6. self._bucket = oss2.Bucket(auth, endpoint, bucket)
  7. self._log_file = log_file
  8. self._pos = self._bucket.append_object(self._log_file, 0, '')
  9. def emit(self, record):
  10. msg = self.format(record)
  11. self._pos = self._bucket.append_object(self._log_file, self._pos.next_position, msg)
  12. oss_handler = OSSLoggingHandler(endpoint, bucket, log_file)
  13. logging.basicConfig(
  14. stream=oss_handler,
  15. format='[%(asctime)s] [%(levelname)s] [%(process)d#%(threadName)s] ' +
  16. '[%(filename)s:%(lineno)d] %(message)s',
  17. level=logging.INFO)

Save & Load模型

可以采用OSS2 python api来save/load pytorch的模型(关于pytorch中如何save/load模型,请参考:https://pytorch.org/tutorials/beginner/saving_loading_models.html)。

pytorch save模型示例:

  1. from io import BytesIO
  2. import torch
  3. import oss2
  4. # bucket_name没有oss://开头
  5. bucket_name = "your_bucket_name"
  6. bucket = oss2.Bucket(auth, endpoint, bucket_name)
  7. buffer = BytesIO()
  8. torch.save(model.state_dict(), buffer)
  9. bucket.put_object("your_model_path", buffer.getvalue())

pytorch load模型示例:

  1. from io import BytesIO
  2. import torch
  3. import oss2
  4. # bucket_name没有oss://开头
  5. bucket_name = "your_bucket_name"
  6. bucket = oss2.Bucket(auth, endpoint, bucket_name)
  7. buffer = BytesIO(bucket.get_object("your_model_path").read())
  8. model.load_state_dict(torch.load(buffer))