kernel count tests for pad [pr] (#10369)

* kernel count tests for pads

* handcoded rand one kernel

* comment

* prerealize device rng counter

* test_rand_handcoded generates /0

* remove track_rewrites
This commit is contained in:
qazal
2025-05-17 17:20:46 +03:00
committed by GitHub
parent 90c4bb10c0
commit e054b53a75

View File

@@ -15,7 +15,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite,
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, create_ast, merge_views
from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, create_ast, merge_views, create_kernels
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
@@ -23,7 +23,7 @@ def verify_ast(sink:UOp): return type_verify(list(sink.toposort()), shape_spec)
class KernelCountException(Exception): pass
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
if to_prerealize:
for pre in to_prerealize: pre.schedule()
with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize)
if isinstance(t, Tensor): sched = t.schedule()
elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
else:
@@ -32,15 +32,14 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
becomes_map = get_becomes_map(sink)
sched, _, __ = create_schedule_with_vars(sink.substitute(becomes_map))
# test lowering all the ScheduleItems to ExecItems
lowered = [x[1] for x in lower_schedule(sched.copy())]
if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
if len(sched) != allowed:
kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
if kernel_cnt != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
if DEBUG >= 3:
for i,s in enumerate(sched):
print("kernel", i+1)
print(s.ast)
raise KernelCountException(f"{len(sched)=} != {allowed}")
raise KernelCountException(f"{kernel_cnt} != {allowed}")
return sched
def _realize_weights(m):
@@ -87,6 +86,44 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(b, 1))
np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2)
def test_push_pads_elementwise(self):
x = Tensor.full((4,4), 2.).contiguous().realize()
y = Tensor.full((4,4), 4.).contiguous().realize()
z = (x.reciprocal()*y).pad((None, (0,1),)).sum()
run_schedule(check_schedule(z, 2))
self.assertEqual(z.item(), 32)
# TODO: same issue in precompute_freqs_cis
def test_push_pads_contiguous(self):
x = Tensor.full((4,1), 2.).contiguous()
y = Tensor.full((4,4), 4.).contiguous()
z = (x.reciprocal().expand(4,4)*y).pad((None, (0,1),)).sum()
run_schedule(check_schedule(z, 3, [x,y]))
self.assertEqual(z.item(), 32)
def test_rand(self):
x = Tensor.rand(32)
check_schedule(x, 3, [Tensor._device_rng_counters[x.device]])
def test_rand_recompute_arange(self):
x = Tensor.rand(32)
with Context(DONT_GROUP_REDUCES=1):
check_schedule(x, 2, [Tensor._device_rng_counters[x.device]])
@unittest.skip("TODO: do not divide by zero given x.idiv(VALID)")
def test_rand_handcoded(self):
Tensor.manual_seed(0)
x = Tensor.rand(32)
# pre-realize shared seed
Tensor._device_rng_counters[x.device].realize()
# run custom kernelized kernel
sched_sink = graph_rewrite(x.lazydata, create_kernels, ctx={u:None for u in x.lazydata.toposort() if u.op is Ops.COPY}, bottom_up=True)
y = Tensor(graph_rewrite(sched_sink, create_ast, bottom_up=True))
run_schedule(check_schedule(y, 1))
# compare against reference
run_schedule(check_schedule(x, 3))
np.testing.assert_allclose(y.numpy(), x.numpy())
def test_empty_is_not_realized(self):
a = Tensor.empty(10)
child = a+2