Skip to content

Commit

Permalink
Quantize: CLI command to quantize input model
Browse files Browse the repository at this point in the history
Usage:
  olive quantize                  \
    -m <model-name>               \
    --trust_remote_code           \
    --device <cpu|gpu|npu>        \
    --algorithms <awq,gptq>       \
    --data_name <data-name>       \
    --train_subset <subset-name>  \
    --batch_size <batch-size>     \
    --tempdir <temp-dir>          \
    -o <output-dir>
  • Loading branch information
shaahji committed Sep 18, 2024
1 parent d2bc1c4 commit f38fa1c
Show file tree
Hide file tree
Showing 8 changed files with 463 additions and 102 deletions.
218 changes: 214 additions & 4 deletions olive/cli/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from olive.cli.constants import CONDA_CONFIG
from olive.common.user_module_loader import UserModuleLoader
from olive.common.utils import hash_dict
from olive.common.utils import hash_dict, set_nested_dict_value, unescaped_str


class BaseOliveCLICommand(ABC):
Expand Down Expand Up @@ -189,6 +189,10 @@ def get_input_model_config(args) -> Union[str, Dict[str, str]]:
return _get_hf_input_model(args, model_name_or_path)


def update_input_model_options(args, config):
    """Store the input model description derived from ``args`` in ``config``.

    The value comes from ``get_input_model_config(args)`` and is written to the
    ``"input_model"`` key. ``config`` is modified in place; nothing is returned.
    """
    input_model = get_input_model_config(args)
    config["input_model"] = input_model


def add_logging_options(sub_parser):
log_group = sub_parser.add_argument_group("logging options")
log_group.add_argument(
Expand All @@ -197,6 +201,7 @@ def add_logging_options(sub_parser):
default=3,
help="Logging level. Default is 3. level 0: DEBUG, 1: INFO, 2: WARNING, 3: ERROR, 4: CRITICAL",
)
return log_group


def add_remote_options(sub_parser):
Expand Down Expand Up @@ -230,8 +235,10 @@ def add_remote_options(sub_parser):
help="The compute name to run the workflow on.",
)

return remote_group

def add_model_options(sub_parser):

def add_input_model_options(sub_parser):
model_group = sub_parser.add_argument_group("Model options")
model_group.add_argument(
"-m",
Expand All @@ -255,13 +262,14 @@ def add_model_options(sub_parser):
default=None,
help="The directory containing the model script file.",
)
return model_group


def is_remote_run(args):
    """Return True only when every remote-run option on ``args`` is truthy.

    The options checked are ``resource_group``, ``workspace_name`` and ``aml_compute``.
    """
    # Read all three attributes up front so a missing attribute always raises,
    # regardless of the values of the others.
    remote_settings = (args.resource_group, args.workspace_name, args.aml_compute)
    for setting in remote_settings:
        if not setting:
            return False
    return True


def update_remote_option(config, args, cli_action, tempdir):
def update_remote_options(config, args, cli_action, tempdir):
if args.resource_group or args.workspace_name or args.aml_compute:
if not is_remote_run(args):
raise ValueError("resource_group, workspace_name and aml_compute are required for remote workflow run.")
Expand All @@ -270,7 +278,7 @@ def update_remote_option(config, args, cli_action, tempdir):

try:
subscription_id = json.loads(subprocess.check_output("az account show", shell=True).decode("utf-8"))["id"]
print("Using Azure subscription ID: %s", subscription_id)
print(f"Using Azure subscription ID: {subscription_id}")

except subprocess.CalledProcessError:
print(
Expand Down Expand Up @@ -316,3 +324,205 @@ def update_model_config(model_config_path: Path, output_path: Path):
model_config_path = output_path / "model_config.json"
with open(model_config_path, "w") as f:
json.dump(model_config, f, indent=4)


def _str_to_bool(value):
    """Parse a command line token into a boolean.

    ``type=bool`` cannot be used with argparse for this purpose because
    ``bool("False")`` is ``True`` (every non-empty string is truthy).

    Raises:
        ValueError: if the token is not a recognised spelling of true/false.
            argparse turns this into a regular usage error.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "on", "1"):
        return True
    # The empty string stays falsy, as it was with ``bool("")``.
    if token in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise ValueError(f"Expected a boolean value, got {value!r}")


def add_dataset_options(sub_parser):
    """Register the "dataset options" argument group on ``sub_parser``.

    ``--data_name`` is required. ``--text_field`` and ``--text_template`` are
    registered in a mutually exclusive (optional) group.

    Returns:
        A ``(dataset_group, text_group)`` tuple: the argument group that holds the
        dataset options and the mutually exclusive group for the text options.
    """
    dataset_group = sub_parser.add_argument_group("dataset options")
    dataset_group.add_argument(
        "-d",
        "--data_name",
        type=str,
        required=True,
        help="The dataset name.",
    )
    dataset_group.add_argument("--train_subset", type=str, help="The subset to use for training.")
    dataset_group.add_argument("--eval_subset", type=str, help="The subset to use for evaluation.")
    # TODO(jambayk): currently only supports single file or list of files, support mapping
    dataset_group.add_argument(
        "--data_files", type=str, help="The dataset files. If multiple files, separate by comma."
    )
    dataset_group.add_argument("--train_split", type=str, default="train", help="The split to use for training.")
    dataset_group.add_argument(
        "--eval_split",
        default="",
        help="The dataset split to evaluate on.",
    )
    text_group = dataset_group.add_mutually_exclusive_group(required=False)
    text_group.add_argument(
        "--text_field",
        type=str,
        help="The text field to use for fine-tuning.",
    )
    text_group.add_argument(
        "--text_template",
        # using special string type to allow for escaped characters like \n
        type=unescaped_str,
        help=r"Template to generate text field from. E.g. '### Question: {prompt} \n### Answer: {response}'",
    )
    dataset_group.add_argument(
        "--max_seq_len",
        type=int,
        default=1024,
        help="Maximum sequence length for the data.",
    )
    dataset_group.add_argument(
        "--add_special_tokens",
        # Not ``type=bool``: that would turn "--add_special_tokens False" into True.
        type=_str_to_bool,
        default=False,
        help="Whether to add special tokens during preprocessing.",
    )
    dataset_group.add_argument(
        "--max_samples",
        type=int,
        default=256,
        help="Maximum samples to select from the dataset.",
    )
    dataset_group.add_argument(
        "--batch_size",
        type=int,
        default=1,
        help="Batch size.",
    )

    return dataset_group, text_group


def update_dataset_options(args, config):
    """Copy the dataset-related CLI arguments into the first entry of ``config["data_configs"]``.

    Each argument is written with ``set_nested_dict_value`` under one of the
    ``load_dataset_config``, ``pre_process_data_config`` or ``dataloader_config``
    sections. Arguments whose value is ``None`` are skipped.
    """
    first_data_config = ("data_configs", 0)
    load_section = (*first_data_config, "load_dataset_config")
    preprocess_section = (*first_data_config, "pre_process_data_config")
    dataloader_section = (*first_data_config, "dataloader_config")

    # Insertion order of the mapping is the order in which values are applied.
    overrides = {
        (*load_section, "data_name"): args.data_name,
        (*load_section, "split"): args.train_split,
        (*load_section, "subset"): args.train_subset,
        # a comma separated string becomes a list of files; empty/missing stays unset
        (*load_section, "data_files"): args.data_files.split(",") if args.data_files else None,
        (*preprocess_section, "text_cols"): args.text_field,
        (*preprocess_section, "text_template"): args.text_template,
        (*preprocess_section, "max_seq_len"): args.max_seq_len,
        (*preprocess_section, "add_special_tokens"): args.add_special_tokens,
        (*preprocess_section, "max_samples"): args.max_samples,
        (*dataloader_section, "batch_size"): args.batch_size,
    }

    for key_path, new_value in overrides.items():
        if new_value is None:
            continue
        set_nested_dict_value(config, key_path, new_value)


def add_hf_dataset_options(sub_parser):
    """Register the huggingface dataset argument group on ``sub_parser``.

    All options are optional and have no explicit default, except the two
    ``store_true`` flags (``--shared_kv`` and ``--generative``).

    Returns:
        The argument group that was created.
    """
    group_title = (
        "huggingface dataset options, if dataset options are not provided, "
        "user should provide the following options to modify the default data config. "
        "Please refer to olive.data.container.TransformersTokenDummyDataContainer for more details."
    )
    hf_dataset_group = sub_parser.add_argument_group(group_title)

    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--hf_model_name", {"help": "Huggingface model name used to load model configs from huggingface."}),
        ("--batch_size", {"type": int, "help": "Batch size of the input data."}),
        ("--seq_len", {"type": int, "help": "Sequence length to use for the input data."}),
        ("--past_seq_len", {"type": int, "help": "Past sequence length to use for the input data."}),
        ("--max_seq_len", {"type": int, "help": "Max sequence length to use for the input data."}),
        ("--shared_kv", {"action": "store_true", "help": "Whether to enable share kv cache in the input data."}),
        ("--generative", {"action": "store_true", "help": "Whether to enable generative mode in the input data."}),
        ("--ort_past_key_name", {"type": str, "help": "Past key name for the input data."}),
        ("--ort_past_value_name", {"type": str, "help": "Past value name for the input data."}),
        # TODO(all): Argument conflicting with use of the same name in model options
        # ("--trust_remote_code", {"action": "store_true", "help": "Whether to trust remote code in the input data."}),
        ("--max_samples", {"type": int, "help": "Max samples to use for the input data."}),
        ("--fields_no_batch", {"nargs": "*", "help": "List of fields that should not be batched."}),
    )
    for flag, options in option_specs:
        hf_dataset_group.add_argument(flag, **options)

    return hf_dataset_group


def add_accelerator_options(sub_parser):
    """Register the "accelerator group" arguments on ``sub_parser``.

    Adds ``--device`` (one of ``gpu``, ``cpu``, ``npu``; default ``cpu``) and
    ``--providers_list`` (zero or more execution provider names, restricted to
    the choices listed below).

    Returns:
        The argument group that was created.
    """
    accelerator_group = sub_parser.add_argument_group("accelerator group")

    accelerator_group.add_argument(
        "--device",
        type=str,
        default="cpu",
        choices=["gpu", "cpu", "npu"],
        help="Device to use for the model.",
    )

    accelerator_group.add_argument(
        "--providers_list",
        type=str,
        nargs="*",
        # Each provider must be its own entry: a fused "QNN...ROCM..." string makes
        # both providers impossible to select from the command line.
        choices=[
            "CUDAExecutionProvider",
            "DmlExecutionProvider",
            "JsExecutionProvider",
            "MIGraphXExecutionProvider",
            "OpenVINOExecutionProvider",
            "QNNExecutionProvider",
            "ROCMExecutionProvider",
            "TensorrtExecutionProvider",
        ],
        help=(
            "List of execution providers to use for ONNX model. They are case sensitive. "
            "If not provided, all available providers will be used."
        ),
    )

    return accelerator_group


def update_accelerator_options(args, config):
    """Write the accelerator CLI arguments into the local system section of ``config``.

    ``args.device`` is stored unless it is ``None``; ``args.providers_list`` is stored
    as ``execution_providers`` only when it is non-empty. Both land under the first
    entry of ``systems.local_system.accelerators``.
    """
    accelerator_path = ("systems", "local_system", "accelerators", 0)

    # Read both arguments before touching ``config``.
    device, providers = args.device, args.providers_list

    if device is not None:
        set_nested_dict_value(config, (*accelerator_path, "device"), device)
    if providers:
        set_nested_dict_value(config, (*accelerator_path, "execution_providers"), providers)
8 changes: 4 additions & 4 deletions olive/cli/capture_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

from olive.cli.base import (
BaseOliveCLICommand,
add_input_model_options,
add_logging_options,
add_model_options,
add_remote_options,
get_input_model_config,
get_output_model_number,
is_remote_run,
update_model_config,
update_remote_option,
update_remote_options,
)
from olive.common.utils import IntEnumBase, hardlink_copy_dir, set_nested_dict_value, set_tempdir

Expand All @@ -46,7 +46,7 @@ def register_subcommand(parser: ArgumentParser):
add_logging_options(sub_parser)

# model options
add_model_options(sub_parser)
add_input_model_options(sub_parser)

sub_parser.add_argument(
"--device",
Expand Down Expand Up @@ -234,7 +234,7 @@ def get_run_config(self, tempdir: str) -> Dict:
if value is None:
continue
set_nested_dict_value(config, keys, value)
update_remote_option(config, self.args, "capture-onnx-graph", tempdir)
update_remote_options(config, self.args, "capture-onnx-graph", tempdir)

return config

Expand Down
Loading

0 comments on commit f38fa1c

Please sign in to comment.