From 7689c92a80719a65baea0310c22ba197742b0689 Mon Sep 17 00:00:00 2001
From: Viacheslav Kovalevskyi
Date: Sat, 22 Jul 2023 13:54:03 -0700
Subject: [PATCH 1/2] Partial support of Apple M1/M2 (via CPU mode)

---
 llama/generation.py | 38 +++++++++++++++++++++++---------------
 llama/model.py      |  6 ++++--
 llama/utils.py      | 37 +++++++++++++++++++++++++++++++++++++
 3 files changed, 64 insertions(+), 17 deletions(-)
 create mode 100644 llama/utils.py

diff --git a/llama/generation.py b/llama/generation.py
index 1f37856ef..f0a01f78c 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -7,6 +7,7 @@
 import time
 from pathlib import Path
 from typing import List, Literal, Optional, Tuple, TypedDict
+from .utils import default_device, model_device, distrubuted_device
 
 import torch
 import torch.nn.functional as F
@@ -58,21 +59,26 @@ def build(
         max_batch_size: int,
         model_parallel_size: Optional[int] = None,
     ) -> "Llama":
+
         if not torch.distributed.is_initialized():
-            torch.distributed.init_process_group("nccl")
+            torch.distributed.init_process_group("gloo")
         if not model_parallel_is_initialized():
             if model_parallel_size is None:
                 model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
-            initialize_model_parallel(model_parallel_size)
-
-        local_rank = int(os.environ.get("LOCAL_RANK", 0))
-        torch.cuda.set_device(local_rank)
+            initialize_model_parallel(
+                model_parallel_size,
+                model_parallel_backend=distrubuted_device()
+            )
+        if torch.cuda.is_available():
+            local_rank = int(os.environ.get("LOCAL_RANK", 0))
+            torch.cuda.set_device(local_rank)
+            if local_rank > 0:
+                sys.stdout = open(os.devnull, "w")
+        device = default_device()
 
         # seed must be the same in all processes
         torch.manual_seed(1)
 
-        if local_rank > 0:
-            sys.stdout = open(os.devnull, "w")
 
         start_time = time.time()
         checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
@@ -81,7 +87,7 @@ def build(
             checkpoints
         ), f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
         ckpt_path = checkpoints[get_model_parallel_rank()]
-        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        checkpoint = torch.load(ckpt_path, map_location=model_device())
         with open(Path(ckpt_dir) / "params.json", "r") as f:
             params = json.loads(f.read())
 
@@ -92,7 +98,9 @@ def build(
         )
         tokenizer = Tokenizer(model_path=tokenizer_path)
         model_args.vocab_size = tokenizer.n_words
-        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        model_args.device = device
+        if torch.cuda.is_available():
+            torch.set_default_tensor_type(torch.cuda.HalfTensor)
         model = Transformer(model_args)
         model.load_state_dict(checkpoint, strict=False)
         print(f"Loaded in {time.time() - start_time:.2f} seconds")
@@ -123,14 +131,14 @@ def generate(
         total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
 
         pad_id = self.tokenizer.pad_id
-        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=default_device())
         for k, t in enumerate(prompt_tokens):
-            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=default_device())
         if logprobs:
-            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+            token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=default_device())
 
         prev_pos = 0
-        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        eos_reached = torch.tensor([False] * bsz, device=default_device())
         input_text_mask = tokens != pad_id
         for cur_pos in range(min_prompt_len, total_len):
             logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
@@ -250,7 +258,7 @@ def chat_completion(
                         f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                         bos=True,
                         eos=True,
-                    )
+                    ).to(default_device())
                     for prompt, answer in zip(
                         dialog[::2],
                         dialog[1::2],
@@ -265,7 +273,7 @@ def chat_completion(
                 f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
                 bos=True,
                 eos=False,
-            )
+            ).to(default_device())
             prompt_tokens.append(dialog_tokens)
 
         generation_tokens, generation_logprobs = self.generate(
diff --git a/llama/model.py b/llama/model.py
index 258a7dc19..61887dfc6 100755
--- a/llama/model.py
+++ b/llama/model.py
@@ -4,6 +4,7 @@
 import math
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
+from .utils import default_device
 
 import fairscale.nn.model_parallel.initialize as fs_init
 import torch
@@ -95,6 +96,7 @@ def __init__(self, args: ModelArgs):
         self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
         self.n_rep = self.n_local_heads // self.n_local_kv_heads
         self.head_dim = args.dim // args.n_heads
+        self.device = args.device if args.device is not None else default_device()
 
         self.wq = ColumnParallelLinear(
             args.dim,
@@ -132,7 +134,7 @@ def __init__(self, args: ModelArgs):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(self.device)
         self.cache_v = torch.zeros(
             (
                 args.max_batch_size,
@@ -140,7 +142,7 @@ def __init__(self, args: ModelArgs):
                 self.n_local_kv_heads,
                 self.head_dim,
             )
-        ).cuda()
+        ).to(self.device)
 
     def forward(
         self,
diff --git a/llama/utils.py b/llama/utils.py
new file mode 100644
index 000000000..b621950e5
--- /dev/null
+++ b/llama/utils.py
@@ -0,0 +1,37 @@
+import platform
+import torch
+
+# set to False since MPS does not yet support BFloat16, which is required for Llama 2
+enable_mps = False
+
+
+def is_it_apple_arm():
+    if platform.system() != 'Darwin':
+        return False
+    if platform.machine() != 'arm64':
+        return False
+    return True
+
+
+def distrubuted_device():
+    if torch.cuda.is_available():
+        return "nccl"
+    else:
+        return "gloo"
+
+
+def default_device():
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    elif is_it_apple_arm() and enable_mps:
+        return torch.device("mps")
+    else:
+        return torch.device("cpu")
+
+
+def model_device():
+    if is_it_apple_arm() and enable_mps:
+        return torch.device("mps")
+    else:
+        # for CUDA we also want to use CPU for the model
+        return torch.device("cpu")
\ No newline at end of file

From 7f62753fefac302f91e20ab8e04d7a0c32af9699 Mon Sep 17 00:00:00 2001
From: Viacheslav Kovalevskyi
Date: Mon, 24 Jul 2023 19:59:01 -0700
Subject: [PATCH 2/2] Fix for the chat prediction

---
 llama/generation.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llama/generation.py b/llama/generation.py
index f0a01f78c..6c20e0c04 100755
--- a/llama/generation.py
+++ b/llama/generation.py
@@ -258,7 +258,7 @@ def chat_completion(
                         f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
                         bos=True,
                         eos=True,
-                    ).to(default_device())
+                    )
                     for prompt, answer in zip(
                         dialog[::2],
                         dialog[1::2],
@@ -273,7 +273,7 @@ def chat_completion(
                 f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
                 bos=True,
                 eos=False,
-            ).to(default_device())
+            )
             prompt_tokens.append(dialog_tokens)
 
         generation_tokens, generation_logprobs = self.generate(
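
Illustrative note (not part of the patch series above): a minimal sketch of how the device-selection helpers added in llama/utils.py resolve once the patches are applied, assuming the patched package is importable as `llama`. The torchrun command in the trailing comment is an assumption based on the repository's documented example scripts, not something introduced by these patches.

    from llama.utils import default_device, distrubuted_device, model_device

    # On an Apple M1/M2 host enable_mps is False, so inference falls back to CPU.
    print(default_device())      # torch.device("cuda") on a CUDA host, otherwise torch.device("cpu")
    print(model_device())        # torch.device("cpu"): used as map_location when loading checkpoints
    print(distrubuted_device())  # "nccl" when CUDA is available, otherwise "gloo"

    # Assumed CPU-mode launch, mirroring the repository's example scripts:
    #   torchrun --nproc_per_node 1 example_text_completion.py \
    #       --ckpt_dir llama-2-7b/ --tokenizer_path tokenizer.model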