move create schedule and delete old API (#3377)

* move create schedule and delete old API

* fix test multitensor
George Hotz
2024-02-12 18:10:45 +01:00
committed by GitHub
parent 41efaa848c
commit 2e60012bcf
21 changed files with 252 additions and 242 deletions

View File

@@ -250,7 +250,8 @@ result = Tensor(2.0).realize() + Tensor(3.0).realize()
# use the real Linearizer to linearize 2+3
from tinygrad.codegen.linearizer import Linearizer
sched = result.lazydata.schedule()
from tinygrad.realize import create_schedule
sched = create_schedule([result.lazydata])
linearizer = Linearizer(sched[-1].ast, ClangCompiler.linearizer_opts)
linearizer.linearize()
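As a follow-on sketch (assuming this commit's Linearizer API, where linearize() populates .uops, as the tests further down show), the linear IR for 2+3 can be inspected directly:

# print the linear IR the Linearizer produced for 2+3
for uop in linearizer.uops: print(uop)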

View File

@@ -73,7 +73,7 @@ assert out.as_buffer().cast('I')[0] == 5
print("******** third, the LazyBuffer ***********")
from tinygrad.lazy import LazyBuffer, LoadOps
from tinygrad.realize import run_schedule
from tinygrad.realize import run_schedule, create_schedule
# allocate some values + load in values
# TODO: remove numpy here
@@ -87,7 +87,7 @@ b.realized = Buffer("CPU", 1, dtypes.int32, np.array([3], np.int32).flatten())
out = a.e(BinaryOps.ADD, b)
# schedule the computation as a list of kernels
sched = out.schedule()
sched = create_schedule([out])
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
# DEBUGGING: print the compute ast as a tree

View File

@@ -20,10 +20,8 @@ More generically, the whole network is a DAG. Ignore the forward/backward stuff,
This is a rewrite of a lot of tinygrad. I don't think continuing to support Interpreted backends is worth it; we have to deal with disk in a smart way.
We keep the frontend: tensor.py + mlops.py + lazy.py
We keep the backend (renderer/runtime): cstyle.py + device.py + ops_*.py
We keep the shapetracker/symbolic: shapetracker.py + view.py + symbolic.py
We keep the features and nn stuff.
But codegen is all rewritten.
We keep the frontend (Tensor -> LazyBuffer): tensor.py + mlops.py + lazy.py
We keep the shapetracker/symbolic (part of the frontend): shapetracker.py + view.py + symbolic.py
Codegen is all rewritten.
We keep the backend (uops renderer/runtime): cstyle.py + device.py + ops_*.py
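To make the new layering concrete, here is a minimal sketch of the flow after this commit, using only APIs that appear in this diff:

from tinygrad.tensor import Tensor
from tinygrad.realize import create_schedule, run_schedule
# frontend: Tensor ops build a lazy graph of LazyBuffers
out = Tensor.rand(4, 4) + 1
# scheduling now lives in tinygrad.realize, not on the LazyBuffer
sched = create_schedule([out.lazydata])  # replaces out.lazydata.schedule()
# backend: lower and run each ScheduleItem
run_schedule(sched)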

View File

@@ -8,6 +8,7 @@ from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
from tinygrad.helpers import ansilen, DEBUG, getenv
from tinygrad.shape.symbolic import sym_infer
from tinygrad.dtype import dtypes
from tinygrad.realize import create_schedule
if __name__ == "__main__":
if getenv("HALF"):
@@ -21,12 +22,12 @@ if __name__ == "__main__":
print(f"optimizing for {Device.DEFAULT}")
# first model run to init the weights, they are saved in seen
mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)
create_schedule([mdl(Tensor.empty(64, 3, 224, 224)).lazydata], seen)
# run model again to get only what changes, these are the kernels of the model
x = Tensor.empty(64, 3, 224, 224)
out = mdl(x)
sched = out.lazydata.schedule(seen)
sched = create_schedule([out.lazydata], seen)
sched = [x for x in sched if x.ast.op not in LoadOps]
# focus on one kernel
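A hedged sketch of the `seen` pattern used above (the names here are illustrative, not from the file): outputs scheduled by the first create_schedule call are added to `seen`, so the second call returns only kernels that were not already scheduled, which is how the weight-init kernels get skipped.

seen = set()
create_schedule([warmup_out.lazydata], seen)   # warmup_out: hypothetical first-run output; its kernels land in seen
sched = create_schedule([real_out.lazydata], seen)  # real_out: hypothetical second-run output; only new kernels are returned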

View File

@@ -3,13 +3,14 @@ from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
from test.external.fuzz_linearizer import run_linearizer
from tinygrad.codegen.kernel import Opt, OptOps
from tinygrad.realize import create_schedule
N = 17**3
a = Tensor.rand(N, N)
b = Tensor.rand(N, N)
c = a @ b
sched = [si for si in c.lazydata.schedule() if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([c.lazydata]) if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
@@ -24,7 +25,7 @@ run_linearizer(lin)
###
a = Tensor.rand(61, 61).sum(axis=0)
sched = [si for si in a.lazydata.schedule() if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([a.lazydata]) if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)

View File

@@ -30,13 +30,14 @@ except ImportError:
import os
from tinygrad.tensor import Tensor
from tinygrad.realize import create_schedule
# define the compute
A = Tensor.rand(M, K, device="clang")
B = Tensor.rand(K, N, device="clang")
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
sched = C.lazydata.schedule()
sched = create_schedule([C.lazydata])
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions
lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))

View File

@@ -16,7 +16,7 @@ from extra.onnx import get_run_onnx
from tinygrad import Tensor, Device, GlobalCounters, dtypes
from tinygrad.dtype import ImageDType
from tinygrad.helpers import partition, Context, fetch, getenv, GRAPH, DEBUG
from tinygrad.realize import run_schedule, lower_schedule_item
from tinygrad.realize import run_schedule, lower_schedule_item, create_schedule
from tinygrad.ops import LoadOps, ScheduleItem
Device.DEFAULT = "GPU"
@@ -32,7 +32,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
# run the model
inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
schedule = ret.lazydata.schedule()
schedule = create_schedule([ret.lazydata])
# filter out schedule items that don't depend on the inputs
input_lb = [x.lazydata.base for x in inputs.values()]
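The comment above hints at a reachability pass over the schedule; a hedged reconstruction (not necessarily the file's exact code) looks like:

# an item depends on the inputs if any of its input buffers is an input
# base or the output of an already-dependent item
depends = set(input_lb)
run_sched, init_sched = [], []
for si in schedule:
  if any(b in depends for b in si.inputs):
    depends.add(si.out)
    run_sched.append(si)
  else:
    init_sched.append(si)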

View File

@@ -4,6 +4,7 @@ from tinygrad.tensor import Tensor
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.renderer.cstyle import OpenCLRenderer
from tinygrad.features.graph import graph_uops
from tinygrad.realize import create_schedule
from tinygrad.nn import Conv2d
class TestUopsGraph(unittest.TestCase):
@@ -11,7 +12,7 @@ class TestUopsGraph(unittest.TestCase):
N = 1024
a = Tensor.rand(N,N)
b = Tensor.rand(N,N)
si = (a@b).lazydata.schedule()[-1]
si = create_schedule([(a@b).lazydata])[-1]
lin = Linearizer(si.ast)
lin.hand_coded_optimizations()
print(lin.colored_shape())
@@ -22,7 +23,7 @@ class TestUopsGraph(unittest.TestCase):
def test_reduce(self):
a = Tensor.rand(1024*1024)
si = a.sum().lazydata.schedule()[-1]
si = create_schedule([a.sum().lazydata])[-1]
lin = Linearizer(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops
@@ -32,7 +33,7 @@ class TestUopsGraph(unittest.TestCase):
def test_conv(self):
x = Tensor.rand(1,3,16,16)
c = Conv2d(3, 16, (3,3))
si = c(x).elu().lazydata.schedule()[-1]
si = create_schedule([c(x).elu().lazydata])[-1]
lin = Linearizer(si.ast)
lin.hand_coded_optimizations()
uops = lin.linearize().uops

View File

@@ -3,6 +3,7 @@ import unittest
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.nn import Conv2d
from tinygrad.realize import create_schedule
class TestConvShapetracker(unittest.TestCase):
def test_conv_3x3_one_view(self):
@@ -10,9 +11,9 @@ class TestConvShapetracker(unittest.TestCase):
seen = set()
# first run to init the weights, they are saved in seen
conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen)
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
# run it again to get the kernels
sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen) if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op not in LoadOps]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
print(sched[0])
for arg in [sched[0].out, *sched[0].inputs]:

View File

@@ -6,6 +6,7 @@ import numpy as np
from hypothesis import given, strategies as strat, settings
from tinygrad.dtype import DType
from tinygrad.helpers import CI, getenv
from tinygrad.realize import create_schedule
from tinygrad.ops import UnaryOps, get_lazyop_info
from test.test_dtype import is_dtype_supported
@@ -64,7 +65,7 @@ def universal_test(a, b, dtype, op):
def universal_test_unary(a, dtype, op):
if not isinstance(op, tuple): op = (op, op)
out: Tensor = op[0](Tensor([a], dtype=dtype))
ast = out.lazydata.schedule()[-1].ast
ast = create_schedule([out.lazydata])[-1].ast
tensor_value = out.numpy()
numpy_value = op[1](np.array([a]).astype(dtype.np))
if dtype in dtypes_float:

View File

@@ -3,8 +3,7 @@ import time
import numpy as np
from tinygrad import Tensor, dtypes
from tinygrad.device import InterpretedASTRunner
from tinygrad.lazy import create_schedule
from tinygrad.realize import run_schedule, lower_schedule_item
from tinygrad.realize import run_schedule, create_schedule, lower_schedule_item
class TestFusionOp(unittest.TestCase):
def test_contiguous_add(self):

View File

@@ -1,5 +1,6 @@
import unittest
from tinygrad.tensor import Tensor
from tinygrad.realize import create_schedule
# stuff needed to unpack a kernel
# ruff: noqa: F401
@@ -16,7 +17,7 @@ inf, nan = float('inf'), float('nan')
class TestLazyOp(unittest.TestCase):
def test_lazyop_str(self):
t = Tensor.rand(10) + Tensor.rand(10)
s = t.lazydata.schedule()
s = create_schedule([t.lazydata])
ast = s[-1].ast
ast_remade = eval(str(ast))
self.assertEqual(ast, ast_remade)

View File

@@ -10,7 +10,7 @@ from tinygrad.shape.view import View
from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
from tinygrad.tensor import Tensor
from tinygrad.features.jit import CacheCollector
from tinygrad.realize import run_schedule
from tinygrad.realize import create_schedule, run_schedule
from tinygrad.helpers import prod, Context
from tinygrad.dtype import DType, dtypes
@@ -33,7 +33,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
k = Linearizer(r.lazydata.schedule()[-1].ast)
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
@@ -46,7 +46,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
k = Linearizer(r.lazydata.schedule()[-1].ast)
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -56,7 +56,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack([a, b])
k = Linearizer(r.lazydata.schedule()[-1].ast)
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -67,7 +67,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor(2), Tensor(3)
r = a * b
k = Linearizer(r.lazydata.schedule()[-1].ast)
k = Linearizer(create_schedule([r.lazydata])[-1].ast)
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
assert num_ops <= 0, "more load or alu uops than needed"
@@ -76,14 +76,14 @@ class TestLinearizer(unittest.TestCase):
for tensor_dtype, acc_dtype in (
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
k = Linearizer(a.lazydata.schedule()[-1].ast)
k = Linearizer(create_schedule([a.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == acc_dtype
def test_arg_acc_dtype(self):
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
k = Linearizer(c.lazydata.schedule()[-1].ast)
k = Linearizer(create_schedule([c.lazydata])[-1].ast)
k.linearize()
local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
assert local[0].dtype == expected_dtype
@@ -121,7 +121,7 @@ class TestLinearizer(unittest.TestCase):
def test_limit_dims_to_max_5d_global(self):
t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
@@ -129,7 +129,7 @@ class TestLinearizer(unittest.TestCase):
def test_sum_collapse(self):
t = Tensor.ones(256,256).sum()
sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
assert len(sched) == 1
lin = Linearizer(sched[0].ast)
assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
@@ -154,7 +154,7 @@ class TestLinearizer(unittest.TestCase):
arg=TernaryOps.WHERE).uop == UOps.ALU
def helper_realized_ast(r:Tensor):
s = r.lazydata.schedule()
s = create_schedule([r.lazydata])
run_schedule(s[:-1]) # run all kernels except the last one
# now all input LazyBuffers in s[-1] should be realized
# allocate an output buffer
@@ -176,7 +176,7 @@ class TestFloat4(unittest.TestCase):
b = Tensor.rand(2, 8).realize()
c = a + b
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.hand_coded_optimizations()
k.linearize()
@@ -188,7 +188,7 @@ class TestFloat4(unittest.TestCase):
b = Tensor.rand(2, 8).realize()
c = a + b
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.shift_to(0, 4) # float4 dimension
k.shift_to(0, 2, insert_before=k.shape_len-1)
@@ -204,7 +204,7 @@ class TestFloat4(unittest.TestCase):
b = Tensor.rand(9).realize().shrink(((1, 9),))
c = a + b
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.hand_coded_optimizations() # implicit trigger float4 dim
k.linearize()
@@ -216,7 +216,7 @@ class TestFloat4(unittest.TestCase):
b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
c = a + b
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
k.upcast()
@@ -234,7 +234,7 @@ class TestFloat4(unittest.TestCase):
# only the first and last conv dot products are aligned in a, and b is never aligned, so no
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.upcast()
k.linearize()
@@ -249,7 +249,7 @@ class TestFloat4(unittest.TestCase):
# dimension, then we could do float4 for only that one set of loads, but we currently
# don't.
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.upcast()
k.upcast()
@@ -265,7 +265,7 @@ class TestFloat4(unittest.TestCase):
# we will upcast the top axis of sz 4. they should not be coalesced into float4,
# since the top axis is not contiguous.
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.shift_to(0, 4, top=True) # top axes are float4 axes
k.upcast()
@@ -281,7 +281,7 @@ class TestFloat4(unittest.TestCase):
# we will upcast the top axis of sz 4. they should not be coalesced into float4,
# since the top axis is not contiguous.
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
@@ -296,7 +296,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
@@ -310,7 +310,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
s = layer_2.lazydata.schedule()[-1]
s = create_schedule([layer_2.lazydata])[-1]
k = Linearizer(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
@@ -323,7 +323,7 @@ class TestHandCodedOpts(unittest.TestCase):
def test_masked_upcast_wino(self):
monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
s = monster.lazydata.schedule()[-1]
s = create_schedule([monster.lazydata])[-1]
k = Linearizer(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
@@ -335,7 +335,7 @@ class TestHandCodedOpts(unittest.TestCase):
x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
out = Tensor.conv2d(x,w, padding=1)
upcasts = []
wino_schedule = out.lazydata.schedule()
wino_schedule = create_schedule([out.lazydata])
# collect upcasts of tile transform kernels
for i, si in enumerate(wino_schedule):
k = Linearizer(si.ast)
@@ -349,7 +349,7 @@ class TestHandCodedOpts(unittest.TestCase):
assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1
out.mean().backward()
backward_schedule = x.grad.lazydata.schedule() + w.grad.lazydata.schedule()
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
for si in backward_schedule:
k = Linearizer(si.ast)
k.hand_coded_optimizations()
@@ -364,7 +364,7 @@ class TestHandCodedOpts(unittest.TestCase):
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4))
layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))
s = layer_3.lazydata.schedule()[-1]
s = create_schedule([layer_3.lazydata])[-1]
k = Linearizer(s.ast)
k.hand_coded_optimizations()
assert len(k.bufs) == 5 # make sure all ops are done in one kernel
@@ -379,7 +379,7 @@ class TestHandCodedOpts(unittest.TestCase):
b = Tensor.rand(N, N).realize()
c = a @ b
s = c.lazydata.schedule()[0]
s = create_schedule([c.lazydata])[0]
k = Linearizer(s.ast)
k.hand_coded_optimizations()

View File

@@ -5,6 +5,7 @@ from tinygrad.device import BufferCopy
from tinygrad.ops import LoadOps, ReduceOps
from tinygrad.helpers import CI
from tinygrad.nn.state import get_parameters
from tinygrad.realize import create_schedule
import numpy as np
from hypothesis import given, strategies as strat, settings
@@ -296,7 +297,7 @@ class TestMultiTensor(unittest.TestCase):
for p in get_parameters(bn): p.shard_(devices).realize()
out = bn(t)
scheds = [sched for sched in out.lazydata.schedule() if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
assert set(sched.out.device for sched in scheds) == set(devices), "should have ast on each shard device"
asts = [sched.ast for sched in scheds]
assert len(asts) == 8, len(asts)
@@ -527,9 +528,9 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
p.shard_(devices)
synced_out = synced_bn(x)
synced_si = [si for si in synced_out.lazydata.schedule()]
synced_si = [si for si in create_schedule(synced_out.lazydata.lbs)]
unsynced_out = unsynced_bn(x)
unsynced_si = [si for si in unsynced_out.lazydata.schedule()]
unsynced_si = [si for si in create_schedule(unsynced_out.lazydata.lbs)]
# TODO: test synced / unsynced batchnorm cross device kernel and copies
assert synced_si

View File

@@ -10,16 +10,17 @@ from tinygrad.device import Device, Compiled
from tinygrad.helpers import DEBUG, GRAPH
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer
from tinygrad.realize import create_schedule
from tinygrad import nn, dtypes
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
seen = set()
if to_prerealize:
for pre in to_prerealize:
for s in pre.lazydata.schedule(seen.copy()):
for s in create_schedule([pre.lazydata], seen.copy()):
if GRAPH: realized_lazybuffer(s.out, 0)
seen.add(s.out)
sched = t.lazydata.schedule(seen)
sched = create_schedule([t.lazydata], seen)
if GRAPH:
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]

View File

@@ -1,6 +1,7 @@
import unittest
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.realize import create_schedule
from tinygrad.features.search import time_linearizer
from tinygrad.device import Compiled, Device, Buffer
from tinygrad.ops import LoadOps
@@ -11,7 +12,7 @@ class TestTimeLinearizer(unittest.TestCase):
if not isinstance(Device[Device.DEFAULT], Compiled): raise unittest.SkipTest("only test for compiled backends")
def test_reasonable_time(self):
si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
si = [si for si in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if si.ast.op not in LoadOps][0]
rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
assert tm > 0 and tm != float('inf')

View File

@@ -3,6 +3,7 @@ from tinygrad import Tensor, GlobalCounters
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG
from tinygrad.ops import LoadOps
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.realize import create_schedule
class TestWinograd(unittest.TestCase):
def setUp(self):
@@ -19,7 +20,7 @@ class TestWinograd(unittest.TestCase):
out = Tensor.conv2d(x, w)
with Timing("scheduling: "):
sched = out.lazydata.schedule()
sched = create_schedule([out.lazydata])
for i,s in enumerate(sched):
if s.ast.op in LoadOps: continue

View File

@@ -4,7 +4,7 @@ import functools, itertools, operator
from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
from tinygrad.dtype import DType, Scalar
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import sint
def all_reduce(op:ReduceOps, lbs):
@@ -57,7 +57,6 @@ class MultiLazyBuffer:
def is_unrealized_contiguous_const(self): return False
# passthroughs
def schedule(self, seen=None): return create_schedule(self.real_lbs, seen)
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
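With the passthrough gone, callers schedule the per-device LazyBuffers themselves, as the updated multitensor tests do. A minimal sketch, assuming `t` is a sharded Tensor backed by a MultiLazyBuffer:

from tinygrad.realize import create_schedule
sched = create_schedule(t.lazydata.lbs)  # was: t.lazydata.schedule()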

View File

@@ -1,19 +1,14 @@
from __future__ import annotations
import sys, math
from collections import defaultdict
from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
from tinygrad.dtype import dtypes, DType, ImageDType, Scalar
from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same, GRAPH
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
from tinygrad.shape.symbolic import sint, Variable
import math
from typing import Union, Optional, Any, Tuple, List, Dict, cast
from tinygrad.dtype import dtypes, DType, Scalar
from tinygrad.helpers import prod, getenv, all_int, all_same
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op
from tinygrad.shape.symbolic import sint
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.device import Buffer
from tinygrad.features.graph import log_lazybuffer
from weakref import ref, ReferenceType
# lazy can recurse a lot
sys.setrecursionlimit(10000)
lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
@@ -76,8 +71,6 @@ class LazyBuffer:
def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST
def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST
def schedule(self, seen=None): return create_schedule([self], seen)
def _copy(self, device:str) -> LazyBuffer:
sync_size = 1 if self.device.startswith("HIP") else 0
sync = LazyBuffer.loadop(LoadOps.SYNC, (sync_size,), dtypes.uint32, self.device, src=self, enable_cache=True)
@@ -124,6 +117,7 @@ class LazyBuffer:
return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))
def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
# TODO: this logic should move to the scheduler
if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
assert len(self.shape)==len(new_shape) and all(ns in (1,s) for s,ns in zip(self.shape,new_shape)), f"not a contraction {self.shape=} {new_shape=}"
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
@@ -149,166 +143,3 @@ class LazyBuffer:
def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
# *** schedule creation ***
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
if (buf, st) in cache: return cache[(buf, st)]
if buf != buf.base:
st = buf.st + st
buf = buf.base
# all buffers here are base now
assert buf.op is not None
# consts are always fused and generated
if buf.op is LoadOps.CONST:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, unbound_st))
# if we aren't fusing it, it's a load and we add it to the inputs
if buf.realized or (buf in realizes and not first):
if buf not in inputs: inputs.append(buf)
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
# if a CONTIGUOUS made it all the way here, just skip it
if buf.op is LoadOps.CONTIGUOUS:
assert first
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
# if it's a reduce, we have to change the shapetracker
if buf.op in ReduceOps:
assert st.contiguous, "ReduceOps late fusion must be contiguous"
st = ShapeTracker.from_shape(buf.srcs[0].shape)
# otherwise we fuse it like normal
cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
return ret
# recursively walk back in the graph to create the schedule
def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
if out in seen or out.realized or out.op == LoadOps.CONST: return []
assert out.base == out
seen.add(out)
inputs: List[LazyBuffer] = []
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
else:
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
if buf in allbufs or buf.base.realized: return
if GRAPH: log_lazybuffer(buf, scheduled)
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match
if buf.base != buf:
# realize all places where the buffer is expanded
if prod(buf.base.st.shape) < prod(buf.st.shape):
if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
else:
realizes.add(buf.base)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
if buf.forced_realize: realizes.add(buf)
allbufs[buf] = None
if buf.op in LoadOps: realizes.add(buf.base)
if buf.op == LoadOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes.add(buf.srcs[0].base)
for x in buf.srcs:
children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
if buf in realizes or buf.realized: return True
# NOTE: this broke to_image_idx and coder with JIT
if buf.op in UNSAFE_PAD_OPS: return False
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
if seen is None: seen = set()
# start by just realizing the buffers passed in
realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Set[LazyBuffer] = set()
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
# check if we have to realize pads
for p in simple_pads:
if not _is_padding_okay(p, realizes):
realizes.add(p)
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
for r in allbufs.keys():
if r != r.base or r.op not in ReduceOps or r in realizes: continue
# follow the reduce down
child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
realized_children: Dict[LazyBuffer, ShapeTracker] = {}
forced_realize = False
can_chase = True
while not forced_realize and len(child_set):
next_child_set = {}
for tr,st in child_set.items():
if tr in realizes:
realized_children[tr] = st
# can only have one output buffer
# can only reduce contiguous
# max one reduceop per kernel
if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
forced_realize = True
break
continue
for tr_next in children[tr].keys():
if not tr_next.realized:
# max one reduceop per kernel
if tr_next.op in ReduceOps:
forced_realize = True
break
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1:
forced_realize = True
break
next_child_set[tr_next] = st + st_childs[0].st
child_set = next_child_set
if forced_realize:
tr = r
if can_chase:
# can chase this down to contiguous children
st = tr.st
while len(children[tr]) == 1:
tr_next = next(iter(children[tr].keys()))
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1: break
if st.size != st_childs[0].st.size: break
st = st + st_childs[0].st
if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next
reduce_for_op[tr] = r
realizes.add(tr)
else:
assert len(realized_children) == 1
reduce_for_op[next(iter(realized_children.keys()))] = r
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)

View File

@@ -1,9 +1,14 @@
from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
import sys
from collections import defaultdict
from typing import List, Dict, Optional, cast, Set, DefaultDict
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, InterpretedASTRunner, Compiled, BufferOptions
from tinygrad.features.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
from tinygrad.features.graph import print_tree, realized_lazybuffer, log_lazybuffer
from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, flatten, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
# *** schedule running ***
@@ -68,3 +73,169 @@ def run_schedule(schedule:List[ScheduleItem]):
if prg: prg.exec(cast(List[Buffer], real_buffers), si.var_vals)
elif si.out.size > 0: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
if GRAPH: realized_lazybuffer(si.out, GlobalCounters.kernel_count)
# *** schedule creation ***
# creation can recurse a lot
sys.setrecursionlimit(10000)
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
if (buf, st) in cache: return cache[(buf, st)]
if buf != buf.base:
st = buf.st + st
buf = buf.base
# all buffers here are base now
assert buf.op is not None
# consts are always fused and generated
if buf.op is LoadOps.CONST:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, unbound_st))
# if we aren't fusing it, it's a load and we add it to the inputs
if buf.realized or (buf in realizes and not first):
if buf not in inputs: inputs.append(buf)
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
# if a CONTIGUOUS made it all the way here, just skip it
if buf.op is LoadOps.CONTIGUOUS:
assert first
return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
# if it's a reduce, we have to change the shapetracker
if buf.op in ReduceOps:
assert st.contiguous, "ReduceOps late fusion must be contiguous"
st = ShapeTracker.from_shape(buf.srcs[0].shape)
# otherwise we fuse it like normal
cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
return ret
# recursively walk back in the graph to create the schedule
def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
if out in seen or out.realized or out.op == LoadOps.CONST: return []
assert out.base == out
seen.add(out)
inputs: List[LazyBuffer] = []
var_vals: Dict[Variable, int] = out.st.var_vals.copy()
if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
else:
output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
if buf in allbufs or buf.base.realized: return
if GRAPH: log_lazybuffer(buf, scheduled)
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
buf.dtype = dtypes.float32 # NOTE: this is what makes the dtype above not match
if buf.base != buf:
# realize all places where the buffer is expanded
if prod(buf.base.st.shape) < prod(buf.st.shape):
if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
else:
realizes.add(buf.base)
return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
if buf.forced_realize: realizes.add(buf)
allbufs[buf] = None
if buf.op in LoadOps: realizes.add(buf.base)
if buf.op == LoadOps.COPY:
assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
realizes.add(buf.srcs[0].base)
for x in buf.srcs:
children[x.base][buf] = None
_recurse_lb(x, realizes, allbufs, simple_pads, children)
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
if buf in realizes or buf.realized: return True
# NOTE: this broke to_image_idx and coder with JIT
if buf.op in UNSAFE_PAD_OPS: return False
return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
if seen is None: seen = set()
# start by just realizing the buffers passed in
realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Set[LazyBuffer] = set()
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
# check if we have to realize pads
for p in simple_pads:
if not _is_padding_okay(p, realizes):
realizes.add(p)
# find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
for r in allbufs.keys():
if r != r.base or r.op not in ReduceOps or r in realizes: continue
# follow the reduce down
child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
realized_children: Dict[LazyBuffer, ShapeTracker] = {}
forced_realize = False
can_chase = True
while not forced_realize and len(child_set):
next_child_set = {}
for tr,st in child_set.items():
if tr in realizes:
realized_children[tr] = st
# can only have one output buffer
# can only reduce contiguous
# max one reduceop per kernel
if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
forced_realize = True
break
continue
for tr_next in children[tr].keys():
if not tr_next.realized:
# max one reduceop per kernel
if tr_next.op in ReduceOps:
forced_realize = True
break
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1:
forced_realize = True
break
next_child_set[tr_next] = st + st_childs[0].st
child_set = next_child_set
if forced_realize:
tr = r
if can_chase:
# can chase this down to contiguous children
st = tr.st
while len(children[tr]) == 1:
tr_next = next(iter(children[tr].keys()))
st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
if len(st_childs) > 1: break
if st.size != st_childs[0].st.size: break
st = st + st_childs[0].st
if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next
reduce_for_op[tr] = r
realizes.add(tr)
else:
assert len(realized_children) == 1
reduce_for_op[next(iter(realized_children.keys()))] = r
return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
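A quick sanity sketch of what the moved code produces (op names as in this diff; the exact LoadOps depend on how the inputs were created):

from tinygrad.tensor import Tensor
from tinygrad.realize import create_schedule
a, b = Tensor.rand(4), Tensor.rand(4)
sched = create_schedule([(a + b).lazydata])
# input-realizing LoadOps items come first, then one BufferOps.STORE
# item whose ast fuses the elementwise add
for si in sched: print(si.ast.op)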

View File

@@ -8,12 +8,12 @@ import numpy as np
from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
from tinygrad.lazy import LazyBuffer, create_schedule
from tinygrad.lazy import LazyBuffer
from tinygrad.features.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Buffer
from tinygrad.shape.symbolic import sint
from tinygrad.realize import run_schedule
from tinygrad.realize import run_schedule, create_schedule
# **** start with two base classes, Tensor and Function ****
@@ -127,7 +127,7 @@ class Tensor:
run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))
def realize(self) -> Tensor:
run_schedule(self.lazydata.schedule())
Tensor.corealize([self])
return self
def assign(self, x) -> Tensor:
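realize() now routes through corealize, which builds one schedule across all the listed tensors so shared subgraphs are only scheduled once. A small sketch of the intended use:

from tinygrad.tensor import Tensor
x = Tensor.rand(8)
y, z = x + 1, x * 2
Tensor.corealize([y, z])  # one create_schedule call covering both outputs
w = (x - 1).realize()     # the single-tensor path, same mechanism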