# Post-patch methods from the hunk touching models/codec/ns3_codec/facodec.py.
# They belong to a class whose header lies outside this excerpt, so each one
# still takes ``self`` as its first argument.


def _pad_data(self, x):
    """Right-pad ``x`` so its last dimension is a multiple of the hop length.

    Args:
        x: Tensor to pad along its last dimension (presumably the audio
            sample axis -- confirm against callers).

    Returns:
        ``x`` extended with zeros at the end of its last dimension, up to
        the next multiple of ``self.hop_length``. Nothing is appended when
        the length is already aligned.
    """
    # ``-n % hop`` is the distance from n up to the next multiple of hop,
    # and is 0 when n is already a multiple (for a positive hop length).
    pad_size = -x.size(-1) % self.hop_length
    return F.pad(x, (0, pad_size))


def forward(self, x):
    """Pad ``x`` to a hop-length multiple and run it through ``self.block``."""
    return self.block(self._pad_data(x))


def inference(self, x):
    """Apply the same pad-then-``self.block`` pipeline as ``forward``."""
    return self.block(self._pad_data(x))


def get_prosody_feature(self, x):
    """Return the first 20 channels of the mel transform of padded ``x``.

    ``x`` is padded to a hop-length multiple, dimension 1 is squeezed
    away, and ``self.mel_transform`` is applied to the result.

    NOTE(review): the slice keeps indices 0..19 along dimension 1 of the
    transform output -- presumably the lowest mel bins; confirm.
    """
    mel = self.mel_transform(self._pad_data(x).squeeze(1))
    return mel[:, :20, :]