mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-25 14:58:46 -05:00
start with the UOps.VALID spec [run_process_replay] (#6435)
* document UOps.VALID [run_process_replay] * now the assert
This commit is contained in:
@@ -99,7 +99,7 @@ class UOps(HashEnum):
|
||||
CONTRACT = auto()
|
||||
SHAPETRACKER = auto()
|
||||
"""
|
||||
Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.CONST`.
|
||||
Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.VALID`.
|
||||
|
||||
- **`dtype`**: `None`
|
||||
- **`src`**: `Tuple[]`
|
||||
@@ -169,14 +169,22 @@ class UOps(HashEnum):
|
||||
|
||||
- **`dtype`**: The scalar DType of the value.
|
||||
|
||||
- **`src`**:
|
||||
The scheduler creates a CONST with a single SHAPETRACKER UOp src: `Tuple[UOp]`.
|
||||
|
||||
The Lowerer replaces the SHAPETRACKER with an empty src.
|
||||
It uses the ShapeTracker valid to create a `WHERE` UOp mask with sources: `(The actual CONST UOp, CONST 0, 0.0 or False)`
|
||||
- **`src`**: `Tuple[]`
|
||||
|
||||
- **`arg`**: The value.
|
||||
"""
|
||||
VALID = auto()
|
||||
"""
|
||||
This is the first argument in a masked CONST.
|
||||
|
||||
- **`dtype`**: `dtypes.bool`
|
||||
- **`src`**:
|
||||
`Tuple[UOp]`
|
||||
- UOps.SHAPETRACKER
|
||||
- **`arg`**: `None`
|
||||
|
||||
A masked CONST is defined as `valid.where(value, 0)`.
|
||||
"""
|
||||
SPECIAL = auto()
|
||||
NOOP = auto()
|
||||
GEP = auto()
|
||||
@@ -503,6 +511,7 @@ def type_verify(uops):
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
|
||||
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
|
||||
if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool
|
||||
if uop is UOps.VALID: assert dtype == dtypes.bool and len(src) == 1 and src[0].op is UOps.SHAPETRACKER and arg is None
|
||||
if uop is UOps.STORE:
|
||||
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
||||
if len(src) == 4: assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}"
|
||||
|
||||
Reference in New Issue
Block a user