From cf64f8bb40463d6a77275e052ff1baba670bde42 Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Tue, 10 Sep 2024 08:00:19 +0800 Subject: [PATCH] start with the UOps.VALID spec [run_process_replay] (#6435) * document UOps.VALID [run_process_replay] * now the assert --- tinygrad/ops.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index eae787b776..19e2ff8787 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -99,7 +99,7 @@ class UOps(HashEnum): CONTRACT = auto() SHAPETRACKER = auto() """ - Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.CONST`. + Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.VALID`. - **`dtype`**: `None` - **`src`**: `Tuple[]` @@ -169,14 +169,22 @@ class UOps(HashEnum): - **`dtype`**: The scalar DType of the value. - - **`src`**: - The scheduler creates a CONST with a single SHAPETRACKER UOp src: `Tuple[UOp]`. - - The Lowerer replaces the SHAPETRACKER with an empty src. - It uses the ShapeTracker valid to create a `WHERE` UOp mask with sources: `(The actual CONST UOp, CONST 0, 0.0 or False)` + - **`src`**: `Tuple[]` - **`arg`**: The value. """ + VALID = auto() + """ + This is the first argument in a masked CONST. + + - **`dtype`**: `dtypes.bool` + - **`src`**: + `Tuple[UOp]` + - UOps.SHAPETRACKER + - **`arg`**: `None` + + A masked CONST is defined as `valid.where(value, 0)`. + """ SPECIAL = auto() NOOP = auto() GEP = auto() @@ -503,6 +511,7 @@ def type_verify(uops): if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}" if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool + if uop is UOps.VALID: assert dtype == dtypes.bool and len(src) == 1 and src[0].op is UOps.SHAPETRACKER and arg is None if uop is UOps.STORE: assert dtype is None, f"{uop} dtype must be None, got {dtype}" if len(src) == 4: assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}"