diff --git a/megatron/model/norms.py b/megatron/model/norms.py index ba175d3eb..777b0e684 100644 --- a/megatron/model/norms.py +++ b/megatron/model/norms.py @@ -33,6 +33,12 @@ def get_norm(neox_args): norm = MixedFusedLayerNorm else: norm = LayerNorm + elif neox_args.norm == "non_parametric_layernorm": + eps = neox_args.layernorm_epsilon + if neox_args.layernorm_fusion: + raise ValueError(f"neox_args.layernorm_fusion not supported for non_parametric_layernorm") + else: + norm = LayerNorm(elementwise_affine=False, bias=False) elif neox_args.norm == "scalenorm": eps = neox_args.scalenorm_epsilon norm = ScaleNorm diff --git a/megatron/neox_arguments/neox_args.py b/megatron/neox_arguments/neox_args.py index 9c8d3635f..82d400032 100644 --- a/megatron/neox_arguments/neox_args.py +++ b/megatron/neox_arguments/neox_args.py @@ -168,10 +168,10 @@ class NeoXArgsModel(NeoXArgsTemplate): """ norm: Literal[ - "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm" + "layernorm", "rmsnorm", "non_parametric_layernorm", "scalenorm", "te_rmsnorm", "te_layernorm" ] = "layernorm" """ - Normalization layer to use. Choose from "layernorm", "rmsnorm", "scalenorm", "te_rmsnorm", "te_layernorm". + Normalization layer to use. Choose from "layernorm", "rmsnorm", "non_parametric_layernorm", "scalenorm", "te_rmsnorm", "te_layernorm". """ layernorm_fusion: bool = False