Skip to content

Commit

Permalink
Correct name of tree print function
Browse files Browse the repository at this point in the history
  • Loading branch information
LemonPi committed Aug 28, 2024
1 parent 3ca731a commit 5288100
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ __pycache__
temp*
build
dist
*.pyc
# These are cloned/generated when testing with mujoco
tests/MUJOCO_LOG.TXT
tests/mujoco_menagerie/
Expand Down
2 changes: 1 addition & 1 deletion src/pytorch_kinematics/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def get_link_names(self):
def get_frame_indices(self, *frame_names):
return torch.tensor([self.frame_to_idx[n] for n in frame_names], dtype=torch.long, device=self.device)

def print_link_tree(self, do_print=True):
def print_tree(self, do_print=True):
tree = str(self._root)
if do_print:
print(tree)
Expand Down
8 changes: 4 additions & 4 deletions tests/test_serial_chain_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_extract_serial_chain_from_tree():
├── right_finger_link
└── ee_gripper_link
"""
full_frame = chain.print_link_tree()
full_frame = chain.print_tree()
assert full_frame_expected.strip() == full_frame.strip()

serial_chain = pk.SerialChain(chain, "ee_gripper_link", "base_link")
Expand All @@ -46,7 +46,7 @@ def test_extract_serial_chain_from_tree():
└── fingers_link
└── ee_gripper_link
"""
serial_frame = serial_chain.print_link_tree()
serial_frame = serial_chain.print_tree()
assert serial_frame_expected.strip() == serial_frame.strip()

# full chain should have DOF = 8, however since we are creating just a serial chain to ee_gripper_link, should be 6
Expand All @@ -65,7 +65,7 @@ def test_extract_serial_chain_from_tree():
└── ee_arm_link
└── gripper_prop_link
"""
serial_frame = serial_chain.print_link_tree()
serial_frame = serial_chain.print_tree()
assert serial_frame_expected.strip() == serial_frame.strip()

serial_chain = pk.SerialChain(chain, "ee_gripper_link", "gripper_link")
Expand All @@ -76,7 +76,7 @@ def test_extract_serial_chain_from_tree():
└── fingers_link
└── ee_gripper_link
"""
serial_frame = serial_chain.print_link_tree()
serial_frame = serial_chain.print_tree()
assert serial_frame_expected.strip() == serial_frame.strip()
# only gripper_link is the parent frame of a joint in this serial chain
assert serial_chain.n_joints == 1
Expand Down

0 comments on commit 5288100

Please sign in to comment.