Update decomp list and shark trainer for DLRM

commit 9681d494eb
parent ede6bf83e2
Author: Vivek Khandelwal
Date: 2023-09-06 15:33:53 +00:00

2 changed files with 10 additions and 3 deletions


@@ -581,6 +581,8 @@ def import_with_fx(
         torch.ops.aten.masked_fill.Tensor,
         torch.ops.aten.masked_fill.Scalar,
         torch.ops.aten._scaled_dot_product_flash_attention.default,
+        torch.ops.aten.index_add,
+        torch.ops.aten.index_add_,
     ]
     if precision in ["int4", "int8"]:
         from brevitas_examples.llm.llm_quant.export import (
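
Context for the first hunk: the list above is a decomposition table handed to torch's tracer, so ops without a direct lowering get rewritten into more primitive aten ops before MLIR import. A minimal standalone sketch of that mechanism, assuming nothing about SHARK beyond the two ops added here (`fn` and the tensor shapes are illustrative):

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

# Illustrative function exercising index_add, one of the ops added above.
def fn(x, index, source):
    return torch.index_add(x, 0, index, source)

# Collect the registered decompositions for the new ops and trace with
# them; index_add should then appear expanded in the traced graph.
decomp_table = get_decompositions(
    [torch.ops.aten.index_add, torch.ops.aten.index_add_]
)
traced = make_fx(fn, decomposition_table=decomp_table)(
    torch.zeros(5, 3),
    torch.tensor([0, 2, 4]),
    torch.ones(3, 3),
)
print(traced.graph)
```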
@@ -681,5 +683,5 @@ def import_with_fx(
         )
         return mlir_module, func_name
-    mlir_module, func_name = mlir_importer.import_mlir()
+    mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type)
     return mlir_module, func_name
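
The second hunk fixes a plumbing bug: the fallback path called `import_mlir()` with no arguments, so the caller's requested dialect was silently dropped and the importer always produced its default. A self-contained toy reproduction of that pattern (all names below are illustrative stand-ins, not SHARK's real signatures):

```python
# Toy inner importer with a default output type, standing in for
# SharkImporter.import_mlir.
def import_mlir(mlir_type="linalg"):
    return f"<module in {mlir_type}>"

def import_with_fx_sketch(mlir_type="linalg"):
    # Before the fix: `return import_mlir()` -- the keyword accepted by
    # the wrapper never reached the inner call, so the default always won.
    return import_mlir(mlir_type=mlir_type)

print(import_with_fx_sketch(mlir_type="torch"))  # <module in torch>
```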


@@ -69,7 +69,7 @@ class SharkTrainer:
         self.frontend = frontend

     # Training function is needed in the case of torch_fn.
-    def compile(self, training_fn=None, extra_args=[]):
+    def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]):
         if self.frontend in ["torch", "pytorch"]:
             packed_inputs = (
                 dict(self.model.named_parameters()),
@@ -77,7 +77,12 @@ class SharkTrainer:
                 tuple(self.input),
             )
             mlir_module, func_name = import_with_fx(
-                training_fn, packed_inputs, False, [], training=True
+                training_fn,
+                packed_inputs,
+                False,
+                [],
+                training=True,
+                mlir_type=mlir_type,
             )
             self.shark_runner = SharkRunner(
                 mlir_module,
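
Taken together, the trainer changes thread the new keyword from `SharkTrainer.compile` through `import_with_fx` down to the importer. A hedged usage sketch: the constructor arguments are assumptions inferred from the attributes this diff references (`self.model`, `self.input`), and the model, inputs, and training function are placeholders, not part of this commit:

```python
import torch
from shark.shark_trainer import SharkTrainer

# Placeholder model and inputs; the constructor signature is an assumption.
model = torch.nn.Linear(4, 2)
sample_inputs = (torch.randn(1, 4),)

def train_step(params, buffers, args):
    # Placeholder training function; the signature import_with_fx expects
    # is not visible in this diff.
    ...

trainer = SharkTrainer(model, sample_inputs)
# New in this commit: choose the MLIR dialect produced during import.
trainer.compile(training_fn=train_step, mlir_type="linalg")
```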