import io
from typing import List

import torch
import torch_mlir
from torch._decomp import get_decompositions
from torch._functorch.compile_utils import strip_overloads
from torch.func import functionalize
from torch.fx.experimental.proxy_tensor import make_fx

from shark.shark_inference import SharkInference


# TODO: Control decompositions.
def default_decompositions():
    """Decomposition table applied when retracing the FX graph in
    `SharkBackend.compile`; it expands these aten ops into primitives
    that torch-mlir can lower."""
    return get_decompositions(
        [
            torch.ops.aten.embedding_dense_backward,
            torch.ops.aten.native_layer_norm_backward,
            torch.ops.aten.slice_backward,
            torch.ops.aten.select_backward,
            torch.ops.aten.norm.ScalarOpt_dim,
            torch.ops.aten.native_group_norm,
            torch.ops.aten.upsample_bilinear2d.vec,
            torch.ops.aten.split.Tensor,
            torch.ops.aten.split_with_sizes,
            torch.ops.aten.native_layer_norm,
            torch.ops.aten.masked_fill.Tensor,
            torch.ops.aten.masked_fill.Scalar,
        ]
    )
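
# A minimal sketch of how the table above is consumed (the toy function `f`
# is hypothetical; see `SharkBackend.compile` below for the real call site):
#
#     def f(x):
#         return torch.nn.functional.layer_norm(x, x.shape[-1:])
#
#     gm = make_fx(f, decomposition_table=default_decompositions())(
#         torch.randn(2, 4)
#     )
#     # aten.native_layer_norm is now expanded into primitive ops in gm.graph.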


def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
    """Drop `None` entries from the graph's output tuple.

    Returns the (sorted) indices at which `None`s were removed, so that
    `SharkBackend.__call__` can re-insert them into the result.
    """
    removed_indexes = []
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, (list, tuple)):
                node_arg = list(node_arg)
                node_args_len = len(node_arg)
                # Walk backwards so popping does not shift pending indices.
                for i in range(node_args_len):
                    curr_index = node_args_len - (i + 1)
                    if node_arg[curr_index] is None:
                        removed_indexes.append(curr_index)
                        node_arg.pop(curr_index)
                node.args = (tuple(node_arg),)
                break

    if len(removed_indexes) > 0:
        fx_g.graph.lint()
        fx_g.graph.eliminate_dead_code()
        fx_g.recompile()
    removed_indexes.sort()
    return removed_indexes
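
# Sketch of the effect (hypothetical graph; `None`s typically appear as
# gradients of non-differentiable inputs in backward graphs):
#
#     before: output((t0, None, t1))   ->   after: output((t0, t1))
#
# The returned indices would be [1], so a caller can re-insert None there.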


def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
    """Return True if the graph's output is an empty tuple."""
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                return len(node_arg) == 0
    return False


def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
    """
    Replace tuple with tuple element in functions that return one-element tuples.

    Returns True if an unwrapping took place, and False otherwise.
    """
    unwrapped_tuple = False
    for node in fx_g.graph.nodes:
        if node.op == "output":
            assert (
                len(node.args) == 1
            ), "Output node must have a single argument"
            node_arg = node.args[0]
            if isinstance(node_arg, tuple):
                if len(node_arg) == 1:
                    node.args = (node_arg[0],)
                    unwrapped_tuple = True
                    break

    if unwrapped_tuple:
        fx_g.graph.lint()
        fx_g.recompile()
    return unwrapped_tuple
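
# Sketch of the effect (hypothetical graph):
#
#     before: output((t0,))   ->   after: output(t0)
#
# TorchScript and torch-mlir then see a single tensor return; `SharkBackend`
# records `was_unwrapped` so `__call__` can re-wrap the result.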


class SharkBackend:
    """Compiles a torch.fx GraphModule through torch-mlir into a
    SharkInference module and makes it callable on torch tensors."""

    def __init__(
        self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
    ):
        self.fx_g = fx_g
        self.inputs = inputs
        self.shark_module = None
        self.device: str = options.get("device", "cpu")
        self.was_unwrapped: bool = False
        self.none_indices: list = []
        self._modify_fx_g()
        self.compile()

    def _modify_fx_g(self):
        # Normalize the output node so torch-mlir sees no Nones and no
        # single-element tuples; record what changed for __call__.
        self.none_indices = _remove_nones(self.fx_g)
        self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)

    def compile(self):
        # Retrace through functionalization and the decomposition table so
        # the graph contains only ops torch-mlir can lower.
        gm = make_fx(
            functionalize(self.fx_g),
            decomposition_table=default_decompositions(),
        )(*self.inputs)
        gm.graph.set_codegen(torch.fx.graph.CodeGen())
        gm.recompile()
        strip_overloads(gm)
        ts_g = torch.jit.script(gm)
        mlir_module = torch_mlir.compile(
            ts_g, self.inputs, output_type="linalg-on-tensors"
        )
        # Serialize to MLIR bytecode for SharkInference (imported at the top
        # of this module).
        bytecode_stream = io.BytesIO()
        mlir_module.operation.write_bytecode(bytecode_stream)
        bytecode = bytecode_stream.getvalue()

        shark_module = SharkInference(
            mlir_module=bytecode,
            device=self.device,
            mlir_dialect="tm_tensor",
        )
        shark_module.compile(extra_args=[])
        self.shark_module = shark_module

    def __call__(self, *inputs):
        # SharkInference consumes host numpy arrays.
        np_inputs = [x.contiguous().detach().cpu().numpy() for x in inputs]
        np_outs = self.shark_module("forward", np_inputs)
        if self.was_unwrapped:
            # Restore the one-element tuple stripped in _modify_fx_g.
            np_outs = [
                np_outs,
            ]

        if not isinstance(np_outs, list):
            return torch.from_numpy(np_outs)

        result = [torch.from_numpy(x) for x in np_outs]
        # Re-insert the Nones that _remove_nones stripped, at their
        # original positions.
        for r_in in self.none_indices:
            result.insert(r_in, None)
        return tuple(result)
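

# Illustrative smoke test (assumptions: a working SHARK/IREE runtime is
# installed, and `f` plus its example inputs below are hypothetical, not
# part of this module's API). Run the file directly to exercise the backend.
if __name__ == "__main__":

    def f(x, y):
        return torch.mm(x, y)

    example_inputs = (torch.randn(4, 4), torch.randn(4, 4))
    fx_graph = make_fx(f)(*example_inputs)
    backend = SharkBackend(fx_graph, example_inputs, options={"device": "cpu"})
    print(backend(*example_inputs))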