Commit e787d01: Use SFT medium as base model

leng-yue committed May 5, 2024
1 parent 2711f5d

Showing 10 changed files with 19 additions and 41 deletions.
docs/en/finetune.md (4 changes: 2 additions & 2 deletions)

@@ -148,7 +148,7 @@ After the command finishes executing, you should see the `quantized-dataset-ft.p
 Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 Finally, you can start the fine-tuning by running the following command:
@@ -182,7 +182,7 @@ After training, you need to convert the LoRA weights to regular weights before p
 python tools/llama/merge_lora.py \
     --llama-config dual_ar_2_codebook_large \
     --lora-config r_8_alpha_16 \
-    --llama-weight checkpoints/text2semantic-sft-large-v1-4k.pth \
+    --llama-weight checkpoints/text2semantic-sft-medium-v1-4k.pth \
     --lora-weight results/text2semantic-finetune-medium-lora/checkpoints/step_000000200.ckpt \
     --output checkpoints/merged.ckpt
 ```
docs/en/inference.md (8 changes: 4 additions & 4 deletions)

@@ -16,7 +16,7 @@ Download the required `vqgan` and `text2semantic` models from our Hugging Face r
 
 ```bash
 huggingface-cli download fishaudio/fish-speech-1 vq-gan-group-fsq-2x1024.pth --local-dir checkpoints
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 ### 1. Generate prompt from voice:
@@ -38,7 +38,7 @@ python tools/llama/generate.py \
     --prompt-text "Your reference text" \
     --prompt-tokens "fake.npy" \
     --config-name dual_ar_2_codebook_large \
-    --checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --num-samples 2 \
     --compile
 ```
@@ -69,7 +69,7 @@ We provide a HTTP API for inference. You can use the following command to start
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8000 \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 ```
@@ -82,7 +82,7 @@ You can start the WebUI using the following command:
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 ```
docs/zh/finetune.md (6 changes: 3 additions & 3 deletions; context translated from Chinese)

@@ -152,13 +152,13 @@ python tools/llama/build_dataset.py \
 Similarly, make sure you have downloaded the `LLAMA` weights. If not, run the following command:
 
 ```bash
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 Users in mainland China can download via a mirror:
 
 ```bash
-HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 Finally, you can start the fine-tuning by running the following command:
@@ -192,7 +192,7 @@ python fish_speech/train.py --config-name text2semantic_finetune \
 python tools/llama/merge_lora.py \
     --llama-config dual_ar_2_codebook_large \
     --lora-config r_8_alpha_16 \
-    --llama-weight checkpoints/text2semantic-sft-large-v1-4k.pth \
+    --llama-weight checkpoints/text2semantic-sft-medium-v1-4k.pth \
     --lora-weight results/text2semantic-finetune-medium-lora/checkpoints/step_000000200.ckpt \
     --output checkpoints/merged.ckpt
 ```
docs/zh/inference.md (10 changes: 5 additions & 5 deletions; context translated from Chinese)

@@ -16,12 +16,12 @@
 
 ```bash
 huggingface-cli download fishaudio/fish-speech-1 vq-gan-group-fsq-2x1024.pth --local-dir checkpoints
-huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 Users in mainland China can download via a mirror:
 
 ```bash
 HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 vq-gan-group-fsq-2x1024.pth --local-dir checkpoints
-HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-large-v1-4k.pth --local-dir checkpoints
+HF_ENDPOINT=https://hf-mirror.com huggingface-cli download fishaudio/fish-speech-1 text2semantic-sft-medium-v1-4k.pth --local-dir checkpoints
 ```
 
 ### 1. Generate prompt from voice:
@@ -43,7 +43,7 @@ python tools/llama/generate.py \
     --prompt-text "Your reference text" \
     --prompt-tokens "fake.npy" \
     --config-name dual_ar_2_codebook_large \
-    --checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --num-samples 2 \
     --compile
 ```
@@ -74,7 +74,7 @@ python tools/vqgan/inference.py \
 ```bash
 python -m tools.api \
     --listen 0.0.0.0:8000 \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 
@@ -90,7 +90,7 @@ HF_ENDPOINT=https://hf-mirror.com python -m ...
 
 ```bash
 python -m tools.webui \
-    --llama-checkpoint-path "checkpoints/text2semantic-sft-large-v1-4k.pth" \
+    --llama-checkpoint-path "checkpoints/text2semantic-sft-medium-v1-4k.pth" \
     --llama-config-name dual_ar_2_codebook_large \
     --vqgan-checkpoint-path "checkpoints/vq-gan-group-fsq-2x1024.pth"
 ```
fish_speech/configs/text2semantic_finetune.yaml (2 changes: 1 addition & 1 deletion)

@@ -5,7 +5,7 @@ defaults:
 
 project: text2semantic_finetune_dual_ar
 max_length: 2048
-ckpt_path: checkpoints/text2semantic-sft-large-v1-4k.pth
+ckpt_path: checkpoints/text2semantic-sft-medium-v1-4k.pth
 resume_weights_only: true
 
 # Lightning Trainer
fish_speech/webui/manage.py (2 changes: 1 addition & 1 deletion)

@@ -470,7 +470,7 @@ def generate_folder_name():
     ckpt_path = (
         "text2semantic-pretrain-medium-2k-v1.pth"
         if llama_base_config == "dual_ar_2_codebook_medium"
-        else "text2semantic-sft-large-v1-4k.pth"
+        else "text2semantic-sft-medium-v1-4k.pth"
     )
 
     latest = list(
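The updated default in `fish_speech/webui/manage.py` is easiest to read as a small selection function. The following is an illustrative restatement of the conditional shown in the diff; the function name `default_ckpt_path` is hypothetical and not from the source:

```python
def default_ckpt_path(llama_base_config: str) -> str:
    """Mirror the checkpoint-selection expression from manage.py."""
    # The medium *pretrain* checkpoint is kept for the medium config;
    # every other config now defaults to the medium *SFT* checkpoint
    # (previously the large SFT checkpoint).
    if llama_base_config == "dual_ar_2_codebook_medium":
        return "text2semantic-pretrain-medium-2k-v1.pth"
    return "text2semantic-sft-medium-v1-4k.pth"


print(default_ckpt_path("dual_ar_2_codebook_large"))  # -> text2semantic-sft-medium-v1-4k.pth
```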
pyproject.toml (1 change: 0 additions & 1 deletion)

@@ -34,7 +34,6 @@ dependencies = [
     "vector_quantize_pytorch>=1.14.7",
     "samplerate>=0.2.1",
     "resampy>=0.4.3",
-    "spaces>=0.26.1",
     "einx[torch]==0.2.2"
 ]
tools/api.py (2 changes: 1 addition & 1 deletion)

@@ -225,7 +225,7 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=str,
-        default="checkpoints/text2semantic-sft-large-v1-4k.pth",
+        default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
     )
     parser.add_argument(
         "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
tools/llama/merge_lora.py (2 changes: 1 addition & 1 deletion)

@@ -15,7 +15,7 @@
 @click.option("--llama-config", type=str, default="dual_ar_2_codebook_large")
 @click.option("--lora-config", type=str, default="r_8_alpha_16")
 @click.option(
-    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-large-v1-4k.pth"
+    "--llama-weight", type=str, default="checkpoints/text2semantic-sft-medium-v1-4k.pth"
 )
 @click.option("--lora-weight", type=str, required=True)
 @click.option("--output", type=str, required=True)
tools/webui.py (23 changes: 1 addition & 22 deletions)

@@ -40,21 +40,6 @@
 TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
 SPACE_IMPORTED = False
 
-try:
-    import spaces
-
-    GPU_DECORATOR = spaces.GPU
-    SPACE_IMPORTED = True
-except ImportError:
-
-    def GPU_DECORATOR(func):
-        @wraps(func)
-        def wrapper(*args, **kwargs):
-            return func(*args, **kwargs)
-
-        wrapper.original = func  # ref
-        return wrapper
-
 
 def build_html_error_message(error):
     return f"""
@@ -65,7 +50,6 @@ def build_html_error_message(error):
     """
 
 
-@GPU_DECORATOR
 @torch.inference_mode()
 def inference(
     text,
@@ -173,11 +157,6 @@ def inference(
 
 inference_stream = partial(inference, streaming=True)
 
-if not SPACE_IMPORTED:
-    logger.info("'spaces' not imported, use original")
-    inference = inference.original
-    inference_stream = partial(inference, streaming=True)
-
 
 def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
     buffer = io.BytesIO()
@@ -343,7 +322,7 @@ def parse_args():
     parser.add_argument(
         "--llama-checkpoint-path",
         type=Path,
-        default="checkpoints/text2semantic-sft-large-v1-4k.pth",
+        default="checkpoints/text2semantic-sft-medium-v1-4k.pth",
     )
     parser.add_argument(
         "--llama-config-name", type=str, default="dual_ar_2_codebook_large"
