Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Get mediapipe info from upstream #10

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ dynamic = ["version"]

[project.optional-dependencies]
dev = ["black", "bumpver", "isort", "pip-tools", "pytest"]
standalone = ["skellytracker"]

[tool.bumpver] #bump the version by entering `bumpver update` in the terminal
current_version = "v2024.04.1022"
Expand Down
4 changes: 2 additions & 2 deletions skelly_viewer/config/folder_and_file_names.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@

OUTPUT_DATA_FOLDER_NAME = 'output_data'

MEDIAPIPE_3D_BODY_FILE_NAME = 'mediaPipeSkel_3d_origin_aligned.npy'
MEDIAPIPE_3D_BODY_ORIGIN_ALIGNED_FILE_NAME = 'mediaPipeSkel_3d_origin_aligned.npy'

#MEDIAPIPE_3D_BODY_FILE_NAME = 'mediapipe_body_3d_xyz.npy'
MEDIAPIPE_3D_BODY_FILE_NAME = 'mediapipe_body_3d_xyz.npy'

TOTAL_BODY_CENTER_OF_MASS_NPY_FILE_NAME = "totalBodyCOM_frame_XYZ.npy"

Expand Down
70 changes: 52 additions & 18 deletions skelly_viewer/gui/qt/skelly_viewer_widget.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from pathlib import Path
from typing import Union
from typing import List, Optional, Union

from PySide6.QtCore import Qt
from PySide6.QtWidgets import QWidget, QHBoxLayout, QVBoxLayout
from skellytracker.trackers.mediapipe_tracker.mediapipe_model_info import MediapipeModelInfo

from skelly_viewer.gui.qt.widgets.multi_video_display import MultiVideoDisplay
from skelly_viewer.gui.qt.widgets.skeleton_view_widget import SkeletonViewWidget
Expand All @@ -11,9 +12,15 @@

class SkellyViewer(QWidget):
# session_folder_loaded_signal = Signal()
def __init__(self, mediapipe_skeleton_npy_path=None, video_folder_path=None):
def __init__(
self,
skeleton_npy_path=None,
video_folder_path=None,
connections: List[tuple] = MediapipeModelInfo.body_connections):
super().__init__()

self.connections = connections

layout = QVBoxLayout()
self.setLayout(layout)
layout.setAlignment(Qt.AlignmentFlag.AlignBottom)
Expand All @@ -38,42 +45,69 @@ def __init__(self, mediapipe_skeleton_npy_path=None, video_folder_path=None):

self._is_video_display_enabled = True

if mediapipe_skeleton_npy_path is not None:
self.load_skeleton_data(mediapipe_skeleton_npy_path)
if skeleton_npy_path and self.connections:
self.load_skeleton_data(
skeleton_npy_path=skeleton_npy_path,
connections=self.connections,
)

if video_folder_path is not None:
self.generate_video_display(video_folder_path)

def load_skeleton_data(self, mediapipe_skeleton_npy_path: Union[str, Path]):
self._skeleton_view_widget.load_skeleton_data(mediapipe_skeleton_npy_path)
def load_skeleton_data(
self, skeleton_npy_path: Union[str, Path], connections: List[tuple]
):
self._skeleton_view_widget.load_skeleton_data(
skeleton_npy_path=skeleton_npy_path,
connections=connections,
)

def generate_video_display(self, video_folder_path: Union[str, Path]):
self.multi_video_display.generate_video_display(video_folder_path)
self.multi_video_display.update_display(self._frame_count_slider._slider.value())

def set_data_paths(self,
mediapipe_skeleton_npy_path: Union[str, Path],
video_folder_path: Union[str, Path]):

self.load_skeleton_data(mediapipe_skeleton_npy_path)
self.multi_video_display.update_display(
self._frame_count_slider._slider.value()
)

def set_data_paths(
self,
skeleton_npy_path: Union[str, Path],
video_folder_path: Union[str, Path],
connections: Optional[List[tuple]] = None,
) -> None:
"""
Load skeleton data and generate video display. Reset frame count slider to 0.
If connections is None, defaults to class connection, which is Mediapipe by default
"""
if connections is None:
connections = self.connections
self.load_skeleton_data(skeleton_npy_path, connections)
self.generate_video_display(video_folder_path)

self._frame_count_slider._slider.setValue(0)

def connect_signals_to_slots(self):
self._skeleton_view_widget.skeleton_data_loaded_signal.connect(
self._handle_data_loaded_signal)
self._handle_data_loaded_signal
)

self._frame_count_slider._slider.valueChanged.connect(self._handle_slider_value_changed)
self._frame_count_slider._slider.valueChanged.connect(
self._handle_slider_value_changed
)

def _handle_data_loaded_signal(self):
self._frame_count_slider.set_slider_range(self._skeleton_view_widget._number_of_frames)
self._frame_count_slider.set_slider_range(
self._skeleton_view_widget._number_of_frames
)
self._frame_count_slider.setEnabled(True)

def _handle_slider_value_changed(self):
self._skeleton_view_widget.update_skeleton_plot(self._frame_count_slider._slider.value())
self._skeleton_view_widget.update_skeleton_plot(
self._frame_count_slider._slider.value()
)
if self._is_video_display_enabled:
self.multi_video_display.update_display(self._frame_count_slider._slider.value())
self.multi_video_display.update_display(
self._frame_count_slider._slider.value()
)

def toggle_video_display(self):
self._is_video_display_enabled = not self._is_video_display_enabled
Expand Down
2 changes: 1 addition & 1 deletion skelly_viewer/gui/qt/skellyview_main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _load_data(self, path: Union[Path, str]):
data_loader = FreeMoCapDataLoader(path_to_session_folder=self._session_folder_path)

self._skelly_viewer.set_data_paths(
mediapipe_skeleton_npy_path=data_loader.find_skeleton_npy_file_name(),
skeleton_npy_path=data_loader.find_skeleton_npy_file_name(),
video_folder_path=data_loader.find_synchronized_videos_folder_path()
)

Expand Down
56 changes: 26 additions & 30 deletions skelly_viewer/gui/qt/widgets/skeleton_view_widget.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Union
from typing import List, Union

import matplotlib
from PySide6.QtCore import Signal
from PySide6.QtWidgets import QWidget, QVBoxLayout

from skelly_viewer.utilities.mediapipe_skeleton_builder import build_skeleton, mediapipe_indices, mediapipe_connections
from skelly_viewer.utilities.skeleton_builder import build_skeleton

matplotlib.use('Qt5Agg')
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg
matplotlib.use('QtAgg')
from matplotlib.backends.backend_qtagg import FigureCanvasQTAgg
from matplotlib.figure import Figure
from pathlib import Path
import numpy as np
Expand All @@ -25,11 +25,12 @@ def __init__(self):
self._layout.addWidget(self._figure_widget)
self._skel_bones = None

def load_skeleton_data(self, mediapipe_skeleton_npy_path: Union[str, Path]):
self._skeleton_3d_frame_marker_xyz = np.load(str(mediapipe_skeleton_npy_path))
self._mediapipe_skeleton = build_skeleton(skeleton_3d_frame_marker_xyz=self._skeleton_3d_frame_marker_xyz,
pose_estimation_markers_list=mediapipe_indices,
pose_estimation_connections_dict=mediapipe_connections)
def load_skeleton_data(self, skeleton_npy_path: Union[str, Path], connections: List[tuple]):
self._skeleton_3d_frame_marker_xyz = np.load(str(skeleton_npy_path))
self._skeleton = build_skeleton(
skeleton_3d_frame_marker_xyz=self._skeleton_3d_frame_marker_xyz,
pose_estimation_connections=connections
)

self._number_of_frames = self._skeleton_3d_frame_marker_xyz.shape[0]
self._initialize_3d_axes()
Expand All @@ -50,11 +51,6 @@ def _initialize_3d_axes(self):
self._skel_bones = None
self.skel_bones = self._plot_skeleton_bones(0)

def reset_slider(self):
self._slider_max = self._number_of_frames - 1
self.slider.setValue(0)
self.slider.setMaximum(self._slider_max)

def _calculate_axes_means(self, skeleton_3d_frame_marker_xyz: np.ndarray):
self._data_midpoint_x = np.nanmean(skeleton_3d_frame_marker_xyz[:, :, 0])
self._data_midpoint_y = np.nanmean(skeleton_3d_frame_marker_xyz[:, :, 1])
Expand All @@ -76,28 +72,28 @@ def _plot_skeleton(self, frame_number, skeleton_points_x, skeleton_points_y, ske

self._figure_widget.figure.canvas.draw_idle()

def _plot_skeleton_bones(self, frame_number):
def _plot_skeleton_bones(self, frame_number: int):
if self._skel_bones is None:
this_frame_skeleton_data = self._mediapipe_skeleton[frame_number]
this_frame_skeleton_data = self._skeleton[frame_number]
self._skel_bones = []
for connection in this_frame_skeleton_data.keys():
line_start_point = this_frame_skeleton_data[connection][0]
line_end_point = this_frame_skeleton_data[connection][1]
for connection in this_frame_skeleton_data:
line_start_point = connection[0]
line_end_point = connection[1]

bone_x, bone_y, bone_z = [line_start_point[0], line_end_point[0]], [line_start_point[1],
line_end_point[1]], [
line_start_point[2], line_end_point[2]]
bone_x = [line_start_point[0], line_end_point[0]]
bone_y = [line_start_point[1], line_end_point[1]]
bone_z = [line_start_point[2], line_end_point[2]]
bone = self._3d_axes.plot(bone_x, bone_y, bone_z)[0]
self._skel_bones.append(bone)
else:
this_frame_skeleton_data = self._mediapipe_skeleton[frame_number]
for i, connection in enumerate(this_frame_skeleton_data.keys()):
line_start_point = this_frame_skeleton_data[connection][0]
line_end_point = this_frame_skeleton_data[connection][1]

bone_x, bone_y, bone_z = [line_start_point[0], line_end_point[0]], [line_start_point[1],
line_end_point[1]], [
line_start_point[2], line_end_point[2]]
this_frame_skeleton_data = self._skeleton[frame_number]
for i, connection in enumerate(this_frame_skeleton_data):
line_start_point = connection[0]
line_end_point = connection[1]

bone_x = [line_start_point[0], line_end_point[0]]
bone_y = [line_start_point[1], line_end_point[1]]
bone_z = [line_start_point[2], line_end_point[2]]
bone = self._skel_bones[i]
bone.set_xdata(bone_x)
bone.set_ydata(bone_y)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@

from typing import List
from PySide6.QtCore import Signal
from PySide6.QtWidgets import QWidget, QVBoxLayout, QComboBox

from skelly_viewer.utilities.mediapipe_skeleton_builder import mediapipe_indices


class MarkerSelectorWidget(QWidget):
marker_to_plot_updated_signal = Signal()
def __init__(self):
def __init__(self, markers: List[str]):
super().__init__()

self._layout = QVBoxLayout()
self.setLayout(self._layout)

combo_box_items = mediapipe_indices
combo_box_items = markers
# combo_box_items.insert(0,'')
self.marker_combo_box = QComboBox()
self.marker_combo_box.addItems(combo_box_items)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,17 @@


class TimeSeriesViewer(QWidget):
def __init__(self, freemocap_data:np.ndarray):
def __init__(self, freemocap_data: np.ndarray, markers: list):
super().__init__()

self.layout = QVBoxLayout()
self.setLayout(self.layout)


self.freemocap_data = freemocap_data
self.connections = markers

self.marker_selector_widget = MarkerSelectorWidget()
self.marker_selector_widget = MarkerSelectorWidget(markers)
self.layout.addWidget(self.marker_selector_widget)

self.time_series_plotter_widget = TimeSeriesPlotterWidget()
Expand All @@ -25,19 +26,19 @@ def __init__(self, freemocap_data:np.ndarray):
self.connect_signals_to_slots()

def connect_signals_to_slots(self):
self.marker_selector_widget.marker_to_plot_updated_signal.connect(lambda: self.time_series_plotter_widget.update_plot(self.marker_selector_widget.current_marker,self.freemocap_data))
self.marker_selector_widget.marker_to_plot_updated_signal.connect(lambda: self.time_series_plotter_widget.update_plot(self.marker_selector_widget.current_marker,self.freemocap_data, self.markers))


if __name__ == "__main__":

class MainWindow(QMainWindow):
def __init__(self, freemocap_data:np.ndarray):
def __init__(self, freemocap_data: np.ndarray, connections: list):
super().__init__()

layout = QVBoxLayout()
widget = QWidget()

self.time_series_viewer = TimeSeriesViewer(freemocap_data)
self.time_series_viewer = TimeSeriesViewer(freemocap_data, connections)
layout.addWidget(self.time_series_viewer)

widget.setLayout(layout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg, NavigationToolbar2QT as NavigationToolbar
from matplotlib.figure import Figure

from skelly_viewer.utilities.mediapipe_skeleton_builder import mediapipe_indices


import numpy as np

Expand Down Expand Up @@ -46,20 +44,20 @@ def initialize_skeleton_plot(self):
self.axes_list = [self.x_ax,self.y_ax,self.z_ax]
return fig, self.axes_list

def get_mediapipe_indices(self,marker_to_plot):
mediapipe_index = mediapipe_indices.index(marker_to_plot)
return mediapipe_index
def get_marker_indices(self, marker_to_plot, markers: list):
index = markers.index(marker_to_plot)
return index


def update_plot(self,marker_to_plot:str, freemocap_data:np.ndarray):
mediapipe_index = self.get_mediapipe_indices(marker_to_plot)
def update_plot(self,marker_to_plot:str, freemocap_data: np.ndarray, markers: list):
index = self.get_marker_indices(marker_to_plot, markers)

axes_names = ['X Axis', 'Y Axis', 'Z Axis']

for dimension, (ax,ax_name) in enumerate(zip(self.axes_list,axes_names)):

ax.cla()
ax.plot(freemocap_data[:,mediapipe_index,dimension], label = 'FreeMoCap', alpha = .7)
ax.plot(freemocap_data[:,index,dimension], label = 'FreeMoCap', alpha = .7)

ax.set_ylabel(ax_name)

Expand Down
18 changes: 7 additions & 11 deletions skelly_viewer/utilities/freemocap_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,14 @@

import numpy as np

from skelly_viewer.config.folder_and_file_names import MEDIAPIPE_3D_BODY_FILE_NAME, OUTPUT_DATA_FOLDER_NAME, \
TOTAL_BODY_CENTER_OF_MASS_NPY_FILE_NAME
from skelly_viewer.config.folder_and_file_names import MEDIAPIPE_3D_BODY_FILE_NAME, MEDIAPIPE_3D_BODY_ORIGIN_ALIGNED_FILE_NAME,\
OUTPUT_DATA_FOLDER_NAME, TOTAL_BODY_CENTER_OF_MASS_NPY_FILE_NAME


class FreeMoCapDataLoader:
def __init__(self, path_to_session_folder: Path):
self._recording_folder_path = path_to_session_folder

def load_mediapipe_body_data(self):
path_to_mediapipe_body_data = self.find_output_data_folder_path()
mediapipe_body_data = np.load(str(path_to_mediapipe_body_data))
return mediapipe_body_data

def load_total_body_COM_data(self):
path_to_total_body_COM_data = self.find_output_data_folder_path() / TOTAL_BODY_CENTER_OF_MASS_NPY_FILE_NAME
total_body_COM_data = np.load(str(path_to_total_body_COM_data))
Expand All @@ -33,11 +28,12 @@ def find_skeleton_npy_file_name(self) -> Path:

npy_path_list = [path.name for path in self.find_output_data_folder_path().glob("*.npy")]

if 'mediaPipeSkel_3d_origin_aligned.npy' in npy_path_list:
return self.find_output_data_folder_path() / MEDIAPIPE_3D_BODY_FILE_NAME
# TODO: Find a mediapipe independent version of this
if MEDIAPIPE_3D_BODY_ORIGIN_ALIGNED_FILE_NAME in npy_path_list:
return self.find_output_data_folder_path() / MEDIAPIPE_3D_BODY_ORIGIN_ALIGNED_FILE_NAME

if 'mediapipe_body_3d_xyz.npy' in npy_path_list:
return self.find_output_data_folder_path() / 'mediapipe_body_3d_xyz.npy'
if MEDIAPIPE_3D_BODY_FILE_NAME in npy_path_list:
return self.find_output_data_folder_path() / MEDIAPIPE_3D_BODY_FILE_NAME

raise Exception(f"Could not find a skeleton NPY file in path {str(self.find_output_data_folder_path())}")

Expand Down
4 changes: 2 additions & 2 deletions skelly_viewer/utilities/get_video_paths.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pathlib import Path
from typing import Union
from typing import List, Union


def get_video_paths(path_to_video_folder: Union[str, Path]) -> list:
def get_video_paths(path_to_video_folder: Union[str, Path]) -> List[Path]:
"""Search the folder for 'mp4' files (case insensitive) and return them as a list"""

list_of_video_paths = list(Path(path_to_video_folder).glob("*.mp4")) + list(
Expand Down
Loading