early uop UOps.BUFFER (only once) [run_process_replay] (#6820)

* buf_uops lookup [run_process_replay]

* next diff will be this

* fix ImageDType
This commit is contained in:
qazal
2024-10-01 08:46:05 +08:00
committed by GitHub
parent e213bea426
commit a1dee0e532

View File

@@ -127,7 +127,7 @@ def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
# *** List[LazyBuffer] lowering to ScheduleItem ***
def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
bufs:Tuple[Buffer, ...], assign_targets:Dict[LazyBuffer, LazyBuffer],
buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer],
cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
"""recursively create a UOp"""
if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -136,32 +136,34 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
# buffer ops define ShapeTracker
if buf.buffer in bufs and buf not in outputs:
# if it's a const, we generate it
if buf.op is MetaOps.CONST:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
if isinstance(val:=buf.arg, Variable):
val, var_val = val.unbind()
var_vals[val] = var_val
else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
# if it's realized, it's a load and we add it to the inputs
if (ubuf:=buf_uops.get(buf.buffer)) and buf not in outputs:
unbound_st, st_var_vals = st.simplify().unbind()
var_vals.update(st_var_vals)
# if it's a const, we generate it
if buf.op is MetaOps.CONST:
if isinstance(val:=buf.arg, Variable):
val, var_val = val.unbind()
var_vals[val] = var_val
else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
# otherwise, it's a load and we add it to the inputs
if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
if buf not in assign_targets and buf not in inputs: inputs.append(buf)
return UOp(UOps.LOAD, dtype, (UOp(UOps.BUFFER, buf.dtype, (), bufs.index(buf.buffer)), unbound_st.to_uop()))
return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))
# reduce ops change ShapeTracker
if buf.op in ReduceOps:
rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, bufs, assign_targets, cache)
rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, buf_uops, assign_targets, cache)
return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st))
# elementwise ops pass shapetracker
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, bufs, assign_targets, cache) for x in buf.srcs)
in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs)
if buf.op is MetaOps.CONTIGUOUS:
assert buf in outputs, f"{buf.op} must be writable"
return in_uops[0]
@@ -170,7 +172,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[Buffer, ...]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {}
@@ -181,14 +183,14 @@ def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[Buffer, ...]) -> Tuple[L
ast: List[UOp] = []
inputs: List[LazyBuffer] = []
for out in outs:
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, bufs, assign_targets, cache=cache)
src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, buf_uops, assign_targets, cache=cache)
if out.op is MetaOps.ASSIGN and out.arg:
assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
output_st = out.arg[0]
output_st, vv = output_st.simplify().unbind()
var_vals.update(vv)
ast.append(UOp(UOps.STORE, dtypes.void, (UOp(UOps.BUFFER, out.dtype, (), bufs.index(out.buffer)), output_st.to_uop(), src)))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(dedup(bufs.index(x.buffer) for x in outs+inputs))))
ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(buf_uops[x.buffer].arg for x in outs+inputs)))
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
# *** DAG creation: decide which LazyBuffers should realize ***
@@ -272,7 +274,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff
def _get_output_groups(outs:List[LazyBuffer]) -> \
Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # these are the output groups
Tuple[Buffer, ...], # these are all the realizes in the graph
Dict[Buffer, UOp], # this is a map of realized Buffers to UOps.BUFFER
Dict[LazyBuffer, LazyBuffer]]: # these are the buffers we ASSIGN to in this schedule
"""find all the realizes in the graph, group the output LazyBuffers into kernels."""
# start by just realizing the buffers passed in
@@ -348,21 +350,26 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
for tr in group: del realizes[tr]
output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
buf_uops: Dict[Buffer, UOp] = {}
for buf in realizes:
if buf.realized is not None or buf.op is MetaOps.CONST: continue
output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
# TODO: const can be in buf_uops too, SWIZZLE on VALID pushes through!
if buf.op is MetaOps.CONST: continue
if buf.realized is None:
output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
# make things that can't be images not images
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
buf.dtype = dtypes.float32
# hack the underlying buffer too
if buf.base is buf:
assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
buf.buffer.dtype = dtypes.float32
buf.buffer.options = None
return output_groups, tuple(dedup(x.buffer for x in realizes)), assign_targets
# make things that can't be images not images
if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
buf.dtype = dtypes.float32
# hack the underlying buffer too
if buf.base is buf:
assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
buf.buffer.dtype = dtypes.float32
buf.buffer.options = None
# NOTE: UOps.BUFFER creation must come after the ImageDType fixup
buf_uops.setdefault(buf.buffer, UOp(UOps.BUFFER, buf.buffer.dtype, (), len(buf_uops)))
return output_groups, buf_uops, assign_targets
SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
def _graph_schedule(outs:List[LazyBuffer]) -> \
@@ -370,12 +377,12 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph
Dict[Variable, int]]: # this has all the var values of the schedule
"""create a graph for realizing the outputs"""
output_groups, bufs, assign_targets = _get_output_groups(outs)
output_groups, buf_uops, assign_targets = _get_output_groups(outs)
# preschedule all buffers in realizes
prescheduled: List[LBScheduleItem] = []
var_vals: Dict[Variable, int] = {}
for group in output_groups.values():
prescheduled.append((ret:=_lower_lazybuffer(group, bufs))[0])
prescheduled.append((ret:=_lower_lazybuffer(group, buf_uops))[0])
var_vals = merge_dicts([var_vals, ret[1]])
schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}