diff --git a/monai/networks/blocks/mednext_block.py b/monai/networks/blocks/mednext_block.py index 8d0f0433bd..d5420f92ed 100644 --- a/monai/networks/blocks/mednext_block.py +++ b/monai/networks/blocks/mednext_block.py @@ -21,6 +21,13 @@ all = ["MedNeXtBlock", "MedNeXtDownBlock", "MedNeXtUpBlock", "MedNeXtOutBlock"] +def get_conv_layer(spatial_dim: int = 3, transpose: bool = False): + if spatial_dim == 2: + return nn.ConvTranspose2d if transpose else nn.Conv2d + else: # spatial_dim == 3 + return nn.ConvTranspose3d if transpose else nn.Conv3d + + class MedNeXtBlock(nn.Module): def __init__( @@ -39,18 +46,9 @@ def __init__( self.do_res = use_residual_connection - assert dim in ["2d", "3d"] self.dim = dim - if self.dim == "2d": - conv = nn.Conv2d - normalized_shape = [in_channels, kernel_size, kernel_size] - grn_parameter_shape = (1, 1) - elif self.dim == "3d": - conv = nn.Conv3d - normalized_shape = [in_channels, kernel_size, kernel_size, kernel_size] - grn_parameter_shape = (1, 1, 1) - else: - raise ValueError("dim must be either '2d' or '3d'") + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) + grn_parameter_shape = (1,) * (2 if dim == "2d" else 3) # First convolution layer with DepthWise Convolutions self.conv1 = conv( in_channels=in_channels, @@ -63,9 +61,11 @@ def __init__( # Normalization Layer. GroupNorm is used by default. if norm_type == "group": - self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) + self.norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels) # type: ignore elif norm_type == "layer": - self.norm = nn.LayerNorm(normalized_shape=normalized_shape) + self.norm = nn.LayerNorm( + normalized_shape=[in_channels] + [kernel_size] * (2 if dim == "2d" else 3) # type: ignore + ) # Second convolution (Expansion) layer with Conv3D 1x1x1 self.conv2 = conv( in_channels=in_channels, out_channels=expansion_ratio * in_channels, kernel_size=1, stride=1, padding=0 @@ -131,10 +131,7 @@ def __init__( grn=grn, ) - if dim == "2d": - conv = nn.Conv2d - else: - conv = nn.Conv3d + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3) self.resample_do_res = use_residual_connection if use_residual_connection: self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) @@ -186,10 +183,7 @@ def __init__( self.resample_do_res = use_residual_connection self.dim = dim - if dim == "2d": - conv = nn.ConvTranspose2d - else: - conv = nn.ConvTranspose3d + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) if use_residual_connection: self.res_conv = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=2) @@ -228,10 +222,7 @@ class MedNeXtOutBlock(nn.Module): def __init__(self, in_channels, n_classes, dim): super().__init__() - if dim == "2d": - conv = nn.ConvTranspose2d - else: - conv = nn.ConvTranspose3d + conv = get_conv_layer(spatial_dim=2 if dim == "2d" else 3, transpose=True) self.conv_out = conv(in_channels, n_classes, kernel_size=1) def forward(self, x): diff --git a/monai/networks/nets/mednext.py b/monai/networks/nets/mednext.py index 4de4b7b442..bb1d6b534c 100644 --- a/monai/networks/nets/mednext.py +++ b/monai/networks/nets/mednext.py @@ -72,8 +72,8 @@ def __init__( init_filters: int = 32, in_channels: int = 1, out_channels: int = 2, - encoder_expansion_ratio: int = 2, - decoder_expansion_ratio: int = 2, + encoder_expansion_ratio: Sequence[int] | int = 2, + decoder_expansion_ratio: Sequence[int] | int = 2, bottleneck_expansion_ratio: int = 2, kernel_size: int = 7, deep_supervision: bool = False, @@ -212,7 +212,7 @@ def __init__( out_blocks.reverse() self.out_blocks = nn.ModuleList(out_blocks) - def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: + def forward(self, x: torch.Tensor) -> torch.Tensor | Sequence[torch.Tensor]: """ Forward pass of the MedNeXt model. @@ -227,7 +227,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | list[torch.Tensor]: x (torch.Tensor): Input tensor. Returns: - torch.Tensor or list[torch.Tensor]: Output tensor(s). + torch.Tensor or Sequence[torch.Tensor]: Output tensor(s). """ # Apply stem convolution x = self.stem(x) @@ -311,7 +311,7 @@ def create_mednext( blocks_down=(2, 2, 2, 2), blocks_bottleneck=2, blocks_up=(2, 2, 2, 2), - **common_args, + **common_args, # type: ignore ) elif variant.upper() == "B": return MedNeXt( @@ -321,7 +321,7 @@ def create_mednext( blocks_down=(2, 2, 2, 2), blocks_bottleneck=2, blocks_up=(2, 2, 2, 2), - **common_args, + **common_args, # type: ignore ) elif variant.upper() == "M": return MedNeXt( @@ -331,7 +331,7 @@ def create_mednext( blocks_down=(3, 4, 4, 4), blocks_bottleneck=4, blocks_up=(4, 4, 4, 3), - **common_args, + **common_args, # type: ignore ) elif variant.upper() == "L": return MedNeXt( @@ -341,7 +341,7 @@ def create_mednext( blocks_down=(3, 4, 8, 8), blocks_bottleneck=8, blocks_up=(8, 8, 4, 3), - **common_args, + **common_args, # type: ignore ) else: raise ValueError(f"Invalid MedNeXt variant: {variant}") diff --git a/tests/test_mednext.py b/tests/test_mednext.py index 4dca898f4a..b4ba4f9939 100644 --- a/tests/test_mednext.py +++ b/tests/test_mednext.py @@ -59,7 +59,7 @@ for spatial_dims in range(2, 4): for out_channels in [1, 2]: test_case = [ - model, + model, # type: ignore {"spatial_dims": spatial_dims, "in_channels": 1, "out_channels": out_channels}, (2, 1, *([16] * spatial_dims)), (2, out_channels, *([16] * spatial_dims)),