Skip to content

Commit

Permalink
Merge branch 'move-resampling-to-network'
Browse files Browse the repository at this point in the history
  • Loading branch information
mitsuse committed Sep 15, 2024
2 parents 42ea029 + d9dda8a commit d69b0d1
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
16 changes: 1 addition & 15 deletions src/torch_wae/cli/export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import torch
import typer
from torchaudio import functional as F

from torch_wae.network import WAENet
from torch_wae.network import WAENet, WithResample

app = typer.Typer()

Expand Down Expand Up @@ -51,18 +50,5 @@ def main(
)


class WithResample(torch.nn.Module):
def __init__(self, f: WAENet, sample_rate: int) -> None:
super().__init__()

self.f = f
self.sample_rate = sample_rate

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
h = F.resample(waveform, self.sample_rate, self.f.SAMPLE_RATE)
z = self.f(h)
return z


if __name__ == "__main__":
app()
14 changes: 14 additions & 0 deletions src/torch_wae/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from convmelspec.stft import ConvertibleSpectrogram as Spectrogram
from torch import nn
from torch.nn import functional as F
from torchaudio import functional as FA


# Wowrd Audio Encoder - A network for audio similar to MobileNet V2 for images.
Expand Down Expand Up @@ -137,3 +138,16 @@ def __init__(self, dim: int = 1, eps: float = 1e-12):

def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.normalize(x, p=2, dim=self.dim, eps=self.eps)


class WithResample(torch.nn.Module):
def __init__(self, f: WAENet, sample_rate: int) -> None:
super().__init__()

self.f = f
self.sample_rate = sample_rate

def forward(self, waveform: torch.Tensor) -> torch.Tensor:
h = FA.resample(waveform, self.sample_rate, self.f.SAMPLE_RATE)
z = self.f(h)
return z

0 comments on commit d69b0d1

Please sign in to comment.