Merge pull request #28 from 0X0StradSong/main
New dataset: p239
ctlllll authored Jun 25, 2024
2 parents 6f6d333 + 31bb84b commit b6c4454
Showing 8 changed files with 226 additions and 57 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -7,6 +7,8 @@
---
*News: We are now on the mainnet with uid 3! Please join the [Bittensor Discord](https://discord.gg/RXST8svz) and see us at Channel γ·gamma·3! Also, please check our [X (Twitter)](https://twitter.com/myshell_ai/status/1772792027148894557) for our vision of creating a collaborative environment where everyone can contribute, benefit, and engage with open-source models, ultimately empowering millions. 03/24*

*Update: We are now in Phase 2 of the subnet. Our goal is to provide a more diversified and exciting voice dataset for miners to train and develop state-of-the-art efficient TTS models. We have observed that miners are making significant improvements in current metrics, which is encouraging. However, we need to be cautious about potential overfitting to these metrics. Our development team is working diligently on an adversarial and highly complex research study to develop an automatic system to address this issue. Have fun! 06/23*

## Introduction

> **Note:** The following documentation assumes you are familiar with basic Bittensor concepts: Miners, Validators, and incentives. If you need a primer, please check out https://docs.bittensor.com/learn/bittensor-building-blocks.
@@ -23,10 +25,12 @@ As building a TTS model is a complex task, we will divide the development into s
- **Phase 3**: More generally, we can have fast-clone models that can be adapted to new speakers with a small amount of data, e.g., [OpenVoice](https://github.com/myshell-ai/OpenVoice). We will move to fast-clone models in this phase.

## Current Status
We are currently in Phase 1. To start, we utilize the [VCTK](https://huggingface.co/datasets/vctk) dataset as the source of our speaker data. We randomly select 1 speaker from the dataset and the goal is to build a TTS model that can mimic this speaker's voice.
We are currently in Phase 2. To start, we utilize the [AniSpeech](https://huggingface.co/datasets/ShoukanLabs/AniSpeech) dataset as the source of our speaker data. We randomly select 1 speaker from the dataset and the goal is to build a TTS model that can perfectly mimic this speaker's voice.

Please refer to the `tts_rater` folder for audio samples from the speaker and the text used for evaluation.

Please refer to the `preprocess` folder for options to download and preprocess the dataset for training.

## Overview
![architecture](docs/tts_subnet.png)
Our subnet operates as follows:
4 changes: 2 additions & 2 deletions constants/__init__.py
@@ -32,10 +32,10 @@ class CompetitionParameters:
COMPETITION_SCHEDULE: List[CompetitionParameters] = [
    CompetitionParameters(
        reward_percentage=1.0,
        competition_id="p334",
        competition_id="p239",
    ),
]
ORIGINAL_COMPETITION_ID = "p334"
ORIGINAL_COMPETITION_ID = "p239"
CONSTANT_ALPHA = 0.1 # prev: 0.2
timestamp_epsilon = 0.005

39 changes: 39 additions & 0 deletions preprocess/clean.py
@@ -0,0 +1,39 @@
import pandas as pd
import wave
import os
from glob import glob
from tqdm import tqdm
df_paths = glob('data/*.parquet')

def save_audio(row, file_name):
    audio_bytes = row['audio']['bytes']

    # Convert bytes to wave format and save
    with wave.open(file_name, 'wb') as wave_file:
        wave_file.setnchannels(1)  # Mono
        wave_file.setsampwidth(2)  # Assuming 16-bit PCM
        wave_file.setframerate(88200)  # Sample rate of 88200 Hz
        wave_file.writeframes(audio_bytes)

voice_counters = {}

for df_path in df_paths:
    df = pd.read_parquet(df_path)
    for index, row in tqdm(df.iterrows()):
        voice = int(row['voice'])
        voice_str = f"p_{voice:03d}"
        if voice_str not in voice_counters:
            voice_counters[voice_str] = 0
        caption = row['caption']
        # mkdir if not exists
        txt_dir = f"txt/{voice_str}"
        os.makedirs(txt_dir, exist_ok=True)
        # save caption
        with open(f"{txt_dir}/{voice_str}_{voice_counters[voice_str]:03d}.txt", 'w') as f:
            f.write(caption)
        # mkdir if not exists
        audio_dir = f"wave/{voice_str}"
        os.makedirs(audio_dir, exist_ok=True)
        # save audio
        save_audio(row, f"{audio_dir}/{voice_str}_{voice_counters[voice_str]:03d}_mic1.wav")
        voice_counters[voice_str] += 1
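
After this script runs, each AniSpeech voice ends up with paired caption and audio files under `txt/p_XXX/` and `wave/p_XXX/`. A minimal sanity-check sketch, assuming that layout (`p_239` is just an example directory name, not something guaranteed by the dataset):

```python
# Sketch: verify one preprocessed voice produced by clean.py.
# Assumes the txt/ and wave/ layout written by the script above; "p_239" is an example id.
import wave
from pathlib import Path

voice = "p_239"
wav_paths = sorted(Path(f"wave/{voice}").glob("*_mic1.wav"))
txt_paths = sorted(Path(f"txt/{voice}").glob("*.txt"))
assert len(wav_paths) == len(txt_paths), "each clip should have a matching caption"

with wave.open(str(wav_paths[0]), "rb") as wf:
    duration = wf.getnframes() / wf.getframerate()
print(f"{voice}: {len(wav_paths)} clips, first clip {duration:.2f}s")
print("first caption:", txt_paths[0].read_text()[:80])
```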
25 changes: 25 additions & 0 deletions preprocess/download.sh
@@ -0,0 +1,25 @@
#!/bin/bash

# Base URL for the files
base_url="https://huggingface.co/datasets/ShoukanLabs/AniSpeech/resolve/main/data/ENGLISH-000"

# Directory to save the downloaded files
output_dir="./data"
mkdir -p "$output_dir"

# Download files from 00000 to 00037
for i in $(seq -w 0 37); do
    file_url="${base_url}${i}-of-00038.parquet?download=true"
    output_file="${output_dir}/ENGLISH-000${i}-of-00038.parquet"

    echo "Downloading ${output_file}..."
    wget -O "$output_file" "$file_url"

    if [ $? -ne 0 ]; then
        echo "Failed to download ${output_file}"
    else
        echo "Successfully downloaded ${output_file}"
    fi
done

echo "Download completed."
43 changes: 42 additions & 1 deletion tts_rater/models.py
@@ -72,4 +72,45 @@ def forward(self, inputs, mask=None):
def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
for i in range(n_convs):
L = (L - kernel_size + 2 * pad) // stride + 1
return L
return L

class StackingSubsampling(nn.Module):
    def __init__(self, stride, feat_in, feat_out):
        super().__init__()
        self.stride = stride
        self.out = nn.Linear(stride * feat_in, feat_out)

    def forward(
        self, features: torch.Tensor, features_length: torch.Tensor
    ) -> torch.Tensor:
        b, t, d = features.size()
        pad_size = (self.stride - (t % self.stride)) % self.stride
        features = nn.functional.pad(features, (0, 0, 0, pad_size))
        _, t, _ = features.size()
        features = torch.reshape(features, (b, t // self.stride, d * self.stride))
        out_features = self.out(features)
        out_length = torch.div(
            features_length + pad_size, self.stride, rounding_mode="floor"
        )
        return out_features, out_length

class RaterJudger(nn.Module):
    def __init__(self):
        super().__init__()
        self.subsampling = StackingSubsampling(3, 128, 128)
        encoder_layer = nn.TransformerEncoderLayer(d_model=128, nhead=8)
        self.transformer = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=8
        )
        self.linear = nn.Linear(128, 1)
        print("ssl: ", sum(p.numel() for p in self.subsampling.parameters()))
        print("conf", sum(p.numel() for p in self.transformer.parameters()))
        print("lin", sum(p.numel() for p in self.linear.parameters()))

    def forward(self, x):
        bsz, _, lens = x.size()
        leng = torch.tensor([lens for _ in range(bsz)]).to(x.device)
        x, leng = self.subsampling(x.transpose(1, 2), leng)
        x = self.transformer(x)
        return self.linear(x.mean(dim=1))
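
As a quick shape check (a sketch, not part of this diff): `RaterJudger` takes a batch of 128-bin mel spectrograms shaped `(batch, 128, frames)`; `StackingSubsampling` pads the frame axis to a multiple of its stride (3) and folds every 3 frames into one 384-dim vector before projecting back to 128 for the transformer.

```python
# Sketch: dry run of RaterJudger on a dummy mel spectrogram to illustrate the expected shapes.
import torch

model = RaterJudger().eval()      # __init__ prints the parameter count of each sub-module
mel = torch.randn(2, 128, 300)    # (batch, n_mels=128, frames); 300 is an arbitrary length
with torch.inference_mode():
    score = model(mel)            # frames: 300 -> 100 after stride-3 subsampling
print(score.shape)                # torch.Size([2, 1]): one scalar judgment per clip
```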
166 changes: 113 additions & 53 deletions tts_rater/rater.py
@@ -1,9 +1,9 @@
from whisper.normalizers import EnglishTextNormalizer

from tts_rater.models import ReferenceEncoder
from tts_rater.models import ReferenceEncoder, RaterJudger
import torch
import torch.nn.functional as F
from tts_rater.mel_processing import spectrogram_torch
from tts_rater.mel_processing import spectrogram_torch, mel_spectrogram_torch
import librosa
import os
import eng_to_ipa as ipa
@@ -20,8 +20,9 @@
from tqdm import tqdm
import onnxruntime as ort
import tempfile
from tts_rater.pann import PANNModel
# from tts_rater.pann import PANNModel
from tts_rater.rawnet.inference import AntiSpoofingInference
from huggingface_hub import hf_hub_download

script_dir = os.path.dirname(os.path.abspath(__file__))

@@ -215,49 +216,49 @@ def compute_dns_mos_loss(audio_paths, batch_size):


# =================== PANN MMD loss ===================
speaker_pann_embeds = torch.load(os.path.join(script_dir, "pann/pann_embeds.pth"), map_location="cuda")
pann_model = PANNModel()
# speaker_pann_embeds = torch.load(os.path.join(script_dir, "pann/pann_embeds.pth"), map_location="cuda")
# pann_model = PANNModel()

def compute_mmd(a_x: torch.Tensor, b_y: torch.Tensor):
_SIGMA = 10
_SCALE = 1000
# def compute_mmd(a_x: torch.Tensor, b_y: torch.Tensor):
# _SIGMA = 10
# _SCALE = 1000

a_x = a_x.double()
b_y = b_y.double()
# a_x = a_x.double()
# b_y = b_y.double()

a_x_sqnorms = torch.sum(a_x**2, dim=1)
b_y_sqnorms = torch.sum(b_y**2, dim=1)
# a_x_sqnorms = torch.sum(a_x**2, dim=1)
# b_y_sqnorms = torch.sum(b_y**2, dim=1)

gamma = 1 / (2 * _SIGMA**2)
# gamma = 1 / (2 * _SIGMA**2)

k_xx = torch.mean(torch.exp(-gamma * (-2 * (a_x @ a_x.T) + a_x_sqnorms[:, None] + a_x_sqnorms[None, :])))
k_xy = torch.mean(torch.exp(-gamma * (-2 * (a_x @ b_y.T) + a_x_sqnorms[:, None] + b_y_sqnorms[None, :])))
k_yy = torch.mean(torch.exp(-gamma * (-2 * (b_y @ b_y.T) + b_y_sqnorms[:, None] + b_y_sqnorms[None, :])))
# k_xx = torch.mean(torch.exp(-gamma * (-2 * (a_x @ a_x.T) + a_x_sqnorms[:, None] + a_x_sqnorms[None, :])))
# k_xy = torch.mean(torch.exp(-gamma * (-2 * (a_x @ b_y.T) + a_x_sqnorms[:, None] + b_y_sqnorms[None, :])))
# k_yy = torch.mean(torch.exp(-gamma * (-2 * (b_y @ b_y.T) + b_y_sqnorms[:, None] + b_y_sqnorms[None, :])))

return _SCALE * (k_xx + k_yy - 2 * k_xy)
# return _SCALE * (k_xx + k_yy - 2 * k_xy)


def compute_pann_mmd_loss(audio_paths: list[str], speaker: str = "p374"):
# def compute_pann_mmd_loss(audio_paths: list[str], speaker: str = "p374"):

n_samples = len(audio_paths)
waveforms = [load_wav_file(fname, 32000) for fname in audio_paths]
# n_samples = len(audio_paths)
# waveforms = [load_wav_file(fname, 32000) for fname in audio_paths]

embeddings = []
for audio in tqdm(waveforms):
audio = torch.Tensor(audio).cuda()
embedding = pann_model.get_embedding(audio[None])[0]
embeddings.append(embedding)
embeddings = torch.stack(embeddings, dim=0)
embeddings_reverse = embeddings.flip(0)
embeddings_cat = torch.cat([embeddings, embeddings_reverse], dim=1).reshape(-1, embeddings.shape[-1])

mmd_losses = []
for idx in range(n_samples):
sampled_embeddings = embeddings_cat[idx: idx + 16]
mmd = compute_mmd(speaker_pann_embeds[speaker], sampled_embeddings)
mmd_losses.append(mmd.item())

return mmd_losses
# embeddings = []
# for audio in tqdm(waveforms):
# audio = torch.Tensor(audio).cuda()
# embedding = pann_model.get_embedding(audio[None])[0]
# embeddings.append(embedding)
# embeddings = torch.stack(embeddings, dim=0)
# embeddings_reverse = embeddings.flip(0)
# embeddings_cat = torch.cat([embeddings, embeddings_reverse], dim=1).reshape(-1, embeddings.shape[-1])

# mmd_losses = []
# for idx in range(n_samples):
# sampled_embeddings = embeddings_cat[idx: idx + 16]
# mmd = compute_mmd(speaker_pann_embeds[speaker], sampled_embeddings)
# mmd_losses.append(mmd.item())

# return mmd_losses

# =================== Anti Spoofing loss ===================
speaker_antispoofing_embeds = torch.load(os.path.join(script_dir, "rawnet/antispoofing_embeds.pth"), map_location="cuda")
@@ -292,6 +293,49 @@ def compute_antispoofing_loss(audio_paths: list[str], batch_size: int = 16, spea
)
return distance.min(dim=1).values.cpu().tolist()

# =================== RaterJudger ==============
rater_judger = RaterJudger().cuda()
model_name = 'myshell-test/judge_239'
model_path = model_name.replace('/', '_')
temp_location = hf_hub_download(repo_id=model_name, repo_type='model', filename='checkpoint.pth', local_dir=model_path)
rd_checkpoint = torch.load(
    os.path.join(script_dir, "judge_239.pth"), map_location="cuda"
)
rater_judger.load_state_dict(rd_checkpoint["model"], strict=True)
rater_judger.eval()
def compute_judger_loss(audio_paths):
    # waveforms = [load_wav_file(fname, hps.data.sampling_rate) for fname in audio_paths]
    waveforms = []
    for f in audio_paths:
        audio = load_wav_file(f, 44100)
        audio = quick_pad(audio, 16384)
        audio = random_crop(audio, 16384)
        waveforms.append(audio)
    batch_size = 16

    gs = []

    for y in tqdm(
        batched(waveforms, batch_size), total=math.ceil(len(waveforms) / batch_size)
    ):
        with torch.inference_mode():
            y = torch.stack(y)
            y = y.to(next(rater_judger.parameters()).device)
            y_hat_mel = mel_spectrogram_torch(
                y.squeeze(1),
                2048,
                128,
                44100,
                512,
                2048,
                0,
                None,
            )
            g = rater_judger(y_hat_mel)
            gs.append(g.detach())
    gs = torch.cat(gs)
    return gs.squeeze_(1).tolist()
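
For reference, the positional arguments to `mel_spectrogram_torch` above presumably follow the usual VITS-style order `(n_fft=2048, num_mels=128, sampling_rate=44100, hop_size=512, win_size=2048, fmin=0, fmax=None)`; this is an assumption, so check `tts_rater/mel_processing.py` to confirm. A minimal usage sketch, assuming a CUDA device (since `rater_judger` is moved to GPU above) and a folder of generated 44.1 kHz clips:

```python
# Sketch: score a folder of generated clips with the judge model.
# "samples/" is a hypothetical directory of wav files produced by a miner's TTS model.
import glob

wav_files = sorted(glob.glob("samples/*.wav"))
judge_scores = compute_judger_loss(wav_files)   # one raw (unnormalized) score per clip
print(len(judge_scores), min(judge_scores), max(judge_scores))
```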

# =================== Rate function ===================

# TODO: Read texts from Internet or use a larger dataset
@@ -300,7 +344,13 @@ def compute_antispoofing_loss(audio_paths: list[str], batch_size: int = 16, spea

def get_normalized_scores(raw_errs: dict[str, float]):
    # we adjust the normalization range to encourage miners to improve the antispoofing score
    score_ranges = {"pann_mmd": (50.0, 200.0), "word_error_rate": (0.0, 0.08), "tone_color": (0.15, 0.4), "antispoofing": (0.6, 1.5)}
    score_ranges = {
        # "pann_mmd": (50.0, 200.0),
        "word_error_rate": (0.04, 0.12),  # a more challenging dataset
        "tone_color": (0.15, 0.4),
        "antispoofing": (0.6, 1.5),
        "judge_scores": (-0.00236, -0.00244)
    }
    normalized_scores = {}
    for key, value in raw_errs.items():
        min_val, max_val = score_ranges[key]
@@ -309,22 +359,22 @@ def get_normalized_scores(raw_errs: dict[str, float]):
    return normalized_scores
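
The loop body is collapsed in this view. Assuming it applies a clipped min-max normalization of the form `(value - min_val) / (max_val - min_val)` (an assumption, not the literal elided code), the inverted `judge_scores` range above, where `min_val > max_val`, simply flips the direction of the mapping:

```python
# Sketch only: a plausible clipped min-max normalization, shown to illustrate why the
# judge_scores range is written with min_val > max_val (it reverses the mapping).
def _normalize(value, min_val, max_val):
    return max(0.0, min(1.0, (value - min_val) / (max_val - min_val)))

print(_normalize(-0.00236, -0.00236, -0.00244))  # 0.0 (raw score at min_val)
print(_normalize(-0.00240, -0.00236, -0.00244))  # ~0.5
print(_normalize(-0.00244, -0.00236, -0.00244))  # 1.0 (raw score at max_val)
```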


def compute_sharpe_ratios(scores: list[float]) -> list[float]:
# Jackknife estimate of the Sharpe ratio.
n = len(scores)
sharpe_ratios = []
for ii in range(n):
scores_jack = scores[:ii] + scores[ii + 1 :]
mean_jack = np.mean(scores_jack)
std_jack = np.std(scores_jack, ddof=1)
sharpe_jack = mean_jack / std_jack
# def compute_sharpe_ratios(scores: list[float]) -> list[float]:
# # Jackknife estimate of the Sharpe ratio.
# n = len(scores)
# sharpe_ratios = []
# for ii in range(n):
# scores_jack = scores[:ii] + scores[ii + 1 :]
# mean_jack = np.mean(scores_jack)
# std_jack = np.std(scores_jack, ddof=1)
# sharpe_jack = mean_jack / std_jack

if mean_jack < 1e-6 and std_jack == 0.0:
sharpe_jack = 0.0
# if mean_jack < 1e-6 and std_jack == 0.0:
# sharpe_jack = 0.0

sharpe_ratios.append(sharpe_jack)
# sharpe_ratios.append(sharpe_jack)

return sharpe_ratios
# return sharpe_ratios


def rate(
@@ -372,8 +422,11 @@ def rate_(
model.tts_to_file(text, spkr, save_path, speed=1.0, quiet=True)

audio_paths = sorted(glob.glob(os.path.join(tmpdir, "*.wav")))

judge_scores = compute_judger_loss(audio_paths)


pann_mmds = compute_pann_mmd_loss(audio_paths, speaker)
# pann_mmds = compute_pann_mmd_loss(audio_paths, speaker)
total_errs, total_words = compute_wer(text_test, audio_paths, batch_size)
word_error_rates = []
for idxs in range(samples):
@@ -384,8 +437,15 @@ def compute_antispoofing_loss(audio_paths: list[str], batch_size: int = 16, spea

antispoofing_losses = compute_antispoofing_loss(audio_paths, batch_size, speaker)

assert len(pann_mmds) == len(word_error_rates) == len(tcs) == len(antispoofing_losses) == samples
raw_errs = {"pann_mmd": pann_mmds, "word_error_rate": word_error_rates, "tone_color": tcs, "antispoofing": antispoofing_losses}
# assert len(pann_mmds) == len(word_error_rates) == len(tcs) == len(antispoofing_losses) == samples == len(disc_scores)
assert len(word_error_rates) == len(tcs) == len(antispoofing_losses) == samples == len(judge_scores)
raw_errs = {
# "pann_mmd": pann_mmds,
"word_error_rate": word_error_rates,
"tone_color": tcs,
"antispoofing": antispoofing_losses,
"judge_scores": judge_scores,
}
norm_dict = get_normalized_scores(raw_errs)

keys = list(norm_dict.keys())
Binary file modified tts_rater/rawnet/antispoofing_embeds.pth
Binary file modified tts_rater/vec_gt.pth
