Skip to content

Commit

Permalink
fix n_samples, sample_rate and n_mels
Browse files Browse the repository at this point in the history
  • Loading branch information
matheusbach committed Dec 13, 2023
1 parent bd6311f commit 63ef88b
Showing 1 changed file with 15 additions and 21 deletions.
36 changes: 15 additions & 21 deletions whisperx_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import os
from pathlib import Path
import re

import pysrt
import whisperx
import whisper # only for detect language

import whisper_utils
import subtitle_utils
from utils import time_task

Expand Down Expand Up @@ -38,22 +37,17 @@ def transcribe_audio(model: whisperx.asr.WhisperModel, audio_path: Path, srt_pat


def detect_language(model: whisperx.asr.WhisperModel, audio_path: Path):
# load audio and pad/trim it to fit 30 seconds
# audio = whisperx.load_audio(audio_path.as_posix(), 16000)
# segment = whisperx.asr.log_mel_spectrogram(audio[: whisperx.asr.N_SAMPLES], padding=0 if audio.shape[0]
# >= whisperx.asr.N_SAMPLES else whisperx.asr.N_SAMPLES - audio.shape[0], device="cpu")
# encoder_output = model.model.encode(segment)
# results = model.model.model.detect_language(encoder_output).to("cpu")
# language_token, language_probability = results[0][0]
# return language_token[2:-2]

# ABOVE CODE IS BEST, BUT ITS NOT WORKING FOR NOW IN SOME SYSTEMS. SAVE FOR THE FUTURE

audio = whisper.load_audio(audio_path.as_posix(), 16000)
audio = whisper.pad_or_trim(audio, whisperx.asr.N_SAMPLES)
# make log-Mel spectrogram and move to the same device as the model
mel = whisper.log_mel_spectrogram(audio).to("cpu")
whisper_model = whisper.load_model("base", device="cpu", in_memory=True)
# detect the spoken language
_, probs = whisper_model.detect_language(mel)
return max(probs, key=probs.get)
try:
if os.getenv("COLAB_RELEASE_TAG"):
raise Exception("Method invalid for Google Colab")
audio = whisperx.load_audio(audio_path.as_posix(), model.model.feature_extractor.sampling_rate)
audio = whisper.pad_or_trim(audio, model.model.feature_extractor.n_samples)
mel = whisperx.asr.log_mel_spectrogram(audio, n_mels=model.model.model.n_mels)
encoder_output = model.model.encode(mel)
results = model.model.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
return language_token[2:-2]
except:
print("using whisper base model for detection: ", end='')
whisper_model = whisper.load_model("base", device="cpu", in_memory=True)
return whisper_utils.detect_language(model=whisper_model, audio_path=audio_path)

0 comments on commit 63ef88b

Please sign in to comment.