diff --git a/tensorrt_llm/builder.py b/tensorrt_llm/builder.py index 335f5c139..bfaff1b95 100644 --- a/tensorrt_llm/builder.py +++ b/tensorrt_llm/builder.py @@ -584,7 +584,7 @@ def from_dict(cls, config, plugin_config=None): config.get('auto_parallel_config', {})) max_encoder_input_len = config.pop('max_encoder_input_len', 1024) weight_streaming = config.pop('weight_streaming', False) - + use_fused_mlp = config.pop('use_fused_mlp', False) use_strip_plan = config.pop('use_strip_plan', False) if plugin_config is None: @@ -623,6 +623,7 @@ def from_dict(cls, config, plugin_config=None): max_encoder_input_len=max_encoder_input_len, weight_sparsity=weight_sparsity, weight_streaming=weight_streaming, + use_fused_mlp=use_fused_mlp, plugin_config=plugin_config, dry_run=dry_run, visualize_network=visualize_network)