diff --git a/.gitignore b/.gitignore index 007add1d..2acb6c2e 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ filelists /*.npy /*.wav /*.mp3 +/*.lab /results /data /.idea @@ -25,6 +26,6 @@ asr-label* /fishenv /.locale /demo-audios -ref_data* +/references /example /faster_whisper diff --git a/tools/api.py b/tools/api.py index 05b31338..bf27f25b 100644 --- a/tools/api.py +++ b/tools/api.py @@ -9,16 +9,20 @@ from argparse import ArgumentParser from http import HTTPStatus from pathlib import Path -from typing import Annotated, Literal, Optional +from typing import Annotated, Any, Literal, Optional import numpy as np +import ormsgpack import pyrootutils import soundfile as sf import torch import torchaudio +from baize.datastructures import ContentType from kui.asgi import ( Body, + FactoryClass, HTTPException, + HttpRequest, HttpView, JSONResponse, Kui, @@ -27,14 +31,16 @@ ) from kui.asgi.routing import MultimethodRoutes from loguru import logger -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, conint pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) # from fish_speech.models.vqgan.lit_module import VQGAN from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture +from fish_speech.text.chn_text_norm.text import Text as ChnNormedText from fish_speech.utils import autocast_exclude_mps from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model +from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text from tools.llama.generate import ( GenerateRequest, GenerateResponse, @@ -82,11 +88,8 @@ async def other_exception_handler(exc: "Exception"): def load_audio(reference_audio, sr): if len(reference_audio) > 255 or not Path(reference_audio).exists(): - try: - audio_data = base64.b64decode(reference_audio) - reference_audio = io.BytesIO(audio_data) - except base64.binascii.Error: - raise ValueError("Invalid path or base64 string") + audio_data = reference_audio + reference_audio = io.BytesIO(audio_data) waveform, original_sr = torchaudio.load( reference_audio, backend="sox" if sys.platform == "linux" else "soundfile" @@ -153,56 +156,36 @@ def decode_vq_tokens( routes = MultimethodRoutes(base_class=HttpView) -def get_random_paths(base_path, data, speaker, emotion): - if base_path and data and speaker and emotion and (Path(base_path).exists()): - if speaker in data and emotion in data[speaker]: - files = data[speaker][emotion] - lab_files = [f for f in files if f.endswith(".lab")] - wav_files = [f for f in files if f.endswith(".wav")] +class ServeReferenceAudio(BaseModel): + audio: bytes + text: str - if lab_files and wav_files: - selected_lab = random.choice(lab_files) - selected_wav = random.choice(wav_files) - lab_path = Path(base_path) / speaker / emotion / selected_lab - wav_path = Path(base_path) / speaker / emotion / selected_wav - if lab_path.exists() and wav_path.exists(): - return lab_path, wav_path - - return None, None - - -def load_json(json_file): - if not json_file: - logger.info("Not using a json file") - return None - try: - with open(json_file, "r", encoding="utf-8") as file: - data = json.load(file) - except FileNotFoundError: - logger.warning(f"ref json not found: {json_file}") - data = None - except Exception as e: - logger.warning(f"Loading json failed: {e}") - data = None - return data - - -class InvokeRequest(BaseModel): +class ServeTTSRequest(BaseModel): text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游." 
-    reference_text: Optional[str] = None
-    reference_audio: Optional[str] = None
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # Reference audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    # Normalize text for en & zh, this increases stability for numbers
+    normalize: bool = True
+    mp3_bitrate: Optional[int] = 64
+    opus_bitrate: Optional[int] = -1000
+    # Balanced mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # The parameters below are not usually used
+    streaming: bool = False
+    emotion: Optional[str] = None
     max_new_tokens: int = 1024
-    chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
     top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
     repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
     temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
-    emotion: Optional[str] = None
-    format: Literal["wav", "mp3", "flac"] = "wav"
-    streaming: bool = False
-    ref_json: Optional[str] = "ref_data.json"
-    ref_base: Optional[str] = "ref_data"
-    speaker: Optional[str] = None
 
 
 def get_content_type(audio_format):
@@ -217,35 +200,52 @@ def get_content_type(audio_format):
 
 @torch.inference_mode()
-def inference(req: InvokeRequest):
-    # Parse reference audio aka prompt
-    prompt_tokens = None
-
-    ref_data = load_json(req.ref_json)
-    ref_base = req.ref_base
-
-    lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
-
-    if lab_path and wav_path:
-        with open(lab_path, "r", encoding="utf-8") as lab_file:
-            ref_text = lab_file.read()
-        req.reference_audio = wav_path
-        req.reference_text = ref_text
-        logger.info("ref_path: " + str(wav_path))
-        logger.info("ref_text: " + ref_text)
-
-    # Parse reference audio aka prompt
-    prompt_tokens = encode_reference(
-        decoder_model=decoder_model,
-        reference_audio=req.reference_audio,
-        enable_reference_audio=req.reference_audio is not None,
-    )
-    logger.info(f"ref_text: {req.reference_text}")
+def inference(req: ServeTTSRequest):
+
+    idstr: str | None = req.reference_id
+    if idstr is not None:
+        ref_folder = Path("references") / idstr
+        ref_folder.mkdir(parents=True, exist_ok=True)
+        ref_audios = list_files(
+            ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
+        )
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=audio_to_bytes(str(ref_audio)),
+                enable_reference_audio=True,
+            )
+            for ref_audio in ref_audios
+        ]
+        prompt_texts = [
+            read_ref_text(str(ref_audio.with_suffix(".lab")))
+            for ref_audio in ref_audios
+        ]
+
+    else:
+        # Parse reference audio aka prompt
+        refs = req.references
+        if refs is None:
+            refs = []
+        prompt_tokens = [
+            encode_reference(
+                decoder_model=decoder_model,
+                reference_audio=ref.audio,
+                enable_reference_audio=True,
+            )
+            for ref in refs
+        ]
+        prompt_texts = [ref.text for ref in refs]
+
     # LLAMA Inference
     request = dict(
         device=decoder_model.device,
         max_new_tokens=req.max_new_tokens,
-        text=req.text,
+        text=(
+            req.text
+            if not req.normalize
+            else ChnNormedText(raw_text=req.text).normalize()
+        ),
         top_p=req.top_p,
         repetition_penalty=req.repetition_penalty,
         temperature=req.temperature,
@@ -254,7 +254,7 @@ def 
inference(req: InvokeRequest): chunk_length=req.chunk_length, max_length=2048, prompt_tokens=prompt_tokens, - prompt_text=req.reference_text, + prompt_text=prompt_texts, ) response_queue = queue.Queue() @@ -307,40 +307,7 @@ def inference(req: InvokeRequest): yield fake_audios -def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True): - if not use_auto_rerank: - # 如果不使用 auto_rerank,直接调用原始的 inference 函数 - return inference(req) - - zh_model, en_model = load_model() - max_attempts = 5 - best_wer = float("inf") - best_audio = None - - for attempt in range(max_attempts): - # 调用原始的 inference 函数 - audio_generator = inference(req) - fake_audios = next(audio_generator) - - asr_result = batch_asr( - zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100 - )[0] - wer = calculate_wer(req.text, asr_result["text"]) - - if wer <= 0.1 and not asr_result["huge_gap"]: - return fake_audios - - if wer < best_wer: - best_wer = wer - best_audio = fake_audios - - if attempt == max_attempts - 1: - break - - return best_audio - - -async def inference_async(req: InvokeRequest): +async def inference_async(req: ServeTTSRequest): for chunk in inference(req): yield chunk @@ -349,9 +316,9 @@ async def buffer_to_async_generator(buffer): yield buffer -@routes.http.post("/v1/invoke") +@routes.http.post("/v1/tts") async def api_invoke_model( - req: Annotated[InvokeRequest, Body(exclusive=True)], + req: Annotated[ServeTTSRequest, Body(exclusive=True)], ): """ Invoke model and generate audio @@ -422,7 +389,7 @@ def parse_args(): parser.add_argument("--half", action="store_true") parser.add_argument("--compile", action="store_true") parser.add_argument("--max-text-length", type=int, default=0) - parser.add_argument("--listen", type=str, default="127.0.0.1:8000") + parser.add_argument("--listen", type=str, default="127.0.0.1:8080") parser.add_argument("--workers", type=int, default=1) parser.add_argument("--use-auto-rerank", type=bool, default=True) @@ -436,18 +403,30 @@ def parse_args(): }, ).routes + +class MsgPackRequest(HttpRequest): + async def data(self) -> Annotated[Any, ContentType("application/msgpack")]: + if self.content_type == "application/msgpack": + return ormsgpack.unpackb(await self.body) + + raise HTTPException( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE, + headers={"Accept": "application/msgpack"}, + ) + + app = Kui( routes=routes + openapi[1:], # Remove the default route exception_handlers={ HTTPException: http_execption_handler, Exception: other_exception_handler, }, + factory_class=FactoryClass(http=MsgPackRequest), cors_config={}, ) if __name__ == "__main__": - import threading import uvicorn @@ -474,18 +453,16 @@ def parse_args(): # Dry run to check if the model is loaded correctly and avoid the first-time latency list( inference( - InvokeRequest( + ServeTTSRequest( text="Hello world.", - reference_text=None, - reference_audio=None, + references=[], + reference_id=None, max_new_tokens=0, top_p=0.7, repetition_penalty=1.2, temperature=0.7, emotion=None, format="wav", - ref_base=None, - ref_json=None, ) ) ) diff --git a/tools/file.py b/tools/file.py index b4b8051d..f7a05973 100644 --- a/tools/file.py +++ b/tools/file.py @@ -1,3 +1,4 @@ +import base64 from pathlib import Path from typing import Union @@ -23,6 +24,22 @@ } +def audio_to_bytes(file_path): + if not file_path or not Path(file_path).exists(): + return None + with open(file_path, "rb") as wav_file: + wav = wav_file.read() + return wav + + +def read_ref_text(ref_text): + path = Path(ref_text) + if path.exists() and 
path.is_file(): + with path.open("r", encoding="utf-8") as file: + return file.read() + return ref_text + + def list_files( path: Union[Path, str], extensions: set[str] = None, diff --git a/tools/gen_ref.py b/tools/gen_ref.py deleted file mode 100644 index a771903b..00000000 --- a/tools/gen_ref.py +++ /dev/null @@ -1,36 +0,0 @@ -import json -from pathlib import Path - - -def scan_folder(base_path): - wav_lab_pairs = {} - - base = Path(base_path) - for suf in ["wav", "lab"]: - for f in base.rglob(f"*.{suf}"): - relative_path = f.relative_to(base) - parts = relative_path.parts - print(parts) - if len(parts) >= 3: - character = parts[0] - emotion = parts[1] - - if character not in wav_lab_pairs: - wav_lab_pairs[character] = {} - if emotion not in wav_lab_pairs[character]: - wav_lab_pairs[character][emotion] = [] - wav_lab_pairs[character][emotion].append(str(f.name)) - - return wav_lab_pairs - - -def save_to_json(data, output_file): - with open(output_file, "w", encoding="utf-8") as file: - json.dump(data, file, ensure_ascii=False, indent=2) - - -base_path = "ref_data" -out_ref_file = "ref_data.json" - -wav_lab_pairs = scan_folder(base_path) -save_to_json(wav_lab_pairs, out_ref_file) diff --git a/tools/merge_asr_files.py b/tools/merge_asr_files.py deleted file mode 100644 index cc120620..00000000 --- a/tools/merge_asr_files.py +++ /dev/null @@ -1,55 +0,0 @@ -import os -from pathlib import Path - -from pydub import AudioSegment -from tqdm import tqdm - -from tools.file import AUDIO_EXTENSIONS, list_files - - -def merge_and_delete_files(save_dir, original_files): - save_path = Path(save_dir) - audio_slice_files = list_files( - path=save_dir, extensions=AUDIO_EXTENSIONS.union([".lab"]), recursive=True - ) - audio_files = {} - label_files = {} - for file_path in tqdm(audio_slice_files, desc="Merging audio files"): - rel_path = Path(file_path).relative_to(save_path) - (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True) - if file_path.suffix == ".wav": - prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0] - if prefix == rel_path.parent / file_path.stem: - continue - audio = AudioSegment.from_wav(file_path) - if prefix in audio_files.keys(): - audio_files[prefix] = audio_files[prefix] + audio - else: - audio_files[prefix] = audio - - elif file_path.suffix == ".lab": - prefix = rel_path.parent / file_path.stem.rsplit("-", 1)[0] - if prefix == rel_path.parent / file_path.stem: - continue - with open(file_path, "r", encoding="utf-8") as f: - label = f.read() - if prefix in label_files.keys(): - label_files[prefix] = label_files[prefix] + ", " + label - else: - label_files[prefix] = label - - for prefix, audio in audio_files.items(): - output_audio_path = save_path / f"{prefix}.wav" - audio.export(output_audio_path, format="wav") - - for prefix, label in label_files.items(): - output_label_path = save_path / f"{prefix}.lab" - with open(output_label_path, "w", encoding="utf-8") as f: - f.write(label) - - for file_path in original_files: - os.remove(file_path) - - -if __name__ == "__main__": - merge_and_delete_files("/made/by/spicysama/laziman", [__file__]) diff --git a/tools/msgpack_api.py b/tools/msgpack_api.py new file mode 100644 index 00000000..52f2220d --- /dev/null +++ b/tools/msgpack_api.py @@ -0,0 +1,68 @@ +from typing import Annotated, AsyncGenerator, Literal, Optional + +import httpx +import ormsgpack +from pydantic import AfterValidator, BaseModel, Field, conint + + +class ServeReferenceAudio(BaseModel): + audio: bytes + text: str + + +class ServeTTSRequest(BaseModel): + 
text: str
+    chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
+    # Audio format
+    format: Literal["wav", "pcm", "mp3"] = "wav"
+    mp3_bitrate: Literal[64, 128, 192] = 128
+    # Reference audios for in-context learning
+    references: list[ServeReferenceAudio] = []
+    # Reference id
+    # For example, if you want to use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
+    # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
+    reference_id: str | None = None
+    # Normalize text for en & zh, this increases stability for numbers
+    normalize: bool = True
+    mp3_bitrate: Optional[int] = 64
+    opus_bitrate: Optional[int] = -1000
+    # Balanced mode will reduce latency to 300ms, but may decrease stability
+    latency: Literal["normal", "balanced"] = "normal"
+    # The parameters below are not usually used
+    streaming: bool = False
+    emotion: Optional[str] = None
+    max_new_tokens: int = 1024
+    top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+    repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
+    temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
+
+
+# priority: ref_id > references
+request = ServeTTSRequest(
+    text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
+    # reference_id="114514",
+    references=[
+        ServeReferenceAudio(
+            audio=open("lengyue.wav", "rb").read(),
+            text=open("lengyue.lab", "r", encoding="utf-8").read(),
+        )
+    ],
+    streaming=True,
+)
+
+with (
+    httpx.Client() as client,
+    open("hello.wav", "wb") as f,
+):
+    with client.stream(
+        "POST",
+        "http://127.0.0.1:8080/v1/tts",
+        content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
+        headers={
+            "authorization": "Bearer YOUR_API_KEY",
+            "content-type": "application/msgpack",
+        },
+        timeout=None,
+    ) as response:
+        for chunk in response.iter_bytes():
+            f.write(chunk)
diff --git a/tools/post_api.py b/tools/post_api.py
index 15389307..79c03cb6 100644
--- a/tools/post_api.py
+++ b/tools/post_api.py
@@ -1,40 +1,18 @@
 import argparse
 import base64
-import json
 import wave
 from pathlib import Path
 
 import pyaudio
 import requests
+from pydub import AudioSegment
+from pydub.playback import play
+
+from tools.file import audio_to_bytes, read_ref_text
 
-def wav_to_base64(file_path):
-    if not file_path or not Path(file_path).exists():
-        return None
-    with open(file_path, "rb") as wav_file:
-        wav_content = wav_file.read()
-    base64_encoded = base64.b64encode(wav_content)
-    return base64_encoded.decode("utf-8")
 
+def parse_args():
 
-def read_ref_text(ref_text):
-    path = Path(ref_text)
-    if path.exists() and path.is_file():
-        with path.open("r", encoding="utf-8") as file:
-            return file.read()
-    return ref_text
-
-
-def play_audio(audio_content, format, channels, rate):
-    p = pyaudio.PyAudio()
-    stream = p.open(format=format, channels=channels, rate=rate, output=True)
-    stream.write(audio_content)
-    stream.stop_stream()
-    stream.close()
-    p.terminate()
-
-
-if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Send a WAV file and text to a server and receive synthesized audio."
    )
 
@@ -43,16 +21,24 @@ def play_audio(audio_content, format, channels, rate):
         "--url",
         "-u",
         type=str,
-        default="http://127.0.0.1:8080/v1/invoke",
+        default="http://127.0.0.1:8080/v1/tts",
         help="URL of the server",
     )
     parser.add_argument(
         "--text", "-t", type=str, required=True, help="Text to be synthesized"
     )
+    parser.add_argument(
+        "--reference_id",
+        "-id",
+        type=str,
+        default=None,
+        help="ID of the reference model to be used for the speech",
+    )
     parser.add_argument(
         "--reference_audio",
         "-ra",
         type=str,
+        nargs="+",
         default=None,
         help="Path to the WAV file",
     )
@@ -60,9 +46,30 @@ def play_audio(audio_content, format, channels, rate):
         "--reference_text",
         "-rt",
         type=str,
+        nargs="+",
         default=None,
         help="Reference text for voice synthesis",
     )
+    parser.add_argument(
+        "--output",
+        "-o",
+        type=str,
+        default="generated_audio",
+        help="Output audio file name",
+    )
+    parser.add_argument(
+        "--play",
+        type=bool,
+        default=True,
+        help="Whether to play audio after receiving data",
+    )
+    parser.add_argument("--normalize", type=bool, default=True)
+    parser.add_argument(
+        "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
+    )
+    parser.add_argument("--mp3_bitrate", type=int, default=64)
+    parser.add_argument("--opus_bitrate", type=int, default=-1000)
+    parser.add_argument("--latency", type=str, default="normal", help="Latency option")
     parser.add_argument(
         "--max_new_tokens",
         type=int,
@@ -88,7 +95,6 @@ def play_audio(audio_content, format, channels, rate):
         "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
     )
     parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
-    parser.add_argument("--format", type=str, default="wav", help="Audio format")
     parser.add_argument(
         "--streaming", type=bool, default=False, help="Enable streaming response"
    )
@@ -97,18 +103,36 @@ def play_audio(audio_content, format, channels, rate):
     )
     parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
 
-    args = parser.parse_args()
+    return parser.parse_args()
+
+
+if __name__ == "__main__":
 
-    base64_audio = wav_to_base64(args.reference_audio)
+    args = parse_args()
 
-    ref_text = args.reference_text
-    if ref_text:
-        ref_text = read_ref_text(ref_text)
+    idstr: str | None = args.reference_id
+    # priority: ref_id > [{text, audio},...]
+ if idstr is None: + base64_audios = [ + audio_to_bytes(ref_audio) for ref_audio in args.reference_audio + ] + ref_texts = [read_ref_text(ref_text) for ref_text in args.reference_text] + else: + base64_audios = [] + ref_texts = [] + pass # in api.py data = { "text": args.text, - "reference_text": ref_text, - "reference_audio": base64_audio, + "references": [ + dict(text=ref_text, audio=ref_audio) + for ref_text, ref_audio in zip(ref_texts, base64_audios) + ], + "reference_id": idstr, + "normalize": args.normalize, + "format": args.format, + "mp3_bitrate": args.mp3_bitrate, + "opus_bitrate": args.opus_bitrate, "max_new_tokens": args.max_new_tokens, "chunk_length": args.chunk_length, "top_p": args.top_p, @@ -116,22 +140,20 @@ def play_audio(audio_content, format, channels, rate): "temperature": args.temperature, "speaker": args.speaker, "emotion": args.emotion, - "format": args.format, "streaming": args.streaming, } response = requests.post(args.url, json=data, stream=args.streaming) - audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format - if response.status_code == 200: if args.streaming: p = pyaudio.PyAudio() + audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format stream = p.open( format=audio_format, channels=args.channels, rate=args.rate, output=True ) - wf = wave.open("generated_audio.wav", "wb") + wf = wave.open(f"{args.output}.wav", "wb") wf.setnchannels(args.channels) wf.setsampwidth(p.get_sample_size(audio_format)) wf.setframerate(args.rate) @@ -153,12 +175,14 @@ def play_audio(audio_content, format, channels, rate): wf.close() else: audio_content = response.content - - with open("generated_audio.wav", "wb") as audio_file: + audio_path = f"{args.output}.{args.format}" + with open(audio_path, "wb") as audio_file: audio_file.write(audio_content) - play_audio(audio_content, audio_format, args.channels, args.rate) - print("Audio has been saved to 'generated_audio.wav'.") + audio = AudioSegment.from_file(audio_path, format=args.format) + if args.play: + play(audio) + print(f"Audio has been saved to '{audio_path}'.") else: print(f"Request failed with status code {response.status_code}") print(response.json()) diff --git a/tools/sensevoice/fun_asr.py b/tools/sensevoice/fun_asr.py index 02c15a59..6789316d 100644 --- a/tools/sensevoice/fun_asr.py +++ b/tools/sensevoice/fun_asr.py @@ -26,7 +26,7 @@ def uvr5_cli( output_folder: Path, audio_files: list[Path] | None = None, output_format: str = "flac", - model: str = "BS-Roformer-Viperx-1296.ckpt", + model: str = "BS-Roformer-Viperx-1297.ckpt", ): # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"] sepr = Separator(