early gate the graph (#3070)

This commit is contained in:
George Hotz
2024-01-09 20:17:13 -08:00
committed by GitHub
parent ff0d6e4551
commit 2495ca95c7
4 changed files with 36 additions and 38 deletions

View File

@@ -7,7 +7,7 @@ from typing import List, Optional
from tinygrad.tensor import Tensor
from tinygrad.ops import LoadOps
from tinygrad.device import Device, Compiled
from tinygrad.helpers import DEBUG
from tinygrad.helpers import DEBUG, GRAPH
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad import nn, dtypes
@@ -17,10 +17,11 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
if to_prerealize:
for pre in to_prerealize:
for s in pre.lazydata.schedule(seen.copy()):
realized_lazybuffer(s.out, 0)
if GRAPH: realized_lazybuffer(s.out, 0)
seen.add(s.out)
sched = t.lazydata.schedule(seen)
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
if GRAPH:
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if len(sched) != allowed or DEBUG >= 3:

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from typing import List, Any, DefaultDict
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp, GlobalCounters
from tinygrad.device import Device
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, getenv
from tinygrad.helpers import GRAPHPATH, DEBUG, getenv
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.shape.symbolic import NumNode
@@ -47,37 +47,35 @@ def str_dtype(dtyp):
return "" if ret == 'float' else f"\n{ret}"
def realized_lazybuffer(lb, num):
  """Mark lazybuffer *lb* as realized in the visualization graph, tagging its node with kernel number *num*.

  Post-commit version: the GRAPH env-var gate was hoisted to the call sites,
  so this function unconditionally updates the graph.
  NOTE(review): reconstructed from a diff whose +/- markers were stripped; the
  removed pre-commit body wrapped the same logic in `if GRAPH:`.
  """
  init_graph()
  # Bold the node and strip the "80" alpha suffix from its fill color to show it is realized.
  G.nodes[nm(lb)]['style'] = '"filled,bold"'
  G.nodes[nm(lb)]['fillcolor'] = G.nodes[nm(lb)]['fillcolor'][:-2]
  # Append kernel number and backing-buffer id; "FAKE" when the buffer was never actually allocated.
  G.nodes[nm(lb)]['label'] = '"' + G.nodes[nm(lb)]["label"].replace('"', '') + f'\nK:{num} b:{"FAKE" if lb.realized is None else nm(lb.realized)}"'
# Node fill colors keyed by op family (hex RGB; an "80" alpha suffix is appended for unscheduled nodes).
top_colors = {LoadOps: '#FFFFa0', UnaryOps: "#c0c0c0", ReduceOps: "#FFA0A0", BinaryOps: "#c0c0c0",
              MovementOps: "#80ff80", TernaryOps: "#c0c0c0", BufferOps: '#a0a0ff'}

def log_lazybuffer(lb, scheduled=False):
  """Add lazybuffer *lb* (and, recursively, its unrealized sources) to the visualization graph.

  *scheduled* draws the node as a box to mark buffers that will actually be realized.
  Post-commit version: the GRAPH gate was moved to the call sites, so this
  always logs.
  NOTE(review): reconstructed from a diff whose +/- markers were stripped; the
  removed pre-commit body wrapped the same logic in `if GRAPH:`.
  """
  init_graph()
  if lb.base != lb:
    # A view: draw a dashed green node with shape/strides (+ offset if nonzero), link it to its base.
    offset = lb.st.expr_node(NumNode(0))[0]
    label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
    G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
    G.add_edge(nm(lb.base), nm(lb), color='#00000060')
    lb = lb.base
  if lb.realized is None:
    # Unrealized base: recurse into unseen sources and draw data-dependency edges.
    for x in lb.srcs:
      if nm(x) not in G.nodes: log_lazybuffer(x)
      G.add_edge(nm(x), nm(lb), color='#a0a0a0')
    # Label: shapes (source shapes too for reduces), dtype, op, arg for CONST/CAST, device if non-default.
    label = '"' + \
      (str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)) + \
      str_dtype(lb.dtype)+f"\n{lb.op}"+(f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else "") + \
      (f"\n{lb.device}" if lb.device != Device.DEFAULT else "") + '"'
    G.add_node(nm(lb), style='"filled,dashed"', fillcolor=[v for k,v in top_colors.items() if lb.op in k][0] + "80", color="black", label=label)
    if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
  else:
    if nm(lb) not in G.nodes:
      # realized but unseen?
      G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
def _tree(lazydata, cycles, cnt, prefix=""):
cnt[0] += 1

View File

@@ -4,7 +4,7 @@ import numpy as np
from collections import defaultdict
from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
from tinygrad.dtype import dtypes, DType, ImageDType
from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same
from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same, GRAPH
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.shape.shapetracker import ShapeTracker
@@ -214,9 +214,9 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
if buf in allbufs or buf.base.realized: return
log_lazybuffer(buf)
if GRAPH: log_lazybuffer(buf, scheduled)
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
@@ -249,14 +249,13 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
if seen is None: seen = set()
for out in outs: log_lazybuffer(out, scheduled=True)
# start by just realizing the buffers passed in
realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
allbufs: Dict[LazyBuffer, None] = {}
simple_pads: Set[LazyBuffer] = set()
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
# check if we have to realize pads
for p in simple_pads:

View File

@@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
from tinygrad.graph import print_tree, realized_lazybuffer
from tinygrad.helpers import colored, getenv
from tinygrad.helpers import colored, getenv, GRAPH
from tinygrad.shape.symbolic import Variable
# *** schedule running ***
@@ -47,4 +47,4 @@ def run_schedule(schedule:List[ScheduleItem]):
assert all(x.realized is not None for x in si.inputs), f"can't run, some inputs aren't realized {[x for x in si.inputs if x.realized is None]}"
if prg: prg.exec([si.out.realized] + [cast(Buffer, x.realized) for x in si.inputs], si.var_vals)
else: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
realized_lazybuffer(si.out, GlobalCounters.kernel_count)
if GRAPH: realized_lazybuffer(si.out, GlobalCounters.kernel_count)