From ab62f35373c3944b68e564214fd04fff39dd92fc Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Fri, 11 Oct 2024 11:15:17 -0500
Subject: [PATCH 1/2] Add more patterns to scalarize-shapes pass (#3781)

- Adds patterns for propagating shapes through AtenWhereSelf and
  AtenEqTensor
- Adds a fold pattern for a rank-0 squeeze.dim of a full op
- Adds support for getting a list from a splat ValueTensorLiteralOp for
  materializing scalar comparisons in where.self and eq.tensor

With a bit of hammering, these changes should unblock several IREE
inference failures.
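
As a rough sketch of the effect (reduced from the new lit tests below),
IR like

    %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
    %7 = torch.aten.where.self %6, %1, %5 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>

is now rewritten element-by-element on !torch.int values (aten.eq.int,
plus rank-0 aten.where.self and aten.item) and rebuilt with
torch.prim.ListConstruct + torch.aten.tensor, keeping the shape
computation scalar.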
---
 .../Torch/Transforms/ScalarizeShapes.cpp      | 211 ++++++++++++++++++
 test/Dialect/Torch/scalarize-shapes.mlir      |  76 +++++++
 2 files changed, 287 insertions(+)

diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
index 168518e3d5c0..dd2f835ed8a3 100644
--- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
+++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -63,6 +63,29 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
   return success();
 }
 
+LogicalResult constructListFromLiteral(PatternRewriter &rewriter,
+                                       ValueTensorLiteralOp literalOp,
+                                       SmallVector<Value> &vals) {
+  // Only supports splat ValueTensorLiterals for now. TODO: add support for
+  // small non-splat ValueTensorLiterals.
+  auto ty = dyn_cast<ValueTensorType>(literalOp.getType());
+  if (!ty || !ty.hasSizes())
+    return failure();
+  auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
+  if (!attr)
+    return failure();
+  auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue<Attribute>());
+  if (!attrInt)
+    return failure();
+  IntegerType intty = cast<IntegerType>(attrInt.getType());
+  if (!intty.isSignedInteger())
+    return failure();
+  Value materializedVal = rewriter.create<Torch::ConstantIntOp>(
+      literalOp.getLoc(), attrInt.getSInt());
+  vals.resize(vals.size() + ty.getSizes()[0], materializedVal);
+  return success();
+}
+
 LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
   constexpr int64_t kMaxFold = 16;
   if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>())
@@ -351,6 +374,172 @@ class PropagateAtenSliceTensorPattern
 };
 } // namespace
 
+namespace {
+class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
+public:
+  using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenWhereSelfOp op,
+                                PatternRewriter &rewriter) const override {
+    Value condition = op.getCondition();
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto conditionTy = dyn_cast<ValueTensorType>(condition.getType());
+    if (!conditionTy || !conditionTy.hasSizes() ||
+        conditionTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad condition type");
+    auto selfTy = dyn_cast<ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t conditionSize = conditionTy.getSizes()[0];
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize || selfSize != conditionSize)
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: support for propagating with implicit broadcasting.");
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "arguments are dynamic or too big");
+
+    SmallVector<Value> conditionList, selfList, otherList;
+    if (failed(getListFromTensor(condition, conditionList)) ||
+        (int64_t)conditionList.size() != conditionSize)
+      return failure();
+
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // This should only occur if we did not generate IR with
+      // constructListFromLiteral.
+      return failure();
+
+    Location loc = op.getLoc();
+    SmallVector<Value> whereVals;
+    auto rank0IntTy = rewriter.getType<ValueTensorType>(
+        ArrayRef<int64_t>({}), selfTy.getDtype());
+    auto rank0BoolTy = rewriter.getType<ValueTensorType>(
+        ArrayRef<int64_t>({}), conditionTy.getDtype());
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      Value rank0Cond = rewriter.create<PrimNumToTensorScalarOp>(
+          loc, rank0BoolTy, conditionList[i]);
+      Value rank0Self = rewriter.create<PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, selfList[i]);
+      Value rank0Other = rewriter.create<PrimNumToTensorScalarOp>(
+          loc, rank0IntTy, otherList[i]);
+      Value rank0Where = rewriter.create<AtenWhereSelfOp>(
+          loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
+      whereVals.push_back(rewriter.create<AtenItemOp>(
+          loc, rewriter.getType<Torch::IntType>(), rank0Where));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(whereVals[0].getType()), whereVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
+namespace {
+class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
+public:
+  using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenEqTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    Value self = op.getSelf();
+    Value other = op.getOther();
+    auto selfTy = dyn_cast<ValueTensorType>(self.getType());
+    if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad self type");
+    auto otherTy = dyn_cast<ValueTensorType>(other.getType());
+    if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
+      return rewriter.notifyMatchFailure(op, "bad other type");
+    int64_t selfSize = selfTy.getSizes()[0];
+    int64_t otherSize = otherTy.getSizes()[0];
+
+    if (selfSize != otherSize)
+      return rewriter.notifyMatchFailure(
+          op,
+          "unimplemented: support for propagating with implicit broadcasting.");
+
+    constexpr int64_t kMaxFold = 16;
+    if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
+        otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
+      return rewriter.notifyMatchFailure(op,
+                                         "self or other is dynamic or too big");
+
+    SmallVector<Value> selfList, otherList;
+    // If one of these tensors is a value tensor literal op, we will need to
+    // create constant ints in the IR to form a list. Before calling
+    // constructListFromLiteral, we must be certain that the conversion can no
+    // longer fail, otherwise we will cause an infinite loop of creating a
+    // constant and removing it.
+    LogicalResult selfFromList = getListFromTensor(self, selfList);
+    LogicalResult otherFromList = getListFromTensor(other, otherList);
+
+    if (failed(selfFromList) && failed(otherFromList))
+      return rewriter.notifyMatchFailure(
+          op, "At least one operand must succeed at constructing a list");
+
+    auto selfLiteral = self.getDefiningOp<ValueTensorLiteralOp>();
+    auto otherLiteral = other.getDefiningOp<ValueTensorLiteralOp>();
+    if (succeeded(selfFromList) && otherLiteral &&
+        failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
+      return failure();
+    if (succeeded(otherFromList) && selfLiteral &&
+        failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
+      return failure();
+    if ((int64_t)selfList.size() != selfSize ||
+        (int64_t)otherList.size() != otherSize)
+      // This should only occur if we did not generate IR with
+      // constructListFromLiteral.
+      return failure();
+
+    SmallVector<Value> eqVals;
+    for (uint64_t i = 0; i < selfList.size(); i++) {
+      eqVals.push_back(
+          rewriter.create<AtenEqIntOp>(op.getLoc(), selfList[i], otherList[i]));
+    }
+    Value list = rewriter.create<Torch::PrimListConstructOp>(
+        op.getLoc(), Torch::ListType::get(eqVals[0].getType()), eqVals);
+    Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
+    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
+        op.getLoc(), rewriter.getBoolAttr(false));
+    rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
+        op, op.getType(), list, cstNone, cstNone, cstFalse);
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
 public:
@@ -454,6 +643,26 @@ class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
 };
 } // namespace
 
+namespace {
+class FoldAtenSqueezeDimPattern : public OpRewritePattern<AtenSqueezeDimOp> {
+public:
+  using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenSqueezeDimOp op,
+                                PatternRewriter &rewriter) const override {
+    auto resultTy = cast<ValueTensorType>(op.getType());
+    if (!resultTy.hasSizes() || resultTy.getSizes().size() != 0)
+      return rewriter.notifyMatchFailure(op, "Unknown result shape");
+
+    if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
+      rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
+          op, resultTy, atenFull.getFillValue());
+      return success();
+    }
+    return failure();
+  }
+};
+} // namespace
+
 namespace {
 class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
 public:
@@ -694,6 +903,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
         PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
         FoldAtenSqueezePattern, FoldAtenUnsqueezePattern, FoldAtenWhereSelf,
         CanonicalizeAtenViewPattern,
+        PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
+        FoldAtenSqueezeDimPattern,
         RemoveUnusedPattern<Torch::AtenIntBoolOp>,
         RemoveUnusedPattern<Torch::AtenEqIntOp>,
         RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
diff --git a/test/Dialect/Torch/scalarize-shapes.mlir b/test/Dialect/Torch/scalarize-shapes.mlir
index 17f786a8215b..c86844996d9c 100644
--- a/test/Dialect/Torch/scalarize-shapes.mlir
+++ b/test/Dialect/Torch/scalarize-shapes.mlir
@@ -160,3 +160,79 @@ func.func @unsqueeze_squeeze_combo(%arg0: !torch.vtensor<[?,?,16,64],f32>) -> !torch.int {
   %14 = torch.aten.item %13 : !torch.vtensor<[1],si64> -> !torch.int
   return %14 : !torch.int
 }
+
+
+// -----
+
+// CHECK-LABEL: @eq_tensor_and_where_self
+func.func @eq_tensor_and_where_self(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],si64> {
+  // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
+  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[I1:.*]] = torch.constant.int 1
+  // CHECK-DAG: %[[I0:.*]] = torch.constant.int 0
+  // CHECK-DAG: %[[DIM1:.*]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK-DAG: %[[DIM0:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[I1]], %[[DIM1]], %[[DIM1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
+  // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],si64>
+  %none = torch.constant.none
+  %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int0 = torch.constant.int 0
+  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %4 = torch.prim.ListConstruct %3, %int1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
+  %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
+  %7 = torch.aten.where.self %6, %1, %5 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64>
+  return %7 : !torch.vtensor<[4],si64>
+}
+
+
+// -----
+
+// CHECK-LABEL: @eq_tensor_from_tensor_and_literal
+func.func @eq_tensor_from_tensor_and_literal(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[4],i1> {
+  // CHECK-DAG: %[[none:.*]] = torch.constant.none
+  // CHECK-DAG: %[[false:.*]] = torch.constant.bool false
+  // CHECK-DAG: %[[true:.*]] = torch.constant.bool true
+  // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[false]], %[[true]], %[[false]], %[[false]] : (!torch.bool, !torch.bool, !torch.bool, !torch.bool) -> !torch.list<bool>
+  // CHECK: %[[TENSOR:.*]] = torch.aten.tensor %[[LIST]], %[[none]], %[[none]], %[[false]] : !torch.list<bool>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],i1>
+  // CHECK: return %[[TENSOR]] : !torch.vtensor<[4],i1>
+  %none = torch.constant.none
+  %0 = torch.vtensor.literal(dense<-1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %1 = torch.vtensor.literal(dense<1> : tensor<4xsi64>) : !torch.vtensor<[4],si64>
+  %false = torch.constant.bool false
+  %int1 = torch.constant.int 1
+  %int-1 = torch.constant.int -1
+  %int0 = torch.constant.int 0
+  %2 = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %3 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %4 = torch.prim.ListConstruct %3, %int-1, %2, %2 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %5 = torch.aten.tensor %4, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[4],si64>
+  %6 = torch.aten.eq.Tensor %5, %0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1>
+  return %6 : !torch.vtensor<[4],i1>
+}
+
+
+// -----
+
+// CHECK-LABEL: @squeeze_dim_full_fold
+func.func @squeeze_dim_full_fold(%arg0: !torch.vtensor<[?,?],si64>) -> !torch.int {
+  // CHECK: %[[I0:.*]] = torch.constant.int 0
+  // CHECK: %[[SZE:.*]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  // CHECK: return %[[SZE]] : !torch.int
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %none = torch.constant.none
+  %false = torch.constant.bool false
+  %51 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?],si64>, !torch.int -> !torch.int
+  %55 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %56 = torch.aten.full %55, %51, %none, %none, %none, %false : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64>
+  %57 = torch.aten.squeeze.dim %56, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64>
+  %58 = torch.aten.item %57 : !torch.vtensor<[],si64> -> !torch.int
+  return %58 : !torch.int
+}

From b176939808046703bd59f5219a9923d62758fa18 Mon Sep 17 00:00:00 2001
From: yyp0
Date: Sat, 12 Oct 2024 17:51:15 +0800
Subject: [PATCH 2/2] [Torch] support 1d aten tensor shape and dtype infer
 (#3776)
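
A sketch of the intended refinement (hypothetical IR, following the
pattern and the e2e test added below): given

    %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>
    %3 = torch.aten.tensor %2, %none, %none, %false : !torch.list<int>, !torch.none, !torch.none, !torch.bool -> !torch.vtensor

the new InferTensorOp pattern refines the result type to
!torch.vtensor<[2],si64> (si64 for !torch.int elements, f32 for
!torch.float), inserting a torch.tensor_static_info_cast so existing
uses keep seeing the original type.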
---
 .../Transforms/SimplifyShapeCalculations.cpp  | 57 +++++++++++++++++++
 .../torch_mlir_e2e_test/test_suite/basic.py   | 24 ++++++++
 2 files changed, 81 insertions(+)

diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
index 6d2008a28407..f63fb4eb96d4 100644
--- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
+++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
@@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
 };
 } // namespace
 
+namespace {
+class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
+public:
+  using OpRewritePattern<AtenTensorOp>::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenTensorOp op,
+                                PatternRewriter &rewriter) const override {
+    auto context = op.getContext();
+    auto loc = op.getLoc();
+    auto result = op.getResult();
+    auto resultType = cast<BaseTensorType>(result.getType());
+    if (resultType.hasSizes() && resultType.hasDtype()) {
+      return rewriter.notifyMatchFailure(
+          op, "the result of aten.tensor already has sizes and a dtype");
+    }
+
+    auto inputList = op.getOperand(0);
+    auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
+    if (!listConstruct) {
+      return rewriter.notifyMatchFailure(
+          op, "operand 0 of aten.tensor is not a PrimListConstructOp");
+    }
+
+    // Currently, only 1-d input lists are supported.
+    SmallVector<int64_t> sizes;
+    sizes.push_back(listConstruct->getOperands().size());
+    FailureOr<Type> torchType;
+    auto eleType = listConstruct->getOperands()[0].getType();
+    if (isa<Torch::IntType>(eleType)) {
+      torchType = getTypeForScalarType(op->getContext(),
+                                       torch_upstream::ScalarType::Long);
+    } else if (isa<Torch::FloatType>(eleType)) {
+      torchType = getTypeForScalarType(op->getContext(),
+                                       torch_upstream::ScalarType::Float);
+    } else {
+      return rewriter.notifyMatchFailure(
+          op, "currently only Int and Float element types are supported");
+    }
+    auto newResultType = ValueTensorType::get(context, sizes, *torchType);
+
+    Value originalTypedValue;
+    for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
+      if (!originalTypedValue) {
+        rewriter.setInsertionPointAfter(op);
+        originalTypedValue =
+            rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
+      }
+      use.set(originalTypedValue);
+    }
+
+    result.setType(newResultType);
+
+    return success();
+  }
+};
+} // namespace
+
 static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
                                                 int resultNum,
                                                 PatternRewriter &rewriter) {
@@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
     populateFoldPrimUncheckedCastOpPattern(patterns, context);
     patterns.insert<DecomposeAtenSizeOp>(context);
     patterns.insert<RefineShapeCalculateOp>(context);
+    patterns.insert<InferTensorOp>(context);
 
     PrimIfOp::getCanonicalizationPatterns(patterns, context);
     Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index ef20079b6f75..bef16f3efcd7 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5621,6 +5621,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
+class TensorAlloc1dStaticModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 4, 6], torch.int, True),
+        ]
+    )
+    def forward(self, x):
+        res = torch.tensor([x.shape[0]])
+        return res
+
+
+@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
+def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 4, 6))
+
+
+# ==============================================================================
+
+
 class ScalarTensorFloat32Module(torch.nn.Module):
     def __init__(self):
         super().__init__()