diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index 655d1a551b80..0769a00f1dba 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -17,6 +17,7 @@
     "Conv1dNoPaddingModule_basic",
     "Conv1dNoPaddingTransposeModule_basic",
     "Conv1dNoPaddingGroupModule_basic",
+    "RepeatInterleaveStaticModule_basic",
     # tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
     "IndexPutImpl2DNoneIndexBroadcastStaticModule_basic"
 }
@@ -287,6 +288,9 @@
 
     # failed to legalize operation 'torch.aten.clamp' that was explicitly marked illegal
     "ElementwiseClampIntModule_basic",
+
+    # failed to legalize operation 'torch.constant.int'
+    "RepeatInterleaveStaticModule_basic",
 }
 
 TORCHDYNAMO_CRASHING_SET = {
@@ -1216,6 +1220,7 @@
     "SplitTensorListUnpackModule_basic",
     "ChunkListUnpack_Module_basic",
     "ChunkListUnpackUneven_Module_basic",
+    "RepeatInterleaveStaticModule_basic",
 }
 
 MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 9e0a29fc0bd7..70e3135b1cd0 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -5071,11 +5071,9 @@ class ConvertAtenFillScalarOp : public OpConversionPattern {
       return rewriter.notifyMatchFailure(
           op, "Supplied value must be a Scalar constant");
 
-    if (outType != constOp.getType())
-      rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
-    else
-      rewriter.replaceOp(op, constOp);
-
+    auto newOp =
+        rewriter.createOrFold<tosa::CastOp>(op.getLoc(), outType, constOp);
+    rewriter.replaceOp(op, newOp);
     return success();
   }
 };
@@ -5407,6 +5405,63 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+template <>
+LogicalResult ConvertAtenOp<AtenRepeatInterleaveTensorOp>::matchAndRewrite(
+    AtenRepeatInterleaveTensorOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto outputTy = getTypeConverter()
+                      ->convertType(op.getType())
+                      .dyn_cast<RankedTensorType>();
+  if (!outputTy)
+    return rewriter.notifyMatchFailure(
+        op, "Only ranked tensor type outputs permitted");
+
+  auto shape = outputTy.getShape();
+  if (shape.size() != 1)
+    return rewriter.notifyMatchFailure(op, "Only rank 1 tensors are permitted");
+
+  int64_t outputSize;
+  if (!matchPattern(op.getOutputSize(), m_TorchConstantInt(&outputSize))) {
+    return rewriter.notifyMatchFailure(
+        op, "Currently only scalar constants are supported for "
+            "output_size in TOSA operation");
+  }
+
+  auto repeats = dyn_cast<tosa::ConstOp>(adaptor.getRepeats().getDefiningOp());
+  if (!repeats)
+    return rewriter.notifyMatchFailure(
+        op, "Currently only constants are supported for "
+            "repeats in TOSA operation");
+
+  auto attr = repeats.getValue();
+  if (!attr.isSplat())
+    return rewriter.notifyMatchFailure(op, "Only single values are supported.");
+
+  auto elementTy = outputTy.getElementType();
+  if (!elementTy.isa<mlir::IntegerType>())
+    return rewriter.notifyMatchFailure(op,
+                                       "Only integer values are supported.");
+
+  int64_t numberOfRepeats = attr.getSplatValue<APInt>().getSExtValue();
+
+  // Create an array of repeated values
+  auto createConstArrayOfRepeatedValues = [&](int64_t numOfRepeats) {
+    SmallVector<int64_t> values;
+    for (int64_t val = 0; val < outputSize / numberOfRepeats; ++val) {
+      SmallVector<int64_t> newValues(numberOfRepeats, val);
+      values.insert(values.end(), newValues.begin(), newValues.end());
+    }
+    return values;
+  };
+
+  auto newOp = tosa::getConstTensor<int64_t>(
+      rewriter, op, createConstArrayOfRepeatedValues(numberOfRepeats), shape,
+      elementTy);
+  rewriter.replaceOp(op, *newOp);
+  return success();
+}
+
 template <typename AtenOpT>
 class ConvertAtenOpToTosaCustomOp : public OpConversionPattern<AtenOpT> {
 public:
@@ -5703,6 +5758,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase {
     INSERT_ATENOP_PATTERN(AtenCatOp);
     INSERT_ATENOP_PATTERN(AtenSqrtOp);
     INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
+    INSERT_ATENOP_PATTERN(AtenRepeatInterleaveTensorOp);
 #undef INSERT_ATENOP_PATTERN
 
 #define INSERT_CLONE_ATENOP_PATTERN(AtenOp)                                    \
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 56af8ef1e9ef..56a8c6746352 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -1482,6 +1482,27 @@ def RepeatInterleaveModule_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class RepeatInterleaveStaticModule(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+    ])
+    def forward(self):
+        x = torch.ones((10), dtype=torch.int).fill_(3)
+        z = torch.ops.aten.repeat_interleave(x, output_size=30)
+        return z
+
+
+@register_test_case(module_factory=lambda: RepeatInterleaveStaticModule())
+def RepeatInterleaveStaticModule_basic(module, tu: TestUtils):
+    module.forward()
+
+# ==============================================================================
+
 class ExpandModule(torch.nn.Module):