Files
SHARK-Studio/amdshark/examples/amdshark_eager/dynamo_demo.ipynb
pdhirajkumarprasad fe03539901 Migration to AMDShark (#2182)
Signed-off-by: pdhirajkumarprasad <dhirajp@amd.com>
2025-11-20 12:52:07 +05:30

309 lines
9.7 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"collapsed": true,
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"# standard imports\n",
"import torch\n",
"from amdshark.iree_utils import get_iree_compiled_module"
]
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# torch dynamo related imports\n",
"try:\n",
" import torchdynamo\n",
" from torchdynamo.optimizations.backends import create_backend\n",
" from torchdynamo.optimizations.subgraph import SubGraph\n",
"except ModuleNotFoundError:\n",
" print(\n",
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
" )\n",
" exit()\n",
"\n",
"# torch-mlir imports for compiling\n",
"from torch_mlir import compile, OutputType"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It creates this FX Graph through bytecode analysis and is designed to mix Python execution with compiled backends."
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"def toy_example(*args):\n",
" a, b = args\n",
"\n",
" x = a / (torch.abs(a) + 1)\n",
" if b.sum() < 0:\n",
" b = b * -1\n",
" return x * b"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"# compiler that lowers fx_graph to through MLIR\n",
"def __torch_mlir(fx_graph, *args, **kwargs):\n",
" assert isinstance(\n",
" fx_graph, torch.fx.GraphModule\n",
" ), \"Model must be an FX GraphModule.\"\n",
"\n",
" def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):\n",
" \"\"\"Replace tuple with tuple element in functions that return one-element tuples.\"\"\"\n",
"\n",
" for node in fx_g.graph.nodes:\n",
" if node.op == \"output\":\n",
" assert (\n",
" len(node.args) == 1\n",
" ), \"Output node must have a single argument\"\n",
" node_arg = node.args[0]\n",
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
" node.args = (node_arg[0],)\n",
" fx_g.graph.lint()\n",
" fx_g.recompile()\n",
" return fx_g\n",
"\n",
" fx_graph = _unwrap_single_tuple_return(fx_graph)\n",
" ts_graph = torch.jit.script(fx_graph)\n",
"\n",
" # torchdynamo does munges the args differently depending on whether you use\n",
" # the @torchdynamo.optimize decorator or the context manager\n",
" if isinstance(args, tuple):\n",
" args = list(args)\n",
" assert isinstance(args, list)\n",
" if len(args) == 1 and isinstance(args[0], list):\n",
" args = args[0]\n",
"\n",
" linalg_module = compile(\n",
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
" )\n",
" callable, _ = get_iree_compiled_module(\n",
" linalg_module, \"cuda\", func_name=\"forward\"\n",
" )\n",
"\n",
" def forward(*inputs):\n",
" return callable(*inputs)\n",
"\n",
" return forward"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"Simplest way to use TorchDynamo with the `torchdynamo.optimize` context manager:"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 1 device(s).\n",
"Device: 0\n",
" Name: NVIDIA GeForce RTX 3080\n",
" Compute Capability: 8.6\n",
"[-0.40066046 -0.4210303 0.03225489 -0.44849953 0.10370405 -0.04422468\n",
" 0.33262825 -0.20109026 0.02102537 -0.24882983]\n",
"[-0.07824923 -0.17004533 0.06439921 -0.06163602 0.26633525 -1.1560082\n",
" -0.06660341 0.24227881 0.1462235 -0.32055548]\n",
"[-0.01464001 0.442209 -0.0607936 -0.5477967 -0.25226554 -0.08588809\n",
" -0.30497575 0.00061084 -0.50069696 0.2317973 ]\n",
"[ 0.25726247 0.39388427 -0.24093066 0.12316308 -0.01981307 0.5661146\n",
" 0.26199922 0.8123446 -0.01576749 0.30846444]\n",
"[ 0.7878203 -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855\n",
" -1.6837492 -0.38442805 0.28220773 -1.5325156 ]\n",
"[ 0.07975311 0.67754704 -0.30927914 0.00347631 -0.07326564 0.01893554\n",
" -0.7518105 -0.03078967 -0.07623022 0.38865626]\n",
"[-0.7751679 -0.5841397 -0.6622711 0.18574935 -0.6049372 0.02844244\n",
" -0.20471913 0.3337415 -0.3619432 -0.35087156]\n",
"[-0.08569919 -0.10775139 -0.02338934 0.21933547 -0.46712473 0.00062137\n",
" -0.58207744 0.06457533 0.18276742 0.03866556]\n",
"[-0.2311981 -0.43036282 0.20561649 -0.10363232 -0.13248594 0.02885137\n",
" -0.31241602 -0.36907142 0.08861586 0.2331427 ]\n",
"[-0.07273526 -0.31246194 -0.24218291 -0.24145737 0.0364486 0.14382267\n",
" -0.00531162 0.15447603 -0.5220248 -0.09016377]\n"
]
}
],
"source": [
"with torchdynamo.optimize(__torch_mlir):\n",
" for _ in range(10):\n",
" print(toy_example(torch.randn(10), torch.randn(10)))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"It can also be used through a decorator:"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [],
"source": [
"@create_backend\n",
"def torch_mlir(subgraph, *args, **kwargs):\n",
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
"\n",
"\n",
"@torchdynamo.optimize(\"torch_mlir\")\n",
"def toy_example2(*args):\n",
" a, b = args\n",
"\n",
" x = a / (torch.abs(a) + 1)\n",
" if b.sum() < 0:\n",
" b = b * -1\n",
" return x * b"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found 1 device(s).\n",
"Device: 0\n",
" Name: NVIDIA GeForce RTX 3080\n",
" Compute Capability: 8.6\n",
"[-0.35494277 0.03409214 -0.02271946 0.7335942 0.03122527 -0.41881397\n",
" -0.6609761 -0.6418614 0.29336175 -0.01973678]\n",
"[-2.7246824e-01 -3.5543957e-01 6.0087401e-01 -7.4570496e-03\n",
" -4.2481605e-02 -5.0296803e-04 7.2928613e-01 -1.4673788e-03\n",
" -2.7621329e-01 -6.0995776e-02]\n",
"[-0.03165906 0.3889693 0.24052973 0.27279532 -0.02773128 -0.12602475\n",
" -1.0124422 0.5720256 -0.35437614 -0.20992722]\n",
"[-0.41831446 0.5525326 -0.29749998 -0.17044766 0.11804754 -0.05210691\n",
" -0.46145165 -0.8776549 0.10090438 0.17463352]\n",
"[ 0.02194221 0.20959911 0.26973712 0.12551276 -0.0020404 0.1490246\n",
" -0.04456685 1.1100804 0.8105744 0.6676846 ]\n",
"[ 0.06528181 -0.13591261 0.5370964 -0.4398162 -0.03372452 0.9691372\n",
" -0.01120087 0.2947028 0.4804801 -0.3324341 ]\n",
"[ 0.33549032 -0.23001772 -0.08681437 0.16490957 -0.11223086 0.09168988\n",
" 0.02403045 0.17344482 0.46406478 -0.00129451]\n",
"[-0.27475086 0.42384806 1.9090122 -0.41147137 -0.6888369 0.08435658\n",
" -0.26628923 -0.17436793 -0.8058869 -0.02582378]\n",
"[-0.10109414 0.08681287 -0.10055986 0.6858881 0.29267687 -0.02797117\n",
" -0.01425194 0.4882803 0.3551982 -0.858935 ]\n",
"[-0.22086617 0.524994 0.17721705 -0.03813264 -0.54570735 -0.4421502\n",
" 0.11938014 -0.01122053 0.39294165 -0.61770755]\n"
]
}
],
"source": [
"for _ in range(10):\n",
" print(toy_example2(torch.randn(10), torch.randn(10)))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}