diff --git a/llama/generation.py b/llama/generation.py index 5f8faf9f3..2de45476a 100755 --- a/llama/generation.py +++ b/llama/generation.py @@ -19,6 +19,13 @@ from llama.model import ModelArgs, Transformer from llama.tokenizer import Tokenizer +if torch.backends.mps.is_available(): + device = torch.device('mps') +elif torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') + Role = Literal["system", "user", "assistant"] @@ -82,14 +89,18 @@ def build( """ if not torch.distributed.is_initialized(): - torch.distributed.init_process_group("nccl") + if device == torch.device('cuda'): + torch.distributed.init_process_group("nccl") + else: + torch.distributed.init_process_group("gloo") if not model_parallel_is_initialized(): if model_parallel_size is None: model_parallel_size = int(os.environ.get("WORLD_SIZE", 1)) initialize_model_parallel(model_parallel_size) local_rank = int(os.environ.get("LOCAL_RANK", 0)) - torch.cuda.set_device(local_rank) + if device == torch.device('cuda'): + torch.cuda.set_device(local_rank) # seed must be the same in all processes torch.manual_seed(seed) @@ -115,9 +126,13 @@ def build( ) tokenizer = Tokenizer(model_path=tokenizer_path) model_args.vocab_size = tokenizer.n_words - torch.set_default_tensor_type(torch.cuda.HalfTensor) + if device == torch.device('cuda'): + torch.set_default_tensor_type(torch.cuda.HalfTensor) + else: + torch.set_default_tensor_type(torch.HalfTensor) model = Transformer(model_args) model.load_state_dict(checkpoint, strict=False) + model.to(device) print(f"Loaded in {time.time() - start_time:.2f} seconds") return Llama(model, tokenizer) @@ -165,14 +180,14 @@ def generate( total_len = min(params.max_seq_len, max_gen_len + max_prompt_len) pad_id = self.tokenizer.pad_id - tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda") + tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device) for k, t in enumerate(prompt_tokens): - tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device) if logprobs: - token_logprobs = torch.zeros_like(tokens, dtype=torch.float) + token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device) prev_pos = 0 - eos_reached = torch.tensor([False] * bsz, device="cuda") + eos_reached = torch.tensor([False] * bsz, device=device) input_text_mask = tokens != pad_id if min_prompt_len == total_len: logits = self.model.forward(tokens, prev_pos) diff --git a/llama/model.py b/llama/model.py index c78570f68..24e8c1e3e 100755 --- a/llama/model.py +++ b/llama/model.py @@ -15,6 +15,12 @@ ) from torch import nn +if torch.backends.mps.is_available(): + device = torch.device('mps') +elif torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') @dataclass class ModelArgs: @@ -153,12 +159,17 @@ def apply_rotary_emb( """ + if not torch.cuda.is_available(): + xq = xq.to('cpu') + xk = xk.to('cpu') xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + if not torch.cuda.is_available(): + freqs_cis = freqs_cis.to('cpu') xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device) def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -240,7 +251,7 @@ def __init__(self, args: ModelArgs): self.n_local_kv_heads, self.head_dim, ) - ).cuda() + ).to(device) self.cache_v = torch.zeros( ( args.max_batch_size, @@ -248,7 +259,7 @@ def __init__(self, args: ModelArgs): self.n_local_kv_heads, self.head_dim, ) - ).cuda() + ).to(device) def forward( self, @@ -474,7 +485,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int): mask = None if seqlen > 1: mask = torch.full( - (seqlen, seqlen), float("-inf"), device=tokens.device + (seqlen, seqlen), float("-inf"), device=torch.device('cpu') ) mask = torch.triu(mask, diagonal=1) @@ -484,12 +495,12 @@ def forward(self, tokens: torch.Tensor, start_pos: int): # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for # j > cache_len + i, since row i corresponds to token cache_len + i. mask = torch.hstack([ - torch.zeros((seqlen, start_pos), device=tokens.device), + torch.zeros((seqlen, start_pos), device=torch.device('cpu')), mask ]).type_as(h) for layer in self.layers: - h = layer(h, start_pos, freqs_cis, mask) + h = layer(h, start_pos, freqs_cis, (mask.to(device) if mask is not None else mask)) h = self.norm(h) output = self.output(h).float() return output diff --git a/llama/tokenizer.py b/llama/tokenizer.py index 3eda89a06..68adaad9f 100755 --- a/llama/tokenizer.py +++ b/llama/tokenizer.py @@ -65,4 +65,4 @@ def decode(self, t: List[int]) -> str: Returns: str: The decoded string. """ - return self.sp_model.decode(t) + return self.sp_model.decode(list(filter(lambda tk: tk != -1, t)))