mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
correctly insert UOps.END* in fuzz result (#4653)
This commit is contained in:
8
test/external/fuzz_uops.py
vendored
8
test/external/fuzz_uops.py
vendored
@@ -1,6 +1,6 @@
|
||||
import numpy as np
|
||||
from dataclasses import replace
|
||||
from typing import DefaultDict, Dict, List
|
||||
from typing import DefaultDict, Dict, List, Set
|
||||
from test.external.fuzz_schedule import find_all_toposorts
|
||||
from tinygrad.codegen.uops import UOp, UOpGraph, UOps
|
||||
from tinygrad.device import Buffer, Device
|
||||
@@ -8,7 +8,7 @@ from tinygrad.engine.realize import CompiledRunner
|
||||
from tinygrad.helpers import DEBUG, colored
|
||||
from tinygrad.shape.symbolic import Variable
|
||||
|
||||
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]):
|
||||
def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int], loops_children:Dict[UOp, Set[UOp]]):
|
||||
paths: List[List[UOp]] = []
|
||||
# TODO: express DEFINE_ACC and loop children conditions in the graph, builtin.
|
||||
for p in find_all_toposorts(graph, in_degree):
|
||||
@@ -16,6 +16,8 @@ def fuzz_uops(graph:DefaultDict[UOp, List[UOp]], in_degree:DefaultDict[UOp, int]
|
||||
paths.append(path:=list(p[:-1]))
|
||||
for u in path:
|
||||
if u.uop is UOps.IF: path.append(UOp(UOps.ENDIF, None, (u,)))
|
||||
if u.uop is UOps.RANGE:
|
||||
path.insert(max(path.index(x) for x in loops_children[u])+1, UOp(UOps.ENDRANGE, None, (u,)))
|
||||
return paths
|
||||
|
||||
class UOpsFuzzerRunner(CompiledRunner):
|
||||
@@ -34,9 +36,9 @@ class UOpsFuzzerRunner(CompiledRunner):
|
||||
uops._uops = list(path)
|
||||
if DEBUG >= 6: uops.print()
|
||||
self.p = replace(self.p, name=(name:=f"{init_name}fuzz{i}"), src=Device[self.p.dname].renderer.render(name, uops), uops=uops)
|
||||
if DEBUG >= 4: print(self.p.src)
|
||||
self.lib = Device[self.p.dname].compiler.compile_cached(self.p.src)
|
||||
self.clprg = Device[self.p.dname].runtime(name, self.lib)
|
||||
if DEBUG >= 4: print(self.p.src)
|
||||
for x in (rawbufs:=[init_globals[i[0]] for i in self.p.globals]): x.copyin(init_rawbufs[x])
|
||||
# verify
|
||||
super().__call__(rawbufs, var_vals, wait)
|
||||
|
||||
Reference in New Issue
Block a user