Hi, in the torch and tm_tensor dialects there is a ScaledDotProductAttentionOp, but when I export the following module:
import torch
from torch import nn
import torch_mlir.fx as fx


class Basic(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V):
        return torch.ops.aten.scaled_dot_product_attention(
            Q, K, V,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=False,
            scale=1.0,
            enable_gqa=False)


Q = torch.randn([15, 64])
K = torch.randn([15, 64])
V = torch.randn([15, 64])

module = fx.export_and_import(
    Basic(),
    Q, K, V,
    # The key is to specify `output_type=fx.OutputType.TORCH`
    output_type=fx.OutputType.TORCH,
)
print(module)
the lowering decomposes it into several ops instead of producing that single op.
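A quick way to see whether the op survived is to search the printed Torch-dialect IR for it. This is a hypothetical check, not part of the original report; it assumes the imported op keeps the name torch.aten.scaled_dot_product_attention:

# Hypothetical check (assumption): prints True only if the attention op was
# imported as a single op rather than decomposed.
print("torch.aten.scaled_dot_product_attention" in str(module))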
ita9naiwa changed the title to "torch.ops.aten.scaled_dot_product_attention is decomposed while compilation?" on Jan 20, 2025
Passing decomposition_table=[] fixes the issue:
module = fx.export_and_import(
    Basic(),
    Q, K, V,
    # The key is to pass an empty `decomposition_table`
    decomposition_table=[],
    output_type=fx.OutputType.TORCH,
)