mirror of
https://github.com/nod-ai/SHARK-Studio.git
synced 2026-01-08 05:24:00 -05:00
309 lines
9.7 KiB
Plaintext
309 lines
9.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {
|
|
"collapsed": true,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/mlevental/miniconda3/envs/torch-mlir/lib/python3.9/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# standard imports\n",
|
|
"import torch\n",
|
|
"from amdshark.iree_utils import get_iree_compiled_module"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"outputs": [],
|
|
"source": [
|
|
"# torch dynamo related imports\n",
|
|
"try:\n",
|
|
" import torchdynamo\n",
|
|
" from torchdynamo.optimizations.backends import create_backend\n",
|
|
" from torchdynamo.optimizations.subgraph import SubGraph\n",
|
|
"except ModuleNotFoundError:\n",
|
|
" print(\n",
|
|
" \"Please install TorchDynamo using pip install git+https://github.com/pytorch/torchdynamo\"\n",
|
|
" )\n",
|
|
" exit()\n",
|
|
"\n",
|
|
"# torch-mlir imports for compiling\n",
|
|
"from torch_mlir import compile, OutputType"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"[TorchDynamo](https://github.com/pytorch/torchdynamo) is a compiler for PyTorch programs that uses the [frame evaluation API](https://www.python.org/dev/peps/pep-0523/) in CPython to dynamically modify Python bytecode right before it is executed. Through bytecode analysis it extracts sequences of PyTorch operations into an FX Graph, which is then handed to a customizable compiler backend; it is designed to mix Python execution with compiled backends."
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"outputs": [],
|
|
"source": [
|
|
"def toy_example(*args):\n",
|
|
" a, b = args\n",
|
|
"\n",
|
|
" x = a / (torch.abs(a) + 1)\n",
|
|
" if b.sum() < 0:\n",
|
|
" b = b * -1\n",
|
|
" return x * b"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"outputs": [],
|
|
"source": [
|
|
"# compiler backend that lowers fx_graph through MLIR\n",
|
|
"def __torch_mlir(fx_graph, *args, **kwargs):\n",
|
|
" assert isinstance(\n",
|
|
" fx_graph, torch.fx.GraphModule\n",
|
|
" ), \"Model must be an FX GraphModule.\"\n",
|
|
"\n",
|
|
" def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule):\n",
|
|
" \"\"\"Replace tuple with tuple element in functions that return one-element tuples.\"\"\"\n",
|
|
"\n",
|
|
" for node in fx_g.graph.nodes:\n",
|
|
" if node.op == \"output\":\n",
|
|
" assert (\n",
|
|
" len(node.args) == 1\n",
|
|
" ), \"Output node must have a single argument\"\n",
|
|
" node_arg = node.args[0]\n",
|
|
" if isinstance(node_arg, tuple) and len(node_arg) == 1:\n",
|
|
" node.args = (node_arg[0],)\n",
|
|
" fx_g.graph.lint()\n",
|
|
" fx_g.recompile()\n",
|
|
" return fx_g\n",
|
|
"\n",
|
|
" fx_graph = _unwrap_single_tuple_return(fx_graph)\n",
|
|
" ts_graph = torch.jit.script(fx_graph)\n",
|
|
"\n",
|
|
" # torchdynamo munges the args differently depending on whether you use\n",
|
|
" # the @torchdynamo.optimize decorator or the context manager\n",
|
|
" if isinstance(args, tuple):\n",
|
|
" args = list(args)\n",
|
|
" assert isinstance(args, list)\n",
|
|
" if len(args) == 1 and isinstance(args[0], list):\n",
|
|
" args = args[0]\n",
|
|
"\n",
|
|
" linalg_module = compile(\n",
|
|
" ts_graph, args, output_type=OutputType.LINALG_ON_TENSORS\n",
|
|
" )\n",
|
|
" callable, _ = get_iree_compiled_module(\n",
|
|
" linalg_module, \"cuda\", func_name=\"forward\"\n",
|
|
" )\n",
|
|
"\n",
|
|
" def forward(*inputs):\n",
|
|
" return callable(*inputs)\n",
|
|
"\n",
|
|
" return forward"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"The simplest way to use TorchDynamo is with the `torchdynamo.optimize` context manager:"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Found 1 device(s).\n",
|
|
"Device: 0\n",
|
|
" Name: NVIDIA GeForce RTX 3080\n",
|
|
" Compute Capability: 8.6\n",
|
|
"[-0.40066046 -0.4210303 0.03225489 -0.44849953 0.10370405 -0.04422468\n",
|
|
" 0.33262825 -0.20109026 0.02102537 -0.24882983]\n",
|
|
"[-0.07824923 -0.17004533 0.06439921 -0.06163602 0.26633525 -1.1560082\n",
|
|
" -0.06660341 0.24227881 0.1462235 -0.32055548]\n",
|
|
"[-0.01464001 0.442209 -0.0607936 -0.5477967 -0.25226554 -0.08588809\n",
|
|
" -0.30497575 0.00061084 -0.50069696 0.2317973 ]\n",
|
|
"[ 0.25726247 0.39388427 -0.24093066 0.12316308 -0.01981307 0.5661146\n",
|
|
" 0.26199922 0.8123446 -0.01576749 0.30846444]\n",
|
|
"[ 0.7878203 -0.45975062 -0.29956317 -0.07032048 -0.55817443 -0.62506855\n",
|
|
" -1.6837492 -0.38442805 0.28220773 -1.5325156 ]\n",
|
|
"[ 0.07975311 0.67754704 -0.30927914 0.00347631 -0.07326564 0.01893554\n",
|
|
" -0.7518105 -0.03078967 -0.07623022 0.38865626]\n",
|
|
"[-0.7751679 -0.5841397 -0.6622711 0.18574935 -0.6049372 0.02844244\n",
|
|
" -0.20471913 0.3337415 -0.3619432 -0.35087156]\n",
|
|
"[-0.08569919 -0.10775139 -0.02338934 0.21933547 -0.46712473 0.00062137\n",
|
|
" -0.58207744 0.06457533 0.18276742 0.03866556]\n",
|
|
"[-0.2311981 -0.43036282 0.20561649 -0.10363232 -0.13248594 0.02885137\n",
|
|
" -0.31241602 -0.36907142 0.08861586 0.2331427 ]\n",
|
|
"[-0.07273526 -0.31246194 -0.24218291 -0.24145737 0.0364486 0.14382267\n",
|
|
" -0.00531162 0.15447603 -0.5220248 -0.09016377]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"with torchdynamo.optimize(__torch_mlir):\n",
|
|
" for _ in range(10):\n",
|
|
" print(toy_example(torch.randn(10), torch.randn(10)))"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"It can also be used through a decorator:"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"outputs": [],
|
|
"source": [
|
|
"@create_backend\n",
|
|
"def torch_mlir(subgraph, *args, **kwargs):\n",
|
|
" assert isinstance(subgraph, SubGraph), \"Model must be a dynamo SubGraph.\"\n",
|
|
" return __torch_mlir(subgraph.model, *list(subgraph.example_inputs))\n",
|
|
"\n",
|
|
"\n",
|
|
"@torchdynamo.optimize(\"torch_mlir\")\n",
|
|
"def toy_example2(*args):\n",
|
|
" a, b = args\n",
|
|
"\n",
|
|
" x = a / (torch.abs(a) + 1)\n",
|
|
" if b.sum() < 0:\n",
|
|
" b = b * -1\n",
|
|
" return x * b"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Found 1 device(s).\n",
|
|
"Device: 0\n",
|
|
" Name: NVIDIA GeForce RTX 3080\n",
|
|
" Compute Capability: 8.6\n",
|
|
"[-0.35494277 0.03409214 -0.02271946 0.7335942 0.03122527 -0.41881397\n",
|
|
" -0.6609761 -0.6418614 0.29336175 -0.01973678]\n",
|
|
"[-2.7246824e-01 -3.5543957e-01 6.0087401e-01 -7.4570496e-03\n",
|
|
" -4.2481605e-02 -5.0296803e-04 7.2928613e-01 -1.4673788e-03\n",
|
|
" -2.7621329e-01 -6.0995776e-02]\n",
|
|
"[-0.03165906 0.3889693 0.24052973 0.27279532 -0.02773128 -0.12602475\n",
|
|
" -1.0124422 0.5720256 -0.35437614 -0.20992722]\n",
|
|
"[-0.41831446 0.5525326 -0.29749998 -0.17044766 0.11804754 -0.05210691\n",
|
|
" -0.46145165 -0.8776549 0.10090438 0.17463352]\n",
|
|
"[ 0.02194221 0.20959911 0.26973712 0.12551276 -0.0020404 0.1490246\n",
|
|
" -0.04456685 1.1100804 0.8105744 0.6676846 ]\n",
|
|
"[ 0.06528181 -0.13591261 0.5370964 -0.4398162 -0.03372452 0.9691372\n",
|
|
" -0.01120087 0.2947028 0.4804801 -0.3324341 ]\n",
|
|
"[ 0.33549032 -0.23001772 -0.08681437 0.16490957 -0.11223086 0.09168988\n",
|
|
" 0.02403045 0.17344482 0.46406478 -0.00129451]\n",
|
|
"[-0.27475086 0.42384806 1.9090122 -0.41147137 -0.6888369 0.08435658\n",
|
|
" -0.26628923 -0.17436793 -0.8058869 -0.02582378]\n",
|
|
"[-0.10109414 0.08681287 -0.10055986 0.6858881 0.29267687 -0.02797117\n",
|
|
" -0.01425194 0.4882803 0.3551982 -0.858935 ]\n",
|
|
"[-0.22086617 0.524994 0.17721705 -0.03813264 -0.54570735 -0.4421502\n",
|
|
" 0.11938014 -0.01122053 0.39294165 -0.61770755]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"for _ in range(10):\n",
|
|
" print(toy_example2(torch.randn(10), torch.randn(10)))"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
} |