Fix around asr bugs (#401)
* Remove unused asr models.

* Fix ipynb

* webui.py ok

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

* Changed to faster whisper.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] committed Jul 20, 2024
1 parent dc4b107 commit eb35b0b
Showing 5 changed files with 129 additions and 84 deletions.
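At a glance, the commit replaces the paired funasr Paraformer models (zh/en) with a single faster-whisper model. A minimal sketch of the new usage pattern as it appears in tools/auto_rerank.py below (the audio path is a placeholder for illustration):

```python
from faster_whisper import WhisperModel

# One multilingual Whisper model replaces the separate zh/en Paraformer models.
model = WhisperModel(
    "medium",
    device="cuda",
    compute_type="float16",
    download_root="faster_whisper",
)

# language=None asks faster-whisper to auto-detect the language.
segments, info = model.transcribe("sample.wav", language=None, beam_size=5)
text = "".join(segment.text.strip() for segment in segments)
```

Because `transcribe` returns a lazy generator of segments, the diff below materializes it with `list(segments)` before reusing the results.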
10 changes: 7 additions & 3 deletions inference.ipynb
@@ -76,7 +76,11 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "vscode": {
+     "languageId": "shellscript"
+    }
+   },
    "outputs": [],
    "source": [
     "!python tools/webui.py \\\n",
@@ -114,7 +118,7 @@
"outputs": [],
"source": [
"## Enter the path to the audio file here\n",
"src_audio = r\"D:\\PythonProject\\\\vo_hutao_draw_appear.wav\"\n",
"src_audio = r\"D:\\PythonProject\\vo_hutao_draw_appear.wav\"\n",
"\n",
"!python tools/vqgan/inference.py \\\n",
" -i {src_audio} \\\n",
@@ -163,7 +167,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Generate speecj from semantic tokens: / 从语义 token 生成人声:"
"### 3. Generate speech from semantic tokens: / 从语义 token 生成人声:"
]
},
{
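One subtlety in the notebook fix above: in a Python raw string every backslash is literal, so the doubled `\\` in the old path embedded two backslashes in the string. A quick check (illustrative paths only):

```python
# In raw strings, backslashes are not escape characters, so r"\\" is two characters.
old = r"D:\PythonProject\\vo_hutao_draw_appear.wav"  # contains a doubled backslash
new = r"D:\PythonProject\vo_hutao_draw_appear.wav"   # single separator, as intended

print(old)  # D:\PythonProject\\vo_hutao_draw_appear.wav
print(new)  # D:\PythonProject\vo_hutao_draw_appear.wav
```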
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -39,7 +39,8 @@ dependencies = [
"pydub",
"faster_whisper",
"modelscope==1.16.1",
"funasr==1.1.2"
"funasr==1.1.2",
"opencc-python-reimplemented==0.1.7"
]

[project.optional-dependencies]
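The new `opencc-python-reimplemented` pin supports the Traditional-to-Simplified conversion that tools/auto_rerank.py now applies to ASR output before scoring. A standalone sketch of that call (the example string is illustrative):

```python
import opencc

# "t2s" = Traditional-to-Simplified, matching t2s_converter in tools/auto_rerank.py.
converter = opencc.OpenCC("t2s")
print(converter.convert("語音識別"))  # -> 语音识别
```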
122 changes: 74 additions & 48 deletions tools/auto_rerank.py
@@ -1,71 +1,81 @@
-import time
+import os
+
+os.environ["MODELSCOPE_CACHE"] = ".cache/"
+
+import string
+import time
+from threading import Lock

+import librosa
 import numpy as np
+import opencc
 import torch
 import torchaudio
-from funasr import AutoModel
-from funasr.models.seaco_paraformer.model import SeacoParaformer
+from faster_whisper import WhisperModel

-# Monkey patching to disable hotwords
-SeacoParaformer.generate_hotwords_list = lambda self, *args, **kwargs: None
+t2s_converter = opencc.OpenCC("t2s")


 def load_model(*, device="cuda"):
-    zh_model = AutoModel(
-        model="paraformer-zh",
-        device=device,
-        disable_pbar=True,
-    )
-    en_model = AutoModel(
-        model="paraformer-en",
+    model = WhisperModel(
+        "medium",
         device=device,
-        disable_pbar=True,
+        compute_type="float16",
+        download_root="faster_whisper",
     )

-    return zh_model, en_model
+    print("faster_whisper loaded!")
+    return model


 @torch.no_grad()
-def batch_asr_internal(model, audios, sr):
+def batch_asr_internal(model: WhisperModel, audios, sr):
     resampled_audios = []
     for audio in audios:
-        # Convert the NumPy array to a PyTorch tensor
+
         if isinstance(audio, np.ndarray):
             audio = torch.from_numpy(audio).float()

         # Ensure the audio is one-dimensional
         if audio.dim() > 1:
             audio = audio.squeeze()

-        audio = torchaudio.functional.resample(audio, sr, 16000)
         assert audio.dim() == 1
-        resampled_audios.append(audio)
+        audio_np = audio.numpy()
+        resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
+        resampled_audios.append(torch.from_numpy(resampled_audio))

-    res = model.generate(input=resampled_audios, batch_size=len(resampled_audios))
+    trans_results = []
+
+    for resampled_audio in resampled_audios:
+        segments, info = model.transcribe(
+            resampled_audio.numpy(), language=None, beam_size=5
+        )
+        trans_results.append(list(segments))

     results = []
-    for r, audio in zip(res, audios):
-        text = r["text"]
+    for trans_res, audio in zip(trans_results, audios):
+
         duration = len(audio) / sr * 1000
         huge_gap = False
+        max_gap = 0.0

+        text = None
+        last_tr = None
+
+        for tr in trans_res:
+            delta = tr.text.strip()
+            if tr.id > 1:
+                max_gap = max(tr.start - last_tr.end, max_gap)
+                text += delta
+            else:
+                text = delta
+
-        if "timestamp" in r and len(r["timestamp"]) > 2:
-            for timestamp_a, timestamp_b in zip(
-                r["timestamp"][:-1], r["timestamp"][1:]
-            ):
-                # If there is a gap of more than 5 seconds, we consider it as a huge gap
-                if timestamp_b[0] - timestamp_a[1] > 5000:
-                    huge_gap = True
-                    break

-            # Doesn't make sense to have a huge gap at the end
-            if duration - r["timestamp"][-1][1] > 3000:
+            last_tr = tr
+        if max_gap > 3.0:
             huge_gap = True

+        sim_text = t2s_converter.convert(text)
         results.append(
             {
-                "text": text,
+                "text": sim_text,
                 "duration": duration,
                 "huge_gap": huge_gap,
             }
@@ -86,11 +96,12 @@ def is_chinese(text):


 def calculate_wer(text1, text2):
-    words1 = text1.split()
-    words2 = text2.split()
+    # Split the texts into character lists
+    chars1 = remove_punctuation(text1)
+    chars2 = remove_punctuation(text2)

     # Compute the edit distance
-    m, n = len(words1), len(words2)
+    m, n = len(chars1), len(chars2)
     dp = [[0] * (n + 1) for _ in range(m + 1)]

     for i in range(m + 1):
@@ -100,27 +111,42 @@ def calculate_wer(text1, text2):

     for i in range(1, m + 1):
         for j in range(1, n + 1):
-            if words1[i - 1] == words2[j - 1]:
+            if chars1[i - 1] == chars2[j - 1]:
                 dp[i][j] = dp[i - 1][j - 1]
             else:
                 dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1

-    # Compute WER
+    # WER
     edits = dp[m][n]
-    wer = edits / len(words1)
-
+    tot = max(len(chars1), len(chars2))
+    wer = edits / tot
+    print(" gt: ", chars1)
+    print(" pred: ", chars2)
+    print(" edits/tot = wer: ", edits, "/", tot, "=", wer)
     return wer


+def remove_punctuation(text):
+    chinese_punctuation = (
+        " \n\t”“！？。。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—"
+        "‘’‛“”„‟…‧﹏"
+    )
+    all_punctuation = string.punctuation + chinese_punctuation
+    translator = str.maketrans("", "", all_punctuation)
+    text_without_punctuation = text.translate(translator)
+    return text_without_punctuation


if __name__ == "__main__":
zh_model, en_model = load_model()
model = load_model()
audios = [
torchaudio.load("lengyue.wav")[0][0],
torchaudio.load("lengyue.wav")[0][0, : 44100 * 5],
librosa.load("44100.wav", sr=44100)[0],
librosa.load("lengyue.wav", sr=44100)[0],
]
print(batch_asr(zh_model, audios, 44100))
print(np.array(audios[0]))
print(batch_asr(model, audios, 44100))

start_time = time.time()
for _ in range(10):
batch_asr(zh_model, audios, 44100)
print(batch_asr(model, audios, 44100))
print("Time taken:", time.time() - start_time)
10 changes: 1 addition & 9 deletions tools/download_models.py
@@ -34,14 +34,6 @@ def check_and_download_files(repo_id, file_list, local_dir):
"firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
]

# 2nd
repo_id_2 = "SpicyqSama007/fish-speech-packed"
local_dir_2 = ".cache/whisper"
files_2 = [
"medium.pt",
"small.pt",
]

# 3rd
repo_id_3 = "fishaudio/fish-speech-1"
local_dir_3 = "./"
@@ -58,6 +50,6 @@ def check_and_download_files(repo_id, file_list, local_dir):
 ]

 check_and_download_files(repo_id_1, files_1, local_dir_1)
-check_and_download_files(repo_id_2, files_2, local_dir_2)
+
 check_and_download_files(repo_id_3, files_3, local_dir_3)
 check_and_download_files(repo_id_4, files_4, local_dir_4)
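Dropping the packed medium.pt/small.pt downloads works because faster-whisper fetches its own CTranslate2-converted weights on first use. A sketch of the behavior this relies on (the CPU settings here are illustrative):

```python
from faster_whisper import WhisperModel

# Constructing the model by size name triggers an automatic download of the
# converted weights into download_root, so no manual .pt files are needed.
model = WhisperModel("medium", device="cpu", compute_type="int8",
                     download_root="faster_whisper")
```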
68 changes: 45 additions & 23 deletions tools/webui.py
@@ -173,25 +173,11 @@ def inference_with_auto_rerank(
     top_p,
     repetition_penalty,
     temperature,
-    use_auto_rerank,
     streaming=False,
+    use_auto_rerank=True,
 ):
-    if not use_auto_rerank:
-        return inference(
-            text,
-            enable_reference_audio,
-            reference_audio,
-            reference_text,
-            max_new_tokens,
-            chunk_length,
-            top_p,
-            repetition_penalty,
-            temperature,
-            streaming,
-        )

-    zh_model, en_model = load_model()
-    max_attempts = 2
+    max_attempts = 2 if use_auto_rerank else 1
     best_wer = float("inf")
     best_audio = None
     best_sample_rate = None
@@ -218,11 +204,11 @@ def inference_with_auto_rerank(
         if audio is None:
             return None, None, message

-        asr_result = batch_asr(
-            zh_model if is_chinese(text) else en_model, [audio], sample_rate
-        )[0]
-        wer = calculate_wer(text, asr_result["text"])
+        if not use_auto_rerank:
+            return None, (sample_rate, audio), None

+        asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
+        wer = calculate_wer(text, asr_result["text"])
         if wer <= 0.3 and not asr_result["huge_gap"]:
             return None, (sample_rate, audio), None

@@ -237,7 +223,7 @@ def inference_with_auto_rerank(
     return None, (best_sample_rate, best_audio), None


-inference_stream = partial(inference_with_auto_rerank, streaming=True)
+inference_stream = partial(inference, streaming=True)

 n_audios = 4

@@ -256,6 +242,7 @@ def inference_wrapper(
     repetition_penalty,
     temperature,
     batch_infer_num,
+    if_load_asr_model,
 ):
     audios = []
     errors = []
@@ -271,6 +258,7 @@ def inference_wrapper(
             top_p,
             repetition_penalty,
             temperature,
+            if_load_asr_model,
         )

         _, audio_data, error_message = result
@@ -313,6 +301,28 @@ def normalize_text(user_input, use_normalization):
     return user_input


+asr_model = None
+
+
+def change_if_load_asr_model(if_load):
+    global asr_model
+
+    if if_load:
+        gr.Warning("Loading faster whisper model...")
+        if asr_model is None:
+            asr_model = load_model()
+        return gr.Checkbox(label="Unload faster whisper model", value=if_load)
+
+    if if_load is False:
+        gr.Warning("Unloading faster whisper model...")
+        del asr_model
+        asr_model = None
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+        gc.collect()
+        return gr.Checkbox(label="Load faster whisper model", value=if_load)
+
+
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -344,8 +354,13 @@ def build_app():
                 if_refine_text = gr.Checkbox(
                     label=i18n("Text Normalization"),
                     value=True,
-                    scale=0,
-                    min_width=150,
+                    scale=1,
                 )
+
+                if_load_asr_model = gr.Checkbox(
+                    label=i18n("Load / Unload ASR model for auto-reranking"),
+                    value=False,
+                    scale=3,
+                )

             with gr.Row():
@@ -458,6 +473,12 @@ def build_app():
         fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
     )

+    if_load_asr_model.change(
+        fn=change_if_load_asr_model,
+        inputs=[if_load_asr_model],
+        outputs=[if_load_asr_model],
+    )
+
     # # Submit
     generate.click(
         inference_wrapper,
@@ -472,6 +493,7 @@ def build_app():
             repetition_penalty,
             temperature,
             batch_infer_num,
+            if_load_asr_model,
         ],
         [stream_audio, *global_audio_list, *global_error_list],
         concurrency_limit=1,
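The webui.py changes boil down to a checkbox that lazily loads the ASR model and frees GPU memory when unticked. The pattern in miniature (a standalone sketch with a placeholder loader, not the repo's exact code):

```python
import gc

import gradio as gr
import torch

asr_model = None  # module-level cache, as in tools/webui.py


def toggle_asr_model(load):
    """Load the model on demand; drop it and reclaim VRAM when unchecked."""
    global asr_model
    if load:
        if asr_model is None:
            asr_model = object()  # placeholder for the real load_model()
        return gr.Checkbox(label="Unload ASR model", value=load)
    asr_model = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return gr.Checkbox(label="Load ASR model", value=load)


with gr.Blocks() as demo:
    box = gr.Checkbox(label="Load ASR model", value=False)
    # Returning a Checkbox from the handler relabels the toggle in place.
    box.change(fn=toggle_asr_model, inputs=[box], outputs=[box])

demo.launch()
```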
