mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
closer
This commit is contained in:
@@ -10,9 +10,8 @@ from hypothesis import assume, given, settings, strategies as strat
|
||||
from tinygrad import nn, dtypes, Device, Tensor, Variable
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat, Kernel
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.rangeify import Kernel
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule
|
||||
|
||||
class KernelCountException(Exception):
  """Exception subclass that adds no behavior of its own.

  NOTE(review): presumably raised when a schedule's kernel count differs from the expected one; confirm at the raise sites.
  """
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import time
|
||||
from typing import cast
|
||||
from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, Kernel
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE, Metadata
|
||||
@@ -22,14 +22,14 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
# build kernel dependency graph: edges from producer kernel to consumer kernels
|
||||
children: dict[UOp, list[UOp]] = {}
|
||||
in_degree: dict[UOp, int] = {}
|
||||
for u in sched_sink.toposort():
|
||||
for u in sched_sink.toposort(gate_kernel_sink):
|
||||
if u.op is Ops.RANGE: in_degree.setdefault(u, 0)
|
||||
if u.op is not Ops.AFTER: continue
|
||||
if (k:=u.src[1]).op is Ops.RANGE: continue # RANGEs are scheduled directly, not through dependency graph
|
||||
assert k.op in {Ops.KERNEL, Ops.END}, f"AFTER src[1] should be KERNEL or END, not {k.op}"
|
||||
in_degree.setdefault(k, 0)
|
||||
if k.op is Ops.END: assert k.src[0].op is Ops.KERNEL, f"END src[0] should be KERNEL, not {k.src[0].op}"
|
||||
for s in k.src[0].src if k.op is Ops.END else k.src[1:]:
|
||||
for s in k.src[0].src[1:] if k.op is Ops.END else k.src[1:]:
|
||||
match (s := _unwrap_src(s)).op:
|
||||
case Ops.AFTER:
|
||||
children.setdefault(s.src[1], []).append(k)
|
||||
@@ -60,7 +60,7 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
|
||||
#ast = (kernel:=cast(Kernel, k.arg)).ast
|
||||
ast = k.src[0]
|
||||
buf_uops = tuple(_unwrap_src(s).buf_uop for s in k.src[1:] if s.op is not Ops.BIND)
|
||||
bound_ranges = tuple(s for s in k.src if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
bound_ranges = tuple(s for s in k.src[1:] if s.op is Ops.BIND and len(s.src) > 1 and s.src[1].op is Ops.RANGE)
|
||||
sched_item[k] = (ast, buf_uops, k.arg.metadata, bound_ranges)
|
||||
schedule.append(k)
|
||||
if rk.op is Ops.END: schedule.append(rk)
|
||||
|
||||
@@ -3,9 +3,9 @@ import functools, itertools
|
||||
from dataclasses import dataclass, field
|
||||
from tinygrad.dtype import dtypes, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, graph_rewrite, sint, AxisType, profile_matches
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort
|
||||
from tinygrad.uop.ops import consumer_map_from_toposort, KernelInfo, BottomUpGate, gate_kernel_sink
|
||||
from tinygrad.uop.symbolic import symbolic, pm_simplify_valid, pm_drop_and_clauses
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored
|
||||
from tinygrad.helpers import argsort, all_same, cpu_profile, PCONTIG, colored, panic
|
||||
|
||||
ALWAYS_CONTIGUOUS: set[Ops] = {Ops.CONTIGUOUS, Ops.ASSIGN, Ops.COPY, Ops.BUFFER, Ops.BUFFER_VIEW,
|
||||
Ops.CONST, Ops.BIND, Ops.DEVICE, Ops.MSELECT, Ops.MSTACK, Ops.PARAM,
|
||||
@@ -19,9 +19,12 @@ def realize_srcs(ctx:dict[UOp, None], rb:UOp) -> None:
|
||||
|
||||
def realize_assign(ctx:dict[UOp, None], a:UOp) -> None:
  """Record in `ctx` which UOps must be realized for `a`.

  `ctx` is used as an insertion-ordered set: keys are the UOps to realize, values are always None.
  `a.src[1]` is recorded unless its op is in ALWAYS_CONTIGUOUS.
  `a` itself is recorded unless `a.src[1]` is a KERNEL.
  NOTE(review): `a` is presumably an ASSIGN UOp whose src[1] is the assigned value; confirm against the pattern that calls this.
  """
  if a.src[1].op not in ALWAYS_CONTIGUOUS: ctx[a.src[1]] = None
  # if it's a kernel, we don't realize it
  # (the old unconditional `ctx[a] = None` must not precede this check, or the guard does nothing)
  if a.src[1].op is not Ops.KERNEL: ctx[a] = None
|
||||
|
||||
pm_generate_realize_map = PatternMatcher([
|
||||
# if it's a Kernel, stop
|
||||
(UPat(Ops.SINK, name="sink"), lambda sink: panic(BottomUpGate()) if isinstance(sink.arg, KernelInfo) else None),
|
||||
# always realize SINK src
|
||||
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.update((x.base, None) for x in s.src if x.base.op not in ALWAYS_CONTIGUOUS)),
|
||||
# always realize COPY/BUFFER_VIEW/CONTIGUOUS/STORE/ENCDEC
|
||||
@@ -161,11 +164,11 @@ def run_rangeify(tsink:UOp, debug:bool=False) -> tuple[UOp, IndexingContext]:
|
||||
rctx = IndexingContext()
|
||||
|
||||
# get ops to realize
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, name="get realize")
|
||||
graph_rewrite(tsink, pm_generate_realize_map, ctx=rctx.realize_map, bottom_up=True, name="get realize")
|
||||
|
||||
# get the consumer map
|
||||
with cpu_profile("consumer map in rangeify", "TINY"):
|
||||
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort())
|
||||
consumer_map = consumer_map_from_toposort(tsink_toposort:=tsink.toposort(gate_kernel_sink))
|
||||
|
||||
# explicit rangeify
|
||||
ending_ranges: dict[UOp, list[UOp]] = {}
|
||||
|
||||
@@ -2,7 +2,7 @@ from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, Kernel, _remove_all_tags, range_str
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags, range_str
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element, panic
|
||||
@@ -66,9 +66,12 @@ mop_cleanup = PatternMatcher([
|
||||
|
||||
def resolve_custom_kernel(ck:UOp) -> UOp:
  """Replace the custom-kernel UOp `ck` with a call of the function stored in its arg.

  One placeholder is created per source of `ck` (slot = source index), `ck.arg.fxn` is invoked on those
  placeholders, and `.call` is applied to the result with the real sources `ck.src`.
  """
  placeholders = [UOp.placeholder_like(s, slot=i) for i,s in enumerate(ck.src)]
  # the earlier `return UOp(Ops.KERNEL, ..., arg=Kernel(...))` is dropped: it made this return unreachable
  return ck.arg.fxn(*placeholders).call(*ck.src)
|
||||
|
||||
def resolve_call(c:UOp) -> UOp:
|
||||
def resolve_call(c:UOp) -> UOp|None:
|
||||
# don't resolve real kernel calls, sink or program
|
||||
if c.src[0].op is Ops.SINK and isinstance(c.src[0].arg, KernelInfo): return None
|
||||
if c.src[0].op is Ops.PROGRAM: return None
|
||||
params = sorted([x for x in c.src[0].toposort() if x.op == Ops.PARAM], key=lambda x: x.arg)
|
||||
args = c.src[1:]
|
||||
# TODO: this check belongs in spec, not here
|
||||
|
||||
@@ -67,7 +67,8 @@ def consumer_map_from_toposort(lst:Iterable[UOp]):
|
||||
ret: dict[UOp, dict[UOp, None]] = {}
|
||||
for u in lst:
|
||||
ret[u] = {}
|
||||
for s in u.src: ret[s][u] = None
|
||||
for s in u.src:
|
||||
if s in ret: ret[s][u] = None
|
||||
return ret
|
||||
|
||||
def pretty_print(x:UOp, cache=None, d=0)->str:
|
||||
@@ -236,7 +237,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
return self.src[0]._shape
|
||||
|
||||
# ops with custom handling
|
||||
case Ops.KERNEL: return self.arg.ast._shape
|
||||
#case Ops.KERNEL: return self.arg.ast._shape
|
||||
|
||||
# TODO: disallow shape changing bitcast
|
||||
case Ops.BITCAST:
|
||||
@@ -836,6 +837,8 @@ class KernelInfo:
|
||||
@property
|
||||
def function_name(self): return to_function_name(self.name)
|
||||
|
||||
def gate_kernel_sink(x:UOp) -> bool:
  """Return False when `x` is a SINK whose arg is a KernelInfo, and True for every other UOp."""
  is_kernel_sink = x.op is Ops.SINK and isinstance(x.arg, KernelInfo)
  return not is_kernel_sink
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CustomKernel:
|
||||
fxn: Callable
|
||||
@@ -1448,10 +1451,10 @@ def pyrender(ast:UOp) -> str:
|
||||
if op_depth > 100: to_render.add(u)
|
||||
depth[u] = 0 if u in to_render else op_depth
|
||||
# do the rendering
|
||||
if u.op is Ops.KERNEL:
|
||||
if u.arg.ast not in kernels:
|
||||
kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
|
||||
r[u.arg.ast] = kernels[u.arg.ast][0]
|
||||
#if u.op is Ops.KERNEL:
|
||||
# if u.arg.ast not in kernels:
|
||||
# kernels[u.arg.ast] = (f"k{len(kernels)}", f"def k{len(kernels)}():\n " + pyrender(u.arg.ast).replace('\n', '\n ') + "\n return ast\n\n")
|
||||
# r[u.arg.ast] = kernels[u.arg.ast][0]
|
||||
ren = cast(str, pm_pyrender.rewrite(u, ctx=r))
|
||||
assert isinstance(ren, str)
|
||||
if u.tag is not None: ren += f".rtag({repr(u.tag)})"
|
||||
|
||||
@@ -45,7 +45,7 @@ from tinygrad.dtype import dtypes
|
||||
uops_colors = {Ops.LOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.REDUCE: "#FF5B5B",
|
||||
Ops.PARAM:"#cb9037", **{x:"#f2cb91" for x in {Ops.DEFINE_LOCAL, Ops.DEFINE_REG}}, Ops.REDUCE_AXIS: "#FF6B6B",
|
||||
Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#909090", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.KERNEL: "#3e7f55", Ops.CUSTOM_KERNEL: "#3ebf55",
|
||||
Ops.INDEX: "#cef263", Ops.WMMA: "#efefc0", Ops.MULTI: "#f6ccff", Ops.CUSTOM_KERNEL: "#3ebf55",
|
||||
**{x:"#D8F9E4" for x in GroupOp.Movement}, **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.THREEFRY:"#ffff80",
|
||||
Ops.BUFFER_VIEW: "#E5EAFF", Ops.BUFFER: "#B0BDFF", Ops.COPY: "#a040a0", Ops.ENCDEC: "#bf71b6",
|
||||
Ops.CALL: "#00B7C8", Ops.PARAM: "#14686F",
|
||||
|
||||
Reference in New Issue
Block a user