diff --git a/tinygrad/ops.py b/tinygrad/ops.py index 5448dfb408..d0919c77be 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -84,7 +84,7 @@ class UOps(HashEnum): """ Holds `UOps.STORE`. SINK defines the AST for a Kernel. - - **`dtype`**: `None` + - **`dtype`**: `dtypes.void` - **`src`**: `Tuple[UOp, ...]`, Only global STOREs are allowed. - **`arg`**: `Optional[KernelInfo]` @@ -104,7 +104,7 @@ class UOps(HashEnum): """ Defines the ShapeTracker for a buffer UOp `UOps.LOAD`, `UOps.STORE` or `UOps.VALID`. - - **`dtype`**: `None` + - **`dtype`**: `dtypes.void` - **`src`**: `Tuple[]` - **`arg`**: `ShapeTracker` """ @@ -129,10 +129,10 @@ class UOps(HashEnum): UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( UOp(UOps.LOAD, dtypes.int, arg=None, src=( x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), - UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), + UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), UOp(UOps.LOAD, dtypes.int, arg=None, src=( x3, - UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)) + UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)) ``` The scheduler rewrites this by pushing the expand in SWIZZLE through the reduce, to the LOAD: @@ -143,16 +143,16 @@ class UOps(HashEnum): - UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (0, 1)), src=( - UOp(UOps.LOAD, dtypes.int, arg=None, src=( - x3:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), - - UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), + - UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)), + UOp(UOps.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2, 3)), src=( + UOp(UOps.LOAD, dtypes.int, arg=None, src=( + x2:=UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int), arg=1, src=()), - + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), + + UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 32, 32), strides=(0, 0, 32, 1), offset=0, mask=None, contiguous=False),)), src=()),)),)), UOp(UOps.LOAD, dtypes.int, arg=None, src=( - x3, - - UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)) + - UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32), strides=(32, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)) + x2, - + UOp(UOps.SHAPETRACKER, None, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)) + + UOp(UOps.SHAPETRACKER, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()),)),)) ``` @@ -265,7 +265,7 @@ class UOps(HashEnum): """ STORE = auto() """ - - **`dtype`**: `None` + - **`dtype`**: `dtypes.void` - **`src`**: Similar to LOAD, the scheduler and Kernel create STOREs with a SHAPETRACKER uop in src: @@ -293,7 +293,7 @@ class UOps(HashEnum): """ Inserts a warp sync between local stores and local loads. - - **`dtype`**: `None` + - **`dtype`**: `dtypes.void` - **`src`**: `Tuple[UOp, ...]`, Only local STOREs are allowed. - **`arg`**: `None` """ @@ -301,7 +301,7 @@ class UOps(HashEnum): """ Gates a single STORE to global memory. The IF block could also contain additional UOps the STORE depends on. - - **`dtype`**: `None` + - **`dtype`**: `dtypes.void` - **`src`**: `Tuple[UOp, UOp]` - Gate UOp, can only return `dtypes.bool` @@ -314,7 +314,7 @@ class UOps(HashEnum): ``` UOp(UOps.IF, src=( UOp(UOps.ALU, dtypes.bool, (...), BinaryOps.CMPNE), - UOp(UOps.BARRIER, None, (...)))) + UOp(UOps.BARRIER, dtypes.void, (...)))) ``` The kernel: ```