diff --git a/hippynn/layers/physics.py b/hippynn/layers/physics.py index 3debf31d..6c6282e4 100644 --- a/hippynn/layers/physics.py +++ b/hippynn/layers/physics.py @@ -247,7 +247,7 @@ def forward(self, features, species): class VecMag(torch.nn.Module): def forward(self, vector_feature): - return torch.norm(vector_feature, dim=1).unsqueeze(1) + return torch.norm(vector_feature, dim=1) class CombineEnergy(torch.nn.Module):