Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for Mac M1/M2 #947

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 22 additions & 7 deletions llama/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
from llama.model import ModelArgs, Transformer
from llama.tokenizer import Tokenizer

if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')

Role = Literal["system", "user", "assistant"]


Expand Down Expand Up @@ -82,14 +89,18 @@ def build(

"""
if not torch.distributed.is_initialized():
torch.distributed.init_process_group("nccl")
if device == torch.device('cuda'):
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")
if not model_parallel_is_initialized():
if model_parallel_size is None:
model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
if device == torch.device('cuda'):
torch.cuda.set_device(local_rank)

# seed must be the same in all processes
torch.manual_seed(seed)
Expand All @@ -115,9 +126,13 @@ def build(
)
tokenizer = Tokenizer(model_path=tokenizer_path)
model_args.vocab_size = tokenizer.n_words
torch.set_default_tensor_type(torch.cuda.HalfTensor)
if device == torch.device('cuda'):
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
torch.set_default_tensor_type(torch.HalfTensor)
model = Transformer(model_args)
model.load_state_dict(checkpoint, strict=False)
model.to(device)
print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama(model, tokenizer)
Expand Down Expand Up @@ -165,14 +180,14 @@ def generate(
total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device=device)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
token_logprobs = torch.zeros_like(tokens, dtype=torch.float, device=device)

prev_pos = 0
eos_reached = torch.tensor([False] * bsz, device="cuda")
eos_reached = torch.tensor([False] * bsz, device=device)
input_text_mask = tokens != pad_id
if min_prompt_len == total_len:
logits = self.model.forward(tokens, prev_pos)
Expand Down
23 changes: 17 additions & 6 deletions llama/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
)
from torch import nn

if torch.backends.mps.is_available():
device = torch.device('mps')
elif torch.cuda.is_available():
device = torch.device('cuda')
else:
device = torch.device('cpu')

@dataclass
class ModelArgs:
Expand Down Expand Up @@ -153,12 +159,17 @@ def apply_rotary_emb(


"""
if not torch.cuda.is_available():
xq = xq.to('cpu')
xk = xk.to('cpu')
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
if not torch.cuda.is_available():
freqs_cis = freqs_cis.to('cpu')
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
return xq_out.type_as(xq).to(device), xk_out.type_as(xk).to(device)


def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
Expand Down Expand Up @@ -240,15 +251,15 @@ def __init__(self, args: ModelArgs):
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)
self.cache_v = torch.zeros(
(
args.max_batch_size,
args.max_seq_len,
self.n_local_kv_heads,
self.head_dim,
)
).cuda()
).to(device)

def forward(
self,
Expand Down Expand Up @@ -474,7 +485,7 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
mask = None
if seqlen > 1:
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device
(seqlen, seqlen), float("-inf"), device=torch.device('cpu')
)

mask = torch.triu(mask, diagonal=1)
Expand All @@ -484,12 +495,12 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
# (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
# j > cache_len + i, since row i corresponds to token cache_len + i.
mask = torch.hstack([
torch.zeros((seqlen, start_pos), device=tokens.device),
torch.zeros((seqlen, start_pos), device=torch.device('cpu')),
mask
]).type_as(h)

for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = layer(h, start_pos, freqs_cis, (mask.to(device) if mask is not None else mask))
h = self.norm(h)
output = self.output(h).float()
return output
2 changes: 1 addition & 1 deletion llama/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,4 @@ def decode(self, t: List[int]) -> str:
Returns:
str: The decoded string.
"""
return self.sp_model.decode(t)
return self.sp_model.decode(list(filter(lambda tk: tk != -1, t)))