diff --git a/olive/cli/base.py b/olive/cli/base.py index 69b83f859..63bcf6e9c 100644 --- a/olive/cli/base.py +++ b/olive/cli/base.py @@ -11,13 +11,13 @@ from abc import ABC, abstractmethod from argparse import ArgumentParser, Namespace from pathlib import Path -from typing import ClassVar, Dict, Optional, Union +from typing import ClassVar, Dict, Optional import yaml from olive.cli.constants import CONDA_CONFIG from olive.common.user_module_loader import UserModuleLoader -from olive.common.utils import hash_dict +from olive.common.utils import hash_dict, set_nested_dict_value, unescaped_str class BaseOliveCLICommand(ABC): @@ -39,280 +39,508 @@ def register_subcommand(parser: ArgumentParser): def run(self): raise NotImplementedError - -def _get_hf_input_model(args, model_path): - print("Loading HuggingFace model from:", model_path) - input_model = { - "type": "HfModel", - "model_path": model_path, - "load_kwargs": { - "trust_remote_code": args.trust_remote_code, - "attn_implementation": "eager", - }, - } - if args.task: - input_model["task"] = args.task - return input_model - - -def _get_onnx_input_model(model_path): - print("Loading ONNX model from:", model_path) - return { - "type": "OnnxModel", - "model_path": model_path, - } - - -def _get_pt_input_model(args, model_path): - if not args.model_script: - raise ValueError("model_script is not provided. Either model_name_or_path or model_script is required.") - - user_module_loader = UserModuleLoader(args.model_script, args.script_dir) - - if not model_path and not user_module_loader.has_function("_model_loader"): - raise ValueError("Either _model_loader or model_name_or_path is required for PyTorch model.") - - input_model_config = { - "type": "PyTorchModel", - "model_script": args.model_script, - } - - if args.script_dir: - input_model_config["script_dir"] = args.script_dir - - if model_path: - print("Loading PyTorch model from:", model_path) - input_model_config["model_path"] = model_path - - if user_module_loader.has_function("_model_loader"): - print("Loading PyTorch model from function: _model_loader.") - input_model_config["model_loader"] = "_model_loader" - - model_funcs = [ - ("io_config", "_io_config"), - ("dummy_inputs_func", "_dummy_inputs"), - ("model_file_format", "_model_file_format"), - ] - input_model_config.update( - {config_key: func_name for config_key, func_name in model_funcs if user_module_loader.has_function(func_name)} - ) - - if "io_config" not in input_model_config and "dummy_inputs_func" not in input_model_config: - raise ValueError("_io_config or _dummy_inputs is required in the script for PyTorch model.") - return input_model_config - - -def get_input_model_config(args) -> Union[str, Dict[str, str]]: - """Parse the model_name_or_path and return the input model config. - - Check model_name_or_path formats in order: - 1. Local PyTorch model with model loader but no model path - 2. azureml:: (only for PyTorch model) - 3. Load PyTorch model with model_script - 4. azureml://registries//models//versions/ (only for HF model) - 5. https://huggingface.co/ (only for HF model) - 6. HF model name string - 7. local file path - a. local onnx model file path (either a user-provided model or a model produced by the Olive CLI) - b. local HF model file path (either a user-provided model or a model produced by the Olive CLI) - """ - model_name_or_path = args.model_name_or_path - - # Check if local PyTorch model with model loader - if model_name_or_path is None: - print("model_name_or_path is not provided. 
Using model_script to load the model.") - return _get_pt_input_model(args, None) - - # Check AzureML model - pattern = r"^azureml:(?P[^:]+):(?P[^:]+)$" - match = re.match(pattern, model_name_or_path) - if match: - return _get_pt_input_model( - args, - { - "type": "azureml_model", - "name": match.group("model_name"), - "version": match.group("version"), - }, + @staticmethod + def _add_input_model_options(sub_parser): + model_group = sub_parser.add_argument_group("Model options") + model_group.add_argument( + "-m", + "--model_name_or_path", + type=str, + help=( + "The model checkpoint for weights initialization. If using an AzureML Registry model, provide the model" + " path as 'registry_name:model_name:version'." + ), + ) + model_group.add_argument( + "--trust_remote_code", action="store_true", help="Trust remote code when loading a model." ) + model_group.add_argument("-t", "--task", type=str, help="Task for which the model is used.") + model_group.add_argument( + "--model_script", + type=str, + help="The script file containing the model definition. Required for PyTorch model.", + ) + model_group.add_argument( + "--script_dir", + type=str, + default=None, + help="The directory containing the model script file.", + ) + return model_group + + def _get_hf_input_model(self, model_path): + print("Loading HuggingFace model from:", model_path) + input_model = { + "type": "HfModel", + "model_path": model_path, + "load_kwargs": { + "trust_remote_code": self.args.trust_remote_code, + "attn_implementation": "eager", + }, + } + if self.args.task: + input_model["task"] = self.args.task + return input_model + + def _get_onnx_input_model(self, model_path): + print("Loading ONNX model from:", model_path) + return { + "type": "OnnxModel", + "model_path": model_path, + } - if args.model_script: - return _get_pt_input_model(args, model_name_or_path) + def _get_pt_input_model(self, model_path): + if not self.args.model_script: + raise ValueError("model_script is not provided. 
Either model_name_or_path or model_script is required.") - # Check AzureML Registry model - pattern = ( - r"^azureml://registries/(?P[^/]+)/models/(?P[^/]+)/versions/(?P[^/]+)$" - ) - match = re.match(pattern, model_name_or_path) - if match: - return _get_hf_input_model( - args, + user_module_loader = UserModuleLoader(self.args.model_script, self.args.script_dir) + + if not model_path and not user_module_loader.has_function("_model_loader"): + raise ValueError("Either _model_loader or model_name_or_path is required for PyTorch model.") + + input_model_config = { + "type": "PyTorchModel", + "model_script": self.args.model_script, + } + + if self.args.script_dir: + input_model_config["script_dir"] = self.args.script_dir + + if model_path: + print("Loading PyTorch model from:", model_path) + input_model_config["model_path"] = model_path + + if user_module_loader.has_function("_model_loader"): + print("Loading PyTorch model from function: _model_loader.") + input_model_config["model_loader"] = "_model_loader" + + model_funcs = [ + ("io_config", "_io_config"), + ("dummy_inputs_func", "_dummy_inputs"), + ("model_file_format", "_model_file_format"), + ] + input_model_config.update( { - "type": "azureml_registry_model", - "registry_name": match.group("registry_name"), - "name": match.group("model_name"), - "version": match.group("version"), - }, + config_key: func_name + for config_key, func_name in model_funcs + if user_module_loader.has_function(func_name) + } ) - # Check HuggingFace url - pattern = r"https://huggingface\.co/([^/]+/[^/]+)(?:/.*)?" - match = re.search(pattern, model_name_or_path) - if match: - return _get_hf_input_model(args, match.group(1)) - - model_path = Path(model_name_or_path) - - # Check HF model name string - if not model_path.exists(): - try: - from huggingface_hub import repo_exists - except ImportError as e: - raise ImportError("Please install huggingface_hub to use the CLI for Huggingface model.") from e - - if not repo_exists(model_name_or_path): - raise ValueError(f"{model_name_or_path} is not a valid Huggingface model name.") - return _get_hf_input_model(args, model_name_or_path) - - # Check if local model is from Olive CLI - if model_path.is_dir(): - for file in model_path.iterdir(): - if file.is_file() and file.name == "model_config.json": - with open(file) as f: - return json.load(f) - - # Check local onnx file (user-provided model) - if model_path.is_file() and model_path.suffix == ".onnx": - return _get_onnx_input_model(model_name_or_path) - - # Check local HF model file (user-provided model) - return _get_hf_input_model(args, model_name_or_path) - - -def add_logging_options(sub_parser): - log_group = sub_parser.add_argument_group("logging options") - log_group.add_argument( - "--log_level", - type=int, - default=3, - help="Logging level. Default is 3. level 0: DEBUG, 1: INFO, 2: WARNING, 3: ERROR, 4: CRITICAL", - ) - - -def add_remote_options(sub_parser): - remote_group = sub_parser.add_argument_group("remote options") - remote_group.add_argument( - "--resource_group", - type=str, - required=False, - help="Resource group for the AzureML workspace.", - ) - remote_group.add_argument( - "--workspace_name", - type=str, - required=False, - help="Workspace name for the AzureML workspace.", - ) - remote_group.add_argument( - "--keyvault_name", - type=str, - required=False, - help=( - "The azureml keyvault name with huggingface token to use for remote run. 
Refer to" - " https://microsoft.github.io/Olive/features/huggingface_model_optimization.html#huggingface-login for" - " more details." - ), - ) - remote_group.add_argument( - "--aml_compute", - type=str, - required=False, - help="The compute name to run the workflow on.", - ) - - -def add_model_options(sub_parser): - model_group = sub_parser.add_argument_group("Model options") - model_group.add_argument( - "-m", - "--model_name_or_path", - type=str, - help=( - "The model checkpoint for weights initialization. If using an AzureML Registry model, provide the model" - " path as 'registry_name:model_name:version'." - ), - ) - model_group.add_argument("--trust_remote_code", action="store_true", help="Trust remote code when loading a model.") - model_group.add_argument("-t", "--task", type=str, help="Task for which the model is used.") - model_group.add_argument( - "--model_script", - type=str, - help="The script file containing the model definition. Required for PyTorch model.", - ) - model_group.add_argument( - "--script_dir", - type=str, - default=None, - help="The directory containing the model script file.", - ) - - -def is_remote_run(args): - return all([args.resource_group, args.workspace_name, args.aml_compute]) - - -def update_remote_option(config, args, cli_action, tempdir): - if args.resource_group or args.workspace_name or args.aml_compute: - if not is_remote_run(args): - raise ValueError("resource_group, workspace_name and aml_compute are required for remote workflow run.") - - config["workflow_id"] = f"{cli_action}-{hash_dict(config)}" - - try: - subscription_id = json.loads(subprocess.check_output("az account show", shell=True).decode("utf-8"))["id"] - print("Using Azure subscription ID: %s", subscription_id) - - except subprocess.CalledProcessError: - print( - "Error: Unable to retrieve account information. " - "Make sure you are logged in to Azure CLI with command `az login`." + if "io_config" not in input_model_config and "dummy_inputs_func" not in input_model_config: + raise ValueError("_io_config or _dummy_inputs is required in the script for PyTorch model.") + return input_model_config + + def _update_input_model_options(self, config): + """Parse the model_name_or_path and return the input model config. + + Check model_name_or_path formats in order: + 1. Local PyTorch model with model loader but no model path + 2. azureml:: (only for PyTorch model) + 3. Load PyTorch model with model_script + 4. azureml://registries//models//versions/ (only for HF model) + 5. https://huggingface.co/ (only for HF model) + 6. HF model name string + 7. local file path + a. local onnx model file path (either a user-provided model or a model produced by the Olive CLI) + b. local HF model file path (either a user-provided model or a model produced by the Olive CLI) + + """ + model_name_or_path = self.args.model_name_or_path + + # Check if local PyTorch model with model loader + if model_name_or_path is None: + print("model_name_or_path is not provided. 
Using model_script to load the model.") + config["input_model"] = self._get_pt_input_model(None) + return + + # Check AzureML model + pattern = r"^azureml:(?P[^:]+):(?P[^:]+)$" + match = re.match(pattern, model_name_or_path) + if match: + config["input_model"] = self._get_pt_input_model( + { + "type": "azureml_model", + "name": match.group("model_name"), + "version": match.group("version"), + }, ) + return - config["azureml_client"] = { - "subscription_id": subscription_id, - "resource_group": args.resource_group, - "workspace_name": args.workspace_name, - "keyvault_name": args.keyvault_name, - "default_auth_params": {"exclude_managed_identity_credential": True}, - } + if self.args.model_script: + config["input_model"] = self._get_pt_input_model(model_name_or_path) + return - conda_file_path = Path(tempdir) / "conda_gpu.yaml" - with open(conda_file_path, "w") as f: - yaml.dump(CONDA_CONFIG, f) - - config["systems"]["aml_system"] = { - "type": "AzureML", - "accelerators": [{"device": "GPU", "execution_providers": ["CUDAExecutionProvider"]}], - "aml_compute": args.aml_compute, - "aml_docker_config": { - "base_image": "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04", - "conda_file_path": str(conda_file_path), - }, - "hf_token": bool(args.keyvault_name), - } - config["workflow_host"] = "aml_system" + # Check AzureML Registry model + pattern = ( + r"^azureml://registries/(?P[^/]+)/models/(?P[^/]+)/versions/(?P[^/]+)$" + ) + match = re.match(pattern, model_name_or_path) + if match: + config["input_model"] = self._get_hf_input_model( + { + "type": "azureml_registry_model", + "registry_name": match.group("registry_name"), + "name": match.group("model_name"), + "version": match.group("version"), + }, + ) + return + + # Check HuggingFace url + pattern = r"https://huggingface\.co/([^/]+/[^/]+)(?:/.*)?" + match = re.search(pattern, model_name_or_path) + if match: + config["input_model"] = self._get_hf_input_model(match.group(1)) + return + + model_path = Path(model_name_or_path) + + # Check HF model name string + if not model_path.exists(): + try: + from huggingface_hub import repo_exists + except ImportError as e: + raise ImportError("Please install huggingface_hub to use the CLI for Huggingface model.") from e + + if not repo_exists(model_name_or_path): + raise ValueError(f"{model_name_or_path} is not a valid Huggingface model name.") + config["input_model"] = self._get_hf_input_model(model_name_or_path) + return + + # Check if local model is from Olive CLI + if model_path.is_dir(): + for file in model_path.iterdir(): + if file.is_file() and file.name == "model_config.json": + with open(file) as f: + config["input_model"] = json.load(f) + return + + # Check local onnx file (user-provided model) + if model_path.is_file() and model_path.suffix == ".onnx": + config["input_model"] = self._get_onnx_input_model(model_name_or_path) + return + + # Check local HF model file (user-provided model) + config["input_model"] = self._get_hf_input_model(model_name_or_path) + + @staticmethod + def _add_logging_options(sub_parser): + log_group = sub_parser.add_argument_group("logging options") + log_group.add_argument( + "--log_level", + type=int, + default=3, + help="Logging level. Default is 3. 
level 0: DEBUG, 1: INFO, 2: WARNING, 3: ERROR, 4: CRITICAL", + ) + return log_group + + @staticmethod + def _add_remote_options(sub_parser): + remote_group = sub_parser.add_argument_group("remote options") + remote_group.add_argument( + "--resource_group", + type=str, + required=False, + help="Resource group for the AzureML workspace.", + ) + remote_group.add_argument( + "--workspace_name", + type=str, + required=False, + help="Workspace name for the AzureML workspace.", + ) + remote_group.add_argument( + "--keyvault_name", + type=str, + required=False, + help=( + "The azureml keyvault name with huggingface token to use for remote run. Refer to" + " https://microsoft.github.io/Olive/features/huggingface_model_optimization.html#huggingface-login for" + " more details." + ), + ) + remote_group.add_argument( + "--aml_compute", + type=str, + required=False, + help="The compute name to run the workflow on.", + ) + + return remote_group + + @staticmethod + def _update_model_config(model_config_path: Path, output_path: Path): + with open(model_config_path) as f: + model_config = json.load(f) + model_path = model_config["config"]["model_path"] + model_config["config"]["model_path"] = str(output_path.resolve() / Path(model_path).name) + model_config_path = output_path / "model_config.json" + with open(model_config_path, "w") as f: + json.dump(model_config, f, indent=4) + + @staticmethod + def _add_dataconfig_options(sub_parser): + dataconfig_group = sub_parser.add_argument_group( + "data config options, which mutually exclusive with huggingface dataset options" + ) + dataconfig_group.add_argument( + "--data_config_path", + type=str, + help="Path to the data config file. It allows to customize the data config(json/yaml) for the model.", + ) + return dataconfig_group -# TODO(team): Remove this function once the output structure is refactored -def get_output_model_number(outputs: Dict) -> int: - return sum(len(f.nodes) for f in outputs.values()) + @staticmethod + def _add_dataset_options(sub_parser): + dataset_group = sub_parser.add_argument_group("dataset options") + dataset_group.add_argument( + "-d", + "--data_name", + type=str, + required=True, + help="The dataset name.", + ) + dataset_group.add_argument("--train_subset", type=str, help="The subset to use for training.") + dataset_group.add_argument("--eval_subset", type=str, help="The subset to use for evaluation.") + # TODO(jambayk): currently only supports single file or list of files, support mapping + dataset_group.add_argument( + "--data_files", type=str, help="The dataset files. If multiple files, separate by comma." + ) + dataset_group.add_argument("--train_split", type=str, default="train", help="The split to use for training.") + dataset_group.add_argument( + "--eval_split", + default="", + help="The dataset split to evaluate on.", + ) + text_group = dataset_group.add_mutually_exclusive_group(required=False) + text_group.add_argument( + "--text_field", + type=str, + help="The text field to use for fine-tuning.", + ) + text_group.add_argument( + "--text_template", + # using special string type to allow for escaped characters like \n + type=unescaped_str, + help=r"Template to generate text field from. E.g. 
'### Question: {prompt} \n### Answer: {response}'", + ) + dataset_group.add_argument( + "--max_seq_len", + type=int, + default=1024, + help="Maximum sequence length for the data.", + ) + dataset_group.add_argument( + "--add_special_tokens", + type=bool, + default=False, + help="Whether to add special tokens during preprocessing.", + ) + dataset_group.add_argument( + "--max_samples", + type=int, + default=256, + help="Maximum samples to select from the dataset.", + ) + dataset_group.add_argument( + "--batch_size", + type=int, + default=1, + help="Batch size.", + ) + return dataset_group, text_group + + def _update_dataset_options(self, config): + load_key = ("data_configs", 0, "load_dataset_config") + preprocess_key = ("data_configs", 0, "pre_process_data_config") + dataloader_key = ("data_configs", 0, "dataloader_config") + to_replace = [ + ((*load_key, "data_name"), self.args.data_name), + ((*load_key, "split"), self.args.train_split), + ((*load_key, "subset"), self.args.train_subset), + ( + (*load_key, "data_files"), + self.args.data_files.split(",") if self.args.data_files else None, + ), + ((*preprocess_key, "text_cols"), self.args.text_field), + ((*preprocess_key, "text_template"), self.args.text_template), + ((*preprocess_key, "max_seq_len"), self.args.max_seq_len), + ((*preprocess_key, "add_special_tokens"), self.args.add_special_tokens), + ((*preprocess_key, "max_samples"), self.args.max_samples), + ((*dataloader_key, "batch_size"), self.args.batch_size), + ] + for keys, value in to_replace: + if value is not None: + set_nested_dict_value(config, keys, value) -def update_model_config(model_config_path: Path, output_path: Path): - with open(model_config_path) as f: - model_config = json.load(f) - model_path = model_config["config"]["model_path"] - model_config["config"]["model_path"] = str(output_path.resolve() / Path(model_path).name) - model_config_path = output_path / "model_config.json" - with open(model_config_path, "w") as f: - json.dump(model_config, f, indent=4) + @staticmethod + def _add_hf_dataset_options(sub_parser): + hf_dataset_group = sub_parser.add_argument_group( + "huggingface dataset options, if dataset options are not provided, " + "user should provide the following options to modify the default data config. " + "Please refer to olive.data.container.TransformersTokenDummyDataContainer for more details." 
+        )
+        hf_dataset_group.add_argument(
+            "--hf_model_name",
+            help="Huggingface model name used to load model configs from huggingface.",
+        )
+        hf_dataset_group.add_argument(
+            "--batch_size",
+            type=int,
+            help="Batch size of the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--seq_len",
+            type=int,
+            help="Sequence length to use for the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--past_seq_len",
+            type=int,
+            help="Past sequence length to use for the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--max_seq_len",
+            type=int,
+            help="Max sequence length to use for the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--shared_kv",
+            action="store_true",
+            help="Whether to enable share kv cache in the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--generative",
+            action="store_true",
+            help="Whether to enable generative mode in the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--ort_past_key_name",
+            type=str,
+            help="Past key name for the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--ort_past_value_name",
+            type=str,
+            help="Past value name for the input data.",
+        )
+        # TODO(all): Argument conflicting with use of the same name in model options
+        # hf_dataset_group.add_argument(
+        #     "--trust_remote_code",
+        #     action="store_true",
+        #     help="Whether to trust remote code in the input data.",
+        # )
+        hf_dataset_group.add_argument(
+            "--max_samples",
+            type=int,
+            help="Max samples to use for the input data.",
+        )
+        hf_dataset_group.add_argument(
+            "--fields_no_batch",
+            nargs="*",
+            help="List of fields that should not be batched.",
+        )
+
+        return hf_dataset_group
+
+    @staticmethod
+    def _add_accelerator_options(sub_parser):
+        accelerator_group = sub_parser.add_argument_group("accelerator group")
+
+        accelerator_group.add_argument(
+            "--device",
+            type=str,
+            default="cpu",
+            choices=["gpu", "cpu", "npu"],
+            help="Device to use for the model.",
+        )
+
+        accelerator_group.add_argument(
+            "--providers_list",
+            type=str,
+            nargs="*",
+            choices=[
+                "CUDAExecutionProvider",
+                "DmlExecutionProvider",
+                "JsExecutionProvider",
+                "MIGraphXExecutionProvider",
+                "OpenVINOExecutionProvider",
+                "QNNExecutionProvider",
+                "ROCMExecutionProvider",
+                "TensorrtExecutionProvider",
+            ],
+            help=(
+                "List of execution providers to use for ONNX model. They are case sensitive. "
+                "If not provided, all available providers will be used."
+ ), + ) + + return accelerator_group + + def _update_accelerator_options(self, config): + to_replace = [ + (("systems", "local_system", "accelerators", 0, "device"), self.args.device), + ] + + if self.args.providers_list: + to_replace.append( + (("systems", "local_system", "accelerators", 0, "execution_providers"), self.args.providers_list) + ) + + for k, v in to_replace: + if v is not None: + set_nested_dict_value(config, k, v) + + def _is_remote_run(self): + return all([self.args.resource_group, self.args.workspace_name, self.args.aml_compute]) + + def _update_remote_options(self, config, cli_action, tempdir): + if self.args.resource_group or self.args.workspace_name or self.args.aml_compute: + if not self._is_remote_run(): + raise ValueError("resource_group, workspace_name and aml_compute are required for remote workflow run.") + + config["workflow_id"] = f"{cli_action}-{hash_dict(config)}" + + try: + subscription_id = json.loads(subprocess.check_output("az account show", shell=True).decode("utf-8"))[ + "id" + ] + print(f"Using Azure subscription ID: {subscription_id}") + + except subprocess.CalledProcessError: + print( + "Error: Unable to retrieve account information. " + "Make sure you are logged in to Azure CLI with command `az login`." + ) + + config["azureml_client"] = { + "subscription_id": subscription_id, + "resource_group": self.args.resource_group, + "workspace_name": self.args.workspace_name, + "keyvault_name": self.args.keyvault_name, + "default_auth_params": {"exclude_managed_identity_credential": True}, + } + + conda_file_path = Path(tempdir) / "conda_gpu.yaml" + with open(conda_file_path, "w") as f: + yaml.dump(CONDA_CONFIG, f) + + config["systems"]["aml_system"] = { + "type": "AzureML", + "accelerators": [{"device": "GPU", "execution_providers": ["CUDAExecutionProvider"]}], + "aml_compute": self.args.aml_compute, + "aml_docker_config": { + "base_image": "mcr.microsoft.com/azureml/openmpi4.1.0-cuda11.8-cudnn8-ubuntu22.04", + "conda_file_path": str(conda_file_path), + }, + "hf_token": bool(self.args.keyvault_name), + } + config["workflow_host"] = "aml_system" + + # TODO(team): Remove this function once the output structure is refactored + @staticmethod + def _get_output_model_number(outputs: Dict) -> int: + return sum(len(f.nodes) for f in outputs.values()) diff --git a/olive/cli/capture_onnx.py b/olive/cli/capture_onnx.py index ad204e1fe..118a42636 100644 --- a/olive/cli/capture_onnx.py +++ b/olive/cli/capture_onnx.py @@ -12,17 +12,7 @@ from pathlib import Path from typing import ClassVar, Dict -from olive.cli.base import ( - BaseOliveCLICommand, - add_logging_options, - add_model_options, - add_remote_options, - get_input_model_config, - get_output_model_number, - is_remote_run, - update_model_config, - update_remote_option, -) +from olive.cli.base import BaseOliveCLICommand from olive.common.utils import IntEnumBase, hardlink_copy_dir, set_nested_dict_value, set_tempdir @@ -43,10 +33,10 @@ def register_subcommand(parser: ArgumentParser): help=("Capture ONNX graph using PyTorch Exporter or Model Builder from the Huggingface model."), ) - add_logging_options(sub_parser) + CaptureOnnxGraphCommand._add_logging_options(sub_parser) # model options - add_model_options(sub_parser) + CaptureOnnxGraphCommand._add_input_model_options(sub_parser) sub_parser.add_argument( "--device", @@ -154,7 +144,7 @@ def register_subcommand(parser: ArgumentParser): ) # remote options - add_remote_options(sub_parser) + CaptureOnnxGraphCommand._add_remote_options(sub_parser) 
sub_parser.set_defaults(func=CaptureOnnxGraphCommand) @@ -168,12 +158,12 @@ def run(self): output = olive_run(run_config) - if is_remote_run(self.args): + if self._is_remote_run(): # TODO(jambayk): point user to datastore with outputs or download outputs # both are not implemented yet return - if get_output_model_number(output) > 0: + if CaptureOnnxGraphCommand._get_output_model_number(output) > 0: output_path = Path(self.args.output_path) output_path.mkdir(parents=True, exist_ok=True) pass_name = "m" if self.args.use_model_builder else "c" @@ -184,14 +174,14 @@ def run(self): else: shutil.move(str(source_path.with_suffix(".onnx")), output_path) - update_model_config(source_path.with_suffix(".json"), output_path) + CaptureOnnxGraphCommand._update_model_config(source_path.with_suffix(".json"), output_path) print(f"ONNX Model is saved to {output_path.resolve()}") else: print("Failed to run capture-onnx-graph. Please set the log_level to 1 for more detailed logs.") def get_run_config(self, tempdir: str) -> Dict: config = deepcopy(TEMPLATE) - config["input_model"] = get_input_model_config(self.args) + self._update_input_model_options(config) to_replace = [ ("output_dir", tempdir), @@ -234,7 +224,7 @@ def get_run_config(self, tempdir: str) -> Dict: if value is None: continue set_nested_dict_value(config, keys, value) - update_remote_option(config, self.args, "capture-onnx-graph", tempdir) + self._update_remote_options(config, "capture-onnx-graph", tempdir) return config diff --git a/olive/cli/finetune.py b/olive/cli/finetune.py index 6637064d7..1422181a7 100644 --- a/olive/cli/finetune.py +++ b/olive/cli/finetune.py @@ -11,18 +11,8 @@ from pathlib import Path from typing import ClassVar, Dict -from olive.cli.base import ( - BaseOliveCLICommand, - add_logging_options, - add_model_options, - add_remote_options, - get_input_model_config, - get_output_model_number, - is_remote_run, - update_model_config, - update_remote_option, -) -from olive.common.utils import hardlink_copy_dir, set_nested_dict_value, set_tempdir, unescaped_str +from olive.cli.base import BaseOliveCLICommand +from olive.common.utils import hardlink_copy_dir, set_nested_dict_value, set_tempdir class FineTuneCommand(BaseOliveCLICommand): @@ -38,7 +28,7 @@ def register_subcommand(parser: ArgumentParser): ), ) - add_logging_options(sub_parser) + FineTuneCommand._add_logging_options(sub_parser) # TODO(jambayk): option to list/install required dependencies? sub_parser.add_argument( @@ -50,7 +40,7 @@ def register_subcommand(parser: ArgumentParser): ) # Model options - add_model_options(sub_parser) + FineTuneCommand._add_input_model_options(sub_parser) sub_parser.add_argument( "--torch_dtype", @@ -64,42 +54,8 @@ def register_subcommand(parser: ArgumentParser): ) # Dataset options - dataset_group = sub_parser.add_argument_group("dataset options") - dataset_group.add_argument( - "-d", - "--data_name", - type=str, - required=True, - help="The dataset name.", - ) - # TODO(jambayk): currently only supports single file or list of files, support mapping - dataset_group.add_argument( - "--data_files", type=str, help="The dataset files. If multiple files, separate by comma." 
- ) - dataset_group.add_argument("--train_split", type=str, default="train", help="The split to use for training.") - dataset_group.add_argument( - "--eval_split", - default="", - help="The dataset split to evaluate on.", - ) - text_group = dataset_group.add_mutually_exclusive_group(required=True) - text_group.add_argument( - "--text_field", - type=str, - help="The text field to use for fine-tuning.", - ) - text_group.add_argument( - "--text_template", - # using special string type to allow for escaped characters like \n - type=unescaped_str, - help=r"Template to generate text field from. E.g. '### Question: {prompt} \n### Answer: {response}'", - ) - dataset_group.add_argument( - "--max_seq_len", - type=int, - default=1024, - help="Maximum sequence length for the data.", - ) + FineTuneCommand._add_dataset_options(sub_parser) + # LoRA options lora_group = sub_parser.add_argument_group("lora options") lora_group.add_argument( @@ -135,7 +91,7 @@ def register_subcommand(parser: ArgumentParser): sub_parser.add_argument("--clean", action="store_true", help="Run in a clean cache directory") # remote options - add_remote_options(sub_parser) + FineTuneCommand._add_remote_options(sub_parser) sub_parser.set_defaults(func=FineTuneCommand) @@ -149,19 +105,19 @@ def run(self): output = olive_run(run_config) - if is_remote_run(self.args): + if self._is_remote_run(): # TODO(jambayk): point user to datastore with outputs or download outputs # both are not implemented yet return - if get_output_model_number(output) > 0: + if FineTuneCommand._get_output_model_number(output) > 0: # need to improve the output structure of olive run output_path = Path(self.args.output_path) output_path.mkdir(parents=True, exist_ok=True) source_path = Path(tempdir) / "-".join(run_config["passes"].keys()) / "gpu-cuda_model" hardlink_copy_dir(source_path, output_path) - update_model_config(source_path.with_suffix(".json"), output_path) + FineTuneCommand._update_model_config(source_path.with_suffix(".json"), output_path) print(f"Model and adapters saved to {output_path.resolve()}") else: print("Failed to run finetune. 
Please set the log_level to 1 for more detailed logs.") @@ -182,19 +138,9 @@ def parse_training_args(self) -> Dict: return {k: v for k, v in vars(training_args).items() if k in arg_keys} def get_run_config(self, tempdir: str) -> Dict: - load_key = ("data_configs", 0, "load_dataset_config") - preprocess_key = ("data_configs", 0, "pre_process_data_config") + finetune_key = ("passes", "f") to_replace = [ - ((*load_key, "data_name"), self.args.data_name), - ((*load_key, "split"), self.args.train_split), - ( - (*load_key, "data_files"), - self.args.data_files.split(",") if self.args.data_files else None, - ), - ((*preprocess_key, "text_cols"), self.args.text_field), - ((*preprocess_key, "text_template"), self.args.text_template), - ((*preprocess_key, "max_seq_len"), self.args.max_seq_len), ((*finetune_key, "type"), self.args.method), ((*finetune_key, "torch_dtype"), self.args.torch_dtype), ((*finetune_key, "training_args"), self.parse_training_args()), @@ -210,12 +156,12 @@ def get_run_config(self, tempdir: str) -> Dict: to_replace.append(((*finetune_key, "target_modules"), self.args.target_modules.split(","))) config = deepcopy(TEMPLATE) - config["input_model"] = get_input_model_config(self.args) + self._update_input_model_options(config) + self._update_dataset_options(config) for keys, value in to_replace: - if value is None: - continue - set_nested_dict_value(config, keys, value) + if value is not None: + set_nested_dict_value(config, keys, value) if self.args.eval_split: eval_data_config = deepcopy(config["data_configs"][0]) @@ -227,7 +173,7 @@ def get_run_config(self, tempdir: str) -> Dict: if not self.args.use_ort_genai: del config["passes"]["m"] - update_remote_option(config, self.args, "finetune", tempdir) + self._update_remote_options(config, "finetune", tempdir) config["log_severity_level"] = self.args.log_level return config @@ -248,6 +194,8 @@ def get_run_config(self, tempdir: str) -> Dict: "type": "HuggingfaceContainer", "load_dataset_config": {}, "pre_process_data_config": {}, + "dataloader_config": {}, + "post_process_data_config": {}, } ], "passes": { diff --git a/olive/cli/launcher.py b/olive/cli/launcher.py index e758f9dfc..f34c14ad3 100644 --- a/olive/cli/launcher.py +++ b/olive/cli/launcher.py @@ -13,6 +13,7 @@ from olive.cli.finetune import FineTuneCommand from olive.cli.manage_aml_compute import ManageAMLComputeCommand from olive.cli.perf_tuning import PerfTuningCommand +from olive.cli.quantize import QuantizeCommand from olive.cli.run import WorkflowRunCommand @@ -33,6 +34,7 @@ def get_cli_parser(called_as_console_script: bool = True) -> ArgumentParser: ConfigureQualcommSDKCommand.register_subcommand(commands_parser) ManageAMLComputeCommand.register_subcommand(commands_parser) PerfTuningCommand.register_subcommand(commands_parser) + QuantizeCommand.register_subcommand(commands_parser) CloudCacheCommand.register_subcommand(commands_parser) return parser diff --git a/olive/cli/manage_aml_compute.py b/olive/cli/manage_aml_compute.py index f054b07ce..832bc8947 100644 --- a/olive/cli/manage_aml_compute.py +++ b/olive/cli/manage_aml_compute.py @@ -68,7 +68,7 @@ def run(self): ) if self.args.create: - print("Creating compute %s...", self.args.compute_name) + print(f"Creating compute {self.args.compute_name}...") if self.args.vm_size is None: raise ValueError("vm_size must be provided if operation is create") if self.args.location is None: @@ -84,19 +84,15 @@ def run(self): ) ml_client.begin_create_or_update(cluster_basic).result() print( - "Successfully created compute: %s at %s 
with vm_size:%s and " - "min_nodes=%d and max_nodes=%d and idle_time_before_scale_down=%d", - self.args.compute_name, - self.args.location, - self.args.vm_size, - self.args.min_nodes, - self.args.max_nodes, - self.args.idle_time_before_scale_down, + f"Successfully created compute: {self.args.compute_name} at {self.args.location} " + f"with vm_size:{self.args.vm_size} and min_nodes={self.args.min_nodes} and " + f"max_nodes={self.args.max_nodes} and " + f"idle_time_before_scale_down={self.args.idle_time_before_scale_down}" ) elif self.args.delete: - print("Deleting compute %s...", self.args.compute_name) + print(f"Deleting compute {self.args.compute_name}...") ml_client.compute.begin_delete(self.args.compute_name).wait() - print("Successfully deleted compute: %s", self.args.compute_name) + print(f"Successfully deleted compute: {self.args.compute_name}") @classmethod def get_ml_client(cls, aml_config_path: str, subscription_id: str, resource_group: str, workspace_name: str): diff --git a/olive/cli/perf_tuning.py b/olive/cli/perf_tuning.py index 51ce006fd..fb9f8bfd1 100644 --- a/olive/cli/perf_tuning.py +++ b/olive/cli/perf_tuning.py @@ -15,16 +15,7 @@ import yaml from olive.auto_optimizer.template_mapping import PERF_TUNING_TEMPLATE -from olive.cli.base import ( - BaseOliveCLICommand, - add_logging_options, - add_model_options, - add_remote_options, - get_input_model_config, - get_output_model_number, - is_remote_run, - update_remote_option, -) +from olive.cli.base import BaseOliveCLICommand from olive.common.utils import set_nested_dict_value, set_tempdir from olive.data.config import DataConfig from olive.workflows import run as olive_run @@ -43,96 +34,26 @@ def register_subcommand(parser: ArgumentParser): "--hf_model_name hf_model_name --device device_type to get the tuned session parameters." ), ) - add_logging_options(sub_parser) + PerfTuningCommand._add_logging_options(sub_parser) # model options - add_model_options(sub_parser) + PerfTuningCommand._add_input_model_options(sub_parser) # dataset options - dataset_group = sub_parser.add_argument_group( - "dataset options, which mutually exclusive with huggingface dataset options" - ) - dataset_group.add_argument( - "--data_config_path", - type=str, - help="Path to the data config file. It allows to customize the data config(json/yaml) for the model.", - ) - - hf_dataset_group = sub_parser.add_argument_group( - "huggingface dataset options, if dataset options are not provided, " - "user should provide the following options to modify the default data config. " - "Please refer to olive.data.container.TransformersTokenDummyDataContainer for more details." 
- ) + PerfTuningCommand._add_dataconfig_options(sub_parser) + hf_dataset_group = PerfTuningCommand._add_hf_dataset_options(sub_parser) hf_dataset_group.add_argument( "--predict_with_kv_cache", action="store_true", help="Whether to use key-value cache for perf_tuning", ) - hf_dataset_group.add_argument( - "--hf_model_name", - help="Huggingface model name used to load model configs from huggingface.", - ) - hf_dataset_group.add_argument( - "--batch_size", - type=int, - help="Batch size of the input data.", - ) - hf_dataset_group.add_argument( - "--seq_len", - type=int, - help="Sequence length to use for the input data.", - ) - hf_dataset_group.add_argument( - "--past_seq_len", - type=int, - help="Past sequence length to use for the input data.", - ) - hf_dataset_group.add_argument( - "--max_seq_len", - type=int, - help="Max sequence length to use for the input data.", - ) - hf_dataset_group.add_argument( - "--shared_kv", - action="store_true", - help="Whether to enable share kv cache in the input data.", - ) - hf_dataset_group.add_argument( - "--generative", - action="store_true", - help="Whether to enable generative mode in the input data.", - ) - hf_dataset_group.add_argument( - "--ort_past_key_name", - type=str, - help="Past key name for the input data.", - ) - hf_dataset_group.add_argument( - "--ort_past_value_name", - type=str, - help="Past value name for the input data.", - ) - hf_dataset_group.add_argument( - "--max_samples", - type=int, - help="Max samples to use for the input data.", - ) - hf_dataset_group.add_argument( - "--fields_no_batch", - nargs="*", - help="List of fields that should not be batched.", - ) # pass options pass_group = sub_parser.add_argument_group("pass options") - pass_group.add_argument( - "--device", - type=str, - default="cpu", - choices=["gpu", "cpu"], - help="Device to use for the model.", - ) + # accelerator options + PerfTuningCommand._add_accelerator_options(sub_parser) + pass_group.add_argument( "--cpu_cores", type=int, @@ -149,15 +70,6 @@ def register_subcommand(parser: ArgumentParser): action="store_true", help="Whether enable CUDA Graph for CUDA execution provider.", ) - pass_group.add_argument( - "--providers_list", - type=str, - nargs="*", - help=( - "List of execution providers to use for ONNX model. They are case sensitive. " - "If not provided, all available providers will be used." - ), - ) pass_group.add_argument( "--execution_mode_list", type=int, nargs="*", help="Parallelism list between operators." 
) @@ -203,7 +115,7 @@ def register_subcommand(parser: ArgumentParser): ) # remote options - add_remote_options(sub_parser) + PerfTuningCommand._add_remote_options(sub_parser) sub_parser.set_defaults(func=PerfTuningCommand) @@ -277,7 +189,7 @@ def refine_args(self): def get_run_config(self, tempdir) -> Dict: template_config = PerfTuningCommand.perf_tuning_template() - template_config["input_model"] = get_input_model_config(self.args) + self._update_input_model_options(template_config) print(f"input_model: {template_config['input_model']}") perf_tuning_key = ("passes", "perf_tuning") @@ -296,12 +208,12 @@ def get_run_config(self, tempdir) -> Dict: to_replace.append((system_ep_key, self.args.providers_list)) config = deepcopy(template_config) - for k, v in to_replace: - if v is None: - continue - set_nested_dict_value(config, k, v) + self._update_accelerator_options(config) + self._update_remote_options(config, "perf-tuning", tempdir) - update_remote_option(config, self.args, "perf-tuning", tempdir) + for k, v in to_replace: + if v is not None: + set_nested_dict_value(config, k, v) config["log_severity_level"] = self.args.log_level return config @@ -314,12 +226,12 @@ def run(self): run_config["output_dir"] = tempdir output = olive_run(run_config) - if is_remote_run(self.args): + if self._is_remote_run(): # TODO(jambayk): point user to datastore with outputs or download outputs # both are not implemented yet return - if get_output_model_number(output) > 0: + if PerfTuningCommand._get_output_model_number(output) > 0: # need to improve the output structure of olive run output_path = Path(self.args.output_path) output_path.mkdir(parents=True, exist_ok=True) diff --git a/olive/cli/quantize.py b/olive/cli/quantize.py new file mode 100644 index 000000000..fc0641979 --- /dev/null +++ b/olive/cli/quantize.py @@ -0,0 +1,149 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +# ruff: noqa: T201 +# ruff: noqa: RUF012 + +import tempfile +from argparse import ArgumentParser +from copy import deepcopy +from pathlib import Path +from typing import Any, Dict + +from olive.cli.base import BaseOliveCLICommand +from olive.common.utils import hardlink_copy_dir, set_nested_dict_value, set_tempdir + + +class QuantizeCommand(BaseOliveCLICommand): + _CONFIG_TEMPLATE = { + "input_model": {"type": "HfModel", "load_kwargs": {"attn_implementation": "eager"}}, + "systems": { + "local_system": { + "type": "LocalSystem", + "accelerators": [{"device": "gpu", "execution_providers": ["CUDAExecutionProvider"]}], + } + }, + "data_configs": [ + { + "name": "default_data_config", + "type": "HuggingfaceContainer", + "load_dataset_config": {}, + "pre_process_data_config": {}, + "dataloader_config": {}, + "post_process_data_config": {}, + } + ], + "passes": { + "awq": {"type": "AutoAWQQuantizer"}, + "gptq": { + # Ref: https://github.com/AutoGPTQ/AutoGPTQ/pull/651/files + "type": "GptqQuantizer", + "data_config": "default_data_config", + }, + "quarot": { + "type": "QuaRot", + "w_rtn": True, + "rotate": True, + "w_bits": 4, + "a_bits": 4, + "k_bits": 4, + "v_bits": 4, + "calibration_data_config": None, + }, + }, + "pass_flows": [], + "cache_dir": "cache", + "output_dir": "models", + "host": "local_system", + "target": "local_system", + } + + @staticmethod + def register_subcommand(parser: ArgumentParser): + sub_parser = parser.add_parser( + "quantize", + help="Quantize the input model", + ) + + # model options + QuantizeCommand._add_input_model_options(sub_parser) + + # Logging options + QuantizeCommand._add_logging_options(sub_parser) + + sub_parser.add_argument( + "-o", + "--output_path", + type=str, + required=True, + help="Path to save quantized model weights.", + ) + sub_parser.add_argument( + "--tempdir", default=None, type=str, help="Root directory for tempfile directories and files" + ) + sub_parser.add_argument( + "--algorithms", + type=str, + nargs="*", + required=True, + choices=sorted(QuantizeCommand._CONFIG_TEMPLATE["passes"].keys()), + help="List of quantization algorithms to run.", + ) + + # dataset options + QuantizeCommand._add_dataset_options(sub_parser) + + # accelerator options + QuantizeCommand._add_accelerator_options(sub_parser) + + # remote options + QuantizeCommand._add_remote_options(sub_parser) + + sub_parser.set_defaults(func=QuantizeCommand) + + def _get_run_config(self, tempdir: str) -> Dict[str, Any]: + to_replace = [ + (("pass_flows"), [[name] for name in self.args.algorithms]), + (("output_dir"), tempdir), + ] + + config = deepcopy(QuantizeCommand._CONFIG_TEMPLATE) + self._update_input_model_options(config) + self._update_dataset_options(config) + self._update_accelerator_options(config) + config["log_severity_level"] = self.args.log_level + + for k, v in to_replace: + if v is not None: + set_nested_dict_value(config, k, v) + + return config + + def run(self): + from olive.workflows import run as olive_run + + set_tempdir(self.args.tempdir) + + with tempfile.TemporaryDirectory() as tempdir: + run_config = self._get_run_config(tempdir) + output = olive_run(run_config) + + if self._is_remote_run(): + # TODO(jambayk): point user to datastore with outputs or download outputs + # both are not implemented yet + return + + if QuantizeCommand._get_output_model_number(output) > 0: + # need to improve the output structure of olive run + output_path = Path(self.args.output_path) + 
output_path.mkdir(parents=True, exist_ok=True) + device_name = "gpu-cuda_model" if self.args.device == "gpu" else "cpu-cpu_model" + for algorithm_name in self.args.algorithms: + hardlink_copy_dir( + Path(tempdir) / algorithm_name / device_name / "model", output_path / algorithm_name + ) + print(f"Quantized models saved to {output_path.resolve()}") + else: + print("Failed to run quantize. Please set the log_level to 1 for more detailed logs.") diff --git a/olive/common/hf/mappings.py b/olive/common/hf/mappings.py index 723928a37..e13bf2d86 100644 --- a/olive/common/hf/mappings.py +++ b/olive/common/hf/mappings.py @@ -70,3 +70,16 @@ "llama": "gpt2", "roberta": "bert", } + +MODEL_OUTSIDE_LAYER_MODULES = { + "phi3": ["model.embed_tokens", "embed_dropout", "model.norm"], +} + +MODEL_INSIDE_LAYER_MODULES = { + "phi3": [ + ["self_attn.qkv_proj"], + ["self_attn.o_proj"], + ["mlp.gate_up_proj"], + ["mlp.down_proj"], + ] +} diff --git a/olive/passes/pytorch/gptq.py b/olive/passes/pytorch/gptq.py index b39bae832..02e3936c3 100644 --- a/olive/passes/pytorch/gptq.py +++ b/olive/passes/pytorch/gptq.py @@ -11,6 +11,7 @@ import torch from olive.common.config_utils import validate_config +from olive.common.hf.mappings import MODEL_INSIDE_LAYER_MODULES, MODEL_OUTSIDE_LAYER_MODULES from olive.data.config import DataConfig from olive.hardware.accelerator import AcceleratorSpec, Device from olive.model import HfModelHandler, PyTorchModelHandler @@ -166,20 +167,26 @@ def _run_for_config( def get_onnx_quant_linear(*args, **kwargs): return QuantLinear - if hasattr(pytorch_model, "config") and pytorch_model.config.model_type in GPTQ_CAUSAL_LM_MODEL_MAP: - model_type = pytorch_model.config.model_type - model_class = GPTQ_CAUSAL_LM_MODEL_MAP[model_type] - quantized_model = model_class(pytorch_model, False, quantize_config) - else: - quantized_model = BaseGPTQForCausalLM(pytorch_model, False, quantize_config) - if not (config["layers_block_name"] and config["outside_layer_modules"] and config["inside_layer_modules"]): - raise ValueError( - "Can't get layers_block_name to quantize automatically, " - "please set layers_block_name, outside_layer_modules and inside_layer_modules in config." 
-                )
-            quantized_model.layers_block_name = config["layers_block_name"]
+        model_type = pytorch_model.config.model_type if hasattr(pytorch_model, "config") else ""
+        model_class = GPTQ_CAUSAL_LM_MODEL_MAP.get(model_type, BaseGPTQForCausalLM)
+        quantized_model = model_class(pytorch_model, False, quantize_config)
+
+        if config["outside_layer_modules"]:
             quantized_model.outside_layer_modules = config["outside_layer_modules"]
+        elif model_type in MODEL_OUTSIDE_LAYER_MODULES:
+            quantized_model.outside_layer_modules = MODEL_OUTSIDE_LAYER_MODULES[model_type]
+        else:
+            raise ValueError("Can't get outside_layer_modules to quantize automatically, please provide it in config.")
+
+        if config["inside_layer_modules"]:
             quantized_model.inside_layer_modules = config["inside_layer_modules"]
+        elif model_type in MODEL_INSIDE_LAYER_MODULES:
+            quantized_model.inside_layer_modules = MODEL_INSIDE_LAYER_MODULES[model_type]
+        else:
+            raise ValueError("Can't get inside_layer_modules to quantize automatically, please provide it in config.")
+
+        if config["layers_block_name"]:
+            quantized_model.layers_block_name = config["layers_block_name"]
 
         import auto_gptq
 
diff --git a/test/unit_test/cli/test_base.py b/test/unit_test/cli/test_base.py
index 04a609938..ac4504512 100644
--- a/test/unit_test/cli/test_base.py
+++ b/test/unit_test/cli/test_base.py
@@ -8,7 +8,16 @@
 
 import pytest
 
-from olive.cli.base import get_input_model_config
+from olive.cli.base import BaseOliveCLICommand
+
+
+class MockCommand(BaseOliveCLICommand):
+    @staticmethod
+    def register_subcommand(parser):
+        pass
+
+    def run(self):
+        pass
 
 
 @pytest.mark.parametrize(
@@ -218,10 +227,11 @@ def has_function_side_effect(arg):
     mock_instance.has_function.side_effect = has_function_side_effect
 
     # execute
-    config = get_input_model_config(args)
+    MockCommand(None, args)._update_input_model_options(config)
 
     # assert
-    assert config == expected_config
+    assert "input_model" in config
+    assert config["input_model"] == expected_config
 
 
 @patch("olive.cli.base.UserModuleLoader")
@@ -238,7 +248,7 @@ def test_insert_input_model_pt_model_missing_loader(MockUserModuleLoader):
 
     # execute and assert
     with pytest.raises(ValueError, match="Either _model_loader or model_name_or_path is required for PyTorch model."):
-        get_input_model_config(args)
+        MockCommand(None, args)._update_input_model_options({})
 
 
 def test_insert_input_model_invalid_hf_model_name():
@@ -253,7 +263,7 @@ def test_insert_input_model_invalid_hf_model_name():
 
     # execute and assert
     with pytest.raises(ValueError, match="invalid-name is not a valid Huggingface model name."):
-        get_input_model_config(args)
+        MockCommand(None, args)._update_input_model_options({})
 
 
 def test_insert_input_model_cli_output_model():
@@ -269,7 +279,9 @@ def test_insert_input_model_cli_output_model():
     expected_config = {"type": "PyTorchModel", "model_path": "model_path"}
 
     # execute
-    config = get_input_model_config(args)
+    config = {}
+    MockCommand(None, args)._update_input_model_options(config)
 
     # assert
-    assert config == expected_config
+    assert "input_model" in config
+    assert config["input_model"] == expected_config
diff --git a/test/unit_test/cli/test_cli.py b/test/unit_test/cli/test_cli.py
index 4a31b678b..e5b734812 100644
--- a/test/unit_test/cli/test_cli.py
+++ b/test/unit_test/cli/test_cli.py
@@ -250,5 +250,49 @@ def test_cloud_cache_command(mock_container_client, test_set):
     mock_container_client().delete_blob.assert_called_once()
 
 
+@pytest.mark.parametrize("algorithm_names", [{"awq"}, {"awq", "gptq"}])
+@patch("olive.workflows.run")
+@patch("olive.cli.finetune.tempfile.TemporaryDirectory") +@patch("huggingface_hub.repo_exists") +def test_quantize_command(mock_repo_exists, mock_tempdir, mock_run, algorithm_names, tmp_path): + # some directories + tmpdir = tmp_path / "tmpdir" + tmpdir.mkdir() + + output_dir = tmp_path / "output_dir" + + # setup + mock_repo_exists.return_value = True + mock_tempdir.return_value = tmpdir.resolve() + mock_run.return_value = {"output_dir": Footprint(nodes={"dummy_output": "dummy_output"})} + + for algo_name in algorithm_names: + workflow_output_dir = tmpdir / algo_name / "cpu-cpu_model" / "model" + workflow_output_dir.mkdir(parents=True) + dummy_model = workflow_output_dir / "dummy_model" + with dummy_model.open("w") as f: + f.write("dummy_model") + + # setup + command_args = [ + "quantize", + "-m", + "dummy_model", + "-d", + "dummy_dataset", + "--algorithms", + *algorithm_names, + "-o", + str(output_dir), + ] + + # execute + cli_main(command_args) + + config = mock_run.call_args[0][0] + assert config["input_model"]["model_path"] == "dummy_model" + assert {el.name for el in output_dir.iterdir()} == algorithm_names + + # TODO(anyone): Add tests for ManageAMLComputeCommand # Test for ExportAdaptersCommand is added as part of test/unit_test/passes/onnx/test_export_adapters.py
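
Usage note (outside the patch): a minimal sketch of driving the new `quantize` subcommand through the CLI entry point, mirroring `test_quantize_command` above. It assumes the launcher's `main` function is the same one the tests import as `cli_main`; the model, dataset, and output names below are placeholders, not values shipped with Olive.

    from olive.cli.launcher import main as cli_main

    # Run both the AWQ and GPTQ pass flows defined in QuantizeCommand._CONFIG_TEMPLATE;
    # each algorithm's model is copied into its own subdirectory of the output path.
    cli_main([
        "quantize",
        "-m", "my-org/my-model",              # placeholder HF repo id or local path
        "-d", "my-org/my-calib-dataset",      # placeholder dataset for the GPTQ data config
        "--algorithms", "awq", "gptq",
        "-o", "quantized-models",
    ])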