vars_from_ast -> LazyOp.vars (#2965)

This commit is contained in:
chenyu
2024-01-01 18:12:38 -05:00
committed by GitHub
parent 980f421442
commit 58d3d5030b
8 changed files with 17 additions and 20 deletions

View File

@@ -8,7 +8,6 @@ from tinygrad.features.search import get_linearizer_actions, bufs_from_lin
from tinygrad.graph import print_tree
from tinygrad.helpers import getenv
from tinygrad.device import Device, Compiled, Interpreted
from tinygrad.ops import vars_from_ast
from tinygrad.codegen.linearizer import UOp
def tuplize_uops(uops:List[UOp]) -> Tuple: return tuple([(x.uop, x.dtype, tuple(uops.index(x) for x in x.vin), x.arg) for x in uops])
@@ -17,7 +16,7 @@ device = Device[Device.DEFAULT]
def run_linearizer(lin: Linearizer, rawbufs=None, var_vals=None):
if rawbufs is None: rawbufs = bufs_from_lin(lin)
if var_vals is None: var_vals = {v: v.min for v in vars_from_ast(lin.ast)}
if var_vals is None: var_vals = {v: v.min for v in lin.ast.vars()}
# TODO: images needs required_optimization
try:
@@ -66,7 +65,7 @@ def fuzz_linearizer(lin: Linearizer):
print(lin.colored_shape())
# get a new output buffer
rawbufs[0] = type(rawbufs[0])(Device.DEFAULT, rawbufs[0].size, rawbufs[0].dtype)
var_vals = {v: random.randint(v.min, v.max) for v in vars_from_ast(lin.ast)}
var_vals = {v: random.randint(v.min, v.max) for v in lin.ast.vars()}
if (msg := run_linearizer(lin, rawbufs, var_vals)) != "PASS":
print(f"{lin.applied_opts=}")
return msg