From 9bc81c6db41c058e7314b8630f346c7917418d40 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Fri, 16 Aug 2024 23:26:34 -0700
Subject: [PATCH] UOps.SHAPETRACKER (#6129)

* UOps.SHAPETRACKER [run_process_replay]

* no process replay
---
 docs/abstractions2.py          |  7 +++----
 extra/ops.py                   |  8 ++++----
 tinygrad/codegen/kernel.py     | 32 ++++++++++++++++----------------
 tinygrad/codegen/lowerer.py    |  2 +-
 tinygrad/engine/schedule.py    | 13 ++++++-------
 tinygrad/ops.py                |  8 ++++----
 tinygrad/shape/shapetracker.py |  2 +-
 7 files changed, 35 insertions(+), 37 deletions(-)

diff --git a/docs/abstractions2.py b/docs/abstractions2.py
index f43c78ad9d..301eac6054 100644
--- a/docs/abstractions2.py
+++ b/docs/abstractions2.py
@@ -51,12 +51,11 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
 # describe the computation
 buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
 buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
-ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *ShapeTracker.from_shape((1,)).to_uops()))
-ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *ShapeTracker.from_shape((1,)).to_uops()))
+ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
+ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
 alu = ld_1 + ld_2
 output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
-idx, valid = ShapeTracker.from_shape((1,)).to_uops()
-st_0 = UOp(UOps.STORE, None, (output_buf, idx, alu, valid))
+st_0 = UOp(UOps.STORE, None, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
 s = UOp(UOps.SINK, None, (st_0,))
 
 # convert the computation to a "linearized" format (print the format)
diff --git a/extra/ops.py b/extra/ops.py
index c3703b60fd..3ce45f6c90 100644
--- a/extra/ops.py
+++ b/extra/ops.py
@@ -115,14 +115,14 @@ def to_uop(*a) -> UOp:
 @functools.lru_cache(None)
 def create_uop(lop:LazyOp) -> UOp:
   if lop.op in BufferOps:
-    idx, valid = lop.arg.st.to_uops()
+    st_uop = lop.arg.st.to_uop()
     membuf_dtype: DType = lop.arg.dtype
     dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype
     if lop.op is BufferOps.CONST:
-      return UOp(UOps.CONST, dtype, (UOp(UOps.CONST, dtype, (valid,), 0), valid), lop.arg.val)
+      return UOp(UOps.CONST, dtype, (st_uop,), lop.arg.val)
     buf = UOp(UOps.DEFINE_GLOBAL, membuf_dtype if isinstance(membuf_dtype, ImageDType) else PtrDType(membuf_dtype), (), lop.arg.idx)
-    if lop.op is BufferOps.LOAD: return UOp(UOps.LOAD, dtype, (buf, idx, valid))
-    return UOp(UOps.STORE, None, (buf, idx, create_uop(lop.src[0]), valid))
+    if lop.op is BufferOps.LOAD: return UOp(UOps.LOAD, dtype, (buf, st_uop))
+    return UOp(UOps.STORE, None, (buf, st_uop, create_uop(lop.src[0])))
   src = tuple(create_uop(x) for x in lop.src)
   if lop.op is MetaOps.KERNEL: return UOp(UOps.SINK, None, src)
   if lop.op in ReduceOps: return UOp(UOps.REDUCE_AXIS, src[0].dtype, src, (lop.op, lop.arg))
diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py
index 7434525454..9a63eb2fa0 100644
--- a/tinygrad/codegen/kernel.py
+++ b/tinygrad/codegen/kernel.py
@@ -638,10 +638,10 @@ class Kernel:
       if op.op in BUFFER_UOPS:
         # for locals, we use the ShapeTracker that's in the srcs
         st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
-        idx, valid = (st if apply_to_st is None else apply_to_st(st)).to_uops()
-        if op.op is UOps.CONST: return replace(op, src=(valid,))
-        if op.op is UOps.STORE: return replace(op, src=(op.src[0], idx, fixup_ast(op.src[2], apply_to_st), valid))
-        return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src[:-2])+(idx, valid))
+        st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
+        if op.op is UOps.CONST: return replace(op, src=(st_uop,))
+        if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
+        return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
       if op.op is UOps.REDUCE_AXIS:
         reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
         reduceop: Union[Literal[ReduceOps.SUM], Literal[ReduceOps.MAX]] = op.arg[0]
@@ -697,10 +697,10 @@ class Kernel:
           for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
             st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
             local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
-            idx, valid = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uops()
-            membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", idx.arg.real_size()))
-            local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, idx, src, valid)), fix_st_fxn)
-            srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, local_store, idx, valid)))
+            st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
+            membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
+            local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
+            srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
         else:
           # for TC=2, we can't do the shapetracker fixup
           srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
@@ -714,9 +714,9 @@ class Kernel:
           start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(op.arg[0], axis))
           local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
             (1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
-          idx, valid = ShapeTracker.from_shape(local_shape).to_uops()
-          local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(cast(DType, op.dtype)), (), ("temp1", idx.arg.real_size()))
-          local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, UOp.store(local_buffer, idx, start, valid), idx, valid))
+          st_uop = ShapeTracker.from_shape(local_shape).to_uop()
+          local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(cast(DType, op.dtype)), (), ("temp1", st_uop.arg.real_size()))
+          local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
           second_axis = tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces))
           return UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
         arg = (reduceop, axis)
@@ -760,7 +760,7 @@ class Kernel:
     return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
                    global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
 
-# the living definition of UOps.ST_IDX and UOps.ST_VALID
+# the living definition of UOps.SHAPETRACKER
 def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
   assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
   sts: Dict[UOp, ShapeTracker] = {}
@@ -768,15 +768,15 @@ def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
     if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
     # restore globals from the two stage reduce
     if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
-      assert_valid(local_reduce:=op.src[1].src[2], op.st_arg)
+      assert_valid(local_reduce:=op.src[2].src[2], op.st_arg)
       return sts.setdefault(op, sts[local_reduce])
     for x in op.src: assert_valid(x, st)
     # only reduceop is allowed to change shape, limited to turning n to 1
     if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[op.src[0]].reduce(op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
     else:
-      # movementops are pushed to the edges with ST_IDX, ST_VALID
+      # movementops are pushed to the edges with SHAPETRACKER
       # elementwise inherits shape
-      st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
+      st = op.arg if op.op is UOps.SHAPETRACKER else sts[op.src[-1]]
     for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
       if sts[x].shape != st.shape:
         if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
@@ -785,4 +785,4 @@ def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
   for out in ast.src: assert_valid(out, out.st_arg)
   shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
   assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
-  return sts
\ No newline at end of file
+  return sts
diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py
index 575d6fbf6f..46018b7590 100644
--- a/tinygrad/codegen/lowerer.py
+++ b/tinygrad/codegen/lowerer.py
@@ -96,7 +96,7 @@ class IndependentLowerer:
     if x.op is UOps.CONST: return valid.where(UOp.const(x.dtype, x.arg), UOp.const(x.dtype, 0))
     buf = x.src[0]
     if x.op is UOps.LOAD:
-      barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[1]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
+      barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
       return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((UOp.const(x.dtype, 0), valid) if has_valid else ()) + barrier)
     # NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
     if x.src[0].op is UOps.DEFINE_GLOBAL:
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index b3690e5713..70690245de 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -61,7 +61,6 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...
   # buffer ops define ShapeTracker
   if buf.realized is not None or (buf in realizes and buf not in outputs):
     unbound_st, st_var_vals = st.simplify().unbind()
-    idx, valid = UOp(UOps.ST_IDX, dtypes.pyint, (), unbound_st), UOp(UOps.ST_VALID, dtypes.bool, (), unbound_st)
     var_vals.update(st_var_vals)
     # if it's a const, we generate it
     if buf.op is MetaOps.CONST:
@@ -69,7 +68,7 @@
         val, var_val = val.unbind()
         var_vals[val] = var_val
       else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
-      return UOp(UOps.CONST, dtype, (valid,), val)
+      return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val)
     # otherwise, it's a load and we add it to the inputs
     if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
         ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
@@ -78,7 +77,7 @@
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
                outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.setdefault(buf, len(inputs)))
-    return UOp(UOps.LOAD, dtype, (ubuf, idx, valid))
+    return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
 
   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
@@ -147,9 +146,9 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
 def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
-    idx, valid = ShapeTracker.from_shape(out.arg).to_uops()
-    rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), idx, valid))
-    wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), idx, rd, valid))
+    st_uop = ShapeTracker.from_shape(out.arg).to_uop()
+    rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop))
+    wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd))
     return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
   if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
     return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
@@ -180,7 +179,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
       output_st, vv = output_st.simplify().unbind()
       if vv: var_vals.update(vv)
       ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
-      ast.append(UOp(UOps.STORE, None, (ubuf, UOp(UOps.ST_IDX, dtypes.pyint, (), output_st), src, UOp(UOps.ST_VALID, dtypes.bool, (), output_st))))
+      ast.append(UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src)))
   return LBScheduleItem(UOp(UOps.SINK, None, tuple(ast)), outs, list(inputs), var_vals,
                         dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))
 
diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 31e2ba9fdc..f6ee8786fd 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -78,7 +78,7 @@ def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x:
 # the order of these UOps controls the order of the toposort
 class UOps(Enum):
   # ops that aren't rendered
-  SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702
+  SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); SHAPETRACKER = auto() # noqa: E702
   DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
   CONST = auto(); SPECIAL = auto() # noqa: E702
   NOOP = auto(); GEP = auto() # noqa: E702
@@ -120,8 +120,8 @@ class UOp:
   @property
   def st_arg(self) -> ShapeTracker:
     assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
-    ret = self.src[-1]
-    assert ret.op in {UOps.ST_IDX, UOps.ST_VALID}, f"st_arg trying to return {ret}"
+    ret = self.src[0 if self.op is UOps.CONST else 1]
+    assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
     return ret.arg
   def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
   def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
@@ -170,7 +170,7 @@ class UOp:
   def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
   @functools.cached_property
   def full_shape(self) -> Tuple[sint, ...]:
-    if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
+    if self.op is UOps.SHAPETRACKER: return self.arg.shape
     # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
     return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
   def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py
index 7e147173f3..54ef9458a7 100644
--- a/tinygrad/shape/shapetracker.py
+++ b/tinygrad/shape/shapetracker.py
@@ -64,7 +64,7 @@ class ShapeTracker:
   def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
 
-  def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self)
+  def to_uop(self) -> UOp: return UOp(UOps.SHAPETRACKER, None, (), self)
 
   def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
     idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \
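
Note: below is a minimal sketch (not part of the patch) of the convention this change lands, mirroring the updated docs/abstractions2.py example above. A ShapeTracker now lowers to a single UOps.SHAPETRACKER uop via to_uop(), replacing the to_uops() idx/valid pair, and the BUFFER_UOPS src layouts become LOAD=(buffer, st), STORE=(buffer, st, value), CONST=(st,). Import paths and buffer indices here are assumptions for illustration only:

from tinygrad.dtype import dtypes, PtrDType           # assumed import path
from tinygrad.ops import UOp, UOps                    # assumed import path
from tinygrad.shape.shapetracker import ShapeTracker  # assumed import path

st = ShapeTracker.from_shape((1,)).to_uop()       # one SHAPETRACKER uop; its arg is the ShapeTracker itself
buf_in = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
buf_out = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
ld = UOp(UOps.LOAD, dtypes.int32, (buf_in, st))   # LOAD srcs: (buffer, shapetracker)
store = UOp(UOps.STORE, None, (buf_out, st, ld))  # STORE srcs: (buffer, shapetracker, value)
sink = UOp(UOps.SINK, None, (store,))
assert store.st_arg.shape == (1,)                 # st_arg now reads src[1] (src[0] for CONST)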