多模态数据管理和使用

1. 概述

多模态数据管理针对图片等多模态数据,可通过多模态大模型、Embedding模型等进行预处理(智能打标、语义索引),形成丰富的元数据。借助这些元数据,支持对多模态数据进行搜索、筛选等操作,便于快速挖掘特定场景的数据子集,用于进一步的数据标注、训练等流程。同时,PAI数据集还开放了全套OpenAPI,便于在自建平台中集成。产品架构如下图所示:

image

2. 使用限制

当前PAI多模态数据管理有如下使用限制:

  • 使用地域:当前支持杭州、上海、深圳、乌兰察布4个地域;

  • 存储类型:当前仅支持在OSS对象存储中使用PAI多模态数据管理;

  • 文件类型:当前仅支持图片类型文件,支持文件格式:jpg、 jpeg、png、gif、bmp、tiff、webp;

  • 文件数量:支持单个数据集版本最大1,000,000个文件,如有特殊需求可联系PAI PDSA扩容;

  • 使用模型:

    • 打标模型:支持使用百炼平台-Qwen VL Max/Plus模型;

    • 索引模型:支持使用PAI-Model Gallery上的GME模型作为索引模型,在PAI-EAS上部署使用;

  • 元信息存储:

    • 元数据:元数据安全存储于PAI内置的元数据库;

    • Embedding向量:支持存储于下列自定义向量数据库中:

      • ElasticSearch(向量增强版,8.17.0版本及以上);

      • OpenSearch(向量检索版);

      • Milvus(2.4及以上版本);

      • Lindorm(向量引擎版本);

  • 数据集处理模式:目前仅支持全量模式运行智能打标任务及语义索引任务,暂不支持增量模式;

3. 使用流程

PAI多模态数据管理使用说明

3.1 前置准备

3.1.1 开通PAI,创建默认工作空间并获取空间管理员权限

  1. 使用主账号开通PAI并创建工作空间。登录PAI控制台,左上角选择开通区域,然后一键授权和开通产品,详情见开通PAI并创建工作空间

  2. 操作账号授权。当使用主账号操作时,可跳过此步。当使用RAM账号时,需要具有空间管理员角色。操作账号授权请参见管理工作空间 > 成员角色配置。

3.1.2 开通百炼并创建API-KEY

开通阿里云百炼并创建API-KEY,请参考阿里云百炼服务开通

3.1.3 创建向量数据库

创建向量数据库实例

多模态数据集管理目前支持以下几种阿里云向量数据库:

  • ElasticSearch(向量增强版,8.17.0版本及以上);

  • OpenSearch(向量检索版);

  • Milvus(2.4及以上版本);

  • Lindorm(向量引擎版本);

各个云向量数据库实例创建请参考对应云产品文档。

网络配置以及白名单配置

  • 公网方式

    若向量库实例开通了公网地址,将下面的IP列表添加到实例的公网访问白名单地址列表后,多模态数据管理服务即可通过公网访问此实例。ElasticSearch白名单设置请参见配置实例公网或私网访问白名单

    地域

    IP列表

    杭州

    47.110.230.142,47.98.189.92

    上海

    47.117.86.159,106.14.192.90

    深圳

    47.106.88.217,39.108.12.110

    乌兰察布

    8.130.24.177,8.130.82.15

  • 私网方式

    请提交工单申请。

创建向量索引表

在某些向量数据库中,向量索引表也称为CollectionIndex。

索引表结构定义(表结构必须遵循如下定义):

表结构定义

{
    "id":"text",                    //主键id,在OpenSearch中需要定义,其他数据库中默认存在,无需定义
    "index_set_id": "keyword",      //索引集ID,需支持索引
    "file_meta_id": "text",         //文件元数据ID   
    "dataset_id": "text",           //数据集ID
    "dataset_version": "text",      //数据集版本
    "uri": "text",                  //OSS文件的uri
    "file_vector": {                //向量字段
        "type": "float",            //向量类型:浮点型
        "dims": 1536,               //向量维度,自定义
        "similarity": "DotProduct"  //向量距离算法,余弦距离或点积距离
    }
}

本文中以ElasticSearch为例通过Python创建语义索引表(其他类型向量数据库索引表的创建请参考对应云产品使用文档)。示例代码如下:

ElasticSearch创建语义索引表示例代码

from elasticsearch import Elasticsearch

# 1. 连接阿里云 Elasticsearch 实例,
# 注意:
# (1)需要安装3.9以上python版本:python3 -V
# (2)elasticsearch客户端版本需安装8.x版本:pip show elasticsearch  !!!!
# (3)如果使用vpc地址,调用方需跟es实例的vpc打通。否则使用公网连接地址并配置调用方公网IP到es白名单。
# 默认的userName为 elastic
es_client = Elasticsearch(
    hosts=["http://es-cn-l4p***5z.elasticsearch.aliyuncs.com:9200"],
    basic_auth=("{userName}", "{password}"),
)

# 2. 定义索引名称和结构,默认采用HNSW索引算法
index_name = "dataset_embed_test"
index_mapping = {
    "settings": {
        "number_of_shards": 1,          # 表分片数
        "number_of_replicas": 1         # 表副本数
    },
    "mappings": {
        "properties": {
            "index_set_id": {
                "type": "keyword"
            },
            "uri": {
                "type": "text"
            },
            "file_meta_id": {
                "type": "text"
            },
            "dataset_id": {
                "type": "text"
            },
            "dataset_version": {
                "type": "text"  
            },
            "file_vector": {
                "type": "dense_vector",  # 定义file_vector为密集向量类型
                "dims": 1536,  # 向量维度为 1536
                "similarity": "dot_product"  # 相似度计算方法为点积
            }
        }
    }
}

# 3. 创建索引
if not es_client.indices.exists(index=index_name):
    es_client.indices.create(index=index_name, body=index_mapping)
    print(f"Index {index_name} create success!")
else:
    print(f"Index {index_name} existed, do not create repeatedly.")

# 4. 查看创建的索引表结构(可选)
# indexes = es_client.indices.get(index=index_name)
# print(indexes)

3.2 创建数据集

  1. 进入PAI工作空间,在左侧菜单栏单击AI资产管理 > 数据集 > 新建数据集,进入数据集配置页面。

    image

  2. 配置数据集参数,关键参数如下,其他参数默认即可。

    1. 存储类型对象存储(OSS)

    2. 类型高级型

    3. 内容类型图片

    4. OSS路径:选择数据集的OSS存储路径。如果您没有准备数据集,可以下载示例数据集retrieval_demo_data,并上传至OSS,体验多模态数据管理功能。

    说明

    此处导入文件/文件夹,仅在系统记录中设置了路径,不会复制数据。

    image

    然后单击确定,创建数据集。

3.3 创建连接

3.3.1 创建智能打标模型连接

  1. 进入PAI工作空间,在左侧菜单栏单击AI资产管理 > 连接 > 模型服务 > 新建连接,打开新建连接页面。

    image

  2. 选择百炼大模型服务,并配置百炼api_key

    image

  3. 创建成功后,在列表页可以看到创建的百炼大模型服务。

    image

3.3.2 创建自定义语义索引模型连接

  1. 在左侧菜单栏单击Model Gallery,找到并部署GME多模态检索模型,得到一个EAS服务。部署大约需要5分钟,当处于运行中时,代表部署成功。

    重要

    当您不需要使用该索引模型时,可停止和删除该服务,以免继续产生费用。

    image

  2. 进入PAI工作空间,在左侧菜单栏单击AI资产管理 > 连接 > 模型服务 > 新建连接,打开新建连接页面。

  3. 选择通用多模态Embedding模型服务,单击EAS服务输入框,选择刚部署的GME多模态检索模型。

    image

    image

  4. 创建成功后,在列表页可以看到创建的模型连接服务。

    image

3.3.3 创建向量数据库连接

  1. 在左侧菜单栏单击AI资产管理 > 连接 > 数据库 > 新建连接,打开新建连接页面。

    image

  2. 多模态检索服务支持Milvus/Lindorm/OpenSearch/ElasticSearch向量数据库,这里以ElasticSearch为例创建连接。选择检索分析服务-Elasticsearch,配置uriusernamepassword等信息。

    image

  3. 创建成功后,在列表页可以看到创建的向量数据库连接。

    image

3.4 创建智能打标任务

3.4.1 创建智能标签定义

  1. 在左侧菜单栏单击AI资产管理 > 数据集 > 智能标签定义 > 新建智能标签,打开标签配置页面,配置示例如下:

    image

    • 引导提示词:作为一个拥有多年驾驶经验的老司机,你有着非常丰富的高速以及城市道路驾驶经验。

    • 标签定义

      自动驾驶示例标签定义

      {
          "反光贴条": " 通常为黄色,或者黄黑间隔,贴在墙角等永久的突出障碍物上,用于提示车主避让。为条带状,不是锥筒,不是地锁,不是水马!",
          "地锁": "也叫车位锁,升起时可以阻止车位被占用。存在地锁时,务必输出地锁为升起还是降下状态。有升起架子时为升起状态,否则为降下状态。",
          "亮灯的工程车": "有左右2个箭头状灯光的,并且亮起的,就是目标,否则不存在;",
          "侧翻的车辆": "车辆侧翻在地;",
          "倒地的水马": "水马是一种用于分割路面或形成阻挡的塑制壳体障碍物,一般为红色塑料墙形态。一般用于道路交通设施,在高速路、城市道路、及天桥街道路口常见。比锥筒显著大,且为片状结构。水马正常为直立,卧倒在地的话需要明确指出。",
          "倒地的锥桶": "又称锥形交通路标、雪糕筒,俗称路锥、三角锥,为锥形临时道路标示。杆状,片状的障碍物不是锥筒,因为其不是锥形。锥筒可能被车撞倒,若图片中存在锥筒,而又需要判断是否为倒地时,可以观察锥筒底部(圆锥的锥底)是否与地面接触,如接触则不为倒地,否则为倒地。",
          "充电车位": "靠墙且明显带着充电枪,充电桩设备的,或者写着新能源车位的为充电车位,只可能出现在停车场(室内室外均可能),注意地锁跟充电无关。",
          "减速带": "一般为黄黑相间,或者黄色,为横在路上垂直于路边缘的窄条突起,用于车辆减速。不可能出现在停车位内。",
          "减速车道线": " 车道两侧鱼骨状虚线,在实线内测,2侧均有才为减速车道线。",
          "匝道": "只有明确看到高速路上的大弯道,一般匝道都在高速路干道的右侧,进出收费站才可判定存在。",
          "地面阴影": "地面有明显的影子的情况。",
          "多云": "只有可以看到天空,且天空有明显的云彩的情况才可判定存在。",
          "炫光的车": "前方灯光发生炫光(灯光由单点变成了线条状光线)情况,通常在夜晚或者下雨天时导致。",
          "左转、右转、掉头箭头": "车道地面上所画的乳白色箭头标志(少数为黄色),不是指高速路的绿白色指示道路右弯的绿白箭头。判断是否存在这些箭头时,首先只有车道地面中间的明晰的箭头标志才是目标,其他的路边的等一律不是。若地面存在箭头,判断箭头朝向方法,右转箭头为从箭头根部到箭头尖部为顺时针旋转;左转箭头为从箭头根部到箭头尖部为类似逆时针旋转;U型的箭头为调头箭头。",
          "斑马线": "仅可能在路面(停车场内也可能),路口存在,一定为白色线条重复间隔分布平行于路边,用于行人通行。不可能出现在高速路,高速匝道,隧道中。",
          "曝光": " 白天,阳光直射导致镜头曝光(只可能发生在白天)。",
          "机动车": "视野中有其他机动车辆。",
          "汇入汇出": " 高速路多条并未一条,或者一条分为多条路的地方。",
          "路口": " 路口,且路口内没有车道线的情况(指路口段内无,路口之外有没有影响)。",
          "禁停牌": " 悬挂或者立在地上的牌子,写着禁止停车文字,或者绘有圆圈包围的P加斜线的标识。",
          "车道线": " 路上的车道线,尤其关注车道线模糊的情况。",
          "道路上掉落的石头、轮胎": " 路上的影响通车的障碍物。",
          "隧道": " 进隧道口,出隧道口尤其注意分辨。",
          "雨天地面潮湿": "下雨天地面湿滑情况。",
          "非机动车": "包括自行车,电动车,轮椅,单轮车,超市推车等非机动车物体,可能被停放在路边,车位,行驶在路上等。"
        }

3.4.2 创建智能打标离线任务

  1. 单击自定义数据集,单击数据集名称进入详情页面,然后再单击智能打标任务

    image

    image

  2. 进入任务页面,单击新建智能打标任务,并配置任务参数。

    • 数据集版本:选择需要打标的版本如v1;

    • 智能打标模型连接:选择创建的百炼模型连接;

    • 智能打标模型:支持通义千问VL-MAX和通义千问VL-Plus;

    • 智能标签定义:选择刚创建的智能标签定义;

    说明

    打标模式暂时仅支持对数据集版本中的全量文件打标。

    image

  3. 智能打标任务创建成功后,在任务列表可以看到创建的打标任务。观察启动的智能打标任务,可点击列表右侧链接查看日志或停止打标任务

    image

    说明

    首次启动智能打标任务,将进行元数据的构建,所需时间可能较长,请耐心等待。

3.5 创建语义索引任务

  1. 单击数据集名称进入详情页面,在索引库配置区域,单击编辑按钮。

    image

  2. 配置索引库。

    • 索引模型连接:选择3.3.2中创建的索引模型连接

    • 索引数据库连接:选择3.3.3中创建的索引库连接;

    • 索引数据库表:输入创建向量索引表中创建的索引表名称,即:dataset_embed_test;

    单击保存 > 立即刷新。然后会创建一个所选数据集版本的语义索引任务,对版本中全量文件更新语义索引。可单击数据集详情页面右上角语义索引任务查看任务详情。

    image

    如果没有单击立即刷新,而是取消,您可以通过在语义索引任务页面手动新建语义索引任务更新索引。
    说明

    首次启动语义索引任务,将进行元数据的构建,所需时间可能较长,请耐心等待。

3.6 数据预览

  1. 待智能打标和语义索引任务完成后,在数据集详情页面,单击查看数据可预览该数据集版本内的图片。

    image

    image

  2. 点击具体图片,可查看大图,并查看图片中包含的标签

    image

3.7 数据搜索

  1. 在“查看数据”界面的左侧工具栏内,可进行索引检索标签搜索,按下Enter或单击搜索即可开始搜索。

  2. 索引检索,文本关键词搜索:基于“语义索引”的结果,通过关键词与图片索引结果的向量匹配进行搜索。在“高级设置”中可以设置topk、Score 阈值等参数。

    image

  3. 索引检索,以图搜图:基于“语义索引”的结果,用户可以从本地上传图片或者选择oss中的图片,与数据集图片索引结果的向量匹配进行搜索。在“高级设置”中可以设置topk、Score阈值等参数。

    image

  4. 标签搜索:基于“智能打标”的结果,通过关键词与图片标签的匹配进行搜索。可同时按照包含以下任意标签同时包含以下标签排除以下任意标签的逻辑进行搜索。

    image

  5. 元数据搜索:可以按照文件名、存储路径、文件最后修改时间进行搜索。

    image

    以上所有搜索条件为AND关系。

3.8 搜索结果集的导出

说明

此步骤的目的,是将搜索结果导出为文件列表索引,用于后续的模型训练或数据分析。目前搜索结果集导出的能力以API的形式提供。

  1. 创建导出任务,通过调用API CreateDatasetJob,关键参数如下:

    1. JobAction: FileMetaExport,指定任务类型为文件导出类型;

    2. JobSpec: 格式为一个JSON格式的字符串,样例为:

      {
        "fileUri":"oss://mybucket.oss-cn-beijing.aliyuncs.com/mypath",
        "queryType":"TAG",
        "queryText":"汽车",
        "topK":100,
        "scoreThreshold":0.6
      }

      其中:

      1. fileUri: 设置导出结果的存储位置,例如 oss://mybucket.oss-cn-beijing.aliyuncs.com/mypath。最终导出文件路径为: {fileUri}/{DatasetId}/{DatasetId-DatasetVersionName-timestamp}.jsonl

      2. 除了fileUri之外,JSON中的其他入参,与获取数据集文件元数据列表ListDatasetFileMetas中除DatasetId、DatasetVersion、WorkspaceId之外的其他参数一样均可按需传入。

  2. 查看导出任务状态及结果,通过调用API GetDatasetJob。当任务状态为 Succeeded 时,表示导出完成。

    在指定的OSS路径下查看导出结果文件,文件格式为JSONL(每行一个JSON对象)。示例内容如下:

    {"file_name": "camera_1.jpg", "path": "oss://bucket1/camera_1.jpg"}
    {"file_name": "camera_2.jpg", "path": "oss://bucket1/camera_2.jpg"}
    ...
    说明

    GetDatasetJob中的入参DatasetJobId应当使用前面创建数据集导出任务时CreateDatasetJob返回的DatasetJobId

  3. 使用导出结果,将导出结果文件与原数据集挂载至对应的训练环境(如DLCDSW实例),通过代码实现读取导出结果文件索引,并从原数据集中加载目标文件进行模型训练或分析。

4. (可选)自定义语义索引模型

您可以通过微调自定义语义检索模型,在EAS部署成功后,可以按照3.3.2中的步骤创建模型连接,在后续的多模态数据管理中使用。

4.1 数据准备

本文提供了示例数据retrieval_demo_data,您可以单击下载。

4.1.1 数据格式要求

每个数据样本以一行JSON格式保存到dataset.jsonl文件中,必须包含以下字段:

  • image_id: 图像唯一标识符(如图片名称或唯一ID)。

  • tags: 与该图像关联的文本标签列表,标签为字符串数组。

示例格式:

{  
    "image_id": "c909f3df-ac4074ed",  
    "tags": ["银色的轿车", "白色的SUV", "城市街道", "下雪", "夜晚"], 
}

4.1.2 文件组织结构

将所有图像文件放入一个文件夹(images),并将dataset.jsonl文件放在与图像文件夹同级的目录中。

目录示例:

├── images
│   ├── image1.jpg
│   ├── image2.jpg
│   └── image3.jpg
└── dataset.jsonl  
重要

务必使用原始文件名dataset.jsonl,文件夹名images不可更改。

4.2 模型训练

  1. 在 Model Gallery 中找到检索相关的模型, 更具所需的模型大小和计算资源,选择合适的模型来进行微调和部署。

    image

    微调 VRAM bs=4

    微调(4*A800)train_samples/second

    部署 VRAM

    向量维度

    GME-2B

    14G

    16.331

    5G

    1536

    GME-7B

    35G

    13.868

    16G

    3584

  2. 以训练GME-2B模型为例,单击训练,填入数据地址 (默认地址即为示例数据地址),填写模型输出路径,即可开始训练模型。

    image

    image

4.3 模型部署

训练完的模型可以训练任务中,点击部署来部署微调后的模型

点击Model Gallery模型选项卡的部署按钮,即可部署原始的GME模型。

image

部署完成后,可在页面中获得对应的 EAS 访问地址Tokenimage

4.4 模型服务调用

输入参数

名称

类型

是否必填

示例值

描述

model

String

pai-multimodal-embedding-v1

模型类型,后续可以添加用户自定义模型的支持 / 进行基模型的版本迭代

contents.input

list(dict) or list(str)

input = [{'text': text}]

input=[xxx,xxx,xxx,...]

input = [{'text': text},{'image', f"data:image/{image_format};base64,{image64}"}]

embedding的内容。

当前只支持 text, image

输出参数

名称

类型

示例值

描述

status_code

Integer

200

http状态码。

200 请求成功

204 请求部分成功

400 请求失败

message

list(str)

['Invalid input data: must be a list of strings or dict']

报错信息

output

dict

见下表

embedding结果

dashscope 返回结果是一个 {'output', {'embeddings': list(dict), 'usage': xxx, 'request_id':xxx}}(暂时不用 'usage', 'request_id')

embeddings 的元素包含以下key (失败的index 会把错误原因加在message中)

名称

类型

示例值

描述

index

数据id

0

http状态码。

200、400、500

embedding

List[Float]

[0.0391846,0.0518188,.....,-0.0329895,

0.0251465]

1536

embedding后的向量

type

String

"Internal execute error."

错误信息

调用示例代码

import base64
import json
import os
import sys
from io import BytesIO

import requests
from PIL import Image, PngImagePlugin
import numpy as np

ENCODING = 'utf-8'

hosts = 'EAS URL'
head = {
    'Authorization': 'EAS TOKEN'
}

def encode_image_to_base64(image_path):
    """
    将图像文件编码为 Base64 字符串
    """
    with open(image_path, "rb") as image_file:
        # 读取图像文件的二进制数据
        image_data = image_file.read()
        # 编码为 Base64 字符串
        base64_encoded = base64.b64encode(image_data).decode('utf-8')
    
    return base64_encoded

if __name__=='__main__':
    iamege_path = "path_to_your_image"
    text = 'prompt'

    image_format = 'jpg'
    input_data = []
    
    image64 = encode_image_to_base64(image_path)
    input_data.append({'image': f"data:image/{image_format};base64,{image64}"})

    input_data.append({'text': text})

    datas = json.dumps({
        'input': {
            'contents': input_data
        }
    })
    r = requests.post(hosts, data=datas, headers=head)
    data = json.loads(r.content.decode('utf-8'))

    if data['status_code']==200:
        if len(data['message'])!=0:
            print('Part failed for the following reasons.')
            print(data['message'])

        for result_item in data['output']['embeddings']:
            print('The following succeed.')
            print('index', result_item['index'])
            print('type', result_item['type'])
            print('embedding', len(result_item['embedding']))
    else:
        print('Processed fail')
        print(data['message'])

输出示例:

{
    "status_code": 200,
    "message": "",
    "output": {
        "embeddings": [
            {
                "index": 0,
                "embedding": [
                    -0.020782470703125,
                    -0.01399993896484375,
                    -0.0229949951171875,
                    ...
                ],
                "type": "text"
            }
        ]
    }
}

4.5 模型评测

在我们的示例数据上的评测效果如下(所使用的评测文件):

原始模型Precision

微调1epoch的模型Precision

gme2b

Precision@1 0.3542

Precision@5 0.5280

Precision@10 0.5923

Precision@50 0.5800

Precision@100 0.5792

Precision@1 0.4271

Precision@5 0.6480

Precision@10 0.7308

Precision@50 0.7331

Precision@100 0.7404

gme7b

Precision@1 0.3958

Precision@5 0.5920

Precision@10 0.6667

Precision@50 0.6517

Precision@100 0.6415

Precision@1 0.4375

Precision@5 0.6680

Precision@10 0.7590

Precision@50 0.7683

Precision@100 0.7723

模型评测示例脚本

import base64
import json
import os
import requests
import numpy as np
import torch
from tqdm import tqdm
from collections import defaultdict


# Constants
ENCODING = 'utf-8'
HOST_URL = 'http://1xxxxxxxx4.cn-xxx.pai-eas.aliyuncs.com/api/xxx'
AUTH_HEADER = {'Authorization': 'ZTg*********Mw=='}

def encode_image_to_base64(image_path):
    """将图像文件编码为 Base64 字符串"""
    with open(image_path, "rb") as image_file:
        image_data = image_file.read()
        base64_encoded = base64.b64encode(image_data).decode(ENCODING)
    return base64_encoded


def load_image_features(feature_file):
    print("Begin to load image features...")
    image_ids, image_feats = [], []
    with open(feature_file, "r") as fin:
        for line in tqdm(fin):
            obj = json.loads(line.strip())
            image_ids.append(obj['image_id'])
            image_feats.append(obj['feature'])
    image_feats_array = np.array(image_feats, dtype=np.float32)
    print("Finished loading image features.")
    return image_ids, image_feats_array


def precision_at_k(predictions, gts, k):
    """
    计算前K个结果的精确率。
    
    :param predictions: [(image_id, similarity_score), ...]
    :param gts: set of ground truth image_ids
    :param k: int, 前K个结果
    :return: float, 精确率
    """
    if len(predictions) > k:
        predictions = predictions[:k]
    
    predicted_ids = {p[0] for p in predictions}
    relevant_and_retrieved = predicted_ids.intersection(gts)
    precision = len(relevant_and_retrieved) / k
    return precision


def main():
    root_dir = '/mnt/data/retrieval/data/'
    data_dir = os.path.join(root_dir, 'images')
    tag_file = os.path.join(root_dir, 'meta/test.jsonl')
    model_type = 'finetune_gme7b_final'
    save_feature_file = os.path.join(root_dir, 'features', f'features_{model_type}_eas.jsonl')
    final_result_log = os.path.join(root_dir, 'results', f'retrieval_{model_type}_log_eas.txt')
    final_result = os.path.join(root_dir, 'results', f'retrieval_{model_type}_log_eas.jsonl')

    os.makedirs(os.path.join(root_dir, 'features'), exist_ok=True)
    os.makedirs(os.path.join(root_dir, 'results'), exist_ok=True)

    tag_dict = defaultdict(list)
    gt_image_ids = []
    with open(tag_file, 'r') as f:
        lines = f.readlines()
        for line in lines:
            data = json.loads(line.strip())
            gt_image_ids.append(data['image_id'])
            img_id = data['image_id'].split('.')[0]
            for caption in data['tags']:
                tag_dict[caption.strip()].append(img_id)

    print('Total tags:', len(tag_dict.keys()))

    prefix = ''
    texts = [prefix + text for text in tag_dict.keys()]
    images = [os.path.join(data_dir, i+'.jpg') for i in gt_image_ids]
    print('Total images:', len(images))

    encode_images = True
    if encode_images:
        with open(save_feature_file, "w") as fout:
            for image_path in tqdm(images):
                image_id = os.path.basename(image_path).split('.')[0]
                image64 = encode_image_to_base64(image_path)
                input_data = [{'image': f"data:image/jpg;base64,{image64}"}]

                datas = json.dumps({'input': {'contents': input_data}})
                r = requests.post(HOST_URL, data=datas, headers=AUTH_HEADER)

                data = json.loads(r.content.decode(ENCODING))
                if data['status_code'] == 200:
                    if len(data['message']) != 0:
                        print('Part failed:', data['message'])
                    for result_item in data['output']['embeddings']:
                        fout.write(json.dumps({"image_id": image_id, "feature": result_item['embedding']}) + "\n")
                else:
                    print('Processed fail:', data['message'])

    image_ids, image_feats_array = load_image_features(save_feature_file)

    top_k_list = [1, 5, 10, 50, 100]
    top_k_list_precision  = [[] for _ in top_k_list]

    with open(final_result, 'w') as f_w, open(final_result_log, 'w') as f:
        for tag in tqdm(texts):
            datas = json.dumps({'input': {'contents': [{'text': tag}]}})
            r = requests.post(HOST_URL, data=datas, headers=AUTH_HEADER)
            data = json.loads(r.content.decode(ENCODING))

            if data['status_code'] == 200:
                if len(data['message']) != 0:
                    print('Part failed:', data['message'])

                for result_item in data['output']['embeddings']:
                    text_feat_tensor = result_item['embedding']
                    idx = 0
                    score_tuples = []
                    batch_size = 128
                    while idx < len(image_ids):
                        img_feats_tensor = torch.from_numpy(image_feats_array[idx:min(idx + batch_size, len(image_ids))]).cuda()
                        batch_scores = torch.from_numpy(np.array(text_feat_tensor)).cuda().float() @ img_feats_tensor.t()
                        for image_id, score in zip(image_ids[idx:min(idx + batch_size, len(image_ids))], batch_scores.squeeze(0).tolist()):
                            score_tuples.append((image_id, score))
                        idx += batch_size
                    
                    predictions = sorted(score_tuples, key=lambda x: x[1], reverse=True)
            else:
                print('Processed fail:', data['message'])

            gts = tag_dict[tag.replace(prefix, '')]

            # Write result
            predictions_tmp = predictions[:10]
            result_dict = {'tag': tag, 'gts': gts, 'preds': [pred[0] for pred in predictions_tmp]}
            f_w.write(json.dumps(result_dict, ensure_ascii=False, indent=4) + '\n')

            for top_k_id, k in enumerate(top_k_list):
                need_exit = False

                if k > len(gts):
                    k = len(gts)
                    need_exit = True

                prec = precision_at_k(predictions, gts, k)

                f.write(f'Tag {tag}, Len(GT) {len(gts)}, Precision@{k} {prec:.4f} \n')
                f.flush()

                if need_exit:
                    break
                else:
                    top_k_list_precision[top_k_id].append(prec)
                    
    for idx, k in enumerate(top_k_list):
        print(f'Precision@{k} {np.mean(top_k_list_precision[idx]):.4f}')


if __name__ == "__main__":
    main()