more lazy cleanup (#1938)

* small lazy cleanups

* a few more

* cleanups

* no more realizing in the scheduler test

* a few more minor things

* that was just wrong

* fix graph. the graph test was completely useless

* make graph usable

* fix op graph
Commit 22b8576887 (parent 2a49f7e456), authored by George Hotz on 2023-09-29 00:53:29 -07:00 and committed via GitHub.
6 changed files with 99 additions and 198 deletions


@@ -3,14 +3,24 @@
# NOTE: this has overlap with external_test_opt.py
import unittest
from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.helpers import DEBUG, dtypes
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.graph import log_schedule_item
from tinygrad import nn
def check_schedule(t:Tensor, allowed:int):
sched = [s for s in t.lazydata.schedule() if s[0].op not in LoadOps]
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
seen = set()
if to_prerealize:
for pre in to_prerealize:
for s in pre.lazydata.schedule(seen.copy()):
log_schedule_item(*s)
seen.add(s[1])
sched = t.lazydata.schedule(seen)
for s in sched: log_schedule_item(*s)
sched = [s for s in sched if s[0].op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
from extra.utils import print_tree
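The new helper replaces explicit realize() calls in the tests below: each tensor in to_prerealize has its schedule walked and its outputs added to a shared seen set, so the tensor under test only schedules whatever work remains. A minimal sketch of the same mechanism outside the test harness (tensor names are illustrative, and LoadOps are filtered out exactly as the helper does):

# illustrative: a prerealized "seen" set stands in for an explicit c.realize()
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps

a, b = Tensor.empty(10), Tensor.empty(10)
c = a + b
d = a + b                                   # same expression as c
seen = set()
for s in c.lazydata.schedule(seen.copy()):  # walk c's schedule without consuming seen
    seen.add(s[1])                          # mark c's outputs as already scheduled
remaining = [s for s in d.lazydata.schedule(seen) if s[0].op not in LoadOps]
assert len(remaining) == 0                  # nothing left to compute for d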
@@ -111,8 +121,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(10)
c = a+b
d = a+b
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])
@unittest.skip("failing in old lazy")
def test_cache_binaryop_reshaped(self):
@@ -120,25 +129,14 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(10)
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])
def test_cache_binaryop_transpose(self):
a = Tensor.empty(10,10)
b = Tensor.empty(10,10)
c = (a.T*b.T).T #.contiguous()
d = a*b
c.realize()
check_schedule(d, 0)
@unittest.skip("failing in old lazy")
def test_cache_binaryop_transpose_realized(self):
a = Tensor.randn(10,10).realize()
b = Tensor.randn(10,10).realize()
c = (a.T*b.T).T
d = a*b
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])
def test_cache_two_reduceops(self):
a = Tensor.empty(10)
@@ -162,23 +160,19 @@ class TestSchedule(unittest.TestCase):
def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)
c1.weight.realize()
c1.bias.realize()
# run
img = Tensor.ones(2,3,64,64)
out = c1(img).relu()
check_schedule(out, 1)
check_schedule(out, 1, [c1.weight, c1.bias])
def test_fold_conv_elu(self):
c1 = nn.Conv2d(3,16,3)
c1.weight.realize()
c1.bias.realize()
# run
img = Tensor.ones(2,3,64,64)
out = c1(img).elu()
check_schedule(out, 1)
check_schedule(out, 1, [c1.weight, c1.bias])
def test_two_sum(self):
img = Tensor.empty(64,64)
@@ -206,8 +200,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(16)
c = a+b
d = (a+b).reshape(16,1)
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])
def test_multi_permute_should_collapse(self):
a = Tensor.empty(4,4,4,4)
@@ -224,19 +217,6 @@ class TestSchedule(unittest.TestCase):
out = c.sum() + d.sum()
check_schedule(out, 1)
"""
def test_reshape_doesnt_matter(self):
a = Tensor.empty(10)
b = a.reshape(10,1)
self.assertIs(a.lazydata.backing, b.lazydata.backing)
def test_permute_doesnt_matter(self):
a = Tensor.empty(10, 10)
b = a.permute(1,0)
c = a.reshape(10, 1, 10).permute(2,1,0)
self.assertIs(b.lazydata.backing, c.lazydata.backing)
"""
# NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
@unittest.skip("not real world")
def test_children_dont_push(self):
@@ -255,8 +235,7 @@ class TestSchedule(unittest.TestCase):
e = keep_me.sum() # give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
d = keep_me+c
check_schedule(d, 2)
d.realize()
check_schedule(keep_me, 0)
check_schedule(keep_me, 0, [d])
@unittest.skip("failing in old lazy")
def test_permute_breaks_fusion(self):
@@ -294,7 +273,6 @@ class TestSchedule(unittest.TestCase):
# NOOP, 3 convs, contiguous
check_schedule(x, 5)
@unittest.skip("failing now with contig")
def test_image_conv_fusion_minimal(self):
b1 = Tensor.empty(16)
b2 = Tensor.empty(16)


@@ -1,73 +0,0 @@
#!/usr/bin/env python
import unittest
import networkx as nx # type: ignore
from tinygrad.tensor import Tensor
from tinygrad.graph import G, log_op, prune_graph
from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps
def buf(*shp): return Tensor.ones(*shp, device="CPU").lazydata
class TestGraph(unittest.TestCase):
def setUp(self):
G.clear()
def helper_compare_graph(self, RG: nx.DiGraph):
assert nx.is_isomorphic(G, RG, node_match=lambda x,y: x["label"] == y["label"], edge_match=lambda x,y: x["label"] == y["label"] if "label" in y else True)
def test_add_graph(self):
a = buf(4,4)
b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (a,b))
ret = buf(4,4)
RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
RG.add_node(1, label="(4, 4)")
RG.add_node(2, label="(4, 4)")
RG.add_edge(0, 2, label="ADD")
RG.add_edge(1, 2, label="ADD")
log_op(ret, ast, show_graph=True)
self.helper_compare_graph(RG)
def test_add_sum_graph(self):
a = buf(4,4)
b = buf(1,1)
op0 = LazyOp(MovementOps.RESHAPE, (b,), (4, 4))
op1 = LazyOp(BinaryOps.ADD, (a,op0))
ast = LazyOp(ReduceOps.SUM, (op1,), (1,1))
ret = buf(1,1)
RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
RG.add_node(1, label="(1, 1)")
RG.add_node(2, label="{(4, 4), (1, 1)}\n(1, 1)")
RG.add_edge(0, 2, label="RES.ADD.SUM")
RG.add_edge(1, 2, label="RES.ADD.SUM")
log_op(ret, ast, show_graph=True)
self.helper_compare_graph(RG)
def test_add_graph_prune(self):
a = buf(1,1)
ast = LazyOp(MovementOps.RESHAPE, (a,), (4, 4))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)
b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (ret,b))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)
prune_graph()
RG = nx.DiGraph()
RG.add_node(0, label="(1, 1)")
RG.add_node(1, label="(4, 4)")
RG.add_node(2, label="(4, 4)")
RG.add_edge(0, 2) # edge connecting pruned nodes
RG.add_edge(1, 2, label="ADD")
self.helper_compare_graph(RG)
if __name__ == "__main__":
unittest.main()


@@ -1,13 +1,12 @@
import os, atexit, itertools
import os, atexit
try:
import networkx as nx # type: ignore
except ImportError:
nx = None # graph won't work
from collections import defaultdict
from typing import Dict, List, Optional, TYPE_CHECKING
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
from tinygrad.runtime.lib import RawConst
from typing import Dict, List, TYPE_CHECKING, Tuple, cast
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters
if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer
@@ -24,7 +23,6 @@ if DEBUG >= 2:
if GRAPH:
def save_graph_exit():
for k,v in cnts.items(): print(k, v)
if PRUNEGRAPH: prune_graph()
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
# -Gnslimit=100 can make it finish, but you won't like results
@@ -40,6 +38,7 @@ def nm(x):
return x.node_id
def get_sop(op: List[Op]):
op = [x for x in op if x not in BufferOps]
if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
return str(len(op))
@@ -48,36 +47,28 @@ def str_dtype(dtyp):
ret = str(dtyp)[7:]
return "" if ret == 'float' else f"\n{ret}"
def log_op(ret: 'LazyBuffer', ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
if show_graph is None: show_graph = bool(GRAPH)
def log_schedule_item(iop: LazyOp, ret: 'LazyBuffer', inp: Tuple['LazyBuffer', ...]):
show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
op: List[Op] = [x.op for x in ast.get_lazyops()]
inp: List['LazyBuffer'] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(cast('LazyBuffer', iop.src[0]).base))
if iop.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return
op: List[Op] = [x.op for x in iop.get_lazyops()]
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
if show_graph:
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#ff8080"}
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore
assert ret.base == ret, "all outputs based"
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
for x in inp:
G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060' if phantom else 'black')
assert x.base == x, "all inputs based"
#assert nm(x) in G.nodes, "all inputs seen"
G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060')
if 'label' not in G.nodes[nm(x)]:
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
if nm(ret) not in G.nodes: G.add_node(nm(ret))
G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('60' if phantom else ('80' if dashed else str()))) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['color'] = 'white' if phantom else 'black'
G.nodes[nm(ret)]['style'] = ('filled, dashed' if dashed else 'filled')
G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]
# prune movementops and loadops
def prune_graph():
dead_nodes = []
for n in G.nodes:
if 'prunable' in G.nodes[n] and G.nodes[n]['prunable']:
G.add_edges_from([(x, y) for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n))])
dead_nodes.append(n)
G.remove_nodes_from(dead_nodes)
G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)+(f"\n{iop.op}" if iop.op in LoadOps else "")
G.nodes[nm(ret)]['fillcolor'] = top_colors[optype]
G.nodes[nm(ret)]['color'] = 'black'
G.nodes[nm(ret)]['style'] = 'filled'
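With log_op gone, the graph is now populated from schedule items, which is how both the scheduler test and run_schedule drive it. A hedged sketch of feeding it by hand (GRAPH=1 must be set in the environment for nodes to actually be added; the tensor expression is illustrative):

# illustrative: each scheduled (ast, output, inputs) triple becomes nodes and edges in G
from tinygrad.tensor import Tensor
from tinygrad.graph import log_schedule_item

out = (Tensor.empty(4, 4) + Tensor.empty(4, 4)).sum()
for iop, ret, inp in out.lazydata.schedule():
    log_schedule_item(iop, ret, inp)  # LoadOps.CONST and CONTIGUOUS items are skipped inside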


@@ -4,8 +4,8 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.graph import log_op
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
from tinygrad.graph import log_schedule_item
from tinygrad.helpers import DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
@@ -27,7 +27,8 @@ REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3
# **** realize functions ****
# **** ast fixing functions ****
def _ast_reduceops(op:LazyOp) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
src = op.src[0]
@@ -71,7 +72,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
if x.base in base_bufs:
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
elif x.realized and isinstance(x.realized, RawConst):
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, st))
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.dtype, st))
elif not x.realized and x.base.op.op == LoadOps.CONST:
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
else:
@@ -102,8 +103,8 @@ UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2,
class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st
self._var_vals: Dict[Variable, int] = var_vals
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
self._realized: Optional[RawBuffer] = src
@@ -111,18 +112,16 @@ class LazyBuffer:
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: WeakSet = WeakSet()
self.views: WeakSet = WeakSet()
# NOTE: op should be read only after construction of LazyBuffer
self.op: LazyOp = op
# NOTE: op should be read only after construction of LazyBuffer. it is now with schedule
if op is not None:
self.op: LazyOp = op
for x in op.buffers: x.children.add(self)
assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based"
self._base = base
if base: base.views.add(self)
for x in op.buffers: x.children.add(self)
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
if not LAZY: self.realize()
# log phantom ops to the graph
if GRAPH >= 3:
log_op(self, self.op, phantom=True)
@property
def var_vals_key(self): return tuple(sorted(self.var_vals.keys()))
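Since op is now Optional, a buffer constructed pre-realized (as the fromCPU path further down does) never gets an op attribute at all, and realized outputs have theirs deleted in run_schedule. A hedged sketch of that invariant (the numpy contents are arbitrary):

# illustrative: a fromCPU buffer is born realized and op-less
import numpy as np
from tinygrad.lazy import LazyBuffer

lb = LazyBuffer.fromCPU(np.ones((2, 2), dtype=np.float32))
assert lb.realized is not None   # backed by a raw buffer from construction
assert not hasattr(lb, "op")     # op=None path: there is nothing to delete later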
@@ -148,7 +147,7 @@ class LazyBuffer:
assert self._base is None, "no setting var_vals of based LazyBuffers"
self._var_vals = val
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self._realized else self._realized} st={self.st}>"
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
@@ -161,19 +160,19 @@ class LazyBuffer:
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]: return []
# *** scheduling ***
def schedule(self, seen=None) -> List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]:
if seen is None: seen = set()
if self in seen or self.realized: return []
seen.add(self)
if self.optype is MovementOps: return self.base.schedule(seen)
op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
if op.op in LoadOps: return [(self.op, self, ())]
if self.optype is MovementOps: return self.base.schedule(seen)
if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps:
op = _ast_reduceops(op)
if op.op in BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)
# HACK: image shape can be wrong, hot cast it back to a normal float
if isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
@@ -185,12 +184,13 @@ class LazyBuffer:
if self.op.op == LoadOps.CONTIGUOUS:
src = cast(LazyBuffer, self.op.src[0])
if src.st.contiguous and src.st.size() == src.base.st.size() and (src.realized or not src.base.op.op == LoadOps.CONST) and (not src.realized or not isinstance(src.realized, RawConst)):
#for c in src.children: print(c)
return src.schedule(seen) + [(self.op, self, ())]
# realize the past and exec the AST
ret = []
for x in op.buffers: ret += x.schedule(seen)
# TODO: this belongs in the schedule in some way
self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))
# run the ast and log the op
@@ -198,29 +198,14 @@ class LazyBuffer:
return ret + [(op, self, tuple(base_bufs))]
def realize(self:LazyBuffer) -> LazyBuffer:
if not self.realized:
# NOTE: if you for loop the schedule it's slow because nothing frees
schedule = self.schedule()
#if DEBUG >= 2: print(f"scheduled {len(schedule)}")
while len(schedule):
op,out,buffers = schedule.pop(0)
if DEBUG >= 3:
from extra.utils import print_tree # type: ignore
print_tree(op)
if op.op in LoadOps:
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
# TODO: why can't we delete these ops?
else:
out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **self._device_extra_args())
del out.op
for v in out.views: del v.op
assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
if not self.realized: run_schedule(self.schedule())
return self
# *** creation/special ops ***
@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
def loadop(op, shape, dtype, device, arg=None, src=None, val_vals=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, val_vals if val_vals else {})
# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -229,11 +214,11 @@ class LazyBuffer:
def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
def prepare_transfer(self):
self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
@@ -289,7 +274,7 @@ class LazyBuffer:
# *** movement ops ***
def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -315,18 +300,17 @@ class LazyBuffer:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].reshape(arg)
return self.shuffle_and_prune_movement_ops(self.st.reshape(arg), MovementOps.RESHAPE, arg)
return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)
def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all(b == 0 and e == 0 for b,e in arg): return self
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(self.st.pad(arg), MovementOps.PAD, arg)
return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)
def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND:
return self.op.src[0].expand(arg)
return self.shuffle_and_prune_movement_ops(self.st.expand(arg), MovementOps.EXPAND, arg)
if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg)
return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)
def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if arg == tuple(range(len(self.shape))): return self
@@ -349,17 +333,17 @@ class LazyBuffer:
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
return self.shuffle_and_prune_movement_ops(self.st.permute(arg), MovementOps.PERMUTE, arg)
return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)
def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(self.st.shrink(arg), MovementOps.SHRINK, arg)
return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)
def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if all(a == 1 for a in arg): return self
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
return self.shuffle_and_prune_movement_ops(self.st.stride(arg), MovementOps.STRIDE, arg)
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
@@ -393,9 +377,28 @@ MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.STRIDE: LazyBuffer.stride,
}
# *** loadop realization (unrelated to lazy) ***
# *** realization (unrelated to lazy) ***
def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
# NOTE: if you for loop the schedule it's slow because nothing frees
while len(schedule):
op,out,buffers = schedule.pop(0)
log_schedule_item(op, out, buffers)
if DEBUG >= 3:
from extra.utils import print_tree # type: ignore
print_tree(op)
if op.op in LoadOps:
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
# TODO: why can't we delete these ops?
else:
out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
del out.op
for v in out.views: del v.op
assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
def _realize_contiguous(buffer: LazyBuffer) -> None:
# this is just a copy now, if it's not a copy schedule will handle it
src = cast(LazyBuffer, buffer.op.src[0])
buffer.realized = src.realized
assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"
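realize() is now a thin wrapper: build the schedule, then hand it to run_schedule. The split also lets callers inspect or log a schedule before executing it. A hedged sketch (the tensor expression is illustrative):

# illustrative: schedule construction and execution are separate steps now
from tinygrad.tensor import Tensor
from tinygrad.lazy import run_schedule

t = Tensor.empty(8, 8) + 1
sched = t.lazydata.schedule()  # List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]
run_schedule(sched)            # LoadOps go through LOAD_OPS_DISPATCHER, the rest through exec_ast
assert t.lazydata.realized is not None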


@@ -11,8 +11,8 @@ from dataclasses import dataclass
# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
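BufferOps moves above the "not allowed in ASTs" line, so linearizer ASTs now bottom out in BufferOps.MEM and BufferOps.CONST leaves instead of raw LazyBuffers. A rough, hedged sketch of what such an AST looks like (the buffer index, dtype, and shape are made up):

# illustrative: an ADD whose inputs are BufferOps leaves, in the shape _replace_bufferops produces
from tinygrad.ops import LazyOp, BinaryOps, BufferOps, MemBuffer, ConstBuffer
from tinygrad.helpers import dtypes
from tinygrad.shape.shapetracker import ShapeTracker

st = ShapeTracker.from_shape((4, 4))
ast = LazyOp(BinaryOps.ADD, (
    LazyOp(BufferOps.MEM, (), MemBuffer(1, dtypes.float32, st)),        # real input buffer
    LazyOp(BufferOps.CONST, (), ConstBuffer(1.0, dtypes.float32, st)),  # inlined constant
))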


@@ -1,6 +1,7 @@
from __future__ import annotations
import functools
from typing import Tuple, List, Optional, NamedTuple
from dataclasses import dataclass
from typing import Tuple, List, Optional
from tinygrad.helpers import prod, all_int
from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint
@@ -14,7 +15,8 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return filter_strides(shape, tuple(strides))
class View(NamedTuple):
@dataclass(frozen=True)
class View:
shape:Tuple[sint, ...]
strides:Tuple[sint, ...]
offset:sint
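View switches from NamedTuple to a frozen dataclass: still immutable and hashable with value equality, but instances are no longer tuples, so tuple-style unpacking and comparison against plain tuples go away. A standalone sketch of the behavior being relied on (this stand-in class only has the fields shown above; it is not the real View):

# illustrative stand-in: frozen dataclasses keep value equality, hashing, and immutability,
# matching what the old NamedTuple-based View provided
from dataclasses import dataclass, FrozenInstanceError

@dataclass(frozen=True)
class ViewLike:
    shape: tuple
    strides: tuple
    offset: int

v = ViewLike((4, 4), (4, 1), 0)
assert v == ViewLike((4, 4), (4, 1), 0)               # value equality
assert hash(v) == hash(ViewLike((4, 4), (4, 1), 0))   # usable as a dict key
try:
    v.offset = 1                                      # mutation is rejected
except FrozenInstanceError:
    pass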