From 63ef88b6a77714b75af1e159fcfee89e85062d83 Mon Sep 17 00:00:00 2001
From: matheusbach <35426162+matheusbach@users.noreply.github.com>
Date: Tue, 12 Dec 2023 19:36:37 -0300
Subject: [PATCH] fix n_samples, sample_rate and n_mels

---
 whisperx_utils.py | 36 +++++++++++++++---------------------
 1 file changed, 15 insertions(+), 21 deletions(-)

diff --git a/whisperx_utils.py b/whisperx_utils.py
index 0e8aecc..ba41ee7 100644
--- a/whisperx_utils.py
+++ b/whisperx_utils.py
@@ -1,11 +1,10 @@
 import os
 from pathlib import Path
-import re
-import pysrt
 
 import whisperx
 import whisper  # only for detect language
+import whisper_utils
 
 import subtitle_utils
 from utils import time_task
 
@@ -38,22 +37,17 @@ def transcribe_audio(model: whisperx.asr.WhisperModel, audio_path: Path, srt_pat
 
 
 def detect_language(model: whisperx.asr.WhisperModel, audio_path: Path):
-    # load audio and pad/trim it to fit 30 seconds
-    # audio = whisperx.load_audio(audio_path.as_posix(), 16000)
-    # segment = whisperx.asr.log_mel_spectrogram(audio[: whisperx.asr.N_SAMPLES], padding=0 if audio.shape[0]
-    #                                            >= whisperx.asr.N_SAMPLES else whisperx.asr.N_SAMPLES - audio.shape[0], device="cpu")
-    # encoder_output = model.model.encode(segment)
-    # results = model.model.model.detect_language(encoder_output).to("cpu")
-    # language_token, language_probability = results[0][0]
-    # return language_token[2:-2]
-
-    # ABOVE CODE IS BEST, BUT ITS NOT WORKING FOR NOW IN SOME SYSTEMS. SAVE FOR THE FUTURE
-
-    audio = whisper.load_audio(audio_path.as_posix(), 16000)
-    audio = whisper.pad_or_trim(audio, whisperx.asr.N_SAMPLES)
-    # make log-Mel spectrogram and move to the same device as the model
-    mel = whisper.log_mel_spectrogram(audio).to("cpu")
-    whisper_model = whisper.load_model("base", device="cpu", in_memory=True)
-    # detect the spoken language
-    _, probs = whisper_model.detect_language(mel)
-    return max(probs, key=probs.get)
+    try:
+        if os.getenv("COLAB_RELEASE_TAG"):
+            raise Exception("Method invalid for Google Colab")
+        audio = whisperx.load_audio(audio_path.as_posix(), model.model.feature_extractor.sampling_rate)
+        audio = whisper.pad_or_trim(audio, model.model.feature_extractor.n_samples)
+        mel = whisperx.asr.log_mel_spectrogram(audio, n_mels=model.model.model.n_mels)
+        encoder_output = model.model.encode(mel)
+        results = model.model.model.detect_language(encoder_output)
+        language_token, language_probability = results[0][0]
+        return language_token[2:-2]
+    except:
+        print("using whisper base model for detection: ", end='')
+        whisper_model = whisper.load_model("base", device="cpu", in_memory=True)
+        return whisper_utils.detect_language(model=whisper_model, audio_path=audio_path)
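
Notes:

The old fallback hardcoded the 16 kHz sample rate, whisperx.asr.N_SAMPLES
and the default 80 mel bins, which breaks language detection for checkpoints
such as large-v3 that expect 128 mel bins. The patched primary path reads
sampling_rate, n_samples and n_mels from the loaded model instead, and the
CPU fallback now delegates to whisper_utils.detect_language. That helper is
not part of this diff; below is a minimal sketch of what it presumably does,
following the same pad/trim + log-Mel + detect_language pattern as the code
removed above (the name and keyword arguments are taken from the call site;
the body is an assumption):

    import whisper

    def detect_language(model: whisper.Whisper, audio_path) -> str:
        # Hypothetical body for whisper_utils.detect_language.
        # Load the audio and pad/trim it to the 30-second window Whisper expects.
        audio = whisper.load_audio(str(audio_path))
        audio = whisper.pad_or_trim(audio)
        # Build the log-Mel spectrogram with the bin count this checkpoint was
        # trained on (model.dims.n_mels: 80 for base, 128 for large-v3) and
        # move it to the model's device.
        mel = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device)
        # detect_language returns (language_tokens, language_probs); keep the
        # most probable language code, e.g. "en".
        _, probs = model.detect_language(mel)
        return max(probs, key=probs.get)

Reading the spectrogram parameters off the loaded model keeps detection
correct for whichever checkpoint is in use; the COLAB_RELEASE_TAG guard skips
the ctranslate2 path on Google Colab, where it reportedly fails, and the bare
except routes any other failure to the same base-model fallback.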