mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
reorder DEFINE_GLOBAL in fuzz_uops (#4651)
* globals base * test: opt out of DEFINE_GLOBAL * do it like ExecItem
This commit is contained in:
11
test/external/fuzz_uops.py
vendored
11
test/external/fuzz_uops.py
vendored
@@ -22,26 +22,27 @@ class UOpsFuzzerRunner(CompiledRunner):
|
||||
def __call__(self, rawbufs:List[Buffer], var_vals:Dict[Variable, int], wait=False):
|
||||
assert self.p.uops is not None and len(self.p.uops.fuzz_paths) >= 1
|
||||
init_rawbufs, init_name = {x:x.as_buffer() for x in rawbufs}, self.p.function_name
|
||||
init_globals = {i[0]:buf for i, buf in zip(self.p.globals, rawbufs)}
|
||||
if DEBUG >= 1: print(colored(f"fuzzing {len(self.p.uops.fuzz_paths)} UOps permutations for {init_name}", "yellow"))
|
||||
|
||||
super().__call__(rawbufs, var_vals, wait)
|
||||
ground_truth = [np.frombuffer(x.as_buffer(), x.dtype.np) for x in rawbufs]
|
||||
ground_truth = {x:np.frombuffer(x.as_buffer(), x.dtype.np) for x in rawbufs}
|
||||
|
||||
for i, path in enumerate(self.p.uops.fuzz_paths):
|
||||
# setup prg
|
||||
uops = UOpGraph()
|
||||
uops._uops = list(path)
|
||||
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops))
|
||||
if DEBUG >= 6: uops.print()
|
||||
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
|
||||
self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
|
||||
self.clprg = Device[self.p.dname].runtime(name, self.lib)
|
||||
for x in rawbufs: x.copyin(init_rawbufs[x])
|
||||
if DEBUG >= 4: print(self.p.src)
|
||||
if DEBUG >= 7: uops.print()
|
||||
for x in (rawbufs:=[init_globals[i[0]] for i in self.p.globals]): x.copyin(init_rawbufs[x])
|
||||
# verify
|
||||
super().__call__(rawbufs, var_vals, wait)
|
||||
for i, x in enumerate(rawbufs):
|
||||
try:
|
||||
np.testing.assert_allclose(np.frombuffer(x.as_buffer(), x.dtype.np), ground_truth[i])
|
||||
np.testing.assert_allclose(np.frombuffer(x.as_buffer(), x.dtype.np), ground_truth[x])
|
||||
if DEBUG >= 2: print(colored(name, "green"))
|
||||
except AssertionError as e:
|
||||
print(colored(name, "red"))
|
||||
|
||||
Reference in New Issue
Block a user