Skip to content

Commit

Permalink
Remove over-use of static methods, just one was useful
Browse files Browse the repository at this point in the history
  • Loading branch information
lauraporta committed Oct 25, 2023
1 parent 36f37b5 commit 779f75e
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 142 deletions.
126 changes: 36 additions & 90 deletions derotation/analysis/derotation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,35 +154,19 @@ def process_analog_signals(self):
self.full_rotation, self.k
)
self.rot_blocks_idx = self.correct_start_and_end_rotation_signal(
self.inter_rotation_interval_min_len, start, end
)
self.rotation_on = self.create_signed_rotation_array(
len(self.full_rotation),
self.rot_blocks_idx["start"],
self.rot_blocks_idx["end"],
self.direction,
start, end
)
self.rotation_on = self.create_signed_rotation_array()

self.rotation_ticks_peaks = self.drop_ticks_outside_of_rotation(
self.rotation_ticks_peaks,
self.rot_blocks_idx["start"],
self.rot_blocks_idx["end"],
self.total_clock_time,
self.number_of_rotations,
)
self.drop_ticks_outside_of_rotation()

self.check_number_of_rotations()
if not self.is_number_of_ticks_correct() and self.adjust_increment:
if self.assume_full_rotation:
(
self.corrected_increments,
self.ticks_per_rotation,
) = self.adjust_rotation_increment(
self.rotation_ticks_peaks,
self.rot_blocks_idx["start"],
self.rot_blocks_idx["end"],
self.rot_deg,
)
) = self.adjust_rotation_increment()
else:
self.corrected_increments = (
self.adjust_rotation_increment_for_incremental_changes()
Expand Down Expand Up @@ -276,9 +260,8 @@ def get_start_end_times_with_threshold(

return start, end

@staticmethod
def correct_start_and_end_rotation_signal(
inter_rotation_interval_min_len: int,
self,
start: np.ndarray,
end: np.ndarray,
) -> dict:
Expand All @@ -287,11 +270,12 @@ def correct_start_and_end_rotation_signal(
periods that are not plausible given the experimental setup.
The two surrounding on periods are merged.
Used the inter_rotation_interval_min_len parameter from the config
file: the minimum length of the time in between two rotations.
It is important to remove artifacts.
Parameters
----------
inter_rotation_interval_min_len : int
Minimum length of the time in between two rotations.
It is important to remove artifacts.
start : np.ndarray
The start times of the on periods of rotation signal.
end : np.ndarray
Expand All @@ -306,33 +290,21 @@ def correct_start_and_end_rotation_signal(
logging.info("Cleaning start and end rotation signal...")

shifted_end = np.roll(end, 1)
mask = start - shifted_end > inter_rotation_interval_min_len
mask = start - shifted_end > self.inter_rotation_interval_min_len
mask[0] = True # first rotation is always a full rotation
shifted_mask = np.roll(mask, -1)
new_start = start[mask]
new_end = end[shifted_mask]

return {"start": new_start, "end": new_end}

@staticmethod
def create_signed_rotation_array(
len_full_rotation: int, starts: np.ndarray, ends: np.ndarray, direction
) -> np.ndarray:
def create_signed_rotation_array(self) -> np.ndarray:
"""Reconstructs an array that has the same length as the full rotation
signal. It is 0 when the motor is off, and it is 1 or -1 when the motor
is on, depending on the direction of rotation. 1 is clockwise, -1 is
counter clockwise.
Parameters
----------
len_full_rotation : int
Length of the full rotation signal.
starts : np.ndarray
The start times of the on periods of rotation signal.
ends : np.ndarray
The end times of the on periods of rotation signal.
direction : _type_
The direction of rotation of the motor.
Uses the start and end times of the on periods of rotation signal, and
the direction of rotation to reconstruct the array.
Returns
-------
Expand All @@ -341,39 +313,19 @@ def create_signed_rotation_array(
"""

logging.info("Creating signed rotation array...")
rotation_on = np.zeros(len_full_rotation)
rotation_on = np.zeros(self.total_clock_time)
for i, (start, end) in enumerate(
zip(
starts,
ends,
self.rot_blocks_idx["start"],
self.rot_blocks_idx["end"],
)
):
rotation_on[start:end] = direction[i]
rotation_on[start:end] = self.direction[i]

return rotation_on

@staticmethod
def drop_ticks_outside_of_rotation(
rotation_ticks_peaks: np.ndarray,
starts: np.ndarray,
ends: np.ndarray,
full_length: int,
number_of_rotations: int,
) -> np.ndarray:
"""_summary_
Parameters
----------
rotation_ticks_peaks : np.ndarray
The clock times of the rotation ticks peaks.
starts : np.ndarray
The start times of the on periods of rotation signal.
ends : np.ndarray
The end times of the on periods of rotation signal.
full_length : int
The length of the analog signals, in clock time.
number_of_rotations : int
The number of rotations.
def drop_ticks_outside_of_rotation(self) -> np.ndarray:
"""Drops the rotation ticks that are outside of the rotation periods.
Returns
-------
Expand All @@ -384,33 +336,33 @@ def drop_ticks_outside_of_rotation(

logging.info("Dropping ticks outside of the rotation period...")

len_before = len(rotation_ticks_peaks)
len_before = len(self.rotation_ticks_peaks)

rolled_starts = np.roll(starts, -1)
rolled_starts[-1] = full_length
rolled_starts = np.roll(self.rot_blocks_idx["start"], -1)
rolled_starts[-1] = self.total_clock_time

inter_roatation_interval = [
idx
for i in range(number_of_rotations)
for i in range(self.number_of_rotations)
for idx in range(
ends[i],
self.rot_blocks_idx["end"][i],
rolled_starts[i],
)
]

rotation_ticks_peaks = np.delete(
rotation_ticks_peaks,
np.where(np.isin(rotation_ticks_peaks, inter_roatation_interval)),
self.rotation_ticks_peaks = np.delete(
self.rotation_ticks_peaks,
np.where(
np.isin(self.rotation_ticks_peaks, inter_roatation_interval)
),
)

len_after = len(rotation_ticks_peaks)
len_after = len(self.rotation_ticks_peaks)
logging.info(
f"Ticks dropped: {len_before - len_after}.\n"
+ f"Ticks remaining: {len_after}"
)

return rotation_ticks_peaks

def check_number_of_rotations(self):
"""Checks that the number of rotations is as expected.
Expand Down Expand Up @@ -460,13 +412,7 @@ def is_number_of_ticks_correct(self) -> bool:
)
return False

@staticmethod
def adjust_rotation_increment(
rotation_ticks_peaks: np.ndarray,
starts: np.ndarray,
ends: np.ndarray,
rot_deg: int,
) -> Tuple[np.ndarray, np.ndarray]:
def adjust_rotation_increment(self) -> Tuple[np.ndarray, np.ndarray]:
"""It calculates the new rotation increment for each rotation, given
the number of ticks in each rotation. It also outputs the number of
ticks in each rotation.
Expand All @@ -492,19 +438,19 @@ def adjust_rotation_increment(
def get_peaks_in_rotation(start, end):
return np.where(
np.logical_and(
rotation_ticks_peaks > start,
rotation_ticks_peaks < end,
self.rotation_ticks_peaks > start,
self.rotation_ticks_peaks < end,
)
)[0].shape[0]

ticks_per_rotation = [
get_peaks_in_rotation(start, end)
for start, end in zip(
starts,
ends,
self.rot_blocks_idx["start"],
self.rot_blocks_idx["end"],
)
]
new_increments = [rot_deg / t for t in ticks_per_rotation]
new_increments = [self.rot_deg / t for t in ticks_per_rotation]

logging.info(f"New increment example: {new_increments[0]:.3f}")

Expand Down
28 changes: 28 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import numpy as np
import pytest

from derotation.analysis.derotation_pipeline import DerotationPipeline


@pytest.fixture(autouse=True)
def random():
Expand Down Expand Up @@ -115,3 +117,29 @@ def rotation_ticks(
)
ticks = np.sort(ticks)
return ticks


@pytest.fixture
def derotation_pipeline(
rotation_ticks,
start_end_times,
full_length,
number_of_rotations,
full_rotation,
direction,
):
pipeline = DerotationPipeline.__new__(DerotationPipeline)

pipeline.inter_rotation_interval_min_len = 50
pipeline.rotation_ticks_peaks = rotation_ticks
pipeline.rot_blocks_idx = {
"start": start_end_times[0],
"end": start_end_times[1],
}
pipeline.number_of_rotations = number_of_rotations
pipeline.direction = direction
pipeline.total_clock_time = full_length
pipeline.full_rotation = full_rotation
pipeline.rot_deg = 360

return pipeline
28 changes: 6 additions & 22 deletions tests/test_unit/test_adjust_rotation_increment.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,12 @@


def test_adjust_rotation_increment_360(
rotation_ticks,
start_end_times,
derotation_pipeline: DerotationPipeline,
):
start, end = start_end_times
rot_deg = 360

(
new_increments,
ticks_per_rotation,
) = DerotationPipeline.adjust_rotation_increment(
rotation_ticks,
start,
end,
rot_deg,
)
) = derotation_pipeline.adjust_rotation_increment()

new_increments = np.round(new_increments, 0)

Expand All @@ -32,21 +23,14 @@ def test_adjust_rotation_increment_360(


def test_adjust_rotation_increment_5(
rotation_ticks,
start_end_times,
derotation_pipeline: DerotationPipeline,
):
start, end = start_end_times
rot_deg = 5
derotation_pipeline.rot_deg = 5

(
new_increments,
ticks_per_rotation,
) = DerotationPipeline.adjust_rotation_increment(
rotation_ticks,
start,
end,
rot_deg,
)
_,
) = derotation_pipeline.adjust_rotation_increment()

new_increments = np.round(new_increments, 3)

Expand Down
22 changes: 8 additions & 14 deletions tests/test_unit/test_create_signed_rotation_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,25 @@


def test_create_signed_rotation_array_interleaved(
full_length, start_end_times, direction_interleaved
derotation_pipeline: DerotationPipeline,
start_end_times: tuple,
):
start, end = start_end_times
rotation_on = DerotationPipeline.create_signed_rotation_array(
full_length,
start,
end,
direction_interleaved,
)
rotation_on = derotation_pipeline.create_signed_rotation_array()

for idx in range(0, len(start), 2):
assert np.all(rotation_on[start[idx] : end[idx]] == 1)
assert np.all(rotation_on[start[idx + 1] : end[idx + 1]] == -1)


def test_create_signed_rotation_array_incremental(
full_length, start_end_times, direction_incremental
derotation_pipeline: DerotationPipeline,
start_end_times: tuple,
direction_incremental: np.ndarray,
):
derotation_pipeline.direction = direction_incremental
start, end = start_end_times
rotation_on = DerotationPipeline.create_signed_rotation_array(
full_length,
start,
end,
direction_incremental,
)
rotation_on = derotation_pipeline.create_signed_rotation_array()

for idx in range(0, 5):
assert np.all(rotation_on[start[idx] : end[idx]] == 1)
Expand Down
9 changes: 3 additions & 6 deletions tests/test_unit/test_drop_ticks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@


def test_drop_ticks_generated_randomly(
rotation_ticks, start_end_times, full_length, number_of_rotations
derotation_pipeline: DerotationPipeline,
):
start, end = start_end_times
cleaned_ticks = DerotationPipeline.drop_ticks_outside_of_rotation(
rotation_ticks, start, end, full_length, number_of_rotations
)
derotation_pipeline.drop_ticks_outside_of_rotation()

assert len(cleaned_ticks) == 362
assert len(derotation_pipeline.rotation_ticks_peaks) == 362
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import numpy as np

from derotation.analysis.derotation_pipeline import DerotationPipeline


def test_finding_correct_start_end_times_with_threshold(
full_rotation, k, rotation_len, number_of_rotations
derotation_pipeline: DerotationPipeline,
full_rotation: np.ndarray,
k: int,
number_of_rotations: int,
rotation_len: int,
):
start, end = DerotationPipeline.get_start_end_times_with_threshold(
start, end = derotation_pipeline.get_start_end_times_with_threshold(
full_rotation, k
)

Expand Down
Loading

0 comments on commit 779f75e

Please sign in to comment.