Override call_cuml_fit_func to use Dataframe, model saving+loading as numpy #352

Merged
merged 4 commits on Aug 8, 2023
Changes from all commits
273 changes: 244 additions & 29 deletions python/src/spark_rapids_ml/umap.py
@@ -14,14 +14,19 @@
# limitations under the License.
#

import json
import os
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Generator,
Iterable,
Iterator,
List,
Optional,
Sequence,
Tuple,
Type,
Union,
Expand All @@ -30,7 +35,9 @@
import numpy as np
import pandas as pd
import pyspark
from pandas import DataFrame as PandasDataFrame
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasOutputCol
from pyspark.ml.util import DefaultParamsReader, DefaultParamsWriter, MLReader, MLWriter
from pyspark.sql import Column, DataFrame
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.types import (
@@ -48,16 +55,26 @@
CumlT,
FitInputType,
_ConstructFunc,
_CumlCommon,
_CumlEstimator,
_CumlEstimatorSupervised,
_CumlModel,
_CumlModelReader,
_CumlModelWriter,
_EvaluateFunc,
_TransformFunc,
alias,
param_alias,
transform_evaluate,
)
from .params import HasFeaturesCols, P, _CumlClass, _CumlParams
from .utils import _ArrayOrder, _concat_and_free, _get_spark_session
from .utils import (
_ArrayOrder,
_concat_and_free,
_get_spark_session,
_is_local,
get_logger,
)

if TYPE_CHECKING:
import cudf
@@ -346,26 +363,40 @@ def _fit(self, dataset: DataFrame) -> "UMAPModel":
# Force to single partition, single worker
self._num_workers = 1
if data_subset.rdd.getNumPartitions() != 1:
data_subset = data_subset.repartition(1)
data_subset = data_subset.coalesce(1)

maxRecordsPerBatch_str = _get_spark_session().conf.get(
"spark.sql.execution.arrow.maxRecordsPerBatch", "10000"
)
assert maxRecordsPerBatch_str is not None
self.maxRecordsPerBatch = int(maxRecordsPerBatch_str)

pipelined_rdd = self._call_cuml_fit_func(
df_output = self._call_cuml_fit_func_dataframe(
dataset=data_subset,
partially_collect=False,
paramMaps=None,
)
rows = pipelined_rdd.collect()
# Collect and concatenate row-by-row fit results
from itertools import chain

embeddings = list(chain.from_iterable([row["embedding_"] for row in rows]))
raw_data = list(chain.from_iterable([row["raw_data_"] for row in rows]))
del rows
pdf_output: PandasDataFrame = df_output.toPandas()

# Collect and concatenate row-by-row fit results
embeddings = np.array(
list(
pd.concat(
[pd.Series(x) for x in pdf_output["embedding_"]], ignore_index=True
)
),
dtype=np.float32,
)
raw_data = np.array(
list(
pd.concat(
[pd.Series(x) for x in pdf_output["raw_data_"]], ignore_index=True
)
),
dtype=np.float32,
)
del pdf_output

spark = _get_spark_session()
broadcast_embeddings = spark.sparkContext.broadcast(embeddings)
@@ -392,7 +423,7 @@ def _get_cuml_fit_func( # type: ignore
self, dataset: DataFrame
) -> Callable[[FitInputType, Dict[str, Any]], Dict[str, Any],]:
"""
This class overrides the parent function with a different return signature.
This class replaces the parent function with a different return signature. See fit_generator_func below.
"""
pass

@@ -454,6 +485,111 @@ def _cuml_fit(

return _cuml_fit

def _call_cuml_fit_func_dataframe(
self,
dataset: DataFrame,
partially_collect: bool = True,
paramMaps: Optional[Sequence["ParamMap"]] = None,
) -> DataFrame:
"""
Fits a model to the input dataset. This overrides the parent function to omit barrier stages and return a dataframe rather than an RDD.

Parameters
----------
dataset : :py:class:`pyspark.sql.DataFrame`
input dataset

Returns
-------
output : :py:class:`pyspark.sql.DataFrame`
fitted model attributes
"""

cls = self.__class__

select_cols, multi_col_names, _, _ = self._pre_process_data(dataset)

dataset = dataset.select(*select_cols)

is_local = _is_local(_get_spark_session().sparkContext)

cuda_managed_mem_enabled = (
_get_spark_session().conf.get("spark.rapids.ml.uvm.enabled", "false")
== "true"
)
if cuda_managed_mem_enabled:
get_logger(cls).info("CUDA managed memory enabled.")

# parameters passed to subclass
params: Dict[str, Any] = {
param_alias.cuml_init: self.cuml_params,
}

params[param_alias.fit_multiple_params] = []

cuml_fit_func = self._get_cuml_fit_generator_func(dataset, None) # type: ignore

array_order = self._fit_array_order()

cuml_verbose = self.cuml_params.get("verbose", False)

def _train_udf(pdf_iter: Iterable[pd.DataFrame]) -> Iterable[pd.DataFrame]:
from pyspark import TaskContext

logger = get_logger(cls)
logger.info("Initializing cuml context")

import cupy as cp

if cuda_managed_mem_enabled:
import rmm
from rmm.allocators.cupy import rmm_cupy_allocator

rmm.reinitialize(managed_memory=True)
cp.cuda.set_allocator(rmm_cupy_allocator)

_CumlCommon.initialize_cuml_logging(cuml_verbose)

context = TaskContext.get()

# set gpu device
_CumlCommon.set_gpu_device(context, is_local)

# handle the input
# inputs = [(X, Optional(y)), (X, Optional(y))]
logger.info("Loading data into python worker memory")
inputs = []
sizes = []
for pdf in pdf_iter:
sizes.append(pdf.shape[0])
if multi_col_names:
features = np.array(pdf[multi_col_names], order=array_order)
else:
features = np.array(list(pdf[alias.data]), order=array_order)
# experiments indicate it is faster to convert to numpy array and then to cupy array than directly
# invoking cupy array on the list
if cuda_managed_mem_enabled:
features = cp.array(features)

label = pdf[alias.label] if alias.label in pdf.columns else None
row_number = (
pdf[alias.row_number] if alias.row_number in pdf.columns else None
)
inputs.append((features, label, row_number))

# call the cuml fit function
# *note*: cuml_fit_func may delete components of inputs to free
# memory. do not rely on inputs after this call.
result = cuml_fit_func(inputs, params)
logger.info("Cuml fit complete")

for row in result:
yield row

output_df = dataset.mapInPandas(_train_udf, schema=self._out_schema())
Review comment (Collaborator):
If this is mostly duplicated code, wondering if it can be refactored into the existing API.

Also, is the fit_multiple_params API (from @wbo4958) explicitly unsupported then? If so, maybe we should document this, especially if it's removed for specific reasons.

Reply (Collaborator Author @rishic3, Aug 7, 2023):
Took out fit_multiple_params since fit_multiple isn't supported in UMAP - basically just trimmed out everything that wasn't relevant to UMAP specifically since this func only lives in UMAP atm.

As for refactoring this into the existing API, don't think there's a clean way without overriding or creating a new call_fit_func in core due to the RDD return signature and the barrier stuff. If we're interested in using dataframes for future algos that don't require NCCL during fit, we could have a second call_cuml_fit_func within core like the one in this PR for those use cases which future algos (and this algo) could inherit from. Not sure if this is preferred.
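As an illustration of the barrier-free pattern discussed here, a minimal sketch (toy column names and a stand-in local "fit"; not the spark-rapids-ml API):

```python
from typing import Iterator

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(100).selectExpr("cast(id as double) as x", "cast(id * 2 as double) as y")

def fit_partition(pdf_iter: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Concatenate the Arrow batches handed to this partition and "fit" on them locally.
    pdf = pd.concat(list(pdf_iter), ignore_index=True)
    # Stand-in for a cuML fit: yield the fitted attributes back as array columns.
    yield pd.DataFrame(
        {"embedding_": [pdf["x"].tolist()], "raw_data_": [pdf["y"].tolist()]}
    )

# Single partition, no barrier stage; the fit results come back as a DataFrame.
out = df.coalesce(1).mapInPandas(
    fit_partition, schema="embedding_ array<double>, raw_data_ array<double>"
)
print(out.toPandas()["embedding_"].iloc[0][:5])
```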


return output_df

def _use_fit_generator(self) -> bool:
return True

@@ -499,8 +635,8 @@ def _pre_process_data(
class UMAPModel(_CumlModel, UMAPClass, _UMAPCumlParams):
def __init__(
self,
embedding_: Union[pyspark.broadcast.Broadcast, List[List[float]]],
raw_data_: Union[pyspark.broadcast.Broadcast, List[List[float]]],
embedding_: Union[pyspark.broadcast.Broadcast, np.ndarray],
raw_data_: Union[pyspark.broadcast.Broadcast, np.ndarray],
n_cols: int,
dtype: str,
) -> None:
@@ -514,16 +650,16 @@ def __init__(
self.raw_data_ = raw_data_

@property
def embedding(self) -> List[List[float]]:
if isinstance(self.embedding_, list):
return self.embedding_
return self.embedding_.value
def embedding(self) -> np.ndarray:
if isinstance(self.embedding_, np.ndarray):
return self.embedding_.tolist()
return self.embedding_.value.tolist()

@property
def raw_data(self) -> List[List[float]]:
if isinstance(self.raw_data_, list):
return self.raw_data_
return self.raw_data_.value
def raw_data(self) -> np.ndarray:
if isinstance(self.raw_data_, np.ndarray):
return self.raw_data_.tolist()
return self.raw_data_.value.tolist()

def _get_cuml_transform_func(
self, dataset: DataFrame, category: str = transform_evaluate.transform
@@ -538,19 +674,31 @@ def _construct_umap() -> CumlT:

from .utils import cudf_to_cuml_array

embedding_np = np.array(self.embedding, dtype=np.float32)
raw_data_np = np.array(self.raw_data, dtype=np.float32)
embedding = (
self.embedding_
if isinstance(self.embedding_, np.ndarray)
else self.embedding_.value
)
raw_data = (
self.raw_data_
if isinstance(self.raw_data_, np.ndarray)
else self.raw_data_.value
)

if embedding.dtype != np.float32:
embedding = embedding.astype(np.float32)
raw_data = raw_data.astype(np.float32)
Review comment (Collaborator):
Probably should log a warning that we're auto-converting the type (but only if it's not too chatty).

Reply (Collaborator Author @rishic3, Aug 7, 2023):
At the moment I'm not supporting user-control over the "convert_dtype" param from cuml (determines whether the internal computations are float64); currently just defaulting to float32. I figured we could keep it like that for now for perf reasons and add float64 support in a future pr (and like Erik mentioned, maybe include a default conversion to float32 much earlier, long before we get to the cuml side, if desired).
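A small sketch of the "convert much earlier" idea (hypothetical column name; not code from this PR): array features can be cast to float32 on the Spark side before anything reaches cuML.

```python
from pyspark.sql import SparkSession, functions as F
from pyspark.sql.types import ArrayType, FloatType

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([([0.1, 0.2, 0.3],), ([0.4, 0.5, 0.6],)], ["features"])

# array<double> -> array<float>: no float64 vectors ever reach the fit/transform path.
df32 = df.withColumn("features", F.col("features").cast(ArrayType(FloatType())))
df32.printSchema()  # features: array (element: float)
```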


if is_sparse(raw_data_np):
raw_data_cuml = SparseCumlArray(raw_data_np, convert_format=False)
if is_sparse(raw_data):
raw_data_cuml = SparseCumlArray(raw_data, convert_format=False)
else:
raw_data_cuml = cudf_to_cuml_array(
raw_data_np,
raw_data,
order="C",
)

internal_model = CumlUMAP(**cuml_alg_params)
internal_model.embedding_ = cp.array(embedding_np).data
internal_model.embedding_ = cp.array(embedding).data
internal_model._raw_data = raw_data_cuml

return internal_model
@@ -596,9 +744,76 @@ def _out_schema(self, input_schema: StructType) -> Union[StructType, str]:
)

def get_model_attributes(self) -> Optional[Dict[str, Any]]:
"""Override parent method to bring broadcast variables to driver before JSON serialization."""
if not isinstance(self.embedding_, list):
"""
Override parent method to bring broadcast variables to driver before JSON serialization.
"""
if not isinstance(self.embedding_, np.ndarray):
self._model_attributes["embedding_"] = self.embedding_.value
if not isinstance(self.raw_data_, list):
if not isinstance(self.raw_data_, np.ndarray):
self._model_attributes["raw_data_"] = self.raw_data_.value
return self._model_attributes

def write(self) -> MLWriter:
return _CumlModelWriterNumpy(self)

@classmethod
def read(cls) -> MLReader:
return _CumlModelReaderNumpy(cls)


class _CumlModelWriterNumpy(_CumlModelWriter):
"""
Override parent writer to save numpy objects of _CumlModel to the file
"""

def saveImpl(self, path: str) -> None:
DefaultParamsWriter.saveMetadata(
self.instance,
path,
self.sc,
extraMetadata={
"_cuml_params": self.instance._cuml_params,
"_num_workers": self.instance._num_workers,
},
)
data_path = os.path.join(path, "data")
model_attributes = self.instance.get_model_attributes()

if not os.path.exists(data_path):
os.makedirs(data_path)
assert model_attributes is not None
for key, value in model_attributes.items():
if isinstance(value, np.ndarray):
array_path = os.path.join(data_path, f"{key}.npy")
np.save(array_path, value)
model_attributes[key] = array_path

metadata_file_path = os.path.join(data_path, "metadata.json")
model_attributes_str = json.dumps(model_attributes)
self.sc.parallelize([model_attributes_str], 1).saveAsTextFile(
metadata_file_path
)


class _CumlModelReaderNumpy(_CumlModelReader):
"""
Override parent reader to instantiate numpy objects of _CumlModel from file
"""

def load(self, path: str) -> "_CumlEstimator":
metadata = DefaultParamsReader.loadMetadata(path, self.sc)
data_path = os.path.join(path, "data")
metadata_file_path = os.path.join(data_path, "metadata.json")

model_attr_str = self.sc.textFile(metadata_file_path).collect()[0]
model_attr_dict = json.loads(model_attr_str)

for key, value in model_attr_dict.items():
if isinstance(value, str) and value.endswith(".npy"):
model_attr_dict[key] = np.load(value)

instance = self.model_cls(**model_attr_dict)
DefaultParamsReader.getAndSetParams(instance, metadata)
instance._cuml_params = metadata["_cuml_params"]
instance._num_workers = metadata["_num_workers"]
return instance
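For reference, a hedged usage sketch of what the numpy-backed writer/reader enable (setter and persistence calls assumed from standard Spark ML conventions; df is a hypothetical DataFrame with an array "features" column):

```python
from spark_rapids_ml.umap import UMAP, UMAPModel

umap = UMAP().setFeaturesCol("features")
model = umap.fit(df)  # df: Spark DataFrame with an array-of-float "features" column

# saveImpl() above writes each np.ndarray model attribute to <path>/data/<attr>.npy
model.write().overwrite().save("/tmp/umap_model")

# load() above reads metadata.json, np.load()s the saved arrays, and rebuilds the model
loaded = UMAPModel.load("/tmp/umap_model")
assert loaded.n_cols == model.n_cols
```
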
4 changes: 2 additions & 2 deletions python/tests/test_umap.py
@@ -202,7 +202,7 @@ def test_spark_umap(
@pytest.mark.parametrize("supervised", [True])
@pytest.mark.parametrize("dataset", ["digits"])
@pytest.mark.parametrize("n_neighbors", [10])
@pytest.mark.parametrize("dtype", [cuml_supported_data_types[0]])
@pytest.mark.parametrize("dtype", cuml_supported_data_types)
@pytest.mark.parametrize("feature_type", [pyspark_supported_feature_types[0]])
def test_spark_umap_fast(
n_parts: int,
@@ -303,7 +303,7 @@ def assert_umap_model(model: UMAPModel) -> None:
assert embedding.shape == (100, 2)
assert raw_data.shape == (100, 20)
assert np.array_equal(raw_data, X.get())
assert model.dtype == "float"
assert model.dtype == "float32"
assert model.n_cols == X.shape[1]

umap_model = umap.fit(df)