From 9681d494eb8dd68685117f3595bb1f032ee0eb4e Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 6 Sep 2023 15:33:53 +0000 Subject: [PATCH] Update decomp list and shark trainer for DLRM --- shark/shark_importer.py | 4 +++- shark/shark_trainer.py | 9 +++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/shark/shark_importer.py b/shark/shark_importer.py index 4109fb0a..6adcc6cd 100644 --- a/shark/shark_importer.py +++ b/shark/shark_importer.py @@ -581,6 +581,8 @@ def import_with_fx( torch.ops.aten.masked_fill.Tensor, torch.ops.aten.masked_fill.Scalar, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten.index_add, + torch.ops.aten.index_add_, ] if precision in ["int4", "int8"]: from brevitas_examples.llm.llm_quant.export import ( @@ -681,5 +683,5 @@ def import_with_fx( ) return mlir_module, func_name - mlir_module, func_name = mlir_importer.import_mlir() + mlir_module, func_name = mlir_importer.import_mlir(mlir_type=mlir_type) return mlir_module, func_name diff --git a/shark/shark_trainer.py b/shark/shark_trainer.py index abf242b7..36916f24 100644 --- a/shark/shark_trainer.py +++ b/shark/shark_trainer.py @@ -69,7 +69,7 @@ class SharkTrainer: self.frontend = frontend # Training function is needed in the case of torch_fn. - def compile(self, training_fn=None, extra_args=[]): + def compile(self, training_fn=None, mlir_type="linalg", extra_args=[]): if self.frontend in ["torch", "pytorch"]: packed_inputs = ( dict(self.model.named_parameters()), @@ -77,7 +77,12 @@ class SharkTrainer: tuple(self.input), ) mlir_module, func_name = import_with_fx( - training_fn, packed_inputs, False, [], training=True + training_fn, + packed_inputs, + False, + [], + training=True, + mlir_type=mlir_type, ) self.shark_runner = SharkRunner( mlir_module,