{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# standard imports\n",
    "import torch\n",
    "from amdshark.iree_utils import get_iree_compiled_module"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "# torch dynamo related imports\n",
    "try:\n",
    "    import torchdynamo\n",
    "    from torchdynamo.optimizations.backends import create_backend\n",
    "    from torchdynamo.optimizations.subgraph import SubGraph\n",
    "except ModuleNotFoundError as err:\n",
    "    # Raise rather than print() + exit(): in a notebook, exit() shuts the\n",
    "    # kernel down, while an exception stops \"Run All\" at this cell and keeps\n",
    "    # the install hint visible. `from err` preserves the original cause.\n",
    "    raise ModuleNotFoundError(\n",
    "        \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
    "    ) from err\n",
    "\n",
    "# torch-mlir imports for compiling\n",
    "from torch_mlir import compile, OutputType"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. It creates an FX graph through bytecode analysis and is designed to mix Python execution with compiled backends."
],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [],
   "source": [
    "def toy_example(*args):\n",
    "    a, b = args\n",
    "\n",
    "    x = a / (torch.abs(a) + 1)\n",
    "    if b.sum() < 0:\n",
    "        b = b * -1\n",
    "    return x * b"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "# compiler that lowers an FX graph through torch-mlir and IREE\n",
    "def __torch_mlir(fx_graph, *args, **kwargs):\n",
    "    \"\"\"Compile a TorchDynamo-captured FX graph with torch-mlir and IREE.\n",
    "\n",
    "    Args:\n",
    "        fx_graph: the `torch.fx.GraphModule` to compile.\n",
    "        *args: example inputs, either unpacked or wrapped in one list\n",
    "            (see the note on arg munging below).\n",
    "        **kwargs: an optional `device` (default \"cuda\") forwarded to\n",
    "            `get_iree_compiled_module`; other keys are ignored.\n",
    "\n",
    "    Returns:\n",
    "        A `forward(*inputs)` function that runs the IREE-compiled module.\n",
    "    \"\"\"\n",
    "    assert isinstance(\n",
    "        fx_graph, torch.fx.GraphModule\n",
    "    ), \"Model must be an FX GraphModule.\"\n",
    "\n",
    "    def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):\n",
    "        \"\"\"Replace tuple with tuple element in functions that return one-element tuples.\"\"\"\n",
    "\n",
    "        for node in fx_g.graph.nodes:\n",
    "            if node.op == \"output\":\n",
    "                assert (\n",
    "                    len(node.args) == 1\n",
    "                ), \"Output node must have a single argument\"\n",
    "                node_arg = node.args[0]\n",
    "                if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
    "                    node.args = (node_arg[0],)\n",
    "        # Validate the edited graph and regenerate the module's code from it.\n",
    "        fx_g.graph.lint()\n",
    "        fx_g.recompile()\n",
    "        return fx_g\n",
    "\n",
    "    fx_graph = _unwrap_single_tuple_return(fx_graph)\n",
    "    ts_graph = torch.jit.script(fx_graph)\n",
    "\n",
    "    # torchdynamo munges the args differently depending on whether you use\n",
    "    # the @torchdynamo.optimize decorator or the context manager: one path\n",
    "    # hands over a single list of example inputs, the other unpacks them.\n",
    "    # `args` from *args is always a tuple, so just copy it into a list.\n",
    "    example_inputs = list(args)\n",
    "    if len(example_inputs) == 1 and isinstance(example_inputs[0], list):\n",
    "        example_inputs = example_inputs[0]\n",
    "\n",
    "    linalg_module = compile(\n",
    "        ts_graph, example_inputs, output_type=OutputType.LINALG_ON_TENSORS\n",
    "    )\n",
    "    # Named `compiled_fn` to avoid shadowing the builtin `callable`.\n",
    "    compiled_fn, _ = get_iree_compiled_module(\n",
    "        linalg_module, kwargs.get(\"device\", \"cuda\"), func_name=\"forward\"\n",
    "    )\n",
    "\n",
    "    def forward(*inputs):\n",
    "        return compiled_fn(*inputs)\n",
    "\n",
    "    return forward"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type":
"markdown", "source": [ "Simplest way to use TorchDynamo with the `torchdynamo.optimize` context manager:" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1 device(s).\n", "Device: 0\n", " Name: NVIDIA GeForce RTX 3080\n", " Compute Capability: 8.6\n", "[-0.40066046 -0.4210303 0.03225489 -0.44849953 0.10370405 -0.04422468\n", " 0.33262825 -0.20109026 0.02102537 -0.24882983]\n", "[-0.07824923 -0.17004533 0.06439921 -0.06163602 0.26633525 -1.1560082\n", " -0.06660341 0.24227881 0.1462235 -0.32055548]\n", "[-0.01464001 0.442209 -0.0607936 -0.5477967 -0.25226554 -0.08588809\n", " -0.30497575 0.00061084 -0.50069696 0.2317973 ]\n", "[ 0.25726247 0.39388427 -0.24093066 0.12316308 -0.01981307 0.5661146\n", " 0.26199922 0.8123446 -0.01576749 0.30846444]\n", "[ 0.7878203 -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855\n", " -1.6837492 -0.38442805 0.28220773 -1.5325156 ]\n", "[ 0.07975311 0.67754704 -0.30927914 0.00347631 -0.07326564 0.01893554\n", " -0.7518105 -0.03078967 -0.07623022 0.38865626]\n", "[-0.7751679 -0.5841397 -0.6622711 0.18574935 -0.6049372 0.02844244\n", " -0.20471913 0.3337415 -0.3619432 -0.35087156]\n", "[-0.08569919 -0.10775139 -0.02338934 0.21933547 -0.46712473 0.00062137\n", " -0.58207744 0.06457533 0.18276742 0.03866556]\n", "[-0.2311981 -0.43036282 0.20561649 -0.10363232 -0.13248594 0.02885137\n", " -0.31241602 -0.36907142 0.08861586 0.2331427 ]\n", "[-0.07273526 -0.31246194 -0.24218291 -0.24145737 0.0364486 0.14382267\n", " -0.00531162 0.15447603 -0.5220248 -0.09016377]\n" ] } ], "source": [ "with torchdynamo.optimize(__torch_mlir):\n", " for _ in range(10):\n", " print(toy_example(torch.randn(10), torch.randn(10)))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "It can also be used through a decorator:" ], "metadata": { 
"collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 6, "outputs": [], "source": [ "@create_backend\n", "def torch_mlir(subgraph, *args, **kwargs):\n", " assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n", " return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n", "\n", "\n", "@torchdynamo.optimize(\"torch_mlir\")\n", "def toy_example2(*args):\n", " a, b = args\n", "\n", " x = a / (torch.abs(a) + 1)\n", " if b.sum() < 0:\n", " b = b * -1\n", " return x * b" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 7, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 1 device(s).\n", "Device: 0\n", " Name: NVIDIA GeForce RTX 3080\n", " Compute Capability: 8.6\n", "[-0.35494277 0.03409214 -0.02271946 0.7335942 0.03122527 -0.41881397\n", " -0.6609761 -0.6418614 0.29336175 -0.01973678]\n", "[-2.7246824e-01 -3.5543957e-01 6.0087401e-01 -7.4570496e-03\n", " -4.2481605e-02 -5.0296803e-04 7.2928613e-01 -1.4673788e-03\n", " -2.7621329e-01 -6.0995776e-02]\n", "[-0.03165906 0.3889693 0.24052973 0.27279532 -0.02773128 -0.12602475\n", " -1.0124422 0.5720256 -0.35437614 -0.20992722]\n", "[-0.41831446 0.5525326 -0.29749998 -0.17044766 0.11804754 -0.05210691\n", " -0.46145165 -0.8776549 0.10090438 0.17463352]\n", "[ 0.02194221 0.20959911 0.26973712 0.12551276 -0.0020404 0.1490246\n", " -0.04456685 1.1100804 0.8105744 0.6676846 ]\n", "[ 0.06528181 -0.13591261 0.5370964 -0.4398162 -0.03372452 0.9691372\n", " -0.01120087 0.2947028 0.4804801 -0.3324341 ]\n", "[ 0.33549032 -0.23001772 -0.08681437 0.16490957 -0.11223086 0.09168988\n", " 0.02403045 0.17344482 0.46406478 -0.00129451]\n", "[-0.27475086 0.42384806 1.9090122 -0.41147137 -0.6888369 0.08435658\n", " -0.26628923 -0.17436793 -0.8058869 -0.02582378]\n", "[-0.10109414 0.08681287 -0.10055986 0.6858881 0.29267687 -0.02797117\n", " -0.01425194 0.4882803 0.3551982 
-0.858935 ]\n", "[-0.22086617 0.524994 0.17721705 -0.03813264 -0.54570735 -0.4421502\n", " 0.11938014 -0.01122053 0.39294165 -0.61770755]\n" ] } ], "source": [ "for _ in range(10):\n", " print(toy_example2(torch.randn(10), torch.randn(10)))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }