[WebNN] Support Cast fusion specific for int64 data type #23256

Open · wants to merge 1 commit into main · Changes from all commits
118 changes: 108 additions & 10 deletions onnxruntime/core/providers/webnn/builders/helper.cc
@@ -58,12 +58,66 @@ bool GetShape(const NodeArg& node_arg, std::vector<int64_t>& shape, const logging::Logger& logger)
  return true;
}

+bool FindCastFusableNodeIndex(const Node& node, const bool is_prec_node, NodeIndex& index) {
+  if (is_prec_node) {
+    // To fuse with its preceding node, the node must have exactly one input edge.
+    if (node.GetInputEdgesCount() == 1) {
+      const auto& prec_node = node.InputEdgesBegin()->GetNode();
+      // The preceding node must have exactly one output edge.
+      if (prec_node.GetOutputEdgesCount() == 1) {
+        index = prec_node.Index();
+        return true;
+      }
+    }
+  } else {
+    // To fuse with its successive node, the node must have exactly one output edge.
+    if (node.GetOutputEdgesCount() == 1) {
+      const auto& next_node = node.OutputEdgesBegin()->GetNode();
+      index = next_node.Index();
+      return true;
+    }
+  }
+
+  return false;
+}

+bool IsCastFusable(const Node& node, const bool is_prec_node, const logging::Logger& logger) {
+  if (node.OpType() != "Cast") {
+    return false;
+  }
+
+  int32_t input_type;
+  if (!GetType(*node.InputDefs()[0], input_type, logger))
+    return false;
+
+  NodeAttrHelper cast_helper(node);
+  const auto to_type = cast_helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+
+  if (is_prec_node) {
+    // When fusing with the preceding node, only allow casting from int64 to int32.
+    if (input_type != ONNX_NAMESPACE::TensorProto_DataType_INT64 ||
+        to_type != ONNX_NAMESPACE::TensorProto_DataType_INT32) {
+      return false;
+    }
+  } else {
+    // When fusing with the successive node, only allow casting from int32 to int64.
+    if (input_type != ONNX_NAMESPACE::TensorProto_DataType_INT32 ||
+        to_type != ONNX_NAMESPACE::TensorProto_DataType_INT64) {
+      return false;
+    }
+  }
+
+  NodeIndex index;
+  return FindCastFusableNodeIndex(node, is_prec_node, index);
+}

bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type,
-                     const emscripten::val& wnn_limits, const logging::Logger& logger) {
+                     const emscripten::val& wnn_limits, bool& is_fusable, const logging::Logger& logger) {
  const auto& op_builders = GetOpBuilders();
  if (Contains(op_builders, node.OpType())) {
    const auto* op_builder = op_builders.at(node.OpType());
-    return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger);
+    return op_builder->IsOpSupported(
+        graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, is_fusable, logger);
  } else {
    return false;
  }
@@ -99,30 +153,74 @@ bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
  return true;
}

-std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
-                                                      const emscripten::val& wnn_builder,
-                                                      const WebnnDeviceType device_type,
-                                                      const emscripten::val& wnn_limits,
-                                                      const logging::Logger& logger) {
+std::vector<std::vector<NodeIndex>> GetSupportedNodes(
+    const GraphViewer& graph_viewer, const emscripten::val& wnn_builder,
+    const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
+    InlinedHashMap<NodeIndex, NodeIndex>& fused_node_map, const logging::Logger& logger) {
  std::vector<std::vector<size_t>> supported_node_groups;
  std::vector<size_t> supported_node_group;
  const auto& node_indices = graph_viewer.GetNodesInTopologicalOrder();

+  // Some WebNN backends do not support the int64 data type, but this limitation can be addressed by converting the
+  // model's int64 inputs, outputs, and initializers to int32. However, certain ONNX nodes, such as ArgMax, ArgMin,
+  // and ScatterND, require the int64 data type for specific inputs or outputs.
+  // To handle such cases, we can add Cast nodes before or after these nodes in the model and fuse them during WebNN EP
+  // optimization. The fusion strategy is as follows:
+  // 1. Verify whether the Cast node can be fused with either its preceding node or its successive node.
+  // 2. Check whether the node requiring the int64 data type would be supported once the int64 limitation is
+  //    addressed, i.e., ensure that the node is unsupported only due to the int64 restriction.
+  // 3. Use an is_fusable flag to record paired nodes as <Cast node index, fusable node index> that can be
+  //    fused together.
+  // 4. Mark the fusable nodes as supported after identifying them.
+  // 5. During WebNN graph compilation, skip the Cast node and fuse it into its paired fusable node.

+  InlinedHashMap<NodeIndex, WebnnNodeInfo> node_info_map;
+  std::vector<NodeIndex> fusable_cast_nodes;
  for (size_t i = 0; i < node_indices.size(); i++) {
    auto node_idx = node_indices[i];
    const auto* node(graph_viewer.GetNode(node_idx));
    bool supported = false;
+    bool is_fusable = false;
    // First, check whether the platform supports the WebNN op.
    if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) {
-      supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger);
+      supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, is_fusable, logger);
+      node_info_map[node_idx] = {supported, is_fusable};
    }
+
+    if (node->OpType() == "Cast" && is_fusable) {
+      fusable_cast_nodes.push_back(node_idx);
+    }
  }

+  // Try to find the fusable partner node for each fusable Cast node.
+  // Note: graph partitioning will make sure fused nodes end up in the same partition.
+  for (size_t i = 0; i < fusable_cast_nodes.size(); i++) {
+    NodeIndex fusable_node_idx;
+    NodeIndex fusable_cast_node_idx = fusable_cast_nodes[i];
+    const auto& cast_node = *graph_viewer.GetNode(fusable_cast_node_idx);
+    // A Cast can only be fused with either its preceding node or its successive node.
+    bool fusable_found = FindCastFusableNodeIndex(cast_node, true, fusable_node_idx) ||
+                         FindCastFusableNodeIndex(cast_node, false, fusable_node_idx);
+
+    if (fusable_found && node_info_map[fusable_node_idx].fusable) {
+      // Mark both nodes in the fused pair as supported.
+      node_info_map[fusable_cast_node_idx].supported = true;
+      node_info_map[fusable_node_idx].supported = true;
+      fused_node_map[fusable_cast_node_idx] = fusable_node_idx;
+    }
+  }

+  for (size_t i = 0; i < node_indices.size(); i++) {
+    auto node_idx = node_indices[i];
+    const auto* node(graph_viewer.GetNode(node_idx));
+    WebnnNodeInfo node_info = node_info_map[node_idx];
+
    LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType()
                          << "] index: [" << node_idx
                          << "] name: [" << node->Name()
-                          << "] supported: [" << supported
+                          << "] supported: [" << node_info.supported
                          << "]";
-    if (supported) {
+    if (node_info.supported) {
      supported_node_group.push_back(node_idx);
    } else {
      if (!supported_node_group.empty()) {
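
The pairing strategy above is easier to see end-to-end on a toy graph. Below is a minimal, self-contained sketch of steps 1-4; ToyNode, FindPartner, and the example graph are hypothetical stand-ins for ORT's Node/GraphViewer machinery rather than the actual API, and the int64/int32 type checks from IsCastFusable are reduced to a single fusable bit.

// Toy model of the Cast-fusion pairing (steps 1-4 above). Hypothetical types, not ORT's API.
#include <cstddef>
#include <cstdio>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

struct ToyNode {
  std::string op_type;
  std::vector<size_t> inputs;   // producer node indices
  std::vector<size_t> outputs;  // consumer node indices
  bool fusable = false;         // set by the op-specific support checks (step 2)
};

// Mirrors FindCastFusableNodeIndex: a Cast with exactly one producer that has a
// single consumer (or, failing that, exactly one consumer) names that neighbor
// as its fusion partner.
std::optional<size_t> FindPartner(const std::vector<ToyNode>& g, size_t cast_idx) {
  const ToyNode& cast = g[cast_idx];
  if (cast.inputs.size() == 1 && g[cast.inputs[0]].outputs.size() == 1)
    return cast.inputs[0];   // fuse into the preceding node
  if (cast.outputs.size() == 1)
    return cast.outputs[0];  // fuse into the successive node
  return std::nullopt;
}

int main() {
  // ArgMax(0) -> Cast(1, int64->int32) -> Add(2)
  std::vector<ToyNode> g = {
      {"ArgMax", {}, {1}, /*fusable=*/true},  // unsupported only due to its int64 output
      {"Cast", {0}, {2}},
      {"Add", {1}, {}},
  };
  std::unordered_map<size_t, size_t> fused_node_map;  // <Cast idx, partner idx> (step 3)
  for (size_t i = 0; i < g.size(); ++i) {
    if (g[i].op_type != "Cast") continue;
    if (auto partner = FindPartner(g, i); partner && g[*partner].fusable)
      fused_node_map[i] = *partner;  // both nodes become supported (step 4)
  }
  for (const auto& [cast_idx, partner_idx] : fused_node_map)
    std::printf("fuse Cast %zu into %s %zu\n", cast_idx,
                g[partner_idx].op_type.c_str(), partner_idx);
  return 0;
}

Compiled as C++17, this prints "fuse Cast 1 into ArgMax 0" — the same <Cast node index, fusable node index> pairing that GetSupportedNodes records in fused_node_map.
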
17 changes: 12 additions & 5 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -32,6 +32,11 @@ enum class WebnnDeviceType {
  NPU,
};

+struct WebnnNodeInfo {
+  bool supported = false;  // whether the node is supported by WebNN.
+  bool fusable = false;    // whether the node is fusable with its preceding or successive node.
+};

WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type);

// Collects all the initializer tensors in the subGraph and its ancestor graphs.
@@ -185,15 +190,17 @@ inline bool TensorExists(const ConstPointerContainer<std::vector<NodeArg*>>& defs
  return tensor_index < defs.size() && defs[tensor_index]->Exists();
}

+bool FindCastFusableNodeIndex(const Node& node, const bool is_prec_node, NodeIndex& index);
+bool IsCastFusable(const Node& node, const bool is_prec_node, const logging::Logger& logger);
bool IsTensorShapeSupported(const NodeArg& node_arg, const std::string& parent_name,
                            const logging::Logger& logger, bool allow_empty_input = false);

// Get a list of groups of supported nodes; each group represents a subgraph supported by the WebNN EP.
-std::vector<std::vector<NodeIndex>> GetSupportedNodes(const GraphViewer& graph_viewer,
-                                                      const emscripten::val& wnn_builder,
-                                                      const WebnnDeviceType device_type,
-                                                      const emscripten::val& wnn_limits,
-                                                      const logging::Logger& logger);
+std::vector<std::vector<NodeIndex>> GetSupportedNodes(
+    const GraphViewer& graph_viewer, const emscripten::val& wnn_builder,
+    const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
+    InlinedHashMap<NodeIndex, NodeIndex>& fused_node_map,
+    const logging::Logger& logger);
static const InlinedHashMap<std::string, std::string> op_map = {
{"Abs", "abs"},
{"Add", "add"},
Expand Down
onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc
@@ -22,6 +22,8 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder {
  // Operator support related.
  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                         WebnnDeviceType device_type, const logging::Logger& logger) const override;
+  bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                               bool& is_fusable, const logging::Logger& logger) const override;
};

// Add operator related.
@@ -43,8 +45,16 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

  emscripten::val options = emscripten::val::object();
  options.set("keepDimensions", keep_dims == 1);
-  // TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API.
-  options.set("outputDataType", "int64");
+  std::string output_type = "int64";
+  // If ArgMax/ArgMin is fused with a Cast (to int32), the output type should be int32.
+  const auto& fused_nodes = model_builder.GetFusedNodes();
+  for (const auto& pair : fused_nodes) {
+    if (pair.second == node.Index()) {
+      output_type = "int32";
+    }
+  }
+
+  options.set("outputDataType", output_type);
  options.set("label", node.Name());
  emscripten::val output = emscripten::val::object();

@@ -75,6 +85,24 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */,
  return true;
}

+bool ArgMaxMinOpBuilder::HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                                                 bool& is_fusable, const logging::Logger& logger) const {
+  const auto& output_defs = node.OutputDefs();
+  const auto& op_type = node.OpType();
+  int32_t output_type = 0;
+
+  if (!GetType(*output_defs[0], output_type, logger))
+    return false;
+
+  // Check whether the successive node is a Cast that can be fused.
+  if (node.GetOutputEdgesCount() == 1) {
+    const auto& next_node = node.OutputEdgesBegin()->GetNode();
+    is_fusable = IsCastFusable(next_node, true, logger);
+  }
+
+  return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "reduced", logger);
+}

void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  if (op_registrations.op_builder_map.count(op_type) > 0)
    return;
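
One design note on the GetFusedNodes() loop above: fused_node_map is keyed by the Cast node's index, so deciding whether this ArgMax/ArgMin node is some Cast's fusion partner is a linear reverse lookup over the map's values. A standalone equivalent of that scan (NodeIndex and the plain std::map are stand-ins for the ORT types):

#include <algorithm>
#include <cstddef>
#include <map>

using NodeIndex = std::size_t;

// True if `idx` appears as the partner (mapped value) of any fused Cast node,
// i.e. the node at `idx` should emit an int32 output instead of int64.
bool IsFusionTarget(const std::map<NodeIndex, NodeIndex>& fused_nodes, NodeIndex idx) {
  return std::any_of(fused_nodes.begin(), fused_nodes.end(),
                     [idx](const auto& p) { return p.second == idx; });
}

With only ArgMax/ArgMin participating in the fusion, this linear scan is presumably cheap; if more ops adopt it, a second map keyed by partner index would make the check a single lookup.
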
30 changes: 16 additions & 14 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc
@@ -16,10 +16,6 @@ namespace webnn {

Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node,
                                        const logging::Logger& logger) const {
-  ORT_RETURN_IF_NOT(
-      IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(),
-                    model_builder.GetOpSupportLimits(), logger),
-      "Unsupported operator ", node.OpType());
  ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger));
  return Status::OK();
}
@@ -28,20 +24,25 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&

bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
                                  const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
-                                  const logging::Logger& logger) const {
-  if (!HasSupportedInputs(initializers, node, wnn_limits, logger))
+                                  bool& is_fusable, const logging::Logger& logger) const {
+  if (!HasSupportedOpSet(node, logger))
    return false;

-  if (!HasSupportedOutputs(node, wnn_limits, logger))
+  if (!IsOpSupportedImpl(initializers, node, device_type, logger))
    return false;

-  if (!HasSupportedOpSet(node, logger))
+  // Don't change the order of the following two calls; the is_fusable flag must be set after the checks above.
+  if (!HasSupportedInputs(initializers, node, wnn_limits, is_fusable, logger))
    return false;

+  if (!HasSupportedOutputs(node, wnn_limits, is_fusable, logger))
+    return false;
+
-  return IsOpSupportedImpl(initializers, node, device_type, logger);
+  return true;
}

-bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
+bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node,
+                                       const emscripten::val& wnn_limits, bool& is_fusable,
                                       const logging::Logger& logger) const {
  const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
  for (const auto* input : node.InputDefs()) {
@@ -50,11 +51,11 @@ bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers,
    }
  }

-  return HasSupportedInputsImpl(initializers, node, wnn_limits, logger);
+  return HasSupportedInputsImpl(initializers, node, wnn_limits, is_fusable, logger);
}

bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
-                                           const emscripten::val& wnn_limits,
+                                           const emscripten::val& wnn_limits, bool& /* is_fusable */,
                                           const logging::Logger& logger) const {
  // By default we only check the type of input 0; a specific op builder can override this.
  const auto& input = *node.InputDefs()[0];
@@ -67,19 +68,20 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers,
}

bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
-                                        const logging::Logger& logger) const {
+                                        bool& is_fusable, const logging::Logger& logger) const {
  const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
  for (const auto* output : node.OutputDefs()) {
    if (!IsTensorShapeSupported(*output, node_name, logger)) {
      return false;
    }
  }

-  return HasSupportedOutputsImpl(node, wnn_limits, logger);
+  return HasSupportedOutputsImpl(node, wnn_limits, is_fusable, logger);
}

bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
                                            const emscripten::val& wnn_limits,
+                                            bool& /* is_fusable */,
                                            const logging::Logger& logger) const {
  // By default we only check the type of output 0; a specific op builder can override this.
  const auto& output = *node.OutputDefs()[0];
14 changes: 9 additions & 5 deletions onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h
@@ -32,18 +32,19 @@ class BaseOpBuilder : public IOpBuilder {
 public:
  bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
                     const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
-                     const logging::Logger& logger) const override;
+                     bool& is_fusable, const logging::Logger& logger) const override;

 protected:
  virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */,
                                 const WebnnDeviceType /* device_type */, const logging::Logger& /* logger */) const {
    return true;
  }

-  virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
+  virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
+                                      const emscripten::val& wnn_limits, bool& is_fusable,
                                      const logging::Logger& logger) const;
  virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
-                                       const logging::Logger& logger) const;
+                                       bool& is_fusable, const logging::Logger& logger) const;

  // ONNX Runtime only *guarantees* support for models stamped
  // with opset version 7 or above for opset domain 'ai.onnx'.
@@ -56,8 +57,11 @@ class BaseOpBuilder : public IOpBuilder {

 private:
  bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
-  bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
-  bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
+  bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node,
+                          const emscripten::val& wnn_limits, bool& is_fusable,
+                          const logging::Logger& logger) const;
+  bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, bool& is_fusable,
+                           const logging::Logger& logger) const;

  const bool allow_empty_tensor_as_input_;  // Some operators can handle ignoring an empty tensor as input.
};
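
The is_fusable out-parameter threaded through these virtuals is the extension point for other int64-bound ops (such as the ScatterND case mentioned in helper.cc). The sketch below is a minimal standalone model of that flow with hypothetical Mini* names; the real builders also run shape, opset, and device checks:

// Standalone model of the support-check flow: the base class drives the checks
// and a subclass sets the is_fusable out-parameter. Hypothetical names, not ORT's API.
#include <cstdio>
#include <string>

struct MiniNode {
  std::string op_type;
  bool output_is_int64 = false;
  bool feeds_fusable_cast = false;  // single consumer that is Cast(int64 -> int32)
};

class MiniOpBuilder {
 public:
  virtual ~MiniOpBuilder() = default;
  bool IsOpSupported(const MiniNode& node, bool& is_fusable) const {
    return HasSupportedOutputsImpl(node, is_fusable);
  }

 protected:
  virtual bool HasSupportedOutputsImpl(const MiniNode& node, bool& is_fusable) const {
    (void)is_fusable;              // the base class never fuses
    return !node.output_is_int64;  // int64 output unsupported on this backend
  }
};

class MiniArgMaxBuilder : public MiniOpBuilder {
 protected:
  bool HasSupportedOutputsImpl(const MiniNode& node, bool& is_fusable) const override {
    // Record fusability, but still report the plain data-type support result;
    // the GetSupportedNodes pairing flips `supported` later if a partner is found.
    is_fusable = node.feeds_fusable_cast;
    return !node.output_is_int64;
  }
};

int main() {
  MiniArgMaxBuilder builder;
  bool is_fusable = false;
  MiniNode argmax{"ArgMax", /*output_is_int64=*/true, /*feeds_fusable_cast=*/true};
  bool supported = builder.IsOpSupported(argmax, is_fusable);
  std::printf("supported=%d is_fusable=%d\n", supported, is_fusable);
  return 0;
}

This prints "supported=0 is_fusable=1": the node is rejected on its own, but the flag lets GetSupportedNodes pair it with its Cast consumer and mark both as supported.
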
onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc
@@ -23,7 +23,8 @@ class BinaryOpBuilder : public BaseOpBuilder {
  bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                         const WebnnDeviceType device_type, const logging::Logger& logger) const override;
  bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
-                              const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
+                              const emscripten::val& wnn_limits, bool& /* is_fusable */,
+                              const logging::Logger& logger) const override;
};

// Add operator related.
@@ -87,7 +88,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
}

bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
-                                             const emscripten::val& wnn_limits, const logging::Logger& logger) const {
+                                             const emscripten::val& wnn_limits, bool& /* is_fusable */,
+                                             const logging::Logger& logger) const {
  const auto& input_defs = node.InputDefs();
  const auto& op_type = node.OpType();
  int32_t input0_type;