Add decomposition for index_select along with a test (#67)
flemairen6 authored Jun 12, 2023
1 parent 3f24fe6 commit 917d5a8
Showing 6 changed files with 177 additions and 0 deletions.
5 changes: 5 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -522,6 +522,7 @@
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectStaticModule_basic",
"IndexTensorStaticModule_basic",
"IndexTensorMultiIndexStaticModule_basic",
"LayerNormLastDimModule_basic",
@@ -1205,8 +1206,12 @@
"SliceWholeTensorModule_basic",
"TensorFloatModule_basic",
"TensorIntModule_basic",
"IndexSelectNegativeDimModule_basic",
"IndexSelectSingleIdxModule_basic",
"IndexSelectTwoIdxModule_basic",
"IndexSelectWholeDimensionModule_basic",
"IndexSelectWholeTensorModule_basic",
"IndexSelectStaticModule_basic",
"LinalgVectorNormModule_basic",
"LinalgVectorNormKeepDimModule_basic",
"NormScalarOptDimKeepDimModule_basic",
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h
@@ -123,6 +123,10 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op,
TypedValue<RankedTensorType> reshapeTo(Location loc, PatternRewriter &rewriter,
Value val, ArrayRef<int64_t> newShape);

TypedValue<RankedTensorType> transposeBy(Location loc,
PatternRewriter &rewriter, Value val,
ArrayRef<int32_t> permutation);

} // namespace tosa
} // namespace mlir

126 changes: 126 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -27,6 +27,8 @@
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "llvm/ADT/SmallVector.h"
#include <numeric>

using namespace mlir;
using namespace mlir::torch;
@@ -3928,6 +3930,129 @@ LogicalResult SimplifyAtenOp<AtenConvolutionOp>::matchAndRewrite(
return success();
}

// The goal of this pattern is to handle the case where the indices for all
// dimensions except one are None.
class ConvertAtenIndexTensorOpNone
: public OpConversionPattern<AtenIndexTensorOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(AtenIndexTensorOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// To do so, index.Tensor is rewritten as follows:
// - To match the TOSA gather format NxKxC, with K the dimension to
//   extract from:
//   - Transpose the dimension to extract into position 'K'.
//   - Flatten the other dimensions.
//   - Reshape to insert a leading 1 as 'N'; the result is 1xKxC, with C
//     the product of the flattened dimensions.
// - Insert a tosa.gather.
// - Bring the result back to the original format:
//   - Reshape.
//   - Transpose.
auto loc = op->getLoc();
auto outTy = dyn_cast<RankedTensorType>(
getTypeConverter()->convertType(op.getType()));
if (!outTy || !outTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op.getLoc(),
"unimplemented: Only static shapes are currently supported");

SmallVector<Value> torchIndices;
if (!getListConstructElements(op.getIndices(), torchIndices))
return rewriter.notifyMatchFailure(
op.getLoc(),
"unimplemented: the tensor list is not from list construct");

auto indicesList =
getTypeConvertedValues(rewriter, loc, typeConverter, torchIndices);

// Check that all indices are None except for one.
int64_t indexDim = -1;
for (size_t i = 0; i < indicesList.size(); ++i) {
if (!indicesList[i])
continue;
if (indexDim != -1) {
return rewriter.notifyMatchFailure(
op.getLoc(), "unimplemented: only one dimension must be set in "
"indices for this pattern to work");
}
indexDim = i;
}
if (indexDim == -1) {
return rewriter.notifyMatchFailure(op.getLoc(),
"unimplemented: all indices are none");
}

auto indices =
dyn_cast<TypedValue<RankedTensorType>>(indicesList[indexDim]);
if (!indices) {
return rewriter.notifyMatchFailure(
op.getLoc(), "unimplemented: index must be ranked tensor");
}

auto input = adaptor.getSelf();
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy || !inputTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op.getLoc(), "unimplemented: input must have static shapes");
auto inputElemTy = inputTy.getElementType();

// Transpose indexDim into dimension 0
SmallVector<int32_t> transposePerm;
for (int64_t i = 0; i < inputTy.getRank(); ++i)
transposePerm.push_back(i);
transposePerm[0] = indexDim;
transposePerm[indexDim] = 0;

auto transposedInput = tosa::transposeBy(loc, rewriter, input, transposePerm);

// Flatten [k, ...] -> [1, k, c]
auto transposedShape = transposedInput.getType().getShape();
int64_t k = transposedShape[0];
int64_t c = std::accumulate(transposedShape.begin() + 1, transposedShape.end(),
int64_t{1}, [](int64_t a, int64_t b) { return a * b; });

SmallVector<int64_t> reshapedFormat = {1, k, c};
// Reshape the input to 1xKxC.
auto reshapedInput =
tosa::reshapeTo(loc, rewriter, transposedInput, reshapedFormat);

auto w = indices.getType().getDimSize(0);
auto reshapedIndices = tosa::reshapeTo(loc, rewriter, indices, {1, w});

// tosa.gather expects i32 indices, so cast them.
TensorType promotedType = reshapedIndices.getType().cloneWith(
reshapedIndices.getType().getShape(), rewriter.getI32Type());
auto castedIndices = rewriter.create<tosa::CastOp>(op->getLoc(), promotedType,
reshapedIndices);

SmallVector<int64_t> gatherShape = {1, w, c};
auto gatherOp = rewriter.create<tosa::GatherOp>(
op->getLoc(), RankedTensorType::get(gatherShape, inputElemTy),
reshapedInput, castedIndices);

// Unflatten [1, w, c] -> [w, ...]
SmallVector<int64_t> unflattenedShape{transposedShape};
unflattenedShape[0] = w;
auto unflattened =
tosa::reshapeTo(loc, rewriter, gatherOp, unflattenedShape);

// Compute the inverse permutation (a plain swap is its own inverse, but
// compute it generically).
SmallVector<int32_t> inversePermutation(transposePerm.size(), 0);
for (size_t i = 0; i < transposePerm.size(); ++i)
inversePermutation[transposePerm[i]] = i;

// Transpose 'w' back to the original position of 'k'
auto unTranspose =
tosa::transposeBy(loc, rewriter, unflattened, inversePermutation);

rewriter.replaceOp(op, unTranspose);
return success();
}
};
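
For readers following the shape bookkeeping above, here is a minimal, runnable PyTorch sketch of the same rewrite. It is a reference for what the emitted TOSA ops compute, not part of the commit; the function name is made up, and plain tensor indexing stands in for tosa.gather.

import torch

def index_one_dim_reference(x, dim, idx):
    # Swap `dim` into position 0 (mirrors tosa::transposeBy).
    perm = list(range(x.dim()))
    perm[0], perm[dim] = perm[dim], perm[0]
    t = x.permute(perm)                                # [k, ...]
    flat = t.reshape(1, t.shape[0], -1)                # [1, k, c] -- tosa.gather's NxKxC
    gathered = flat[:, idx.long(), :]                  # [1, w, c] -- what tosa.gather yields
    out = gathered.reshape(idx.numel(), *t.shape[1:])  # [w, ...]
    return out.permute(perm)                           # the swap is its own inverse

x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)
idx = torch.tensor([2, 0])
assert torch.equal(index_one_dim_reference(x, 1, idx), x[:, idx, :])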

template <>
LogicalResult ConvertAtenOp<AtenIndexTensorOp>::matchAndRewrite(
AtenIndexTensorOp op, OpAdaptor adaptor,
@@ -5307,6 +5432,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {

patterns.add<SimplifyAten_IndexPutImplOp>(context);
patterns.add<SimplifyAten_IndexPutImplOpNone>(context);
patterns.add<ConvertAtenIndexTensorOpNone>(typeConverter, context);

#define INSERT_SIMPLIFY_OP_PATTERN(AtenOp) \
target.addIllegalOp<AtenOp>(); \
21 changes: 21 additions & 0 deletions lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp
@@ -400,6 +400,27 @@ TypedValue<RankedTensorType> reshapeTo(Location loc, PatternRewriter &rewriter,
loc, newTy, val, rewriter.getDenseI64ArrayAttr(newShape));
}

TypedValue<RankedTensorType> transposeBy(Location loc, PatternRewriter &rewriter,
Value val,
ArrayRef<int32_t> permutation) {
auto tensorTy = dyn_cast<TensorType>(val.getType());
assert(tensorTy);

auto permType = RankedTensorType::get({(int64_t)permutation.size()},
rewriter.getI32Type());
auto permAttr = DenseElementsAttr::get(permType, permutation);
auto permOp = rewriter.create<tosa::ConstOp>(loc, permType, permAttr);

SmallVector<int64_t> newShape{tensorTy.getShape()};
for (size_t i = 0; i < newShape.size(); i++)
newShape[i] = tensorTy.getShape()[permutation[i]];

auto newTy = RankedTensorType::get(newShape, tensorTy.getElementType());

auto v = rewriter.createOrFold<tosa::TransposeOp>(loc, newTy, val, permOp);
return cast<TypedValue<RankedTensorType>>(v);
}
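
As a sanity check on the shape rule used here (output dimension i takes the size of input dimension permutation[i], matching tosa.TRANSPOSE semantics), a small illustrative NumPy snippet:

import numpy as np

x = np.zeros((4, 5, 6))
perm = [1, 0, 2]
# transposeBy computes newShape[i] = oldShape[permutation[i]]; numpy agrees:
assert np.transpose(x, perm).shape == tuple(x.shape[p] for p in perm)  # (5, 4, 6)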

// Template instantiation
template std::optional<Value> getConstTensor<bool>(PatternRewriter &,
Operation *,
1 change: 1 addition & 0 deletions python/torch_mlir/dynamo.py
@@ -68,6 +68,7 @@ def _get_decomposition_table():
aten.im2col,
aten.index_select,
aten.linalg_vector_norm,
aten.index_select,
]
# TODO: enable test once 2.1.0 is stable
if torch_baseversion() >= (2, 1):
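For context, the names listed in _get_decomposition_table feed PyTorch's decomposition machinery, which rewrites aten.index_select into more primitive ops before the module reaches the TorchToTosa lowering. A minimal sketch of how such a table is materialized, assuming the stock torch._decomp API:

import torch
from torch._decomp import get_decompositions

aten = torch.ops.aten
# Maps each index_select overload to the Python function that decomposes it.
table = get_decompositions([aten.index_select])
print(table.keys())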
20 changes: 20 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/index_select.py
@@ -11,6 +11,26 @@

# ==============================================================================

class IndexSelectStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([3, 3], torch.float32, True),
([1], torch.int, True),
])
def forward(self, x, y):
return torch.ops.aten.index_select(x, 0, y)


@register_test_case(module_factory=lambda: IndexSelectStaticModule())
def IndexSelectStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 3), torch.tensor([1], dtype=torch.int))
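
For reference, the behavior this test pins down, as a plain runnable PyTorch snippet (not part of the commit):

import torch

x = torch.arange(9, dtype=torch.float32).reshape(3, 3)
out = torch.ops.aten.index_select(x, 0, torch.tensor([1], dtype=torch.int))
print(out)  # tensor([[3., 4., 5.]]) -- row 1 of x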


class IndexSelectSingleIdxModule(torch.nn.Module):
def __init__(self):
