VALIDATE_WITH_CPU [pr] (#9488)

* VALIDATE_WITH_CPU [pr]

* fix test
This commit is contained in:
George Hotz
2025-03-18 15:15:04 +08:00
committed by GitHub
parent 935cd01f56
commit 117b7a16ef
8 changed files with 28 additions and 13 deletions

View File

@@ -113,7 +113,7 @@ if __name__ == "__main__":
GlobalCounters.reset()
out = a.sum()
sis = out.schedule()
for i,ei in enumerate(lower_schedule(sis)):
for i,(_,ei) in enumerate(lower_schedule(sis)):
if i == 0:
# change the source code
prg_spec = ei.prg.p

View File

@@ -123,7 +123,7 @@ class TestImageDType(unittest.TestCase):
loss = x.image_dot(w1).image_dot(w2).float().max()
loss.backward()
sched = unwrap(w1.grad).schedule()
for s,ei in zip(sched, lower_schedule(sched[:])):
for s,(_,ei) in zip(sched, lower_schedule(sched[:])):
ei.run()
if s.bufs[0].dtype == dtypes.float:
lst = s.bufs[0].as_buffer().cast("f").tolist()

View File

@@ -71,7 +71,7 @@ class TestLinearizer(unittest.TestCase):
a, b = Tensor.randn(4).realize(), Tensor.randn(4).realize()
np_a, np_b = a.numpy(), b.numpy()
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
lowered = list(lower_schedule(c.schedule()))
lowered = [x[1] for x in lower_schedule(c.schedule())]
for ei in lowered: ei.run()
rawbufs = lowered[-1].bufs
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}

View File

@@ -81,7 +81,7 @@ class TestMultiTensor(unittest.TestCase):
out = (X + X)
sched = out.schedule()
names = []
for si, ei in zip(sched[:], lower_schedule(sched)):
for si, ei in lower_schedule(sched):
if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
ei.run()
self.assertEqual(len(set(names)), 3), "function was relinearized"

View File

@@ -99,7 +99,7 @@ class TestRandomness(unittest.TestCase):
@unittest.skipIf(getenv("PTX"), "fails with PTX")
def test_threefry_doesnt_use_long(self):
for ei in lower_schedule(Tensor.rand(20).schedule()):
for (_,ei) in lower_schedule(Tensor.rand(20).schedule()):
if isinstance(ei.prg, CompiledRunner):
for u in ei.prg.p.uops:
self.assertNotIn(u.dtype, {dtypes.long, dtypes.ulong}, msg=f"long found in {ei.prg.p.name}")

View File

@@ -31,7 +31,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealiz
assert isinstance(t, UOp), f"can't schedule {t}"
sched, _, __ = create_schedule_with_vars(t.sink())
# test lowering all the ScheduleItems to ExecItems
lowered = list(lower_schedule(sched.copy()))
lowered = [x[1] for x in lower_schedule(sched.copy())]
if filter_sink: sched = [s for s,ei in zip(sched, lowered) if isinstance(ei.prg, CompiledRunner)]
if len(sched) != allowed:
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
@@ -1614,7 +1614,7 @@ class TestIndexing(unittest.TestCase):
with Context(FUSE_ARANGE=getenv("FUSE_ARANGE", 1)):
lst = [xt] if isinstance(xt, Tensor) else xt
s = Tensor.schedule(*lst)
lowered = list(lower_schedule(s.copy()))
lowered = [x[1] for x in lower_schedule(s.copy())]
kernels = [ei for ei in list(lowered) if isinstance(ei.prg, CompiledRunner)]
if FUSE_ARANGE: self.assertEqual(len(kernels), cnt)
for ei in lowered: ei.run(do_update_stats=True)

View File

@@ -2,7 +2,7 @@ from typing import Optional, cast, Generator
import time, pprint
from dataclasses import dataclass, replace
from tinygrad.helpers import all_same, colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, TRACEMETA
from tinygrad.helpers import DEVECTORIZE, time_to_str
from tinygrad.helpers import DEVECTORIZE, time_to_str, VALIDATE_WITH_CPU
from tinygrad.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
from tinygrad.device import Device, Buffer
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
@@ -150,10 +150,10 @@ si_lowerer = PatternMatcher([
])
def lower_schedule_item(si:ScheduleItem) -> ExecItem: return ExecItem(*cast(tuple[Runner,list], si_lowerer.rewrite(si.ast, si.bufs)), si.metadata)
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, None]:
def lower_schedule(schedule:list[ScheduleItem]) -> Generator[tuple[ScheduleItem, ExecItem], None, None]:
while len(schedule):
si = schedule.pop(0)
try: yield lower_schedule_item(si)
try: yield (si, lower_schedule_item(si))
except Exception as e:
if DEBUG >= 2:
print(f"error lowering {si.ast.op}")
@@ -166,6 +166,21 @@ def lower_schedule(schedule:list[ScheduleItem]) -> Generator[ExecItem, None, Non
capturing: list = [] # put classes with an add method in here
def run_schedule(schedule:list[ScheduleItem], var_vals:Optional[dict[Variable, int]]=None, do_update_stats=True):
for ei in lower_schedule(schedule):
for si, ei in lower_schedule(schedule):
if len(capturing) and CAPTURING: capturing[0].add(ei)
ei.run(var_vals, do_update_stats=do_update_stats)
if VALIDATE_WITH_CPU and si.ast.op is Ops.SINK:
# copy in allocated buffers from the GPU
nb: tuple[Buffer, ...] = tuple(Buffer("CPU", b.size, b.dtype) for b in si.bufs)
for cpu_b, gpu_b in zip(nb, si.bufs):
if gpu_b.is_allocated(): cpu_b.ensure_allocated().copyin(gpu_b.as_buffer())
# run on GPU
ei.run(var_vals, do_update_stats=do_update_stats)
# validate the output buffers match (NOTE: this is assuming the output is buffer 0)
lower_schedule_item(ScheduleItem(si.ast, nb, si.metadata)).run(var_vals, do_update_stats=do_update_stats)
import numpy as np
np.testing.assert_allclose(nb[0].numpy(), si.bufs[0].numpy(), rtol=1e-3, atol=1e-3)
else:
ei.run(var_vals, do_update_stats=do_update_stats)

View File

@@ -113,7 +113,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE = ContextVar("QUANTIZE", 0)
QUANTIZE, VALIDATE_WITH_CPU = ContextVar("QUANTIZE", 0), ContextVar("VALIDATE_WITH_CPU", 0)
@dataclass(frozen=True)
class Metadata: