Quantize: CLI command to quantize input model
Usage:
 olive quantize --m <model-name> --device <cpu|gpu> --algorithms <awq,gptq> --data_config_path <file-name> -o <output-folder>
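
For example, a hypothetical invocation following the usage above (the model name, data config file, and output folder are placeholder values, not taken from this commit):

 olive quantize --m microsoft/phi-2 --device gpu --algorithms awq --data_config_path quant_data_config.json -o ./quantized-model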

A few other code improvements:
* Moved the global functions in cli/base.py to be static members of
  BaseOliveCLICommand, avoiding repeated imports in each CLI command
  implementation. These functions are only usable in the context of a CLI
  command implementation anyway.
* Added new helper functions (add_data_config_options, add_hf_dataset_options,
  and add_accelerator_options) to BaseOliveCLICommand to avoid code duplication
  and standardize options across the different CLI command implementations
  (see the sketch below).
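
A minimal sketch of the resulting pattern, assuming argparse-based subcommands; the helper bodies and option names below are illustrative, not copied from the actual implementation:

from argparse import ArgumentParser


class BaseOliveCLICommand:
    # Shared option helpers live on the base class, so each subcommand needs a
    # single import (BaseOliveCLICommand) instead of a list of free functions.
    @staticmethod
    def _add_logging_options(sub_parser: ArgumentParser):
        sub_parser.add_argument("--log_level", type=int, default=3, help="Logging level.")

    @staticmethod
    def add_accelerator_options(sub_parser: ArgumentParser):
        # hypothetical body; the real helper may expose more options
        sub_parser.add_argument("--device", type=str, choices=["cpu", "gpu"], help="Target device.")


class QuantizeCommand(BaseOliveCLICommand):
    @staticmethod
    def register_subcommand(parser):
        sub_parser = parser.add_parser("quantize", help="Quantize the input model.")
        QuantizeCommand._add_logging_options(sub_parser)
        QuantizeCommand.add_accelerator_options(sub_parser)
        sub_parser.set_defaults(func=QuantizeCommand)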
shaahji committed Sep 17, 2024
1 parent 9fa2604 commit 7cc7299
Showing 11 changed files with 791 additions and 490 deletions.
760 changes: 494 additions & 266 deletions olive/cli/base.py

Large diffs are not rendered by default.

28 changes: 9 additions & 19 deletions olive/cli/capture_onnx.py
@@ -12,17 +12,7 @@
 from pathlib import Path
 from typing import ClassVar, Dict
 
-from olive.cli.base import (
-    BaseOliveCLICommand,
-    add_logging_options,
-    add_model_options,
-    add_remote_options,
-    get_input_model_config,
-    get_output_model_number,
-    is_remote_run,
-    update_model_config,
-    update_remote_option,
-)
+from olive.cli.base import BaseOliveCLICommand
 from olive.common.utils import IntEnumBase, hardlink_copy_dir, set_nested_dict_value, set_tempdir


@@ -43,10 +33,10 @@ def register_subcommand(parser: ArgumentParser):
             help=("Capture ONNX graph using PyTorch Exporter or Model Builder from the Huggingface model."),
         )
 
-        add_logging_options(sub_parser)
+        CaptureOnnxGraphCommand._add_logging_options(sub_parser)
 
         # model options
-        add_model_options(sub_parser)
+        CaptureOnnxGraphCommand._add_input_model_options(sub_parser)
 
         sub_parser.add_argument(
             "--device",
@@ -154,7 +144,7 @@ def register_subcommand(parser: ArgumentParser):
         )
 
         # remote options
-        add_remote_options(sub_parser)
+        CaptureOnnxGraphCommand._add_remote_options(sub_parser)
 
         sub_parser.set_defaults(func=CaptureOnnxGraphCommand)

@@ -168,12 +158,12 @@ def run(self):
 
         output = olive_run(run_config)
 
-        if is_remote_run(self.args):
+        if self._is_remote_run():
             # TODO(jambayk): point user to datastore with outputs or download outputs
             # both are not implemented yet
             return
 
-        if get_output_model_number(output) > 0:
+        if CaptureOnnxGraphCommand._get_output_model_number(output) > 0:
             output_path = Path(self.args.output_path)
             output_path.mkdir(parents=True, exist_ok=True)
             pass_name = "m" if self.args.use_model_builder else "c"
@@ -184,14 +174,14 @@
             else:
                 shutil.move(str(source_path.with_suffix(".onnx")), output_path)
 
-            update_model_config(source_path.with_suffix(".json"), output_path)
+            CaptureOnnxGraphCommand._update_model_config(source_path.with_suffix(".json"), output_path)
             print(f"ONNX Model is saved to {output_path.resolve()}")
         else:
             print("Failed to run capture-onnx-graph. Please set the log_level to 1 for more detailed logs.")
 
     def get_run_config(self, tempdir: str) -> Dict:
         config = deepcopy(TEMPLATE)
-        config["input_model"] = get_input_model_config(self.args)
+        self._update_input_model_options(config)
 
         to_replace = [
             ("output_dir", tempdir),
@@ -234,7 +224,7 @@ def get_run_config(self, tempdir: str) -> Dict:
             if value is None:
                 continue
             set_nested_dict_value(config, keys, value)
-        update_remote_option(config, self.args, "capture-onnx-graph", tempdir)
+        self._update_remote_options(config, "capture-onnx-graph", tempdir)
 
         return config

88 changes: 18 additions & 70 deletions olive/cli/finetune.py
@@ -11,18 +11,8 @@
 from pathlib import Path
 from typing import ClassVar, Dict
 
-from olive.cli.base import (
-    BaseOliveCLICommand,
-    add_logging_options,
-    add_model_options,
-    add_remote_options,
-    get_input_model_config,
-    get_output_model_number,
-    is_remote_run,
-    update_model_config,
-    update_remote_option,
-)
-from olive.common.utils import hardlink_copy_dir, set_nested_dict_value, set_tempdir, unescaped_str
+from olive.cli.base import BaseOliveCLICommand
+from olive.common.utils import hardlink_copy_dir, set_nested_dict_value, set_tempdir
 
 
 class FineTuneCommand(BaseOliveCLICommand):
@@ -38,7 +28,7 @@ def register_subcommand(parser: ArgumentParser):
             ),
         )
 
-        add_logging_options(sub_parser)
+        FineTuneCommand._add_logging_options(sub_parser)
 
         # TODO(jambayk): option to list/install required dependencies?
        sub_parser.add_argument(
@@ -50,7 +40,7 @@ def register_subcommand(parser: ArgumentParser):
         )
 
         # Model options
-        add_model_options(sub_parser)
+        FineTuneCommand._add_input_model_options(sub_parser)
 
         sub_parser.add_argument(
             "--torch_dtype",
@@ -64,42 +54,8 @@ def register_subcommand(parser: ArgumentParser):
         )
 
         # Dataset options
-        dataset_group = sub_parser.add_argument_group("dataset options")
-        dataset_group.add_argument(
-            "-d",
-            "--data_name",
-            type=str,
-            required=True,
-            help="The dataset name.",
-        )
-        # TODO(jambayk): currently only supports single file or list of files, support mapping
-        dataset_group.add_argument(
-            "--data_files", type=str, help="The dataset files. If multiple files, separate by comma."
-        )
-        dataset_group.add_argument("--train_split", type=str, default="train", help="The split to use for training.")
-        dataset_group.add_argument(
-            "--eval_split",
-            default="",
-            help="The dataset split to evaluate on.",
-        )
-        text_group = dataset_group.add_mutually_exclusive_group(required=True)
-        text_group.add_argument(
-            "--text_field",
-            type=str,
-            help="The text field to use for fine-tuning.",
-        )
-        text_group.add_argument(
-            "--text_template",
-            # using special string type to allow for escaped characters like \n
-            type=unescaped_str,
-            help=r"Template to generate text field from. E.g. '### Question: {prompt} \n### Answer: {response}'",
-        )
-        dataset_group.add_argument(
-            "--max_seq_len",
-            type=int,
-            default=1024,
-            help="Maximum sequence length for the data.",
-        )
+        FineTuneCommand._add_dataset_options(sub_parser)
 
         # LoRA options
         lora_group = sub_parser.add_argument_group("lora options")
         lora_group.add_argument(
@@ -135,7 +91,7 @@ def register_subcommand(parser: ArgumentParser):
         sub_parser.add_argument("--clean", action="store_true", help="Run in a clean cache directory")
 
         # remote options
-        add_remote_options(sub_parser)
+        FineTuneCommand._add_remote_options(sub_parser)
 
         sub_parser.set_defaults(func=FineTuneCommand)

@@ -149,19 +105,19 @@ def run(self):
 
         output = olive_run(run_config)
 
-        if is_remote_run(self.args):
+        if self._is_remote_run():
             # TODO(jambayk): point user to datastore with outputs or download outputs
             # both are not implemented yet
             return
 
-        if get_output_model_number(output) > 0:
+        if FineTuneCommand._get_output_model_number(output) > 0:
             # need to improve the output structure of olive run
             output_path = Path(self.args.output_path)
             output_path.mkdir(parents=True, exist_ok=True)
             source_path = Path(tempdir) / "-".join(run_config["passes"].keys()) / "gpu-cuda_model"
             hardlink_copy_dir(source_path, output_path)
 
-            update_model_config(source_path.with_suffix(".json"), output_path)
+            FineTuneCommand._update_model_config(source_path.with_suffix(".json"), output_path)
             print(f"Model and adapters saved to {output_path.resolve()}")
         else:
             print("Failed to run finetune. Please set the log_level to 1 for more detailed logs.")
@@ -182,19 +138,9 @@ def parse_training_args(self) -> Dict:
         return {k: v for k, v in vars(training_args).items() if k in arg_keys}
 
     def get_run_config(self, tempdir: str) -> Dict:
-        load_key = ("data_configs", 0, "load_dataset_config")
-        preprocess_key = ("data_configs", 0, "pre_process_data_config")
-
         finetune_key = ("passes", "f")
         to_replace = [
-            ((*load_key, "data_name"), self.args.data_name),
-            ((*load_key, "split"), self.args.train_split),
-            (
-                (*load_key, "data_files"),
-                self.args.data_files.split(",") if self.args.data_files else None,
-            ),
-            ((*preprocess_key, "text_cols"), self.args.text_field),
-            ((*preprocess_key, "text_template"), self.args.text_template),
-            ((*preprocess_key, "max_seq_len"), self.args.max_seq_len),
             ((*finetune_key, "type"), self.args.method),
             ((*finetune_key, "torch_dtype"), self.args.torch_dtype),
             ((*finetune_key, "training_args"), self.parse_training_args()),
@@ -210,12 +156,12 @@ def get_run_config(self, tempdir: str) -> Dict:
             to_replace.append(((*finetune_key, "target_modules"), self.args.target_modules.split(",")))
 
         config = deepcopy(TEMPLATE)
-        config["input_model"] = get_input_model_config(self.args)
+        self._update_input_model_options(config)
+        self._update_dataset_options(config)
 
         for keys, value in to_replace:
-            if value is None:
-                continue
-            set_nested_dict_value(config, keys, value)
+            if value is not None:
+                set_nested_dict_value(config, keys, value)
 
         if self.args.eval_split:
             eval_data_config = deepcopy(config["data_configs"][0])
@@ -227,7 +173,7 @@ def get_run_config(self, tempdir: str) -> Dict:
         if not self.args.use_ort_genai:
             del config["passes"]["m"]
 
-        update_remote_option(config, self.args, "finetune", tempdir)
+        self._update_remote_options(config, "finetune", tempdir)
         config["log_severity_level"] = self.args.log_level
 
         return config
@@ -248,6 +194,8 @@
                 "type": "HuggingfaceContainer",
                 "load_dataset_config": {},
                 "pre_process_data_config": {},
+                "dataloader_config": {},
+                "post_process_data_config": {},
             }
         ],
         "passes": {
2 changes: 2 additions & 0 deletions olive/cli/launcher.py
@@ -13,6 +13,7 @@
 from olive.cli.finetune import FineTuneCommand
 from olive.cli.manage_aml_compute import ManageAMLComputeCommand
 from olive.cli.perf_tuning import PerfTuningCommand
+from olive.cli.quantize import QuantizeCommand
 from olive.cli.run import WorkflowRunCommand


@@ -33,6 +34,7 @@ def get_cli_parser(called_as_console_script: bool = True) -> ArgumentParser:
     ConfigureQualcommSDKCommand.register_subcommand(commands_parser)
     ManageAMLComputeCommand.register_subcommand(commands_parser)
     PerfTuningCommand.register_subcommand(commands_parser)
+    QuantizeCommand.register_subcommand(commands_parser)
     CloudCacheCommand.register_subcommand(commands_parser)
 
     return parser
18 changes: 7 additions & 11 deletions olive/cli/manage_aml_compute.py
@@ -68,7 +68,7 @@ def run(self):
         )
 
         if self.args.create:
-            print("Creating compute %s...", self.args.compute_name)
+            print(f"Creating compute {self.args.compute_name}...")
             if self.args.vm_size is None:
                 raise ValueError("vm_size must be provided if operation is create")
             if self.args.location is None:
@@ -84,19 +84,15 @@ def run(self):
             )
             ml_client.begin_create_or_update(cluster_basic).result()
             print(
-                "Successfully created compute: %s at %s with vm_size:%s and "
-                "min_nodes=%d and max_nodes=%d and idle_time_before_scale_down=%d",
-                self.args.compute_name,
-                self.args.location,
-                self.args.vm_size,
-                self.args.min_nodes,
-                self.args.max_nodes,
-                self.args.idle_time_before_scale_down,
+                f"Successfully created compute: {self.args.compute_name} at {self.args.location} "
+                f"with vm_size:{self.args.vm_size} and min_nodes={self.args.min_nodes} and "
+                f"max_nodes={self.args.max_nodes} and "
+                f"idle_time_before_scale_down={self.args.idle_time_before_scale_down}"
             )
         elif self.args.delete:
-            print("Deleting compute %s...", self.args.compute_name)
+            print(f"Deleting compute {self.args.compute_name}...")
             ml_client.compute.begin_delete(self.args.compute_name).wait()
-            print("Successfully deleted compute: %s", self.args.compute_name)
+            print(f"Successfully deleted compute: {self.args.compute_name}")
 
     @classmethod
     def get_ml_client(cls, aml_config_path: str, subscription_id: str, resource_group: str, workspace_name: str):