-
Notifications
You must be signed in to change notification settings - Fork 384
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support WavLM-based speaker similarity metric (#97)
WavLM-based speaker similarity: code and docs
- Loading branch information
1 parent
4cb808b
commit 17dc95a
Showing
4 changed files
with
60 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
# Copyright (c) 2023 Amphion. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
import os | ||
import librosa | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
from transformers import Wav2Vec2FeatureExtractor, WavLMForXVector | ||
|
||
|
||
def extract_wavlm_similarity(target_path, reference_path): | ||
"""Extract cosine similarity based on WavLM for two given audio folders. | ||
target_path: path to the ground truth audio folder. | ||
reference_path: path to the predicted audio folder. | ||
""" | ||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained( | ||
"microsoft/wavlm-base-plus-sv" | ||
) | ||
gpu = False | ||
model = WavLMForXVector.from_pretrained("microsoft/wavlm-base-plus-sv") | ||
if torch.cuda.is_available(): | ||
print("Cuda available, conducting inference on GPU") | ||
model = model.to("cuda") | ||
gpu = True | ||
|
||
similarity_scores = [] | ||
|
||
for file in tqdm(os.listdir(reference_path)): | ||
ref_wav_path = os.path.join(reference_path, file) | ||
tgt_wav_path = os.path.join(target_path, file) | ||
|
||
ref_wav, _ = librosa.load(ref_wav_path, sr=16000) | ||
tgt_wav, _ = librosa.load(tgt_wav_path, sr=16000) | ||
|
||
inputs = feature_extractor( | ||
[tgt_wav, ref_wav], padding=True, return_tensors="pt" | ||
) | ||
|
||
if gpu: | ||
for key in inputs.keys(): | ||
inputs[key] = inputs[key].cuda("cuda") | ||
|
||
with torch.no_grad(): | ||
embeddings = model(**inputs).embeddings | ||
embeddings = embeddings.cpu() | ||
cos_sim_score = F.cosine_similarity(embeddings[0], embeddings[1], dim=-1) | ||
similarity_scores.append(cos_sim_score.item()) | ||
|
||
return np.mean(similarity_scores) |