diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 4f170e3f84..a45d9e3a12 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -598,7 +598,7 @@ def _fp8_gemm(): tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype, torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), + te.module.base.get_workspace().repeat(3), bias=None, use_bias=False, gelu=False, @@ -639,7 +639,7 @@ def _fp8_gemm2(gemm1_out): tex.FP8FwdTensors.GEMM2_INPUT, fp8_dtype, torch.uint8 if opts.fp8_output else torch.bfloat16, - te.module.base.get_workspace(), + te.module.base.get_workspace().repeat(3), bias=None, use_bias=False, gelu=False, @@ -662,7 +662,7 @@ def _gemm(): kernel_t, gemm_inp, torch.bfloat16, - te.module.base.get_workspace(), + te.module.base.get_workspace().repeat(3), bias=None, use_bias=False, gelu=False,