Skip to content

Commit

Permalink
Merge pull request #27 from UM-ARM-Lab/optimize_transform3d
Browse files Browse the repository at this point in the history
Remove laziness to speed up computations by about 25%
  • Loading branch information
LemonPi authored Nov 27, 2023
2 parents b8b07db + 0ebabfc commit 31f1422
Showing 1 changed file with 12 additions and 64 deletions.
76 changes: 12 additions & 64 deletions src/pytorch_kinematics/transforms/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,6 @@ def __init__(
rot_h = torch.cat((rot, zeros), dim=-2).reshape(-1, 4, 3)
self._matrix = torch.cat((rot_h, self._matrix[:, :, 3].reshape(-1, 4, 1)), dim=-1)

self._transforms = [] # store transforms to compose
self._lu = None
self.device = device
self.dtype = self._matrix.dtype
Expand Down Expand Up @@ -240,37 +239,19 @@ def compose(self, *others):
Returns:
A new Transform3d with the stored transforms
"""
out = Transform3d(device=self.device, dtype=self.dtype)
out._matrix = self._matrix.clone()

mat = self._matrix
for other in others:
if not isinstance(other, Transform3d):
msg = "Only possible to compose Transform3d objects; got %s"
raise ValueError(msg % type(other))
out._transforms = self._transforms + list(others)
mat = _broadcast_bmm(mat, other.get_matrix())

out = Transform3d(device=self.device, dtype=self.dtype, matrix=mat)
return out

def get_matrix(self):
    """
    Return the Nx4x4 homogeneous transformation matrix represented by this object.

    No composition or other computation happens here; the stored batch of
    matrices is returned directly (not a copy), so mutating the result
    mutates this transform.

    Returns:
        Tensor of shape (N, 4, 4) holding the homogeneous transforms.
    """
    return self._matrix

def _get_matrix_inverse(self):
"""
Expand All @@ -282,7 +263,7 @@ def _get_matrix_inverse(self):
@staticmethod
def _invert_transformation_matrix(T):
"""
Inverts homogeneous transformation matrix
Invert homogeneous transformation matrix.
"""
Tinv = T.clone()
R = T[:, :3, :3]
Expand All @@ -297,54 +278,23 @@ def inverse(self, invert_composed: bool = False):
current transformation.
Args:
invert_composed:
- True: First compose the list of stored transformations
and then apply inverse to the result. This is
potentially slower for classes of transformations
with inverses that can be computed efficiently
(e.g. rotations and translations).
- False: Invert the individual stored transformations
independently without composing them.
invert_composed: ignored, included for backwards compatibility
Returns:
A new Transform3D object containing the inverse of the original
transformation.
"""

tinv = Transform3d(device=self.device)
i_matrix = self._get_matrix_inverse()

if invert_composed:
# first compose then invert
tinv._matrix = self._invert_transformation_matrix(self.get_matrix())
else:
# self._get_matrix_inverse() implements efficient inverse
# of self._matrix
i_matrix = self._get_matrix_inverse()

# 2 cases:
if len(self._transforms) > 0:
# a) Either we have a non-empty list of transforms:
# Here we take self._matrix and append its inverse at the
# end of the reverted _transforms list. After composing
# the transformations with get_matrix(), this correctly
# right-multiplies by the inverse of self._matrix
# at the end of the composition.
tinv._transforms = [t.inverse() for t in reversed(self._transforms)]
last = Transform3d(device=self.device)
last._matrix = i_matrix
tinv._transforms.append(last)
else:
# b) Or there are no stored transformations
# we just set inverted matrix
tinv._matrix = i_matrix
tinv = Transform3d(matrix=i_matrix, device=self.device)

return tinv

def stack(self, *others):
    """
    Stack this transform together with *others* along the batch dimension.

    Args:
        *others: additional Transform3d objects whose matrices are
            concatenated after this one's.

    Returns:
        A new Transform3d whose matrix is the (sum-of-N, 4, 4) concatenation
        of all the input transforms' matrices, on this transform's device
        and dtype.
    """
    all_transforms = (self, *others)
    stacked = torch.cat(tuple(t._matrix for t in all_transforms), dim=0)
    return Transform3d(matrix=stacked, device=self.device, dtype=self.dtype)

def transform_points(self, points, eps: Optional[float] = None):
Expand Down Expand Up @@ -478,7 +428,6 @@ def clone(self):
if self._lu is not None:
other._lu = [elem.clone() for elem in self._lu]
other._matrix = self._matrix.clone()
other._transforms = [t.clone() for t in self._transforms]
return other

def to(self, device, copy: bool = False, dtype=None):
Expand All @@ -504,7 +453,6 @@ def to(self, device, copy: bool = False, dtype=None):
other.device = device
other.dtype = dtype if dtype is not None else other.dtype
other._matrix = self._matrix.to(device=device, dtype=dtype)
other._transforms = [t.to(device, copy=copy, dtype=dtype) for t in other._transforms]
return other

def cpu(self):
Expand Down

0 comments on commit 31f1422

Please sign in to comment.