diff --git a/test/test_schedule.py b/test/test_schedule.py
index ccbac47541..805755d9d0 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -3,14 +3,24 @@
 # NOTE: this has overlap with external_test_opt.py

 import unittest
+from typing import List, Optional
 from tinygrad.tensor import Tensor
 from tinygrad.ops import LoadOps
 from tinygrad.helpers import DEBUG, dtypes
 from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.graph import log_schedule_item
 from tinygrad import nn

-def check_schedule(t:Tensor, allowed:int):
-  sched = [s for s in t.lazydata.schedule() if s[0].op not in LoadOps]
+def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
+  seen = set()
+  if to_prerealize:
+    for pre in to_prerealize:
+      for s in pre.lazydata.schedule(seen.copy()):
+        log_schedule_item(*s)
+        seen.add(s[1])
+  sched = t.lazydata.schedule(seen)
+  for s in sched: log_schedule_item(*s)
+  sched = [s for s in sched if s[0].op not in LoadOps]
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
   if len(sched) != allowed or DEBUG >= 3:
     from extra.utils import print_tree
@@ -111,8 +121,7 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.empty(10)
     c = a+b
     d = a+b
-    c.realize()
-    check_schedule(d, 0)
+    check_schedule(d, 0, [c])

   @unittest.skip("failing in old lazy")
   def test_cache_binaryop_reshaped(self):
@@ -120,25 +129,14 @@
     a = Tensor.empty(10)
     b = Tensor.empty(10)
     c = a+b
     d = a.reshape(10,1)+b.reshape(10,1)
-    c.realize()
-    check_schedule(d, 0)
+    check_schedule(d, 0, [c])

   def test_cache_binaryop_transpose(self):
     a = Tensor.empty(10,10)
     b = Tensor.empty(10,10)
     c = (a.T*b.T).T #.contiguous()
     d = a*b
-    c.realize()
-    check_schedule(d, 0)
-
-  @unittest.skip("failing in old lazy")
-  def test_cache_binaryop_transpose_realized(self):
-    a = Tensor.randn(10,10).realize()
-    b = Tensor.randn(10,10).realize()
-    c = (a.T*b.T).T
-    d = a*b
-    c.realize()
-    check_schedule(d, 0)
+    check_schedule(d, 0, [c])

   def test_cache_two_reduceops(self):
     a = Tensor.empty(10)
@@ -162,23 +160,19 @@ class TestSchedule(unittest.TestCase):

   def test_fold_conv_relu(self):
     c1 = nn.Conv2d(3,16,3)
-    c1.weight.realize()
-    c1.bias.realize()

     # run
     img = Tensor.ones(2,3,64,64)
     out = c1(img).relu()
-    check_schedule(out, 1)
+    check_schedule(out, 1, [c1.weight, c1.bias])

   def test_fold_conv_elu(self):
     c1 = nn.Conv2d(3,16,3)
-    c1.weight.realize()
-    c1.bias.realize()

     # run
     img = Tensor.ones(2,3,64,64)
     out = c1(img).elu()
-    check_schedule(out, 1)
+    check_schedule(out, 1, [c1.weight, c1.bias])

   def test_two_sum(self):
     img = Tensor.empty(64,64)
@@ -206,8 +200,7 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.empty(16)
     c = a+b
     d = (a+b).reshape(16,1)
-    c.realize()
-    check_schedule(d, 0)
+    check_schedule(d, 0, [c])

   def test_multi_permute_should_collapse(self):
     a = Tensor.empty(4,4,4,4)
@@ -224,19 +217,6 @@ class TestSchedule(unittest.TestCase):
     out = c.sum() + d.sum()
     check_schedule(out, 1)

-  """
-  def test_reshape_doesnt_matter(self):
-    a = Tensor.empty(10)
-    b = a.reshape(10,1)
-    self.assertIs(a.lazydata.backing, b.lazydata.backing)
-
-  def test_permute_doesnt_matter(self):
-    a = Tensor.empty(10, 10)
-    b = a.permute(1,0)
-    c = a.reshape(10, 1, 10).permute(2,1,0)
-    self.assertIs(b.lazydata.backing, c.lazydata.backing)
-  """
-
   # NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
   @unittest.skip("not real world")
   def test_children_dont_push(self):
@@ -255,8 +235,7 @@ class TestSchedule(unittest.TestCase):
     e = keep_me.sum()
     # give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
     d = keep_me+c
     check_schedule(d, 2)
-    d.realize()
-    check_schedule(keep_me, 0)
+    check_schedule(keep_me, 0, [d])

   @unittest.skip("failing in old lazy")
   def test_permute_breaks_fusion(self):
@@ -294,7 +273,6 @@ class TestSchedule(unittest.TestCase):
     # NOOP, 3 convs, contiguous
     check_schedule(x, 5)

-  @unittest.skip("failing now with contig")
   def test_image_conv_fusion_minimal(self):
     b1 = Tensor.empty(16)
     b2 = Tensor.empty(16)
diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py
deleted file mode 100644
index fcb41dcfbd..0000000000
--- a/test/unit/test_graph.py
+++ /dev/null
@@ -1,73 +0,0 @@
-#!/usr/bin/env python
-import unittest
-import networkx as nx # type: ignore
-from tinygrad.tensor import Tensor
-from tinygrad.graph import G, log_op, prune_graph
-from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps
-
-def buf(*shp): return Tensor.ones(*shp, device="CPU").lazydata
-
-class TestGraph(unittest.TestCase):
-  def setUp(self):
-    G.clear()
-
-  def helper_compare_graph(self, RG: nx.DiGraph):
-    assert nx.is_isomorphic(G, RG, node_match=lambda x,y: x["label"] == y["label"], edge_match=lambda x,y: x["label"] == y["label"] if "label" in y else True)
-
-  def test_add_graph(self):
-    a = buf(4,4)
-    b = buf(4,4)
-    ast = LazyOp(BinaryOps.ADD, (a,b))
-    ret = buf(4,4)
-
-    RG = nx.DiGraph()
-    RG.add_node(0, label="(4, 4)")
-    RG.add_node(1, label="(4, 4)")
-    RG.add_node(2, label="(4, 4)")
-    RG.add_edge(0, 2, label="ADD")
-    RG.add_edge(1, 2, label="ADD")
-
-    log_op(ret, ast, show_graph=True)
-    self.helper_compare_graph(RG)
-
-  def test_add_sum_graph(self):
-    a = buf(4,4)
-    b = buf(1,1)
-    op0 = LazyOp(MovementOps.RESHAPE, (b,), (4, 4))
-    op1 = LazyOp(BinaryOps.ADD, (a,op0))
-    ast = LazyOp(ReduceOps.SUM, (op1,), (1,1))
-    ret = buf(1,1)
-
-    RG = nx.DiGraph()
-    RG.add_node(0, label="(4, 4)")
-    RG.add_node(1, label="(1, 1)")
-    RG.add_node(2, label="{(4, 4), (1, 1)}\n(1, 1)")
-    RG.add_edge(0, 2, label="RES.ADD.SUM")
-    RG.add_edge(1, 2, label="RES.ADD.SUM")
-
-    log_op(ret, ast, show_graph=True)
-    self.helper_compare_graph(RG)
-
-  def test_add_graph_prune(self):
-    a = buf(1,1)
-    ast = LazyOp(MovementOps.RESHAPE, (a,), (4, 4))
-    ret = buf(4,4)
-    log_op(ret, ast, show_graph=True)
-
-    b = buf(4,4)
-    ast = LazyOp(BinaryOps.ADD, (ret,b))
-    ret = buf(4,4)
-    log_op(ret, ast, show_graph=True)
-    prune_graph()
-
-    RG = nx.DiGraph()
-    RG.add_node(0, label="(1, 1)")
-    RG.add_node(1, label="(4, 4)")
-    RG.add_node(2, label="(4, 4)")
-    RG.add_edge(0, 2) # edge connecting pruned nodes
-    RG.add_edge(1, 2, label="ADD")
-
-    self.helper_compare_graph(RG)
-
-if __name__ == "__main__":
-  unittest.main()
diff --git a/tinygrad/graph.py b/tinygrad/graph.py
index dd44f024b9..17e2a06238 100644
--- a/tinygrad/graph.py
+++ b/tinygrad/graph.py
@@ -1,13 +1,12 @@
-import os, atexit, itertools
+import os, atexit
 try: import networkx as nx # type: ignore
 except ImportError: nx = None # graph won't work
 from collections import defaultdict
-from typing import Dict, List, Optional, TYPE_CHECKING
-from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, Op, OpType, LazyOp
-from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
-from tinygrad.runtime.lib import RawConst
+from typing import Dict, List, TYPE_CHECKING, Tuple, cast
+from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
+from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters

 if TYPE_CHECKING:
   from tinygrad.lazy import LazyBuffer
@@ -24,7 +23,6 @@ if DEBUG >= 2:

 if GRAPH:
   def save_graph_exit():
     for k,v in cnts.items(): print(k, v)
-    if PRUNEGRAPH: prune_graph()
     print("saving", G)
     nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot') # -Gnslimit=100 can make it finish, but you won't like results
@@ -40,6 +38,7 @@ def nm(x):
   return x.node_id

 def get_sop(op: List[Op]):
+  op = [x for x in op if x not in BufferOps]
   if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
   if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
   return str(len(op))
@@ -48,36 +47,28 @@ def str_dtype(dtyp):
   ret = str(dtyp)[7:]
   return "" if ret == 'float' else f"\n{ret}"

-def log_op(ret: 'LazyBuffer', ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
-  if show_graph is None: show_graph = bool(GRAPH)
+def log_schedule_item(iop: LazyOp, ret: 'LazyBuffer', inp: Tuple['LazyBuffer', ...]):
+  show_graph = bool(GRAPH)
   if not DEBUG and not show_graph: return
-  op: List[Op] = [x.op for x in ast.get_lazyops()]
-  inp: List['LazyBuffer'] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
-  oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
+  if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(cast('LazyBuffer', iop.src[0]).base))
+  if iop.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
+
+  op: List[Op] = [x.op for x in iop.get_lazyops()]
+  oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
   optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
   cnts[optype] += 1
-  if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
   if show_graph:
-    top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#ff8080"}
-    dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
-
+    assert ret.base == ret, "all outputs based"
+    top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
     for x in inp:
-      G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060' if phantom else 'black')
+      assert x.base == x, "all inputs based"
+      #assert nm(x) in G.nodes, "all inputs seen"
+      G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060')
       if 'label' not in G.nodes[nm(x)]: G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
     if nm(ret) not in G.nodes: G.add_node(nm(ret))
-    G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)
-    G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('60' if phantom else ('80' if dashed else str()))) if optype in top_colors else "#ffffff"
-    G.nodes[nm(ret)]['color'] = 'white' if phantom else 'black'
-    G.nodes[nm(ret)]['style'] = ('filled, dashed' if dashed else 'filled')
-    G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]
-
-# prune movementops and loadops
-def prune_graph():
-  dead_nodes = []
-  for n in G.nodes:
-    if 'prunable' in G.nodes[n] and G.nodes[n]['prunable']:
-      G.add_edges_from([(x, y) for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n))])
-      dead_nodes.append(n)
-  G.remove_nodes_from(dead_nodes)
+    G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)+(f"\n{iop.op}" if iop.op in LoadOps else "")
+    G.nodes[nm(ret)]['fillcolor'] = top_colors[optype]
+    G.nodes[nm(ret)]['color'] = 'black'
+    G.nodes[nm(ret)]['style'] = 'filled'
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 8165f9a2dd..5593cae41d 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -4,8 +4,8 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
 from weakref import ref, WeakSet, WeakValueDictionary
 import numpy as np

-from tinygrad.graph import log_op
-from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
+from tinygrad.graph import log_schedule_item
+from tinygrad.helpers import DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
 from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
@@ -27,7 +27,8 @@ REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE
 MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
 PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3

-# **** realize functions ****
+# **** ast fixing functions ****
+
 def _ast_reduceops(op:LazyOp) -> LazyOp:
   # TODO: this can also corealize a binary op after the reduce, not just before
   src = op.src[0]
@@ -71,7 +72,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
     if x.base in base_bufs:
       replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
     elif x.realized and isinstance(x.realized, RawConst):
-      replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, st))
+      replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.dtype, st))
     elif not x.realized and x.base.op.op == LoadOps.CONST:
       replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
     else:
@@ -102,8 +103,8 @@ UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2,

 class LazyBuffer:
   __deletable__ = ('op',)
-  def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
-    self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
+  def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
+    self.st: ShapeTracker = st
     self._var_vals: Dict[Variable, int] = var_vals
     self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
     self._realized: Optional[RawBuffer] = src
@@ -111,18 +112,16 @@ class LazyBuffer:
     # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
     self.children: WeakSet = WeakSet()
     self.views: WeakSet = WeakSet()
-    # NOTE: op should be read only after construction of LazyBuffer
-    self.op: LazyOp = op
+    # NOTE: op should be read only after construction of LazyBuffer. it is now with schedule
+    if op is not None:
+      self.op: LazyOp = op
+      for x in op.buffers: x.children.add(self)
     assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based"
     self._base = base
     if base: base.views.add(self)
-    for x in op.buffers: x.children.add(self)
+    else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
     if not LAZY: self.realize()

-    # log phantom ops to the graph
-    if GRAPH >= 3:
-      log_op(self, self.op, phantom=True)
-
   @property
   def var_vals_key(self): return tuple(sorted(self.var_vals.keys()))
@@ -148,7 +147,7 @@
     assert self._base is None, "no setting var_vals of based LazyBuffers"
     self._var_vals = val
-  def __repr__(self): return f""
+  def __repr__(self): return f""

   @property
   def key(self):
     if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
@@ -161,19 +160,19 @@ class LazyBuffer:
   def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
   def get_lazyops(self) -> List[LazyOp]: return []

+  # *** scheduling ***
+
   def schedule(self, seen=None) -> List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]:
     if seen is None: seen = set()
     if self in seen or self.realized: return []
     seen.add(self)
+    if self.optype is MovementOps: return self.base.schedule(seen)
     op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
     if op.op in LoadOps: return [(self.op, self, ())]
-    if self.optype is MovementOps: return self.base.schedule(seen)
     if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
-    elif self.optype is ReduceOps:
-      op = _ast_reduceops(op)
-      if op.op in BinaryOps: op = _ast_binaryops(op, self.shape)
+    elif self.optype is ReduceOps: op = _ast_reduceops(op)

     # HACK: image shape can be wrong, hot cast it back to a normal float
     if isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
@@ -185,12 +184,13 @@ class LazyBuffer:
     if self.op.op == LoadOps.CONTIGUOUS:
       src = cast(LazyBuffer, self.op.src[0])
       if src.st.contiguous and src.st.size() == src.base.st.size() and (src.realized or not src.base.op.op == LoadOps.CONST) and (not src.realized or not isinstance(src.realized, RawConst)):
-        #for c in src.children: print(c)
         return src.schedule(seen) + [(self.op, self, ())]

     # realize the past and exec the AST
     ret = []
     for x in op.buffers: ret += x.schedule(seen)
+
+    # TODO: this belongs in the schedule in some way
     self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))

     # run the ast and log the op
@@ -198,29 +198,14 @@ class LazyBuffer:
     return ret + [(op, self, tuple(base_bufs))]

   def realize(self:LazyBuffer) -> LazyBuffer:
-    if not self.realized:
-      # NOTE: if you for loop the schedule it's slow because nothing frees
-      schedule = self.schedule()
-      #if DEBUG >= 2: print(f"scheduled {len(schedule)}")
-      while len(schedule):
-        op,out,buffers = schedule.pop(0)
-        if DEBUG >= 3:
-          from extra.utils import print_tree # type: ignore
-          print_tree(op)
-        if op.op in LoadOps:
-          LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
-          # TODO: why can't we delete these ops?
-        else:
-          out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **self._device_extra_args())
-          del out.op
-          for v in out.views: del v.op
-        assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
-        assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
+    if not self.realized: run_schedule(self.schedule())
     return self

+  # *** creation/special ops ***
+
   @staticmethod
-  def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
-    return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
+  def loadop(op, shape, dtype, device, arg=None, src=None, val_vals=None) -> LazyBuffer:
+    return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, val_vals if val_vals else {})

   # create a constant with the shape and dtype of self
   def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -229,11 +214,11 @@ class LazyBuffer:

   def contiguous(self:LazyBuffer) -> LazyBuffer:
     if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
-    return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
+    return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals)

   @staticmethod
   def fromCPU(x: np.ndarray) -> LazyBuffer:
-    return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
+    return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))

   def prepare_transfer(self):
     self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
@@ -289,7 +274,7 @@ class LazyBuffer:

   # *** movement ops ***

-  def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
+  def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
     if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
       return self.op.replace_with_movement_ops([(op, arg)])
     if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -315,18 +300,17 @@ class LazyBuffer:
         assert isinstance(self.op.src[0], LazyBuffer)
         self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
         return self.op.src[0].reshape(arg)
-    return self.shuffle_and_prune_movement_ops(self.st.reshape(arg), MovementOps.RESHAPE, arg)
+    return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)

   def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
     if all(b == 0 and e == 0 for b,e in arg): return self
     if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(self.st.pad(arg), MovementOps.PAD, arg)
+    return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)

   def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
     if self.shape == arg: return self
-    if not self.realized and self.op.op == MovementOps.EXPAND:
-      return self.op.src[0].expand(arg)
-    return self.shuffle_and_prune_movement_ops(self.st.expand(arg), MovementOps.EXPAND, arg)
+    if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg)
+    return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)

   def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if arg == tuple(range(len(self.shape))): return self
@@ -349,17 +333,17 @@ class LazyBuffer:
       if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
         self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
         return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
-    return self.shuffle_and_prune_movement_ops(self.st.permute(arg), MovementOps.PERMUTE, arg)
+    return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)

   def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
     if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
     if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
-    return self.shuffle_and_prune_movement_ops(self.st.shrink(arg), MovementOps.SHRINK, arg)
+    return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)

   def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
     if all(a == 1 for a in arg): return self
     if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
-    return self.shuffle_and_prune_movement_ops(self.st.stride(arg), MovementOps.STRIDE, arg)
+    return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)

   def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
     y = self
@@ -393,9 +377,28 @@ MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
   MovementOps.STRIDE: LazyBuffer.stride,
 }

-# *** loadop realization (unrelated to lazy) ***
+# *** realization (unrelated to lazy) ***
+
+def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
+  # NOTE: if you for loop the schedule it's slow because nothing frees
+  while len(schedule):
+    op,out,buffers = schedule.pop(0)
+    log_schedule_item(op, out, buffers)
+    if DEBUG >= 3:
+      from extra.utils import print_tree # type: ignore
+      print_tree(op)
+    if op.op in LoadOps:
+      LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
+      # TODO: why can't we delete these ops?
+    else:
+      out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
+      del out.op
+      for v in out.views: del v.op
+    assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
+    assert out.realized.dtype == out.dtype, "realized dtype is incorrect"

 def _realize_contiguous(buffer: LazyBuffer) -> None:
+  # this is just a copy now, if it's not a copy schedule will handle it
   src = cast(LazyBuffer, buffer.op.src[0])
   buffer.realized = src.realized
   assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index bfb71a9343..3789037ebc 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -11,8 +11,8 @@ from dataclasses import dataclass
 # NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
 class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
 class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
-class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
+class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
 class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
 # Ops below this line are not allowed in ASTs
 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py
index 1e6fb31036..d577b68acc 100644
--- a/tinygrad/shape/view.py
+++ b/tinygrad/shape/view.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 import functools
-from typing import Tuple, List, Optional, NamedTuple
+from dataclasses import dataclass
+from typing import Tuple, List, Optional
 from tinygrad.helpers import prod, all_int
 from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint

@@ -14,7 +15,8 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
   for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
   return filter_strides(shape, tuple(strides))

-class View(NamedTuple):
+@dataclass(frozen=True)
+class View:
   shape:Tuple[sint, ...]
   strides:Tuple[sint, ...]
   offset:sint
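
Not part of the patch: a minimal usage sketch of the scheduling API this diff introduces, assuming a tinygrad checkout with these changes applied. run_schedule, schedule(), and the (LazyOp, LazyBuffer, input LazyBuffers) tuple layout come from the hunks above; the tensors and the fusion expectation are illustrative only, mirroring how check_schedule in test_schedule.py inspects a schedule.

# sketch: build a schedule for a lazy computation, inspect it, then execute it
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.lazy import run_schedule

a, b = Tensor.empty(10), Tensor.empty(10)
out = (a + b).relu()

# schedule() returns (LazyOp, output LazyBuffer, input LazyBuffers) tuples,
# including LoadOps items for the two EMPTY input buffers
sched = out.lazydata.schedule()
kernels = [s for s in sched if s[0].op not in LoadOps]
print(f"{len(kernels)} kernel(s) scheduled")  # the elementwise add+relu should fuse into one

# run_schedule consumes the list and realizes each output, logging every item
# via log_schedule_item; this is the loop that used to live in LazyBuffer.realize()
run_schedule(sched)
assert out.lazydata.realized is not None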