diff --git a/mergekit/common.py b/mergekit/common.py index 86f5c296..908914d8 100644 --- a/mergekit/common.py +++ b/mergekit/common.py @@ -71,8 +71,10 @@ def merged( return ModelReference(path=out_path) - def config(self) -> PretrainedConfig: - return AutoConfig.from_pretrained(self.path) + def config(self, trust_remote_code: bool = False) -> PretrainedConfig: + return AutoConfig.from_pretrained( + self.path, trust_remote_code=trust_remote_code + ) def tensor_index(self, cache_dir: Optional[str] = None) -> ShardedTensorIndex: assert self.lora_path is None diff --git a/mergekit/merge.py b/mergekit/merge.py index 6415d275..3aaddcb6 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -60,7 +60,8 @@ def run_merge(merge_config: MergeConfiguration, out_path: str, options: MergeOpt method = merge_methods.get(merge_config.merge_method) model_arch_info = [ - get_architecture_info(m.config()) for m in merge_config.referenced_models() + get_architecture_info(m.config(trust_remote_code=options.trust_remote_code)) + for m in merge_config.referenced_models() ] if not options.allow_crimes: if not all(a == model_arch_info[0] for a in model_arch_info[1:]): @@ -102,7 +103,9 @@ def run_merge(merge_config: MergeConfiguration, out_path: str, options: MergeOpt clone_tensors=options.clone_tensors, ) - cfg_out = method.model_out_config(merge_config) + cfg_out = method.model_out_config( + merge_config, trust_remote_code=options.trust_remote_code + ) if tokenizer: try: cfg_out.vocab_size = len(tokenizer.get_vocab()) diff --git a/mergekit/merge_methods/base.py b/mergekit/merge_methods/base.py index aabbacae..549ca3d1 100644 --- a/mergekit/merge_methods/base.py +++ b/mergekit/merge_methods/base.py @@ -45,12 +45,18 @@ def input_layer_dependencies( """List any tensors necessary when input includes a specific layer""" return [] - def model_out_config(self, config: MergeConfiguration) -> PretrainedConfig: + def model_out_config( + self, config: MergeConfiguration, trust_remote_code: bool = False + ) -> PretrainedConfig: """Return a configuration for the resulting model.""" if config.base_model: - res = ModelReference.parse(config.base_model).config() + res = ModelReference.parse(config.base_model).config( + trust_remote_code=trust_remote_code + ) else: - res = config.referenced_models()[0].config() + res = config.referenced_models()[0].config( + trust_remote_code=trust_remote_code + ) if config.dtype: res.torch_dtype = config.dtype diff --git a/mergekit/plan.py b/mergekit/plan.py index a05a0a55..c337b58a 100644 --- a/mergekit/plan.py +++ b/mergekit/plan.py @@ -35,6 +35,7 @@ def plan( merge_config: MergeConfiguration, arch_info: ArchitectureInfo, embed_permutations: Optional[Dict[ModelReference, torch.Tensor]] = None, + trust_remote_code: bool = False, ) -> Tuple[List[TensorReference], Dict[TensorReference, Operation]]: layer_idx = 0 @@ -62,7 +63,7 @@ def plan( if base_model and mref == base_model: base_included = True - model_cfg = mref.config() + model_cfg = mref.config(trust_remote_code=trust_remote_code) num_layers = arch_info.num_layers(model_cfg) slices_in.append( InputSliceDefinition( @@ -74,7 +75,7 @@ def plan( if base_model and not base_included: logging.info("Base model specified but not in input models - adding") - base_cfg = base_model.config() + base_cfg = base_model.config(trust_remote_code=trust_remote_code) num_layers = arch_info.num_layers(base_cfg) slices_in.append( InputSliceDefinition(