-
Notifications
You must be signed in to change notification settings - Fork 348
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support store_param_remainders
feature from Apex in TE Fused Adam
#1408
base: main
Are you sure you want to change the base?
Conversation
Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>
Signed-off-by: Selvaraj Anandaraj <[email protected]>
…ransformerEngine into param_remainder
for more information, see https://pre-commit.ci
@@ -243,13 +256,14 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale): | |||
unscaled_state.mul_(rscale) | |||
scaled_state.copy_(unscaled_state) | |||
|
|||
def get_unscaled_state(self, param, state_name): | |||
def get_unscaled_state(self, param, state_name, store_param_remainders=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The default value of store_param_remainders
is False
here, but it's True
by default in the constructor. I think it's misleading, why not just set it to True
here as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't want to store param remainders for state_name other than master_params, that's why it's defaulted to false.
I'm getting NaNs when using this feature. You can reproduce it by running Still, with all these changes, the tests fail at |
Description
When the master parameter is in FP32 and the model parameters are in BF16, we can store the trailing 16 remainder bits and reconstruct the master FP32 param from (BF16 model param + the remainder).
This helps us half the master parameter memory usage.