Skip to content

Commit

Permalink
fix ensureIdCol to avoid using isSet(idCol)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinfeng <[email protected]>
  • Loading branch information
lijinf2 committed May 8, 2024
1 parent 3bde6d7 commit 6cef74a
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/src/spark_rapids_ml/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _ensureIdCol(self, df: DataFrame) -> DataFrame:
id_col_name = self.getIdCol()
df_withid = (
df
if self.isSet("idCol")
if id_col_name in df.columns
else df.select(monotonically_increasing_id().alias(id_col_name), "*")
)
return df_withid
Expand Down
11 changes: 11 additions & 0 deletions python/tests/test_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,17 @@ def assert_knn_metadata_equal(knn_metadata: List[List[str]]) -> None:
assert knnjoin_queries[i]["features"] == query[i][0]
assert knnjoin_queries[i]["metadata"] == query[i][1]

# Test fit(dataset, ParamMap), which fits on a copy of the existing estimator.
# On the copy, isSet("idCol") becomes True, but the added id column does not exist in the dataframe.
paramMap = gpu_knn.extractParamMap()
gpu_model_v2 = gpu_knn.fit(item_df_withid, paramMap)

assert gpu_knn.isSet("idCol") is False
assert gpu_model_v2.isSet("idCol") is True

(_, _, knn_df_v2) = gpu_model_v2.kneighbors(query_df)
assert knn_df_v2.collect() == knn_df.collect()

return gpu_knn, gpu_model


Expand Down

0 comments on commit 6cef74a

Please sign in to comment.