diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 4c6b864f365d..a231b36ef538 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -522,6 +522,7 @@
     "IndexSelectWholeDimensionModule_basic",
     "IndexSelectWholeTensorModule_basic",
     "IndexSelectNegativeDimModule_basic",
+    "IndexSelectStaticModule_basic",
     "IndexTensorStaticModule_basic",
     "IndexTensorMultiIndexStaticModule_basic",
     "LayerNormLastDimModule_basic",
@@ -1205,8 +1206,12 @@
     "SliceWholeTensorModule_basic",
     "TensorFloatModule_basic",
     "TensorIntModule_basic",
+    "IndexSelectNegativeDimModule_basic",
+    "IndexSelectSingleIdxModule_basic",
+    "IndexSelectTwoIdxModule_basic",
     "IndexSelectWholeDimensionModule_basic",
     "IndexSelectWholeTensorModule_basic",
+    "IndexSelectStaticModule_basic",
     "LinalgVectorNormModule_basic",
     "LinalgVectorNormKeepDimModule_basic",
     "NormScalarOptDimKeepDimModule_basic",
diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
index a49b4e35bc22..f58704b2fbe8 100644
--- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
+++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -123,6 +123,10 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
 
 TypedValue<TensorType> reshapeTo(Location loc, PatternRewriter &rewriter,
                                  Value val, ArrayRef<int64_t> newShape);
 
+TypedValue<TensorType> transposeBy(Location loc,
+                                   PatternRewriter &rewriter, Value val,
+                                   ArrayRef<int32_t> permutation);
+
 } // namespace tosa
 } // namespace mlir
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 6b9e8a14df41..f76e29f6e778 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -27,6 +27,8 @@
 #include "torch-mlir/Dialect/Torch/Utils/Utils.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
+#include "llvm/ADT/SmallVector.h"
+#include <numeric>
 
 using namespace mlir;
 using namespace mlir::torch;
@@ -3928,6 +3930,144 @@
 
   return success();
 }
+
+// The goal of this pattern is to handle the case where the indices for all
+// dimensions except one are None.
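+//
+// As a concrete illustration (example shapes chosen for exposition, not
+// taken from the test suite): indexing a 3x5x7 tensor on dim 1 with 4
+// indices goes through
+//   transpose: 3x5x7  -> 5x3x7   (bring dim 1 to the front)
+//   reshape:   5x3x7  -> 1x5x21  (tosa.gather expects NxKxC)
+//   gather:    1x5x21 -> 1x4x21  (W=4 indices into K=5 rows)
+//   reshape:   1x4x21 -> 4x3x7
+//   transpose: 4x3x7  -> 3x4x7   (move the gathered dim back to position 1)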
+class ConvertAtenIndexTensorOpNone
+    : public OpConversionPattern<AtenIndexTensorOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    // To do so, we rewrite index.Tensor as follows:
+    // - To match the tosa format of NxKxC, with K the dimension to extract
+    //   from:
+    //   - Transpose the dim to extract into position 'K'
+    //   - Flatten the other dimensions
+    //   - Reshape to insert a 1x dimension as the N - the format should be
+    //     1xKxC, with C the flattened dimensions
+    // - Insert a tosa.gather
+    // - Bring back to the original format:
+    //   - Reshape
+    //   - Transpose
+    auto loc = op->getLoc();
+    auto outTy = dyn_cast<RankedTensorType>(
+        getTypeConverter()->convertType(op.getType()));
+    if (!outTy || !outTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op.getLoc(),
+          "unimplemented: only static shapes are currently supported");
+
+    SmallVector<Value> torchIndices;
+    if (!getListConstructElements(op.getIndices(), torchIndices))
+      return rewriter.notifyMatchFailure(
+          op.getLoc(),
+          "unimplemented: the tensor list is not from list construct");
+
+    auto indicesList =
+        getTypeConvertedValues(rewriter, loc, typeConverter, torchIndices);
+
+    // Check that exactly one index is non-None.
+    int64_t indexDim = -1;
+    for (size_t i = 0; i < indicesList.size(); ++i) {
+      if (!indicesList[i])
+        continue;
+      if (indexDim != -1) {
+        return rewriter.notifyMatchFailure(
+            op.getLoc(), "unimplemented: exactly one index must be non-None "
+                         "for this pattern to apply");
+      }
+      indexDim = i;
+    }
+    if (indexDim == -1) {
+      return rewriter.notifyMatchFailure(op.getLoc(),
+                                         "unimplemented: all indices are none");
+    }
+
+    auto indices =
+        dyn_cast<TypedValue<RankedTensorType>>(indicesList[indexDim]);
+    if (!indices || indices.getType().getRank() != 1) {
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "unimplemented: index must be a ranked 1-d tensor");
+    }
+
+    auto input = adaptor.getSelf();
+    auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+    if (!inputTy || !inputTy.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op.getLoc(), "unimplemented: input must have static shapes");
+    auto inputElemTy = inputTy.getElementType();
+
+    // Transpose indexDim into dimension 0.
+    SmallVector<int32_t> transposePerm;
+    for (int32_t i = 0; i < inputTy.getRank(); ++i)
+      transposePerm.push_back(i);
+    transposePerm[0] = indexDim;
+    transposePerm[indexDim] = 0;
+
+    auto transposedInput =
+        tosa::transposeBy(loc, rewriter, input, transposePerm);
+
+    // Flatten matrix [k, ...] -> [1, k, c].
+    auto transposedShape = transposedInput.getType().getShape();
+    int64_t k = transposedShape[0];
+    int64_t c =
+        std::accumulate(transposedShape.begin() + 1, transposedShape.end(),
+                        (int64_t)1, [](int64_t a, int64_t b) { return a * b; });
+
+    // Reshape the input to 1xKx(flattened_dims).
+    SmallVector<int64_t> reshapedFormat = {1, k, c};
+    auto reshapedInput =
+        tosa::reshapeTo(loc, rewriter, transposedInput, reshapedFormat);
+
+    auto w = indices.getType().getDimSize(0);
+    auto reshapedIndices = tosa::reshapeTo(loc, rewriter, indices, {1, w});
+
+    // And cast the indices to i32, as required by tosa.gather.
+    TensorType promotedType = reshapedIndices.getType().cloneWith(
+        reshapedIndices.getType().getShape(), rewriter.getI32Type());
+    auto castedIndices = rewriter.create<tosa::CastOp>(
+        op->getLoc(), promotedType, reshapedIndices);
+
+    SmallVector<int64_t> gatherShape = {1, w, c};
+    auto gatherOp = rewriter.create<tosa::GatherOp>(
+        op->getLoc(), RankedTensorType::get(gatherShape, inputElemTy),
+        reshapedInput, castedIndices);
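+    // Per tosa.gather semantics, the gather output at [0][w][c] equals
+    // reshapedInput[0][castedIndices[0][w]][c]: each index selects one
+    // C-element row of the flattened input.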
+    // Unflatten [1, w, c] -> [w, ...].
+    SmallVector<int64_t> unflattenedShape(transposedShape.begin(),
+                                          transposedShape.end());
+    unflattenedShape[0] = w;
+    auto unflattened =
+        tosa::reshapeTo(loc, rewriter, gatherOp, unflattenedShape);
+
+    // Invert the permutation, i.e. inversePermutation[transposePerm[i]] == i.
+    SmallVector<int32_t> inversePermutation(transposePerm.size(), 0);
+    for (size_t i = 0; i < transposePerm.size(); ++i)
+      inversePermutation[transposePerm[i]] = i;
+
+    // Transpose 'w' back into the original position of 'k'.
+    auto unTranspose =
+        tosa::transposeBy(loc, rewriter, unflattened, inversePermutation);
+
+    rewriter.replaceOp(op, unTranspose);
+    return success();
+  }
+};
+
 template <>
 LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
     AtenIndexTensorOp op, OpAdaptor adaptor,
@@ -5307,6 +5447,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
 
     patterns.add(context);
     patterns.add(context);
+    patterns.add<ConvertAtenIndexTensorOpNone>(typeConverter, context);
 
 #define INSERT_SIMPLIFY_OP_PATTERN(AtenOp)                                    \
   target.addIllegalOp<AtenOp>();                                              \
diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
index a7327783060b..5d929f7c2481 100644
--- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
+++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -400,6 +400,29 @@ TypedValue<TensorType> reshapeTo(Location loc, PatternRewriter &rewriter,
       loc, newTy, val, rewriter.getDenseI64ArrayAttr(newShape));
 }
 
+TypedValue<TensorType> transposeBy(Location loc, PatternRewriter &rewriter,
+                                   Value val,
+                                   ArrayRef<int32_t> permutation) {
+  auto tensorTy = dyn_cast<TensorType>(val.getType());
+  assert(tensorTy && "transposeBy expects a tensor value");
+
+  auto permType = RankedTensorType::get({(int64_t)permutation.size()},
+                                        rewriter.getI32Type());
+  auto permAttr = DenseElementsAttr::get(permType, permutation);
+  auto permOp = rewriter.create<tosa::ConstOp>(loc, permType, permAttr);
+
+  // The result shape is the input shape reordered by the permutation.
+  SmallVector<int64_t> newShape;
+  newShape.reserve(permutation.size());
+  for (int32_t dim : permutation)
+    newShape.push_back(tensorTy.getShape()[dim]);
+
+  auto newTy = RankedTensorType::get(newShape, tensorTy.getElementType());
+
+  auto v = rewriter.createOrFold<tosa::TransposeOp>(loc, newTy, val, permOp);
+  return cast<TypedValue<TensorType>>(v);
+}
+
 // Template instantiation
 template std::optional<Value> getConstTensor<APInt>(PatternRewriter &,
                                                     Operation *,
diff --git a/python/torch_mlir_e2e_test/test_suite/index_select.py b/python/torch_mlir_e2e_test/test_suite/index_select.py
index 0fdda62a13a0..e76c85503a4a 100644
--- a/python/torch_mlir_e2e_test/test_suite/index_select.py
+++ b/python/torch_mlir_e2e_test/test_suite/index_select.py
@@ -11,6 +11,25 @@
 
 # ==============================================================================
 
+class IndexSelectStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([3, 3], torch.float32, True),
+        ([1], torch.int, True),
+    ])
+    def forward(self, x, y):
+        return torch.ops.aten.index_select(x, 0, y)
+
+
+@register_test_case(module_factory=lambda: IndexSelectStaticModule())
+def IndexSelectStaticModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(3, 3), torch.tensor([1], dtype=torch.int))
+
 class IndexSelectSingleIdxModule(torch.nn.Module):
 
     def __init__(self):