In [1]:
# standard imports
import torch
from amdshark.iree_utils import get_iree_compiled_module

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# TorchDynamo imports. They are guarded so that a missing installation
# produces an install hint instead of a raw ModuleNotFoundError traceback.
# NOTE(review): exit() is called after the hint; inside a notebook this
# presumably ends the kernel session rather than just this cell — confirm
# that is the intended behavior.
try:
    import torchdynamo
    from torchdynamo.optimizations.backends import create_backend
    from torchdynamo.optimizations.subgraph import SubGraph
except ModuleNotFoundError:
    print(
        "Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo"
    )
    exit()

# torch-mlir imports for compiling.
# Note: this binds the name `compile` to torch_mlir's function, so Python's
# built-in `compile` is shadowed for every cell that runs after this one.
from torch_mlir import compile, OutputType

[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It builds an FX Graph through bytecode analysis and is designed to mix ordinary Python execution with compiled backends.

In [3]:
def toy_example(*args):
    """Compute ``a / (|a| + 1) * b``, flipping the sign of ``b`` when its sum is negative.

    Args:
        *args: exactly two positional values ``(a, b)``; unpacking raises
            ``ValueError`` for any other count.

    Returns:
        The element-wise product of the scaled ``a`` and the (possibly
        negated) ``b``.
    """
    lhs, rhs = args

    denominator = torch.abs(lhs) + 1
    scaled = lhs / denominator
    # The branch depends on a runtime tensor value, not on shapes or types.
    if rhs.sum() < 0:
        rhs = rhs * -1
    return scaled * rhs

In [4]:
# Backend compiler that lowers an FX graph through MLIR.
def __torch_mlir(fx_graph, *args, **kwargs):
    """Compile an FX GraphModule via torch-mlir and return a callable for it.

    Steps, in order:
      1. Rewrite any output node that returns a one-element tuple so it
         returns the bare element, then lint and recompile the graph.
      2. Script the graph with ``torch.jit.script``.
      3. Lower it with ``compile(..., output_type=OutputType.LINALG_ON_TENSORS)``.
      4. Pass the result to ``get_iree_compiled_module`` with device ``"cuda"``
         and ``func_name="forward"``.

    Args:
        fx_graph: the ``torch.fx.GraphModule`` to compile (mutated in place
            by step 1).
        *args: example inputs, given either as separate positional values or
            as a single list holding all of them.
        **kwargs: accepted but not used.

    Returns:
        A function ``forward(*inputs)`` that forwards its arguments to the
        compiled module.

    Raises:
        AssertionError: if ``fx_graph`` is not a ``GraphModule`` or an output
            node does not have exactly one argument.
    """
    assert isinstance(
        fx_graph, torch.fx.GraphModule
    ), "Model must be an FX GraphModule."

    # Replace tuple with tuple element in graphs that return one-element tuples.
    for fx_node in fx_graph.graph.nodes:
        if fx_node.op != "output":
            continue
        assert (
            len(fx_node.args) == 1
        ), "Output node must have a single argument"
        returned = fx_node.args[0]
        if isinstance(returned, tuple) and len(returned) == 1:
            fx_node.args = (returned[0],)
    fx_graph.graph.lint()
    fx_graph.recompile()

    scripted_graph = torch.jit.script(fx_graph)

    # torchdynamo munges the args differently depending on whether you use
    # the @torchdynamo.optimize decorator or the context manager, so accept
    # both "inputs spread out" and "one list of inputs".
    # (*args always arrives as a tuple, hence the unconditional list().)
    example_inputs = list(args)
    if len(example_inputs) == 1 and isinstance(example_inputs[0], list):
        example_inputs = example_inputs[0]

    linalg_module = compile(
        scripted_graph,
        example_inputs,
        output_type=OutputType.LINALG_ON_TENSORS,
    )
    compiled_fn, _ = get_iree_compiled_module(
        linalg_module, "cuda", func_name="forward"
    )

    def forward(*inputs):
        return compiled_fn(*inputs)

    return forward

The simplest way to use TorchDynamo is with the `torchdynamo.optimize` context manager:

In [5]:
# Call toy_example ten times on fresh random length-10 inputs while the
# torchdynamo.optimize context manager is active, with __torch_mlir passed
# as the backend compiler. Each result is printed.
with torchdynamo.optimize(__torch_mlir):
    for _ in range(10):
        print(toy_example(torch.randn(10), torch.randn(10)))

Found 1 device(s).
Device: 0
  Name: NVIDIA GeForce RTX 3080
  Compute Capability: 8.6
[-0.40066046 -0.4210303   0.03225489 -0.44849953  0.10370405 -0.04422468
  0.33262825 -0.20109026  0.02102537 -0.24882983]
[-0.07824923 -0.17004533  0.06439921 -0.06163602  0.26633525 -1.1560082
 -0.06660341  0.24227881  0.1462235  -0.32055548]
[-0.01464001  0.442209   -0.0607936  -0.5477967  -0.25226554 -0.08588809
 -0.30497575  0.00061084 -0.50069696  0.2317973 ]
[ 0.25726247  0.39388427 -0.24093066  0.12316308 -0.01981307  0.5661146
  0.26199922  0.8123446  -0.01576749  0.30846444]
[ 0.7878203  -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855
 -1.6837492  -0.38442805  0.28220773 -1.5325156 ]
[ 0.07975311  0.67754704 -0.30927914  0.00347631 -0.07326564  0.01893554
 -0.7518105  -0.03078967 -0.07623022  0.38865626]
[-0.7751679  -0.5841397  -0.6622711   0.18574935 -0.6049372   0.02844244
 -0.20471913  0.3337415  -0.3619432  -0.35087156]
[-0.08569919 -0.10775139 -0.02338934  0.21933547 -0.46

It can also be used through a decorator:

In [6]:
@create_backend
def torch_mlir(subgraph, *args, **kwargs):
    """Backend wrapper that compiles a dynamo SubGraph with ``__torch_mlir``.

    Args:
        subgraph: a ``SubGraph``; its ``model`` and ``example_inputs``
            attributes are forwarded to ``__torch_mlir``.
        *args, **kwargs: accepted but not used.

    Returns:
        Whatever ``__torch_mlir`` returns for the subgraph's model.

    Raises:
        AssertionError: if ``subgraph`` is not a ``SubGraph``.
    """
    assert isinstance(subgraph, SubGraph), "Model must be a dynamo SubGraph."
    graph_module = subgraph.model
    example_inputs = list(subgraph.example_inputs)
    return __torch_mlir(graph_module, *example_inputs)


@torchdynamo.optimize("torch_mlir")
def toy_example2(*args):
    """Decorator-based variant: ``a / (|a| + 1) * b`` with ``b`` negated if its sum is negative.

    Args:
        *args: exactly two positional values ``(a, b)``.

    Returns:
        The element-wise product of the scaled ``a`` and the (possibly
        negated) ``b``.
    """
    lhs, rhs = args

    denominator = torch.abs(lhs) + 1
    scaled = lhs / denominator
    # The branch depends on a runtime tensor value, not on shapes or types.
    if rhs.sum() < 0:
        rhs = rhs * -1
    return scaled * rhs

In [7]:
# Call the decorated toy_example2 ten times on fresh random length-10 inputs
# and print each result. No context manager is used here: the function was
# wrapped with @torchdynamo.optimize("torch_mlir") where it was defined.
for _ in range(10):
    print(toy_example2(torch.randn(10), torch.randn(10)))

Found 1 device(s).
Device: 0
  Name: NVIDIA GeForce RTX 3080
  Compute Capability: 8.6
[-0.35494277  0.03409214 -0.02271946  0.7335942   0.03122527 -0.41881397
 -0.6609761  -0.6418614   0.29336175 -0.01973678]
[-2.7246824e-01 -3.5543957e-01  6.0087401e-01 -7.4570496e-03
 -4.2481605e-02 -5.0296803e-04  7.2928613e-01 -1.4673788e-03
 -2.7621329e-01 -6.0995776e-02]
[-0.03165906  0.3889693   0.24052973  0.27279532 -0.02773128 -0.12602475
 -1.0124422   0.5720256  -0.35437614 -0.20992722]
[-0.41831446  0.5525326  -0.29749998 -0.17044766  0.11804754 -0.05210691
 -0.46145165 -0.8776549   0.10090438  0.17463352]
[ 0.02194221  0.20959911  0.26973712  0.12551276 -0.0020404   0.1490246
 -0.04456685  1.1100804   0.8105744   0.6676846 ]
[ 0.06528181 -0.13591261  0.5370964  -0.4398162  -0.03372452  0.9691372
 -0.01120087  0.2947028   0.4804801  -0.3324341 ]
[ 0.33549032 -0.23001772 -0.08681437  0.16490957 -0.11223086  0.09168988
  0.02403045  0.17344482  0.46406478 -0.00129451]
[-0.27475086  0.4238480