move graph.py and jit.py into features (#3376)

* move graph.py into features

* move jit into features

* fix quickstart

Author: George Hotz
Date: 2024-02-12 17:34:34 +01:00 (committed by GitHub)
Parent: 0f6cde243d
Commit: 41efaa848c
41 changed files with 77 additions and 47 deletions


@@ -277,7 +277,7 @@ result = Tensor(2.0) + Tensor(3.0)
# we have a global cache used by the JIT
# from there, we can see the generated clang code
-from tinygrad.jit import CacheCollector
+from tinygrad.features.jit import CacheCollector
CacheCollector.start() # enables the cache
result.realize() # creates the program and runs it
cache_saved = CacheCollector.finish() # disable the cache
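
For reference, a self-contained version of the snippet above using the new import path (an editor's sketch; the cache-inspection lines at the end rest on an assumption about the JitItem layout of this era and are not shown in the diff):

```python
from tinygrad import Tensor
from tinygrad.features.jit import CacheCollector  # new location after this commit

CacheCollector.start()                   # enable the global JIT cache
result = (Tensor(2.0) + Tensor(3.0)).realize()
cache_saved = CacheCollector.finish()    # disable the cache and get the captured kernels

print(len(cache_saved))                  # number of captured kernel invocations
for ji in cache_saved:
  # assumption: each JitItem keeps its runner in `ji.prg`, and on compiled backends
  # the runner stores its generated source string in `.prg`
  print(getattr(ji.prg, "prg", ji.prg))
```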


@@ -91,7 +91,7 @@ sched = out.schedule()
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
# DEBUGGING: print the compute ast as a tree
-from tinygrad.graph import print_tree
+from tinygrad.features.graph import print_tree
print_tree(sched[-1].ast)
# NOTE: sched[-1].ast is the same as st_0 above
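
A minimal stand-alone version of that debugging flow with the new import path (an editor's sketch; the tensors are arbitrary):

```python
from tinygrad import Tensor
from tinygrad.features.graph import print_tree  # new location after this commit

out = Tensor([1.0, 2.0]) + Tensor([3.0, 4.0])
sched = out.schedule()       # list of ScheduleItems, as in the snippet above
print_tree(sched[-1].ast)    # print the compute AST of the last kernel as a tree
```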

docs/linearizer_v2.md (new file, 29 additions)

@@ -0,0 +1,29 @@
At its base, the Linearizer is a function that takes an AST + opts -> uops.
It should be rewritten like this. The AST can't be a LazyOp, because it should be able to have multiple outputs.
We need a generic class to represent DAGs.
This refactor is probably a prereq for the new linearizer, and can be used on existing uops also.
Can this class also represent the large graph? The op graph is a subset of the large graph.
Currently the Linearizer is merging many concerns:
1. LocalBuffers are added. These should be added to the upper DAG, for both grouping and tensor cores. Some opts are used here. NOTE: currently reduce splitting is done in lazy.py and it shouldn't be
2. The ShapeTrackers at the edges are collected and modified according to the other opts.
3. The Ops are toposorted.
4. The Ops are lowered to UOps. This requires expansion and loop assignment, potentially to global dimensions
5. The indexes into the Tensor are computed from the shapetrackers
More generically, the whole network is a DAG. Ignore the forward/backward stuff, I'm fine with starting at the LazyBuffer level.
1. Is it possible to put an entire network in a single kernel? I think the answer has to be yes, but you may end up doing an absolutely crazy amount of recomputation. This should still be doable to check correctness.
2. You can use intermediate buffers, be they local or global, to do less compute.
This is a rewrite of a lot of tinygrad. I don't think continuing to support Interpreted backends is worth it, though we'll have to deal with disk in a smart way.
We keep the frontend: tensor.py + mlops.py + lazy.py
We keep the backend (renderer/runtime): cstyle.py + device.py + ops_*.py
We keep the shapetracker/symbolic: shapetracker.py + view.py + symbolic.py
We keep the features and nn stuff.
But codegen is all rewritten.
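
To make the "generic class to represent DAGs" idea above concrete, here is a hypothetical sketch (editorial, not part of the commit); every name in it is invented for illustration:

```python
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, List

@dataclass(frozen=True)
class DAGNode:
  """Hypothetical generic DAG node: an op plus its inputs.
  Multiple outputs fall out naturally, since several nodes can share the same sources."""
  op: Any                         # e.g. a BinaryOps/ReduceOps member
  srcs: tuple[DAGNode, ...] = ()
  arg: Any = None

def toposort(outputs: List[DAGNode]) -> List[DAGNode]:
  # depth-first toposort of the DAG, visiting each node exactly once
  seen, order = set(), []
  def visit(n: DAGNode):
    if id(n) in seen: return
    seen.add(id(n))
    for s in n.srcs: visit(s)
    order.append(n)
  for o in outputs: visit(o)
  return order
```

The same structure could in principle serve both the per-kernel op graph and the larger schedule graph, which is what the note above asks for.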


@@ -247,7 +247,7 @@ To use the JIT we just need to add a function decorator to the forward pass of o
Or, as in this case, we can create a wrapper function and decorate it to speed up the evaluation of our neural network.
```python
-from tinygrad.jit import TinyJit
+from tinygrad import TinyJit
@TinyJit
def jit(x):
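# --- editor's sketch (not part of the diff): the decorated wrapper described above
# usually looks like this; `net` is a toy stand-in for the network built earlier in the tutorial.
from tinygrad import Tensor, TinyJit   # TinyJit is re-exported from the top-level package after this change
from tinygrad.nn import Linear

net = Linear(784, 10)                  # illustrative model only

@TinyJit
def jit(x: Tensor) -> Tensor:
  return net(x).realize()

for _ in range(3):                     # the JIT starts capturing kernels from the second call
  out = jit(Tensor.randn(1, 784))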


@@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters
from tinygrad.nn import optim
from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import getenv
-from tinygrad.jit import CacheCollector
+from tinygrad.features.jit import CacheCollector
def tensors_allocated():
return sum(isinstance(x, Tensor) for x in gc.get_objects())


@@ -9,7 +9,7 @@ import numpy as np
from PIL import Image
from tinygrad.tensor import Tensor
from tinygrad.helpers import getenv, fetch, Timing
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from extra.models.efficientnet import EfficientNet
np.set_printoptions(suppress=True)


@@ -3,7 +3,7 @@ start = time.perf_counter()
from pathlib import Path
import numpy as np
from tinygrad import Tensor, Device, dtypes, GlobalCounters
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad.nn.state import get_parameters, load_state_dict, safe_load
from tinygrad.helpers import getenv, Timing
from examples.mlperf import helpers
@@ -103,7 +103,7 @@ def eval_retinanet():
coco_eval = COCOeval(coco, iouType="bbox")
coco_evalimgs, evaluated_imgs, ncats, narea = [], [], len(coco_eval.params.catIds), len(coco_eval.params.areaRng)
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
mdlrun = TinyJit(lambda x: mdl(input_fixup(x)).realize())
n, bs = 0, 8


@@ -13,7 +13,7 @@ from tinygrad import Device, GlobalCounters, dtypes, Tensor
from tinygrad.helpers import Timing, Context, getenv, fetch, colored
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
class AttnBlock:
def __init__(self, in_channels):


@@ -9,7 +9,7 @@ from tinygrad import nn, dtypes
from tinygrad.helpers import fetch
from tinygrad.nn.state import torch_load
from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from unidecode import unidecode
LRELU_SLOPE = 0.1


@@ -6,7 +6,7 @@ import base64
import multiprocessing
import numpy as np
from typing import Optional, Union, Literal, List
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad.nn.state import torch_load, load_state_dict
from tinygrad.helpers import getenv, DEBUG, CI, fetch
import tinygrad.nn as nn


@@ -91,7 +91,7 @@ def pad_input(volume, roi_shape, strides, padding_mode="constant", padding_val=-
return F.pad(torch.from_numpy(volume), paddings, mode=padding_mode, value=padding_val).numpy(), paddings
def sliding_window_inference(model, inputs, labels, roi_shape=(128, 128, 128), overlap=0.5):
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
mdl_run = TinyJit(lambda x: model(x).realize())
image_shape, dim = list(inputs.shape[2:]), len(inputs.shape[2:])
strides = [int(roi_shape[i] * (1 - overlap)) for i in range(dim)]

extra/dist/world.py (vendored, 2 changes)

@@ -9,7 +9,7 @@ try:
from tinygrad.runtime.ops_hip import RawHIPBuffer, check
except: RawHIPBuffer = None
from tinygrad.runtime.ops_disk import RawDiskBuffer
-from tinygrad.jit import CacheCollector
+from tinygrad.features.jit import CacheCollector
from tinygrad.tensor import Tensor, Function
import numpy as np


@@ -1,7 +1,7 @@
from typing import Tuple, Dict, List
from tinygrad.dtype import DType
from tinygrad.tensor import Device, Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad.nn.state import get_state_dict
from tinygrad.dtype import dtypes
import json
@@ -133,7 +133,7 @@ def export_model_webgl(functions, statements, bufs, bufs_to_save, weight_names,
gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, textures[0].tex, 0);
gl.useProgram(program);
gl.uniform1i(gl.getUniformLocation(program, "width"), textures[0].width);
gl.uniform1i(gl.getUniformLocation(program, "width"), textures[0].width);
const vao = setupVertexData(gl, program, [-1, 1, 0, 1, -1, -1, 0, 0, 1, 1, 1, 1, 1, -1, 1, 0]);
gl.bindVertexArray(vao);
@@ -158,13 +158,13 @@ def export_model_webgl(functions, statements, bufs, bufs_to_save, weight_names,
function limitTextureDims(size, threshold) {{
if (size <= threshold) {{ return [size, 1] }};
for (let i = 2; i < threshold + 1; i++) {{
if ((size % i == 0) && (Math.floor(size / i) <= threshold)) {{
return [Math.floor(size / i), i];
}}
}}
return [size, 1];
}}
@@ -197,11 +197,11 @@ def export_model_webgl(functions, statements, bufs, bufs_to_save, weight_names,
const internalFormat = gl.RGBA;
const texSize = limitTextureDims(size, gl.getParameter(gl.MAX_TEXTURE_SIZE));
let weights;
if (tensorBuffer != null) {{
if (!isHalf)
weights = new Float32Array(tensorBuffer.buffer, tensorBuffer.byteOffset, tensorBuffer.byteLength / Float32Array.BYTES_PER_ELEMENT);
else
else
weights = new Uint16Array(tensorBuffer.buffer, tensorBuffer.byteOffset, tensorBuffer.byteLength / Uint16Array.BYTES_PER_ELEMENT);
}} else {{
if (!isHalf)


@@ -32,7 +32,7 @@ except RuntimeError:
print("no torch metal conv")
from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad import Device
b = Tensor(nb)
c = Tensor(nc)


@@ -116,7 +116,7 @@ tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator


@@ -4,7 +4,7 @@ import numpy as np
import time, torch, torch.mps
from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad import Device, GlobalCounters, dtypes
from tinygrad.helpers import colored, getenv, CI, flat_mv
@@ -108,7 +108,7 @@ metalalloc.copyout(flat_mv(metal_a.data), a)
np.testing.assert_allclose(metal_a, torch_a, atol=5e-3)
from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
b = Tensor(nb)
c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator


@@ -1,5 +1,5 @@
from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad.nn import Linear, Embedding
from tinygrad.helpers import fetch
import numpy as np


@@ -2,7 +2,7 @@ import numpy as np
from tqdm import trange
from tinygrad.tensor import Tensor
from tinygrad.helpers import CI
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
def train(model, X_train, Y_train, optim, steps, BS=128, lossfn=lambda out,y: out.sparse_categorical_crossentropy(y),


@@ -66,7 +66,7 @@ def benchmark_model(m, devices, validate_outs=False):
tinygrad_model = get_run_onnx(onnx_model)
benchmark(m, f"tinygrad_{device.lower()}_jitless", lambda: {k:v.numpy() for k,v in tinygrad_model(inputs).items()})
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
tinygrad_jitted_model = TinyJit(lambda **kwargs: {k:v.realize() for k,v in tinygrad_model(kwargs).items()})
for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821


@@ -1,5 +1,5 @@
from extra import dist
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
if __name__ == "__main__":
dist.preinit()


@@ -1,5 +1,5 @@
from extra import dist
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
if __name__ == "__main__":
dist.preinit()


@@ -2,7 +2,7 @@
import unittest
import numpy as np
from tinygrad import Tensor, dtypes
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad.helpers import CI
from test.helpers import derandomize_model


@@ -17,7 +17,7 @@ from tinygrad.helpers import getenv
from tinygrad.nn import optim
#from tinygrad.lazy import PUSH_PERMUTES
PUSH_PERMUTES = False
-from tinygrad.jit import CacheCollector
+from tinygrad.features.jit import CacheCollector
class CLCache:
def __init__(self, allowed=None, strict=False, preclear=True, var_vals=None):


@@ -3,7 +3,7 @@ import unittest
from tinygrad.tensor import Tensor
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.renderer.cstyle import OpenCLRenderer
-from tinygrad.graph import graph_uops
+from tinygrad.features.graph import graph_uops
from tinygrad.nn import Conv2d
class TestUopsGraph(unittest.TestCase):


@@ -6,7 +6,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_lin
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.tensor import Tensor
-from tinygrad.graph import print_tree
+from tinygrad.features.graph import print_tree
from tinygrad.helpers import getenv, from_mv, Context
from tinygrad.device import Device, Compiled, Interpreted
from tinygrad.codegen.linearizer import UOp


@@ -3,7 +3,7 @@ import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.nn import optim
from tinygrad.nn.state import get_parameters
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad import Device, GlobalCounters, dtypes
from tinygrad.helpers import CI
from tinygrad.shape.symbolic import Variable


@@ -87,7 +87,7 @@ class TestCustomFunction(unittest.TestCase):
@unittest.skipIf(Device.DEFAULT in ["CPU"], "atan2_cpu not jittable")
def test_atan2_jit(self):
# custom ops even work in the JIT!
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
@TinyJit
def jitted_atan2(a:Tensor, b:Tensor) -> Tensor:


@@ -4,7 +4,7 @@ import numpy as np
from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad.device import Device
from tinygrad.helpers import CI


@@ -9,7 +9,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
from tinygrad.tensor import Tensor
-from tinygrad.jit import CacheCollector
+from tinygrad.features.jit import CacheCollector
from tinygrad.realize import run_schedule
from tinygrad.helpers import prod, Context
from tinygrad.dtype import DType, dtypes


@@ -9,7 +9,7 @@ from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.helpers import DEBUG, GRAPH
from tinygrad.codegen.linearizer import Linearizer
-from tinygrad.graph import print_tree, realized_lazybuffer
+from tinygrad.features.graph import print_tree, realized_lazybuffer
from tinygrad import nn, dtypes
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):


@@ -13,7 +13,7 @@ from tinygrad import Device, GlobalCounters
from tinygrad.tensor import Tensor
from tinygrad.nn import Conv2d
from tinygrad.helpers import colored, getenv, CI
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
IN_CHANS = [int(x) for x in getenv("IN_CHANS", "4,16,64").split(",")]


@@ -1,7 +1,7 @@
import unittest
from test.helpers import assert_jit_cache_len
-from tinygrad.jit import TinyJit
+from tinygrad.features.jit import TinyJit
from tinygrad.helpers import getenv
from tinygrad.shape.symbolic import Variable
from tinygrad.tensor import Tensor


@@ -1,5 +1,5 @@
from tinygrad.tensor import Tensor # noqa: F401
-from tinygrad.jit import TinyJit # noqa: F401
+from tinygrad.features.jit import TinyJit # noqa: F401
from tinygrad.shape.symbolic import Variable # noqa: F401
from tinygrad.dtype import dtypes # noqa: F401
from tinygrad.ops import GlobalCounters # noqa: F401


@@ -383,7 +383,7 @@ class Linearizer(Kernel):
for u in self.uops:
print(f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}") # noqa: E501
if getenv("GRAPHUOPS"):
-from tinygrad.graph import graph_uops
+from tinygrad.features.graph import graph_uops
graph_uops(self.uops)
# restore backups


@@ -43,7 +43,7 @@ class JITRunner:
def __init__(self): self.op_estimate, self.mem_estimate = 0, 0
def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]:
var_vals = var_vals if var_vals is not None else {}
-from tinygrad.jit import CacheCollector
+from tinygrad.features.jit import CacheCollector
et = self(rawbufs, var_vals)
CacheCollector.add(self, rawbufs, var_vals)
return et
@@ -194,7 +194,7 @@ class Interpreted:
def _get_interpreted_fxn(fxn_for_op:Dict[Op, Callable], ast:LazyOp) -> InterpretedASTRunner:
if DEBUG >= 3:
-from tinygrad.graph import print_tree
+from tinygrad.features.graph import print_tree
print_tree(ast)
tglob: Dict[str, Any] = {"Variable": Variable}
@@ -322,7 +322,7 @@ class Compiled:
def get_linearizer(self, ast:LazyOp) -> Linearizer:
if DEBUG >= 3:
-from tinygrad.graph import print_tree
+from tinygrad.features.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, self.compiler.linearizer_opts)


@@ -8,7 +8,7 @@ from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Bu
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
-from tinygrad.graph import log_lazybuffer
+from tinygrad.features.graph import log_lazybuffer
from weakref import ref, ReferenceType
# lazy can recurse a lot


@@ -1,7 +1,7 @@
from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, InterpretedASTRunner, Compiled, BufferOptions
-from tinygrad.graph import print_tree, realized_lazybuffer
+from tinygrad.features.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
from tinygrad.shape.symbolic import Variable


@@ -5,7 +5,8 @@ from tinygrad.helpers import init_c_var, encode_args_cuda_style, all_same, Graph
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.shape.symbolic import Variable
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
+from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, \
+  get_jc_idxs_with_updatable_launch_dims, get_jc_idxs_with_updatable_var_vals
class CUDAGraph:
def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):


@@ -3,7 +3,7 @@ import Metal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2, GraphException
from tinygrad.device import Buffer, CompiledASTRunner, update_stats
-from tinygrad.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
+from tinygrad.features.jit import JitItem, get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_metal import MetalDevice