From 94a72d44d238eebe1a16e9a3c1adf316a638bdcb Mon Sep 17 00:00:00 2001 From: gswangg <152219575+greg-niemeyer@users.noreply.github.com> Date: Wed, 28 Aug 2024 12:26:50 -0700 Subject: [PATCH] update CI tests in extra with UOp AST (#6290) --- extra/optimization/helpers.py | 4 ++-- extra/to_movement_ops.py | 5 ++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index 427cf5ed2a..3de3a6c12d 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -11,7 +11,7 @@ inf, nan = float('inf'), float('nan') # kernel unpacker from tinygrad.codegen.kernel import Kernel -def ast_str_to_ast(ast_str:str) -> LazyOp: return LazyOp(MetaOps.KERNEL, val) if isinstance(val:=eval(ast_str), tuple) else val +def ast_str_to_ast(ast_str:str) -> UOp: return eval(ast_str) def ast_str_to_lin(ast_str:str, opts=None): return Kernel(ast_str_to_ast(ast_str), opts=opts) def kern_str_to_lin(kern_str:str, opts=None): (ast, applied_opts,) = eval(kern_str) @@ -28,7 +28,7 @@ from tinygrad.helpers import dedup def load_worlds(filter_reduce=True, filter_noimage=True, filter_novariable=True): fn = Path(__file__).parent.parent / "datasets/sops.gz" ast_strs = dedup(gzip.open(fn).read().decode('utf-8').strip().split("\n")) - if filter_reduce: ast_strs = [x for x in ast_strs if "ReduceOps" in x] + if filter_reduce: ast_strs = [x for x in ast_strs if "REDUCE_AXIS" in x] if filter_noimage: ast_strs = [x for x in ast_strs if "dtypes.image" not in x] if filter_novariable: ast_strs = [x for x in ast_strs if "Variable" not in x] random.seed(1337) diff --git a/extra/to_movement_ops.py b/extra/to_movement_ops.py index f68902a56b..3545799712 100644 --- a/extra/to_movement_ops.py +++ b/extra/to_movement_ops.py @@ -3,9 +3,8 @@ from enum import Enum, auto from collections import defaultdict from typing import List, Tuple, DefaultDict from extra.optimization.helpers import load_worlds, ast_str_to_ast -from extra.ops import LazyOp from tinygrad.helpers import prod, tqdm -from tinygrad.ops import UOps +from tinygrad.ops import UOp, UOps from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.symbolic import sym_infer, Node @@ -136,7 +135,7 @@ def test_rebuild(st: ShapeTracker): last_v2 = rebuilt_st.views[-1] assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}" -def test_rebuild_bufferop_st(ast:LazyOp): +def test_rebuild_bufferop_st(ast:UOp): if ast.op is UOps.SHAPETRACKER: test_rebuild(ast.arg) for src in ast.src: test_rebuild_bufferop_st(src)