Skip to content

Commit

Permalink
Extended testing and added front_keypoint argument to `compute_head…
Browse files Browse the repository at this point in the history
…_direction_vector()`
  • Loading branch information
b-peri committed Sep 10, 2024
1 parent d6d3e85 commit e48a2ec
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 11 deletions.
35 changes: 33 additions & 2 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import xarray as xr

from movement.utils.logging import log_error
from movement.utils.vector import convert_to_unit


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Expand Down Expand Up @@ -122,13 +123,19 @@ def _compute_approximate_time_derivative(


def compute_head_direction_vector(
data: xr.DataArray, left_keypoint: str, right_keypoint: str
data: xr.DataArray,
left_keypoint: str,
right_keypoint: str,
front_keypoint: str | None = None,
):
"""Compute the 2D head direction vector given two keypoints on the head.
The head direction vector is computed as a vector perpendicular to the
line connecting two keypoints on either side of the head, pointing
forwards (in the rostral direction).
forwards (in the rostral direction). As the forward direction may
differ between coordinate systems, the front keypoint is used ...,
when present. Otherwise, we assume that coordinates are given in the
image coordinate system (where the origin is located in the top-left).
Parameters
----------
Expand All @@ -140,6 +147,8 @@ def compute_head_direction_vector(
Name of the left keypoint, e.g., "left_ear"
right_keypoint : str
Name of the right keypoint, e.g., "right_ear"
front_keypoint : str | None
(Optional) Name of the front keypoint, e.g., "nose".
Returns
-------
Expand All @@ -150,6 +159,9 @@ def compute_head_direction_vector(
"""
# Validate input dataset
_validate_type_data_array(data)
_validate_time_keypoints_space_dimensions(data)

if left_keypoint == right_keypoint:
raise log_error(
ValueError, "The left and right keypoints may not be identical."
Expand All @@ -176,6 +188,25 @@ def compute_head_direction_vector(
:, :, :-1
]

# Check computed head_vector is pointing in the same direction as vector
# from head midpoint to snout
if front_keypoint:
head_front = data.sel(keypoints=front_keypoint, drop=True)
head_midpoint = (head_right + head_left) / 2
mid_to_front_vector = head_front - head_midpoint
dot_product_array = (
convert_to_unit(head_vector.sel(individuals=data.individuals[0]))
* convert_to_unit(mid_to_front_vector).sel(
individuals=data.individuals[0]
)
).sum(dim="space")
median_dot_product = float(dot_product_array.median(dim="time").values)
if median_dot_product < 0:
perpendicular_vector = np.array([0, 0, 1])
head_vector.values = np.cross(
right_to_left_vector, perpendicular_vector
)[:, :, :-1]

return head_vector


Expand Down
33 changes: 24 additions & 9 deletions tests/test_unit/test_kinematics.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from contextlib import nullcontext as does_not_raise

import numpy as np
Expand Down Expand Up @@ -116,7 +117,7 @@ class TestNavigation:
"""Test suite for navigation-related functions in the kinematics module."""

@pytest.fixture
def mock_data_array(self):
def mock_dataarray(self):
"""Return a mock DataArray containing four known head orientations."""
time = [0, 1, 2, 3]
individuals = ["individual_0"]
Expand All @@ -141,7 +142,7 @@ def mock_data_array(self):
return ds

@pytest.fixture
def mock_data_array_3D(self):
def mock_dataarray_3D(self):
"""Return a 3D DataArray containing a known head orientation."""
time = [0]
individuals = ["individual_0"]
Expand All @@ -163,7 +164,7 @@ def mock_data_array_3D(self):
return ds

def test_compute_head_direction_vector(
self, mock_data_array, mock_data_array_3D
self, mock_dataarray, mock_dataarray_3D
):
"""Test that the correct head direction vectors
are computed from a basic mock dataset.
Expand All @@ -173,14 +174,14 @@ def test_compute_head_direction_vector(
# Catch incorrect datatype
with pytest.raises(TypeError, match="must be an xarray.DataArray"):
kinematics.compute_head_direction_vector(
mock_data_array.values, "left_ear", "right_ear"
mock_dataarray.values, "left_ear", "right_ear"
)

# Catch incorrect dimensions
with pytest.raises(
AttributeError, match="'time', 'space', and 'keypoints'"
):
mock_data_keypoint = mock_data_array.sel(
mock_data_keypoint = mock_dataarray.sel(
keypoints="nose", drop=True
)
kinematics.compute_head_direction_vector(
Expand All @@ -190,20 +191,21 @@ def test_compute_head_direction_vector(
# Catch identical left and right keypoints
with pytest.raises(ValueError, match="keypoints may not be identical"):
kinematics.compute_head_direction_vector(
mock_data_array, "left_ear", "left_ear"
mock_dataarray, "left_ear", "left_ear"
)

# Catch incorrect spatial dimensions
with pytest.raises(
ValueError, match="must have 2 (and only 2) spatial dimensions"
ValueError,
match=re.escape("must have 2 (and only 2) spatial dimensions"),
):
kinematics.compute_head_direction_vector(
mock_data_array_3D, "left", "right"
mock_dataarray_3D, "left", "right"
)

# Test that output contains correct datatype, dimensions, and values
head_vector = kinematics.compute_head_direction_vector(
mock_data_array, "left_ear", "right_ear"
mock_dataarray, "left_ear", "right_ear"
)
known_vectors = np.array([[[0, 2]], [[-2, 0]], [[0, -2]], [[2, 0]]])

Expand All @@ -213,3 +215,16 @@ def test_compute_head_direction_vector(
and ("keypoints" not in head_vector.dims)
)
assert np.equal(head_vector.values, known_vectors).all()

# Test behaviour with NaNs
nan_dataarray = mock_dataarray.where(
(mock_dataarray.time != 1)
| (mock_dataarray.keypoints != "left_ear")
)
head_vector = kinematics.compute_head_direction_vector(
nan_dataarray, "left_ear", "right_ear"
)
assert (
np.isnan(head_vector.values[1, 0, :]).all()
and not np.isnan(head_vector.values[[0, 2, 3], 0, :]).any()
)

0 comments on commit e48a2ec

Please sign in to comment.