Skip to content

Commit

Permalink
changed FullMarginalQuerySet/NapsuMQModel to raise ValueErrors if fea…
Browse files Browse the repository at this point in the history
…tures in provided query feature sets do not match features present in data

- also changed NapsuMQModel to no longer raise an error if forced_queries_in_automatic_selection is None instead of an empty iterable
  • Loading branch information
lumip committed Mar 8, 2024
1 parent fb1f65c commit c522a90
Show file tree
Hide file tree
Showing 6 changed files with 81 additions and 20 deletions.
2 changes: 2 additions & 0 deletions ChangeLog.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
- master:
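- changed FullMarginalQuerySet to raise a ValueError if the provided query feature sets do not cover all features present in the data, or if they reference features that are not present in the data (the error also surfaces through NapsuMQModel.fit)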
- changed NapsuMQModel to no longer raise an error if forced_queries_in_automatic_selection is None instead of an empty iterable
- changed InferenceModel.fit: added show_progress and return_diagnostics arguments, removed model specific kwargs (breaking)
NapsuMQModel.fit and DPVIModel.fit changed accordingly
- DPVIResult no longer contains the final ELBO from model fitting (breaking)
Expand Down
45 changes: 32 additions & 13 deletions tests/napsu_mq/marginal_query_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,21 +151,13 @@ def test_canonical_queries_large_nonbinary_domain(self):
naive_bayes_cross_queries = FullMarginalQuerySet([(0, 1), (0, 2), (0, 3), (1, 4), (2, 5)], domain.value_counts_by_col)
naive_bayes_3_way_cross_queries = FullMarginalQuerySet([(0, 1), (0, 2), (0, 3), (1, 4, 5), (2, 5)],
domain.value_counts_by_col)
naive_bayes_3_way_cross_queries_missing = FullMarginalQuerySet([(0, 1), (0, 2), (0, 3), (1, 4, 3)],
domain.value_counts_by_col)
one_way_marginals = FullMarginalQuerySet([(0,), (1,), (2,), (3,), (4,), (5,)], domain.value_counts_by_col)
one_way_marginals_missing = FullMarginalQuerySet([(0,), (1,), (2,), (3,), (4,)], domain.value_counts_by_col)

canon_queries = naive_bayes_3_way_cross_queries.get_canonical_queries().flatten()
n_canon_queries = len(canon_queries.queries)
rank = self.query_matrix_rank(domain, canon_queries)
self.assertEqual(rank, n_canon_queries + 1)

canon_queries = naive_bayes_3_way_cross_queries_missing.get_canonical_queries().flatten()
n_canon_queries = len(canon_queries.queries)
rank = self.query_matrix_rank(domain, canon_queries)
self.assertEqual(rank, n_canon_queries + 1)

canon_queries = naive_bayes_cross_queries.get_canonical_queries().flatten()
n_canon_queries = len(canon_queries.queries)
rank = self.query_matrix_rank(domain, canon_queries)
Expand All @@ -177,11 +169,38 @@ def test_canonical_queries_large_nonbinary_domain(self):
self.assertEqual(n_canon_queries, 1 + 1 + 2 + 3 + 4 + 5)
self.assertEqual(rank, n_canon_queries + 1)

canon_queries = one_way_marginals_missing.get_canonical_queries().flatten()
n_canon_queries = len(canon_queries.queries)
rank = self.query_matrix_rank(domain, canon_queries)
self.assertEqual(n_canon_queries, 1 + 1 + 2 + 3 + 4)
self.assertEqual(rank, n_canon_queries + 1)

class FullMarginalQuerySetTests(unittest.TestCase):
    """Checks how FullMarginalQuerySet's constructor treats its feature sets."""

    def setUp(self) -> None:
        # Shared mapping handed to the constructor as `value_counts_by_feature`.
        self.value_counts_by_feature = dict(A=3, B=3, C=2)

    def test_single_query(self) -> None:
        """One feature set naming every feature is accepted and kept as a query key."""
        feature_sets = [('A', 'B', 'C')]
        counts = self.value_counts_by_feature

        query_set = FullMarginalQuerySet(feature_sets=feature_sets, value_counts_by_feature=counts)

        expected_keys = set(feature_sets)
        self.assertEqual(expected_keys, query_set.queries.keys())
        self.assertEqual(counts, query_set.value_counts_by_feature)

    def test_incomplete_queries(self) -> None:
        """Feature sets that never mention 'C' must be rejected with a 'not covered' error."""
        # 'D' is additionally unknown; the expected message is still the
        # "not covered" one naming 'C'.
        leaves_c_out = [('A', 'B'), ('A',), ('D',)]
        with self.assertRaisesRegex(ValueError, r"not covered.*C.*"):
            FullMarginalQuerySet(
                feature_sets=leaves_c_out,
                value_counts_by_feature=self.value_counts_by_feature,
            )

    def test_nonexistent_features(self) -> None:
        """Feature sets naming the unknown feature 'D' must be rejected with a 'not present' error."""
        # A, B and C are all mentioned here, so only the unknown 'D' is at fault.
        mentions_unknown = [('A', 'B'), ('B', 'C'), ('D',)]
        with self.assertRaisesRegex(ValueError, r"not present.*D.*"):
            FullMarginalQuerySet(
                feature_sets=mentions_unknown,
                value_counts_by_feature=self.value_counts_by_feature,
            )


if __name__ == "__main__":
Expand Down
5 changes: 3 additions & 2 deletions tests/napsu_mq/mst_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@

import unittest


# TODO: write tests


class MSTTest(unittest.TestCase):
def setUp(self):

def setUp(self) -> None:
pass

def test_MST_selection(self):
Expand Down
30 changes: 30 additions & 0 deletions tests/napsu_mq/napsu_mq_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,36 @@ def test_NAPSUMQ_model_fit_rejects_pure_integer_data(self) -> None:
with self.assertRaises(ValueError):
model.fit(data=data, rng=rng, epsilon=1, delta=(n ** (-2)))

def test_NAPSUMQ_incomplete_queries(self) -> None:
    """Fitting with queries that never mention column 'C' must raise a 'not covered' ValueError."""
    n = 4

    # Three categorical columns A, B, C, each filled with n random integers.
    columns = {
        name: np.random.randint(500, 1000, size=n)
        for name in ('A', 'B', 'C')
    }
    data = pd.DataFrame(columns, dtype='category')

    rng = d3p.random.PRNGKey(42)

    # The queries leave out 'C' (and additionally name an unknown 'D').
    incomplete_queries = [('A', 'B'), ('A',), ('D',)]
    model = NapsuMQModel(queries=incomplete_queries)

    delta = n ** (-2)
    with self.assertRaisesRegex(ValueError, r"not covered.*C.*"):
        model.fit(data=data, rng=rng, epsilon=1, delta=delta)

def test_NAPSUMQ_nonexistent_features(self) -> None:
    """Fitting with a query on the unknown column 'D' must raise a 'not present' ValueError."""
    n = 4

    # Three categorical columns A, B, C, each filled with n random integers.
    columns = {
        name: np.random.randint(500, 1000, size=n)
        for name in ('A', 'B', 'C')
    }
    data = pd.DataFrame(columns, dtype='category')

    rng = d3p.random.PRNGKey(42)

    # Every data column is mentioned; only 'D' has no counterpart in the data.
    queries_with_unknown = [('A', 'B'), ('B', 'C'), ('D',)]
    model = NapsuMQModel(queries=queries_with_unknown)

    delta = n ** (-2)
    with self.assertRaisesRegex(ValueError, r"not present.*D.*"):
        model.fit(data=data, rng=rng, epsilon=1, delta=delta)


class TestNapsuMQResult(unittest.TestCase):

Expand Down
15 changes: 12 additions & 3 deletions twinify/napsu_mq/marginal_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def __init__(self, feature_sets: Iterable[Tuple], value_counts_by_feature: Dict[

self.feature_by_index = list(self.value_counts_by_feature.keys())

all_features_set = set(self.value_counts_by_feature.keys())
covered_features = set.union(*(set(fs) for fs in feature_sets)) if len(feature_sets) > 0 else set()
if covered_features != all_features_set:
not_covered = all_features_set - covered_features
if not_covered:
raise ValueError(f"The provided query feature sets (feature_sets) must cover all features in the data. Features not covered: {not_covered}")
else:
not_exists = covered_features - all_features_set
raise ValueError(f"The provided query feature sets (feature_sets) cover features that are not present in the data: {not_exists}.")

self.int_feature_sets = [
tuple(self.feature_by_index.index(feature) for feature in feature_set)
for feature_set in self.feature_sets
Expand Down Expand Up @@ -202,9 +212,8 @@ def get_canonical_queries(self, show_progressbar=False) -> 'FullMarginalQuerySet
original_clique_queries[original_clique].append(new_query)

canonical_queries = {key: QueryList(queries) for key, queries in original_clique_queries.items()}
new_fmqs = FullMarginalQuerySet([], self.value_counts_by_feature)
new_fmqs.queries = canonical_queries
new_fmqs.feature_sets = list(canonical_queries.keys())
new_fmqs = FullMarginalQuerySet(canonical_queries.keys(), self.value_counts_by_feature)
new_fmqs.queries = canonical_queries # TODO: this seems quite hacky here; being able to create FullMarginalQuerySet with pre-made Query objects would be a better way
return new_fmqs


Expand Down
4 changes: 2 additions & 2 deletions twinify/napsu_mq/napsu_mq.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class NapsuMQModel(InferenceModel):

def __init__(
self, queries: Optional[Iterable[FrozenSet[str]]] = None,
forced_queries_in_automatic_selection: Optional[Iterable[FrozenSet[str]]] = tuple(),
forced_queries_in_automatic_selection: Optional[Iterable[FrozenSet[str]]] = None,
inference_config: NapsuMQInferenceConfig = NapsuMQInferenceConfig(),
# required_marginals: Iterable[FrozenSet[str]] = tuple()
) -> None:
Expand All @@ -157,7 +157,7 @@ def __init__(

super().__init__()
if forced_queries_in_automatic_selection is None:
raise ValueError("forced_queries_in_automatic_selection may not be None")
forced_queries_in_automatic_selection = tuple()
self._forced_queries_in_automatic_selection = forced_queries_in_automatic_selection
self._queries = queries
self._inference_config = inference_config
Expand Down

0 comments on commit c522a90

Please sign in to comment.