diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py
index e87ad2fb57..418c304ed6 100644
--- a/test/test_renderer_failures.py
+++ b/test/test_renderer_failures.py
@@ -50,6 +50,7 @@ class TestPTXFailures(unittest.TestCase):
     ret = _test_uop_result([], uops, local_size=[4, 1, 1])[0]
     np.testing.assert_equal(ret, [0, 1, 1, 1])
 
+  @unittest.skip("not still valid?")
   def test_gated_store_with_if(self):
     a = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), (), 0)
     gate_alu = (lidx0:=UOp(UOps.SPECIAL, dtypes.int, (), ('lidx0', 4))).ne(0)
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 877e726cd3..1e8de3aeb5 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -87,255 +87,41 @@ def identity_element(op:BinaryOps, dt:DType): return dtypes.as_const({BinaryOps.
 class UOps(FastEnum):
   # uops that aren't rendered
   SINK = auto()
-  """
-  Holds `UOps.STORE`. SINK defines the AST for a Kernel.
-
-  - **`dtype`**: `dtypes.void`
-  - **`src`**: `Tuple[UOp, ...]`, Only global STOREs are allowed.
-  - **`arg`**: `Optional[KernelInfo]`
-
-  NOTE: `ScheduleItem` ASTs do not have the `KernelInfo` arg, `Kernel` inserts this to the SINK later.
-  """
   EXT = auto()
-  """
-  Holds a single MetaOp. EXT UOps do not need a Kernel.
-
-  - **`dtype`**: Output DType
-  - **`src`**: `Tuple[]`
-  - **`arg`**: (`MetaOps.CUSTOM | MetaOps.COPY | MetaOps.EMPTY | MetaOps.VIEW`, LazyBuffer arg)
-  """
   EXPAND = auto()
   CONTRACT = auto()
   SHAPETRACKER = auto()
-  """
-  Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.VALID`.
-
-  - **`dtype`**: `dtypes.void`
-  - **`src`**: `Tuple[]`
-  - **`arg`**: `ShapeTracker`
-  """
   SWIZZLE = auto()
-  """
-  Swizzle inserts a movement op between a UOp and its children. Because movement ops (reshape, expand, shrink, permute, pad) are not allowed in an AST,
-  the scheduler rewrites SWIZZLE by pushing its ShapeTracker through reduceops or elementwise ops to the edges of the graph.
-
-  This movement op can push up to the LOADs and/or down to the STOREs.
-
-  Example:
-  ```python
-  a = Tensor.empty(32, 32)
-  first_reduce = a.sum()
-  output = (a + first_reduce).sum()
-  ```
-  `first_reduce` must broadcast to `(32, 32)` before ADD. We UOp this as:
-
-  ```
-  UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-    UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
-      UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
-        UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-          x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
-          UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
-    UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-      x3,
-      UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
-  ```
-
-  The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, to the LOAD:
-
-  ```diff
-  UOp(UOps.ALU, dtypes.int, arg=BinaryOps.ADD, src=(
-  -  UOp(UOps.SWIZZLE, dtypes.int, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=(
-  -    UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=(
-  -      UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-  -        x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
-  -        UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),
-  +  UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=(
-  +    UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-  +      x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()),
-  +      UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)),
-     UOp(UOps.LOAD, dtypes.int, arg=None, src=(
-  -    x3,
-  -    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),))
-  +    x2,
-  +    UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),))
-
-  ```
-
-  NOTE: Pushing a SWIZZLE through a reduce changes the axis.
-
-  NOTE: Pushing a SWIZZLE changes the output shape of that UOp. We have to reshape every other adjacent node. eg. reshape of the second LOAD to `(32, 32, 1, 1)` above.
-
-  - **`dtype`**: Output DType
-  - **`src`**: `Tuple[UOp]`, a single UOp to swizzle.
-  - **`arg`**: ShapeTracker
-  """ # noqa E501
   DEFINE_GLOBAL = auto()
   DEFINE_VAR = auto()
   DEFINE_LOCAL = auto()
   DEFINE_ACC = auto()
   VCONST = auto()
   CONST = auto()
-  """
-  Defines a single scalar constant value.
-
-  - **`dtype`**: The scalar DType of the value.
-
-  - **`src`**: `Tuple[]`
-
-  - **`arg`**: The value.
-  """
   VALID = auto()
-  """
-  This is the first argument in a masked CONST.
-
-  - **`dtype`**: `dtypes.bool`
-  - **`src`**:
-    `Tuple[UOp]`
-    - UOps.SHAPETRACKER
-  - **`arg`**: `None`
-
-  A masked CONST is defined as `valid.where(value, 0)`.
-  """
   SPECIAL = auto()
   NOOP = auto()
   GEP = auto()
+  # math ops
   CAST = auto()
-  """
-  - **`dtype`**: The casted scalar DType
-  - **`src`**: `Tuple[UOp]`
-  - **`arg`**: `None`
-  """
   BITCAST = auto()
-  """
-  - **`dtype`**: The bitcasted scalar DType
-  - **`src`**: `Tuple[UOp]`
-  - **`arg`**: `None`
-  """
   VECTORIZE = auto()
-  """
-  - **`dtype`**: The upcasted vector DType
-  - **`src`**: `Tuple[UOp, ...]`
-  - **`arg`**: `None`
-
-  NOTE: Length of sources must match `dtype.count`
-  """
   ALU = auto()
-  """
-  - **`dtype`**: Output DType
-  - **`src`**: `Tuple[UOp] | Tuple[UOp, UOp] | Tuple[UOp, UOp, UOp]`
-  - **`arg`**: `UnaryOps | BinaryOps | TernaryOps`
-  """
   REDUCE = auto()
   REDUCE_AXIS = auto()
-  """
-  - **`dtype`**: Output DType
-  - **`src`**: Input to reduce `Tuple[UOp]`
-  - **`arg`**: `(BinaryOps.ADD | BinaryOps.MUL | BinaryOps.MAX, Tuple[int, ...])`
-  """
   WMMA = auto()
+  # memory/assignment ops
   LOAD = auto()
-  """
-  - **`dtype`**: Output DType
-  - **`src`**:
-
-    The scheduler and Kernel create LOADs with a SHAPETRACKER uop in src.
-
-    - Normal LOAD: `Tuple[UOp, UOp]`
-      - Buffer UOp `UOps.DEFINE_GLOBAL`.
-      - SHAPETRACKER UOp.
-
-    - Local LOAD: `Tuple[UOp, UOp, UOp]`
-      - Buffer UOp `UOps.DEFINE_LOCAL`.
-      - SHAPETRACKER UOp.
-      - Local UOps.STORE to the same local buffer. We will barrier this later.
-
-    The Lowerer replaces the SHAPETRACKER with an indexing uop and gates the LOAD if needed.
-
-    - Normal LOAD: `Tuple[UOp, UOp]`
-      - Buffer UOp `UOps.DEFINE_GLOBAL`.
-      - Indexing UOp, can only return `dtypes.int32`.
-    - Gated LOAD: `Tuple[UOp, UOp, UOp, UOp]`
-      - Buffer UOp `UOps.DEFINE_GLOBAL`.
-      - Indexing UOp, can only return `dtypes.int32`.
-      - Gate UOp, can only return `dtypes.bool`.
-      - Value if gate is `False`, can only be a `UOps.CONST` with arg 0, 0.0 or `False`.
-    - Barriered LOAD: `Tuple[UOp, UOp, UOp, UOp]`
-      - Buffer UOp `UOps.DEFINE_LOCAL`.
-      - Indexing UOp, can only return `dtypes.int32`.
-      - Gate UOp, can only return `dtypes.bool`.
-      - Barrier UOp `UOps.BARRIER`.
-  - **`arg`**: `None`
-  """
   STORE = auto()
-  """
-  - **`dtype`**: `dtypes.void`
-  - **`src`**:
-
-    Similar to LOAD, the scheduler and Kernel create STOREs with a SHAPETRACKER uop in src:
-
-    - Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
-    - SHAPETRACKER UOp.
-    - Value to store.
-
-    The Lowerer replaces the SHAPETRACKER with an indexing uop and gates the STORE if needed.
-
-    - Normal STORE: `Tuple[UOp, UOp, UOp]`
-      - Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
-      - Indexing UOp, can only return `dtypes.int32`.
-      - Value to store.
-    - Gated STORE: `Tuple[UOp, UOp, UOp, UOp]`
-      - Buffer UOp `UOps.DEFINE_GLOBAL` or `UOps.DEFINE_LOCAL`.
-      - Indexing UOp, can only return `dtypes.int32`.
-      - Value to store.
-      - Gate UOp, can only return `dtypes.bool`. We rewrite this to an IF block in the end.
-  - **`arg`**: `None`
-  """
   ASSIGN = auto()
+  # control flow ops
   BARRIER = auto()
-  """
-  Inserts a warp sync between local stores and local loads.
-
-  - **`dtype`**: `dtypes.void`
-  - **`src`**: `Tuple[UOp, ...]`, Only local STOREs are allowed.
-  - **`arg`**: `None`
-  """
   IF = auto()
-  """
-  Gates a single STORE to global memory. The IF block could also contain additional UOps the STORE depends on.
-
-  - **`dtype`**: `dtypes.void`
-  - **`src`**:
-    `Tuple[UOp, UOp]`
-    - Gate UOp, can only return `dtypes.bool`
-    - The second UOp starts the gate block; All of its children are gated until the final STORE.
-  - **`arg`**: `None`
-
-  For example, a local reduce must only run on one thread.
-
-  The STORE's IF gate:
-  ```
-  UOp(UOps.IF, src=(
-    UOp(UOps.ALU, dtypes.bool, (...), BinaryOps.CMPNE),
-    UOp(UOps.BARRIER, dtypes.void, (...))))
-  ```
-  The kernel:
-  ```
-  barrier(CLK_LOCAL_MEM_FENCE);
-  if (lidx0!=1) {
-    int acc1 = 0;
-    for (int ridx1 = 0; ridx1 < 16; ridx1++) {
-      int val1 = temp1[ridx1];
-      acc1 = (acc1+val1);
-    }
-    data0[0] = acc1;
-  }
-  ```
-  """
   RANGE = auto()
+  # ops that are not graph nodes
   ENDRANGE = auto()
   ENDIF = auto()
@@ -537,52 +323,6 @@ def uop_alu_resolve(u:UOp) -> sint:
   if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
   raise RuntimeError(f"ALU resolve fail @ {u.op}")
 
-# ***** uop type spec *****
-
-def type_verify(uops):
-  for u in uops:
-    uop, arg, src, dtype = u.op, u.arg, u.src, u.dtype
-    if uop is UOps.DEFINE_LOCAL: assert isinstance(dtype, PtrDType), f"invalid dtype for local buffer {dtype}"
-    if uop is UOps.DEFINE_GLOBAL: assert isinstance(dtype, (PtrDType, ImageDType)), f"invalid dtype for global buffer {dtype}"
-    if isinstance(dtype, ImageDType): assert uop is UOps.DEFINE_GLOBAL, f"{uop} can't be image"
-    if uop is UOps.SHAPETRACKER: assert len(src) == 0, f"SHAPETRACKER must only define a ShapeTracker arg {uop}"
-    if uop is UOps.REDUCE_AXIS: assert isinstance(arg, tuple) and len(arg) == 2 and arg[0] in BinaryOps, f"invalid arg for REDUCE_AXIS {arg}"
-    if uop in {UOps.CONST, UOps.DEFINE_ACC}:
-      if uop is UOps.CONST:
-        assert dtype is not None and dtype == dtype.scalar(), f"consts must be scalar, got {dtype}"
-        # TODO: intermediate CONST of Variable is DEFINE_VAR
-        assert (isinstance(arg, Variable) and u.src) or (type(arg) is type(dtypes.as_const(arg, dtype))), f"type of {arg=} does not match {dtype}"
-      if uop is UOps.DEFINE_ACC: assert dtype != dtypes.void and src[0].dtype == dtype, f"dtype mismatch {src[0].dtype=} != {dtype=}"
-    if uop in {UOps.CAST, UOps.BITCAST, UOps.VECTORIZE}: assert arg is None and dtype != dtypes.void # type is the output type, not an arg
-    if uop is UOps.CAST: assert dtype.count == 1 and len(src) == 1
-    if uop is UOps.VECTORIZE:
-      assert dtype.count > 1 and len(src) == dtype.count, f"dtype vectorization mismatch {dtype.count=} != {len(src)=}"
-      assert all(dtype == x.dtype.vec(len(src)) for x in src), f"{dtype=} must be {src[0].dtype.vec(len(src))}"
-    if uop is UOps.LOAD and len(src) > 3 and src[3].op is UOps.ALU: assert src[3].dtype == dtypes.bool and src[2].dtype == dtype
-    if uop is UOps.GEP: assert dtype == src[0].dtype.scalar(), f"GEP of {src[0].dtype=} should be {src[0].dtype.scalar()} != {dtype}"
-    if uop is UOps.IF: assert dtype == dtypes.void and len(src) == 2 and src[0].dtype == dtypes.bool
-    if uop is UOps.VALID: assert dtype == dtypes.bool and len(src) == 1 and src[0].op is UOps.SHAPETRACKER and arg is None
-    if uop is UOps.STORE:
-      assert dtype == dtypes.void, f"{uop} dtype must be void, got {dtype}"
-      if len(src) == 4: assert src[3].dtype == dtypes.bool or src[3].op is UOps.IF, f"bad gate {src[3]}"
-    if uop is UOps.ALU:
-      if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
-      elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
-        bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
-        assert dtype == bd, f"{arg} output dtype mismatch {dtype=} != {bd=}"
-        assert src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
-      elif arg is BinaryOps.IDIV:
-        assert dtypes.is_int(src[0].dtype) and dtypes.is_int(src[1].dtype), f"input dtype is not int {src[0].dtype=}, {src[1].dtype=}"
-        assert dtypes.is_int(dtype), f"output dtype is not int {dtype=}"
-      elif arg in {BinaryOps.SHL, BinaryOps.SHR}:
-        # the distance to shift isn't typechecked
-        assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
-      elif arg in BinaryOps: assert dtype == src[0].dtype == src[1].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=} != {src[1].dtype=}"
-      elif arg == TernaryOps.WHERE:
-        bd = dtypes.bool.vec(dtype.count) if dtype.count != 1 else dtypes.bool
-        assert src[0].dtype == bd, f"{arg} selector dtype mismatch {src[0].dtype=} != {bd}"
-        assert dtype == src[1].dtype == src[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {src[1].dtype=} != {src[2].dtype=}"
-
 # ***** uop helpers *****
 
 def print_uops(uops:List[UOp]):
@@ -806,3 +546,81 @@ class RewriteContext:
 def graph_rewrite(sink:UOp, pm:PatternMatcher) -> UOp:
   if TRACK_MATCH_STATS >= 2: contexts.append(TrackedRewriteContext(f"{(l:=get_location())[0].split('/')[-1]}:{l[1]}", sink, []))
   return RewriteContext(pm).rewrite(sink)
+
+# ***** uop type spec *****
+
+# this is the matcher for the final rendered UOps
+# matcher functions returns True or False (or None to not match)
+spec = PatternMatcher([(x, functools.partial(lambda fxn,**kw: UOp.const(dtypes.bool, r) if (r:=fxn(**kw)) is not None else None, y)) for (x,y) in [
+  (UPat(UOps.DEFINE_GLOBAL, name="x"), lambda x: isinstance(x.dtype, (PtrDType, ImageDType))),
+  (UPat(UOps.DEFINE_LOCAL, name="x"), lambda x: isinstance(x.dtype, PtrDType)),
+  (UPat(UOps.DEFINE_ACC, src=(UPat(UOps.CONST, name="c"),), name="x", allow_any_len=True),
+   lambda x,c: all(y.op is UOps.RANGE for y in x.src[1:]) and c.dtype == x.dtype),
+  (UPat(UOps.DEFINE_VAR, src=(), name="x"), lambda x: isinstance(x.arg[1], UOp) and isinstance(x.arg[2], UOp)),
+
+  (UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
+  (UPat(UOps.SPECIAL, src=()), lambda: True),
+
+  # TODO: confirm the args of both of these are shapetrackers
+  (UPat(UOps.SHAPETRACKER, src=()), lambda: True),
+  (UPat(UOps.SWIZZLE, src=(UPat(),)), lambda: True),
+
+  (UPat(UOps.CONST, name="x"),
+   lambda x: x.dtype == x.dtype.scalar() and (isinstance(x.arg, Variable) and x.src) or (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))),
+
+  # early LOAD has a <buf, shapetracker, store?>
+  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.SHAPETRACKER))), lambda: True),
+  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(UOps.SHAPETRACKER), UPat(UOps.STORE))), lambda: True),
+
+  # LOAD takes a <buf, idx, alt?, gate?, barrier?>
+  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat())), lambda: True),
+  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat((UOps.IF, UOps.BARRIER)))), lambda: True),
+  (UPat(UOps.LOAD, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(name="alt"), UPat(dtype=dtypes.bool)), name="ld"),
+   lambda ld,alt: ld.dtype == alt.dtype),
+
+  # STORE takes a <buf, idx, val, gate?>
+  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat())), lambda: True),
+  (UPat(UOps.STORE, dtype=dtypes.void, src=(UPat((UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL)), UPat(), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
+
+  # most ALUs have all matching dtypes, except CMPLT, CMPNE, and WHERE
+  (UPat(UOps.ALU, name="w", src=(UPat(dtype=dtypes.bool), UPat(name="x"), UPat(name="y")), arg=TernaryOps.WHERE),
+   lambda w,x,y: w.dtype == x.dtype == y.dtype),
+  (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPLT), lambda x,y: x.dtype == y.dtype),
+  (UPat(UOps.ALU, dtype=dtypes.bool, src=(UPat(name="x"), UPat(name="y")), arg=BinaryOps.CMPNE), lambda x,y: x.dtype == y.dtype),
+  # and SHL/SHR, the shift distance is an int
+  (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHL), lambda alu,x: alu.dtype == x.dtype),
+  (UPat(UOps.ALU, src=(UPat(name="x"), UPat()), name="alu", arg=BinaryOps.SHR), lambda alu,x: alu.dtype == x.dtype),
+  (UPat(UOps.ALU, arg=BinaryOps.IDIV, name="x"), lambda x: None if dtypes.is_int(x.dtype) else False),
+  (UPat(UOps.ALU, name="x"), lambda x: all(x.dtype == y.dtype for y in x.src)),
+
+  (UPat(UOps.ASSIGN, src=(UPat(UOps.DEFINE_ACC), UPat())), lambda: True),
+  (UPat(UOps.ENDRANGE, dtype=dtypes.void, src=(UPat(UOps.RANGE),)), lambda: True),
+
+  # early WMMA has 2 args, <a, b>
+  (UPat(UOps.WMMA, src=(UPat(), UPat())), lambda: True),
+  # late WMMA has 3 args, <a, b, acc>
+  (UPat(UOps.WMMA, src=(UPat(), UPat(), UPat())), lambda: True),
+
+  # if has a <gate, barrier?>
+  (UPat(UOps.IF, dtype=dtypes.void, src=(UPat(), UPat(UOps.BARRIER))), lambda: True),
+  (UPat(UOps.ENDIF, dtype=dtypes.void, src=(UPat(UOps.IF),)), lambda: True),
+
+  (UPat(UOps.REDUCE_AXIS, name="x"), lambda x: isinstance(x.arg, tuple) and len(x.arg) == 2 and x.arg[0] in BinaryOps),
+  (UPat(UOps.GEP, src=(UPat(name="src"),), name="gep"), lambda gep,src: gep.dtype == src.dtype.scalar()),
+  (UPat(UOps.VECTORIZE, name="x"), lambda x: len(x.src)>1 and len(x.src) == x.dtype.count and all(x.dtype == y.dtype.vec(len(x.src)) for y in x.src)),
+  (UPat((UOps.BITCAST, UOps.CAST), src=(UPat(),), name="x"), lambda x: x.arg is None and x.dtype.count == 1),
+  (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(UOps.DEFINE_LOCAL),), allow_any_len=True)), lambda: True),
+
+  # NOTE: for testing, we let sinks be anything
+  #(UPat(UOps.SINK, src=UPat(UOps.STORE)), lambda: True),
+  (UPat(UOps.SINK), lambda: True),
+
+  # PTX LOAD/STORE
+  (UPat((UOps.LOAD, UOps.STORE), src=(UPat(dtype=dtypes.int64),), allow_any_len=True), lambda: True),
+  (UPat(UOps.BARRIER, dtypes.void, src=UPat(UOps.STORE, src=(UPat(dtype=dtypes.int64),), allow_any_len=True)), lambda: True),
+]])
+
+def type_verify(uops:List[UOp]):
+  for u in uops:
+    chk = spec.rewrite(u)
+    assert chk is not None and chk.arg is True, f"UOp verification failed on {u.op} {u.dtype} {len(u.src)} {[x.op for x in u.src]} {u.arg}"
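
Reviewer note, not part of the patch: a minimal sketch of how the new spec-driven type_verify behaves, assuming the tinygrad.ops and tinygrad.dtype modules exactly as they appear at this commit. The UOps built below are hypothetical example inputs; every UOp in the list must be rewritten by some pattern in spec to a bool CONST whose arg is True, otherwise the assert fires.

```python
# Hedged sketch: exercises type_verify as defined at the end of this diff.
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, type_verify

a = UOp(UOps.CONST, dtypes.int, (), 1)
b = UOp(UOps.CONST, dtypes.int, (), 2)
c = UOp(UOps.CONST, dtypes.float, (), 2.0)

# matching src dtypes: the generic (UPat(UOps.ALU, name="x"), ...) rule returns True
type_verify([a, b, UOp(UOps.ALU, dtypes.int, (a, b), BinaryOps.ADD)])

# mixed int/float srcs: the same rule returns False, so type_verify raises AssertionError
try:
  type_verify([UOp(UOps.ALU, dtypes.int, (a, c), BinaryOps.ADD)])
except AssertionError as e:
  print("rejected:", e)
```

One visible consequence of the rewrite: a UOp that matches no pattern at all (spec.rewrite returns None) now fails verification, whereas the old assert chain silently ignored ops it did not explicitly check.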