mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
kernel count tests for pad [pr] (#10369)
* kernel count tests for pads
* handcoded rand one kernel
* comment
* prerealize device rng counter
* test_rand_handcoded generates /0
* remove track_rewrites
This commit is contained in:
@@ -15,7 +15,7 @@ from tinygrad.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite,
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.spec import type_verify, shape_spec
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, create_ast, merge_views
|
||||
from tinygrad.engine.grouper import view_left, view_right, sym, get_becomes_map, Kernel, create_ast, merge_views, create_kernels
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
|
||||
@@ -23,7 +23,7 @@ def verify_ast(sink:UOp): return type_verify(list(sink.toposort()), shape_spec)
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize: pre.schedule()
|
||||
with Context(DEBUG=0, TRACK_MATCH_STATS=0): Tensor.realize(*to_prerealize)
|
||||
if isinstance(t, Tensor): sched = t.schedule()
|
||||
elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
|
||||
else:
|
||||
@@ -32,15 +32,14 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
|
||||
becomes_map = get_becomes_map(sink)
|
||||
sched, _, __ = create_schedule_with_vars(sink.substitute(becomes_map))
|
||||
# test lowering all the ScheduleItems to ExecItems
|
||||
lowered = [x[1] for x in lower_schedule(sched.copy())]
|
||||
if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
|
||||
if len(sched) != allowed:
|
||||
kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
|
||||
if kernel_cnt != allowed:
|
||||
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
if DEBUG >= 3:
|
||||
for i,s in enumerate(sched):
|
||||
print("kernel", i+1)
|
||||
print(s.ast)
|
||||
raise KernelCountException(f"{len(sched)=} != {allowed}")
|
||||
raise KernelCountException(f"{kernel_cnt} != {allowed}")
|
||||
return sched
|
||||
|
||||
def _realize_weights(m):
|
||||
@@ -87,6 +86,44 @@ class TestSchedule(unittest.TestCase):
|
||||
run_schedule(check_schedule(b, 1))
|
||||
np.testing.assert_allclose(b.numpy(), np.broadcast_to(a.numpy().astype(np.float16), (2, 4, 4))+2)
|
||||
|
||||
def test_push_pads_elementwise(self):
  """Pad applied after an elementwise op on two already-realized inputs.

  check_schedule must report exactly 2 kernels, and the padded sum must equal 32.
  """
  lhs = Tensor.full((4,4), 2.).contiguous().realize()
  rhs = Tensor.full((4,4), 4.).contiguous().realize()
  total = (lhs.reciprocal()*rhs).pad((None, (0,1),)).sum()
  # check_schedule raises KernelCountException when the kernel count differs from 2
  sched = check_schedule(total, 2)
  run_schedule(sched)
  self.assertEqual(total.item(), 32)
|
||||
|
||||
# TODO: same issue in precompute_freqs_cis
|
||||
def test_push_pads_contiguous(self):
  """Pad applied after an expanded elementwise op on contiguous inputs.

  Both inputs are passed to check_schedule as `to_prerealize`; the remaining
  schedule must contain exactly 3 kernels and the padded sum must equal 32.
  """
  col = Tensor.full((4,1), 2.).contiguous()
  grid = Tensor.full((4,4), 4.).contiguous()
  total = (col.reciprocal().expand(4,4)*grid).pad((None, (0,1),)).sum()
  # the inputs are not realized here; check_schedule handles them via to_prerealize
  sched = check_schedule(total, 3, [col, grid])
  run_schedule(sched)
  self.assertEqual(total.item(), 32)
|
||||
|
||||
def test_rand(self):
  """Tensor.rand(32) must schedule as exactly 3 kernels.

  The per-device RNG counter is passed as `to_prerealize` to check_schedule.
  """
  sample = Tensor.rand(32)
  rng_counter = Tensor._device_rng_counters[sample.device]
  check_schedule(sample, 3, [rng_counter])
|
||||
|
||||
def test_rand_recompute_arange(self):
  """Tensor.rand(32) must schedule as exactly 2 kernels under DONT_GROUP_REDUCES=1.

  The per-device RNG counter is passed as `to_prerealize` to check_schedule.
  """
  sample = Tensor.rand(32)
  with Context(DONT_GROUP_REDUCES=1):
    # the counter lookup happens inside the context, as in the single-line form
    rng_counter = Tensor._device_rng_counters[sample.device]
    check_schedule(sample, 2, [rng_counter])
|
||||
|
||||
@unittest.skip("TODO: do not divide by zero given x.idiv(VALID)")
def test_rand_handcoded(self):
  """Compare a hand-kernelized Tensor.rand(32) (1 kernel) with the regular schedule (3 kernels).

  Currently skipped; see the skip reason on the decorator.
  """
  Tensor.manual_seed(0)
  ref = Tensor.rand(32)
  # pre-realize shared seed
  Tensor._device_rng_counters[ref.device].realize()
  # run custom kernelized kernel: the rewrite context is keyed by every COPY uop in the graph
  copy_ctx = {u:None for u in ref.lazydata.toposort() if u.op is Ops.COPY}
  kernelized = graph_rewrite(ref.lazydata, create_kernels, ctx=copy_ctx, bottom_up=True)
  handcoded = Tensor(graph_rewrite(kernelized, create_ast, bottom_up=True))
  run_schedule(check_schedule(handcoded, 1))
  # compare against reference
  run_schedule(check_schedule(ref, 3))
  np.testing.assert_allclose(handcoded.numpy(), ref.numpy())
|
||||
|
||||
def test_empty_is_not_realized(self):
|
||||
a = Tensor.empty(10)
|
||||
child = a+2
|
||||
|
||||
Reference in New Issue
Block a user