Skip to content

Commit

Permalink
Run ruff format
Browse files Browse the repository at this point in the history
  • Loading branch information
cboulay committed Sep 26, 2024
1 parent 74ef32a commit 1fb927c
Show file tree
Hide file tree
Showing 45 changed files with 818 additions and 462 deletions.
13 changes: 8 additions & 5 deletions src/ezmsg/sigproc/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class ActivationFunction(OptionsEnum):
"""Activation (transformation) function."""

NONE = "none"
"""None."""

Expand Down Expand Up @@ -48,8 +49,12 @@ def activation(
# str type. There's probably an easier way to support either enum or str argument. Oh well this works.
function: str = function.lower()
if function not in ActivationFunction.options():
raise ValueError(f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}")
function = list(ACTIVATIONS.keys())[ActivationFunction.options().index(function)]
raise ValueError(
f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}"
)
function = list(ACTIVATIONS.keys())[
ActivationFunction.options().index(function)
]
func = ACTIVATIONS[function]

msg_out = AxisArray(np.array([]), dims=[""])
Expand All @@ -67,6 +72,4 @@ class Activation(GenAxisArray):
SETTINGS = ActivationSettings

def construct_generator(self):
self.STATE.gen = activation(
function=self.SETTINGS.function
)
self.STATE.gen = activation(function=self.SETTINGS.function)
23 changes: 16 additions & 7 deletions src/ezmsg/sigproc/affinetransform.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ def affine_transform(
weights = np.ascontiguousarray(weights)

# State variables
new_axis: typing.Optional[AxisArray.Axis] = None # New axis with transformed labels, if required
# New axis with transformed labels, if required
new_axis: typing.Optional[AxisArray.Axis] = None

# Reset if any of these change.
check_input = {"key": None}
Expand All @@ -69,17 +70,19 @@ def affine_transform(
check_input["key"] = msg_in.key
# Determine if we need to modify the transformed axis.
if (
axis in msg_in.axes
and hasattr(msg_in.axes[axis], "labels")
and weights.shape[0] != weights.shape[1]
axis in msg_in.axes
and hasattr(msg_in.axes[axis], "labels")
and weights.shape[0] != weights.shape[1]
):
in_labels = msg_in.axes[axis].labels
new_labels = []
n_in = weights.shape[1 if right_multiply else 0]
n_out = weights.shape[0 if right_multiply else 1]
if len(in_labels) != n_in:
# Something upstream did something it wasn't supposed to. We will drop the labels.
ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.")
ez.logger.warning(
f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels."
)
else:
b_used_inputs = np.any(weights, axis=0 if right_multiply else 1)
b_filled_outputs = np.any(weights, axis=1 if right_multiply else 0)
Expand Down Expand Up @@ -107,8 +110,10 @@ def affine_transform(
if data.shape[axis_idx] == (weights.shape[0] - 1):
# The weights are stacked A|B where A is the transform and B is a single row
# in the equation y = Ax + B. This supports NeuroKey's weights matrices.
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx+1:]
data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx)
sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :]
data = np.concatenate(
(data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx
)

if axis_idx in [-1, len(msg_in.dims) - 1]:
data = np.matmul(data, weights)
Expand All @@ -128,13 +133,15 @@ class AffineTransformSettings(ez.Settings):
Settings for :obj:`AffineTransform`.
See :obj:`affine_transform` for argument details.
"""

weights: typing.Union[np.ndarray, str, Path]
axis: typing.Optional[str] = None
right_multiply: bool = True


class AffineTransform(GenAxisArray):
""":obj:`Unit` for :obj:`affine_transform`"""

SETTINGS = AffineTransformSettings

def construct_generator(self):
Expand Down Expand Up @@ -206,6 +213,7 @@ class CommonRereferenceSettings(ez.Settings):
Settings for :obj:`CommonRereference`
See :obj:`common_rereference` for argument details.
"""

mode: str = "mean"
axis: typing.Optional[str] = None
include_current: bool = True
Expand All @@ -215,6 +223,7 @@ class CommonRereference(GenAxisArray):
"""
:obj:`Unit` for :obj:`common_rereference`.
"""

SETTINGS = CommonRereferenceSettings

def construct_generator(self):
Expand Down
20 changes: 13 additions & 7 deletions src/ezmsg/sigproc/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

class AggregationFunction(OptionsEnum):
"""Enum for aggregation functions available to be used in :obj:`ranged_aggregate` operation."""

NONE = "None (all)"
MAX = "max"
MIN = "min"
Expand Down Expand Up @@ -49,7 +50,7 @@ class AggregationFunction(OptionsEnum):
def ranged_aggregate(
axis: typing.Optional[str] = None,
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None,
operation: AggregationFunction = AggregationFunction.MEAN
operation: AggregationFunction = AggregationFunction.MEAN,
):
"""
Apply an aggregation operation over one or more bands.
Expand Down Expand Up @@ -92,17 +93,20 @@ def ranged_aggregate(

ax_idx = msg_in.get_axis_idx(axis)

ax_vec = target_axis.offset + np.arange(msg_in.data.shape[ax_idx]) * target_axis.gain
ax_vec = (
target_axis.offset
+ np.arange(msg_in.data.shape[ax_idx]) * target_axis.gain
)
slices = []
mids = []
for (start, stop) in bands:
for start, stop in bands:
inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0]
mids.append(np.mean(inds) * target_axis.gain + target_axis.offset)
slices.append(np.s_[inds[0]:inds[-1] + 1])
slices.append(np.s_[inds[0] : inds[-1] + 1])
out_ax_kwargs = {
"unit": target_axis.unit,
"offset": mids[0],
"gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0
"gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0,
}
if hasattr(target_axis, "labels"):
out_ax_kwargs["labels"] = [f"{_[0]} - {_[1]}" for _ in bands]
Expand All @@ -117,7 +121,7 @@ def ranged_aggregate(
msg_out = replace(
msg_in,
data=np.stack(out_data, axis=ax_idx),
axes={**msg_in.axes, axis: out_axis}
axes={**msg_in.axes, axis: out_axis},
)
if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]:
# Convert indices returned by argmin/argmax into the value along the axis.
Expand All @@ -133,6 +137,7 @@ class RangedAggregateSettings(ez.Settings):
Settings for ``RangedAggregate``.
See :obj:`ranged_aggregate` for details.
"""

axis: typing.Optional[str] = None
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None
operation: AggregationFunction = AggregationFunction.MEAN
Expand All @@ -142,11 +147,12 @@ class RangedAggregate(GenAxisArray):
"""
Unit for :obj:`ranged_aggregate`
"""

SETTINGS = RangedAggregateSettings

def construct_generator(self):
self.STATE.gen = ranged_aggregate(
axis=self.SETTINGS.axis,
bands=self.SETTINGS.bands,
operation=self.SETTINGS.operation
operation=self.SETTINGS.operation,
)
24 changes: 15 additions & 9 deletions src/ezmsg/sigproc/bandpower.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
@consumer
def bandpower(
spectrogram_settings: SpectrogramSettings,
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [(17, 30), (70, 170)]
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [
(17, 30),
(70, 170),
],
) -> typing.Generator[AxisArray, AxisArray, None]:
"""
Calculate the average spectral power in each band.
Expand All @@ -33,12 +36,10 @@ def bandpower(
window_shift=spectrogram_settings.window_shift,
window=spectrogram_settings.window,
transform=spectrogram_settings.transform,
output=spectrogram_settings.output
output=spectrogram_settings.output,
)
f_agg = ranged_aggregate(
axis="freq",
bands=bands,
operation=AggregationFunction.MEAN
axis="freq", bands=bands, operation=AggregationFunction.MEAN
)
pipeline = compose(f_spec, f_agg)

Expand All @@ -52,17 +53,22 @@ class BandPowerSettings(ez.Settings):
Settings for ``BandPower``.
See :obj:`bandpower` for details.
"""
spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings)
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = (
field(default_factory=lambda: [(17, 30), (70, 170)]))

spectrogram_settings: SpectrogramSettings = field(
default_factory=SpectrogramSettings
)
bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = field(
default_factory=lambda: [(17, 30), (70, 170)]
)


class BandPower(GenAxisArray):
""":obj:`Unit` for :obj:`bandpower`."""

SETTINGS = BandPowerSettings

def construct_generator(self):
self.STATE.gen = bandpower(
spectrogram_settings=self.SETTINGS.spectrogram_settings,
bands=self.SETTINGS.bands
bands=self.SETTINGS.bands,
)
1 change: 0 additions & 1 deletion src/ezmsg/sigproc/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,3 @@ async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator:
ez.logger.debug(f"Generator closed in {self.address}")
except Exception:
ez.logger.info(traceback.format_exc())

16 changes: 13 additions & 3 deletions src/ezmsg/sigproc/butterworthfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

class ButterworthFilterSettings(FilterSettingsBase):
"""Settings for :obj:`ButterworthFilter`."""

order: int = 0

cuton: typing.Optional[float] = None
Expand All @@ -27,7 +28,11 @@ class ButterworthFilterSettings(FilterSettingsBase):
or if it is less than cuton then it is the beginning of the bandstop.
"""

def filter_specs(self) -> typing.Optional[typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]]:
def filter_specs(
self,
) -> typing.Optional[
typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]
]:
"""
Determine the filter type given the corner frequencies.
Expand Down Expand Up @@ -83,7 +88,8 @@ def butter(
).filter_specs()

# State variables
filter_gen = filtergen(axis, None, coef_type) # Initialize filtergen as passthrough until we can calculate coefs.
# Initialize filtergen as passthrough until we can calculate coefs.
filter_gen = filtergen(axis, None, coef_type)

# Reset if these change.
check_input = {"gain": None}
Expand All @@ -98,7 +104,11 @@ def butter(
if b_reset:
check_input["gain"] = msg_in.axes[axis].gain
coefs = scipy.signal.butter(
order, Wn=cutoffs, btype=btype, fs=1 / msg_in.axes[axis].gain, output=coef_type
order,
Wn=cutoffs,
btype=btype,
fs=1 / msg_in.axes[axis].gain,
output=coef_type,
)
filter_gen = filtergen(axis, coefs, coef_type)

Expand Down
1 change: 1 addition & 0 deletions src/ezmsg/sigproc/decimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class Decimate(ez.Collection):
A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter
and a :obj:`Downsample` node.
"""

SETTINGS = DownsampleSettings

INPUT_SIGNAL = ez.InputStream(AxisArray)
Expand Down
19 changes: 11 additions & 8 deletions src/ezmsg/sigproc/downsample.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

@consumer
def downsample(
axis: typing.Optional[str] = None,
factor: int = 1
axis: typing.Optional[str] = None, factor: int = 1
) -> typing.Generator[AxisArray, AxisArray, None]:
"""
Construct a generator that yields a downsampled version of the data .send() to it.
Expand Down Expand Up @@ -50,7 +49,10 @@ def downsample(
axis_info = msg_in.get_axis(axis)
axis_idx = msg_in.get_axis_idx(axis)

b_reset = msg_in.axes[axis].gain != check_input["gain"] or msg_in.key != check_input["key"]
b_reset = (
msg_in.axes[axis].gain != check_input["gain"]
or msg_in.key != check_input["key"]
)
if b_reset:
check_input["gain"] = axis_info.gain
check_input["key"] = msg_in.key
Expand Down Expand Up @@ -78,9 +80,9 @@ def downsample(
axis: replace(
axis_info,
gain=axis_info.gain * factor,
offset=axis_info.offset + axis_info.gain * n_step
)
}
offset=axis_info.offset + axis_info.gain * n_step,
),
},
)


Expand All @@ -89,16 +91,17 @@ class DownsampleSettings(ez.Settings):
Settings for :obj:`Downsample` node.
See :obj:`downsample` documentation for a description of the parameters.
"""

axis: typing.Optional[str] = None
factor: int = 1


class Downsample(GenAxisArray):
""":obj:`Unit` for :obj:`bandpower`."""

SETTINGS = DownsampleSettings

def construct_generator(self):
self.STATE.gen = downsample(
axis=self.SETTINGS.axis,
factor=self.SETTINGS.factor
axis=self.SETTINGS.axis, factor=self.SETTINGS.factor
)
5 changes: 2 additions & 3 deletions src/ezmsg/sigproc/ewmfilter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from .window import Window, WindowSettings




class EWMSettings(ez.Settings):
axis: typing.Optional[str] = None
"""Name of the axis to accumulate."""
Expand Down Expand Up @@ -117,6 +115,7 @@ class EWMFilter(ez.Collection):
Consider :obj:`scaler` for a more efficient alternative.
"""

SETTINGS = EWMFilterSettings

INPUT_SIGNAL = ez.InputStream(AxisArray)
Expand All @@ -128,7 +127,7 @@ class EWMFilter(ez.Collection):
def configure(self) -> None:
self.EWM.apply_settings(
EWMSettings(
axis=self.SETTINGS.axis,
axis=self.SETTINGS.axis,
zero_offset=self.SETTINGS.zero_offset,
)
)
Expand Down
Loading

0 comments on commit 1fb927c

Please sign in to comment.