feat(pipeline): add non-p2p-comm support #385

Open · wants to merge 3 commits into base: develop
1 change: 1 addition & 0 deletions configs/_base_/models/internlm2_1B.py
@@ -63,6 +63,7 @@
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler,
defaults to False.
3. batch_p2p_comm: bool, enable/disable batched p2p communication (batch_isend_irecv), defaults to False.
weight parallel (dict):
1. size: int, the size of weight parallel.
2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False.
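Note: the new `batch_p2p_comm` flag is read from the pipeline section of the parallel config. A minimal sketch of how it would be set, assuming the usual dict-based config layout (the `size` and `interleaved_overlap` values below are illustrative placeholders, not taken from this PR):

```python
# Illustrative config excerpt; only batch_p2p_comm is introduced by this PR.
parallel = dict(
    pipeline=dict(
        size=4,                    # number of pipeline stages (placeholder)
        interleaved_overlap=True,  # overlap comm for the interleaved scheduler (placeholder)
        batch_p2p_comm=False,      # False = issue isend/irecv individually (the new default)
    ),
)
```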
104 changes: 74 additions & 30 deletions internlm/core/scheduler/comm/p2p.py
@@ -44,6 +44,15 @@ def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) ->
return tensor_chunk_shape, chunk_tensor


def _p2p_func(_comm_op, _obj, _comm_rank):
if getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
op_or_handle = dist.P2POp(_comm_op, _obj, _comm_rank)
else:
op_or_handle = _comm_op(_obj, _comm_rank)

return op_or_handle
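Note: `_p2p_func` returns two different kinds of objects depending on the flag. With `batch_p2p_comm=True` it builds a `dist.P2POp` descriptor that is later launched by `dist.batch_isend_irecv`; otherwise it calls `dist.isend`/`dist.irecv` immediately and returns the resulting work handle. A rough, self-contained sketch of the two completion paths (placeholder buffers and ranks, assuming an initialized process group; this is not code from the PR):

```python
import torch
import torch.distributed as dist

def complete_p2p(send_buf: torch.Tensor, recv_buf: torch.Tensor,
                 next_rank: int, prev_rank: int, batch_p2p_comm: bool) -> None:
    if batch_p2p_comm:
        # Batched path: collect P2POp descriptors, launch them together, then wait.
        ops = [
            dist.P2POp(dist.isend, send_buf, next_rank),
            dist.P2POp(dist.irecv, recv_buf, prev_rank),
        ]
        for req in dist.batch_isend_irecv(ops):
            req.wait()
    else:
        # Non-batched path: isend/irecv start right away and return work handles.
        handles = [
            dist.isend(send_buf, next_rank),
            dist.irecv(recv_buf, prev_rank),
        ]
        for handle in handles:
            handle.wait()
```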


def create_recv_buffer_with_shapes(recv_shapes, dtype, scatter_gather_tensors):
if isinstance(recv_shapes, torch.Size):
recv_chunk_shape, recv_split = _get_tensor_shape(recv_shapes, scatter_gather_tensors)
@@ -78,12 +87,10 @@ def process_object_to_send(object_send, scatter_gather_tensors):

def filling_ops_queue(obj, comm_op, comm_rank, ops_queue):
if isinstance(obj, torch.Tensor):
op_to_add = dist.P2POp(comm_op, obj, comm_rank)
ops_queue.append(op_to_add)
ops_queue.append(_p2p_func(comm_op, obj, comm_rank))
else:
for tensor_to_comm in obj:
op_to_add = dist.P2POp(comm_op, tensor_to_comm, comm_rank)
ops_queue.append(op_to_add)
ops_queue.append(_p2p_func(comm_op, tensor_to_comm, comm_rank))


def _communicate(
@@ -156,23 +163,42 @@ def _communicate(
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)

ops = []
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0:
if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
else:
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)

if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
if getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
else:
for req in ops:
req.wait()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
@@ -265,29 +291,47 @@ def _communicate_async(
object_send_next = process_object_to_send(object_send_next, scatter_gather_tensors)

ops = []
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)
if gpc.get_local_rank(ParallelMode.PIPELINE) % 2 == 0:
if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)
if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if len(ops) > 0:
if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)
else:
if tensor_recv_prev is not None:
filling_ops_queue(tensor_recv_prev, dist.irecv, prev_rank, ops)

if object_send_next is not None:
filling_ops_queue(object_send_next, dist.isend, next_rank, ops)

if tensor_recv_next is not None:
filling_ops_queue(tensor_recv_next, dist.irecv, next_rank, ops)

if object_send_prev is not None:
filling_ops_queue(object_send_prev, dist.isend, prev_rank, ops)

if len(ops) > 0 and getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
reqs = dist.batch_isend_irecv(ops)

# return and do other things
yield

if len(ops) > 0:
for req in reqs: # pylint: disable=E0601
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
if getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True:
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
internlm_accelerator.synchronize()
else:
for req in ops:
req.wait()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
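Note on the reordering in `_communicate`/`_communicate_async`: with `batch_isend_irecv` the backend coordinates the whole group of p2p ops, so queueing order is not critical. With individual `isend`/`irecv` calls, neighbouring pipeline stages need to post their operations in complementary order, which is presumably why the new code branches on `gpc.get_local_rank(ParallelMode.PIPELINE) % 2`: even stages post sends before the matching receives and odd stages do the reverse, so every send is paired with a receive on the neighbour. In the batched path the code still waits on the returned requests and then calls `internlm_accelerator.synchronize()`; in the non-batched path it simply waits on each handle. A minimal sketch of the even/odd pairing idea, written outside this codebase with placeholder buffers and neighbour ranks:

```python
import torch
import torch.distributed as dist

def exchange_with_neighbours(pp_rank: int, send_buf: torch.Tensor, recv_buf: torch.Tensor,
                             next_rank: int, prev_rank: int) -> None:
    # Assumes an initialized process group; tensors and ranks are placeholders.
    handles = []
    if pp_rank % 2 == 0:
        # Even stages: send to the next stage first, then receive from the previous one.
        handles.append(dist.isend(send_buf, next_rank))
        handles.append(dist.irecv(recv_buf, prev_rank))
    else:
        # Odd stages: mirror the order so the neighbour's send meets a posted receive.
        handles.append(dist.irecv(recv_buf, prev_rank))
        handles.append(dist.isend(send_buf, next_rank))
    for handle in handles:
        handle.wait()
```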
114 changes: 70 additions & 44 deletions internlm/core/scheduler/pipeline_scheduler_1f1b.py
@@ -197,7 +197,7 @@ def _call_engine(engine, data): # pylint: disable=W0237

def load_batch(self, engine, data_iter):
# Pipeline schedule just puts data in memory,
batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)
batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=True)

# Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed,
# because internlm's current train dataset is packed, even using dummy data.
@@ -309,17 +309,18 @@ def _forward_step(
accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced

moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff

# the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
# so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())
else:
moe_loss = None

return output_obj, moe_loss

@@ -413,7 +414,11 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
accum_moe_loss = torch.zeros(1, device=get_current_device())

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
accum_moe_loss = torch.zeros(1, device=get_current_device())
else:
accum_moe_loss = None

# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
@@ -456,8 +461,8 @@ def _forward_only_step(self, engine, return_loss=True, return_output_label=True)
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

if accum_loss is not None:
accum_loss += accum_moe_loss
if accum_loss is not None:
accum_loss += accum_moe_loss

return output, label, accum_loss, accum_moe_loss

@@ -514,7 +519,11 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True)
else None
)
accum_moe_loss = torch.zeros(1, device=get_current_device())

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
accum_moe_loss = torch.zeros(1, device=get_current_device())
else:
accum_moe_loss = None

# Used for tensor meta information communication
forward_recv_shapes = self.tensor_shape
@@ -660,8 +669,8 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
dist.all_reduce(accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))

if accum_loss is not None:
accum_loss += accum_moe_loss
if accum_loss is not None:
accum_loss += accum_moe_loss

return output, label, accum_loss, accum_moe_loss

@@ -776,6 +785,7 @@ def __init__(
self._output_obj_grads = [[] for _ in range(num_chunks)]
self._moe_losses = [[] for _ in range(num_chunks)]

self._preload_micro_data = [None for _ in range(self.num_microbatches)]
self._input_obj_shapes = [self.tensor_shape for _ in range(num_chunks)]
self._output_obj_shapes = [None for _ in range(num_chunks)]
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(num_chunks)]
@@ -799,26 +809,37 @@ def _clear_state(self) -> None:
self._output_obj_grads = [[] for _ in range(self._num_chunks)]
self._moe_losses = [[] for _ in range(self._num_chunks)]

self._preload_micro_data = [None for _ in range(self.num_microbatches)]
self._input_obj_shapes = [self.tensor_shape for _ in range(self._num_chunks)]
self._output_obj_shapes = [None for _ in range(self._num_chunks)]
self._send_tensor_shape_flags = [self.tensor_shape is None for _ in range(self._num_chunks)]

def load_batch(self, engine, data_iter):
super().load_batch(engine, data_iter)

for mbs in range(self.num_microbatches):
micro_batch_data, micro_batch_label = self._load_micro_batch(
data=self.batch_data,
label=self.batch_label,
offset=mbs * self.bsz_stride,
bsz_stride=self.bsz_stride,
)

if self.data_process_func:
micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label)

micro_batch_data["label"] = micro_batch_label
self._preload_micro_data[mbs] = micro_batch_data

# overwrite microbatch_offset, since model chunks load the same microbatch and should track the offset
self.microbatch_offset = [0 for _ in range(self._num_chunks)]

def load_micro_batch(self, model_chunk_id):
micro_batch_data, micro_batch_label = self._load_micro_batch(
data=self.batch_data,
label=self.batch_label,
offset=self.microbatch_offset[model_chunk_id],
bsz_stride=self.bsz_stride,
)
if self.data_process_func:
micro_batch_data, micro_batch_label = self.data_process_func(micro_batch_data, micro_batch_label)
micro_batch_data["label"] = micro_batch_label
self.microbatch_offset[model_chunk_id] += self.bsz_stride
offset = self.microbatch_offset[model_chunk_id]
assert self._preload_micro_data[offset] is not None, "preload micro batch data is None"

micro_batch_data = self._preload_micro_data[offset]
self.microbatch_offset[model_chunk_id] += 1

result = move_to_device(micro_batch_data)
return result
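Note: the interleaved scheduler now slices and preprocesses every microbatch once per global batch in `load_batch`, caching the results in `self._preload_micro_data`, and `load_micro_batch` only indexes into that cache, so both model chunks reuse the same preprocessed microbatch instead of re-slicing it per chunk. A stripped-down toy sketch of the preload-then-index pattern (simplified names, not the PR's exact code):

```python
class PreloadingScheduler:
    """Toy illustration of the preload-then-index pattern."""

    def __init__(self, num_microbatches: int, num_chunks: int):
        self.num_microbatches = num_microbatches
        self._preload = [None] * num_microbatches
        self._offset = [0] * num_chunks

    def load_batch(self, batch_data: list) -> None:
        # Slice and preprocess every microbatch exactly once per global batch.
        for i in range(self.num_microbatches):
            self._preload[i] = batch_data[i]  # stands in for _load_micro_batch + data_process_func
        self._offset = [0] * len(self._offset)

    def load_micro_batch(self, chunk_id: int):
        # Each model chunk walks the same cached list with its own cursor.
        data = self._preload[self._offset[chunk_id]]
        self._offset[chunk_id] += 1
        return data
```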
@@ -872,18 +893,19 @@ def _forward_step(self, engine, chunk_id, input_obj=None):
self._accum_loss.add_(loss_reduced.detach())
output_obj = loss_reduced

moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
moe_loss = sum(moe_losses) * gpc.config.loss.moe_loss_coeff

if self._accum_moe_loss is not None:
self._accum_moe_loss.add_(moe_loss.detach())
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
# so we need to do allreduce
if gpc.config.parallel.sequence_parallel or gpc.config.parallel.expert.no_tp:
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches

if self._accum_moe_loss is not None:
self._accum_moe_loss.add_(moe_loss.detach())
else:
moe_loss = None

self._output_objs[chunk_id].append(output_obj)
self._moe_losses[chunk_id].append(moe_loss)
@@ -1394,7 +1416,9 @@ def forward_backward_step(self, engine, data_iter, forward_only=False, return_lo

if return_loss and gpc.is_pipeline_last_stage(ignore_virtual=True):
self._accum_loss = torch.zeros(1, device=get_current_device())
self._accum_moe_loss = torch.zeros(1, device=get_current_device())

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
self._accum_moe_loss = torch.zeros(1, device=get_current_device())

if return_output_label:
self._return_tensors = []
@@ -1409,13 +1433,15 @@
else:
output, label = (None, None)

accum_loss = self._accum_loss
accum_moe_loss = self._accum_moe_loss

if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1:
dist.all_reduce(self._accum_moe_loss, group=gpc.get_group(ParallelMode.PIPELINE))
accum_moe_loss = self._accum_moe_loss
accum_moe_loss = self._accum_moe_loss

accum_loss = self._accum_loss
if accum_loss is not None:
accum_loss += self._accum_moe_loss
if accum_loss is not None:
accum_loss += self._accum_moe_loss

self._clear_state()

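Note: throughout this file the MoE auxiliary loss is now only materialized when the model actually has experts (`gpc.config.model.num_experts > 1`); otherwise `moe_loss` and `accum_moe_loss` are `None` instead of zero tensors. Callers of the scheduler therefore need to treat the returned MoE loss as optional. A hedged usage sketch (the `metrics` dict is a placeholder, not this repo's API; the four returned values follow the `output, label, accum_loss, accum_moe_loss` pattern visible in the diff):

```python
# Placeholder consumer; only the None-check reflects the behaviour introduced above.
output, label, loss, moe_loss = scheduler.forward_backward_step(
    engine, data_iter, forward_only=False, return_loss=True
)

metrics = {}
if loss is not None:
    metrics["loss"] = loss.item()
if moe_loss is not None:
    # Only MoE models report an auxiliary balancing loss.
    metrics["moe_loss"] = moe_loss.item()
```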
10 changes: 7 additions & 3 deletions internlm/initialize/launch.py
@@ -87,9 +87,6 @@ def args_sanity_check():
if "pipeline" not in gpc.config.parallel:
gpc.config.parallel._add_item("pipeline", dict(size=1, interleaved_overlap=False, mode="1F1B"))

if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline:
gpc.config.parallel.pipeline._add_item("mode", "1F1B")

if "tensor" not in gpc.config.parallel:
gpc.config.parallel._add_item("tensor", dict(size=1, mode=TensorParallelMode.mtp.name))

@@ -104,9 +101,16 @@

if isinstance(gpc.config.parallel.pipeline, int):
pp = gpc.config.parallel.pipeline
gpc.config.parallel._add_item("pipeline", dict(size=pp, interleaved_overlap=False))
else:
pp = gpc.config.parallel.pipeline.size

if isinstance(gpc.config.parallel.pipeline, dict) and "mode" not in gpc.config.parallel.pipeline:
gpc.config.parallel.pipeline._add_item("mode", "1F1B")

if "batch_p2p_comm" not in gpc.config.parallel.pipeline:
gpc.config.parallel.pipeline["batch_p2p_comm"] = False

if isinstance(gpc.config.parallel.pipeline, dict):
gpc.config.parallel.pipeline["mode"] = gpc.config.parallel.pipeline["mode"].upper()
assert gpc.config.parallel.pipeline["mode"] in [
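Note: after `args_sanity_check` runs with these changes, a bare `pipeline=2` (or a pipeline dict missing the new keys) is normalized so that both `mode` and `batch_p2p_comm` have defaults. A sketch of the resulting values, inferred from the diff above:

```python
# User config before the sanity check:
parallel = dict(pipeline=2)

# Effective config after the sanity check (values inferred from the diff above):
parallel = dict(
    pipeline=dict(
        size=2,
        interleaved_overlap=False,
        mode="1F1B",           # default mode, upper-cased by the existing check
        batch_p2p_comm=False,  # new default added by this PR
    ),
)
```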
1 change: 1 addition & 0 deletions internlm/train/pipeline.py
@@ -533,6 +533,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList], isp_communicato
if (
zero_cfg.overlap_sync_grad
and gpc.is_using_parallel_mode(ParallelMode.PIPELINE)
and getattr(gpc.config.parallel.pipeline, "batch_p2p_comm", False) is True
and gpc.is_pipeline_first_stage() is False
):
# When pipeline parallelism is enabled, we prefer to only enable optimizer
Expand Down