[LoRA] Add transforms to inject/optimize LoRA
Lunderberg committed Feb 28, 2024
1 parent 10b5259 commit 13755c8
Showing 1 changed file: mlc_llm/core.py (27 additions, 0 deletions).
@@ -1501,15 +1501,35 @@ def build_model_from_args(args: argparse.Namespace):

    transform_seq = []

    if args.lora is not None:
        prefill_name = "prefill" if "prefill" in mod else "prefill_with_embed"
        transform_seq.extend(
            [
                remove_decode_func,
                relax.transform.DeadCodeElimination(),
                # TODO: Shouldn't assume that we know which parameters
                # should be lora-tuned at this stage. Maybe start with
                # everything being lora-ized, then make specialized
                # versions with BindParams for common cases?
                lora_optimization_pipeline(mod[prefill_name].params, args.lora),
            ]
        )

    transform_seq.append(optimize_mod_pipeline(args, model_config))

    if args.lora is not None:
        transform_seq.append(bundle_lora_params)

    transform_seq.append(
        tvm.ir.transform.ApplyPassToFunction(
            relax.transform.BundleModelParams("base_params"),
            "(?!transform_params).*",
        )
    )

    if args.lora is not None:
        transform_seq.append(reorder_lora_params_after_base_model_params)

    transform_seq.append(
        tvm.ir.transform.ApplyPassToFunction(
            tvm.ir.transform.Sequential(
@@ -1523,6 +1543,13 @@ def build_model_from_args(args: argparse.Namespace):
        )
    )

    if args.lora is not None:
        # TODO(Lunderberg): Replace this with
        # transform.CheckForSpecialCase once
        # https://github.com/apache/tvm/pull/16457 is fully implemented
        # and landed.
        transform_seq.append(auto_generate_decode_func)

    mod = tvm.ir.transform.Sequential(transform_seq, name="OptimizeMLCModel")(mod)

    mod.show(
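The diff above only wires the LoRA passes into the build pipeline; the helpers it references (remove_decode_func, lora_optimization_pipeline, bundle_lora_params, reorder_lora_params_after_base_model_params, auto_generate_decode_func) are presumably defined elsewhere in the branch. As a rough illustration of the rewrite a LoRA injection pass performs, here is a minimal NumPy sketch; the names, shapes, and scaling are hypothetical, not taken from this commit:

    # Minimal NumPy sketch of a lora-ized matmul.  All names here are
    # illustrative; the commit's actual passes operate on Relax IR.
    import numpy as np

    def base_matmul(x, w):
        # Original computation: y = x @ W, with W frozen.
        return x @ w

    def lora_matmul(x, w, lora_a, lora_b, alpha):
        # LoRA-injected computation: y = x @ W + (x @ A^T) @ B^T * (alpha / rank).
        # A (rank, in) and B (out, rank) keep the trainable update low-rank.
        rank = lora_a.shape[0]
        return x @ w + (x @ lora_a.T) @ lora_b.T * (alpha / rank)

    rng = np.random.default_rng(0)
    x = rng.normal(size=(2, 16))       # batch of activations
    w = rng.normal(size=(16, 32))      # frozen base weight
    lora_a = rng.normal(size=(4, 16))  # rank-4 down-projection
    lora_b = np.zeros((32, 4))         # up-projection, zero-initialized

    # With B = 0 the lora-ized function reproduces the base function exactly,
    # which is why injecting LoRA before training is behavior-preserving.
    assert np.allclose(lora_matmul(x, w, lora_a, lora_b, alpha=8.0),
                       base_matmul(x, w))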

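The ApplyPassToFunction wrappers scope a pass to the functions whose names match a regex. Assuming full-match semantics, the negative lookahead in "(?!transform_params).*" matches every function name except those beginning with transform_params, so BundleModelParams("base_params") is applied everywhere but the weight-preprocessing function. A quick check of the pattern with Python's re module:

    import re

    pattern = re.compile("(?!transform_params).*")

    # The lookahead fails only for names starting with "transform_params",
    # so those functions are excluded from the pass.
    assert pattern.fullmatch("prefill_with_embed") is not None
    assert pattern.fullmatch("decode") is not None
    assert pattern.fullmatch("transform_params") is None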
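
Finally, the accumulated transform_seq is fused into a single named pass and applied to the module in one shot. A minimal sketch of the same pattern, using stock Relax passes purely as placeholders:

    import tvm
    from tvm import relax

    pipeline = tvm.ir.transform.Sequential(
        [
            relax.transform.DeadCodeElimination(),
            relax.transform.FoldConstant(),
        ],
        name="ExamplePipeline",
    )

    # Applying the fused pass to a Relax IRModule returns the rewritten module:
    # mod = pipeline(mod)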