Skip to content

Commit

Permalink
Implemented PR feedback and bugfixes for `compute_2d_head_direction_v…
Browse files Browse the repository at this point in the history
…ector()`
  • Loading branch information
b-peri committed Sep 16, 2024
1 parent e3db1ad commit 9a156ed
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 160 deletions.
116 changes: 65 additions & 51 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Compute kinematic variables like velocity and acceleration."""

import numpy as np
import numpy.typing as npt
import xarray as xr

from movement.utils.logging import log_error
from movement.utils.vector import convert_to_unit
from movement.utils.vector import compute_norm


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Expand Down Expand Up @@ -163,20 +164,44 @@ def _compute_approximate_time_derivative(
return result


def compute_head_direction_vector(
def compute_2d_head_direction_vector(
data: xr.DataArray,
left_keypoint: str,
right_keypoint: str,
front_keypoint: str | None = None,
upward_vector: npt.ArrayLike = (0, 0, -1),
):
"""Compute the 2D head direction vector given two keypoints on the head.
The head direction vector is computed as a vector perpendicular to the
line connecting two keypoints on either side of the head, pointing
forwards (in the rostral direction). As the forward direction may
differ between coordinate systems, the front keypoint is used ...,
when present. Otherwise, we assume that coordinates are given in the
image coordinate system (where the origin is located in the top-left).
line connecting two symmetrical keypoints on either side of the head
(i.e., symmetrical relative to the sagittal plane), and pointing
forwards (in the rostral direction). A top-down or bottom-up view of the
animal is assumed.
To determine the forward direction of the animal, we need to specify
(1) the right-to-left direction of the animal and (2) its upward direction.
We determine the right-to-left direction via the input left and right
keypoints. For the forward direction, if no additional information is
provided, we assume the keypoints are expressed in the image coordinate
system (where the origin is located in the top-left corner of the screen),
and that the analysed image is a top-down view of the animal. In this case
the upward direction of the animal is the negative z direction of the image
coordinate system. Alternatively, users can specify the upward direction
of the animal directly.
If one of the required pieces of information is missing for a frame (e.g.,
the left keypoint is not visible), then the computed head direction vector
is set to NaN.
Notes
-----
If specified, the upward direction must be expressed in the same coordinate
system as the keypoint data.
Note that the assumed upward direction would be incorrect if the animal
is recorded from its belly (bottom-up view). The default upward direction
would be the negative z direction in the image coordinate system, but the
true upward direction of the animal is the positive z direction.
Parameters
----------
Expand All @@ -188,8 +213,10 @@ def compute_head_direction_vector(
Name of the left keypoint, e.g., "left_ear"
right_keypoint : str
Name of the right keypoint, e.g., "right_ear"
front_keypoint : str | None
(Optional) Name of the front keypoint, e.g., "nose".
upward_vector : array-like, optional
The upward vector in the coordinate system the keypoints are in.
By default, it is the negative z-axis in the image coordinate
system, i.e., [0, 0, -1].
Returns
-------
Expand All @@ -199,56 +226,43 @@ def compute_head_direction_vector(
``keypoints`` dimension.
"""
# Validate input dataset
# Validate input data
_validate_type_data_array(data)
_validate_time_keypoints_space_dimensions(data)

if left_keypoint == right_keypoint:
raise log_error(
ValueError, "The left and right keypoints may not be identical."
)
if len(data.space) != 2:
raise log_error(
ValueError,
"Input data must have 2 (and only 2) spatial dimensions, but "
f"currently has {len(data.space)}.",
)

# Select the right and left keypoints
head_left = data.sel(keypoints=left_keypoint, drop=True)
head_right = data.sel(keypoints=right_keypoint, drop=True)

# Initialize a vector from right to left ear, and another vector
# perpendicular to the X-Y plane
right_to_left_vector = head_left - head_right
perpendicular_vector = np.array([0, 0, -1])

# Compute cross product
head_vector = head_right.copy()
head_vector.values = np.cross(right_to_left_vector, perpendicular_vector)[
:, :, :-1
]

# Check computed head_vector is pointing in the same direction as vector
# from head midpoint to snout
if front_keypoint:
head_front = data.sel(keypoints=front_keypoint, drop=True)
head_midpoint = (head_right + head_left) / 2
mid_to_front_vector = head_front - head_midpoint
dot_product_array = (
convert_to_unit(head_vector.sel(individuals=data.individuals[0]))
* convert_to_unit(mid_to_front_vector).sel(
individuals=data.individuals[0]
)
).sum(dim="space")
median_dot_product = float(dot_product_array.median(dim="time").values)
if median_dot_product < 0:
perpendicular_vector = np.array([0, 0, 1])
head_vector.values = np.cross(
right_to_left_vector, perpendicular_vector
)[:, :, :-1]

return head_vector
# Validate input keypoints
if left_keypoint == right_keypoint:
raise log_error(
ValueError, "The left and right keypoints may not be identical."
)

# Define right-to-left vector
right_to_left_vector = data.sel(
keypoints=left_keypoint, drop=True
) - data.sel(keypoints=right_keypoint, drop=True)

# Define upward vector
# default: negative z direction in the image coordinate system
upward_vector = xr.DataArray(
np.tile(np.array(upward_vector).reshape(1, -1), [len(data.time), 1]),
dims=["time", "space"],
)

# Compute forward direction as the cross product
# (right-to-left) cross (forward) = up
forward_vector = xr.cross(
right_to_left_vector, upward_vector, dim="space"
)[:, :, :-1] # keep only the first 2 dimensions of the result

# Return unit vector

return forward_vector / compute_norm(forward_vector)


def _validate_time_dimension(data: xr.DataArray) -> None:
Expand Down
Loading

0 comments on commit 9a156ed

Please sign in to comment.