本文为您介绍如何使用OSS Python SDK及OSS Python API读写OSS数据。
背景信息
对象存储OSS(Object Storage Service)是阿里云提供的海量、安全、低成本及高可靠性的云存储服务。DSW不仅预置NAS文件系统,而且对接OSS存储。
如果您需要频繁访问和处理大规模数据,推荐您在创建DSW实例时将OSS注册为数据集,然后挂载该数据集;如果您只需要临时访问OSS数据,或者根据业务逻辑来决定是否访问OSS,可采用更灵活的方式,比如SDK和API方式。
OSS Python SDK
通常,您可以直接使用OSS的Python API读写OSS中的数据,详情请参见OSS2 Package。
DSW已预装OSS2 Python包,您可以参见如下方法读写OSS数据。
鉴权及初始化。
import oss2 auth = oss2.Auth('<your_AccessKey_ID>', '<your_AccessKey_Secret>') bucket = oss2.Bucket(auth, 'http://oss-cn-beijing-internal.aliyuncs.com', '<your_bucket_name>')
需要根据实际需要修改以下参数。
参数
描述
<your_AccessKey_ID>
阿里云的AccessKey ID。
<your_AccessKey_Secret>
阿里云的AccessKey Secret。
http://oss-cn-beijing-internal.aliyuncs.com
OSS域名。需要根据实例的地域选择对应的OSS域名:
华北2(北京)后付费实例:oss-cn-beijing.aliyuncs.com
华北2(北京)预付费实例:oss-cn-beijing-internal.aliyuncs.com
华东2(上海)GPU P100实例或CPU实例:oss-cn-shanghai.aliyuncs.com
华东2(上海)GPU M40实例:oss-cn-shanghai-internal.aliyuncs.com
<your_bucket_name>
Bucket名称,且开头不带oss://。
读写OSS数据。
#读取一个完整文件。 result = bucket.get_object('<your_file_path/your_file>') print(result.read()) #按Range读取数据。 result = bucket.get_object('<your_file_path/your_file>', byte_range=(0, 99)) #写数据至OSS。 bucket.put_object('<your_file_path/your_file>', '<your_object_content>') #对文件进行Append。 result = bucket.append_object('<your_file_path/your_file>', 0, '<your_object_content>') result = bucket.append_object('<your_file_path/your_file>', result.next_position, '<your_object_content>')
其中
<your_file_path/your_file>
表示待读写的文件路径,<your_object_content>
表示待Append的内容,需要根据实际情况修改。
OSS Python API
对于PyTorch用户,DSW提供OSS Python API,用于直接读写OSS数据。
您可以在OSS存储训练数据或模型:
加载训练数据
您可以将数据存放在一个OSS Bucket中,且将数据路径和对应的Label存储在同一个OSS Bucket的索引文件中。通过自定义DataSet,在PyTorch中使用
DataLoader
API多进程并行读取数据,示例如下。import io import oss2 import PIL import torch class OSSDataset(torch.utils.data.dataset.Dataset): def __init__(self, endpoint, bucket, auth, index_file): self._bucket = oss2.Bucket(auth, endpoint, bucket) self._indices = self._bucket.get_object(index_file).read().split(',') def __len__(self): return len(self._indices) def __getitem__(self, index): img_path, label = self._indices(index).strip().split(':') img_str = self._bucket.get_object(img_path) img_buf = io.BytesIO() img_buf.write(img_str.read()) img_buf.seek(0) img = Image.open(img_buf).convert('RGB') img_buf.close() return img, label dataset = OSSDataset(endpoint, bucket, auth, index_file) data_loader = torch.utils.data.DataLoader( dataset, batch_size=batch_size, num_workers=num_loaders, pin_memory=True)
其中
endpoint
为OSS域名,bucket
为Bucket名称,auth
为鉴权对象,index_file
为索引文件的路径,都需要根据实际情况修改。说明示例中,索引文件格式为每条样本使用英文逗号(,)分隔,样本路径与Label之间使用英文冒号(:)分隔。
Save或Load模型
您可以使用OSS2 Python API Save或Load PyTorch模型(关于PyTorch如何Save或Load模型,详情请参见PyTorch),示例如下:
Save模型
from io import BytesIO import torch import oss2 # bucket_name bucket_name = "<your_bucket_name>" bucket = oss2.Bucket(auth, endpoint, bucket_name) buffer = BytesIO() torch.save(model.state_dict(), buffer) bucket.put_object("<your_model_path>", buffer.getvalue())
其中
endpoint
为OSS域名,<your_bucket_name>
为OSS Bucket名称,且开头不带oss://,auth
为鉴权对象,<your_model_path>
为模型路径,都需要根据实际情况修改。Load模型
from io import BytesIO import torch import oss2 bucket_name = "<your_bucket_name>" bucket = oss2.Bucket(auth, endpoint, bucket_name) buffer = BytesIO(bucket.get_object("<your_model_path>").read()) model.load_state_dict(torch.load(buffer))
其中
endpoint
为OSS域名,<your_bucket_name>
为OSS Bucket名称,且开头不带oss://,auth
为鉴权对象,<your_model_path>
为模型路径,都需要根据实际情况修改。