From 9e07824542b50f5c2f276a897efacb7c32e3b050 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 27 Nov 2023 11:34:37 -0800
Subject: [PATCH] move device to device.py (#2466)

* move device to device.py

* pylint test --disable R,C,W,E --enable E0611

* fix tests
---
 .github/workflows/test.yml | 6 +-
 docs/abstractions.py | 2 +-
 examples/f16_w_uint32.py | 2 +-
 examples/gpt2.py | 2 +-
 examples/hlb_cifar10.py | 2 +-
 examples/llama.py | 2 +-
 examples/stable_diffusion.py | 2 +-
 examples/webgpu/stable_diffusion/compile.py | 2 +-
 extra/dist/__init__.py | 2 +-
 extra/gemm/metal_conv.py | 2 +-
 extra/gemm/metal_matvec.py | 2 +-
 extra/utils.py | 2 +-
 openpilot/compile2.py | 4 +-
 test/external/external_llama_eval.py | 2 +-
 test/external/external_model_benchmark.py | 2 +-
 .../external_test_allocator_on_models.py | 2 +-
 test/external/external_test_jit_on_models.py | 2 +-
 test/external/external_test_onnx_backend.py | 2 +-
 test/external/external_test_speed_llama.py | 4 +-
 test/external/fuzz_linearizer.py | 3 +-
 test/external/test_example.py | 2 +-
 test/models/test_bert.py | 2 +-
 test/models/test_real_world.py | 2 +-
 test/models/test_waifu2x.py | 2 +-
 test/models/test_whisper.py | 2 +-
 test/test_allocators.py | 2 +-
 test/test_assign.py | 2 +-
 test/test_copy_speed.py | 2 +-
 test/test_custom_function.py | 2 +-
 test/test_dtype.py | 2 +-
 test/test_jit.py | 2 +-
 test/test_kernel_cache.py | 2 +-
 test/test_lazybuffer.py | 2 +-
 test/test_linearizer.py | 5 +-
 test/test_linearizer_failures.py | 2 +-
 test/test_ops.py | 2 +-
 test/test_sample.py | 2 +-
 test/test_schedule.py | 3 +-
 test/test_search.py | 3 +-
 test/test_specific_conv.py | 2 +-
 test/test_speed_v_torch.py | 2 +-
 test/test_to_numpy.py | 2 +-
 test/test_uops.py | 3 +-
 tinygrad/__init__.py | 2 +-
 tinygrad/codegen/kernel.py | 3 +-
 tinygrad/device.py | 200 +++++++++++++++++
 tinygrad/features/search.py | 3 +-
 tinygrad/jit.py | 3 +-
 tinygrad/nn/state.py | 2 +-
 tinygrad/ops.py | 202 +----------------
 tinygrad/realize.py | 3 +-
 tinygrad/runtime/ops_clang.py | 2 +-
 tinygrad/runtime/ops_cpu.py | 3 +-
 tinygrad/runtime/ops_cuda.py | 2 +-
 tinygrad/runtime/ops_disk.py | 3 +-
 tinygrad/runtime/ops_gpu.py | 2 +-
 tinygrad/runtime/ops_hip.py | 2 +-
 tinygrad/runtime/ops_llvm.py | 2 +-
 tinygrad/runtime/ops_metal.py | 2 +-
 tinygrad/runtime/ops_torch.py | 3 +-
 tinygrad/runtime/ops_webgpu.py | 2 +-
 tinygrad/tensor.py | 3 +-
 62 files changed, 281 insertions(+), 264 deletions(-)
 create mode 100644 tinygrad/device.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bd68558bd4..a406efda70 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -245,7 +245,7 @@ jobs: # dtype test has issues on test_half_to_int8 - name: Check Device.DEFAULT (METAL) and print some source run: | - METAL=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'METAL', Device.DEFAULT" + METAL=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'METAL', Device.DEFAULT" METAL=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add - name: Run metal ops test run: DEBUG=2 METAL=1 python -m pytest -n=auto test/test_ops.py @@ -263,7 +263,7 @@ jobs: run: METAL=1 TC=2 python -m pytest -n=auto test/test_ops.py - name: Check Device.DEFAULT (WEBGPU) and print some source run: | - WEBGPU=1 python -c "from tinygrad.ops import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT" + WEBGPU=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT" WEBGPU=1
DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add #- name: Run webgpu pytest # run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto @@ -349,7 +349,7 @@ jobs: run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='cuda'&&',cuda'||matrix.backend=='ptx'&&',cuda'||matrix.backend=='triton'&&',triton'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ - name: Check Device.DEFAULT and print some source run: | - python -c "from tinygrad.ops import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT" + python -c "from tinygrad import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT" DEBUG=5 PYTHONPATH=${{ github.workspace }} FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add - name: Run pytest (not cuda) if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton' diff --git a/docs/abstractions.py b/docs/abstractions.py index 32f629c157..537ef3f061 100644 --- a/docs/abstractions.py +++ b/docs/abstractions.py @@ -22,7 +22,7 @@ from abc import ABC # let's trace an addition down through the layers of abstraction. # we will be using the clang backend -from tinygrad.ops import Device +from tinygrad import Device Device.DEFAULT = "CLANG" # first, 2+3 as a Tensor, the highest level diff --git a/examples/f16_w_uint32.py b/examples/f16_w_uint32.py index 0a01bae72e..bf281661e9 100644 --- a/examples/f16_w_uint32.py +++ b/examples/f16_w_uint32.py @@ -1,7 +1,7 @@ import numpy as np from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes -from tinygrad.ops import Device +from tinygrad import Device # TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul def bit_extract(x, s, e) -> Tensor: diff --git a/examples/gpt2.py b/examples/gpt2.py index 1b875a7ec0..5a818a9c2b 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -2,7 +2,7 @@ import argparse from tqdm import trange import numpy as np -from tinygrad.ops import Device +from tinygrad import Device from typing import Optional from tinygrad.tensor import Tensor from tinygrad.nn import Embedding, Linear, LayerNorm diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index a298d12d5c..4c0c7764f5 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -16,7 +16,7 @@ from extra.datasets import fetch_cifar, cifar_mean, cifar_std from tinygrad import nn from tinygrad.nn.state import get_state_dict from tinygrad.nn import optim -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor from tinygrad.helpers import GlobalCounters from tinygrad.shape.symbolic import Node diff --git a/examples/llama.py b/examples/llama.py index 18ac8129d1..84a209f468 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -8,7 +8,7 @@ import sys, argparse, json import numpy as np np.set_printoptions(linewidth=200) from tinygrad.helpers import Timing, Profiling, getenv, DEBUG, dtypes -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor from tinygrad.nn.state import safe_load, torch_load, load_state_dict, get_parameters from tinygrad.helpers import GlobalCounters diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py index 316e0b9326..5e2dd2bbb9 100644 --- a/examples/stable_diffusion.py +++ b/examples/stable_diffusion.py @@ -8,7 +8,7 @@ from 
collections import namedtuple from tqdm import tqdm from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict diff --git a/examples/webgpu/stable_diffusion/compile.py b/examples/webgpu/stable_diffusion/compile.py index f7a8600a9d..7984e63132 100644 --- a/examples/webgpu/stable_diffusion/compile.py +++ b/examples/webgpu/stable_diffusion/compile.py @@ -3,7 +3,7 @@ from extra.export_model import compile_net, jit_model from examples.stable_diffusion import StableDiffusion from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import fetch from typing import NamedTuple, Any, List from pathlib import Path diff --git a/extra/dist/__init__.py b/extra/dist/__init__.py index 6a75354987..e25e58c1e0 100644 --- a/extra/dist/__init__.py +++ b/extra/dist/__init__.py @@ -40,7 +40,7 @@ def _process_wrap(rank:int, device:str, oob:_OOB, fn:Callable, args=()): OOB = oob # do specific runtime initialization for distributed - from tinygrad.ops import Device + from tinygrad import Device device, device_num = Device.canonicalize(device), 0 if ":" not in device else int(device.split(":")[-1]) if "GPU" in device: from tinygrad.runtime.ops_gpu import CL diff --git a/extra/gemm/metal_conv.py b/extra/gemm/metal_conv.py index c37c93b9ea..9b3df4628b 100644 --- a/extra/gemm/metal_conv.py +++ b/extra/gemm/metal_conv.py @@ -33,7 +33,7 @@ except RuntimeError: from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit -from tinygrad.ops import Device +from tinygrad import Device b = Tensor(nb) c = Tensor(nc) # TODO: slowness without the JIT I suspect comes from a lack of a caching allocator diff --git a/extra/gemm/metal_matvec.py b/extra/gemm/metal_matvec.py index 6a99375462..60df010d49 100644 --- a/extra/gemm/metal_matvec.py +++ b/extra/gemm/metal_matvec.py @@ -6,7 +6,7 @@ import time, torch, torch.mps from tinygrad.helpers import GlobalCounters from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import colored, getenv, CI import os diff --git a/extra/utils.py b/extra/utils.py index e4b1f4d6a1..d7c39ff255 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -10,7 +10,7 @@ from tinygrad.helpers import prod, getenv, DEBUG, dtypes, get_child from tinygrad.helpers import GlobalCounters from tinygrad.tensor import Tensor from tinygrad.lazy import LazyBuffer -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.shape.view import strides_for_shape OSX = platform.system() == "Darwin" WINDOWS = platform.system() == "Windows" diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 872006ad34..441cb46711 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -14,10 +14,10 @@ from typing import Tuple, List from extra.utils import fetch from extra.onnx import get_run_onnx from tinygrad.graph import print_tree, log_schedule_item -from tinygrad.tensor import Tensor +from tinygrad import Tensor, Device from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv, ImageDType, GRAPH from tinygrad.realize import run_schedule -from 
tinygrad.ops import LoadOps, Device, ScheduleItem +from tinygrad.ops import LoadOps, ScheduleItem from tinygrad.features.image import fix_schedule_for_images Device.DEFAULT = "GPU" diff --git a/test/external/external_llama_eval.py b/test/external/external_llama_eval.py index dfbbcbac6c..ce1cbd8004 100644 --- a/test/external/external_llama_eval.py +++ b/test/external/external_llama_eval.py @@ -4,7 +4,7 @@ import torch, json, argparse from examples.llama import LLaMa from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device class LLaMaAdaptor(BaseLM): def __init__( diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py index b784f11b65..857c148a1d 100644 --- a/test/external/external_model_benchmark.py +++ b/test/external/external_model_benchmark.py @@ -10,7 +10,7 @@ from extra.utils import download_file from extra.onnx import get_run_onnx from tinygrad.helpers import OSX, DEBUG from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device MODELS = { "resnet50": "https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx", diff --git a/test/external/external_test_allocator_on_models.py b/test/external/external_test_allocator_on_models.py index a1c86aa483..a0211ab111 100644 --- a/test/external/external_test_allocator_on_models.py +++ b/test/external/external_test_allocator_on_models.py @@ -6,7 +6,7 @@ from tinygrad.nn.state import get_state_dict from tinygrad.helpers import GlobalCounters from tinygrad.runtime.lib import RawBuffer, LRUAllocator from tinygrad.helpers import dtypes, prod -from tinygrad.ops import Device +from tinygrad import Device from test.helpers import derandomize_model from examples.llama import Transformer diff --git a/test/external/external_test_jit_on_models.py b/test/external/external_test_jit_on_models.py index 95a5d36ca1..ba47cc23a7 100644 --- a/test/external/external_test_jit_on_models.py +++ b/test/external/external_test_jit_on_models.py @@ -4,7 +4,7 @@ import numpy as np from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit from tinygrad.helpers import dtypes, CI -from tinygrad.ops import Device +from tinygrad import Device from test.helpers import derandomize_model from examples.llama import Transformer diff --git a/test/external/external_test_onnx_backend.py b/test/external/external_test_onnx_backend.py index 659cc88757..4df86c88c8 100644 --- a/test/external/external_test_onnx_backend.py +++ b/test/external/external_test_onnx_backend.py @@ -5,7 +5,7 @@ import onnx.backend.test import numpy as np from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, CI -from tinygrad.ops import Device +from tinygrad import Device # pip3 install tabulate pytest_plugins = 'onnx.backend.test.report', diff --git a/test/external/external_test_speed_llama.py b/test/external/external_test_speed_llama.py index 4d5777d76d..2f54ad89d6 100644 --- a/test/external/external_test_speed_llama.py +++ b/test/external/external_test_speed_llama.py @@ -3,9 +3,9 @@ import unittest, time import numpy as np from examples.llama import Transformer, MODEL_PARAMS from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.nn.state import get_state_dict -from tinygrad.ops import Compiled +from tinygrad.device import Compiled from tinygrad.helpers import dtypes, prod, Profiling from tinygrad.runtime.lib import RawBuffer diff --git a/test/external/fuzz_linearizer.py 
b/test/external/fuzz_linearizer.py index ba185a1242..6684664299 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -6,12 +6,11 @@ from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import get_linearizer_actions, bufs_from_lin, tuplize_uops from tinygrad.graph import print_tree from tinygrad.helpers import getenv -from tinygrad.ops import Device, Compiled, Interpreted +from tinygrad.device import Device, Compiled, Interpreted from tinygrad.lazy import vars_from_ast device = Device[Device.DEFAULT] - def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None): if rawbufs is None: rawbufs = bufs_from_lin(lin) if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)} diff --git a/test/external/test_example.py b/test/external/test_example.py index af66df182b..ed3c71b1ed 100644 --- a/test/external/test_example.py +++ b/test/external/test_example.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, CI diff --git a/test/models/test_bert.py b/test/models/test_bert.py index d382e78d6a..f09eb56970 100644 --- a/test/models/test_bert.py +++ b/test/models/test_bert.py @@ -2,7 +2,7 @@ import unittest import numpy as np from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device import torch def get_question_samp(bsz, seq_len, vocab_size, seed): diff --git a/test/models/test_real_world.py b/test/models/test_real_world.py index 24e3f5c01b..f2e9a9aa9b 100644 --- a/test/models/test_real_world.py +++ b/test/models/test_real_world.py @@ -4,7 +4,7 @@ from tinygrad.tensor import Tensor from tinygrad.nn import optim from tinygrad.nn.state import get_parameters from tinygrad.jit import TinyJit -from tinygrad.ops import Device, GlobalCounters +from tinygrad import Device, GlobalCounters from tinygrad.helpers import CI, dtypes from test.helpers import derandomize_model diff --git a/test/models/test_waifu2x.py b/test/models/test_waifu2x.py index 0b34ae0356..7af07877ad 100644 --- a/test/models/test_waifu2x.py +++ b/test/models/test_waifu2x.py @@ -3,7 +3,7 @@ import pathlib import unittest import numpy as np from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device class TestVGG7(unittest.TestCase): def test_vgg7(self): diff --git a/test/models/test_whisper.py b/test/models/test_whisper.py index 4b93e041ae..0423c22c67 100644 --- a/test/models/test_whisper.py +++ b/test/models/test_whisper.py @@ -2,7 +2,7 @@ import unittest import pathlib from examples.whisper import init_whisper, load_file_waveform, transcribe_file, transcribe_waveform from tinygrad.helpers import CI, fetch -from tinygrad.ops import Device +from tinygrad import Device # Audio generated with the command on MacOS: # say "Could you please let me out of the box?" 
--file-format=WAVE --data-format=LEUI8@16000 -o test diff --git a/test/test_allocators.py b/test/test_allocators.py index 849ccbb96a..1fe00ecc57 100644 --- a/test/test_allocators.py +++ b/test/test_allocators.py @@ -7,7 +7,7 @@ from weakref import ref from tinygrad.helpers import GlobalCounters from tinygrad.runtime.lib import RawBuffer, LRUAllocator from tinygrad.helpers import dtypes, prod -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor def check_gc(): diff --git a/test/test_assign.py b/test/test_assign.py index 02b9521474..1373268245 100644 --- a/test/test_assign.py +++ b/test/test_assign.py @@ -2,7 +2,7 @@ import unittest import numpy as np from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import dtypes N = 200 # has to be bigger than the cache to fail diff --git a/test/test_copy_speed.py b/test/test_copy_speed.py index 4f78c1cb04..a7e98a080b 100644 --- a/test/test_copy_speed.py +++ b/test/test_copy_speed.py @@ -1,6 +1,6 @@ import unittest from tinygrad import Tensor -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import Timing, CI import multiprocessing.shared_memory as shared_memory diff --git a/test/test_custom_function.py b/test/test_custom_function.py index 92968b84a4..f12f51f44f 100644 --- a/test/test_custom_function.py +++ b/test/test_custom_function.py @@ -9,7 +9,7 @@ from tinygrad.helpers import prod, dtypes # *** first, we implement the atan2 op at the lowest level *** # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers from tinygrad.lazy import LazyBuffer, create_lazybuffer -from tinygrad.ops import CompiledASTRunner, Device +from tinygrad.device import CompiledASTRunner, Device from tinygrad.shape.shapetracker import ShapeTracker import pytest diff --git a/test/test_dtype.py b/test/test_dtype.py index f03a64fa7f..64668eb514 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,7 +1,7 @@ import unittest import numpy as np from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor, dtypes from typing import Any, List diff --git a/test/test_jit.py b/test/test_jit.py index 8a43dfd6d3..8366244a15 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1,7 +1,7 @@ #!/usr/bin/env python import unittest import numpy as np -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit diff --git a/test/test_kernel_cache.py b/test/test_kernel_cache.py index 82a38f5319..96a38cf8f3 100644 --- a/test/test_kernel_cache.py +++ b/test/test_kernel_cache.py @@ -3,7 +3,7 @@ import unittest import secrets import string from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import diskcache def generate_random_string(length=16): diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index 45e222f93f..381c39ec80 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -2,7 +2,7 @@ import numpy as np import unittest from tinygrad.lazy import LazyBuffer -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor from tinygrad.shape.symbolic import Variable from tinygrad.jit import CacheCollector diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 3668598262..e11069d8d7 100644 --- 
a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -3,13 +3,14 @@ import unittest, os from tinygrad.codegen.kernel import Opt, OptOps, tensor_cores from tinygrad.codegen.linearizer import Linearizer, UOp, UOps -from tinygrad.ops import BufferOps, Compiled, ConstBuffer, Device, LazyOp, LoadOps, TernaryOps +from tinygrad.device import Compiled, Device +from tinygrad.ops import BufferOps, ConstBuffer, LazyOp, LoadOps, TernaryOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View from tinygrad.tensor import Tensor from tinygrad.jit import CacheCollector from tinygrad.realize import run_schedule -from tinygrad.helpers import dtypes, prod, getenv, CI +from tinygrad.helpers import dtypes, prod class TestLinearizer(unittest.TestCase): def test_arg_dedup(self): diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index cfb4caa966..424aaac73f 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -1,7 +1,7 @@ import unittest from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import Opt, OptOps -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import OSX, CI from test.external.fuzz_linearizer import run_linearizer diff --git a/test/test_ops.py b/test/test_ops.py index 4dc5d88ff3..a8726cdb8a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -5,7 +5,7 @@ import numpy as np import unittest from tinygrad.tensor import Tensor from tinygrad.helpers import getenv, IMAGE, DEBUG, CI, dtypes -from tinygrad.ops import Device +from tinygrad import Device if CI: import warnings diff --git a/test/test_sample.py b/test/test_sample.py index 9ed8ad8ced..66692f6a5a 100644 --- a/test/test_sample.py +++ b/test/test_sample.py @@ -1,7 +1,7 @@ import unittest import numpy as np from tinygrad.tensor import Tensor -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.shape.symbolic import Variable @unittest.skipUnless(Device.DEFAULT in ["GPU", "METAL", "CLANG", "CUDA", "LLVM"], f"{Device.DEFAULT} is not supported") diff --git a/test/test_schedule.py b/test/test_schedule.py index 9da8b578ed..d54ed41dd0 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -5,7 +5,8 @@ import unittest from typing import List, Optional from tinygrad.tensor import Tensor -from tinygrad.ops import LoadOps, Device, Compiled +from tinygrad.ops import LoadOps +from tinygrad.device import Device, Compiled from tinygrad.helpers import DEBUG, dtypes from tinygrad.codegen.linearizer import Linearizer from tinygrad.graph import log_schedule_item, print_tree diff --git a/test/test_search.py b/test/test_search.py index 400c04a575..3ec3a49cc9 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -2,7 +2,8 @@ import unittest from tinygrad.codegen.linearizer import Linearizer from tinygrad.features.search import time_linearizer -from tinygrad.ops import Compiled, Device, LoadOps +from tinygrad.device import Compiled, Device +from tinygrad.ops import LoadOps from tinygrad.tensor import Tensor class TestTimeLinearizer(unittest.TestCase): diff --git a/test/test_specific_conv.py b/test/test_specific_conv.py index ac28e6ccc2..992f7a8abf 100644 --- a/test/test_specific_conv.py +++ b/test/test_specific_conv.py @@ -1,7 +1,7 @@ import unittest from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes -from tinygrad.ops import Device +from tinygrad import Device import pytest # similar to test/external/external_test_gpu_ast.py, but 
universal diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index 0537d82a07..c360675503 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -9,7 +9,7 @@ torch.set_num_threads(1) import time import numpy as np np.set_printoptions(linewidth=160) -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.helpers import GlobalCounters from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d diff --git a/test/test_to_numpy.py b/test/test_to_numpy.py index 9acdcb54a4..a5a5c9c133 100644 --- a/test/test_to_numpy.py +++ b/test/test_to_numpy.py @@ -1,4 +1,4 @@ -from tinygrad.ops import Device +from tinygrad import Device from tinygrad.tensor import Tensor import numpy as np import pickle diff --git a/test/test_uops.py b/test/test_uops.py index 8749e89123..49d3120974 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -3,7 +3,8 @@ import unittest, math import numpy as np from tinygrad.helpers import dtypes, getenv, DType, PtrDType from tinygrad.tensor import Device -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, CompiledASTRunner, Compiled +from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps +from tinygrad.device import CompiledASTRunner, Compiled from tinygrad.codegen.linearizer import UOps, UOp def _uops_to_prg(uops): diff --git a/tinygrad/__init__.py b/tinygrad/__init__.py index 7538966865..0913ae165a 100644 --- a/tinygrad/__init__.py +++ b/tinygrad/__init__.py @@ -4,5 +4,5 @@ from tinygrad.shape.symbolic import Variable # noqa: F401 from tinygrad.helpers import dtypes # noqa: F401 # NOTE: these should not be relied on to be stable -from tinygrad.ops import Device # noqa: F401 +from tinygrad.device import Device # noqa: F401 from tinygrad.helpers import GlobalCounters # noqa: F401 \ No newline at end of file diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index cd40b0c36a..7e4387c2d4 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -2,7 +2,8 @@ from __future__ import annotations import os, math, itertools from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union from tinygrad.lazy import vars_from_ast -from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps, Device, Compiled +from tinygrad.ops import LazyOp, FlopCounter, get_lazyop_info, UnaryOps, BinaryOps, ReduceOps, MemBuffer, ConstBuffer, BufferOps +from tinygrad.device import Device, Compiled from tinygrad.helpers import dedup, dtypes, colored, ImageDType, DType, ansilen, getenv, prod, DEBUG, round_up from tinygrad.shape.shapetracker import ShapeTracker, get_contraction from tinygrad.shape.symbolic import sint diff --git a/tinygrad/device.py b/tinygrad/device.py new file mode 100644 index 0000000000..f7ea12e33c --- /dev/null +++ b/tinygrad/device.py @@ -0,0 +1,200 @@ +from __future__ import annotations +from typing import TYPE_CHECKING, Union, Type, Any, List, Optional, Dict, Callable +import importlib, inspect, functools, pathlib, time, re +from tinygrad.helpers import ansilen, DEBUG, getenv, GlobalCounters, colored, BEAM, NOOPT, all_int, to_function_name +from tinygrad.runtime.lib import RawBuffer +from tinygrad.shape.symbolic import Variable, sym_infer, sint +from tinygrad.ops import LazyOp, TernaryOps, get_lazyop_info, ReduceOps, BufferOps, BinaryOps, Op + +if TYPE_CHECKING: + from tinygrad.codegen.linearizer import Linearizer + from tinygrad.codegen.kernel import LinearizerOptions + +# **************** Device 
**************** + +class _Device: + def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] + def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT + @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none + def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: + x = x.split(":")[0].upper() + return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0] + @functools.cached_property + def DEFAULT(self) -> str: + device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) # type: ignore + if device_from_env: return device_from_env + for device in ["METAL", "CUDA", "GPU"]: + try: + if self[device]: return device + except Exception: pass + return "CPU" +Device = _Device() + +# **************** shared device helpers **************** + +class JITRunner: + def __init__(self): + self.op_estimate, self.mem_estimate = 0, 0 + def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]: + var_vals = var_vals if var_vals is not None else {} + from tinygrad.jit import CacheCollector + et = self(rawbufs, var_vals) + CacheCollector.add(self, rawbufs, var_vals) + return et + def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: + raise NotImplementedError("override this") + +def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra: Optional[Dict]=None): + if var_vals is None: var_vals = {} + op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals) + if DEBUG >= 2: + print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + + (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) + GlobalCounters.kernel_count += num_kernels + GlobalCounters.global_ops += op_estimate + GlobalCounters.global_mem += mem_estimate + if et is not None: GlobalCounters.time_sum_s += et + +# **************** for Interpreted Buffers **************** + +class InterpretedASTRunner(JITRunner): + def __init__(self, ast:LazyOp, fxn:Callable): + super().__init__() + self.fxn = fxn + info = get_lazyop_info(ast) + self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate + + def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float: + st = time.perf_counter() + ret: RawBuffer = self.fxn(rawbufs[1:], var_vals) + et = time.perf_counter() - st + update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit) + assert rawbufs[0].dtype == ret.dtype, 
f"dtype mismatch in Interpreted, {rawbufs[0].dtype=} != {ret.dtype=}" + rawbufs[0].dtype, rawbufs[0].size, rawbufs[0]._buf, rawbufs[0].offset = ret.dtype, ret.size, ret._buf, ret.offset + return et + +class Interpreted: + def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable]): + self.buffer, self.fxn_for_op = buffer, fxn_for_op + self.synchronize, self.codegen, self.graph = lambda: None, None, None + + @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none + def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast) + +def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner: + if DEBUG >= 3: + from tinygrad.graph import print_tree + print_tree(ast) + tglob: Dict[str, Any] = {"Variable": Variable} + lines: List[str] = [] + + @functools.lru_cache(None) + def gstr(x:Any, nm=None) -> str: + if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg): + str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg) + # TODO: (Variable - Variable) might create NumNode. can we remove it? + return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg) + ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}" + tglob[ret] = x + return ret + + @functools.lru_cache(None) + def _interpret_ast(ast:LazyOp) -> str: + if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: + ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) + + if ast.op in BufferOps: + tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])" + for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})" + else: + tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})" + + ret = f"a{len(lines)}" + lines.append(f" {ret} = {tmp}") + return ret + + ret = _interpret_ast(ast) + src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(fxn_for_op[BufferOps.FROM_UNDERLYING], BufferOps.FROM_UNDERLYING)}({ret})" if BufferOps.FROM_UNDERLYING in fxn_for_op else f" return {ret}"]) + if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src)) + exec(compile(src, "", "exec"), tglob) # pylint: disable=exec-used + return InterpretedASTRunner(ast, tglob['run']) + +# **************** for Compiled Buffers **************** + +class CompiledASTRunner(JITRunner): + def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None): + super().__init__() + if DEBUG >= 4: print(prg) + if global_size is not None: global_size = global_size + [1]*(3-len(global_size)) + if local_size is not None: local_size = local_size + [1]*(3-len(local_size)) + self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \ + to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {} + self.vars: List[Variable] = [] + if ast: + info = get_lazyop_info(ast) + self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate + from tinygrad.lazy import vars_from_ast + self.vars = vars_from_ast(ast) + assert all(v._val is None for v in 
self.vars), f"ASTRunner contains bound Variable {self.vars}" + + def build(self, compiler, runtime): + self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg) + self.clprg = runtime(self.name, self.lib) + return self + + def launch_dims(self, var_vals): + global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size + local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size + return global_size, local_size + + def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: + global_size, local_size = self.launch_dims(var_vals) + if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type] + # TODO: this is copied from get_program + from tinygrad.features.search import optimize_local_size + local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs) + global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] + lra = self.runtime_args.copy() + if global_size: lra['global_size'] = global_size + if local_size and 'local_size' not in lra: lra['local_size'] = local_size + et = self.clprg(*rawbufs, *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2) + update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra) + return et + +class Compiled: + def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, graph=None): + self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.graph = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, graph + + def to_program(self, k:Linearizer) -> CompiledASTRunner: + k.linearize() + src, runtime_args = self.renderer(to_function_name(k.name), k.uops) + return CompiledASTRunner(k.ast, k.name, src, k.global_size, k.local_size, runtime_args).build(self.compiler, self.runtime) + + @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none + def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(_get_optimized_linearizer(self.linearizer_opts, ast)) + +def _get_optimized_linearizer(linearizer_opts:LinearizerOptions, ast:LazyOp) -> Linearizer: + if DEBUG >= 3: + from tinygrad.graph import print_tree + print_tree(ast) + from tinygrad.codegen.linearizer import Linearizer + k = Linearizer(ast, linearizer_opts) + if not NOOPT: + if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations() + if BEAM >= 1: + lins = [(("tc" if used_tensor_cores else "hc"), k)] + kb = Linearizer(ast, linearizer_opts) + kb.required_optimizations() + from tinygrad.features.search import beam_search, time_linearizer, bufs_from_lin + # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions + test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization + lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))))) + if used_tensor_cores: + lins.append(("hc", Linearizer(ast, linearizer_opts))) + lins[-1][1].hand_coded_optimizations() + timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2]) + if DEBUG >= 1: print(" < 
".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed)) + k = timed[0][1] + else: + k.required_optimizations() + return k \ No newline at end of file diff --git a/tinygrad/features/search.py b/tinygrad/features/search.py index b7edd367c4..0bcf6f945b 100644 --- a/tinygrad/features/search.py +++ b/tinygrad/features/search.py @@ -1,7 +1,8 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable import itertools, random, math, time from tinygrad.lazy import vars_from_ast -from tinygrad.ops import Device, Compiled, MemBuffer +from tinygrad.device import Device, Compiled +from tinygrad.ops import MemBuffer from tinygrad.helpers import prod, ImageDType, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, all_int, colored, Timing from tinygrad.codegen.linearizer import Linearizer, UOp from tinygrad.runtime.lib import RawBuffer diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 7b93806ec4..8118a022df 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -2,7 +2,8 @@ from __future__ import annotations from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic import functools, itertools, operator from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int -from tinygrad.ops import RawBuffer, Device, JITRunner, CompiledASTRunner +from tinygrad.device import Device, JITRunner, CompiledASTRunner +from tinygrad.runtime.lib import RawBuffer from tinygrad.tensor import Tensor from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import Variable, NumNode, Node diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 63b34d4e64..a1adf348af 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -4,7 +4,7 @@ from typing import Dict, Union, List, Optional, Any, Tuple from tinygrad.tensor import Tensor from tinygrad.helpers import dtypes, prod, argsort, DEBUG, Timing, GlobalCounters, CI, unwrap from tinygrad.shape.view import strides_for_shape -from tinygrad.ops import Device +from tinygrad import Device safe_dtypes = {"F16": dtypes.float16, "F32": dtypes.float32, "U8": dtypes.uint8, "I8": dtypes.int8, "I32": dtypes.int32, "I64": dtypes.int64} inverse_safe_dtypes = {v:k for k,v in safe_dtypes.items()} diff --git a/tinygrad/ops.py b/tinygrad/ops.py index c7829e5f47..3c631449a9 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1,10 +1,9 @@ from __future__ import annotations -import importlib, inspect, functools, pathlib, time, re +from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Dict, Callable, Mapping +import functools from enum import Enum, auto -from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, Mapping -from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, BEAM, NOOPT, dedup, all_int, to_function_name -from tinygrad.runtime.lib import RawBuffer -from tinygrad.shape.symbolic import Variable, sym_infer, sint +from tinygrad.helpers import prod, DType, dedup +from tinygrad.shape.symbolic import Variable from dataclasses import dataclass # these are the llops your accelerator must implement, along with toCpu @@ -26,8 +25,6 @@ OpType = Union[Type[UnaryOps], Type[BinaryOps], Type[ReduceOps], Type[MovementOp if TYPE_CHECKING: from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.lazy import LazyBuffer - from tinygrad.codegen.linearizer import Linearizer - from tinygrad.codegen.kernel import LinearizerOptions @dataclass(frozen=True) 
class MemBuffer: @@ -83,26 +80,6 @@ class LazyOp: def shrink(self, _): raise NotImplementedError def stride(self, _): raise NotImplementedError -# **************** Device **************** - -class _Device: - def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")] - def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT - @functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none - def __getitem__(self, x:str) -> Union[Interpreted, Compiled]: - x = x.split(":")[0].upper() - return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0] - @functools.cached_property - def DEFAULT(self) -> str: - device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) # type: ignore - if device_from_env: return device_from_env - for device in ["METAL", "CUDA", "GPU"]: - try: - if self[device]: return device - except Exception: pass - return "CPU" -Device = _Device() - # **************** independent FlopCounter **************** @dataclass @@ -130,174 +107,3 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: @functools.lru_cache(None) # NOTE: this cache needs to be recreated for new ASTs def run_ast(ast): return InterpretedFlopCounter[ast.op](*([run_ast(x) for x in ast.src]+([ast.arg] if ast.arg is not None else []))) return run_ast(ast) - -# **************** GlobalCounters stats **************** - -def update_stats(name:str, op_estimate:sint, mem_estimate:sint, var_vals: Optional[Dict[Variable, int]], et: Optional[float], buf_count, jit=False, num_kernels=1, lra=None): - if var_vals is None: var_vals = {} - op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(mem_estimate, var_vals) - if DEBUG >= 2: - print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " + - (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)")) - GlobalCounters.kernel_count += num_kernels - GlobalCounters.global_ops += op_estimate - GlobalCounters.global_mem += mem_estimate - if et is not None: GlobalCounters.time_sum_s += et - -# **************** shared Runner that can go in the JIT **************** - -class JITRunner: - def __init__(self): - self.op_estimate, self.mem_estimate = 0, 0 - def exec(self, rawbufs:List[RawBuffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]: - var_vals = var_vals if var_vals is not None else {} - from tinygrad.jit import CacheCollector - et = self(rawbufs, var_vals) - CacheCollector.add(self, rawbufs, var_vals) - return et - def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: - raise NotImplementedError("override this") - -# **************** for Interpreted 
Buffers **************** - -class InterpretedASTRunner(JITRunner): - def __init__(self, ast:LazyOp, fxn:Callable): - super().__init__() - self.fxn = fxn - info = get_lazyop_info(ast) - self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate - - def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> float: - st = time.perf_counter() - ret: RawBuffer = self.fxn(rawbufs[1:], var_vals) - et = time.perf_counter() - st - update_stats(f"", self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit) - assert rawbufs[0].dtype == ret.dtype, f"dtype mismatch in Interpreted, {rawbufs[0].dtype=} != {ret.dtype=}" - rawbufs[0].dtype, rawbufs[0].size, rawbufs[0]._buf, rawbufs[0].offset = ret.dtype, ret.size, ret._buf, ret.offset - return et - -class Interpreted: - def __init__(self, buffer: Type[RawBuffer], fxn_for_op:Dict[Op, Callable]): - self.buffer, self.fxn_for_op = buffer, fxn_for_op - self.synchronize, self.codegen, self.graph = lambda: None, None, None - - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def get_runner(self, ast:LazyOp) -> InterpretedASTRunner: return _get_interpreted_fxn(self.fxn_for_op, ast) - -def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner: - if DEBUG >= 3: - from tinygrad.graph import print_tree - print_tree(ast) - tglob: Dict[str, Any] = {"Variable": Variable} - lines: List[str] = [] - - @functools.lru_cache(None) - def gstr(x:Any, nm=None) -> str: - if ('Variable' in (str_arg := repr(x)) or 'NumNode' in str_arg): - str_arg = re.sub(r'Variable\(.*?\)', lambda m: f'var_vals[{str(m.group(0))}]', str_arg) - # TODO: (Variable - Variable) might create NumNode. can we remove it? - return re.sub(r'NumNode\((.*?)\)', r'\1', str_arg) - ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}" - tglob[ret] = x - return ret - - @functools.lru_cache(None) - def _interpret_ast(ast:LazyOp) -> str: - if TernaryOps.MULACC in fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL: - ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg) - - if ast.op in BufferOps: - tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})" if ast.op == BufferOps.CONST else f"{gstr(fxn_for_op[ast.op], ast.op)}(inputs[{ast.arg.idx-1}])" - for mop,arg in ast.arg.st.to_movement_ops(): tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})" - else: - tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})" - - ret = f"a{len(lines)}" - lines.append(f" {ret} = {tmp}") - return ret - - ret = _interpret_ast(ast) - src = '\n'.join(['def run(inputs, var_vals):'] + lines + [f" return {gstr(fxn_for_op[BufferOps.FROM_UNDERLYING], BufferOps.FROM_UNDERLYING)}({ret})" if BufferOps.FROM_UNDERLYING in fxn_for_op else f" return {ret}"]) - if DEBUG >= 4: print(functools.reduce(lambda x,y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x), tglob.items(), src)) - exec(compile(src, "", "exec"), tglob) # pylint: disable=exec-used - return InterpretedASTRunner(ast, tglob['run']) - -# **************** for Compiled Buffers **************** - -class CompiledASTRunner(JITRunner): - def __init__(self, ast:Optional[LazyOp], name:str, prg:str, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, runtime_args:Optional[dict]=None): - super().__init__() - if DEBUG >= 4: print(prg) - if global_size is not None: 
global_size = global_size + [1]*(3-len(global_size)) - if local_size is not None: local_size = local_size + [1]*(3-len(local_size)) - self.name, self.display_name, self.prg, self.global_size, self.local_size, self.runtime_args = \ - to_function_name(name), name, prg, global_size, local_size, runtime_args if runtime_args is not None else {} - self.vars: List[Variable] = [] - if ast: - info = get_lazyop_info(ast) - self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate - from tinygrad.lazy import vars_from_ast - self.vars = vars_from_ast(ast) - assert all(v._val is None for v in self.vars), f"ASTRunner contains bound Variable {self.vars}" - - def build(self, compiler, runtime): - self.lib = compiler.__wrapped__(self.prg) if getenv("DISABLE_COMPILER_CACHE") else compiler(self.prg) - self.clprg = runtime(self.name, self.lib) - return self - - def launch_dims(self, var_vals): - global_size = [sym_infer(sz, var_vals) for sz in self.global_size] if self.global_size is not None else self.global_size - local_size = [sym_infer(sz, var_vals) for sz in self.local_size] if self.local_size is not None else self.local_size - return global_size, local_size - - def __call__(self, rawbufs:List[RawBuffer], var_vals:Dict[Variable, int], wait=False, jit=False) -> Optional[float]: - global_size, local_size = self.launch_dims(var_vals) - if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type] - # TODO: this is copied from get_program - from tinygrad.features.search import optimize_local_size - local_size = self.local_size = optimize_local_size(self.clprg, global_size, rawbufs) - global_size = self.global_size = [g//l if g%l == 0 else g/l for g,l in zip(global_size, local_size)] - lra = self.runtime_args.copy() - if global_size: lra['global_size'] = global_size - if local_size and 'local_size' not in lra: lra['local_size'] = local_size - et = self.clprg(*rawbufs, *[var_vals[k] for k in self.vars], **lra, wait=wait or DEBUG>=2) - update_stats(self.display_name, self.op_estimate, self.mem_estimate, var_vals, et, len(rawbufs), jit, lra=lra) - return et - -class Compiled: - def __init__(self, buffer: Type[RawBuffer], linearizer_opts:LinearizerOptions, renderer, compiler, runtime, synchronize=lambda: None, graph=None): - self.buffer, self.linearizer_opts, self.renderer, self.compiler, self.runtime, self.synchronize, self.graph = buffer, linearizer_opts, renderer, compiler, runtime, synchronize, graph - - def to_program(self, k:Linearizer) -> CompiledASTRunner: - k.linearize() - src, runtime_args = self.renderer(to_function_name(k.name), k.uops) - return CompiledASTRunner(k.ast, k.name, src, k.global_size, k.local_size, runtime_args).build(self.compiler, self.runtime) - - @functools.lru_cache(None) # pylint: disable=method-cache-max-size-none - def get_runner(self, ast:LazyOp) -> CompiledASTRunner: return self.to_program(_get_optimized_linearizer(self.linearizer_opts, ast)) - -def _get_optimized_linearizer(linearizer_opts:LinearizerOptions, ast:LazyOp) -> Linearizer: - if DEBUG >= 3: - from tinygrad.graph import print_tree - print_tree(ast) - from tinygrad.codegen.linearizer import Linearizer - k = Linearizer(ast, linearizer_opts) - if not NOOPT: - if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations() - if BEAM >= 1: - lins = [(("tc" if used_tensor_cores else "hc"), k)] - kb = Linearizer(ast, linearizer_opts) - kb.required_optimizations() - from tinygrad.features.search import beam_search, time_linearizer, 
bufs_from_lin - # TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions - test_rawbuffers = bufs_from_lin(kb) # allocate scratch buffers for optimization - lins.append((f"beam{BEAM.value}", beam_search(kb, test_rawbuffers, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1))))) - if used_tensor_cores: - lins.append(("hc", Linearizer(ast, linearizer_opts))) - lins[-1][1].hand_coded_optimizations() - timed = sorted([(nm, tk, time_linearizer(tk, test_rawbuffers, allow_test_size=False, clear_l2=True)) for nm, tk in lins], key=lambda x: x[2]) - if DEBUG >= 1: print(" < ".join(f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us" for nm, lin, tm in timed)) - k = timed[0][1] - else: - k.required_optimizations() - return k diff --git a/tinygrad/realize.py b/tinygrad/realize.py index a6805b4d00..ec236c86a8 100644 --- a/tinygrad/realize.py +++ b/tinygrad/realize.py @@ -1,6 +1,7 @@ from typing import List, cast, Dict, Callable import numpy as np -from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, BufferOps +from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, BufferOps +from tinygrad.device import Device from tinygrad.graph import log_schedule_item, print_tree from tinygrad.lazy import LazyBuffer from tinygrad.helpers import DEBUG, prod, all_int, IMAGE diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index 59116a9333..470a784de3 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,6 +1,6 @@ import time, ctypes, subprocess, platform, functools, pathlib, tempfile from typing import Any -from tinygrad.ops import Compiled +from tinygrad.device import Compiled from tinygrad.helpers import diskcache from tinygrad.runtime.lib import RawMallocBuffer from tinygrad.codegen.kernel import LinearizerOptions diff --git a/tinygrad/runtime/ops_cpu.py b/tinygrad/runtime/ops_cpu.py index 98cc2032b4..a5230440c5 100644 --- a/tinygrad/runtime/ops_cpu.py +++ b/tinygrad/runtime/ops_cpu.py @@ -1,7 +1,8 @@ import numpy as np from typing import Callable, Dict, Tuple, Optional from tinygrad.helpers import dtypes, DType -from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op, Interpreted +from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, ReduceOps, TernaryOps, Op +from tinygrad.device import Interpreted from tinygrad.runtime.lib import RawBuffer class RawNumpyBuffer(RawBuffer): diff --git a/tinygrad/runtime/ops_cuda.py b/tinygrad/runtime/ops_cuda.py index b5be0a758c..2359d1f015 100644 --- a/tinygrad/runtime/ops_cuda.py +++ b/tinygrad/runtime/ops_cuda.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple import numpy as np from pycuda.compiler import compile as cuda_compile from tinygrad.helpers import DEBUG, getenv, colored, diskcache -from tinygrad.ops import Compiled +from tinygrad.device import Compiled from tinygrad.runtime.lib import RawBufferCopyInOut, RawMallocBuffer, LRUAllocator from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cuda import CUDARenderer diff --git a/tinygrad/runtime/ops_disk.py b/tinygrad/runtime/ops_disk.py index 15181a43f1..ef9c4b2951 100644 --- a/tinygrad/runtime/ops_disk.py +++ b/tinygrad/runtime/ops_disk.py @@ -5,7 +5,8 @@ from typing import Optional from typing import Callable, Dict, Tuple from tinygrad.helpers import prod, all_int, DType, OSX from tinygrad.runtime.lib import RawBufferMapped -from tinygrad.ops import Interpreted, Op, MovementOps, UnaryOps, BufferOps +from tinygrad.device import 
Interpreted +from tinygrad.ops import Op, MovementOps, UnaryOps, BufferOps from tinygrad.shape.view import strides_for_shape MAP_LOCKED, MAP_POPULATE = 0x2000, 0x008000 diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py index 26f790da4d..8c647a0468 100644 --- a/tinygrad/runtime/ops_gpu.py +++ b/tinygrad/runtime/ops_gpu.py @@ -6,7 +6,7 @@ import numpy as np import pyopencl as cl from typing import Optional, List, Tuple from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache -from tinygrad.ops import Compiled +from tinygrad.device import Compiled from tinygrad.renderer.opencl import OpenCLRenderer from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer from tinygrad.codegen.kernel import LinearizerOptions diff --git a/tinygrad/runtime/ops_hip.py b/tinygrad/runtime/ops_hip.py index 7f0547b1fe..653668dc73 100644 --- a/tinygrad/runtime/ops_hip.py +++ b/tinygrad/runtime/ops_hip.py @@ -3,7 +3,7 @@ import ctypes import extra.hip_wrapper as hip from typing import Tuple, List, Any, Dict, cast, Optional, Callable from tinygrad.helpers import DEBUG, getenv, diskcache -from tinygrad.ops import Compiled, CompiledASTRunner, update_stats +from tinygrad.device import Compiled, CompiledASTRunner, update_stats from tinygrad.renderer.hip import HIPRenderer from tinygrad.runtime.lib import RawBufferCopyInOut, LRUAllocator, RawBufferTransfer, RawBuffer from tinygrad.codegen.kernel import LinearizerOptions diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 708e6b827d..2ac9e4e448 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -1,6 +1,6 @@ import time, ctypes from typing import ClassVar -from tinygrad.ops import Compiled +from tinygrad.device import Compiled from tinygrad.helpers import getenv, DEBUG, diskcache from ctypes import CFUNCTYPE from tinygrad.codegen.kernel import LinearizerOptions diff --git a/tinygrad/runtime/ops_metal.py b/tinygrad/runtime/ops_metal.py index 85fbc9e0d7..45ff3bbd6b 100644 --- a/tinygrad/runtime/ops_metal.py +++ b/tinygrad/runtime/ops_metal.py @@ -3,7 +3,7 @@ import Metal, libdispatch from typing import List, Any, Tuple, Dict, cast, Optional from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.helpers import prod, getenv, DEBUG, DType, dtypes, diskcache, dedup -from tinygrad.ops import Compiled, CompiledASTRunner, update_stats +from tinygrad.device import Compiled, CompiledASTRunner, update_stats from tinygrad.renderer.metal import MetalRenderer from tinygrad.runtime.lib import RawBufferMapped, RawBuffer, LRUAllocator from tinygrad.shape.symbolic import Variable diff --git a/tinygrad/runtime/ops_torch.py b/tinygrad/runtime/ops_torch.py index de7e44295f..14f38f139c 100644 --- a/tinygrad/runtime/ops_torch.py +++ b/tinygrad/runtime/ops_torch.py @@ -1,7 +1,8 @@ import torch import numpy as np from typing import Dict, Callable, Optional -from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op, Interpreted +from tinygrad.ops import BufferOps, UnaryOps, BinaryOps, MovementOps, TernaryOps, ReduceOps, Op +from tinygrad.device import Interpreted from tinygrad.helpers import getenv, dtypes, prod, DType from tinygrad.runtime.ops_cpu import einsum_mulacc, shape_to_axis from tinygrad.runtime.lib import RawBuffer diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py index 319b2018ee..12887c65ed 100644 --- a/tinygrad/runtime/ops_webgpu.py +++ b/tinygrad/runtime/ops_webgpu.py @@ -3,7 
+3,7 @@ import functools from wgpu.utils.device import get_default_device from tinygrad.runtime.lib import RawBufferCopyIn, LRUAllocator from tinygrad.helpers import dtypes, DType -from tinygrad.ops import Compiled +from tinygrad.device import Compiled from tinygrad.codegen.kernel import LinearizerOptions from tinygrad.renderer.cstyle import uops_to_cstyle from tinygrad.renderer.wgsl import WGSLLanguage diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 40ad59765b..74862df718 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -9,7 +9,8 @@ import numpy as np from tinygrad.helpers import ImageDType, argfix, make_pair, getenv, IMAGE, DEBUG, flatten, DType, dtypes, prod, all_int, round_up from tinygrad.lazy import LazyBuffer -from tinygrad.ops import Device, LoadOps +from tinygrad.ops import LoadOps +from tinygrad.device import Device from tinygrad.shape.symbolic import sint from tinygrad.realize import run_schedule
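
Note (illustrative, not part of the patch): after this change tinygrad/__init__.py re-exports Device from the new tinygrad.device module, while internal code imports Compiled and Interpreted from tinygrad.device directly. The sketch below only restates the import surface visible in the diff above (the CI check and the Device[Device.DEFAULT] lookup from fuzz_linearizer.py); it assumes a tinygrad checkout at this commit.

    # illustrative sketch of the post-patch import style
    from tinygrad import Device                         # re-exported via tinygrad/__init__.py
    from tinygrad.device import Compiled, Interpreted   # moved here from tinygrad.ops

    print(Device.DEFAULT)                 # e.g. "METAL", "CUDA", "GPU", or "CPU"
    backend = Device[Device.DEFAULT]      # per _Device.__getitem__, an Interpreted or Compiled backend
    assert isinstance(backend, (Interpreted, Compiled))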