diff --git a/ChangeLog.txt b/ChangeLog.txt index 6eb3b9e..30fa799 100644 --- a/ChangeLog.txt +++ b/ChangeLog.txt @@ -1,4 +1,6 @@ - master: + - changed FullMarginalQuerySet to raise ValueErrors if features in provided query feature sets do not match features present in data + - changed NapsuMQModel to no longer raise an error if forced_queries_in_automatic_selection is None instead of an empty iterable - changed InferenceModel.fit: added show_progress and return_diagnostics arguments, removed model specific kwargs (breaking) NapsuMQModel.fit and DPVIModel.fit changed accordingly - DPVIResult no longer contains the final ELBO from model fitting (breaking) diff --git a/tests/napsu_mq/marginal_query_test.py b/tests/napsu_mq/marginal_query_test.py index bafabed..edd2250 100644 --- a/tests/napsu_mq/marginal_query_test.py +++ b/tests/napsu_mq/marginal_query_test.py @@ -151,21 +151,13 @@ def test_canonical_queries_large_nonbinary_domain(self): naive_bayes_cross_queries = FullMarginalQuerySet([(0, 1), (0, 2), (0, 3), (1, 4), (2, 5)], domain.value_counts_by_col) naive_bayes_3_way_cross_queries = FullMarginalQuerySet([(0, 1), (0, 2), (0, 3), (1, 4, 5), (2, 5)], domain.value_counts_by_col) - naive_bayes_3_way_cross_queries_missing = FullMarginalQuerySet([(0, 1), (0, 2), (0, 3), (1, 4, 3)], - domain.value_counts_by_col) one_way_marginals = FullMarginalQuerySet([(0,), (1,), (2,), (3,), (4,), (5,)], domain.value_counts_by_col) - one_way_marginals_missing = FullMarginalQuerySet([(0,), (1,), (2,), (3,), (4,)], domain.value_counts_by_col) canon_queries = naive_bayes_3_way_cross_queries.get_canonical_queries().flatten() n_canon_queries = len(canon_queries.queries) rank = self.query_matrix_rank(domain, canon_queries) self.assertEqual(rank, n_canon_queries + 1) - canon_queries = naive_bayes_3_way_cross_queries_missing.get_canonical_queries().flatten() - n_canon_queries = len(canon_queries.queries) - rank = self.query_matrix_rank(domain, canon_queries) - self.assertEqual(rank, n_canon_queries + 1) - canon_queries = naive_bayes_cross_queries.get_canonical_queries().flatten() n_canon_queries = len(canon_queries.queries) rank = self.query_matrix_rank(domain, canon_queries) @@ -177,11 +169,38 @@ def test_canonical_queries_large_nonbinary_domain(self): self.assertEqual(n_canon_queries, 1 + 1 + 2 + 3 + 4 + 5) self.assertEqual(rank, n_canon_queries + 1) - canon_queries = one_way_marginals_missing.get_canonical_queries().flatten() - n_canon_queries = len(canon_queries.queries) - rank = self.query_matrix_rank(domain, canon_queries) - self.assertEqual(n_canon_queries, 1 + 1 + 2 + 3 + 4) - self.assertEqual(rank, n_canon_queries + 1) + +class FullMarginalQuerySetTests(unittest.TestCase): + + def setUp(self) -> None: + self.value_counts_by_feature = { + 'A': 3, + 'B': 3, + 'C': 2, + } + + def test_single_query(self) -> None: + feature_sets = [('A', 'B', 'C')] + query_set = FullMarginalQuerySet( + feature_sets=feature_sets, + value_counts_by_feature=self.value_counts_by_feature + ) + self.assertEqual(set(feature_sets), query_set.queries.keys()) + self.assertEqual(self.value_counts_by_feature, query_set.value_counts_by_feature) + + def test_incomplete_queries(self) -> None: + with self.assertRaisesRegex(ValueError, r"not covered.*C.*"): + FullMarginalQuerySet( + feature_sets=[('A', 'B'), ('A',), ('D',)], + value_counts_by_feature=self.value_counts_by_feature + ) + + def test_nonexistent_features(self) -> None: + with self.assertRaisesRegex(ValueError, r"not present.*D.*"): + FullMarginalQuerySet( + feature_sets=[('A', 'B'), ('B', 'C'), ('D',)], + value_counts_by_feature=self.value_counts_by_feature + ) if __name__ == "__main__": diff --git a/tests/napsu_mq/mst_test.py b/tests/napsu_mq/mst_test.py index ad5fc05..6dd8046 100644 --- a/tests/napsu_mq/mst_test.py +++ b/tests/napsu_mq/mst_test.py @@ -15,11 +15,12 @@ import unittest - # TODO: write tests + class MSTTest(unittest.TestCase): - def setUp(self): + + def setUp(self) -> None: pass def test_MST_selection(self): diff --git a/tests/napsu_mq/napsu_mq_test.py b/tests/napsu_mq/napsu_mq_test.py index f5f2829..f41d2dc 100644 --- a/tests/napsu_mq/napsu_mq_test.py +++ b/tests/napsu_mq/napsu_mq_test.py @@ -275,6 +275,36 @@ def test_NAPSUMQ_model_fit_rejects_pure_integer_data(self) -> None: with self.assertRaises(ValueError): model.fit(data=data, rng=rng, epsilon=1, delta=(n ** (-2))) + def test_NAPSUMQ_incomplete_queries(self) -> None: + n = 4 + + data = pd.DataFrame({ + 'A': np.random.randint(500, 1000, size=n), + 'B': np.random.randint(500, 1000, size=n), + 'C': np.random.randint(500, 1000, size=n) + }, dtype='category') + + rng = d3p.random.PRNGKey(42) + + model = NapsuMQModel(queries=[('A', 'B'), ('A',), ('D',)]) + with self.assertRaisesRegex(ValueError, r"not covered.*C.*"): + model.fit(data=data, rng=rng, epsilon=1, delta=(n ** (-2))) + + def test_NAPSUMQ_nonexistent_features(self) -> None: + n = 4 + + data = pd.DataFrame({ + 'A': np.random.randint(500, 1000, size=n), + 'B': np.random.randint(500, 1000, size=n), + 'C': np.random.randint(500, 1000, size=n) + }, dtype='category') + + rng = d3p.random.PRNGKey(42) + + model = NapsuMQModel(queries=[('A', 'B'), ('B', 'C'), ('D',)]) + with self.assertRaisesRegex(ValueError, r"not present.*D.*"): + model.fit(data=data, rng=rng, epsilon=1, delta=(n ** (-2))) + class TestNapsuMQResult(unittest.TestCase): diff --git a/twinify/napsu_mq/marginal_query.py b/twinify/napsu_mq/marginal_query.py index 970b243..7e1fbdb 100644 --- a/twinify/napsu_mq/marginal_query.py +++ b/twinify/napsu_mq/marginal_query.py @@ -69,6 +69,16 @@ def __init__(self, feature_sets: Iterable[Tuple], value_counts_by_feature: Dict[ self.feature_by_index = list(self.value_counts_by_feature.keys()) + all_features_set = set(self.value_counts_by_feature.keys()) + covered_features = set.union(*(set(fs) for fs in feature_sets)) if len(feature_sets) > 0 else set() + if covered_features != all_features_set: + not_covered = all_features_set - covered_features + if not_covered: + raise ValueError(f"The provided query feature sets (feature_sets) must cover all features in the data. Features not covered: {not_covered}") + else: + not_exists = covered_features - all_features_set + raise ValueError(f"The provided query feature sets (feature_sets) cover features that are not present in the data: {not_exists}.") + self.int_feature_sets = [ tuple(self.feature_by_index.index(feature) for feature in feature_set) for feature_set in self.feature_sets @@ -202,9 +212,8 @@ def get_canonical_queries(self, show_progressbar=False) -> 'FullMarginalQuerySet original_clique_queries[original_clique].append(new_query) canonical_queries = {key: QueryList(queries) for key, queries in original_clique_queries.items()} - new_fmqs = FullMarginalQuerySet([], self.value_counts_by_feature) - new_fmqs.queries = canonical_queries - new_fmqs.feature_sets = list(canonical_queries.keys()) + new_fmqs = FullMarginalQuerySet(canonical_queries.keys(), self.value_counts_by_feature) + new_fmqs.queries = canonical_queries # TODO: this seems quite hacky here; being able to create FullMarginalQuerySet with pre-made Query objects would be a better way return new_fmqs diff --git a/twinify/napsu_mq/napsu_mq.py b/twinify/napsu_mq/napsu_mq.py index d9d05ff..2e5a668 100644 --- a/twinify/napsu_mq/napsu_mq.py +++ b/twinify/napsu_mq/napsu_mq.py @@ -144,7 +144,7 @@ class NapsuMQModel(InferenceModel): def __init__( self, queries: Optional[Iterable[FrozenSet[str]]] = None, - forced_queries_in_automatic_selection: Optional[Iterable[FrozenSet[str]]] = tuple(), + forced_queries_in_automatic_selection: Optional[Iterable[FrozenSet[str]]] = None, inference_config: NapsuMQInferenceConfig = NapsuMQInferenceConfig(), # required_marginals: Iterable[FrozenSet[str]] = tuple() ) -> None: @@ -157,7 +157,7 @@ def __init__( super().__init__() if forced_queries_in_automatic_selection is None: - raise ValueError("forced_queries_in_automatic_selection may not be None") + forced_queries_in_automatic_selection = tuple() self._forced_queries_in_automatic_selection = forced_queries_in_automatic_selection self._queries = queries self._inference_config = inference_config