使用OssCheckpoint在OSS中存储和访问检查点

本文为您介绍如何使用OssCheckpoint直接从OSS中读写检查点(模型训练过程中保存的特定时间点的模型状态)。

前提条件

已安装并配置OSS Connector for AI/ML。具体操作,请参见安装OSS Connector for AI/ML配置OSS Connector for AI/ML

OssCheckpoint

OssCheckpoint适用于数据训练过程中对训练结果进行读写需求的场景。

以下示例展示了如何使用OssCheckpoint来进行Checkpoint的读取和写入。

import torch
from osstorchconnector import OssCheckpoint

ENDPOINT = "endpoint"
CRED_PATH = "/root/.alibabacloud/credentials"
CONFIG_PATH = "/etc/oss-connector/config.json"

#  使用OssCheckpoint创建checkpoint
checkpoint = OssCheckpoint(endpoint=ENDPOINT, cred_path=CRED_PATH, config_path=CONFIG_PATH)

# 读 checkpoint
CHECKPOINT_READ_URI = "oss://checkpoint/epoch.0"
with checkpoint.reader(CHECKPOINT_READ_URI) as reader:
   state_dict = torch.load(reader)

# 写 checkpoint
CHECKPOINT_WRITE_URI = "oss://checkpoint/epoch.1"
with checkpoint.writer(CHECKPOINT_WRITE_URI) as writer:
   torch.save(state_dict, writer)

数据类型

通过OssCheckpoint创建的checkpoint对象实现了常用的IO接口。更多信息,请参见OSS Connector for AI/ML中的数据类型

参数配置

使用OssCheckpoint时需要进行相应配置,具体配置项说明请参见下表。

参数名

参数类型

是否必选

说明

endpoint

string

OSS对外服务的访问域名。更多信息,请参见访问域名和数据中心

cred_path

string

鉴权文件默认路径为/root/.alibabacloud/credentials,更多信息请参见配置访问凭证

config_path

string

OSS Connector配置文件默认路径为/etc/oss-connector/config.json,更多信息请参见配置OSS Connector