Schedule item (#2012)

* ScheduleItem

* put var_vals in the schedule

* fix tests, wow that proliferated quickly

* not ready to be in the schedule
George Hotz
2023-10-07 08:59:25 -07:00
committed by GitHub
parent f1f64bc88d
commit 121f7aa8c5
9 changed files with 97 additions and 91 deletions
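
Before the per-file diffs, a minimal sketch of the shape of this change: schedule() stops returning bare (LazyOp, LazyBuffer, inputs) tuples and returns a frozen ScheduleItem dataclass instead, so every consumer switches from positional indexing to named attributes. This restates the dataclass the diff adds in tinygrad/ops.py below; LazyOp and LazyBuffer appear only as string annotations here.

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ScheduleItem:
  ast: "LazyOp"                     # the kernel AST this item will run
  out: "LazyBuffer"                 # the LazyBuffer the kernel writes
  inputs: Tuple["LazyBuffer", ...]  # the input LazyBuffers
  # TODO: move var_vals here -- attempted in this PR, then reverted ("not ready to be in the schedule")

# old consumers unpacked the tuple positionally:
#   for op, out, inp in sched: lin = Linearizer(op)
# new consumers read named fields:
#   for si in sched: lin = Linearizer(si.ast)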

View File

@@ -267,7 +267,7 @@ result = Tensor(2).realize() + Tensor(3).realize()
# use the real Linearizer to linearize 2+3
from tinygrad.codegen.linearizer import Linearizer
sched = result.lazydata.schedule()
-linearizer = Linearizer(sched[-1][0])
+linearizer = Linearizer(sched[-1].ast)
linearizer.linearize()
# print the uops

View File

@@ -19,12 +19,12 @@ if __name__ == "__main__":
x = Tensor.empty(64, 3, 224, 224)
out = mdl(x)
sched = out.lazydata.schedule(seen)
-sched = [x for x in sched if x[0].op not in LoadOps]
+sched = [x for x in sched if x.ast.op not in LoadOps]
# work with the schedule
total_tm = 0
-for i,(op,out,inp) in enumerate(sched):
-if DEBUG >= 2: print_tree(op)
+for i,si in enumerate(sched):
+if DEBUG >= 2: print_tree(si.ast)
# enable only one kernel to focus on it
#if i != 1: continue
@@ -37,7 +37,7 @@ if __name__ == "__main__":
for big_chomp in [1,2]: #[1,2,4,8,16]:
for lil_chomp in [2,4,7,8,14]:
for upcasted in [0,1,2]:
-lin = Linearizer(op, LinearizerOptions(device="METAL"))
+lin = Linearizer(si.ast, LinearizerOptions(device="METAL"))
lin.reshape_and_permute(lambda x: (4096//big_chomp,big_chomp,56//lil_chomp,lil_chomp,56//lil_chomp,lil_chomp)+x[-2:], [0,2,4,1,3,5,6,7])
lin.upcasted += upcasted
lin.local_dims += 3
@@ -45,13 +45,13 @@ if __name__ == "__main__":
else:
# try with and without tensor cores
for tc in [0,1]:
-lin = Linearizer(op, LinearizerOptions(device="METAL"))
+lin = Linearizer(si.ast, LinearizerOptions(device="METAL"))
lin.hand_coded_optimizations(use_tensor_cores=tc)
lins.append(lin)
# create output/input buffers
-rout = RawMetalBuffer(out.st.size(), out.dtype)
-rin = [RawMetalBuffer(x.st.size(), x.dtype) for x in inp]
+rout = RawMetalBuffer(si.out.st.size(), si.out.dtype)
+rin = [RawMetalBuffer(x.st.size(), x.dtype) for x in si.inputs]
# benchmark the programs
choices = []

View File

@@ -32,7 +32,7 @@ class TestLinearizer(unittest.TestCase):
# these are of size 3 to avoid float4 coalesce
r = a[:-1] + a[1:]
-k = Linearizer(r.lazydata.schedule()[-1][0])
+k = Linearizer(r.lazydata.schedule()[-1].ast)
k.upcast()
k.linearize()
num_loads = len([uop for uop in k.uops if uop.uop == UOps.LOAD])
@@ -48,7 +48,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = a.expand([2]) + b.expand([2])
-k = Linearizer(r.lazydata.schedule()[-1][0])
+k = Linearizer(r.lazydata.schedule()[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -61,7 +61,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
r = Tensor.stack([a, b])
-k = Linearizer(r.lazydata.schedule()[-1][0])
+k = Linearizer(r.lazydata.schedule()[-1].ast)
k.upcast()
k.linearize()
num_ops = len([uop for uop in k.uops if uop.uop == UOps.ALU])
@@ -106,8 +106,8 @@ def helper_realized_ast(r:Tensor):
s = r.lazydata.schedule()
run_schedule(s[:-1]) # run all kernels except the last one
# now all input LazyBuffers buffers in s[-1] should be realized
-output_buffer = Device[s[-1][1].device].buffer(prod((s if isinstance(s, int) else s.max for s in s[-1][1].shape)), s[-1][1].dtype, **s[-1][1]._device_extra_args()) # allocate an output buffer
-return s[-1][0], [output_buffer] + [l.realized for l in s[-1][2]]
+output_buffer = Device[s[-1].out.device].buffer(prod((s if isinstance(s, int) else s.max for s in s[-1].out.shape)), s[-1].out.dtype, **s[-1].out._device_extra_args()) # allocate an output buffer
+return s[-1].ast, [output_buffer] + [l.realized for l in s[-1].inputs]
def helper_linearizer_opt(r:Tensor, opts=[]):
wanna_output = None
@@ -232,7 +232,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.hand_coded_optimizations()
k.linearize()
@@ -244,7 +244,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.shift_to(0, 4) # float4 dimension
k.shift_to(0, 2, insert_before=k.shape_len-1)
k.upcast()
@@ -260,7 +260,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.hand_coded_optimizations() # implicit trigger float4 dim
k.linearize()
@@ -272,7 +272,7 @@ class TestFloat4(unittest.TestCase):
c = a + b
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
k.upcast()
k.shift_to(len(k.full_unupcasted_shape)-1, 2, insert_before=k.shape_len-1)
@@ -290,7 +290,7 @@ class TestFloat4(unittest.TestCase):
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.upcast()
k.linearize()
@@ -305,7 +305,7 @@ class TestFloat4(unittest.TestCase):
# don't.
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.upcast()
k.upcast()
k.linearize()
@@ -321,7 +321,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.shift_to(0, 4, top=True) # top axes are float4 axes
k.upcast()
k.linearize()
@@ -337,7 +337,7 @@ class TestFloat4(unittest.TestCase):
# since the top axis is not contiguous.
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()
@@ -352,7 +352,7 @@ class TestFloat4(unittest.TestCase):
# should float4 b but not a
s = c.lazydata.schedule()[0]
-k = Linearizer(s[0])
+k = Linearizer(s.ast)
k.shift_to(0, 4) # float4 axis
k.upcast()
k.linearize()

View File

@@ -16,21 +16,21 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
if to_prerealize:
for pre in to_prerealize:
for s in pre.lazydata.schedule(seen.copy()):
-log_schedule_item(*s)
-seen.add(s[1])
+log_schedule_item(s)
+seen.add(s.out)
sched = t.lazydata.schedule(seen)
-for s in sched: log_schedule_item(*s)
-if filter_loadops: sched = [s for s in sched if s[0].op not in LoadOps]
+for s in sched: log_schedule_item(s)
+if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
for i, s in enumerate(sched):
print("op", i)
-print_tree(s[0])
+print_tree(s.ast)
assert len(sched) == allowed
# test the (non loadops) ops linearize
for s in sched:
-if s[0].op in LoadOps: continue
-l = Linearizer(s[0])
+if s.ast.op in LoadOps: continue
+l = Linearizer(s.ast)
l.hand_coded_optimizations()
l.linearize()

View File

@@ -22,10 +22,10 @@ class TestWinograd(unittest.TestCase):
sched = out.lazydata.schedule()
for i,s in enumerate(sched):
-if s[0].op in LoadOps: continue
-ops = s[0].get_lazyops()
+if s.ast.op in LoadOps: continue
+ops = s.ast.get_lazyops()
with Timing(f"linearize {i} with {len(ops):4d} ops: "):
-l = Linearizer(s[0])
+l = Linearizer(s.ast)
l.hand_coded_optimizations()
l.linearize()

View File

@@ -4,12 +4,10 @@ try:
except ImportError:
nx = None # graph won't work
from collections import defaultdict
-from typing import Dict, List, TYPE_CHECKING, Tuple
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
+from typing import Dict, List
+from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
-if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
# **** debugging and graphing ****
G = nx.DiGraph() if nx is not None else None
@@ -47,31 +45,31 @@ def str_dtype(dtyp):
ret = str(dtyp)[7:]
return "" if ret == 'float' else f"\n{ret}"
-def log_schedule_item(iop: LazyOp, ret: 'LazyBuffer', inp: Tuple['LazyBuffer', ...]):
+def log_schedule_item(si: ScheduleItem):
show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
-if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(inp[0].base))
-if iop.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
+if si.ast.op == LoadOps.CONTIGUOUS: setattr(si.out, 'node_id', nm(si.inputs[0].base))
+if si.ast.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
-op: List[Op] = [x.op for x in iop.get_lazyops()]
+op: List[Op] = [x.op for x in si.ast.get_lazyops()]
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if show_graph:
-assert ret.base == ret, "all outputs based"
+assert si.out.base == si.out, "all outputs based"
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
-for x in inp:
+for x in si.inputs:
assert x.base == x, "all inputs based"
#assert nm(x) in G.nodes, "all inputs seen"
-G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060')
+G.add_edge(nm(x), nm(si.out), label=get_sop(op), color='#00000060')
if 'label' not in G.nodes[nm(x)]:
-G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
-if nm(ret) not in G.nodes: G.add_node(nm(ret))
+G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(si.out.dtype)
+if nm(si.out) not in G.nodes: G.add_node(nm(si.out))
-G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)+(f"\n{iop.op}" if iop.op in LoadOps else "")
-G.nodes[nm(ret)]['fillcolor'] = top_colors[optype]
-G.nodes[nm(ret)]['color'] = 'black'
-G.nodes[nm(ret)]['style'] = 'filled'
+G.nodes[nm(si.out)]['label'] = (str(set(x.shape for x in si.inputs))+"\n"+str(si.out.shape) if optype == ReduceOps else str(si.out.shape))+str_dtype(si.out.dtype)+(f"\n{si.ast.op}" if si.ast.op in LoadOps else "")
+G.nodes[nm(si.out)]['fillcolor'] = top_colors[optype]
+G.nodes[nm(si.out)]['color'] = 'black'
+G.nodes[nm(si.out)]['style'] = 'filled'
def _tree(lazydata, prefix=""):
if type(lazydata).__name__ == "LazyBuffer": return [f"━━ realized {lazydata.dtype.name} {lazydata.shape}"] if (lazydata.realized) else _tree(lazydata.op, "LB ")

View File

@@ -5,7 +5,7 @@ from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, partition, dedup, merge_dicts
-from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
+from tinygrad.ops import ScheduleItem, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
@@ -79,7 +79,7 @@ def get_single_root(root:LazyBuffer) -> LazyBuffer: return get_single_root(cast(
def get_movementroot(root:LazyBuffer, allow_contiguous=False) -> LazyBuffer: return get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous) if not root.realized and (root.optype == MovementOps or (root.op.op == LoadOps.CONTIGUOUS and allow_contiguous and root.op.src[0].st.contiguous)) else root
def get_movementroot_contiguous(x:LazyBuffer) -> LazyBuffer: return get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0])) if not x.realized and x.op.op == LoadOps.CONTIGUOUS else (get_movementroot(x, True) if x.optype == MovementOps and x.st.contiguous else x)
-def var_vals_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.var_vals() for x in ast.get_lazyops() if x.op in BufferOps]))
+def var_vals_from_ast(ast:LazyOp) -> List[Variable]: return dedup(functools.reduce(operator.add, [x.arg.st.var_vals() for x in ast.get_lazyops() if x.op in BufferOps], []))
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], base:Optional[LazyBuffer]=None):
@@ -159,7 +159,7 @@ class LazyBuffer:
# *** scheduling ***
-def schedule(self, seen=None) -> List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]:
+def schedule(self, seen=None) -> List[ScheduleItem]:
if seen is None: seen = set()
if self in seen or self.realized or self.is_unrealized_const(): return []
seen.add(self)
@@ -180,7 +180,7 @@ class LazyBuffer:
# run the ast and log the op
op, base_bufs = _replace_bufferops(op)
-return ret + [(op, self, tuple(base_bufs))]
+return ret + [ScheduleItem(op, self, tuple(base_bufs))]
# *** creation/special ops ***
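
Aside from the new List[ScheduleItem] annotations, the one behavioral fix in this file is the [] initial value added to the functools.reduce call in var_vals_from_ast: without it, an ast whose get_lazyops() yields no BufferOps makes reduce raise instead of returning an empty list. A standalone sketch of that plain-Python behavior (not tinygrad code):

import functools, operator

per_buffer_vars = []  # stands in for an ast with no BufferOps, so there are no per-buffer variable lists
# functools.reduce(operator.add, per_buffer_vars)            # TypeError: reduce() of empty iterable with no initial value
print(functools.reduce(operator.add, per_buffer_vars, []))   # prints [] -- the fixed call just returns the initial value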

View File

@@ -4,6 +4,8 @@ import numpy as np
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
+from tinygrad.runtime.lib import RawBuffer
+from tinygrad.shape.symbolic import Variable, sym_infer
from dataclasses import dataclass
# these are the llops your accelerator must implement, along with toCpu
@@ -38,6 +40,13 @@ class ConstBuffer:
dtype: DType
st: ShapeTracker
+@dataclass(frozen=True)
+class ScheduleItem:
+ast: LazyOp
+out: LazyBuffer
+inputs: Tuple[LazyBuffer, ...]
+# TODO: move var_vals here
class LazyOp:
__slots__ = "op", "src", "arg", "buffers", "__weakref__"
op: Op
@@ -154,9 +163,6 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
# **************** for Compiled Buffers ****************
-from tinygrad.runtime.lib import RawBuffer
-from tinygrad.shape.symbolic import Variable, sym_infer
class BasicBatchExecutor:
def __init__(self, jit_cache:List[Tuple[Any, Any, Any]]): pass
def exec(self, jit_cache: List[Tuple[Any, Any, Any]], updatable_entries):

View File

@@ -1,6 +1,7 @@
-from typing import List, Tuple, cast, Dict, Callable
+from typing import List, cast, Dict, Callable
import numpy as np
-from tinygrad.ops import LazyOp, LoadOps, Device, UnaryOps, BufferOps, MemBuffer, get_lazyop_info
+import dataclasses
+from tinygrad.ops import ScheduleItem, LazyOp, LoadOps, Device, UnaryOps, BufferOps, MemBuffer, get_lazyop_info
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.lazy import LazyBuffer
from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE, ImageDType, dtypes
@@ -8,59 +9,60 @@ from tinygrad.helpers import DEBUG, prod, all_int, getenv, IMAGE, ImageDType, dt
from tinygrad.runtime.lib import RawBufferMapped, RawBufferTransfer
from tinygrad.runtime.ops_disk import RawDiskBuffer
-def fix_schedule_for_images(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
+def fix_schedule_for_images(schedule:List[ScheduleItem]):
# this is the fundamental fix, find unwritable or unreadable images and convert them to normal float32 (TODO: should it be float16?)
-for op,out,buffers in schedule:
-if isinstance(out.dtype, ImageDType) and (prod(out.shape) != prod(out.dtype.shape) or not any(out.shape[x]%4 == 0 for x in out.st.unit_stride_axes())):
-out.dtype = dtypes.float32
-bops = [x for x in op.get_lazyops() if x.op == BufferOps.MEM]
-for b in bops:
-if isinstance(buffers[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
-buffers[b.arg.idx-1].dtype = dtypes.float32
+for si in schedule:
+if isinstance(si.out.dtype, ImageDType) and (prod(si.out.shape) != prod(si.out.dtype.shape) or not any(si.out.shape[x]%4 == 0 for x in si.out.st.unit_stride_axes())):
+si.out.dtype = dtypes.float32
+for b in si.ast.get_lazyops():
+if b.op != BufferOps.MEM: continue
+if isinstance(si.inputs[b.arg.idx-1].dtype, ImageDType) and (b.arg.st.real_offset() % 4 != 0 or not any(b.arg.st.shape[x]%4 == 0 for x in b.arg.st.unit_stride_axes())):
+si.inputs[b.arg.idx-1].dtype = dtypes.float32
# now fix up the schedule to reflect the new dtypes
-fixed_schedule = []
-for op,out,buffers in schedule:
+fixed_schedule:List[ScheduleItem] = []
+for si in schedule:
+ast = si.ast
# fix input dtypes to match what they actually are
-bops = [x for x in op.get_lazyops() if x.op == BufferOps.MEM]
replacements = {}
-for x in bops:
-if x.arg.dtype != buffers[x.arg.idx-1].dtype:
-replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(x.arg.idx, buffers[x.arg.idx-1].dtype, x.arg.st))
-if replacements: op = op.map_buffers(replacements)
+for b in si.ast.get_lazyops():
+if b.op != BufferOps.MEM: continue
+if b.arg.dtype != si.inputs[b.arg.idx-1].dtype:
+replacements[b] = LazyOp(BufferOps.MEM, (), MemBuffer(b.arg.idx, si.inputs[b.arg.idx-1].dtype, b.arg.st))
+if replacements: ast = ast.map_buffers(replacements)
# fix the ops to create the output dtype
-if op.op not in LoadOps:
-info = get_lazyop_info(op)
-if info.dtype != out.dtype:
-op = LazyOp(UnaryOps.CAST, (op,), (out.dtype, False))
+if ast.op not in LoadOps:
+info = get_lazyop_info(ast)
+if info.dtype != si.out.dtype:
+ast = LazyOp(UnaryOps.CAST, (ast,), (si.out.dtype, False))
# put this in the fixed schedule
-fixed_schedule.append((op, out, buffers))
+fixed_schedule.append(dataclasses.replace(si, ast=ast))
return fixed_schedule
# *** this is where things happen ***
-def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
+def run_schedule(schedule:List[ScheduleItem]):
# HACK: images can be not usable due to shape
if IMAGE >= 2: schedule = fix_schedule_for_images(schedule)
# NOTE: if you for loop the schedule it's slow because nothing frees
while len(schedule):
-op,out,buffers = schedule.pop(0)
-log_schedule_item(op, out, buffers)
-assert all(x.realized for x in buffers), "can't run schedule, some buffers aren't realized"
-if DEBUG >= 3: print_tree(op)
-if op.op in LoadOps:
+si = schedule.pop(0)
+log_schedule_item(si)
+assert all(x.realized for x in si.inputs), "can't run schedule, some inputs aren't realized"
+if DEBUG >= 3: print_tree(si.ast)
+if si.ast.op in LoadOps:
# confirm the LoadOps are contiguous and in order
-for i,s in enumerate(op.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
-LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out, *buffers)
+for i,s in enumerate(si.ast.src): assert isinstance(s, LazyOp) and s.op == BufferOps.MEM and s.arg.idx == i+1 and s.arg.st.contiguous, f"bad LoadOps src {i}: {s}"
+LOAD_OPS_DISPATCHER[cast(LoadOps, si.ast.op)](si.out, *si.inputs)
else:
-out.realized = Device[out.device].exec_ast(op, output=out, inputs=buffers, var_vals=out.var_vals, **out._device_extra_args())
-del out.op
-for v in out.views: del v.op
-assert out.realized and isinstance(out.realized, Device[out.device].buffer), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
-assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
+si.out.realized = Device[si.out.device].exec_ast(si.ast, output=si.out, inputs=si.inputs, var_vals=si.out.var_vals, **si.out._device_extra_args())
+del si.out.op
+for v in si.out.views: del v.op
+assert si.out.realized and isinstance(si.out.realized, Device[si.out.device].buffer), f"device mismatch on realized got {type(si.out.realized)} expected {si.out.device}"
+assert si.out.realized.dtype == si.out.dtype, "realized dtype is incorrect"
# *** zero op LoadOps ***
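
One idiom worth noting in the rewritten fix_schedule_for_images above: because ScheduleItem is a frozen dataclass, si.ast cannot be reassigned in place, so the fixed item is rebuilt with dataclasses.replace. A self-contained sketch of that behavior, using a hypothetical stand-in class rather than the real ScheduleItem:

import dataclasses
from typing import Any, Tuple

@dataclasses.dataclass(frozen=True)
class Item:  # hypothetical stand-in for tinygrad.ops.ScheduleItem
  ast: Any
  out: Any
  inputs: Tuple[Any, ...]

si = Item(ast="old_ast", out=None, inputs=())
try:
  si.ast = "new_ast"  # frozen dataclass: assignment raises dataclasses.FrozenInstanceError
except dataclasses.FrozenInstanceError:
  si = dataclasses.replace(si, ast="new_ast")  # build a copy with only the ast swapped
assert si.ast == "new_ast" and si.inputs == ()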