FusionDefinition.execute returns output shardings. #3732

Open
wants to merge 8 commits into base: wjy/execute
Conversation


@wujingyue wujingyue commented Jan 19, 2025

This is done by adding an nvfuser.DistributedTensor that inherits from torch.Tensor and wraps a mesh and a mesh-axis-to-tensor-axis mapping.
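
For context, here is a rough usage sketch of what this enables. The axis_sharded_on accessor and the parallel-type argument are assumptions inferred from the C++ changes in this PR, not a confirmed Python API:

import torch
from nvfuser import FusionDefinition, DataType

with FusionDefinition() as fd:
    inp = fd.define_tensor(shape=[-1, -1], dtype=DataType.Float)
    out = fd.ops.relu(inp)
    fd.add_output(out)

# Without a multidevice schedule, execute() behaves as before and returns
# plain torch.Tensors (the wrapped mesh is empty).
(out_tensor,) = fd.execute([torch.randn(4, 8, device="cuda")])
assert isinstance(out_tensor, torch.Tensor)

# With a multidevice schedule, each output is an nvfuser.DistributedTensor,
# a torch.Tensor subclass that also exposes the output's sharding:
#   out_tensor.mesh                    -> the DeviceMesh the output lives on
#   out_tensor.axis_sharded_on(ptype)  -> tensor axis sharded by that parallel
#                                         type, or -1 if not sharded on it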


github-actions bot commented Jan 19, 2025

PR Reviewer Guide 🔍

(Review updated until commit d7dc6cd)

Here are some key observations to aid the review process:

⏱️ Estimated effort to review: 4 🔵🔵🔵🔵⚪
🧪 PR contains tests
⚡ Recommended focus areas for review

Logic Change

The FusionDefinition::execute function now returns a std::vector<DistributedTensor> instead of a std::vector<at::Tensor>. This change may affect the logic of the function and its call sites.

std::vector<DistributedTensor> FusionDefinition::execute(
    const at::ArrayRef<c10::IValue>& inputs,
    std::optional<int8_t> selected_device,
    bool override_user_schedule,
    bool capture_debug_output,
    bool profile,
    std::vector<std::string> _enable_options,
    std::vector<std::string> _disable_options) const {
  debug_output_ = std::nullopt;
  std::stringstream debug_ss;
  DebugStreamGuard dsg(capture_debug_output ? debug_ss : std::cout);

  NVF_CHECK(id().has_value(), "Valid fusion schedule is not available!");

  auto scheds = fusionCache()->queryFusionSchedules(id().value());

  if (profile) {
    ProfilerOptionsGuard::getCurOptions().set(ProfilerOption::Enable);
  }

  EnableOptionsGuard enable_opt_guard;
  for (const auto& _enable_option : _enable_options) {
    std::optional<EnableOption> opt = stringToEnableOption(_enable_option);
    NVF_CHECK(opt.has_value(), "Unrecognized enable_option: ", _enable_option);
    EnableOptionsGuard::getCurOptions().set(opt.value());
  }

  DisableOptionsGuard disable_opt_guard;
  for (const auto& _disable_option : _disable_options) {
    std::optional<DisableOption> opt = stringToDisableOption(_disable_option);
    NVF_CHECK(
        opt.has_value(), "Unrecognized disable_option: ", _disable_option);
    DisableOptionsGuard::getCurOptions().set(opt.value());
  }

  auto find_user_schedule = [&]() -> const UserSchedule* {
    if (override_user_schedule) {
      return nullptr;
    }

    auto user_sched_id = fusionCache()->queryUserScheduleId(scheds, inputs);
    if (!user_sched_id.has_value()) {
      return nullptr;
    }

    auto device = getCommonDeviceCUDA(inputs, selected_device);
    NVF_CHECK(
        inputs.empty() || device > -1,
        "Inputs are not all on the same device or don't match selection!");
    const UserSchedule& user_sched =
        fusionCache()->queryUserSchedule(scheds, user_sched_id.value(), device);
    return &user_sched;
  };
  const auto* user_sched = find_user_schedule();

  std::vector<at::Tensor> out_tensors;
  if (user_sched == nullptr) {
    out_tensors = scheds->auto_gen_schedules->runFusionWithInputs(
        inputs, std::nullopt, selected_device);
  } else {
    if (isProfilerEnabledWithCupti()) {
      FusionProfiler::start();
      FusionProfiler::createSegments(1);
    }
    scheds->last_user_def_scheduled_ir = user_sched->scheduled_fusion.get();
    scheds->last_user_def_executor = user_sched->executor.get();

    if (user_sched->heuristic_params == nullptr) {
      // Manual schedule
      if (!user_sched->executor->isCompiled()) {
        user_sched->executor->compile(
            user_sched->scheduled_fusion.get(), inputs);
      }
      out_tensors = user_sched->executor->run(inputs);
    } else {
      // Automatic scheduler was used for UserSchedule.
      // Pass launch and compile params to compileFusion and runFusion.
      if (!user_sched->executor->isCompiled()) {
        user_sched->executor->compile(
            user_sched->scheduled_fusion.get(),
            KernelArgumentHolder::createKernelArgumentHolder(
                inputs, getCommonDeviceCUDA(inputs)),
            user_sched->heuristic_params->lparams,
            user_sched->heuristic_params->cparams,
            user_sched->heuristic_params->scheduler_type);
      }
      out_tensors = user_sched->executor->run(
          inputs,
          user_sched->heuristic_params->lparams,
          user_sched->heuristic_params->cparams);
    }

    if (isProfilerEnabledWithCupti()) {
      FusionProfiler::segment(0).scheduler("user");
      FusionProfiler::stop();
      if (isProfilerPrintingEnabled()) {
        debug() << FusionProfiler::profile();
      }
    }
  }

  if (profile) {
    ProfilerOptionsGuard::getCurOptions().unset(ProfilerOption::Enable);
  }

  if (capture_debug_output) {
    debug_output_ = debug_ss.str();
  }

  // Convert `at::Tensor`s to `DistributedTensor`s.
  std::vector<DistributedTensor> out_dtensors;
  out_dtensors.reserve(out_tensors.size());
  if (user_sched == nullptr) {
    FusionKernelRuntime* runtime =
        scheds->auto_gen_schedules->getMostRecentKernelRuntime();
    Fusion* fusion = runtime->fusionSegments()->completeFusion();

    int64_t tensor_index = 0;
    for (Val* out_val : fusion->outputs()) {
      auto* out_tv = out_val->as<TensorView>();
      if (fusion->getOutputAlias(out_tv).hide_output) {
        continue;
      }

      const at::Tensor& out_tensor = out_tensors.at(tensor_index);
      tensor_index++;
      const DeviceMesh& mesh = out_tv->getDeviceMesh();
      DistributedTensor out_dtensor(out_tensor, mesh);

      if (mesh.size() > 0) {
        for (const ParallelType parallel_type : kParallelTypeDIDs) {
          if (const auto axis = getShardedLogicalAxis(out_tv, parallel_type);
              axis != -1) {
            out_dtensor.setAxisIsShardedOn(axis, parallel_type);
          }
        }
      }

      out_dtensors.push_back(std::move(out_dtensor));
    }
    NVF_ERROR(out_dtensors.size() == out_tensors.size());
  } else {
    for (const auto& out_tensor : out_tensors) {
      out_dtensors.emplace_back(out_tensor);
    }
  }
  return out_dtensors;
}
New Functionality

A new class DistributedTensor has been introduced, which represents a distributed tensor and provides methods to set and get the axis sharded on a parallel type.

// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on

#include <exceptions.h>
#include <python_frontend/distributed_tensor.h>
#include <utils.h>

namespace nvfuser::python_frontend {

void DistributedTensor::setAxisIsShardedOn(
    const int64_t axis,
    const ParallelType parallel_type) {
  const auto i = axis_sharded_on_.find(parallel_type);
  NVF_CHECK(
      i == axis_sharded_on_.end(),
      "Parallel type ",
      parallel_type,
      " was already used to shard axis ",
      i->second);
  axis_sharded_on_[parallel_type] = axis;
}

int64_t DistributedTensor::axisShardedOn(
    const ParallelType parallel_type) const {
  return getOrDefault(axis_sharded_on_, parallel_type, -1L);
}

} // namespace nvfuser::python_frontend
Type Change

The FusionDefinition.execute method now returns a list of DistributedTensor instead of torch.Tensor.

    capture_debug_output=False,
    print_repro=False,
    profile=False,
    save_repro_inputs=False,
    _enable_options: list[str] = [],
    _disable_options: list[str] = [],
) -> list[torch.Tensor]:
    """
    Executes a set of nvFuser kernels for a given Fusion

    The FusionDefinition will be executed on a single CUDA device.
    Typically, which device to run on is determined by the devices where
    the input tensors reside. However, if the Fusion is defined such that
    none of the inputs are tensors, we are not able to infer a device from
    the inputs. For example, the following FusionDefinition will be unable
    to unambiguously infer the device of its output:

        with FusionDefinition() as fd:
            tv1 = fd.ops.full([5])
            fd.add_output(tv1)

    In that case, we default to selecting the first CUDA
    device, i.e. `torch.device("cuda:0")`. This method enables selecting an
    alternative preferred device.

    Args:
        inputs (List[Union[Tensor, Scalar]]): A list of inputs to fusion.

    Kwargs:
        device (Optional[Union[int, str, torch.device]]): This is a hint to run
            the Fusion on the given CUDA device. This is not typically
            necessary, as the device is usually inferred from the locations
            of input tensors. However, for some fusion definitions, no
            tensors will be input (for example when all tensors are
            generated with `full` or `uniform` ops). In these cases, we
            must either tell NVFuser where to run the resulting kernel, or
            let it default to 0. Note that passing this option while
            providing input tensors that lie on another device is an error.
        override_user_schedule (bool): For a user defined schedule,
            override with auto-generated schedule (default: False)
        capture_debug_output (bool): Whether to capture any printed
            debugging information as a string. If True, the string can be
            retrieved after execution using :meth:`get_debug_output`. If False,
            then that method will return None when called.
        print_repro (bool): Prints a reproduction script to stdout.
        profile (bool): Captures a CUPTI based profile of a fusion.
        save_repro_inputs (bool): Saves the inputs for last_repro_script() to
            provide a reproduction script.
        _enable_options/_disable_options (list): NVFUSER_ENABLE/DISABLE options to use.
            This is an alternative to environment variables.
            Note: Currently, we do not cache/store these options in the FusionCache, which makes it
                possible to reuse kernels when executing the same fusion definition with different sets of options.
                Reset the FusionCache manually to avoid inadvertent kernel reuse when switching between different sets of options.

    Returns:
        List[Tensor]
    """
    self.profiled = profile

    if device is not None:
        if not isinstance(device, torch.device):
            device = torch.device(device)
        assert (
            device.type == "cuda"
        ), "If device argument is passed it must be a CUDA device"
        device = device.index

    # if definition is not defined by a context manager, try a child class
    if self.id() is None:
        self._setup_definition()
        self.definition()
        self._finalize_definition()

    defined_multidevice_schedule = hasattr(
        self, "multidevice_schedule"
    ) and isinstance(self.multidevice_schedule, Callable)
    defined_schedule = hasattr(self, "schedule") and isinstance(
        self.schedule, Callable
    )
    assert not (
        defined_multidevice_schedule and defined_schedule
    ), "I haven't tested what if both are defined. We don't plan to support this use case although it may just work."

    if defined_multidevice_schedule:
        # Unlike `schedule`, `multidevice_schedule` is designed for inter-device
        # scheduling. The scheduling is done before concretization and therefore
        # before pre-segmentation. `schedule` however assumes the FusionDefinition
        # has been concretized and pre-segmented, and therefore requires
        # `_setup_schedule` and `_finalize_schedule` to be called before and after.
        #
        # Note: there's a plan to embed multidevice schedules into FusionDefinition
        # as annotating nodes. This may eventually replace `multidevice_schedule`.
        self._setup_multidevice_schedule()
        self.multidevice_schedule()
        self._finalize_multidevice_schedule()

    # If schedule is defined by child class and schedule is not defined for
    # inputs, make a schedule.
    if defined_schedule:
        # Schedule fusion if it does not exist yet or profiling fusion
        if profile or not self._exist_schedule(inputs):
            self._setup_schedule(inputs, overwrite_existing_schedule=profile)
            self.schedule()
            self._finalize_schedule(inputs)

    if save_repro_inputs:
        from torch._subclasses.fake_tensor import FakeTensorMode

        fake_mode = FakeTensorMode()
        self.fake_inputs = [fake_mode.from_tensor(inp) for inp in inputs]

    if hasattr(self, "segments") and len(self.segments) > 0:
        return self._execute_segments(inputs, device=device, profile=profile)

    try:
        if print_repro:
            print(self.repro_script_for(inputs))
        if len(_enable_options) or len(_disable_options):
            warnings.warn(
                "Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options."
            )

        out_dtensors: Iterable[_C.DistributedTensor] = self._execute(
            inputs,
            device=device,
            override_user_schedule=override_user_schedule,
            capture_debug_output=capture_debug_output,
            profile=profile,
            _enable_options=_enable_options,
            _disable_options=_disable_options,
        )

        out_tensors: list[torch.Tensor] = []
        for out_dtensor in out_dtensors:
            if out_dtensor.mesh.size == 0:
                out_tensors.append(out_dtensor.local)
            else:
                out_tensors.append(DistributedTensor(out_dtensor))
        return out_tensors
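
For reference, a minimal sketch of how the nvfuser.DistributedTensor wrapper used above might be structured as a torch.Tensor subclass. The local and mesh properties mirror the bindings shown in this diff; the axis_sharded_on method name is an assumption mirroring the C++ axisShardedOn, and this is not the actual __init__.py code:

class DistributedTensor(torch.Tensor):
    """Sketch only: a torch.Tensor subclass carrying sharding metadata."""

    @staticmethod
    def __new__(cls, dtensor: _C.DistributedTensor):
        # Reuse the local shard's storage; the subclass only adds metadata.
        return torch.Tensor._make_subclass(cls, dtensor.local)

    def __init__(self, dtensor: _C.DistributedTensor):
        # `dtensor` is the C++ DistributedTensor returned by self._execute().
        self._dtensor = dtensor

    @property
    def mesh(self):
        return self._dtensor.mesh

    def axis_sharded_on(self, parallel_type) -> int:
        # Mirrors C++ DistributedTensor::axisShardedOn: returns -1 when the
        # output is not sharded on the given parallel type.
        return self._dtensor.axis_sharded_on(parallel_type)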

@@ -71,11 +75,24 @@ void bindDeviceMesh(py::module& nvfuser) {
py::arg("device_id"));
}

void bindDistributedTensor(py::module& nvfuser) {
py::class_<DistributedTensor> distributed_tensor(
nvfuser, "_DistributedTensor");
Collaborator Author


Docs are omitted because the class name starts with an underscore, indicating it's hidden from the user. nvfuser.DistributedTensor, defined in __init__.py, has docs though.

@wujingyue
Collaborator Author

!test

@wujingyue
Collaborator Author

!test

@wujingyue
Collaborator Author

!test

@@ -344,7 +347,7 @@ void FusionDefinition::print(std::ostream& os) const {
os << std::endl;
}

std::vector<at::Tensor> FusionDefinition::execute(
std::vector<DistributedTensor> FusionDefinition::execute(
Collaborator Author


Alternatives:

  1. Return std::vector<std::variant<at::Tensor, DistributedTensor>>. Because DistributedTensor can also represent a non-distributed tensor, I chose the current API for simplicity -- C++ is more verbose than Python when dealing with dynamic types.
  2. Return std::variant<std::vector<at::Tensor>, std::vector<DistributedTensor>>. Same reason.
  3. Store output shardings (i.e. the mesh and the mesh-axis-to-tensor-axis mapping) in a field of FusionDefinition and retrieve them using another method. This would be similar to getDebugOutput. I didn't choose that because it would introduce new state in the class that could get out of sync.
