mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
launch a cmp kernel
This commit is contained in:
22
test/external/fuzz_schedule.py
vendored
22
test/external/fuzz_schedule.py
vendored
@@ -2,10 +2,12 @@ import numpy as np
|
||||
from collections import defaultdict
|
||||
from typing import DefaultDict, Dict, List, Set, TypeVar
|
||||
from tinygrad.buffer import Buffer
|
||||
from tinygrad.device import Device
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import CustomOp, lower_schedule, capturing
|
||||
from tinygrad.helpers import DEBUG, colored, getenv
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.engine.schedule import _graph_schedule
|
||||
from tinygrad.engine.schedule import _graph_schedule, create_schedule
|
||||
from tinygrad.ops import LoadOps, ScheduleItem
|
||||
from tinygrad.tensor import Tensor
|
||||
|
||||
@@ -51,12 +53,20 @@ def fuzz_schedule(outs: List[LazyBuffer]):
|
||||
|
||||
# assert all LazyBuffers realized correctly
|
||||
for lb, bufs in outputs.items():
|
||||
ground_truth = np.frombuffer(bufs[0].as_buffer(), bufs[0].dtype.np)
|
||||
a = Tensor.empty((bufs[0].size,), dtype=bufs[0].dtype)
|
||||
b = Tensor.empty((bufs[0].size,), dtype=bufs[0].dtype)
|
||||
ast = assert_allclose_ast(a, b)
|
||||
for buf in bufs[1:]:
|
||||
try: np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), ground_truth, atol=1e-2, rtol=1e-2)
|
||||
except AssertionError as e:
|
||||
print(f"COMPARE FAILED FOR {lb}")
|
||||
raise e
|
||||
prg = Device[Device.DEFAULT].get_runner(*ast)
|
||||
ret_buf = Buffer(Device.DEFAULT, 1, dtypes.bool).allocate()
|
||||
prg.exec([ret_buf, bufs[0], buf])
|
||||
del buf
|
||||
assert np.frombuffer(ret_buf.as_buffer(), ret_buf.dtype.np), f"FAILED FOR {lb}"
|
||||
|
||||
def assert_allclose_ast(a:Tensor, b:Tensor, atol=1e-2, rtol=1e-2):
|
||||
diff = (a - b).abs()
|
||||
tol = atol + rtol * b.abs()
|
||||
return create_schedule([((diff > tol).sum() == 0).lazydata])[-1].ast
|
||||
|
||||
T = TypeVar("T")
|
||||
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[List[T]]:
|
||||
|
||||
Reference in New Issue
Block a user