mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-08 22:48:25 -05:00

move create schedule and delete old API (#3377)

* move create schedule and delete old API
* fix test multitensor
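The whole change in one snippet: LazyBuffer.schedule() and MultiLazyBuffer.schedule() are deleted, and every caller now goes through the free function create_schedule from tinygrad.realize. A minimal before/after sketch of the migration (assuming a tensor `out` on a compiled backend):

    # before this commit
    sched = out.lazydata.schedule()            # method on LazyBuffer (removed)

    # after this commit
    from tinygrad.realize import create_schedule, run_schedule
    sched = create_schedule([out.lazydata])    # free function, takes a list of output LazyBuffers
    run_schedule(sched)                        # unchanged: execute the ScheduleItems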
@@ -250,7 +250,8 @@ result = Tensor(2.0).realize() + Tensor(3.0).realize()
 # use the real Linearizer to linearize 2+3
 from tinygrad.codegen.linearizer import Linearizer
-sched = result.lazydata.schedule()
+from tinygrad.realize import create_schedule
+sched = create_schedule([result.lazydata])
 linearizer = Linearizer(sched[-1].ast, ClangCompiler.linearizer_opts)
 linearizer.linearize()
@@ -73,7 +73,7 @@ assert out.as_buffer().cast('I')[0] == 5
 print("******** third, the LazyBuffer ***********")

 from tinygrad.lazy import LazyBuffer, LoadOps
-from tinygrad.realize import run_schedule
+from tinygrad.realize import run_schedule, create_schedule

 # allocate some values + load in values
 # TODO: remove numpy here
@@ -87,7 +87,7 @@ b.realized = Buffer("CPU", 1, dtypes.int32, np.array([3], np.int32).flatten())
 out = a.e(BinaryOps.ADD, b)

 # schedule the computation as a list of kernels
-sched = out.schedule()
+sched = create_schedule([out])
 for si in sched: print(si.ast.op)  # NOTE: the first two convert it to CLANG

 # DEBUGGING: print the compute ast as a tree
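Pieced together from the fragments above, the updated LazyBuffer walkthrough now reads roughly like this (a sketch; `a` and `b` are the hand-built LazyBuffers whose realized Buffers were filled in earlier in the docs file):

    from tinygrad.ops import BinaryOps
    from tinygrad.realize import run_schedule, create_schedule

    out = a.e(BinaryOps.ADD, b)            # build the lazy compute graph
    sched = create_schedule([out])         # schedule the computation as a list of kernels
    for si in sched: print(si.ast.op)      # the first two items convert the inputs to CLANG
    run_schedule(sched)                    # run the kernels to realize out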
@@ -20,10 +20,8 @@ More generically, the whole network is a DAG. Ignore the forward/backward stuff,

 This is a rewrite of a lot of tinygrad. I don't think continuing to support Interpreted backends is worth it, have to deal with disk in a smart way.

-We keep the frontend: tensor.py + mlops.py + lazy.py
-We keep the backend (renderer/runtime): cstyle.py + device.py + ops_*.py
-We keep the shapetracker/symbolic: shapetracker.py + view.py + symbolic.py
-We keep the features and nn stuff.
-But codegen is all rewritten.
+We keep the frontend (Tensor -> LazyBuffer): tensor.py + mlops.py + lazy.py
+We keep the shapetracker/symbolic (part of the frontend): shapetracker.py + view.py + symbolic.py
+Codegen is all rewritten.
+We keep the backend (uops renderer/runtime): cstyle.py + device.py + ops_*.py
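Under the new layering, a tensor expression flows frontend -> scheduler -> codegen -> backend, and the scheduler entry point is now create_schedule. A hedged end-to-end sketch using only names that appear elsewhere in this diff:

    from tinygrad import Tensor
    from tinygrad.realize import create_schedule, run_schedule
    from tinygrad.codegen.linearizer import Linearizer

    out = (Tensor.rand(4, 4) @ Tensor.rand(4, 4)).sum()
    sched = create_schedule([out.lazydata])   # frontend -> list of ScheduleItems
    lin = Linearizer(sched[-1].ast)           # codegen: kernel AST -> uops
    lin.linearize()
    run_schedule(sched)                       # renderer/runtime executes the kernels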
@@ -8,6 +8,7 @@ from tinygrad.features.search import time_linearizer, beam_search, bufs_from_lin
 from tinygrad.helpers import ansilen, DEBUG, getenv
 from tinygrad.shape.symbolic import sym_infer
 from tinygrad.dtype import dtypes
+from tinygrad.realize import create_schedule

 if __name__ == "__main__":
   if getenv("HALF"):
@@ -21,12 +22,12 @@ if __name__ == "__main__":
   print(f"optimizing for {Device.DEFAULT}")

   # first model run to init the weights, they are saved in seen
-  mdl(Tensor.empty(64, 3, 224, 224)).lazydata.schedule(seen)
+  create_schedule([mdl(Tensor.empty(64, 3, 224, 224)).lazydata], seen)

   # run model again to get only what changes, these are the kernels of the model
   x = Tensor.empty(64, 3, 224, 224)
   out = mdl(x)
-  sched = out.lazydata.schedule(seen)
+  sched = create_schedule([out.lazydata], seen)
   sched = [x for x in sched if x.ast.op not in LoadOps]

   # focus on one kernel
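The shared seen set is what makes the two-pass pattern above work: the first create_schedule call records every buffer it schedules (the weight-init kernels), so the second call over the same model returns only the kernels that actually run each step. A minimal sketch of the pattern, assuming any model callable mdl:

    seen = set()
    # pass 1: weight-init kernels land in seen and are discarded
    create_schedule([mdl(Tensor.empty(64, 3, 224, 224)).lazydata], seen)
    # pass 2: only the per-step compute kernels remain
    sched = create_schedule([mdl(Tensor.empty(64, 3, 224, 224)).lazydata], seen)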
@@ -3,13 +3,14 @@ from tinygrad.ops import LoadOps
 from tinygrad.codegen.linearizer import Linearizer
 from test.external.fuzz_linearizer import run_linearizer
 from tinygrad.codegen.kernel import Opt, OptOps
+from tinygrad.realize import create_schedule

 N = 17**3

 a = Tensor.rand(N, N)
 b = Tensor.rand(N, N)
 c = a @ b
-sched = [si for si in c.lazydata.schedule() if si.ast.op not in LoadOps]
+sched = [si for si in create_schedule([c.lazydata]) if si.ast.op not in LoadOps]
 assert len(sched) == 1
 lin = Linearizer(sched[0].ast)
@@ -24,7 +25,7 @@ run_linearizer(lin)
 ###

 a = Tensor.rand(61, 61).sum(axis=0)
-sched = [si for si in a.lazydata.schedule() if si.ast.op not in LoadOps]
+sched = [si for si in create_schedule([a.lazydata]) if si.ast.op not in LoadOps]
 assert len(sched) == 1
 lin = Linearizer(sched[0].ast)
@@ -30,13 +30,14 @@ except ImportError:

 import os
 from tinygrad.tensor import Tensor
+from tinygrad.realize import create_schedule

 # define the compute
 A = Tensor.rand(M, K, device="clang")
 B = Tensor.rand(K, N, device="clang")
 C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)

-sched = C.lazydata.schedule()
+sched = create_schedule([C.lazydata])
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.codegen.kernel import LinearizerOptions
 lin = Linearizer(sched[-1].ast, LinearizerOptions(has_local=False, supports_float4=False))
@@ -16,7 +16,7 @@ from extra.onnx import get_run_onnx
 from tinygrad import Tensor, Device, GlobalCounters, dtypes
 from tinygrad.dtype import ImageDType
 from tinygrad.helpers import partition, Context, fetch, getenv, GRAPH, DEBUG
-from tinygrad.realize import run_schedule, lower_schedule_item
+from tinygrad.realize import run_schedule, lower_schedule_item, create_schedule
 from tinygrad.ops import LoadOps, ScheduleItem
 Device.DEFAULT = "GPU"
@@ -32,7 +32,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
   # run the model
   inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
   ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
-  schedule = ret.lazydata.schedule()
+  schedule = create_schedule([ret.lazydata])

   # filter schedule that don't depend on the inputs
   input_lb = [x.lazydata.base for x in inputs.values()]
7  test/external/external_test_uops_graphing.py  (vendored)
@@ -4,6 +4,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.renderer.cstyle import OpenCLRenderer
 from tinygrad.features.graph import graph_uops
+from tinygrad.realize import create_schedule
 from tinygrad.nn import Conv2d

 class TestUopsGraph(unittest.TestCase):
@@ -11,7 +12,7 @@ class TestUopsGraph(unittest.TestCase):
     N = 1024
     a = Tensor.rand(N,N)
    b = Tensor.rand(N,N)
-    si = (a@b).lazydata.schedule()[-1]
+    si = create_schedule([(a@b).lazydata])[-1]
     lin = Linearizer(si.ast)
     lin.hand_coded_optimizations()
     print(lin.colored_shape())
@@ -22,7 +23,7 @@ class TestUopsGraph(unittest.TestCase):

   def test_reduce(self):
     a = Tensor.rand(1024*1024)
-    si = a.sum().lazydata.schedule()[-1]
+    si = create_schedule([a.sum().lazydata])[-1]
     lin = Linearizer(si.ast)
     lin.hand_coded_optimizations()
     uops = lin.linearize().uops
@@ -32,7 +33,7 @@ class TestUopsGraph(unittest.TestCase):
   def test_conv(self):
     x = Tensor.rand(1,3,16,16)
     c = Conv2d(3, 16, (3,3))
-    si = c(x).elu().lazydata.schedule()[-1]
+    si = create_schedule([c(x).elu().lazydata])[-1]
     lin = Linearizer(si.ast)
     lin.hand_coded_optimizations()
     uops = lin.linearize().uops
@@ -3,6 +3,7 @@ import unittest
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
 from tinygrad.nn import Conv2d
+from tinygrad.realize import create_schedule

 class TestConvShapetracker(unittest.TestCase):
   def test_conv_3x3_one_view(self):
@@ -10,9 +11,9 @@ class TestConvShapetracker(unittest.TestCase):
     seen = set()

     # first run to init the weights, they are saved in seen
-    conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen)
+    create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
     # run it again to get the kernels
-    sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).lazydata.schedule(seen) if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op not in LoadOps]
     assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
     print(sched[0])
     for arg in [sched[0].out, *sched[0].inputs]:
@@ -6,6 +6,7 @@ import numpy as np
 from hypothesis import given, strategies as strat, settings
 from tinygrad.dtype import DType
 from tinygrad.helpers import CI, getenv
+from tinygrad.realize import create_schedule
 from tinygrad.ops import UnaryOps, get_lazyop_info
 from test.test_dtype import is_dtype_supported
@@ -64,7 +65,7 @@ def universal_test(a, b, dtype, op):
 def universal_test_unary(a, dtype, op):
   if not isinstance(op, tuple): op = (op, op)
   out: Tensor = op[0](Tensor([a], dtype=dtype))
-  ast = out.lazydata.schedule()[-1].ast
+  ast = create_schedule([out.lazydata])[-1].ast
   tensor_value = out.numpy()
   numpy_value = op[1](np.array([a]).astype(dtype.np))
   if dtype in dtypes_float:
@@ -3,8 +3,7 @@ import time
 import numpy as np
 from tinygrad import Tensor, dtypes
 from tinygrad.device import InterpretedASTRunner
-from tinygrad.lazy import create_schedule
-from tinygrad.realize import run_schedule, lower_schedule_item
+from tinygrad.realize import run_schedule, create_schedule, lower_schedule_item

 class TestFusionOp(unittest.TestCase):
   def test_contiguous_add(self):
@@ -1,5 +1,6 @@
 import unittest
 from tinygrad.tensor import Tensor
+from tinygrad.realize import create_schedule

 # stuff needed to unpack a kernel
 # ruff: noqa: F401
@@ -16,7 +17,7 @@ inf, nan = float('inf'), float('nan')
 class TestLazyOp(unittest.TestCase):
   def test_lazyop_str(self):
     t = Tensor.rand(10) + Tensor.rand(10)
-    s = t.lazydata.schedule()
+    s = create_schedule([t.lazydata])
     ast = s[-1].ast
     ast_remade = eval(str(ast))
     self.assertEqual(ast, ast_remade)
@@ -10,7 +10,7 @@ from tinygrad.shape.view import View
 from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, Node, create_rednode
 from tinygrad.tensor import Tensor
 from tinygrad.features.jit import CacheCollector
-from tinygrad.realize import run_schedule
+from tinygrad.realize import create_schedule, run_schedule
 from tinygrad.helpers import prod, Context
 from tinygrad.dtype import DType, dtypes
@@ -33,7 +33,7 @@ class TestLinearizer(unittest.TestCase):
     # these are of size 3 to avoid float4 coalesce
     r = a[:-1] + a[1:]

-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
@@ -46,7 +46,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = a.expand([2]) + b.expand([2])

-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -56,7 +56,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
     r = Tensor.stack([a, b])

-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.upcast()
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -67,7 +67,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor(2), Tensor(3)
     r = a * b

-    k = Linearizer(r.lazydata.schedule()[-1].ast)
+    k = Linearizer(create_schedule([r.lazydata])[-1].ast)
     k.linearize()
     num_ops = len([uop for uop in k.uops if uop.uop in [UOps.LOAD, UOps.ALU]])
     assert num_ops <= 0, "more load or alu uops than needed"
@@ -76,14 +76,14 @@ class TestLinearizer(unittest.TestCase):
     for tensor_dtype, acc_dtype in (
      (dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
       a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
-      k = Linearizer(a.lazydata.schedule()[-1].ast)
+      k = Linearizer(create_schedule([a.lazydata])[-1].ast)
       k.linearize()
       local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
       assert local[0].dtype == acc_dtype

   def test_arg_acc_dtype(self):
     def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
-      k = Linearizer(c.lazydata.schedule()[-1].ast)
+      k = Linearizer(create_schedule([c.lazydata])[-1].ast)
       k.linearize()
       local = [uop for uop in k.uops if uop.uop == UOps.DEFINE_ACC]
       assert local[0].dtype == expected_dtype
@@ -121,7 +121,7 @@ class TestLinearizer(unittest.TestCase):

   def test_limit_dims_to_max_5d_global(self):
     t = Tensor.rand(3, 4, 5, 6, 7).pad(((1, 1), (1, 1), (1, 1), (1, 1), (1, 1))) + 1
-    sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
     assert len(sched) == 1
     lin = Linearizer(sched[0].ast)
     assert lin.full_shape[:lin.global_dims] == (5, 6, 7, 8, 9)
@@ -129,7 +129,7 @@ class TestLinearizer(unittest.TestCase):

   def test_sum_collapse(self):
     t = Tensor.ones(256,256).sum()
-    sched = [si for si in t.lazydata.schedule() if si.ast.op not in LoadOps]
+    sched = [si for si in create_schedule([t.lazydata]) if si.ast.op not in LoadOps]
     assert len(sched) == 1
     lin = Linearizer(sched[0].ast)
     assert not any(u.uop == UOps.LOOP for u in lin.linearize().uops), "found loop in sum collapse"
@@ -154,7 +154,7 @@ class TestLinearizer(unittest.TestCase):
                        arg=TernaryOps.WHERE).uop == UOps.ALU

 def helper_realized_ast(r:Tensor):
-  s = r.lazydata.schedule()
+  s = create_schedule([r.lazydata])
   run_schedule(s[:-1])  # run all kernels except the last one
   # now all input LazyBuffers buffers in s[-1] should be realized
   # allocate an output buffer
@@ -176,7 +176,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(2, 8).realize()
     c = a + b

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     k.linearize()
@@ -188,7 +188,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(2, 8).realize()
     c = a + b

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4)  # float4 dimension
     k.shift_to(0, 2, insert_before=k.shape_len-1)
@@ -204,7 +204,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(9).realize().shrink(((1, 9),))
     c = a + b

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()  # implicit trigger float4 dim
     k.linearize()
@@ -216,7 +216,7 @@ class TestFloat4(unittest.TestCase):
     b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
     c = a + b

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(len(k.full_unupcasted_shape)-1, 4)  # manual trigger float4 dim
     k.upcast()
@@ -234,7 +234,7 @@ class TestFloat4(unittest.TestCase):
     # only the first and last conv dot products are aligned in a, and b is never aligned, so no
     # float4 should be emitted (the reduce axis of size 4 is the float4 axis here)

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.upcast()
     k.linearize()
@@ -249,7 +249,7 @@ class TestFloat4(unittest.TestCase):
     # dimension, then we could do float4 for only that one set of loads, but we currently
     # don't.

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.upcast()
     k.upcast()
@@ -265,7 +265,7 @@ class TestFloat4(unittest.TestCase):
     # we will upcast the top axis of sz 4. they should not be coalesced into float4,
     # since the top axis is not contiguous.

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4, top=True)  # top axes are float4 axes
     k.upcast()
@@ -281,7 +281,7 @@ class TestFloat4(unittest.TestCase):
     # we will upcast the top axis of sz 4. they should not be coalesced into float4,
     # since the top axis is not contiguous.

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4)  # float4 axis
     k.upcast()
@@ -296,7 +296,7 @@ class TestFloat4(unittest.TestCase):

     # should float4 b but not a

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.shift_to(0, 4)  # float4 axis
     k.upcast()
@@ -310,7 +310,7 @@ class TestHandCodedOpts(unittest.TestCase):
     layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])
     layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))

-    s = layer_2.lazydata.schedule()[-1]
+    s = create_schedule([layer_2.lazydata])[-1]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 6  # make sure all ops are done in one kernel
@@ -323,7 +323,7 @@ class TestHandCodedOpts(unittest.TestCase):
   def test_masked_upcast_wino(self):
     monster = Tensor.stack([Tensor.stack([Tensor.rand(16) for _ in range(6)]) for _ in range(6)])

-    s = monster.lazydata.schedule()[-1]
+    s = create_schedule([monster.lazydata])[-1]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 37  # make sure all ops are done in one kernel
@@ -335,7 +335,7 @@ class TestHandCodedOpts(unittest.TestCase):
     x,w = Tensor.rand(1,4,8,8, requires_grad=True).realize(), Tensor.rand(4,4,3,3, requires_grad=True).realize()
     out = Tensor.conv2d(x,w, padding=1)
     upcasts = []
-    wino_schedule = out.lazydata.schedule()
+    wino_schedule = create_schedule([out.lazydata])
     # collect upcasts of tile transform kernels
     for i, si in enumerate(wino_schedule):
       k = Linearizer(si.ast)
@@ -349,7 +349,7 @@ class TestHandCodedOpts(unittest.TestCase):
     assert upcasts.count((6, 6)) == 2  #and upcasts.count((4, 4)) == 1

     out.mean().backward()
-    backward_schedule = x.grad.lazydata.schedule() + w.grad.lazydata.schedule()
+    backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
     for si in backward_schedule:
       k = Linearizer(si.ast)
       k.hand_coded_optimizations()
@@ -364,7 +364,7 @@ class TestHandCodedOpts(unittest.TestCase):
     layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 7, 4))
     layer_3 = Tensor.cat(layer_2.unsqueeze(0), Tensor.rand(6, 7, 7, 4))

-    s = layer_3.lazydata.schedule()[-1]
+    s = create_schedule([layer_3.lazydata])[-1]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
     assert len(k.bufs) == 5  # make sure all ops are done in one kernel
@@ -379,7 +379,7 @@ class TestHandCodedOpts(unittest.TestCase):
     b = Tensor.rand(N, N).realize()
     c = a @ b

-    s = c.lazydata.schedule()[0]
+    s = create_schedule([c.lazydata])[0]
     k = Linearizer(s.ast)
     k.hand_coded_optimizations()
@@ -5,6 +5,7 @@ from tinygrad.device import BufferCopy
 from tinygrad.ops import LoadOps, ReduceOps
 from tinygrad.helpers import CI
 from tinygrad.nn.state import get_parameters
+from tinygrad.realize import create_schedule
 import numpy as np
 from hypothesis import given, strategies as strat, settings
@@ -296,7 +297,7 @@ class TestMultiTensor(unittest.TestCase):
     for p in get_parameters(bn): p.shard_(devices).realize()

     out = bn(t)
-    scheds = [sched for sched in out.lazydata.schedule() if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
+    scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.out.device in devices and sched.ast.op is not LoadOps.COPY]
     assert set(sched.out.device for sched in scheds) == set(devices), "should have ast on each shard device"
     asts = [sched.ast for sched in scheds]
     assert len(asts) == 8, len(asts)
@@ -527,9 +528,9 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
       p.shard_(devices)

     synced_out = synced_bn(x)
-    synced_si = [si for si in synced_out.lazydata.schedule()]
+    synced_si = [si for si in create_schedule(synced_out.lazydata.lbs)]
     unsynced_out = unsynced_bn(x)
-    unsynced_si = [si for si in unsynced_out.lazydata.schedule()]
+    unsynced_si = [si for si in create_schedule(unsynced_out.lazydata.lbs)]

     # TODO: test synced / unsynced batchnorm cross device kernel and copies
     assert synced_si
@@ -10,16 +10,17 @@ from tinygrad.device import Device, Compiled
 from tinygrad.helpers import DEBUG, GRAPH
 from tinygrad.codegen.linearizer import Linearizer
 from tinygrad.features.graph import print_tree, realized_lazybuffer
+from tinygrad.realize import create_schedule
 from tinygrad import nn, dtypes

 def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_loadops=True):
   seen = set()
   if to_prerealize:
     for pre in to_prerealize:
-      for s in pre.lazydata.schedule(seen.copy()):
+      for s in create_schedule([pre.lazydata], seen.copy()):
         if GRAPH: realized_lazybuffer(s.out, 0)
         seen.add(s.out)
-  sched = t.lazydata.schedule(seen)
+  sched = create_schedule([t.lazydata], seen)
   if GRAPH:
     for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
   if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
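check_schedule is the central assertion of this test file: optionally pre-realize some tensors into a seen set, schedule the target against that set, drop LoadOps, and (below the lines shown) compare the kernel count to allowed. A hedged usage sketch, assuming the helper asserts the count as described:

    a, b = Tensor.rand(10), Tensor.rand(10)
    c = (a + b) * 2
    check_schedule(c, 1)   # the elementwise chain should fuse into a single kernel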
@@ -1,6 +1,7 @@
 import unittest

 from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.realize import create_schedule
 from tinygrad.features.search import time_linearizer
 from tinygrad.device import Compiled, Device, Buffer
 from tinygrad.ops import LoadOps
@@ -11,7 +12,7 @@ class TestTimeLinearizer(unittest.TestCase):
     if not isinstance(Device[Device.DEFAULT], Compiled): raise unittest.SkipTest("only test for compiled backends")

   def test_reasonable_time(self):
-    si = [si for si in Tensor([1,2,3,4]).add(1).lazydata.schedule() if si.ast.op not in LoadOps][0]
+    si = [si for si in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if si.ast.op not in LoadOps][0]
     rawbufs = [Buffer(Device.DEFAULT, si.out.st.real_size(), si.out.dtype)] + [Buffer(Device.DEFAULT, x.st.real_size(), x.dtype) for x in si.inputs]
     tm = time_linearizer(Linearizer(si.ast), rawbufs, allow_test_size=False, cnt=10)
     assert tm > 0 and tm != float('inf')
@@ -3,6 +3,7 @@ from tinygrad import Tensor, GlobalCounters
 from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG
 from tinygrad.ops import LoadOps
 from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.realize import create_schedule

 class TestWinograd(unittest.TestCase):
   def setUp(self):
@@ -19,7 +20,7 @@ class TestWinograd(unittest.TestCase):
     out = Tensor.conv2d(x, w)

     with Timing("scheduling: "):
-      sched = out.lazydata.schedule()
+      sched = create_schedule([out.lazydata])

     for i,s in enumerate(sched):
       if s.ast.op in LoadOps: continue
@@ -4,7 +4,7 @@ import functools, itertools, operator
 from tinygrad.helpers import all_same, dedup, round_up, prod, DEBUG
 from tinygrad.dtype import DType, Scalar
 from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
-from tinygrad.lazy import LazyBuffer, create_schedule
+from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import sint

 def all_reduce(op:ReduceOps, lbs):
@@ -57,7 +57,6 @@ class MultiLazyBuffer:
   def is_unrealized_contiguous_const(self): return False

   # passthroughs
-  def schedule(self, seen=None): return create_schedule(self.real_lbs, seen)
   def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
   def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis, self.real)
   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
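With the passthrough gone, multi-device outputs are scheduled by handing create_schedule the per-device LazyBuffers directly, as the test_multitensor changes above do with out.lazydata.lbs. A hedged sketch (assumes two devices are available; the device names are placeholders):

    from tinygrad import Tensor
    from tinygrad.realize import create_schedule

    t = Tensor.rand(256, 256).shard_(("CLANG:0", "CLANG:1"), axis=0)
    out = (t + 1).contiguous()
    sched = create_schedule(out.lazydata.lbs)   # one ScheduleItem list covering every shard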
183  tinygrad/lazy.py
@@ -1,19 +1,14 @@
 from __future__ import annotations
-import sys, math
-from collections import defaultdict
-from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
-from tinygrad.dtype import dtypes, DType, ImageDType, Scalar
-from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same, GRAPH
-from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
-from tinygrad.shape.symbolic import sint, Variable
+import math
+from typing import Union, Optional, Any, Tuple, List, Dict, cast
+from tinygrad.dtype import dtypes, DType, Scalar
+from tinygrad.helpers import prod, getenv, all_int, all_same
+from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op
+from tinygrad.shape.symbolic import sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
-from tinygrad.features.graph import log_lazybuffer
 from weakref import ref, ReferenceType

-# lazy can recurse a lot
-sys.setrecursionlimit(10000)
-
 lazycache: Dict[Any, ReferenceType[LazyBuffer]] = {}
 def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Op]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(),
                       base:Optional[LazyBuffer]=None, enable_cache=bool(getenv("LAZYCACHE", 1))):
@@ -76,8 +71,6 @@ class LazyBuffer:
   def is_unrealized_const(self): return not self.base.realized and self.base.op is LoadOps.CONST
   def is_unrealized_contiguous_const(self): return self.base == self and not self.base.realized and self.op is LoadOps.CONST

-  def schedule(self, seen=None): return create_schedule([self], seen)
-
   def _copy(self, device:str) -> LazyBuffer:
     sync_size = 1 if self.device.startswith("HIP") else 0
     sync = LazyBuffer.loadop(LoadOps.SYNC, (sync_size,), dtypes.uint32, self.device, src=self, enable_cache=True)
@@ -124,6 +117,7 @@ class LazyBuffer:
     return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), self.dtype, op, unbound_new_shape, (self,))

   def r(self, op:ReduceOps, new_shape:Tuple[sint, ...]) -> LazyBuffer:
+    # TODO: this logic should move to the scheduler
     if self.size == 0 and 0 not in new_shape: return self.const({ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[op], new_shape)
     assert len(self.shape)==len(new_shape) and all(ns in (1,s) for s,ns in zip(self.shape,new_shape)), f"not a contraction {self.shape=} {new_shape=}"
     # TODO: can we split symbolic shape if the reduce axis is not symbolic?
@@ -149,166 +143,3 @@ class LazyBuffer:
   def permute(self, arg:Tuple[int, ...]): return self._view(self.st.permute(arg))
   def shrink(self, arg:Tuple[Tuple[sint, sint], ...]): return self._view(self.st.shrink(arg))
   def stride(self, arg:Tuple[int, ...]): return self._view(self.st.stride(arg))
-
-# *** schedule creation ***
-
-# recursively create a lazyop
-def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
-                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
-  if (buf, st) in cache: return cache[(buf, st)]
-  if buf != buf.base:
-    st = buf.st + st
-    buf = buf.base
-  # all buffers here are base now
-  assert buf.op is not None
-
-  # consts are always fused and generated
-  if buf.op is LoadOps.CONST:
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, unbound_st))
-
-  # if we aren't fusing it, it's a load and we add it to the inputs
-  if buf.realized or (buf in realizes and not first):
-    if buf not in inputs: inputs.append(buf)
-    unbound_st, st_var_vals = st.simplify().unbind()
-    var_vals.update(st_var_vals)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
-
-  # if a CONTIGUOUS made it all the way here, just skip it
-  if buf.op is LoadOps.CONTIGUOUS:
-    assert first
-    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
-
-  # if it's a reduce, we have to change the shapetracker
-  if buf.op in ReduceOps:
-    assert st.contiguous, "ReduceOps late fusion must be contiguous"
-    st = ShapeTracker.from_shape(buf.srcs[0].shape)
-
-  # otherwise we fuse it like normal
-  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
-  return ret
-
-# recursively walk back in the graph to create the schedule
-def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
-                        reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
-  if out in seen or out.realized or out.op == LoadOps.CONST: return []
-  assert out.base == out
-  seen.add(out)
-
-  inputs: List[LazyBuffer] = []
-  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
-  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
-    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
-  else:
-    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
-    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
-    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
-
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
-
-# recursively search the entire graph for all LazyBuffers, insert realizes after expands
-def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
-                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
-  if buf in allbufs or buf.base.realized: return
-  if GRAPH: log_lazybuffer(buf, scheduled)
-  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
-                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
-    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
-    buf.dtype = dtypes.float32  # NOTE: this is what makes the dtype above not match
-  if buf.base != buf:
-    # realize all places where the buffer is expanded
-    if prod(buf.base.st.shape) < prod(buf.st.shape):
-      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
-          prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
-        simple_pads.add(buf.base)
-      else:
-        realizes.add(buf.base)
-    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
-  if buf.forced_realize: realizes.add(buf)
-  allbufs[buf] = None
-  if buf.op in LoadOps: realizes.add(buf.base)
-  if buf.op == LoadOps.COPY:
-    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
-    realizes.add(buf.srcs[0].base)
-  for x in buf.srcs:
-    children[x.base][buf] = None
-    _recurse_lb(x, realizes, allbufs, simple_pads, children)
-
-UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
-def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
-  if buf in realizes or buf.realized: return True
-  # NOTE: this broke to_image_idx and coder with JIT
-  if buf.op in UNSAFE_PAD_OPS: return False
-  return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
-
-def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
-  if seen is None: seen = set()
-
-  # start by just realizing the buffers passed in
-  realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
-  allbufs: Dict[LazyBuffer, None] = {}
-  simple_pads: Set[LazyBuffer] = set()
-  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
-  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
-
-  # check if we have to realize pads
-  for p in simple_pads:
-    if not _is_padding_okay(p, realizes):
-      realizes.add(p)
-
-  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
-  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
-  for r in allbufs.keys():
-    if r != r.base or r.op not in ReduceOps or r in realizes: continue
-
-    # follow the reduce down
-    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
-    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
-    forced_realize = False
-    can_chase = True
-    while not forced_realize and len(child_set):
-      next_child_set = {}
-      for tr,st in child_set.items():
-        if tr in realizes:
-          realized_children[tr] = st
-          # can only have one output buffer
-          # can only reduce contiguous
-          # max one reduceop per kernel
-          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
-            can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
-            forced_realize = True
-            break
-          continue
-        for tr_next in children[tr].keys():
-          if not tr_next.realized:
-            # max one reduceop per kernel
-            if tr_next.op in ReduceOps:
-              forced_realize = True
-              break
-            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
-            if len(st_childs) > 1:
-              forced_realize = True
-              break
-            next_child_set[tr_next] = st + st_childs[0].st
-      child_set = next_child_set
-    if forced_realize:
-      tr = r
-      if can_chase:
-        # can chase this down to contiguous children
-        st = tr.st
-        while len(children[tr]) == 1:
-          tr_next = next(iter(children[tr].keys()))
-          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
-          if len(st_childs) > 1: break
-          if st.size != st_childs[0].st.size: break
-          st = st + st_childs[0].st
-          if not st.contiguous or tr_next.op in ReduceOps: break
-          tr = tr_next
-      reduce_for_op[tr] = r
-      realizes.add(tr)
-    else:
-      assert len(realized_children) == 1
-      reduce_for_op[next(iter(realized_children.keys()))] = r
-
-  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
@@ -1,9 +1,14 @@
-from typing import List, Dict, Optional, cast
-from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
+import sys
+from collections import defaultdict
+from typing import List, Dict, Optional, cast, Set, DefaultDict
+from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
 from tinygrad.device import Device, Buffer, BufferCopy, BufferXfer, BufferRead, JITRunner, update_stats, InterpretedASTRunner, Compiled, BufferOptions
-from tinygrad.features.graph import print_tree, realized_lazybuffer
-from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG
+from tinygrad.features.graph import print_tree, realized_lazybuffer, log_lazybuffer
+from tinygrad.helpers import colored, getenv, GRAPH, cpu_time_execution, DEBUG, flatten, prod, dedup, all_int
+from tinygrad.shape.symbolic import Variable
+from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.lazy import LazyBuffer
+from tinygrad.shape.shapetracker import ShapeTracker

 # *** schedule running ***
@@ -68,3 +73,169 @@ def run_schedule(schedule:List[ScheduleItem]):
     if prg: prg.exec(cast(List[Buffer], real_buffers), si.var_vals)
     elif si.out.size > 0: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
     if GRAPH: realized_lazybuffer(si.out, GlobalCounters.kernel_count)
+
+# *** schedule creation ***
+
+# creation can recurse a lot
+sys.setrecursionlimit(10000)
+
+# recursively create a lazyop
+def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
+                      realizes:Set[LazyBuffer], cache, first=True) -> LazyOp:
+  if (buf, st) in cache: return cache[(buf, st)]
+  if buf != buf.base:
+    st = buf.st + st
+    buf = buf.base
+  # all buffers here are base now
+  assert buf.op is not None
+
+  # consts are always fused and generated
+  if buf.op is LoadOps.CONST:
+    unbound_st, st_var_vals = st.simplify().unbind()
+    var_vals.update(st_var_vals)
+    return LazyOp(BufferOps.CONST, (), ConstBuffer(float(buf.arg), buf.dtype, unbound_st))
+
+  # if we aren't fusing it, it's a load and we add it to the inputs
+  if buf.realized or (buf in realizes and not first):
+    if buf not in inputs: inputs.append(buf)
+    unbound_st, st_var_vals = st.simplify().unbind()
+    var_vals.update(st_var_vals)
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(inputs.index(buf)+1, buf.dtype, unbound_st))
+
+  # if a CONTIGUOUS made it all the way here, just skip it
+  if buf.op is LoadOps.CONTIGUOUS:
+    assert first
+    return _recursive_lazyop(buf.srcs[0], inputs, var_vals, st, realizes, cache, False)
+
+  # if it's a reduce, we have to change the shapetracker
+  if buf.op in ReduceOps:
+    assert st.contiguous, "ReduceOps late fusion must be contiguous"
+    st = ShapeTracker.from_shape(buf.srcs[0].shape)
+
+  # otherwise we fuse it like normal
+  cache[(buf, st)] = ret = LazyOp(buf.op, tuple(_recursive_lazyop(x, inputs, var_vals, st, realizes, cache, False) for x in buf.srcs), buf.arg)
+  return ret
+
+# recursively walk back in the graph to create the schedule
+def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyBuffer],
+                        reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> List[ScheduleItem]:
+  if out in seen or out.realized or out.op == LoadOps.CONST: return []
+  assert out.base == out
+  seen.add(out)
+
+  inputs: List[LazyBuffer] = []
+  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
+  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.WAIT, LoadOps.COPY, LoadOps.EMPTY}:
+    op, inputs = LazyOp(out.op, (), out.arg), list(out.srcs)
+  else:
+    output_st = ShapeTracker.from_shape(reduce_for_op[out].shape if out in reduce_for_op else out.shape)
+    op = _recursive_lazyop(out, inputs, var_vals, output_st, realizes, cache={})
+    op = LazyOp(BufferOps.STORE, (op, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
+
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in inputs) + [ScheduleItem(op, out, tuple(inputs), var_vals)]
+
+# recursively search the entire graph for all LazyBuffers, insert realizes after expands
+def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
+                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
+  if buf in allbufs or buf.base.realized: return
+  if GRAPH: log_lazybuffer(buf, scheduled)
+  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
+                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
+    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
+    buf.dtype = dtypes.float32  # NOTE: this is what makes the dtype above not match
+  if buf.base != buf:
+    # realize all places where the buffer is expanded
+    if prod(buf.base.st.shape) < prod(buf.st.shape):
+      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
+          prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
+        simple_pads.add(buf.base)
+      else:
+        realizes.add(buf.base)
+    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
+  if buf.forced_realize: realizes.add(buf)
+  allbufs[buf] = None
+  if buf.op in LoadOps: realizes.add(buf.base)
+  if buf.op == LoadOps.COPY:
+    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
+    realizes.add(buf.srcs[0].base)
+  for x in buf.srcs:
+    children[x.base][buf] = None
+    _recurse_lb(x, realizes, allbufs, simple_pads, children)
+
+UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
+def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
+  if buf in realizes or buf.realized: return True
+  # NOTE: this broke to_image_idx and coder with JIT
+  if buf.op in UNSAFE_PAD_OPS: return False
+  return all(_is_padding_okay(x.base, realizes) for x in buf.srcs)
+
+def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
+  if seen is None: seen = set()
+
+  # start by just realizing the buffers passed in
+  realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
+  allbufs: Dict[LazyBuffer, None] = {}
+  simple_pads: Set[LazyBuffer] = set()
+  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
+  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
+
+  # check if we have to realize pads
+  for p in simple_pads:
+    if not _is_padding_okay(p, realizes):
+      realizes.add(p)
+
+  # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
+  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
+  for r in allbufs.keys():
+    if r != r.base or r.op not in ReduceOps or r in realizes: continue
+
+    # follow the reduce down
+    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
+    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
+    forced_realize = False
+    can_chase = True
+    while not forced_realize and len(child_set):
+      next_child_set = {}
+      for tr,st in child_set.items():
+        if tr in realizes:
+          realized_children[tr] = st
+          # can only have one output buffer
+          # can only reduce contiguous
+          # max one reduceop per kernel
+          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
+            can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
+            forced_realize = True
+            break
+          continue
+        for tr_next in children[tr].keys():
+          if not tr_next.realized:
+            # max one reduceop per kernel
+            if tr_next.op in ReduceOps:
+              forced_realize = True
+              break
+            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
+            if len(st_childs) > 1:
+              forced_realize = True
+              break
+            next_child_set[tr_next] = st + st_childs[0].st
+      child_set = next_child_set
+    if forced_realize:
+      tr = r
+      if can_chase:
+        # can chase this down to contiguous children
+        st = tr.st
+        while len(children[tr]) == 1:
+          tr_next = next(iter(children[tr].keys()))
+          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
+          if len(st_childs) > 1: break
+          if st.size != st_childs[0].st.size: break
+          st = st + st_childs[0].st
+          if not st.contiguous or tr_next.op in ReduceOps: break
+          tr = tr_next
+      reduce_for_op[tr] = r
+      realizes.add(tr)
+    else:
+      assert len(realized_children) == 1
+      reduce_for_op[next(iter(realized_children.keys()))] = r
+
+  return flatten(_recursive_schedule(x.base, seen, realizes, reduce_for_op) for x in outs)
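Every ScheduleItem this code emits carries the STORE-rooted LazyOp in .ast, the output LazyBuffer in .out, its .inputs, and the bound .var_vals, which is everything the lowering in run_schedule needs. A hedged inspection sketch:

    from tinygrad import Tensor
    from tinygrad.realize import create_schedule

    si = create_schedule([(Tensor.rand(8) + 1).lazydata])[-1]   # last item is the compute kernel
    print(si.ast.op)                                            # BufferOps.STORE at the root
    print(si.out.shape, [x.shape for x in si.inputs])           # output and input buffers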
@@ -8,12 +8,12 @@ import numpy as np

 from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype
 from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
-from tinygrad.lazy import LazyBuffer, create_schedule
+from tinygrad.lazy import LazyBuffer
 from tinygrad.features.multi import MultiLazyBuffer
 from tinygrad.ops import LoadOps
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.symbolic import sint
-from tinygrad.realize import run_schedule
+from tinygrad.realize import run_schedule, create_schedule

 # **** start with two base classes, Tensor and Function ****
@@ -127,7 +127,7 @@ class Tensor:
     run_schedule(create_schedule(flatten([x.lazydata.lbs if isinstance(x.lazydata, MultiLazyBuffer) else [x.lazydata] for x in lst])))

   def realize(self) -> Tensor:
-    run_schedule(self.lazydata.schedule())
+    Tensor.corealize([self])
     return self

   def assign(self, x) -> Tensor:
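Tensor.realize is now a thin wrapper over Tensor.corealize, which builds one schedule across a whole group of tensors so shared subgraphs are scheduled once. A hedged sketch:

    x, y = Tensor.rand(16), Tensor.rand(16)
    a, b = x + y, x * y
    Tensor.corealize([a, b])   # one create_schedule over both outputs, one run_schedule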