Incorporate latest changes in the shark_dynamo backend.

This commit is contained in:
Prashant Kumar
2023-02-08 10:56:10 +00:00
parent 3a9cfe113a
commit 3595b4aaff
2 changed files with 9 additions and 4 deletions

View File

@@ -1,6 +1,6 @@
import torchdynamo
import torch
import torch_mlir
import torch._dynamo as torchdynamo
from shark.sharkdynamo.utils import make_shark_compiler

View File

@@ -3,7 +3,7 @@ import time
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch._src.compile_utils import strip_overloads
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
@@ -119,14 +119,19 @@ def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
example_inputs,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
)
import io
bytecode_stream = io.BytesIO()
linalg_module.operation.write_bytecode(bytecode_stream)
mlir_module = bytecode_stream.getvalue()
shark_module = SharkInference(
linalg_module, "forward", mlir_dialect="linalg", device=device
mlir_module, mlir_dialect="linalg", device=device
)
shark_module.compile()
def forward(*inputs):
result = shark_module.forward(inputs)
result = shark_module("forward", inputs)
result = tuple() if result is None else result
return (result,) if was_unwrapped else result