mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Update torch_mlir.compile API.
torch_mlir.compile API is updated and verified by compiling all the torch models via generate_sharktank script.
This commit is contained in:
@@ -12,26 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
import io
|
||||
import pickle
|
||||
|
||||
from torch_mlir.dialects.torch.importer.jit_ir import (
|
||||
ClassAnnotator,
|
||||
ModuleBuilder,
|
||||
)
|
||||
from torch_mlir_e2e_test.torchscript.serialization import (
|
||||
extract_serializable_annotations,
|
||||
apply_serializable_annotations,
|
||||
SerializableTest,
|
||||
)
|
||||
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
from torch_mlir.passmanager import PassManager
|
||||
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
||||
from torch_mlir.ir import StringAttr
|
||||
import torch_mlir
|
||||
from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend
|
||||
|
||||
|
||||
def get_module_name_for_asm_dump(module):
|
||||
@@ -45,22 +28,6 @@ def get_module_name_for_asm_dump(module):
|
||||
).value
|
||||
|
||||
|
||||
def get_input_annotations(inputs: tuple, dynamic: bool) -> list:
|
||||
"""TODO: Include necessary documentation"""
|
||||
|
||||
annotations_list = [None]
|
||||
for i in inputs:
|
||||
temp_list = []
|
||||
if dynamic:
|
||||
temp_list.append([-1 for i in range(len(i.shape))])
|
||||
else:
|
||||
temp_list.append(list(i.shape))
|
||||
temp_list.append(i.dtype)
|
||||
temp_list.append(True)
|
||||
annotations_list.append(tuple(temp_list))
|
||||
return annotations_list
|
||||
|
||||
|
||||
def run_on_refbackend(torch_module, inputs):
|
||||
backend = refbackend.RefBackendLinalgOnTensorsBackend()
|
||||
compiled = backend.compile(torch_module)
|
||||
@@ -69,42 +36,16 @@ def run_on_refbackend(torch_module, inputs):
|
||||
return jit_module.forward(np_inputs[0])
|
||||
|
||||
|
||||
def shark_jit_trace(
|
||||
module, input: tuple, dynamic: bool, tracing_required: bool
|
||||
):
|
||||
"""TODO: Include necessary documentation."""
|
||||
|
||||
if not tracing_required:
|
||||
return torch.jit.script(module)
|
||||
|
||||
traced_module = torch.jit.trace_module(module, {"forward": input})
|
||||
actual_script = traced_module._actual_script_module
|
||||
export(actual_script.forward)
|
||||
annotate_args_decorator = annotate_args(
|
||||
get_input_annotations(input, dynamic)
|
||||
)
|
||||
annotate_args_decorator(actual_script.forward)
|
||||
module = torch.jit.script(actual_script)
|
||||
|
||||
# TODO: remove saved annotations.pickle
|
||||
torchscript_module_bytes = module.save_to_buffer(
|
||||
{
|
||||
"annotations.pkl": pickle.dumps(
|
||||
extract_serializable_annotations(module)
|
||||
)
|
||||
}
|
||||
)
|
||||
serializable_test = SerializableTest(
|
||||
unique_name="", program=torchscript_module_bytes, trace=None
|
||||
)
|
||||
_extra_files = {"annotations.pkl": ""}
|
||||
module = torch.jit.load(
|
||||
io.BytesIO(serializable_test.program), _extra_files=_extra_files
|
||||
)
|
||||
# Load the pickled annotations.
|
||||
annotations = pickle.loads(_extra_files["annotations.pkl"])
|
||||
apply_serializable_annotations(module, annotations)
|
||||
return module
|
||||
# Creates dynamic dims for all dims.
|
||||
# TODO: Pass user specified dynamic dims.
|
||||
def create_dynamic_placeholders(inputs):
|
||||
placeholders = []
|
||||
for inp in inputs:
|
||||
placeholder = torch_mlir.TensorPlaceholder.like(
|
||||
inp, dynamic_axes=[i for i in range(len(inp.shape))]
|
||||
)
|
||||
placeholders.append(placeholder)
|
||||
return tuple(placeholders)
|
||||
|
||||
|
||||
def get_torch_mlir_module(
|
||||
@@ -114,39 +55,18 @@ def get_torch_mlir_module(
|
||||
jit_trace: bool,
|
||||
from_torchscript: bool = False,
|
||||
):
|
||||
"""TODO: Include necessary documentation."""
|
||||
"""Get the MLIR's linalg-on-tensors module from torchscipt module."""
|
||||
ignore_traced_shapes = False
|
||||
if dynamic:
|
||||
input = create_dynamic_placeholders(input)
|
||||
if jit_trace:
|
||||
ignore_traced_shapes = True
|
||||
|
||||
# Static modules compiles well with the torch_mlir.compile API.
|
||||
# We will always jit_trace = True with the API since we always
|
||||
# want to propagate static shapes.
|
||||
if not dynamic:
|
||||
module = torch_mlir.compile(
|
||||
module,
|
||||
input,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=jit_trace,
|
||||
)
|
||||
return module
|
||||
|
||||
# Tracing is not required from the aot_module.
|
||||
if not from_torchscript:
|
||||
module = shark_jit_trace(module, input, dynamic, jit_trace)
|
||||
|
||||
mb = ModuleBuilder()
|
||||
class_annotator = ClassAnnotator()
|
||||
class_annotator.exportNone(module._c._type())
|
||||
class_annotator.exportPath(module._c._type(), ["forward"])
|
||||
class_annotator.annotateArgs(
|
||||
module._c._type(),
|
||||
["forward"],
|
||||
get_input_annotations(input, dynamic),
|
||||
module = torch_mlir.compile(
|
||||
module,
|
||||
input,
|
||||
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=jit_trace,
|
||||
ignore_traced_shapes=ignore_traced_shapes,
|
||||
)
|
||||
mb.import_module(module._c, class_annotator)
|
||||
|
||||
with mb.module.context:
|
||||
pm = PassManager.parse(
|
||||
"torchscript-module-to-torch-backend-pipeline,torch-backend-to-linalg-on-tensors-backend-pipeline"
|
||||
)
|
||||
pm.run(mb.module)
|
||||
|
||||
return mb.module
|
||||
return module
|
||||
|
||||
Reference in New Issue
Block a user