Support lazy tensor allocation (#193)
Summary:

Support lazy tensor allocation

The current algorithm for allocating tensors in et_replay is to find the tensors that cannot be generated when all ops are replayed, pre-allocate them before replay starts, and keep them across iterations.
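
As a rough illustration of that pre-allocation step (a minimal sketch, not the actual et_replay code; nodes, output_tensor_ids, input_tensor_ids, and tensor_shapes are assumed inputs):

import torch

def pre_allocate_inputs(nodes, tensor_shapes, device="cuda"):
    # Tensors produced by some replayed op can be regenerated during replay;
    # everything else must be materialized up front and reused every iteration.
    generated = {t for node in nodes for t in node.output_tensor_ids}
    needed = {t for node in nodes for t in node.input_tensor_ids}
    return {
        t: torch.empty(tensor_shapes[t], device=device)
        for t in needed - generated
    }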

However, this algorithm leads to OOM when replaying the Llama4 70B model. This PR introduces TensorAllcationMode:

class TensorAllcationMode(Enum):
    """
    Enum to represent the tensor allocation mode
    """

    # Allocate input tensors that can not be generated when replaying the trace
    # at the beginning and reuse them for all iterations.
    PRE_ALLOCATE = 1

    # Allocate tensors on the fly and free them after they are out of scope
    LAZY_ALLOCATE = 2

In LAZY_ALLOCATE mode, tensors are kept in tensor_storage_map and tensor_registry, while replay_tensor_id_to_last_node_id_map and tensor_storage_id_to_last_node_id_map track the last node id that accesses each tensor and each tensor storage. Once replay passes that last node, the corresponding tensor or tensor storage is deleted.
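
A minimal sketch of that last-use bookkeeping, assuming both maps are precomputed by scanning the trace; the function name and signature below are illustrative, not the et_replay API:

def free_dead_tensors(node_id, tensor_registry, tensor_storage_map,
                      replay_tensor_id_to_last_node_id_map,
                      tensor_storage_id_to_last_node_id_map):
    # A tensor is dead once replay has passed the last node that accesses it.
    for t_id, last_node_id in replay_tensor_id_to_last_node_id_map.items():
        if node_id >= last_node_id:
            tensor_registry.pop(t_id, None)
    # The same idea applies at the storage level, so a shared storage stays
    # alive until every tensor viewing it is no longer needed.
    for s_id, last_node_id in tensor_storage_id_to_last_node_id_map.items():
        if node_id >= last_node_id:
            tensor_storage_map.pop(s_id, None)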

This diff also introduces another option, --device-memory-threshold. With LAZY_ALLOCATE, this option frees all tensors once the ratio of allocated device memory to total device memory exceeds device-memory-threshold. This keeps replay running, at the cost of extra freeing and re-allocation of memory. Llama4 7B does not need this option when the ET is captured with unique storage ids (https://www.internalfb.com/diff/D66849516).
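
A hedged sketch of how such a threshold check can be implemented with standard torch.cuda queries; the helper name and the exact freeing policy are assumptions, not the code in this diff:

import torch

def maybe_free_all(tensor_registry, tensor_storage_map, device_memory_threshold):
    allocated = torch.cuda.memory_allocated()
    total = torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory
    if allocated / total > device_memory_threshold:
        # Drop every lazily allocated tensor; they are re-created on demand,
        # trading alloc/free overhead for staying under the memory budget.
        tensor_registry.clear()
        tensor_storage_map.clear()
        torch.cuda.empty_cache()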

This fixes the OOM issue in Llama4 70B.

Reviewed By: sanrise

Differential Revision: D66487952
shengfukevin authored and facebook-github-bot committed Jan 2, 2025
1 parent 827ac1f commit f9c45d6
Showing 3 changed files with 152 additions and 40 deletions.
15 changes: 9 additions & 6 deletions et_replay/et_replay_utils.py
@@ -16,31 +16,33 @@
TORCH_DTYPES_RNG = {
"bool": (torch.bool, torch.ones),
"int8": (torch.int8, torch.ones),
"half": (torch.half, torch.ones),
"half": (torch.half, torch.randn),
"int": (torch.int, torch.ones),
"long": (torch.int64, torch.ones),
"long int": (torch.int64, torch.ones),
"float": (torch.float32, torch.randn),
"double": (torch.float64, torch.randn),
"signed char": (torch.int8, torch.ones),
"unsigned char": (torch.uint8, torch.ones),
"c10::Half": (torch.half, torch.ones),
"c10::BFloat16": (torch.bfloat16, torch.ones),
"c10::Half": (torch.half, torch.randn),
"c10::BFloat16": (torch.bfloat16, torch.randn),
"c10::complex<float>": (torch.complex32, torch.randn),
}

TORCH_DTYPES_RNG_str = {
"bool": ("torch.bool", "torch.ones"),
"int8": ("torch.int8", "torch.ones"),
"half": ("torch.half", "torch.ones"),
"half": ("torch.half", "torch.randn"),
"int": ("torch.int", "torch.ones"),
"long": ("torch.int64", "torch.ones"),
"long int": ("torch.int64", "torch.ones"),
"float": ("torch.float32", "torch.randn"),
"double": ("torch.float64", "torch.randn"),
"signed char": ("torch.int8", "torch.ones"),
"unsigned char": ("torch.uint8", "torch.ones"),
"c10::Half": ("torch.half", "torch.ones"),
"c10::BFloat16": ("torch.bfloat16", "torch.ones"),
"c10::Half": ("torch.half", "torch.randn"),
"c10::BFloat16": ("torch.bfloat16", "torch.randn"),
"c10::complex<float>": ("torch.complex32", "torch.randn"),
}

TORCH_DTYPES_BYTES = {
@@ -56,6 +58,7 @@
"unsigned char": 1,
"c10::Half": 2,
"c10::BFloat16": 2,
"c10::complex<float>": 8,
}


2 changes: 1 addition & 1 deletion et_replay/tools/comm_replay.py
@@ -1153,7 +1153,7 @@ def replaySingle(

if self.backendFuncs.get_global_rank() == 0:
logger.info(
f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} in block [{curBlockStack}]... {global_latency:.2f} us"
f"{logLable}[{cnt+1} / {self.max_msg_cnt}] Replayed {recordName} with id={curComm.id} in block [{curBlockStack}]... {global_latency:.2f} us"
)

def benchTime(self, commsParams: commsParamsHolderBase) -> None: