diff --git a/padertorch/contrib/mk/modules/features/timefreq.py b/padertorch/contrib/mk/modules/features/timefreq.py
index d77305b2..6b2c6539 100644
--- a/padertorch/contrib/mk/modules/features/timefreq.py
+++ b/padertorch/contrib/mk/modules/features/timefreq.py
@@ -19,7 +19,7 @@
 from padertorch.contrib.mk.typing import TSeqLen, TSeqReturn
 
 __all__ = [
-    'Sequential',
+    'Identity',
     'Logarithm',
     'STFT',
     'MelTransform',
@@ -27,7 +27,7 @@
 ]
 
 
-class Sequential(pt.Module):
+class Identity(pt.Module):
     def forward(
         self, x: Tensor, sequence_lengths: TSeqLen = None
     ) -> TSeqReturn:
@@ -71,6 +71,7 @@ def forward(self, x: Tensor) -> Tensor:
         x = self.log_fn(torch.maximum(
             torch.tensor(self.eps).to(x.device), x
         ))
+        return x
 
     def inverse(self, x: torch.Tensor) -> torch.Tensor:
         return self.power_fn(x)
@@ -88,6 +89,8 @@ class STFT(_STFT):
         pad (bool): See paderbox.transform.module_stft.stft.
         symmetric_window (bool): See paderbox.transform.module_stft.stft.
         complex_representation (str): See padertorch.ops._stft.STFT.
+        preemphasis (float, optional): If not None, apply pre-emphasis with
+            this value to the input signals. Defaults to None.
         spectrogram (bool): If True, return the magnitude spectrogram.
             Defaults to False.
         power (float): If `spectrogram` is True, raise magnitude to `power`.
@@ -97,6 +100,8 @@ class STFT(_STFT):
             paderbox.transform.module_fbank.fbank. Defaults to False.
         log_base (str, int, float, bool, optional): See Logarithm.
             Defaults to False.
+        sequence_last (bool): If True, move the sequence axis to the last
+            position. Defaults to True.
        normalization (InputNormalization, optional): InputNormalization
            instance to perform z-normalization. Defaults to None.
    """
@@ -111,10 +116,12 @@ def __init__(
         pad: bool = True,
         symmetric_window: bool = False,
         complex_representation: str = 'complex',
+        preemphasis: tp.Optional[float] = None,
         spectrogram: bool = False,
         power: float = 1.,
         scale_spec: bool = False,
         log_base: tp.Union[None, str, int, float, bool] = False,
+        sequence_last: bool = True,
         normalization: tp.Union[InputNormalization, None] = None,
     ):
         if not spectrogram and log_base:
@@ -134,12 +141,48 @@ def __init__(
         # Keep references to window and symmetric_window
         self.window = window
         self.symmetric_window = symmetric_window
+
+        if preemphasis is not None:
+            try:
+                from torchaudio.transforms import Preemphasis
+            except ImportError as e:
+                try:
+                    import torchaudio
+                except ImportError as e2:
+                    raise ImportError(
+                        "You need to install torchaudio>=2.0.1 to use "
+                        "pre-emphasis."
+                    ) from e2
+                raise ImportError(
+                    f"Your torchaudio version ({torchaudio.__version__}) "
+                    "does not support pre-emphasis. If you want to use "
+                    "pre-emphasis, install torchaudio>=2.0.1."
+                ) from e
+            self.preemphasis = Preemphasis(preemphasis)
+        else:
+            self.preemphasis = None
         self.spectrogram = spectrogram
         self.power = power
         self.scale_spec = scale_spec
         self.log = Logarithm(log_base=log_base)
+        self.sequence_last = sequence_last
         self.normalization = normalization
 
+    def to_spectrogram(self, stft_signal: Tensor) -> Tensor:
+        if self.complex_representation == 'complex':
+            spect = torch.abs(stft_signal)
+        elif self.complex_representation == 'stacked':
+            spect = stft_signal.pow(2).sum(-1).sqrt()
+        else:
+            real, imag = torch.split(
+                stft_signal, stft_signal.shape[-1] // 2, dim=-1
+            )
+            spect = (real.pow(2) + imag.pow(2)).sqrt()
+        spect = spect ** self.power
+        if self.scale_spec:
+            spect /= self.size
+        return spect
+
     def __call__(
         self,
         inputs: Tensor,
@@ -154,34 +197,35 @@ def __call__(
                 time signals in `inputs`. Defaults to None.
 
         Returns:
-            encoded (Tensor): Spectrogram of shape (batch, time, bins) if
+            encoded (Tensor): Spectrogram of shape (batch, bins, time) if
                 `spectrogram` is True else STFT of shape
-                - (batch, time, bins) if `complex_representation` is 'complex',
-                - (batch, time, bins, 2) if `complex_representation` is 'stacked', or
-                - (batch, channels, time, 2*bins) if `complex_representation` is 'concat'.
+                - (batch, bins, time) if `complex_representation` is 'complex',
+                - (batch, bins, time, 2) if `complex_representation` is 'stacked', or
+                - (batch, channels, 2*bins, time) if `complex_representation` is 'concat'.
+                If `sequence_last` is False, the time and bins axes are swapped.
             sequence_lengths (list, optional): List of number of frames of
                 spectrograms in `encoded` if input `sequence_lengths` is not
                 None.
         """
+        if self.preemphasis is not None:
+            inputs = self.preemphasis(inputs)
+
         encoded = super().__call__(inputs)
         if self.spectrogram:
-            if self.complex_representation == 'complex':
-                encoded = torch.abs(encoded)
-            elif self.complex_representation == 'stacked':
-                encoded = encoded.pow(2).sum(-1).sqrt()
-            else:
-                real, imag = torch.split(
-                    encoded, encoded.shape[-1] // 2, dim=-1
-                )
-                encoded = (real.pow(2) + imag.pow(2)).sqrt()
-            encoded = encoded ** self.power
-            if self.scale_spec:
-                encoded /= self.size
+            encoded = self.to_spectrogram(encoded)
             encoded = self.log(encoded)
         if sequence_lengths is not None:
             sequence_lengths = self.samples_to_frames(
                 np.asarray(sequence_lengths)
             )
+        if self.sequence_last:
+            if (
+                self.complex_representation == 'stacked'
+                and not self.spectrogram
+            ):
+                encoded = encoded.transpose(-2, -3)
+            else:
+                encoded = encoded.transpose(-2, -1)  # (..., bins, time)
         if self.normalization is not None:
             encoded = self.normalization(
                 encoded, sequence_lengths=sequence_lengths
@@ -235,6 +279,8 @@ class MelTransform(pt.Module):
             filter banks are used.
         squeeze_channel_axis (bool): If True, squeeze the channel axis and
             always return a 3D tensor. Defaults to False.
+        sequence_last (bool): If True, move the sequence axis to the last
+            position. Defaults to True.
        normalization (InputNormalization, optional): InputNormalization
            instance to perform z-normalization. Defaults to None.
    """
@@ -254,6 +300,7 @@ def __init__(
         warping_fn=None,
         independent_axis: tp.Union[int, tp.Sequence[int]] = 0,
         squeeze_channel_axis: bool = False,
+        sequence_last: bool = True,
         normalization: tp.Union[InputNormalization, None] = None,
     ):
         super().__init__()
@@ -261,10 +308,6 @@ def __init__(
 
         self.stft_size = stft_size
         self.stft = stft
-        if self.stft is not None and not self.stft.spectrogram:
-            raise ValueError(
-                f'stft.spectrogram must be True but is {stft.spectrogram}'
-            )
 
         self.number_of_filters = number_of_filters
         self.lowest_frequency = lowest_frequency
@@ -313,6 +356,7 @@ def __init__(
         )
 
         self.squeeze_channel_axis = squeeze_channel_axis
+        self.sequence_last = sequence_last
         self.normalization = normalization
 
     @classmethod
@@ -322,6 +366,7 @@ def finalize_dogmatic_config(cls, config):
             'window': 'hann',
             'spectrogram': True,
             'size': config['stft_size'],
+            'sequence_last': False,
         }
 
     def _normalize(self, mel_basis):
@@ -365,7 +410,8 @@ def forward(
                 spectrograms or number of samples in `x`.
         Returns:
             x (Tensor): Mel spectrogram of shape
-                (batch, ..., time, number_of_filters).
+                (batch, ..., number_of_filters, time). If `sequence_last` is
+                False, the time and number_of_filters axes are swapped.
             sequence_lengths (list, optional): List of number of frames of mel
                 spectrograms in `x` if input `sequence_lengths` is not None.
         """
@@ -373,6 +419,10 @@ def forward(
 
         if self.stft is not None:
             x, sequence_lengths = self.stft(x, sequence_lengths)
+            if not self.stft.spectrogram:
+                x = self.stft.to_spectrogram(x)
+            if self.stft.sequence_last:
+                x = x.transpose(-2, -1)
 
         if not self.training or self.warping_fn is None:
             x = torch.matmul(x, self.mel_basis.to(x.device))
@@ -410,6 +460,9 @@ def forward(
         if x.ndim == 4 and self.squeeze_channel_axis:
             x = x.squeeze(1)
 
+        if self.sequence_last:
+            x = x.transpose(-2, -1)  # (..., bins, time)
+
         if self.normalization is not None:
             x = self.normalization(x, sequence_lengths=sequence_lengths)
 
@@ -427,7 +480,7 @@
 class MFCC(pt.Module):
     def __init__(
         self,
-        number_of_filters: int,
+        number_of_bins: int,
         transform: tp.Optional[tp.Union[MelTransform, STFT]] = None,
         axis: int = -1,
         channel_axis: int = 1,
@@ -439,24 +492,28 @@ def __init__(
         """Extract mel-cepstral coefficients from audio.
 
         Args:
-            number_of_filters: Number of filters in the filterbank.
-            mel_transform: Optional `MelTransform` instance. If not None,
-                expect time signal as input and compute the log (mel)
+            number_of_bins (int): Number of frequency bins in the
+                time-frequency representation.
+            transform: Optional `MelTransform` or `STFT` instance. If not
+                None, expect time signal as input and compute the log (mel)
                 spectrogram before extracting the cepstral coefficients.
-            axis: Position of the frequency axis.
-            channel_axis: Position of the channel axis. Can be set to None if
-                the input has no channel axis.
-            num_cep: Number of cepstral coefficients to keep. If None, all
-                coefficients are kept.
-            low_pass: If True and `num_cep` is not None, keep the lowest
+            axis (int): Position of the frequency axis. Defaults to -1.
+            channel_axis (int): Position of the channel axis. Can be set to
+                None if the input has no channel axis. Defaults to 1.
+            num_cep (int, optional): Number of cepstral coefficients to keep.
+                If None, all coefficients are kept. Defaults to None.
+            low_pass (bool): If True and `num_cep` is not None, keep the lowest
                 `num_cep` coefficients and discard the rest (default behavior).
                 If False, keep the highest `number_of_filters-num_cep`
-                coefficients (high-pass behavior).
-            lifter_coeff: Liftering in the cepstral domain. See
+                coefficients (high-pass behavior). Defaults to True.
+            lifter_coeff (int): Liftering in the cepstral domain. See
                 `paderbox.transform.module_mfcc`. If 0, no liftering is applied.
+                Defaults to 0.
+            normalization (InputNormalization, optional): InputNormalization
+                instance to perform z-normalization. Defaults to None.
         """
         super().__init__()
-        self.number_of_filters = number_of_filters
+        self.number_of_bins = number_of_bins
         self.transform = transform
         self.axis = axis
         self.channel_axis = channel_axis
@@ -543,7 +600,7 @@ def inverse(self, x_mfcc: Tensor) -> Tensor:
         if self.num_cep is not None:
             shape = list(x_mfcc.shape)
             if self.low_pass:
-                shape[self.axis] = self.number_of_filters - self.num_cep
+                shape[self.axis] = self.number_of_bins - self.num_cep
             x_mfcc = torch.cat(
                 (x_mfcc, torch.zeros(shape).to(x_mfcc.device)),
                 dim=self.axis
@@ -556,7 +613,7 @@ def inverse(self, x_mfcc: Tensor) -> Tensor:
             )
         spect = torch.index_select(
             torch.fft.irfft(x_mfcc, axis=self.axis, norm='ortho'),
-            self.axis, torch.arange(self.number_of_filters).to(x_mfcc.device)
+            self.axis, torch.arange(self.number_of_bins).to(x_mfcc.device)
         )
         return spect
 
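
Note (not part of the patch): the sketch below illustrates how the new `preemphasis` and `sequence_last` options of `STFT` might be exercised. It assumes this module and, for pre-emphasis, torchaudio>=2.0.1 are installed; the `size`/`shift` keyword names and the tuple return of `__call__` follow the docstrings and call sites shown in the diff, while the concrete values (512, 128, 0.97, signal lengths) are illustrative only.

import torch

from padertorch.contrib.mk.modules.features.timefreq import STFT

# Hypothetical front end: magnitude spectrogram with the two new options.
# `preemphasis` is forwarded to torchaudio.transforms.Preemphasis, so
# torchaudio>=2.0.1 is required.
stft = STFT(
    size=512,              # assumed keyword, as in finalize_dogmatic_config
    shift=128,             # assumed keyword following padertorch's STFT
    spectrogram=True,      # return the magnitude spectrogram
    preemphasis=0.97,      # new in this patch: pre-emphasis coefficient
    sequence_last=True,    # new in this patch: output layout (..., bins, time)
)

signals = torch.randn(2, 16000)  # (batch, samples)
encoded, frame_lengths = stft(signals, sequence_lengths=[16000, 12000])
# Per the docstring above, `encoded` has shape (batch, bins, time) because
# `spectrogram=True` and `sequence_last=True`, and `frame_lengths` holds the
# number of frames per example.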
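
Similarly hedged, a minimal sketch of the new `STFT.to_spectrogram` helper, which the patch factors out of `__call__` and which `MelTransform.forward` now calls when its `stft` front end returns a raw (non-spectrogram) STFT; again, only the arguments shown in the diff are used and everything else is assumed to keep its defaults.

import torch

from padertorch.contrib.mk.modules.features.timefreq import STFT

stft = STFT(complex_representation='complex')  # raw complex STFT, no magnitude
signal = torch.randn(1, 16000)
stft_signal, _ = stft(signal, sequence_lengths=[16000])

# Convert the complex STFT to a magnitude spectrogram (raised to `power` and
# divided by `size` if `scale_spec` is set), independent of the chosen
# complex representation.
spect = stft.to_spectrogram(stft_signal)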