Skip to content

Commit

Permalink
change Transform3d.inverse to more efficient inverse. Add Transform3d…
Browse files Browse the repository at this point in the history
….transform_shape_operator function
  • Loading branch information
powertj committed Nov 2, 2023
1 parent 6d1fd2d commit 652642f
Showing 1 changed file with 42 additions and 7 deletions.
49 changes: 42 additions & 7 deletions src/pytorch_kinematics/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,20 @@ def _get_matrix_inverse(self):
"""
Return the inverse of self._matrix.
"""
return torch.inverse(self._matrix)

return self._invert_transformation_matrix(self._matrix)

@staticmethod
def _invert_transformation_matrix(T):
"""
Inverts homogeneous transformation matrix
"""
Tinv = T.clone()
R = T[:, :3, :3]
t = T[:, :3, 3]
Tinv[:, :3, :3] = R.transpose(1, 2)
Tinv[:, :3, 3:] = -Tinv[:, :3, :3] @ t.unsqueeze(-1)
return Tinv

def inverse(self, invert_composed: bool = False):
"""
Expand All @@ -293,15 +306,15 @@ def inverse(self, invert_composed: bool = False):
independently without composing them.
Returns:
A new Transform3D object contaning the inverse of the original
A new Transform3D object containing the inverse of the original
transformation.
"""

tinv = Transform3d(device=self.device)

if invert_composed:
# first compose then invert
tinv._matrix = torch.inverse(self.get_matrix())
tinv._matrix = self._invert_transformation_matrix(self.get_matrix())
else:
# self._get_matrix_inverse() implements efficient inverse
# of self._matrix
Expand Down Expand Up @@ -392,10 +405,7 @@ def transform_normals(self, normals):
if normals.dim() not in [2, 3]:
msg = "Expected normals to have dim = 2 or dim = 3: got shape %r"
raise ValueError(msg % (normals.shape,))
composed_matrix = self.get_matrix()

# TODO: inverse is bad! Solve a linear system instead
mat = composed_matrix[:, :3, :3]
mat = self._get_matrix_inverse()[:, :3, :3]
normals_out = _broadcast_bmm(normals, mat.inverse())

# This doesn't pass unit tests. TODO investigate further
Expand All @@ -410,6 +420,31 @@ def transform_normals(self, normals):

return normals_out

def transform_shape_operator(self, shape_operators):
"""
Use this transform to transform a set of shape_operator (or Weingarten map).
This is the hessian of a signed-distance, i.e. gradient of a normal vector.
Args:
shape_operators: Tensor of shape (P, 3, 3) or (N, P, 3, 3)
Returns:
shape_operators_out: Tensor of shape (P, 3, 3) or (N, P, 3, 3) depending
on the dimensions of the transform
"""
if shape_operators.dim() not in [3, 4]:
msg = "Expected shape_operators to have dim = 3 or dim = 4: got shape %r"
raise ValueError(msg % (shape_operators.shape,))
mat = self._get_matrix_inverse()[:, :3, :3]
shape_operators_out = _broadcast_bmm(mat.permute(0, 2, 1), _broadcast_bmm(shape_operators, mat))

# When transform is (1, 4, 4) and shape_operator is (P, 3, 3) return
# shape_operators_out of shape (P, 3, 3)
if shape_operators_out.shape[0] == 1 and shape_operators.dim() == 3:
shape_operators_out = shape_operators_out.reshape(shape_operators.shape)

return shape_operators_out

def translate(self, *args, **kwargs):
return self.compose(Translate(device=self.device, *args, **kwargs))

Expand Down

0 comments on commit 652642f

Please sign in to comment.