UOps.SHAPETRACKER (#6129)

* UOps.SHAPETRACKER [run_process_replay]

* no process replay
This commit is contained in:
George Hotz
2024-08-16 23:26:34 -07:00
committed by GitHub
parent 5048066e79
commit 9bc81c6db4
7 changed files with 35 additions and 37 deletions

View File

@@ -51,12 +51,11 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
# describe the computation
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *ShapeTracker.from_shape((1,)).to_uops()))
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *ShapeTracker.from_shape((1,)).to_uops()))
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, ShapeTracker.from_shape((1,)).to_uop()))
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, ShapeTracker.from_shape((1,)).to_uop()))
alu = ld_1 + ld_2
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
idx, valid = ShapeTracker.from_shape((1,)).to_uops()
st_0 = UOp(UOps.STORE, None, (output_buf, idx, alu, valid))
st_0 = UOp(UOps.STORE, None, (output_buf, ShapeTracker.from_shape((1,)).to_uop(), alu))
s = UOp(UOps.SINK, None, (st_0,))
# convert the computation to a "linearized" format (print the format)

View File

@@ -115,14 +115,14 @@ def to_uop(*a) -> UOp:
@functools.lru_cache(None)
def create_uop(lop:LazyOp) -> UOp:
if lop.op in BufferOps:
idx, valid = lop.arg.st.to_uops()
st_uop = lop.arg.st.to_uop()
membuf_dtype: DType = lop.arg.dtype
dtype = membuf_dtype.base if isinstance(membuf_dtype, ImageDType) else membuf_dtype
if lop.op is BufferOps.CONST:
return UOp(UOps.CONST, dtype, (UOp(UOps.CONST, dtype, (valid,), 0), valid), lop.arg.val)
return UOp(UOps.CONST, dtype, (st_uop,), lop.arg.val)
buf = UOp(UOps.DEFINE_GLOBAL, membuf_dtype if isinstance(membuf_dtype, ImageDType) else PtrDType(membuf_dtype), (), lop.arg.idx)
if lop.op is BufferOps.LOAD: return UOp(UOps.LOAD, dtype, (buf, idx, valid))
return UOp(UOps.STORE, None, (buf, idx, create_uop(lop.src[0]), valid))
if lop.op is BufferOps.LOAD: return UOp(UOps.LOAD, dtype, (buf, st_uop))
return UOp(UOps.STORE, None, (buf, st_uop, create_uop(lop.src[0])))
src = tuple(create_uop(x) for x in lop.src)
if lop.op is MetaOps.KERNEL: return UOp(UOps.SINK, None, src)
if lop.op in ReduceOps: return UOp(UOps.REDUCE_AXIS, src[0].dtype, src, (lop.op, lop.arg))

View File

@@ -638,10 +638,10 @@ class Kernel:
if op.op in BUFFER_UOPS:
# for locals, we use the ShapeTracker that's in the srcs
st = op.st_arg if op.src[0].op is UOps.DEFINE_LOCAL else self.sts[self.bufs.index(op)]
idx, valid = (st if apply_to_st is None else apply_to_st(st)).to_uops()
if op.op is UOps.CONST: return replace(op, src=(valid,))
if op.op is UOps.STORE: return replace(op, src=(op.src[0], idx, fixup_ast(op.src[2], apply_to_st), valid))
return replace(op, src=tuple(fixup_ast(x, apply_to_st) for x in op.src[:-2])+(idx, valid))
st_uop = (st if apply_to_st is None else apply_to_st(st)).to_uop()
if op.op is UOps.CONST: return replace(op, src=(st_uop,))
if op.op is UOps.STORE: return replace(op, src=(op.src[0], st_uop, fixup_ast(op.src[2], apply_to_st)))
return replace(op, src=(op.src[0], st_uop, *[fixup_ast(x, apply_to_st) for x in op.src[2:]]))
if op.op is UOps.REDUCE_AXIS:
reduce_idx = len(self.bufs) + self.reduceops.index(op)*2
reduceop: Union[Literal[ReduceOps.SUM], Literal[ReduceOps.MAX]] = op.arg[0]
@@ -697,10 +697,10 @@ class Kernel:
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
idx, valid = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uops()
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", idx.arg.real_size()))
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, idx, src, valid)), fix_st_fxn)
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, local_store, idx, valid)))
st_uop = ShapeTracker.from_shape(local_shape).expand(ex_shape).to_uop()
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", st_uop.arg.real_size()))
local_store = fixup_ast(UOp(UOps.STORE, tc.dtype_in, (membuf, st_uop, src)), fix_st_fxn)
srcs.append(UOp(UOps.LOAD, tc.dtype_in, (membuf, st_uop, local_store)))
else:
# for TC=2, we can't do the shapetracker fixup
srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])]
@@ -714,9 +714,9 @@ class Kernel:
start = UOp(UOps.REDUCE_AXIS, op.dtype, (fixup_ast(op.src[0], apply_to_st),), arg=(op.arg[0], axis))
local_shape = (1,) * self.global_dims + self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces] + \
(1,) * (self.first_upcast - self.group_for_reduces - self.first_reduce) + tuple([x[0] for x in self.upcasted_axis(0)])
idx, valid = ShapeTracker.from_shape(local_shape).to_uops()
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(cast(DType, op.dtype)), (), ("temp1", idx.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, UOp.store(local_buffer, idx, start, valid), idx, valid))
st_uop = ShapeTracker.from_shape(local_shape).to_uop()
local_buffer = UOp(UOps.DEFINE_LOCAL, PtrDType(cast(DType, op.dtype)), (), ("temp1", st_uop.arg.real_size()))
local_load = UOp(UOps.LOAD, op.dtype, (local_buffer, st_uop, UOp.store(local_buffer, st_uop, start)))
second_axis = tuple(range(self.first_reduce, self.first_reduce+self.group_for_reduces))
return UOp(UOps.REDUCE_AXIS, op.dtype, (local_load,), arg=(op.arg[0], second_axis))
arg = (reduceop, axis)
@@ -760,7 +760,7 @@ class Kernel:
return Program(ansiname, src, self.opts.device, self.uops, mem_estimate=mem_bytes,
global_size=[1,1,1] if self.opts.has_local else None, local_size=[1,1,1] if self.opts.has_local else None)
# the living definition of UOps.ST_IDX and UOps.ST_VALID
# the living definition of UOps.SHAPETRACKER
def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
assert ast.op is UOps.SINK and all(x.op is UOps.STORE for x in ast.src), "must be SINK"
sts: Dict[UOp, ShapeTracker] = {}
@@ -768,15 +768,15 @@ def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
if op in sts or op.op in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL}: return
# restore globals from the two stage reduce
if op.op is UOps.LOAD and op.src[0].op is UOps.DEFINE_LOCAL:
assert_valid(local_reduce:=op.src[1].src[2], op.st_arg)
assert_valid(local_reduce:=op.src[2].src[2], op.st_arg)
return sts.setdefault(op, sts[local_reduce])
for x in op.src: assert_valid(x, st)
# only reduceop is allowed to change shape, limited to turning n to 1
if op.op is UOps.REDUCE_AXIS: st = ShapeTracker.from_shape(sts[op.src[0]].reduce(op.arg[1][-1] if op.arg[0] is ReduceOps.WMMA else op.arg[1]))
else:
# movementops are pushed to the edges with ST_IDX, ST_VALID
# movementops are pushed to the edges with SHAPETRACKER
# elementwise inherits shape
st = op.arg if op.op in {UOps.ST_IDX, UOps.ST_VALID} else sts[op.src[-1]]
st = op.arg if op.op is UOps.SHAPETRACKER else sts[op.src[-1]]
for x in (op.src[1:] if op.op in BUFFER_UOPS else op.src):
if sts[x].shape != st.shape:
if prod(sts[x].shape) == prod(st.shape): raise AssertionError(f"found implicit reshape {x.op} {op.op} {sts[x].shape} != {st.shape}")
@@ -785,4 +785,4 @@ def verify_ast(ast:UOp) -> Dict[UOp, ShapeTracker]:
for out in ast.src: assert_valid(out, out.st_arg)
shape_dims = [sorted(dedup(dims)) for dims in zip(*[x.shape for x in sts.values()])]
assert all(len(x) == 1 or (len(x) == 2 and x[0] == 1) for x in shape_dims), f"shapes must have either 1 or n in each dimension, {shape_dims}"
return sts
return sts

View File

@@ -96,7 +96,7 @@ class IndependentLowerer:
if x.op is UOps.CONST: return valid.where(UOp.const(x.dtype, x.arg), UOp.const(x.dtype, 0))
buf = x.src[0]
if x.op is UOps.LOAD:
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[1]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[2]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((UOp.const(x.dtype, 0), valid) if has_valid else ()) + barrier)
# NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
if x.src[0].op is UOps.DEFINE_GLOBAL:

View File

@@ -61,7 +61,6 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
# buffer ops define ShapeTracker
if buf.realized is not None or (buf in realizes and buf not in outputs):
unbound_st, st_var_vals = st.simplify().unbind()
idx, valid = UOp(UOps.ST_IDX, dtypes.pyint, (), unbound_st), UOp(UOps.ST_VALID, dtypes.bool, (), unbound_st)
var_vals.update(st_var_vals)
# if it's a const, we generate it
if buf.op is MetaOps.CONST:
@@ -69,7 +68,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
val, var_val = val.unbind()
var_vals[val] = var_val
else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
return UOp(UOps.CONST, dtype, (valid,), val)
return UOp(UOps.CONST, dtype, (unbound_st.to_uop(),), val)
# otherwise, it's a load and we add it to the inputs
if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
@@ -78,7 +77,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
ubuf = UOp(UOps.DEFINE_GLOBAL, buf.dtype if isinstance(buf.dtype, ImageDType) else PtrDType(buf.dtype), (),
outputs.index(assign_targets[buf]) if buf in assign_targets else len(outputs)+inputs.setdefault(buf, len(inputs)))
return UOp(UOps.LOAD, dtype, (ubuf, idx, valid))
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
# reduce ops change ShapeTracker
if buf.op in ReduceOps:
@@ -147,9 +146,9 @@ def _recurse_reduceops(buf:LazyBuffer, st:ShapeTracker, realizes:Dict[LazyBuffer
def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) -> LBScheduleItem:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op is MetaOps.COPY and getenv("USE_COPY_KERNEL") and out.device.split(":")[0] == out.srcs[0].device.split(":")[0]:
idx, valid = ShapeTracker.from_shape(out.arg).to_uops()
rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), idx, valid))
wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), idx, rd, valid))
st_uop = ShapeTracker.from_shape(out.arg).to_uop()
rd = UOp(UOps.LOAD, dtypes.uint8, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.uint8), (), 1), st_uop))
wr = UOp(UOps.STORE, None, (UOp(UOps.DEFINE_GLOBAL, PtrDType(out.dtype), (), 0), st_uop, rd))
return LBScheduleItem(UOp(UOps.SINK, None, (wr,)), outs, [x.base for x in out.srcs])
if out.op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), outs, [x.base for x in out.srcs])
@@ -180,7 +179,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]) ->
output_st, vv = output_st.simplify().unbind()
if vv: var_vals.update(vv)
ubuf = UOp(UOps.DEFINE_GLOBAL, out.dtype if isinstance(out.dtype, ImageDType) else PtrDType(out.dtype), (), i)
ast.append(UOp(UOps.STORE, None, (ubuf, UOp(UOps.ST_IDX, dtypes.pyint, (), output_st), src, UOp(UOps.ST_VALID, dtypes.bool, (), output_st))))
ast.append(UOp(UOps.STORE, None, (ubuf, output_st.to_uop(), src)))
return LBScheduleItem(UOp(UOps.SINK, None, tuple(ast)), outs, list(inputs), var_vals,
dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))

View File

@@ -78,7 +78,7 @@ def exec_alu(op:Op, dtype:DType, operands): return truncate.get(dtype, lambda x:
# the order of these UOps controls the order of the toposort
class UOps(Enum):
# ops that aren't rendered
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); ST_IDX = auto(); ST_VALID = auto() # noqa: E702
SINK = auto(); EXT = auto(); EXPAND = auto(); CONTRACT = auto(); SHAPETRACKER = auto() # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_VAR = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # noqa: E702
CONST = auto(); SPECIAL = auto() # noqa: E702
NOOP = auto(); GEP = auto() # noqa: E702
@@ -120,8 +120,8 @@ class UOp:
@property
def st_arg(self) -> ShapeTracker:
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"
ret = self.src[-1]
assert ret.op in {UOps.ST_IDX, UOps.ST_VALID}, f"st_arg trying to return {ret}"
ret = self.src[0 if self.op is UOps.CONST else 1]
assert ret.op is UOps.SHAPETRACKER, f"st_arg trying to return {ret}"
return ret.arg
def ufix(self, x): return self.const(x) if not isinstance(x, UOp) else x
def cast(self, dtype=None): return type(self)(UOps.CAST, dtype, (self,))
@@ -170,7 +170,7 @@ class UOp:
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
if self.op in {UOps.ST_IDX, UOps.ST_VALID}: return self.arg.shape
if self.op is UOps.SHAPETRACKER: return self.arg.shape
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])

View File

@@ -64,7 +64,7 @@ class ShapeTracker:
def reduce(self, axis:Tuple[int, ...]) -> Tuple[sint, ...]: return tuple(1 if i in axis else s for i,s in enumerate(self.shape))
def to_uops(self) -> Tuple[UOp, UOp]: return UOp(UOps.ST_IDX, dtypes.pyint, (), self), UOp(UOps.ST_VALID, dtypes.bool, (), self)
def to_uop(self) -> UOp: return UOp(UOps.SHAPETRACKER, None, (), self)
def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \