diff --git a/extra/reduce_speed.py b/extra/reduce_speed.py
index ca9ea34674..36bd0d3d5c 100644
--- a/extra/reduce_speed.py
+++ b/extra/reduce_speed.py
@@ -113,7 +113,7 @@ if __name__ == "__main__":
   GlobalCounters.reset()
   out = a.sum()
   sis = out.schedule()
-  for i,ei in enumerate(lower_schedule(sis)):
+  for i,(_,ei) in enumerate(lower_schedule(sis)):
     if i == 0:
       # change the source code
       prg_spec = ei.prg.p
diff --git a/test/test_image_dtype.py b/test/test_image_dtype.py
index f7e2154c22..256bc4b0ba 100644
--- a/test/test_image_dtype.py
+++ b/test/test_image_dtype.py
@@ -123,7 +123,7 @@ class TestImageDType(unittest.TestCase):
     loss = x.image_dot(w1).image_dot(w2).float().max()
     loss.backward()
     sched = unwrap(w1.grad).schedule()
-    for s,ei in zip(sched, lower_schedule(sched[:])):
+    for s,(_,ei) in zip(sched, lower_schedule(sched[:])):
       ei.run()
       if s.bufs[0].dtype == dtypes.float: lst = s.bufs[0].as_buffer().cast("f").tolist()
diff --git a/test/test_linearizer.py b/test/test_linearizer.py
index f9601eb131..af4fce506f 100644
--- a/test/test_linearizer.py
+++ b/test/test_linearizer.py
@@ -71,7 +71,7 @@ class TestLinearizer(unittest.TestCase):
     a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
     np_a, np_b = a.numpy(), b.numpy()
     c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
-    lowered = list(lower_schedule(c.schedule()))
+    lowered = [x[1] for x in lower_schedule(c.schedule())]
     for ei in lowered: ei.run()
     rawbufs = lowered[-1].bufs
     assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index 0289a55c3d..c6298b63ef 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -81,7 +81,7 @@ class TestMultiTensor(unittest.TestCase):
     out = (X + X)
     sched = out.schedule()
     names = []
-    for si, ei in zip(sched[:], lower_schedule(sched)):
+    for si, ei in lower_schedule(sched):
       if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
       ei.run()
     self.assertEqual(len(set(names)), 3), "function was relinearized"
diff --git a/test/test_randomness.py b/test/test_randomness.py
index ae7bfa1dd2..59d1048e8b 100644
--- a/test/test_randomness.py
+++ b/test/test_randomness.py
@@ -99,7 +99,7 @@ class TestRandomness(unittest.TestCase):
 
   @unittest.skipIf(getenv("PTX"), "fails with PTX")
   def test_threefry_doesnt_use_long(self):
-    for ei in lower_schedule(Tensor.rand(20).schedule()):
+    for (_,ei) in lower_schedule(Tensor.rand(20).schedule()):
       if isinstance(ei.prg, CompiledRunner):
         for u in ei.prg.p.uops:
           self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")
diff --git a/test/test_schedule.py b/test/test_schedule.py
index d4116362db..95b1bbfcd4 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -31,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
     assert isinstance(t, UOp), f"can't schedule {t}"
     sched, _, __ = create_schedule_with_vars(t.sink())
   # test lowering all the ScheduleItems to ExecItems
-  lowered = list(lower_schedule(sched.copy()))
+  lowered = [x[1] for x in lower_schedule(sched.copy())]
   if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
   if len(sched) != allowed: print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
@@ -1614,7 +1614,7 @@ class TestIndexing(unittest.TestCase):
     with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
       lst = [xt] if isinstance(xt, Tensor) else xt
       s = Tensor.schedule(*lst)
-      lowered = list(lower_schedule(s.copy()))
+      lowered = [x[1] for x in lower_schedule(s.copy())]
       kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
       if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
       for ei in lowered: ei.run(do_update_stats=True)
diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py
index 22c946a080..bce072ad4c 100644
--- a/tinygrad/engine/realize.py
+++ b/tinygrad/engine/realize.py
@@ -2,7 +2,7 @@ from typing import Optional, cast, Generator
 import time, pprint
 from dataclasses import dataclass, replace
 from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
-from tinygrad.helpers import DEVECTORIZE, time_to_str
+from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
 from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
 from tinygrad.device import Device, Buffer
 from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@@ -150,10 +150,10 @@ si_lowerer = PatternMatcher([
 ])
 def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
 
-def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
+def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
   while len(schedule):
     si = schedule.pop(0)
-    try: yield lower_schedule_item(si)
+    try: yield (si, lower_schedule_item(si))
     except Exception as e:
       if DEBUG >= 2:
         print(f"error lowering {si.ast.op}")
@@ -166,6 +166,21 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, Non
 capturing: list = [] # put classes with an add method in here
 
 def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
-  for ei in lower_schedule(schedule):
+  for si, ei in lower_schedule(schedule):
     if len(capturing) and CAPTURING: capturing[0].add(ei)
-    ei.run(var_vals, do_update_stats=do_update_stats)
+    if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
+      # copy in allocated buffers from the GPU
+      nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
+      for cpu_b, gpu_b in zip(nb, si.bufs):
+        if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
+
+      # run on GPU
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
+      # validate the output buffers match (NOTE: this is assuming the output is buffer 0)
+      lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
+      import numpy as np
+      np.testing.assert_allclose(nb[0].numpy(), si.bufs[0].numpy(), rtol=1e-3, atol=1e-3)
+    else:
+      ei.run(var_vals, do_update_stats=do_update_stats)
+
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 3c87eba7d6..924454886e 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -113,7 +113,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
-QUANTIZE = ContextVar("QUANTIZE", 0)
+QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)
 
 @dataclass(frozen=True)
 class Metadata:
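
Note on the API change (editor commentary, not part of the patch): lower_schedule now yields (ScheduleItem, ExecItem) pairs instead of bare ExecItems, which is why every caller above either unpacks the pair or indexes x[1]. A minimal sketch of the new calling convention, assuming a tinygrad checkout with this patch applied:

  from tinygrad import Tensor
  from tinygrad.engine.realize import lower_schedule

  out = (Tensor.randn(4) + 1).sum()
  for si, ei in lower_schedule(out.schedule()):
    # si is the ScheduleItem that was lowered, ei is the runnable ExecItem
    print(si.ast.op, type(ei.prg).__name__)
    ei.run()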
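
Usage note for the new flag (a sketch of intended use; the env-var spelling is an assumption based on ContextVar defaults in helpers.py being read via getenv): with VALIDATE_WITH_CPU set, run_schedule copies each SINK kernel's realized input buffers to CPU, runs the kernel on both the device and the CPU backend, and asserts that output buffer 0 agrees within rtol=1e-3/atol=1e-3, raising an AssertionError on mismatch. For example:

  VALIDATE_WITH_CPU=1 python3 examples/beautiful_mnist.py

or scoped inside Python:

  from tinygrad import Tensor
  from tinygrad.helpers import Context

  with Context(VALIDATE_WITH_CPU=1):
    # realize() runs the schedule; raises if device and CPU results diverge
    (Tensor.randn(16, 16) @ Tensor.randn(16, 16)).realize()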