Skip to content

Commit

Permalink
quick fix for release (#734)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinfeng <[email protected]>
  • Loading branch information
lijinf2 committed Sep 6, 2024
1 parent fc17cd0 commit bd106d3
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 2 additions & 0 deletions python/src/spark_rapids_ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,8 @@ class LogisticRegression(
And it will automatically map pyspark parameters
to cuML parameters.
In the case of applying LogisticRegression on sparse vectors, Spark 3.4 or above is required.
Parameters
----------
featuresCol: str or List[str]
Expand Down
12 changes: 8 additions & 4 deletions python/src/spark_rapids_ml/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -851,7 +851,11 @@ def setAlgorithm(self: P, value: str) -> P:
"""
Sets the value of `algorithm`.
"""
assert value == "ivfflat", "Only IVFFLAT algorithm is currently supported"
assert value in {
"ivfflat",
"ivfpq",
"cagra",
}, "Only ivfflat, ivfpq, and cagra are currently supported"
self._set_params(algorithm=value)
return self

Expand Down Expand Up @@ -919,7 +923,7 @@ class ApproximateNearestNeighbors(
the default number of approximate nearest neighbors to retrieve for each query.
algorithm: str (default = 'ivfflat')
the algorithm parameter to be passed into cuML. It currently must be 'ivfflat' or 'ivfpq'. Other algorithms are expected to be supported later.
the algorithm parameter to be passed into cuML. It currently must be 'ivfflat', 'ivfpq' or 'cagra'. Other algorithms are expected to be supported later.
algoParams: Optional[Dict[str, Any]] (default = None)
if set, algoParam is used to configure the algorithm, on each data partition (or maxRecordsPerBatch if Arrow is enabled) of the item_df.
Expand Down Expand Up @@ -1455,7 +1459,7 @@ def _transform_internal(

start_time = time.time()

if nn_object is not "cagra":
if nn_object != "cagra":
nn_object.fit(item)
else:
from cuvs.neighbors import cagra
Expand All @@ -1473,7 +1477,7 @@ def _transform_internal(

start_time = time.time()

if nn_object is not "cagra":
if nn_object != "cagra":
distances, indices = nn_object.kneighbors(bcast_qfeatures.value)
else:
gpu_qfeatures = cp.array(
Expand Down

0 comments on commit bd106d3

Please sign in to comment.