diff --git a/mergekit/io/lazy_tensor_loader.py b/mergekit/io/lazy_tensor_loader.py
index 94931a08..d450707f 100644
--- a/mergekit/io/lazy_tensor_loader.py
+++ b/mergekit/io/lazy_tensor_loader.py
@@ -98,7 +98,7 @@ def from_disk(cls, base_path: str) -> "ShardedTensorIndex":
                     tensor_paths = {key: shard_name for key in st.keys()}
             else:
                 # this is ugly but not much else can be done
-                shard = torch.load(model_path)
+                shard = torch.load(model_path, map_location="meta")
                 if "state_dict" in shard:
                     shard = shard["state_dict"]
diff --git a/mergekit/scripts/mixtral_moe.py b/mergekit/scripts/mixtral_moe.py
index 08559c5c..0e2355ca 100644
--- a/mergekit/scripts/mixtral_moe.py
+++ b/mergekit/scripts/mixtral_moe.py
@@ -100,6 +100,7 @@ def get_gate_params(
     load_in_4bit: bool = False,
     load_in_8bit: bool = False,
     lazy_unpickle: bool = False,
+    device: str = "auto",
 ):
     gate_vecs = []
     _do_it = None
@@ -121,7 +122,7 @@ def get_gate_params(
         model = AutoModelForCausalLM.from_pretrained(
             model_ref.path,
             torch_dtype=torch.bfloat16,
-            device_map="auto",
+            device_map=device,
             low_cpu_mem_usage=True,
             load_in_4bit=load_in_4bit,
             load_in_8bit=load_in_8bit,
@@ -150,6 +151,7 @@ def build(
     load_in_4bit: bool = False,
     load_in_8bit: bool = False,
     lazy_unpickle: bool = False,
+    device: str = "auto",
 ):
     base_model = ModelReference.parse(config.base_model)
     base_cfg = base_model.config()
@@ -223,6 +225,7 @@ def build(
         load_in_4bit=load_in_4bit,
         load_in_8bit=load_in_8bit,
         lazy_unpickle=lazy_unpickle,
+        device=device,
     )

     # gate_vecs: (num_layers, num_experts, hidden_size)
@@ -251,6 +254,9 @@ def main(
     lazy_unpickle: Annotated[
         bool, typer.Option(help="Use experimental lazy unpickler")
     ] = False,
+    device: Annotated[
+        str, typer.Option(help="Device to use to compute embeddings")
+    ] = "auto",
 ):
     with open(config_path, "r", encoding="utf-8") as file:
         data = yaml.load(file, yaml.SafeLoader)
@@ -262,6 +268,7 @@ def main(
         load_in_4bit=load_in_4bit,
         load_in_8bit=load_in_8bit,
         lazy_unpickle=lazy_unpickle,
+        device=device,
     )