diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 8e4a43c3f24586..5c977055e95dc8 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -15,6 +15,7 @@ #define MLIR_TRANSFORMS_PASSES_H #include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/LocationSnapshot.h" #include "mlir/Transforms/ViewOpGraph.h" diff --git a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp index 70fd9bc0a1e68f..b61218bb7f1af6 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp @@ -44,7 +44,8 @@ struct NarrowingPattern : OpRewritePattern { NarrowingPattern(MLIRContext *ctx, const ArithIntNarrowingOptions &options, PatternBenefit benefit = 1) : OpRewritePattern(ctx, benefit), - supportedBitwidths(options.bitwidthsSupported) { + supportedBitwidths(options.bitwidthsSupported.begin(), + options.bitwidthsSupported.end()) { assert(!supportedBitwidths.empty() && "Invalid options"); assert(!llvm::is_contained(supportedBitwidths, 0) && "Invalid bitwidth"); llvm::sort(supportedBitwidths); @@ -757,7 +758,8 @@ struct ArithIntNarrowingPass final MLIRContext *ctx = op->getContext(); RewritePatternSet patterns(ctx); populateArithIntNarrowingPatterns( - patterns, ArithIntNarrowingOptions{bitwidthsSupported}); + patterns, ArithIntNarrowingOptions{ + llvm::to_vector_of(bitwidthsSupported)}); if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns)))) signalPassFailure(); } diff --git a/mlir/tools/mlir-tblgen/PassGen.cpp b/mlir/tools/mlir-tblgen/PassGen.cpp index 90aa67115a4007..655843f26201ac 100644 --- a/mlir/tools/mlir-tblgen/PassGen.cpp +++ b/mlir/tools/mlir-tblgen/PassGen.cpp @@ -97,7 +97,7 @@ static void emitPassOptionsStruct(const Pass &pass, raw_ostream &os) { std::string type = opt.getType().str(); if (opt.isListOption()) - type = "::llvm::ArrayRef<" + type + ">"; + type = "::llvm::SmallVector<" + type + ">"; os.indent(2) << llvm::formatv("{0} {1}", type, opt.getCppVariableName()); @@ -128,8 +128,8 @@ static void emitPassDecls(const Pass &pass, raw_ostream &os) { // Declaration of the constructor with options. if (ArrayRef options = pass.getOptions(); !options.empty()) - os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(const " - "{0}Options &options);\n", + os << llvm::formatv("std::unique_ptr<::mlir::Pass> create{0}(" + "{0}Options options);\n", passName); } @@ -236,7 +236,7 @@ namespace impl {{ const char *const friendDefaultConstructorWithOptionsDeclTemplate = R"( namespace impl {{ - std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options); + std::unique_ptr<::mlir::Pass> create{0}({0}Options options); } // namespace impl )"; @@ -247,8 +247,8 @@ const char *const friendDefaultConstructorDefTemplate = R"( )"; const char *const friendDefaultConstructorWithOptionsDefTemplate = R"( - friend std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{ - return std::make_unique(options); + friend std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{ + return std::make_unique(std::move(options)); } )"; @@ -259,8 +259,8 @@ std::unique_ptr<::mlir::Pass> create{0}() {{ )"; const char *const defaultConstructorWithOptionsDefTemplate = R"( -std::unique_ptr<::mlir::Pass> create{0}(const {0}Options &options) {{ - return impl::create{0}(options); +std::unique_ptr<::mlir::Pass> create{0}({0}Options options) {{ + return impl::create{0}(std::move(options)); } )"; @@ -326,10 +326,10 @@ static void emitPassDefs(const Pass &pass, raw_ostream &os) { if (ArrayRef options = pass.getOptions(); !options.empty()) { os.indent(2) << llvm::formatv( - "{0}Base(const {0}Options &options) : {0}Base() {{\n", passName); + "{0}Base({0}Options options) : {0}Base() {{\n", passName); for (const PassOption &opt : pass.getOptions()) - os.indent(4) << llvm::formatv("{0} = options.{0};\n", + os.indent(4) << llvm::formatv("{0} = std::move(options.{0});\n", opt.getCppVariableName()); os.indent(2) << "}\n"; diff --git a/mlir/unittests/TableGen/PassGenTest.cpp b/mlir/unittests/TableGen/PassGenTest.cpp index 859eba7e517869..27f2fa06542961 100644 --- a/mlir/unittests/TableGen/PassGenTest.cpp +++ b/mlir/unittests/TableGen/PassGenTest.cpp @@ -72,8 +72,7 @@ TEST(PassGenTest, PassOptions) { TestPassWithOptionsOptions options; options.testOption = 57; - llvm::SmallVector testListOption = {1, 2}; - options.testListOption = testListOption; + options.testListOption = {1, 2}; const auto unwrap = [](const std::unique_ptr &pass) { return static_cast(pass.get());