读写OSS数据

本文为您介绍如何使用OSS Python SDKOSS Python API读写OSS数据。

背景信息

对象存储OSS(Object Storage Service)是阿里云提供的海量、安全、低成本及高可靠性的云存储服务。DSW不仅预置NAS文件系统,而且对接OSS存储。

如果您需要频繁访问和处理大规模数据,推荐您在创建DSW实例时将OSS注册为数据集,然后挂载该数据集;如果您只需要临时访问OSS数据,或者根据业务逻辑来决定是否访问OSS,可采用更灵活的方式,比如SDKAPI方式。

OSS Python SDK

通常,您可以直接使用OSSPython API读写OSS中的数据,详情请参见OSS2 Package

DSW已预装OSS2 Python包,您可以参见如下方法读写OSS数据。

  1. 鉴权及初始化。

    import oss2
    auth = oss2.Auth('<your_AccessKey_ID>', '<your_AccessKey_Secret>')
    bucket = oss2.Bucket(auth, 'http://oss-cn-beijing-internal.aliyuncs.com', '<your_bucket_name>')

    需要根据实际需要修改以下参数。

    参数

    描述

    <your_AccessKey_ID>

    阿里云的AccessKey ID。

    <your_AccessKey_Secret>

    阿里云的AccessKey Secret。

    http://oss-cn-beijing-internal.aliyuncs.com

    OSS域名。需要根据实例的地域选择对应的OSS域名:

    • 华北2(北京)后付费实例:oss-cn-beijing.aliyuncs.com

    • 华北2(北京)预付费实例:oss-cn-beijing-internal.aliyuncs.com

    • 华东2(上海)GPU P100实例或CPU实例:oss-cn-shanghai.aliyuncs.com

    • 华东2(上海)GPU M40实例:oss-cn-shanghai-internal.aliyuncs.com

    <your_bucket_name>

    Bucket名称,且开头不带oss://

  2. 读写OSS数据。

    #读取一个完整文件。
    result = bucket.get_object('<your_file_path/your_file>')
    print(result.read())
    #按Range读取数据。
    result = bucket.get_object('<your_file_path/your_file>', byte_range=(0, 99))
    #写数据至OSS。
    bucket.put_object('<your_file_path/your_file>', '<your_object_content>')
    #对文件进行Append。
    result = bucket.append_object('<your_file_path/your_file>', 0, '<your_object_content>')
    result = bucket.append_object('<your_file_path/your_file>', result.next_position, '<your_object_content>')

    其中<your_file_path/your_file>表示待读写的文件路径,<your_object_content>表示待Append的内容,需要根据实际情况修改。

OSS Python API

对于PyTorch用户,DSW提供OSS Python API,用于直接读写OSS数据。

您可以在OSS存储训练数据或模型:

  • 加载训练数据

    您可以将数据存放在一个OSS Bucket中,且将数据路径和对应的Label存储在同一个OSS Bucket的索引文件中。通过自定义DataSet,在PyTorch中使用DataLoader API多进程并行读取数据,示例如下。

    import io
    import oss2
    import PIL
    import torch
    class OSSDataset(torch.utils.data.dataset.Dataset):
        def __init__(self, endpoint, bucket, auth, index_file):
            self._bucket = oss2.Bucket(auth, endpoint, bucket)
            self._indices = self._bucket.get_object(index_file).read().split(',')
        def __len__(self):
            return len(self._indices)
        def __getitem__(self, index):
            img_path, label = self._indices(index).strip().split(':')
            img_str = self._bucket.get_object(img_path)
            img_buf = io.BytesIO()
            img_buf.write(img_str.read())
            img_buf.seek(0)
            img = Image.open(img_buf).convert('RGB')
            img_buf.close()
            return img, label
    dataset = OSSDataset(endpoint, bucket, auth, index_file)
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=num_loaders,
        pin_memory=True)

    其中endpointOSS域名,bucketBucket名称,auth为鉴权对象,index_file为索引文件的路径,都需要根据实际情况修改。

    说明

    示例中,索引文件格式为每条样本使用英文逗号(,)分隔,样本路径与Label之间使用英文冒号(:)分隔。

  • SaveLoad模型

    您可以使用OSS2 Python API SaveLoad PyTorch模型(关于PyTorch如何SaveLoad模型,详情请参见PyTorch),示例如下:

    • Save模型

      from io import BytesIO
      import torch
      import oss2
      # bucket_name
      bucket_name = "<your_bucket_name>"
      bucket = oss2.Bucket(auth, endpoint, bucket_name)
      buffer = BytesIO()
      torch.save(model.state_dict(), buffer)
      bucket.put_object("<your_model_path>", buffer.getvalue())

      其中endpointOSS域名,<your_bucket_name>OSS Bucket名称,且开头不带oss://auth为鉴权对象,<your_model_path>为模型路径,都需要根据实际情况修改。

    • Load模型

      from io import BytesIO
      import torch
      import oss2
      bucket_name = "<your_bucket_name>"
      bucket = oss2.Bucket(auth, endpoint, bucket_name)
      buffer = BytesIO(bucket.get_object("<your_model_path>").read())
      model.load_state_dict(torch.load(buffer))

      其中endpointOSS域名,<your_bucket_name>OSS Bucket名称,且开头不带oss://auth为鉴权对象,<your_model_path>为模型路径,都需要根据实际情况修改。