remove late numbering of UOps (#14923)

* remove late numbering of UOps

* stupid fix

* dead code
This commit is contained in:
George Hotz
2026-02-21 09:18:48 +08:00
committed by GitHub
parent c9b706125d
commit df7774661a
4 changed files with 21 additions and 67 deletions

View File

@@ -1322,7 +1322,7 @@ class TestCopyFolding(unittest.TestCase):
a = Tensor.ones(4, 4).contiguous().realize()
# use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph
a.assign(Tensor(a.uop.copy_to_device(a.device), a.device))
run_schedule(check_schedule(a, 0, filter_sink=False))
run_schedule(check_schedule(a, 2, filter_sink=False))
self.assertListEqual(a.tolist(), [[1.]*4]*4)
def test_clone(self):

View File

@@ -412,20 +412,20 @@ class TestSchedule(unittest.TestCase):
out = bn(c1(img)).relu()
check_schedule(out, 4, [c1.weight, c1.bias])
def test_fold_conv_batchnorm_optim(self):
# this is too high
for optim, cnt in [(nn.optim.Adam, 17), (nn.optim.SGD, 7)]:
with self.subTest(optim=optim.__name__):
with Tensor.train():
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
_realize_weights([c1, bn])
opt = optim(nn.state.get_parameters([c1, bn]))
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
check_schedule(opt.schedule_step(), cnt)
def test_fold_conv_batchnorm_optim(self, adam=False):
# 2 is too low?
optim, cnt = (nn.optim.Adam, 16) if adam else (nn.optim.SGD, 2)
with Tensor.train():
img = Tensor.ones(1,3,4,4)
c1 = nn.Conv2d(3,32,3)
bn = nn.BatchNorm2d(32, track_running_stats=False)
_realize_weights([c1, bn])
opt = optim(nn.state.get_parameters([c1, bn]))
img_bn = bn(c1(img)).elu().sum()
opt.zero_grad()
img_bn.backward()
check_schedule(opt.schedule_step(), cnt)
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
def test_fold_batchnorm_backward(self):
with Tensor.train():

View File

@@ -4,7 +4,7 @@ from collections import deque
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
from tinygrad.uop.spec import type_verify, tensor_spec
from tinygrad.device import Buffer, MultiBuffer
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE
from tinygrad.engine.realize import ExecItem
from tinygrad.engine.allocations import allocate_global_buffers
@@ -127,11 +127,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
if SPEC: type_verify(big_sink, tensor_spec)
if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
big_sink = get_rangeify(big_sink_cache)
pre_schedule, buf_uops_sink = create_schedule(big_sink)
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)

View File

@@ -2,9 +2,9 @@ from dataclasses import dataclass, field, replace
import itertools
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate
from tinygrad.uop.symbolic import symbolic
from tinygrad.helpers import argsort, prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
from tinygrad.helpers import PCONTIG, partition, get_single_element
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
from tinygrad.codegen.opt import Opt
@@ -169,7 +169,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
# *****************
# 3.5 cleanups
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC}
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
def cleanup_dead_axes(b:UOp):
@@ -527,42 +527,9 @@ split_kernels = PatternMatcher([
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
])
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
if x.tag is not None or x in ctx[1]: return None
if x.tag is None and x.op is Ops.CALL:
# don't tag anything in a CALL
for u in x.src[0].toposort(): ctx[1].add(u)
if x.dtype.scalar() == dtypes.index: return None
ctx[0].append(x)
return x.replace(tag=(len(ctx[0])-1,))
add_tags = pm_gate_kernel_sink+PatternMatcher([
# don't tag BUFFERs, they are global
(UPat(GroupOp.All-{Ops.PARAM, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.END,
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.PARAM for s in x.src) else tag_uop(ctx, x)),
])
# support for using a contiguous permuted view instead of the parent view if one exists
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
x = src
while x is not src.base:
if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg))
elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape)
else: return None
x = x.src[0]
ctx[src.base] = contig
replace_contiguous = PatternMatcher([
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
])
def get_rangeify(sink:UOp) -> UOp:
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
uop_list: list[UOp] = []
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites+replace_contiguous, ctx={}, bottom_up=True, name="earliest rewrites")
tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
# convert movement ops to ranges
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
@@ -570,12 +537,6 @@ def get_rangeify(sink:UOp) -> UOp:
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor
# if it's not tagged by here, it's out
tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.PARAM, Ops.AFTER} and \
x.tag is not None and len(x.tag)])
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
# bufferize -> store
@@ -597,8 +558,5 @@ def get_rangeify(sink:UOp) -> UOp:
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER")
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
# TODO: we can probably get this earlier
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
return tsink