自建Milvus迁移至AnalyticDB PostgreSQL版

Milvus是专门用于处理输入向量查询的数据库,能够对万亿级向量进行索引。本文介绍如何通过Python将本地Milvus集群的数据迁移至AnalyticDB PostgreSQL版实例。

前提条件

  • 已创建2.3.x及以上版本的Milvus集群。

  • 已安装3.8及以上版本的Python环境。

  • 已安装所需的Python库。

    pip install psycopg2
    pip install pymilvus==2.3.0
    pip install pyaml
    pip install tqdm

迁移操作

步骤一:导出Milvus数据

  1. 准备好导出脚本export.py及导出配置文件milvus2csv.yaml,并创建输出目录,本文以output为例。

    导出脚本export.py如下。

    import yaml
    import json
    from pymilvus import (
        connections,
        DataType,
        Collection,
    )
    import os
    from tqdm import tqdm
    
    # Load the export configuration (Milvus connection info and export
    # settings) from the YAML file that sits next to this script.
    with open("./milvus2csv.yaml", "r") as f:
        config = yaml.safe_load(f)
    
    print("configuration:")
    print(config)
    
    # Connection parameters forwarded verbatim to connections.connect().
    milvus_config = config["milvus"]
    
    # Maps each Milvus field type to the AnalyticDB PostgreSQL column type
    # used when generating the CREATE TABLE statement for a collection.
    milvus_type_to_adbpg_type = {
        DataType.BOOL: "bool",
        DataType.INT8: "smallint",
        DataType.INT16: "smallint",
        DataType.INT32: "integer",
        DataType.INT64: "bigint",
    
        DataType.FLOAT: "real",
        DataType.DOUBLE: "double precision",
    
        DataType.STRING: "text",
        DataType.VARCHAR: "varchar",
        DataType.JSON: "json",
    
        # Vector types are exported as PostgreSQL array literals.
        DataType.BINARY_VECTOR: "bit[]",
        DataType.FLOAT_VECTOR: "real[]",
    }
    
    
    def convert_to_binary(binary_data):
        """Render *binary_data* (bytes) as comma-separated bits, MSB first.

        Used to serialize one element of a Milvus BINARY_VECTOR into the
        text form expected inside the PostgreSQL bit[] array literal.
        """
        as_int = int.from_bytes(binary_data, byteorder='big')
        bit_string = format(as_int, 'b').zfill(8 * len(binary_data))
        # Joining a string iterates its characters, i.e. one bit per cell.
        return ','.join(bit_string)
    
    
    def data_convert_to_str(data, dtype, delimeter):
        """Serialize one Milvus field value into its CSV cell text.

        Args:
            data: field value as produced by the export loop (FLOAT_VECTOR
                values arrive already rendered as "{...}" strings).
            dtype: the field's pymilvus DataType.
            delimeter: CSV delimiter character, escaped inside text payloads.

        Returns:
            str: the text to place in the CSV cell.

        Raises:
            ValueError: if dtype is not a supported DataType.
        """
        if dtype == DataType.BOOL:
            return "1" if data else "0"
        elif dtype in [DataType.INT8, DataType.INT16,
                       DataType.INT32, DataType.INT64,
                       DataType.FLOAT, DataType.DOUBLE]:
            return str(data)
        elif dtype in [DataType.STRING, DataType.VARCHAR, DataType.JSON]:
            # Escape the delimiter and double quotes so the value survives
            # the later COPY ... DELIMITER '|' import.
            return str(data).replace(delimeter, f"\\{delimeter}").replace("\"", "\\\"")
        elif dtype == DataType.BINARY_VECTOR:
            return "{" + ','.join([convert_to_binary(d) for d in data]) + "}"
        elif dtype == DataType.FLOAT_VECTOR:
            return data

        # BUG FIX: the original code built this Exception object without
        # raising it, so unsupported types silently returned None and later
        # crashed delimiter.join() with a confusing TypeError.
        raise ValueError(f"Unsupported DataType {dtype}")
    
    
    def csv_write_rows(datum, fd, fields_types, delimiter="|"):
        """Serialize every row in *datum* and write it to *fd*, one CSV line each.

        Each cell is converted with data_convert_to_str according to the
        column's Milvus type. Note: rows are converted in place.
        """
        for row in datum:
            converted = [data_convert_to_str(value, fields_types[col], delimiter)
                         for col, value in enumerate(row)]
            row[:] = converted
            fd.write(delimiter.join(row) + "\n")
    
    
    def csv_write_header(headers, fd, delimiter="|"):
        """Write the CSV header line: column names joined by *delimiter*."""
        header_line = delimiter.join(headers)
        fd.write(header_line + "\n")
    
    
    def dump_collection(collection_name: str):
        """Export one Milvus collection to CSV files plus a create_table.sql.

        Output layout: <output_path>/<collection_name>/NNNNNNNNNN.csv files
        (at most export.max_line_in_file rows each) and a create_table.sql
        holding the matching AnalyticDB PostgreSQL DDL.

        Relies on module-level `config`, `milvus_config` and
        `milvus_type_to_adbpg_type`.
        """
        results = []   # rows buffered until they are flushed to one CSV file
        file_cnt = 0   # running index used to name the CSV output files
        print("connecting to milvus...")
        connections.connect("default", **milvus_config)
    
        export_config = config["export"]
        collection = Collection(collection_name)
        collection.load()
        # Per-collection output directory; output_path itself must already exist.
        tmp_path = os.path.join(export_config["output_path"], collection_name)
        if not os.path.exists(tmp_path):
            os.mkdir(tmp_path)
    
        # Walk the schema once to collect CSV headers, per-column types and
        # the column definition list for the CREATE TABLE statement.
        fields_meta_str = ""
        fields_types = []
        headers = []
        for schema in collection.schema.fields:
            print(schema)
            fields_types.append(schema.dtype)
            headers.append(schema.name)
            if len(fields_meta_str) != 0:
                fields_meta_str += ","
            fields_meta_str += f"{schema.name} {milvus_type_to_adbpg_type[schema.dtype]}"
            # VARCHAR carries its max_length through to varchar(N).
            if schema.dtype == DataType.VARCHAR and "max_length" in schema.params.keys():
                fields_meta_str += f"({schema.params['max_length']})"
            if schema.is_primary:
                fields_meta_str += " PRIMARY KEY"
    
        create_table_sql = f"CREATE TABLE {collection.name} " \
                           f" ({fields_meta_str});"
    
        # The DDL is saved next to the CSV files so the import script can
        # create the table before running COPY.
        with open(os.path.join(export_config["output_path"], collection_name, "create_table.sql"), "w") as f:
            f.write(create_table_sql)
    
        print(create_table_sql)
    
        print(headers)
    
        total_num = collection.num_entities
        collection.load()
        # Stream the collection in batches instead of loading it whole.
        query_iterator = collection.query_iterator(batch_size=1000, expr="", output_fields=headers)
    
        def write_to_csv_file(col_names, data):
            # Flush the buffered rows into the next numbered CSV file.
            # No-op when the buffer is empty (e.g. final flush after an
            # exact multiple of max_line_in_file rows).
            if len(results) == 0:
                return
            nonlocal file_cnt
            assert(file_cnt <= 1e9)  # file names are zero-padded to 10 digits
            output_file_name = os.path.join(export_config["output_path"], collection_name, f"{str(file_cnt).zfill(10)}.csv")
            with open(output_file_name, "w", newline="") as csv_file:
                # write header
                csv_write_header(col_names, csv_file)
                # write data
                csv_write_rows(data, csv_file, fields_types)
                file_cnt += 1
                results.clear()
    
        with tqdm(total=total_num, bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}") as pbar:
            while True:
                res = query_iterator.next()
                if len(res) == 0:
                    print("query iteration finished, close")
                    # close the iterator
                    query_iterator.close()
                    break
                for row in res:
                    row_list = []
                    for i in range(len(headers)):
                        field = row[headers[i]]
                        # Float vectors become PostgreSQL array literals here;
                        # binary vectors are handled later in csv_write_rows.
                        if isinstance(field, list) and fields_types[i] != DataType.BINARY_VECTOR:
                            row_list.append("{" + ", ".join(str(x) for x in field) + "}")
                        elif isinstance(field, dict):
                            # JSON fields are serialized preserving non-ASCII text.
                            row_list.append(json.dumps(field, ensure_ascii=False))
                        else:
                            row_list.append(field)
                    results.append(row_list)
                    # Roll over to a new CSV file once the buffer is full.
                    if len(results) >= export_config["max_line_in_file"]:
                        write_to_csv_file(headers, data=results)
                    pbar.update(1)
    
        # Flush any rows remaining in the buffer after iteration ends.
        write_to_csv_file(headers, data=results)
    
    if __name__ == "__main__":
        # Export every collection listed in the configuration, one at a time.
        for name in config["export"]["collections"]:
            dump_collection(name)
    

    导出配置文件milvus2csv.yaml如下。

    # Configuration for export.py.
    milvus:
      host: '<localhost>'        # Milvus server host address
      port: 19530                # Milvus service port
      user: '<user_name>'        # user name
      password: '<password>'     # password
      db_name: '<database_name>' # database name
      token: '<token_id>'        # access token
    
    export:
      collections:
        - 'test'
        - 'medium_articles_with_json'
        # - 'hello_milvus'
        # - 'car'
        # - 'medium_articles_with_dynamic'
        # list every collection that needs to be exported
      max_line_in_file: 40000     # maximum number of rows per output CSV file
      output_path: './output'     # export target directory, './output' in this topic
  2. 将导出脚本export.py、导出配置文件milvus2csv.yaml及输出目录output存放至同一个目录下。目录层级如下。

    ├── export.py
    ├── milvus2csv.yaml
    └── output
  3. 根据Milvus集群信息,修改milvus2csv.yaml中配置项。

  4. 运行Python脚本,并查看输出结果。

    python export.py

    输出结果如下。

    .
    ├── export.py
    ├── milvus2csv.yaml
    └── output
        ├── medium_articles_with_json
        │   ├── 0000000000.csv
        │   ├── 0000000001.csv
        │   ├── 0000000002.csv
        │   └── create_table.sql
        └── test
            ├── 0000000000.csv
            └── create_table.sql

步骤二:导入AnalyticDB PostgreSQL版向量数据库

  1. 准备好导入脚本import.py、导入配置文件csv2adbpg.yaml及需要导入的数据data(即在导出步骤中得到的output目录)。

    导入脚本import.py如下。

    import psycopg2
    import yaml
    import glob
    import os
    
    if __name__ == "__main__":
        # Load connection settings and the data directory from csv2adbpg.yaml.
        with open('csv2adbpg.yaml', 'r') as config_file:
            config = yaml.safe_load(config_file)
    
        print("current config:" + str(config))
    
        db_host = config['database']['host']
        db_port = config['database']['port']
        db_name = config['database']['name']
        schema_name = config['database']['schema']
        db_user = config['database']['user']
        db_password = config['database']['password']
        data_path = config['data_path']
    
        # Put the target schema first on search_path so unqualified table
        # names in create_table.sql and COPY resolve into it.
        conn = psycopg2.connect(
            host=db_host,
            port=db_port,
            database=db_name,
            user=db_user,
            password=db_password,
            options=f'-c search_path={schema_name},public'
        )
    
        cur = conn.cursor()
    
        # check schema
        cur.execute("SELECT schema_name FROM information_schema.schemata WHERE schema_name = %s", (schema_name,))
        existing_schema = cur.fetchone()
        if existing_schema:
            print(f"Schema {schema_name} already exists.")
        else:
            # create schema
            cur.execute(f"CREATE SCHEMA {schema_name}")
            print(f"Created schema: {schema_name}")
    
        # Each sub-directory of data_path is one table: it holds a
        # create_table.sql plus the CSV files produced by export.py.
        for table_name in os.listdir(data_path):
            table_folder = os.path.join(data_path, table_name)
            print(f"Begin Process table: {table_name}")
            if os.path.isdir(table_folder):
                create_table_file = os.path.join(table_folder, 'create_table.sql')
                with open(create_table_file, 'r') as file:
                    create_table_sql = file.read()
                try:
                    cur.execute(create_table_sql)
                except psycopg2.errors.DuplicateTable as e:
                    # Table already exists: roll back the failed transaction
                    # and skip this table entirely rather than re-importing.
                    print(e)
                    conn.rollback()
                    continue
                print(f"Created table: {table_name}")
    
                cnt = 0
                csv_files = glob.glob(os.path.join(table_folder, '*.csv'))
                for csv_file in csv_files:
                    with open(csv_file, 'r') as file:
                        # Files were written with a header line and '|' as
                        # delimiter by export.py; COPY must mirror that.
                        copy_command = f"COPY {table_name} FROM STDIN DELIMITER '|' HEADER"
                        cur.copy_expert(copy_command, file)
                    cnt += 1
                    print(f"Imported data from: {csv_file} | {cnt}/{len(csv_files)} file(s) Done")
    
            # Commit once per table so each table imports atomically.
            conn.commit()
            print(f"Finished import table: {table_name}")
            print(' # '*60)
    
        cur.close()
        conn.close()
    

    导入配置文件csv2adbpg.yaml如下。

    # Configuration for import.py.
    database:
      host: "192.16.XX.XX"         # public endpoint of the AnalyticDB for PostgreSQL instance
      port: 5432                   # port number of the instance
      name: "vector_database"      # target database name for the import
      user: "username"             # database account of the instance
      password: ""                 # account password
      schema: "public"             # target schema name; created automatically if it does not exist
    
    data_path: "./data"            # import data source directory
  2. 将导入脚本import.py和导入配置文件csv2adbpg.yaml与需要导入的数据data存放在同一目录下。目录层级如下。

    .
    ├── csv2adbpg.yaml
    ├── data
    │   ├── medium_articles_with_json
    │   │   ├── 0000000000.csv
    │   │   ├── 0000000001.csv
    │   │   ├── 0000000002.csv
    │   │   └── create_table.sql
    │   └── test
    │       ├── 0000000000.csv
    │       └── create_table.sql
    └── import.py
  3. 根据AnalyticDB PostgreSQL版实例信息,修改csv2adbpg.yaml文件中配置项。

  4. 运行Python脚本。

    python import.py
  5. 在AnalyticDB PostgreSQL版向量数据库中检查数据是否正常导入。

  6. 重建所需要的索引。具体操作,请参见创建向量索引。

相关文档

更多关于Milvus的介绍,请参见Milvus产品文档。