CosyVoice 性能压测指南

本文介绍如何对 PAI-EAS 上的 CozyVoice 服务进行压力测试。通过模拟高并发、多模式的语音合成请求并收集关键性能指标,可有效评估服务性能,从而保障其稳定性与高可用性。

1. 概述

1.1 压测场景

支持对 CosyVoice 的三种核心功能模式进行压测:

  • 3s复刻(fast_replication):测试基于3秒参考音频的快速音色克隆与语音合成性能。

  • 跨语种复刻(cross_lingual_replication):测试使用中文参考音频合成英文文本的性能。

  • 自然语言控制(natural_language_replication):测试通过自然语言指令(如改变方言、语气)控制语音合成的性能。

1.2 核心性能指标

压测脚本关注以下核心指标,以全面评估服务性能:

指标

含义

说明

首包耗时 (TTFP)

从客户端发出请求到收到第一个音频数据包的耗时。

反映服务端的响应速度和网络延迟,是衡量用户体验的关键指标。越低越好。

实时率 (RTF)

每生成1秒音频需要多少秒。

衡量音频生成速度。RTF 小于 1 表示音频生成速度快于其实时播放速度,是流式体验的核心。

成功率

成功完成的请求数占总请求数的百分比。

衡量服务的稳定性。

P99/P95分位值

99%或95%的请求的首包耗时都低于该值。

评估在高负载下,绝大多数用户的体验水平。

2. 快速开始

本章节将在5分钟内完成环境准备、配置,并执行一次简单的性能测试,以验证环境和配置的正确性。

2.1 准备环境与依赖

  1. 操作系统:建议使用 Linux 操作系统,以 Ubuntu 22.04为例。

  2. Python 环境:使用 Python 3.10 版本。

  3. 安装依赖

    • 将以下内容保存为 requirements.txt 文件。

      numpy==1.26.4
      sseclient-py==1.8.0
      websocket-client==1.8.0
      requests==2.32.3
      注意sseclient-py 不是sseclient
    • 执行以下命令安装所有依赖。推荐使用阿里云 PyPI 镜像以加快下载速度。

      pip install -r requirements.txt -i http://mirrors.cloud.aliyuncs.com/pypi/simple --trusted-host mirrors.cloud.aliyuncs.com

2.2 下载并配置压测文件

获取所有压测相关文件,并按如下目录结构组织。

.
├── benchmark.sh             # 主执行脚本
├── performance_test.py      # 性能测试核心逻辑
├── requirements.txt         # Python 依赖列表
├── official_client/
│   └── websocket_stream.py  # WebSocket 客户端实现
└── asset/
    ├── dataset.txt          # 中文测试文本
    ├── dataset-en.txt       # 英文测试文本
    └── zero_shot_prompt.wav # 用于创建参考音频的 WAV 文件
  • 测试数据。请置于asset文件夹下。

  • 压测代码与脚本

    benchmark.sh

    注意修改 benchmark.sh 文件中 的 URLTOKEN 变量值。

    #!/bin/bash
    
    # 请替换为CosyVoice服务的访问地址和Token。可在EAS服务详情页查看调用信息获取。
    URL=http://cosyvoice-frontend.aaaabbbbcccc.vpc.cn-hangzhou.pai-eas.aliyuncs.com/
    TOKEN=AAAABBBBCCCCDDDDEEEEFFFFGGGGHHHHIIIIJJJJKKKKLLLLMMMMM==
    
    PROTOCAL='WEBSOCKET_STREAM HTTP_STREAM HTTP_NON_STREAM'
    MODES='fast_replication cross_lingual_replication natural_language_replication'
    CONCURRENCY='1'
    REQUEST_NUM=140 # must > 40
    LOGLEVEL=INFO
    
    for ((i=0;i<10;i++)); do
        echo Round-$i
        sleep 1
        for p in ${PROTOCAL}; do
            for c in ${CONCURRENCY}; do
                for m in ${MODES}; do
                    echo Protocol: $p, Concurrency: $c, Mode: $m
                    sleep 1
                    python3 performance_test.py \
                    --random \
                    --concurrency $c \
                    --request-num ${REQUEST_NUM} \
                    --protocol $p \
                    --mode $m \
                    --url ${URL} \
                    --token ${TOKEN} \
                    --log-level ${LOGLEVEL} \
                    --title Protocol=${p}_Concurrency=${c}_Request=${REQUEST_NUM}
                    sleep 5
                done
            done
        done
    done

    performance_test.py

    import base64
    import os
    import io
    import json
    import re
    import sys
    import time
    import numpy as np
    import logging
    import requests
    import argparse
    import multiprocessing
    from sseclient import SSEClient
    from official_client.websocket_stream import TTSClient
    import itertools
    import random
    import importlib.metadata
    from packaging import version
    from concurrent.futures import ThreadPoolExecutor, as_completed
    
    required_version = "2.32.3"
    
    if version.parse(requests.__version__) < version.parse(required_version):
        raise RuntimeError(f"requests version must >= {required_version}")
    
    #output_format = ("mp3", "pcm")
    output_format = ("mp3",)
    
    def sseclient_version():
    
        try:
            importlib.metadata.version("sseclient")
            return "sseclient"
        except importlib.metadata.PackageNotFoundError:
            pass
    
        return "Unknown"
    
    if sseclient_version() == 'sseclient':
        raise "You should pip install sseclient-py instead of sseclient"
    
    MODE_MAP = {
        'fast_replication': ('3s复刻', os.path.join(os.path.dirname(__file__), 'asset/dataset.txt')),
        'cross_lingual_replication': ('跨语种复刻', os.path.join(os.path.dirname(__file__), 'asset/dataset-en.txt')),
        'natural_language_replication': ('自然语言控制', os.path.join(os.path.dirname(__file__), 'asset/dataset.txt')),
    }
    
    PROTOCOL=("HTTP_NON_STREAM", "HTTP_STREAM", "WEBSOCKET_STREAM")
                                          
    DEFAULT_INSTRUCT='用四川话说'
    
    log = logging.getLogger("performance_test")
    #log_formatter = logging.Formatter('%(asctime)s - Process(%(process)s) - %(levelname)s - Line: %(lineno)d - %(message)s')
    log_formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    stream_handler = logging.StreamHandler(stream=sys.stdout)
    stream_handler.setFormatter(log_formatter)
    log.addHandler(stream_handler)
    
    def argument_parser():
        parser = argparse.ArgumentParser()
        group = parser.add_argument_group()
        group.add_argument('--concurrency',
                           required=True,
                           type=int,
                           help='Specify the process num of multiprocess')
        group.add_argument('--request-num',
                           required=False,
                           default=-1,
                           type=int,
                           help='Specify the request num')
        group.add_argument('--request-timeout',
                           required=False,
                           default=30,
                           type=int,
                           help='Specify the request timeout')
        group.add_argument('--title',
                           required=True,
                           type=str,
                           help='title')
        group.add_argument('--url',
                           required=True,
                           type=str,
                           help='Specify the URL')
        group.add_argument('--token',
                           default='EMPTY',
                           type=str,
                           help='Specify the token')
        parser.add_argument("--mode",
                            default="fast_replication",
                            choices=MODE_MAP.keys(),
                            help="请求服务的模式")
        parser.add_argument("--protocol",
                            default="HTTP_STREAM",
                            choices=PROTOCOL,
                            help="请求服务的通信协议")
        parser.add_argument("--log-level",
                            default="INFO",
                            choices=["DEBUG", "INFO", "WARNING", "ERROR"],
                            help="日志级别")
        parser.add_argument("--random",
                            action="store_true",
                            help="add random prefix for text.")
    
    
        return parser.parse_args()
    
    
    def play_wav_audio(file_or_string: str):
        """播放WAV音频"""
        import soundfile as sf
        import sounddevice as sd
        try:
            data, sample_rate = sf.read(file_or_string, dtype='float32')
            if data.ndim == 1:
                data = data.reshape(-1, 1)
            else:
                data = data.T
            sd.play(data, sample_rate)
            sd.wait()
            log.info(f"已播放音频: {file_or_string}")
        except Exception as e:
            log.error(f"播放音频出错: {str(e)}")
    
    
    def create_reference_audio():
        cur_path = os.path.dirname(os.path.abspath(__file__))
        response = requests.post(
            url=f"{args.url}/api/v1/audio/reference_audio",
            headers={
                "Authorization": f"Bearer {args.token}",
            },
            files={"file": open(f"{cur_path}/asset/zero_shot_prompt.wav", "rb")},
            data={"text": "希望你以后能够做的得比我还好哟"}
        )
        if response.status_code != 200:
            log.error(response.text)
            exit()
        reference_audio_id = response.json()['id']
        return reference_audio_id
    
    
    def list_reference_audio():
        response = requests.get(
            url=f"{args.url}/api/v1/audio/reference_audio",
            headers={
                "Authorization": f"Bearer {args.token}",
            },
        )
        if response.status_code != 200:
            log.error(response.text)
            exit()
        reference_audio_list = response.json()
        return reference_audio_list
    
    
    def request_websocket_stream(mode, texts, reference_audio_id, instruct, format):
        match = re.search(r'([http|https]+)://([A-Za-z0-9\-\.\:]+).*', args.url)
        if match:
            endpoint = match.group(2)
        else:
            raise Exception('Invalid url:', args.url)
        params = {
            'mode': mode,
            'reference_audio_id': reference_audio_id,
            'instruct': instruct,
            'output_format': format,
            'sample_rate': '32000',
            'bit_rate': '192k',
            'volume': 1.0,
            'texts': texts
        }
        client = TTSClient(args.token, f"ws://{endpoint}/api-ws/v1/audio/speech", params=params, log_level=args.log_level)
        client.run()
        metrics = client.get_metrics()
        return metrics
    
    
    def performance_test_websocket(dataset, mode, reference_audio_id, instruct=None,
                                   stream=True, protocol="", enable_random=False):
        pid = os.getpid()
        with open(dataset, 'r', encoding='utf-8') as f:
            all_metrics, total_num = [], 0
            for i, text in enumerate(itertools.cycle(f)):
                if args.request_num > 0 and i >= args.request_num:
                    break
                total_num += 1
                if enable_random:
                    text = str(random.randint(0, 100000)) + ", " + text
                text = str(pid) + ", " + text
                format = random.choice(output_format)
                request_metrics = request_websocket_stream(mode=mode, texts=[text],
                                                           reference_audio_id=reference_audio_id,
                                                           instruct=instruct,
                                                           format=format)
                """
                request_metrics = {
                    "first_package_time": 0
                    "rtf": 0,
                    "cost_time": 0,
                    "speech_len": 0,
                }
                """
                if 'error' in request_metrics:
                    log.error(f'[Request-{i}]: mode: {mode}, error: {request_metrics["error"]}')
                    continue
                all_metrics.append(request_metrics)
                log.info(f'[Req-{i}] '
                         f'{protocol}/{MODE_MAP[mode][0]}/{format} '
                         f'cli_first_pkg_time={request_metrics["client_first_package_time"]:.3f} '
                         f'cli_rtf={request_metrics["client_rtf"]:.3f} '
                         f'cli_cost_time={request_metrics["client_cost_time"]:.3f} '
                         f'server_first_pkg_time={request_metrics["server_first_package_time"]:.3f} '
                         f'speech_len={request_metrics["speech_len"]:.3f} '
                         f'server_cost_time={request_metrics["server_cost_time"]:.3f} '
                         f'generate_time={request_metrics["generate_time"]:.3f}')
    
            all_metrics = all_metrics[20:-20]
            total_num -= 40
    
            metric_dict = {
                'first_package_time': np.mean([m["client_first_package_time"] for m in all_metrics]),
                'first_package_time_details': [m["client_first_package_time"] for m in all_metrics],
                'rtf': np.mean([m["client_rtf"] for m in all_metrics]),
                'cost_time': np.mean([m["client_cost_time"] for m in all_metrics]),
                'status': {
                    'success': len(all_metrics),
                    'fail': total_num - len(all_metrics),
                }
            }
            return metric_dict
    
    def request_http(session, mode, text, reference_audio_id, i, instruct=None, stream=True, format="mp3"):
        try:
            #assert mode in ('fast_replication', 'cross_lingual_replication', 'natural_language_replication')
            assert mode in MODE_MAP
            with session.post(
                f"{args.url}/api/v1/audio/speech/",
                headers={
                    "Authorization": f"Bearer {args.token}",
                    "Content-Type": "application/json",
                },
                json={
                    "model": "CosyVoice2-0.5B",
                    "input": {
                        "mode": mode,
                        "reference_audio_id": reference_audio_id,
                        "text": text,
                        "output_format": format,
                        "sample_rate": "32000",
                        "bit_rate": "192k",
                        "volume": 1.0,
                        "speed": 1.0,
                        "debug": True,
                        "instruct": instruct
                    },
                    "stream": stream
                },
                stream=stream,
                timeout=args.request_timeout
            ) as response:
                if response.status_code != 200:
                    yield '', 'error', {'error': response.text}
    
                if stream:
                    messages = SSEClient(response)
                    for j, msg in enumerate(messages.events()):
                        data = json.loads(msg.data)
                        encode_buffer = data['output']['audio']['data']
                        metrics = data['metrics']
    
                        # import torchaudio
                        # decode_buffer = base64.b64decode(encode_buffer)
                        # audio_stream = io.BytesIO(decode_buffer)
                        # play_wav_audio(audio_stream)
    
                        # tts_speech, sample_rate = torchaudio.load(io.BytesIO(decode_buffer))
                        # torchaudio.save(f'output_v2_{i}.wav', tts_speech, sample_rate)
                        yield encode_buffer, None, metrics
    
                else:
                    data = response.json()
                    encode_buffer = data['output']['audio']['data']
                    metrics = data['metrics']
                    yield encode_buffer, None, metrics
        except Exception as e:
            yield '', 'error', {'error': str(e)}
    
    
    def performance_test_http(dataset, mode, reference_audio_id, instruct=None, stream=True,
                              protocol="", enable_random=False):
    
        pid = os.getpid()
        with open(dataset, 'r', encoding='utf-8') as f, requests.Session() as session:
            all_metrics = []
            total_num = 0
            for i, text in enumerate(itertools.cycle(f)):
                if args.request_num > 0 and i >= args.request_num:
                    break
                total_num += 1
                text = text.strip()
                if enable_random:
                    text = str(random.randint(0, 100000)) + ", " + text
                text = str(pid) + ", " + text
                last_time = time.time()
                metrics = []
                format = random.choice(output_format)
                for j, (data, error, metric) in enumerate(request_http(session=session,
                                                                       mode=mode,
                                                                       text=text,
                                                                       reference_audio_id=reference_audio_id,
                                                                       i=i,
                                                                       instruct=instruct,
                                                                       stream=stream,
                                                                       format=format)):
                    """
                    metric = {
                        "first_package_time": 0,
                        "speech_len": 0,
                        "server_cost_time": 0,
                        "generate_time": 0,
                        # "text_normalize_time": 0,
                        # "frontend_zero_shot_time": 0,
                        # "frontend_cross_lingual_time": 0,
                        # "frontend_instruct2_time": 0,
                    }
                    """
                    if error:
                        log.error(f'[Request-{i}]: mode: {MODE_MAP[mode][0]}, error: {metric["error"]}')
                        continue
                    now = time.time()
                    cost_time = now - last_time
                    last_time = now
                    metric['client_cost_time'] = cost_time
                    metrics.append(metric)
                    if j == 0:
                        if mode == 'fast_replication':
                            log.debug(f'[Chunk-{i}_{j}]: '
                                      f'mode: {MODE_MAP[mode][0]}, '
                                      f'client_cost_time: {cost_time:.3f}, '
                                      f'client_rtf: {cost_time / metric["speech_len"]:.3f}, '
                                      f'speech_len: {metric["speech_len"]:3}, '
                                      f'server_cost_time: {metric["server_cost_time"]:.3f}, '
                                      f'generate_time: {metric["generate_time"]:.3f}, '
                                      f'text_normalize_time: {metric["text_normalize_time"]:.3f}, '
                                      f'frontend_zero_shot_time: {metric["frontend_zero_shot_time"]:.3f}')
                        elif mode == 'cross_lingual_replication':
                            log.debug(f'[Chunk-{i}_{j}]: '
                                      f'mode: {MODE_MAP[mode][0]}, '
                                      f'client_cost_time: {cost_time:.3f}, '
                                      f'client_rtf: {cost_time / metric["speech_len"]:.3f}, '
                                      f'speech_len: {metric["speech_len"]:3}, '
                                      f'server_cost_time: {metric["server_cost_time"]:.3f}, '
                                      f'generate_time: {metric["generate_time"]:.3f}, '
                                      f'frontend_cross_lingual_time: {metric["frontend_cross_lingual_time"]:.3f}')
                        elif mode == 'natural_language_replication':
                            log.debug(f'[Chunk-{i}_{j}]: '
                                      f'mode: {MODE_MAP[mode][0]}, '
                                      f'client_cost_time: {cost_time:.3f}, '
                                      f'client_rtf: {cost_time / metric["speech_len"]:.3f}, '
                                      f'speech_len: {metric["speech_len"]:3}, '
                                      f'server_cost_time: {metric["server_cost_time"]:.3f}, '
                                      f'generate_time: {metric["generate_time"]:.3f}, '
                                      f'frontend_instruct2_time: {metric["frontend_instruct2_time"]:.3f}')
                    else:
                        log.debug(f'[Chunk-{i}_{j}]: mode: {MODE_MAP[mode][0]}, client_cost_time: {cost_time:.3f}, client_rtf: {cost_time / metric["speech_len"]:.3f}, speech_len: {metric["speech_len"]:3}, server_cost_time: {metric["server_cost_time"]:.3f}, generate_time: {metric["generate_time"]:.3f}')
    
                if len(metrics) == 0:
                    continue
                request_metrics = {
                    "client_first_package_time": metrics[0]["client_cost_time"],
                    "client_rtf": sum([m["client_cost_time"] for m in metrics]) / sum([m["speech_len"] for m in metrics]),
                    "client_cost_time": sum([m["client_cost_time"] for m in metrics]),
                    "speech_len": sum([m["speech_len"] for m in metrics]),
                    "server_first_package_time": metrics[0]["first_package_time"],
                    "server_rtf": sum([m["server_cost_time"] for m in metrics]) / sum([m["speech_len"] for m in metrics]),
                    "server_cost_time": sum([m["server_cost_time"] for m in metrics]),
                    "generate_time": sum([m["generate_time"] for m in metrics])
                }
                all_metrics.append(request_metrics)
                log.info(f'[Req-{i}] '
                         f'{protocol}/{MODE_MAP[mode][0]}/{format} '
                         f'cli_first_pkg_time={request_metrics["client_first_package_time"]:.3f} '
                         f'cli_rtf={request_metrics["client_rtf"]:.3f} '
                         f'cli_cost_time={request_metrics["client_cost_time"]:.3f} '
                         f'server_first_pkg_time={request_metrics["server_first_package_time"]:.3f} '
                         f'speech_len={request_metrics["speech_len"]:.3f} '
                         f'server_cost_time={request_metrics["server_cost_time"]:.3f} '
                         f'generate_time={request_metrics["generate_time"]:.3f}')
    
            all_metrics = all_metrics[20:-20]
            total_num -= 40
    
            metric_dict = {
                'first_package_time': np.mean([m["client_first_package_time"] for m in all_metrics]),
                'first_package_time_details': [m["client_first_package_time"] for m in all_metrics],
                'rtf': np.mean([m["client_rtf"] for m in all_metrics]),
                'cost_time': np.mean([m["client_cost_time"] for m in all_metrics]),
                'status': {
                    'success': len(all_metrics),
                    'fail': total_num - len(all_metrics),
                }
            }
            return metric_dict
    
    
    def performance_task(params):
    
        reference_audio_id, args = params
        metric_dict = {}
        protocol = args.protocol
        mode = args.mode
        random = args.random
    
        assert mode in MODE_MAP
        assert protocol in PROTOCOL
    
        if protocol == 'HTTP_NON_STREAM':
            request_func = performance_test_http
            stream = False
        elif protocol == 'HTTP_STREAM':
            request_func = performance_test_http
            stream = True
        elif protocol == 'WEBSOCKET_STREAM':
            request_func = performance_test_websocket
            stream = True
    
        kwargs = {
            "dataset": os.path.join(MODE_MAP[mode][1]),
            "mode": mode,
            "reference_audio_id" : reference_audio_id,
            "stream" : stream,
            "protocol" : protocol,
            "enable_random": random
        }
    
        if mode == "natural_language_replication":
            kwargs['instruct'] = DEFAULT_INSTRUCT
    
        metric_dict[mode] = request_func(**kwargs)
        
        return metric_dict
    
    
    def main(args):
        reference_audio_list = list_reference_audio()
        if reference_audio_list:
            reference_audio_id = reference_audio_list[0]['id']
        else:
            reference_audio_id = create_reference_audio()
            log.info(f'create reference_audio_id: {reference_audio_id}')
    
        with ThreadPoolExecutor(max_workers=args.concurrency) as executor:
    
            start_time = time.time()
            results = []
            futures = [
                executor.submit(performance_task, (reference_audio_id, args))
                for _ in range(args.concurrency)
            ]
            for future in as_completed(futures):
                results.append(future.result())
            end_time = time.time()
    
            #overall_dict = {}
            relica_dict = {}
    
            for metric_dict in results:
                for _mode, tdict in metric_dict.items():
                    relica_dict[_mode] = relica_dict.get(_mode, {'status': {'success': 0, 'fail': 0}})
                    success, fail = tdict['status']['success'], tdict['status']['fail']
                    relica_dict[_mode]['status']['success'] += success
                    relica_dict[_mode]['status']['fail'] += fail
                    if success > 0:
                        relica_dict[_mode]['first_package_time'] = relica_dict[_mode].get('first_package_time', 0) + tdict['first_package_time'] * success
                        relica_dict[_mode]['first_package_time_details'] = relica_dict[_mode].get('first_package_time_details', []) + tdict['first_package_time_details']
                        relica_dict[_mode]['rtf'] = relica_dict[_mode].get('rtf', 0) + tdict['rtf'] * success
                        relica_dict[_mode]['cost_time'] = relica_dict[_mode].get('cost_time', 0) + tdict['cost_time'] * success
    
                    #overall_dict['status'] = overall_dict.get('status', {'success': 0, 'fail': 0})
                    #overall_dict['status']['success'] +=  success
                    #overall_dict['status']['fail'] +=  fail
                    #if success > 0:
                    #    overall_dict['first_package_time'] = overall_dict.get('first_package_time', 0) + tdict['first_package_time'] * success
                    #    overall_dict['rtf'] = overall_dict.get('rtf', 0) + tdict['rtf'] * success
                    #    overall_dict['cost_time'] = overall_dict.get('cost_time', 0) + tdict['cost_time'] * success
    
            log.info(f'====Performance Test Summary: {args.title}====')
            for _mode, tdict in relica_dict.items():
                success = tdict['status']['success']
                total = success + tdict['status']['fail']
                success_rate = success / total
                if success > 0:
                    tdict['first_package_time'] /= success
                    tdict['first_package_time_details'] = np.percentile(tdict['first_package_time_details'], [90, 95, 96, 97, 98, 99])
                    tdict['rtf'] /= success
                    tdict['cost_time'] /= success
                log.info(f'{MODE_MAP[_mode][0]}({int(end_time - start_time)}s)')
                log.info(f'  success rate = {success_rate*100:.1f}% ({success}/{total})')
                log.info(f'  avg_first_package_time = {tdict.get("first_package_time", 0):.3f}')
                log.info(f'  first_package_time_p90/p95/p96/p97/p98/p99 = {"/".join(f"{p:.3f}" for p in tdict.get("first_package_time_details"))}')
                log.info(f'  avg_rtf = {tdict.get("rtf", 0):.3f}')
                log.info(f'  avg_cost_time = {tdict.get("cost_time", 0):.3f}')
            log.info(f'================================================================================')
    
            #success = overall_dict['status']['success']
            #total = success + overall_dict['status']['fail']
            #success_rate = success / total
            #overall_dict['first_package_time'] /= success
            #overall_dict['rtf'] /= success
            #overall_dict['cost_time'] /= success
            #log.info("=" * 50)
            #log.info(f'Performance Test Summary, mode: OVERALL: success rate = {success_rate*100:.1f}% ({success} / {total}), first_package_time = {overall_dict["first_package_time"]:.3f}, rtf = {overall_dict["rtf"]:.3f}, cost_time = {overall_dict["cost_time"]:.3f}')
    
    
        # 5. 所有任务完成后,处理结果
        #log.info("All tasks finished!")
    
    
    if __name__ == '__main__':
        args = argument_parser()
        if args.request_num <= 40:
            raise "--request-num must > 40, because first 20 and last 20 request's metrics is ignored."
        log.setLevel(args.log_level)
        main(args)
    
    # /nasmnt/envs/cosyvoice_vllm/bin/python performance_test.py --concurrency 10 --request-num 50 --protocol HTTP_STREAM --url http://localhost:50000 --title test
    

    official_client/websocket_stream.py

    #!/usr/bin/python
    # -*- coding: utf-8 -*-
    
    import base64
    import json
    import logging
    import sys
    import time
    import uuid
    import traceback
    import websocket
    
    
    class TTSClient:
        def __init__(self, api_key, uri, params, log_level='INFO'):
            """
        初始化 TTSClient 实例
    
        参数:
            api_key (str): 鉴权用的 API Key
            uri (str): WebSocket 服务地址
        """
            self._api_key = api_key  # 替换为你的 API Key
            self._uri = uri  # 替换为你的 WebSocket 地址
            self._task_id = str(uuid.uuid4())  # 生成唯一任务 ID
            self._ws = None  # WebSocketApp 实例
            self._task_started = False  # 是否收到 task-started
            self._task_finished = False  # 是否收到 task-finished / task-failed
            self._check_params(params)
            self._params = params
            self._chunk_metrics = []
            self._metrics = {}
            self._first_package_time = None
            self._last_time = None
            self._init_log(log_level)
            self.audio_data = b''
    
        def _init_log(self, log_level):
            self._log = logging.getLogger("ws_client")
            log_formatter = logging.Formatter('%(asctime)s - Process(%(process)s) - %(levelname)s - %(message)s')
            stream_handler = logging.StreamHandler(stream=sys.stdout)
            stream_handler.setFormatter(log_formatter)
            self._log.addHandler(stream_handler)
            self._log.setLevel(log_level)
    
        def get_metrics(self):
            """获取合成结果性能指标"""
            return self._metrics
    
        def _check_params(self, params):
            assert 'mode' in params and params['mode'] in ['fast_replication', 'cross_lingual_replication', 'natural_language_replication']
            assert 'reference_audio_id' in params
            assert 'output_format' in params and params['output_format'] in ['wav', 'mp3', 'pcm']
            if params['mode'] == 'natural_language_replication':
                assert 'instruct' in params and params['instruct']
            else:
                if 'instruct' in params:
                    del params['instruct']
    
        def on_open(self, ws):
            """
        WebSocket 连接建立时回调函数
        发送 run-task 指令开启语音合成任务
        """
            self._log.debug("WebSocket 已连接")
    
            # 构造 run-task 指令
            run_task_cmd = {
                "header": {
                    "action": "run-task",
                    "task_id": self._task_id,
                    "streaming": "duplex"
                },
                "payload": {
                    "task_group": "audio",
                    "task": "tts",
                    "function": "SpeechSynthesizer",
                    "model": "cosyvoice-v2",
                    "parameters": {
                        "mode": self._params['mode'],
                        "reference_audio_id": self._params['reference_audio_id'],
                        "output_format": self._params.get('output_format', 'wav'),
                        "sample_rate": self._params.get('sample_rate', 24000),
                        "bit_rate": self._params.get('bit_rate', '192k'),
                        "volume": self._params.get('volume', 1.0),
                        "instruct": self._params.get('instruct', ''),
                        "speed": self._params.get('speed', 1.0),
                        "debug": True,
                    },
                    "input": {}
                }
            }
    
            # 发送 run-task 指令
            ws.send(json.dumps(run_task_cmd))
            self._log.debug("已发送 run-task 指令")
    
        def on_message(self, ws, message):
            """
        接收到消息时的回调函数
        区分文本和二进制消息处理
        """
            try:
                msg_json = json.loads(message)
                # self._log.debug(f"收到 JSON 消息: {msg_json}")
                self._log.debug(f"收到 JSON 消息: {msg_json['header']['event']}")
    
                if "header" in msg_json:
                    header = msg_json["header"]
    
                    if "event" in header:
                        event = header["event"]
    
                        if event == "task-started":
                            self._log.debug("任务已启动")
                            self._task_started = True
    
                            # 发送 continue-task 指令
                            for text in self._params['texts']:
                                self.send_continue_task(text)
    
                            # 所有 continue-task 发送完成后发送 finish-task
                            self.send_finish_task()
                            self._last_time = time.time()
                        elif event == "result-generated":
                            metrics = msg_json['payload']['metrics']
                            cur_time = time.time()
                            metrics['client_cost_time'] = cur_time - self._last_time
                            self._last_time = cur_time
    
                            encode_data = msg_json["payload"]["output"]["audio"]["data"]
                            decode_data = base64.b64decode(encode_data)
                            self._log.debug(f"收到音频数据,大小: {len(decode_data)} 字节")
                            self.audio_data += decode_data
    
                            metrics['client_rtf'] = metrics['client_cost_time'] / metrics['speech_len']
                            self._chunk_metrics.append(metrics)
    
                        elif event == "task-finished":
                            self._metrics = {
                                'client_first_package_time': self._chunk_metrics[0]['client_cost_time'],
                                "client_rtf": sum([m["client_cost_time"] for m in self._chunk_metrics]) / sum([m["speech_len"] for m in self._chunk_metrics]),
                                'client_cost_time': sum([m["client_cost_time"] for m in self._chunk_metrics]),
                                'speech_len': sum([m["speech_len"] for m in self._chunk_metrics]),
                                'server_first_package_time': self._chunk_metrics[0]['server_cost_time'],
                                'server_rtf': sum([m["server_cost_time"] for m in self._chunk_metrics]) / sum([m["speech_len"] for m in self._chunk_metrics]),
                                'server_cost_time': sum([m["server_cost_time"] for m in self._chunk_metrics]),
                                "generate_time": sum([m["generate_time"] for m in self._chunk_metrics])
                            }
    
                            self._log.debug(f"任务已完成, 请求性能指标: client_first_package_time: {self._metrics['client_first_package_time']:.3f}, client_rtf: {self._metrics['client_rtf']:.3f}, client_cost_time: {self._metrics['client_cost_time']:.3f}, speech_len: {self._metrics['speech_len']:.3f}, server_cost_time: {self._metrics['server_cost_time']:.3f}, generate_time: {self._metrics['generate_time']:.3f}")
                            self._task_finished = True
                            self.close(ws)
    
                        elif event == "task-failed":
                            self._log.error(f"任务失败: {msg_json}")
                            self._task_finished = True
                            self.close(ws)
    
            except json.JSONDecodeError as e:
                self._log.error(f"JSON 解析失败: {str(e)}\t{traceback.format_exc()}")
    
        def on_error(self, ws, error):
            """发生错误时的回调"""
            self._log.error(f"WebSocket 出错: {error}\t{traceback.format_exc()}")
            self._metrics = {'error': error}
    
        def on_close(self, ws, close_status_code, close_msg):
            """连接关闭时的回调"""
            self._log.debug(f"WebSocket 已关闭: {close_msg} ({close_status_code})")
    
        def send_continue_task(self, text):
            """发送 continue-task 指令,附带要合成的文本内容"""
            cmd = {
                "header": {
                    "action": "continue-task",
                    "task_id": self._task_id,
                    "streaming": "duplex"
                },
                "payload": {
                    "input": {
                        "text": text
                    }
                }
            }
    
            self._ws.send(json.dumps(cmd))
            self._log.debug(f"已发送 continue-task 指令,文本内容: {text}")
    
        def send_finish_task(self):
            """发送 finish-task 指令,结束语音合成任务"""
            cmd = {
                "header": {
                    "action": "finish-task",
                    "task_id": self._task_id,
                    "streaming": "duplex"
                },
                "payload": {
                    "input": {}
                }
            }
    
            self._ws.send(json.dumps(cmd))
            self._log.debug("已发送 finish-task 指令")
    
        def close(self, ws):
            """主动关闭连接"""
            if ws and ws.sock and ws.sock.connected:
                ws.close()
                self._log.debug("已主动关闭连接")
    
        def run(self):
            """启动 WebSocket 客户端"""
            # 设置请求头部(鉴权)
            header = {
                "Authorization": f"Bearer {self._api_key}",
            }
    
            # 创建 WebSocketApp 实例
            self._ws = websocket.WebSocketApp(
                self._uri,
                header=header,
                on_open=self.on_open,
                on_message=self.on_message,
                on_error=self.on_error,
                on_close=self.on_close
            )
    
            self._log.debug("正在监听 WebSocket 消息...")
            self._ws.run_forever()  # 启动长连接监听
    
    
    # 示例使用方式
    if __name__ == "__main__":
        API_KEY = 'your-api-key'
        SERVER_URI = "ws://localhost:50000/api-ws/v1/audio/speech"
    
        texts = [
            # "床前明月光,疑是地上霜。",
            # "举头望明月,低头思故乡。"
            "北京, 你好, 收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。"
        ]
        params = {
            'mode': 'natural_language_replication',
            'texts': texts,
            'reference_audio_id': '<reference_audio_id>', 
            'speed': 1.0,
            'output_format': 'wav',
            'sample_rate': 24000,
            "instruct": "用冷静的语气说",
        }
    
        client = TTSClient(API_KEY, SERVER_URI, params, log_level='DEBUG')
        client.run()
        with open('./websocket_stream.wav', 'wb') as wfile:
            wfile.write(client.audio_data)
    

2.3 开始压测

执行以下命令即开始压测:

bash benchmark.sh

3. 详细指南

3.1 主脚本配置 (benchmark.sh)

benchmark.sh 是一个外层封装脚本,通过循环调用 performance_test.py,实现对不同协议、并发数和模式的组合测试。

变量

作用

URL

PAI-EAS 模型在线服务终端节点地址。

TOKEN

访问服务的鉴权令牌

PROTOCOL

通信协议,可设置多个,用空格隔开。

MODES

服务模式,可设置多个,用空格隔开。

CONCURRENCY

并发数,即同时运行的测试进程数。

REQUEST_NUM

每个进程发起的总请求数

LOGLEVEL

日志级别

3.2 核心脚本参数 (performance_test.py)

performance_test.py 是执行性能测试的核心脚本,接收命令行参数以控制测试行为。

参数

是否必须

描述

--concurrency

指定并发进程数。

--request-num

每个进程的总请求数为确保统计有效,该值必须大于40。脚本会自动忽略前20个(预热)和后20个(收尾)请求的性能数据。

--title

本次测试的标题,会显示在最终的性能总结报告中。

--url

服务的终端节点地址。

--token

访问服务的鉴权令牌

--mode

服务模式。可选:fast_replication, cross_lingual_replication, natural_language_replication

--protocol

通信协议。可选:HTTP_NON_STREAM, HTTP_STREAM, WEBSOCKET_STREAM

--request-timeout

HTTP 请求的客户端请求超时时间(秒)。默认值为 30

--log-level

日志级别。可选:DEBUG, INFO, WARNING, ERROR

--random

在每个请求的文本前添加一个随机数前缀,用于防止服务端缓存。

3.3 压测模式说明

可以通过 --mode 参数指定不同的语音合成模式。

模式

说明

数据集

描述

fast_replication

3s复刻

asset/dataset.txt

使用 asset/zero_shot_prompt.wav 作为参考音频,合成 dataset.txt 中的中文文本。

cross_lingual_replication

跨语种复刻

asset/dataset-en.txt

使用中文参考音频,合成 dataset-en.txt 中的英文文本。目前仅支持中英跨语种复刻

natural_language_replication

自然语言控制

asset/dataset.txt

3s复刻的基础上,增加 instruct 参数,通过自然语言指令控制合成风格。例如:用四川话说用冷静的语气说

3.4 通信协议说明

可以通过 --protocol 参数选择不同的通信协议进行测试。

协议

描述

推荐与适用条件

HTTP_NON_STREAM

非流式 HTTP:发送完整请求,等待服务端处理完毕后一次性返回完整音频。

适用于短文本合成。延迟较高,不适合实时交互场景。

HTTP_STREAM

流式 HTTP:服务端以 Server-Sent Events (SSE) 的形式,边生成边推送音频数据流。

推荐。适用于 Web 前端等需要快速响应的场景,可以显著降低首包耗时

WEBSOCKET_STREAM

流式 WebSocket:通过建立持久的 WebSocket 连接进行双向通信和音频流传输。

推荐。适用于需要频繁交互或有严格实时性要求的后端服务。相比 HTTP,减少了每次请求建立连接的开销。

4. 结果解读

4.1 日志输出示例

在测试过程中,您会看到两种主要的日志输出:

  • 单次请求日志:每一行代表一次完整的语音合成请求的性能数据。

    [Req-21] WEBSOCKET_STREAM/3s复刻/mp3 cli_first_pkg_time=0.251 cli_rtf=0.280 cli_cost_time=0.871 server_first_pkg_time=0.150 speech_len=3.110 server_cost_time=0.750 generate_time=0.650
  • 性能总结报告:在所有并发进程结束后,会打印格式化的总结报告。

    ====Performance Test Summary: Protocol=WEBSOCKET_STREAM_Concurrency=1_Request=41====
    3s复刻(15s)
      success rate = 100.0% (1/1)
      avg_first_package_time = 0.251
      first_package_time_p90/p95/p96/p97/p98/p99 = 0.251/0.251/0.251/0.251/0.251/0.251
      avg_rtf = 0.280
      avg_cost_time = 0.871
    ================================================================================

4.2 性能指标定义与分析

指标 (日志中)

全称/计算方式

业务含义与分析建议

cli_first_pkg_time

Client First Package Time

客户端首包耗时:从客户端发起请求到收到第一个音频数据包的总耗时。该值包含网络延迟和服务端处理时间,是衡量用户体验最直接的指标。

cli_rtf

Client Real Time Factor

客户端实时率客户端总耗时 / 音频总时长。该值小于1表示生成速度快于播放速度,是衡量流式体验的关键。

cli_cost_time

Client Cost Time

客户端总耗时:客户端从发请求到接收完所有音频数据的总耗时。

server_first_pkg_time

Server First Package Time

服务端首包耗时:服务端内部从接收请求到生成第一个音频包的耗时。此指标排除了网络延迟,纯粹反映服务端处理性能。

speech_len

Speech Length

合成出的音频时长(秒)。

server_cost_time

Server Cost Time

服务端耗时:服务端处理该请求的总耗时。

generate_time

Generate Time

模型生成时间:模型实际用于生成音频的耗时,是 server_cost_time 的一部分。

success rate

成功率

成功请求数 / 总请求数。反映了服务的稳定性。

first_package_time_p90/p95...

首包耗时百分位

例如,p99=0.500 表示 99% 的请求其首包耗时都低于 0.5 秒。这是衡量服务在高并发下整体性能稳定性的重要参考。

5. 故障排查 (FAQ)

Q:报错 ERROR Authorization failed

请按以下步骤排查:

  1. 检查 benchmark.sh 中的 TOKEN 变量是否已替换为您从 PAI-EAS 服务详情页获取的真实令牌

  2. 确认 TOKEN 字符串的完整性,复制时不要遗漏任何字符。

Q:报错 FileNotFoundError: [Errno 2] No such file or directory: '/root/asset/dataset.txt'

请确保您已按照 2.2 下载并配置压测文件中的目录结构组织文件。asset 文件夹必须与 performance_test.py 脚本位于同一目录下。

6. 附录

6.1 official_client/websocket_stream.py 说明

此文件是为 WEBSOCKET_STREAM 协议实现的专用 Python 客户端,封装了与 CosyVoice 服务进行语音合成的完整交互逻辑。performance_test.py 在测试 WebSocket 协议时会调用此客户端。其核心流程如下:

  1. 建立连接:使用 Authorization: Bearer {TOKEN} 头进行鉴权

  2. 启动任务:发送 run-task 指令,并携带合成模式、参考音频ID等参数。

  3. 发送文本:通过 continue-task 指令发送待合成的文本。

  4. 结束任务:发送 finish-task 指令,通知服务端文本已全部发送。

  5. 接收音频:在 result-generated 事件中接收服务端返回的音频数据流。

  6. 处理结果:在 task-finishedtask-failed 事件中获取最终结果或错误信息。