Skip to content

Commit

Permalink
Fix Preprocess Bugs (#154)
Browse files Browse the repository at this point in the history
* Fix button height

* Streaming support

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Convert to 1 channel

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix Conversion bug

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix target path

* Add checkpoint selection

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix gpup decorator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
AnyaCoder and pre-commit-ci[bot] committed May 5, 2024
1 parent f473a75 commit 2711f5d
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 18 deletions.
101 changes: 84 additions & 17 deletions fish_speech/webui/manage.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,11 +255,11 @@ def show_selected(options):
from pydub import AudioSegment


def convert_to_mono_in_place(audio_path: Path):
    """Down-mix a multi-channel audio file to mono, overwriting it in place.

    Files that are already mono are left untouched (no re-encode). The export
    format is taken from the file's own suffix so wav/flac inputs are not
    silently re-encoded as mp3.

    Args:
        audio_path: Path to an audio file readable by pydub/ffmpeg.
    """
    audio = AudioSegment.from_file(audio_path)
    if audio.channels > 1:
        mono_audio = audio.set_channels(1)
        # suffix includes the leading dot; pydub/ffmpeg format names do not.
        # lower() guards against uppercase extensions such as ".WAV".
        mono_audio.export(audio_path, format=audio_path.suffix[1:].lower())
        logger.info(f"Convert {audio_path} successfully")


Expand All @@ -277,12 +277,11 @@ def list_copy(list_file_path, method):
if target_wav_path.is_file():
continue
target_wav_path.parent.mkdir(parents=True, exist_ok=True)
convert_to_mono_in_place(original_wav_path)
if method == i18n("Copy"):
shutil.copy(original_wav_path, target_wav_path)
else:
shutil.move(original_wav_path, target_wav_path.parent)

convert_to_mono_in_place(target_wav_path)
original_lab_path = original_wav_path.with_suffix(".lab")
target_lab_path = (
wav_root
Expand Down Expand Up @@ -312,8 +311,16 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
tar_path = data_path / item_path.name

if content["type"] == "folder" and item_path.is_dir():
if content["method"] == i18n("Copy"):
os.makedirs(tar_path, exist_ok=True)
shutil.copytree(
src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
)
elif not tar_path.is_dir():
shutil.move(src=str(item_path), dst=str(tar_path))

for suf in ["wav", "flac", "mp3"]:
for audio_path in item_path.glob(f"**/*.{suf}"):
for audio_path in tar_path.glob(f"**/*.{suf}"):
convert_to_mono_in_place(audio_path)

cur_lang = content["label_lang"]
Expand All @@ -328,9 +335,9 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
"--device",
label_device,
"--audio-dir",
item_path,
tar_path,
"--save-dir",
item_path,
tar_path,
"--language",
cur_lang,
],
Expand All @@ -339,14 +346,6 @@ def check_files(data_path: str, max_depth: int, label_model: str, label_device:
except Exception:
print("Transcription error occurred")

if content["method"] == i18n("Copy"):
os.makedirs(tar_path, exist_ok=True)
shutil.copytree(
src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
)
elif not tar_path.is_dir():
shutil.move(src=str(item_path), dst=str(tar_path))

elif content["type"] == "file" and item_path.is_file():
list_copy(item_path, content["method"])

Expand All @@ -359,6 +358,7 @@ def train_process(
data_path: str,
option: str,
# vq-gan config
vqgan_ckpt,
vqgan_lr,
vqgan_maxsteps,
vqgan_data_num_workers,
Expand All @@ -367,6 +367,7 @@ def train_process(
vqgan_precision,
vqgan_check_interval,
# llama config
llama_ckpt,
llama_base_config,
llama_lr,
llama_maxsteps,
Expand Down Expand Up @@ -400,12 +401,29 @@ def generate_folder_name():
str(data_pre_output.relative_to(cur_work_dir)),
]
)
latest = list(
sorted(
[
str(p.relative_to("results"))
for p in Path("results").glob("vqgan_*/")
],
reverse=True,
)
)[0]
project = (
("vqgan_" + new_project)
if vqgan_ckpt == "new"
else latest
if vqgan_ckpt == "latest"
else vqgan_ckpt
)
logger.info(project)
train_cmd = [
PYTHON,
"fish_speech/train.py",
"--config-name",
"vqgan_finetune",
f"project={'vqgan_' + new_project}",
f"project={project}",
f"trainer.strategy.process_group_backend={backend}",
f"model.optimizer.lr={vqgan_lr}",
f"trainer.max_steps={vqgan_maxsteps}",
Expand Down Expand Up @@ -454,12 +472,30 @@ def generate_folder_name():
if llama_base_config == "dual_ar_2_codebook_medium"
else "text2semantic-sft-large-v1-4k.pth"
)

latest = list(
sorted(
[
str(p.relative_to("results"))
for p in Path("results").glob("text2sem*/")
],
reverse=True,
)
)[0]
project = (
("text2semantic_" + new_project)
if llama_ckpt == "new"
else latest
if llama_ckpt == "latest"
else llama_ckpt
)
logger.info(project)
train_cmd = [
PYTHON,
"fish_speech/train.py",
"--config-name",
"text2semantic_finetune",
f"project={'text2semantic_' + new_project}",
f"project={project}",
f"ckpt_path=checkpoints/{ckpt_path}",
f"trainer.strategy.process_group_backend={backend}",
f"[email protected]={llama_base_config}",
Expand Down Expand Up @@ -530,6 +566,18 @@ def fresh_vqgan_model():
)


def fresh_vqgan_ckpt():
    """Rebuild the VQGAN checkpoint dropdown choices.

    Choices are "latest", "new", plus every existing ``results/vqgan_*`` run
    directory. Directories are listed relative to ``results/`` so the selected
    value matches what the "latest" branch of ``train_process`` resolves (it
    uses ``p.relative_to("results")``); a full ``results/...`` string would
    otherwise end up as a nested ``project=`` path.
    """
    run_dirs = [
        str(p.relative_to("results")) for p in Path("results").glob("vqgan_*/")
    ]
    return gr.Dropdown(choices=["latest", "new"] + run_dirs)


def fresh_llama_ckpt():
    """Rebuild the LLAMA checkpoint dropdown choices.

    Choices are "latest", "new", plus every existing ``results/text2sem*`` run
    directory. Directories are listed relative to ``results/`` so the selected
    value matches what the "latest" branch of ``train_process`` resolves (it
    uses ``p.relative_to("results")``); a full ``results/...`` string would
    otherwise end up as a nested ``project=`` path.
    """
    run_dirs = [
        str(p.relative_to("results")) for p in Path("results").glob("text2sem*/")
    ]
    return gr.Dropdown(choices=["latest", "new"] + run_dirs)


def fresh_llama_model():
return gr.Dropdown(
choices=[init_llama_yml["ckpt_path"]]
Expand Down Expand Up @@ -655,6 +703,14 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
)
with gr.Row():
with gr.Tab(label=i18n("VQGAN Configuration")):
with gr.Row(equal_height=False):
vqgan_ckpt = gr.Dropdown(
label="Select VQGAN ckpt",
choices=["latest", "new"]
+ [str(p) for p in Path("results").glob("vqgan_*/")],
value="latest",
interactive=True,
)
with gr.Row(equal_height=False):
vqgan_lr_slider = gr.Slider(
label=i18n("Initial Learning Rate"),
Expand Down Expand Up @@ -728,6 +784,13 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
),
value=True,
)
llama_ckpt = gr.Dropdown(
label="Select LLAMA ckpt",
choices=["latest", "new"]
+ [str(p) for p in Path("results").glob("text2sem*/")],
value="latest",
interactive=True,
)
with gr.Row(equal_height=False):
llama_lr_slider = gr.Slider(
label=i18n("Initial Learning Rate"),
Expand Down Expand Up @@ -1022,6 +1085,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
train_box,
model_type_radio,
# vq-gan config
vqgan_ckpt,
vqgan_lr_slider,
vqgan_maxsteps_slider,
vqgan_data_num_workers_slider,
Expand All @@ -1030,6 +1094,7 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
vqgan_precision_dropdown,
vqgan_check_interval_slider,
# llama config
llama_ckpt,
llama_base_config,
llama_lr_slider,
llama_maxsteps_slider,
Expand Down Expand Up @@ -1065,6 +1130,8 @@ def llama_lora_merge(llama_weight, lora_weight, llama_lora_output):
fresh_btn.click(
fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
)
vqgan_ckpt.change(fn=fresh_vqgan_ckpt, inputs=[], outputs=[vqgan_ckpt])
llama_ckpt.change(fn=fresh_llama_ckpt, inputs=[], outputs=[llama_ckpt])
llama_lora_merge_btn.click(
fn=llama_lora_merge,
inputs=[llama_weight, lora_weight, llama_lora_output],
Expand Down
11 changes: 10 additions & 1 deletion tools/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import queue
import wave
from argparse import ArgumentParser
from functools import partial
from functools import partial, wraps
from pathlib import Path

import gradio as gr
Expand Down Expand Up @@ -38,17 +38,21 @@
"""

TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
SPACE_IMPORTED = False

# Optional HuggingFace Spaces integration: when running on Spaces, the
# `spaces` package is available and `spaces.GPU` schedules the decorated
# function on a GPU worker. Anywhere else the import fails and we install a
# transparent pass-through decorator instead.
try:
    import spaces

    GPU_DECORATOR = spaces.GPU
    SPACE_IMPORTED = True
except ImportError:

    def GPU_DECORATOR(func):
        # No-op fallback: wraps() preserves the function's metadata, and the
        # undecorated callable is kept on `.original` so the later
        # `if not SPACE_IMPORTED` check can unwrap back to it.
        @wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper.original = func  # ref
        return wrapper


Expand Down Expand Up @@ -169,6 +173,11 @@ def inference(

inference_stream = partial(inference, streaming=True)

if not SPACE_IMPORTED:
logger.info("‘spaces’ not imported, use original")
inference = inference.original
inference_stream = partial(inference, streaming=True)


def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
buffer = io.BytesIO()
Expand Down

0 comments on commit 2711f5d

Please sign in to comment.