基于Ray+LLaMA-Factory实现高效图片打标
本文介绍如何利用Ray与LLaMA Factory结合的技术方案实现高效的图像打标。
背景信息
某游戏社区场景,旨在为玩家和开发者提供游戏分发与互动服务。针对玩家在游戏中频繁遇到的攻略不匹配、上下文缺失等问题,探索引入AI构建“游戏陪玩助手”,旨在通过识别玩家当前游戏状态,结合站内内容提供相应建议。
现有方案基于LLaMA-Factory进行SFT和CPT训练,并借助VLLM或阿里云百炼进行推理,但同时依赖大量人工标注的图像数据以支持视觉理解。
在这一背景下,以ADB Ray为中心,与Lance进行集成,利用RayData提升分布式图文数据处理效率和结构化能力;同时集成LLaMA-Factory,通过Ray实现对Qwen-VL多模态模型的分布式微调。
方案优势
使用RayData实现任意格式源数据的高效加载与转换,并将其统一存储为Lance格式。
Lance支持图像二进制和结构化数据的集成存储,提供更优的数据一致性和版本控制,从而减少远程IO。
结合Ray实现Lance分布式数据打标和增量更新(新增列),与Parquet数据相比,速度提升193%。
方案流程
准备工作
操作步骤
步骤一:准备数据并将其写入Lance
根据实际环境,替换以下代码中的相关配置并执行,加载model scope数据集并下载图片二进制数据,然后使用RayData进行数据格式处理,处理完成后将数据写入Lance。
from modelscope.msdatasets import MsDataset
import lance
import pyarrow as pa
import pandas as pd
import os
import json
import uuid
import shutil
from tqdm import tqdm
import time
import requests
from PIL import Image
import io
import numpy as np
import ray
import pyarrow.compute as pc
# 配置参数
OUTPUT_DIR = "/home/ray/binary_data"
NUM_SAMPLES = 10000
FIXED_USER_PROMPT = "请描述下述图片中发生了什么<image>"
HEADERS = {'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
def prepare_output_directory():
"""准备输出目录结构"""
if os.path.exists(OUTPUT_DIR):
shutil.rmtree(OUTPUT_DIR)
os.makedirs(OUTPUT_DIR)
print(f"输出目录已创建: {OUTPUT_DIR}")
def download_image_binary(image_url, max_retries=3):
"""从URL下载图像文件,返回图片二进制数据"""
for attempt in range(max_retries):
try:
# 下载图像数据
response = requests.get(image_url, headers=HEADERS, timeout=10)
response.raise_for_status()
# 检查内容是否为有效图像
try:
image = Image.open(io.BytesIO(response.content))
image.verify() # 验证是否为有效图像
return response.content
except (IOError, SyntaxError) as e:
raise ValueError(f"下载的图像无效: {str(e)}")
except (requests.RequestException, ValueError) as e:
if attempt < max_retries - 1:
wait_time = 2 ** attempt
print(f"下载失败 (尝试 {attempt+1}/{max_retries}), 将在 {wait_time} 秒后重试: {str(e)}")
time.sleep(wait_time)
else:
# 创建替代图像二进制数据
print(f"无法下载图像,创建替代图像: {image_url}")
img = Image.new('RGB', (256, 256), color=(122, 122, 122))
buffer = io.BytesIO()
img.save(buffer, format='JPEG')
return buffer.getvalue()
def parse_global_caption(sample):
"""从样本中提取全局描述文本"""
# 尝试从不同位置获取global_caption
global_caption = None
# 首先尝试直接从顶级字段获取
if 'global_caption' in sample:
global_caption = sample['global_caption']
# 然后尝试在cap_seg字段中查找
elif 'cap_seg' in sample:
cap_seg = sample['cap_seg']
# 处理JSON字符串格式
if isinstance(cap_seg, str):
try:
cap_seg = json.loads(cap_seg)
except json.JSONDecodeError:
# 可能是带单引号的字符串
try:
cap_seg = json.loads(cap_seg.replace("'", '"'))
except:
cap_seg = {}
# 处理字典格式
if isinstance(cap_seg, dict) and 'global_caption' in cap_seg:
global_caption = cap_seg['global_caption']
# 回退方案:使用默认描述
if not global_caption or not isinstance(global_caption, str):
print(f"未找到有效的global_caption,使用默认描述")
global_caption = "这是一张图片"
return global_caption.strip()
def convert_samples(samples_df):
"""转换单个批次的样本为目标格式"""
results = []
for _, row in samples_df.iterrows():
sample = row.to_dict()
index = sample.get("__index_level_0__", "") # 如果有索引列则获取
# 生成唯一ID
image_id = str(uuid.uuid4())
# 获取图像URL并下载
image_url = sample.get("opensource_url", "")
if not image_url:
print(f"缺少opensource_url,使用默认图片")
image_url = "https://modelscope.cn-beijing.oss.aliyuncs.com/open_data/sa-1b-cot-qwen/default.jpg"
# 下载图像二进制数据
image_binary = download_image_binary(image_url)
# 构建对话结构
result = {
"id": image_id,
"messages": [
{"content": FIXED_USER_PROMPT, "role": "user"},
{"content": parse_global_caption(sample), "role": "assistant"},
],
"images": [image_binary], # 直接存储二进制图片数据
"image_url": image_url, # 保留原始URL用于调试
}
results.append(result)
return pd.DataFrame(results)
def save_with_ray_lance(processed_ray_dataset):
"""使用 Ray 的 write_lance 方法保存为 Lance 格式"""
print("正在使用 Ray 的 write_lance 方法保存数据...")
lance_output_path = os.path.join(OUTPUT_DIR, "original_data.lance")
# 写入 Lance 格式
processed_ray_dataset.write_lance(lance_output_path)
print(f"数据已成功写入: {lance_output_path}")
def main():
# 设置数据集信息
dataset_name = "Tongyi-DataEngine/SA1B-Paired-Captions-Images"
# 准备输出目录
prepare_output_directory()
# 下载数据集
print(f"正在下载数据集: {dataset_name} (前 {NUM_SAMPLES} 个样本)")
ms_dataset = MsDataset.load(dataset_name, split="train")
ray_dataset = ray.data.from_huggingface(ms_dataset).limit(NUM_SAMPLES)
ray_dataset = ray_dataset.repartition(100)
# 处理样本
print("开始处理样本...")
processed_ray_dataset = ray_dataset.map_batches(
convert_samples,
batch_format="pandas",
num_cpus=1,
concurrency=100,
batch_size=32,
)
# 保存转换后的样本
save_with_ray_lance(processed_ray_dataset)
# 添加元数据
metadata = {
"dataset": dataset_name,
"num_samples": NUM_SAMPLES,
"conversion_date": pd.Timestamp.now().isoformat(),
"conversion_format": "SA1B转多模态对话(二进制图片存储)",
"message_structure": [
{"role": "user", "content": "固定提示词<image>"},
{"role": "assistant", "content": "图像描述"},
],
}
with open(os.path.join(OUTPUT_DIR, "metadata.json"), "w", encoding="utf-8") as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
print("\n 转换完成!")
return 0
if __name__ == '__main__':
main()
步骤二:Lance增量打标
使用Ray Serve部署score模型,执行如下代码,部署打分服务。
""" ray serve,一个持续运行中的打分服务,供应数据预处理使用 demo中score函数为random,生产中可以替换为LLM模型或者其他打分器 """ import ray from ray import serve from ray.data import Dataset import pandas as pd import pyarrow as pa import logging import random # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # 初始化Ray ray.init(address="ray://127.0.0.1:10001") # 配置Ray Serve @serve.deployment() class ScoringModel: def __init__(self): # 在这里加载您的模型 # demo中为空 pass def score(self,data_batch) : results = [] print(f"Scoring {len(data_batch)} items") print(data_batch[0]) for item in data_batch: # 调用实际评分模型 (这里使用随机值模拟) score = random.randint(60,100) item["score"] = score results.append(item) return results # 部署模型服务 model_deployment = ScoringModel.bind() serve.run(model_deployment, name="scoring_model")
执行如下代码,对Lance数据增量打标。
import pyarrow as pa from pathlib import Path import lance import ray import pandas as pd from lance.ray.fragment_api import add_columns import random from ray import serve path = "/nas/lance/binary_data_10w/original_data.lance" # Define label generation logic def generate_labels(batch: pa.RecordBatch) -> pa.RecordBatch: """使用Ray Serve服务对数据批次进行评分,返回带分数的新RecordBatch""" # 将RecordBatch转为Pandas DataFrame batch_df = batch.to_pandas() handle = serve.get_app_handle("scoring_model") # 将DataFrame转换为字典列表(每行一个字典) dict_list = batch_df.to_dict('records') # 异步调用评分服务 scored_data_ref = handle.score.remote(dict_list) scores = scored_data_ref.result() # 获取评分结果 return pa.RecordBatch.from_arrays([scores], names=["score1"]) def main(): # Add new columns in parallel lance_ds = lance.dataset(path) add_columns( lance_ds, generate_labels, source_columns=["images"], # Input columns needed ) print("数据打标完成") if __name__ == "__main__": main()
步骤三:LLaMA-Factory多模训练
将
dataset_info
元数据插入dataset_info.json
中,或新建一个dataset_info.json
文件,用于声明数据目录和数据字段的映射关系。请提交工单联系技术支持修改
llama-factory
源代码,以支持Lance数据读取,并确保仅读取需要的数据列。使用
llamafactory webui
训练,使用设定的数据集名称指定数据集。