Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add shape check to Dataset initialization #106

Merged
merged 5 commits into from
Jun 12, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pymare/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd

from pymare.utils import _listify
from pymare.utils import _check_inputs_shape, _listify

from .estimators import (
DerSimonianLaird,
Expand Down Expand Up @@ -94,6 +94,10 @@ def __init__(
self.X = X
self.X_names = names

_check_inputs_shape(self.y, self.X, "y", "X", row=True)
_check_inputs_shape(self.y, self.v, "y", "v", row=True, column=True)
_check_inputs_shape(self.y, self.n, "y", "n", row=True, column=True)

def _get_predictors(self, X, names, add_intercept):
if X is None and not add_intercept:
raise ValueError(
Expand Down
27 changes: 27 additions & 0 deletions pymare/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,37 @@
"""Tests for pymare.utils."""
import os.path as op

import numpy as np
import pytest

from pymare import utils


def test_get_resource_path():
"""Test nimare.utils.get_resource_path."""
print(utils.get_resource_path())
assert op.isdir(utils.get_resource_path())


def test_check_inputs_shape():
"""Test nimare.utils._check_inputs_shape."""
n_rows = 5
n_columns = 4
n_pred = 3
y = np.random.randint(1, 100, size=(n_rows, n_columns))
v = np.random.randint(1, 100, size=(n_rows + 1, n_columns))
n = np.random.randint(1, 100, size=(n_rows, n_columns))
X = np.random.randint(1, 100, size=(n_rows, n_pred))
X_names = [f"X{x}" for x in range(n_pred)]

utils._check_inputs_shape(y, X, "y", "X", row=True)
utils._check_inputs_shape(y, n, "y", "n", row=True, column=True)
utils._check_inputs_shape(X, np.array(X_names)[None, :], "X", "X_names", column=True)

# Raise error if the number of rows and columns of v don't match y
with pytest.raises(ValueError):
utils._check_inputs_shape(y, v, "y", "v", row=True, column=True)

# Raise error if neither row or column is True
with pytest.raises(ValueError):
utils._check_inputs_shape(y, n, "y", "n")
Comment on lines +31 to +37
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry, I know I approved already, but I just realized that, while the function allows Nones, that behavior isn't tested here. Can you test Nones? Not every Dataset will have v or n.

36 changes: 36 additions & 0 deletions pymare/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,39 @@ def _listify(obj):
This provides a simple way to accept flexible arguments.
"""
return obj if isinstance(obj, (list, tuple, type(None), np.ndarray)) else [obj]


def _check_inputs_shape(param1, param2, param1_name, param2_name, row=False, column=False):
"""Check whether 'param1' and 'param2' have the same shape.

Parameters
----------
param1 : array
param2 : array
param1_name : str
param2_name : str
row : bool, default to False.
column : bool, default to False.
"""
if (param1 is not None) and (param2 is not None):
if row and not column:
shape1 = param1.shape[0]
shape2 = param2.shape[0]
message = "rows"
elif column and not row:
shape1 = param1.shape[1]
shape2 = param2.shape[1]
message = "columns"
elif row and column:
shape1 = param1.shape
shape2 = param2.shape
message = "rows and columns"
else:
raise ValueError("At least one of the two parameters (row or column) should be True.")

if shape1 != shape2:
raise ValueError(
f"{param1_name} and {param2_name} should have the same number of {message}. "
f"You provided {param1_name} with shape {param1.shape} and {param2_name} "
f"with shape {param2.shape}."
)