Merge branch 'main' into model_loading
tigranfah committed Sep 20, 2024
2 parents ef0af4b + 379be76 commit 6049de6
Showing 1 changed file with 9 additions and 11 deletions.
torchtitan/models/opt/model.py (20 changes: 9 additions & 11 deletions)
@@ -306,8 +306,8 @@ def __init__(self, model_args: ModelArgs):
         self.norm = build_norm(
             model_args.norm_type, dim=model_args.dim, eps=model_args.norm_eps
         )
+        self.output = lambda x: F.linear(x, self.tok_embeddings.weight)
 
-        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
         self.init_weights()
 
     def init_weights(self):
@@ -322,24 +322,22 @@ def init_weights(self):
         ``init_weights``. We only call it in the constructor of this
         ``Transformer`` root module to avoid reinitializing tensors.
         """
-        if self.tok_embeddings is not None:
-            nn.init.normal_(self.tok_embeddings.weight)
-        nn.init.normal_(self.pos_encoder.weight)
-        for layer in self.layers.values():
-            if layer is not None:
-                layer.init_weights()
-        if self.norm is not None:
-            self.norm.reset_parameters()
         final_out_std = self.model_args.dim**-0.5
         cutoff_factor = 3
-        if self.output is not None:
+        if self.tok_embeddings is not None:
             nn.init.trunc_normal_(
-                self.output.weight,
+                self.tok_embeddings.weight,
                 mean=0.0,
                 std=final_out_std,
                 a=-cutoff_factor * final_out_std,
                 b=cutoff_factor * final_out_std,
             )
+        nn.init.normal_(self.pos_encoder.weight)
+        for layer in self.layers.values():
+            if layer is not None:
+                layer.init_weights()
+        if self.norm is not None:
+            self.norm.reset_parameters()
 
     def forward(self, tokens: torch.Tensor):
         """
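The second hunk keeps the same truncated-normal scheme, std = dim ** -0.5 with the distribution cut off at 3 standard deviations on each side, but applies it to tok_embeddings.weight instead of output.weight, then runs the remaining per-module initializations afterwards. A standalone sketch of that initializer, again with assumed toy dimensions rather than the repo's module:

    import torch
    import torch.nn as nn

    dim, vocab_size = 512, 50272           # illustrative values
    weight = torch.empty(vocab_size, dim)

    final_out_std = dim ** -0.5            # 1/sqrt(dim) ~= 0.044 for dim=512
    cutoff_factor = 3                      # truncate at +/- 3 standard deviations
    nn.init.trunc_normal_(
        weight,
        mean=0.0,
        std=final_out_std,
        a=-cutoff_factor * final_out_std,
        b=cutoff_factor * final_out_std,
    )

    # No sample can fall outside the truncation window.
    assert weight.abs().max().item() <= cutoff_factor * final_out_std

Because the weight is shared, this one call effectively initializes both the embedding lookup and the output projection.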
