AI multimodal search with Alibaba Cloud ES (Bailian)


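This topic uses code examples to show how to combine Alibaba Cloud Elasticsearch (ES) with the Qwen-VL large model to extract image features, and how to use a multimodal embedding model to implement multimodal search. It covers four retrieval modes: text-to-image, text-to-text, image-to-image, and image-to-text.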

Demo preview

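Background

In multimodal search, unstructured data such as images and text must first be converted into vector representations; vector retrieval then finds similar content quickly. This practice uses the following tools:

  • ES: a vector database used to store and retrieve vectors.

  • Qwen-VL: extracts image descriptions and keywords. For more information, see Image and video understanding.

  • DashScope Embedding API: converts images and text into vectors. For more information, see Multimodal embedding.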

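The system supports the following search modes:

  • Text-to-image: enter a text query and retrieve the most similar images.

  • Text-to-text: enter a text query and retrieve the most similar image descriptions.

  • Image-to-image: provide an image as the query and retrieve the most similar images.

  • Image-to-text: provide an image as the query and retrieve the most similar image descriptions.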
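To make "most similar" concrete, the following minimal sketch ranks a few stored vectors against a query vector by cosine similarity. It is an illustration only: the three-dimensional vectors and file names are made up, and cosine similarity is an assumption here, since the metric ES actually applies depends on the index mapping.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def rank_by_similarity(query_vec, docs):
    """Return document ids sorted from most to least similar to the query."""
    scored = [(cosine_similarity(query_vec, vec), doc_id) for doc_id, vec in docs.items()]
    return [doc_id for _, doc_id in sorted(scored, reverse=True)]

# Hypothetical 3-dimensional vectors, made up for illustration only.
# Real embeddings in this topic have 1152 dimensions.
docs = {
    "lion.jpg": [0.9, 0.1, 0.0],
    "dog.jpg": [0.1, 0.9, 0.0],
    "car.jpg": [0.0, 0.1, 0.9],
}
ranking = rank_by_similarity([1.0, 0.2, 0.0], docs)
print(ranking)
```

In the real system this comparison is not done in Python; ES performs the nearest-neighbor search over the stored dense_vector fields.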

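System architecture

The following figure shows the overall architecture of the multimodal search system used in this topic.

[Figure: es.svg]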

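Prerequisites

  • An Elasticsearch instance of version 8.17 or later has been created. For details, see Create an Alibaba Cloud Elasticsearch instance.

  • The Bailian service is activated and an API key has been obtained. For details, see Get an API key.

  • Python 3.8 or later is installed.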

Environment setup

Install dependencies

pip install elasticsearch dashscope requests streamlit

Download the sample dataset

Run the following commands to download and extract the sample dataset:

wget https://github.com/milvus-io/pymilvus-assets/releases/download/imagedata/reverse_image_search.zip
unzip -q -o reverse_image_search.zip

The sample dataset contains a CSV file, reverse_image_search.csv, and a set of image files.
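Before running the ingestion script, it can help to look at the first few rows of the CSV file. The sketch below assumes the columns id, path, and label, which are the names read by load_image_data in write.py; the inline sample rows are made up so that the snippet runs without the dataset.

```python
import csv
import io

def preview_rows(csv_file, limit=3):
    """Return the first `limit` rows of an open CSV file as dicts."""
    reader = csv.DictReader(csv_file)
    rows = []
    for i, row in enumerate(reader):
        if i >= limit:
            break
        rows.append(row)
    return rows

# Inline sample standing in for reverse_image_search.csv; the rows are hypothetical.
sample = io.StringIO(
    "id,path,label\n"
    "0,./train/example_0.jpg,example_class\n"
    "1,./train/example_1.jpg,example_class\n"
)
rows = preview_rows(sample, limit=1)
print(rows)
```

For the real file, pass a handle opened with open("reverse_image_search.csv", newline="", encoding="utf-8") instead of the inline sample.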

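Directory structure

Create a working directory and organize the files as follows:

multi_modal_search/
├── reverse_image_search.csv    # dataset CSV file
├── train/                      # image directory (created when the archive is extracted)
│   └── *.jpg
└── scripts/                    # script directory
    ├── write.py               # data ingestion script
    ├── read.py                # query script
    └── demo.py                # frontend demo script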
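As an optional sanity check, a short script can report which of the expected entries are missing from the working directory. This is a sketch, not part of the original scripts; the helper name missing_entries is hypothetical, and the demonstration builds a throwaway directory so that it runs anywhere.

```python
import tempfile
from pathlib import Path

# Entries expected under multi_modal_search/, taken from the layout above
EXPECTED = [
    "reverse_image_search.csv",
    "train",
    "scripts/write.py",
    "scripts/read.py",
    "scripts/demo.py",
]

def missing_entries(root):
    """Return the expected relative paths that do not exist under root."""
    root = Path(root)
    return [rel for rel in EXPECTED if not (root / rel).exists()]

# Demonstration on a temporary directory that deliberately lacks two scripts
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp) / "multi_modal_search"
    (root / "scripts").mkdir(parents=True)
    (root / "train").mkdir()
    (root / "reverse_image_search.csv").touch()
    (root / "scripts" / "write.py").touch()
    missing = missing_entries(root)

print(missing)
```

Against a real working directory, call missing_entries("multi_modal_search") and expect an empty list.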

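Core code

Write flow

The write flow first uses the Qwen-VL model to extract a description of each image and stores it in the text_input field. The multimodal embedding model then converts the image and its description into the vectors image_embedding and text_embedding, respectively, for later cross-modal retrieval or analysis.

To keep the demonstration short, this example processes only the first 200 images of the dataset.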
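The shape of the document sent to ES can be sketched in isolation. In the listing that follows, the inference request template places ${input} inside "contents":[...], so each input field appears to need a JSON object serialized as a string; that reading is an assumption drawn from the template, and the helper name build_document is hypothetical.

```python
import base64
import json

def build_document(description, image_bytes, image_format="jpeg"):
    """Build the document body: both fields are JSON objects serialized as strings."""
    encoded = base64.b64encode(image_bytes).decode("utf-8")
    image_uri = f"data:image/{image_format};base64,{encoded}"
    return {
        "text_input": json.dumps({"text": description}, ensure_ascii=False),
        "image_input": json.dumps({"image": image_uri}, ensure_ascii=False),
    }

# The bytes below are a placeholder, not a real image
doc = build_document("a lion resting on a log", b"\xff\xd8\xff")
print(doc["text_input"])
```

The ingest pipeline then reads text_input and image_input and writes the resulting vectors to text_embedding and image_embedding.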

import os
import csv
import base64
import time
import json
import logging
import requests
from http import HTTPStatus
from elasticsearch import Elasticsearch
import dashscope
from dashscope import MultiModalConversation

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==================== Configuration ====================
# ES configuration
ES_HOST = "<ES_HOST>"           # replace with your ES instance endpoint
ES_PORT = 9200
ES_USER = "<ES_USER>"           # replace with your ES username
ES_PASSWORD = "<ES_PASSWORD>"   # replace with your ES password
ES_URL = f"http://{ES_HOST}:{ES_PORT}"

# Bailian API configuration
DASHSCOPE_API_KEY = "<DASHSCOPE_API_KEY>"  # replace with your API key

# Index configuration
INDEX_NAME = "multi_modal_test"
PIPELINE_NAME = "bailian_pipeline"
INFERENCE_ID = "bailian_mm"

# Data configuration
CSV_FILE = "../reverse_image_search.csv"
IMAGE_BASE_DIR = ".."  # base directory of the images
MAX_IMAGES = 200  # maximum number of images to process


def create_es_client():
    """创建ES客户端"""
    client = Elasticsearch(
        ES_URL,
        basic_auth=(ES_USER, ES_PASSWORD),
        request_timeout=300
    )
    logger.info(f"ES连接状态: {client.info()['cluster_name']}")
    return client


def delete_inference(client, inference_id):
    """删除 inference 接口"""
    try:
        # 直接使用 requests 库调用 DELETE API
        url = f"{ES_URL}/_inference/multi_modal_embedding/{inference_id}"
        response = requests.delete(
            url,
            auth=(ES_USER, ES_PASSWORD),
            headers={"Content-Type": "application/json"}
        )
        if response.status_code == 200:
            logger.info(f"Inference 接口 {inference_id} 已删除")
        elif response.status_code == 404:
            logger.info(f"Inference 接口 {inference_id} 不存在,无需删除")
        else:
            logger.warning(f"删除 inference 接口返回状态: {response.status_code}")
    except Exception as e:
        logger.warning(f"删除 inference 接口失败: {e}")


def create_inference(client, inference_id, api_key):
    """创建 inference 接口"""
    inference_config = {
        "service": "alibaba-cloud-custom-model",
        "service_settings": {
            "secret_parameters": {
                "DASHSCOPE_API_KEY": api_key
            },
            "url": "https://dashscope.aliyuncs.com",
            "path": {
                "/api/v1/services/embeddings/multimodal-embedding/multimodal-embedding": {
                    "POST": {
                        "headers": {
                            "Authorization": "Bearer ${DASHSCOPE_API_KEY}",
                            "Content-Type": "application/json;charset=utf-8"
                        },
                        "request": {
                            "format": "string",
                            "content": '''
                            {
                              "model": "tongyi-embedding-vision-plus",
                              "input": {
                                "contents":[${input}]
                              },
                              "parameters": {
                                "dimension":${dimension}
                              }
                            }
                            '''
                        },
                        "response": {
                            "json_parser": {
                                "multi_modal_embedding": "$.output.embeddings[*].embedding"
                            }
                        }
                    }
                }
            }
        },
        "task_settings": {
            "single_input": True,
            "parameters": {
                "dimension": "1152"
            }
        }
    }
    
    try:
        # 直接使用 requests 库调用 PUT API
        url = f"{ES_URL}/_inference/multi_modal_embedding/{inference_id}"
        response = requests.put(
            url,
            auth=(ES_USER, ES_PASSWORD),
            headers={"Content-Type": "application/json"},
            json=inference_config,
            timeout=60
        )
        
        if response.status_code == 200:
            logger.info(f"Inference 接口 {inference_id} 创建成功")
        else:
            logger.error(f"Inference 接口创建失败: {response.status_code} - {response.text}")
            raise Exception(f"HTTP {response.status_code}: {response.text}")
    except Exception as e:
        logger.error(f"创建 inference 接口失败: {e}")
        raise


def create_pipeline(client, pipeline_name, inference_id):
    """创建 ingest pipeline"""
    pipeline_config = {
        "description": "This is an example of multi_modal_embedding",
        "processors": [
            {
                "text_embedding": {
                    "if": "ctx.containsKey('text_input')",
                    "model_id": inference_id,
                    "input_output": [
                        {
                            "input_field": "text_input",
                            "output_field": "text_embedding"
                        }
                    ]
                }
            },
            {
                "text_embedding": {
                    "if": "ctx.containsKey('image_input')",
                    "model_id": inference_id,
                    "input_output": [
                        {
                            "input_field": "image_input",
                            "output_field": "image_embedding"
                        }
                    ]
                }
            }
        ]
    }
    
    try:
        client.ingest.put_pipeline(id=pipeline_name, body=pipeline_config)
        logger.info(f"Pipeline {pipeline_name} 创建成功")
    except Exception as e:
        logger.error(f"创建 pipeline 失败: {e}")
        raise


def delete_index(client, index_name):
    """删除索引"""
    try:
        if client.indices.exists(index=index_name):
            client.indices.delete(index=index_name)
            logger.info(f"索引 {index_name} 已删除")
        else:
            logger.info(f"索引 {index_name} 不存在,无需删除")
    except Exception as e:
        logger.error(f"删除索引失败: {e}")


def create_index(client, index_name):
    """创建索引"""
    mapping = {
        "mappings": {
            "properties": {
                "image_input": {
                    "type": "text"
                },
                "text_input": {
                    "type": "text"
                },
                "text_embedding": {
                    "type": "dense_vector",
                    "dims": 1152
                },
                "image_embedding": {
                    "type": "dense_vector",
                    "dims": 1152
                }
            }
        }
    }
    
    try:
        client.indices.create(index=index_name, body=mapping)
        logger.info(f"索引 {index_name} 创建成功")
    except Exception as e:
        logger.error(f"创建索引失败: {e}")
        # Re-raise so that later steps do not write into an index without the mapping
        raise


def read_image_as_base64(image_path):
    """读取图片并转换为base64编码"""
    _, ext = os.path.splitext(image_path)
    image_format = ext.lstrip(".").lower()
    if image_format == "jpg":
        image_format = "jpeg"
    
    with open(image_path, "rb") as f:
        image_data = f.read()
    base64_data = base64.b64encode(image_data).decode("utf-8")
    return f"data:image/{image_format};base64,{base64_data}"


def retry_with_backoff(func, max_retries=3, initial_delay=1):
    """带指数退避的重试机制"""
    for attempt in range(max_retries):
        try:
            return func()
        except Exception as e:
            if attempt < max_retries - 1:
                delay = initial_delay * (2 ** attempt)
                logger.warning(f"第 {attempt + 1} 次尝试失败: {e},{delay}秒后重试...")
                time.sleep(delay)
            else:
                raise


def extract_image_description(image_path, api_key, max_retries=3):
    """使用Qwen-VL提取图片语义描述,带重试机制"""
    def _call_api():
        # 构造消息
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant."}]
            },
            {
                "role": "user",
                "content": [
                    {"image": image_path},
                    {"text": "先用50字内的文字描述这张图片,然后再给出5个关键词"}
                ]
            }
        ]
        
        # 调用Qwen-VL-Plus
        response = MultiModalConversation.call(
            model="qwen-vl-plus",
            messages=messages,
            api_key=api_key
        )
        
        if response.status_code == HTTPStatus.OK:
            return response.output["choices"][0]["message"].content[0]["text"]
        else:
            raise RuntimeError(f"API调用失败,状态码: {response.status_code}, 错误信息: {response.message}")
    
    try:
        return retry_with_backoff(_call_api, max_retries=max_retries)
    except Exception as e:
        logger.error(f"提取图片描述失败: {e}")
        return "无法识别的图片内容"


def load_image_data(csv_path, base_dir, max_count):
    """从CSV文件加载图片数据"""

    images = []
    script_dir = os.path.dirname(os.path.abspath(__file__))

    # newline='' is the recommended way to open files for the csv module
    with open(os.path.join(script_dir, csv_path), 'r', newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for i, row in enumerate(reader):
            if i >= max_count:
                break
            # 处理图片路径
            image_path = row['path']
            # 将相对路径转换为绝对路径
            if image_path.startswith('./'):
                image_path = image_path[2:]
            full_path = os.path.join(script_dir, base_dir, image_path)
            
            if os.path.exists(full_path):
                images.append({
                    'id': row['id'],
                    'path': full_path,
                    'label': row['label']
                })
            else:
                logger.warning(f"图片不存在: {full_path}")
    
    logger.info(f"加载了 {len(images)} 张图片")
    return images


def write_to_es(es_client, index_name, pipeline_name, doc, max_retries=3):
    """写入单条文档到ES,带重试机制"""
    def _write():
        es_client.index(
            index=index_name,
            body=doc,
            pipeline=pipeline_name
        )
        return True
    
    try:
        return retry_with_backoff(_write, max_retries=max_retries)
    except Exception as e:
        logger.error(f"写入ES失败: {e}")
        return False


def main():
    logger.info("=" * 50)
    logger.info("多模态数据写入程序")
    logger.info("=" * 50)
    
    # 1. 创建 ES 客户端
    logger.info("\n[1/7] 创建 ES 客户端...")
    es_client = create_es_client()
    
    # 2. 删除同名 inference 接口
    logger.info("\n[2/7] 删除同名 inference 接口...")
    delete_inference(es_client, INFERENCE_ID)
    
    # 3. 创建 inference 接口
    logger.info("\n[3/7] 创建 inference 接口...")
    create_inference(es_client, INFERENCE_ID, DASHSCOPE_API_KEY)
    
    # 4. 创建 pipeline
    logger.info("\n[4/7] 创建 pipeline...")
    create_pipeline(es_client, PIPELINE_NAME, INFERENCE_ID)
    
    # 5. 清理同名索引
    logger.info("\n[5/7] 清理同名索引...")
    delete_index(es_client, INDEX_NAME)
    
    # 6. 创建新索引
    logger.info("\n[6/7] 创建新索引...")
    create_index(es_client, INDEX_NAME)
    
    # 7. 加载图片数据
    logger.info("\n[7/7] 加载图片数据...")
    images = load_image_data(CSV_FILE, IMAGE_BASE_DIR, MAX_IMAGES)
    
    # 8. 处理图片并写入 ES
    logger.info("\n开始处理图片并写入 ES...")
    success_count = 0
    fail_count = 0
    
    for i, img_info in enumerate(images):
        try:
            logger.info(f"处理第 {i+1}/{len(images)} 张图片: {img_info['path']}")
            
            # 提取图片描述(使用图片路径)
            description = extract_image_description(img_info['path'], DASHSCOPE_API_KEY)
            logger.info(f"  描述: {description}")
            
            # 读取图片并转 base64
            image_base64 = read_image_as_base64(img_info['path'])
            
            # 构造写入文档
            doc = {
                "text_input": json.dumps({"text": description}, ensure_ascii=False),
                "image_input": json.dumps({"image": image_base64}, ensure_ascii=False)
            }
            
            # 写入 ES(带重试机制)
            if write_to_es(es_client, INDEX_NAME, PIPELINE_NAME, doc):
                success_count += 1
            else:
                fail_count += 1
            
        except Exception as e:
            logger.error(f"处理失败: {e}")
            fail_count += 1
    
    logger.info("\n" + "=" * 50)
    logger.info(f"处理完成!成功: {success_count}, 失败: {fail_count}")
    logger.info("=" * 50)


if __name__ == "__main__":
    main()

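Replace the following parameters with your actual values:

Parameter            Description
ES_HOST              Endpoint of the Elasticsearch instance.
ES_PORT              Port of the Elasticsearch instance. Default: 9200.
ES_USER              Elasticsearch username.
ES_PASSWORD          Password of the Elasticsearch user.
DASHSCOPE_API_KEY    API key of the Bailian platform, used to call Qwen-VL and the multimodal embedding model.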
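As an alternative to editing the constants in each script, the same parameters could be read from environment variables so that credentials stay out of the source files. The sketch below is one possible approach, not part of the original scripts; the function name load_config and the use of environment variables with these names are assumptions.

```python
import os

REQUIRED = ["ES_HOST", "ES_USER", "ES_PASSWORD", "DASHSCOPE_API_KEY"]

def load_config(env=None):
    """Collect connection settings from a mapping (defaults to os.environ)."""
    env = os.environ if env is None else env
    missing = [name for name in REQUIRED if not env.get(name)]
    if missing:
        raise ValueError(f"Missing configuration: {', '.join(missing)}")
    port = int(env.get("ES_PORT", "9200"))  # 9200 is the default named in the table above
    return {
        "es_url": f"http://{env['ES_HOST']}:{port}",
        "es_user": env["ES_USER"],
        "es_password": env["ES_PASSWORD"],
        "dashscope_api_key": env["DASHSCOPE_API_KEY"],
    }

# Placeholder values for illustration only
demo = load_config({
    "ES_HOST": "es.example.internal",
    "ES_USER": "elastic",
    "ES_PASSWORD": "placeholder",
    "DASHSCOPE_API_KEY": "placeholder",
})
print(demo["es_url"])
```

Calling load_config() without arguments reads from the process environment and fails early with a clear message if a value is missing.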

Run python3 write.py. The output shows the description generated for each image and the processing progress.

[Figure: image.png]

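Query flow

The query flow defines four query types: text-to-image, text-to-text, image-to-image, and image-to-text. The input text or image is embedded with the Bailian multimodal model, and, depending on the query type, the resulting vector is matched against the text_embedding or image_embedding field to retrieve the most relevant text or images.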
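The relationship between the four query types and the two vector fields can be summarized as a small routing table. The sketch below mirrors what search_by_text and search_by_image do in read.py: text queries are matched against text_embedding, and image queries against image_embedding. The names SEARCH_MODES and build_knn_query are hypothetical.

```python
# Which input kind each mode takes and which vector field it searches (as in read.py)
SEARCH_MODES = {
    "text_to_image": {"input": "text", "field": "text_embedding"},
    "text_to_text": {"input": "text", "field": "text_embedding"},
    "image_to_image": {"input": "image", "field": "image_embedding"},
    "image_to_text": {"input": "image", "field": "image_embedding"},
}

def build_knn_query(mode, query_input, k=10, model_id="bailian_mm"):
    """Build the search body for one mode; query_input is text or a base64 data URI."""
    if mode not in SEARCH_MODES:
        raise ValueError(f"Unsupported query type: {mode}")
    spec = SEARCH_MODES[mode]
    # read.py passes text as model_text and images as model_image
    input_key = "model_text" if spec["input"] == "text" else "model_image"
    return {
        "_source": {"includes": ["text_input", "image_input"]},
        "knn": {
            "field": spec["field"],
            "k": k,
            "num_candidates": 100,
            "query_vector_builder": {
                "multi_modal_embedding": {"model_id": model_id, input_key: query_input}
            },
        },
        "size": k,
    }

query = build_knn_query("image_to_text", "data:image/jpeg;base64,...", k=3)
print(query["knn"]["field"])
```

Note that the two modes sharing an input kind issue the same query; they differ only in which part of each hit (image or description) is returned to the caller.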

import os
import base64
import json
import logging
from elasticsearch import Elasticsearch

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ==================== Configuration ====================
# ES configuration
ES_HOST = "<ES_HOST>"           # replace with your ES instance endpoint
ES_PORT = 9200
ES_USER = "<ES_USER>"           # replace with your ES username
ES_PASSWORD = "<ES_PASSWORD>"   # replace with your ES password
ES_URL = f"http://{ES_HOST}:{ES_PORT}"

# Index configuration
INDEX_NAME = "multi_modal_test"
MODEL_ID = "bailian_mm"  # Inference ID


def create_es_client():
    """创建ES客户端"""
    client = Elasticsearch(
        ES_URL,
        basic_auth=(ES_USER, ES_PASSWORD),
        request_timeout=300
    )
    return client


def read_image_as_base64(image_path):
    """读取图片并转换为base64编码"""
    _, ext = os.path.splitext(image_path)
    image_format = ext.lstrip(".").lower()
    if image_format == "jpg":
        image_format = "jpeg"
    
    with open(image_path, "rb") as f:
        image_data = f.read()
    base64_data = base64.b64encode(image_data).decode("utf-8")
    return f"data:image/{image_format};base64,{base64_data}"


def bytes_to_base64(image_bytes, image_format="jpeg"):
    """将图片bytes转换为base64编码"""
    base64_data = base64.b64encode(image_bytes).decode("utf-8")
    return f"data:image/{image_format};base64,{base64_data}"


def search_by_text(es_client, query_text, k=10):
    """
    以文本搜索
    返回格式: {"text_input": ..., "image_input": ...}
    """
    query = {
        "_source": {
            "includes": ["text_input", "image_input"]
        },
        "knn": {
            "field": "text_embedding",
            "k": k,
            "num_candidates": 100,
            "query_vector_builder": {
                "multi_modal_embedding": {
                    "model_id": MODEL_ID,
                    "model_text": query_text
                }
            }
        },
        "size":k
    }
    
    try:
        response = es_client.search(index=INDEX_NAME, body=query)
        return parse_search_response(response)
    except Exception as e:
        logger.error(f"搜索失败: {e}")
        return []


def search_by_image(es_client, image_base64, k=10):
    """
    以图片搜索
    image_base64: base64编码的图片,格式为 "data:image/jpeg;base64,xxx"
    返回格式: {"text_input": ..., "image_input": ...}
    """
    query = {
        "_source": {
            "includes": ["text_input", "image_input"]
        },
        "knn": {
            "field": "image_embedding",
            "k": k,
            "num_candidates": 100,
            "query_vector_builder": {
                "multi_modal_embedding": {
                    "model_id": MODEL_ID,
                    "model_image": image_base64
                }
            }
        },
        "size":k
    }
    
    try:
        response = es_client.search(index=INDEX_NAME, body=query)
        return parse_search_response(response)
    except Exception as e:
        logger.error(f"搜索失败: {e}")
        return []


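def parse_search_response(response):
    """Parse an ES search response into a list of result dicts."""
    results = []

    for hit in response.get("hits", {}).get("hits", []):
        source = hit.get("_source", {})
        score = hit.get("_score", 0)

        # text_input is stored as a JSON string such as {"text": "..."}
        text_input_str = source.get("text_input", "{}")
        try:
            text_data = json.loads(text_input_str)
            text_content = text_data.get("text", "")
        except (ValueError, TypeError, AttributeError):
            # Fall back to the raw value if it is not a JSON object
            text_content = text_input_str

        # image_input is stored as a JSON string such as {"image": "data:image/..."}
        image_input_str = source.get("image_input", "{}")
        try:
            image_data = json.loads(image_input_str)
            image_content = image_data.get("image", "")
        except (ValueError, TypeError, AttributeError):
            image_content = image_input_str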
        
        results.append({
            "score": score,
            "text": text_content,
            "image": image_content
        })
    
    return results


def text_to_image(es_client, query_text, k=10):
    """
    以文搜图:输入文本查询,搜索最相似的图片
    返回图片列表
    """
    results = search_by_text(es_client, query_text, k)
    return [{"score": r["score"], "image": r["image"], "description": r["text"]} for r in results]


def text_to_text(es_client, query_text, k=10):
    """
    以文搜文:输入文本查询,搜索最相似的图片描述
    返回文本列表
    """
    results = search_by_text(es_client, query_text, k)
    return [{"score": r["score"], "text": r["text"]} for r in results]


def image_to_image(es_client, image_base64, k=10):
    """
    以图搜图:输入图片查询,搜索最相似的图片
    返回图片列表
    """
    results = search_by_image(es_client, image_base64, k)
    return [{"score": r["score"], "image": r["image"], "description": r["text"]} for r in results]


def image_to_text(es_client, image_base64, k=10):
    """
    以图搜文:输入图片查询,搜索最相似的图片描述
    返回文本列表
    """
    results = search_by_image(es_client, image_base64, k)
    return [{"score": r["score"], "text": r["text"]} for r in results]


class MultiModalSearcher:
    """多模态搜索器类,封装所有搜索功能"""
    
    def __init__(self):
        self.es_client = create_es_client()
        logger.info("多模态搜索器初始化完成")
    
    def text_to_image(self, query_text, k=10):
        """以文搜图"""
        return text_to_image(self.es_client, query_text, k)
    
    def text_to_text(self, query_text, k=10):
        """以文搜文"""
        return text_to_text(self.es_client, query_text, k)
    
    def image_to_image(self, image_input, k=10):
        """
        以图搜图
        image_input: 可以是图片路径或base64编码
        """
        if os.path.exists(image_input):
            image_base64 = read_image_as_base64(image_input)
        else:
            image_base64 = image_input
        return image_to_image(self.es_client, image_base64, k)
    
    def image_to_text(self, image_input, k=10):
        """
        以图搜文
        image_input: 可以是图片路径或base64编码
        """
        if os.path.exists(image_input):
            image_base64 = read_image_as_base64(image_input)
        else:
            image_base64 = image_input
        return image_to_text(self.es_client, image_base64, k)
    
    def search(self, query_type, query_input, k=10):
        """
        统一搜索接口
        query_type: "text_to_image", "text_to_text", "image_to_image", "image_to_text"
        query_input: 文本或图片路径/base64
        """
        if query_type == "text_to_image":
            return self.text_to_image(query_input, k)
        elif query_type == "text_to_text":
            return self.text_to_text(query_input, k)
        elif query_type == "image_to_image":
            return self.image_to_image(query_input, k)
        elif query_type == "image_to_text":
            return self.image_to_text(query_input, k)
        else:
            raise ValueError(f"不支持的查询类型: {query_type}")


# 测试代码
if __name__ == "__main__":
    import csv
    
    # 创建搜索器
    searcher = MultiModalSearcher()
    
    # 测试以文搜图
    print("\n=== 测试以文搜图 ===")
    results = searcher.text_to_image("狮子", k=3)
    for i, r in enumerate(results):
        print(f"{i+1}. 得分: {r['score']:.4f}, 描述: {r['description'][:50]}...")
    
    # 测试以文搜文
    print("\n=== 测试以文搜文 ===")
    results = searcher.text_to_text("狮子", k=3)
    for i, r in enumerate(results):
        print(f"{i+1}. 得分: {r['score']:.4f}, 文本: {r['text'][:50]}...")
    
    # 获取一张测试图片路径
    script_dir = os.path.dirname(os.path.abspath(__file__))
    csv_path = os.path.join(script_dir, "../reverse_image_search.csv")
    test_image_path = None
    with open(csv_path, 'r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            image_path = row['path']
            if image_path.startswith('./'):
                image_path = image_path[2:]
            full_path = os.path.join(script_dir, "..", image_path)
            if os.path.exists(full_path):
                test_image_path = full_path
                break
    
    if test_image_path:
        print(f"\n使用测试图片: {test_image_path}")
        
        # 测试以图搜图
        print("\n=== 测试以图搜图 ===")
        results = searcher.image_to_image(test_image_path, k=3)
        for i, r in enumerate(results):
            desc = r.get('description', '')[:50] if r.get('description') else '无描述'
            print(f"{i+1}. 得分: {r['score']:.4f}, 描述: {desc}...")
        
        # 测试以图搜文
        print("\n=== 测试以图搜文 ===")
        results = searcher.image_to_text(test_image_path, k=3)
        for i, r in enumerate(results):
            text = r.get('text', '')[:50] if r.get('text') else '无文本'
            print(f"{i+1}. 得分: {r['score']:.4f}, 文本: {text}...")
    else:
        print("\n未找到测试图片,跳过图搜测试")
    
    print("\n测试完成!")

The ES configuration parameters are the same as in the write flow.

Frontend demo

import streamlit as st
import base64
from read import MultiModalSearcher, bytes_to_base64

# 页面配置
st.set_page_config(
    page_title="多模态搜索系统",
    layout="wide"
)

# 初始化搜索器
@st.cache_resource
def get_searcher():
    return MultiModalSearcher()

def display_image_from_base64(base64_str, width=300):
    """从base64字符串显示图片"""
    if base64_str and base64_str.startswith("data:image"):
        st.image(base64_str, width=width)
    else:
        st.warning("图片无法显示")

def main():
    st.title("多模态搜索系统")
    st.markdown("---")
    
    # 初始化搜索器
    try:
        searcher = get_searcher()
    except Exception as e:
        st.error(f"初始化搜索器失败: {e}")
        return
    
    # 侧边栏 - 搜索模式选择
    with st.sidebar:
        st.header("搜索设置")
        
        search_mode = st.radio(
            "选择搜索模式",
            options=[
                "以文搜图",
                "以文搜文",
                "以图搜图",
                "以图搜文"
            ],
            index=0
        )
        
        k = st.slider("返回结果数量", min_value=1, max_value=20, value=5)
        
        st.markdown("---")
        st.markdown("""
        ### 使用说明
        - **以文搜图**: 输入文本,搜索相似图片
        - **以文搜文**: 输入文本,搜索相似描述
        - **以图搜图**: 上传图片,搜索相似图片
        - **以图搜文**: 上传图片,搜索相似描述
        """)
    
    # 主界面
    col1, col2 = st.columns([1, 2])
    
    with col1:
        st.header("输入区域")
        
        if search_mode in ["以文搜图", "以文搜文"]:
            # 文本输入
            query_text = st.text_area(
                "请输入搜索文本",
                placeholder="例如:狮子、棕色的狗...",
                height=100
            )
            
            if st.button("开始搜索", type="primary", use_container_width=True):
                if query_text.strip():
                    with st.spinner("搜索中..."):
                        if search_mode == "以文搜图":
                            results = searcher.text_to_image(query_text, k)
                            st.session_state["results"] = results
                            st.session_state["result_type"] = "image"
                        else:
                            results = searcher.text_to_text(query_text, k)
                            st.session_state["results"] = results
                            st.session_state["result_type"] = "text"
                else:
                    st.warning("请输入搜索文本")
        
        else:
            # 图片输入
            uploaded_file = st.file_uploader(
                "请上传图片",
                type=["jpg", "jpeg", "png", "gif", "bmp"],
                help="支持 JPG, JPEG, PNG, GIF, BMP 格式"
            )
            
            if uploaded_file is not None:
                # 显示上传的图片
                st.image(uploaded_file, caption="上传的图片", use_container_width=True)
                
                if st.button("开始搜索", type="primary", use_container_width=True):
                    with st.spinner("搜索中..."):
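                        # Read the image bytes and convert them to base64.
                        # getvalue() returns the full content regardless of the current
                        # read position, which st.image above may already have advanced.
                        image_bytes = uploaded_file.getvalue()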
                        image_format = uploaded_file.type.split("/")[-1]
                        if image_format == "jpg":
                            image_format = "jpeg"
                        image_base64 = bytes_to_base64(image_bytes, image_format)
                        
                        if search_mode == "以图搜图":
                            results = searcher.image_to_image(image_base64, k)
                            st.session_state["results"] = results
                            st.session_state["result_type"] = "image"
                        else:
                            results = searcher.image_to_text(image_base64, k)
                            st.session_state["results"] = results
                            st.session_state["result_type"] = "text"
    
    with col2:
        st.header("搜索结果")
        
        if "results" in st.session_state and st.session_state["results"]:
            results = st.session_state["results"]
            result_type = st.session_state.get("result_type", "text")
            
            st.success(f"找到 {len(results)} 条结果")
            
            if result_type == "image":
                # 显示图片结果
                # 每行显示3张图片
                cols_per_row = 3
                for i in range(0, len(results), cols_per_row):
                    cols = st.columns(cols_per_row)
                    for j, col in enumerate(cols):
                        idx = i + j
                        if idx < len(results):
                            r = results[idx]
                            with col:
                                st.markdown(f"**结果 {idx+1}** (得分: {r['score']:.4f})")
                                if r.get("image"):
                                    display_image_from_base64(r["image"], width=250)
                                if r.get("description"):
                                    st.caption(r["description"][:100] + "..." if len(r.get("description", "")) > 100 else r.get("description", ""))
            else:
                # 显示文本结果
                for i, r in enumerate(results):
                    with st.expander(f"结果 {i+1} (得分: {r['score']:.4f})", expanded=True):
                        st.write(r.get("text", "无描述"))
        else:
            st.info("请在左侧输入搜索内容并点击搜索按钮")


if __name__ == "__main__":
    main()

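Procedure

Step 1: Configure parameters

Before running the scripts, modify the following configuration parameters:

  1. Open write.py and read.py and modify the following settings:

# ES configuration
ES_HOST = "<ES_HOST>"           # replace with your ES instance endpoint
ES_PORT = 9200
ES_USER = "<ES_USER>"           # replace with your ES username
ES_PASSWORD = "<ES_PASSWORD>"   # replace with your ES password

# Bailian API configuration (write.py only)
DASHSCOPE_API_KEY = "<DASHSCOPE_API_KEY>"  # replace with a valid API key from the Bailian platform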

Step 2: Load the dataset

Go to the scripts directory and run the data ingestion script:

cd scripts
python3 write.py

On success, output similar to the following is displayed:

INFO - [1/7] 创建 ES 客户端...
INFO - ES连接状态: xxx
...
INFO - 处理第 1/200 张图片: xxx
INFO -   描述: xxx
...
INFO - 处理完成!成功: 200, 失败: 0

Step 3: Verify the written data (optional)

Run the query script to verify that the data was written successfully:

python3 read.py

Taking text-to-image as an example, a successful run returns results such as:

以文搜图 - 搜索关键词"狮子"
✓ 得分: 0.8077 - 一只狮子坐在倒下的树干上,周围是茂密的灌木和树枝
✓ 得分: 0.7732 - 雄壮的狮子站在草地上,鬃毛在阳光下威武宁静
✓ 得分: 0.7566 - 雄狮特写,鬃毛浓密,眼神锐利
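Another quick check is to compare the number of indexed documents with the success count reported by write.py. The sketch below uses the count API of the Elasticsearch Python client; the helper name indexed_doc_count is hypothetical, and the FakeClient class is only a stand-in so that the snippet runs without a cluster.

```python
def indexed_doc_count(client, index_name):
    """Return the number of documents in the index via the count API."""
    return client.count(index=index_name)["count"]

class FakeClient:
    """Stand-in for an Elasticsearch client so the sketch runs offline."""

    def count(self, index):
        # A real client returns a response that also contains a "count" key
        return {"count": 200}

count = indexed_doc_count(FakeClient(), "multi_modal_test")
print(count)
```

With a real cluster, pass the client returned by create_es_client() in read.py. Recently written documents may only be counted after the index has refreshed.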

Step 4: Start the frontend demo

streamlit run demo.py

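After startup, the browser automatically opens http://localhost:8501

Step 5: Run multimodal vector search

In the search settings sidebar (搜索设置), select a search mode. In the input area (输入区域), enter the search text or upload an image, then click the search button (开始搜索) to retrieve results.

[Figure: image.png]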