Files
SHARK-Studio/amdshark/torch_mlir_lockstep_tensor.py
pdhirajkumarprasad fe03539901 Migration to AMDShark (#2182)
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
2025-11-20 12:52:07 +05:30

221 lines
9.2 KiB
Python

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
import contextlib
import re
import traceback
import warnings
from typing import Any
import numpy as np
import torch
from torch.utils._pytree import tree_map
from torch_mlir.eager_mode.ir_building import build_mlir_module
from torch_mlir.eager_mode.torch_mlir_dispatch import (
UnsupportedByTorchMlirEagerMode,
normalize_args_kwargs,
check_get_aliased_arg,
)
from torch_mlir.eager_mode import EAGER_MODE_DEBUG
from torch_mlir.eager_mode.torch_mlir_tensor import (
TorchMLIRTensor,
check_requires_grad,
make_wrapper_subclass_from_torch_tensor,
make_bare_wrapper_subclass,
UNSUPPORTED_OPS,
no_dispatch,
)
from torch_mlir.eager_mode import torch_mlir_tensor
from amdshark.iree_eager_backend import EagerModeIREELinalgOnTensorsBackend
# Module-wide IREE eager backend instance, constructed with the "cpu" argument.
# It is created at import time and shared by every TorchMLIRLockstepTensor.
backend = EagerModeIREELinalgOnTensorsBackend("cpu")
# Rebind the module-level `backend` attribute of torch_mlir_tensor to the same
# instance (presumably so the base TorchMLIRTensor machinery uses it too --
# TODO confirm against torch_mlir_tensor).
torch_mlir_tensor.backend = backend
# Relative and absolute tolerances handed to np.testing.assert_allclose when
# the lockstep dispatch compares the backend result with the PyTorch result.
rtol = 1e-04
atol = 1e-05
class TorchMLIRLockstepTensor(TorchMLIRTensor):
    """Tensor subclass that verifies every dispatched op against PyTorch eager.

    This class overrides the dispatching for TorchMLIRTensor to allow for an
    op-by-op numerical comparison between PyTorch and the Torch-MLIR -> IREE
    backend compilation pipeline. This only supports the IREE backend and
    focuses on op-by-op level verification.

    TODO: Extend this to do a cumulative trace with summary statistics at the
    end. Possibly requires a wrapper environment to store full trace info.
    """

    @staticmethod
    def _lockstep_to_device(tensor):
        """Copy a torch tensor to the backend device via a C-contiguous array.

        This is a hack to handle non-contiguous data through the IREE backend:
        the tensor is viewed as a numpy array and, if that array is not
        C-contiguous, a contiguous copy is made before the transfer.
        TODO: Find a better way to handle this.
        """
        array = tensor.numpy()
        if not array.flags["C_CONTIGUOUS"]:
            array = np.ascontiguousarray(array, dtype=array.dtype)
        return backend.transfer_from_torch_to_device(torch.from_numpy(array))

    def __new__(cls, elem, **kwargs):
        """Wrap ``elem`` and attach its device-side counterpart as ``.elem``.

        Args:
            elem: A backend device tensor (only when the keyword
                ``constructing_from_device_tensor`` is truthy), a
                ``torch.nn.Parameter``, a ``torch.Tensor``, or a Python
                ``int``/``float``/``bool``.
            **kwargs: Forwarded to ``backend.get_torch_metadata`` or to
                ``make_wrapper_subclass_from_torch_tensor``.

        Returns:
            The wrapper tensor, or ``elem`` itself for Python scalars.

        Raises:
            ValueError: If ``elem`` is none of the supported types.
        """
        if kwargs.get("constructing_from_device_tensor", False):
            # elem already lives on the device: build a bare wrapper from the
            # metadata the backend reports for it.
            tensor_meta_data = backend.get_torch_metadata(elem, kwargs)
            r = make_bare_wrapper_subclass(
                cls=cls,
                size=tensor_meta_data.size,
                strides=tensor_meta_data.strides,
                storage_offset=tensor_meta_data.storage_offset,
                dtype=tensor_meta_data.dtype,
                layout=tensor_meta_data.layout,
                device=tensor_meta_data.device,
                requires_grad=tensor_meta_data.requires_grad,
            )
            r.elem = elem
        elif isinstance(elem, torch.nn.Parameter):
            r = make_wrapper_subclass_from_torch_tensor(
                cls, elem.data, **kwargs
            )
            r.elem = cls._lockstep_to_device(elem.detach().data)
        elif isinstance(elem, torch.Tensor):
            r = make_wrapper_subclass_from_torch_tensor(cls, elem, **kwargs)
            r.elem = cls._lockstep_to_device(elem)
        # This branch handles the case when a python scalar is passed to some
        # op or is returned from some aten op, such as _local_scalar_dense.
        elif isinstance(elem, (int, float, bool)):
            return elem
        else:
            raise ValueError(f"Unknown element type: {type(elem)}")
        return r

    def __repr__(self):
        """Return a string showing the device element, backend and grad_fn."""
        if self.grad_fn:
            return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__}, grad_fn={self.grad_fn})"
        else:
            return f"TorchMLIRLockstepTensor({self.elem}, backend={backend.__class__.__name__})"

    @classmethod
    def _lockstep_run_native(cls, func, op_name, args, kwargs):
        """Run ``func`` through PyTorch eager on unwrapped arguments.

        Ops whose name contains ``_reshape_alias`` are executed as
        ``torch.ops.aten.view`` on the first two positional arguments.

        Returns:
            A tuple ``(result, unwrapped_args, unwrapped_kwargs)``.
        """
        with no_dispatch():
            unwrapped_args = tree_map(cls.unwrap, args)
            unwrapped_kwargs = tree_map(cls.unwrap, kwargs)
            if "_reshape_alias" in op_name:
                result = torch.ops.aten.view(
                    unwrapped_args[0], unwrapped_args[1]
                )
            else:
                result = func(*unwrapped_args, **unwrapped_kwargs)
        return result, unwrapped_args, unwrapped_kwargs

    @classmethod
    def __torch_dispatch__(cls, func, _types, args=(), kwargs=None):
        """Dispatch ``func`` through Torch-MLIR/IREE and verify it numerically.

        This does essentially the same dispatch as TorchMLIRTensor but operates
        as if debug mode is enabled. The numeric verification happens after the
        Torch-MLIR result is obtained by comparing against the PyTorch eager
        result using ``np.testing.assert_allclose`` with the module-level
        ``rtol``/``atol``.

        A numerical mismatch only emits a warning. Any other failure while
        building, compiling or running the op emits warnings and the op is
        re-run through PyTorch eager instead.

        Args:
            func: The op being dispatched.
            _types: Unused.
            args: Positional arguments of the op.
            kwargs: Keyword arguments of the op; ``None`` is treated as ``{}``.

        Returns:
            The op result wrapped in this class (Python scalars are returned
            unwrapped, see ``__new__``).
        """
        # kwargs defaults to None, which cannot be splatted with ``**``.
        if kwargs is None:
            kwargs = {}
        requires_grad = check_requires_grad(*args, **kwargs)
        # Bound before the try block so the fallback path can always read it,
        # even when the failure happened while resolving the op name.
        op_name = ""
        try:
            with no_dispatch():
                if hasattr(func, "op_name"):
                    op_name = func.op_name
                elif hasattr(func, "__name__"):
                    # Handle builtin_function_or_method.
                    op_name = func.__name__
                else:
                    raise RuntimeError(f"op {func} has no name")

                if UNSUPPORTED_OPS.match(op_name):
                    raise UnsupportedByTorchMlirEagerMode(op_name)

                if not hasattr(func, "_schema"):
                    raise RuntimeError(f"op {func} has no schema.")

                normalized_kwargs = normalize_args_kwargs(func, args, kwargs)

                if "layout" in normalized_kwargs and normalized_kwargs[
                    "layout"
                ] not in {0, None}:
                    raise UnsupportedByTorchMlirEagerMode(
                        f"{normalized_kwargs['layout']} layout not supported."
                    )
                if "memory_format" in normalized_kwargs and normalized_kwargs[
                    "memory_format"
                ] not in {0, None}:
                    raise UnsupportedByTorchMlirEagerMode(
                        f"{normalized_kwargs['memory_format']} memory format not supported."
                    )

                eager_module = build_mlir_module(func, normalized_kwargs)
                device_tensor_args = [
                    kwarg.elem
                    for kwarg in normalized_kwargs.values()
                    if isinstance(kwarg, cls)
                ]
                # Explicit raise rather than assert so the check survives -O;
                # it is handled by the fallback below like any other failure.
                if len(eager_module.body.operations[0].arguments) != len(
                    device_tensor_args
                ):
                    raise RuntimeError(
                        "Number of parameters and number of arguments differs."
                    )
                op_mlir_backend_callable = backend.compile(eager_module)
                out = op_mlir_backend_callable(*device_tensor_args)
                out = tree_map(
                    lambda x: cls(
                        x,
                        requires_grad=requires_grad,
                        constructing_from_device_tensor=True,
                    ),
                    out,
                )

            # Numeric verification; Value for comparison comes from PyTorch eager
            (
                native_out,
                unwrapped_args,
                unwrapped_kwargs,
            ) = cls._lockstep_run_native(func, op_name, args, kwargs)

            # NOTE: `.elem` requires a single tensor result. For anything else
            # the attribute access raises and the op takes the fallback below.
            native_out = tree_map(
                lambda x: cls(x, requires_grad=requires_grad), native_out
            ).elem
            tmp_out = out.elem

            try:
                np.testing.assert_allclose(
                    native_out.to_host(),
                    tmp_out.to_host(),
                    rtol=rtol,
                    atol=atol,
                )
            except Exception as e:
                # Report shapes instead of full tensors to keep the warning
                # readable.
                shaped_args = [
                    arg.shape if torch.is_tensor(arg) else arg
                    for arg in unwrapped_args
                ]
                # Iterate over items: looping over the dict itself would only
                # yield the keyword names and never the values.
                shaped_kwargs = {
                    name: value.shape if torch.is_tensor(value) else value
                    for name, value in unwrapped_kwargs.items()
                }
                warnings.warn(
                    f"Lockstep accuracy verification failed with error: *{str(e)}*; "
                    f"Dispatched function name: *{str(func)}*; "
                    f"Dispatched function args: *{str(shaped_args)}*; "
                    f"Dispatched function kwargs: *{str(shaped_kwargs)}*; "
                )
        except Exception as e:
            warnings.warn(traceback.format_exc())
            if isinstance(e, UnsupportedByTorchMlirEagerMode):
                warnings.warn(
                    f"Couldn't use TorchMLIR eager because current incompatibility: *{str(e)}*; running through PyTorch eager."
                )
            else:
                warnings.warn(
                    f"Couldn't use TorchMLIR eager because of error: *{str(e)}*; "
                    f"Running through PyTorch eager"
                )

            # Fallback: produce the result with plain PyTorch eager.
            out, _, _ = cls._lockstep_run_native(func, op_name, args, kwargs)
            out = tree_map(lambda x: cls(x, requires_grad=requires_grad), out)

        maybe_aliased_arg_name = check_get_aliased_arg(func)
        if maybe_aliased_arg_name is not None:
            warnings.warn(
                "Found aliased arg, but didn't copy tensor contents. This could lead to incorrect results for E2E model execution but doesn't affect the validity of the lockstep op verification."
            )
            # TODO: Find a way to handle argument aliasing for IREE backend
            # backend.copy_into(normalized_kwargs[maybe_aliased_arg_name].elem, out.elem)

        return out