diff --git a/test/test_rangeify.py b/test/test_rangeify.py index 44c56b70a1..ef4c98332f 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -1,5 +1,5 @@ import unittest -from tinygrad import Tensor +from tinygrad import Tensor, nn from tinygrad.helpers import RANGEIFY, Context, GlobalCounters from tinygrad.uop.ops import UOp @@ -93,6 +93,16 @@ class TestRangeify(unittest.TestCase): w2 = Tensor.empty(12, 8, 3, 3) x.conv2d(w1).conv2d(w2).realize() + def test_conv_maxpool_contig(self): self.test_conv_maxpool(True) + def test_conv_maxpool(self, contig=False): + GlobalCounters.reset() + x = Tensor.empty(32, 16, 64, 64) + l1 = nn.Conv2d(16, 16, 3) + for p in nn.state.get_parameters(l1): p.replace(Tensor.empty(p.shape)) + x = l1(x) + if contig: x = x.contiguous() + x.max_pool2d().realize() + def test_double_conv2d_half_contig(self): x = Tensor.empty(1, 4, 32, 32) w1 = Tensor.empty(8, 4, 3, 3) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 20e3b8ec02..9a2772c508 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -192,7 +192,7 @@ class ClangRenderer(CStyleLanguage): float4_style = ('{', '}') gep_arr_threshold = 0 has_local = False - has_threads = True + has_threads = bool(getenv("THREADED", 1)) global_max = (CPU_COUNT.value, 0, 0) infinity = "__builtin_inff()" nan = '__builtin_nanf("")' diff --git a/tinygrad/schedule/rangeify.py b/tinygrad/schedule/rangeify.py index 7681b6a0cb..f8ea8af998 100644 --- a/tinygrad/schedule/rangeify.py +++ b/tinygrad/schedule/rangeify.py @@ -16,15 +16,15 @@ double_reshape = PatternMatcher([ ]) earliest_rewrites = double_reshape+PatternMatcher([ + # non shape changing RESHAPE is NOOP + (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None), # UOp with size 0 is zero (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None), - # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE - (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]), # reduce of size 0 is the identity element (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)), lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None), - # non shape changing RESHAPE is NOOP - (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None), + # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE + (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]), # RESHAPE after COPY (UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)), # TODO: this should be BUFFER_VIEW @@ -307,12 +307,23 @@ def cleanup_dead_axes(b:UOp): # if a buffer is being stored just for permutes or something, remove it # we want to reexpress the indexes of idx2 in terms of the implied b1 -def remove_bufferize(b2:UOp, idx2:UOp): - # HACK - if len(b2.src) != len(idx2.src): return None - assert len(b2.src) == len(idx2.src) - assert all(x.op is Ops.RANGE for x in b2.src[1:]) - return b2.src[0].substitute(dict(zip(b2.src[1:], idx2.src[1:]))) +def remove_bufferize(src:UOp, buf:UOp, idx:UOp): + # see if we can't do it, should this ever hit? + assert len(buf.src) == len(idx.src), "index on wrong bufferize" + assert all(x.op is Ops.RANGE for x in buf.src[1:]) + + # here is where we compute the cost + # for now just no REDUCE, COPY, or ASSIGN + # TODO: exclude fusion of user contiguous + #ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX}) + #if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.ASSIGN} for x in ran): return None + + # simple, matching old behavior + if src.op is not Ops.INDEX: return None + + # this is the ranges replaced + return src.substitute(dict(zip(buf.src[1:], idx.src[1:]))) + pm_cleanups = double_reshape+pm_mops+PatternMatcher([ #(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes), @@ -320,8 +331,8 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([ # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more #(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"), # lambda idx,b2: idx.src[0] if idx.src[1:] == b2.src[1:] else None), - # remove reindexing - (UPat(Ops.INDEX).f(Ops.BUFFERIZE, allow_any_len=True, name="b2").f(Ops.INDEX, allow_any_len=True, name="idx2"), remove_bufferize), + # remove reindexing with cost function + (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize), # no buffers for const #(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape)), ])