diff --git a/spark_matcher/matching_base/matching_base.py b/spark_matcher/matching_base/matching_base.py index b36b977..5ea1fa1 100644 --- a/spark_matcher/matching_base/matching_base.py +++ b/spark_matcher/matching_base/matching_base.py @@ -28,17 +28,18 @@ def __init__(self, spark_session: SparkSession, table_checkpointer: Optional[Tab blocking_recall: float = 1.0, n_perfect_train_matches=1, n_train_samples: int = 100_000, ratio_hashed_samples: float = 0.5, scorer: Optional[Scorer] = None, verbose: int = 0): self.spark_session = spark_session - self.table_checkpointer = table_checkpointer - if not self.table_checkpointer: - if checkpoint_dir: - self.table_checkpointer = ParquetCheckPointer(self.spark_session, checkpoint_dir, - "checkpoint_deduplicator") - else: - warnings.warn( - 'Either `table_checkpointer` or `checkpoint_dir` should be provided. This instance can only be used' - ' when loading a previously saved instance.') - - if col_names: + if not self.table_checkpointer and checkpoint_dir: + self.table_checkpointer = ParquetCheckPointer(self.spark_session, checkpoint_dir, "checkpoint_deduplicator") + elif table_checkpointer: + self.table_checkpointer = table_checkpointer + else: + warnings.warn( + 'Either `table_checkpointer` or `checkpoint_dir` should be provided. This instance can only be used ' + 'when loading a previously saved instance.') + + if col_names and field_info: + raise ValueError("Either `col_names` or `field_info` should be provided.") + elif col_names: self.col_names = col_names self.field_info = {col_name: [token_set_ratio, token_sort_ratio] for col_name in self.col_names}