document UOps.IF [run_process_replay] (#6374)

* document UOps.IF [run_process_replay]

* this will be a block of STOREs after merge_gates

* now i can enable the assert

* more docs

* raw code block

* cname

* cleanup

Revert "cname"

This reverts commit d823f87561.
qazal
2024-09-06 05:23:16 +08:00
committed by GitHub
parent c26744de9f
commit a63f53c28b

@@ -270,7 +270,7 @@ class UOps(Enum):
- Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
- Indexing UOp, can only return `dtypes.int32`.
- Value to store.
- Gate UOp, can only return `dtypes.bool`.
- Gate UOp, can only return `dtypes.bool`. We rewrite this to an IF block in the end.
- **`arg`**: `None`
"""
PHI = auto()
@@ -284,6 +284,37 @@ class UOps(Enum):
- **`arg`**: `None`
"""
IF = auto()
"""
Gates a single STORE to global memory. The IF block could also contain additional UOps the STORE depends on.
- **`dtype`**: `None`
- **`src`**:
`Tuple[UOp, UOp]`
- Gate UOp, can only return `dtypes.bool`
- The second UOp starts the gate block; All of its children are gated until the final STORE.
- **`arg`**: `None`
For example, a local reduce must only run on one thread.
The STORE's IF gate:
```
UOp(UOps.IF, src=(
UOp(UOps.ALU, dtypes.bool, (...), BinaryOps.CMPNE),
UOp(UOps.BARRIER, None, (...))))
```
The kernel:
```
barrier(CLK_LOCAL_MEM_FENCE);
if (lidx0!=1) {
int acc1 = 0;
for (int ridx1 = 0; ridx1 < 16; ridx1++) {
int val1 = temp1[ridx1];
acc1 = (acc1+val1);
}
data0[0] = acc1;
}
```
"""
RANGE = auto()
# ops that are not graph nodes
ENDRANGE = auto()
@@ -634,6 +665,7 @@ def type_verify(uops):
assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
if uop is UOps.IF: assert dtype is None and len(src) == 2 and src[0].dtype == dtypes.bool
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4:
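
The new `UOps.IF` check in `type_verify` can be tried out in isolation. The following is a minimal sketch, not tinygrad code: `Op`, `Node`, and `is_valid_if` are hypothetical stand-ins for `UOps`, `UOp`, and the assert above, and dtypes are modelled as plain strings. It only mirrors the three conditions the assert states: `dtype` is `None`, there are exactly two sources, and the first source is a `bool` gate.

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional, Tuple


class Op(Enum):
    """Hypothetical stand-in for the few UOps members used here."""
    ALU = auto()
    BARRIER = auto()
    IF = auto()


@dataclass(frozen=True)
class Node:
    """Hypothetical stand-in for UOp; dtype is a plain string or None."""
    op: Op
    dtype: Optional[str] = None
    src: Tuple["Node", ...] = ()


def is_valid_if(u: Node) -> bool:
    """Mirror the documented IF invariant: no dtype, two sources, bool gate first."""
    return (
        u.op is Op.IF
        and u.dtype is None
        and len(u.src) == 2
        and u.src[0].dtype == "bool"
    )


# The gate is a comparison (an ALU op returning bool); the second source starts the block.
gate = Node(Op.ALU, "bool")
barrier = Node(Op.BARRIER, None)

good_if = Node(Op.IF, None, (gate, barrier))
swapped_if = Node(Op.IF, None, (barrier, gate))  # gate is not first
short_if = Node(Op.IF, None, (gate,))            # block-start source is missing
```

Under these assumptions, `is_valid_if(good_if)` holds, while the swapped and one-source variants are rejected, which is the same shape as the `UOp(UOps.IF, src=(ALU, BARRIER))` example in the docstring.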