Optimize dp etc. (#407)
* Remove unused asr models.

* Fix ipynb

* webui.py ok

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove unused code

* Changed to faster whisper.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* unused

* Optimize sth.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Auto Labeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unused package

* Advice for learning with a small number of samples

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Recommendations refined

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
AnyaCoder and pre-commit-ci[bot] committed Jul 23, 2024
1 parent 7248b44 commit 979b0e5
Showing 4 changed files with 103 additions and 44 deletions.
43 changes: 25 additions & 18 deletions fish_speech/webui/manage.py
@@ -510,6 +510,10 @@ def train_process(
         )
     )
     logger.info(project)
+
+    if llama_check_interval > llama_maxsteps:
+        llama_check_interval = llama_maxsteps
+
     train_cmd = [
         PYTHON,
         "fish_speech/train.py",
@@ -800,7 +804,7 @@ def llama_quantify(llama_weight, quantify_mode)
                "Use LoRA can save GPU memory, but may reduce the quality of the model"
             ),
             value=True,
-            interactive=False,
+            interactive=True,
         )
         llama_ckpt = gr.Dropdown(
             label=i18n("Select LLAMA ckpt"),
@@ -816,19 +820,25 @@ def llama_quantify(llama_weight, quantify_mode)
     with gr.Row(equal_height=False):
         llama_lr_slider = gr.Slider(
             label=i18n("Initial Learning Rate"),
+            info=i18n(
+                "lr smaller -> usually train slower but more stable"
+            ),
             interactive=True,
             minimum=1e-5,
             maximum=1e-4,
             step=1e-5,
-            value=init_llama_yml["model"]["optimizer"]["lr"],
+            value=5e-5,
         )
         llama_maxsteps_slider = gr.Slider(
             label=i18n("Maximum Training Steps"),
+            info=i18n(
+                "recommend: max_steps = num_audios // batch_size * (2 to 5)"
+            ),
             interactive=True,
-            minimum=50,
+            minimum=1,
             maximum=10000,
-            step=50,
-            value=init_llama_yml["trainer"]["max_steps"],
+            step=1,
+            value=50,
         )
     with gr.Row(equal_height=False):
         llama_base_config = gr.Dropdown(
@@ -841,13 +851,9 @@ def llama_quantify(llama_weight, quantify_mode)
         llama_data_num_workers_slider = gr.Slider(
             label=i18n("Number of Workers"),
             minimum=1,
-            maximum=16,
+            maximum=32,
             step=1,
-            value=(
-                init_llama_yml["data"]["num_workers"]
-                if sys.platform == "linux"
-                else 1
-            ),
+            value=4,
         )
     with gr.Row(equal_height=False):
         llama_data_batch_size_slider = gr.Slider(
@@ -856,15 +862,15 @@ def llama_quantify(llama_weight, quantify_mode)
             minimum=1,
             maximum=32,
             step=1,
-            value=init_llama_yml["data"]["batch_size"],
+            value=4,
         )
         llama_data_max_length_slider = gr.Slider(
             label=i18n("Maximum Length per Sample"),
             interactive=True,
             minimum=1024,
             maximum=4096,
             step=128,
-            value=init_llama_yml["max_length"],
+            value=1024,
         )
     with gr.Row(equal_height=False):
         llama_precision_dropdown = gr.Dropdown(
@@ -878,13 +884,14 @@ def llama_quantify(llama_weight, quantify_mode)
         )
         llama_check_interval_slider = gr.Slider(
             label=i18n("Save model every n steps"),
+            info=i18n(
+                "make sure that it's not greater than max_steps"
+            ),
             interactive=True,
-            minimum=50,
+            minimum=1,
             maximum=1000,
-            step=50,
-            value=init_llama_yml["trainer"][
-                "val_check_interval"
-            ],
+            step=1,
+            value=50,
         )
     with gr.Row(equal_height=False):
         llama_grad_batches = gr.Slider(
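A note on the guard added to train_process() above: the "Save model every n steps" slider now allows values up to 1000 while the step budget can be as low as 1, so a save interval larger than max steps is clamped, otherwise a short run could finish without ever saving a checkpoint. A minimal sketch of the clamp (the helper name is illustrative, not from the repository):

def clamp_check_interval(llama_check_interval: int, llama_maxsteps: int) -> int:
    # A save interval beyond the step budget would never fire, so cap it
    # at the budget; this mirrors the two-line guard in the diff above.
    if llama_check_interval > llama_maxsteps:
        llama_check_interval = llama_maxsteps
    return llama_check_interval

assert clamp_check_interval(1000, 50) == 50    # short run: save at the end
assert clamp_check_interval(50, 10000) == 50   # long run: unchanged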
6 changes: 5 additions & 1 deletion start.bat
@@ -3,7 +3,11 @@ chcp 65001
 
 set USE_MIRROR=true
 set PYTHONPATH=%~dp0
-set PYTHON_CMD=%cd%\fishenv\env\python
+set PYTHON_CMD=python
+if exist "fishenv" (
+    set PYTHON_CMD=%cd%\fishenv\env\python
+)
+
 set API_FLAG_PATH=%~dp0API_FLAGS.txt
 set KMP_DUPLICATE_LIB_OK=TRUE
 
45 changes: 26 additions & 19 deletions tools/auto_rerank.py
@@ -40,13 +40,16 @@ def batch_asr_internal(model: WhisperModel, audios, sr)
         assert audio.dim() == 1
         audio_np = audio.numpy()
         resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
-        resampled_audios.append(torch.from_numpy(resampled_audio))
+        resampled_audios.append(resampled_audio)
 
     trans_results = []
 
     for resampled_audio in resampled_audios:
         segments, info = model.transcribe(
-            resampled_audio.numpy(), language=None, beam_size=5
+            resampled_audio,
+            language=None,
+            beam_size=5,
+            initial_prompt="Punctuation is needed in any language.",
         )
         trans_results.append(list(segments))
 
@@ -71,6 +74,7 @@ def batch_asr_internal(model: WhisperModel, audios, sr)
             last_tr = tr
         if max_gap > 3.0:
             huge_gap = True
+            break
 
     sim_text = t2s_converter.convert(text)
     results.append(
@@ -95,34 +99,37 @@ def is_chinese(text)
     return True
 
 
-def calculate_wer(text1, text2):
-    # Split the texts into character lists
+def calculate_wer(text1, text2, debug=False):
     chars1 = remove_punctuation(text1)
     chars2 = remove_punctuation(text2)
 
-    # Compute the edit distance
     m, n = len(chars1), len(chars2)
-    dp = [[0] * (n + 1) for _ in range(m + 1)]
 
-    for i in range(m + 1):
-        dp[i][0] = i
-    for j in range(n + 1):
-        dp[0][j] = j
+    if m > n:
+        chars1, chars2 = chars2, chars1
+        m, n = n, m
 
-    for i in range(1, m + 1):
-        for j in range(1, n + 1):
+    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
+    curr = [0] * (m + 1)
+
+    for j in range(1, n + 1):
+        curr[0] = j
+        for i in range(1, m + 1):
             if chars1[i - 1] == chars2[j - 1]:
-                dp[i][j] = dp[i - 1][j - 1]
+                curr[i] = prev[i - 1]
             else:
-                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
+                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
+        prev, curr = curr, prev
 
-    # WER
-    edits = dp[m][n]
+    edits = prev[m]
     tot = max(len(chars1), len(chars2))
     wer = edits / tot
-    print("  gt:  ", chars1)
-    print("  pred:", chars2)
-    print("  edits/tot = wer:", edits, "/", tot, "=", wer)
+
+    if debug:
+        print("  gt:  ", chars1)
+        print("  pred:", chars2)
+        print("  edits/tot = wer:", edits, "/", tot, "=", wer)
+
     return wer
 
 
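The calculate_wer() rewrite above is the "dp" optimization in the commit title: the full (m+1) x (n+1) edit-distance table is replaced by two rolling rows, and the shorter string is swapped onto the row axis, so memory drops from O(mn) to O(min(m, n)) while time stays O(mn). A self-contained sketch of the same recurrence, assuming plain strings as input (edit_distance is an illustrative name; the diff then divides the edit count by the longer length to get the WER):

def edit_distance(a: str, b: str) -> int:
    # Keep the shorter string on the row axis so the rows stay small.
    if len(a) > len(b):
        a, b = b, a
    m, n = len(a), len(b)

    prev = list(range(m + 1))  # distances against the empty prefix of b
    curr = [0] * (m + 1)

    for j in range(1, n + 1):
        curr[0] = j  # turning the empty prefix of a into b[:j]
        for i in range(1, m + 1):
            if a[i - 1] == b[j - 1]:
                curr[i] = prev[i - 1]  # match: no edit needed
            else:
                # substitution, deletion, insertion
                curr[i] = min(prev[i - 1], prev[i], curr[i - 1]) + 1
        prev, curr = curr, prev  # recycle the old row as scratch space

    return prev[m]  # after the final swap, prev holds the completed last row

assert edit_distance("kitten", "sitting") == 3

With that helper, wer = edit_distance(gt, pred) / max(len(gt), len(pred)), matching the edits / tot expression in the diff.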
53 changes: 47 additions & 6 deletions tools/webui.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 
 import gradio as gr
+import librosa
 import numpy as np
 import pyrootutils
 import torch
@@ -323,6 +324,23 @@ def change_if_load_asr_model(if_load)
     return gr.Checkbox(label="Load faster whisper model", value=if_load)
 
 
+def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
+    if if_load and asr_model is not None:
+        if (
+            if_auto_label
+            and enable_ref
+            and ref_audio is not None
+            and ref_text.strip() == ""
+        ):
+            data, sample_rate = librosa.load(ref_audio)
+            res = batch_asr(asr_model, [data], sample_rate)[0]
+            ref_text = res["text"]
+    else:
+        gr.Warning("Whisper model not loaded!")
+
+    return gr.Textbox(value=ref_text)
+
+
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -419,12 +437,19 @@ def build_app()
                 label=i18n("Reference Audio"),
                 type="filepath",
             )
-            reference_text = gr.Textbox(
-                label=i18n("Reference Text"),
-                placeholder=i18n("Reference Text"),
-                lines=1,
-                value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-            )
+            with gr.Row():
+                if_auto_label = gr.Checkbox(
+                    label=i18n("Auto Labeling"),
+                    min_width=100,
+                    scale=0,
+                    value=False,
+                )
+                reference_text = gr.Textbox(
+                    label=i18n("Reference Text"),
+                    lines=1,
+                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                    value="",
+                )
         with gr.Tab(label=i18n("Batch Inference")):
             batch_infer_num = gr.Slider(
                 label="Batch infer nums",
@@ -479,6 +504,22 @@ def build_app()
             outputs=[if_load_asr_model],
         )
 
+        if_auto_label.change(
+            fn=lambda: gr.Textbox(value=""),
+            inputs=[],
+            outputs=[reference_text],
+        ).then(
+            fn=change_if_auto_label,
+            inputs=[
+                if_load_asr_model,
+                if_auto_label,
+                enable_reference_audio,
+                reference_audio,
+                reference_text,
+            ],
+            outputs=[reference_text],
+        )
+
         # # Submit
         generate.click(
             inference_wrapper,
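The if_auto_label.change(...).then(...) wiring above uses Gradio's event chaining: the first callback clears the reference textbox immediately, and .then() runs change_if_auto_label, which may call the Whisper model, only after the clear has been applied, so the user never sees a stale label while transcription runs. A stripped-down sketch of the same pattern, with placeholder components and a stub in place of the ASR call:

import gradio as gr


def fake_transcribe(enabled):
    # Stand-in for the batch_asr() call in tools/webui.py.
    return gr.Textbox(value="transcribed reference text" if enabled else "")


with gr.Blocks() as demo:
    auto_label = gr.Checkbox(label="Auto Labeling", value=False)
    ref_text = gr.Textbox(label="Reference Text")

    # Stage 1 clears the box at once; stage 2 fills it in when the
    # slower callback finishes.
    auto_label.change(
        fn=lambda: gr.Textbox(value=""),
        inputs=[],
        outputs=[ref_text],
    ).then(
        fn=fake_transcribe,
        inputs=[auto_label],
        outputs=[ref_text],
    )

if __name__ == "__main__":
    demo.launch()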
