Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 23:18:04 -05:00)
more lazy cleanup (#1938)
* small lazy cleanups
* a few more
* cleanups
* no more realizing in the scheduler test
* a few more minor things
* that was just wrong
* fix graph. the graph test was completely useless
* make graph usable
* fix op graph
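Editor's note, for orientation: this commit moves the scheduler tests away from calling realize(); tests now pre-log the schedules of tensors that "would have been realized" and assert how many kernels the target tensor still needs. Below is a condensed sketch of that pattern, adapted from the check_schedule helper in the first hunk of this diff (per its imports, the real helper also prints offending ASTs with print_tree and exercises the Linearizer); the tensors, the assert, and the expected count here are illustrative only.

```python
# condensed sketch of the check_schedule helper from the first hunk below; illustrative use at the end
from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.graph import log_schedule_item

def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
  seen = set()
  if to_prerealize:
    for pre in to_prerealize:              # pretend these were realized earlier
      for s in pre.lazydata.schedule(seen.copy()):
        log_schedule_item(*s)
        seen.add(s[1])
  sched = t.lazydata.schedule(seen)
  for s in sched: log_schedule_item(*s)
  sched = [s for s in sched if s[0].op not in LoadOps]
  assert len(sched) == allowed, f"expected {allowed} kernels, got {len(sched)}"

a, b = Tensor.empty(10), Tensor.empty(10)
c, d = a+b, a+b                            # same AST, so the lazy cache dedups them
check_schedule(d, 0, to_prerealize=[c])    # d needs no new kernels once c is scheduled
```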
@@ -3,14 +3,24 @@
# NOTE: this has overlap with external_test_opt.py

import unittest
from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.helpers import DEBUG, dtypes
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.graph import log_schedule_item
from tinygrad import nn

def check_schedule(t:Tensor, allowed:int):
sched = [s for s in t.lazydata.schedule() if s[0].op not in LoadOps]
def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=None):
seen = set()
if to_prerealize:
for pre in to_prerealize:
for s in pre.lazydata.schedule(seen.copy()):
log_schedule_item(*s)
seen.add(s[1])
sched = t.lazydata.schedule(seen)
for s in sched: log_schedule_item(*s)
sched = [s for s in sched if s[0].op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:
from extra.utils import print_tree
@@ -111,8 +121,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(10)
c = a+b
d = a+b
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])

@unittest.skip("failing in old lazy")
def test_cache_binaryop_reshaped(self):
@@ -120,25 +129,14 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(10)
c = a+b
d = a.reshape(10,1)+b.reshape(10,1)
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])

def test_cache_binaryop_transpose(self):
a = Tensor.empty(10,10)
b = Tensor.empty(10,10)
c = (a.T*b.T).T #.contiguous()
d = a*b
c.realize()
check_schedule(d, 0)

@unittest.skip("failing in old lazy")
def test_cache_binaryop_transpose_realized(self):
a = Tensor.randn(10,10).realize()
b = Tensor.randn(10,10).realize()
c = (a.T*b.T).T
d = a*b
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])

def test_cache_two_reduceops(self):
a = Tensor.empty(10)
@@ -162,23 +160,19 @@ class TestSchedule(unittest.TestCase):

def test_fold_conv_relu(self):
c1 = nn.Conv2d(3,16,3)
c1.weight.realize()
c1.bias.realize()

# run
img = Tensor.ones(2,3,64,64)
out = c1(img).relu()
check_schedule(out, 1)
check_schedule(out, 1, [c1.weight, c1.bias])

def test_fold_conv_elu(self):
c1 = nn.Conv2d(3,16,3)
c1.weight.realize()
c1.bias.realize()

# run
img = Tensor.ones(2,3,64,64)
out = c1(img).elu()
check_schedule(out, 1)
check_schedule(out, 1, [c1.weight, c1.bias])

def test_two_sum(self):
img = Tensor.empty(64,64)
@@ -206,8 +200,7 @@ class TestSchedule(unittest.TestCase):
b = Tensor.empty(16)
c = a+b
d = (a+b).reshape(16,1)
c.realize()
check_schedule(d, 0)
check_schedule(d, 0, [c])

def test_multi_permute_should_collapse(self):
a = Tensor.empty(4,4,4,4)
@@ -224,19 +217,6 @@ class TestSchedule(unittest.TestCase):
out = c.sum() + d.sum()
check_schedule(out, 1)

"""
def test_reshape_doesnt_matter(self):
a = Tensor.empty(10)
b = a.reshape(10,1)
self.assertIs(a.lazydata.backing, b.lazydata.backing)

def test_permute_doesnt_matter(self):
a = Tensor.empty(10, 10)
b = a.permute(1,0)
c = a.reshape(10, 1, 10).permute(2,1,0)
self.assertIs(b.lazydata.backing, c.lazydata.backing)
"""

# NOTE: for this to pass, LazyViews must be children of LazyBuffers so the (a+b) runs first
@unittest.skip("not real world")
def test_children_dont_push(self):
@@ -255,8 +235,7 @@ class TestSchedule(unittest.TestCase):
e = keep_me.sum() # give keep_me a child (NOTE: BinaryOps won't be a child since it will instant fuse)
d = keep_me+c
check_schedule(d, 2)
d.realize()
check_schedule(keep_me, 0)
check_schedule(keep_me, 0, [d])

@unittest.skip("failing in old lazy")
def test_permute_breaks_fusion(self):
@@ -294,7 +273,6 @@ class TestSchedule(unittest.TestCase):
# NOOP, 3 convs, contiguous
check_schedule(x, 5)

@unittest.skip("failing now with contig")
def test_image_conv_fusion_minimal(self):
b1 = Tensor.empty(16)
b2 = Tensor.empty(16)
@@ -1,73 +0,0 @@
#!/usr/bin/env python
import unittest
import networkx as nx # type: ignore
from tinygrad.tensor import Tensor
from tinygrad.graph import G, log_op, prune_graph
from tinygrad.ops import BinaryOps, LazyOp, MovementOps, ReduceOps

def buf(*shp): return Tensor.ones(*shp, device="CPU").lazydata

class TestGraph(unittest.TestCase):
def setUp(self):
G.clear()

def helper_compare_graph(self, RG: nx.DiGraph):
assert nx.is_isomorphic(G, RG, node_match=lambda x,y: x["label"] == y["label"], edge_match=lambda x,y: x["label"] == y["label"] if "label" in y else True)

def test_add_graph(self):
a = buf(4,4)
b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (a,b))
ret = buf(4,4)

RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
RG.add_node(1, label="(4, 4)")
RG.add_node(2, label="(4, 4)")
RG.add_edge(0, 2, label="ADD")
RG.add_edge(1, 2, label="ADD")

log_op(ret, ast, show_graph=True)
self.helper_compare_graph(RG)

def test_add_sum_graph(self):
a = buf(4,4)
b = buf(1,1)
op0 = LazyOp(MovementOps.RESHAPE, (b,), (4, 4))
op1 = LazyOp(BinaryOps.ADD, (a,op0))
ast = LazyOp(ReduceOps.SUM, (op1,), (1,1))
ret = buf(1,1)

RG = nx.DiGraph()
RG.add_node(0, label="(4, 4)")
RG.add_node(1, label="(1, 1)")
RG.add_node(2, label="{(4, 4), (1, 1)}\n(1, 1)")
RG.add_edge(0, 2, label="RES.ADD.SUM")
RG.add_edge(1, 2, label="RES.ADD.SUM")

log_op(ret, ast, show_graph=True)
self.helper_compare_graph(RG)

def test_add_graph_prune(self):
a = buf(1,1)
ast = LazyOp(MovementOps.RESHAPE, (a,), (4, 4))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)

b = buf(4,4)
ast = LazyOp(BinaryOps.ADD, (ret,b))
ret = buf(4,4)
log_op(ret, ast, show_graph=True)
prune_graph()

RG = nx.DiGraph()
RG.add_node(0, label="(1, 1)")
RG.add_node(1, label="(4, 4)")
RG.add_node(2, label="(4, 4)")
RG.add_edge(0, 2) # edge connecting pruned nodes
RG.add_edge(1, 2, label="ADD")

self.helper_compare_graph(RG)

if __name__ == "__main__":
unittest.main()
@@ -1,13 +1,12 @@
import os, atexit, itertools
import os, atexit
try:
import networkx as nx # type: ignore
except ImportError:
nx = None # graph won't work
from collections import defaultdict
from typing import Dict, List, Optional, TYPE_CHECKING
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.helpers import GRAPH, GRAPHPATH, PRUNEGRAPH, DEBUG, GlobalCounters
from tinygrad.runtime.lib import RawConst
from typing import Dict, List, TYPE_CHECKING, Tuple, cast
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, OpType, LazyOp
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, GlobalCounters

if TYPE_CHECKING: from tinygrad.lazy import LazyBuffer

@@ -24,7 +23,6 @@ if DEBUG >= 2:
if GRAPH:
def save_graph_exit():
for k,v in cnts.items(): print(k, v)
if PRUNEGRAPH: prune_graph()
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
# -Gnslimit=100 can make it finish, but you won't like results
@@ -40,6 +38,7 @@ def nm(x):
return x.node_id

def get_sop(op: List[Op]):
op = [x for x in op if x not in BufferOps]
if len(op) <= 2: return '.'.join([str(y).split(".")[1] for y in op][::-1])
if len(op) <= 4: return '.'.join([str(y).split(".")[1][0:3] for y in op][::-1])
return str(len(op))
@@ -48,36 +47,28 @@ def str_dtype(dtyp):
ret = str(dtyp)[7:]
return "" if ret == 'float' else f"\n{ret}"

def log_op(ret: 'LazyBuffer', ast: LazyOp, show_graph: Optional[bool] = None, phantom=False):
if show_graph is None: show_graph = bool(GRAPH)
def log_schedule_item(iop: LazyOp, ret: 'LazyBuffer', inp: Tuple['LazyBuffer', ...]):
show_graph = bool(GRAPH)
if not DEBUG and not show_graph: return
op: List[Op] = [x.op for x in ast.get_lazyops()]
inp: List['LazyBuffer'] = [x for x in ast.buffers if not isinstance(x.realized, RawConst) or GRAPH > 1]
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps]
if iop.op == LoadOps.CONTIGUOUS: setattr(ret, 'node_id', nm(cast('LazyBuffer', iop.src[0]).base))
if iop.op in {LoadOps.CONST, LoadOps.CONTIGUOUS}: return

op: List[Op] = [x.op for x in iop.get_lazyops()]
oporder = [LoadOps, TernaryOps, ReduceOps, BinaryOps, UnaryOps, MovementOps, BufferOps]
optype = type(sorted(op, key=lambda x: oporder.index(type(x)))[0])
cnts[optype] += 1
if DEBUG >= 6: print(f"{op} : {', '.join([f'{x.shape}-<{nm(x)}>' for x in inp])} -> {ret.shape}-<{nm(ret)}>")
if show_graph:
top_colors = {LoadOps: '#FFFF80', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#ff8080"}
dashed = (optype == LoadOps and hasattr(ret, "_backing")) or (hasattr(ret, "st") and not ret.st.contiguous) # type: ignore

assert ret.base == ret, "all outputs based"
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#8080ff", BinaryOps: "#c0c0c0", MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#FF8080'}
for x in inp:
G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060' if phantom else 'black')
assert x.base == x, "all inputs based"
#assert nm(x) in G.nodes, "all inputs seen"
G.add_edge(nm(x), nm(ret), label=get_sop(op), color='#00000060')
if 'label' not in G.nodes[nm(x)]:
G.nodes[nm(x)]['label'] = str(x.shape)+str_dtype(ret.dtype)
if nm(ret) not in G.nodes: G.add_node(nm(ret))

G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('60' if phantom else ('80' if dashed else str()))) if optype in top_colors else "#ffffff"
G.nodes[nm(ret)]['color'] = 'white' if phantom else 'black'
G.nodes[nm(ret)]['style'] = ('filled, dashed' if dashed else 'filled')
G.nodes[nm(ret)]['prunable'] = optype in [LoadOps, MovementOps]

# prune movementops and loadops
def prune_graph():
dead_nodes = []
for n in G.nodes:
if 'prunable' in G.nodes[n] and G.nodes[n]['prunable']:
G.add_edges_from([(x, y) for (x,_),(_,y) in itertools.product(G.in_edges(n), G.out_edges(n))])
dead_nodes.append(n)
G.remove_nodes_from(dead_nodes)
G.nodes[nm(ret)]['label'] = (str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape))+str_dtype(ret.dtype)+(f"\n{iop.op}" if iop.op in LoadOps else "")
G.nodes[nm(ret)]['fillcolor'] = top_colors[optype]
G.nodes[nm(ret)]['color'] = 'black'
G.nodes[nm(ret)]['style'] = 'filled'

tinygrad/lazy.py (105 lines changed)
@@ -4,8 +4,8 @@ from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapp
from weakref import ref, WeakSet, WeakValueDictionary

import numpy as np
from tinygrad.graph import log_op
from tinygrad.helpers import GRAPH, DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
from tinygrad.graph import log_schedule_item
from tinygrad.helpers import DEBUG, prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
@@ -27,7 +27,8 @@ REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT>=2, OPT>=2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT>=3, OPT>=3

# **** realize functions ****
# **** ast fixing functions ****

def _ast_reduceops(op:LazyOp) -> LazyOp:
# TODO: this can also corealize a binary op after the reduce, not just before
src = op.src[0]
@@ -71,7 +72,7 @@ def _replace_bufferops(op:LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
if x.base in base_bufs:
replacements[x] = LazyOp(BufferOps.MEM, (), MemBuffer(base_bufs.index(x.base)+1, x.dtype, st))
elif x.realized and isinstance(x.realized, RawConst):
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.realized.dtype, st))
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(x.realized._buf, x.dtype, st))
elif not x.realized and x.base.op.op == LoadOps.CONST:
replacements[x] = LazyOp(BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st))
else:
@@ -102,8 +103,8 @@ UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, UnaryOps.LOG2, UnaryOps.EXP2,

class LazyBuffer:
__deletable__ = ('op',)
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:LazyOp, dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st # NOTE: this is not a copy! this should be a "read-only" ShapeTracker
def __init__(self, device:str, st:ShapeTracker, optype:OpType, op:Optional[LazyOp], dtype:DType, var_vals:Dict[Variable,int], src:Optional[RawBuffer]=None, base:Optional[LazyBuffer]=None):
self.st: ShapeTracker = st
self._var_vals: Dict[Variable, int] = var_vals
self.device, self.shape, self.optype, self._dtype = device, self.st.shape, optype, dtype
self._realized: Optional[RawBuffer] = src
@@ -111,18 +112,16 @@ class LazyBuffer:
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: WeakSet = WeakSet()
self.views: WeakSet = WeakSet()
# NOTE: op should be read only after construction of LazyBuffer
self.op: LazyOp = op
# NOTE: op should be read only after construction of LazyBuffer. it is now with schedule
if op is not None:
self.op: LazyOp = op
for x in op.buffers: x.children.add(self)
assert optype != MovementOps or (base is not None and base.optype != MovementOps), "MovementOps must be based"
self._base = base
if base: base.views.add(self)
for x in op.buffers: x.children.add(self)
else: assert st.contiguous, "unbased LazyBuffers must be contiguous"
if not LAZY: self.realize()

# log phantom ops to the graph
if GRAPH >= 3:
log_op(self, self.op, phantom=True)

@property
def var_vals_key(self): return tuple(sorted(self.var_vals.keys()))

@@ -148,7 +147,7 @@ class LazyBuffer:
assert self._base is None, "no setting var_vals of based LazyBuffers"
self._var_vals = val

def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if not self._realized else self._realized} st={self.st}>"
def __repr__(self): return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
@property
def key(self):
if self.realized: return (self.dtype, self.realized.key, self.st, self.var_vals_key)
@@ -161,19 +160,19 @@ class LazyBuffer:
def map_buffers(self, real_srcs: Mapping[LazyBuffer, Union[LazyBuffer, LazyOp]]): return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]: return []

# *** scheduling ***

def schedule(self, seen=None) -> List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]:
if seen is None: seen = set()
if self in seen or self.realized: return []
seen.add(self)
if self.optype is MovementOps: return self.base.schedule(seen)

op = self.op if self.op.op != LoadOps.CONTIGUOUS else LazyOp(UnaryOps.NOOP, self.op.src)
if op.op in LoadOps: return [(self.op, self, ())]
if self.optype is MovementOps: return self.base.schedule(seen)

if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps:
op = _ast_reduceops(op)
if op.op in BinaryOps: op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps: op = _ast_reduceops(op)

# HACK: image shape can be wrong, hot cast it back to a normal float
if isinstance(self.dtype, ImageDType) and (prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x]%4 == 0 for x in self.st.unit_stride_axes())):
@@ -185,12 +184,13 @@ class LazyBuffer:
if self.op.op == LoadOps.CONTIGUOUS:
src = cast(LazyBuffer, self.op.src[0])
if src.st.contiguous and src.st.size() == src.base.st.size() and (src.realized or not src.base.op.op == LoadOps.CONST) and (not src.realized or not isinstance(src.realized, RawConst)):
#for c in src.children: print(c)
return src.schedule(seen) + [(self.op, self, ())]

# realize the past and exec the AST
ret = []
for x in op.buffers: ret += x.schedule(seen)

# TODO: this belongs in the schedule in some way
self.var_vals = dict(sorted(merge_dicts([buf.var_vals for buf in op.buffers]).items(), key=lambda kv:cast(Variable,kv[0]).key))

# run the ast and log the op
@@ -198,29 +198,14 @@ class LazyBuffer:
return ret + [(op, self, tuple(base_bufs))]

def realize(self:LazyBuffer) -> LazyBuffer:
if not self.realized:
# NOTE: if you for loop the schedule it's slow because nothing frees
schedule = self.schedule()
#if DEBUG >= 2: print(f"scheduled {len(schedule)}")
while len(schedule):
op,out,buffers = schedule.pop(0)
if DEBUG >= 3:
from extra.utils import print_tree # type: ignore
print_tree(op)
if op.op in LoadOps:
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
# TODO: why can't we delete these ops?
else:
out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **self._device_extra_args())
del out.op
for v in out.views: del v.op
assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
assert out.realized.dtype == out.dtype, "realized dtype is incorrect"
if not self.realized: run_schedule(self.schedule())
return self

# *** creation/special ops ***

@staticmethod
def loadop(op, shape, dtype, device, arg=None, src=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, {})
def loadop(op, shape, dtype, device, arg=None, src=None, val_vals=None) -> LazyBuffer:
return create_lazybuffer(device, ShapeTracker.from_shape(tuple(shape)), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, val_vals if val_vals else {})

# create a constant with the shape and dtype of self
def const(self, val:Union[float, int]) -> LazyBuffer:
@@ -229,11 +214,11 @@ class LazyBuffer:

def contiguous(self:LazyBuffer) -> LazyBuffer:
if not self.realized and self.op.op == LoadOps.CONTIGUOUS: return self # two CONTIGUOUS in a row is one
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, self.var_vals)
return self.loadop(LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self, val_vals=self.var_vals)

@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, LazyOp(LoadOps.EMPTY, (), None), dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))
return LazyBuffer("CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), {}, RawNumpyBuffer.fromCPU(x))

def prepare_transfer(self):
self_casted = self.e(UnaryOps.CAST, arg=(dtypes.from_np(self.dtype.np), False)) if dtypes.from_np(self.dtype.np) != self.dtype else self
@@ -289,7 +274,7 @@ class LazyBuffer:

# *** movement ops ***

def shuffle_and_prune_movement_ops(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
def _movement_op(self, st: ShapeTracker, op: MovementOps, arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]) -> LazyBuffer:
if SHUFFLE_MOVEMENT_OPS and self.optype == BinaryOps and not self.realized and (op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (op == MovementOps.RESHAPE and self.op.op in UnaryOps)) and not self.children:
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
@@ -315,18 +300,17 @@ class LazyBuffer:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].reshape(arg)
return self.shuffle_and_prune_movement_ops(self.st.reshape(arg), MovementOps.RESHAPE, arg)
return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)

def pad(self:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
if all(b == 0 and e == 0 for b,e in arg): return self
if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad(tuple([(b1+b2, e1+e2) for (b1,e1),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(self.st.pad(arg), MovementOps.PAD, arg)
return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)

def expand(self: LazyBuffer, arg:Tuple[sint, ...]) -> LazyBuffer:
if self.shape == arg: return self
if not self.realized and self.op.op == MovementOps.EXPAND:
return self.op.src[0].expand(arg)
return self.shuffle_and_prune_movement_ops(self.st.expand(arg), MovementOps.EXPAND, arg)
if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg)
return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)

def permute(self: LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if arg == tuple(range(len(self.shape))): return self
@@ -349,17 +333,17 @@ class LazyBuffer:
if shape_idx_groups := get_contraction(self.op.src[0].shape, self.shape):
self.op.src[0].children.discard(self) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].permute(tuple(flatten(shape_idx_groups[i] for i in arg))).reshape(self.st.permute(arg).shape)
return self.shuffle_and_prune_movement_ops(self.st.permute(arg), MovementOps.PERMUTE, arg)
return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)

def shrink(self:LazyBuffer, arg:Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self
if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink(tuple([(b1+b2, b1+e2) for (b1,_),(b2,e2) in zip(self.op.arg, arg)]))
return self.shuffle_and_prune_movement_ops(self.st.shrink(arg), MovementOps.SHRINK, arg)
return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)

def stride(self:LazyBuffer, arg:Tuple[int, ...]) -> LazyBuffer:
if all(a == 1 for a in arg): return self
if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride(tuple(map(operator.mul, arg, self.op.arg)))
return self.shuffle_and_prune_movement_ops(self.st.stride(arg), MovementOps.STRIDE, arg)
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)

def replace_with_movement_ops(self: LazyBuffer, ops:List[Tuple[MovementOps, Any]]) -> LazyBuffer:
y = self
@@ -393,9 +377,28 @@ MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.STRIDE: LazyBuffer.stride,
}

# *** loadop realization (unrelated to lazy) ***
# *** realization (unrelated to lazy) ***

def run_schedule(schedule:List[Tuple[LazyOp, LazyBuffer, Tuple[LazyBuffer, ...]]]):
# NOTE: if you for loop the schedule it's slow because nothing frees
while len(schedule):
op,out,buffers = schedule.pop(0)
log_schedule_item(op, out, buffers)
if DEBUG >= 3:
from extra.utils import print_tree # type: ignore
print_tree(op)
if op.op in LoadOps:
LOAD_OPS_DISPATCHER[cast(LoadOps, op.op)](out)
# TODO: why can't we delete these ops?
else:
out.realized = Device[out.device].exec_ast(op, output=out, inputs=[x.realized for x in buffers], var_vals=out.var_vals, **out._device_extra_args())
del out.op
for v in out.views: del v.op
assert out.realized and isinstance(out.realized, (RawConst, Device[out.device].buffer)), f"device mismatch on realized got {type(out.realized)} expected {out.device}"
assert out.realized.dtype == out.dtype, "realized dtype is incorrect"

def _realize_contiguous(buffer: LazyBuffer) -> None:
# this is just a copy now, if it's not a copy schedule will handle it
src = cast(LazyBuffer, buffer.op.src[0])
buffer.realized = src.realized
assert buffer.dtype == src.dtype, f"contiguous dtype mismatch, expecting {buffer.dtype}, got {src.dtype}"

@@ -11,8 +11,8 @@ from dataclasses import dataclass
# NOTE: rdna3 only has RECIP and not DIV. DIV and POW are on the chopping block
class UnaryOps(Enum): NOOP = auto(); EXP2 = auto(); LOG2 = auto(); CAST = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
class BinaryOps(Enum): ADD = auto(); SUB = auto(); MUL = auto(); DIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class TernaryOps(Enum): MULACC = auto(); WHERE = auto() # noqa: E702
class ReduceOps(Enum): SUM = auto(); MAX = auto() # noqa: E702
class BufferOps(Enum): MEM = auto(); CONST = auto() # noqa: E702
# Ops below this line are not allowed in ASTs
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702

@@ -1,6 +1,7 @@
from __future__ import annotations
import functools
from typing import Tuple, List, Optional, NamedTuple
from dataclasses import dataclass
from typing import Tuple, List, Optional
from tinygrad.helpers import prod, all_int
from tinygrad.shape.symbolic import Node, NumNode, is_sym_int, sint

@@ -14,7 +15,8 @@ def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return filter_strides(shape, tuple(strides))

class View(NamedTuple):
@dataclass(frozen=True)
class View:
shape:Tuple[sint, ...]
strides:Tuple[sint, ...]
offset:sint
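Editor's note, not part of the commit: after this change the realize path is just "build a schedule, then run it" (LazyBuffer.realize calls run_schedule(self.schedule()), and log_schedule_item is now invoked from run_schedule rather than at LazyBuffer construction). Below is a minimal sketch of driving that split by hand; the tensor expression is illustrative, and it assumes run_schedule is importable from tinygrad.lazy, where this diff defines it.

```python
# illustrative: split scheduling from execution the way LazyBuffer.realize now does internally
from tinygrad.tensor import Tensor
from tinygrad.lazy import run_schedule   # defined in tinygrad/lazy.py by this diff

out = (Tensor.ones(4, 4) + Tensor.ones(4, 4)).sum()
sched = out.lazydata.schedule()   # list of (LazyOp, output LazyBuffer, input LazyBuffers) tuples
run_schedule(sched)               # LoadOps go through LOAD_OPS_DISPATCHER, the rest through exec_ast
print(out.numpy())                # the data is already computed at this point
```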