简介
物体检测是识别图片中有多个主体、位置信息及数量。
前提条件
在对物体检测进行训练之前,要准备好如下数据:
开通OSS授权。
用于训练的图片。
图片的标签和对应的标注信息。
操作步骤
下面将以JAVA SDK为例,详细描述如何训练自己的物体检测模型。操作步骤如下:
将图片上传到OSS。通过OSS控制台,或者通过OSS SDK将图片上传到OSS对应区域的Bucket中
调用CreateProject接口或者在控制台上创建项目。接口核心示例代码:
CreateProjectRequest request = new CreateProjectRequest(); request.setName("物体检测测试"); request.setDescription("物体检测描述"); request.setProType("ObjectDetection"); CreateProjectResponse response = client.getAcsResponse(request) // 保存项目ID。 String projectId = response.getProject().getProjectId();
调用CreateTag接口或者控制台上创建图片标签,每个项目最少需要两个以上的标签。例子中,将创建苹果和香蕉的标签。核心代码:
CreateTagRequest request = new CreateTagRequest(); // 创建项目时,返回的项目ID。 request.setProjectId(projectId); request.setName("苹果"); request.setDescription("苹果的描述"); CreateTagResponse response = client.getAcsResponse(request); // 保存苹果的标签ID。 String appleTagId = response.getTag().getTagId(); request.setName("香蕉"); request.setDescription("香蕉的描述"); response = client.getAcsResponse(request); // 保存香蕉的标签ID。 String bananaTagId = response.getTag().getTagId();
调用CreateTrainDatasFromUrls接口,将OSS的文件数据添加到训练集。然后调用CreateTrainDataRegionTag接口标注数据
// 添加苹果训练数据。 for(int i=0; i<15; i++) { CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest(); request.setProjectId(projectId); // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。 request.setUrls(String.Format("http://xxxxx.oss-cn-beijing.aliyuncs.com/apple/%d.jpg", i)); CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request); String dataId = response.getTrainDatas().get(0).getDataId(); CreateTrainDataRegionTagRequest request = new CreateTrainDataRegionTagRequest(); request.setProjectId(projectId); request.setDataId(dataId); JSONArray array = new JSONArray(); JSONObject object = new JSONObject(); object.put("TagId", appleTagId); JSONObject region = new JSONObject(); // 改成物体的位置信息,Left, Top, Width, Height region.put("Left", 0); region.put("Top", 0); region.put("Width", 640); region.put("Height", 480); object.put("Region", region); array.add(object); request.setTagItems(array.toJSONString()); CreateTrainDataRegionTagResponse response = client.getAcsResponse(request); } // 添加香蕉的训练数据。 for(int i=0; i<15; i++) { CreateTrainDatasFromUrlsRequest request = new CreateTrainDatasFromUrlsRequest(); request.setProjectId(projectId); // OSS地址URL列表,用","分隔。请将OSS地址替换成自己的地址。 request.setUrls(String.Format("http://xxxxx.oss-cn-beijing.aliyuncs.com/banana/%d.jpg", i)); CreateTrainDatasFromUrlsResponse response = client.getAcsResponse(request); String dataId = response.getTrainDatas().get(0).getDataId(); CreateTrainDataRegionTagRequest request = new CreateTrainDataRegionTagRequest(); request.setProjectId(projectId); request.setDataId(dataId); JSONArray array = new JSONArray(); JSONObject object = new JSONObject(); object.put("TagId", bananaTagId); JSONObject region = new JSONObject(); // 改成物体的位置信息,Left, Top, Width, Height region.put("Left", 0); region.put("Top", 0); region.put("Width", 640); region.put("Height", 480); object.put("Region", region); array.add(object); request.setTagItems(array.toJSONString()); CreateTrainDataRegionTagResponse response = client.getAcsResponse(request); }
调用TrainProject接口,或者控制台,进行项目的训练。
TrainProjectRequest request = new TrainProjectRequest(); request.setProjectId(projectId); TrainProjectResponse response = client.getAcsResponse(request); // 保存迭代ID。 String iterationId = response.getIterationId(); // 等待迭代完成 while(true) { DescribeTrainResultRequest request = new DescribeTrainResultRequest(); request.setProjectId(projectId); request.setIterationId(iterationId); TrainProjectResponse response = client.getAcsResponse(request); if ("TrainSuccess".equals(response.getTrainResult().getStatus())) { break; } TimeUnit.SECONDS.sleep(5); }
预测图片训练结束之后,可以拿到训练时的迭代ID,进行预测。预测作业为异步接口,提交完成之后,通过查询接口进行查询预测结果。核心代码:
// 提交图片预测。 PredictImageRequest request = new PredictImageRequest(); request.setProjectId(projectId); request.setIterationId(iterationId); // 替换成OSS图片URL request.setDataUrls("test-bucket.oss-cn-beijing.aliyuncs.com/predict/1.jpg"); PredictImageResponse response = client.getAcsResponse(request); // 等待一会 TimeUnit.SECONDS.sleep(10); // 查询结果 List<PredictImageResponse.PredictData> datas = response.getPredictDatas(); List<String> dataIds = new ArrayList<>(datas.size()); for (PredictImageResponse.PredictData data : datas) { dataIds.add(data.getDataId()); } DescribePredictDatasRequest request = new DescribePredictDatasRequest(); request.setProjectId(projectId); request.setIterationId(iterationId); request.setDataIds(dataIds); PredictImageResponse response = client.getAcsResponse(request); // 输出预测结果 System.out.println(JSON.toJSONString(response));
反馈
- 本页导读 (0)
文档反馈