mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
early gate the graph (#3070)
This commit is contained in:
@@ -7,7 +7,7 @@ from typing import List, Optional
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import LoadOps
|
||||
from tinygrad.device import Device, Compiled
|
||||
from tinygrad.helpers import DEBUG
|
||||
from tinygrad.helpers import DEBUG, GRAPH
|
||||
from tinygrad.codegen.linearizer import Linearizer
|
||||
from tinygrad.graph import print_tree, realized_lazybuffer
|
||||
from tinygrad import nn, dtypes
|
||||
@@ -17,10 +17,11 @@ def check_schedule(t:Tensor, allowed:int, to_prerealize:Optional[List[Tensor]]=N
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize:
|
||||
for s in pre.lazydata.schedule(seen.copy()):
|
||||
realized_lazybuffer(s.out, 0)
|
||||
if GRAPH: realized_lazybuffer(s.out, 0)
|
||||
seen.add(s.out)
|
||||
sched = t.lazydata.schedule(seen)
|
||||
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
|
||||
if GRAPH:
|
||||
for i,s in enumerate(sched): realized_lazybuffer(s.out, i+1)
|
||||
if filter_loadops: sched = [s for s in sched if s.ast.op not in LoadOps]
|
||||
if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if len(sched) != allowed or DEBUG >= 3:
|
||||
|
||||
@@ -3,7 +3,7 @@ from collections import defaultdict
|
||||
from typing import List, Any, DefaultDict
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, MovementOps, LoadOps, BufferOps, TernaryOps, Op, LazyOp, GlobalCounters
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.helpers import GRAPH, GRAPHPATH, DEBUG, getenv
|
||||
from tinygrad.helpers import GRAPHPATH, DEBUG, getenv
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.shape.symbolic import NumNode
|
||||
|
||||
@@ -47,37 +47,35 @@ def str_dtype(dtyp):
|
||||
return "" if ret == 'float' else f"\n{ret}"
|
||||
|
||||
def realized_lazybuffer(lb, num):
  """Restyle the graph node of `lb` to mark it as realized.

  This function does not look at GRAPH itself; the call sites are expected to
  check the flag before calling. The node for `lb` must already exist in `G`,
  otherwise the lookup below raises KeyError.

  Args:
    lb: the buffer whose node is restyled; its node key is `nm(lb)`.
    num: number appended to the label after "K:".
  """
  init_graph()
  node = G.nodes[nm(lb)]
  node['style'] = '"filled,bold"'
  # drop the last two characters of the fill colour exactly once
  # (log_lazybuffer appends "80" to the colour of unrealized nodes)
  node['fillcolor'] = node['fillcolor'][:-2]
  # a buffer with no backing allocation is labelled FAKE instead of calling nm(None)
  buf_name = "FAKE" if lb.realized is None else nm(lb.realized)
  node['label'] = '"' + node["label"].replace('"', '') + f'\nK:{num} b:{buf_name}"'
|
||||
|
||||
# Fill colour per op family. log_lazybuffer picks the first entry whose key
# contains lb.op, so the insertion order below is kept as-is.
top_colors = {
  LoadOps: '#FFFFa0',
  UnaryOps: "#c0c0c0",
  ReduceOps: "#FFA0A0",
  BinaryOps: "#c0c0c0",
  MovementOps: "#80ff80",
  TernaryOps: "#c0c0c0",
  BufferOps: '#a0a0ff',
}
|
||||
def log_lazybuffer(lb, scheduled=False):
  """Add `lb` (and, recursively, its not-yet-drawn sources) to the graph `G`.

  This function does not look at GRAPH itself; the call sites are expected to
  check the flag before calling.

  Args:
    lb: the buffer to draw. When `lb.base != lb`, a node for `lb` is added and
      linked from its base, and the rest of the function operates on the base.
    scheduled: when true, an unrealized node gets shape 'box'.
  """
  init_graph()

  if lb.base != lb:
    # lb is not its own base: give it a separate node labelled with its shape,
    # strides and (if non-zero) offset, then continue with the base
    offset = lb.st.expr_node(NumNode(0))[0]
    label = f"{lb.st.shape}\n{lb.st.real_strides()}" + (f"\n{offset}" if offset != 0 else "")
    G.add_node(nm(lb), style='"filled,dashed"', fillcolor="#80ff8080", color="black", label=label)
    G.add_edge(nm(lb.base), nm(lb), color='#00000060')
    lb = lb.base

  if lb.realized is not None:
    if nm(lb) not in G.nodes:
      # realized but unseen?
      G.add_node(nm(lb), label=f'"{str(lb.base.realized)[5:-1].replace(" ", chr(10))}\nb:{nm(lb.realized)}"', style='filled', fillcolor="#f0c08080")
    return

  # unrealized: make sure every source has a node, then connect it to this one
  for x in lb.srcs:
    if nm(x) not in G.nodes: log_lazybuffer(x)
    G.add_edge(nm(x), nm(lb), color='#a0a0a0')

  # reduce ops show the set of input shapes above the output shape
  shape_text = str(set(x.shape for x in lb.srcs))+"\n"+str(lb.shape) if lb.op in ReduceOps else str(lb.shape)
  arg_text = f"\n{lb.arg}" if lb.op in {LoadOps.CONST, UnaryOps.CAST} else ""
  device_text = f"\n{lb.device}" if lb.device != Device.DEFAULT else ""
  label = '"' + shape_text + str_dtype(lb.dtype) + f"\n{lb.op}" + arg_text + device_text + '"'

  # first matching op family decides the colour; "80" is appended while unrealized
  fill = [v for k,v in top_colors.items() if lb.op in k][0] + "80"
  G.add_node(nm(lb), style='"filled,dashed"', fillcolor=fill, color="black", label=label)
  if scheduled: G.nodes[nm(lb)]['shape'] = 'box'
|
||||
|
||||
def _tree(lazydata, cycles, cnt, prefix=""):
|
||||
cnt[0] += 1
|
||||
|
||||
@@ -4,7 +4,7 @@ import numpy as np
|
||||
from collections import defaultdict
|
||||
from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
|
||||
from tinygrad.dtype import dtypes, DType, ImageDType
|
||||
from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same
|
||||
from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same, GRAPH
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
@@ -214,9 +214,9 @@ def _recursive_schedule(out:LazyBuffer, seen:Set[LazyBuffer], realizes:Set[LazyB
|
||||
|
||||
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
|
||||
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
|
||||
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]]):
|
||||
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
|
||||
if buf in allbufs or buf.base.realized: return
|
||||
log_lazybuffer(buf)
|
||||
if GRAPH: log_lazybuffer(buf, scheduled)
|
||||
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
|
||||
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
|
||||
if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
|
||||
@@ -249,14 +249,13 @@ def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer]) -> bool:
|
||||
|
||||
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
|
||||
if seen is None: seen = set()
|
||||
for out in outs: log_lazybuffer(out, scheduled=True)
|
||||
|
||||
# start by just realizing the buffers passed in
|
||||
realizes: Set[LazyBuffer] = set([x.base for x in outs if not x.base.realized])
|
||||
allbufs: Dict[LazyBuffer, None] = {}
|
||||
simple_pads: Set[LazyBuffer] = set()
|
||||
children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
|
||||
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children)
|
||||
for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
|
||||
|
||||
# check if we have to realize pads
|
||||
for p in simple_pads:
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import List, Dict, Optional, cast
|
||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, GlobalCounters
|
||||
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner, update_stats, InterpretedASTRunner
|
||||
from tinygrad.graph import print_tree, realized_lazybuffer
|
||||
from tinygrad.helpers import colored, getenv
|
||||
from tinygrad.helpers import colored, getenv, GRAPH
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
# *** schedule running ***
|
||||
@@ -47,4 +47,4 @@ def run_schedule(schedule:List[ScheduleItem]):
|
||||
assert all(x.realized is not None for x in si.inputs), f"can't run, some inputs aren't realized {[x for x in si.inputs if x.realized is None]}"
|
||||
if prg: prg.exec([si.out.realized] + [cast(Buffer, x.realized) for x in si.inputs], si.var_vals)
|
||||
else: update_stats(colored(f"empty {si.out.st.size:10d} {si.out.dtype}", "yellow"), 0, 0, {}, None, 1, device=si.out.device)
|
||||
realized_lazybuffer(si.out, GlobalCounters.kernel_count)
|
||||
if GRAPH: realized_lazybuffer(si.out, GlobalCounters.kernel_count)
|
||||
|
||||
Reference in New Issue
Block a user