Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support gpu version crossvalidator for RandomForestRegressor #303

Merged
merged 2 commits into from
Jun 26, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 1 addition & 22 deletions python/src/spark_rapids_ml/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union

from pyspark.ml.evaluation import Evaluator, MulticlassClassificationEvaluator

Expand Down Expand Up @@ -49,7 +49,6 @@
CumlT,
TransformInputType,
_ConstructFunc,
_CumlModel,
_EvaluateFunc,
_TransformFunc,
alias,
Expand Down Expand Up @@ -324,26 +323,6 @@ def __init__(
self._model_json = model_json
self._rf_spark_model: Optional[SparkRandomForestClassificationModel] = None

@staticmethod
def _combine(models: List[_CumlModel]) -> "RandomForestClassificationModel":
    """Merge several RandomForestClassificationModel instances into one model.

    The merged model is constructed from the first model's attributes, with
    ``treelite_model`` and ``model_json`` replaced by lists that hold the
    corresponding value of every input model (in input order). Param values
    and cuml params are then copied over from the first model.

    Parameters
    ----------
    models : list of _CumlModel
        a non-empty list whose items are all RandomForestClassificationModel

    Returns
    -------
    RandomForestClassificationModel
        a single model wrapping all the input models
    """
    assert len(models) > 0
    assert all(isinstance(m, RandomForestClassificationModel) for m in models)

    rf_models = cast(List[RandomForestClassificationModel], models)
    head = rf_models[0]

    per_model_treelite = [m._treelite_model for m in rf_models]
    per_model_json = [m._model_json for m in rf_models]

    combined_attrs = head.get_model_attributes()
    assert combined_attrs is not None
    combined_attrs["treelite_model"] = per_model_treelite
    combined_attrs["model_json"] = per_model_json

    combined = RandomForestClassificationModel(**combined_attrs)
    # The first model is the source of truth for all param values.
    head._copyValues(combined)
    head._copy_cuml_params(combined)
    return combined

def cpu(self) -> SparkRandomForestClassificationModel:
"""Return the PySpark ML RandomForestClassificationModel"""

Expand Down
4 changes: 2 additions & 2 deletions python/src/spark_rapids_ml/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,8 +973,8 @@ def _transformEvaluate(
"""
raise NotImplementedError()

@staticmethod
def _combine(models: List["_CumlModel"]) -> "_CumlModel":
@classmethod
def _combine(cls: Type["_CumlModel"], models: List["_CumlModel"]) -> "_CumlModel":
"""Combine a list of same type models into a model"""
raise NotImplementedError()

Expand Down
45 changes: 30 additions & 15 deletions python/src/spark_rapids_ml/metrics/RegressionMetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@
# limitations under the License.
#
import math
from typing import List
from collections import namedtuple
from typing import List, Optional, cast

from pyspark import Row
from pyspark.ml.evaluation import RegressionEvaluator

from spark_rapids_ml.core import pred

# Field names of the regression statistics that are exchanged as row/column
# names between the evaluate function and RegressionMetrics.from_rows.
RegMetrics = namedtuple("RegMetrics", ("m2n", "m2", "l1", "mean", "total_count"))
# Every field of this instance holds its own name (e.g. reg_metrics.mean == "mean"),
# so column names can be referenced as attributes instead of string literals.
reg_metrics = RegMetrics("m2n", "m2", "l1", "mean", "total_count")


# This class is aligning with Spark SummarizerBuffer scala version
class _SummarizerBuffer:
Expand Down Expand Up @@ -124,21 +131,8 @@ def _compute_variance(self) -> List[float]:
self._weight_square_sum / self._total_weight_sum
)
if denominator > 0.0:
delta_mean = self._curr_mean
real_variance = [
max(
(
self._curr_m2n[i]
# + delta_mean[i]
# * delta_mean[i]
# * self._curr_weight_sum[i]
# * (self._total_weight_sum - self._curr_weight_sum[i])
# / self._total_weight_sum
)
/ denominator,
0.0,
)
for i in range(self._num_cols)
max(self._curr_m2n[i] / denominator, 0.0) for i in range(self._num_cols)
]
else:
real_variance = [0] * self._num_cols
Expand Down Expand Up @@ -172,6 +166,27 @@ def create(
) -> "RegressionMetrics":
return RegressionMetrics(_SummarizerBuffer(mean, m2n, m2, l1, total_cnt))

@classmethod
def from_rows(cls, num_models: int, rows: List[Row]) -> List["RegressionMetrics"]:
    """Build one merged RegressionMetrics per model from evaluation rows.

    Rows that share the same ``pred.model_index`` are merged together, so the
    result holds exactly one entry per model, positioned by model index.

    Parameters
    ----------
    num_models : int
        number of models; the result list has this length
    rows : list of :py:class:`pyspark.Row`
        rows that must contain pred.model_index, and mean/m2n/m2/l1/total_count

    Returns
    -------
    list of RegressionMetrics
        the merged metrics, where position ``i`` belongs to model index ``i``

    Raises
    ------
    ValueError
        if at least one of the ``num_models`` models has no row in ``rows``
    """
    metrics: List[Optional["RegressionMetrics"]] = [None] * num_models

    for row in rows:
        index = row[pred.model_index]
        metric = cls.create(
            mean=row[reg_metrics.mean],
            m2n=row[reg_metrics.m2n],
            m2=row[reg_metrics.m2],
            l1=row[reg_metrics.l1],
            total_cnt=row[reg_metrics.total_count],
        )
        old_metric = metrics[index]
        metrics[index] = metric if old_metric is None else old_metric.merge(metric)

    # Fail fast with a clear message instead of letting the cast below hide
    # None entries that would only blow up later when the metric is used.
    missing = [index for index, metric in enumerate(metrics) if metric is None]
    if missing:
        raise ValueError(
            f"No metrics rows found for model index(es) {missing}; "
            f"expected rows for all {num_models} model(s)."
        )

    return cast(List["RegressionMetrics"], metrics)

def merge(self, other: "RegressionMetrics") -> "RegressionMetrics":
"""Merge other to self and return a new RegressionMetrics"""
summary = self._summary.merge(other._summary)
Expand Down
91 changes: 91 additions & 0 deletions python/src/spark_rapids_ml/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import pandas as pd
from pyspark import Row, keyword_only
from pyspark.ml.common import _py2java
from pyspark.ml.evaluation import Evaluator, RegressionEvaluator
from pyspark.ml.linalg import Vector, Vectors, _convert_to_vector
from pyspark.ml.regression import LinearRegressionModel as SparkLinearRegressionModel
from pyspark.ml.regression import LinearRegressionSummary
Expand Down Expand Up @@ -56,9 +57,12 @@
_CumlModelWithPredictionCol,
_EvaluateFunc,
_TransformFunc,
alias,
param_alias,
pred,
transform_evaluate,
)
from .metrics.RegressionMetrics import RegressionMetrics, reg_metrics
from .params import HasFeaturesCols, P, _CumlClass, _CumlParams
from .tree import (
_RandomForestClass,
Expand All @@ -68,6 +72,9 @@
)
from .utils import PartitionDescriptor, _get_spark_session, cudf_to_cuml_array, java_uid

if TYPE_CHECKING:
from pyspark.ml._typing import ParamMap

T = TypeVar("T")


Expand Down Expand Up @@ -784,6 +791,9 @@ def _is_classification(self) -> bool:
def _create_pyspark_model(self, result: Row) -> "RandomForestRegressionModel":
return RandomForestRegressionModel.from_row(result)

def _supportsTransformEvaluate(self, evaluator: Evaluator) -> bool:
    """Return True only when ``evaluator`` is a RegressionEvaluator.

    Parameters
    ----------
    evaluator : :py:class:`pyspark.ml.evaluation.Evaluator`
        the evaluator to check

    Returns
    -------
    bool
        whether the evaluator is an instance of RegressionEvaluator
    """
    # isinstance already yields a bool, no conditional expression is needed.
    return isinstance(evaluator, RegressionEvaluator)


class RandomForestRegressionModel(
_RandomForestRegressorClass,
Expand Down Expand Up @@ -833,3 +843,84 @@ def cpu(self) -> SparkRandomForestRegressionModel:

def _is_classification(self) -> bool:
return False

def _get_cuml_transform_func(
self, dataset: DataFrame, category: str = transform_evaluate.transform
) -> Tuple[_ConstructFunc, _TransformFunc, Optional[_EvaluateFunc],]:
_construct_rf, _predict, _ = super()._get_cuml_transform_func(dataset, category)

def _evaluate(
input: TransformInputType,
transformed: TransformInputType,
) -> pd.DataFrame:
# calculate the metrics (mean, m2n, m2, l1, total count) of label, label-prediction and prediction
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

count -> metrics

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Done

comb = pd.DataFrame(
{
"label": input[alias.label],
"prediction": transformed,
}
)
comb.insert(1, "label-prediction", comb["label"] - comb["prediction"])
total_cnt = comb.shape[0]
return pd.DataFrame(
data={
reg_metrics.mean: [comb.mean().to_list()],
reg_metrics.m2n: [(comb.var(ddof=0) * total_cnt).to_list()],
reg_metrics.m2: [comb.pow(2).sum().to_list()],
reg_metrics.l1: [comb.abs().sum().to_list()],
reg_metrics.total_count: total_cnt,
}
)

return _construct_rf, _predict, _evaluate

def _transformEvaluate(
self,
dataset: DataFrame,
evaluator: Evaluator,
params: Optional["ParamMap"] = None,
) -> List[float]:
"""
Transforms and evaluates the input dataset with optional parameters in a single pass.

Parameters
----------
dataset : :py:class:`pyspark.sql.DataFrame`
a dataset that contains labels/observations and predictions
evaluator: :py:class:`pyspark.ml.evaluation.Evaluator`
an evaluator user intends to use
params : dict, optional
an optional param map that overrides embedded params

Returns
-------
list of float
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

list of float?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. Done

metric
"""

if not isinstance(evaluator, RegressionEvaluator):
raise NotImplementedError(f"{evaluator} is unsupported yet.")

if self.getLabelCol() not in dataset.schema.names:
raise RuntimeError("Label column is not existing.")

dataset = dataset.withColumnRenamed(self.getLabelCol(), alias.label)

schema = StructType(
[
StructField(pred.model_index, IntegerType()),
StructField(reg_metrics.mean, ArrayType(FloatType())),
StructField(reg_metrics.m2n, ArrayType(FloatType())),
StructField(reg_metrics.m2, ArrayType(FloatType())),
StructField(reg_metrics.l1, ArrayType(FloatType())),
StructField(reg_metrics.total_count, IntegerType()),
]
)

rows = super()._transform_evaluate_internal(dataset, schema).collect()
num_models = (
len(self._treelite_model) if isinstance(self._treelite_model, list) else 1
)

metrics = RegressionMetrics.from_rows(num_models, rows)
return [metric.evaluate(evaluator) for metric in metrics]
29 changes: 18 additions & 11 deletions python/src/spark_rapids_ml/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,7 @@
import math
import pickle
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Union,
cast,
)
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -598,3 +588,20 @@ def _transform(self, dataset: DataFrame) -> DataFrame:
return df.withColumn(
self.getPredictionCol(), df[self.getPredictionCol()].cast("double")
)

@classmethod
def _combine(
    cls: Type["_RandomForestModel"], models: List["_RandomForestModel"]  # type: ignore
) -> "_RandomForestModel":
    """Merge several models of type ``cls`` into a single ``cls`` model.

    The merged model is constructed from the first model's attributes, with
    ``treelite_model`` and ``model_json`` replaced by lists that hold the
    corresponding value of every input model (in input order). Param values
    and cuml params are then copied over from the first model.

    Parameters
    ----------
    models : list of _RandomForestModel
        a non-empty list whose items are all instances of ``cls``

    Returns
    -------
    _RandomForestModel
        a single ``cls`` instance wrapping all the input models
    """
    assert len(models) > 0
    assert all(isinstance(m, cls) for m in models)

    head = models[0]
    per_model_treelite = [m._treelite_model for m in models]
    per_model_json = [m._model_json for m in models]

    combined_attrs = head.get_model_attributes()
    assert combined_attrs is not None
    combined_attrs["treelite_model"] = per_model_treelite
    combined_attrs["model_json"] = per_model_json

    combined = cls(**combined_attrs)
    # The first model is the source of truth for all param values.
    head._copyValues(combined)
    head._copy_cuml_params(combined)
    return combined
51 changes: 38 additions & 13 deletions python/tests/test_random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
RandomForestClassificationModel as SparkRFClassificationModel,
)
from pyspark.ml.classification import RandomForestClassifier as SparkRFClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, RegressionEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.param import Param
from pyspark.ml.regression import RandomForestRegressionModel as SparkRFRegressionModel
Expand Down Expand Up @@ -61,6 +61,13 @@
RandomForest = TypeVar(
"RandomForest", Type[RandomForestClassifier], Type[RandomForestRegressor]
)

# Type variable constrained to the evaluator classes that are paired with the
# random forest estimators in the parametrized cross-validator test.
RandomForestEvaluator = TypeVar(
"RandomForestEvaluator",
Type[MulticlassClassificationEvaluator],
Type[RegressionEvaluator],
)

RandomForestModel = TypeVar(
"RandomForestModel",
Type[RandomForestClassificationModel],
Expand Down Expand Up @@ -797,37 +804,55 @@ def get_num_trees(
)


@pytest.mark.parametrize(
"estimator_evaluator",
[
(RandomForestClassifier, MulticlassClassificationEvaluator),
(RandomForestRegressor, RegressionEvaluator),
],
)
@pytest.mark.parametrize("feature_type", [feature_types.vector])
@pytest.mark.parametrize("data_type", [np.float32])
@pytest.mark.parametrize("data_shape", [(100, 8)], ids=idfn)
def test_crossvalidator_random_forest_classifier(
def test_crossvalidator_random_forest(
estimator_evaluator: Tuple[RandomForest, RandomForestEvaluator],
tmp_path: str,
feature_type: str,
data_type: np.dtype,
data_shape: Tuple[int, int],
) -> None:
RF, Evaluator = estimator_evaluator

# Train a toy model
X, _, y, _ = make_classification_dataset(
datatype=data_type,
nrows=data_shape[0],
ncols=data_shape[1],
n_classes=4,
n_informative=data_shape[1],
n_redundant=0,
n_repeated=0,
)

if RF == RandomForestClassifier:
X, _, y, _ = make_classification_dataset(
datatype=data_type,
nrows=data_shape[0],
ncols=data_shape[1],
n_classes=4,
n_informative=data_shape[1],
n_redundant=0,
n_repeated=0,
)
else:
X, _, y, _ = make_regression_dataset(
datatype=data_type,
nrows=data_shape[0],
ncols=data_shape[1],
)

with CleanSparkSession() as spark:
df, features_col, label_col = create_pyspark_dataframe(
spark, feature_type, data_type, X, y
)
assert label_col is not None

rfc = RandomForestClassifier()
rfc = RF()
rfc.setFeaturesCol(features_col)
rfc.setLabelCol(label_col)

evaluator = MulticlassClassificationEvaluator()
evaluator = Evaluator()
evaluator.setLabelCol(label_col)

grid = (
Expand Down