mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Add the shark backend for torch.compile API. (#1596)
This commit is contained in:
154
shark/dynamo_backend/utils.py
Normal file
154
shark/dynamo_backend/utils.py
Normal file
@@ -0,0 +1,154 @@
|
||||
import functools
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
from torch.func import functionalize
|
||||
import io
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
|
||||
def default_decompositions():
|
||||
return get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
torch.ops.aten.native_layer_norm,
|
||||
torch.ops.aten.masked_fill.Tensor,
|
||||
torch.ops.aten.masked_fill.Scalar,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]:
|
||||
removed_indexes = []
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, (list, tuple)):
|
||||
node_arg = list(node_arg)
|
||||
node_args_len = len(node_arg)
|
||||
for i in range(node_args_len):
|
||||
curr_index = node_args_len - (i + 1)
|
||||
if node_arg[curr_index] is None:
|
||||
removed_indexes.append(curr_index)
|
||||
node_arg.pop(curr_index)
|
||||
node.args = (tuple(node_arg),)
|
||||
break
|
||||
|
||||
if len(removed_indexes) > 0:
|
||||
fx_g.graph.lint()
|
||||
fx_g.graph.eliminate_dead_code()
|
||||
fx_g.recompile()
|
||||
removed_indexes.sort()
|
||||
return removed_indexes
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
class SharkBackend:
|
||||
def __init__(
|
||||
self, fx_g: torch.fx.GraphModule, inputs: tuple, options: dict
|
||||
):
|
||||
self.fx_g = fx_g
|
||||
self.inputs = inputs
|
||||
self.shark_module = None
|
||||
self.device: str = options.get("device", "cpu")
|
||||
self.was_unwrapped: bool = False
|
||||
self.none_indices: list = []
|
||||
self._modify_fx_g()
|
||||
self.compile()
|
||||
|
||||
def _modify_fx_g(self):
|
||||
self.none_indices = _remove_nones(self.fx_g)
|
||||
self.was_unwrapped = _unwrap_single_tuple_return(self.fx_g)
|
||||
|
||||
def compile(self):
|
||||
gm = make_fx(
|
||||
functionalize(self.fx_g),
|
||||
decomposition_table=default_decompositions(),
|
||||
)(*self.inputs)
|
||||
gm.graph.set_codegen(torch.fx.graph.CodeGen())
|
||||
gm.recompile()
|
||||
strip_overloads(gm)
|
||||
ts_g = torch.jit.script(gm)
|
||||
mlir_module = torch_mlir.compile(
|
||||
ts_g, self.inputs, output_type="linalg-on-tensors"
|
||||
)
|
||||
bytecode_stream = io.BytesIO()
|
||||
mlir_module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=self.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
shark_module.compile(extra_args=[])
|
||||
self.shark_module = shark_module
|
||||
|
||||
def __call__(self, *inputs):
|
||||
np_inputs = [x.detach().cpu().numpy() for x in inputs]
|
||||
np_outs = self.shark_module("forward", np_inputs)
|
||||
if self.was_unwrapped:
|
||||
np_outs = [
|
||||
np_outs,
|
||||
]
|
||||
|
||||
if not isinstance(np_outs, list):
|
||||
res = torch.from_numpy(np_outs)
|
||||
return res
|
||||
|
||||
result = [torch.from_numpy(x) for x in np_outs]
|
||||
for r_in in self.none_indices:
|
||||
result.insert(r_in, None)
|
||||
result = tuple(result)
|
||||
return result
|
||||
@@ -1,11 +0,0 @@
|
||||
1. Install torchdynamo
|
||||
- `git clone https://github.com/pytorch/torchdynamo.git`
|
||||
- `cd torchdynamo`
|
||||
- `python -m pip install -r requirements.txt`
|
||||
- `python setup.py develop`
|
||||
|
||||
2. Install functorch
|
||||
- `python -m pip install -v "git+https://github.com/pytorch/pytorch.git@$(python -c "import torch.version; print(torch.version.git_version)")#subdirectory=functorch"`
|
||||
|
||||
3. Run examples.
|
||||
- `python shark/examples/shark_dynamo/basic_examples.py`
|
||||
@@ -1,163 +0,0 @@
|
||||
import functools
|
||||
import time
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._functorch.compile_utils import strip_overloads
|
||||
from shark.shark_inference import SharkInference
|
||||
from torch._decomp import get_decompositions
|
||||
|
||||
import torch_mlir
|
||||
|
||||
|
||||
# TODO: Control decompositions.
|
||||
def default_decompositions():
|
||||
return get_decompositions(
|
||||
[
|
||||
torch.ops.aten.embedding_dense_backward,
|
||||
torch.ops.aten.native_layer_norm_backward,
|
||||
torch.ops.aten.slice_backward,
|
||||
torch.ops.aten.select_backward,
|
||||
torch.ops.aten.norm.ScalarOpt_dim,
|
||||
torch.ops.aten.native_group_norm,
|
||||
torch.ops.aten.upsample_bilinear2d.vec,
|
||||
torch.ops.aten.split.Tensor,
|
||||
torch.ops.aten.split_with_sizes,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def timeit(*, append_time_to: Optional[List] = None):
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
start_time = time.time_ns()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time_ns()
|
||||
|
||||
if append_time_to is not None:
|
||||
append_time_to.append(end_time - start_time)
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool:
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
return len(node_arg) == 0
|
||||
return False
|
||||
|
||||
|
||||
def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool:
|
||||
"""
|
||||
Replace tuple with tuple element in functions that return one-element tuples.
|
||||
Returns true if an unwrapping took place, and false otherwise.
|
||||
"""
|
||||
unwrapped_tuple = False
|
||||
for node in fx_g.graph.nodes:
|
||||
if node.op == "output":
|
||||
assert (
|
||||
len(node.args) == 1
|
||||
), "Output node must have a single argument"
|
||||
node_arg = node.args[0]
|
||||
if isinstance(node_arg, tuple):
|
||||
if len(node_arg) == 1:
|
||||
node.args = (node_arg[0],)
|
||||
unwrapped_tuple = True
|
||||
break
|
||||
|
||||
if unwrapped_tuple:
|
||||
fx_g.graph.lint()
|
||||
fx_g.recompile()
|
||||
return unwrapped_tuple
|
||||
|
||||
|
||||
def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
|
||||
def compiler(
|
||||
fx_graph: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor],
|
||||
):
|
||||
"""Compile GraphModule using torch-mlir + SHARK."""
|
||||
if verbose:
|
||||
print("Compiling graph...")
|
||||
|
||||
if _returns_nothing(fx_graph):
|
||||
return fx_graph
|
||||
|
||||
was_unwrapped = _unwrap_single_tuple_return(fx_graph)
|
||||
fx_graph = make_fx(
|
||||
fx_graph, decomposition_table=default_decompositions()
|
||||
)(*example_inputs)
|
||||
strip_overloads(fx_graph)
|
||||
|
||||
if verbose:
|
||||
print("torch.fx graph:")
|
||||
print(fx_graph.graph)
|
||||
|
||||
ts_compiler = torch.jit.trace if use_tracing else torch.jit.script
|
||||
ts_graph = ts_compiler(fx_graph, example_inputs)
|
||||
|
||||
if verbose:
|
||||
torch_mlir_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
example_inputs,
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
)
|
||||
print("\n\ntorch-mlir backend contract graph:")
|
||||
print(torch_mlir_module)
|
||||
|
||||
linalg_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
example_inputs,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
)
|
||||
import io
|
||||
|
||||
bytecode_stream = io.BytesIO()
|
||||
linalg_module.operation.write_bytecode(bytecode_stream)
|
||||
mlir_module = bytecode_stream.getvalue()
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module, mlir_dialect="linalg", device=device
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
def forward(*inputs):
|
||||
result = shark_module("forward", inputs)
|
||||
result = tuple() if result is None else result
|
||||
return (result,) if was_unwrapped else result
|
||||
|
||||
return forward
|
||||
|
||||
return compiler
|
||||
|
||||
|
||||
def check_results(compiled_results, eager_results):
|
||||
for compiled_result, eager_result in zip(compiled_results, eager_results):
|
||||
if not torch.allclose(
|
||||
compiled_result.to("cpu"), eager_result.to("cpu"), atol=1e-5
|
||||
):
|
||||
print("Compiled result does not match eager result")
|
||||
return
|
||||
print("Compiled result matches eager result!")
|
||||
|
||||
|
||||
def print_time_stats(times):
|
||||
times_tensor = torch.tensor(times)
|
||||
|
||||
def quantile_ms(q):
|
||||
return torch.quantile(times_tensor.to(float), q).item() / 1e6
|
||||
|
||||
print(f"Median: {quantile_ms(0.5)} ms")
|
||||
print(f"10%ile: {quantile_ms(0.1)} ms")
|
||||
print(f"90%ile: {quantile_ms(0.9)} ms")
|
||||
print(f"Total: {torch.sum(times_tensor) / 1e6} ms")
|
||||
print()
|
||||
Reference in New Issue
Block a user