Merge pull request #457 from Xilinx/bump_to_b1769398
[AutoBump] Merge with b176939 (Oct 12) (81)
mgehre-amd authored Jan 14, 2025
2 parents c1e0eda + 6c6007d commit 28940f2
Showing 4 changed files with 368 additions and 0 deletions.
211 changes: 211 additions & 0 deletions lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp
@@ -63,6 +63,29 @@ LogicalResult getListOperands(Value value, SmallVector<Value> &vals) {
return success();
}

LogicalResult constructListFromLiteral(PatternRewriter &rewriter,
ValueTensorLiteralOp literalOp,
SmallVector<Value> &vals) {
// Only supports splat ValueTensorLiterals for now. TODO: add support for
// small non-splat ValueTensorLiterals.
auto ty = dyn_cast<ValueTensorType>(literalOp.getType());
if (!ty || !ty.hasSizes())
return failure();
auto attr = dyn_cast_or_null<SplatElementsAttr>(literalOp.getValue());
if (!attr)
return failure();
auto attrInt = dyn_cast<IntegerAttr>(attr.getSplatValue<Attribute>());
if (!attrInt)
return failure();
IntegerType intty = cast<IntegerType>(attrInt.getType());
if (!intty.isSignedInteger())
return failure();
Value materializedVal = rewriter.create<Torch::ConstantIntOp>(
literalOp.getLoc(), attrInt.getSInt());
vals.resize(vals.size() + ty.getSizes()[0], materializedVal);
return success();
}
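
A loose, Python-level sketch (not part of this diff, and assuming the constant is imported as a splat vtensor literal) of the case the helper above handles: every element of the literal is the same signed integer, so the tensor can be expanded into one scalar repeated size-many times.

import torch

# A "splat" constant: all elements share one value, so it is representable as a
# single scalar plus a length.
splat = torch.full((4,), 7, dtype=torch.int64)
# What constructListFromLiteral effectively builds: the splat value repeated
# ty.getSizes()[0] times (as torch.constant.int ops in the IR).
as_list = [int(splat[0])] * splat.shape[0]
assert as_list == [7, 7, 7, 7]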

LogicalResult getListFromTensor(Value value, SmallVector<Value> &vals) {
constexpr int64_t kMaxFold = 16;
if (auto tensor = value.getDefiningOp<Torch::AtenTensorOp>())
@@ -351,6 +374,172 @@ class PropagateAtenSliceTensorPattern
};
} // namespace

namespace {
class PropagateAtenWhereSelfPattern : public OpRewritePattern<AtenWhereSelfOp> {
public:
using OpRewritePattern<AtenWhereSelfOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenWhereSelfOp op,
PatternRewriter &rewriter) const override {
Value condition = op.getCondition();
Value self = op.getSelf();
Value other = op.getOther();
auto conditionTy = dyn_cast<Torch::ValueTensorType>(condition.getType());
if (!conditionTy || !conditionTy.hasSizes() ||
conditionTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad condition type");
auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad self type");
auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad other type");
int64_t conditionSize = conditionTy.getSizes()[0];
int64_t selfSize = selfTy.getSizes()[0];
int64_t otherSize = otherTy.getSizes()[0];

if (selfSize != otherSize || selfSize != conditionSize)
return rewriter.notifyMatchFailure(
op,
"unimplemented: support for propogating with implicit broadcasting.");

constexpr int64_t kMaxFold = 16;
if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold)
return rewriter.notifyMatchFailure(op,
"arguments are dynamic or too big");

SmallVector<Value> conditionList, selfList, otherList;
if (failed(getListFromTensor(condition, conditionList)) ||
(int64_t)conditionList.size() != conditionSize)
return failure();

// If one of these tensors is a value tensor literal op, we will need to
// create constant ints in the IR to form a list. Before calling
// constructListFromLiteral, we must be certain that the conversion can no
// longer fail, otherwise we will cause an infinite loop of creating a
// constant and removing it.
LogicalResult selfFromList = getListFromTensor(self, selfList);
LogicalResult otherFromList = getListFromTensor(other, otherList);

if (failed(selfFromList) && failed(otherFromList))
return rewriter.notifyMatchFailure(
op, "At least one operand must succeed at constructing a list");

auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
if (succeeded(selfFromList) && otherLiteral &&
failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
return failure();
if (succeeded(otherFromList) && selfLiteral &&
failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
return failure();
if ((int64_t)selfList.size() != selfSize ||
(int64_t)otherList.size() != otherSize)
// this should only occur if we did not generate IR with
// constructListFromLiteral
return failure();

Location loc = op.getLoc();
SmallVector<Value> whereVals;
auto rank0IntTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), selfTy.getDtype());
auto rank0BoolTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>({}), conditionTy.getDtype());
for (uint64_t i = 0; i < selfList.size(); i++) {
Value rank0Cond = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, rank0BoolTy, conditionList[i]);
Value rank0Self = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, rank0IntTy, selfList[i]);
Value rank0Other = rewriter.create<Torch::PrimNumToTensorScalarOp>(
loc, rank0IntTy, otherList[i]);
Value rank0Where = rewriter.create<AtenWhereSelfOp>(
loc, rank0IntTy, rank0Cond, rank0Self, rank0Other);
whereVals.push_back(rewriter.create<AtenItemOp>(
loc, rewriter.getType<Torch::IntType>(), rank0Where));
}
Value list = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(whereVals[0].getType()), whereVals);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
op.getLoc(), rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
op, op.getType(), list, cstNone, cstNone, cstFalse);
return success();
}
};
} // namespace
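
A rough illustration (my own sketch, not code from this change) of the kind of computation the pattern above scalarizes: an element-wise select over three small, equally sized 1-D tensors, as typically produced by shape arithmetic. The single aten.where.self is rewritten into per-element rank-0 where ops followed by aten.item, so each result element becomes an individual !torch.int value.

import torch

def pick_dims(x: torch.Tensor) -> torch.Tensor:
    sizes = torch.tensor(x.shape)       # small, statically sized 1-D int tensor
    ones = torch.tensor([1] * x.dim())  # same length as `sizes`
    keep = sizes > 1                    # 1-D boolean condition
    # A where over three equal-length 1-D tensors is the case handled above;
    # implicit broadcasting and dynamic or large sizes are rejected.
    return torch.where(keep, sizes, ones)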

namespace {
class PropagateAtenEqTensorPattern : public OpRewritePattern<AtenEqTensorOp> {
public:
using OpRewritePattern<AtenEqTensorOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenEqTensorOp op,
PatternRewriter &rewriter) const override {
Value self = op.getSelf();
Value other = op.getOther();
auto selfTy = dyn_cast<Torch::ValueTensorType>(self.getType());
if (!selfTy || !selfTy.hasSizes() || selfTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad self type");
auto otherTy = dyn_cast<Torch::ValueTensorType>(other.getType());
if (!otherTy || !otherTy.hasSizes() || otherTy.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "bad other type");
int64_t selfSize = selfTy.getSizes()[0];
int64_t otherSize = otherTy.getSizes()[0];

if (selfSize != otherSize)
return rewriter.notifyMatchFailure(
op,
"unimplemented: support for propogating with implicit broadcasting.");

constexpr int64_t kMaxFold = 16;
if (selfSize == Torch::kUnknownSize || selfSize > kMaxFold ||
otherSize == Torch::kUnknownSize || otherSize > kMaxFold)
return rewriter.notifyMatchFailure(op,
"self or other is dynamic or too big");

SmallVector<Value> selfList, otherList;
// If one of these tensors is a value tensor literal op, we will need to
// create constant ints in the IR to form a list. Before calling
// constructListFromLiteral, we must be certain that the conversion can no
// longer fail, otherwise we will cause an infinite loop of creating a
// constant and removing it.
LogicalResult selfFromList = getListFromTensor(self, selfList);
LogicalResult otherFromList = getListFromTensor(other, otherList);

if (failed(selfFromList) && failed(otherFromList))
return rewriter.notifyMatchFailure(
op, "At least one operand must succeed at constructing a list");

auto selfLiteral = self.getDefiningOp<Torch::ValueTensorLiteralOp>();
auto otherLiteral = other.getDefiningOp<Torch::ValueTensorLiteralOp>();
if (succeeded(selfFromList) && otherLiteral &&
failed(constructListFromLiteral(rewriter, otherLiteral, otherList)))
return failure();
if (succeeded(otherFromList) && selfLiteral &&
failed(constructListFromLiteral(rewriter, selfLiteral, selfList)))
return failure();
if ((int64_t)selfList.size() != selfSize ||
(int64_t)otherList.size() != otherSize)
// this should only occur if we did not generate IR with
// constructListFromLiteral
return failure();

SmallVector<Value> eqVals;
for (uint64_t i = 0; i < selfList.size(); i++) {
eqVals.push_back(
rewriter.create<AtenEqIntOp>(op.getLoc(), selfList[i], otherList[i]));
}
Value list = rewriter.create<Torch::PrimListConstructOp>(
op.getLoc(), Torch::ListType::get(eqVals[0].getType()), eqVals);
Value cstNone = rewriter.create<Torch::ConstantNoneOp>(op.getLoc());
Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(
op.getLoc(), rewriter.getBoolAttr(false));
rewriter.replaceOpWithNewOp<Torch::AtenTensorOp>(
op, op.getType(), list, cstNone, cstNone, cstFalse);
return success();
}
};
} // namespace
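
In the same spirit, a hedged sketch (not from this change) of what the eq.Tensor pattern targets: comparing two small, equally sized 1-D integer tensors, which is folded into per-element aten.eq.int results gathered back into a tensor.

import torch

def same_dims(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Assumes x and y have the same rank, mirroring the equal-size requirement
    # enforced by the pattern (no implicit broadcasting).
    a = torch.tensor(x.shape)
    b = torch.tensor(y.shape)
    return torch.eq(a, b)  # rewritten element-wise into aten.eq.int ops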

namespace {
class PropagateAtenItemPattern : public OpRewritePattern<AtenItemOp> {
public:
@@ -454,6 +643,26 @@ class FoldAtenSqueezePattern : public OpRewritePattern<AtenSqueezeOp> {
};
} // namespace

namespace {
class FoldAtenSqueezeDimPattern : public OpRewritePattern<AtenSqueezeDimOp> {
public:
using OpRewritePattern<AtenSqueezeDimOp>::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSqueezeDimOp op,
PatternRewriter &rewriter) const override {
auto resultTy = cast<ValueTensorType>(op.getType());
if (!resultTy.hasSizes() || resultTy.getSizes().size() != 0)
return rewriter.notifyMatchFailure(op, "result shape unknown or not rank-0");

if (auto atenFull = op.getSelf().getDefiningOp<AtenFullOp>()) {
rewriter.replaceOpWithNewOp<PrimNumToTensorScalarOp>(
op, resultTy, atenFull.getFillValue());
return success();
}
return failure();
}
};
} // namespace
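
For context, a minimal sketch (mine, not part of the diff) of the fold above: squeezing the only dimension of an aten.full result leaves just the fill value, so the whole expression can be replaced by prim.NumToTensor.Scalar on that value.

import torch

def scalar_from_fill(v: int) -> torch.Tensor:
    filled = torch.full((1,), v)  # aten.full with a single-element shape
    # Squeezing away that dimension yields a rank-0 tensor holding only the
    # fill value, which is what the fold materializes directly.
    return filled.squeeze(0)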

namespace {
class FoldAtenWhereSelf : public OpRewritePattern<AtenWhereSelfOp> {
public:
@@ -694,6 +903,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase<ScalarizeShapesPass> {
PropagateAtenSliceTensorPattern, FoldAtenTensorSplatPattern,
FoldAtenSqueezePattern, FoldAtenUnsqueezePattern,
FoldAtenWhereSelf, CanonicalizeAtenViewPattern,
PropagateAtenEqTensorPattern, PropagateAtenWhereSelfPattern,
FoldAtenSqueezeDimPattern,
RemoveUnusedPattern<Torch::AtenIntBoolOp>,
RemoveUnusedPattern<Torch::AtenEqIntOp>,
RemoveUnusedPattern<Torch::PrimNumToTensorScalarOp>,
57 changes: 57 additions & 0 deletions lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp
@@ -46,6 +46,62 @@ class DecomposeAtenSizeOp : public OpRewritePattern<AtenSizeOp> {
};
} // namespace

namespace {
class InferTensorOp : public OpRewritePattern<AtenTensorOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTensorOp op,
PatternRewriter &rewriter) const override {
auto context = op.getContext();
auto loc = op.getLoc();
auto result = op.getResult();
auto resultType = cast<BaseTensorType>(result.getType());
if (resultType.hasSizes() && resultType.hasDtype()) {
return rewriter.notifyMatchFailure(
op, "The result of aten.tensor is already a BaseTensorType.");
}

auto inputList = op.getOperand(0);
auto listConstruct = inputList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
return rewriter.notifyMatchFailure(
op, "The operand 0 of aten.tensor is not PrimListConstructOp.");
}

// Currently only supports 1-D input lists.
SmallVector<int64_t> sizes;
sizes.push_back(listConstruct->getOperands().size());
FailureOr<Type> torchType;
auto eleType = listConstruct->getOperands()[0].getType();
if (isa<Torch::IntType>(eleType)) {
torchType = getTypeForScalarType(op->getContext(),
torch_upstream::ScalarType::Long);
} else if (isa<Torch::FloatType>(eleType)) {
torchType = getTypeForScalarType(op->getContext(),
torch_upstream::ScalarType::Float);
} else {
return rewriter.notifyMatchFailure(
op, "Currently only support Int and Float Type.");
}
auto newResultType = ValueTensorType::get(context, sizes, *torchType);

Value originalTypedValue;
for (OpOperand &use : llvm::make_early_inc_range(result.getUses())) {
if (!originalTypedValue) {
rewriter.setInsertionPointAfter(op);
originalTypedValue =
rewriter.create<TensorStaticInfoCastOp>(loc, resultType, result);
}
use.set(originalTypedValue);
}

result.setType(newResultType);

return success();
}
};
} // namespace
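
A short sketch (not from this change) of the case InferTensorOp refines: an aten.tensor built from a prim.ListConstruct of ints or floats gets a static 1-D result type inferred from the list (int64 for integer elements, float32 for float elements), with a TensorStaticInfoCast inserted so existing uses keep the old type.

import torch

def dims_tensor(x: torch.Tensor) -> torch.Tensor:
    # torch.tensor over a list of Python ints derived from sizes; the pattern
    # infers a static [2] x int64 result type from the list construct.
    return torch.tensor([x.shape[0], x.shape[1]])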

static LogicalResult refineShapeCalculateResult(ShapeCalculateOp op,
int resultNum,
PatternRewriter &rewriter) {
@@ -135,6 +191,7 @@ class SimplifyShapeCalculationsPass
populateFoldPrimUncheckedCastOpPattern(patterns, context);
patterns.insert<DecomposeAtenSizeOp>(context);
patterns.insert<RefineShapeCalculateOp>(context);
patterns.insert<InferTensorOp>(context);

PrimIfOp::getCanonicalizationPatterns(patterns, context);
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
24 changes: 24 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -5941,6 +5941,30 @@ def ConstantBoolParameterModule_basic(module, tu: TestUtils):
# ==============================================================================


class TensorAlloc1dStaticModule(torch.nn.Module):
def __init__(self):
super().__init__()

@export
@annotate_args(
[
None,
([2, 4, 6], torch.float32, True),
]
)
def forward(self, x):
res = torch.tensor([x.shape[0]])
return res


@register_test_case(module_factory=lambda: TensorAlloc1dStaticModule())
def TensorAlloc1dStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 4, 6))


# ==============================================================================


class ScalarTensorFloat32Module(torch.nn.Module):
def __init__(self):
super().__init__()