From d9dda8a8e40e225b80d5b92f26d27a81fc42c8f4 Mon Sep 17 00:00:00 2001
From: Tomoya Kose
Date: Sun, 15 Sep 2024 14:16:20 +0900
Subject: [PATCH] Move WithResample layer to torch_wae.network.

---
 src/torch_wae/cli/export_to_onnx.py | 16 +---------------
 src/torch_wae/network.py            | 14 ++++++++++++++
 2 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/torch_wae/cli/export_to_onnx.py b/src/torch_wae/cli/export_to_onnx.py
index 13526b7..54818c1 100644
--- a/src/torch_wae/cli/export_to_onnx.py
+++ b/src/torch_wae/cli/export_to_onnx.py
@@ -4,9 +4,8 @@
 
 import torch
 import typer
-from torchaudio import functional as F
 
-from torch_wae.network import WAENet
+from torch_wae.network import WAENet, WithResample
 
 app = typer.Typer()
 
@@ -51,18 +50,5 @@ def main(
     )
 
 
-class WithResample(torch.nn.Module):
-    def __init__(self, f: WAENet, sample_rate: int) -> None:
-        super().__init__()
-
-        self.f = f
-        self.sample_rate = sample_rate
-
-    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
-        h = F.resample(waveform, self.sample_rate, self.f.SAMPLE_RATE)
-        z = self.f(h)
-        return z
-
-
 if __name__ == "__main__":
     app()
diff --git a/src/torch_wae/network.py b/src/torch_wae/network.py
index b951389..6402416 100644
--- a/src/torch_wae/network.py
+++ b/src/torch_wae/network.py
@@ -4,6 +4,7 @@
 from convmelspec.stft import ConvertibleSpectrogram as Spectrogram
 from torch import nn
 from torch.nn import functional as F
+from torchaudio import functional as FA
 
 
 # Wowrd Audio Encoder - A network for audio similar to MobileNet V2 for images.
@@ -137,3 +138,16 @@ def __init__(self, dim: int = 1, eps: float = 1e-12):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.normalize(x, p=2, dim=self.dim, eps=self.eps)
+
+
+class WithResample(torch.nn.Module):
+    """Resample input from `sample_rate` to `f.SAMPLE_RATE`, then apply `f`."""
+
+    def __init__(self, f: WAENet, sample_rate: int) -> None:
+        super().__init__()
+        self.f = f
+        self.sample_rate = sample_rate
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        resampled = FA.resample(waveform, self.sample_rate, self.f.SAMPLE_RATE)
+        return self.f(resampled)