diff --git a/src/torch_wae/network.py b/src/torch_wae/network.py index 984fcc0..b951389 100644 --- a/src/torch_wae/network.py +++ b/src/torch_wae/network.py @@ -41,18 +41,18 @@ def __init__(self, s: int) -> None: self.layers = nn.Sequential( # -------------------- - # shape: (1, 64, 64) -> (16, 32, 32) + # shape: (1, 64, 64) -> (8, 32, 32) # -------------------- - InvertedBottleneck(k=3, c_in=1, c_out=16 * s, stride=2), + InvertedBottleneck(k=3, c_in=1, c_out=8 * s, stride=2), # -------------------- - # shape: (16, 32, 32) -> (8, 32, 32) + # shape: (8, 32, 32) -> (12, 32, 32) # -------------------- - InvertedBottleneck(k=3, c_in=16 * s, c_out=8 * s, stride=1), - InvertedBottleneck(k=3, c_in=8 * s, c_out=8 * s, stride=1), + InvertedBottleneck(k=3, c_in=8 * s, c_out=12 * s, stride=1), + InvertedBottleneck(k=3, c_in=12 * s, c_out=12 * s, stride=1), # -------------------- - # shape: (8, 32, 32) -> (12, 16, 16) + # shape: (12, 32, 32) -> (12, 16, 16) # -------------------- - InvertedBottleneck(k=3, c_in=8 * s, c_out=12 * s, stride=2), + InvertedBottleneck(k=3, c_in=12 * s, c_out=12 * s, stride=2), InvertedBottleneck(k=3, c_in=12 * s, c_out=12 * s, stride=1), InvertedBottleneck(k=3, c_in=12 * s, c_out=12 * s, stride=1), # --------------------