Support store_param_remainders feature from Apex in TE Fused Adam #1408

Open
wants to merge 6 commits into base: main
Conversation

sanandaraj5597 (Contributor)

Description

When the master parameter is in FP32 and the model parameters are in BF16, we can store only the trailing 16 remainder bits and reconstruct the FP32 master param from the BF16 model param plus the remainder.

This halves the master parameter memory usage.
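
A minimal PyTorch sketch of the bit-level idea (not the actual TE/Apex kernel; the helper names are illustrative, and it assumes the BF16 model param is the truncated upper 16 bits of the FP32 bit pattern, whereas the real optimizer may round):

    import torch

    def split_master_fp32(master_fp32: torch.Tensor):
        # Split an FP32 master param into (BF16 model param, int16 remainder).
        # Illustration only: BF16 = truncated upper 16 bits, remainder = lower 16 bits.
        bits = master_fp32.view(torch.int32)
        model_bf16 = (bits >> 16).to(torch.int16).view(torch.bfloat16)
        lo = bits & 0xFFFF                                   # values in [0, 65535]
        # Map [0, 65535] onto int16 two's complement so the bit pattern is preserved.
        remainder = torch.where(lo >= 32768, lo - 65536, lo).to(torch.int16)
        return model_bf16, remainder

    def reconstruct_master_fp32(model_bf16: torch.Tensor, remainder: torch.Tensor):
        # Rebuild the FP32 master param from the BF16 param plus the stored remainder.
        hi = model_bf16.view(torch.int16).to(torch.int32) << 16
        lo = remainder.to(torch.int32) & 0xFFFF              # undo sign extension
        return (hi | lo).view(torch.float32)

    # Round trip is bit-exact: only the 2-byte int16 remainder is stored alongside
    # the existing BF16 model param, instead of a separate 4-byte FP32 master copy.
    master = torch.randn(1024, dtype=torch.float32)
    bf16, rem = split_master_fp32(master)
    assert torch.equal(reconstruct_master_fp32(bf16, rem), master)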

@@ -243,13 +256,14 @@ def _apply_scale(self, state_name, unscaled_state, scaled_state, scale):
         unscaled_state.mul_(rscale)
         scaled_state.copy_(unscaled_state)

-    def get_unscaled_state(self, param, state_name):
+    def get_unscaled_state(self, param, state_name, store_param_remainders=False):

The default value of store_param_remainders is False here, but it's True by default in the constructor. I think that's misleading; why not just set it to True here as well?

sanandaraj5597 (Contributor, Author)

I don't want to store param remainders for any state_name other than the master params; that's why it's defaulted to False.
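
For illustration, a hypothetical caller sketch of that convention, where only the master-weight state opts into the remainder path; the state-name strings and the optimizer variable are assumptions, not taken from the PR:

    # Hypothetical usage: only the master-weight state opts into remainder storage;
    # other optimizer states (exp_avg, exp_avg_sq, ...) keep the default False.
    master = optimizer.get_unscaled_state(param, "master_param", store_param_remainders=True)
    exp_avg = optimizer.get_unscaled_state(param, "exp_avg")  # plain unscaled state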

@MaciejBalaNV

I'm getting NaNs when using this feature. You can reproduce it by running the test_fused_optimizer tests after setting store_param_remainders=True in the _initialize_state method (otherwise it fails earlier) and commenting out the torch.testing.assert_close(ref_params, master_params) check (that one is expected to fail, since we now keep master_params as int16).

Still, with all these changes, the tests fail at torch.testing.assert_close(ref_params, model_params_to_fp32, rtol=1e-2, atol=1e-2, equal_nan=True) with an error message saying the weights are NaN.
