launch a cmp kernel

This commit is contained in:
qazal
2024-04-16 10:13:04 +03:00
parent efe5428ae8
commit 791c608992

View File

@@ -2,10 +2,12 @@ import numpy as np
from collections import defaultdict
from typing import DefaultDict, Dict, List, Set, TypeVar
from tinygrad.buffer import Buffer
from tinygrad.device import Device
from tinygrad.dtype import dtypes
from tinygrad.engine.realize import CustomOp, lower_schedule, capturing
from tinygrad.helpers import DEBUG, colored, getenv
from tinygrad.lazy import LazyBuffer
from tinygrad.engine.schedule import _graph_schedule
from tinygrad.engine.schedule import _graph_schedule, create_schedule
from tinygrad.ops import LoadOps, ScheduleItem
from tinygrad.tensor import Tensor
@@ -51,12 +53,20 @@ def fuzz_schedule(outs: List[LazyBuffer]):
# assert all LazyBuffers realized correctly
for lb, bufs in outputs.items():
ground_truth = np.frombuffer(bufs[0].as_buffer(), bufs[0].dtype.np)
a = Tensor.empty((bufs[0].size,), dtype=bufs[0].dtype)
b = Tensor.empty((bufs[0].size,), dtype=bufs[0].dtype)
ast = assert_allclose_ast(a, b)
for buf in bufs[1:]:
try: np.testing.assert_allclose(np.frombuffer(buf.as_buffer(), buf.dtype.np), ground_truth, atol=1e-2, rtol=1e-2)
except AssertionError as e:
print(f"COMPARE FAILED FOR {lb}")
raise e
prg = Device[Device.DEFAULT].get_runner(*ast)
ret_buf = Buffer(Device.DEFAULT, 1, dtypes.bool).allocate()
prg.exec([ret_buf, bufs[0], buf])
del buf
assert np.frombuffer(ret_buf.as_buffer(), ret_buf.dtype.np), f"FAILED FOR {lb}"
def assert_allclose_ast(a:Tensor, b:Tensor, atol=1e-2, rtol=1e-2):
diff = (a - b).abs()
tol = atol + rtol * b.abs()
return create_schedule([((diff > tol).sum() == 0).lazydata])[-1].ast
T = TypeVar("T")
def find_all_toposorts(graph:DefaultDict[T, List[T]], in_degree:DefaultDict[T, int]) -> List[List[T]]: