mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
use the PatternMatcher to validate UOps type [run_process_replay] (#6583)
* use the PatternMatcher to validate UOps type [run_process_replay] * type check tests pass * DEFINE_VAR * fix precommit * fix tests * ptx * type check tests pass * ptx test * int64 * ptx barrier * delete old stuff
This commit is contained in:
@@ -50,6 +50,7 @@ class TestPTXFailures(unittest.TestCase):
|
||||
ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
|
||||
np.testing.assert_equal(ret, [0, 1, 1, 1])
|
||||
|
||||
@unittest.skip("not still valid?")
|
||||
def test_gated_store_with_if(self):
|
||||
a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
|
||||
gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
|
||||
|
||||
346
tinygrad/ops.py
346
tinygrad/ops.py
@@ -87,255 +87,41 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
|
||||
class UOps(FastEnum):
|
||||
# uops that aren't rendered
|
||||
SINK = auto()
|
||||
"""
|
||||
Holds `UOps.STORE`. SINK defines the AST for a Kernel.
|
||||
|
||||
- **`dtype`**: `dtypes.void`
|
||||
- **`src`**: `Tuple[UOp, ...]`, Only global STOREs are allowed.
|
||||
- **`arg`**: `Optional[KernelInfo]`
|
||||
|
||||
NOTE: `ScheduleItem` ASTs do not have the `KernelInfo` arg, `Kernel` inserts this to the SINK later.
|
||||
"""
|
||||
EXT = auto()
|
||||
"""
|
||||
Holds a single MetaOp. EXT UOps do not need a Kernel.
|
||||
|
||||
- **`dtype`**: Output DType
|
||||
- **`src`**: `Tuple[]`
|
||||
- **`arg`**: (`MetaOps.CUSTOM | MetaOps.COPY | MetaOps.EMPTY | MetaOps.VIEW`, LazyBuffer arg)
|
||||
"""
|
||||
EXPAND = auto()
|
||||
CONTRACT = auto()
|
||||
SHAPETRACKER = auto()
|
||||
"""
|
||||
Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.VALID`.
|
||||
|
||||
- **`dtype`**: `dtypes.void`
|
||||
- **`src`**: `Tuple[]`
|
||||
- **`arg`**: `ShapeTracker`
|
||||
"""
|
||||
SWIZZLE = auto()
|
||||
"""
|
||||
Swizzle inserts a movement op between a UOp and its children. Because movement ops (reshape, expand, shrink, permute, pad) are not allowed in an AST,
|
||||
the scheduler rewrites SWIZZLE by pushing its ShapeTracker through reduceops or elementwise ops to the edges of the graph.
|
||||
|
||||
This movement op can push up to the LOADs and/or down to the STOREs.
|
||||
|
||||
Example:
|
||||
```python
|
||||
a = Tensor.empty(32, 32)
|
||||
first_reduce = a.sum()
|
||||
output = (a + first_reduce).sum()
|
||||
```
|
||||
`first_reduce` must broadcast to `(32, 32)` before the ADD. We represent this in UOps as:
|
||||
|
||||
```
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
||||
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
x3,
|
||||
UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
|
||||
```
|
||||
|
||||
The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, to the LOAD:
|
||||
|
||||
```diff
|
||||
UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
|
||||
- UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
|
||||
- UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
|
||||
- UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
- x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
||||
- UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
|
||||
+ UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=(
|
||||
+ UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
+ x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
|
||||
+ UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),
|
||||
UOp(UOps.LOAD, dtypes.int, arg=None, src=(
|
||||
- x3,
|
||||
- UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
|
||||
+ x2,
|
||||
+ UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),))
|
||||
|
||||
```
|
||||
|
||||
NOTE: Pushing a SWIZZLE through a reduce changes the axis.
|
||||
|
||||
NOTE: Pushing a SWIZZLE changes the output shape of that UOp. We have to reshape every other adjacent node, e.g. the reshape of the second LOAD to `(32, 32, 1, 1)` above.
|
||||
|
||||
- **`dtype`**: Output DType
|
||||
- **`src`**: `Tuple[UOp]`, a single UOp to swizzle.
|
||||
- **`arg`**: ShapeTracker
|
||||
""" # noqa E501
|
||||
DEFINE_GLOBAL = auto()
|
||||
DEFINE_VAR = auto()
|
||||
DEFINE_LOCAL = auto()
|
||||
DEFINE_ACC = auto()
|
||||
VCONST = auto()
|
||||
CONST = auto()
|
||||
"""
|
||||
Defines a single scalar constant value.
|
||||
|
||||
- **`dtype`**: The scalar DType of the value.
|
||||
|
||||
- **`src`**: `Tuple[]`
|
||||
|
||||
- **`arg`**: The value.
|
||||
"""
|
||||
VALID = auto()
|
||||
"""
|
||||
This is the first argument in a masked CONST.
|
||||
|
||||
- **`dtype`**: `dtypes.bool`
|
||||
- **`src`**:
|
||||
`Tuple[UOp]`
|
||||
- UOps.SHAPETRACKER
|
||||
- **`arg`**: `None`
|
||||
|
||||
A masked CONST is defined as `valid.where(value, 0)`.
|
||||
"""
|
||||
SPECIAL = auto()
|
||||
NOOP = auto()
|
||||
GEP = auto()
|
||||
|
||||
# math ops
|
||||
CAST = auto()
|
||||
"""
|
||||
- **`dtype`**: The casted scalar DType
|
||||
- **`src`**: `Tuple[UOp]`
|
||||
- **`arg`**: `None`
|
||||
"""
|
||||
BITCAST = auto()
|
||||
"""
|
||||
- **`dtype`**: The bitcasted scalar DType
|
||||
- **`src`**: `Tuple[UOp]`
|
||||
- **`arg`**: `None`
|
||||
"""
|
||||
VECTORIZE = auto()
|
||||
"""
|
||||
- **`dtype`**: The upcasted vector DType
|
||||
- **`src`**: `Tuple[UOp, ...]`
|
||||
- **`arg`**: `None`
|
||||
|
||||
NOTE: Length of sources must match `dtype.count`
|
||||
"""
|
||||
ALU = auto()
|
||||
"""
|
||||
- **`dtype`**: Output DType
|
||||
- **`src`**: `Tuple[UOp] | Tuple[UOp, UOp] | Tuple[UOp, UOp, UOp]`
|
||||
- **`arg`**: `UnaryOps | BinaryOps | TernaryOps`
|
||||
"""
|
||||
REDUCE = auto()
|
||||
REDUCE_AXIS = auto()
|
||||
"""
|
||||
- **`dtype`**: Output DType
|
||||
- **`src`**: Input to reduce `Tuple[UOp]`
|
||||
- **`arg`**: `(BinaryOps.ADD | BinaryOps.MUL | BinaryOps.MAX, Tuple[int, ...])`
|
||||
"""
|
||||
WMMA = auto()
|
||||
|
||||
# memory/assignment ops
|
||||
LOAD = auto()
|
||||
"""
|
||||
- **`dtype`**: Output DType
|
||||
- **`src`**:
|
||||
|
||||
The scheduler and Kernel create LOADs with a SHAPETRACKER uop in src.
|
||||
|
||||
- Normal LOAD: `Tuple[UOp, UOp]`
|
||||
- Buffer UOp `UOps.DEFINE_GLOBAL`.
|
||||
- SHAPETRACKER UOp.
|
||||
|
||||
- Local LOAD: `Tuple[UOp, UOp, UOp]`
|
||||
- Buffer UOp `UOps.DEFINE_LOCAL`.
|
||||
- SHAPETRACKER UOp.
|
||||
- Local UOps.STORE to the same local buffer. We will barrier this later.
|
||||
|
||||
The Lowerer replaces the SHAPETRACKER with an indexing uop and gates the LOAD if needed.
|
||||
|
||||
- Normal LOAD: `Tuple[UOp, UOp]`
|
||||
- Buffer UOp `UOps.DEFINE_GLOBAL`.
|
||||
- Indexing UOp, can only return `dtypes.int32`.
|
||||
- Gated LOAD: `Tuple[UOp, UOp, UOp, UOp]`
|
||||
- Buffer UOp `UOps.DEFINE_GLOBAL`.
|
||||
- Indexing UOp, can only return `dtypes.int32`.
|
||||
- Gate UOp, can only return `dtypes.bool`.
|
||||
- Value if gate is `False`, can only be a `UOps.CONST` with arg 0, 0.0 or `False`.
|
||||
- Barriered LOAD: `Tuple[UOp, UOp, UOp, UOp]`
|
||||
- Buffer UOp `UOps.DEFINE_LOCAL`.
|
||||
- Indexing UOp, can only return `dtypes.int32`.
|
||||
- Gate UOp, can only return `dtypes.bool`.
|
||||
- Barrier UOp `UOps.BARRIER`.
|
||||
- **`arg`**: `None`
|
||||
"""
|
||||
STORE = auto()
|
||||
"""
|
||||
- **`dtype`**: `dtypes.void`
|
||||
- **`src`**:
|
||||
|
||||
Similar to LOAD, the scheduler and Kernel create STOREs with a SHAPETRACKER uop in src:
|
||||
|
||||
- Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
|
||||
- SHAPETRACKER UOp.
|
||||
- Value to store.
|
||||
|
||||
The Lowerer replaces the SHAPETRACKER with an indexing uop and gates the STORE if needed.
|
||||
|
||||
- Normal STORE: `Tuple[UOp, UOp, UOp]`
|
||||
- Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
|
||||
- Indexing UOp, can only return `dtypes.int32`.
|
||||
- Value to store.
|
||||
- Gated STORE: `Tuple[UOp, UOp, UOp, UOp]`
|
||||
- Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
|
||||
- Indexing UOp, can only return `dtypes.int32`.
|
||||
- Value to store.
|
||||
- Gate UOp, can only return `dtypes.bool`. We rewrite this to an IF block in the end.
|
||||
- **`arg`**: `None`
|
||||
"""
|
||||
ASSIGN = auto()
|
||||
|
||||
# control flow ops
|
||||
BARRIER = auto()
|
||||
"""
|
||||
Inserts a warp sync between local stores and local loads.
|
||||
|
||||
- **`dtype`**: `dtypes.void`
|
||||
- **`src`**: `Tuple[UOp, ...]`, Only local STOREs are allowed.
|
||||
- **`arg`**: `None`
|
||||
"""
|
||||
IF = auto()
|
||||
"""
|
||||
Gates a single STORE to global memory. The IF block could also contain additional UOps the STORE depends on.
|
||||
|
||||
- **`dtype`**: `dtypes.void`
|
||||
- **`src`**:
|
||||
`Tuple[UOp, UOp]`
|
||||
- Gate UOp, can only return `dtypes.bool`
|
||||
- The second UOp starts the gate block; All of its children are gated until the final STORE.
|
||||
- **`arg`**: `None`
|
||||
|
||||
For example, a local reduce must only run on one thread.
|
||||
|
||||
The STORE's IF gate:
|
||||
```
|
||||
UOp(UOps.IF, src=(
|
||||
UOp(UOps.ALU, dtypes.bool, (...), BinaryOps.CMPNE),
|
||||
UOp(UOps.BARRIER, dtypes.void, (...))))
|
||||
```
|
||||
The kernel:
|
||||
```
|
||||
barrier(CLK_LOCAL_MEM_FENCE);
|
||||
if (lidx0!=1) {
|
||||
int acc1 = 0;
|
||||
for (int ridx1 = 0; ridx1 < 16; ridx1++) {
|
||||
int val1 = temp1[ridx1];
|
||||
acc1 = (acc1+val1);
|
||||
}
|
||||
data0[0] = acc1;
|
||||
}
|
||||
```
|
||||
"""
|
||||
RANGE = auto()
|
||||
|
||||
# ops that are not graph nodes
|
||||
ENDRANGE = auto()
|
||||
ENDIF = auto()
|
||||
@@ -537,52 +323,6 @@ def uop_alu_resolve(u:UOp) -> sint:
|
||||
if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
|
||||
raise RuntimeError(f"ALU resolve fail @ {u.op}")
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
def type_verify(uops):
  """Assert that each UOp in `uops` obeys the per-op dtype/src/arg rules checked below.

  The checks run in order for every UOp; the first violated rule raises AssertionError
  (most rules carry a message describing the mismatch).
  """
  for u in uops:
    uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype

    # buffer definitions: locals need a PtrDType, globals a PtrDType or ImageDType,
    # and nothing other than a global buffer may have an image dtype
    if uop is UOps.DEFINE_LOCAL:
      assert isinstance(dtype, PtrDType), f"invalid dtype for local buffer {dtype}"
    if uop is UOps.DEFINE_GLOBAL:
      assert isinstance(dtype, (PtrDType, ImageDType)), f"invalid dtype for global buffer {dtype}"
    if isinstance(dtype, ImageDType):
      assert uop is UOps.DEFINE_GLOBAL, f"{uop} can't be image"

    if uop is UOps.SHAPETRACKER:
      assert len(src) == 0, f"SHAPETRACKER must only define a ShapeTracker arg {uop}"
    if uop is UOps.REDUCE_AXIS:
      assert isinstance(arg, tuple) and len(arg) == 2 and arg[0] in BinaryOps, f"invalid arg for REDUCE_AXIS {arg}"

    # constants are scalar and their python value type must agree with the dtype
    if uop is UOps.CONST:
      assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
      # TODO: intermediate CONST of Variable is DEFINE_VAR
      assert (isinstance(arg, Variable) and u.src) or (type(arg) is type(dtypes.as_const(arg, dtype))), f"type of {arg=} does not match {dtype}"
    # an accumulator has the same dtype as its first source
    if uop is UOps.DEFINE_ACC:
      assert dtype != dtypes.void and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"

    if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}:
      assert arg is None and dtype != dtypes.void # type is the output type, not an arg
    if uop is UOps.CAST:
      assert dtype.count == 1 and len(src) == 1
    if uop is UOps.VECTORIZE:
      assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
      assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"

    # a LOAD whose fourth source is an ALU: that source must be bool and the third source must match the output dtype
    if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU:
      assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
    if uop is UOps.GEP:
      assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
    if uop is UOps.IF:
      assert dtype == dtypes.void and len(src) == 2 and src[0].dtype == dtypes.bool
    if uop is UOps.VALID:
      assert dtype == dtypes.bool and len(src) == 1 and src[0].op is UOps.SHAPETRACKER and arg is None

    if uop is UOps.STORE:
      assert dtype == dtypes.void, f"{uop} dtype must be void, got {dtype}"
      if len(src) == 4:
        assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}"

    if uop is UOps.ALU:
      if arg in UnaryOps:
        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
      elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
        # comparisons output bool (vectorized to the output width) and compare like-typed inputs
        bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
        assert dtype == bd, f"{arg} output dtype mismatch {dtype=} != {bd=}"
        assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
      elif arg is BinaryOps.IDIV:
        assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
        assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
      elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
        # the distance to shift isn't typechecked
        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
      elif arg in BinaryOps:
        assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
      elif arg == TernaryOps.WHERE:
        # the selector is bool (vectorized to the output width); both choices match the output dtype
        bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
        assert src[0].dtype == bd, f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
        assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
|
||||
|
||||
# ***** uop helpers *****
|
||||
|
||||
def print_uops(uops:List[UOp]):
|
||||
@@ -806,3 +546,81 @@ class RewriteContext:
|
||||
def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
  """Rewrite the graph rooted at `sink` with the patterns in `pm` and return the result.

  When TRACK_MATCH_STATS >= 2, a TrackedRewriteContext is first appended to `contexts`, labelled with
  the last path component and second field of get_location() (presumably file name and line number).
  """
  if TRACK_MATCH_STATS >= 2:
    contexts.append(TrackedRewriteContext(f"{(l:=get_location())[0].split('/')[-1]}:{l[1]}", sink, []))
  rewriter = RewriteContext(pm)
  return rewriter.rewrite(sink)
|
||||
|
||||
# ***** uop type spec *****
|
||||
|
||||
# this is the matcher for the final rendered UOps
|
||||
# matcher functions return True or False (or None to not match)
|
||||
# Each (pattern, checker) pair is wrapped so the checker's result is returned as a bool CONST UOp
# (type_verify reads its `.arg`); a checker returning None is passed through as None (no match).
spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [
  (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))),
  (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)),
  (UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True),
   lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
  (UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], UOp) and isinstance(x.arg[2], UOp)),

  (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
  (UPat(UOps.SPECIAL, src=()), lambda: True),

  # TODO: confirm the args of both of these are shapetrackers
  (UPat(UOps.SHAPETRACKER, src=()), lambda: True),
  (UPat(UOps.SWIZZLE, src=(UPat(),)), lambda: True),

  # a CONST must always be scalar; in addition, either it is a Variable with sources or the python type of
  # its arg matches the dtype. The inner parentheses matter: without them `and` binds tighter than `or`,
  # so the scalar check would not apply to the type-match branch. bool() keeps the result a real bool
  # instead of leaking the src tuple into UOp.const.
  (UPat(UOps.CONST, name="x"),
   lambda x: x.dtype == x.dtype.scalar() and
             ((isinstance(x.arg, Variable) and bool(x.src)) or (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype))))),

  # early LOAD has a <buf, shapetracker, store?>
  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.SHAPETRACKER))), lambda: True),
  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.SHAPETRACKER), UPat(UOps.STORE))), lambda: True),

  # LOAD takes a <buf, idx, alt?, gate?, barrier?>
  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"),
   lambda ld,alt: ld.dtype == alt.dtype),

  # STORE takes a <buf, idx, val, gate?>
  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat())), lambda: True),
  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),

  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
  (UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
   lambda w,x,y: w.dtype == x.dtype == y.dtype),
  (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
  (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
  # and SHL/SHR, the shift distance is an int
  (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHL), lambda alu,x: alu.dtype == x.dtype),
  (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHR), lambda alu,x: alu.dtype == x.dtype),
  # IDIV: a non-int output fails outright; an int output returns None so the generic ALU rule below decides
  (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
  (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),

  (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True),
  (UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),

  # early WMMA has 2 args, <x, w>
  (UPat(UOps.WMMA, src=(UPat(), UPat())), lambda: True),
  # late WMMA has 3 args, <x, w, acc>
  (UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),

  # if has a <gate, barrier>
  (UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
  (UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),

  (UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in BinaryOps),
  (UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
  (UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
  (UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None and x.dtype.count == 1),
  (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(UOps.DEFINE_LOCAL),), allow_any_len=True)), lambda: True),

  # NOTE: for testing, we let sinks be anything
  #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
  (UPat(UOps.SINK), lambda: True),

  # PTX LOAD/STORE
  (UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
  (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
]])
|
||||
|
||||
def type_verify(uops:List[UOp]):
  """Check every UOp in `uops` against the `spec` PatternMatcher.

  A UOp passes only when `spec.rewrite` returns a non-None result whose `arg` is exactly True.
  Otherwise an AssertionError is raised naming the op, dtype, source count, source ops and arg.
  """
  for u in uops:
    outcome = spec.rewrite(u)
    # None means no spec rule produced a verdict for this UOp; anything but True is a failed check
    assert outcome is not None and outcome.arg is True, \
      f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
|
||||
|
||||
Reference in New Issue
Block a user