start ripping out old scheduler -- no maps (#14909)

* start ripping out old scheduler -- no maps

* no more metadata
George Hotz authored 2026-02-20 21:05:04 +08:00, committed by GitHub
parent 1b3b94a72a
commit 2611907afb
8 changed files with 17 additions and 271 deletions
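
In practice this collapses the old two-step call pattern into one call. A minimal before/after sketch, assuming a tinygrad checkout that includes this commit (names taken from the diffs below):

from tinygrad import Tensor
from tinygrad.schedule.rangeify import get_rangeify

big_sink = (Tensor.empty(4, 4) + 1).contiguous().uop.sink()
# old (removed): tensor_map = get_rangeify_map(big_sink)
#                sched_sink = big_sink.substitute(tensor_map)
# new: the substitution happens inside, and the rewritten sink UOp comes back
sched_sink = get_rangeify(big_sink)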

View File

@@ -10,7 +10,7 @@ Directories are listed in order of how they are processed.
 Group UOps into kernels.
-::: tinygrad.schedule.rangeify.get_rangeify_map
+::: tinygrad.schedule.rangeify.get_rangeify
     options:
       members: false
       show_labels: false

View File

@@ -266,7 +266,7 @@ class TestCustomKernel(unittest.TestCase):
     The custom_addmul kernel should be at index 3.
     """
     from tinygrad.engine.schedule import create_schedule
-    from tinygrad.schedule.rangeify import get_rangeify_map
+    from tinygrad.schedule.rangeify import get_rangeify
     A, B = Tensor.empty(4, 4), Tensor.empty(4, 4)
     A2 = (A + 1).contiguous()  # kernel 0: depends on A
@@ -277,8 +277,7 @@ class TestCustomKernel(unittest.TestCase):
     result = (C + D + E).sum()  # kernel 3: custom_addmul, then kernel 4: sum
     big_sink = result.uop.sink()
-    tensor_map = get_rangeify_map(big_sink)
-    sched_sink = big_sink.substitute(tensor_map)
+    sched_sink = get_rangeify(big_sink)
     schedule, _ = create_schedule(sched_sink)
     # Find the custom_addmul kernel position

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 # compare kernels created by HEAD against master
-import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, itertools, functools, base64, codecs
+import os, multiprocessing, logging, pickle, sqlite3, difflib, warnings, functools, base64, codecs
 from dataclasses import replace
 from typing import Callable, Any
@@ -8,7 +8,6 @@ ASSERT_DIFF = int((flag:="[pr]") in os.getenv("COMMIT_MESSAGE", flag) or flag in
 if not int(os.getenv("ASSERT_PROCESS_REPLAY", "1")): ASSERT_DIFF = 0
 try:
-  from tinygrad.schedule.rangeify import get_rangeify_map
   from tinygrad.renderer import Renderer, ProgramSpec
   from tinygrad.engine.realize import get_program
   from tinygrad.uop.ops import UOp, Ops, KernelInfo
@@ -43,14 +42,6 @@ class ProcessReplayWarning(Warning): pass
 # *** replay the function and convert return values to string
-def replay_get_rangeify_map(ret:dict[UOp, UOp], big_sink:UOp) -> tuple[str, str, tuple[Any, ...]]:
-  UOp.unique_num = itertools.count(max([u.arg for u in big_sink.toposort() if u.op is Ops.UNIQUE], default=0)+1)
-  new_sink = big_sink.substitute(get_rangeify_map(big_sink))
-  def to_str(ret:UOp) -> str:
-    asts = [repr(u.arg.ast) for u in ret.toposort() if u.op is Ops.CALL]
-    return "\n".join([f"{len(asts)} kernels", *asts])
-  return to_str(new_sink), to_str(big_sink.substitute(ret)), (big_sink,)
 def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]|None=None) -> tuple[str, str, tuple[Any, ...]]:
   # the ast.arg is non None if we are inside of search.py
   sink_arg = ast.arg or KernelInfo()
@@ -68,8 +59,6 @@ def replay_get_program(p:ProgramSpec, ast:UOp, renderer:Renderer, opts:list[Opt]
 replayers: dict[str, Callable[..., tuple[str, str, tuple[Any, ...]]]] = {}
 replayers["get_program"] = replay_get_program
-# disable this for speed, does it ever find things?
-#replayers["get_rangeify_map"] = replay_get_rangeify_map
 # *** run replayers on captured rows and print diffs
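
For orientation: each registered replayer receives the captured return value plus the original call's arguments, recomputes the call with HEAD's code, and returns a (recomputed, captured, inputs) triple whose first two strings get diffed. A minimal, hypothetical replayer following that shape (replay_identity is made up for illustration):

def replay_identity(ret:str, x:str) -> tuple[str, str, tuple]:
  # "recompute" is a no-op here: HEAD's result is just the input itself
  return x, ret, (x,)
# replayers["identity"] = replay_identity  # hypothetical registration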

View File

@@ -1,202 +0,0 @@
-import unittest
-from tinygrad import dtypes
-from tinygrad.uop.ops import UOp, graph_rewrite_map, _substitute
-from tinygrad.uop.symbolic import symbolic
-
-class TestRewriteMap(unittest.TestCase):
-  def test_substitute(self):
-    a = UOp.variable('a', 0, 10)
-    b = UOp.variable('b', 0, 10)
-    c = UOp.variable('c', 0, 10)
-    e = UOp.variable('e', 0, 10)
-    ret = (a+b)*c
-    sub = {a+b: e}
-    sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
-    self.assertIs(sub_map[a+b], e)
-    self.assertIs(sub_map[(a+b)*c], e*c)
-
-  def test_substitute_depth_2(self):
-    a = UOp.variable('a', 0, 10)
-    b = UOp.variable('b', 0, 10)
-    c = UOp.variable('c', 0, 10)
-    d = UOp.variable('d', 0, 10)
-    e = UOp.variable('e', 0, 10)
-    f = UOp.variable('f', 0, 10)
-    ret = (a+b)*c+d
-    sub = {a+b: e, (a+b)*c: f}
-    sub_map = graph_rewrite_map(ret, _substitute, sub, bottom_up=True)
-    self.assertIs(sub_map[a+b], e)
-    self.assertIs(sub_map[(a+b)*c], f)
-
-  def test_multistage_substitute(self):
-    a = UOp.variable('a', 0, 10)
-    b = UOp.variable('b', 0, 10)
-    c = UOp.variable('c', 0, 10)
-    d = UOp.variable('d', 0, 10)
-    sub1 = {a+b:c}
-    start = (a+b)*c
-    # stage 1: (a+b)*c -> c*c
-    sub_map1 = graph_rewrite_map(start, _substitute, sub1, bottom_up=True)
-    self.assertIs(sub_map1[(a+b)*c], c*c)
-    # stage 2: c*c -> d
-    sub2 = {c*c:d}
-    sub_map2 = graph_rewrite_map(sub_map1[start], _substitute, sub2, input_map=sub_map1, bottom_up=True)
-    # (a+b)*c -> c*c -> d
-    self.assertIs(sub_map2[(a+b)*c], d)
-
-  def test_add_zero(self):
-    # Build a small graph: add(0, add(const=0, const=5))
-    zero_node = UOp.const(dtypes.index, 0)
-    five_node = UOp.const(dtypes.index, 5)
-    inner_add = zero_node + five_node
-    root_add = zero_node + inner_add
-    # Perform top-down rewrite
-    node_map = graph_rewrite_map(root_add, symbolic)
-    # We expect that add(0, add(0, 5)) -> add(0, 5) -> 5
-    # Check the mapping
-    assert node_map[root_add] == five_node
-    assert node_map[inner_add] == five_node
-    # zero_node and five_node map to themselves
-    assert node_map[zero_node] == zero_node
-    assert node_map[five_node] == five_node
-
-  def test_double_neg(self):
-    """
-    Test rewriting neg(neg(5)) => 5 using symbolic.
-    """
-    # In some versions of TinyGrad, you might do: (-(-five_node))
-    five_node = UOp.const(dtypes.index, 5)
-    # If your code allows UOp(...), do that; else you might do something like:
-    # double_neg_five = -(-five_node)
-    # But let's be explicit:
-    neg_five = -five_node
-    double_neg_five = -neg_five
-    node_map = graph_rewrite_map(double_neg_five, symbolic)
-    # node_map should map double_neg_five -> five_node
-    self.assertEqual(node_map[double_neg_five], five_node)
-    # five_node maps to itself
-    self.assertEqual(node_map[five_node], five_node)
-
-  def test_add_zero_and_double_neg(self):
-    """
-    Combine both rewrites: add(0, neg(neg(5))) => add(0, 5) => 5
-    """
-    zero_node = UOp.const(dtypes.index, 0)
-    five_node = UOp.const(dtypes.index, 5)
-    neg_five = -five_node
-    double_neg_five = -neg_five
-    root_add = zero_node + double_neg_five
-    node_map = graph_rewrite_map(root_add, symbolic)
-    # node_map: root_add -> five_node, double_neg_five -> five_node
-    self.assertEqual(node_map[root_add], five_node)
-    self.assertEqual(node_map[double_neg_five], five_node)
-    # zero_node, five_node map to themselves
-    self.assertEqual(node_map[zero_node], zero_node)
-    self.assertEqual(node_map[five_node], five_node)
-
-  def test_multi_var_rewrites(self):
-    x_var = UOp.variable('x', 0, 10)
-    y_var = UOp.variable('y', -5, 5)
-    zero_node = UOp.const(dtypes.index, 0)
-    sum_with_zero = y_var + zero_node  # (y + 0)
-    combined = x_var + sum_with_zero  # x + (y + 0)
-    double_neg = -(-combined)  # neg(neg(x + y))
-    final_expr = zero_node + double_neg  # 0 + (x + y)
-    node_map = graph_rewrite_map(final_expr, symbolic)
-    # The final root should be (x_var + y_var).
-    expected = x_var + y_var
-    # Each sub-expression has its own "final" result.
-    # (y + 0) -> y_var
-    self.assertEqual(node_map[sum_with_zero], y_var)
-    # (x + (y+0)) -> (x + y)
-    self.assertEqual(node_map[combined], expected)
-    # neg(neg(x+y)) -> (x + y)
-    self.assertEqual(node_map[double_neg], expected)
-    # 0 + (x+y) -> (x + y)
-    self.assertEqual(node_map[final_expr], expected)
-    # x_var, y_var, zero_node remain unchanged
-    self.assertEqual(node_map[x_var], x_var)
-    self.assertEqual(node_map[y_var], y_var)
-    self.assertEqual(node_map[zero_node], zero_node)
-
-  def test_complex_multi_var_edges(self):
-    """
-    Build a multi-variable expression with multiple intermediates:
-      x_var = UOp.variable('x', 1, 10)
-      y_var = UOp.variable('y', -5, 5)
-      z_var = UOp.variable('z', 0, 5)
-      zero_node = UOp.const(dtypes.int, 0)
-      one_node = UOp.const(dtypes.int, 1)
-      yz_sum = y_var + z_var
-      yz_sum_zero = yz_sum + zero_node     -> rewrites to yz_sum
-      yz_neg = -yz_sum_zero                -> -(y+z)
-      yz_dneg = -yz_neg                    -> y+z (double neg gone)
-      x_plus_yz = x_var + yz_dneg          -> x + (y+z)
-      double_neg_x = -(-x_plus_yz)         -> x + (y+z)
-      final_expr = double_neg_x * one_node -> x + (y+z)
-    We expect the final result to be (x + (y+z)).
-    Each original node should map to the final node that replaces it,
-    which might be structurally equivalent but not the same reference.
-    """
-    x_var = UOp.variable('x', 1, 10)
-    y_var = UOp.variable('y', -5, 5)
-    z_var = UOp.variable('z', 0, 5)
-    zero_node = UOp.const(dtypes.index, 0)
-    one_node = UOp.const(dtypes.index, 1)
-    # Build sub-expressions
-    yz_sum = y_var + z_var  # (y + z)
-    yz_sum_zero = yz_sum + zero_node  # (y + z) + 0
-    yz_neg = -yz_sum_zero  # -(y+z)
-    yz_dneg = -yz_neg  # -(-(y+z)) -> (y+z)
-    x_plus_yz = x_var + yz_dneg  # x + (y+z)
-    double_neg_x = -(-x_plus_yz)  # neg(neg(x+(y+z))) -> x+(y+z)
-    final_expr = double_neg_x * one_node  # (x+(y+z)) * 1 -> x+(y+z)
-    node_map = graph_rewrite_map(final_expr, symbolic)
-    # (y + z) is unchanged
-    self.assertEqual(node_map[yz_sum], yz_sum)
-    # (y+z) + 0 => (y+z)
-    self.assertEqual(node_map[yz_sum_zero], yz_sum)
-    # -(y+z) remains -(y+z), but might be a new UOp with updated children
-    # Compare structurally to -(y_var + z_var).
-    self.assertEqual(node_map[yz_neg], -yz_sum)
-    # -(-(y+z)) => (y+z)
-    self.assertEqual(node_map[yz_dneg], yz_sum)
-    # x + (y+z) => might get recreated if yz_dneg was changed, so compare to x + yz_sum
-    self.assertEqual(node_map[x_plus_yz], x_var + yz_sum)
-    # -(-(x+(y+z))) => x + (y+z)
-    self.assertEqual(node_map[double_neg_x], x_var + yz_sum)
-    # (x+(y+z)) * 1 => x+(y+z)
-    self.assertEqual(node_map[final_expr], x_var + yz_sum)
-    # Unchanged atomic nodes map to themselves
-    self.assertEqual(node_map[x_var], x_var)
-    self.assertEqual(node_map[y_var], y_var)
-    self.assertEqual(node_map[z_var], z_var)
-    self.assertEqual(node_map[zero_node], zero_node)
-    self.assertEqual(node_map[one_node], one_node)
-
-if __name__ == "__main__":
-  unittest.main()
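
What these deleted tests exercised through graph_rewrite_map's per-node map can still be expressed with the surviving graph_rewrite, which returns only the rewritten root. A minimal sketch of the test_substitute case under that assumption:

from tinygrad.uop.ops import UOp, graph_rewrite, _substitute

a = UOp.variable('a', 0, 10)
b = UOp.variable('b', 0, 10)
c = UOp.variable('c', 0, 10)
e = UOp.variable('e', 0, 10)
# rewrite (a+b)*c under the substitution {a+b: e}; no intermediate map to apply
ret = graph_rewrite((a+b)*c, _substitute, ctx={a+b: e}, bottom_up=True)
assert ret is e*c  # UOps are interned, so identity comparison works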

View File

@@ -1,7 +1,7 @@
 import time
 from typing import cast
 from collections import deque
-from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, graph_rewrite_map, gate_kernel_sink
+from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
 from tinygrad.uop.spec import type_verify, tensor_spec
 from tinygrad.device import Buffer, MultiBuffer
 from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
@@ -63,8 +63,8 @@ def create_schedule(sched_sink:UOp) -> tuple[list[ExecItem], UOp]:
   return pre_schedule, UOp.sink(*buf_uops_list)
 from tinygrad.engine.memory import memory_planner
-from tinygrad.schedule.rangeify import get_rangeify_map
-from tinygrad.schedule.multi import get_multi_map
+from tinygrad.schedule.rangeify import get_rangeify
+from tinygrad.schedule.multi import multi_pm
 def replace_input_buffer(ctx:tuple[dict[UOp, UOp], dict[str, int], list[int], list[int]], b:UOp):
   if (ret:=ctx[0].get(b, None)) is None:
@@ -128,20 +128,11 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
     # verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
     if SPEC: type_verify(big_sink, tensor_spec)
-    # hack to preserve metadata
-    graph_rewrite_map(big_sink, pm_pre_sched_cache, ctx=({}, {}, [0], [0]), name="preserve metadata")
-    # tensor map is what we return
-    tensor_map: dict[UOp, UOp] = {}
     if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
-      tensor_map |= get_multi_map(big_sink_cache)
-      big_sink_cache = big_sink_cache.substitute(tensor_map, name="Apply Multi Map")
+      big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
       big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
-    tensor_map |= get_rangeify_map(big_sink_cache)
-    big_sink = big_sink_cache.substitute(tensor_map, name="Apply Kernelize Map")
+    big_sink = get_rangeify(big_sink_cache)
     pre_schedule, buf_uops_sink = create_schedule(big_sink)
     if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
   else:
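
Put together, the cache-miss path now reduces to direct rewrites with no tensor_map accumulation. A condensed sketch of the flow in the hunk above (the wrapper name schedule_big_sink is made up; the body mirrors the new code):

from tinygrad.helpers import flatten
from tinygrad.uop.ops import UOp, Ops, graph_rewrite
from tinygrad.schedule.multi import multi_pm
from tinygrad.schedule.rangeify import get_rangeify

def schedule_big_sink(big_sink_cache:UOp) -> UOp:
  if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
    # multi-device: rewrite MULTI ops in place, then splat their srcs into the sink
    big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
    big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
  # group into kernels; returns the rewritten sink directly
  return get_rangeify(big_sink_cache)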

View File

@@ -1,6 +1,6 @@
 import functools, itertools
-from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, VIZ, getenv
-from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp, graph_rewrite_map, graph_rewrite
+from tinygrad.helpers import all_same, all_int, prod, DEBUG, RING, ALL2ALL, getenv
+from tinygrad.uop.ops import Ops, UOp, PatternMatcher, UPat, GroupOp
 from tinygrad.dtype import dtypes
 # *** allreduce implementation ***
@@ -187,9 +187,3 @@ multi_pm = PatternMatcher([
   (UPat(Ops.AFTER, src=(UPat(Ops.MULTI, name="multi"), UPat(Ops.CALL)), name="a"),
    lambda multi,a: a.replace(src=(multi.src[0],)+a.src[1:]).multi(multi.axis)),
 ])+replace_allreduce
-def get_multi_map(big_sink:UOp) -> dict[UOp, UOp]:
-  if VIZ: graph_rewrite(big_sink, PatternMatcher([]), name="View Multi AST")
-  ret = graph_rewrite_map(big_sink, multi_pm, name="multi_pm")
-  if VIZ: graph_rewrite(ret[big_sink], PatternMatcher([]), name="View Post Multi AST")
-  return ret

View File

@@ -4,7 +4,7 @@ from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
 from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
 from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags
 from tinygrad.uop.symbolic import symbolic
-from tinygrad.helpers import argsort, prod, all_same, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
+from tinygrad.helpers import argsort, prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
 from tinygrad.helpers import PCONTIG, partition, get_single_element
 from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
 from tinygrad.codegen.opt import Opt
@@ -503,7 +503,7 @@ pm_add_range_tags = PatternMatcher([
   (UPat(Ops.RANGE, name="x"), lambda x: x.rtag(())),
 ])
-def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
+def split_store(x:UOp) -> UOp|None:
   # if we have any open ranges here, we don't split
   if x.ranges: return None
@@ -511,9 +511,6 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   lctx = LocalAddBufferContext()
   ret = graph_rewrite(x, to_define_global+pm_flatten_range+rangeify_codegen+pm_remove_tags, ctx=lctx, name="kernel split", bottom_up=True)
-  # gather the metadata
-  metadatas = [ctx[y].metadata for y in lctx.parent_tags]
   # SINK requires all buffers on the same device, but COPY/BUFFER_VIEW/ENCDEC are cross-device or special hardware ops
   if ret.op is Ops.STORE: stored = ret.src[1]
   elif ret.op is Ops.END and ret.src[0].op is Ops.STORE: stored = ret.src[0].src[1]
@@ -521,8 +518,7 @@ def split_store(ctx:list[UOp], x:UOp) -> UOp|None:
   if stored.op in {Ops.COPY, Ops.BUFFER_VIEW, Ops.ENCDEC}: ret = stored
   else: ret = ret.sink(arg=KernelInfo(opts_to_apply=lctx.opts))
-  metadata = tuple(dedup(flatten([x for x in metadatas if x is not None])))[::-1]
-  kernel = ret.call(*lctx.map.values(), *lctx.vars.keys(), metadata=metadata)
+  kernel = ret.call(*lctx.map.values(), *lctx.vars.keys())
   if ret.op is Ops.SINK and not all_same([x.device for x in kernel.src[1:] if x.op is not Ops.BIND]):
     raise RuntimeError(f"all buffers must be on the same device: {tuple(b.buf_uop for b in kernel.src[1:])}")
   return kernel
@@ -561,7 +557,7 @@ replace_contiguous = PatternMatcher([
   (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
 ])
-def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
+def get_rangeify(sink:UOp) -> UOp:
   if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
   uop_list: list[UOp] = []
   tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
@@ -586,7 +582,7 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   lunique_start: int = max([-1]+[x.arg for x in tsink.toposort() if x.op is Ops.LUNIQUE]) + 1
   tsink = graph_rewrite(tsink, pm_gate_kernel_sink+pm_add_buffers+pm_add_range_tags, ctx=itertools.count(lunique_start), bottom_up=True,
                         name="bufferize to store")
-  tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, ctx=uop_list, bottom_up=True, name="split kernels")
+  tsink = graph_rewrite(tsink, pm_gate_kernel_sink+split_kernels, bottom_up=True, name="split kernels")
   # WAR deps: if kernel U reads buffer S, and S is also written by another kernel, S's write must wait for U to finish
   afters = [u for u in tsink.toposort() if u.op is Ops.AFTER]
@@ -603,15 +599,6 @@ def get_rangeify_map(sink:UOp) -> dict[UOp, UOp]:
   if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
-  # TODO: we can probably get this earlier
-  sink_tags = [s.tag for s in tsink.src]
   tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
   if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
-  becomes_map: dict[UOp, UOp] = {}
-  for tag, s in zip(sink_tags, tsink.src):
-    assert tag is not None
-    for a in tag:
-      if a is None: continue
-      becomes_map[uop_list[int(a)]] = s
-  return becomes_map
+  return tsink
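
With becomes_map gone, callers work on the returned sink itself; kernels appear as CALL nodes in the rewritten graph (the same walk the removed replay_get_rangeify_map used in its to_str helper). A small sketch under that assumption:

from tinygrad import Tensor
from tinygrad.uop.ops import Ops
from tinygrad.schedule.rangeify import get_rangeify

sched_sink = get_rangeify((Tensor.empty(8) + 1).contiguous().uop.sink())
# each kernel is a CALL node in the graph; no reverse map from tensors is needed
kernels = [u for u in sched_sink.toposort() if u.op is Ops.CALL]
print(f"{len(kernels)} kernels")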

View File

@@ -1311,18 +1311,6 @@ def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=N
   rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
   return rewrite_ctx.unified_rewrite(sink)
-@profile_matches
-def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, bpm=None,
-                      input_map:dict[UOp, UOp]|None=None) -> dict[UOp, UOp]:
-  rewrite_ctx = RewriteContext(pm if not bottom_up else None, pm if bottom_up else bpm, ctx)
-  new_map: dict[UOp, UOp] = {}
-  for k in (list(sink.toposort())[::-1] if bottom_up else sink.toposort()):
-    new_map[k] = v = rewrite_ctx.unified_rewrite(k)
-    if k is not v and k.metadata is not None: all_metadata[v] = tuple(dedup(all_metadata.get(v, ())))+k.metadata
-  if input_map is not None:
-    for k,v in input_map.items(): new_map[k] = new_map.get(v,v)
-  return new_map
 def sint_to_uop(x:sint, dtype=dtypes.index) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x.cast(dtype)
 def select_dtype(u): return (dtypes.long if u.overflows(dtypes.int32) else dtypes.int).vec(u.dtype.count)