start with the UOps.VALID spec [run_process_replay] (#6435)

* document UOps.VALID [run_process_replay]

* now the assert
This commit is contained in:
qazal
2024-09-10 08:00:19 +08:00
committed by GitHub
parent 58a1b4f427
commit cf64f8bb40

View File

@@ -99,7 +99,7 @@ class UOps(HashEnum):
CONTRACT = auto()
SHAPETRACKER = auto()
"""
- Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.CONST`.
+ Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.VALID`.
- **`dtype`**: `None`
- **`src`**: `Tuple[]`
@@ -169,14 +169,22 @@ class UOps(HashEnum):
- **`dtype`**: The scalar DType of the value.
- - **`src`**:
- The scheduler creates a CONST with a single SHAPETRACKER UOp src: `Tuple[UOp]`.
- The Lowerer replaces the SHAPETRACKER with an empty src.
- It uses the ShapeTracker valid to create a `WHERE` UOp mask with sources: `(The actual CONST UOp, CONST 0, 0.0 or False)`
+ - **`src`**: `Tuple[]`
- **`arg`**: The value.
"""
VALID = auto()
"""
This is the first argument in a masked CONST.
- **`dtype`**: `dtypes.bool`
- **`src`**:
`Tuple[UOp]`
- UOps.SHAPETRACKER
- **`arg`**: `None`
A masked CONST is defined as `valid.where(value, 0)`.
"""
SPECIAL = auto()
NOOP = auto()
GEP = auto()
@@ -503,6 +511,7 @@ def type_verify(uops):
if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool
if uop is UOps.VALID: assert dtype == dtypes.bool and len(src) == 1 and src[0].op is UOps.SHAPETRACKER and arg is None
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}"