rangeify cost function infrastructure (#12091)

* one call to hc opt

* does that pass?

* add cost function to rangeify

* test

* more test

* gate thread

* bufferize has shape

* ish

* match old behavior

* no ci there
Author: George Hotz
Date: 2025-09-11 07:19:53 +08:00
Committed by: GitHub
Parent: 78610b681e
Commit: d4eba5800d
3 changed files with 35 additions and 14 deletions

View File

@@ -1,5 +1,5 @@
 import unittest
-from tinygrad import Tensor
+from tinygrad import Tensor, nn
 from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
 from tinygrad.uop.ops import UOp
@@ -93,6 +93,16 @@ class TestRangeify(unittest.TestCase):
     w2 = Tensor.empty(12, 8, 3, 3)
     x.conv2d(w1).conv2d(w2).realize()
+  def test_conv_maxpool_contig(self): self.test_conv_maxpool(True)
+  def test_conv_maxpool(self, contig=False):
+    GlobalCounters.reset()
+    x = Tensor.empty(32, 16, 64, 64)
+    l1 = nn.Conv2d(16, 16, 3)
+    for p in nn.state.get_parameters(l1): p.replace(Tensor.empty(p.shape))
+    x = l1(x)
+    if contig: x = x.contiguous()
+    x.max_pool2d().realize()
   def test_double_conv2d_half_contig(self):
     x = Tensor.empty(1, 4, 32, 32)
     w1 = Tensor.empty(8, 4, 3, 3)
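As a usage note, here is a minimal standalone sketch (not part of the diff) of what the new test exercises: the same conv2d -> max_pool2d chain with the optional contiguous() barrier, wrapped so kernel counts can be compared. It reuses the shapes from the test above and assumes GlobalCounters.kernel_count is the counter that GlobalCounters.reset() clears.

# sketch only: count kernels for the conv2d -> max_pool2d chain, with and
# without an explicit contiguous() barrier between the two ops
from tinygrad import Tensor, nn
from tinygrad.helpers import GlobalCounters

def kernels_for(contig: bool) -> int:
  GlobalCounters.reset()
  x = Tensor.empty(32, 16, 64, 64)
  l1 = nn.Conv2d(16, 16, 3)
  # swap in uninitialized weights, as the test does, to avoid random init
  for p in nn.state.get_parameters(l1): p.replace(Tensor.empty(p.shape))
  out = l1(x)
  if contig: out = out.contiguous()
  out.max_pool2d().realize()
  return GlobalCounters.kernel_count  # assumed attribute, implied by reset() above

if __name__ == "__main__":
  print("fused:", kernels_for(False), "with contiguous barrier:", kernels_for(True))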

View File

@@ -192,7 +192,7 @@ class ClangRenderer(CStyleLanguage):
   float4_style = ('{', '}')
   gep_arr_threshold = 0
   has_local = False
-  has_threads = True
+  has_threads = bool(getenv("THREADED", 1))
   global_max = (CPU_COUNT.value, 0, 0)
   infinity = "__builtin_inff()"
   nan = '__builtin_nanf("")'
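The has_threads change above turns CPU threading in the Clang renderer into an opt-out: tinygrad's getenv helper parses the THREADED environment variable with the type of its default, so an unset variable keeps the previous has_threads = True and THREADED=0 switches it off. A rough sketch of that gate in isolation, assuming only the getenv helper:

# sketch of the THREADED gate: getenv("THREADED", 1) returns 1 when the env var
# is unset, so threading stays enabled by default; THREADED=0 disables it
from tinygrad.helpers import getenv

has_threads = bool(getenv("THREADED", 1))
print("clang backend threads enabled:", has_threads)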

View File

@@ -16,15 +16,15 @@ double_reshape = PatternMatcher([
 ])
 earliest_rewrites = double_reshape+PatternMatcher([
-  # non shape changing RESHAPE is NOOP
-  (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
   # UOp with size 0 is zero
   (UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
-  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
-  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
   # reduce of size 0 is the identity element
   (UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
    lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
+  # non shape changing RESHAPE is NOOP
+  (UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
+  # DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
+  (UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
   # RESHAPE after COPY
   (UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
   # TODO: this should be BUFFER_VIEW
@@ -307,12 +307,23 @@ def cleanup_dead_axes(b:UOp):
 # if a buffer is being stored just for permutes or something, remove it
 # we want to reexpress the indexes of idx2 in terms of the implied b1
-def remove_bufferize(b2:UOp, idx2:UOp):
-  # HACK
-  if len(b2.src) != len(idx2.src): return None
-  assert len(b2.src) == len(idx2.src)
-  assert all(x.op is Ops.RANGE for x in b2.src[1:])
-  return b2.src[0].substitute(dict(zip(b2.src[1:], idx2.src[1:])))
+def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
+  # see if we can't do it, should this ever hit?
+  assert len(buf.src) == len(idx.src), "index on wrong bufferize"
+  assert all(x.op is Ops.RANGE for x in buf.src[1:])
+  # here is where we compute the cost
+  # for now just no REDUCE, COPY, or ASSIGN
+  # TODO: exclude fusion of user contiguous
+  #ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX})
+  #if any(x.op in {Ops.REDUCE, Ops.COPY, Ops.ASSIGN} for x in ran): return None
+  # simple, matching old behavior
+  if src.op is not Ops.INDEX: return None
+  # this is the ranges replaced
+  return src.substitute(dict(zip(buf.src[1:], idx.src[1:])))
 
 pm_cleanups = double_reshape+pm_mops+PatternMatcher([
   #(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
@@ -320,8 +331,8 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
   # NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
   #(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
   # lambda idx,b2: idx.src[0] if idx.src[1:] == b2.src[1:] else None),
-  # remove reindexing
-  (UPat(Ops.INDEX).f(Ops.BUFFERIZE, allow_any_len=True, name="b2").f(Ops.INDEX, allow_any_len=True, name="idx2"), remove_bufferize),
+  # remove reindexing with cost function
+  (UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
   # no buffers for const
   #(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape)),
 ])
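The new remove_bufferize is where the cost function is meant to live: for now it only inlines a BUFFERIZE whose source is itself an INDEX, which matches the old behavior, and the commented-out lines hint at the fuller gate. A hedged sketch of that gate as a standalone predicate, using only names that appear in the diff and assuming UOp.toposort accepts the gate callable exactly as the comment uses it:

# illustrative only: the cost check described by the commented-out lines.
# Return True only when the producer graph of `src` (walked up to INDEX
# boundaries) contains no REDUCE, COPY, or ASSIGN, i.e. when inlining the
# BUFFERIZE into its consumer should be cheap.
from tinygrad.uop.ops import UOp, Ops

def bufferize_is_cheap(src: UOp) -> bool:
  ran = src.toposort(gate=lambda x: x.op not in {Ops.INDEX})
  return not any(x.op in {Ops.REDUCE, Ops.COPY, Ops.ASSIGN} for x in ran)

In the pattern itself, returning None rather than False is what tells the PatternMatcher to skip the rewrite and leave the BUFFERIZE in place, as the other lambdas in the diff do with their trailing "else None".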