
convert-torch-onnx-to-torch generates invalid IR for onnx.Resize where scaling is in the first two dimensions #3453

Open
mgehre-amd (Contributor) opened this issue Jun 12, 2024 · 4 comments
The code in that pass seems to silently assume that the first two dimensions are not scaled, but ONNX has no such restriction.
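
For reference, a model that exercises this case is easy to build with the onnx Python helpers. The sketch below (graph and tensor names are arbitrary) constructs a Resize whose scales apply to dimensions 1 and 2 rather than the trailing two; it should pass the ONNX checker, since the spec places no restriction on which dimensions may be scaled:

import onnx
from onnx import helper, TensorProto

# Scales of [1, 2, 2, 1]: dimensions 1 and 2 are resized, not the last two.
scales = helper.make_tensor("scales", TensorProto.FLOAT, [4], [1.0, 2.0, 2.0, 1.0])
resize = helper.make_node(
    "Resize", inputs=["X", "", "scales"], outputs=["Y"],  # empty roi input
    coordinate_transformation_mode="half_pixel",
    mode="nearest", nearest_mode="round_prefer_floor")
graph = helper.make_graph(
    [resize], "resize_middle",
    inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 36, 42, 384])],
    outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 72, 84, 384])],
    initializer=[scales])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)])
onnx.checker.check_model(model)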

With input

func.func @test_resize_middle(%arg0: !torch.vtensor<[1,36,42,384],f32>) -> !torch.vtensor<[1,72,84,384],f32>
    attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
  %none = torch.constant.none
  %12 = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 2.000000e+00, 1.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
  %19 = torch.operator "onnx.Resize"(%arg0, %none, %12) {
      torch.onnx.coordinate_transformation_mode = "half_pixel",
      torch.onnx.mode = "nearest",
      torch.onnx.nearest_mode = "round_prefer_floor"} : (!torch.vtensor<[1,36,42,384],f32>, !torch.none, !torch.vtensor<[4],f32>) -> !torch.vtensor<[1,72,84,384],f32>
  return %19 : !torch.vtensor<[1,72,84,384],f32>
}

we get

module {
  func.func @test_resize_middle(%arg0: !torch.vtensor<[1,36,42,384],f32>) -> !torch.vtensor<[1,72,84,384],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.vtensor.literal(dense<[1.000000e+00, 2.000000e+00, 2.000000e+00, 1.000000e+00]> : tensor<4xf32>) : !torch.vtensor<[4],f32>
    %none_0 = torch.constant.none
    %int0 = torch.constant.int 0
    %false = torch.constant.bool false
    %true = torch.constant.bool true
    %str = torch.constant.str "nearest_half_pixel,round_prefer_floor"
    %int2 = torch.constant.int 2
    %1 = torch.aten.select.int %0, %int0, %int2 : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
    %2 = torch.aten.item %1 : !torch.vtensor<[1],f32> -> !torch.float
    %int3 = torch.constant.int 3
    %3 = torch.aten.select.int %0, %int0, %int3 : !torch.vtensor<[4],f32>, !torch.int, !torch.int -> !torch.vtensor<[1],f32>
    %4 = torch.aten.item %3 : !torch.vtensor<[1],f32> -> !torch.float
    %5 = torch.prim.ListConstruct %2, %4 : (!torch.float, !torch.float) -> !torch.list<float>
    %6 = torch.aten.__interpolate.size_list_scale_list %arg0, %none_0, %5, %str, %false, %none_0, %false : !torch.vtensor<[1,36,42,384],f32>, !torch.none, !torch.list<float>, !torch.str, !torch.bool, !torch.none, !torch.bool -> !torch.vtensor<[1,72,84,384],f32>
    return %6 : !torch.vtensor<[1,72,84,384],f32>
  }
}

Here, %1 and %3 only read the last two elements of %0. When lowering this IR to linalg (-convert-torch-to-linalg), we get

error: unexpected error: 'tensor.cast' op operand type 'tensor<1x36x?x?xf32>' and result type 'tensor<1x72x84x384xf32>' are cast incompatible

because the scales passed to torch.aten.__interpolate.size_list_scale_list no longer match the declared output shape.
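
The mismatch is easy to see with plain PyTorch: interpolate applies a scale_factor list only to the trailing spatial dimensions, so feeding it elements 2 and 3 of the ONNX scales produces the wrong shape here. A minimal sketch ("nearest" is used only to show the shapes, not the exact torch-mlir mode string):

import torch
import torch.nn.functional as F

x = torch.zeros(1, 36, 42, 384)
# Elements 2 and 3 of the ONNX scales [1, 2, 2, 1] are [2.0, 1.0]; interpolate
# applies them to the last two (spatial) dimensions only.
y = F.interpolate(x, scale_factor=[2.0, 1.0], mode="nearest")
print(y.shape)  # torch.Size([1, 36, 84, 384]), not the expected [1, 72, 84, 384]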

@arnavmehta1

I can work on this.

@zjgarvey (Collaborator) commented Jan 3, 2025

@arnavmehta1, since this is going stale, I'm going to unassign you from this task and hand it off to someone else. Please let me know if you have any active work on this issue.

@bjacobgordon (Contributor)

@zjgarvey Taking it!

@bjacobgordon (Contributor)

See #3945
