mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
rangeify: late zero folding (#12464)
* rangeify: late zero folding
* early
* not kernels
* none
* multi
* linter
* mstack is sink comment
* more comment
This commit is contained in:
@@ -7,7 +7,7 @@ from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
|
||||
import numpy as np
|
||||
from hypothesis import given, strategies as strat, settings
|
||||
from test.helpers import REAL_DEV, not_support_multi_device, expect_rangeify_fails
|
||||
from test.helpers import REAL_DEV, not_support_multi_device
|
||||
|
||||
settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
|
||||
settings.load_profile("my_profile")
|
||||
@@ -201,7 +201,6 @@ class TestMultiTensor(unittest.TestCase):
|
||||
fn = f(n)
|
||||
np.testing.assert_allclose(fX.numpy(), fn, rtol=1e-6, atol=1e-6)
|
||||
|
||||
@expect_rangeify_fails # TODO: fix
|
||||
def test_allreduce_shard_ring_sum(self):
|
||||
for axis in (0, 1, None):
|
||||
for use_ring in (0, 2):
|
||||
|
||||
@@ -814,6 +814,14 @@ class TestSchedule(unittest.TestCase):
|
||||
check_schedule(a, 0)
|
||||
self.assertEqual(a.tolist(), [])
|
||||
|
||||
def test_zero_size_children(self):
  # keepdim sum over axis 1 of a realized (1,2) tensor of ones
  reduced = Tensor.ones(1,2).contiguous().realize().sum(axis=(1,), keepdim=True)
  # first child of the reduce: the single element, doubled
  doubled = reduced.reshape(1)*2
  # second child of the reduce: shrink to the empty range (1,1), then double
  empty_doubled = reduced.reshape(1).shrink(((1,1),))*2
  # pad the empty child back to one element so it can be added to the first child
  total = doubled+empty_doubled.pad(((1, 0),))
  # check_schedule is called with 1 -- presumably the expected kernel count; confirm against its definition
  sched = check_schedule(total, 1)
  run_schedule(sched)
  self.assertEqual(total.item(), 4.)
|
||||
|
||||
def test_reduce_permute_nofuse(self):
|
||||
x = Tensor.empty(32, 32, 32)
|
||||
y = Tensor.empty(32, 32)
|
||||
|
||||
@@ -57,6 +57,9 @@ earliest_rewrites = PatternMatcher([
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
|
||||
# handle size 0
|
||||
(UPat(GroupOp.All-{Ops.SINK}, name="x"), lambda x: x.const_like(0).rtag(x.tag) if x.st is not None and x.size == 0 else None),
|
||||
|
||||
# remove contiguous on movement ops before a copy on disk
|
||||
(UPat(GroupOp.Movement-{Ops.SHRINK, Ops.RESHAPE}, name="x").f(Ops.CONTIGUOUS).f(Ops.COPY, allow_any_len=True, name="copy"),
|
||||
lambda x,copy: copy.replace(src=(x,)+copy.src[1:]) if isinstance(x.device, str) and x.device.startswith("DISK") else None),
|
||||
@@ -133,7 +136,8 @@ def extract_children(ctx:ChildrenContext, x:UOp):
|
||||
children_map = x.get_children_map()
|
||||
ctx.children = {}
|
||||
for k,v in children_map.items():
|
||||
non_sink_children = [u for u in v if u.op is not Ops.SINK]
|
||||
# NOTE: we treat mstack children like sink here
|
||||
non_sink_children = [u for u in v if u.op not in {Ops.SINK, Ops.MSTACK}]
|
||||
if len(non_sink_children) <= 1: continue
|
||||
# NOTE: this gate shouldn't be here
|
||||
if k.op_in_parents(Ops.REDUCE_AXIS) and k.op_in_parents(Ops.BUFFER, Ops.CONTIGUOUS):
|
||||
@@ -363,9 +367,6 @@ pm_rangeify = pm_mops+PatternMatcher([
|
||||
# handle arg on any op with weight. old endrange stuff
|
||||
(UPat(Ops.INDEX, src=(UPat(GroupOp.Elementwise.union({Ops.REDUCE_AXIS})),), allow_any_len=True, name="idx"), might_end_axis),
|
||||
|
||||
# handle size 0
|
||||
(UPat(Ops.INDEX, name="x"), lambda x: x.replace(src=(x.const_like(0),)+x.src[1:]) if x.st is not None and x.size == 0 else None),
|
||||
|
||||
# handle assign
|
||||
(UPat(Ops.INDEX, src=(UPat(Ops.ASSIGN, name="assign"),), allow_any_len=True, name="x"),
|
||||
lambda x,assign: assign.replace(src=tuple([s.index(*x.src[1:]) for s in assign.src])+(assign.src[0],)) \
|
||||
@@ -572,7 +573,7 @@ pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
|
||||
|
||||
# move RESHAPEs through MSELECT/MSTACK
|
||||
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
||||
lambda m: m.replace(src=tuple([x.src[0] for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)),
|
||||
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.src[0].arg).rtag(m.tag)),
|
||||
])
|
||||
|
||||
# *****************
|
||||
|
||||
Reference in New Issue
Block a user