mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
remove late numbering of UOps (#14923)
* remove late numbering of UOps * stupid fix * dead code
This commit is contained in:
@@ -1322,7 +1322,7 @@ class TestCopyFolding(unittest.TestCase):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
# use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph
|
||||
a.assign(Tensor(a.uop.copy_to_device(a.device), a.device))
|
||||
run_schedule(check_schedule(a, 0, filter_sink=False))
|
||||
run_schedule(check_schedule(a, 2, filter_sink=False))
|
||||
self.assertListEqual(a.tolist(), [[1.]*4]*4)
|
||||
|
||||
def test_clone(self):
|
||||
|
||||
@@ -412,20 +412,20 @@ class TestSchedule(unittest.TestCase):
|
||||
out = bn(c1(img)).relu()
|
||||
check_schedule(out, 4, [c1.weight, c1.bias])
|
||||
|
||||
def test_fold_conv_batchnorm_optim(self):
|
||||
# this is too high
|
||||
for optim, cnt in [(nn.optim.Adam, 17), (nn.optim.SGD, 7)]:
|
||||
with self.subTest(optim=optim.__name__):
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
_realize_weights([c1, bn])
|
||||
opt = optim(nn.state.get_parameters([c1, bn]))
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
check_schedule(opt.schedule_step(), cnt)
|
||||
def test_fold_conv_batchnorm_optim(self, adam=False):
|
||||
# 2 is too low?
|
||||
optim, cnt = (nn.optim.Adam, 16) if adam else (nn.optim.SGD, 2)
|
||||
with Tensor.train():
|
||||
img = Tensor.ones(1,3,4,4)
|
||||
c1 = nn.Conv2d(3,32,3)
|
||||
bn = nn.BatchNorm2d(32, track_running_stats=False)
|
||||
_realize_weights([c1, bn])
|
||||
opt = optim(nn.state.get_parameters([c1, bn]))
|
||||
img_bn = bn(c1(img)).elu().sum()
|
||||
opt.zero_grad()
|
||||
img_bn.backward()
|
||||
check_schedule(opt.schedule_step(), cnt)
|
||||
def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True)
|
||||
|
||||
def test_fold_batchnorm_backward(self):
|
||||
with Tensor.train():
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import deque
|
||||
from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink
|
||||
from tinygrad.uop.spec import type_verify, tensor_spec
|
||||
from tinygrad.device import Buffer, MultiBuffer
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE
|
||||
from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE
|
||||
from tinygrad.engine.realize import ExecItem
|
||||
from tinygrad.engine.allocations import allocate_global_buffers
|
||||
|
||||
@@ -127,11 +127,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li
|
||||
if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None:
|
||||
# verify Tensors match the spec (on big_sink, we only need to do this if cache misses)
|
||||
if SPEC: type_verify(big_sink, tensor_spec)
|
||||
|
||||
if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()):
|
||||
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
|
||||
big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src]))
|
||||
|
||||
big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm")
|
||||
big_sink = get_rangeify(big_sink_cache)
|
||||
pre_schedule, buf_uops_sink = create_schedule(big_sink)
|
||||
if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink)
|
||||
|
||||
@@ -2,9 +2,9 @@ from dataclasses import dataclass, field, replace
|
||||
import itertools
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace
|
||||
from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags
|
||||
from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.helpers import argsort, prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS
|
||||
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
||||
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
||||
from tinygrad.codegen.opt import Opt
|
||||
@@ -169,7 +169,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([
|
||||
# *****************
|
||||
# 3.5 cleanups
|
||||
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC}
|
||||
ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP}
|
||||
|
||||
# you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left
|
||||
def cleanup_dead_axes(b:UOp):
|
||||
@@ -527,42 +527,9 @@ split_kernels = PatternMatcher([
|
||||
(UPat((Ops.STORE, Ops.END), name="x"), split_store),
|
||||
])
|
||||
|
||||
def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp):
|
||||
if x.tag is not None or x in ctx[1]: return None
|
||||
if x.tag is None and x.op is Ops.CALL:
|
||||
# don't tag anything in a CALL
|
||||
for u in x.src[0].toposort(): ctx[1].add(u)
|
||||
if x.dtype.scalar() == dtypes.index: return None
|
||||
ctx[0].append(x)
|
||||
return x.replace(tag=(len(ctx[0])-1,))
|
||||
add_tags = pm_gate_kernel_sink+PatternMatcher([
|
||||
# don't tag BUFFERs, they are global
|
||||
(UPat(GroupOp.All-{Ops.PARAM, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.END,
|
||||
Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop),
|
||||
(UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.PARAM for s in x.src) else tag_uop(ctx, x)),
|
||||
])
|
||||
|
||||
# support for using a contiguous permuted view instead of the parent view if one exists
|
||||
|
||||
def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp):
|
||||
x = src
|
||||
while x is not src.base:
|
||||
if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg))
|
||||
elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape)
|
||||
else: return None
|
||||
x = x.src[0]
|
||||
ctx[src.base] = contig
|
||||
replace_contiguous = PatternMatcher([
|
||||
(UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous),
|
||||
(UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None),
|
||||
])
|
||||
|
||||
def get_rangeify(sink:UOp) -> UOp:
|
||||
if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph")
|
||||
uop_list: list[UOp] = []
|
||||
tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops")
|
||||
|
||||
tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites+replace_contiguous, ctx={}, bottom_up=True, name="earliest rewrites")
|
||||
tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites")
|
||||
|
||||
# convert movement ops to ranges
|
||||
tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY))
|
||||
@@ -570,12 +537,6 @@ def get_rangeify(sink:UOp) -> UOp:
|
||||
tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf")
|
||||
tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers")
|
||||
|
||||
# rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph
|
||||
# MSTACK stacks multiple BUFFERIZEs in one tagged tensor
|
||||
# if it's not tagged by here, it's out
|
||||
tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.PARAM, Ops.AFTER} and \
|
||||
x.tag is not None and len(x.tag)])
|
||||
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify")
|
||||
|
||||
# bufferize -> store
|
||||
@@ -597,8 +558,5 @@ def get_rangeify(sink:UOp) -> UOp:
|
||||
raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER")
|
||||
assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,))
|
||||
if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign")
|
||||
|
||||
# TODO: we can probably get this earlier
|
||||
tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags")
|
||||
if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph")
|
||||
return tsink
|
||||
Reference in New Issue
Block a user