本文为您介绍如何使用OSS Python SDK、TensorFlow OSS IO及PyTorch OSS API读写OSS数据。

背景信息

对象存储OSS(Object Storage Service)是阿里云提供的海量、安全、低成本及高可靠性的云存储服务。PAI-DSW不仅预置NAS文件系统,而且对接OSS存储。

OSS Python SDK

通常,您可以直接使用OSS的Python API读写OSS中的数据,详情请参见OSS2 Package。PAI-DSW已预装OSS2 Python包,您可以参见如下方法读写OSS数据。

  1. 鉴权及初始化。
    import oss2
    auth = oss2.Auth('<your_AccessKey_ID>', '<your_AccessKey_Secret>')
    bucket = oss2.Bucket(auth, 'http://oss-cn-beijing-internal.aliyuncs.com', '<your_bucket_name>')
    需要根据实际需要修改以下参数。
    参数 描述
    <your_AccessKey_ID> 阿里云的AccessKey ID。
    <your_AccessKey_Secret> 阿里云的AccessKey Secret。
    http://oss-cn-beijing-internal.aliyuncs.com OSS域名。需要根据实例的地域选择对应的OSS域名:
    • 华北2(北京)后付费实例:oss-cn-beijing.aliyuncs.com
    • 华北2(北京)预付费实例:oss-cn-beijing-internal.aliyuncs.com
    • 华东2(上海)GPU P100实例或CPU实例:oss-cn-shanghai.aliyuncs.com
    • 华东2(上海)GPU M40实例:oss-cn-shanghai-internal.aliyuncs.com
    <your_bucket_name> Bucket名称,且开头不带oss://
  2. 读写OSS数据。
    #读取一个完整文件。
    result = bucket.get_object('<your_file_path/your_file>')
    print(result.read())
    #按Range读取数据。
    result = bucket.get_object('<your_file_path/your_file>', byte_range=(0, 99))
    #写数据至OSS。
    bucket.put_object('<your_file_path/your_file>', '<your_object_content>')
    #对文件进行Append。
    result = bucket.append_object('<your_file_path/your_file>', 0, '<your_object_content>')
    result = bucket.append_object('<your_file_path/your_file>', result.next_position, '<your_object_content>')
    其中<your_file_path/your_file>表示待读写的文件路径,<your_object_content>表示待Append的内容,需要根据实际情况修改。

TensorFlow OSS IO

PAI-DSW提供tensorflow_io.oss模块,TensorFlow用户可以使用其直接读取OSS数据。在执行TensorFlow训练任务的过程中,无需频繁拷贝数据文件或模型文件。

  1. 导入tensorflow_io.oss包,并拼接OSS Bucket URL。
    import tensorflow as tf
    import tensorflow_io.oss
    access_id="<your_Access_Key_ID>"
    access_key="<your_Access_Key_Secret>"
    host = "oss-cn-beijing-internal.aliyuncs.com"
    bucket="oss://<your_bucket_name>"
    oss_bucket_root="{}\x01id={}\x02key={}\x02host={}/".format(bucket, access_id, access_key, host)
    需要根据实际需要修改以下参数。
    参数 描述
    <your_AccessKey_ID> 阿里云的AccessKey ID。
    <your_AccessKey_Secret> 阿里云的AccessKey Secret。
    oss-cn-beijing-internal.aliyuncs.com OSS域名。需要根据实例的地域选择对应的OSS域名:
    • 华北2(北京)后付费实例:oss-cn-beijing.aliyuncs.com
    • 华北2(北京)预付费实例:oss-cn-beijing-internal.aliyuncs.com
    • 华东2(上海)GPU P100实例或CPU实例:oss-cn-shanghai.aliyuncs.com
    • 华东2(上海)GPU M40实例:oss-cn-shanghai-internal.aliyuncs.com
    <your_bucket_name> Bucket名称,且开头不带oss://
  2. 您可以使用以下任何一种方式读取OSS数据(以TensorFlow 1.0 API为例):
    • 使用GFile读写OSS文本文件。
      oss_file = oss_bucket_root + "test.txt"
      with tf.gfile.GFile(oss_file, "w") as f:
        f.write("<your_context>")
      with tf.gfile.GFile(oss_file, "r") as f:
        print(f.read())
      其中<your_context>表示写入的内容,需要根据实际情况修改。
    • 使用TextLineDataset读取OSS数据。
      #Test textline reader op.
      oss_file = oss_bucket_root + "test.txt"
      dataset = tf.data.TextLineDataset([oss_file])
      iterator = dataset.make_initializable_iterator()
      a = iterator.get_next()
      with tf.Session() as sess:
          tf.global_variables_initializer().run()
          sess.run(iterator.initializer)
          print(sess.run(a))

OSS Python API

对于PyTorch用户,PAI-DSW提供OSS Python API,用于直接读写OSS数据。

您可以在OSS存储训练数据、日志或模型:
  • 加载训练数据
    您可以将数据存放在一个OSS Bucket中,且将数据路径和对应的Label存储在同一个OSS Bucket的索引文件中。通过自定义DataSet,在PyTorch中使用DataLoader API多进程并行读取数据,示例如下。
    import io
    import oss2
    import PIL
    import torch
    class OSSDataset(torch.utils.data.dataset.Dataset):
        def __init__(self, endpoint, bucket, auth, index_file):
            self._bucket = oss2.Bucket(auth, endpoint, bucket)
            self._indices = self._bucket.get_object(index_file).read().split(',')
        def __len__(self):
            return len(self._indices)
        def __getitem__(self, index):
            img_path, label = self._indices(index).strip().split(':')
            img_str = self._bucket.get_object(img_path)
            img_buf = io.BytesIO()
            img_buf.write(img_str.read())
            img_buf.seek(0)
            img = Image.open(img_buf).convert('RGB')
            img_buf.close()
            return img, label
    dataset = OSSDataset(endpoint, bucket, index_file)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_loaders,
        pin_memory=True)
    其中endpoint为OSS域名,bucket为Bucket名称,auth为鉴权对象,index_file为索引文件的路径,都需要根据实际情况修改。
    说明 示例中,索引文件格式为每条样本使用英文逗号(,)分隔,样本路径与Label之间使用英文冒号(:)分隔。
  • 写日志
    您可以编写一个StreamHandler,用于封装Logging日志的输出。
    说明 多个进程不能同时写一个日志文件。
    import oss2
    import logging
    class OSSLoggingHandler(logging.StreamHandler):
        def __init__(self, endpoint, bucket, auth, log_file):
            OSSLoggingHandler.__init__(self)
            self._bucket = oss2.Bucket(auth, endpoint, bucket)
            self._log_file = log_file
            self._pos = self._bucket.append_object(self._log_file, 0, '')
        def emit(self, record):
            msg = self.format(record)
            self._pos = self._bucket.append_object(self._log_file, self._pos.next_position, msg)
    oss_handler = OSSLoggingHandler(endpoint, bucket, log_file)
    logging.basicConfig(
        stream=oss_handler,
        format='[%(asctime)s] [%(levelname)s] [%(process)d#%(threadName)s] ' +
               '[%(filename)s:%(lineno)d] %(message)s',
        level=logging.INFO)
    其中endpoint为OSS域名,bucket为OSS Bucket名称,auth为鉴权对象,log_file为日志文件路径,都需要根据实际情况修改。
  • Save或Load模型
    您可以使用OSS2 Python API Save或Load PyTorch模型(关于PyTorch如何Save或Load模型,详情请参见PyTorch),示例如下:
    • Save模型
      from io import BytesIO
      import torch
      import oss2
      # bucket_name
      bucket_name = "<your_bucket_name>"
      bucket = oss2.Bucket(auth, endpoint, bucket_name)
      buffer = BytesIO()
      torch.save(model.state_dict(), buffer)
      bucket.put_object("<your_model_path>", buffer.getvalue())
      其中endpoint为OSS域名,<your_bucket_name>为OSS Bucket名称,且开头不带oss://auth为鉴权对象,<your_model_path>为模型路径,都需要根据实际情况修改。
    • Load模型
      from io import BytesIO
      import torch
      import oss2
      bucket_name = "<your_bucket_name>"
      bucket = oss2.Bucket(auth, endpoint, bucket_name)
      buffer = BytesIO(bucket.get_object("<your_model_path>").read())
      model.load_state_dict(torch.load(buffer))
      其中endpoint为OSS域名,<your_bucket_name>为OSS Bucket名称,且开头不带oss://auth为鉴权对象,<your_model_path>为模型路径,都需要根据实际情况修改。