From 13755c8393264f2a3c0be84c47fe5c930abbbaf0 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Thu, 8 Feb 2024 20:31:33 +0000
Subject: [PATCH] [LoRA] Add transforms to inject/optimize LoRA

---
 mlc_llm/core.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/mlc_llm/core.py b/mlc_llm/core.py
index 1c10aec9fe..86984995df 100644
--- a/mlc_llm/core.py
+++ b/mlc_llm/core.py
@@ -1501,8 +1501,25 @@ def build_model_from_args(args: argparse.Namespace):

     transform_seq = []

+    if args.lora is not None:
+        prefill_name = "prefill" if "prefill" in mod else "prefill_with_embed"
+        transform_seq.extend(
+            [
+                remove_decode_func,
+                relax.transform.DeadCodeElimination(),
+                # TODO: Shouldn't assume that we know which parameters
+                # should be lora-tuned at this stage.  Maybe start with
+                # everything being lora-ized, then make specialized
+                # versions with BindParams for common cases?
+                lora_optimization_pipeline(mod[prefill_name].params, args.lora),
+            ]
+        )
+
     transform_seq.append(optimize_mod_pipeline(args, model_config))

+    if args.lora is not None:
+        transform_seq.append(bundle_lora_params)
+
     transform_seq.append(
         tvm.ir.transform.ApplyPassToFunction(
             relax.transform.BundleModelParams("base_params"),
@@ -1510,6 +1527,9 @@ def build_model_from_args(args: argparse.Namespace):
         )
     )

+    if args.lora is not None:
+        transform_seq.append(reorder_lora_params_after_base_model_params)
+
     transform_seq.append(
         tvm.ir.transform.ApplyPassToFunction(
             tvm.ir.transform.Sequential(
@@ -1523,6 +1543,13 @@ def build_model_from_args(args: argparse.Namespace):
         )
     )

+    if args.lora is not None:
+        # TODO(Lunderberg): Replace this with
+        # transform.CheckForSpecialCase once
+        # https://github.com/apache/tvm/pull/16457 is fully implemented
+        # and landed.
+        transform_seq.append(auto_generate_decode_func)
+
     mod = tvm.ir.transform.Sequential(transform_seq, name="OptimizeMLCModel")(mod)

     mod.show(
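
Note (not part of the patch): the passes referenced above (remove_decode_func, lora_optimization_pipeline, bundle_lora_params, reorder_lora_params_after_base_model_params, auto_generate_decode_func) are assumed to be provided elsewhere; they are not part of this diff. For context, below is a minimal sketch of the kind of rewrite a pass like lora_optimization_pipeline is presumably responsible for: each lora-tuned matmul gains a low-rank update computed from two extra parameters. The parameter names (lora_A, lora_B), shapes, and rank 16 are illustrative assumptions, not taken from the patch.

from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class AfterLoraInjection:
    @R.function
    def prefill(
        x: R.Tensor((1, 4096), "float16"),
        base_weight: R.Tensor((4096, 4096), "float16"),
        # Hypothetical LoRA parameters: a rank-16 decomposition of the weight delta.
        lora_A: R.Tensor((4096, 16), "float16"),
        lora_B: R.Tensor((16, 4096), "float16"),
    ) -> R.Tensor((1, 4096), "float16"):
        with R.dataflow():
            # Base projection, unchanged from the original model.
            base = R.matmul(x, base_weight)
            # Low-rank update: x @ lora_A @ lora_B, mathematically the same as
            # adding (lora_A @ lora_B) onto base_weight before the matmul.
            delta = R.matmul(R.matmul(x, lora_A), lora_B)
            out = R.add(base, delta)
            R.output(out)
        return out

Downstream, BundleModelParams("base_params") packs the base weights into a single tuple parameter; the reorder_lora_params_after_base_model_params step above presumably keeps the LoRA tensors grouped after that tuple, so they can be swapped at runtime without touching the base weights.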