Changes incorporating the recent torch_mlir compile api changes.

This commit is contained in:
Prashant Kumar
2022-11-15 09:39:39 +00:00
parent d8a9bee244
commit 7b1f04d121
3 changed files with 59 additions and 48 deletions

View File

@@ -2,10 +2,8 @@ import os
import torch
from shark.shark_inference import SharkInference
from shark.shark_importer import SharkImporter
from torch.fx.experimental.proxy_tensor import make_fx
from stable_args import args
from torch._decomp import get_decompositions
from shark.shark_importer import import_with_fx
if args.import_mlir:
import torch_mlir
@@ -53,51 +51,10 @@ def get_shark_model(tank_url, model_name, extra_args=[]):
return _compile_module(shark_module, model_name, extra_args)
# Converts the torch-module into shark_module.
# Converts the torch-module into a shark_module.
def compile_through_fx(model, inputs, model_name, extra_args=[]):
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
ts_g = torch.jit.trace(fx_g, inputs)
mlir_importer = SharkImporter(
ts_g,
inputs,
frontend="torch",
)
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
mlir_module, func_name = import_with_fx(model, inputs)
shark_module = SharkInference(
mlir_module,

View File

@@ -243,3 +243,58 @@ class SharkImporter:
self.inputs,
golden_out,
)
# Applies fx conversion to the model and imports the mlir.
def import_with_fx(model, inputs, debug=False):
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
# TODO: Control the decompositions.
fx_g = make_fx(
model,
decomposition_table=get_decompositions(
[
torch.ops.aten.embedding_dense_backward,
torch.ops.aten.native_layer_norm_backward,
torch.ops.aten.slice_backward,
torch.ops.aten.select_backward,
torch.ops.aten.norm.ScalarOpt_dim,
torch.ops.aten.native_group_norm,
torch.ops.aten.upsample_bilinear2d.vec,
torch.ops.aten.split.Tensor,
torch.ops.aten.split_with_sizes,
]
),
)(*inputs)
fx_g.graph.set_codegen(torch.fx.graph.CodeGen())
fx_g.recompile()
def strip_overloads(gm):
"""
Modifies the target of graph nodes in :attr:`gm` to strip overloads.
Args:
gm(fx.GraphModule): The input Fx graph module to be modified
"""
for node in gm.graph.nodes:
if isinstance(node.target, torch._ops.OpOverload):
node.target = node.target.overloadpacket
gm.recompile()
strip_overloads(fx_g)
mlir_importer = SharkImporter(
fx_g,
inputs,
frontend="torch",
)
if debug:
(mlir_module, func_name), _, _ = mlir_importer.import_debug()
return mlir_module, func_name
mlir_module, func_name = mlir_importer.import_mlir()
return mlir_module, func_name

View File

@@ -56,9 +56,8 @@ def get_torch_mlir_module(
input: tuple,
dynamic: bool,
jit_trace: bool,
from_torchscript: bool = False,
):
"""Get the MLIR's linalg-on-tensors module from torchscipt module."""
"""Get the MLIR's linalg-on-tensors module from the torchscipt module."""
ignore_traced_shapes = False
if dynamic:
input = create_dynamic_placeholders(input)