diff --git a/src/ezmsg/sigproc/activation.py b/src/ezmsg/sigproc/activation.py index fc674e9..7b8ab12 100644 --- a/src/ezmsg/sigproc/activation.py +++ b/src/ezmsg/sigproc/activation.py @@ -13,6 +13,7 @@ class ActivationFunction(OptionsEnum): """Activation (transformation) function.""" + NONE = "none" """None.""" @@ -48,8 +49,12 @@ def activation( # str type. There's probably an easier way to support either enum or str argument. Oh well this works. function: str = function.lower() if function not in ActivationFunction.options(): - raise ValueError(f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}") - function = list(ACTIVATIONS.keys())[ActivationFunction.options().index(function)] + raise ValueError( + f"Unrecognized activation function {function}. Must be one of {ACTIVATIONS.keys()}" + ) + function = list(ACTIVATIONS.keys())[ + ActivationFunction.options().index(function) + ] func = ACTIVATIONS[function] msg_out = AxisArray(np.array([]), dims=[""]) @@ -67,6 +72,4 @@ class Activation(GenAxisArray): SETTINGS = ActivationSettings def construct_generator(self): - self.STATE.gen = activation( - function=self.SETTINGS.function - ) + self.STATE.gen = activation(function=self.SETTINGS.function) diff --git a/src/ezmsg/sigproc/affinetransform.py b/src/ezmsg/sigproc/affinetransform.py index adf330c..85af22b 100644 --- a/src/ezmsg/sigproc/affinetransform.py +++ b/src/ezmsg/sigproc/affinetransform.py @@ -46,7 +46,8 @@ def affine_transform( weights = np.ascontiguousarray(weights) # State variables - new_axis: typing.Optional[AxisArray.Axis] = None # New axis with transformed labels, if required + # New axis with transformed labels, if required + new_axis: typing.Optional[AxisArray.Axis] = None # Reset if any of these change. check_input = {"key": None} @@ -69,9 +70,9 @@ def affine_transform( check_input["key"] = msg_in.key # Determine if we need to modify the transformed axis. if ( - axis in msg_in.axes - and hasattr(msg_in.axes[axis], "labels") - and weights.shape[0] != weights.shape[1] + axis in msg_in.axes + and hasattr(msg_in.axes[axis], "labels") + and weights.shape[0] != weights.shape[1] ): in_labels = msg_in.axes[axis].labels new_labels = [] @@ -79,7 +80,9 @@ def affine_transform( n_out = weights.shape[0 if right_multiply else 1] if len(in_labels) != n_in: # Something upstream did something it wasn't supposed to. We will drop the labels. - ez.logger.warning(f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels.") + ez.logger.warning( + f"Received {len(in_labels)} for {n_in} inputs. Check upstream labels." + ) else: b_used_inputs = np.any(weights, axis=0 if right_multiply else 1) b_filled_outputs = np.any(weights, axis=1 if right_multiply else 0) @@ -107,8 +110,10 @@ def affine_transform( if data.shape[axis_idx] == (weights.shape[0] - 1): # The weights are stacked A|B where A is the transform and B is a single row # in the equation y = Ax + B. This supports NeuroKey's weights matrices. - sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx+1:] - data = np.concatenate((data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx) + sample_shape = data.shape[:axis_idx] + (1,) + data.shape[axis_idx + 1 :] + data = np.concatenate( + (data, np.ones(sample_shape).astype(data.dtype)), axis=axis_idx + ) if axis_idx in [-1, len(msg_in.dims) - 1]: data = np.matmul(data, weights) @@ -128,6 +133,7 @@ class AffineTransformSettings(ez.Settings): Settings for :obj:`AffineTransform`. See :obj:`affine_transform` for argument details. """ + weights: typing.Union[np.ndarray, str, Path] axis: typing.Optional[str] = None right_multiply: bool = True @@ -135,6 +141,7 @@ class AffineTransformSettings(ez.Settings): class AffineTransform(GenAxisArray): """:obj:`Unit` for :obj:`affine_transform`""" + SETTINGS = AffineTransformSettings def construct_generator(self): @@ -206,6 +213,7 @@ class CommonRereferenceSettings(ez.Settings): Settings for :obj:`CommonRereference` See :obj:`common_rereference` for argument details. """ + mode: str = "mean" axis: typing.Optional[str] = None include_current: bool = True @@ -215,6 +223,7 @@ class CommonRereference(GenAxisArray): """ :obj:`Unit` for :obj:`common_rereference`. """ + SETTINGS = CommonRereferenceSettings def construct_generator(self): diff --git a/src/ezmsg/sigproc/aggregate.py b/src/ezmsg/sigproc/aggregate.py index 85d72f1..2d5cff3 100644 --- a/src/ezmsg/sigproc/aggregate.py +++ b/src/ezmsg/sigproc/aggregate.py @@ -13,6 +13,7 @@ class AggregationFunction(OptionsEnum): """Enum for aggregation functions available to be used in :obj:`ranged_aggregate` operation.""" + NONE = "None (all)" MAX = "max" MIN = "min" @@ -49,7 +50,7 @@ class AggregationFunction(OptionsEnum): def ranged_aggregate( axis: typing.Optional[str] = None, bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None, - operation: AggregationFunction = AggregationFunction.MEAN + operation: AggregationFunction = AggregationFunction.MEAN, ): """ Apply an aggregation operation over one or more bands. @@ -92,17 +93,20 @@ def ranged_aggregate( ax_idx = msg_in.get_axis_idx(axis) - ax_vec = target_axis.offset + np.arange(msg_in.data.shape[ax_idx]) * target_axis.gain + ax_vec = ( + target_axis.offset + + np.arange(msg_in.data.shape[ax_idx]) * target_axis.gain + ) slices = [] mids = [] - for (start, stop) in bands: + for start, stop in bands: inds = np.where(np.logical_and(ax_vec >= start, ax_vec <= stop))[0] mids.append(np.mean(inds) * target_axis.gain + target_axis.offset) - slices.append(np.s_[inds[0]:inds[-1] + 1]) + slices.append(np.s_[inds[0] : inds[-1] + 1]) out_ax_kwargs = { "unit": target_axis.unit, "offset": mids[0], - "gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0 + "gain": (mids[1] - mids[0]) if len(mids) > 1 else 1.0, } if hasattr(target_axis, "labels"): out_ax_kwargs["labels"] = [f"{_[0]} - {_[1]}" for _ in bands] @@ -117,7 +121,7 @@ def ranged_aggregate( msg_out = replace( msg_in, data=np.stack(out_data, axis=ax_idx), - axes={**msg_in.axes, axis: out_axis} + axes={**msg_in.axes, axis: out_axis}, ) if operation in [AggregationFunction.ARGMIN, AggregationFunction.ARGMAX]: # Convert indices returned by argmin/argmax into the value along the axis. @@ -133,6 +137,7 @@ class RangedAggregateSettings(ez.Settings): Settings for ``RangedAggregate``. See :obj:`ranged_aggregate` for details. """ + axis: typing.Optional[str] = None bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = None operation: AggregationFunction = AggregationFunction.MEAN @@ -142,11 +147,12 @@ class RangedAggregate(GenAxisArray): """ Unit for :obj:`ranged_aggregate` """ + SETTINGS = RangedAggregateSettings def construct_generator(self): self.STATE.gen = ranged_aggregate( axis=self.SETTINGS.axis, bands=self.SETTINGS.bands, - operation=self.SETTINGS.operation + operation=self.SETTINGS.operation, ) diff --git a/src/ezmsg/sigproc/bandpower.py b/src/ezmsg/sigproc/bandpower.py index b945b98..09e397c 100644 --- a/src/ezmsg/sigproc/bandpower.py +++ b/src/ezmsg/sigproc/bandpower.py @@ -14,7 +14,10 @@ @consumer def bandpower( spectrogram_settings: SpectrogramSettings, - bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [(17, 30), (70, 170)] + bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = [ + (17, 30), + (70, 170), + ], ) -> typing.Generator[AxisArray, AxisArray, None]: """ Calculate the average spectral power in each band. @@ -33,12 +36,10 @@ def bandpower( window_shift=spectrogram_settings.window_shift, window=spectrogram_settings.window, transform=spectrogram_settings.transform, - output=spectrogram_settings.output + output=spectrogram_settings.output, ) f_agg = ranged_aggregate( - axis="freq", - bands=bands, - operation=AggregationFunction.MEAN + axis="freq", bands=bands, operation=AggregationFunction.MEAN ) pipeline = compose(f_spec, f_agg) @@ -52,17 +53,22 @@ class BandPowerSettings(ez.Settings): Settings for ``BandPower``. See :obj:`bandpower` for details. """ - spectrogram_settings: SpectrogramSettings = field(default_factory=SpectrogramSettings) - bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = ( - field(default_factory=lambda: [(17, 30), (70, 170)])) + + spectrogram_settings: SpectrogramSettings = field( + default_factory=SpectrogramSettings + ) + bands: typing.Optional[typing.List[typing.Tuple[float, float]]] = field( + default_factory=lambda: [(17, 30), (70, 170)] + ) class BandPower(GenAxisArray): """:obj:`Unit` for :obj:`bandpower`.""" + SETTINGS = BandPowerSettings def construct_generator(self): self.STATE.gen = bandpower( spectrogram_settings=self.SETTINGS.spectrogram_settings, - bands=self.SETTINGS.bands + bands=self.SETTINGS.bands, ) diff --git a/src/ezmsg/sigproc/base.py b/src/ezmsg/sigproc/base.py index d1d1169..749025d 100644 --- a/src/ezmsg/sigproc/base.py +++ b/src/ezmsg/sigproc/base.py @@ -36,4 +36,3 @@ async def on_signal(self, message: AxisArray) -> typing.AsyncGenerator: ez.logger.debug(f"Generator closed in {self.address}") except Exception: ez.logger.info(traceback.format_exc()) - diff --git a/src/ezmsg/sigproc/butterworthfilter.py b/src/ezmsg/sigproc/butterworthfilter.py index 10faa6b..85e0d75 100644 --- a/src/ezmsg/sigproc/butterworthfilter.py +++ b/src/ezmsg/sigproc/butterworthfilter.py @@ -11,6 +11,7 @@ class ButterworthFilterSettings(FilterSettingsBase): """Settings for :obj:`ButterworthFilter`.""" + order: int = 0 cuton: typing.Optional[float] = None @@ -27,7 +28,11 @@ class ButterworthFilterSettings(FilterSettingsBase): or if it is less than cuton then it is the beginning of the bandstop. """ - def filter_specs(self) -> typing.Optional[typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]]]: + def filter_specs( + self, + ) -> typing.Optional[ + typing.Tuple[str, typing.Union[float, typing.Tuple[float, float]]] + ]: """ Determine the filter type given the corner frequencies. @@ -83,7 +88,8 @@ def butter( ).filter_specs() # State variables - filter_gen = filtergen(axis, None, coef_type) # Initialize filtergen as passthrough until we can calculate coefs. + # Initialize filtergen as passthrough until we can calculate coefs. + filter_gen = filtergen(axis, None, coef_type) # Reset if these change. check_input = {"gain": None} @@ -98,7 +104,11 @@ def butter( if b_reset: check_input["gain"] = msg_in.axes[axis].gain coefs = scipy.signal.butter( - order, Wn=cutoffs, btype=btype, fs=1 / msg_in.axes[axis].gain, output=coef_type + order, + Wn=cutoffs, + btype=btype, + fs=1 / msg_in.axes[axis].gain, + output=coef_type, ) filter_gen = filtergen(axis, coefs, coef_type) diff --git a/src/ezmsg/sigproc/decimate.py b/src/ezmsg/sigproc/decimate.py index 39cdfec..c5a0d6b 100644 --- a/src/ezmsg/sigproc/decimate.py +++ b/src/ezmsg/sigproc/decimate.py @@ -11,6 +11,7 @@ class Decimate(ez.Collection): A :obj:`Collection` chaining a :obj:`Filter` node configured as a lowpass Chebyshev filter and a :obj:`Downsample` node. """ + SETTINGS = DownsampleSettings INPUT_SIGNAL = ez.InputStream(AxisArray) diff --git a/src/ezmsg/sigproc/downsample.py b/src/ezmsg/sigproc/downsample.py index 6df2f04..247e1f7 100644 --- a/src/ezmsg/sigproc/downsample.py +++ b/src/ezmsg/sigproc/downsample.py @@ -11,8 +11,7 @@ @consumer def downsample( - axis: typing.Optional[str] = None, - factor: int = 1 + axis: typing.Optional[str] = None, factor: int = 1 ) -> typing.Generator[AxisArray, AxisArray, None]: """ Construct a generator that yields a downsampled version of the data .send() to it. @@ -50,7 +49,10 @@ def downsample( axis_info = msg_in.get_axis(axis) axis_idx = msg_in.get_axis_idx(axis) - b_reset = msg_in.axes[axis].gain != check_input["gain"] or msg_in.key != check_input["key"] + b_reset = ( + msg_in.axes[axis].gain != check_input["gain"] + or msg_in.key != check_input["key"] + ) if b_reset: check_input["gain"] = axis_info.gain check_input["key"] = msg_in.key @@ -78,9 +80,9 @@ def downsample( axis: replace( axis_info, gain=axis_info.gain * factor, - offset=axis_info.offset + axis_info.gain * n_step - ) - } + offset=axis_info.offset + axis_info.gain * n_step, + ), + }, ) @@ -89,16 +91,17 @@ class DownsampleSettings(ez.Settings): Settings for :obj:`Downsample` node. See :obj:`downsample` documentation for a description of the parameters. """ + axis: typing.Optional[str] = None factor: int = 1 class Downsample(GenAxisArray): """:obj:`Unit` for :obj:`bandpower`.""" + SETTINGS = DownsampleSettings def construct_generator(self): self.STATE.gen = downsample( - axis=self.SETTINGS.axis, - factor=self.SETTINGS.factor + axis=self.SETTINGS.axis, factor=self.SETTINGS.factor ) diff --git a/src/ezmsg/sigproc/ewmfilter.py b/src/ezmsg/sigproc/ewmfilter.py index 54bc1ad..643751d 100644 --- a/src/ezmsg/sigproc/ewmfilter.py +++ b/src/ezmsg/sigproc/ewmfilter.py @@ -9,8 +9,6 @@ from .window import Window, WindowSettings - - class EWMSettings(ez.Settings): axis: typing.Optional[str] = None """Name of the axis to accumulate.""" @@ -117,6 +115,7 @@ class EWMFilter(ez.Collection): Consider :obj:`scaler` for a more efficient alternative. """ + SETTINGS = EWMFilterSettings INPUT_SIGNAL = ez.InputStream(AxisArray) @@ -128,7 +127,7 @@ class EWMFilter(ez.Collection): def configure(self) -> None: self.EWM.apply_settings( EWMSettings( - axis=self.SETTINGS.axis, + axis=self.SETTINGS.axis, zero_offset=self.SETTINGS.zero_offset, ) ) diff --git a/src/ezmsg/sigproc/filter.py b/src/ezmsg/sigproc/filter.py index 1c894ac..05b1686 100644 --- a/src/ezmsg/sigproc/filter.py +++ b/src/ezmsg/sigproc/filter.py @@ -17,8 +17,10 @@ class FilterCoefficients: def _normalize_coefs( - coefs: typing.Union[FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray],npt.NDArray] -) -> typing.Tuple[str, typing.Tuple[npt.NDArray,...]]: + coefs: typing.Union[ + FilterCoefficients, typing.Tuple[npt.NDArray, npt.NDArray], npt.NDArray + ], +) -> typing.Tuple[str, typing.Tuple[npt.NDArray, ...]]: coef_type = "ba" if coefs is not None: # scipy.signal functions called with first arg `*coefs`. @@ -86,7 +88,9 @@ def filtergen( n_tail = msg_in.data.ndim - axis_idx - 1 zi = zi_func(*coefs) zi_expand = (None,) * axis_idx + (slice(None),) + (None,) * n_tail - n_tile = msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :] + n_tile = ( + msg_in.data.shape[:axis_idx] + (1,) + msg_in.data.shape[axis_idx + 1 :] + ) if coef_type == "sos": # sos zi must keep its leading dimension (`order / 2` for low|high; `order` for bpass|bstop) zi_expand = (slice(None),) + zi_expand @@ -218,4 +222,7 @@ async def apply_filter(self, msg: AxisArray) -> typing.AsyncGenerator: if one_dimensional: arr_out = np.squeeze(arr_out, axis=1) - yield self.OUTPUT_SIGNAL, replace(msg, data=arr_out), + yield ( + self.OUTPUT_SIGNAL, + replace(msg, data=arr_out), + ) diff --git a/src/ezmsg/sigproc/filterbank.py b/src/ezmsg/sigproc/filterbank.py index 51d00a6..c095b0d 100644 --- a/src/ezmsg/sigproc/filterbank.py +++ b/src/ezmsg/sigproc/filterbank.py @@ -19,6 +19,7 @@ class FilterbankMode(OptionsEnum): """The mode of operation for the filterbank.""" + CONV = "Direct Convolution" FFT = "FFT Convolution" AUTO = "Automatic" @@ -26,6 +27,7 @@ class FilterbankMode(OptionsEnum): class MinPhaseMode(OptionsEnum): """The mode of operation for the filterbank.""" + NONE = "No kernel modification" HILBERT = "Hilbert Method; designed to be used with equiripple filters (e.g., from remez) with unity or zero gain regions" HOMOMORPHIC = "Works best with filters with an odd number of taps, and the resulting minimum phase filter will have a magnitude response that approximates the square root of the original filter’s magnitude response using half the number of taps" @@ -80,10 +82,13 @@ def filterbank( axis = axis or msg_in.dims[0] gain = msg_in.axes[axis].gain if axis in msg_in.axes else 1.0 targ_ax_ix = msg_in.get_axis_idx(axis) - in_shape = msg_in.data.shape[:targ_ax_ix] + msg_in.data.shape[targ_ax_ix + 1:] + in_shape = msg_in.data.shape[:targ_ax_ix] + msg_in.data.shape[targ_ax_ix + 1 :] b_reset = msg_in.key != check_input["key"] - b_reset = b_reset or (gain != check_input["gain"] and mode in [FilterbankMode.FFT, FilterbankMode.AUTO]) + b_reset = b_reset or ( + gain != check_input["gain"] + and mode in [FilterbankMode.FFT, FilterbankMode.AUTO] + ) b_reset = b_reset or msg_in.data.dtype.kind != check_input["kind"] b_reset = b_reset or in_shape != check_input["shape"] if b_reset: @@ -99,12 +104,16 @@ def filterbank( # MinPhaseMode.HOMOMORPHICFULL: ("homomorphic", True), }[min_phase] kernels = [ - sps.minimum_phase(k, method=method) # , half=half) -- half requires later scipy >= 1.14 + sps.minimum_phase( + k, method=method + ) # , half=half) -- half requires later scipy >= 1.14 for k in kernels ] # Determine if this will be operating with complex data. - b_complex = msg_in.data.dtype.kind == "c" or any([_.dtype.kind == "c" for _ in kernels]) + b_complex = msg_in.data.dtype.kind == "c" or any( + [_.dtype.kind == "c" for _ in kernels] + ) # Calculate window_dur, window_shift, nfft max_kernel_len = max([_.size for _ in kernels]) @@ -120,8 +129,10 @@ def filterbank( dummy_shape = in_shape + (len(kernels), 0) template = AxisArray( data=np.zeros(dummy_shape, dtype="complex" if b_complex else "float"), - dims=msg_in.dims[:targ_ax_ix] + msg_in.dims[targ_ax_ix + 1:] + [new_axis, axis], - axes=msg_in.axes.copy() # We do not have info for kernel/filter axis :(. + dims=msg_in.dims[:targ_ax_ix] + + msg_in.dims[targ_ax_ix + 1 :] + + [new_axis, axis], + axes=msg_in.axes.copy(), # We do not have info for kernel/filter axis :(. ) # Determine optimal mode. Assumes 100 msec chunks. @@ -136,15 +147,21 @@ def filterbank( if mode == FilterbankMode.CONV: # Preallocate memory for convolution result and overlap-add - dest_shape = in_shape + (len(kernels), overlap + msg_in.data.shape[targ_ax_ix]) - dest_arr = np.zeros(dest_shape, dtype="complex" if b_complex else "float") + dest_shape = in_shape + ( + len(kernels), + overlap + msg_in.data.shape[targ_ax_ix], + ) + dest_arr = np.zeros( + dest_shape, dtype="complex" if b_complex else "float" + ) elif mode == FilterbankMode.FFT: # Calculate optimal nfft and windowing size. opt_size = -overlap * lambertw(-1 / (2 * math.e * overlap), k=-1).real nfft = sp_fft.next_fast_len(math.ceil(opt_size)) win_len = nfft - overlap - infft = win_len + overlap # Same as nfft. Keeping as separate variable because I might need it again. + # infft same as nfft. Keeping as separate variable because I might need it again. + infft = win_len + overlap # Create windowing node. # Note: We could do windowing manually to avoid the overhead of the message structure, @@ -182,7 +199,9 @@ def filterbank( in_dat = np.moveaxis(msg_in.data, targ_ax_ix, -1) if mode == FilterbankMode.FFT: # Fix msg_in .dims because we will pass it to wingen - move_dims = msg_in.dims[:targ_ax_ix] + msg_in.dims[targ_ax_ix + 1:] + [axis] + move_dims = ( + msg_in.dims[:targ_ax_ix] + msg_in.dims[targ_ax_ix + 1 :] + [axis] + ) msg_in = replace(msg_in, data=in_dat, dims=move_dims) else: in_dat = msg_in.data @@ -196,13 +215,15 @@ def filterbank( # TODO: Parallelize this loop. for k_ix, k in enumerate(kernels): n_out = in_dat.shape[-1] + k.shape[-1] - 1 - dest_arr[..., k_ix, :n_out] = np.apply_along_axis(np.convolve, -1, in_dat, k, mode="full") + dest_arr[..., k_ix, :n_out] = np.apply_along_axis( + np.convolve, -1, in_dat, k, mode="full" + ) dest_arr[..., :overlap] += tail # Add previous overlap - new_tail = dest_arr[..., in_dat.shape[-1]:n_dest] + new_tail = dest_arr[..., in_dat.shape[-1] : n_dest] if new_tail.size > 0: # COPY overlap for next iteration tail = new_tail.copy() - res = dest_arr[..., :in_dat.shape[-1]].copy() + res = dest_arr[..., : in_dat.shape[-1]].copy() elif mode == FilterbankMode.FFT: # Slice into non-overlapping windows win_msg = wingen.send(msg_in) @@ -217,9 +238,12 @@ def filterbank( overlapped = ifft(conv_spec, axis=-1) # Do the overlap-add on the `axis` axis - overlapped[..., :1, :overlap] += tail # Previous iteration's tail - overlapped[..., 1:, :overlap] += overlapped[..., :-1, -overlap:] # window-to-window - new_tail = overlapped[..., -1:, -overlap:] # Save tail + # Previous iteration's tail: + overlapped[..., :1, :overlap] += tail + # window-to-window: + overlapped[..., 1:, :overlap] += overlapped[..., :-1, -overlap:] + # Save tail: + new_tail = overlapped[..., -1:, -overlap:] if new_tail.size > 0: # All of the above code works if input is size-zero, but we don't want to save a zero-size tail. tail = new_tail # Save the tail for the next iteration. @@ -227,9 +251,7 @@ def filterbank( res = overlapped[..., :-overlap].reshape(overlapped.shape[:-2] + (-1,)) msg_out = replace( - template, - data=res, - axes={**template.axes, axis: msg_in.axes[axis]} + template, data=res, axes={**template.axes, axis: msg_in.axes[axis]} ) @@ -242,6 +264,7 @@ class FilterbankSettings(ez.Settings): class Filterbank(GenAxisArray): """Unit for :obj:`spectrum`""" + SETTINGS = FilterbankSettings INPUT_SETTINGS = ez.InputStream(FilterbankSettings) @@ -251,5 +274,5 @@ def construct_generator(self): kernels=self.SETTINGS.kernels, mode=self.SETTINGS.mode, min_phase=self.SETTINGS.min_phase, - axis=self.SETTINGS.axis + axis=self.SETTINGS.axis, ) diff --git a/src/ezmsg/sigproc/math/abs.py b/src/ezmsg/sigproc/math/abs.py index 937c2a6..8f52138 100644 --- a/src/ezmsg/sigproc/math/abs.py +++ b/src/ezmsg/sigproc/math/abs.py @@ -10,8 +10,7 @@ @consumer -def abs( -) -> typing.Generator[AxisArray, AxisArray, None]: +def abs() -> typing.Generator[AxisArray, AxisArray, None]: msg_out = AxisArray(np.array([]), dims=[""]) while True: msg_in = yield msg_out diff --git a/src/ezmsg/sigproc/math/clip.py b/src/ezmsg/sigproc/math/clip.py index 025dca6..09a04b5 100644 --- a/src/ezmsg/sigproc/math/clip.py +++ b/src/ezmsg/sigproc/math/clip.py @@ -10,10 +10,7 @@ @consumer -def clip( - a_min: float, - a_max: float -) -> typing.Generator[AxisArray, AxisArray, None]: +def clip(a_min: float, a_max: float) -> typing.Generator[AxisArray, AxisArray, None]: msg_in = AxisArray(np.array([]), dims=[""]) msg_out = AxisArray(np.array([]), dims=[""]) while True: @@ -30,7 +27,4 @@ class Clip(GenAxisArray): SETTINGS = ClipSettings def construct_generator(self): - self.STATE.gen = clip( - a_min=self.SETTINGS.a_min, - a_max=self.SETTINGS.a_max - ) + self.STATE.gen = clip(a_min=self.SETTINGS.a_min, a_max=self.SETTINGS.a_max) diff --git a/src/ezmsg/sigproc/math/difference.py b/src/ezmsg/sigproc/math/difference.py index e111399..55df6c8 100644 --- a/src/ezmsg/sigproc/math/difference.py +++ b/src/ezmsg/sigproc/math/difference.py @@ -11,8 +11,7 @@ @consumer def const_difference( - value: float = 0.0, - subtrahend: bool = True + value: float = 0.0, subtrahend: bool = True ) -> typing.Generator[AxisArray, AxisArray, None]: """ result = (in_data - value) if subtrahend else (value - in_data) @@ -21,7 +20,9 @@ def const_difference( msg_out = AxisArray(np.array([]), dims=[""]) while True: msg_in: AxisArray = yield msg_out - msg_out = replace(msg_in, data=(msg_in.data - value) if subtrahend else (value - msg_in.data)) + msg_out = replace( + msg_in, data=(msg_in.data - value) if subtrahend else (value - msg_in.data) + ) class ConstDifferenceSettings(ez.Settings): @@ -34,10 +35,10 @@ class ConstDifference(GenAxisArray): def construct_generator(self): self.STATE.gen = const_difference( - value=self.SETTINGS.value, - subtrahend=self.SETTINGS.subtrahend + value=self.SETTINGS.value, subtrahend=self.SETTINGS.subtrahend ) + # class DifferenceSettings(ez.Settings): # pass # diff --git a/src/ezmsg/sigproc/math/invert.py b/src/ezmsg/sigproc/math/invert.py index e5a327e..9432bda 100644 --- a/src/ezmsg/sigproc/math/invert.py +++ b/src/ezmsg/sigproc/math/invert.py @@ -10,8 +10,7 @@ @consumer -def invert( -) -> typing.Generator[AxisArray, AxisArray, None]: +def invert() -> typing.Generator[AxisArray, AxisArray, None]: msg_in = AxisArray(np.array([]), dims=[""]) msg_out = AxisArray(np.array([]), dims=[""]) while True: diff --git a/src/ezmsg/sigproc/math/log.py b/src/ezmsg/sigproc/math/log.py index d55078a..0cb368c 100644 --- a/src/ezmsg/sigproc/math/log.py +++ b/src/ezmsg/sigproc/math/log.py @@ -29,6 +29,4 @@ class Log(GenAxisArray): SETTINGS = LogSettings def construct_generator(self): - self.STATE.gen = log( - base=self.SETTINGS.base - ) + self.STATE.gen = log(base=self.SETTINGS.base) diff --git a/src/ezmsg/sigproc/math/scale.py b/src/ezmsg/sigproc/math/scale.py index c5aeeb7..1c6f40b 100644 --- a/src/ezmsg/sigproc/math/scale.py +++ b/src/ezmsg/sigproc/math/scale.py @@ -10,9 +10,7 @@ @consumer -def scale( - scale: float = 1.0 -) -> typing.Generator[AxisArray, AxisArray, None]: +def scale(scale: float = 1.0) -> typing.Generator[AxisArray, AxisArray, None]: msg_in = AxisArray(np.array([]), dims=[""]) msg_out = AxisArray(np.array([]), dims=[""]) while True: diff --git a/src/ezmsg/sigproc/sampler.py b/src/ezmsg/sigproc/sampler.py index 3279643..a75ff0f 100644 --- a/src/ezmsg/sigproc/sampler.py +++ b/src/ezmsg/sigproc/sampler.py @@ -25,7 +25,6 @@ class SampleTriggerMessage: @dataclass class SampleMessage: - trigger: SampleTriggerMessage """The time, window, and value (if any) associated with the trigger.""" @@ -35,12 +34,14 @@ class SampleMessage: @consumer def sampler( - buffer_dur: float, - axis: typing.Optional[str] = None, - period: typing.Optional[typing.Tuple[float, float]] = None, - value: typing.Any = None, - estimate_alignment: bool = True -) -> typing.Generator[typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None]: + buffer_dur: float, + axis: typing.Optional[str] = None, + period: typing.Optional[typing.Tuple[float, float]] = None, + value: typing.Any = None, + estimate_alignment: bool = True, +) -> typing.Generator[ + typing.List[SampleMessage], typing.Union[AxisArray, SampleTriggerMessage], None +]: """ A generator function that samples data into a buffer, accepts triggers, and returns slices of sampled data around the trigger time. @@ -98,16 +99,16 @@ def sampler( # Check that period is valid if _period[0] >= _period[1]: - ez.logger.warning(f"Sampling failed: invalid period requested ({_period})") + ez.logger.warning( + f"Sampling failed: invalid period requested ({_period})" + ) continue # Check that period is compatible with buffer duration. max_buf_len = int(np.round(buffer_dur * check_inputs["fs"])) req_buf_len = int(np.round((_period[1] - _period[0]) * check_inputs["fs"])) if req_buf_len >= max_buf_len: - ez.logger.warning( - f"Sampling failed: {period=} >= {buffer_dur=}" - ) + ez.logger.warning(f"Sampling failed: {period=} >= {buffer_dur=}") continue trigger_ts: float = msg_in.timestamp @@ -115,7 +116,9 @@ def sampler( # Override the trigger timestamp with the next sample's likely timestamp. trigger_ts = offset + (n_samples + 1) / check_inputs["fs"] - new_trig_msg = replace(msg_in, timestamp=trigger_ts, period=_period, value=_value) + new_trig_msg = replace( + msg_in, timestamp=trigger_ts, period=_period, value=_value + ) triggers.append(new_trig_msg) elif isinstance(msg_in, AxisArray): @@ -124,7 +127,9 @@ def sampler( axis_idx = msg_in.get_axis_idx(axis) axis_info = msg_in.get_axis(axis) fs = 1.0 / axis_info.gain - sample_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1:] + sample_shape = ( + msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :] + ) # TODO: We could accommodate change in dim order. # if axis_idx != check_inputs["axis_idx"]: @@ -138,7 +143,8 @@ def sampler( # If the properties have changed in a breaking way then reset buffer and triggers. b_reset = fs != check_inputs["fs"] b_reset = b_reset or sample_shape != check_inputs["shape"] - b_reset = b_reset or axis_idx != check_inputs["axis_idx"] # TODO: Skip this if we do np.moveaxis above + # TODO: Skip next line if we do np.moveaxis above + b_reset = b_reset or axis_idx != check_inputs["axis_idx"] b_reset = b_reset or msg_in.key != check_inputs["key"] if b_reset: check_inputs["fs"] = fs @@ -155,7 +161,11 @@ def sampler( offset = axis_info.offset # Update buffer - buffer = msg_in.data if buffer is None else np.concatenate((buffer, msg_in.data), axis=axis_idx) + buffer = ( + msg_in.data + if buffer is None + else np.concatenate((buffer, msg_in.data), axis=axis_idx) + ) # Calculate timestamps associated with buffer. buffer_offset = np.arange(buffer.shape[axis_idx], dtype=float) @@ -190,9 +200,16 @@ def sampler( trigger=trig, sample=replace( msg_in, - data=slice_along_axis(buffer, slice(start, stop), axis_idx), - axes={**msg_in.axes, axis: replace(axis_info, offset=buffer_offset[start])} - ) + data=slice_along_axis( + buffer, slice(start, stop), axis_idx + ), + axes={ + **msg_in.axes, + axis: replace( + axis_info, offset=buffer_offset[start] + ), + }, + ), ) ) triggers.remove(trig) @@ -206,28 +223,34 @@ class SamplerSettings(ez.Settings): Settings for :obj:`Sampler`. See :obj:`sampler` for a description of the fields. """ + buffer_dur: float axis: typing.Optional[str] = None - period: typing.Optional[ - typing.Tuple[float, float] - ] = None # Optional default period if unspecified in SampleTriggerMessage - value: typing.Any = None # Optional default value if unspecified in SampleTriggerMessage + period: typing.Optional[typing.Tuple[float, float]] = None + """Optional default period if unspecified in SampleTriggerMessage""" + + value: typing.Any = None + """Optional default value if unspecified in SampleTriggerMessage""" estimate_alignment: bool = True - # If true, use message timestamp fields and reported sampling rate to estimate - # sample-accurate alignment for samples. - # If false, sampling will be limited to incoming message rate -- "Block timing" - # NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect - # "realtime" operation for estimate_alignment to operate correctly. + """ + If true, use message timestamp fields and reported sampling rate to estimate sample-accurate alignment for samples. + If false, sampling will be limited to incoming message rate -- "Block timing" + NOTE: For faster-than-realtime playback -- Incoming timestamps must reflect + "realtime" operation for estimate_alignment to operate correctly. + """ class SamplerState(ez.State): cur_settings: SamplerSettings - gen: typing.Generator[typing.Union[AxisArray, SampleTriggerMessage], typing.List[SampleMessage], None] + gen: typing.Generator[ + typing.Union[AxisArray, SampleTriggerMessage], typing.List[SampleMessage], None + ] class Sampler(ez.Unit): """An :obj:`Unit` for :obj:`sampler`.""" + SETTINGS = SamplerSettings STATE = SamplerState @@ -242,7 +265,7 @@ def construct_generator(self): axis=self.STATE.cur_settings.axis, period=self.STATE.cur_settings.period, value=self.STATE.cur_settings.value, - estimate_alignment=self.STATE.cur_settings.estimate_alignment + estimate_alignment=self.STATE.cur_settings.estimate_alignment, ) async def initialize(self) -> None: diff --git a/src/ezmsg/sigproc/scaler.py b/src/ezmsg/sigproc/scaler.py index f4f5ee3..bee583d 100644 --- a/src/ezmsg/sigproc/scaler.py +++ b/src/ezmsg/sigproc/scaler.py @@ -46,6 +46,7 @@ def scaler( standardized, or "Z-scored" version of the input. """ from river import preprocessing + msg_out = AxisArray(np.array([]), dims=[""]) _scaler = None while True: @@ -78,8 +79,7 @@ def scaler( @consumer def scaler_np( - time_constant: float = 1.0, - axis: typing.Optional[str] = None + time_constant: float = 1.0, axis: typing.Optional[str] = None ) -> typing.Generator[AxisArray, AxisArray, None]: """ Create a generator function that applies an adaptive standard scaler. @@ -143,8 +143,8 @@ def _ew_update(arr, prev, _alpha): vars_sq_means = _ew_update(sample**2, vars_sq_means, alpha) means = _ew_update(sample, means, alpha) # Get step - varis = vars_sq_means - vars_means ** 2 - y = ((sample - means) / (varis**0.5)) + varis = vars_sq_means - vars_means**2 + y = (sample - means) / (varis**0.5) result[sample_ix] = y result[np.isnan(result)] = 0.0 @@ -157,16 +157,17 @@ class AdaptiveStandardScalerSettings(ez.Settings): Settings for :obj:`AdaptiveStandardScaler`. See :obj:`scaler_np` for a description of the parameters. """ + time_constant: float = 1.0 axis: typing.Optional[str] = None class AdaptiveStandardScaler(GenAxisArray): """Unit for :obj:`scaler_np`""" + SETTINGS = AdaptiveStandardScalerSettings def construct_generator(self): self.STATE.gen = scaler_np( - time_constant=self.SETTINGS.time_constant, - axis=self.SETTINGS.axis + time_constant=self.SETTINGS.time_constant, axis=self.SETTINGS.axis ) diff --git a/src/ezmsg/sigproc/signalinjector.py b/src/ezmsg/sigproc/signalinjector.py index 561d527..7e14af6 100644 --- a/src/ezmsg/sigproc/signalinjector.py +++ b/src/ezmsg/sigproc/signalinjector.py @@ -8,8 +8,8 @@ class SignalInjectorSettings(ez.Settings): - time_dim: str = 'time' # Input signal needs a time dimension with units in sec. - frequency: typing.Optional[float] = None # Hz + time_dim: str = "time" # Input signal needs a time dimension with units in sec. + frequency: typing.Optional[float] = None # Hz amplitude: float = 1.0 mixing_seed: typing.Optional[int] = None @@ -46,7 +46,6 @@ async def on_amplitude(self, msg: float) -> None: @ez.subscriber(INPUT_SIGNAL) @ez.publisher(OUTPUT_SIGNAL) async def inject(self, msg: AxisArray) -> typing.AsyncGenerator: - if self.STATE.cur_shape != msg.shape: self.STATE.cur_shape = msg.shape rng = np.random.default_rng(self.SETTINGS.mixing_seed) @@ -56,10 +55,10 @@ async def inject(self, msg: AxisArray) -> typing.AsyncGenerator: if self.STATE.cur_frequency is None: yield self.OUTPUT_SIGNAL, msg else: - out_msg = replace(msg, data = msg.data.copy()) + out_msg = replace(msg, data=msg.data.copy()) t = out_msg.ax(self.SETTINGS.time_dim).values[..., np.newaxis] signal = np.sin(2 * np.pi * self.STATE.cur_frequency * t) mixed_signal = signal * self.STATE.mixing * self.STATE.cur_amplitude with out_msg.view2d(self.SETTINGS.time_dim) as view: view[...] = view + mixed_signal.astype(view.dtype) - yield self.OUTPUT_SIGNAL, out_msg \ No newline at end of file + yield self.OUTPUT_SIGNAL, out_msg diff --git a/src/ezmsg/sigproc/slicer.py b/src/ezmsg/sigproc/slicer.py index 7f6b724..37e551a 100644 --- a/src/ezmsg/sigproc/slicer.py +++ b/src/ezmsg/sigproc/slicer.py @@ -54,11 +54,11 @@ def slicer( _slice: typing.Optional[typing.Union[slice, npt.NDArray]] = None new_axis: typing.Optional[AxisArray.Axis] = None b_change_dims: bool = False # If number of dimensions changes when slicing - + # Reset if input changes check_input = { "key": None, # key change used as proxy for label change, which we don't check explicitly - "len": None + "len": None, } while True: @@ -70,7 +70,8 @@ def slicer( b_reset = _slice is None # or new_axis is None b_reset = b_reset or msg_in.key != check_input["key"] b_reset = b_reset or ( - (msg_in.data.shape[axis_idx] != check_input["len"]) and (type(_slice) is np.ndarray) + (msg_in.data.shape[axis_idx] != check_input["len"]) + and (type(_slice) is np.ndarray) ) if b_reset: check_input["key"] = msg_in.key @@ -82,7 +83,8 @@ def slicer( _slices = parse_slice(selection) if len(_slices) == 1: _slice = _slices[0] - b_change_dims = isinstance(_slice, int) # If we drop the sliced dimension + # Do we drop the sliced dimension? + b_change_dims = isinstance(_slice, int) else: # Multiple slices, but this cannot be done in a single step, so we convert the slices # to a discontinuous set of integer indexes. @@ -91,26 +93,29 @@ def slicer( _slice = np.s_[indices] # Integer scalar array # Create the output axis. - if (axis in msg_in.axes - and hasattr(msg_in.axes[axis], "labels") - and len(msg_in.axes[axis].labels) > 0): + if ( + axis in msg_in.axes + and hasattr(msg_in.axes[axis], "labels") + and len(msg_in.axes[axis].labels) > 0 + ): new_labels = msg_in.axes[axis].labels[_slice] - new_axis = replace( - msg_in.axes[axis], - labels=new_labels - ) + new_axis = replace(msg_in.axes[axis], labels=new_labels) replace_kwargs = {} if b_change_dims: # Dropping the target axis - replace_kwargs["dims"] = [_ for dim_ix, _ in enumerate(msg_in.dims) if dim_ix != axis_idx] + replace_kwargs["dims"] = [ + _ for dim_ix, _ in enumerate(msg_in.dims) if dim_ix != axis_idx + ] replace_kwargs["axes"] = {k: v for k, v in msg_in.axes.items() if k != axis} elif new_axis is not None: - replace_kwargs["axes"] = {k: (v if k != axis else new_axis) for k, v in msg_in.axes.items()} + replace_kwargs["axes"] = { + k: (v if k != axis else new_axis) for k, v in msg_in.axes.items() + } msg_out = replace( msg_in, data=slice_along_axis(msg_in.data, _slice, axis_idx), - **replace_kwargs + **replace_kwargs, ) diff --git a/src/ezmsg/sigproc/spectrogram.py b/src/ezmsg/sigproc/spectrogram.py index a547b2f..de41a2a 100644 --- a/src/ezmsg/sigproc/spectrogram.py +++ b/src/ezmsg/sigproc/spectrogram.py @@ -6,10 +6,7 @@ from ezmsg.util.messages.modify import modify_axis from .window import windowing -from .spectrum import ( - spectrum, - WindowFunction, SpectralTransform, SpectralOutput -) +from .spectrum import spectrum, WindowFunction, SpectralTransform, SpectralOutput from .base import GenAxisArray @@ -19,7 +16,7 @@ def spectrogram( window_shift: typing.Optional[float] = None, window: WindowFunction = WindowFunction.HANNING, transform: SpectralTransform = SpectralTransform.REL_DB, - output: SpectralOutput = SpectralOutput.POSITIVE + output: SpectralOutput = SpectralOutput.POSITIVE, ) -> typing.Generator[typing.Optional[AxisArray], AxisArray, None]: """ Calculate a spectrogram on streaming data. @@ -41,9 +38,11 @@ def spectrogram( """ pipeline = compose( - windowing(axis="time", newaxis="win", window_dur=window_dur, window_shift=window_shift), + windowing( + axis="time", newaxis="win", window_dur=window_dur, window_shift=window_shift + ), spectrum(axis="time", window=window, transform=transform, output=output), - modify_axis(name_map={"win": "time"}) + modify_axis(name_map={"win": "time"}), ) # State variables @@ -59,8 +58,11 @@ class SpectrogramSettings(ez.Settings): Settings for :obj:`Spectrogram`. See :obj:`spectrogram` for a description of the parameters. """ + window_dur: typing.Optional[float] = None # window duration in seconds - window_shift: typing.Optional[float] = None # window step in seconds. If None, window_shift == window_dur + window_shift: typing.Optional[float] = None + """"window step in seconds. If None, window_shift == window_dur""" + # See SpectrumSettings for details of following settings: window: WindowFunction = WindowFunction.HAMMING transform: SpectralTransform = SpectralTransform.REL_DB @@ -71,6 +73,7 @@ class Spectrogram(GenAxisArray): """ Unit for :obj:`spectrogram`. """ + SETTINGS = SpectrogramSettings def construct_generator(self): @@ -79,5 +82,5 @@ def construct_generator(self): window_shift=self.SETTINGS.window_shift, window=self.SETTINGS.window, transform=self.SETTINGS.transform, - output=self.SETTINGS.output + output=self.SETTINGS.output, ) diff --git a/src/ezmsg/sigproc/spectrum.py b/src/ezmsg/sigproc/spectrum.py index 05ae8f4..e64392e 100644 --- a/src/ezmsg/sigproc/spectrum.py +++ b/src/ezmsg/sigproc/spectrum.py @@ -18,7 +18,8 @@ def options(cls): class WindowFunction(OptionsEnum): - """Windowing function prior to calculating spectrum. """ + """Windowing function prior to calculating spectrum.""" + NONE = "None (Rectangular)" """None.""" @@ -46,6 +47,7 @@ class WindowFunction(OptionsEnum): class SpectralTransform(OptionsEnum): """Additional transformation functions to apply to the spectral result.""" + RAW_COMPLEX = "Complex FFT Output" REAL = "Real Component of FFT" IMAG = "Imaginary Component of FFT" @@ -55,6 +57,7 @@ class SpectralTransform(OptionsEnum): class SpectralOutput(OptionsEnum): """The expected spectral contents.""" + FULL = "Full Spectrum" POSITIVE = "Positive Frequencies" NEGATIVE = "Negative Frequencies" @@ -109,7 +112,7 @@ def spectrum( "ndim": None, # Input ndim changed: Need to recalc windows "kind": None, # Input dtype changed: Need to re-init fft funcs "ax_idx": None, # Axis index changed: Need to re-init fft funcs - "gain": None # Gain changed: Need to re-calc freqs + "gain": None, # Gain changed: Need to re-calc freqs # "key": None # There's no temporal continuity; we can ignore key changes } @@ -139,10 +142,18 @@ def spectrum( # Pre-calculate windowing window = WINDOWS[window](targ_len) - window = window.reshape([1] * ax_idx + [len(window),] + [1] * (msg_in.data.ndim - 1 - ax_idx)) - if (transform != SpectralTransform.RAW_COMPLEX and - not (transform == SpectralTransform.REAL or transform == SpectralTransform.IMAG)): - scale = np.sum(window ** 2.0) * ax_info.gain + window = window.reshape( + [1] * ax_idx + + [ + len(window), + ] + + [1] * (msg_in.data.ndim - 1 - ax_idx) + ) + if transform != SpectralTransform.RAW_COMPLEX and not ( + transform == SpectralTransform.REAL + or transform == SpectralTransform.IMAG + ): + scale = np.sum(window**2.0) * ax_info.gain # Pre-calculate frequencies and select our fft function. b_complex = msg_in.data.dtype.kind == "c" @@ -168,18 +179,35 @@ def spectrum( ) if out_axis is None: out_axis = axis - new_dims = msg_in.dims[:ax_idx] + [out_axis, ] + msg_in.dims[ax_idx + 1:] + new_dims = ( + msg_in.dims[:ax_idx] + + [ + out_axis, + ] + + msg_in.dims[ax_idx + 1 :] + ) + + def f_transform(x): + return x - def f_transform(x): return x if transform != SpectralTransform.RAW_COMPLEX: if transform == SpectralTransform.REAL: - def f_transform(x): return x.real + + def f_transform(x): + return x.real elif transform == SpectralTransform.IMAG: - def f_transform(x): return x.imag + + def f_transform(x): + return x.imag else: - def f1(x): return (np.abs(x) ** 2.0) / scale + + def f1(x): + return (np.abs(x) ** 2.0) / scale + if transform == SpectralTransform.REL_DB: - def f_transform(x): return 10 * np.log10(f1(x)) + + def f_transform(x): + return 10 * np.log10(f1(x)) else: f_transform = f1 @@ -190,7 +218,8 @@ def f_transform(x): return 10 * np.log10(f1(x)) win_dat = msg_in.data * window else: win_dat = msg_in.data - spec = fftfun(win_dat, n=nfft, axis=ax_idx, norm=norm) # norm="forward" equivalent to `/ nfft` + spec = fftfun(win_dat, n=nfft, axis=ax_idx, norm=norm) + # Note: norm="forward" equivalent to `/ nfft` if do_fftshift or output == SpectralOutput.NEGATIVE: spec = np.fft.fftshift(spec, axes=ax_idx) spec = f_transform(spec) @@ -204,6 +233,7 @@ class SpectrumSettings(ez.Settings): Settings for :obj:`Spectrum. See :obj:`spectrum` for a description of the parameters. """ + axis: typing.Optional[str] = None # n: typing.Optional[int] = None # n parameter for fft out_axis: typing.Optional[str] = "freq" # If none; don't change dim name @@ -214,6 +244,7 @@ class SpectrumSettings(ez.Settings): class Spectrum(GenAxisArray): """Unit for :obj:`spectrum`""" + SETTINGS = SpectrumSettings INPUT_SETTINGS = ez.InputStream(SpectrumSettings) @@ -224,5 +255,5 @@ def construct_generator(self): out_axis=self.SETTINGS.out_axis, window=self.SETTINGS.window, transform=self.SETTINGS.transform, - output=self.SETTINGS.output + output=self.SETTINGS.output, ) diff --git a/src/ezmsg/sigproc/synth.py b/src/ezmsg/sigproc/synth.py index 7255371..c163c91 100644 --- a/src/ezmsg/sigproc/synth.py +++ b/src/ezmsg/sigproc/synth.py @@ -12,9 +12,7 @@ from .base import GenAxisArray -def clock( - dispatch_rate: Optional[float] -) -> Generator[ez.Flag, None, None]: +def clock(dispatch_rate: Optional[float]) -> Generator[ez.Flag, None, None]: """ Construct a generator that yields events at a specified rate. @@ -34,9 +32,7 @@ def clock( yield ez.Flag() -async def aclock( - dispatch_rate: Optional[float] -) -> AsyncGenerator[ez.Flag, None]: +async def aclock(dispatch_rate: Optional[float]) -> AsyncGenerator[ez.Flag, None]: """ ``asyncio`` version of :obj:`clock`. @@ -55,6 +51,7 @@ async def aclock( class ClockSettings(ez.Settings): """Settings for :obj:`Clock`. See :obj:`clock` for parameter description.""" + # Message dispatch rate (Hz), or None (fast as possible) dispatch_rate: Optional[float] @@ -66,6 +63,7 @@ class ClockState(ez.State): class Clock(ez.Unit): """Unit for :obj:`clock`.""" + SETTINGS = ClockSettings STATE = ClockState @@ -238,27 +236,26 @@ def construct_generator(self): self.STATE.cur_settings.fs, n_ch=self.STATE.cur_settings.n_ch, dispatch_rate=self.STATE.cur_settings.dispatch_rate, - mod=self.STATE.cur_settings.mod + mod=self.STATE.cur_settings.mod, ) self.STATE.new_generator.set() - + @ez.subscriber(INPUT_CLOCK) @ez.publisher(OUTPUT_SIGNAL) async def on_clock(self, clock: ez.Flag): - if self.STATE.cur_settings.dispatch_rate == 'ext_clock': + if self.STATE.cur_settings.dispatch_rate == "ext_clock": out = await self.STATE.gen.__anext__() yield self.OUTPUT_SIGNAL, out @ez.publisher(OUTPUT_SIGNAL) async def run_generator(self) -> AsyncGenerator: while True: - await self.STATE.new_generator.wait() self.STATE.new_generator.clear() - - if self.STATE.cur_settings.dispatch_rate == 'ext_clock': + + if self.STATE.cur_settings.dispatch_rate == "ext_clock": continue - + while not self.STATE.new_generator.is_set(): out = await self.STATE.gen.__anext__() yield self.OUTPUT_SIGNAL, out @@ -306,6 +303,7 @@ class SinGeneratorSettings(ez.Settings): Settings for :obj:`SinGenerator`. See :obj:`sin` for parameter descriptions. """ + time_axis: Optional[str] = "time" freq: float = 1.0 # Oscillation frequency in Hz amp: float = 1.0 # Amplitude @@ -316,6 +314,7 @@ class SinGenerator(GenAxisArray): """ Unit for :obj:`sin`. """ + SETTINGS = SinGeneratorSettings def construct_generator(self): @@ -323,12 +322,13 @@ def construct_generator(self): axis=self.SETTINGS.time_axis, freq=self.SETTINGS.freq, amp=self.SETTINGS.amp, - phase=self.SETTINGS.phase + phase=self.SETTINGS.phase, ) class OscillatorSettings(ez.Settings): """Settings for :obj:`Oscillator`""" + n_time: int """Number of samples to output per block.""" @@ -358,6 +358,7 @@ class Oscillator(ez.Collection): """ :obj:`Collection that chains :obj:`Counter` and :obj:`SinGenerator`. """ + SETTINGS = OscillatorSettings INPUT_CLOCK = ez.InputStream(ez.Flag) @@ -411,6 +412,7 @@ class RandomGenerator(ez.Unit): """ Replaces input data with random data and yields the result. """ + SETTINGS = RandomGeneratorSettings INPUT_SIGNAL = ez.InputStream(AxisArray) @@ -430,12 +432,12 @@ class NoiseSettings(ez.Settings): """ See :obj:`CounterSettings` and :obj:`RandomGeneratorSettings`. """ + n_time: int # Number of samples to output per block fs: float # Sampling rate of signal output in Hz n_ch: int = 1 # Number of channels to output - dispatch_rate: Optional[ - Union[float, str] - ] = None # (Hz), 'realtime', or 'ext_clock' + dispatch_rate: Optional[Union[float, str]] = None + """(Hz), 'realtime', or 'ext_clock'""" loc: float = 0.0 # DC offset scale: float = 1.0 # Scale (in standard deviations) @@ -447,6 +449,7 @@ class WhiteNoise(ez.Collection): """ A :obj:`Collection` that chains a :obj:`Counter` and :obj:`RandomGenerator`. """ + SETTINGS = NoiseSettings INPUT_CLOCK = ez.InputStream(ez.Flag) @@ -485,6 +488,7 @@ class PinkNoise(ez.Collection): """ A :obj:`Collection` that chains :obj:`WhiteNoise` and :obj:`ButterworthFilter`. """ + SETTINGS = PinkNoiseSettings INPUT_CLOCK = ez.InputStream(ez.Flag) @@ -497,7 +501,9 @@ def configure(self) -> None: self.WHITE_NOISE.apply_settings(self.SETTINGS) self.FILTER.apply_settings( ButterworthFilterSettings( - axis="time", order=1, cutoff=self.SETTINGS.fs * 0.01 # Hz + axis="time", + order=1, + cutoff=self.SETTINGS.fs * 0.01, # Hz ) ) @@ -542,6 +548,7 @@ async def output(self) -> AsyncGenerator: class EEGSynthSettings(ez.Settings): """See :obj:`OscillatorSettings`.""" + fs: float = 500.0 # Hz n_time: int = 100 alpha_freq: float = 10.5 # Hz @@ -553,6 +560,7 @@ class EEGSynth(ez.Collection): A :obj:`Collection` that chains a :obj:`Clock` to both :obj:`PinkNoise` and :obj:`Oscillator`, then :obj:`Add` s the result. """ + SETTINGS = EEGSynthSettings OUTPUT_SIGNAL = ez.OutputStream(AxisArray) diff --git a/src/ezmsg/sigproc/wavelets.py b/src/ezmsg/sigproc/wavelets.py index 41d31f5..08925c0 100644 --- a/src/ezmsg/sigproc/wavelets.py +++ b/src/ezmsg/sigproc/wavelets.py @@ -56,13 +56,13 @@ def cwt( "kind": None, # Need to recalc kernels at same complexity as input "gain": None, # Need to recalc freqs "shape": None, # Need to recalc template and buffer - "key": None # Buffer obsolete + "key": None, # Buffer obsolete } while True: msg_in: AxisArray = yield msg_out ax_idx = msg_in.get_axis_idx(axis) - in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1:] + in_shape = msg_in.data.shape[:ax_idx] + msg_in.data.shape[ax_idx + 1 :] b_reset = msg_in.data.dtype.kind != check_input["kind"] b_reset = b_reset or msg_in.axes[axis].gain != check_input["gain"] @@ -78,7 +78,7 @@ def cwt( # convert int_psi, wave_xvec to the same precision as the data dt_data = msg_in.data.dtype # _check_dtype(msg_in.data) dt_cplx = np.result_type(dt_data, np.complex64) - dt_psi = dt_cplx if int_psi.dtype.kind == 'c' else dt_data + dt_psi = dt_cplx if int_psi.dtype.kind == "c" else dt_data int_psi = np.asarray(int_psi, dtype=dt_psi) # TODO: Currently int_psi cannot be made non-complex once it is complex. @@ -94,21 +94,30 @@ def cwt( int_psi_scales.append(int_psi[reix][::-1]) # CONV is probably best because we often get huge kernels. - fbgen = filterbank(int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis) + fbgen = filterbank( + int_psi_scales, mode=FilterbankMode.CONV, min_phase=min_phase, axis=axis + ) - freqs = pywt.scale2frequency(wavelet, scales, precision) / msg_in.axes[axis].gain + freqs = ( + pywt.scale2frequency(wavelet, scales, precision) + / msg_in.axes[axis].gain + ) fstep = (freqs[1] - freqs[0]) if len(freqs) > 1 else 1.0 # Create output template dummy_shape = in_shape + (len(scales), 0) template = AxisArray( - np.zeros(dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data), - dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1:] + ["freq", axis], + np.zeros( + dummy_shape, dtype=dt_cplx if wavelet.complex_cwt else dt_data + ), + dims=msg_in.dims[:ax_idx] + msg_in.dims[ax_idx + 1 :] + ["freq", axis], axes={ **msg_in.axes, - "freq": AxisArray.Axis("Hz", offset=freqs[0], gain=fstep) + "freq": AxisArray.Axis("Hz", offset=freqs[0], gain=fstep), }, ) - last_conv_samp = np.zeros(dummy_shape[:-1] + (1,), dtype=template.data.dtype) + last_conv_samp = np.zeros( + dummy_shape[:-1] + (1,), dtype=template.data.dtype + ) conv_msg = fbgen.send(msg_in) @@ -118,7 +127,7 @@ def cwt( # Store last_conv_samp for next iteration. last_conv_samp = conv_msg.data[..., -1:] - if template.data.dtype.kind != 'c': + if template.data.dtype.kind != "c": coef = coef.real # pywt.cwt slices off the beginning and end of the result where the convolution overran. We don't have @@ -126,9 +135,7 @@ def cwt( # d = (coef.shape[-1] - msg_in.data.shape[ax_idx]) / 2. # coef = coef[..., math.floor(d):-math.ceil(d)] msg_out = replace( - template, - data=coef, - axes={**template.axes, axis: msg_in.axes[axis]} + template, data=coef, axes={**template.axes, axis: msg_in.axes[axis]} ) @@ -137,6 +144,7 @@ class CWTSettings(ez.Settings): Settings for :obj:`CWT` See :obj:`cwt` for argument details. """ + scales: typing.Union[list, tuple, npt.NDArray] wavelet: typing.Union[str, pywt.ContinuousWavelet, pywt.Wavelet] min_phase: MinPhaseMode = MinPhaseMode.NONE @@ -147,6 +155,7 @@ class CWT(GenAxisArray): """ :obj:`Unit` for :obj:`common_rereference`. """ + SETTINGS = CWTSettings def construct_generator(self): @@ -154,5 +163,5 @@ def construct_generator(self): scales=self.SETTINGS.scales, wavelet=self.SETTINGS.wavelet, min_phase=self.SETTINGS.min_phase, - axis=self.SETTINGS.axis + axis=self.SETTINGS.axis, ) diff --git a/src/ezmsg/sigproc/window.py b/src/ezmsg/sigproc/window.py index 3c6f0d1..c51048c 100644 --- a/src/ezmsg/sigproc/window.py +++ b/src/ezmsg/sigproc/window.py @@ -5,7 +5,11 @@ import ezmsg.core as ez import numpy as np import numpy.typing as npt -from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis, sliding_win_oneaxis +from ezmsg.util.messages.axisarray import ( + AxisArray, + slice_along_axis, + sliding_win_oneaxis, +) from ezmsg.util.generator import consumer from .base import GenAxisArray @@ -17,7 +21,7 @@ def windowing( newaxis: str = "win", window_dur: typing.Optional[float] = None, window_shift: typing.Optional[float] = None, - zero_pad_until: str = "input" + zero_pad_until: str = "input", ) -> typing.Generator[AxisArray, AxisArray, None]: """ Construct a generator that yields windows of data from an input :obj:`AxisArray`. @@ -52,19 +56,24 @@ def windowing( ez.logger.warning("`newaxis` must not be None. Setting to 'win'.") newaxis = "win" if window_shift is None and zero_pad_until != "input": - ez.logger.warning("`zero_pad_until` must be 'input' if `window_shift` is None. " - f"Ignoring received argument value: {zero_pad_until}") + ez.logger.warning( + "`zero_pad_until` must be 'input' if `window_shift` is None. " + f"Ignoring received argument value: {zero_pad_until}" + ) zero_pad_until = "input" elif window_shift is not None and zero_pad_until == "input": - ez.logger.warning("windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size " - "of the first input. We recommend using 'shift' when `window_shift` is float-valued.") + ez.logger.warning( + "windowing is non-deterministic with `zero_pad_until='input'` as it depends on the size " + "of the first input. We recommend using 'shift' when `window_shift` is float-valued." + ) msg_out = AxisArray(np.array([]), dims=[""]) # State variables buffer: typing.Optional[npt.NDArray] = None window_samples: typing.Optional[int] = None window_shift_samples: typing.Optional[int] = None - shift_deficit: int = 0 # Number of incoming samples to ignore. Only relevant when shift > window. + # Number of incoming samples to ignore. Only relevant when shift > window.: + shift_deficit: int = 0 b_1to1 = window_shift is None newaxis_warned: bool = b_1to1 out_newaxis: typing.Optional[AxisArray.Axis] = None @@ -85,11 +94,13 @@ def windowing( fs = 1.0 / axis_info.gain if not newaxis_warned and newaxis in msg_in.dims: - ez.logger.warning(f"newaxis {newaxis} present in input dims. Using {newaxis}_win instead") + ez.logger.warning( + f"newaxis {newaxis} present in input dims. Using {newaxis}_win instead" + ) newaxis_warned = True newaxis = f"{newaxis}_win" - samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1:] + samp_shape = msg_in.data.shape[:axis_idx] + msg_in.data.shape[axis_idx + 1 :] # If buffer unset or input stats changed, create a new buffer b_reset = buffer is None @@ -112,7 +123,11 @@ def windowing( else: # i.e. zero_pad_until == "input" req_samples = msg_in.data.shape[axis_idx] n_zero = max(0, window_samples - req_samples) - buffer = np.zeros(msg_in.data.shape[:axis_idx] + (n_zero,) + msg_in.data.shape[axis_idx + 1:]) + buffer = np.zeros( + msg_in.data.shape[:axis_idx] + + (n_zero,) + + msg_in.data.shape[axis_idx + 1 :] + ) # Add new data to buffer. # Currently, we concatenate the new time samples and clip the output. @@ -144,7 +159,7 @@ def windowing( out_newaxis = replace( axis_info, gain=0.0 if b_1to1 else axis_info.gain * window_shift_samples, - offset=0.0 # offset modified per-msg below + offset=0.0, # offset modified per-msg below ) # Generate outputs. @@ -164,8 +179,12 @@ def windowing( elif buffer.shape[axis_idx] >= window_samples: # Deterministic window shifts. out_dat = sliding_win_oneaxis(buffer, window_samples, axis_idx) - out_dat = slice_along_axis(out_dat, slice(None, None, window_shift_samples), axis_idx) - offset_view = sliding_win_oneaxis(buffer_offset, window_samples, 0)[::window_shift_samples] + out_dat = slice_along_axis( + out_dat, slice(None, None, window_shift_samples), axis_idx + ) + offset_view = sliding_win_oneaxis(buffer_offset, window_samples, 0)[ + ::window_shift_samples + ] out_newaxis = replace(out_newaxis, offset=offset_view[0, 0]) # Drop expired beginning of buffer and update shift_deficit @@ -175,22 +194,16 @@ def windowing( else: # Not enough data to make a new window. Return empty data. empty_data_shape = ( - msg_in.data.shape[:axis_idx] - + (0, window_samples) - + msg_in.data.shape[axis_idx + 1:] + msg_in.data.shape[:axis_idx] + + (0, window_samples) + + msg_in.data.shape[axis_idx + 1 :] ) out_dat = np.zeros(empty_data_shape, dtype=msg_in.data.dtype) # out_newaxis will have first timestamp in input... but mostly meaningless because output is size-zero. out_newaxis = replace(out_newaxis, offset=axis_info.offset) msg_out = replace( - msg_in, - data=out_dat, - dims=out_dims, - axes={ - **out_axes, - newaxis: out_newaxis - } + msg_in, data=out_dat, dims=out_dims, axes={**out_axes, newaxis: out_newaxis} ) @@ -209,6 +222,7 @@ class WindowState(ez.State): class Window(GenAxisArray): """:obj:`Unit` for :obj:`bandpower`.""" + SETTINGS = WindowSettings INPUT_SIGNAL = ez.InputStream(AxisArray) @@ -220,7 +234,7 @@ def construct_generator(self): newaxis=self.SETTINGS.newaxis, window_dur=self.SETTINGS.window_dur, window_shift=self.SETTINGS.window_shift, - zero_pad_until=self.SETTINGS.zero_pad_until + zero_pad_until=self.SETTINGS.zero_pad_until, ) @ez.subscriber(INPUT_SIGNAL, zero_copy=True) @@ -229,28 +243,37 @@ async def on_signal(self, msg: AxisArray) -> typing.AsyncGenerator: try: out_msg = self.STATE.gen.send(msg) if out_msg.data.size > 0: - if self.SETTINGS.newaxis is not None or self.SETTINGS.window_dur is None: + if ( + self.SETTINGS.newaxis is not None + or self.SETTINGS.window_dur is None + ): # Multi-win mode or pass-through mode. yield self.OUTPUT_SIGNAL, out_msg else: # We need to split out_msg into multiple yields, dropping newaxis. axis_idx = out_msg.get_axis_idx("win") win_axis = out_msg.axes["win"] - offsets = np.arange(out_msg.data.shape[axis_idx]) * win_axis.gain + win_axis.offset + offsets = ( + np.arange(out_msg.data.shape[axis_idx]) * win_axis.gain + + win_axis.offset + ) for msg_ix in range(out_msg.data.shape[axis_idx]): # Need to drop 'win' and replace self.SETTINGS.axis from axes. _out_axes = { - **{k: v for k, v in out_msg.axes.items() if k not in ["win", self.SETTINGS.axis]}, + **{ + k: v + for k, v in out_msg.axes.items() + if k not in ["win", self.SETTINGS.axis] + }, self.SETTINGS.axis: replace( - out_msg.axes[self.SETTINGS.axis], - offset=offsets[msg_ix] - ) + out_msg.axes[self.SETTINGS.axis], offset=offsets[msg_ix] + ), } _out_msg = replace( out_msg, data=slice_along_axis(out_msg.data, msg_ix, axis_idx), - dims=out_msg.dims[:axis_idx] + out_msg.dims[axis_idx + 1:], - axes=_out_axes + dims=out_msg.dims[:axis_idx] + out_msg.dims[axis_idx + 1 :], + axes=_out_axes, ) yield self.OUTPUT_SIGNAL, _out_msg except (StopIteration, GeneratorExit): diff --git a/tests/conftest.py b/tests/conftest.py index 5e85fe6..0028bcf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,4 @@ import sys import os -sys.path.append(os.path.join(os.path.dirname(__file__), 'helpers')) + +sys.path.append(os.path.join(os.path.dirname(__file__), "helpers")) diff --git a/tests/helpers/util.py b/tests/helpers/util.py index 3ba26db..153650e 100644 --- a/tests/helpers/util.py +++ b/tests/helpers/util.py @@ -39,9 +39,9 @@ def create_messages_with_periodic_signal( {"f": 35.0, "dur": 5.0, "offset": 5.0}, {"f": 300.0, "dur": 5.0, "offset": 5.0}, ], - fs: float = 1000., + fs: float = 1000.0, msg_dur: float = 1.0, - win_step_dur: typing.Optional[float] = None + win_step_dur: typing.Optional[float] = None, ) -> typing.List[AxisArray]: """ Create a continuous signal with periodic components. The signal will be divided into n segments, @@ -57,12 +57,16 @@ def create_messages_with_periodic_signal( for s_p in sin_params: offs = s_p.get("offset", 0.0) b_t = np.logical_and(t_vec >= offs, t_vec <= offs + s_p["dur"]) - data[b_t] += s_p.get("a", 1.) * np.sin(2 * np.pi * s_p["f"] * t_vec[b_t] + s_p.get("p", 0)) + data[b_t] += s_p.get("a", 1.0) * np.sin( + 2 * np.pi * s_p["f"] * t_vec[b_t] + s_p.get("p", 0) + ) # How will we split the data into messages? With a rolling window or non-overlapping? if win_step_dur is not None: win_step = int(win_step_dur * fs) - data_splits = sliding_window_view(data, (int(msg_dur * fs),), axis=0)[::win_step] + data_splits = sliding_window_view(data, (int(msg_dur * fs),), axis=0)[ + ::win_step + ] else: n_msgs = int(t_end / msg_dur) data_splits = np.array_split(data, n_msgs, axis=0) @@ -76,7 +80,7 @@ def create_messages_with_periodic_signal( AxisArray( split_dat[..., None], dims=["time", "ch"], - axes=frozendict({"time": _time_axis}) + axes=frozendict({"time": _time_axis}), ) ) offset += split_dat.shape[0] / fs diff --git a/tests/test_activation.py b/tests/test_activation.py index 6b61679..5054bad 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -8,21 +8,25 @@ from ezmsg.sigproc.activation import activation, ActivationFunction, ACTIVATIONS -@pytest.mark.parametrize("function", [_ for _ in ActivationFunction] + ActivationFunction.options()) +@pytest.mark.parametrize( + "function", [_ for _ in ActivationFunction] + ActivationFunction.options() +) def test_activation(function: str): in_fs = 19.0 sig = np.arange(24, dtype=float).reshape(4, 3, 2) if function in [ActivationFunction.LOGIT, "logit"]: sig += 1e-9 - sig /= (np.max(sig) + 1e-3) + sig /= np.max(sig) + 1e-3 def msg_generator(): for msg_ix in range(sig.shape[0]): - msg_sig = sig[msg_ix:msg_ix+1] + msg_sig = sig[msg_ix : msg_ix + 1] msg = AxisArray( data=msg_sig, dims=["time", "ch", "feat"], - axes=frozendict({"time": AxisArray.Axis.TimeAxis(fs=in_fs, offset=msg_ix / in_fs)}) + axes=frozendict( + {"time": AxisArray.Axis.TimeAxis(fs=in_fs, offset=msg_ix / in_fs)} + ), ) yield msg diff --git a/tests/test_affine_transform.py b/tests/test_affine_transform.py index 2817a51..ab951f3 100644 --- a/tests/test_affine_transform.py +++ b/tests/test_affine_transform.py @@ -64,7 +64,8 @@ def test_affine_passthrough(): gen = affine_transform(weights="passthrough", axis="does not matter") msg_out = gen.send(msg_in) - assert msg_out.data is in_dat # This is not desirable in ezmsg pipeline but fine for the generator + # We wouldn't want out_data is in_dat ezmsg pipeline but it's fine for the generator + assert msg_out.data is in_dat assert_messages_equal([msg_out], backup) diff --git a/tests/test_aggregate.py b/tests/test_aggregate.py index 8874643..8cf9393 100644 --- a/tests/test_aggregate.py +++ b/tests/test_aggregate.py @@ -11,7 +11,6 @@ def get_msg_gen(n_chans=20, n_freqs=100, data_dur=30.0, fs=1024.0, key=""): - n_samples = int(data_dur * fs) data = np.arange(n_samples * n_chans * n_freqs).reshape(n_samples, n_chans, n_freqs) n_msgs = int(data_dur / 2) @@ -22,19 +21,23 @@ def msg_generator(): msg = AxisArray( data=arr, dims=["time", "ch", "freq"], - axes=frozendict({ - "time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset), - "freq": AxisArray.Axis(gain=1.0, offset=0.0, unit="Hz") - }), - key=key + axes=frozendict( + { + "time": AxisArray.Axis.TimeAxis(fs=fs, offset=offset), + "freq": AxisArray.Axis(gain=1.0, offset=0.0, unit="Hz"), + } + ), + key=key, ) offset += arr.shape[0] / fs yield msg + return msg_generator() @pytest.mark.parametrize( - "agg_func", [AggregationFunction.MEAN, AggregationFunction.MEDIAN, AggregationFunction.STD] + "agg_func", + [AggregationFunction.MEAN, AggregationFunction.MEDIAN, AggregationFunction.STD], ) def test_aggregate(agg_func: AggregationFunction): bands = [(5.0, 20.0), (30.0, 50.0)] @@ -45,6 +48,7 @@ def test_aggregate(agg_func: AggregationFunction): # Grab a deepcopy backup of the inputs so we can check the inputs didn't change # while being processed. import copy + backup = [copy.deepcopy(_) for _ in in_msgs] gen = ranged_aggregate(axis=targ_ax, bands=bands, operation=agg_func) @@ -69,12 +73,17 @@ def test_aggregate(agg_func: AggregationFunction): agg_func = { AggregationFunction.MEAN: partial(np.mean, axis=-1, keepdims=True), AggregationFunction.MEDIAN: partial(np.median, axis=-1, keepdims=True), - AggregationFunction.STD: partial(np.std, axis=-1, keepdims=True) + AggregationFunction.STD: partial(np.std, axis=-1, keepdims=True), }[agg_func] - expected_data = np.concatenate([ - agg_func(data[..., np.logical_and(targ_ax_vec >= start, targ_ax_vec <= stop)]) - for (start, stop) in bands - ], axis=-1) + expected_data = np.concatenate( + [ + agg_func( + data[..., np.logical_and(targ_ax_vec >= start, targ_ax_vec <= stop)] + ) + for (start, stop) in bands + ], + axis=-1, + ) received_data = AxisArray.concatenate(*out_msgs, dim="time").data assert np.allclose(received_data, expected_data) @@ -104,10 +113,19 @@ def test_aggregate_handle_change(change_ax: str): change_ax being 'ch' should work while 'freq' should fail. """ in_msgs1 = [_ for _ in get_msg_gen(n_chans=20, n_freqs=100)] - in_msgs2 = [_ for _ in get_msg_gen(n_chans=17 if change_ax == "ch" else 20, - n_freqs=70 if change_ax == "freq" else 100)] - - gen = ranged_aggregate(axis="freq", bands=[(5.0, 20.0), (30.0, 50.0)], operation=AggregationFunction.MEAN) + in_msgs2 = [ + _ + for _ in get_msg_gen( + n_chans=17 if change_ax == "ch" else 20, + n_freqs=70 if change_ax == "freq" else 100, + ) + ] + + gen = ranged_aggregate( + axis="freq", + bands=[(5.0, 20.0), (30.0, 50.0)], + operation=AggregationFunction.MEAN, + ) out_msgs1 = [gen.send(_) for _ in in_msgs1] print(len(out_msgs1)) diff --git a/tests/test_bandpower.py b/tests/test_bandpower.py index 320f98b..79f7d63 100644 --- a/tests/test_bandpower.py +++ b/tests/test_bandpower.py @@ -10,13 +10,16 @@ def _debug_plot(result): import matplotlib.pyplot as plt - t_vec = result.axes["time"].offset + np.arange(result.data.shape[0]) * result.axes["time"].gain + t_vec = ( + result.axes["time"].offset + + np.arange(result.data.shape[0]) * result.axes["time"].gain + ) plt.plot(t_vec, result.data[..., 0]) def test_bandpower(): win_dur = 1.0 - fs = 1000. + fs = 1000.0 bands = [(9, 11), (70, 90), (134, 136)] sin_params = [ @@ -29,7 +32,7 @@ def test_bandpower(): sin_params=sin_params, fs=fs, msg_dur=0.4, - win_step_dur=None # The spectrogram will do the windowing + win_step_dur=None, # The spectrogram will do the windowing ) # Grab a deepcopy backup of the inputs, so we can check the inputs didn't change @@ -41,7 +44,7 @@ def test_bandpower(): window_dur=win_dur, window_shift=0.1, ), - bands=bands + bands=bands, ) results = [gen.send(_) for _ in messages] @@ -51,13 +54,16 @@ def test_bandpower(): # _debug_plot(result) # Check the amplitudes at the midpoints of each of our sinusoids. - t_vec = result.axes["time"].offset + np.arange(result.data.shape[0]) * result.axes["time"].gain + t_vec = ( + result.axes["time"].offset + + np.arange(result.data.shape[0]) * result.axes["time"].gain + ) mags = [] for s_p in sin_params[:2]: - ix = np.argmin(np.abs(t_vec - (s_p["offset"] + s_p["dur"]/2))) + ix = np.argmin(np.abs(t_vec - (s_p["offset"] + s_p["dur"] / 2))) mags.append(result.data[ix, 0, 0]) for s_p in sin_params[2:]: - ix = np.argmin(np.abs(t_vec - (s_p["offset"] + s_p["dur"]/2))) + ix = np.argmin(np.abs(t_vec - (s_p["offset"] + s_p["dur"] / 2))) mags.append(result.data[ix, 2, 0]) # The sorting of the measured magnitudes should match the sorting of the parameter magnitudes. assert np.array_equal(np.argsort(mags), np.argsort([_["a"] for _ in sin_params])) diff --git a/tests/test_butter.py b/tests/test_butter.py index 926e402..6926833 100644 --- a/tests/test_butter.py +++ b/tests/test_butter.py @@ -5,7 +5,9 @@ from ezmsg.util.messages.axisarray import AxisArray from ezmsg.sigproc.butterworthfilter import butter -from ezmsg.sigproc.butterworthfilter import ButterworthFilterSettings as LegacyButterSettings +from ezmsg.sigproc.butterworthfilter import ( + ButterworthFilterSettings as LegacyButterSettings, +) @pytest.mark.parametrize( @@ -116,7 +118,11 @@ def test_butterworth( for split_dat in np.array_split(in_dat, n_splits, axis=time_ax): _time_axis = AxisArray.Axis.TimeAxis(fs=fs, offset=n_seen / fs) messages.append( - AxisArray(split_dat, dims=dat_dims, axes=frozendict({**other_axes, "time": _time_axis})) + AxisArray( + split_dat, + dims=dat_dims, + axes=frozendict({**other_axes, "time": _time_axis}), + ) ) n_seen += split_dat.shape[time_ax] @@ -131,4 +137,4 @@ def test_butterworth( ) result = np.concatenate([gen.send(_).data for _ in messages], axis=time_ax) - assert np.allclose(result, out_dat) \ No newline at end of file + assert np.allclose(result, out_dat) diff --git a/tests/test_butterworth.py b/tests/test_butterworth.py index 42728b9..45107b8 100644 --- a/tests/test_butterworth.py +++ b/tests/test_butterworth.py @@ -87,7 +87,7 @@ def test_butterworth_system( system = ButterworthSystem(settings) - ez.run(SYSTEM = system) + ez.run(SYSTEM=system) messages: typing.List[AxisArray] = [] for msg in message_log(test_filename): diff --git a/tests/test_downsample.py b/tests/test_downsample.py index a452aa0..4c6c7aa 100644 --- a/tests/test_downsample.py +++ b/tests/test_downsample.py @@ -29,19 +29,23 @@ def test_downsample_core(block_size: int, factor: int): n_features = 3 num_samps = int(np.ceil(test_dur * in_fs)) num_msgs = int(np.ceil(num_samps / block_size)) - sig = np.arange(num_samps * n_channels * n_features).reshape(num_samps, n_channels, n_features) + sig = np.arange(num_samps * n_channels * n_features).reshape( + num_samps, n_channels, n_features + ) # tvec = np.arange(num_samps) / in_fs def msg_generator(): for msg_ix in range(num_msgs): - msg_sig = sig[msg_ix*block_size:(msg_ix+1)*block_size] + msg_sig = sig[msg_ix * block_size : (msg_ix + 1) * block_size] msg_idx: float = msg_sig[0, 0, 0] / (n_channels * n_features) msg_offs = msg_idx / in_fs msg = AxisArray( data=msg_sig, dims=["time", "ch", "feat"], - axes=frozendict({"time": AxisArray.Axis.TimeAxis(fs=in_fs, offset=msg_offs)}), - key="test_downsample_core" + axes=frozendict( + {"time": AxisArray.Axis.TimeAxis(fs=in_fs, offset=msg_offs)} + ), + key="test_downsample_core", ) yield msg @@ -61,7 +65,9 @@ def msg_generator(): assert all(msg.axes["time"].gain == factor / in_fs for msg in out_msgs) # Assert messages have the correct timestamps - expected_offsets = np.cumsum([0] + [_.data.shape[0] for _ in out_msgs]) * factor / in_fs + expected_offsets = ( + np.cumsum([0] + [_.data.shape[0] for _ in out_msgs]) * factor / in_fs + ) actual_offsets = np.array([_.axes["time"].offset for _ in out_msgs]) assert np.allclose(actual_offsets, expected_offsets[:-1]) @@ -129,7 +135,9 @@ def test_downsample_system( settings = DownsampleSystemSettings( num_msgs=num_msgs, counter_settings=CounterSettings( - n_time=block_size, fs=in_fs, dispatch_rate=20.0, + n_time=block_size, + fs=in_fs, + dispatch_rate=20.0, ), down_settings=DownsampleSettings(factor=factor), log_settings=MessageLoggerSettings(output=test_filename), @@ -138,7 +146,7 @@ def test_downsample_system( system = DownsampleSystem(settings) - ez.run(SYSTEM = system) + ez.run(SYSTEM=system) messages: List[AxisArray] = [_ for _ in message_log(test_filename)] os.remove(test_filename) @@ -148,7 +156,12 @@ def test_downsample_system( out_fs = in_fs / factor assert np.allclose( np.array([1 / msg.axes["time"].gain for msg in messages]), - np.ones(len(messages,)) * out_fs + np.ones( + len( + messages, + ) + ) + * out_fs, ) # Check data @@ -168,8 +181,7 @@ def test_downsample_system( first_samps = np.concatenate(first_samps, axis=time_ax_idx) expected_offsets = first_samps.squeeze() / out_fs / factor assert np.allclose( - np.array([msg.axes["time"].offset for msg in messages]), - expected_offsets + np.array([msg.axes["time"].offset for msg in messages]), expected_offsets ) ez.logger.info("Test Complete.") diff --git a/tests/test_filterbank.py b/tests/test_filterbank.py index fa79ffd..d5c311a 100644 --- a/tests/test_filterbank.py +++ b/tests/test_filterbank.py @@ -20,7 +20,7 @@ def gen_signal(fs, dur): # generate signal f_gains = [9, 5] # frequency will scale as square of this value * time. t_offsets = [0.3 * dur, 0.1 * dur] # When the chirp starts. - time = np.arange(int(dur*fs)) / fs + time = np.arange(int(dur * fs)) / fs # TODO: Replace with sps.chirp? chirp1, frequency1 = make_chirp(time, t_offsets[0], f_gains[0]) chirp2, frequency2 = make_chirp(time, t_offsets[1], f_gains[1]) @@ -32,12 +32,20 @@ def gen_signal(fs, dur): def bandpass_kaiser(ntaps, lowcut, highcut, fs, width): atten = sps.kaiser_atten(ntaps, width / (0.5 * fs)) beta = sps.kaiser_beta(atten) - taps = sps.firwin(ntaps, [lowcut, highcut], fs=fs, pass_zero="bandpass", - window=("kaiser", beta), scale=False) + taps = sps.firwin( + ntaps, + [lowcut, highcut], + fs=fs, + pass_zero="bandpass", + window=("kaiser", beta), + scale=False, + ) return taps -@pytest.mark.parametrize("mode", [FilterbankMode.CONV, FilterbankMode.FFT, FilterbankMode.AUTO]) +@pytest.mark.parametrize( + "mode", [FilterbankMode.CONV, FilterbankMode.FFT, FilterbankMode.AUTO] +) @pytest.mark.parametrize("kernel_type", ["kaiser", "brickwall"]) def test_filterbank(mode: str, kernel_type: str): # Generate test signal @@ -52,11 +60,9 @@ def test_filterbank(mode: str, kernel_type: str): for idx in range(0, len(tvec), step_size): in_messages.append( AxisArray( - data=chirp[:, idx:idx+step_size], + data=chirp[:, idx : idx + step_size], dims=["ch", "time"], - axes={ - "time": AxisArray.Axis.TimeAxis(offset=tvec[idx], fs=fs) - } + axes={"time": AxisArray.Axis.TimeAxis(offset=tvec[idx], fs=fs)}, ) ) @@ -90,13 +96,18 @@ def test_filterbank(mode: str, kernel_type: str): # - conv has transients at the beginning that we need to skip over # - oaconvolve assumes the data is finished so it returns the trailing windows, # but filterbank keeps the tail assuming more data is coming. - expected = np.stack([sps.oaconvolve(chirp, _[None, :], axes=1) for _ in kernels], axis=1) + expected = np.stack( + [sps.oaconvolve(chirp, _[None, :], axes=1) for _ in kernels], axis=1 + ) idx0 = ntaps if mode in [FilterbankMode.CONV, FilterbankMode.AUTO] else 0 - assert np.allclose(result.data[..., idx0:], expected[..., idx0:result.data.shape[-1]]) + assert np.allclose( + result.data[..., idx0:], expected[..., idx0 : result.data.shape[-1]] + ) if False: # Debug visualize result import matplotlib.pyplot as plt + # tmp = result.data tmp = expected nch = tmp.shape[0] diff --git a/tests/test_math.py b/tests/test_math.py index b26996c..1138381 100644 --- a/tests/test_math.py +++ b/tests/test_math.py @@ -46,7 +46,9 @@ def test_const_difference(value: float, subtrahend: bool): proc = const_difference(value, subtrahend) msg_out = proc.send(msg_in) - assert np.array_equal(msg_out.data, (in_dat - value) if subtrahend else (value - in_dat)) + assert np.array_equal( + msg_out.data, (in_dat - value) if subtrahend else (value - in_dat) + ) def test_invert(): diff --git a/tests/test_sampler.py b/tests/test_sampler.py index b2669ed..561d73a 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -10,10 +10,12 @@ from ezmsg.util.messages.axisarray import AxisArray from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings from ezmsg.sigproc.sampler import ( - Sampler, SamplerSettings, - TriggerGenerator, TriggerGeneratorSettings, + Sampler, + SamplerSettings, + TriggerGenerator, + TriggerGeneratorSettings, SampleTriggerMessage, - sampler + sampler, ) from ezmsg.sigproc.synth import Oscillator, OscillatorSettings from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings @@ -25,21 +27,27 @@ def test_sampler_gen(): data_dur = 10.0 chunk_period = 0.1 - fs = 500. + fs = 500.0 n_chans = 3 # The sampler is a bit complicated as it requires 2 different inputs: signal and triggers # Prepare signal data n_data = int(data_dur * fs) - data = np.arange(n_chans *n_data).reshape(n_chans, n_data) + data = np.arange(n_chans * n_data).reshape(n_chans, n_data) offsets = np.arange(n_data) / fs n_chunks = int(np.ceil(data_dur / chunk_period)) n_per_chunk = int(np.ceil(n_data / n_chunks)) signal_msgs = [ AxisArray( - data=data[:, ix * n_per_chunk:(ix + 1) * n_per_chunk], + data=data[:, ix * n_per_chunk : (ix + 1) * n_per_chunk], dims=["ch", "time"], - axes=frozendict({"time": AxisArray.Axis.TimeAxis(fs=fs, offset=offsets[ix * n_per_chunk])}) + axes=frozendict( + { + "time": AxisArray.Axis.TimeAxis( + fs=fs, offset=offsets[ix * n_per_chunk] + ) + } + ), ) for ix in range(n_chunks) ] @@ -51,23 +59,25 @@ def test_sampler_gen(): period = (-0.01, 0.74) trigger_msgs = [ SampleTriggerMessage( - timestamp=_ts, - period=period, - value=["Start", "Stop"][_ix % 2] + timestamp=_ts, period=period, value=["Start", "Stop"][_ix % 2] ) for _ix, _ts in enumerate(trig_ts) ] backup_trigger = [copy.deepcopy(_) for _ in trigger_msgs] # Mix the messages and sort by time - msg_ts = [_.axes["time"].offset for _ in signal_msgs] + [_.timestamp for _ in trigger_msgs] + msg_ts = [_.axes["time"].offset for _ in signal_msgs] + [ + _.timestamp for _ in trigger_msgs + ] mix_msgs = signal_msgs + trigger_msgs mix_msgs = [mix_msgs[_] for _ in np.argsort(msg_ts)] # Create the sample-generator period_dur = period[1] - period[0] buffer_dur = 2 * max(period_dur, period[1]) - gen = sampler(buffer_dur, axis="time", period=None, value=None, estimate_alignment=True) + gen = sampler( + buffer_dur, axis="time", period=None, value=None, estimate_alignment=True + ) # Run the messages through the generator and collect samples. samples = [] @@ -79,12 +89,19 @@ def test_sampler_gen(): assert len(samples) == n_trigs # Check sample data size - assert all([_.sample.data.shape == (n_chans, int(fs * period_dur)) for _ in samples]) + assert all( + [_.sample.data.shape == (n_chans, int(fs * period_dur)) for _ in samples] + ) # Compare the sample window slice against the trigger timestamps - latencies = [_.sample.axes["time"].offset - (_.trigger.timestamp + _.trigger.period[0]) for _ in samples] + latencies = [ + _.sample.axes["time"].offset - (_.trigger.timestamp + _.trigger.period[0]) + for _ in samples + ] assert all([0 <= _ < 1 / fs for _ in latencies]) # Check the sample trigger value matches the trigger input. - assert all([_.trigger.value == ["Start", "Stop"][ix % 2] for ix, _ in enumerate(samples)]) + assert all( + [_.trigger.value == ["Start", "Stop"][ix % 2] for ix, _ in enumerate(samples)] + ) class SamplerSystemSettings(ez.Settings): @@ -131,7 +148,7 @@ def test_sampler_system(test_name: Optional[str] = None): period = (0.5, 1.5) n_msgs = 4 - sample_dur = (period[1] - period[0]) + sample_dur = period[1] - period[0] publish_period = sample_dur * 2.0 test_filename = get_test_fn(test_name) @@ -148,9 +165,7 @@ def test_sampler_system(test_name: Optional[str] = None): sync=True, # Adjust `freq` to sync with sampling rate ), trigger_settings=TriggerGeneratorSettings( - period=period, - prewait=0.5, - publish_period=publish_period + period=period, prewait=0.5, publish_period=publish_period ), sampler_settings=SamplerSettings(buffer_dur=publish_period + 1.0), log_settings=MessageLoggerSettings(output=test_filename), @@ -166,6 +181,9 @@ def test_sampler_system(test_name: Optional[str] = None): assert len(messages) == n_msgs assert all([_.sample.data.shape == (int(freq * sample_dur), 1) for _ in messages]) # Test the sample window slice vs the trigger timestamps - latencies = [_.sample.axes["time"].offset - (_.trigger.timestamp + _.trigger.period[0]) for _ in messages] - assert all([0 <= _ < 1/freq for _ in latencies]) + latencies = [ + _.sample.axes["time"].offset - (_.trigger.timestamp + _.trigger.period[0]) + for _ in messages + ] + assert all([0 <= _ < 1 / freq for _ in latencies]) # Given that the input is a pure sinusoid, we could test that the signal has expected characteristics. diff --git a/tests/test_scaler.py b/tests/test_scaler.py index 713972e..2c84bf1 100644 --- a/tests/test_scaler.py +++ b/tests/test_scaler.py @@ -23,12 +23,14 @@ def test_adaptive_standard_scaler_river(): # Test data values taken from river: # https://github.com/online-ml/river/blob/main/river/preprocessing/scale.py#L511-L536C17 data = np.array([5.278, 5.050, 6.550, 7.446, 9.472, 10.353, 11.784, 11.173]) - expected_result = np.array([0.0, -0.816, 0.812, 0.695, 0.754, 0.598, 0.651, 0.124]) + expected_result = np.array( + [0.0, -0.816, 0.812, 0.695, 0.754, 0.598, 0.651, 0.124] + ) test_input = AxisArray( np.tile(data, (2, 1)), dims=["ch", "time"], - axes=frozendict({"time": AxisArray.Axis()}) + axes=frozendict({"time": AxisArray.Axis()}), ) backup = [copy.deepcopy(test_input)] @@ -53,7 +55,9 @@ class ScalerTestSystemSettings(ez.Settings): counter_settings: CounterSettings scaler_settings: AdaptiveStandardScalerSettings log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) + term_settings: TerminateOnTotalSettings = field( + default_factory=TerminateOnTotalSettings + ) class ScalerTestSystem(ez.Collection): @@ -74,15 +78,15 @@ def network(self) -> ez.NetworkDefinition: return ( (self.COUNTER.OUTPUT_SIGNAL, self.SCALER.INPUT_SIGNAL), (self.SCALER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE) + (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) def test_scaler_system( - tau: float = 1.0, - fs: float = 10.0, - duration: float = 2.0, - test_name: Optional[str] = None, + tau: float = 1.0, + fs: float = 10.0, + duration: float = 2.0, + test_name: Optional[str] = None, ): """ For this test, we assume that Counter and scaler_np are functioning properly. @@ -101,16 +105,13 @@ def test_scaler_system( dispatch_rate=duration, # Simulation duration in 1.0 seconds mod=None, ), - scaler_settings=AdaptiveStandardScalerSettings( - time_constant=tau, - axis="time" - ), + scaler_settings=AdaptiveStandardScalerSettings(time_constant=tau, axis="time"), log_settings=MessageLoggerSettings( output=test_filename, ), term_settings=TerminateOnTotalSettings( total=int(duration * fs / block_size), - ) + ), ) system = ScalerTestSystem(settings) ez.run(SYSTEM=system) @@ -124,7 +125,7 @@ def test_scaler_system( expected_input = AxisArray( np.arange(len(data))[None, :], dims=["ch", "time"], - axes=frozendict({"time": AxisArray.Axis(gain=1/fs, offset=0.0)}) + axes=frozendict({"time": AxisArray.Axis(gain=1 / fs, offset=0.0)}), ) _scaler = scaler_np(time_constant=tau, axis="time") expected_output = _scaler.send(expected_input) diff --git a/tests/test_slicer.py b/tests/test_slicer.py index 648c312..6c84c6e 100644 --- a/tests/test_slicer.py +++ b/tests/test_slicer.py @@ -67,7 +67,7 @@ def test_slicer_gen_drop_dim(): dims=["time", "ch"], axes={ "time": AxisArray.Axis.TimeAxis(fs=100.0, offset=0.1), - } + }, ) backup = [copy.deepcopy(msg_in)] diff --git a/tests/test_spectrogram.py b/tests/test_spectrogram.py index 151fdea..a9470af 100644 --- a/tests/test_spectrogram.py +++ b/tests/test_spectrogram.py @@ -11,29 +11,35 @@ def _debug_plot( - ax_arr: AxisArray, - sin_params: typing.List[typing.Dict[str, float]] = None + ax_arr: AxisArray, sin_params: typing.List[typing.Dict[str, float]] = None ): import matplotlib.pyplot as plt t_ix = ax_arr.get_axis_idx("time") - t_vec = ax_arr.axes["time"].offset + np.arange(ax_arr.data.shape[t_ix] * ax_arr.axes["time"].gain) + t_vec = ax_arr.axes["time"].offset + np.arange( + ax_arr.data.shape[t_ix] * ax_arr.axes["time"].gain + ) t_vec -= ax_arr.axes["time"].gain / 2 f_ix = ax_arr.get_axis_idx("freq") - f_vec = ax_arr.axes["freq"].offset + np.arange(ax_arr.data.shape[f_ix] * ax_arr.axes["freq"].gain) + f_vec = ax_arr.axes["freq"].offset + np.arange( + ax_arr.data.shape[f_ix] * ax_arr.axes["freq"].gain + ) f_vec -= ax_arr.axes["freq"].gain / 2 plt.imshow( ax_arr.data[..., 0].T, origin="lower", aspect="auto", - extent=(t_vec[0], t_vec[-1], f_vec[0], f_vec[-1]) + extent=(t_vec[0], t_vec[-1], f_vec[0], f_vec[-1]), ) plt.xlabel("Time") plt.ylabel("Frequency") if sin_params is not None: for s_p in sin_params: - xx = (s_p.get("offset", 0.0) + t_vec[0], s_p.get("offset", 0.0) + t_vec[0] + s_p["dur"]) + xx = ( + s_p.get("offset", 0.0) + t_vec[0], + s_p.get("offset", 0.0) + t_vec[0] + s_p["dur"], + ) yy = (s_p["f"], s_p["f"]) plt.plot(xx, yy, linestyle="--", color="r", linewidth=1.0) @@ -41,7 +47,7 @@ def _debug_plot( def test_spectrogram(): win_dur = 1.0 win_step_dur = 0.5 - fs = 1000. + fs = 1000.0 seg_dur = 5.0 sin_params = [ {"f": 10.0, "dur": seg_dur, "offset": 0.0}, @@ -55,7 +61,7 @@ def test_spectrogram(): sin_params=sin_params, fs=fs, msg_dur=0.4, - win_step_dur=None # The spectrogram will do the windowing + win_step_dur=None, # The spectrogram will do the windowing ) backup = [copy.deepcopy(_) for _ in messages] @@ -74,7 +80,9 @@ def test_spectrogram(): # Check that the windows span the expected times. expected_t_span = 2 * seg_dur - data_t_span = (results[-1].axes["time"].offset + win_step_dur) - results[0].axes["time"].offset + data_t_span = (results[-1].axes["time"].offset + win_step_dur) - results[0].axes[ + "time" + ].offset assert np.abs(expected_t_span - data_t_span) < 1e-9 all_deltas = np.diff([_.axes["time"].offset for _ in results]) assert np.allclose(all_deltas, win_step_dur + np.zeros((len(results) - 1))) diff --git a/tests/test_spectrum.py b/tests/test_spectrum.py index d48607b..9fec5df 100644 --- a/tests/test_spectrum.py +++ b/tests/test_spectrum.py @@ -10,14 +10,23 @@ import ezmsg.core as ez from ezmsg.util.messages.axisarray import AxisArray, slice_along_axis from ezmsg.sigproc.spectrum import ( - spectrum, SpectralTransform, SpectralOutput, WindowFunction, Spectrum, SpectrumSettings + spectrum, + SpectralTransform, + SpectralOutput, + WindowFunction, + Spectrum, + SpectrumSettings, ) from ezmsg.sigproc.window import Window, WindowSettings from ezmsg.sigproc.synth import EEGSynth, EEGSynthSettings from ezmsg.util.messagelogger import MessageLogger, MessageLoggerSettings from ezmsg.util.messagecodec import message_log from ezmsg.util.terminate import TerminateOnTotal, TerminateOnTotalSettings -from util import get_test_fn, create_messages_with_periodic_signal, assert_messages_equal +from util import ( + get_test_fn, + create_messages_with_periodic_signal, + assert_messages_equal, +) def _debug_plot_welch(raw: AxisArray, result: AxisArray, welch_db: bool = True): @@ -26,7 +35,9 @@ def _debug_plot_welch(raw: AxisArray, result: AxisArray, welch_db: bool = True): fig, ax = plt.subplots(2, 1) t_ax = raw.axes["time"] - t_vec = np.arange(raw.data.shape[raw.get_axis_idx("time")]) * t_ax.gain + t_ax.offset + t_vec = ( + np.arange(raw.data.shape[raw.get_axis_idx("time")]) * t_ax.gain + t_ax.offset + ) ch0_raw = raw.data[..., :, 0] if ch0_raw.ndim > 1: # For multi-win inputs @@ -35,14 +46,19 @@ def _debug_plot_welch(raw: AxisArray, result: AxisArray, welch_db: bool = True): ax[0].set_xlabel("Time (s)") f_ax = result.axes["freq"] - f_vec = np.arange(result.data.shape[result.get_axis_idx("freq")]) * f_ax.gain + f_ax.offset + f_vec = ( + np.arange(result.data.shape[result.get_axis_idx("freq")]) * f_ax.gain + + f_ax.offset + ) ch0_spec = result.data[..., :, 0] if ch0_spec.ndim > 1: ch0_spec = ch0_spec[0] ax[1].plot(f_vec, ch0_spec, label="calculated", linewidth=2.0) ax[1].set_xlabel("Frequency (Hz)") - f, Pxx = sps.welch(ch0_raw, fs=1 / raw.axes["time"].gain, window="hamming", nperseg=len(ch0_raw)) + f, Pxx = sps.welch( + ch0_raw, fs=1 / raw.axes["time"].gain, window="hamming", nperseg=len(ch0_raw) + ) if welch_db: Pxx = 10 * np.log10(Pxx) ax[1].plot(f, Pxx, label="welch", color="tab:orange", linestyle="--") @@ -54,12 +70,14 @@ def _debug_plot_welch(raw: AxisArray, result: AxisArray, welch_db: bool = True): @pytest.mark.parametrize("window", [WindowFunction.HANNING, WindowFunction.HAMMING]) -@pytest.mark.parametrize("transform", [SpectralTransform.REL_DB, SpectralTransform.REL_POWER]) -@pytest.mark.parametrize("output", [SpectralOutput.POSITIVE, SpectralOutput.NEGATIVE, SpectralOutput.FULL]) +@pytest.mark.parametrize( + "transform", [SpectralTransform.REL_DB, SpectralTransform.REL_POWER] +) +@pytest.mark.parametrize( + "output", [SpectralOutput.POSITIVE, SpectralOutput.NEGATIVE, SpectralOutput.FULL] +) def test_spectrum_gen_multiwin( - window: WindowFunction, - transform: SpectralTransform, - output: SpectralOutput + window: WindowFunction, transform: SpectralTransform, output: SpectralOutput ): win_dur = 1.0 win_step_dur = 0.5 @@ -72,13 +90,10 @@ def test_spectrum_gen_multiwin( win_len = int(win_dur * fs) messages = create_messages_with_periodic_signal( - sin_params=sin_params, - fs=fs, - msg_dur=win_dur, - win_step_dur=win_step_dur + sin_params=sin_params, fs=fs, msg_dur=win_dur, win_step_dur=win_step_dur ) input_multiwin = AxisArray.concatenate(*messages, dim="win") - input_multiwin.axes["win"] = AxisArray.Axis.TimeAxis(offset=0, fs=1/win_step_dur) + input_multiwin.axes["win"] = AxisArray.Axis.TimeAxis(offset=0, fs=1 / win_step_dur) gen = spectrum(axis="time", window=window, transform=transform, output=output) result = gen.send(input_multiwin) @@ -92,24 +107,31 @@ def test_spectrum_gen_multiwin( assert result.axes["freq"].gain == 1 / win_dur assert "freq" in result.dims fax_ix = result.get_axis_idx("freq") - f_len = win_len if output == SpectralOutput.FULL else (win_len // 2 + 1 - (win_len % 2)) + f_len = ( + win_len if output == SpectralOutput.FULL else (win_len // 2 + 1 - (win_len % 2)) + ) assert result.data.shape[fax_ix] == f_len f_vec = result.axes["freq"].gain * np.arange(f_len) + result.axes["freq"].offset if output == SpectralOutput.NEGATIVE: f_vec = np.abs(f_vec) for s_p in sin_params: f_ix = np.argmin(np.abs(f_vec - s_p["f"])) - peak_inds = np.argmax(slice_along_axis(result.data, slice(f_ix-3, f_ix+3), axis=fax_ix), axis=fax_ix) + peak_inds = np.argmax( + slice_along_axis(result.data, slice(f_ix - 3, f_ix + 3), axis=fax_ix), + axis=fax_ix, + ) assert np.all(peak_inds == 3) @pytest.mark.parametrize("window", [WindowFunction.HANNING, WindowFunction.HAMMING]) -@pytest.mark.parametrize("transform", [SpectralTransform.REL_DB, SpectralTransform.REL_POWER]) -@pytest.mark.parametrize("output", [SpectralOutput.POSITIVE, SpectralOutput.NEGATIVE, SpectralOutput.FULL]) +@pytest.mark.parametrize( + "transform", [SpectralTransform.REL_DB, SpectralTransform.REL_POWER] +) +@pytest.mark.parametrize( + "output", [SpectralOutput.POSITIVE, SpectralOutput.NEGATIVE, SpectralOutput.FULL] +) def test_spectrum_gen( - window: WindowFunction, - transform: SpectralTransform, - output: SpectralOutput + window: WindowFunction, transform: SpectralTransform, output: SpectralOutput ): win_dur = 1.0 win_step_dur = 0.5 @@ -120,10 +142,7 @@ def test_spectrum_gen( {"a": 0.2, "f": 200.0, "p": np.pi / 11, "dur": 20.0}, ] messages = create_messages_with_periodic_signal( - sin_params=sin_params, - fs=fs, - msg_dur=win_dur, - win_step_dur=win_step_dur + sin_params=sin_params, fs=fs, msg_dur=win_dur, win_step_dur=win_step_dur ) backup = [copy.deepcopy(_) for _ in messages] @@ -150,10 +169,7 @@ def test_spectrum_vs_sps_fft(complex: bool): {"a": 0.2, "f": 200.0, "p": np.pi / 11, "dur": 20.0}, ] messages = create_messages_with_periodic_signal( - sin_params=sin_params, - fs=fs, - msg_dur=win_dur, - win_step_dur=win_step_dur + sin_params=sin_params, fs=fs, msg_dur=win_dur, win_step_dur=win_step_dur ) nfft = 1 << (messages[0].data.shape[0] - 1).bit_length() # nextpow2 @@ -164,7 +180,7 @@ def test_spectrum_vs_sps_fft(complex: bool): output=SpectralOutput.FULL if complex else SpectralOutput.POSITIVE, norm="backward", do_fftshift=False, - nfft=nfft + nfft=nfft, ) results = [gen.send(msg) for msg in messages] test_spec = results[0].data @@ -180,7 +196,9 @@ class SpectrumSettingsTest(ez.Settings): window_settings: WindowSettings spectrum_settings: SpectrumSettings log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) + term_settings: TerminateOnTotalSettings = field( + default_factory=TerminateOnTotalSettings + ) class SpectrumIntegrationTest(ez.Collection): @@ -204,14 +222,14 @@ def network(self) -> ez.NetworkDefinition: (self.SOURCE.OUTPUT_SIGNAL, self.WIN.INPUT_SIGNAL), (self.WIN.OUTPUT_SIGNAL, self.SPEC.INPUT_SIGNAL), (self.SPEC.OUTPUT_SIGNAL, self.SINK.INPUT_MESSAGE), - (self.SINK.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE) + (self.SINK.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) def test_spectrum_system( test_name: typing.Optional[str] = None, ): - fs = 500. + fs = 500.0 n_time = 100 # samples per block. dispatch_rate = fs / n_time target_dur = 2.0 window_dur = 1.0 @@ -244,7 +262,7 @@ def test_spectrum_system( ), term_settings=TerminateOnTotalSettings( total=target_messages, - ) + ), ) system = SpectrumIntegrationTest(settings) ez.run(SYSTEM=system) diff --git a/tests/test_synth.py b/tests/test_synth.py index ac18a0e..22549e1 100644 --- a/tests/test_synth.py +++ b/tests/test_synth.py @@ -14,10 +14,16 @@ from ezmsg.util.terminate import TerminateOnTotalSettings, TerminateOnTotal from util import get_test_fn from ezmsg.sigproc.synth import ( - clock, aclock, Clock, ClockSettings, - acounter, Counter, CounterSettings, + clock, + aclock, + Clock, + ClockSettings, + acounter, + Counter, + CounterSettings, sin, - EEGSynth, EEGSynthSettings + EEGSynth, + EEGSynthSettings, ) @@ -36,7 +42,8 @@ def test_clock_gen(dispatch_rate: typing.Optional[float]): if dispatch_rate is not None: assert (run_time - 1 / dispatch_rate) < t_elapsed < (run_time + 0.1) else: - assert t_elapsed < (n_target * 1e-4) # 100 usec per iteration is pretty generous + # 100 usec per iteration is pretty generous + assert t_elapsed < (n_target * 1e-4) @pytest.mark.parametrize("dispatch_rate", [None, 2.0, 20.0]) @@ -55,13 +62,16 @@ async def test_aclock_agen(dispatch_rate: typing.Optional[float]): if dispatch_rate: assert (run_time - 1.1 / dispatch_rate) < t_elapsed < (run_time + 0.1) else: - assert t_elapsed < (n_target * 1e-4) # 100 usec per iteration is pretty generous + # 100 usec per iteration is pretty generous + assert t_elapsed < (n_target * 1e-4) class ClockTestSystemSettings(ez.Settings): clock_settings: ClockSettings log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) + term_settings: TerminateOnTotalSettings = field( + default_factory=TerminateOnTotalSettings + ) class ClockTestSystem(ez.Collection): @@ -79,14 +89,14 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( (self.CLOCK.OUTPUT_CLOCK, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE) + (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) @pytest.mark.parametrize("dispatch_rate", [None, 2.0, 20.0]) def test_clock_system( - dispatch_rate: typing.Optional[float], - test_name: typing.Optional[str] = None, + dispatch_rate: typing.Optional[float], + test_name: typing.Optional[str] = None, ): run_time = 1.0 n_target = int(np.ceil(dispatch_rate * run_time)) if dispatch_rate else 100 @@ -95,7 +105,7 @@ def test_clock_system( settings = ClockTestSystemSettings( clock_settings=ClockSettings(dispatch_rate=dispatch_rate), log_settings=MessageLoggerSettings(output=test_filename), - term_settings=TerminateOnTotalSettings(total=n_target) + term_settings=TerminateOnTotalSettings(total=n_target), ) system = ClockTestSystem(settings) ez.run(SYSTEM=system) @@ -108,19 +118,20 @@ def test_clock_system( assert len(messages) >= n_target - @pytest.mark.parametrize("block_size", [1, 20]) @pytest.mark.parametrize("fs", [10.0, 1000.0]) @pytest.mark.parametrize("n_ch", [3]) -@pytest.mark.parametrize("dispatch_rate", [None, "realtime", "ext_clock", 2.0, 20.0]) # "ext_clock" needs a separate test -@pytest.mark.parametrize("mod", [2 ** 3, None]) +@pytest.mark.parametrize( + "dispatch_rate", [None, "realtime", "ext_clock", 2.0, 20.0] +) # "ext_clock" needs a separate test +@pytest.mark.parametrize("mod", [2**3, None]) @pytest.mark.asyncio async def test_acounter( block_size: int, fs: float, n_ch: int, dispatch_rate: typing.Optional[typing.Union[float, str]], - mod: typing.Optional[int] + mod: typing.Optional[int], ): target_dur = 2.6 # 2.6 seconds per test if dispatch_rate is None: @@ -134,7 +145,7 @@ async def test_acounter( chunk_dur = 0.1 else: # Note: float dispatch_rate will yield different number of samples than expected by target_dur and fs - chunk_dur = 1. / dispatch_rate + chunk_dur = 1.0 / dispatch_rate target_messages = int(target_dur / chunk_dur) # Run generator @@ -163,14 +174,16 @@ async def test_acounter( atol = 0.002 else: # Offsets are synthetic. - atol = 1.e-8 + atol = 1.0e-8 assert np.allclose(offsets[2:], expected_offsets[2:], atol=atol) class CounterTestSystemSettings(ez.Settings): counter_settings: CounterSettings log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) + term_settings: TerminateOnTotalSettings = field( + default_factory=TerminateOnTotalSettings + ) class CounterTestSystem(ez.Collection): @@ -188,7 +201,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( (self.COUNTER.OUTPUT_SIGNAL, self.LOG.INPUT_MESSAGE), - (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE) + (self.LOG.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) @@ -203,14 +216,14 @@ def network(self) -> ez.NetworkDefinition: (10, 10.0, 20.0, 2**3), # No test for ext_clock because that requires a different system # (20, 10.0, "ext_clock", None), - ] + ], ) def test_counter_system( - block_size: int, - fs: float, - dispatch_rate: typing.Optional[typing.Union[float, str]], - mod: typing.Optional[int], - test_name: typing.Optional[str] = None, + block_size: int, + fs: float, + dispatch_rate: typing.Optional[typing.Union[float, str]], + mod: typing.Optional[int], + test_name: typing.Optional[str] = None, ): n_ch = 3 target_dur = 2.6 # 2.6 seconds per test @@ -222,7 +235,7 @@ def test_counter_system( chunk_dur = block_size / fs else: # Note: float dispatch_rate will yield different number of samples than expected by target_dur and fs - chunk_dur = 1. / dispatch_rate + chunk_dur = 1.0 / dispatch_rate target_messages = int(target_dur / chunk_dur) test_filename = get_test_fn(test_name) @@ -240,7 +253,7 @@ def test_counter_system( ), term_settings=TerminateOnTotalSettings( total=target_messages, - ) + ), ) system = CounterTestSystem(settings) ez.run(SYSTEM=system) @@ -269,11 +282,7 @@ def test_counter_system( # TEST SIN # -def test_sin_gen( - freq: float = 1.0, - amp: float = 1.0, - phase: float = 0.0 -): +def test_sin_gen(freq: float = 1.0, amp: float = 1.0, phase: float = 0.0): axis: typing.Optional[str] = "time" srate = max(4.0 * freq, 1000.0) sim_dur = 30.0 @@ -282,13 +291,16 @@ def test_sin_gen( axis_idx = 0 messages = [] - for split_dat in np.array_split(np.arange(n_samples)[:, None], n_msgs, axis=axis_idx): + for split_dat in np.array_split( + np.arange(n_samples)[:, None], n_msgs, axis=axis_idx + ): _time_axis = AxisArray.Axis.TimeAxis(fs=srate, offset=float(split_dat[0, 0])) messages.append( AxisArray(split_dat, dims=["time", "ch"], axes={"time": _time_axis}) ) - def f_test(t): return amp * np.sin(2 * np.pi * freq * t + phase) + def f_test(t): + return amp * np.sin(2 * np.pi * freq * t + phase) gen = sin(axis=axis, freq=freq, amp=amp, phase=phase) results = [] @@ -297,7 +309,9 @@ def f_test(t): return amp * np.sin(2 * np.pi * freq * t + phase) assert np.allclose(res.data, f_test(msg.data / srate)) results.append(res) concat_ax_arr = AxisArray.concatenate(*results, dim="time") - assert np.allclose(concat_ax_arr.data, f_test(np.arange(n_samples) / srate)[:, None]) + assert np.allclose( + concat_ax_arr.data, f_test(np.arange(n_samples) / srate)[:, None] + ) # TODO: test SinGenerator in a system. @@ -306,7 +320,9 @@ def f_test(t): return amp * np.sin(2 * np.pi * freq * t + phase) class EEGSynthSettingsTest(ez.Settings): synth_settings: EEGSynthSettings log_settings: MessageLoggerSettings - term_settings: TerminateOnTotalSettings = field(default_factory=TerminateOnTotalSettings) + term_settings: TerminateOnTotalSettings = field( + default_factory=TerminateOnTotalSettings + ) class EEGSynthIntegrationTest(ez.Collection): @@ -324,7 +340,7 @@ def configure(self) -> None: def network(self) -> ez.NetworkDefinition: return ( (self.SOURCE.OUTPUT_SIGNAL, self.SINK.INPUT_MESSAGE), - (self.SINK.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE) + (self.SINK.OUTPUT_MESSAGE, self.TERM.INPUT_MESSAGE), ) @@ -332,7 +348,7 @@ def test_eegsynth_system( test_name: typing.Optional[str] = None, ): # Just a quick test to make sure the system runs. We aren't checking validity of values or anything. - fs = 500. + fs = 500.0 n_time = 100 # samples per block. dispatch_rate = fs / n_time target_dur = 2.0 target_messages = int(target_dur * fs / n_time) @@ -352,7 +368,7 @@ def test_eegsynth_system( ), term_settings=TerminateOnTotalSettings( total=target_messages, - ) + ), ) system = EEGSynthIntegrationTest(settings) diff --git a/tests/test_wavelets.py b/tests/test_wavelets.py index 31cbfd0..7511cbf 100644 --- a/tests/test_wavelets.py +++ b/tests/test_wavelets.py @@ -17,7 +17,9 @@ def make_chirp(t, t0, a): def scratch(): scales = np.geomspace(4, 256, num=35) - wavelets = [f"cmor{x:.1f}-{y:.1f}" for x in [0.5, 1.5, 2.5] for y in [0.5, 1.0, 1.5]] + wavelets = [ + f"cmor{x:.1f}-{y:.1f}" for x in [0.5, 1.5, 2.5] for y in [0.5, 1.0, 1.5] + ] wavelet = wavelets[1] # Generate test signal @@ -51,6 +53,7 @@ def scratch(): conv = np.convolve(chirp[0], int_psi_scales[-1]) import matplotlib.pyplot as plt + plt.plot(chirp[0]) plt.plot(int_psi_scales[-1]) plt.plot(conv) @@ -69,7 +72,7 @@ def test_cwt(): # Generate test signal fs = 1000 dur = 2.0 - tvec = np.arange(int(dur*fs)) / fs + tvec = np.arange(int(dur * fs)) / fs chirp1, frequency1 = make_chirp(tvec, 0.2, 9) chirp2, frequency2 = make_chirp(tvec, 0.1, 5) chirp = chirp1 + 0.6 * chirp2 @@ -83,20 +86,21 @@ def test_cwt(): for idx in range(0, len(tvec), step_size): in_messages.append( AxisArray( - data=chirp[:, idx:idx+step_size], + data=chirp[:, idx : idx + step_size], dims=["ch", "time"], - axes={ - "time": AxisArray.Axis.TimeAxis(offset=tvec[idx], fs=fs) - } + axes={"time": AxisArray.Axis.TimeAxis(offset=tvec[idx], fs=fs)}, ) ) # Prepare expected output from pywt.cwt - expected, freqs = pywt.cwt(chirp, scales, wavelet, 1/fs, method="conv", axis=-1) - expected = np.swapaxes(expected, 0, 1) # Swap scales and channels -> ch, freqs, time + expected, freqs = pywt.cwt(chirp, scales, wavelet, 1 / fs, method="conv", axis=-1) + # Swap scales and channels -> ch, freqs, time + expected = np.swapaxes(expected, 0, 1) # Prep filterbank - gen = cwt(scales=scales, wavelet=wavelet, min_phase=MinPhaseMode.HOMOMORPHIC, axis="time") + gen = cwt( + scales=scales, wavelet=wavelet, min_phase=MinPhaseMode.HOMOMORPHIC, axis="time" + ) # Pass the messages out_messages = [gen.send(in_messages[0])] @@ -108,6 +112,7 @@ def test_cwt(): if False: # Debug visualize result import matplotlib.pyplot as plt + tmp = result.data title = "ezmsg minphase homomorphic" # tmp = expected @@ -118,7 +123,9 @@ def test_cwt(): for ch_ix in range(nch): axes[0, ch_ix].set_title(f"Channel {ch_ix}") axes[0, ch_ix].plot(tvec, chirp[ch_ix]) - _ = axes[1, ch_ix].pcolormesh(tvec[:tmp.shape[-1]], freqs, np.abs(tmp[ch_ix, :-1, :-1])) + _ = axes[1, ch_ix].pcolormesh( + tvec[: tmp.shape[-1]], freqs, np.abs(tmp[ch_ix, :-1, :-1]) + ) axes[1, ch_ix].set_yscale("log") axes[1, ch_ix].set_xlabel("Time (s)") axes[1, ch_ix].set_ylabel("Frequency (Hz)") diff --git a/tests/test_window.py b/tests/test_window.py index 1c04f7f..9d8b490 100644 --- a/tests/test_window.py +++ b/tests/test_window.py @@ -22,7 +22,19 @@ from util import get_test_fn, assert_messages_equal -def calculate_expected_results(orig, fs, win_shift, zero_pad, msg_block_size, shift_len, win_len, nchans, data_len, n_msgs, win_ax): +def calculate_expected_results( + orig, + fs, + win_shift, + zero_pad, + msg_block_size, + shift_len, + win_len, + nchans, + data_len, + n_msgs, + win_ax, +): # For the calculation, we assume time_ax is last then transpose if necessary at the end. expected = orig.copy() tvec = np.arange(orig.shape[1]) / fs @@ -35,7 +47,9 @@ def calculate_expected_results(orig, fs, win_shift, zero_pad, msg_block_size, sh n_cut = win_len n_keep = win_len - n_cut if n_keep > 0: - expected = np.concatenate((np.zeros((nchans, win_len))[..., -n_keep:], expected), axis=-1) + expected = np.concatenate( + (np.zeros((nchans, win_len))[..., -n_keep:], expected), axis=-1 + ) tvec = np.hstack(((np.arange(-win_len, 0) / fs)[-n_keep:], tvec)) # Moving window -- assumes step size of 1 expected = sliding_window_view(expected, win_len, axis=-1) @@ -46,7 +60,9 @@ def calculate_expected_results(orig, fs, win_shift, zero_pad, msg_block_size, sh # If the window length is smaller than the block size then we only the tail of each block. first = max(min(msg_block_size, data_len) - win_len, 0) if tvec[::msg_block_size].shape[0] < n_msgs: - expected = np.concatenate((expected[:, first::msg_block_size], expected[:, -1:]), axis=1) + expected = np.concatenate( + (expected[:, first::msg_block_size], expected[:, -1:]), axis=1 + ) tvec = np.hstack((tvec[first::msg_block_size, 0], tvec[-1:, 0])) else: expected = expected[:, first::msg_block_size] @@ -71,7 +87,7 @@ def test_window_gen_nodur(): test_msg = AxisArray( data=data, dims=["ch", "time"], - axes=frozendict({"time": AxisArray.Axis.TimeAxis(fs=500., offset=0.)}) + axes=frozendict({"time": AxisArray.Axis.TimeAxis(fs=500.0, offset=0.0)}), ) backup = [copy.deepcopy(test_msg)] gen = windowing(window_dur=None) @@ -89,13 +105,13 @@ def test_window_gen_nodur(): @pytest.mark.parametrize("fs", [10.0, 500.0]) @pytest.mark.parametrize("time_ax", [0, 1]) def test_window_generator( - msg_block_size: int, - newaxis: typing.Optional[str], - win_dur: float, - win_shift: typing.Optional[float], - zero_pad: str, - fs: float, - time_ax: int + msg_block_size: int, + newaxis: typing.Optional[str], + win_dur: float, + win_shift: typing.Optional[float], + zero_pad: str, + fs: float, + time_ax: int, ): nchans = 3 @@ -111,24 +127,36 @@ def test_window_generator( n_msgs = int(np.ceil(data_len / msg_block_size)) # Instantiate the generator function - gen = windowing(axis="time", newaxis=newaxis, window_dur=win_dur, window_shift=win_shift, zero_pad_until=zero_pad) + gen = windowing( + axis="time", + newaxis=newaxis, + window_dur=win_dur, + window_shift=win_shift, + zero_pad_until=zero_pad, + ) # Create inputs and send them to the generator, collecting the results along the way. test_msg = AxisArray( data[..., ()], dims=["ch", "time"] if time_ax == 1 else ["time", "ch"], - axes=frozendict({"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.)}) + axes=frozendict({"time": AxisArray.Axis.TimeAxis(fs=fs, offset=0.0)}), ) messages = [] backup = [] results = [] for msg_ix in range(n_msgs): - msg_data = data[..., msg_ix * msg_block_size:(msg_ix+1) * msg_block_size] + msg_data = data[..., msg_ix * msg_block_size : (msg_ix + 1) * msg_block_size] if time_ax == 0: msg_data = np.ascontiguousarray(msg_data.T) - test_msg = replace(test_msg, data=msg_data, axes={ - "time": AxisArray.Axis.TimeAxis(fs=fs, offset=tvec[msg_ix * msg_block_size]) - }) + test_msg = replace( + test_msg, + data=msg_data, + axes={ + "time": AxisArray.Axis.TimeAxis( + fs=fs, offset=tvec[msg_ix * msg_block_size] + ) + }, + ) messages.append(test_msg) backup.append(copy.deepcopy(test_msg)) win_msg = gen.send(test_msg) @@ -137,25 +165,43 @@ def test_window_generator( assert_messages_equal(messages, backup) # Check each return value's metadata (offsets checked at end) - expected_dims = test_msg.dims[:time_ax] + [newaxis or "win"] + test_msg.dims[time_ax:] + expected_dims = ( + test_msg.dims[:time_ax] + [newaxis or "win"] + test_msg.dims[time_ax:] + ) for msg in results: - assert msg.axes["time"].gain == 1/fs + assert msg.axes["time"].gain == 1 / fs assert msg.dims == expected_dims assert (newaxis or "win") in msg.axes - assert msg.axes[(newaxis or "win")].gain == (0.0 if win_shift is None else shift_len / fs) + assert msg.axes[(newaxis or "win")].gain == ( + 0.0 if win_shift is None else shift_len / fs + ) # Post-process the results to yield a single data array and a single vector of offsets. win_ax = time_ax time_ax = win_ax + 1 result = np.concatenate([_.data for _ in results], win_ax) - offsets = np.hstack([ - _.axes[newaxis or "win"].offset + _.axes[newaxis or "win"].gain * np.arange(_.data.shape[win_ax]) - for _ in results - ]) + offsets = np.hstack( + [ + _.axes[newaxis or "win"].offset + + _.axes[newaxis or "win"].gain * np.arange(_.data.shape[win_ax]) + for _ in results + ] + ) # Calculate the expected results for comparison. - expected, tvec = calculate_expected_results(data, fs, win_shift, zero_pad, msg_block_size, shift_len, win_len, - nchans, data_len, n_msgs, win_ax) + expected, tvec = calculate_expected_results( + data, + fs, + win_shift, + zero_pad, + msg_block_size, + shift_len, + win_len, + nchans, + data_len, + n_msgs, + win_ax, + ) # Compare results to expected assert np.array_equal(result, expected) @@ -209,15 +255,18 @@ def network(self) -> ez.NetworkDefinition: # It takes >15 minutes to go through the full set of combinations tested for the generator. # We need only test a subset to assert integration is correct. -@pytest.mark.parametrize("msg_block_size, newaxis, win_dur, win_shift, zero_pad, fs", [ - (1, None, 0.2, None, "input", 10.0), - (20, None, 0.2, None, "input", 10.0), - (1, "step", 0.2, None, "input", 10.0), - (10, "step", 0.2, 1.0, "shift", 500.0), - (20, "step", 1.0, 1.0, "shift", 500.0), - (10, "step", 1.0, 1.0, "none", 500.0), - (20, None, None, None, "input", 10.0), -]) +@pytest.mark.parametrize( + "msg_block_size, newaxis, win_dur, win_shift, zero_pad, fs", + [ + (1, None, 0.2, None, "input", 10.0), + (20, None, 0.2, None, "input", 10.0), + (1, "step", 0.2, None, "input", 10.0), + (10, "step", 0.2, 1.0, "shift", 500.0), + (20, "step", 1.0, 1.0, "shift", 500.0), + (10, "step", 1.0, 1.0, "none", 500.0), + (20, None, None, None, "input", 10.0), + ], +) def test_window_system( msg_block_size: int, newaxis: typing.Optional[str], @@ -249,7 +298,7 @@ def test_window_system( newaxis=newaxis, window_dur=win_dur, window_shift=win_shift, - zero_pad_until=zero_pad + zero_pad_until=zero_pad, ), log_settings=MessageLoggerSettings(output=test_filename), term_settings=TerminateTestSettings(time=1.0), # sec @@ -269,8 +318,11 @@ def test_window_system( # In this test, we should have consistent dimensions assert msg.dims == ([newaxis, "time", "ch"] if newaxis else ["time", "ch"]) # Window should always output the same shape data - assert msg.shape[msg.get_axis_idx("ch")] == 1 # Counter yields only one channel. - assert msg.shape[msg.get_axis_idx("time")] == (msg_block_size if win_dur is None else win_len) + assert msg.shape[msg.get_axis_idx("ch")] == 1 + # Counter yields only one channel. + assert msg.shape[msg.get_axis_idx("time")] == ( + msg_block_size if win_dur is None else win_len + ) ez.logger.info("Consistent metadata!") @@ -279,10 +331,13 @@ def test_window_system( if newaxis is None: offsets = np.array([_.axes["time"].offset for _ in messages]) else: - offsets = np.hstack([ - _.axes[newaxis].offset + _.axes[newaxis].gain * np.arange(_.data.shape[0]) - for _ in messages - ]) + offsets = np.hstack( + [ + _.axes[newaxis].offset + + _.axes[newaxis].gain * np.arange(_.data.shape[0]) + for _ in messages + ] + ) # If this test was performed in "one-to-one" mode, we should # have one window output per message pushed to Window @@ -297,8 +352,19 @@ def test_window_system( # Calculate the expected results for comparison. sent_data = np.arange(num_msgs * msg_block_size)[None, :] - expected, tvec = calculate_expected_results(sent_data, fs, win_shift, zero_pad, msg_block_size, shift_len, win_len, - 1, data_len, num_msgs, 0) + expected, tvec = calculate_expected_results( + sent_data, + fs, + win_shift, + zero_pad, + msg_block_size, + shift_len, + win_len, + 1, + data_len, + num_msgs, + 0, + ) # Compare results to expected if win_dur is None: