mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
rangeify cost function infrastructure (#12091)
* one call to hc opt
* does that pass?
* add cost function to rangeify
* test
* more test
* gate thread
* bufferize has shape
* ish
* match old behavior
* no ci there
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad import Tensor, nn
|
||||
from tinygrad.helpers import RANGEIFY, Context, GlobalCounters
|
||||
from tinygrad.uop.ops import UOp
|
||||
|
||||
@@ -93,6 +93,16 @@ class TestRangeify(unittest.TestCase):
|
||||
w2 = Tensor.empty(12, 8, 3, 3)
|
||||
x.conv2d(w1).conv2d(w2).realize()
|
||||
|
||||
def test_conv_maxpool_contig(self):
  # same scenario as test_conv_maxpool, with the conv output made contiguous before pooling
  self.test_conv_maxpool(contig=True)
|
||||
def test_conv_maxpool(self, contig=False):
  """Realize a 3x3 Conv2d (16 -> 16 channels) followed by max_pool2d.

  With contig=True the conv output goes through .contiguous() before the pool.
  The test only checks that the graph realizes; it asserts nothing about the result.
  """
  GlobalCounters.reset()
  inp = Tensor.empty(32, 16, 64, 64)
  conv = nn.Conv2d(16, 16, 3)
  # swap every conv parameter for an empty tensor of the same shape
  # (presumably so no weight initialization work gets realized -- confirm)
  for param in nn.state.get_parameters(conv):
    param.replace(Tensor.empty(param.shape))
  out = conv(inp)
  if contig:
    out = out.contiguous()
  out.max_pool2d().realize()
|
||||
|
||||
def test_double_conv2d_half_contig(self):
|
||||
x = Tensor.empty(1, 4, 32, 32)
|
||||
w1 = Tensor.empty(8, 4, 3, 3)
|
||||
|
||||
@@ -192,7 +192,7 @@ class ClangRenderer(CStyleLanguage):
|
||||
float4_style = ('{', '}')
|
||||
gep_arr_threshold = 0
|
||||
has_local = False
|
||||
has_threads = True
|
||||
has_threads = bool(getenv("THREADED", 1))
|
||||
global_max = (CPU_COUNT.value, 0, 0)
|
||||
infinity = "__builtin_inff()"
|
||||
nan = '__builtin_nanf("")'
|
||||
|
||||
@@ -16,15 +16,15 @@ double_reshape = PatternMatcher([
|
||||
])
|
||||
|
||||
earliest_rewrites = double_reshape+PatternMatcher([
|
||||
# non shape changing RESHAPE is NOOP
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
|
||||
# UOp with size 0 is zero
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="root"), lambda root: root.const_like(0) if root.base.st is not None and root.size == 0 else None),
|
||||
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
# non shape changing RESHAPE is NOOP
|
||||
(UPat(Ops.RESHAPE, name="x"), lambda x: x.src[0] if x.src[0].shape == x.arg else None),
|
||||
# DETACH and CONTIGUOUS_BACKWARD are NOOPs here, so is FUSE
|
||||
(UPat((Ops.DETACH, Ops.CONTIGUOUS_BACKWARD, Ops.FUSE), name="x"), lambda x: x.src[0]),
|
||||
# RESHAPE after COPY
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.RESHAPE, name="r"),UPat(name="d")), name="c"), lambda c,r,d: c.replace(src=(r.src[0],d)).reshape(r.arg)),
|
||||
# TODO: this should be BUFFER_VIEW
|
||||
@@ -307,12 +307,23 @@ def cleanup_dead_axes(b:UOp):
|
||||
|
||||
# if a buffer is being stored just for permutes or something, remove it
|
||||
# we want to reexpress the indexes of idx2 in terms of the implied b1
|
||||
def remove_bufferize(b2:UOp, idx2:UOp):
|
||||
# HACK
|
||||
if len(b2.src) != len(idx2.src): return None
|
||||
assert len(b2.src) == len(idx2.src)
|
||||
assert all(x.op is Ops.RANGE for x in b2.src[1:])
|
||||
return b2.src[0].substitute(dict(zip(b2.src[1:], idx2.src[1:])))
|
||||
def remove_bufferize(src:UOp, buf:UOp, idx:UOp):
  """Try to drop a BUFFERIZE that is immediately re-indexed.

  Matched in pm_cleanups as INDEX(BUFFERIZE(src, *ranges), *indexes): `src` is the
  bufferized value, `buf` the BUFFERIZE and `idx` the INDEX reading from it.

  Returns `src` with each range of `buf` substituted by the corresponding index
  of `idx`, or None when the buffer is kept.
  """
  buf_ranges, new_indexes = buf.src[1:], idx.src[1:]

  # sanity checks; unclear whether these can ever trigger
  assert len(buf.src) == len(idx.src), "index on wrong bufferize"
  assert all(r.op is Ops.RANGE for r in buf_ranges)

  # this is where the cost function decides whether removing the buffer is worth it
  # TODO: planned rule: keep the buffer if a REDUCE, COPY or ASSIGN is reachable from src
  #       (toposort gated to stop at INDEX), and exclude fusion of user contiguous
  # for now, match the old behavior: only a buffer sitting directly on an INDEX is removed
  if src.op is not Ops.INDEX:
    return None

  # re-express src in terms of the outer indexes by replacing the buffer's ranges
  range_to_index = dict(zip(buf_ranges, new_indexes))
  return src.substitute(range_to_index)
|
||||
|
||||
|
||||
pm_cleanups = double_reshape+pm_mops+PatternMatcher([
|
||||
#(UPat(Ops.BUFFERIZE, name="b"), cleanup_dead_axes),
|
||||
@@ -320,8 +331,8 @@ pm_cleanups = double_reshape+pm_mops+PatternMatcher([
|
||||
# NOTE: this is mostly the same case as below, but if there's no INDEX this gets more
|
||||
#(UPat(Ops.INDEX, name="idx").f(Ops.BUFFERIZE, allow_any_len=True, name="b2"),
|
||||
# lambda idx,b2: idx.src[0] if idx.src[1:] == b2.src[1:] else None),
|
||||
# remove reindexing
|
||||
(UPat(Ops.INDEX).f(Ops.BUFFERIZE, allow_any_len=True, name="b2").f(Ops.INDEX, allow_any_len=True, name="idx2"), remove_bufferize),
|
||||
# remove reindexing with cost function
|
||||
(UPat.var("src").f(Ops.BUFFERIZE, allow_any_len=True, name="buf").f(Ops.INDEX, allow_any_len=True, name="idx"), remove_bufferize),
|
||||
# no buffers for const
|
||||
#(UPat(Ops.CONST, name='c').f(Ops.BUFFERIZE, allow_any_len=True, name="b"), lambda c,b: c.reshape((1,)*len(b.shape)).expand(b.shape)),
|
||||
])
|
||||
|
||||
Reference in New Issue
Block a user