Add device option to mergekit-moe for use on cpu & Apple silicon (thanks @ddh0)
cg123 committed Dec 22, 2023
1 parent 66523c3 commit 503e740
Showing 2 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion mergekit/io/lazy_tensor_loader.py
@@ -98,7 +98,7 @@ def from_disk(cls, base_path: str) -> "ShardedTensorIndex":
             tensor_paths = {key: shard_name for key in st.keys()}
         else:
             # this is ugly but not much else can be done
-            shard = torch.load(model_path)
+            shard = torch.load(model_path, map_location="meta")
             if "state_dict" in shard:
                 shard = shard["state_dict"]

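For context, a minimal sketch (not part of the commit) of what the map_location change buys: deserializing onto PyTorch's meta device yields tensor names, shapes, and dtypes without allocating or device-mapping any weight storage, which is all the shard index needs and works on machines without a GPU. The checkpoint path is illustrative.

    import torch

    # Load the checkpoint onto the meta device: metadata only, no weight memory.
    shard = torch.load("pytorch_model.bin", map_location="meta")
    if "state_dict" in shard:
        shard = shard["state_dict"]

    # Every tensor is present by name and shape, but lives on device "meta".
    for name, tensor in shard.items():
        print(name, tuple(tensor.shape), tensor.dtype, tensor.device)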
9 changes: 8 additions & 1 deletion mergekit/scripts/mixtral_moe.py
@@ -100,6 +100,7 @@ def get_gate_params(
     load_in_4bit: bool = False,
     load_in_8bit: bool = False,
     lazy_unpickle: bool = False,
+    device: str = "auto",
 ):
     gate_vecs = []
     _do_it = None
@@ -121,7 +122,7 @@ def get_gate_params(
     model = AutoModelForCausalLM.from_pretrained(
         model_ref.path,
         torch_dtype=torch.bfloat16,
-        device_map="auto",
+        device_map=device,
         low_cpu_mem_usage=True,
         load_in_4bit=load_in_4bit,
         load_in_8bit=load_in_8bit,
@@ -150,6 +151,7 @@ def build(
     load_in_4bit: bool = False,
     load_in_8bit: bool = False,
     lazy_unpickle: bool = False,
+    device: str = "auto",
 ):
     base_model = ModelReference.parse(config.base_model)
     base_cfg = base_model.config()
@@ -223,6 +225,7 @@ def build(
         load_in_4bit=load_in_4bit,
         load_in_8bit=load_in_8bit,
         lazy_unpickle=lazy_unpickle,
+        device=device,
     )
     # gate_vecs: (num_layers, num_experts, hidden_size)

@@ -251,6 +254,9 @@ def main(
     lazy_unpickle: Annotated[
         bool, typer.Option(help="Use experimental lazy unpickler")
     ] = False,
+    device: Annotated[
+        str, typer.Option(help="Device to use to compute embeddings")
+    ] = "auto",
 ):
     with open(config_path, "r", encoding="utf-8") as file:
         data = yaml.load(file, yaml.SafeLoader)
@@ -262,6 +268,7 @@ def main(
         load_in_4bit=load_in_4bit,
         load_in_8bit=load_in_8bit,
         lazy_unpickle=lazy_unpickle,
+        device=device,
     )
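The new option flows straight through to transformers as device_map. A hedged sketch of the difference it makes (the model id is illustrative, not from this commit): "auto" lets accelerate spread the weights over whatever GPUs exist, while an explicit "cpu" or "mps" pins them to a single device, which is what makes the script usable on CPU-only machines and Apple silicon.

    import torch
    from transformers import AutoModelForCausalLM

    # Pick a sensible single device when no CUDA GPU is present.
    device = "mps" if torch.backends.mps.is_available() else "cpu"

    # Same call shape as get_gate_params above, with the device passed through.
    model = AutoModelForCausalLM.from_pretrained(
        "mistralai/Mistral-7B-v0.1",  # illustrative model id
        torch_dtype=torch.bfloat16,
        device_map=device,
        low_cpu_mem_usage=True,
    )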


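Typer derives the command-line flag from the parameter name, so the added option surfaces as --device. A self-contained sketch of the same Annotated pattern (a demo script, not mergekit code):

    import typer
    from typing import Annotated

    app = typer.Typer()

    @app.command()
    def main(
        device: Annotated[
            str, typer.Option(help="Device to use to compute embeddings")
        ] = "auto",
    ):
        # e.g. `python demo.py --device cpu`, or `--device mps` on Apple silicon
        typer.echo(f"computing hidden states on {device!r}")

    if __name__ == "__main__":
        app()

Against mergekit itself, the equivalent invocation would be along the lines of `mergekit-moe config.yml ./output --device mps` (assuming the usual config-path and output-path positional arguments).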
