mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
MetaOps.KERNEL (#5543)
This commit is contained in:
@@ -53,7 +53,7 @@ ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_s
|
||||
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
|
||||
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
sink = LazyOp(MetaOps.SINK, (st_0,))
|
||||
sink = LazyOp(MetaOps.KERNEL, (st_0,))
|
||||
|
||||
# convert the computation to a "linearized" format (print the format)
|
||||
from tinygrad.engine.realize import get_kernel, CompiledRunner
|
||||
|
||||
@@ -69,7 +69,7 @@ if __name__ == "__main__":
|
||||
print(f"optimizing for {Device.DEFAULT}")
|
||||
|
||||
sched = globals()[f"get_sched_{getenv('MODEL', 'resnet')}"]()
|
||||
sched = [x for x in sched if x.ast.op is MetaOps.SINK]
|
||||
sched = [x for x in sched if x.ast.op is MetaOps.KERNEL]
|
||||
|
||||
# focus on one kernel
|
||||
if getenv("KERNEL", -1) >= 0: sched = sched[getenv("KERNEL", -1):getenv("KERNEL", -1)+1]
|
||||
|
||||
@@ -49,7 +49,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
|
||||
print(f"{len(schedule)} schedule items depend on the input, {len(schedule_independent)} don't")
|
||||
|
||||
# confirm no non-sink metaop in the (non independent) schedule except for the ones that load the input buffers
|
||||
assert all(si.ast.op is MetaOps.SINK or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
|
||||
assert all(si.ast.op is MetaOps.KERNEL or out in input_lb for si in schedule for out in si.outputs), "has non SINK ops, can't compile to Thneed"
|
||||
return schedule, schedule_independent, inputs
|
||||
|
||||
def test_vs_onnx(onnx_data, eis:Optional[List[ExecItem]], inputs:Dict[str, Tensor]):
|
||||
@@ -105,7 +105,7 @@ if __name__ == "__main__":
|
||||
#exit(0)
|
||||
|
||||
schedule, schedule_independent, inputs = get_schedule(onnx_data)
|
||||
schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.SINK)
|
||||
schedule, schedule_input = partition(schedule, lambda x: x.ast.op is MetaOps.KERNEL)
|
||||
print(f"{len(schedule_input)} inputs")
|
||||
|
||||
run_schedule(schedule_independent)
|
||||
|
||||
@@ -10,7 +10,7 @@ inf, nan = float('inf'), float('nan')
|
||||
|
||||
# kernel unpacker
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.SINK, val) if isinstance(val:=eval(ast_str), tuple) else val
|
||||
def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.KERNEL, val) if isinstance(val:=eval(ast_str), tuple) else val
|
||||
def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts)
|
||||
def kern_str_to_lin(kern_str:str, opts=None):
|
||||
(ast, applied_opts,) = eval(kern_str)
|
||||
|
||||
2
test/external/external_benchmark_schedule.py
vendored
2
test/external/external_benchmark_schedule.py
vendored
@@ -20,7 +20,7 @@ if __name__ == "__main__":
|
||||
with Timing("***** model schedule in "):
|
||||
sched = out.schedule()
|
||||
|
||||
asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.SINK])
|
||||
asts = dedup([x.ast for x in sched if x.ast.op is MetaOps.KERNEL])
|
||||
uops = []
|
||||
with Profiling(PROFILE):
|
||||
with Timing("***** model uops in "):
|
||||
|
||||
@@ -9,7 +9,7 @@ from test.helpers import is_dtype_supported
|
||||
def _check_ast_count(desired_count:int, t:Tensor):
|
||||
# NOTE: this has side effect because everything can be scheduled only once
|
||||
schedule = create_schedule(t.lazydata.lbs)
|
||||
asts = [s for s in schedule if s.ast.op is MetaOps.SINK]
|
||||
asts = [s for s in schedule if s.ast.op is MetaOps.KERNEL]
|
||||
assert len(asts) == desired_count
|
||||
|
||||
class TestUnaryOpsConstFolding(unittest.TestCase):
|
||||
|
||||
@@ -13,7 +13,7 @@ class TestConvShapetracker(unittest.TestCase):
|
||||
# first run to init the weights, they are saved in seen
|
||||
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen)
|
||||
# run it again to get the kernels
|
||||
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is MetaOps.SINK]
|
||||
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata], seen) if si.ast.op is MetaOps.KERNEL]
|
||||
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
|
||||
for st in [x.arg.st for x in sched[0].ast.lazyops if x.op is BufferOps.LOAD]:
|
||||
assert len(st.views) == 1
|
||||
|
||||
@@ -785,7 +785,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
def test_div_collapse(self):
|
||||
def helper(t, msg, max_ops=0):
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
|
||||
assert len(sched) == 1
|
||||
|
||||
lin = Kernel(sched[0].ast)
|
||||
@@ -806,7 +806,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
def test_sum_collapse(self):
|
||||
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.SINK]
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is MetaOps.KERNEL]
|
||||
assert len(sched) == 1
|
||||
lin = Kernel(sched[0].ast)
|
||||
assert not any(u.op is UOps.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
@@ -1159,7 +1159,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
assert k.upcasted == 1
|
||||
|
||||
def helper_linearizer_ast(ast:Union[Tuple[LazyOp, ...], LazyOp], inputs:List[Tensor], *args, **kwargs):
|
||||
if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.SINK, ast)
|
||||
if not isinstance(ast, LazyOp): ast = LazyOp(MetaOps.KERNEL, ast)
|
||||
inbufs = [x.lazydata.buffer for x in inputs]
|
||||
outbufs = [Buffer(inbufs[-1].device if inbufs else Device.DEFAULT, out.arg.st.size, out.arg.dtype).allocate() for out in ast.src]
|
||||
return _helper_linearizer_opt_ast(ast, outbufs+inbufs, *args, **kwargs)
|
||||
|
||||
@@ -253,7 +253,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
|
||||
def test_failure_33(self):
|
||||
# UOps.UNMUL left after linearize
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(
|
||||
ast = LazyOp(op=MetaOps.KERNEL, src=(
|
||||
LazyOp(op=BufferOps.STORE, src=(
|
||||
LazyOp(op=ReduceOps.SUM, src=(
|
||||
LazyOp(op=BinaryOps.MUL, src=(
|
||||
@@ -282,7 +282,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
|
||||
# from fuzzing on metal
|
||||
def test_failure_34(self, unroll=False):
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(
|
||||
ast = LazyOp(op=MetaOps.KERNEL, src=(
|
||||
LazyOp(op=BufferOps.STORE, src=(
|
||||
LazyOp(op=BinaryOps.MAX, src=(
|
||||
LazyOp(op=ReduceOps.SUM, src=(
|
||||
@@ -298,7 +298,7 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
# from world fuzz_linearizer: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_N=100 FUZZ_NTH=84 python3 ./test/external/fuzz_linearizer.py
|
||||
def test_failure_36(self):
|
||||
# UOps.UNMUL left after linearize
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))))),), arg=dtypes.uint),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.uint, st=ShapeTracker(views=(View(shape=(5, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.uchar),), arg=MemBuffer(idx=0, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
opts = [Opt(op=OptOps.UPCAST, axis=0, amt=0)]
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
|
||||
@@ -307,14 +307,14 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
def test_failure_37(self):
|
||||
# beautiful mnist kernel number 28: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail
|
||||
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=28 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
for axis in [0,1,2,3,4,5]:
|
||||
opts = [Opt(op=OptOps.TC, axis=axis, amt=2)]
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
def test_failure_38(self):
|
||||
# beautiful mnist kernel number 87: 6 possible TC axis_choices (2 for axis_buf1 and 3 reduce) and first/second reduce axis fail for both axis_buf1 choices
|
||||
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=87 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(
|
||||
ast = LazyOp(op=MetaOps.KERNEL, src=(
|
||||
LazyOp(op=BufferOps.STORE, src=(
|
||||
LazyOp(op=ReduceOps.SUM, src=(
|
||||
LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(
|
||||
@@ -327,14 +327,14 @@ class TestLinearizerFailures(unittest.TestCase):
|
||||
def test_failure_39(self):
|
||||
# beautiful mnist kernel number 127: 6 possible TC axis_choices (3 for axis_buf1 and 2 reduce) and all fail
|
||||
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=1 FUZZ_NTH=127 DEBUG=2 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.MAX, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.uchar, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(784, 0, 0, 28, 1, 0, 28, 1), offset=0, mask=None, contiguous=False),)))),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 5, 5), strides=(0, 0, 25, 0, 0, 0, 5, 1), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(6, 7)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 1, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=0.0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
for axis in [0,1,2,3,4,5]:
|
||||
opts = [Opt(op=OptOps.TC, axis=axis, amt=2)]
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=[])
|
||||
def test_failure_40(self):
|
||||
# beautiful mnist kernel number 3:
|
||||
# fuzz: PYTHONPATH=. METAL=1 FUZZ_ALL_ACTIONS=1 DEPTH=2 DEBUG=2 FUZZ_NTH=3 python3 ./test/external/fuzz_linearizer.py --logfile /tmp/beautiful_mnist.kernels.txt
|
||||
ast = LazyOp(op=MetaOps.SINK, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
ast = LazyOp(op=MetaOps.KERNEL, src=(LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BinaryOps.ADD, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))))),), arg=(1,)), LazyOp(op=BufferOps.CONST, src=(), arg=ConstBuffer(val=-1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)))),), arg=None)
|
||||
for amt in [16,32]:
|
||||
opts = [Opt(op=OptOps.GROUPTOP, axis=0, amt=amt), Opt(op=OptOps.UNROLL, axis=0, amt=0)]
|
||||
helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["METAL", "GPU"])
|
||||
|
||||
@@ -440,7 +440,7 @@ class TestNN(unittest.TestCase):
|
||||
[12, 19, 8, 1]])
|
||||
result = layer(a)
|
||||
schedule = create_schedule([result.lazydata])
|
||||
self.assertEqual(3, len([item for item in schedule if item.ast.op is MetaOps.SINK]), "first run realizes arange, weight, and embedding")
|
||||
self.assertEqual(3, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "first run realizes arange, weight, and embedding")
|
||||
run_schedule(schedule)
|
||||
|
||||
b = Tensor([[1, 2, 3],
|
||||
@@ -448,7 +448,7 @@ class TestNN(unittest.TestCase):
|
||||
[7, 8, 9]])
|
||||
result = layer(b)
|
||||
schedule = create_schedule([result.lazydata])
|
||||
self.assertEqual(1, len([item for item in schedule if item.ast.op is MetaOps.SINK]), "second run realizes embedding only")
|
||||
self.assertEqual(1, len([item for item in schedule if item.ast.op is MetaOps.KERNEL]), "second run realizes embedding only")
|
||||
run_schedule(schedule)
|
||||
|
||||
def test_load_state_dict(self):
|
||||
|
||||
@@ -28,7 +28,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
|
||||
for i,out in enumerate(s.outputs):
|
||||
seen.add(out)
|
||||
sched = create_schedule(flatten([r.lazydata.lbs for r in t]), seen)
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.SINK]
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is MetaOps.KERNEL]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
for i, s in enumerate(sched):
|
||||
@@ -37,7 +37,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
|
||||
if len(sched) != allowed: raise KernelCountException(f"{len(sched)=} != {allowed}")
|
||||
# test the (sink) ops linearize
|
||||
for s in sched:
|
||||
if s.ast.op is not MetaOps.SINK: continue
|
||||
if s.ast.op is not MetaOps.KERNEL: continue
|
||||
l = Kernel(s.ast)
|
||||
l.hand_coded_optimizations()
|
||||
l.linearize()
|
||||
|
||||
@@ -15,7 +15,7 @@ from tinygrad.shape.view import View
|
||||
|
||||
class TestTimeLinearizer(unittest.TestCase):
|
||||
def test_reasonable_time(self):
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.SINK][0]
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.KERNEL][0]
|
||||
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
|
||||
memops = {x.arg.idx:x.arg.st.real_size() for x in si.ast.lazyops if x.op is BufferOps.LOAD}
|
||||
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
|
||||
@@ -23,7 +23,7 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||
assert tm > 0 and tm != float('inf')
|
||||
|
||||
def test_bufs_from_lin(self):
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.SINK][0]
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is MetaOps.KERNEL][0]
|
||||
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
|
||||
assert len(rawbufs) == len(lin.membufs)
|
||||
assert all(r is not None for r in rawbufs)
|
||||
|
||||
@@ -11,7 +11,7 @@ from tinygrad.shape.view import View
|
||||
|
||||
class InvalidLazyOpException(Exception): pass
|
||||
def lower(*ast:LazyOp):
|
||||
sink_ast = LazyOp(MetaOps.SINK, ast)
|
||||
sink_ast = LazyOp(MetaOps.KERNEL, ast)
|
||||
if DEBUG >= 3:
|
||||
for op in ast: print_tree(op)
|
||||
try: verify_lazyop(sink_ast)
|
||||
|
||||
@@ -23,7 +23,7 @@ class TestWinograd(unittest.TestCase):
|
||||
sched = create_schedule([out.lazydata])
|
||||
|
||||
for i,s in enumerate(sched):
|
||||
if s.ast.op is not MetaOps.SINK: continue
|
||||
if s.ast.op is not MetaOps.KERNEL: continue
|
||||
ops = s.ast.lazyops
|
||||
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
|
||||
l = Kernel(s.ast)
|
||||
|
||||
@@ -58,9 +58,9 @@ class Kernel:
|
||||
def __init__(self, *ast:LazyOp, opts:Optional[Renderer]=None):
|
||||
if len(ast) > 1 or ast[0].op is BufferOps.STORE:
|
||||
assert all(x.op is BufferOps.STORE for x in ast)
|
||||
self.ast = LazyOp(MetaOps.SINK, ast)
|
||||
self.ast = LazyOp(MetaOps.KERNEL, ast)
|
||||
else:
|
||||
assert len(ast) == 1 and ast[0].op is MetaOps.SINK
|
||||
assert len(ast) == 1 and ast[0].op is MetaOps.KERNEL
|
||||
self.ast = ast[0]
|
||||
|
||||
self.opts = opts if opts is not None else Device[Device.DEFAULT].renderer
|
||||
@@ -718,7 +718,7 @@ class Kernel:
|
||||
local_store = LazyOp(BufferOps.STORE, (start,), local_buffer)
|
||||
local_load = LazyOp(BufferOps.LOAD, (local_store,), local_buffer)
|
||||
return LazyOp(op.op, (local_load,), tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces)))
|
||||
elif op.op is MetaOps.SINK:
|
||||
elif op.op is MetaOps.KERNEL:
|
||||
arg = KernelInfo(self.local_dims, self.upcasted)
|
||||
else:
|
||||
arg = op.arg
|
||||
|
||||
@@ -139,7 +139,7 @@ class IndependentLowerer:
|
||||
return UOp(UOps.STORE, None, (buf, idx, self.to_uop(x.src[0])) + ((valid,) if has_valid else ()))
|
||||
|
||||
in_uops = tuple(self.to_uop(y) for y in x.src)
|
||||
if x.op is MetaOps.SINK: return UOp(UOps.SINK, src=in_uops)
|
||||
if x.op is MetaOps.KERNEL: return UOp(UOps.SINK, src=in_uops)
|
||||
if x.op is UnaryOps.CAST: return UOp(UOps.CAST, x.arg.scalar(), in_uops)
|
||||
if x.op is UnaryOps.BITCAST: return UOp(UOps.BITCAST, x.arg.scalar(), in_uops)
|
||||
if x.op in ReduceOps:
|
||||
|
||||
@@ -189,7 +189,7 @@ class ExecItem:
|
||||
|
||||
def lower_schedule_item(si:ScheduleItem) -> ExecItem:
|
||||
assert len(set(x.device for x in si.bufs)) == 1 or si.ast.op is MetaOps.COPY or getenv("USE_COPY_KERNEL")
|
||||
if si.ast.op is MetaOps.SINK:
|
||||
if si.ast.op is MetaOps.KERNEL:
|
||||
runner = get_runner(si.outputs[0].device, si.ast)
|
||||
return ExecItem(runner, [si.bufs[x[0]] for x in runner.p.globals], si.metadata)
|
||||
out = si.outputs[0]
|
||||
|
||||
@@ -27,11 +27,11 @@ class ScheduleItem:
|
||||
@property
|
||||
def outputs(self) -> Tuple[Buffer, ...]:
|
||||
"""Read/write or write only buffers in the schedule."""
|
||||
return self.bufs[:len(self.ast.src)] if self.ast.op is MetaOps.SINK else self.bufs[0:1]
|
||||
return self.bufs[:len(self.ast.src)] if self.ast.op is MetaOps.KERNEL else self.bufs[0:1]
|
||||
@property
|
||||
def inputs(self) -> Tuple[Buffer, ...]:
|
||||
"""Read only buffers in the schedule."""
|
||||
return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.SINK else self.bufs[1:]
|
||||
return self.bufs[len(self.ast.src):] if self.ast.op is MetaOps.KERNEL else self.bufs[1:]
|
||||
|
||||
# *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
|
||||
|
||||
@@ -105,7 +105,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
|
||||
"""describe the computation for a LazyBuffer with LazyOp + inputs + var_vals"""
|
||||
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
|
||||
rd = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.uint8, st:=ShapeTracker.from_shape((out.arg,))))
|
||||
return LazyOp(MetaOps.SINK, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
|
||||
return LazyOp(MetaOps.KERNEL, (LazyOp(BufferOps.STORE, (rd,), MemBuffer(0, dtypes.uint8, st)), )), [x.base for x in out.srcs], {}, []
|
||||
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LazyOp(out.op, (), out.arg), [x.base for x in out.srcs], {}, []
|
||||
var_vals: Dict[Variable, int] = merge_dicts([out.st.var_vals.copy() for out in outs])
|
||||
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
|
||||
@@ -122,7 +122,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
|
||||
output_view, vv = output_view.simplify().unbind()
|
||||
if vv: var_vals.update(vv)
|
||||
ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
|
||||
return LazyOp(MetaOps.SINK, tuple(ast)), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
|
||||
return LazyOp(MetaOps.KERNEL, tuple(ast)), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
@@ -318,7 +318,7 @@ def create_schedule_with_vars(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffe
|
||||
var_vals = merge_dicts([var_vals, ps[3]])
|
||||
for out in ps[0]: del out.srcs # can only schedule once
|
||||
schedule.append(si:=ScheduleItem(ps[1], tuple(x.buffer for x in ps[0]+ps[2] if x.size != 0), ps[4]))
|
||||
if logops and si.ast.op is MetaOps.SINK and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
|
||||
if logops and si.ast.op is MetaOps.KERNEL and not any(i.device.startswith("DISK:") for i in si.inputs): logops.write(str(si.ast)+"\n")
|
||||
for x in graph[ps[0][0]]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0: queue.append(prescheduled[x])
|
||||
@@ -382,5 +382,5 @@ def _internal_memory_planner(buffers:List[Union[List[Buffer], Tuple[Buffer, ...]
|
||||
def memory_planner(schedule:List[ScheduleItem]) -> List[ScheduleItem]:
|
||||
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
|
||||
assigned = _internal_memory_planner([si.bufs for si in schedule],
|
||||
noopt_buffers={b for si in schedule if si.ast.op is not MetaOps.SINK for b in si.bufs})
|
||||
noopt_buffers={b for si in schedule if si.ast.op is not MetaOps.KERNEL for b in si.bufs})
|
||||
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
|
||||
|
||||
@@ -27,7 +27,7 @@ class ReduceOps(Enum):
|
||||
SUM = auto(); MAX = auto(); WMMA = auto() # noqa: E702
|
||||
class BufferOps(Enum): LOAD = auto(); CONST = auto(); STORE = auto() # noqa: E702
|
||||
class MetaOps(Enum):
|
||||
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); SINK = auto() # noqa: E702
|
||||
EMPTY = auto(); CONST = auto(); COPY = auto(); CONTIGUOUS = auto(); CUSTOM = auto(); ASSIGN = auto(); VIEW = auto(); KERNEL = auto() # noqa: E702
|
||||
Op = Union[UnaryOps, BinaryOps, ReduceOps, MetaOps, TernaryOps, BufferOps]
|
||||
|
||||
# do not preserve f(0) = 0
|
||||
@@ -171,7 +171,7 @@ def reduce_st(st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return
|
||||
|
||||
# the living definition of LazyOps
|
||||
def verify_lazyop(ast:LazyOp) -> Dict[LazyOp, ShapeTracker]:
|
||||
assert ast.op is MetaOps.SINK, "must be SINK"
|
||||
assert ast.op is MetaOps.KERNEL, "must be SINK"
|
||||
sts: Dict[LazyOp, ShapeTracker] = {}
|
||||
def dfs(op:LazyOp, st:ShapeTracker):
|
||||
if op in sts: return
|
||||
|
||||
Reference in New Issue
Block a user