简介
图像分类是指识别图片中主体或者状态单一的场景。
前提
在对图像分类进行训练之前,要准备好如下数据:
开通OSS授权
用于训练的图片集。
图片集对应的标签。
操作步骤
下面将以JAVA SDK为例,详细描述如何训练自己的图像分类模型。操作步骤如下:
1.创建项目。核心示例代码:
CreateProjectRequest request = new CreateProjectRequest();
request.setName("图像分类测试");
request.setDescription("图像分类描述");
request.setProType("Classification");
CreateProjectResponse response = client.getAcsResponse(request)
// 保存项目ID。
String projectId = response.getProject().getProjectId();
2.创建图片标签,每个项目最少需要两个以上的标签。例子中,将创建苹果和香蕉的标签。核心代码:
CreateTagRequest request = new CreateTagRequest();
// 创建项目时,返回的项目ID。
request.setProjectId(projectId);
request.setName("苹果");
request.setDescription("苹果的描述");
CreateTagResponse response = client.getAcsResponse(request);
// 保存苹果的标签ID。
String appleTagId = response.getTag().getTagId();
request.setName("香蕉");
request.setDescription("香蕉的描述");
response = client.getAcsResponse(request);
// 保存香蕉的标签ID。
String bananaTagId = response.getTag().getTagId();
3.创建训练数据,并做标注。
将苹果和香蕉的图片分别上传至OSS。
通过创建训练数据接口将OSS文件添加到训练集,同时标注。例子里,核心代码:
CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest(); request.setProjectId(projectId); // 添加苹果训练数据。 // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。 request.setUrls("http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/1.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/2.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/3.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/4.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/5.jpg,"); request.setTagId(appleTagId); CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request); // 添加香蕉训练数据。 // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。 request.setUrls("http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/1.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/2.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/3.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/4.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/5.jpg,"); request.setTagId(bananaTagId); CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);
4.开始训练数据准备完毕之后,调用开始训练接口,并等待训练完成。模型训练的时间较长,请耐心等待。核心代码:
TrainProjectRequest request = new TrainProjectRequest();
request.setProjectId(projectId);
TrainProjectResponse response = client.getAcsResponse(request);
// 保存迭代ID。
String iterationId = response.getIterationId();
// 等待迭代完成
while(true) {
DescribeTrainResultRequest request = new DescribeTrainResultRequest();
request.setProjectId(projectId);
request.setIterationId(iterationId);
TrainProjectResponse response = client.getAcsResponse(request);
if ("TrainSuccess".equals(response.getTrainResult().getStatus())) {
break;
}
TimeUnit.SECONDS.sleep(5);
}
5.预测图片训练结束之后,可以拿到训练时的迭代ID,进行预测。预测作业为异步接口,提交完成之后,通过查询接口进行查询预测结果。核心代码:
// 提交图片预测。
PredictImageRequest request = new PredictImageRequest();
request.setProjectId(projectId);
request.setIterationId(iterationId);
// 替换成OSS图片URL
request.setDataUrls("test-bucket.oss-cn-beijing.aliyuncs.com/predict/1.jpg");
PredictImageResponse response = client.getAcsResponse(request);
// 等待一会
TimeUnit.SECONDS.sleep(10);
// 查询结果
List<PredictImageResponse.PredictData> datas = response.getPredictDatas();
List<String> dataIds = new ArrayList<>(datas.size());
for (PredictImageResponse.PredictData data : datas) {
dataIds.add(data.getDataId());
}
DescribePredictDatasRequest request = new DescribePredictDatasRequest();
request.setProjectId(projectId);
request.setIterationId(iterationId);
request.setDataIds(dataIds);
PredictImageResponse response = client.getAcsResponse(request);
// 输出预测结果
System.out.println(JSON.toJSONString(response));
反馈
- 本页导读 (0)
文档反馈