From 979b0e5523e324105e7fa3879e07e73a8f736199 Mon Sep 17 00:00:00 2001
From: spicysama
Date: Tue, 23 Jul 2024 16:03:47 +0800
Subject: [PATCH] Optimize dp etc. (#407)

* Remove unused asr models.

* Fix ipynb

* webui.py ok

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* remove unused code

* Changed to faster whisper.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* unused

* Optimize sth.

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Auto Labeling

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Remove unused package

* Advice for learning with a small number of samples

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Recommendations refined

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 fish_speech/webui/manage.py | 43 +++++++++++++++++------------
 start.bat                   |  6 ++++-
 tools/auto_rerank.py        | 45 ++++++++++++++++++-------------
 tools/webui.py              | 53 ++++++++++++++++++++++++++++-----
 4 files changed, 103 insertions(+), 44 deletions(-)

diff --git a/fish_speech/webui/manage.py b/fish_speech/webui/manage.py
index d66e72f7..9c183acd 100644
--- a/fish_speech/webui/manage.py
+++ b/fish_speech/webui/manage.py
@@ -510,6 +510,10 @@ def train_process(
         )
     )
     logger.info(project)
+
+    if llama_check_interval > llama_maxsteps:
+        llama_check_interval = llama_maxsteps
+
     train_cmd = [
         PYTHON,
         "fish_speech/train.py",
@@ -800,7 +804,7 @@ def llama_quantify(llama_weight, quantify_mode):
                                 "Use LoRA can save GPU memory, but may reduce the quality of the model"
                             ),
                             value=True,
-                            interactive=False,
+                            interactive=True,
                         )
                         llama_ckpt = gr.Dropdown(
                             label=i18n("Select LLAMA ckpt"),
@@ -816,19 +820,25 @@ def llama_quantify(llama_weight, quantify_mode):
                     with gr.Row(equal_height=False):
                         llama_lr_slider = gr.Slider(
                             label=i18n("Initial Learning Rate"),
+                            info=i18n(
+                                "lr smaller -> usually train slower but more stable"
+                            ),
                             interactive=True,
                             minimum=1e-5,
                             maximum=1e-4,
                             step=1e-5,
-                            value=init_llama_yml["model"]["optimizer"]["lr"],
+                            value=5e-5,
                         )
                         llama_maxsteps_slider = gr.Slider(
                             label=i18n("Maximum Training Steps"),
+                            info=i18n(
+                                "recommend: max_steps = num_audios // batch_size * (2 to 5)"
+                            ),
                             interactive=True,
-                            minimum=50,
+                            minimum=1,
                             maximum=10000,
-                            step=50,
-                            value=init_llama_yml["trainer"]["max_steps"],
+                            step=1,
+                            value=50,
                         )
                     with gr.Row(equal_height=False):
                         llama_base_config = gr.Dropdown(
@@ -841,13 +851,9 @@ def llama_quantify(llama_weight, quantify_mode):
                         llama_data_num_workers_slider = gr.Slider(
                             label=i18n("Number of Workers"),
                             minimum=1,
-                            maximum=16,
+                            maximum=32,
                             step=1,
-                            value=(
-                                init_llama_yml["data"]["num_workers"]
-                                if sys.platform == "linux"
-                                else 1
-                            ),
+                            value=4,
                         )
                     with gr.Row(equal_height=False):
                         llama_data_batch_size_slider = gr.Slider(
@@ -856,7 +862,7 @@ def llama_quantify(llama_weight, quantify_mode):
                             minimum=1,
                             maximum=32,
                             step=1,
-                            value=init_llama_yml["data"]["batch_size"],
+                            value=4,
                         )
                         llama_data_max_length_slider = gr.Slider(
                             label=i18n("Maximum Length per Sample"),
@@ -864,7 +870,7 @@ def llama_quantify(llama_weight, quantify_mode):
                             minimum=1024,
                             maximum=4096,
                             step=128,
-                            value=init_llama_yml["max_length"],
+                            value=1024,
                         )
                     with gr.Row(equal_height=False):
                         llama_precision_dropdown = gr.Dropdown(
@@ -878,13 +884,14 @@ def llama_quantify(llama_weight, quantify_mode):
                         )
                         llama_check_interval_slider = gr.Slider(
                             label=i18n("Save model every n steps"),
+                            info=i18n(
+                                "make sure that it's not greater than max_steps"
+                            ),
                             interactive=True,
-                            minimum=50,
+                            minimum=1,
                             maximum=1000,
-                            step=50,
-                            value=init_llama_yml["trainer"][
-                                "val_check_interval"
-                            ],
+                            step=1,
+                            value=50,
                         )
                     with gr.Row(equal_height=False):
                         llama_grad_batches = gr.Slider(
diff --git a/start.bat b/start.bat
index f3b58a6a..05565642 100644
--- a/start.bat
+++ b/start.bat
@@ -3,7 +3,11 @@ chcp 65001
 
 set USE_MIRROR=true
 set PYTHONPATH=%~dp0
-set PYTHON_CMD=%cd%\fishenv\env\python
+set PYTHON_CMD=python
+if exist "fishenv" (
+    set PYTHON_CMD=%cd%\fishenv\env\python
+)
+
 set API_FLAG_PATH=%~dp0API_FLAGS.txt
 set KMP_DUPLICATE_LIB_OK=TRUE
 
diff --git a/tools/auto_rerank.py b/tools/auto_rerank.py
index 04346e90..0297d63d 100644
--- a/tools/auto_rerank.py
+++ b/tools/auto_rerank.py
@@ -40,13 +40,16 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
         assert audio.dim() == 1
         audio_np = audio.numpy()
         resampled_audio = librosa.resample(audio_np, orig_sr=sr, target_sr=16000)
-        resampled_audios.append(torch.from_numpy(resampled_audio))
+        resampled_audios.append(resampled_audio)
 
     trans_results = []
 
     for resampled_audio in resampled_audios:
         segments, info = model.transcribe(
-            resampled_audio.numpy(), language=None, beam_size=5
+            resampled_audio,
+            language=None,
+            beam_size=5,
+            initial_prompt="Punctuation is needed in any language.",
         )
         trans_results.append(list(segments))
 
@@ -71,6 +74,7 @@ def batch_asr_internal(model: WhisperModel, audios, sr):
             last_tr = tr
             if max_gap > 3.0:
                 huge_gap = True
+                break
 
         sim_text = t2s_converter.convert(text)
         results.append(
@@ -95,34 +99,37 @@ def is_chinese(text):
     return True
 
 
-def calculate_wer(text1, text2):
-    # 将文本分割成字符列表
+def calculate_wer(text1, text2, debug=False):
     chars1 = remove_punctuation(text1)
     chars2 = remove_punctuation(text2)
 
-    # 计算编辑距离
     m, n = len(chars1), len(chars2)
-    dp = [[0] * (n + 1) for _ in range(m + 1)]
-    for i in range(m + 1):
-        dp[i][0] = i
-    for j in range(n + 1):
-        dp[0][j] = j
+    if m > n:
+        chars1, chars2 = chars2, chars1
+        m, n = n, m
 
-    for i in range(1, m + 1):
-        for j in range(1, n + 1):
+    prev = list(range(m + 1))  # row 0 distance: [0, 1, 2, ...]
+    curr = [0] * (m + 1)
+
+    for j in range(1, n + 1):
+        curr[0] = j
+        for i in range(1, m + 1):
             if chars1[i - 1] == chars2[j - 1]:
-                dp[i][j] = dp[i - 1][j - 1]
+                curr[i] = prev[i - 1]
             else:
-                dp[i][j] = min(dp[i - 1][j], dp[i][j - 1], dp[i - 1][j - 1]) + 1
+                curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
+        prev, curr = curr, prev
 
-    # WER
-    edits = dp[m][n]
+    edits = prev[m]
    tot = max(len(chars1), len(chars2))
     wer = edits / tot
-    print("  gt: ", chars1)
-    print("  pred: ", chars2)
-    print("  edits/tot = wer: ", edits, "/", tot, "=", wer)
+
+    if debug:
+        print("  gt: ", chars1)
+        print("  pred: ", chars2)
+        print("  edits/tot = wer: ", edits, "/", tot, "=", wer)
+
     return wer
diff --git a/tools/webui.py b/tools/webui.py
index 6afb8633..9e01233a 100644
--- a/tools/webui.py
+++ b/tools/webui.py
@@ -9,6 +9,7 @@
 from pathlib import Path
 
 import gradio as gr
+import librosa
 import numpy as np
 import pyrootutils
 import torch
@@ -323,6 +324,23 @@ def change_if_load_asr_model(if_load):
     return gr.Checkbox(label="Load faster whisper model", value=if_load)
 
 
+def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
+    if if_load and asr_model is not None:
+        if (
+            if_auto_label
+            and enable_ref
+            and ref_audio is not None
+            and ref_text.strip() == ""
+        ):
+            data, sample_rate = librosa.load(ref_audio)
+            res = batch_asr(asr_model, [data], sample_rate)[0]
+            ref_text = res["text"]
+    else:
+        gr.Warning("Whisper model not loaded!")
+
+    return gr.Textbox(value=ref_text)
+
+
 def build_app():
     with gr.Blocks(theme=gr.themes.Base()) as app:
         gr.Markdown(HEADER_MD)
@@ -419,12 +437,19 @@ def build_app():
                             reference_audio = gr.Audio(
                                 label=i18n("Reference Audio"),
                                 type="filepath",
                             )
-                            reference_text = gr.Textbox(
-                                label=i18n("Reference Text"),
-                                placeholder=i18n("Reference Text"),
-                                lines=1,
-                                value="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
-                            )
+                            with gr.Row():
+                                if_auto_label = gr.Checkbox(
+                                    label=i18n("Auto Labeling"),
+                                    min_width=100,
+                                    scale=0,
+                                    value=False,
+                                )
+                                reference_text = gr.Textbox(
+                                    label=i18n("Reference Text"),
+                                    lines=1,
+                                    placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
+                                    value="",
+                                )
                         with gr.Tab(label=i18n("Batch Inference")):
                             batch_infer_num = gr.Slider(
@@ -479,6 +504,22 @@ def build_app():
             outputs=[if_load_asr_model],
         )
 
+        if_auto_label.change(
+            fn=lambda: gr.Textbox(value=""),
+            inputs=[],
+            outputs=[reference_text],
+        ).then(
+            fn=change_if_auto_label,
+            inputs=[
+                if_load_asr_model,
+                if_auto_label,
+                enable_reference_audio,
+                reference_audio,
+                reference_text,
+            ],
+            outputs=[reference_text],
+        )
+
         # # Submit
         generate.click(
             inference_wrapper,
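
Note on the calculate_wer rewrite in tools/auto_rerank.py: the old code built the full (m+1) x (n+1) edit-distance table, which costs O(m*n) memory; the new code keeps only two rolling rows (prev and curr) and swaps the inputs so the rows are sized by the shorter string, cutting memory to O(min(m, n)) while leaving the recurrence and the reported WER (edits divided by the longer string's length) unchanged. The standalone sketch below restates the same two-row technique; the function name levenshtein_two_rows and the self-test values are illustrative only and are not part of this patch.

    def levenshtein_two_rows(a: str, b: str) -> int:
        """Levenshtein distance with two rolling rows instead of a full table.

        Same recurrence as the full-table version:
            d[i][j] = d[i-1][j-1]                     if a[i-1] == b[j-1]
            d[i][j] = 1 + min(d[i-1][j],    # deletion
                              d[i][j-1],    # insertion
                              d[i-1][j-1])  # substitution
        """
        m, n = len(a), len(b)
        if m > n:  # make `a` the shorter string so the rows stay small
            a, b = b, a
            m, n = n, m

        prev = list(range(m + 1))  # row 0: distance from the empty prefix
        curr = [0] * (m + 1)

        for j in range(1, n + 1):
            curr[0] = j  # column 0: j edits to reach b[:j] from ""
            for i in range(1, m + 1):
                if a[i - 1] == b[j - 1]:
                    curr[i] = prev[i - 1]
                else:
                    curr[i] = min(prev[i], curr[i - 1], prev[i - 1]) + 1
            prev, curr = curr, prev  # reuse buffers instead of reallocating

        return prev[m]

    # "kitten" -> "sitting" takes 3 edits: k->s, e->i, insert g.
    assert levenshtein_two_rows("kitten", "sitting") == 3
    assert levenshtein_two_rows("", "abc") == 3
    assert levenshtein_two_rows("same", "same") == 0

Because edit distance is symmetric, swapping the inputs changes neither the distance nor the WER derived from it; it only bounds the working set by the shorter of the two transcripts.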