Skip to content

Commit

Permalink
build: manually update PyTorch version (llvm#3727)
Browse files Browse the repository at this point in the history
Set PyTorch and TorchVision version to nightly release 2024-10-15.

Tracker issue for the failing tests added to xfail_set in this PR.
Issue: llvm#3796
This commit disables the failing sparse tensor tests since they are not 
maintained on day-to-day basis and blocks the roll PyTorch update for now.

Signed-Off By: Vivek Khandelwal <[email protected]>
  • Loading branch information
vivekkhandelwal1 authored and mgehre-amd committed Jan 16, 2025
1 parent bcbcdb4 commit 03816f9
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 191 deletions.
24 changes: 24 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6317,6 +6317,30 @@ def Torch_AtenDotOp : Torch_Op<"aten.dot", [
let hasCanonicalizer = 1;
}

def Torch_AtenOuterOp : Torch_Op<"aten.outer", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::outer : (Tensor, Tensor) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$vec2
);
let results = (outs
AnyTorchOptionalTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenOuterOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 2, 1);
}
void AtenOuterOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 2, 1);
}
}];
}

def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
66 changes: 24 additions & 42 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7607,6 +7607,13 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" } : (!torch.int, !torch.bool) -> ()\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.outer\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %int0 = torch.constant.int 0\n"
" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %1 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list<int>, !torch.int -> !torch.int\n"
" %2 = torch.prim.ListConstruct %0, %1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" return %2 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.dot\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>) -> !torch.list<int> {\n"
" %0 = torch.prim.ListConstruct : () -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -13441,6 +13448,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.outer\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.mm\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>) -> !torch.int {\n"
" %false = torch.constant.bool false\n"
" %int5 = torch.constant.int 5\n"
Expand Down Expand Up @@ -13851,63 +13866,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list<optional<int>>\n"
" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %3 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list<int>\n"
" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %4 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
" %str = torch.constant.str \"AssertionError: \"\n"
" %int11 = torch.constant.int 11\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %3 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %4 = torch.aten.ne.int %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %4 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %5 = torch.aten.ne.int %2#1, %int11 : !torch.int, !torch.int -> !torch.bool\n"
" torch.prim.If %5 -> () {\n"
" torch.prim.If.yield\n"
" } else {\n"
" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n"
" torch.prim.If.yield\n"
" }\n"
" %6 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %7 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %8 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%6, %7) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %8 : !torch.int\n"
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.addcdiv\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.number) -> !torch.int {\n"
" %int6 = torch.constant.int 6\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %2:2 = torch.prim.TupleUnpack %arg2 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %3 = torch.prim.ListConstruct %0#0, %1#0, %2#0 : (!torch.int, !torch.int, !torch.int) -> !torch.list<optional<int>>\n"
" %4 = torch.prim.ListConstruct %0#1, %1#1, %2#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>\n"
" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list<optional<int>>, !torch.list<int>) -> !torch.int\n"
" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n"
" %7 = torch.prim.If %6 -> (!torch.int) {\n"
" torch.prim.If.yield %int6 : !torch.int\n"
" } else {\n"
" torch.prim.If.yield %5 : !torch.int\n"
" }\n"
" return %7 : !torch.int\n"
" return %5 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.add.Scalar\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.number, %arg2: !torch.number) -> !torch.int {\n"
" %none = torch.constant.none\n"
Expand Down
101 changes: 74 additions & 27 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,6 @@
"ElementwiseDequantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseQuantizePerTensorUIntModule_basic",
"ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"EqIntModule_basic",
"FloatImplicitModule_basic",
Expand Down Expand Up @@ -527,9 +523,6 @@
"RepeatInterleaveStaticModule_basic",
"RsubInt0d_NumToTensor_Module_basic",
"ScalarImplicitFloatModule_basic",
"SignAndLogarithmOfDeterminantModule_F32",
"SignAndLogarithmOfDeterminantBatchedModule_F32",
"SignAndLogarithmOfDeterminantDynamicModule_F32",
"SortIntListReverse_basic",
"SortIntList_basic",
"SplitDimDynamicModule_basic",
Expand Down Expand Up @@ -562,6 +555,34 @@
"SplitTensorNegativeDimModule_basic",
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"AdaptiveMaxPool1dDynamicNoBatch_basic",
"AdaptiveMaxPool1dDynamic_basic",
"AdaptiveMaxPool1dStatic_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"IndexPutImpl1DFloatAccumulateModule_basic",
"IndexPutImpl1DFloatNonAccumulateModule_basic",
"IndexPutImpl1DIntAccumulateModule_basic",
"IndexPutImpl1DIntNonAccumulateModule_basic",
"IndexPutImpl2DFloatNonAccumulateModule_basic",
"IndexPutImpl2DImplicitModule_basic",
"IndexPutImpl2DIndexModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"InterpolateDynamicModule_sizes_nearest",
"IouOfModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"OneHotModule_basic",
}

if torch_version_for_comparison() < version.parse("2.6.0.dev"):
Expand All @@ -586,6 +607,7 @@
# Runtime op verification: out-of-bounds access
"_SoftmaxModule_basic",
"UpSampleNearest2dDynamicFactor_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
}

FX_IMPORTER_STABLEHLO_XFAIL_SET = {
Expand Down Expand Up @@ -614,10 +636,6 @@
"ElementwiseRemainderTensorModule_Int_Float_NegativeDivisor_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDividend_basic",
"ElementwiseRemainderTensorModule_Int_NegativeDivisor_basic",
"ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"MaxPool1dCeilModeTrueModule_basic",
"MaxPool1dStaticCeilModeTrueModule_basic",
"MaxUnpool3dModulePad0_basic",
Expand Down Expand Up @@ -651,7 +669,6 @@
"AdaptiveAvgPool3dDynamic_basic",
"AdaptiveMaxPool1dDynamicNoBatch_basic",
"AdaptiveMaxPool1dDynamic_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"AdaptiveMaxPool1dStatic_basic",
"AdaptiveMaxPool2dDynamicNoBatch_basic",
"AdaptiveMaxPool2dDynamicWithIndices_basic",
Expand Down Expand Up @@ -818,12 +835,7 @@
"MaxPool2dWithIndicesBackwardStatic3DModule_basic",
"MaxPool2dWithIndicesBackwardStatic4DModule_basic",
"MaxPool3dCeilModeTrueModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
"MaxPool3dModule_basic",
"MaxPool3dStaticCeilModeTrueModule_basic",
"MaxPool3dStaticModule_basic",
"MaxPool3dWithIndicesAllNegativeValuesModule_basic",
"MaxPool3dWithIndicesAllOnesModule_basic",
"MaxPool3dWithIndicesCeilModeTrueModule_basic",
Expand Down Expand Up @@ -980,6 +992,51 @@
"Unfold_Module_Rank_Zero_basic",
"Unfold_Module_Rank_Zero_Size_Zero_basic",
"Unfold_Module_Dynamic_basic",
"AdaptiveAvgPool1dGeneralDynamic_basic",
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
"AdaptiveAvgPool1dStaticLargerOutput_basic",
"AdaptiveAvgPool2dDynamicNoBatch_basic",
"AdaptiveAvgPool2dDynamic_basic",
"AddIntModule_basic",
"AtenIntTensorByteDtypeModule_basic",
"AtenIntTensorCharDtypeModule_basic",
"AtenItemIntOpModule_basic",
"CrossEntropyLossModule_basic",
"CrossEntropyLossNoReductionModule_basic",
"EinsumStaticContractRhsModule_basic",
"EinsumStaticFourDimensionModule_basic",
"EinsumStaticModule_basic",
"EinsumStaticWithEllipsisSlicingAndBroadcastModule_basic",
"EinsumStaticWithEllipsisSlicingModule_basic",
"ElementwiseExpm1IntModule_basic",
"ElementwiseExpm1Module_basic",
"InterpolateDynamicModule_sizes_nearest",
"IouOfModule_basic",
"IscloseStaticModuleTrue_basic",
"IscloseStaticModule_basic",
"MeshgridIndexingIJ_basic",
"MeshgridIndexingXY_basic",
"Meshgrid_basic",
"MulIntModule_basic",
"OneHotModule_basic",
"ReduceFrobeniusNormComplexModule_basic",
"ScalarImplicitIntModule_basic",
"ScaledDotProductAttentionBoolMaskModule_basic",
"ScaledDotProductAttentionDifferentCausalModule_basic",
"ScaledDotProductAttentionDifferentDynamicCausalModule_basic",
"ScaledDotProductAttentionDifferentModule_basic",
"ScaledDotProductAttentionMaskModule_basic",
"ScaledDotProductAttentionSameCausalModule_basic",
"ScaledDotProductAttentionSameDynamicModule_basic",
"ScaledDotProductAttentionSameModule_basic",
"SubIntModule_basic",
"TensorToIntZeroRank_basic",
"UpSampleNearest2dDynamicFactor_basic",
"UpSampleNearest2dDynamicSize_basic",
"UpSampleNearest2dStaticFactor_basic",
"UpSampleNearest2dStaticSize_basic",
"UpSampleNearest2d_basic",
}

FX_IMPORTER_STABLEHLO_CRASHING_SET = {
Expand Down Expand Up @@ -3535,7 +3592,6 @@
"SplitWithSizesListUnpackModule_basic",
"SplitWithSizes_Module_basic",
"ElementwiseCreateComplexModule_basic",
"AdaptiveMaxPool1dDimOneStatic_basic",
"AtenPolarDoubleModule_basic",
"AtenPolarFloatModule_basic",
"HstackBasicComplexModule_basic",
Expand All @@ -3553,10 +3609,6 @@
"Conv_Transpose3dStaticModule_basic",
"ElementwiseFloatTensorGtIntTensorModule_basic",
"ElementwiseIntTensorLtFloatTensorModule_basic",
"ElementwiseRreluEvalModule_basic",
"ElementwiseRreluEvalStaticModule_basic",
"ElementwiseRreluTrainModule_basic",
"ElementwiseRreluTrainStaticModule_basic",
"IndexPutWithNoneAndBroadcastModule_basic",
"MaskedScatterStaticBasic_basic",
"MaxUnpool3dModulePad0_basic",
Expand Down Expand Up @@ -3866,12 +3918,7 @@
"MaxPool2dWithIndicesNonDefaultStrideModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",
"MaxPool3dCeilModeTrueModule_basic",
"MaxPool3dEmptyStrideStaticModule_basic",
"MaxPool3dLargeDatadModule_basic",
"MaxPool3dModuleRandomSimple_basic",
"MaxPool3dModule_basic",
"MaxPool3dStaticCeilModeTrueModule_basic",
"MaxPool3dStaticModule_basic",
"MaxPool3dWithIndicesAllNegativeValuesModule_basic",
"MaxPool3dWithIndicesAllOnesModule_basic",
"MaxPool3dWithIndicesCeilModeTrueModule_basic",
Expand Down
Loading

0 comments on commit 03816f9

Please sign in to comment.