From caca87bbe72e24886afbe4b7aec16910ab424de9 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Thu, 27 Jul 2023 16:27:23 +0200 Subject: [PATCH] Import node debugName (FX graph node name) --- .../Torch/Transforms/GlobalizeObjectGraph.cpp | 2 ++ .../Torch/Transforms/InlineGlobalSlots.cpp | 7 +++++-- python/torch_mlir/__init__.py | 4 +++- .../importer/jit_ir/csrc/import_options.h | 4 ++++ .../jit_ir/csrc/import_options_pybind.cpp | 4 +++- .../importer/jit_ir/csrc/node_importer.cpp | 18 ++++++++++++++++++ 6 files changed, 35 insertions(+), 4 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index da8be9b17e0b..a8d404ffd47e 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -524,6 +524,8 @@ static LogicalResult rewriteMonomorphizedFuncClone( auto newOp = OpBuilder(op).create( op.getLoc(), op.getType(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName()); + if (auto fxOutputName = op->getAttr("FXOutputName")) + newOp->setAttr("FXOutputName", fxOutputName); op.replaceAllUsesWith(&*newOp); } toErase.push_back(op); diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 76b57fe8c9a3..b593db3ac77e 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -375,8 +375,11 @@ class InlineGlobalSlotsPass getBackwardSliceIncludingRoot(initialValue); IRMapping mapping; OpBuilder builder(op); - for (Operation *opInSlice : slice) - builder.clone(*opInSlice, mapping); + for (Operation *opInSlice : slice) { + auto clonedOp = builder.clone(*opInSlice, mapping); + if (auto fxOutputName = op->getAttr("FXOutputName")) + clonedOp->setAttr("FXOutputName", fxOutputName); + } auto inlinedInitialValue = mapping.lookup(initialValue); inlinedInitialValue = Torch::adjustStaticInformation( builder, op.getLoc(), inlinedInitialValue, op.getType(), diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 555642ac4947..4ae7bb2c03e2 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -328,7 +328,8 @@ def compile(model: torch.nn.Module, backend_legal_ops: Optional[Sequence[str]] = None, extra_library: Iterable[Callable] = [], verbose: bool = False, - use_make_fx: bool = False): + use_make_fx: bool = False, + add_fx_outputname: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -437,6 +438,7 @@ def compile(model: torch.nn.Module, mb = ModuleBuilder() import_options = ImportOptions() import_options.ignoreExistingTensorShapesAndDtypes = ignore_traced_shapes + import_options.addFxOutputName = add_fx_outputname try: original_stderr = sys.stderr sys.stderr = StringIO() diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h index b620c784ac24..240e49e199d4 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options.h @@ -33,6 +33,10 @@ struct ImportOptions { // In that case, the appropriate shape information is provided via the type // bound annotations on the function arguments instead. bool ignoreExistingTensorShapesAndDtypes = false; + + // Whether to add the FXOutputName attribute from the debug name of the jit + // node. + bool addFxOutputName = false; }; } // namespace torch_mlir diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp index b072b0ed922c..e3bc46a123f1 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/import_options_pybind.cpp @@ -20,5 +20,7 @@ void torch_mlir::initImportOptionsBindings(py::module &m) { .def_readwrite("assumeTensorsHaveValueSemantics", &ImportOptions::assumeTensorsHaveValueSemantics) .def_readwrite("ignoreExistingTensorShapesAndDtypes", - &ImportOptions::ignoreExistingTensorShapesAndDtypes); + &ImportOptions::ignoreExistingTensorShapesAndDtypes) + .def_readwrite("addFxOutputName", + &ImportOptions::addFxOutputName); } diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp index 15cffedbe834..39a1cb63185d 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/csrc/node_importer.cpp @@ -84,6 +84,21 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, MlirLocation loc = getMlirLocationFromNode(context, node); auto kind = node->kind(); + + auto outs = node->outputs(); + auto output = outs.size() == 1 ? outs[0] : nullptr; + + auto addFxOutputNameAttr = [&](MlirOperation& operation) { + if (importOptions.addFxOutputName && output && output->hasDebugName()) { + std::string name = output->debugName(); + size_t len = name.size(); + if (len > 2 && name[len-2] == '.' && name[len-1] == '1') + name = name.substr(0, len-2); + auto strAttr = mlirStringAttrGet(context, toMlirStringRef(name)); + mlirOperationSetAttributeByName(operation, toMlirStringRef("FXOutputName"), strAttr); + } + }; + auto createAndMapTrivialNode = [&](Node *node, const std::string &opName, InputsTransformFn t) { std::vector mappedInputs = lookupMappedValues(node->inputs()); @@ -91,6 +106,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, appendToBlock, opName, loc, getMlirTypesFromValues(loc, node->outputs(), importOptions), t ? t(mappedInputs) : mappedInputs); + addFxOutputNameAttr(operation); mapResults(node, operation); }; @@ -102,6 +118,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, getMlirTypesFromValues(loc, node->outputs(), importOptions), lookupMappedValues(node->inputs()), toMlirNamedAttribute(attrName.c_str(), attr)); + addFxOutputNameAttr(operation); mapResults(node, operation); }; @@ -112,6 +129,7 @@ void NodeImporter::importNode(Node *node, MlirBlock appendToBlock, appendToBlock, loc, node->schema(), getMlirTypesFromValues(loc, node->outputs(), importOptions), lookupMappedValues(node->inputs())); + addFxOutputNameAttr(operation); mapResults(node, operation); return; }