Support Multi-speaker VITS (#131)
Support Multi-speaker VITS & Hi-Fi TTS dataset preprocessing
zyingt committed Feb 23, 2024
1 parent d37d8f1 commit 6e9d34f
Showing 14 changed files with 191 additions and 48 deletions.
4 changes: 2 additions & 2 deletions bins/tts/preprocess.py
@@ -88,11 +88,11 @@ def extract_phonme_sequences(dataset, output_path, cfg, dataset_types):
        dataset_file = os.path.join(dataset_output, "{}.json".format(dataset_type))
        with open(dataset_file, "r") as f:
            metadata.extend(json.load(f))
-    phone_extractor.extract_utt_phone_sequence(cfg, metadata)
+    phone_extractor.extract_utt_phone_sequence(dataset, cfg, metadata)


def preprocess(cfg, args):
"""Proprocess raw data of single or multiple datasets (in cfg.dataset)
"""Preprocess raw data of single or multiple datasets (in cfg.dataset)
Args:
cfg (dict): dictionary that stores configurations
1 change: 1 addition & 0 deletions config/tts.json
@@ -16,6 +16,7 @@
// Directory names of processed data or extracted features
"phone_dir": "phones",
"use_phone": true,
"add_blank": true
},
"model": {
"text_token_num": 512,
31 changes: 31 additions & 0 deletions egs/datasets/README.md
@@ -6,6 +6,7 @@ Amphion support the following academic datasets (sort alphabetically):
- [AudioCaps](#audiocaps)
- [CSD](#csd)
- [CustomSVCDataset](#customsvcdataset)
- [Hi-Fi TTS](#hifitts)
- [KiSing](#kising)
- [LibriLight](#librilight)
- [LibriTTS](#libritts)
@@ -75,6 +76,36 @@ We support custom dataset for Singing Voice Conversion. Organize your data in th
┣ ...
```


## Hi-Fi TTS

Download the official Hi-Fi TTS dataset [here](https://www.openslr.org/109/). The file structure looks like the following:

```plaintext
[Hi-Fi TTS dataset path]
┣ audio
┃ ┣ 11614_other {Speaker_ID}_{SNR_subset}
┃ ┃ ┣ 10547 {Book_ID}
┃ ┃ ┃ ┣ thousandnights8_04_anonymous_0001.flac
┃ ┃ ┃ ┣ thousandnights8_04_anonymous_0003.flac
┃ ┃ ┃ ┣ thousandnights8_04_anonymous_0004.flac
┃ ┃ ┃ ┣ ...
┃ ┃ ┣ ...
┃ ┣ ...
┣ 92_manifest_clean_dev.json
┣ 92_manifest_clean_test.json
┣ 92_manifest_clean_train.json
┣ ...
┣ {Speaker_ID}_manifest_{SNR_subset}_{dataset_split}.json
┣ ...
┣ books_bandwidth.tsv
┣ LICENSE.txt
┣ readers_books_clean.txt
┣ readers_books_other.txt
┣ README.txt
```
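
Each `{Speaker_ID}_manifest_{SNR_subset}_{dataset_split}.json` file is a JSON-lines manifest with one utterance per line. As a quick sanity check before preprocessing, a sketch like the following can peek at a manifest; the field names (`audio_filepath`, `duration`, `text`) are assumptions based on the published Hi-Fi TTS manifests, so verify them against your download:

```python
import json

# Hypothetical path for illustration; point this at your dataset root.
manifest_path = "[Hi-Fi TTS dataset path]/92_manifest_clean_train.json"

with open(manifest_path, "r", encoding="utf-8") as f:
    for i, line in enumerate(f):
        entry = json.loads(line)  # one JSON object per line
        print(entry.get("audio_filepath"), entry.get("duration"), entry.get("text"))
        if i >= 2:  # peek at the first few utterances only
            break
```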

## KiSing

Download the official KiSing dataset [here](http://shijt.site/index.php/2021/05/16/kising-the-first-open-source-mandarin-singing-voice-synthesis-corpus/). The file structure looks like the following:
79 changes: 58 additions & 21 deletions egs/tts/VITS/README.md
@@ -3,7 +3,7 @@
[![hf](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-Spaces-yellow)](https://huggingface.co/spaces/amphion/Text-to-Speech)
[![openxlab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Amphion/Text-to-Speech)

- In this recipe, we will show how to train [VITS](https://arxiv.org/abs/2106.06103) using Amphion's infrastructure. VITS is an end-to-end TTS architecture that utilizes conditional variational autoencoder with adversarial learning.
+ In this recipe, we will show how to train VITS using Amphion's infrastructure. [VITS](https://arxiv.org/abs/2106.06103) is an end-to-end TTS architecture that utilizes a conditional variational autoencoder with adversarial learning.

There are four stages in total:

@@ -20,7 +20,7 @@ There are four stages in total:
## 1. Data Preparation
### Dataset Download
- You can use the commonly used TTS dataset to train TTS model, e.g., LJSpeech, VCTK, LibriTTS, etc. We strongly recommend you use LJSpeech to train TTS model for the first time. How to download dataset is detailed [here](../../datasets/README.md).
+ You can use commonly used TTS datasets to train the TTS model, e.g., LJSpeech, VCTK, Hi-Fi TTS, LibriTTS, etc. We strongly recommend using LJSpeech when training a single-speaker TTS model for the first time, and Hi-Fi TTS when training a multi-speaker TTS model for the first time. How to download these datasets is detailed [here](../../datasets/README.md).
### Configuration
@@ -29,32 +29,41 @@ After downloading the dataset, you can set the dataset paths in `exp_config.jso
```json
"dataset": [
"LJSpeech",
//"hifitts"
],
"dataset_path": {
// TODO: Fill in your dataset path
"LJSpeech": "[LJSpeech dataset path]",
//"hifitts": "[Hi-Fi TTS dataset path]
},
```
## 2. Features Extraction
### Configuration
- Specify the `processed_dir` and the `log_dir` and for saving the processed data and the checkpoints in `exp_config.json`:
+ In `exp_config.json`, specify the `log_dir` for saving checkpoints and logs, and the `processed_dir` for saving processed data. For preprocessing a multi-speaker TTS dataset, set `extract_audio` and `use_spkid` to `true`:
```json
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
"log_dir": "ckpts/tts",
"preprocess": {
//"extract_audio": true,
"use_phone": true,
// linguistic features
"extract_phone": true,
"phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
// TODO: Fill in the output data path. The default value is "Amphion/data"
"processed_dir": "data",
...
"sample_rate": 22050, //target sampling rate
"valid_file": "valid.json", //validation set
//"use_spkid": true, //use speaker ID to train multi-speaker TTS model
},
```
### Run
- Run the `run.sh` as the preproces stage (set `--stage 1`):
+ Run the `run.sh` as the preprocess stage (set `--stage 1`):
```bash
sh egs/tts/VITS/run.sh --stage 1
```

@@ -66,17 +75,22 @@ sh egs/tts/VITS/run.sh --stage 1
### Configuration
- We provide the default hyparameters in the `exp_config.json`. They can work on single NVIDIA-24g GPU. You can adjust them based on your GPU machines.
+ We provide the default hyperparameters in the `exp_config.json`. They can work on a single 24GB NVIDIA GPU. You can adjust them based on your GPU machines.
+ For training a multi-speaker TTS model, specify `n_speakers` to be greater than or equal to the number of speakers in your dataset(s) (a larger value reserves room for fine-tuning on new speakers) and set `multi_speaker_training` to `true`.
- ```
- "train": {
-     "batch_size": 16,
- }
+ ```json
+ "model": {
+     //"n_speakers": 10 //Number of speakers in the dataset(s) used. The default value is 0 if not specified.
+ },
+ "train": {
+     "batch_size": 16,
+     //"multi_speaker_training": true,
+ }
```
### Train From Scratch
- Run the `run.sh` as the training stage (set `--stage 2`). Specify a experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.
+ Run the `run.sh` as the training stage (set `--stage 2`). Specify an experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/tts/[YourExptName]`.
```bash
sh egs/tts/VITS/run.sh --stage 2 --name [YourExptName]
```

@@ -139,12 +153,35 @@ For inference, you need to specify the following configurations when running `ru
| `--infer_expt_dir` | The experimental directory which contains `checkpoint` | `Amphion/ckpts/tts/[YourExptName]` |
| `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/tts/[YourExptName]/result` |
| `--infer_mode` | The inference mode, e.g., "`single`", "`batch`". | "`single`" to generate a clip of speech, "`batch`" to generate a batch of speech at a time. |
- | `--infer_dataset` | The dataset used for inference. | For LJSpeech dataset, the inference dataset would be `LJSpeech`. |
- | `--infer_testing_set` | The subset of the inference dataset used for inference, e.g., train, test, golden_test | For LJSpeech dataset, the testing set would be "`test`" split from LJSpeech at the feature extraction, or "`golden_test`" cherry-picked from test set as template testing set. |
+ | `--infer_dataset` | The dataset used for inference. | For the LJSpeech dataset, the inference dataset would be `LJSpeech`.<br> For the Hi-Fi TTS dataset, the inference dataset would be `hifitts`. |
+ | `--infer_testing_set` | The subset of the inference dataset used for inference, e.g., train, test, golden_test | For the LJSpeech dataset, the testing set would be "`test`" split from LJSpeech at feature extraction, or "`golden_test`" cherry-picked from the test set as a template testing set.<br>For the Hi-Fi TTS dataset, the testing set would be "`test`" split from Hi-Fi TTS during the feature extraction process. |
| `--infer_text` | The text to be synthesized. | "`This is a clip of generated speech with the given text from a TTS model.`" |
+ | `--infer_speaker_name` | The target speaker whose voice is to be synthesized.<br> (***Note: only applicable to multi-speaker TTS models***) | For the Hi-Fi TTS dataset, the list of available speakers includes: "`hifitts_11614`", "`hifitts_11697`", "`hifitts_12787`", "`hifitts_6097`", "`hifitts_6670`", "`hifitts_6671`", "`hifitts_8051`", "`hifitts_9017`", "`hifitts_9136`", "`hifitts_92`". <br> You can find the list of available speakers in the `spk2id.json` file generated under the `log_dir/[YourExptName]` you specified in `exp_config.json`. |
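
As the last row of the table notes, the valid values for `--infer_speaker_name` live in the generated `spk2id.json`. A minimal sketch for listing them, assuming the file is a flat `{speaker_name: id}` JSON mapping (check your generated file to confirm):

```python
import json

# Path assumed from the table above; substitute your own log_dir and experiment name.
with open("ckpts/tts/[YourExptName]/spk2id.json", "r") as f:
    spk2id = json.load(f)  # e.g., {"hifitts_92": 0, "hifitts_6097": 1, ...}

print("\n".join(sorted(spk2id)))  # each line is a valid --infer_speaker_name value
```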
### Run
- For example, if you want to generate speech of all testing set split from LJSpeech, just run:
#### Single text inference:
For the single-speaker TTS model, if you want to generate a single clip of speech from a given text, just run:
```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
    --infer_mode "single" \
    --infer_text "This is a clip of generated speech with the given text from a TTS model."
```
For the multi-speaker TTS model, in addition to the above-mentioned arguments, you need to add the `--infer_speaker_name` argument, and run:
```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
    --infer_mode "single" \
    --infer_text "This is a clip of generated speech with the given text from a TTS model." \
    --infer_speaker_name "hifitts_92"
```
#### Batch inference:
For the single-speaker TTS model, if you want to generate speech for the whole testing set split from LJSpeech, just run:
```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
@@ -154,18 +191,18 @@ sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
--infer_dataset "LJSpeech" \
--infer_testing_set "test"
```

- Or, if you want to generate a single clip of speech from a given text, just run:

For the multi-speaker TTS model, if you want to generate speech for the whole testing set split from Hi-Fi TTS, follow the same procedure as above with `LJSpeech` replaced by `hifitts`:
```bash
sh egs/tts/VITS/run.sh --stage 3 --gpu "0" \
    --infer_expt_dir Amphion/ckpts/tts/[YourExptName] \
    --infer_output_dir Amphion/ckpts/tts/[YourExptName]/result \
-   --infer_mode "single" \
-   --infer_text "This is a clip of generated speech with the given text from a TTS model."
+   --infer_mode "batch" \
+   --infer_dataset "hifitts" \
+   --infer_testing_set "test"
```
- We released a pre-trained Amphion VITS model trained on LJSpeech. So you can download the pre-trained model [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech following the above inference instruction.
+ We have released a pre-trained Amphion VITS model trained on LJSpeech, so you can download it [here](https://huggingface.co/amphion/vits-ljspeech) and generate speech following the inference instructions above. The pre-trained multi-speaker VITS model trained on Hi-Fi TTS will be released soon. Stay tuned.
```bibtex
@@ -176,4 +176,4 @@ We released a pre-trained Amphion VITS model trained on LJSpeech. So you can dow
pages={5530--5540},
year={2021},
}
```
21 changes: 14 additions & 7 deletions egs/tts/VITS/exp_config.json
@@ -2,26 +2,33 @@
"base_config": "config/vits.json",
"model_type": "VITS",
"dataset": [
"LJSpeech"
"LJSpeech",
//"hifitts"
],
"dataset_path": {
// TODO: Fill in your dataset path
"LJSpeech": "[LJSpeech dataset path]"
"LJSpeech": "[LJSpeech dataset path]",
//"hifitts": "[Hi-Fi TTS dataset path]
},
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/tts"
"log_dir": "ckpts/tts",
"preprocess": {
//"extract_audio":true,
"use_phone": true,
// linguistic features
"extract_phone": true,
"phone_extractor": "lexicon", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
"phone_extractor": "espeak", // "espeak, pypinyin, pypinyin_initials_finals, lexicon (only for language=en-us right now)"
// TODO: Fill in the output data path. The default value is "Amphion/data"
"processed_dir": "data",

"sample_rate": 22050,
"valid_file": "test.json", // validattion set
"sample_rate": 22050, // target sampling rate
"valid_file": "valid.json", // validation set
//"use_spkid": true // use speaker ID to train multi-speaker TTS model
},
"model":{
//"n_speakers": 10 // number of speakers, greater than or equal to the number of speakers in the dataset(s) used. The default value is 0 if not specified.
},
"train": {
"batch_size": 16,
//"multi_speaker_training": true
}
}
23 changes: 16 additions & 7 deletions egs/tts/VITS/run.sh
@@ -18,7 +18,7 @@ cd $work_dir

######## Parse the Given Parameters from the Commond ###########
# options=$(getopt -o c:n:s --long gpu:,config:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir:,name:,stage: -- "$@")
- options=$(getopt -o c:n:s --long gpu:,config:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,name:,stage: -- "$@")
+ options=$(getopt -o c:n:s --long gpu:,config:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_mode:,infer_dataset:,infer_testing_set:,infer_text:,infer_speaker_name:,name:,stage: -- "$@")
eval set -- "$options"

while true; do
@@ -43,14 +43,16 @@ while true; do
--infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
# [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
--infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
-    # [Only for Inference] The inference mode. It can be "batch" to generate speech by batch, or "single" to generage a single clip of speech.
+    # [Only for Inference] The inference mode. It can be "batch" to generate speech by batch, or "single" to generate a single clip of speech.
--infer_mode) shift; infer_mode=$1 ; shift ;;
-    # [Only for Inference] The inference dataset. It is only used when the inference model is "batch".
+    # [Only for Inference] The inference dataset. It is only used when the inference mode is "batch".
--infer_dataset) shift; infer_dataset=$1 ; shift ;;
-    # [Only for Inference] The inference testing set. It is only used when the inference model is "batch". It can be "test" set split from the dataset, or "golden_test" carefully selected from the testing set.
+    # [Only for Inference] The inference testing set. It is only used when the inference mode is "batch". It can be "test" set split from the dataset, or "golden_test" carefully selected from the testing set.
--infer_testing_set) shift; infer_testing_set=$1 ; shift ;;
-    # [Only for Inference] The text to be synthesized from. It is only used when the inference model is "single".
+    # [Only for Inference] The text to be synthesized from. It is only used when the inference mode is "single".
--infer_text) shift; infer_text=$1 ; shift ;;
# [Only for Inference] The chosen speaker's voice to be synthesized. It is only used when the inference mode is "single" for multi-speaker VITS.
--infer_speaker_name) shift; infer_speaker_name=$1 ; shift ;;

--) shift ; break ;;
*) echo "Invalid option: $1" exit 1 ;;
@@ -67,7 +69,7 @@ fi
if [ -z "$exp_config" ]; then
exp_config="${exp_dir}"/exp_config.json
fi
echo "Exprimental Configuration File: $exp_config"
echo "Experimental Configuration File: $exp_config"

if [ -z "$gpu" ]; then
gpu="0"
@@ -86,7 +88,7 @@ if [ $running_stage -eq 2 ]; then
echo "[Error] Please specify the experiments name"
exit 1
fi
echo "Exprimental Name: $exp_name"
echo "Experimental Name: $exp_name"

# add default value
if [ -z "$resume_from_ckpt_path" ]; then
@@ -153,6 +155,12 @@ if [ $running_stage -eq 3 ]; then
elif [ "$infer_mode" = "batch" ]; then
infer_text=''
fi

if [ -z "$infer_speaker_name" ]; then
infer_speaker_name=None
fi




CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/tts/inference.py \
@@ -163,6 +171,7 @@
--dataset $infer_dataset \
--testing_set $infer_testing_set \
--text "$infer_text" \
--speaker_name $infer_speaker_name \
--log_level debug


3 changes: 3 additions & 0 deletions models/tts/base/tts_dataset.py
@@ -209,6 +209,9 @@ def __init__(self, cfg, dataset, is_valid=False):
        phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
        sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)

        if cfg.preprocess.add_blank:
            sequence = intersperse(sequence, 0)

        self.utt2seq[utt] = sequence

def __getitem__(self, index):
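
For context, this is where the `add_blank` option added to `config/tts.json` above takes effect: interspersing a blank token (ID 0) between phone IDs is the standard VITS text-preprocessing trick, which tends to improve alignment and pronunciation. A minimal sketch of what such an `intersperse` helper typically looks like (the actual implementation lives elsewhere in Amphion and may differ):

```python
def intersperse(sequence, item):
    # [p1, p2, p3] -> [item, p1, item, p2, item, p3, item]
    result = [item] * (len(sequence) * 2 + 1)
    result[1::2] = sequence
    return result

assert intersperse([7, 8, 9], 0) == [0, 7, 0, 8, 0, 9, 0]
```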
12 changes: 11 additions & 1 deletion models/tts/base/tts_inferece.py
@@ -12,6 +12,7 @@
from tqdm import tqdm
from accelerate.logging import get_logger
from torch.utils.data import DataLoader
from safetensors.torch import load_file


from abc import abstractmethod
@@ -162,7 +163,16 @@ def _load_model(
        ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
        checkpoint_path = ls[0]

-       self.accelerator.load_state(str(checkpoint_path))
+       if (
+           Path(os.path.join(checkpoint_path, "model.safetensors")).exists()
+           and accelerate.__version__ < "0.25"
+       ):
+           self.model.load_state_dict(
+               load_file(os.path.join(checkpoint_path, "model.safetensors")),
+               strict=False,
+           )
+       else:
+           self.accelerator.load_state(str(checkpoint_path))
return str(checkpoint_path)

def inference(self):
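
One caveat on the branch above: `accelerate.__version__ < "0.25"` compares version strings lexicographically, which happens to work for recent releases but misorders some versions (as strings, `"0.9.0" < "0.25.0"` is False, although 0.9.0 is the older release). A sketch of a more robust check, assuming the `packaging` library is available:

```python
import accelerate
from packaging import version

# Parse the versions instead of comparing raw strings.
needs_direct_load = version.parse(accelerate.__version__) < version.parse("0.25.0")
```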