Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FusionDefinition.execute returns output shardings. #3732

Open
wants to merge 8 commits into
base: wjy/execute
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ endif()

if(BUILD_PYTHON)
list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/distributed_tensor.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp
Expand Down
33 changes: 33 additions & 0 deletions csrc/python_frontend/distributed_tensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <exceptions.h>
#include <python_frontend/distributed_tensor.h>
#include <utils.h>

namespace nvfuser::python_frontend {

// Records that tensor dimension `axis` is sharded on `parallel_type`.
// A parallel type may shard at most one axis; registering the same parallel
// type twice is an error.
void DistributedTensor::setAxisIsShardedOn(
    const int64_t axis,
    const ParallelType parallel_type) {
  // try_emplace performs the lookup and the insertion in a single pass
  // (the original code did a `find` followed by `operator[]`, i.e. two
  // lookups). When insertion fails, `i` points at the existing entry, so
  // the error message can report the previously registered axis.
  const auto [i, inserted] = axis_sharded_on_.try_emplace(parallel_type, axis);
  NVF_CHECK(
      inserted,
      "Parallel type ",
      parallel_type,
      " was already used to shard axis ",
      i->second);
}

// Returns the tensor axis sharded on `parallel_type`, or -1 when no axis is
// registered for that parallel type (i.e. the tensor is replicated on it).
int64_t DistributedTensor::axisShardedOn(
    const ParallelType parallel_type) const {
  const auto it = axis_sharded_on_.find(parallel_type);
  return it == axis_sharded_on_.end() ? -1 : it->second;
}

} // namespace nvfuser::python_frontend
50 changes: 50 additions & 0 deletions csrc/python_frontend/distributed_tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#pragma once

#include <ATen/core/TensorBody.h>

#include <multidevice/device_mesh.h>
#include <type.h>

namespace nvfuser::python_frontend {

// A class that represents a distributed tensor. It wraps a local tensor, a
// mesh, and a mapping from mesh axes to tensor axes. If the mesh is empty,
// it degenerates into a non-distributed tensor.
class DistributedTensor {
 public:
  // Wraps `local_tensor` with an optional device mesh. A default (empty)
  // mesh means the tensor is not distributed.
  explicit DistributedTensor(
      at::Tensor local_tensor,
      DeviceMesh mesh = DeviceMesh())
      : local_(std::move(local_tensor)), mesh_(std::move(mesh)) {}
  // Move-only: copying is disabled, presumably to avoid accidental copies of
  // the underlying tensor wrapper -- confirm with the original author.
  DistributedTensor(const DistributedTensor&) = delete;
  DistributedTensor& operator=(const DistributedTensor&) = delete;
  DistributedTensor(DistributedTensor&&) = default;
  DistributedTensor& operator=(DistributedTensor&&) = default;

  // The device mesh this tensor is distributed over; empty for a
  // non-distributed tensor.
  const DeviceMesh& mesh() const {
    return mesh_;
  }

  // The local (per-device) tensor.
  at::Tensor local() const {
    return local_;
  }

  // Records that tensor dimension `axis` is sharded on `parallel_type`.
  void setAxisIsShardedOn(int64_t axis, ParallelType parallel_type);

  // Returns the axis sharded on `parallel_type`, or -1 if none.
  int64_t axisShardedOn(ParallelType parallel_type) const;

 private:
  at::Tensor local_;
  DeviceMesh mesh_;
  // Maps each device-parallel type to the tensor axis it shards.
  std::unordered_map<ParallelType, int64_t> axis_sharded_on_;
};

} // namespace nvfuser::python_frontend
52 changes: 46 additions & 6 deletions csrc/python_frontend/fusion_definition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@
#include <debug.h>
#include <fusion_profiler.h>
#include <instrumentation.h>
#include <multidevice/utils.h>
#include <options.h>
#include <preseg_passes/pre_segmenter.h>
#include <python_frontend/distributed_tensor.h>
#include <python_frontend/fusion_cache.h>
#include <python_frontend/fusion_definition.h>
#include <python_frontend/translation.h>
#include <runtime/executor_kernel_arg.h>
#include <runtime/fusion_kernel_runtime.h>
#include <scheduler/compile_time_info.h>
#include <scheduler/scheduler_types.h>
#include <utils.h>
Expand Down Expand Up @@ -344,7 +347,7 @@ void FusionDefinition::print(std::ostream& os) const {
os << std::endl;
}

std::vector<at::Tensor> FusionDefinition::execute(
std::vector<DistributedTensor> FusionDefinition::execute(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatives:

  1. Return std::vector<std::variant<at::Tensor, DistributedTensor>>. Because DistributedTensor can also represent a non-distributed tensor, I chose the current API for simplicity -- C++ is more verbose than Python when dealing with dynamic types.
  2. Return std::variant<std::vector<at::Tensor>, std::vector<DistributedTensor>>. Same reason.
  3. Store output shardings (i.e. the mesh and the mesh-to-tensor-axis mapping) in a field of FusionDefinition and retrieve it using another method. This would be similar to getDebugOutput. I didn't choose that because it would introduce a new state in the class that could get out of sync.

const at::ArrayRef<c10::IValue>& inputs,
std::optional<int8_t> selected_device,
bool override_user_schedule,
Expand Down Expand Up @@ -399,9 +402,9 @@ std::vector<at::Tensor> FusionDefinition::execute(
};
const auto* user_sched = find_user_schedule();

std::vector<at::Tensor> outputs;
std::vector<at::Tensor> out_tensors;
if (user_sched == nullptr) {
outputs = scheds->auto_gen_schedules->runFusionWithInputs(
out_tensors = scheds->auto_gen_schedules->runFusionWithInputs(
inputs, std::nullopt, selected_device);
} else {
if (isProfilerEnabledWithCupti()) {
Expand All @@ -417,7 +420,7 @@ std::vector<at::Tensor> FusionDefinition::execute(
user_sched->executor->compile(
user_sched->scheduled_fusion.get(), inputs);
}
outputs = user_sched->executor->run(inputs);
out_tensors = user_sched->executor->run(inputs);
} else {
// Automatic scheduler was used for UserSchedule.
// Pass launch and compile params to compileFusion and runFusion.
Expand All @@ -430,7 +433,7 @@ std::vector<at::Tensor> FusionDefinition::execute(
user_sched->heuristic_params->cparams,
user_sched->heuristic_params->scheduler_type);
}
outputs = user_sched->executor->run(
out_tensors = user_sched->executor->run(
inputs,
user_sched->heuristic_params->lparams,
user_sched->heuristic_params->cparams);
Expand All @@ -453,7 +456,44 @@ std::vector<at::Tensor> FusionDefinition::execute(
debug_output_ = debug_ss.str();
}

return outputs;
// Convert `at::Tensor`s to `DistributedTensor`s.
std::vector<DistributedTensor> out_dtensors;
out_dtensors.reserve(out_tensors.size());
if (user_sched == nullptr) {
FusionKernelRuntime* runtime =
scheds->auto_gen_schedules->getMostRecentKernelRuntime();
Fusion* fusion = runtime->fusionSegments()->completeFusion();

int64_t tensor_index = 0;
for (Val* out_val : fusion->outputs()) {
auto* out_tv = out_val->as<TensorView>();
if (fusion->getOutputAlias(out_tv).hide_output) {
continue;
}

const at::Tensor& out_tensor = out_tensors.at(tensor_index);
tensor_index++;
const DeviceMesh& mesh = out_tv->getDeviceMesh();
DistributedTensor out_dtensor(out_tensor, mesh);

if (mesh.size() > 0) {
for (const ParallelType parallel_type : kParallelTypeDIDs) {
if (const auto axis = getShardedLogicalAxis(out_tv, parallel_type);
axis != -1) {
out_dtensor.setAxisIsShardedOn(axis, parallel_type);
}
}
}

out_dtensors.push_back(std::move(out_dtensor));
}
NVF_ERROR(out_dtensors.size() == out_tensors.size());
} else {
for (const auto& out_tensor : out_tensors) {
out_dtensors.emplace_back(out_tensor);
}
}
return out_dtensors;
}

std::string FusionDefinition::fusionIr() {
Expand Down
5 changes: 3 additions & 2 deletions csrc/python_frontend/fusion_definition.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
// clang-format on
#pragma once

#include <exceptions.h>
#include <functional>
#include <iostream>
#include <unordered_map>

#include <exceptions.h>
#include <python_frontend/distributed_tensor.h>
#include <python_frontend/fusion_state.h>
#include <python_frontend/segmentation.h>
#include <visibility.h>
Expand Down Expand Up @@ -193,7 +194,7 @@ class NVF_API FusionDefinition : public FusionState {
//! Prints a python function representing the definition
NVF_API void print(std::ostream& os) const;
//! Executes a fusion if a valid definition or cache lookup occurred prior
NVF_API std::vector<at::Tensor> execute(
NVF_API std::vector<DistributedTensor> execute(
const at::ArrayRef<c10::IValue>& inputs,
std::optional<int8_t> device,
bool override_user_schedule,
Expand Down
25 changes: 21 additions & 4 deletions csrc/python_frontend/multidevice_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,18 @@ void bindCommunicator(py::module& nvfuser) {
}

void bindDeviceMesh(py::module& nvfuser) {
py::class_<DeviceMesh> device_mesh_class(nvfuser, "DeviceMesh");
device_mesh_class.def(py::init<std::vector<int64_t>>());
device_mesh_class.def("__repr__", [](const DeviceMesh& self) {
py::class_<DeviceMesh> device_mesh(nvfuser, "DeviceMesh");
device_mesh.def(py::init<std::vector<int64_t>>());
device_mesh.def("__repr__", [](const DeviceMesh& self) {
std::stringstream ss;
ss << self;
return ss.str();
});
device_mesh_class.def(
device_mesh.def_property_readonly(
"size",
[](const DeviceMesh& self) -> int64_t { return self.size(); },
"Returns the size of the mesh.");
device_mesh.def(
"shard_tensor",
[](const DeviceMesh& self,
at::Tensor tensor,
Expand All @@ -71,11 +75,24 @@ void bindDeviceMesh(py::module& nvfuser) {
py::arg("device_id"));
}

// Binds DistributedTensor to Python as "_DistributedTensor". The leading
// underscore marks it as internal; the user-facing wrapper
// (nvfuser.DistributedTensor) is defined in __init__.py and carries the docs.
void bindDistributedTensor(py::module& nvfuser) {
  py::class_<DistributedTensor>(nvfuser, "_DistributedTensor")
      .def_property_readonly("local", &DistributedTensor::local)
      .def_property_readonly(
          "mesh", &DistributedTensor::mesh, py::return_value_policy::reference)
      .def(
          "axis_sharded_on",
          &DistributedTensor::axisShardedOn,
          py::arg("parallel_type"));
}

} // namespace

// Registers all multidevice-related Python bindings (communicator, device
// mesh, and distributed tensor) on the given module.
void bindMultidevice(py::module& nvfuser) {
  bindCommunicator(nvfuser);
  bindDeviceMesh(nvfuser);
  bindDistributedTensor(nvfuser);
}

} // namespace nvfuser::python_frontend
7 changes: 5 additions & 2 deletions csrc/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -543,9 +543,12 @@ NVF_API char* getNvFuserEnv(const char* env_name);

// Returns the mapped value or the default.
template <typename K, typename V>
V getOrDefault(
    const std::unordered_map<K, V>& map,
    const K& key,
    const V& default_value = V()) {
  // Returns by value. Returning `const V&` here is a dangling-reference
  // hazard: when the result is the defaulted argument, that argument is a
  // temporary that dies at the end of the caller's full expression, so
  // `const auto& r = getOrDefault(m, k);` would dangle.
  if (const auto i = map.find(key); i != map.end()) {
    return i->second;
  }
  return default_value;
}

size_t deviceAvailableSharedMemoryBytes();
Expand Down
65 changes: 59 additions & 6 deletions nvfuser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,11 @@
import os
import re
import sys
from typing import Callable, Optional, Union, List # noqa: F401
from typing import Callable, Optional, Union, List, Iterable # noqa: F401
import warnings

import torch
from torch.utils._pytree import tree_map

# This is needed when libnvfuser.so is patched and doesn't have the pytorch library location available.
pytorch_lib_dir = os.path.join(os.path.dirname(torch.__file__), "lib")
Expand Down Expand Up @@ -50,6 +51,54 @@ def disable_automatic_serialization():
atexit.unregister(_C.serialize)


class DistributedTensor(torch.Tensor):
    """Wraps a _C._DistributedTensor as a torch.Tensor.

    This way, the user can use the underlying local tensor without extra
    conversion. For example,
    `torch.testing.assert_close(dtensor, expected_local_tensor)`.
    """

    # The C++ distributed tensor this wrapper forwards to.
    _dtensor: _C._DistributedTensor

    @staticmethod
    def __new__(cls, dtensor: _C._DistributedTensor):
        # Create a wrapper subclass that mirrors the local tensor's metadata
        # (shape, strides, dtype, device, ...) without copying its data.
        t = dtensor.local
        return torch.Tensor._make_wrapper_subclass(
            cls,
            t.shape,
            strides=t.stride(),
            storage_offset=t.storage_offset(),
            device=t.device,
            layout=t.layout,
            requires_grad=t.requires_grad,
            dtype=t.dtype,
        )

    def __init__(self, dtensor: _C._DistributedTensor):
        self._dtensor = dtensor

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Per the PyTorch wrapper-subclass convention, `kwargs` may arrive as
        # None; the previous signature would crash on `**tree_map(...)` then.
        kwargs = kwargs or {}

        def unwrap(t):
            # Replace any DistributedTensor argument with its local tensor so
            # `func` runs on plain torch.Tensors.
            if isinstance(t, DistributedTensor):
                return t._dtensor.local
            return t

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

    @property
    def mesh(self) -> DeviceMesh:
        """Returns the device mesh."""
        return self._dtensor.mesh

    def axis_sharded_on(self, parallel_type: ParallelType) -> int:
        """Returns the axis sharded on the given parallel type.

        If the distributed tensor is replicated on that parallel type, returns -1.
        """
        return self._dtensor.axis_sharded_on(parallel_type)


class FusionDefinition(_C._FusionDefinition):
def __init__(self, id=None, max_length=1024):
super(FusionDefinition, self).__init__(id, max_length)
Expand Down Expand Up @@ -198,7 +247,7 @@ def execute(
save_repro_inputs=False,
_enable_options: list[str] = [],
_disable_options: list[str] = [],
):
) -> list[torch.Tensor]:
"""
Executes an nvFuser set of kernels for a given Fusion

Expand Down Expand Up @@ -306,8 +355,6 @@ def execute(
if hasattr(self, "segments") and len(self.segments) > 0:
return self._execute_segments(inputs, device=device, profile=profile)

results = None

try:
if print_repro:
print(self.repro_script_for(inputs))
Expand All @@ -316,7 +363,7 @@ def execute(
"Reset the FusionCache manually to avoid reusing kernels when re-executing the fusion definition with different options."
)

results = self._execute(
out_dtensors: Iterable[_C.DistributedTensor] = self._execute(
inputs,
device=device,
override_user_schedule=override_user_schedule,
Expand All @@ -325,7 +372,13 @@ def execute(
_enable_options=_enable_options,
_disable_options=_disable_options,
)
return results
out_tensors = []
for out_dtensor in out_dtensors:
if out_dtensor.mesh.size == 0:
out_tensors.append(out_dtensor.local)
else:
out_tensors.append(DistributedTensor(out_dtensor))
return out_tensors
except Exception as err:
logger.exception(self._repro_error_str("executing", inputs))
raise
Expand Down
Loading
Loading