diff --git a/test/backend/test_schedule.py b/test/backend/test_schedule.py index 6a754080d1..c1f419ed40 100644 --- a/test/backend/test_schedule.py +++ b/test/backend/test_schedule.py @@ -1322,7 +1322,7 @@ class TestCopyFolding(unittest.TestCase): a = Tensor.ones(4, 4).contiguous().realize() # use copy_to_device to bypass Tensor.to() shortcircuit and force a real same-device COPY in the graph a.assign(Tensor(a.uop.copy_to_device(a.device), a.device)) - run_schedule(check_schedule(a, 0, filter_sink=False)) + run_schedule(check_schedule(a, 2, filter_sink=False)) self.assertListEqual(a.tolist(), [[1.]*4]*4) def test_clone(self): diff --git a/test/null/test_schedule.py b/test/null/test_schedule.py index c27c303e13..d53c64ca68 100644 --- a/test/null/test_schedule.py +++ b/test/null/test_schedule.py @@ -412,20 +412,20 @@ class TestSchedule(unittest.TestCase): out = bn(c1(img)).relu() check_schedule(out, 4, [c1.weight, c1.bias]) - def test_fold_conv_batchnorm_optim(self): - # this is too high - for optim, cnt in [(nn.optim.Adam, 17), (nn.optim.SGD, 7)]: - with self.subTest(optim=optim.__name__): - with Tensor.train(): - img = Tensor.ones(1,3,4,4) - c1 = nn.Conv2d(3,32,3) - bn = nn.BatchNorm2d(32, track_running_stats=False) - _realize_weights([c1, bn]) - opt = optim(nn.state.get_parameters([c1, bn])) - img_bn = bn(c1(img)).elu().sum() - opt.zero_grad() - img_bn.backward() - check_schedule(opt.schedule_step(), cnt) + def test_fold_conv_batchnorm_optim(self, adam=False): + # 2 is too low? + optim, cnt = (nn.optim.Adam, 16) if adam else (nn.optim.SGD, 2) + with Tensor.train(): + img = Tensor.ones(1,3,4,4) + c1 = nn.Conv2d(3,32,3) + bn = nn.BatchNorm2d(32, track_running_stats=False) + _realize_weights([c1, bn]) + opt = optim(nn.state.get_parameters([c1, bn])) + img_bn = bn(c1(img)).elu().sum() + opt.zero_grad() + img_bn.backward() + check_schedule(opt.schedule_step(), cnt) + def test_fold_conv_batchnorm_optim_adam(self): self.test_fold_conv_batchnorm_optim(True) def test_fold_batchnorm_backward(self): with Tensor.train(): diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index b7eda1ccee..d3fe1ab683 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -4,7 +4,7 @@ from collections import deque from tinygrad.uop.ops import UOp, Ops, buffers, UOpMetaClass, track_rewrites, PatternMatcher, UPat, graph_rewrite, gate_kernel_sink from tinygrad.uop.spec import type_verify, tensor_spec from tinygrad.device import Buffer, MultiBuffer -from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, flatten, pluralize, SCACHE +from tinygrad.helpers import DEBUG, cpu_profile, TracingKey, SPEC, pluralize, SCACHE from tinygrad.engine.realize import ExecItem from tinygrad.engine.allocations import allocate_global_buffers @@ -127,11 +127,7 @@ def complete_create_schedule_with_vars(big_sink:UOp) -> tuple[dict[UOp, UOp], li if not SCACHE or (sc_ret:=schedule_cache.get(sched_cache_key, None)) is None: # verify Tensors match the spec (on big_sink, we only need to do this if cache misses) if SPEC: type_verify(big_sink, tensor_spec) - - if any(isinstance(x._device, tuple) for x in big_sink_cache.toposort()): - big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm") - big_sink_cache = UOp.sink(*flatten([x.src if x.op is Ops.MULTI else [x] for x in big_sink_cache.src])) - + big_sink_cache = graph_rewrite(big_sink_cache, multi_pm, name="multi_pm") big_sink = get_rangeify(big_sink_cache) pre_schedule, buf_uops_sink = create_schedule(big_sink) if SCACHE: schedule_cache[sched_cache_key] = (pre_schedule, buf_uops_sink) diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 7d2070df99..2456ce7874 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -2,9 +2,9 @@ from dataclasses import dataclass, field, replace import itertools from tinygrad.dtype import dtypes, PtrDType, ImageDType, AddrSpace from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _substitute, KernelInfo, pm_gate_kernel_sink -from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate, _remove_all_tags +from tinygrad.uop.ops import graph_rewrite, identity_element, sint, AxisType, BottomUpGate from tinygrad.uop.symbolic import symbolic -from tinygrad.helpers import argsort, prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS +from tinygrad.helpers import prod, all_same, getenv, dedup, all_int, DEBUG, SPLIT_REDUCEOP, DEBUG_RANGEIFY, VIZ, MAX_KERNEL_BUFFERS from tinygrad.helpers import PCONTIG, partition, get_single_element from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify from tinygrad.codegen.opt import Opt @@ -169,7 +169,7 @@ earliest_rewrites = mop_cleanup+PatternMatcher([ # ***************** # 3.5 cleanups -ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC} +ALWAYS_RUN_OPS = {Ops.CONTIGUOUS, Ops.COPY, Ops.ASSIGN, Ops.ENCDEC, Ops.NOOP} # you don't know in the first pass if axes are going to die, this happens if there's an EXPAND to the left def cleanup_dead_axes(b:UOp): @@ -527,42 +527,9 @@ split_kernels = PatternMatcher([ (UPat((Ops.STORE, Ops.END), name="x"), split_store), ]) -def tag_uop(ctx:tuple[list[UOp], set[UOp]], x:UOp): - if x.tag is not None or x in ctx[1]: return None - if x.tag is None and x.op is Ops.CALL: - # don't tag anything in a CALL - for u in x.src[0].toposort(): ctx[1].add(u) - if x.dtype.scalar() == dtypes.index: return None - ctx[0].append(x) - return x.replace(tag=(len(ctx[0])-1,)) -add_tags = pm_gate_kernel_sink+PatternMatcher([ - # don't tag BUFFERs, they are global - (UPat(GroupOp.All-{Ops.PARAM, Ops.CONST, Ops.DEVICE, Ops.UNIQUE, Ops.LUNIQUE, Ops.DEFINE_VAR, Ops.BIND, Ops.END, - Ops.MSTACK, Ops.MSELECT, Ops.RANGE}.union(GroupOp.Movement), name="x"), tag_uop), - (UPat({Ops.MSTACK, Ops.MSELECT}, name="x"), lambda ctx,x: None if all(s.op is Ops.PARAM for s in x.src) else tag_uop(ctx, x)), -]) - -# support for using a contiguous permuted view instead of the parent view if one exists - -def found_contiguous(ctx:dict[UOp, UOp], contig:UOp, src:UOp): - x = src - while x is not src.base: - if x.op is Ops.PERMUTE: contig = contig.permute(argsort(x.marg)) - elif x.op is Ops.RESHAPE: contig = contig.reshape(x.src[0].shape) - else: return None - x = x.src[0] - ctx[src.base] = contig -replace_contiguous = PatternMatcher([ - (UPat(Ops.CONTIGUOUS, src=(UPat(GroupOp.Movement, name="src"),), name="contig"), found_contiguous), - (UPat(GroupOp.ALU, name="alu"), lambda ctx,alu: alu.replace(src=new_src) if (new_src:=tuple(ctx.get(s, s) for s in alu.src)) != alu.src else None), -]) - def get_rangeify(sink:UOp) -> UOp: if VIZ: graph_rewrite(sink, PatternMatcher([]), name="View Input Graph") - uop_list: list[UOp] = [] - tsink = graph_rewrite(sink, add_tags, ctx=(uop_list, set()), bottom_up=True, name="number the uops") - - tsink = graph_rewrite(tsink, pm_syntactic_sugar+pm_mops+earliest_rewrites+replace_contiguous, ctx={}, bottom_up=True, name="earliest rewrites") + tsink = graph_rewrite(sink, pm_syntactic_sugar+pm_mops+earliest_rewrites, bottom_up=True, name="earliest rewrites") # convert movement ops to ranges tsink, rctx = run_rangeify(tsink, bool(DEBUG_RANGEIFY)) @@ -570,12 +537,6 @@ def get_rangeify(sink:UOp) -> UOp: tsink = graph_rewrite(tsink, symbolic+pm_reduce_simplify+pm_const_buffer_folding+pm_remove_bufferize, name="symbolic+reduce_collapse+debuf") tsink = graph_rewrite(tsink, pm_limit_bufs, ctx=rctx, name="limit buffers") - # rebuild the sink with all the BUFFERIZEs with tags, this is what's ending up in the tensor graph - # MSTACK stacks multiple BUFFERIZEs in one tagged tensor - # if it's not tagged by here, it's out - tsink = UOp.sink(*[x for x in tsink.backward_slice if x.base.op in {Ops.BUFFERIZE, Ops.MSTACK, Ops.CONST, Ops.PARAM, Ops.AFTER} and \ - x.tag is not None and len(x.tag)]) - if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Tagged Rangeify") # bufferize -> store @@ -597,8 +558,5 @@ def get_rangeify(sink:UOp) -> UOp: raise RuntimeError(f"cycle detected in graph, kernel for {u.buf_uop} must either depend on AFTER or BUFFER") assign_rep[a] = kernel_assign[s] = a.replace(src=a.src+(u,)) if assign_rep: tsink = graph_rewrite(tsink, _substitute, ctx=assign_rep, bottom_up=True, name="fix_assign") - - # TODO: we can probably get this earlier - tsink = graph_rewrite(tsink, _remove_all_tags, name="remove all tags") if VIZ: graph_rewrite(tsink, PatternMatcher([]), name="View Kernel Graph") return tsink \ No newline at end of file