torch.ops.aten.scaled_dot_product_attention is decomposed during compilation? #3973

Closed
ita9naiwa opened this issue Jan 20, 2025 · 2 comments

@ita9naiwa

Hi,

in the torch and tm_tensor dialects there is a ScaledDotProductAttentionOp, but

import torch
from torch import nn

import torch_mlir.fx as fx


class Basic(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V):
        return torch.ops.aten.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=1.0,
            enable_gqa=False)


Q = torch.randn([15, 64])
K = torch.randn([15, 64])
V = torch.randn([15, 64])

module = fx.export_and_import(
    Basic(),
    Q, K, V,
    # The key is to specify `output_type=fx.OutputType.TORCH`
    output_type=fx.OutputType.TORCH
)
print(module)

when I try to export and lower the module above, the op is decomposed into several smaller ops:

module {
  func.func @main(%arg0: !torch.vtensor<[15,64],f32>, %arg1: !torch.vtensor<[15,64],f32>, %arg2: !torch.vtensor<[15,64],f32>) -> !torch.vtensor<[15,64],f32> {
    %false = torch.constant.bool false
    %int6 = torch.constant.int 6
    %0 = torch.vtensor.literal(dense<0.000000e+00> : tensor<f64>) : !torch.vtensor<[],f64>
    %float-Inf = torch.constant.float 0xFFF0000000000000
    %none = torch.constant.none
    %true = torch.constant.bool true
    %int-1 = torch.constant.int -1
    %int-2 = torch.constant.int -2
    %float1.000000e00 = torch.constant.float 1.000000e+00
    %1 = torch.aten.mul.Scalar %arg0, %float1.000000e00 : !torch.vtensor<[15,64],f32>, !torch.float -> !torch.vtensor<[15,64],f32>
    %2 = torch.aten.transpose.int %arg1, %int-2, %int-1 : !torch.vtensor<[15,64],f32>, !torch.int, !torch.int -> !torch.vtensor<[64,15],f32>
    %3 = torch.aten.mul.Scalar %2, %float1.000000e00 : !torch.vtensor<[64,15],f32>, !torch.float -> !torch.vtensor<[64,15],f32>
    %4 = torch.aten.mm %1, %3 : !torch.vtensor<[15,64],f32>, !torch.vtensor<[64,15],f32> -> !torch.vtensor<[15,15],f32>
    %values, %indices = torch.aten.max.dim %4, %int-1, %true : !torch.vtensor<[15,15],f32>, !torch.int, !torch.bool -> !torch.vtensor<[15,1],f32>, !torch.vtensor<[15,1],si64>
    %5 = torch.aten.sub.Tensor %4, %values, %float1.000000e00 : !torch.vtensor<[15,15],f32>, !torch.vtensor<[15,1],f32>, !torch.float -> !torch.vtensor<[15,15],f32>
    %6 = torch.aten.exp %5 : !torch.vtensor<[15,15],f32> -> !torch.vtensor<[15,15],f32>
    %7 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list<int>
    %8 = torch.aten.sum.dim_IntList %6, %7, %true, %none : !torch.vtensor<[15,15],f32>, !torch.list<int>, !torch.bool, !torch.none -> !torch.vtensor<[15,1],f32>
    %9 = torch.aten.div.Tensor %6, %8 : !torch.vtensor<[15,15],f32>, !torch.vtensor<[15,1],f32> -> !torch.vtensor<[15,15],f32>
    %10 = torch.aten.eq.Scalar %4, %float-Inf : !torch.vtensor<[15,15],f32>, !torch.float -> !torch.vtensor<[15,15],i1>
    %11 = torch.aten.all.dim %10, %int-1, %true : !torch.vtensor<[15,15],i1>, !torch.int, !torch.bool -> !torch.vtensor<[15,1],i1>
    %12 = torch.aten.to.dtype %0, %int6, %false, %false, %none : !torch.vtensor<[],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f32>
    %13 = torch.aten.where.self %11, %12, %9 : !torch.vtensor<[15,1],i1>, !torch.vtensor<[],f32>, !torch.vtensor<[15,15],f32> -> !torch.vtensor<[15,15],f32>
    %14 = torch.aten.mm %13, %arg2 : !torch.vtensor<[15,15],f32>, !torch.vtensor<[15,64],f32> -> !torch.vtensor<[15,64],f32>
    return %14 : !torch.vtensor<[15,64],f32>
  }
}
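For reference, the decomposition appears to come from the default table that fx.export_and_import falls back to when no decomposition_table is given. A quick way to check is sketched below (assuming the table is built by torch_mlir.extras.fx_decomp_util.get_decomposition_table and maps OpOverloads to decomposition functions):

import torch
from torch_mlir.extras.fx_decomp_util import get_decomposition_table

# Default table applied via run_decompositions() during export; if an SDPA
# overload is listed here, the op is rewritten into mul/mm/softmax pieces
# before it reaches the torch dialect.
table = get_decomposition_table()
print([op for op in table if "scaled_dot_product" in str(op)])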
@ita9naiwa (Author)

This seems related to issue #3953 and PR #3956.

@ita9naiwa (Author)

Passing decomposition_table=[] fixes the issue; the op is preserved:

module = fx.export_and_import(
    Basic(),
    Q, K, V,
    # The key is to pass an empty `decomposition_table`
    decomposition_table=[],
    output_type=fx.OutputType.TORCH
)

which produces:
module {
  func.func @main(%arg0: !torch.vtensor<[15,64],f32>, %arg1: !torch.vtensor<[15,64],f32>, %arg2: !torch.vtensor<[15,64],f32>) -> !torch.vtensor<[15,64],f32> {
    %none = torch.constant.none
    %float0.000000e00 = torch.constant.float 0.000000e+00
    %false = torch.constant.bool false
    %float1.000000e00 = torch.constant.float 1.000000e+00
    %0 = torch.aten.scaled_dot_product_attention %arg0, %arg1, %arg2, %none, %float0.000000e00, %false, %float1.000000e00, %false : !torch.vtensor<[15,64],f32>, !torch.vtensor<[15,64],f32>, !torch.vtensor<[15,64],f32>, !torch.none, !torch.float, !torch.bool, !torch.float, !torch.bool -> !torch.vtensor<[15,64],f32>
    return %0 : !torch.vtensor<[15,64],f32>
  }
}
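Note that an empty table disables all decompositions, not just SDPA. If the other default decompositions should still run, one option is to filter SDPA out of the default table and pass the result in; a sketch follows (again assuming the default table comes from torch_mlir.extras.fx_decomp_util.get_decomposition_table and is a dict of OpOverload -> function):

from torch_mlir.extras.fx_decomp_util import get_decomposition_table
import torch_mlir.fx as fx

# Keep every default decomposition except the scaled_dot_product_attention
# overloads, so the op survives export as torch.aten.scaled_dot_product_attention.
table = {
    op: fn
    for op, fn in get_decomposition_table().items()
    if "scaled_dot_product" not in str(op)
}

module = fx.export_and_import(
    Basic(),
    Q, K, V,
    decomposition_table=table,
    output_type=fx.OutputType.TORCH,
)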
