全部产品
云市场
云游戏

物体检测训练和预测

更新时间:2019-10-25 19:01:00

简介

物体检测是识别图片中有多个主体、位置信息及数量。

前提条件

在对物体检测进行训练之前,要准备好如下数据:

  • 开通OSS授权。

  • 用于训练的图片。

  • 图片的标签和对应的标注信息。

操作步骤

下面将以JAVA SDK为例,详细描述如何训练自己的物体检测模型。操作步骤如下:

  1. 将图片上传到OSS。通过OSS控制台,或者通过OSS SDK将图片上传到OSS对应区域的Bucket中

  2. 调用CreateProject接口或者在控制台上创建项目。接口核心示例代码:

    1. CreateProjectRequest request = new CreateProjectRequest();
    2. request.setName("物体检测测试");
    3. request.setDescription("物体检测描述");
    4. request.setProType("ObjectDetection");
    5. CreateProjectResponse response = client.getAcsResponse(request)
    6. // 保存项目ID。
    7. String projectId = response.getProject().getProjectId();
  3. 调用CreateTag接口或者控制台上创建图片标签,每个项目最少需要两个以上的标签。例子中,将创建苹果和香蕉的标签。核心代码:

    1. CreateTagRequest request = new CreateTagRequest();
    2. // 创建项目时,返回的项目ID。
    3. request.setProjectId(projectId);
    4. request.setName("苹果");
    5. request.setDescription("苹果的描述");
    6. CreateTagResponse response = client.getAcsResponse(request);
    7. // 保存苹果的标签ID。
    8. String appleTagId = response.getTag().getTagId();
    9. request.setName("香蕉");
    10. request.setDescription("香蕉的描述");
    11. response = client.getAcsResponse(request);
    12. // 保存香蕉的标签ID。
    13. String bananaTagId = response.getTag().getTagId();
  4. 调用CreateTrainDatasFromUrls接口,将OSS的文件数据添加到训练集。然后调用CreateTrainDataRegionTag接口标注数据

    1. // 添加苹果训练数据。
    2. for(int i=0; i<15; i++) {
    3. CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest();
    4. request.setProjectId(projectId);
    5. // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
    6. request.setUrls(String.Format("http://xxxxx.oss-cn-beijing.aliyuncs.com/apple/%d.jpg", i));
    7. CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);
    8. String dataId = response.getTrainDatas().get(0).getDataId();
    9. CreateTrainDataRegionTagRequest request = new CreateTrainDataRegionTagRequest();
    10. request.setProjectId(projectId);
    11. request.setDataId(dataId);
    12. JSONArray array = new JSONArray();
    13. JSONObject object = new JSONObject();
    14. object.put("TagId", appleTagId);
    15. JSONObject region = new JSONObject();
    16. // 改成物体的位置信息,Left, Top, Width, Height
    17. region.put("Left", 0);
    18. region.put("Top", 0);
    19. region.put("Width", 640);
    20. region.put("Height", 480);
    21. object.put("Region", region);
    22. array.add(object);
    23. request.setTagItems(array.toJSONString());
    24. CreateTrainDataRegionTagResponse response = client.getAcsResponse(request);
    25. }
    26. // 添加香蕉的训练数据。
    27. for(int i=0; i<15; i++) {
    28. CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest();
    29. request.setProjectId(projectId);
    30. // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
    31. request.setUrls(String.Format("http://xxxxx.oss-cn-beijing.aliyuncs.com/banana/%d.jpg", i));
    32. CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);
    33. String dataId = response.getTrainDatas().get(0).getDataId();
    34. CreateTrainDataRegionTagRequest request = new CreateTrainDataRegionTagRequest();
    35. request.setProjectId(projectId);
    36. request.setDataId(dataId);
    37. JSONArray array = new JSONArray();
    38. JSONObject object = new JSONObject();
    39. object.put("TagId", bananaTagId);
    40. JSONObject region = new JSONObject();
    41. // 改成物体的位置信息,Left, Top, Width, Height
    42. region.put("Left", 0);
    43. region.put("Top", 0);
    44. region.put("Width", 640);
    45. region.put("Height", 480);
    46. object.put("Region", region);
    47. array.add(object);
    48. request.setTagItems(array.toJSONString());
    49. CreateTrainDataRegionTagResponse response = client.getAcsResponse(request);
    50. }
  5. 调用TrainProject接口,或者控制台,进行项目的训练。

    1. TrainProjectRequest request = new TrainProjectRequest();
    2. request.setProjectId(projectId);
    3. TrainProjectResponse response = client.getAcsResponse(request);
    4. // 保存迭代ID。
    5. String iterationId = response.getIterationId();
    6. // 等待迭代完成
    7. while(true) {
    8. DescribeTrainResultRequest request = new DescribeTrainResultRequest();
    9. request.setProjectId(projectId);
    10. request.setIterationId(iterationId);
    11. TrainProjectResponse response = client.getAcsResponse(request);
    12. if ("TrainSuccess".equals(response.getTrainResult().getStatus())) {
    13. break;
    14. }
    15. TimeUnit.SECONDS.sleep(5);
    16. }
  6. 预测图片训练结束之后,可以拿到训练时的迭代ID,进行预测。预测作业为异步接口,提交完成之后,通过查询接口进行查询预测结果。核心代码:

    1. // 提交图片预测。
    2. PredictImageRequest request = new PredictImageRequest();
    3. request.setProjectId(projectId);
    4. request.setIterationId(iterationId);
    5. // 替换成OSS图片URL
    6. request.setDataUrls("test-bucket.oss-cn-beijing.aliyuncs.com/predict/1.jpg");
    7. PredictImageResponse response = client.getAcsResponse(request);
    8. // 等待一会
    9. TimeUnit.SECONDS.sleep(10);
    10. // 查询结果
    11. List<PredictImageResponse.PredictData> datas = response.getPredictDatas();
    12. List<String> dataIds = new ArrayList<>(datas.size());
    13. for (PredictImageResponse.PredictData data : datas) {
    14. dataIds.add(data.getDataId());
    15. }
    16. DescribePredictDatasRequest request = new DescribePredictDatasRequest();
    17. request.setProjectId(projectId);
    18. request.setIterationId(iterationId);
    19. request.setDataIds(dataIds);
    20. PredictImageResponse response = client.getAcsResponse(request);
    21. // 输出预测结果
    22. System.out.println(JSON.toJSONString(response));