基于Ray+LLaMA-Factory实现高效图片打标

更新时间:

本文介绍如何利用RayLLaMA Factory结合的技术方案实现高效的图像打标。

背景信息

某游戏社区场景,旨在为玩家和开发者提供游戏分发与互动服务。针对玩家在游戏中频繁遇到的攻略不匹配、上下文缺失等问题,探索引入AI构建“游戏陪玩助手”,旨在通过识别玩家当前游戏状态,结合站内内容提供相应建议。

现有方案基于LLaMA-Factory进行SFTCPT训练,并借助VLLM或阿里云百炼进行推理,但同时依赖大量人工标注的图像数据以支持视觉理解。

在这一背景下,以ADB Ray为中心,与Lance进行集成,利用RayData提升分布式图文数据处理效率和结构化能力;同时集成LLaMA-Factory,通过Ray实现对Qwen-VL多模态模型的分布式微调。

方案优势

  • 使用RayData实现任意格式源数据的高效加载与转换,并将其统一存储为Lance格式。

  • Lance支持图像二进制和结构化数据的集成存储,提供更优的数据一致性和版本控制,从而减少远程IO。

  • 结合Ray实现Lance分布式数据打标和增量更新(新增列),与Parquet数据相比,速度提升193%。

方案流程

image

准备工作

  1. AnalyticDB for MySQL集群的产品系列为企业版、基础版或湖仓版

  2. 托管Ray服务,并将Worker资源类型选择为GPU。

  3. 提交工单联系技术支持协助部署LLaMA-Factory框架。

  4. 提交工单联系技术支持协助挂载网卡,确保对应VPC可以访问公网。

操作步骤

步骤一:准备数据并将其写入Lance

根据实际环境,替换以下代码中的相关配置并执行,加载model scope数据集并下载图片二进制数据,然后使用RayData进行数据格式处理,处理完成后将数据写入Lance。

from modelscope.msdatasets import MsDataset
import lance
import pyarrow as pa
import pandas as pd
import os
import json
import uuid
import shutil
from tqdm import tqdm
import time
import requests
from PIL import Image
import io
import numpy as np
import ray
import pyarrow.compute as pc

# 配置参数
OUTPUT_DIR = "/home/ray/binary_data"
NUM_SAMPLES = 10000
FIXED_USER_PROMPT = "请描述下述图片中发生了什么<image>"
HEADERS = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}

def prepare_output_directory():
    """准备输出目录结构"""
    if os.path.exists(OUTPUT_DIR):
        shutil.rmtree(OUTPUT_DIR)
    
    os.makedirs(OUTPUT_DIR)
    print(f"输出目录已创建: {OUTPUT_DIR}")

def download_image_binary(image_url, max_retries=3):
    """从URL下载图像文件,返回图片二进制数据"""
    for attempt in range(max_retries):
        try:
            # 下载图像数据
            response = requests.get(image_url, headers=HEADERS, timeout=10)
            response.raise_for_status()
            
            # 检查内容是否为有效图像
            try:
                image = Image.open(io.BytesIO(response.content))
                image.verify()  # 验证是否为有效图像
                return response.content
            except (IOError, SyntaxError) as e:
                raise ValueError(f"下载的图像无效: {str(e)}")
        
        except (requests.RequestException, ValueError) as e:
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt
                print(f"下载失败 (尝试 {attempt+1}/{max_retries}), 将在 {wait_time} 秒后重试: {str(e)}")
                time.sleep(wait_time)
            else:
                # 创建替代图像二进制数据
                print(f"无法下载图像,创建替代图像: {image_url}")
                img = Image.new('RGB', (256, 256), color=(122, 122, 122))
                buffer = io.BytesIO()
                img.save(buffer, format='JPEG')
                return buffer.getvalue()

def parse_global_caption(sample):
    """从样本中提取全局描述文本"""
    # 尝试从不同位置获取global_caption
    global_caption = None

    # 首先尝试直接从顶级字段获取
    if 'global_caption' in sample:
        global_caption = sample['global_caption']

    # 然后尝试在cap_seg字段中查找
    elif 'cap_seg' in sample:
        cap_seg = sample['cap_seg']
        # 处理JSON字符串格式
        if isinstance(cap_seg, str):
            try:
                cap_seg = json.loads(cap_seg)
            except json.JSONDecodeError:
                # 可能是带单引号的字符串
                try:
                    cap_seg = json.loads(cap_seg.replace("'", '"'))
                except:
                    cap_seg = {}

        # 处理字典格式
        if isinstance(cap_seg, dict) and 'global_caption' in cap_seg:
            global_caption = cap_seg['global_caption']

    # 回退方案:使用默认描述
    if not global_caption or not isinstance(global_caption, str):
        print(f"未找到有效的global_caption,使用默认描述")
        global_caption = "这是一张图片"

    return global_caption.strip()


def convert_samples(samples_df):
    """转换单个批次的样本为目标格式"""
    results = []
    for _, row in samples_df.iterrows():
        sample = row.to_dict()
        index = sample.get("__index_level_0__", "")  # 如果有索引列则获取

        # 生成唯一ID
        image_id = str(uuid.uuid4())

        # 获取图像URL并下载
        image_url = sample.get("opensource_url", "")
        if not image_url:
            print(f"缺少opensource_url,使用默认图片")
            image_url = "https://modelscope.cn-beijing.oss.aliyuncs.com/open_data/sa-1b-cot-qwen/default.jpg"

        # 下载图像二进制数据
        image_binary = download_image_binary(image_url)

        # 构建对话结构
        result = {
            "id": image_id,
            "messages": [
                {"content": FIXED_USER_PROMPT, "role": "user"},
                {"content": parse_global_caption(sample), "role": "assistant"},
            ],
            "images": [image_binary],  # 直接存储二进制图片数据
            "image_url": image_url,  # 保留原始URL用于调试
        }
        results.append(result)

    return pd.DataFrame(results)


def save_with_ray_lance(processed_ray_dataset):
    """使用 Ray 的 write_lance 方法保存为 Lance 格式"""
    print("正在使用 Ray 的 write_lance 方法保存数据...")

    lance_output_path = os.path.join(OUTPUT_DIR, "original_data.lance")

    # 写入 Lance 格式
    processed_ray_dataset.write_lance(lance_output_path)

    print(f"数据已成功写入: {lance_output_path}")


def main():
    # 设置数据集信息
    dataset_name = "Tongyi-DataEngine/SA1B-Paired-Captions-Images"

    # 准备输出目录
    prepare_output_directory()

    # 下载数据集
    print(f"正在下载数据集: {dataset_name} (前 {NUM_SAMPLES} 个样本)")

    ms_dataset = MsDataset.load(dataset_name, split="train")

    ray_dataset = ray.data.from_huggingface(ms_dataset).limit(NUM_SAMPLES)

    ray_dataset = ray_dataset.repartition(100)

    # 处理样本
    print("开始处理样本...")
    processed_ray_dataset = ray_dataset.map_batches(
        convert_samples,
        batch_format="pandas",
        num_cpus=1,
        concurrency=100,
        batch_size=32,
    )

    # 保存转换后的样本
    save_with_ray_lance(processed_ray_dataset)

    # 添加元数据
    metadata = {
        "dataset": dataset_name,
        "num_samples": NUM_SAMPLES,
        "conversion_date": pd.Timestamp.now().isoformat(),
        "conversion_format": "SA1B转多模态对话(二进制图片存储)",
        "message_structure": [
            {"role": "user", "content": "固定提示词<image>"},
            {"role": "assistant", "content": "图像描述"},
        ],
    }

    with open(os.path.join(OUTPUT_DIR, "metadata.json"), "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)

    print("\n 转换完成!")
    return 0


if __name__ == '__main__':
    main()

步骤二:Lance增量打标

  1. 使用Ray Serve部署score模型,执行如下代码,部署打分服务。

    """
    ray serve,一个持续运行中的打分服务,供应数据预处理使用
    demoscore函数为random,生产中可以替换为LLM模型或者其他打分器
    """
    import ray
    from ray import serve
    from ray.data import Dataset
    import pandas as pd
    import pyarrow as pa
    import logging
    import random
    # 设置日志
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    
    # 初始化Ray
    ray.init(address="ray://127.0.0.1:10001")
    
    # 配置Ray Serve
    @serve.deployment()
    class ScoringModel:
        def __init__(self):
            # 在这里加载您的模型
            # demo中为空
            pass
        
        def score(self,data_batch) :
            results = []
            print(f"Scoring {len(data_batch)} items")
            print(data_batch[0])
            for item in data_batch:
                # 调用实际评分模型 (这里使用随机值模拟)
                score = random.randint(60,100)
                item["score"] = score
                results.append(item)
            return results
    
    # 部署模型服务
    model_deployment = ScoringModel.bind()
    serve.run(model_deployment, name="scoring_model")
    
  2. 执行如下代码,对Lance数据增量打标。

    import pyarrow as pa
    from pathlib import Path
    import lance
    import ray
    import pandas as pd
    from lance.ray.fragment_api import add_columns
    import random
    from ray import serve
    
    path = "/nas/lance/binary_data_10w/original_data.lance"
    
    
    # Define label generation logic
    def generate_labels(batch: pa.RecordBatch) -> pa.RecordBatch:
        """使用Ray Serve服务对数据批次进行评分,返回带分数的新RecordBatch"""
        # 将RecordBatch转为Pandas DataFrame
        batch_df = batch.to_pandas()
    
        handle = serve.get_app_handle("scoring_model")
        # 将DataFrame转换为字典列表(每行一个字典)
        dict_list = batch_df.to_dict('records')
    
        # 异步调用评分服务
        scored_data_ref = handle.score.remote(dict_list)
        scores = scored_data_ref.result()  # 获取评分结果
    
        return pa.RecordBatch.from_arrays([scores], names=["score1"])
    
    
    def main():
        # Add new columns in parallel
        lance_ds = lance.dataset(path)
    
        add_columns(
            lance_ds,
            generate_labels,
            source_columns=["images"],  # Input columns needed
        )
    
        print("数据打标完成")
    
    
    if __name__ == "__main__":
        main()
    

步骤三:LLaMA-Factory多模训练

  1. 安装登录LLaMA FactoryWeb UI界面

  2. dataset_info元数据插入dataset_info.json中,或新建一个dataset_info.json文件,用于声明数据目录和数据字段的映射关系。

    image.png

  3. 提交工单联系技术支持修改llama-factory源代码,以支持Lance数据读取,并确保仅读取需要的数据列。

  4. 使用llamafactory webui训练,使用设定的数据集名称指定数据集。

    image