keep up with official closed-source api (#513)
* keep up with official closed-source api

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* curl support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* avoid empty ref

* remove unused files

* api CHN normalize

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* ormsgpack support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] committed Sep 8, 2024
1 parent 26fa73a commit e9394c7
Showing 8 changed files with 251 additions and 255 deletions.
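
The new /v1/tts route (replacing /v1/invoke) expects an application/msgpack request body, which the MsgPackRequest class added to tools/api.py decodes with ormsgpack. A minimal client sketch under those assumptions: the field names follow ServeTTSRequest in the diff below and the address is the new 127.0.0.1:8080 default, while the use of the requests package, the chosen field values, and writing the reply straight to output.wav (non-streaming path) are assumptions of this example rather than part of the commit.

# Hypothetical client for the msgpack-based TTS endpoint; field names follow ServeTTSRequest.
import ormsgpack
import requests

payload = {
    "text": "Hello world.",
    "format": "wav",       # "wav", "pcm" or "mp3"
    "normalize": True,     # apply the CHN/EN text normalization added in this commit
    "reference_id": None,  # or an id whose clips live under references/<id>/
    "references": [],      # inline references: [{"audio": <bytes>, "text": <str>}, ...]
}

response = requests.post(
    "http://127.0.0.1:8080/v1/tts",
    data=ormsgpack.packb(payload),
    headers={"Content-Type": "application/msgpack"},
)
response.raise_for_status()

# On the default non-streaming path the response body is the generated audio.
with open("output.wav", "wb") as f:
    f.write(response.content)

Because msgpack handles binary natively, the inline references entries can carry raw audio bytes, whereas the old load_audio path accepted a base64 string instead.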
3 changes: 2 additions & 1 deletion .gitignore
@@ -15,6 +15,7 @@ filelists
/*.npy
/*.wav
/*.mp3
/*.lab
/results
/data
/.idea
@@ -25,6 +26,6 @@ asr-label*
/fishenv
/.locale
/demo-audios
ref_data*
/references
/example
/faster_whisper
215 changes: 96 additions & 119 deletions tools/api.py
@@ -9,16 +9,20 @@
from argparse import ArgumentParser
from http import HTTPStatus
from pathlib import Path
from typing import Annotated, Literal, Optional
from typing import Annotated, Any, Literal, Optional

import numpy as np
import ormsgpack
import pyrootutils
import soundfile as sf
import torch
import torchaudio
from baize.datastructures import ContentType
from kui.asgi import (
Body,
FactoryClass,
HTTPException,
HttpRequest,
HttpView,
JSONResponse,
Kui,
@@ -27,14 +31,16 @@
)
from kui.asgi.routing import MultimethodRoutes
from loguru import logger
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, conint

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

# from fish_speech.models.vqgan.lit_module import VQGAN
from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
from fish_speech.utils import autocast_exclude_mps
from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
from tools.llama.generate import (
GenerateRequest,
GenerateResponse,
@@ -82,11 +88,8 @@ async def other_exception_handler(exc: "Exception"):

def load_audio(reference_audio, sr):
if len(reference_audio) > 255 or not Path(reference_audio).exists():
try:
audio_data = base64.b64decode(reference_audio)
reference_audio = io.BytesIO(audio_data)
except base64.binascii.Error:
raise ValueError("Invalid path or base64 string")
audio_data = reference_audio
reference_audio = io.BytesIO(audio_data)

waveform, original_sr = torchaudio.load(
reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
@@ -153,56 +156,36 @@ def decode_vq_tokens(
routes = MultimethodRoutes(base_class=HttpView)


def get_random_paths(base_path, data, speaker, emotion):
if base_path and data and speaker and emotion and (Path(base_path).exists()):
if speaker in data and emotion in data[speaker]:
files = data[speaker][emotion]
lab_files = [f for f in files if f.endswith(".lab")]
wav_files = [f for f in files if f.endswith(".wav")]
class ServeReferenceAudio(BaseModel):
audio: bytes
text: str

if lab_files and wav_files:
selected_lab = random.choice(lab_files)
selected_wav = random.choice(wav_files)

lab_path = Path(base_path) / speaker / emotion / selected_lab
wav_path = Path(base_path) / speaker / emotion / selected_wav
if lab_path.exists() and wav_path.exists():
return lab_path, wav_path

return None, None


def load_json(json_file):
if not json_file:
logger.info("Not using a json file")
return None
try:
with open(json_file, "r", encoding="utf-8") as file:
data = json.load(file)
except FileNotFoundError:
logger.warning(f"ref json not found: {json_file}")
data = None
except Exception as e:
logger.warning(f"Loading json failed: {e}")
data = None
return data


class InvokeRequest(BaseModel):
class ServeTTSRequest(BaseModel):
text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
reference_text: Optional[str] = None
reference_audio: Optional[str] = None
chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
# Audio format
format: Literal["wav", "pcm", "mp3"] = "wav"
mp3_bitrate: Literal[64, 128, 192] = 128
# References audios for in-context learning
references: list[ServeReferenceAudio] = []
# Reference id
# For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
# Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
reference_id: str | None = None
# Normalize text for en & zh, this increase stability for numbers
normalize: bool = True
mp3_bitrate: Optional[int] = 64
opus_bitrate: Optional[int] = -1000
# Balance mode will reduce latency to 300ms, but may decrease stability
latency: Literal["normal", "balanced"] = "normal"
# not usually used below
streaming: bool = False
emotion: Optional[str] = None
max_new_tokens: int = 1024
chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
emotion: Optional[str] = None
format: Literal["wav", "mp3", "flac"] = "wav"
streaming: bool = False
ref_json: Optional[str] = "ref_data.json"
ref_base: Optional[str] = "ref_data"
speaker: Optional[str] = None


def get_content_type(audio_format):
@@ -217,35 +200,52 @@ def get_content_type(audio_format):


@torch.inference_mode()
def inference(req: InvokeRequest):
# Parse reference audio aka prompt
prompt_tokens = None

ref_data = load_json(req.ref_json)
ref_base = req.ref_base

lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)

if lab_path and wav_path:
with open(lab_path, "r", encoding="utf-8") as lab_file:
ref_text = lab_file.read()
req.reference_audio = wav_path
req.reference_text = ref_text
logger.info("ref_path: " + str(wav_path))
logger.info("ref_text: " + ref_text)

# Parse reference audio aka prompt
prompt_tokens = encode_reference(
decoder_model=decoder_model,
reference_audio=req.reference_audio,
enable_reference_audio=req.reference_audio is not None,
)
logger.info(f"ref_text: {req.reference_text}")
def inference(req: ServeTTSRequest):

idstr: str | None = req.reference_id
if idstr is not None:
ref_folder = Path("references") / idstr
ref_folder.mkdir(parents=True, exist_ok=True)
ref_audios = list_files(
ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
)
prompt_tokens = [
encode_reference(
decoder_model=decoder_model,
reference_audio=audio_to_bytes(str(ref_audio)),
enable_reference_audio=True,
)
for ref_audio in ref_audios
]
prompt_texts = [
read_ref_text(str(ref_audio.with_suffix(".lab")))
for ref_audio in ref_audios
]

else:
# Parse reference audio aka prompt
refs = req.references
if refs is None:
refs = []
prompt_tokens = [
encode_reference(
decoder_model=decoder_model,
reference_audio=ref.audio,
enable_reference_audio=True,
)
for ref in refs
]
prompt_texts = [ref.text for ref in refs]

# LLAMA Inference
request = dict(
device=decoder_model.device,
max_new_tokens=req.max_new_tokens,
text=req.text,
text=(
req.text
if not req.normalize
else ChnNormedText(raw_text=req.text).normalize()
),
top_p=req.top_p,
repetition_penalty=req.repetition_penalty,
temperature=req.temperature,
Expand All @@ -254,7 +254,7 @@ def inference(req: InvokeRequest):
chunk_length=req.chunk_length,
max_length=2048,
prompt_tokens=prompt_tokens,
prompt_text=req.reference_text,
prompt_text=prompt_texts,
)

response_queue = queue.Queue()
@@ -307,40 +307,7 @@ def inference(req: InvokeRequest):
yield fake_audios


def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
if not use_auto_rerank:
        # If auto_rerank is not used, call the original inference function directly
return inference(req)

zh_model, en_model = load_model()
max_attempts = 5
best_wer = float("inf")
best_audio = None

for attempt in range(max_attempts):
        # Call the original inference function
audio_generator = inference(req)
fake_audios = next(audio_generator)

asr_result = batch_asr(
zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
)[0]
wer = calculate_wer(req.text, asr_result["text"])

if wer <= 0.1 and not asr_result["huge_gap"]:
return fake_audios

if wer < best_wer:
best_wer = wer
best_audio = fake_audios

if attempt == max_attempts - 1:
break

return best_audio


async def inference_async(req: InvokeRequest):
async def inference_async(req: ServeTTSRequest):
for chunk in inference(req):
yield chunk

@@ -349,9 +316,9 @@ async def buffer_to_async_generator(buffer):
yield buffer


@routes.http.post("/v1/invoke")
@routes.http.post("/v1/tts")
async def api_invoke_model(
req: Annotated[InvokeRequest, Body(exclusive=True)],
req: Annotated[ServeTTSRequest, Body(exclusive=True)],
):
"""
Invoke model and generate audio
@@ -422,7 +389,7 @@ def parse_args():
parser.add_argument("--half", action="store_true")
parser.add_argument("--compile", action="store_true")
parser.add_argument("--max-text-length", type=int, default=0)
parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
parser.add_argument("--workers", type=int, default=1)
parser.add_argument("--use-auto-rerank", type=bool, default=True)

@@ -436,18 +403,30 @@ def parse_args():
},
).routes


class MsgPackRequest(HttpRequest):
async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
if self.content_type == "application/msgpack":
return ormsgpack.unpackb(await self.body)

raise HTTPException(
HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
headers={"Accept": "application/msgpack"},
)


app = Kui(
routes=routes + openapi[1:], # Remove the default route
exception_handlers={
HTTPException: http_execption_handler,
Exception: other_exception_handler,
},
factory_class=FactoryClass(http=MsgPackRequest),
cors_config={},
)


if __name__ == "__main__":
import threading

import uvicorn

@@ -474,18 +453,16 @@ def parse_args():
# Dry run to check if the model is loaded correctly and avoid the first-time latency
list(
inference(
InvokeRequest(
ServeTTSRequest(
text="Hello world.",
reference_text=None,
reference_audio=None,
references=[],
reference_id=None,
max_new_tokens=0,
top_p=0.7,
repetition_penalty=1.2,
temperature=0.7,
emotion=None,
format="wav",
ref_base=None,
ref_json=None,
)
)
)
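
In the new inference path above, setting reference_id makes the server look for prompt material under references/<reference_id>/: each audio file found there is encoded into prompt tokens, and a .lab file with the same stem, read through read_ref_text from tools/file.py (shown next), supplies the matching transcript text. A hypothetical layout that would be picked up, using the id from the code comment and made-up clip names:

references/
  7f92f8afb8ec43bf81429cc1c9199cb1/
    calm_01.wav
    calm_01.lab
    calm_02.wav
    calm_02.lab
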
17 changes: 17 additions & 0 deletions tools/file.py
@@ -1,3 +1,4 @@
import base64
from pathlib import Path
from typing import Union

@@ -23,6 +24,22 @@
}


def audio_to_bytes(file_path):
if not file_path or not Path(file_path).exists():
return None
with open(file_path, "rb") as wav_file:
wav = wav_file.read()
return wav


def read_ref_text(ref_text):
path = Path(ref_text)
if path.exists() and path.is_file():
with path.open("r", encoding="utf-8") as file:
return file.read()
return ref_text


def list_files(
path: Union[Path, str],
extensions: set[str] = None,
(Diffs for the remaining 5 changed files are not shown.)
