Skip to content

Commit

Permalink
TorchToTosa: Legalize aten.repeat_interleave operation. (#96)
Browse files Browse the repository at this point in the history
  • Loading branch information
ttjost authored Jun 14, 2023
1 parent 0a19469 commit 8ca7252
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 5 deletions.
5 changes: 5 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"Conv1dNoPaddingModule_basic",
"Conv1dNoPaddingTransposeModule_basic",
"Conv1dNoPaddingGroupModule_basic",
"RepeatInterleaveStaticModule_basic",
# tm_tensor.scatter' op mismatch in shape of indices and update value at dim#0
"IndexPutImpl2DNoneIndexBroadcastStaticModule_basic"
}
Expand Down Expand Up @@ -287,6 +288,9 @@

# failed to legalize operation 'torch.aten.clamp' that was explicitly marked illegal
"ElementwiseClampIntModule_basic",

# failed to legalize operation 'torch.constant.int'
"RepeatInterleaveStaticModule_basic",
}

TORCHDYNAMO_CRASHING_SET = {
Expand Down Expand Up @@ -1216,6 +1220,7 @@
"SplitTensorListUnpackModule_basic",
"ChunkListUnpack_Module_basic",
"ChunkListUnpackUneven_Module_basic",
"RepeatInterleaveStaticModule_basic",
}

MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | {
Expand Down
66 changes: 61 additions & 5 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5071,11 +5071,9 @@ class ConvertAtenFillScalarOp : public OpConversionPattern<AtenOpT> {
return rewriter.notifyMatchFailure(
op, "Supplied value must be a Scalar constant");

if (outType != constOp.getType())
rewriter.replaceOpWithNewOp<tosa::CastOp>(op, outType, constOp);
else
rewriter.replaceOp(op, constOp);

auto newOp =
rewriter.createOrFold<tosa::CastOp>(op.getLoc(), outType, constOp);
rewriter.replaceOp(op, newOp);
return success();
}
};
Expand Down Expand Up @@ -5407,6 +5405,63 @@ LogicalResult ConvertAtenOp<AtenEmptyMemoryFormatOp>::matchAndRewrite(
return success();
}

// Legalizes aten.repeat_interleave(Tensor repeats, int? output_size) by
// materializing the result as a tosa.const. Supported only for the static
// case: rank-1 integer output, a constant splat `repeats` tensor, and a
// compile-time-constant `output_size`. Element i of the result is
// i / repeatCount, i.e. every index 0..outputSize/repeatCount-1 appears
// exactly repeatCount times.
template <>
LogicalResult ConvertAtenOp<AtenRepeatInterleaveTensorOp>::matchAndRewrite(
    AtenRepeatInterleaveTensorOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  auto outputTy = getTypeConverter()
                      ->convertType(op.getType())
                      .dyn_cast<RankedTensorType>();
  if (!outputTy)
    return rewriter.notifyMatchFailure(
        op, "Only ranked tensor type outputs permitted");

  auto shape = outputTy.getShape();
  if (shape.size() != 1)
    return rewriter.notifyMatchFailure(op, "Only rank 1 tensors are permitted");

  int64_t outputSize;
  if (!matchPattern(op.getOutputSize(), m_TorchConstantInt(&outputSize))) {
    return rewriter.notifyMatchFailure(
        op, "Currently only scalar constants are supported for "
            "output_size in TOSA operation");
  }

  auto repeats = dyn_cast<tosa::ConstOp>(adaptor.getRepeats().getDefiningOp());
  if (!repeats)
    return rewriter.notifyMatchFailure(
        op, "Currently only constants are supported for "
            "repeats in TOSA operation");

  auto attr = repeats.getValue();
  if (!attr.isSplat())
    return rewriter.notifyMatchFailure(op, "Only single values are supported.");

  auto elementTy = outputTy.getElementType();
  if (!elementTy.isa<mlir::IntegerType>())
    return rewriter.notifyMatchFailure(op,
                                       "Only integer values are supported.");

  int64_t numberOfRepeats = attr.getSplatValue<llvm::APInt>().getSExtValue();

  // Guard the division below: a non-positive repeat count would divide by
  // zero (or yield an empty result), and an output_size that is not a
  // multiple of the splat repeat count cannot match the declared shape.
  if (numberOfRepeats <= 0)
    return rewriter.notifyMatchFailure(
        op, "Repeat count must be a positive integer");
  if (outputSize % numberOfRepeats != 0)
    return rewriter.notifyMatchFailure(
        op, "output_size must be a multiple of the repeat count");

  // Build [0,0,..,0, 1,1,..,1, ...] with each value repeated numOfRepeats
  // times. Note: the lambda now uses its parameter rather than silently
  // falling back to the captured variable of a near-identical name.
  auto createConstArrayOfRepeatedValues = [&](int64_t numOfRepeats) {
    SmallVector<int64_t> values;
    values.reserve(outputSize);
    for (int64_t val = 0; val < outputSize / numOfRepeats; ++val)
      values.append(numOfRepeats, val);
    return values;
  };

  auto newOp = tosa::getConstTensor<int64_t>(
      rewriter, op, createConstArrayOfRepeatedValues(numberOfRepeats), shape,
      elementTy);
  // getConstTensor returns an optional; do not dereference blindly.
  if (!newOp)
    return rewriter.notifyMatchFailure(
        op, "Failed to materialize constant result tensor");
  rewriter.replaceOp(op, *newOp);
  return success();
}

template <typename AtenOpT>
class ConvertAtenOpToTosaCustomOp : public OpConversionPattern<AtenOpT> {
public:
Expand Down Expand Up @@ -5703,6 +5758,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
INSERT_ATENOP_PATTERN(AtenRepeatInterleaveTensorOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
Expand Down
21 changes: 21 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,6 +1482,27 @@ def RepeatInterleaveModule_basic(module, tu: TestUtils):

# ==============================================================================

class RepeatInterleaveStaticModule(torch.nn.Module):
    """Exercises aten.repeat_interleave in its fully static form: a constant
    splat repeats tensor and an explicit output_size, with no runtime inputs.
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
    ])
    def forward(self):
        # Ten elements, each repeated 3 times -> output_size of 10 * 3 = 30.
        repeats = torch.ones((10), dtype=torch.int).fill_(3)
        return torch.ops.aten.repeat_interleave(repeats, output_size=30)


@register_test_case(module_factory=lambda: RepeatInterleaveStaticModule())
def RepeatInterleaveStaticModule_basic(module, tu: TestUtils):
    # The module takes no inputs (annotate_args lists only None), so forward()
    # is invoked with no arguments; all tensors are built inside the module.
    module.forward()

# ==============================================================================


class ExpandModule(torch.nn.Module):

Expand Down

0 comments on commit 8ca7252

Please sign in to comment.