Skip to content

Commit

Permalink
Support WavLM-based speaker similarity metric (#97)
Browse files Browse the repository at this point in the history
WavLM-based speaker similarity: code and docs
  • Loading branch information
HeCheng0625 committed Jan 11, 2024
1 parent 4cb808b commit 17dc95a
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ Amphion provides a comprehensive objective evaluation of the generated audio. Th
- **Energy Modeling**: Energy Root Mean Square Error, Energy Pearson Coefficients, etc.
- **Intelligibility**: Character/Word Error Rate, which can be calculated based on [Whisper](https://github.com/openai/whisper) and more.
- **Spectrogram Distortion**: Frechet Audio Distance (FAD), Mel Cepstral Distortion (MCD), Multi-Resolution STFT Distance (MSTFT), Perceptual Evaluation of Speech Quality (PESQ), Short Time Objective Intelligibility (STOI), etc.
- **Speaker Similarity**: Cosine similarity, which can be calculated based on [RawNet3](https://github.com/Jungjee/RawNet), [Resemblyzer](https://github.com/resemble-ai/Resemblyzer), [WeSpeaker](https://github.com/wenet-e2e/wespeaker), and more.
- **Speaker Similarity**: Cosine similarity, which can be calculated based on [RawNet3](https://github.com/Jungjee/RawNet), [Resemblyzer](https://github.com/resemble-ai/Resemblyzer), [WeSpeaker](https://github.com/wenet-e2e/wespeaker), [WavLM](https://github.com/microsoft/unilm/tree/master/wavlm) and more.

### Datasets

Expand Down
4 changes: 3 additions & 1 deletion bins/calc_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from evaluation.metrics.similarity.resemblyzer_similarity import (
extract_resemblyzer_similarity,
)
from evaluation.metrics.similarity.wavlm_similarity import extract_wavlm_similarity
from evaluation.metrics.spectrogram.frechet_distance import extract_fad
from evaluation.metrics.spectrogram.mel_cepstral_distortion import extract_mcd
from evaluation.metrics.spectrogram.multi_resolution_stft_distance import extract_mstft
Expand All @@ -52,6 +53,7 @@
"wer": extract_wer,
"rawnet3_similarity": extract_speaker_similarity,
"resemblyzer_similarity": extract_resemblyzer_similarity,
"wavlm_similarity": extract_wavlm_similarity,
"fad": extract_fad,
"mcd": extract_mcd,
"mstft": extract_mstft,
Expand All @@ -66,7 +68,7 @@ def calc_metric(ref_dir, deg_dir, dump_dir, metrics, fs=None):
result = defaultdict()

for metric in tqdm(metrics):
if metric in ["fad", "rawnet3_similarity"]:
if metric in ["fad", "rawnet3_similarity", "wavlm_similarity"]:
result[metric] = str(METRIC_FUNC[metric](ref_dir, deg_dir))
continue
elif metric in ["resemblyzer_similarity"]:
Expand Down
2 changes: 2 additions & 0 deletions egs/metrics/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Until now, Amphion Evaluation has supported the following objective metrics:
- **Speaker Similarity**:
- Cosine similarity based on [Rawnet3](https://github.com/Jungjee/RawNet)
- Cosine similarity based on [Resemblyzer](https://github.com/resemble-ai/Resemblyzer)
- Cosine similarity based on [WavLM](https://github.com/microsoft/unilm/tree/master/wavlm)
- Cosine similarity based on [WeSpeaker](https://github.com/wenet-e2e/wespeaker) (👨‍💻 developing)

We provide a recipe to demonstrate how to objectively evaluate your generated audios. There are three steps in total:
Expand Down Expand Up @@ -86,6 +87,7 @@ All currently available metrics keywords are listed below:
| `cer` | Character Error Rate |
| `wer` | Word Error Rate |
| `rawnet3_similarity` | Cos Similarity based on RawNet3 |
| `wavlm_similarity` | Cos Similarity based on WavLM |
| `resemblyzer_similarity` | Cos Similarity based on Resemblyzer |
| `fad` | Frechet Audio Distance |
| `mcd` | Mel Cepstral Distortion |
Expand Down
54 changes: 54 additions & 0 deletions evaluation/metrics/similarity/wavlm_similarity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
import os
import librosa
import numpy as np
from tqdm import tqdm

from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector


def extract_wavlm_similarity(target_path, reference_path):
"""Extract cosine similarity based on WavLM for two given audio folders.
target_path: path to the ground truth audio folder.
reference_path: path to the predicted audio folder.
"""
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
"microsoft/wavlm-base-plus-sv"
)
gpu = False
model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv")
if torch.cuda.is_available():
print("Cuda available, conducting inference on GPU")
model = model.to("cuda")
gpu = True

similarity_scores = []

for file in tqdm(os.listdir(reference_path)):
ref_wav_path = os.path.join(reference_path, file)
tgt_wav_path = os.path.join(target_path, file)

ref_wav, _ = librosa.load(ref_wav_path, sr=16000)
tgt_wav, _ = librosa.load(tgt_wav_path, sr=16000)

inputs = feature_extractor(
[tgt_wav, ref_wav], padding=True, return_tensors="pt"
)

if gpu:
for key in inputs.keys():
inputs[key] = inputs[key].cuda("cuda")

with torch.no_grad():
embeddings = model(**inputs).embeddings
embeddings = embeddings.cpu()
cos_sim_score = F.cosine_similarity(embeddings[0], embeddings[1], dim=-1)
similarity_scores.append(cos_sim_score.item())

return np.mean(similarity_scores)

0 comments on commit 17dc95a

Please sign in to comment.