文档

物体检测训练和预测

更新时间:

简介

物体检测是识别图片中有多个主体、位置信息及数量。

前提条件

在对物体检测进行训练之前,要准备好如下数据:

  • 开通OSS授权。

  • 用于训练的图片。

  • 图片的标签和对应的标注信息。

操作步骤

下面将以JAVA SDK为例,详细描述如何训练自己的物体检测模型。操作步骤如下:

  1. 将图片上传到OSS。通过OSS控制台,或者通过OSS SDK将图片上传到OSS对应区域的Bucket中

  2. 调用CreateProject接口或者在控制台上创建项目。接口核心示例代码:

         CreateProjectRequest request = new CreateProjectRequest();
         request.setName("物体检测测试");
         request.setDescription("物体检测描述");
         request.setProType("ObjectDetection");
         CreateProjectResponse response = client.getAcsResponse(request)
         // 保存项目ID。
         String projectId = response.getProject().getProjectId();
  3. 调用CreateTag接口或者控制台上创建图片标签,每个项目最少需要两个以上的标签。例子中,将创建苹果和香蕉的标签。核心代码:

        CreateTagRequest request = new CreateTagRequest();
        // 创建项目时,返回的项目ID。
        request.setProjectId(projectId);
        request.setName("苹果");
        request.setDescription("苹果的描述");
        CreateTagResponse response = client.getAcsResponse(request);
        // 保存苹果的标签ID。
        String appleTagId = response.getTag().getTagId();
    
        request.setName("香蕉");
        request.setDescription("香蕉的描述");
        response = client.getAcsResponse(request);
        // 保存香蕉的标签ID。
        String bananaTagId = response.getTag().getTagId();
  4. 调用CreateTrainDatasFromUrls接口,将OSS的文件数据添加到训练集。然后调用CreateTrainDataRegionTag接口标注数据

     // 添加苹果训练数据。
     for(int i=0; i<15; i++) {
         CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest();
         request.setProjectId(projectId);
         // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
         request.setUrls(String.Format("http://xxxxx.oss-cn-beijing.aliyuncs.com/apple/%d.jpg", i));
         CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);
         String dataId = response.getTrainDatas().get(0).getDataId();
         CreateTrainDataRegionTagRequest request = new CreateTrainDataRegionTagRequest();
         request.setProjectId(projectId);
         request.setDataId(dataId);
         JSONArray array = new JSONArray();
         JSONObject object = new JSONObject();
         object.put("TagId", appleTagId);
         JSONObject region = new JSONObject();
         // 改成物体的位置信息,Left, Top, Width, Height
         region.put("Left", 0);
         region.put("Top", 0);
         region.put("Width", 640);
         region.put("Height", 480);
    
         object.put("Region", region);
         array.add(object);
         request.setTagItems(array.toJSONString());
         CreateTrainDataRegionTagResponse response = client.getAcsResponse(request);
     }
    
     // 添加香蕉的训练数据。
     for(int i=0; i<15; i++) {
         CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest();
         request.setProjectId(projectId);
         // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。
         request.setUrls(String.Format("http://xxxxx.oss-cn-beijing.aliyuncs.com/banana/%d.jpg", i));
         CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request);
         String dataId = response.getTrainDatas().get(0).getDataId();
         CreateTrainDataRegionTagRequest request = new CreateTrainDataRegionTagRequest();
         request.setProjectId(projectId);
         request.setDataId(dataId);
         JSONArray array = new JSONArray();
         JSONObject object = new JSONObject();
         object.put("TagId", bananaTagId);
         JSONObject region = new JSONObject();
         // 改成物体的位置信息,Left, Top, Width, Height
         region.put("Left", 0);
         region.put("Top", 0);
         region.put("Width", 640);
         region.put("Height", 480);
    
         object.put("Region", region);
         array.add(object);
         request.setTagItems(array.toJSONString());
         CreateTrainDataRegionTagResponse response = client.getAcsResponse(request);
     }
  5. 调用TrainProject接口,或者控制台,进行项目的训练。

         TrainProjectRequest request = new TrainProjectRequest();
         request.setProjectId(projectId);
         TrainProjectResponse response = client.getAcsResponse(request);
         // 保存迭代ID。
         String iterationId = response.getIterationId();
    
         // 等待迭代完成
         while(true) {
             DescribeTrainResultRequest request = new DescribeTrainResultRequest();
             request.setProjectId(projectId);
             request.setIterationId(iterationId);
             TrainProjectResponse response = client.getAcsResponse(request);
             if ("TrainSuccess".equals(response.getTrainResult().getStatus())) {
                 break;
             }
             TimeUnit.SECONDS.sleep(5);
         }
  6. 预测图片训练结束之后,可以拿到训练时的迭代ID,进行预测。预测作业为异步接口,提交完成之后,通过查询接口进行查询预测结果。核心代码:

         // 提交图片预测。
         PredictImageRequest request = new PredictImageRequest();
         request.setProjectId(projectId);
         request.setIterationId(iterationId);
         // 替换成OSS图片URL
         request.setDataUrls("test-bucket.oss-cn-beijing.aliyuncs.com/predict/1.jpg");
         PredictImageResponse response = client.getAcsResponse(request);
    
         // 等待一会
         TimeUnit.SECONDS.sleep(10);
    
         // 查询结果
         List<PredictImageResponse.PredictData> datas = response.getPredictDatas();
         List<String> dataIds = new ArrayList<>(datas.size());
         for (PredictImageResponse.PredictData data : datas) {
             dataIds.add(data.getDataId());
         }
         DescribePredictDatasRequest request = new DescribePredictDatasRequest();
         request.setProjectId(projectId);
         request.setIterationId(iterationId);
         request.setDataIds(dataIds);
         PredictImageResponse response = client.getAcsResponse(request);
         // 输出预测结果
         System.out.println(JSON.toJSONString(response));
  • 本页导读 (0)
文档反馈