From 39548134eaa1b3ff3c0e371dd42a0ac2ad625b7a Mon Sep 17 00:00:00 2001
From: Mike Guo
Date: Wed, 3 Apr 2024 14:21:33 +0800
Subject: [PATCH 1/6] introduce host_device for pass (#1047)

## Describe your changes
Introduce the host device concept for passes so that each pass can either override it or use it as provided.

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link
---
 docs/architecture.md                    |  2 +-
 olive/engine/engine.py                  | 10 ++++++-
 olive/passes/olive_pass.py              | 29 ++++++++++++++-----
 olive/passes/onnx/conversion.py         |  8 ++---
 olive/passes/pytorch/gptq.py            | 11 +++++--
 olive/passes/pytorch/sparsegpt.py       |  8 ++---
 olive/passes/pytorch/tensor_parallel.py |  2 +-
 .../passes/test_pass_serialization.py   |  8 +++--
 8 files changed, 56 insertions(+), 22 deletions(-)

diff --git a/docs/architecture.md b/docs/architecture.md
index 31cf86478..7b65200ef 100644
--- a/docs/architecture.md
+++ b/docs/architecture.md
@@ -35,7 +35,7 @@ Passes are the building blocks of an Olive workflow. Olive uses multiple Passes
 The base class for Pass:
 ```python
 class Pass(ABC):
-    def __init__(self, accelerator_spec: AcceleratorSpec, config: Union[Dict[str, Any], BaseModel], disable_search: Optional[bool] = False):
+    def __init__(self, accelerator_spec: AcceleratorSpec, config: Union[Dict[str, Any], BaseModel], disable_search: Optional[bool] = False, host_device: Optional[str] = None):
         ...
 
     @classmethod
diff --git a/olive/engine/engine.py b/olive/engine/engine.py
index 6df118d47..f1080e012 100644
--- a/olive/engine/engine.py
+++ b/olive/engine/engine.py
@@ -372,14 +372,22 @@ def run_accelerator(
         logger.debug("run_accelerator done")
         return output_footprint
 
+    def get_host_device(self):
+        if self.host_config.config.accelerators:
+            # for host device, we will always use the first accelerator device
+            return self.host_config.config.accelerators[0].device
+        else:
+            return None
+
     def setup_passes(self, accelerator_spec: "AcceleratorSpec"):
+        host_device = self.get_host_device()
         # clean the passes
         self.passes.clear()
         for name, config in self.pass_config.items():
             pass_cls: Type["Pass"] = config["type"]
             pass_cfg = config["config"]
             pass_cfg = pass_cls.generate_search_space(accelerator_spec, pass_cfg, config["disable_search"])
-            p = pass_cls(accelerator_spec, pass_cfg, config["disable_search"])
+            p = pass_cls(accelerator_spec, pass_cfg, config["disable_search"], host_device)
             self.register_pass(
                 p,
                 name=name,
diff --git a/olive/passes/olive_pass.py b/olive/passes/olive_pass.py
index 0a91d4206..25c3ed77c 100644
--- a/olive/passes/olive_pass.py
+++ b/olive/passes/olive_pass.py
@@ -61,14 +61,22 @@ def __init_subclass__(cls, **kwargs) -> None:
         cls.registry[cls.__name__.lower()] = cls
 
     def __init__(
-        self, accelerator_spec: AcceleratorSpec, config: Dict[str, Any], disable_search: Optional[bool] = False
+        self,
+        accelerator_spec: AcceleratorSpec,
+        config: Dict[str, Any],
+        disable_search: Optional[bool] = False,
+        host_device=None,
     ):
         """Initialize the pass.
 
-        :param config_class: the PassConfig class with the default value or default search values.
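To make the new parameter concrete, here is a minimal, self-contained sketch (toy classes, not Olive's actual implementation) of how the engine-selected host device flows into a pass constructor and how a pass can fall back when it is unset:

```python
# Hypothetical stand-ins for the real classes in olive/engine/engine.py and
# olive/passes/olive_pass.py; illustrative only.
from typing import Optional


def get_host_device(host_accelerators: Optional[list]) -> Optional[str]:
    # Mirrors Engine.get_host_device: always take the first accelerator's device.
    return host_accelerators[0] if host_accelerators else None


class ToyPass:
    def __init__(
        self,
        accelerator_spec: str,
        config: dict,
        disable_search: bool = False,
        host_device: Optional[str] = None,
    ):
        self.accelerator_spec = accelerator_spec  # target device/EP for the output model
        self.host_device = host_device  # device the pass itself runs on
        self.config = config

    def compute_device(self) -> str:
        # A pass may adopt the engine-provided host device or override it.
        return self.host_device or "cpu"


p = ToyPass("cpu", {}, host_device=get_host_device(["gpu"]))
assert p.compute_device() == "gpu"  # runs on GPU while targeting CPU inference
```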
- :type config_class: Type[PassConfigBase] + :param accelerator_spec: the accelerator spec for the pass. + :type accelerator_spec: AcceleratorSpec :param config: the configuration representing search space. :type config: Dict[str, Any] + :param disable_search: whether to disable search. + :type disable_search: Optional[bool] + :param host_device: the host device for the pass. + :type host_device: Optional[str] """ assert accelerator_spec is not None, "Please specify the accelerator spec for the pass." assert config is not None, "Please specify the configuration for the pass." @@ -76,6 +84,7 @@ def __init__( config_class, default_config = self.get_config_class(accelerator_spec, disable_search) self.accelerator_spec = accelerator_spec + self.host_device = host_device self._config_class = config_class self.config = config @@ -221,6 +230,7 @@ def to_json(self, check_object: bool = False) -> Dict[str, Any]: "type": self.__class__.__name__, "disable_search": True, "accelerator": self.accelerator_spec.to_json(), + "host_device": self.host_device, "config": self.serialize_config(self.config, check_object), } @@ -394,6 +404,7 @@ class FullPassConfig(ConfigBase): type: str disable_search: bool = False accelerator: Dict[str, str] = None + host_device: Optional[str] = None config: Dict[str, Any] = None @validator("type") @@ -408,17 +419,21 @@ def create_pass(self): pass_cls = Pass.registry[self.type.lower()] accelerator_spec = AcceleratorSpec(**self.accelerator) # pylint: disable=not-a-mapping - return pass_cls(accelerator_spec, self.config, self.disable_search) + return pass_cls(accelerator_spec, self.config, self.disable_search, self.host_device) -# TODO(myguo): deprecate or remove this method by explicitly specify the accelerator_spec in the arguments +# TODO(myguo): deprecate or remove this function by explicitly specify the accelerator_spec in the arguments # instead of using the default argument. def create_pass_from_dict( - pass_cls: Type[Pass], config: Dict[str, Any] = None, disable_search=False, accelerator_spec: AcceleratorSpec = None + pass_cls: Type[Pass], + config: Dict[str, Any] = None, + disable_search=False, + accelerator_spec: AcceleratorSpec = None, + host_device=None, ) -> Pass: """Create a pass from a dictionary.""" if accelerator_spec is None: accelerator_spec = DEFAULT_CPU_ACCELERATOR config = pass_cls.generate_search_space(accelerator_spec, config, disable_search) - return pass_cls(accelerator_spec, config, disable_search) + return pass_cls(accelerator_spec, config, disable_search, host_device) diff --git a/olive/passes/onnx/conversion.py b/olive/passes/onnx/conversion.py index e4d587a9f..458a86a51 100644 --- a/olive/passes/onnx/conversion.py +++ b/olive/passes/onnx/conversion.py @@ -72,7 +72,7 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon "use_dynamo_exporter": PassConfigParam( type_=bool, default_value=False, description="Whether to use dynamo_export API to export ONNX model." ), - "use_device": PassConfigParam( + "device": PassConfigParam( type_=str, description=( "The device to use for conversion, e.g., 'cuda' or 'cpu'. 
If not specified, will use 'cpu' for" @@ -114,14 +114,14 @@ def _run_for_config( ) -> Union[CompositeModelHandler, DistributedOnnxModelHandler, ONNXModelHandler]: # get the device to use for conversion # default to "cpu" for PyTorchModelHandler and "cuda" for DistributedPyTorchModel - device = config["use_device"] or "cpu" + device = config["device"] or "cpu" # get the dtype to use for conversion torch_dtype = resolve_torch_dtype(config["torch_dtype"]) if config["torch_dtype"] else None if torch_dtype == torch.float16 and device == "cpu": raise ValueError("Conversion to float16 is not supported for CPU.") if isinstance(model, DistributedPyTorchModelHandler): - if not config["use_device"]: + if not config["device"]: device = "cuda" return self._convert_distributed_model_on_device( model, data_root, config, output_model_path, device, torch_dtype @@ -342,7 +342,7 @@ def _load_pytorch_model( logger.warning( "Loading model on CPU, but the model loading args specify dtype float16 which is not supported for" " conversion on CPU. The dtype is changed to float32. If float16 model is desired, please specify" - " use_device as 'cuda' or use OrtTransformerOptimization/OnnxFloatToFloat16 pass after conversion to" + " device as 'cuda' or use OrtTransformerOptimization/OnnxFloatToFloat16 pass after conversion to" " convert the model to float16." ) new_from_pretrained_args["torch_dtype"] = torch.float32 diff --git a/olive/passes/pytorch/gptq.py b/olive/passes/pytorch/gptq.py index efd78a10b..78393750c 100644 --- a/olive/passes/pytorch/gptq.py +++ b/olive/passes/pytorch/gptq.py @@ -2,6 +2,7 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import logging from typing import Any, Callable, Dict, List, Union import torch @@ -14,6 +15,8 @@ from olive.passes import Pass from olive.passes.pass_config import PassConfigParam +logger = logging.getLogger(__name__) + class GptqQuantizer(Pass): """GPTQ quantization using Hugging Face Optimum and export model with onnxruntime optimized kernel.""" @@ -125,8 +128,13 @@ def _run_for_config( from olive.passes.pytorch.gptq_utils import QuantLinearORT - if self.accelerator_spec.accelerator_type != Device.GPU: + if not torch.cuda.is_available(): raise ValueError("Please use GPU to run gptq quantization.") + elif self.host_device != Device.GPU: + logger.warning( + "GPTQ quantization requires GPU but the host device is %s, will ignore the host device", + self.host_device, + ) dataset = None if config["dataloader_func"]: @@ -195,7 +203,6 @@ def get_onnx_quant_linear(*args, **kwargs): auto_gptq.modeling._utils.dynamically_import_QuantLinear = original # pylint: disable=protected-access quantized_model = quantized_model.model - assert self.accelerator_spec.accelerator_type == Device.GPU output_model_path = normalize_path_suffix(output_model_path, "model.pt") torch.save(quantized_model, output_model_path) diff --git a/olive/passes/pytorch/sparsegpt.py b/olive/passes/pytorch/sparsegpt.py index d8a313eb1..bcca2ea6f 100644 --- a/olive/passes/pytorch/sparsegpt.py +++ b/olive/passes/pytorch/sparsegpt.py @@ -70,9 +70,9 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon description="Only prune layers whose name contains the given string(s).", ), # this is not the same as accelerator_spec.device which is the target device for inference - # compute_device is the device we want to run the algorithm on, does not affect the final model - 
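The comments being rewritten here draw a distinction worth keeping in mind: the pass-level device (where the algorithm executes) is independent of `accelerator_spec.device` (the inference target). A hedged sketch of the resolution rules used by the conversion and SparseGPT hunks, with illustrative helper names:

```python
# Illustrative only: how a pass-level compute device could be resolved,
# independently of accelerator_spec.device (the inference target).
from typing import Optional

import torch


def resolve_pass_device(configured: Optional[str], is_distributed: bool = False) -> str:
    if configured == "auto":
        # SparseGPT-style "auto": prefer CUDA when present.
        return "cuda" if torch.cuda.is_available() else "cpu"
    if configured is None:
        # Conversion-style default: "cuda" for distributed models, else "cpu".
        return "cuda" if is_distributed else "cpu"
    return configured


assert resolve_pass_device("cuda") == "cuda"
assert resolve_pass_device(None) == "cpu"
# Note: the conversion pass additionally rejects torch.float16 on "cpu".
```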
# so accelerator_spec.device can be cpu but compute_device can be cuda for faster pass execution - "compute_device": PassConfigParam( + # device is the device we want to run the algorithm on, does not affect the final model + # so accelerator_spec.device can be cpu but device can be cuda for faster pass execution + "device": PassConfigParam( type_=str, default_value="auto", description=( @@ -108,7 +108,7 @@ def _run_for_config( n, m = sparsity if mode == "structured" else [0, 0] # get device to use for computations - device = config["compute_device"] + device = config["device"] if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" logger.debug( diff --git a/olive/passes/pytorch/tensor_parallel.py b/olive/passes/pytorch/tensor_parallel.py index fd638ab21..588d08c75 100644 --- a/olive/passes/pytorch/tensor_parallel.py +++ b/olive/passes/pytorch/tensor_parallel.py @@ -168,7 +168,7 @@ def _run_for_config( with multiprocessing.Pool(processes=max_parallel_jobs) as pool: results = pool.map(PyTorchTensorParallel._generate_one, params) - if self.accelerator_spec.accelerator_type == Device.GPU and torch.cuda.is_available(): + if self.host_device == Device.GPU and torch.cuda.is_available(): torch.cuda.empty_cache() if world_size != sum(results): diff --git a/test/unit_test/passes/test_pass_serialization.py b/test/unit_test/passes/test_pass_serialization.py index e6ef61cac..21214734e 100644 --- a/test/unit_test/passes/test_pass_serialization.py +++ b/test/unit_test/passes/test_pass_serialization.py @@ -2,15 +2,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- +import pytest + from olive.hardware import DEFAULT_CPU_ACCELERATOR from olive.passes.olive_pass import FullPassConfig from olive.passes.onnx.conversion import OnnxConversion -def test_pass_serialization(): +@pytest.mark.parametrize("host_device", [None, "cpu", "gpu"]) +def test_pass_serialization(host_device): onnx_conversion_config = {} config = OnnxConversion.generate_search_space(DEFAULT_CPU_ACCELERATOR, onnx_conversion_config) - onnx_conversion = OnnxConversion(DEFAULT_CPU_ACCELERATOR, config) + onnx_conversion = OnnxConversion(DEFAULT_CPU_ACCELERATOR, config, host_device=host_device) json = onnx_conversion.to_json(True) cfg = FullPassConfig.from_json(json) @@ -18,3 +21,4 @@ def test_pass_serialization(): assert isinstance(p, OnnxConversion) assert p.accelerator_spec == DEFAULT_CPU_ACCELERATOR assert p.config == config + assert p.host_device == host_device From 9a504594704db4eda4631daa012fd0c49cac8b6c Mon Sep 17 00:00:00 2001 From: trajep Date: Wed, 3 Apr 2024 17:53:20 +0800 Subject: [PATCH 2/6] =?UTF-8?q?=F0=9F=8D=B8=20Comments=20about=20where=20a?= =?UTF-8?q?nd=20how=20accelerators=20specs=20are=20used=20(#1050)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Describe your changes 1. Comments about where and how accelerators specs are used. 2. Typos fix. ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? 
If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link --- olive/hardware/accelerator.py | 2 +- olive/workflows/run/run.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/olive/hardware/accelerator.py b/olive/hardware/accelerator.py index 17eea6626..bfbc28144 100644 --- a/olive/hardware/accelerator.py +++ b/olive/hardware/accelerator.py @@ -173,7 +173,7 @@ def normalize_accelerators(system_config: "SystemConfig", skip_supported_eps_che * the accelerators is not specified, infer the device/ep based on the installed ORT in case of local/python system. * only device is specified, infer the execution providers based on the installed ORT in case of local/python system. - * only EP is specified, infer the device based on the installed ORT in case of local/python sytemm. + * only EP is specified, infer the device based on the installed ORT in case of local/python system. * For AzureML and Docker system, the accelerators and execution providers must be specified. """ from olive.systems.common import SystemType diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py index 8458a62c3..c3d6fc8ca 100644 --- a/olive/workflows/run/run.py +++ b/olive/workflows/run/run.py @@ -184,6 +184,8 @@ def run_engine(config: RunConfig, data_root: str = None): and config.auto_optimizer_config is not None and not config.auto_optimizer_config.disable_auto_optimizer ): + # For auto optimizer, Olive generates passes and pass_flows for each accelerator + # that means, the passes and pass_flows might be different for each accelerator for acc_spec in accelerator_specs: _passes, pass_flows = AutoOptimizer( input_model, @@ -195,10 +197,20 @@ def run_engine(config: RunConfig, data_root: str = None): pass_list.append(({k: RunPassConfig.parse_obj(v) for k, v in _passes.items()}, pass_flows)) acc_list.append([acc_spec]) else: + # For non-auto-optimizer case, Olive uses the same passes and pass_flows for all accelerators + # if user needs different passes and pass_flows for each accelerator, they need to write multiple + # config files. pass_list.append((config.passes, config.pass_flows)) acc_list.append(accelerator_specs) run_rls = {} + # Note that, in Olive, there are two positions where the accelerator_specs are looped over: + # 1. olive workflow run level: this is where the accelerator_specs are created and passed to + # the engine. In this level, accelerator specs can be used to generate passes and pass_flows. + # 2. engine level: this is where the accelerator_specs are looped over to run the passes. + # TODO(anyone): refactor the code to remove the engine level loop if possible. + # For time being, we are keeping both loops, but in future, we might want to refactor the code + # to remove engine level loop and pass the accelerator_specs to the engine directly. for accelerator_spec, (passes, pass_flows) in zip(acc_list, pass_list): engine.reset_passes() if passes: From dd9139376d114f144b7d665dfe4c9d231fe1f86d Mon Sep 17 00:00:00 2001 From: Xiaoyu <85524621+xiaoyu-work@users.noreply.github.com> Date: Thu, 4 Apr 2024 15:40:14 -0700 Subject: [PATCH 3/6] Add mlflow packaging option for phi2 example (#1053) ## Describe your changes Add mlflow packaging option for phi2 example ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. 
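For illustration, the engine section that `--export_mlflow_format` injects into the generated phi2_optimize.json (see the phi2.py hunk below) would look roughly like this sketch:

```python
# Hedged sketch of the packaging_config that --export_mlflow_format adds;
# mirrors the phi2.py hunk in this patch.
import json

template_json = {"engine": {}}
template_json["engine"]["packaging_config"] = [
    {
        "type": "Zipfile",
        "name": "mlflow_model",  # the zip produced in the output folder
        "config": {"export_in_mlflow_format": True},
    }
]
print(json.dumps(template_json, indent=4))
```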
- [ ] Lint and apply fixes to your code by running `lintrunner -a`
- [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes.
- [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR.

## (Optional) Issue link
---
 examples/phi2/README.md |  2 ++
 examples/phi2/phi2.py   | 16 +++++++++++++++-
 2 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/examples/phi2/README.md b/examples/phi2/README.md
index 04ad44ee5..bc56d89cf 100644
--- a/examples/phi2/README.md
+++ b/examples/phi2/README.md
@@ -76,6 +76,8 @@ For better generation experience, here is the way to run inference with the opti
 python phi2.py --model_type cpu_fp32 --inference --optimum_optimization --prompt "Write a extremely long story starting with once upon a time"
 ```
 
+## Export output models in MLflow format
+If you want to output the optimized models to a zip file in MLflow format, add the `--export_mlflow_format` argument. The MLflow model will be packaged in a zip file named `mlflow_model` in the output folder.
 
 ## Limitations
 1. The latest ONNXRuntime implements specific fusion patterns for better performance but only works for ONNX model from TorchDynamo-based ONNX Exporter. And the TorchDynamo-based ONNX Exporter is only available on Linux.
diff --git a/examples/phi2/phi2.py b/examples/phi2/phi2.py
index 7cbec4251..c961e5e59 100644
--- a/examples/phi2/phi2.py
+++ b/examples/phi2/phi2.py
@@ -78,7 +78,12 @@ def get_args(raw_args):
     parser.add_argument(
         "--optimum_optimization",
         action="store_true",
-        help="Run inference with optimized model",
+        help="Use optimum optimization",
+    )
+    parser.add_argument(
+        "--export_mlflow_format",
+        action="store_true",
+        help="Export the model in MLflow format.",
     )
     parser.add_argument(
         "--prompt",
@@ -173,6 +178,15 @@ def main(raw_args=None):
             del template_json["passes"][pass_name]
             continue
 
+    if args.export_mlflow_format:
+        template_json["engine"]["packaging_config"] = [
+            {
+                "type": "Zipfile",
+                "name": "mlflow_model",
+                "config": {"export_in_mlflow_format": True},
+            }
+        ]
+
     with open("phi2_optimize.json", "w") as f:
         json.dump(template_json, f, indent=4)

From 5f65d62588686b2539efef29c09f61fee345d84a Mon Sep 17 00:00:00 2001
From: shaahji <96227573+shaahji@users.noreply.github.com>
Date: Thu, 4 Apr 2024 22:23:27 -0700
Subject: [PATCH 4/6] [Performance]: Delayed Python pass module load (#1028)

## Describe your changes

### [Performance]: Delayed Python module load
Introducing olive_config.json to encapsulate all the different global environment properties (like pass module configuration) needed to set up Olive. Also merged extra_dependencies.json into olive_config.json. Users can provide an alternative configuration file with the command line argument '--package-config'.

Instead of loading all modules (specifically passes) at launch, delay the load until after the run config is parsed, and load only the ones that are explicitly required.

Moved a few hard-coded dependency mappings into the Olive config.

### Release Note:
Added support for extending the available Olive passes and configuring each via a data file (olive_config.json). The configuration file is provided to Olive run with the command line param `--package-config`.

## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [x] Update documents if necessary.
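Since `--package-config` accepts any file with this schema, a user could in principle register a private pass next to its dependencies. A minimal sketch, with invented names (`MyPrivatePass`, `my-private-sdk`) purely for illustration:

```python
# Hypothetical user-supplied package config; in practice you would start from
# Olive's built-in olive_config.json and add entries, so the built-in passes
# remain resolvable.
import json

custom_package_config = {
    "passes": {
        "MyPrivatePass": {
            "module_path": "my_company.olive_passes.private_pass.MyPrivatePass",
            "module_dependencies": ["my-private-sdk"],
        }
    },
    "extra_dependencies": {},
}
with open("my_olive_config.json", "w") as f:
    json.dump(custom_package_config, f, indent=4)
# then: python -m olive.workflows.run --run-config run.json --package-config my_olive_config.json
```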
- [x] Lint and apply fixes to your code by running `lintrunner -a` - [x] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link --- docs/source/exts/auto_config_doc/__init__.py | 16 +- olive/olive_config.json | 234 ++++++++++++++++++ olive/package_config.py | 36 +++ olive/passes/__init__.py | 8 +- olive/passes/olive_pass.py | 7 - olive/passes/onnx/__init__.py | 50 ---- olive/passes/onnx/vitis_ai/__init__.py | 3 - olive/passes/openvino/__init__.py | 7 - olive/passes/pass_config.py | 14 +- olive/passes/pytorch/__init__.py | 17 -- olive/passes/qnn/__init__.py | 6 - olive/passes/snpe/__init__.py | 5 - olive/systems/azureml/aml_pass_runner.py | 5 + olive/systems/docker/runner.py | 6 + .../systems/python_environment/pass_runner.py | 9 +- olive/workflows/run/__main__.py | 14 +- olive/workflows/run/config.py | 52 ++-- olive/workflows/run/run.py | 141 +++++------ setup.py | 6 +- .../aml_model_test/test_aml_model.py | 2 +- test/multiple_ep/utils.py | 2 +- test/unit_test/engine/test_engine.py | 4 +- .../passes/common/test_user_script.py | 2 +- .../passes/onnx/test_bnb_quantization.py | 2 +- .../onnx/test_dynamic_to_fixed_shape.py | 2 +- .../passes/onnx/test_model_optimizer.py | 2 +- .../unit_test/passes/onnx/test_perf_tuning.py | 5 +- .../passes/onnx/test_qnn_preprocess.py | 2 +- .../passes/onnx/test_quantization.py | 2 +- .../onnx/test_transformer_optimization.py | 2 +- test/unit_test/passes/pytorch/test_gptq.py | 2 +- test/unit_test/passes/pytorch/test_lora.py | 2 +- .../test_quantization_aware_training.py | 2 +- .../passes/pytorch/test_sparsegpt.py | 2 +- .../pytorch/test_torch_trt_conversion.py | 2 +- test/unit_test/test_module_config.py | 15 ++ test/unit_test/utils.py | 5 +- 37 files changed, 463 insertions(+), 230 deletions(-) create mode 100644 olive/olive_config.json create mode 100644 olive/package_config.py create mode 100644 test/unit_test/test_module_config.py diff --git a/docs/source/exts/auto_config_doc/__init__.py b/docs/source/exts/auto_config_doc/__init__.py index 67c3409f7..8d68e527a 100644 --- a/docs/source/exts/auto_config_doc/__init__.py +++ b/docs/source/exts/auto_config_doc/__init__.py @@ -15,16 +15,19 @@ from olive.common.auto_config import AutoConfigClass from olive.hardware import DEFAULT_CPU_ACCELERATOR +from olive.package_config import OlivePackageConfig from olive.passes import Pass # pylint: skip-file -def import_class(class_name: str): - module_name = ".".join(class_name.split(".")[:-1]) - class_name = class_name.split(".")[-1] - module = import_module(module_name) - return getattr(module, class_name) +def import_class(class_name: str, package_config: OlivePackageConfig): + module_path, module_name = class_name.rsplit(".", 1) + if module_name in package_config.passes: + return package_config.import_pass_module(module_name) + + module = import_module(module_path) + return getattr(module, module_name) class AutoConfigDirective(Directive): @@ -68,7 +71,8 @@ def make_doc(self, auto_config_class: Union[AutoConfigClass, Pass]): def run(self): (class_name,) = self.arguments - auto_config_class = import_class(class_name) + package_config = OlivePackageConfig.load_default_config() + auto_config_class = import_class(class_name, package_config) assert issubclass(auto_config_class, AutoConfigClass) or issubclass( 
auto_config_class, Pass ), f"{class_name} is not a subclass of AutoConfigClass or Pass" diff --git a/olive/olive_config.json b/olive/olive_config.json new file mode 100644 index 000000000..f29393376 --- /dev/null +++ b/olive/olive_config.json @@ -0,0 +1,234 @@ +{ + "passes": { + "AppendPrePostProcessingOps": { + "module_path": "olive.passes.onnx.append_pre_post_processing_ops.AppendPrePostProcessingOps" + }, + "DynamicToFixedShape": { + "module_path": "olive.passes.onnx.dynamic_to_fixed_shape.DynamicToFixedShape" + }, + "GenAIModelExporter": { + "module_path": "olive.passes.onnx.genai_model_exporter.GenAIModelExporter" + }, + "IncDynamicQuantization": { + "module_path": "olive.passes.onnx.inc_quantization.IncDynamicQuantization", + "extra_dependencies": [ + "inc" + ] + }, + "IncQuantization": { + "module_path": "olive.passes.onnx.inc_quantization.IncQuantization", + "extra_dependencies": [ + "inc" + ] + }, + "IncStaticQuantization": { + "module_path": "olive.passes.onnx.inc_quantization.IncStaticQuantization", + "extra_dependencies": [ + "inc" + ] + }, + "InsertBeamSearch": { + "module_path": "olive.passes.onnx.insert_beam_search.InsertBeamSearch" + }, + "MoEExpertsDistributor": { + "module_path": "olive.passes.onnx.moe_experts_distributor.MoEExpertsDistributor" + }, + "OnnxBnb4Quantization": { + "module_path": "olive.passes.onnx.bnb_quantization.OnnxBnb4Quantization" + }, + "OnnxConversion": { + "module_path": "olive.passes.onnx.conversion.OnnxConversion" + }, + "OnnxDynamicQuantization": { + "module_path": "olive.passes.onnx.quantization.OnnxDynamicQuantization" + }, + "OnnxFloatToFloat16": { + "module_path": "olive.passes.onnx.float16_conversion.OnnxFloatToFloat16", + "module_dependencies": [ + "onnxconverter-common" + ] + }, + "OnnxMatMul4Quantizer": { + "module_path": "olive.passes.onnx.quantization.OnnxMatMul4Quantizer" + }, + "OnnxModelOptimizer": { + "module_path": "olive.passes.onnx.model_optimizer.OnnxModelOptimizer" + }, + "OnnxOpVersionConversion": { + "module_path": "olive.passes.onnx.conversion.OnnxOpVersionConversion" + }, + "OnnxQuantization": { + "module_path": "olive.passes.onnx.quantization.OnnxQuantization" + }, + "OnnxStaticQuantization": { + "module_path": "olive.passes.onnx.quantization.OnnxStaticQuantization" + }, + "OptimumConversion": { + "module_path": "olive.passes.onnx.optimum_conversion.OptimumConversion", + "extra_dependencies": [ + "optimum" + ] + }, + "OptimumMerging": { + "module_path": "olive.passes.onnx.optimum_merging.OptimumMerging", + "extra_dependencies": [ + "optimum" + ] + }, + "OrtMixedPrecision": { + "module_path": "olive.passes.onnx.mixed_precision.OrtMixedPrecision" + }, + "OrtPerfTuning": { + "module_path": "olive.passes.onnx.perf_tuning.OrtPerfTuning", + "module_dependencies": [ + "psutil" + ] + }, + "OrtTransformersOptimization": { + "module_path": "olive.passes.onnx.transformer_optimization.OrtTransformersOptimization" + }, + "QNNPreprocess": { + "module_path": "olive.passes.onnx.qnn_preprocess.QNNPreprocess" + }, + "VitisAIQuantization": { + "module_path": "olive.passes.onnx.vitis_ai_quantization.VitisAIQuantization" + }, + "VitisQDQQuantizer": { + "module_path": "olive.passes.onnx.vitis_ai.quantizer.VitisQDQQuantizer" + }, + "VitisQOpQuantizer": { + "module_path": "olive.passes.onnx.vitis_ai.quantizer.VitisQOpQuantizer" + }, + "quantize_static": { + "module_path": "olive.passes.onnx.vitis_ai.quantize.quantize_static" + }, + "PowerOfTwoMethod": { + "module_path": "olive.passes.onnx.vitis_ai.quant_utils.PowerOfTwoMethod" + }, + 
"OpenVINOConversion": { + "module_path": "olive.passes.openvino.conversion.OpenVINOConversion", + "extra_dependencies": [ + "openvino" + ] + }, + "OpenVINOQuantization": { + "module_path": "olive.passes.openvino.quantization.OpenVINOQuantization", + "extra_dependencies": [ + "openvino" + ] + }, + "GptqQuantizer": { + "module_path": "olive.passes.pytorch.gptq.GptqQuantizer", + "module_dependencies": [ + "auto-gptq", + "optimum" + ] + }, + "LoftQ": { + "module_path": "olive.passes.pytorch.lora.LoftQ" + }, + "LoRA": { + "module_path": "olive.passes.pytorch.lora.LoRA", + "extra_dependencies": [ + "lora" + ] + }, + "PyTorchTensorParallel": { + "module_path": "olive.passes.pytorch.tensor_parallel.PyTorchTensorParallel" + }, + "QLoRA": { + "module_path": "olive.passes.pytorch.lora.QLoRA", + "extra_dependencies": [ + "bnb", + "lora" + ] + }, + "QuantizationAwareTraining": { + "module_path": "olive.passes.pytorch.quantization_aware_training.QuantizationAwareTraining", + "module_dependencies": [ + "pytorch-lightning" + ] + }, + "SparseGPT": { + "module_path": "olive.passes.pytorch.sparsegpt.SparseGPT" + }, + "TorchTRTConversion": { + "module_path": "olive.passes.pytorch.torch_trt_conversion.TorchTRTConversion", + "extra_dependencies": [ + "torch-tensorrt" + ] + }, + "QNNConversion": { + "module_path": "olive.passes.qnn.conversion.QNNConversion" + }, + "QNNModelLibGenerator": { + "module_path": "olive.passes.qnn.model_lib_generator.QNNModelLibGenerator" + }, + "QNNContextBinaryGenerator": { + "module_path": "olive.passes.qnn.context_binary_generator.QNNContextBinaryGenerator" + }, + "SNPEConversion": { + "module_path": "olive.passes.snpe.conversion.SNPEConversion" + }, + "SNPEQuantization": { + "module_path": "olive.passes.snpe.quantization.SNPEQuantization" + }, + "SNPEtoONNXConversion": { + "module_path": "olive.passes.snpe.snpe_to_onnx.SNPEtoONNXConversion" + } + }, + "extra_dependencies": { + "azureml": [ + "azure-ai-ml>=1.11.1", + "azure-keyvault-secrets", + "azure-identity", + "azureml-fsspec" + ], + "docker": [ + "docker" + ], + "cpu": [ + "onnxruntime" + ], + "gpu": [ + "onnxruntime-gpu" + ], + "directml": [ + "onnxruntime-directml" + ], + "openvino": [ + "openvino==2023.2.0", + "nncf==2.7.0" + ], + "tf": [ + "tensorflow==1.15.0" + ], + "inc": [ + "neural-compressor" + ], + "optimum": [ + "optimum" + ], + "torch-tensorrt": [ + "torch-tensorrt" + ], + "lora": [ + "accelerate", + "peft", + "scipy" + ], + "bnb": [ + "bitsandbytes" + ], + "ort-training": [ + "onnxruntime-training", + "torch-ort" + ], + "ort": [ + "onnxruntime", + "onnxruntime-directml", + "onnxruntime-gpu", + "onnxruntime-openvino" + ] + } +} diff --git a/olive/package_config.py b/olive/package_config.py new file mode 100644 index 000000000..472277f5e --- /dev/null +++ b/olive/package_config.py @@ -0,0 +1,36 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +import importlib +from pathlib import Path +from typing import Dict, List + +from olive.common.config_utils import ConfigBase +from olive.passes import PassModuleConfig + + +class OlivePackageConfig(ConfigBase): + passes: Dict[str, PassModuleConfig] + extra_dependencies: Dict[str, List[str]] + + @staticmethod + def get_default_config_path() -> str: + return str(Path(__file__).parent / "olive_config.json") + + @staticmethod + def load_default_config() -> "OlivePackageConfig": + return OlivePackageConfig.parse_file(OlivePackageConfig.get_default_config_path()) + + def import_pass_module(self, pass_type): + if "." in pass_type: + _, module_name = pass_type.rsplit(".", 1) + return self.import_pass_module(module_name) + + if pass_type in self.passes: + pass_module_config = self.passes.get(pass_type) + module_path, module_name = pass_module_config.module_path.rsplit(".", 1) + module = importlib.import_module(module_path, module_name) + return getattr(module, module_name) + + raise ValueError(f"Package configuration for pass of type '{pass_type}' not found") diff --git a/olive/passes/__init__.py b/olive/passes/__init__.py index dc084cf06..06e5bfc31 100644 --- a/olive/passes/__init__.py +++ b/olive/passes/__init__.py @@ -3,17 +3,13 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from olive.passes.olive_pass import FullPassConfig, Pass -from olive.passes.onnx import * # noqa: F403 -from olive.passes.openvino import * # noqa: F403 -from olive.passes.pass_config import PassParamDefault -from olive.passes.pytorch import * # noqa: F403 -from olive.passes.qnn import * # noqa: F403 -from olive.passes.snpe import * # noqa: F403 +from olive.passes.pass_config import PassModuleConfig, PassParamDefault REGISTRY = Pass.registry __all__ = [ "Pass", "PassParamDefault", + "PassModuleConfig", "FullPassConfig", ] diff --git a/olive/passes/olive_pass.py b/olive/passes/olive_pass.py index 25c3ed77c..7e3ac41e2 100644 --- a/olive/passes/olive_pass.py +++ b/olive/passes/olive_pass.py @@ -9,7 +9,6 @@ from typing import Any, Callable, ClassVar, Dict, Optional, Tuple, Type, Union, get_args from olive.common.config_utils import ConfigBase, ParamCategory, validate_config -from olive.common.pydantic_v1 import validator from olive.common.user_module_loader import UserModuleLoader from olive.data.config import DataConfig from olive.hardware import DEFAULT_CPU_ACCELERATOR, AcceleratorSpec @@ -407,12 +406,6 @@ class FullPassConfig(ConfigBase): host_device: Optional[str] = None config: Dict[str, Any] = None - @validator("type") - def validate_type(cls, v): - if v.lower() not in Pass.registry: - raise ValueError(f"Unknown pass type {v}") - return v - def create_pass(self): if not isinstance(self.accelerator, dict): raise ValueError(f"accelerator must be a dict, got {self.accelerator}") diff --git a/olive/passes/onnx/__init__.py b/olive/passes/onnx/__init__.py index 4bef604db..862c45ce3 100644 --- a/olive/passes/onnx/__init__.py +++ b/olive/passes/onnx/__init__.py @@ -2,53 +2,3 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
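The heart of the change is the delayed-load pattern: a dotted `module_path` is resolved to a class only when a configured pass actually needs it. A standalone sketch of the same pattern (not Olive's exact code):

```python
# Resolve "package.module.ClassName" lazily: nothing is imported until the
# first call, which is what keeps Olive's startup cheap.
import importlib


def import_class_lazily(dotted_path: str):
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)  # deferred to this call
    return getattr(module, class_name)


# e.g. import_class_lazily("olive.passes.onnx.conversion.OnnxConversion");
# demonstrated here with a stdlib class so the snippet runs anywhere:
OrderedDict = import_class_lazily("collections.OrderedDict")
assert OrderedDict.__name__ == "OrderedDict"
```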
# -------------------------------------------------------------------------- -from olive.passes.onnx.append_pre_post_processing_ops import AppendPrePostProcessingOps -from olive.passes.onnx.bnb_quantization import OnnxBnb4Quantization -from olive.passes.onnx.conversion import OnnxConversion, OnnxOpVersionConversion -from olive.passes.onnx.dynamic_to_fixed_shape import DynamicToFixedShape -from olive.passes.onnx.float16_conversion import OnnxFloatToFloat16 -from olive.passes.onnx.genai_model_exporter import GenAIModelExporter -from olive.passes.onnx.inc_quantization import IncDynamicQuantization, IncQuantization, IncStaticQuantization -from olive.passes.onnx.insert_beam_search import InsertBeamSearch -from olive.passes.onnx.mixed_precision import OrtMixedPrecision -from olive.passes.onnx.model_optimizer import OnnxModelOptimizer -from olive.passes.onnx.moe_experts_distributor import MoEExpertsDistributor -from olive.passes.onnx.optimum_conversion import OptimumConversion -from olive.passes.onnx.optimum_merging import OptimumMerging -from olive.passes.onnx.perf_tuning import OrtPerfTuning -from olive.passes.onnx.qnn_preprocess import QNNPreprocess -from olive.passes.onnx.quantization import ( - OnnxDynamicQuantization, - OnnxMatMul4Quantizer, - OnnxQuantization, - OnnxStaticQuantization, -) -from olive.passes.onnx.transformer_optimization import OrtTransformersOptimization -from olive.passes.onnx.vitis_ai_quantization import VitisAIQuantization - -__all__ = [ - "AppendPrePostProcessingOps", - "DynamicToFixedShape", - "GenAIModelExporter", - "IncDynamicQuantization", - "IncQuantization", - "IncStaticQuantization", - "InsertBeamSearch", - "MoEExpertsDistributor", - "OnnxBnb4Quantization", - "OnnxConversion", - "OnnxDynamicQuantization", - "OnnxFloatToFloat16", - "OnnxMatMul4Quantizer", - "OnnxModelOptimizer", - "OnnxOpVersionConversion", - "OnnxQuantization", - "OnnxStaticQuantization", - "OptimumConversion", - "OptimumMerging", - "OrtMixedPrecision", - "OrtPerfTuning", - "OrtTransformersOptimization", - "QNNPreprocess", - "VitisAIQuantization", -] diff --git a/olive/passes/onnx/vitis_ai/__init__.py b/olive/passes/onnx/vitis_ai/__init__.py index 819539a7b..d70cb0180 100644 --- a/olive/passes/onnx/vitis_ai/__init__.py +++ b/olive/passes/onnx/vitis_ai/__init__.py @@ -7,12 +7,9 @@ from olive.passes.onnx.vitis_ai.quant_utils import PowerOfTwoMethod from olive.passes.onnx.vitis_ai.quantize import quantize_static -from olive.passes.onnx.vitis_ai.quantizer import VitisQDQQuantizer, VitisQOpQuantizer __all__ = [ "CalibrationDataReader", - "VitisQDQQuantizer", - "VitisQOpQuantizer", "quantize_static", "PowerOfTwoMethod", "QuantFormat", diff --git a/olive/passes/openvino/__init__.py b/olive/passes/openvino/__init__.py index 5d7f9b8aa..862c45ce3 100644 --- a/olive/passes/openvino/__init__.py +++ b/olive/passes/openvino/__init__.py @@ -2,10 +2,3 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. 
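With the wildcard re-exports removed, call sites now import each pass from its defining module, as the test updates later in this patch series do; for example:

```python
# New import style (hedged example; mirrors the test diffs below).
# before: from olive.passes.onnx import OnnxConversion, OrtPerfTuning
from olive.passes.onnx.conversion import OnnxConversion
from olive.passes.onnx.perf_tuning import OrtPerfTuning
```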
 # --------------------------------------------------------------------------
-from olive.passes.openvino.conversion import OpenVINOConversion
-from olive.passes.openvino.quantization import OpenVINOQuantization
-
-__all__ = [
-    "OpenVINOConversion",
-    "OpenVINOQuantization",
-]
diff --git a/olive/passes/pass_config.py b/olive/passes/pass_config.py
index 22fb338f8..944d2e4e9 100644
--- a/olive/passes/pass_config.py
+++ b/olive/passes/pass_config.py
@@ -4,7 +4,7 @@
 from enum import Enum
 from pathlib import Path
-from typing import Callable, Dict, Optional, Type, Union
+from typing import Callable, Dict, List, Optional, Type, Union
 
 from olive.common.config_utils import ConfigBase, ConfigParam, ParamCategory, validate_object, validate_resource_path
 from olive.common.pydantic_v1 import create_model, validator
@@ -131,3 +131,15 @@ def create_config_class(
         config[param] = (type_, param_config.default_value)
 
     return create_model(f"{pass_type}Config", **config, __base__=PassConfigBase, __validators__=validators)
+
+
+class PassModuleConfig(ConfigBase):
+    module_path: str
+    module_dependencies: List[str] = []
+    extra_dependencies: List[str] = []
+
+    @validator("module_path", pre=True)
+    def validate_module_path(cls, v, values):
+        if not v:
+            raise ValueError("module_path cannot be empty or None")
+        return v
diff --git a/olive/passes/pytorch/__init__.py b/olive/passes/pytorch/__init__.py
index a1d1f5259..862c45ce3 100644
--- a/olive/passes/pytorch/__init__.py
+++ b/olive/passes/pytorch/__init__.py
@@ -2,20 +2,3 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
-from olive.passes.pytorch.gptq import GptqQuantizer
-from olive.passes.pytorch.lora import LoftQ, LoRA, QLoRA
-from olive.passes.pytorch.quantization_aware_training import QuantizationAwareTraining
-from olive.passes.pytorch.sparsegpt import SparseGPT
-from olive.passes.pytorch.tensor_parallel import PyTorchTensorParallel
-from olive.passes.pytorch.torch_trt_conversion import TorchTRTConversion
-
-__all__ = [
-    "GptqQuantizer",
-    "LoftQ",
-    "LoRA",
-    "PyTorchTensorParallel",
-    "QLoRA",
-    "QuantizationAwareTraining",
-    "SparseGPT",
-    "TorchTRTConversion",
-]
diff --git a/olive/passes/qnn/__init__.py b/olive/passes/qnn/__init__.py
index 20c6e1def..862c45ce3 100644
--- a/olive/passes/qnn/__init__.py
+++ b/olive/passes/qnn/__init__.py
@@ -2,9 +2,3 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
 # --------------------------------------------------------------------------
-
-from olive.passes.qnn.context_binary_generator import QNNContextBinaryGenerator
-from olive.passes.qnn.conversion import QNNConversion
-from olive.passes.qnn.model_lib_generator import QNNModelLibGenerator
-
-__all__ = ["QNNConversion", "QNNModelLibGenerator", "QNNContextBinaryGenerator"]
diff --git a/olive/passes/snpe/__init__.py b/olive/passes/snpe/__init__.py
index 14b17daff..862c45ce3 100644
--- a/olive/passes/snpe/__init__.py
+++ b/olive/passes/snpe/__init__.py
@@ -2,8 +2,3 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
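To make the schema concrete, one entry from olive_config.json parses into a `PassModuleConfig`-shaped model like so (a standalone sketch using the pydantic v1-style API rather than Olive's `ConfigBase`):

```python
from typing import List

from pydantic import BaseModel, validator  # pydantic v1-style API


class PassModuleConfigSketch(BaseModel):
    module_path: str
    module_dependencies: List[str] = []
    extra_dependencies: List[str] = []

    @validator("module_path", pre=True)
    def validate_module_path(cls, v):
        if not v:
            raise ValueError("module_path cannot be empty or None")
        return v


entry = PassModuleConfigSketch.parse_obj(
    {
        "module_path": "olive.passes.pytorch.gptq.GptqQuantizer",
        "module_dependencies": ["auto-gptq", "optimum"],
    }
)
assert entry.module_dependencies == ["auto-gptq", "optimum"]
assert entry.extra_dependencies == []  # defaults apply per instance
```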
# -------------------------------------------------------------------------- -from olive.passes.snpe.conversion import SNPEConversion -from olive.passes.snpe.quantization import SNPEQuantization -from olive.passes.snpe.snpe_to_onnx import SNPEtoONNXConversion - -__all__ = ["SNPEConversion", "SNPEQuantization", "SNPEtoONNXConversion"] diff --git a/olive/systems/azureml/aml_pass_runner.py b/olive/systems/azureml/aml_pass_runner.py index 7ef7103a4..f44e9d655 100644 --- a/olive/systems/azureml/aml_pass_runner.py +++ b/olive/systems/azureml/aml_pass_runner.py @@ -17,6 +17,7 @@ from olive.hardware import AcceleratorSpec from olive.logging import set_verbosity_from_env from olive.model import ModelConfig +from olive.package_config import OlivePackageConfig from olive.passes import REGISTRY as PASS_REGISTRY from olive.passes import FullPassConfig, Pass from olive.resource_path import create_resource_path @@ -108,6 +109,10 @@ def main(raw_args=None): pass_config = json.load(f) pass_type = pass_config["type"].lower() + # Import the pass package configuration from the package_config + package_config = OlivePackageConfig.load_default_config() + package_config.import_pass_module(pass_config["type"]) + if version.parse(ort_version) < version.parse("1.16.0"): # In onnxruntime, the following PRs will make the optimize_model save external data in the temporary folder # * https://github.com/microsoft/onnxruntime/pull/16531 diff --git a/olive/systems/docker/runner.py b/olive/systems/docker/runner.py index df6088cdb..75590a24b 100644 --- a/olive/systems/docker/runner.py +++ b/olive/systems/docker/runner.py @@ -12,6 +12,7 @@ from olive.common.utils import huggingface_login from olive.logging import set_verbosity_from_env from olive.model import ModelConfig +from olive.package_config import OlivePackageConfig from olive.passes.olive_pass import FullPassConfig logger = logging.getLogger("olive") @@ -29,6 +30,11 @@ def runner_entry(config, output_path, output_name): model = ModelConfig.from_json(model_json).create_model() pass_config = config_json["pass"] + + # Import the pass package configuration from the package_config + package_config = OlivePackageConfig.load_default_config() + package_config.import_pass_module(pass_config["type"]) + the_pass = FullPassConfig.from_json(pass_config).create_pass() output_model = the_pass.run(model, None, output_path) # save model json diff --git a/olive/systems/python_environment/pass_runner.py b/olive/systems/python_environment/pass_runner.py index ef3ebd612..7bbf0ba06 100644 --- a/olive/systems/python_environment/pass_runner.py +++ b/olive/systems/python_environment/pass_runner.py @@ -9,6 +9,7 @@ from olive.common.utils import set_tempdir from olive.logging import set_verbosity_from_env from olive.model import ModelConfig +from olive.package_config import OlivePackageConfig from olive.passes.olive_pass import FullPassConfig @@ -33,7 +34,13 @@ def main(raw_args=None): set_tempdir(args.tempdir) model = ModelConfig.parse_file(args.model_config).create_model() - the_pass = FullPassConfig.parse_file(args.pass_config).create_pass() + pass_config = FullPassConfig.parse_file(args.pass_config) + + # Import the pass package configuration from the package_config + package_config = OlivePackageConfig.load_default_config() + package_config.import_pass_module(pass_config.type) + + the_pass = pass_config.create_pass() # run pass output_model = the_pass.run(model, args.data_root, args.output_model_path) diff --git a/olive/workflows/run/__main__.py b/olive/workflows/run/__main__.py index 
19c178760..5f2aa463e 100644 --- a/olive/workflows/run/__main__.py +++ b/olive/workflows/run/__main__.py @@ -9,9 +9,19 @@ if __name__ == "__main__": parser = argparse.ArgumentParser("Olive Workflow: Custom Run") - parser.add_argument("--config", type=str, help="Path to json config file", required=True) + parser.add_argument( + "--package-config", + type=str, + required=False, + help=( + "For advanced users. Path to optional package (json) config file with location " + "of individual pass module implementation and corresponding dependencies." + "Configuration might also include user owned/proprietary/private pass implementations." + ), + ) + parser.add_argument("--run-config", "--config", type=str, help="Path to json config file", required=True) parser.add_argument("--setup", help="Whether run environment setup", action="store_true") - parser.add_argument("--data_root", help="The data root path for optimization", required=False) + parser.add_argument("--data-root", "--data_root", help="The data root path for optimization", required=False) parser.add_argument("--tempdir", type=str, help="Root directory for tempfile directories and files", required=False) args = parser.parse_args() diff --git a/olive/workflows/run/config.py b/olive/workflows/run/config.py index 72699e3b2..88e9650a9 100644 --- a/olive/workflows/run/config.py +++ b/olive/workflows/run/config.py @@ -16,7 +16,7 @@ from olive.engine.packaging.packaging_config import PackagingConfig from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.model import ModelConfig -from olive.passes import FullPassConfig, Pass +from olive.passes import FullPassConfig from olive.passes.pass_config import PassParamDefault from olive.resource_path import AZUREML_RESOURCE_TYPES from olive.systems.system_config import SystemConfig @@ -220,32 +220,30 @@ def validate_pass_search(cls, v, values): disable_search = disable_search or False v["disable_search"] = disable_search - pass_cls = Pass.registry.get(v["type"].lower(), None) - if pass_cls: - if not v.get("config"): - return v - - searchable_configs = set() - for param_name in v["config"]: - if v["config"][param_name] == PassParamDefault.SEARCHABLE_VALUES: - searchable_configs.add(param_name) - if param_name.endswith("data_config"): - # we won't auto insert the input model data config for pass - # user must explicitly set the data config to INPUT_MODEL_DATA_CONFIG if needed - v["config"] = _resolve_data_config(v["config"], values, param_name, auto_insert=False) - - data_dir_config = v["config"].get("data_dir", None) - if isinstance(data_dir_config, dict): - if _have_aml_client(data_dir_config, values): - data_dir_config["config"]["azureml_client"] = values["azureml_client"] - v["config"]["data_dir"] = data_dir_config - - if disable_search and searchable_configs: - raise ValueError( - f"You cannot disable search for {v['type']} and" - f" set {searchable_configs} to SEARCHABLE_VALUES at the same time." - " Please remove SEARCHABLE_VALUES or enable search(needs search strategy configs)." 
- ) + if not v.get("config"): + return v + + searchable_configs = set() + for param_name in v["config"]: + if v["config"][param_name] == PassParamDefault.SEARCHABLE_VALUES: + searchable_configs.add(param_name) + if param_name.endswith("data_config"): + # we won't auto insert the input model data config for pass + # user must explicitly set the data config to INPUT_MODEL_DATA_CONFIG if needed + v["config"] = _resolve_data_config(v["config"], values, param_name, auto_insert=False) + + data_dir_config = v["config"].get("data_dir", None) + if isinstance(data_dir_config, dict): + if _have_aml_client(data_dir_config, values): + data_dir_config["config"]["azureml_client"] = values["azureml_client"] + v["config"]["data_dir"] = data_dir_config + + if disable_search and searchable_configs: + raise ValueError( + f"You cannot disable search for {v['type']} and" + f" set {searchable_configs} to SEARCHABLE_VALUES at the same time." + " Please remove SEARCHABLE_VALUES or enable search(needs search strategy configs)." + ) return v @validator("auto_optimizer_config", always=True) diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py index c3d6fc8ca..af1a30075 100644 --- a/olive/workflows/run/run.py +++ b/olive/workflows/run/run.py @@ -3,27 +3,25 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import importlib.metadata -import json import logging -import os import subprocess import sys +from copy import deepcopy from pathlib import Path from typing import List, Union from olive.auto_optimizer import AutoOptimizer from olive.hardware.accelerator import create_accelerators from olive.logging import enable_filelog, set_default_logger_severity, set_ort_logger_severity, set_verbosity_info +from olive.package_config import OlivePackageConfig from olive.systems.common import SystemType from olive.workflows.run.config import RunConfig, RunPassConfig logger = logging.getLogger(__name__) -def dependency_setup(config: RunConfig): - here = os.path.abspath(os.path.dirname(__file__)) - with open(os.path.join(here, "../../extra_dependencies.json")) as f: - extras = json.load(f) +def dependency_setup(package_config: OlivePackageConfig, run_config: RunConfig): + extras = deepcopy(package_config.extra_dependencies) def get_system_extras(host_type, accelerators, execution_providers): extra_name = None @@ -45,41 +43,23 @@ def get_system_extras(host_type, accelerators, execution_providers): return extra_name def get_pass_extras(pass_type): - pass_to_extra = { - "OnnxFloatToFloat16": ["onnxconverter-common"], - "OrtPerfTuning": ["psutil"], - "QuantizationAwareTraining": ["pytorch-lightning"], - "GptqQuantizer": ["auto-gptq", "optimum"], - } - - pass_to_extra_names = { - "OpenVINOConversion": ["openvino"], - "OpenVINOQuantization": ["openvino"], - "IncQuantization": ["inc"], - "IncDynamicQuantization": ["inc"], - "IncStaticQuantization": ["inc"], - "OptimumConversion": ["optimum"], - "OptimumMerging": ["optimum"], - "TorchTRTConversion": ["torch-tensorrt"], - "LoRA": ["lora"], - "QLoRA": ["bnb", "lora"], - } + pass_module_config = package_config.passes.get(pass_type) extra_results = [] - extra_results.extend(pass_to_extra.get(pass_type, [])) - for extra_name in pass_to_extra_names.get(pass_type, []): - extra_results.extend(extras.get(extra_name)) + extra_results.extend(pass_module_config.module_dependencies) + for extra_name in pass_module_config.extra_dependencies: + extra_results.extend(package_config.extra_dependencies.get(extra_name, [])) return 
extra_results - ort_packages = ["onnxruntime", "onnxruntime-directml", "onnxruntime-gpu", "onnxruntime-openvino"] + ort_packages = extras.get("ort", []) local_packages = [] remote_packages = [] # add dependencies for passes - if config.passes: - for pass_config in config.passes.values(): - host = pass_config.host or config.engine.host + if run_config.passes: + for pass_config in run_config.passes.values(): + host = pass_config.host or run_config.engine.host if (host and host.type == SystemType.Local) or not host: local_packages.extend(get_pass_extras(pass_config.type)) else: @@ -95,10 +75,10 @@ def get_pass_extras(pass_type): host_type = None accelerators = [] execution_providers = [] - if config.engine.host: - host_type = config.engine.host.type - if config.engine.host.config.accelerators: - for acc in config.engine.host.config.accelerators: + if run_config.engine.host: + host_type = run_config.engine.host.type + if run_config.engine.host.config.accelerators: + for acc in run_config.engine.host.config.accelerators: accelerators.append(acc.device) if acc.execution_providers: execution_providers.extend(acc.execution_providers) @@ -134,55 +114,55 @@ def get_pass_extras(pass_type): if remote_packages: logger.info( "Please make sure the following packages are installed in %s environment: %s", - config.engine.host.type, + run_config.engine.host.type, remote_packages, ) -def run_engine(config: RunConfig, data_root: str = None): +def run_engine(package_config: OlivePackageConfig, run_config: RunConfig, data_root: str = None): import onnxruntime as ort from olive.passes import Pass # for onnxruntime # ort_py_log_severity_level: python logging levels - set_ort_logger_severity(config.engine.ort_py_log_severity_level) + set_ort_logger_severity(run_config.engine.ort_py_log_severity_level) # ort_log_severity_level: C++ logging levels - ort.set_default_logger_severity(config.engine.ort_log_severity_level) + ort.set_default_logger_severity(run_config.engine.ort_log_severity_level) # input model - input_model = config.input_model + input_model = run_config.input_model # Azure ML Client - if config.azureml_client: - config.engine.azureml_client_config = config.azureml_client + if run_config.azureml_client: + run_config.engine.azureml_client_config = run_config.azureml_client - engine = config.engine.create_engine() + engine = run_config.engine.create_engine() - # config file will be uploaded to AML job - is_azureml_system = (config.engine.host is not None and config.engine.host.type == SystemType.AzureML) or ( - config.engine.target is not None and config.engine.target.type == SystemType.AzureML + # run_config file will be uploaded to AML job + is_azureml_system = (run_config.engine.host is not None and run_config.engine.host.type == SystemType.AzureML) or ( + run_config.engine.target is not None and run_config.engine.target.type == SystemType.AzureML ) if is_azureml_system: from olive.systems.azureml.aml_system import AzureMLSystem - AzureMLSystem.olive_config = config.to_json() + AzureMLSystem.olive_config = run_config.to_json() no_evaluation = ( engine.evaluator_config is None - and config.passes - and all(pass_config.evaluator is None for pass_config in config.passes.values()) + and run_config.passes + and all(pass_config.evaluator is None for pass_config in run_config.passes.values()) ) accelerator_specs = create_accelerators(engine.target_config, skip_supported_eps_check=no_evaluation) pass_list = [] acc_list = [] if ( - not config.passes - and config.auto_optimizer_config is not None - and not 
config.auto_optimizer_config.disable_auto_optimizer + not run_config.passes + and run_config.auto_optimizer_config is not None + and not run_config.auto_optimizer_config.disable_auto_optimizer ): # For auto optimizer, Olive generates passes and pass_flows for each accelerator # that means, the passes and pass_flows might be different for each accelerator @@ -191,8 +171,8 @@ def run_engine(config: RunConfig, data_root: str = None): input_model, engine.evaluator_config, acc_spec, - config.auto_optimizer_config, - config.data_configs, + run_config.auto_optimizer_config, + run_config.data_configs, ).suggest() pass_list.append(({k: RunPassConfig.parse_obj(v) for k, v in _passes.items()}, pass_flows)) acc_list.append([acc_spec]) @@ -200,7 +180,7 @@ def run_engine(config: RunConfig, data_root: str = None): # For non-auto-optimizer case, Olive uses the same passes and pass_flows for all accelerators # if user needs different passes and pass_flows for each accelerator, they need to write multiple # config files. - pass_list.append((config.passes, config.pass_flows)) + pass_list.append((run_config.passes, run_config.pass_flows)) acc_list.append(accelerator_specs) run_rls = {} @@ -214,6 +194,12 @@ def run_engine(config: RunConfig, data_root: str = None): for accelerator_spec, (passes, pass_flows) in zip(acc_list, pass_list): engine.reset_passes() if passes: + # First pass registers the necessary module implementation + for pass_config in passes.values(): + logger.info("Importing pass module %s", pass_config.type) + package_config.import_pass_module(pass_config.type) + + # Second pass, initializes the pass and registers it with the engine for pass_name, pass_config in passes.items(): host = pass_config.host.create_system() if pass_config.host is not None else None engine.register( @@ -229,7 +215,7 @@ def run_engine(config: RunConfig, data_root: str = None): engine.set_pass_flows(pass_flows) if data_root is None: - data_root = config.data_root + data_root = run_config.data_root # run run_rls.update( @@ -237,34 +223,49 @@ def run_engine(config: RunConfig, data_root: str = None): input_model, accelerator_spec, data_root, - config.engine.packaging_config, - config.engine.output_dir, - config.engine.output_name, - config.engine.evaluate_input_model, + run_config.engine.packaging_config, + run_config.engine.output_dir, + run_config.engine.output_name, + run_config.engine.evaluate_input_model, ) ) return run_rls -def run(config: Union[str, Path, dict], setup: bool = False, data_root: str = None): +def run( + run_config: Union[str, Path, dict], + setup: bool = False, + data_root: str = None, + package_config: Union[str, Path, dict] = None, +): + if package_config is None: + package_config = OlivePackageConfig.get_default_config_path() + # we use parse_file and parse_obj to be safe. If implemented as expected, both should be equivalent. 
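Programmatic callers get the same hook as the new CLI flag; a hedged usage sketch of the extended `run()` signature (file paths are placeholders):

```python
# Equivalent CLI:
#   python -m olive.workflows.run --run-config run.json --package-config olive_config.json
from olive.workflows.run.run import run

run_results = run(
    run_config="run.json",               # workflow configuration (placeholder path)
    package_config="olive_config.json",  # optional; defaults to Olive's built-in config
)
```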
- if isinstance(config, (str, Path)): - config = RunConfig.parse_file(config) + logger.info("Loading Olive module configuration: %s", package_config) + if isinstance(package_config, (str, Path)): + package_config = OlivePackageConfig.parse_file(package_config) + else: + package_config = OlivePackageConfig.parse_obj(package_config) + + logger.info("Loading run configuration: %s", run_config) + if isinstance(run_config, (str, Path)): + run_config = RunConfig.parse_file(run_config) else: - config = RunConfig.parse_obj(config) + run_config = RunConfig.parse_obj(run_config) # set log level for olive - set_default_logger_severity(config.engine.log_severity_level) - if config.engine.log_to_file: - enable_filelog(config.engine.log_severity_level) + set_default_logger_severity(run_config.engine.log_severity_level) + if run_config.engine.log_to_file: + enable_filelog(run_config.engine.log_severity_level) if setup: # set the log level to INFO for setup set_verbosity_info() - dependency_setup(config) + dependency_setup(package_config, run_config) return None else: - return run_engine(config, data_root) + return run_engine(package_config, run_config, data_root) def check_local_ort_installation(package_name: str): diff --git a/setup.py b/setup.py index 2dcd2df3f..75f6d92a2 100644 --- a/setup.py +++ b/setup.py @@ -25,13 +25,13 @@ def get_version(rel_path): def get_extra_deps(rel_path): here = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(here, rel_path)) as fp: - return json.load(fp) + return json.load(fp)["extra_dependencies"] # use techniques described at https://packaging.python.org/en/latest/guides/single-sourcing-package-version/ # Don't use technique 6 since it needs extra dependencies. VERSION = get_version("olive/__init__.py") -EXTRAS = get_extra_deps("olive/extra_dependencies.json") +EXTRAS = get_extra_deps("olive/olive_config.json") with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")) as req_file: requirements = req_file.read().splitlines() @@ -81,7 +81,7 @@ def get_extra_deps(rel_path): extras_require=EXTRAS, include_package_data=False, package_data={ - "olive": ["extra_dependencies.json"], + "olive": ["olive_config.json"], "olive.auto_optimizer": ["config_template/*.yaml"], "olive.engine.packaging": ["sample_code/*/*/*"], "olive.platform_sdk.qualcomm": ["create_python_env.sh", "create_python_env.ps1", "copy_libcdsprpc.ps1"], diff --git a/test/integ_test/aml_model_test/test_aml_model.py b/test/integ_test/aml_model_test/test_aml_model.py index 31b68916f..a70490d4c 100644 --- a/test/integ_test/aml_model_test/test_aml_model.py +++ b/test/integ_test/aml_model_test/test_aml_model.py @@ -7,8 +7,8 @@ from olive.azureml.azureml_client import AzureMLClientConfig from olive.model import ModelConfig -from olive.passes import OnnxConversion from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.onnx.conversion import OnnxConversion from olive.resource_path import ResourcePath from olive.systems.azureml import AzureMLDockerConfig, AzureMLSystem diff --git a/test/multiple_ep/utils.py b/test/multiple_ep/utils.py index b5a44a58c..e5b97423b 100644 --- a/test/multiple_ep/utils.py +++ b/test/multiple_ep/utils.py @@ -12,7 +12,7 @@ from olive.evaluator.metric import LatencySubType, Metric, MetricType from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware.accelerator import create_accelerators -from olive.passes.onnx import OrtPerfTuning +from olive.passes.onnx.perf_tuning import OrtPerfTuning # pylint: 
disable=redefined-outer-name diff --git a/test/unit_test/engine/test_engine.py b/test/unit_test/engine/test_engine.py index 2707c96f0..f2f0e642c 100644 --- a/test/unit_test/engine/test_engine.py +++ b/test/unit_test/engine/test_engine.py @@ -22,7 +22,9 @@ from olive.evaluator.olive_evaluator import OliveEvaluatorConfig from olive.hardware import DEFAULT_CPU_ACCELERATOR from olive.hardware.accelerator import create_accelerators -from olive.passes.onnx import OnnxConversion, OnnxDynamicQuantization, OnnxStaticQuantization, OptimumConversion +from olive.passes.onnx.conversion import OnnxConversion +from olive.passes.onnx.optimum_conversion import OptimumConversion +from olive.passes.onnx.quantization import OnnxDynamicQuantization, OnnxStaticQuantization from olive.systems.common import SystemType from olive.systems.local import LocalSystem from olive.systems.system_config import LocalTargetUserConfig, SystemConfig diff --git a/test/unit_test/passes/common/test_user_script.py b/test/unit_test/passes/common/test_user_script.py index ad22c8935..73c0c7ff1 100644 --- a/test/unit_test/passes/common/test_user_script.py +++ b/test/unit_test/passes/common/test_user_script.py @@ -8,7 +8,7 @@ from olive.common.pydantic_v1 import ValidationError from olive.hardware import DEFAULT_CPU_ACCELERATOR -from olive.passes.onnx import OrtPerfTuning +from olive.passes.onnx.perf_tuning import OrtPerfTuning class TestUserScriptConfig: diff --git a/test/unit_test/passes/onnx/test_bnb_quantization.py b/test/unit_test/passes/onnx/test_bnb_quantization.py index 6a5db73ef..6eb3cc139 100644 --- a/test/unit_test/passes/onnx/test_bnb_quantization.py +++ b/test/unit_test/passes/onnx/test_bnb_quantization.py @@ -12,7 +12,7 @@ from olive.model import ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx import OnnxBnb4Quantization +from olive.passes.onnx.bnb_quantization import OnnxBnb4Quantization # pylint: disable=protected-access diff --git a/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py b/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py index a871b2d3a..d780ec4d4 100644 --- a/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py +++ b/test/unit_test/passes/onnx/test_dynamic_to_fixed_shape.py @@ -4,7 +4,7 @@ from olive.model import ONNXModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx import DynamicToFixedShape +from olive.passes.onnx.dynamic_to_fixed_shape import DynamicToFixedShape @pytest.mark.parametrize( diff --git a/test/unit_test/passes/onnx/test_model_optimizer.py b/test/unit_test/passes/onnx/test_model_optimizer.py index ab578c443..730cc2797 100644 --- a/test/unit_test/passes/onnx/test_model_optimizer.py +++ b/test/unit_test/passes/onnx/test_model_optimizer.py @@ -11,8 +11,8 @@ from olive.hardware import DEFAULT_CPU_ACCELERATOR, DEFAULT_GPU_CUDA_ACCELERATOR from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx import OnnxModelOptimizer from olive.passes.onnx.common import model_proto_to_olive_model +from olive.passes.onnx.model_optimizer import OnnxModelOptimizer if TYPE_CHECKING: from olive.model import ONNXModelHandler diff --git a/test/unit_test/passes/onnx/test_perf_tuning.py b/test/unit_test/passes/onnx/test_perf_tuning.py index 6933affe7..151d37157 100644 --- a/test/unit_test/passes/onnx/test_perf_tuning.py +++ b/test/unit_test/passes/onnx/test_perf_tuning.py @@ -15,8 +15,7 @@ from olive.evaluator.olive_evaluator import OliveEvaluator, OnnxEvaluator from 
olive.hardware.accelerator import DEFAULT_CPU_ACCELERATOR, DEFAULT_GPU_CUDA_ACCELERATOR, AcceleratorSpec, Device from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx import OrtPerfTuning -from olive.passes.onnx.perf_tuning import PERFTUNING_BASELINE, PerfTuningRunner, generate_test_name +from olive.passes.onnx.perf_tuning import PERFTUNING_BASELINE, OrtPerfTuning, PerfTuningRunner, generate_test_name @pytest.mark.parametrize( @@ -38,7 +37,7 @@ def test_ort_perf_tuning_pass(config, tmp_path): p.run(input_model, None, output_folder) -@patch("olive.passes.onnx.OrtPerfTuning._run_for_config") +@patch("olive.passes.onnx.perf_tuning.OrtPerfTuning._run_for_config") @pytest.mark.parametrize( "config", [ diff --git a/test/unit_test/passes/onnx/test_qnn_preprocess.py b/test/unit_test/passes/onnx/test_qnn_preprocess.py index 68adb8520..b50580ef1 100644 --- a/test/unit_test/passes/onnx/test_qnn_preprocess.py +++ b/test/unit_test/passes/onnx/test_qnn_preprocess.py @@ -7,7 +7,7 @@ from packaging import version from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx import QNNPreprocess +from olive.passes.onnx.qnn_preprocess import QNNPreprocess @pytest.mark.skipif( diff --git a/test/unit_test/passes/onnx/test_quantization.py b/test/unit_test/passes/onnx/test_quantization.py index 3f86a5359..944cee7c1 100644 --- a/test/unit_test/passes/onnx/test_quantization.py +++ b/test/unit_test/passes/onnx/test_quantization.py @@ -8,7 +8,7 @@ from olive.common.pydantic_v1 import ValidationError from olive.hardware.accelerator import AcceleratorSpec from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx import OnnxMatMul4Quantizer, OnnxQuantization, OnnxStaticQuantization +from olive.passes.onnx.quantization import OnnxMatMul4Quantizer, OnnxQuantization, OnnxStaticQuantization class DummyCalibrationDataReader(CalibrationDataReader): diff --git a/test/unit_test/passes/onnx/test_transformer_optimization.py b/test/unit_test/passes/onnx/test_transformer_optimization.py index 8fadcb289..3035e3492 100644 --- a/test/unit_test/passes/onnx/test_transformer_optimization.py +++ b/test/unit_test/passes/onnx/test_transformer_optimization.py @@ -13,8 +13,8 @@ from olive.hardware import DEFAULT_CPU_ACCELERATOR, DEFAULT_GPU_CUDA_ACCELERATOR, DEFAULT_GPU_TRT_ACCELERATOR from olive.hardware.accelerator import AcceleratorSpec, Device -from olive.passes.onnx import OrtTransformersOptimization from olive.passes.onnx.common import get_external_data_config +from olive.passes.onnx.transformer_optimization import OrtTransformersOptimization # pylint: disable=redefined-outer-name, abstract-method, protected-access diff --git a/test/unit_test/passes/pytorch/test_gptq.py b/test/unit_test/passes/pytorch/test_gptq.py index 303214153..7335fe036 100644 --- a/test/unit_test/passes/pytorch/test_gptq.py +++ b/test/unit_test/passes/pytorch/test_gptq.py @@ -10,7 +10,7 @@ from olive.hardware.accelerator import AcceleratorSpec, Device from olive.model.handler.pytorch import PyTorchModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.pytorch import GptqQuantizer +from olive.passes.pytorch.gptq import GptqQuantizer def get_dummy_dataloader_func(): diff --git a/test/unit_test/passes/pytorch/test_lora.py b/test/unit_test/passes/pytorch/test_lora.py index 9f4e8ec80..7ab20791a 100644 --- a/test/unit_test/passes/pytorch/test_lora.py +++ b/test/unit_test/passes/pytorch/test_lora.py @@ -15,7 +15,7 @@ from olive.data.template import 
huggingface_data_config_template from olive.model import PyTorchModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.pytorch import LoftQ, LoRA, QLoRA +from olive.passes.pytorch.lora import LoftQ, LoRA, QLoRA # pylint: disable=redefined-outer-name diff --git a/test/unit_test/passes/pytorch/test_quantization_aware_training.py b/test/unit_test/passes/pytorch/test_quantization_aware_training.py index b85c2e8ba..47a337cea 100644 --- a/test/unit_test/passes/pytorch/test_quantization_aware_training.py +++ b/test/unit_test/passes/pytorch/test_quantization_aware_training.py @@ -5,7 +5,7 @@ from test.unit_test.utils import create_dataloader, get_pytorch_model from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.pytorch import QuantizationAwareTraining +from olive.passes.pytorch.quantization_aware_training import QuantizationAwareTraining def test_quantization_aware_training_pass_default(tmp_path): diff --git a/test/unit_test/passes/pytorch/test_sparsegpt.py b/test/unit_test/passes/pytorch/test_sparsegpt.py index 1bbfe44c2..792e52db8 100644 --- a/test/unit_test/passes/pytorch/test_sparsegpt.py +++ b/test/unit_test/passes/pytorch/test_sparsegpt.py @@ -5,7 +5,7 @@ from olive.data.template import huggingface_data_config_template from olive.model import PyTorchModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.pytorch import SparseGPT +from olive.passes.pytorch.sparsegpt import SparseGPT def test_sparsegpt(tmp_path): diff --git a/test/unit_test/passes/pytorch/test_torch_trt_conversion.py b/test/unit_test/passes/pytorch/test_torch_trt_conversion.py index 374876771..d5df0a1c4 100644 --- a/test/unit_test/passes/pytorch/test_torch_trt_conversion.py +++ b/test/unit_test/passes/pytorch/test_torch_trt_conversion.py @@ -13,7 +13,7 @@ from olive.data.template import huggingface_data_config_template from olive.model import PyTorchModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.pytorch import TorchTRTConversion +from olive.passes.pytorch.torch_trt_conversion import TorchTRTConversion # pylint: disable=abstract-method diff --git a/test/unit_test/test_module_config.py b/test/unit_test/test_module_config.py new file mode 100644 index 000000000..74a6ffea1 --- /dev/null +++ b/test/unit_test/test_module_config.py @@ -0,0 +1,15 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from olive.package_config import OlivePackageConfig + + +class TestModuleConfig: + def test_passes_configuration(self): + package_config = OlivePackageConfig.load_default_config() + for pass_module_name, pass_module_config in package_config.passes.items(): + assert pass_module_config.module_path + assert pass_module_config.module_path.endswith(pass_module_name) + package_config.import_pass_module(pass_module_name) diff --git a/test/unit_test/utils.py b/test/unit_test/utils.py index b97ac631d..85dfc3d46 100644 --- a/test/unit_test/utils.py +++ b/test/unit_test/utils.py @@ -18,7 +18,6 @@ from olive.evaluator.metric_config import MetricGoal from olive.model import ModelConfig, ONNXModelHandler, PyTorchModelHandler from olive.passes.olive_pass import create_pass_from_dict -from olive.passes.onnx import OnnxConversion, OnnxDynamicQuantization ONNX_MODEL_PATH = Path(__file__).absolute().parent / "dummy_model.onnx" @@ -256,6 +255,8 @@ def get_throughput_metric(*lat_subtype, user_config=None): def get_onnxconversion_pass(ignore_pass_config=True, target_opset=13): + from olive.passes.onnx.conversion import OnnxConversion + onnx_conversion_config = {"target_opset": target_opset} p = create_pass_from_dict(OnnxConversion, onnx_conversion_config) if ignore_pass_config: @@ -266,6 +267,8 @@ def get_onnxconversion_pass(ignore_pass_config=True, target_opset=13): def get_onnx_dynamic_quantization_pass(disable_search=False): + from olive.passes.onnx.quantization import OnnxDynamicQuantization + return create_pass_from_dict(OnnxDynamicQuantization, disable_search=disable_search) From c360bfeb9c4ee68856f6d49ab07090279fa9c69c Mon Sep 17 00:00:00 2001 From: shaahji <96227573+shaahji@users.noreply.github.com> Date: Fri, 5 Apr 2024 03:50:13 -0700 Subject: [PATCH 5/6] [Docs]: Delayed Python pass module load (#1051) --- README.md | 4 +- docs/source/api/systems.rst | 2 +- docs/source/getstarted/installation.md | 5 +- docs/source/overview/quicktour.md | 3 ++ olive/extra_dependencies.json | 48 ------------------- ...odule_config.py => test_package_config.py} | 2 +- 6 files changed, 10 insertions(+), 54 deletions(-) delete mode 100644 olive/extra_dependencies.json rename test/unit_test/{test_module_config.py => test_package_config.py} (96%) diff --git a/README.md b/README.md index 0406332f5..e6c6cf0e3 100644 --- a/README.md +++ b/README.md @@ -66,8 +66,8 @@ pip install olive-ai[directml] ``` ### Optional Dependencies -Olive has optional dependencies that can be installed to enable additional features. Please refer to [extra dependencies](./olive/extra_dependencies.json) for -the list of extras and their dependencies. +Olive has optional dependencies that can be installed to enable additional features. Please refer to +[Olive package config](./olive/olive_config.json) for the list of extras and their dependencies. ## Pipeline Status diff --git a/docs/source/api/systems.rst b/docs/source/api/systems.rst index 36c0c333d..41fa1e29f 100644 --- a/docs/source/api/systems.rst +++ b/docs/source/api/systems.rst @@ -82,7 +82,7 @@ PythonEnvironmentSystem IsolatedORTSystem ^^^^^^^^^^^^^^^^^^^^^^^^ -.. autoclass:: olive.systems.ort_evironment.IsolatedORTSystem +.. autoclass:: olive.systems.isolated_ort.IsolatedORTSystem .. 
_olive_system_alias: diff --git a/docs/source/getstarted/installation.md b/docs/source/getstarted/installation.md index 546ef799b..bc1080350 100644 --- a/docs/source/getstarted/installation.md +++ b/docs/source/getstarted/installation.md @@ -62,5 +62,6 @@ pip install -e . ``` ## Optional Dependencies -Olive has optional dependencies that can be installed to enable additional features. Please refer to [extra dependencies](https://github.com/microsoft/Olive/blob/main/olive/extra_dependencies.json) -for the list of extras and their dependencies. +Olive has optional dependencies that can be installed to enable additional features. Please refer to +[Olive package config](https://github.com/microsoft/Olive/blob/main/olive/olive_config.json) for the list of extras +and their dependencies. diff --git a/docs/source/overview/quicktour.md b/docs/source/overview/quicktour.md index 5632889c1..82c587350 100644 --- a/docs/source/overview/quicktour.md +++ b/docs/source/overview/quicktour.md @@ -31,6 +31,9 @@ olive_run("my_model_acceleration_description.json") You can use setup mode `python -m olive.workflows.run --config my_model_acceleration_description.json --setup` to identify list of additional packages you may need to install for your workflow. +To include user implemented (or proprietary, or private) passes as part of workflow, clone olive_config.json and update it. +Provide the path to the cloned _olive_config.json_ file at launch using the '--package-config' command line option. + You can also change the default directory for temporary files and directories using `--tempdir` option. Set this to a local directory if you want to avoid using the default tempdir for reasons such as disk space and permissions. diff --git a/olive/extra_dependencies.json b/olive/extra_dependencies.json deleted file mode 100644 index cf12deae0..000000000 --- a/olive/extra_dependencies.json +++ /dev/null @@ -1,48 +0,0 @@ -{ - "azureml": [ - "azure-ai-ml>=1.11.1", - "azure-keyvault-secrets", - "azure-identity", - "azureml-fsspec" - ], - "docker": [ - "docker" - ], - "cpu": [ - "onnxruntime" - ], - "gpu": [ - "onnxruntime-gpu" - ], - "directml": [ - "onnxruntime-directml" - ], - "openvino": [ - "openvino==2023.2.0", - "nncf==2.7.0" - ], - "tf": [ - "tensorflow==1.15.0" - ], - "inc": [ - "neural-compressor" - ], - "optimum": [ - "optimum" - ], - "torch-tensorrt": [ - "torch-tensorrt" - ], - "lora": [ - "accelerate", - "peft", - "scipy" - ], - "bnb": [ - "bitsandbytes" - ], - "ort-training": [ - "onnxruntime-training", - "torch-ort" - ] -} diff --git a/test/unit_test/test_module_config.py b/test/unit_test/test_package_config.py similarity index 96% rename from test/unit_test/test_module_config.py rename to test/unit_test/test_package_config.py index 74a6ffea1..28aa875cb 100644 --- a/test/unit_test/test_module_config.py +++ b/test/unit_test/test_package_config.py @@ -6,7 +6,7 @@ from olive.package_config import OlivePackageConfig -class TestModuleConfig: +class TestPackageConfig: def test_passes_configuration(self): package_config = OlivePackageConfig.load_default_config() for pass_module_name, pass_module_config in package_config.passes.items(): From bd4fa38606e3befc48b9a804a5ee946bce727bdc Mon Sep 17 00:00:00 2001 From: trajep Date: Sun, 7 Apr 2024 15:30:43 +0800 Subject: [PATCH 6/6] =?UTF-8?q?=E2=98=82=EF=B8=8F=20Make=20`data=5Fconfig`?= =?UTF-8?q?=20optional=20for=20SNPE=20quantization=20(#1054)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Describe your changes 1. 
Make `data_config` optional for SNPE quantization as it can accept `dataloder_func` or `data_config` 2. Update run logs to let it output `package_config, run_config` only if they are str or path. Output the object will let the logs look a bit messy. ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. - [ ] Is this PR including examples changes? If yes, please remember to update [example documentation](https://github.com/microsoft/Olive/blob/main/docs/source/examples.md) in a follow-up PR. ## (Optional) Issue link --- olive/passes/snpe/quantization.py | 1 - olive/workflows/run/run.py | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/olive/passes/snpe/quantization.py b/olive/passes/snpe/quantization.py index 6f1b167d7..d6b2ff704 100644 --- a/olive/passes/snpe/quantization.py +++ b/olive/passes/snpe/quantization.py @@ -51,7 +51,6 @@ def _default_config(cls, accelerator_spec: AcceleratorSpec) -> Dict[str, PassCon ), "data_config": PassConfigParam( type_=Union[DataConfig, Dict], - required=True, description="Data config for quantization, required if dataloader_func is None", ), "use_enhanced_quantizer": PassConfigParam( diff --git a/olive/workflows/run/run.py b/olive/workflows/run/run.py index af1a30075..896809547 100644 --- a/olive/workflows/run/run.py +++ b/olive/workflows/run/run.py @@ -242,14 +242,14 @@ def run( package_config = OlivePackageConfig.get_default_config_path() # we use parse_file and parse_obj to be safe. If implemented as expected, both should be equivalent. - logger.info("Loading Olive module configuration: %s", package_config) if isinstance(package_config, (str, Path)): + logger.info("Loading Olive module configuration from: %s", package_config) package_config = OlivePackageConfig.parse_file(package_config) else: package_config = OlivePackageConfig.parse_obj(package_config) - logger.info("Loading run configuration: %s", run_config) if isinstance(run_config, (str, Path)): + logger.info("Loading run configuration from: %s", run_config) run_config = RunConfig.parse_file(run_config) else: run_config = RunConfig.parse_obj(run_config)
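
For reference, a minimal sketch of how the reworked `run` entry point is driven once this series is applied. This is an illustration rather than part of the patch series: the two JSON file names below are placeholders, and the snippet assumes the `olive_run` import alias used in the quicktour.

```python
# Minimal usage sketch for the reworked entry point (illustrative only;
# "my_model_acceleration_description.json" and "my_olive_config.json" are
# placeholder file names, not files shipped with Olive).
from olive.workflows import run as olive_run

# Default behavior: with package_config=None, run() falls back to the bundled
# olive/olive_config.json via OlivePackageConfig.get_default_config_path().
olive_run("my_model_acceleration_description.json")

# Cloned and edited package config, e.g. one that registers user-implemented
# pass modules; this mirrors the --package-config command line option.
olive_run(
    "my_model_acceleration_description.json",
    package_config="my_olive_config.json",
)
```

A dict can be passed for either argument as well, since both inputs go through the `parse_file`/`parse_obj` branches shown in the final run.py hunk above.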