MetaOps.KERNEL (#5543)

This commit is contained in:
George Hotz
2024-07-17 19:41:23 -07:00
committed by GitHub
parent d3b098299d
commit fa7e734b49
19 changed files with 39 additions and 39 deletions

View File

@@ -53,7 +53,7 @@ ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_s
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
sink = LazyOp(MetaOps.SINK, (st_0,))
sink = LazyOp(MetaOps.KERNEL, (st_0,))
# convert the computation to a "linearized" format (print the format)
from tinygrad.engine.realize import get_kernel, CompiledRunner

View File

@@ -69,7 +69,7 @@ if __name__ == "__main__":
print(f"optimizing for {Device.DEFAULT}")
sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
sched = [x for x in sched if x.ast.op is MetaOps.SINK]
sched = [x for x in sched if x.ast.op is MetaOps.KERNEL]
# focus on one kernel
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]

View File

@@ -49,7 +49,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
# confirm no non-kernel metaop in the (non independent) schedule except for the ones that load the input buffers
assert all(si.ast.op is MetaOps.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
# every schedule item must be a MetaOps.KERNEL unless its output is one of the input buffers; the message names the op actually checked
assert all(si.ast.op is MetaOps.KERNEL or out in input_lb for si in schedule for out in si.outputs), "has non KERNEL ops, can't compile to Thneed"
return schedule, schedule_independent, inputs
def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
@@ -105,7 +105,7 @@ if __name__ == "__main__":
#exit(0)
schedule, schedule_independent, inputs = get_schedule(onnx_data)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.SINK)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.KERNEL)
print(f"{len(schedule_input)} inputs")
run_schedule(schedule_independent)

View File

@@ -10,7 +10,7 @@ inf, nan = float('inf'), float('nan')
# kernel unpacker
from tinygrad.codegen.kernel import Kernel
def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.SINK, val) if isinstance(val:=eval(ast_str), tuple) else val
def ast_str_to_ast(ast_str:str) -> LazyOp:
  """Evaluate `ast_str` and return the result as a single root LazyOp.

  If the evaluated value is a tuple, it is passed as the second argument of a
  new LazyOp whose op is MetaOps.KERNEL; any other value is returned unchanged.
  """
  # NOTE: eval runs arbitrary Python, so only pass strings from a trusted source
  parsed = eval(ast_str)
  if isinstance(parsed, tuple): return LazyOp(MetaOps.KERNEL, parsed)
  return parsed
def ast_str_to_lin(ast_str:str, opts=None):
  """Build a Kernel from an AST string; `opts` is forwarded to the Kernel constructor as-is."""
  root = ast_str_to_ast(ast_str)
  return Kernel(root, opts=opts)
def kern_str_to_lin(kern_str:str, opts=None):
(ast, applied_opts,) = eval(kern_str)

View File

@@ -20,7 +20,7 @@ if __name__ == "__main__":
with Timing("***** model schedule in "):
sched = out.schedule()
asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.SINK])
asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.KERNEL])
uops = []
with Profiling(PROFILE):
with Timing("***** model uops in "):

View File

@@ -9,7 +9,7 @@ from test.helpers import is_dtype_supported
def _check_ast_count(desired_count:int, t:Tensor):
# NOTE: this has side effect because everything can be scheduled only once
schedule = create_schedule(t.lazydata.lbs)
asts = [s for s in schedule if s.ast.op is MetaOps.SINK]
asts = [s for s in schedule if s.ast.op is MetaOps.KERNEL]
assert len(asts) == desired_count
class TestUnaryOpsConstFolding(unittest.TestCase):

View File

@@ -13,7 +13,7 @@ class TestConvShapetracker(unittest.TestCase):
# first run to init the weights, they are saved in seen
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
# run it again to get the kernels
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is MetaOps.SINK]
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is MetaOps.KERNEL]
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
for st in [x.arg.st for x in sched[0].ast.lazyops if x.op is BufferOps.LOAD]:
assert len(st.views) == 1

View File

@@ -785,7 +785,7 @@ class TestLinearizer(unittest.TestCase):
def test_div_collapse(self):
def helper(t, msg, max_ops=0):
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
assert len(sched) == 1
lin = Kernel(sched[0].ast)
@@ -806,7 +806,7 @@ class TestLinearizer(unittest.TestCase):
def test_sum_collapse(self):
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
assert len(sched) == 1
lin = Kernel(sched[0].ast)
assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
@@ -1159,7 +1159,7 @@ class TestHandCodedOpts(unittest.TestCase):
assert k.upcasted == 1
def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp], inputs:List[Tensor], *args, **kwargs):
if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.SINK, ast)
if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.KERNEL, ast)
inbufs = [x.lazydata.buffer for x in inputs]
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)

View File

@@ -253,7 +253,7 @@ class TestLinearizerFailures(unittest.TestCase):
def test_failure_33(self):
# UOps.UNMUL left after linearize
ast = LazyOp(op=MetaOps.SINK, src=(
ast = LazyOp(op=MetaOps.KERNEL, src=(
LazyOp(op=BufferOps.STORE, src=(
LazyOp(op=ReduceOps.SUM, src=(
LazyOp(op=BinaryOps.MUL, src=(
@@ -282,7 +282,7 @@ class TestLinearizerFailures(unittest.TestCase):
# from fuzzing on metal
def test_failure_34(self, unroll=False):
ast = LazyOp(op=MetaOps.SINK, src=(
ast = LazyOp(op=MetaOps.KERNEL, src=(
LazyOp(op=BufferOps.STORE, src=(
LazyOp(op=BinaryOps.MAX, src=(
LazyOp(op=ReduceOps.SUM, src=(
@@ -298,7 +298,7 @@ class TestLinearizerFailures(unittest.TestCase):
# from world fuzz_linearizer: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_N=100 FUZZ_NTH=84 python3 ./test/external/fuzz_linearizer.py
def test_failure_36(self):
# UOps.UNMUL left after linearize
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
@@ -307,14 +307,14 @@ class TestLinearizerFailures(unittest.TestCase):
def test_failure_37(self):
# beautiful mnist kernel number 28: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=28 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
for axis in [0,1,2,3,4,5]:
opts = [Opt(op=OptOps.TC, axis=axis, amt=2)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_38(self):
# beautiful mnist kernel number 87: 6 possible TC axis_choices (2 for axis_buf1 and 3 reduce) and first/second reduce axis fail for both axis_buf1 choices
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=87 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
ast = LazyOp(op=MetaOps.SINK, src=(
ast = LazyOp(op=MetaOps.KERNEL, src=(
LazyOp(op=BufferOps.STORE, src=(
LazyOp(op=ReduceOps.SUM, src=(
LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(
@@ -327,14 +327,14 @@ class TestLinearizerFailures(unittest.TestCase):
def test_failure_39(self):
# beautiful mnist kernel number 127: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=127 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
for axis in [0,1,2,3,4,5]:
opts = [Opt(op=OptOps.TC, axis=axis, amt=2)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
def test_failure_40(self):
# beautiful mnist kernel number 3:
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 DEBUG=2 FUZZ_NTH=3 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
for amt in [16,32]:
opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=amt), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["METAL", "GPU"])

View File

@@ -440,7 +440,7 @@ class TestNN(unittest.TestCase):
[12, 19, 8, 1]])
result = layer(a)
schedule = create_schedule([result.lazydata])
self.assertEqual(3, len([item for item in schedule if item.ast.op is MetaOps.SINK]), "first run realizes arange, weight, and embedding")
self.assertEqual(3, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "first run realizes arange, weight, and embedding")
run_schedule(schedule)
b = Tensor([[1, 2, 3],
@@ -448,7 +448,7 @@ class TestNN(unittest.TestCase):
[7, 8, 9]])
result = layer(b)
schedule = create_schedule([result.lazydata])
self.assertEqual(1, len([item for item in schedule if item.ast.op is MetaOps.SINK]), "second run realizes embedding only")
self.assertEqual(1, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "second run realizes embedding only")
run_schedule(schedule)
def test_load_state_dict(self):

View File

@@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
for i,out in enumerate(s.outputs):
seen.add(out)
sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.SINK]
if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.KERNEL]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
@@ -37,7 +37,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
# test the (kernel) ops linearize
for s in sched:
if s.ast.op is not MetaOps.SINK: continue
if s.ast.op is not MetaOps.KERNEL: continue
l = Kernel(s.ast)
l.hand_coded_optimizations()
l.linearize()

View File

@@ -15,7 +15,7 @@ from tinygrad.shape.view import View
class TestTimeLinearizer(unittest.TestCase):
def test_reasonable_time(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.SINK][0]
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.KERNEL][0]
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
memops = {x.arg.idx:x.arg.st.real_size() for x in si.ast.lazyops if x.op is BufferOps.LOAD}
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
@@ -23,7 +23,7 @@ class TestTimeLinearizer(unittest.TestCase):
assert tm > 0 and tm != float('inf')
def test_bufs_from_lin(self):
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.SINK][0]
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.KERNEL][0]
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
assert len(rawbufs) == len(lin.membufs)
assert all(r is not None for r in rawbufs)

View File

@@ -11,7 +11,7 @@ from tinygrad.shape.view import View
class InvalidLazyOpException(Exception): pass
def lower(*ast:LazyOp):
sink_ast = LazyOp(MetaOps.SINK, ast)
sink_ast = LazyOp(MetaOps.KERNEL, ast)
if DEBUG >= 3:
for op in ast: print_tree(op)
try: verify_lazyop(sink_ast)

View File

@@ -23,7 +23,7 @@ class TestWinograd(unittest.TestCase):
sched = create_schedule([out.lazydata])
for i,s in enumerate(sched):
if s.ast.op is not MetaOps.SINK: continue
if s.ast.op is not MetaOps.KERNEL: continue
ops = s.ast.lazyops
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
l = Kernel(s.ast)

View File

@@ -58,9 +58,9 @@ class Kernel:
def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
if len(ast) > 1 or ast[0].op is BufferOps.STORE:
assert all(x.op is BufferOps.STORE for x in ast)
self.ast = LazyOp(MetaOps.SINK, ast)
self.ast = LazyOp(MetaOps.KERNEL, ast)
else:
assert len(ast) == 1 and ast[0].op is MetaOps.SINK
assert len(ast) == 1 and ast[0].op is MetaOps.KERNEL
self.ast = ast[0]
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
@@ -718,7 +718,7 @@ class Kernel:
local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
elif op.op is MetaOps.SINK:
elif op.op is MetaOps.KERNEL:
arg = KernelInfo(self.local_dims, self.upcasted)
else:
arg = op.arg

View File

@@ -139,7 +139,7 @@ class IndependentLowerer:
return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
in_uops = tuple(self.to_uop(y) for y in x.src)
if x.op is MetaOps.SINK: return UOp(UOps.SINK, src=in_uops)
if x.op is MetaOps.KERNEL: return UOp(UOps.SINK, src=in_uops)
if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
if x.op in ReduceOps:

View File

@@ -189,7 +189,7 @@ class ExecItem:
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is MetaOps.COPY or getenv("USE_COPY_KERNEL")
if si.ast.op is MetaOps.SINK:
if si.ast.op is MetaOps.KERNEL:
runner = get_runner(si.outputs[0].device, si.ast)
return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals], si.metadata)
out = si.outputs[0]

View File

@@ -27,11 +27,11 @@ class ScheduleItem:
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
return self.bufs[:len(self.ast.src)] if self.ast.op is MetaOps.SINK else self.bufs[0:1]
return self.bufs[:len(self.ast.src)] if self.ast.op is MetaOps.KERNEL else self.bufs[0:1]
@property
def inputs(self) -> Tuple[Buffer, ...]:
"""Read only buffers in the schedule."""
return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.SINK else self.bufs[1:]
return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]
# *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
@@ -105,7 +105,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
"""describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
return LazyOp(MetaOps.SINK, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
@@ -122,7 +122,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
output_view, vv = output_view.simplify().unbind()
if vv: var_vals.update(vv)
ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
return LazyOp(MetaOps.SINK, tuple(ast)), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
return LazyOp(MetaOps.KERNEL, tuple(ast)), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -318,7 +318,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
var_vals = merge_dicts([var_vals, ps[3]])
for out in ps[0]: del out.srcs # can only schedule once
schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4]))
if logops and si.ast.op is MetaOps.SINK and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
for x in graph[ps[0][0]]:
in_degree[x] -= 1
if in_degree[x] == 0: queue.append(prescheduled[x])
@@ -382,5 +382,5 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
assigned = _internal_memory_planner([si.bufs for si in schedule],
noopt_buffers={b for si in schedule if si.ast.op is not MetaOps.SINK for b in si.bufs})
noopt_buffers={b for si in schedule if si.ast.op is not MetaOps.KERNEL for b in si.bufs})
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]

View File

@@ -27,7 +27,7 @@ class ReduceOps(Enum):
SUM = auto(); MAX = auto(); WMMA = auto() # noqa: E702
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
class MetaOps(Enum):
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); SINK = auto() # noqa: E702
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); KERNEL = auto() # noqa: E702
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
# do not preserve f(0) = 0
@@ -171,7 +171,7 @@ def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return
# the living definition of LazyOps
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
assert ast.op is MetaOps.SINK, "must be SINK"
# the root of a verified AST must be MetaOps.KERNEL; keep the message in sync with the enum member being checked
assert ast.op is MetaOps.KERNEL, "must be KERNEL"
sts: Dict[LazyOp, ShapeTracker] = {}
def dfs(op:LazyOp, st:ShapeTracker):
if op in sts: return