Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sgmse implementation #177

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions bins/sgmse/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
from argparse import ArgumentParser
import os

from models.sgmse.dereverberation.dereverberation_inference import (
DereverberationInference,
)
from utils.util import save_config, load_model_config, load_config
import numpy as np
import torch


def build_inference(args, cfg):
supported_inference = {
"dereverberation": DereverberationInference,
}

inference_class = supported_inference[cfg.model_type]
inference = inference_class(args, cfg)
return inference


def build_parser():
parser = argparse.ArgumentParser()

parser.add_argument(
"--config",
type=str,
required=True,
help="JSON/YAML file for configurations.",
)
parser.add_argument(
"--checkpoint_path",
type=str,
)
parser.add_argument(
"--test_dir",
type=str,
required=True,
help="Directory containing the test data (must have subdirectory noisy/)",
)
parser.add_argument(
"--corrector_steps", type=int, default=1, help="Number of corrector steps"
)
parser.add_argument(
"--output_dir",
type=str,
default=None,
help="Output dir for saving generated results",
)
parser.add_argument(
"--snr",
type=float,
default=0.33,
help="SNR value for (annealed) Langevin dynmaics.",
)
parser.add_argument("--N", type=int, default=50, help="Number of reverse steps")
parser.add_argument("--local_rank", default=0, type=int)
return parser


def main():
# Parse arguments
args = build_parser().parse_args()
# args, infer_type = formulate_parser(args)

# Parse config
cfg = load_config(args.config)
if torch.cuda.is_available():
args.local_rank = torch.device("cuda")
else:
args.local_rank = torch.device("cpu")
print("args: ", args)

# Build inference
inferencer = build_inference(args, cfg)

# Run inference
inferencer.inference()


if __name__ == "__main__":
main()
53 changes: 53 additions & 0 deletions bins/sgmse/preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import faulthandler

faulthandler.enable()
import os
import argparse
import json
from multiprocessing import cpu_count
from utils.util import load_config
from preprocessors.processor import preprocess_dataset


def preprocess(cfg):
"""Proprocess raw data of single or multiple datasets (in cfg.dataset)

Args:
cfg (dict): dictionary that stores configurations
"""
# Specify the output root path to save the processed data
output_path = cfg.preprocess.processed_dir
os.makedirs(output_path, exist_ok=True)

## Split train and test sets
for dataset in cfg.dataset:
print("Preprocess {}...".format(dataset))

preprocess_dataset(
dataset,
cfg.dataset_path[dataset],
output_path,
cfg.preprocess,
cfg.task_type,
is_custom_dataset=dataset in cfg.use_custom_dataset,
)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config", default="config.json", help="json files for configurations."
)
parser.add_argument("--num_workers", type=int, default=int(cpu_count()))
args = parser.parse_args()
cfg = load_config(args.config)
preprocess(cfg)


if __name__ == "__main__":
main()
87 changes: 87 additions & 0 deletions bins/sgmse/train_sgmse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os
import torch
from models.sgmse.dereverberation.dereverberation_Trainer import DereverberationTrainer

from utils.util import load_config


def build_trainer(args, cfg):
supported_trainer = {
"dereverberation": DereverberationTrainer,
}

trainer_class = supported_trainer[cfg.model_type]
trainer = trainer_class(args, cfg)
return trainer


def cuda_relevant(deterministic=False):
torch.cuda.empty_cache()
# TF32 on Ampere and above
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.allow_tf32 = True
# Deterministic
torch.backends.cudnn.deterministic = deterministic
torch.backends.cudnn.benchmark = not deterministic
torch.use_deterministic_algorithms(deterministic)


def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config",
default="config.json",
help="json files for configurations.",
required=True,
)
parser.add_argument(
"--num_workers", type=int, default=4, help="Number of dataloader workers."
)
parser.add_argument(
"--exp_name",
type=str,
default="exp_name",
help="A specific name to note the experiment",
required=True,
)
parser.add_argument(
"--log_level", default="warning", help="logging level (debug, info, warning)"
)
parser.add_argument("--stdout_interval", default=5, type=int)
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
cfg = load_config(args.config)
cfg.exp_name = args.exp_name
args.log_dir = os.path.join(cfg.log_dir, args.exp_name)
os.makedirs(args.log_dir, exist_ok=True)
# Data Augmentation
if cfg.preprocess.data_augment:
new_datasets_list = []
for dataset in cfg.preprocess.data_augment:
new_datasets = [
# f"{dataset}_pitch_shift",
# f"{dataset}_formant_shift",
f"{dataset}_equalizer",
f"{dataset}_time_stretch",
]
new_datasets_list.extend(new_datasets)
cfg.dataset.extend(new_datasets_list)

# CUDA settings
cuda_relevant()

# Build trainer
trainer = build_trainer(args, cfg)

trainer.train()


if __name__ == "__main__":
main()
42 changes: 42 additions & 0 deletions config/sgmse.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
{
"base_config": "config/base.json",
"dataset": [
"wsj0reverb"
],
"task_type": "sgmse",
"preprocess": {
"dummy": false,
"num_frames":256,
"normalize": "noisy",
"hop_length": 128,
"n_fft": 510,
"spec_abs_exponent": 0.5,
"spec_factor": 0.15,
"use_spkid": false,
"use_uv": false,
"use_frame_pitch": false,
"use_phone_pitch": false,
"use_frame_energy": false,
"use_phone_energy": false,
"use_mel": false,
"use_audio": false,
"use_label": false,
"use_one_hot": false
},
"model": {
"sgmse": {
"backbone": "ncsnpp",
"sde": "ouve",

"gpus": 1
}
},
"train": {
"batch_size": 8,
"lr": 1e-4,
"ema_decay": 0.999,
"t_eps": 3e-2,
"num_eval_files": 20
}

}
98 changes: 98 additions & 0 deletions egs/sgmse/README.md
lithr1 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Amphion Speech Enhancement and Dereverberation with Diffusion-based Generative Models Recipe


<br>
<div align="center">
<img src="../../imgs/sgmse/diffusion_process.png" width="90%">
</div>
<br>
This repository contains the PyTorch implementations for the 2023 papers and also adapted from [sgmse](https://github.com/sp-uhh/sgmse):
- Julius Richter, Simon Welker, Jean-Marie Lemercier, Bunlong Lay, Timo Gerkmann. [*"Speech Enhancement and Dereverberation with Diffusion-Based Generative Models"*](https://ieeexplore.ieee.org/abstract/document/10149431), IEEE/ACM Transactions on Audio, Speech, and Language Processing, vol. 31, pp. 2351-2364, 2023.


You can use any sgmse architecture with any dataset you want. There are three steps in total:

1. Data preparation
2. Training
3. Inference


> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
> ```bash
> cd Amphion
> ```

## 1. Data Preparation

You can train the vocoder with any datasets. Amphion's supported open-source datasets are detailed [here](../../../datasets/README.md).

### Configuration

Specify the dataset path in `exp_config_base.json`. Note that you can change the `dataset` list to use your preferred datasets.

```json
"dataset": [
"wsj0reverb"
],
"dataset_path": {
// TODO: Fill in your dataset path
"wsj0reverb": ""
},
"preprocess": {
"processed_dir": "",
"sample_rate": 16000
},
```

## 2. Training

### Configuration

We provide the default hyparameters in the `exp_config_base.json`. They can work on single NVIDIA-24g GPU. You can adjust them based on you GPU machines.

```json
"train": {
// TODO: Fill in your checkpoint path
"checkpoint": "",
"adam": {
"lr": 1e-4
},
"ddp": false,
"batch_size": 8,
"epochs": 200000,
"save_checkpoints_steps": 800,
"save_summary_steps": 1000,
"max_steps": 1000000,
"ema_decay": 0.999,
"valid_interval": 800,
"t_eps": 3e-2,
"num_eval_files": 20

}
}
```

### Run

Run the `run.sh` as the training stage (set `--stage 2`).

```bash
sh egs/sgmse/dereverberation/run.sh --stage 2
```

> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` in default. You can change it when running `run.sh` by specifying such as `--gpu "0,1,2,3"`.

## 3. Inference

### Run

Run the `run.sh` as the training stage (set `--stage 3`)

```bash
sh egs/sgmse/dereverberation/run.sh --stage 3
--checkpoint_path [your path]
--test_dir [your path]
--output_dir [your path]

```

Loading
Loading