Compare commits

...

2 Commits

Author SHA1 Message Date
Chi_Liu
dedb995af3 Add decompose of aten._scaled_dot_product_flash_attention_for_cpu (#2064)
New decomposition from: https://github.com/pytorch/pytorch/pull/117390
Required by the chatglm model: https://github.com/llvm/torch-mlir/issues/2730
2024-01-15 20:03:17 -08:00
AmosLewis
c199ac78eb Add decompose of aten._scaled_dot_product_flash_attention.default
The new decomposition was recently added in PyTorch.
Here is the PyTorch PR: https://github.com/pytorch/pytorch/pull/117390
This decomposition is required for lowering the chatglm model in torch-mlir.
Here is the issue: https://github.com/llvm/torch-mlir/issues/2730
2024-01-16 03:03:14 +00:00


@@ -686,6 +686,7 @@ def import_with_fx(
         torch.ops.aten._scaled_dot_product_flash_attention.default,
         torch.ops.aten.index_add,
         torch.ops.aten.index_add_,
+        torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
     ]
     if precision in ["int4", "int8"] and not is_gptq:
         from brevitas_examples.llm.llm_quant.export import (
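
For context, below is a minimal sketch of how a list like the one above is typically consumed, assuming a PyTorch build that includes the decomposition from https://github.com/pytorch/pytorch/pull/117390. The exact wiring inside import_with_fx may differ; the attention helper, the ops_to_decompose name, and the tensor shapes are illustrative assumptions, not code from this repository. The idea: torch._decomp.get_decompositions builds a decomposition table, and make_fx applies it while tracing, so the fused CPU flash-attention op is rewritten into primitive aten ops that torch-mlir can lower.

import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Ops to decompose while tracing; the CPU flash-attention variant is the
# entry added by this commit.
ops_to_decompose = [
    torch.ops.aten._scaled_dot_product_flash_attention.default,
    torch.ops.aten._scaled_dot_product_flash_attention_for_cpu,
]

def attention(q, k, v):
    # Call the fused CPU op directly; it returns (output, logsumexp).
    return torch.ops.aten._scaled_dot_product_flash_attention_for_cpu(q, k, v)[0]

# Hypothetical shapes: (batch, heads, seq_len, head_dim).
q = k = v = torch.randn(1, 2, 16, 8)
traced = make_fx(
    attention,
    decomposition_table=get_decompositions(ops_to_decompose),
)(q, k, v)
print(traced.graph)  # fused op replaced by matmul/softmax-style aten ops

Printing traced.graph should show the fused kernel replaced by simpler aten ops, which is what allows the chatglm model to be lowered through torch-mlir.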