全部产品
云市场
云游戏

图像分类训练和预测

更新时间:2019-10-25 19:01:00

简介

图像分类是指识别图片中主体或者状态单一的场景。

前提

在对图像分类进行训练之前,要准备好如下数据:

  • 开通OSS授权
  • 用于训练的图片集。
  • 图片集对应的标签。

操作步骤

下面将以JAVA SDK为例,详细描述如何训练自己的图像分类模型。操作步骤如下:

1.创建项目。核心示例代码:

  1. CreateProjectRequest request = new CreateProjectRequest();
  2. request.setName("图像分类测试");
  3. request.setDescription("图像分类描述");
  4. request.setProType("Classification");
  5. CreateProjectResponse response = client.getAcsResponse(request)
  6. // 保存项目ID。
  7. String projectId = response.getProject().getProjectId();

2.创建图片标签,每个项目最少需要两个以上的标签。例子中,将创建苹果和香蕉的标签。核心代码:

  1. CreateTagRequest request = new CreateTagRequest();
  2. // 创建项目时,返回的项目ID。
  3. request.setProjectId(projectId);
  4. request.setName("苹果");
  5. request.setDescription("苹果的描述");
  6. CreateTagResponse response = client.getAcsResponse(request);
  7. // 保存苹果的标签ID。
  8. String appleTagId = response.getTag().getTagId();
  9. request.setName("香蕉");
  10. request.setDescription("香蕉的描述");
  11. response = client.getAcsResponse(request);
  12. // 保存香蕉的标签ID。
  13. String bananaTagId = response.getTag().getTagId();

3.创建训练数据,并做标注。

  • 将苹果和香蕉的图片分别上传至OSS。
  • 通过创建训练数据接口将OSS文件添加到训练集,同时标注。例子里,核心代码:

    1. CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest();
    2. request.setProjectId(projectId);
    3. // 添加苹果训练数据。
    4. // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
    5. request.setUrls("http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/1.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/2.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/3.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/4.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/apple/5.jpg,");
    6. request.setTagId(appleTagId);
    7. CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);
    8. // 添加香蕉训练数据。
    9. // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
    10. request.setUrls("http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/1.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/2.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/3.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/4.jpg,http://test-bucket.oss-cn-beijing.aliyuncs.com/banana/5.jpg,");
    11. request.setTagId(bananaTagId);
    12. CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);

4.开始训练训练数据准备完毕之后,调用开始训练接口,并等待训练完成。模型训练的时间较长,请耐心等待。核心代码:

  1. TrainProjectRequest request = new TrainProjectRequest();
  2. request.setProjectId(projectId);
  3. TrainProjectResponse response = client.getAcsResponse(request);
  4. // 保存迭代ID。
  5. String iterationId = response.getIterationId();
  6. // 等待迭代完成
  7. while(true) {
  8. DescribeTrainResultRequest request = new DescribeTrainResultRequest();
  9. request.setProjectId(projectId);
  10. request.setIterationId(iterationId);
  11. TrainProjectResponse response = client.getAcsResponse(request);
  12. if ("TrainSuccess".equals(response.getTrainResult().getStatus())) {
  13. break;
  14. }
  15. TimeUnit.SECONDS.sleep(5);
  16. }

5.预测图片训练结束之后,可以拿到训练时的迭代ID,进行预测。预测作业为异步接口,提交完成之后,通过查询接口进行查询预测结果。核心代码:

  1. // 提交图片预测。
  2. PredictImageRequest request = new PredictImageRequest();
  3. request.setProjectId(projectId);
  4. request.setIterationId(iterationId);
  5. // 替换成OSS图片URL
  6. request.setDataUrls("test-bucket.oss-cn-beijing.aliyuncs.com/predict/1.jpg");
  7. PredictImageResponse response = client.getAcsResponse(request);
  8. // 等待一会
  9. TimeUnit.SECONDS.sleep(10);
  10. // 查询结果
  11. List<PredictImageResponse.PredictData> datas = response.getPredictDatas();
  12. List<String> dataIds = new ArrayList<>(datas.size());
  13. for (PredictImageResponse.PredictData data : datas) {
  14. dataIds.add(data.getDataId());
  15. }
  16. DescribePredictDatasRequest request = new DescribePredictDatasRequest();
  17. request.setProjectId(projectId);
  18. request.setIterationId(iterationId);
  19. request.setDataIds(dataIds);
  20. PredictImageResponse response = client.getAcsResponse(request);
  21. // 输出预测结果
  22. System.out.println(JSON.toJSONString(response));