early uop UOps.BUFFER (only once) [run_process_replay] (#6820)
* buf_uops lookup [run_process_replay]
* next diff will be this
* fix ImageDType
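The gist of the change: instead of passing a Tuple[Buffer, ...] into lowering and calling bufs.index(...) every time a LOAD or STORE needs a buffer index, _get_output_groups now builds a Dict[Buffer, UOp] once, creating a single UOps.BUFFER UOp per realized Buffer, and lowering just looks it up. A minimal standalone sketch of that pattern follows; FakeBuffer and FakeUOp are hypothetical stand-ins for illustration, not tinygrad's Buffer/UOp classes.

# Sketch only: the "create UOps.BUFFER once, then look it up" pattern.
# FakeBuffer/FakeUOp are hypothetical stand-ins, not tinygrad's Buffer/UOp.
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass(frozen=True)
class FakeBuffer:
  name: str
  dtype: str

@dataclass(frozen=True)
class FakeUOp:          # plays the role of UOp(UOps.BUFFER, dtype, (), idx)
  dtype: str
  idx: int

def make_buf_uops(realizes: Tuple[FakeBuffer, ...]) -> Dict[FakeBuffer, FakeUOp]:
  buf_uops: Dict[FakeBuffer, FakeUOp] = {}
  for buf in realizes:
    # setdefault creates exactly one BUFFER uop per distinct buffer,
    # numbered in first-seen order; this replaces repeated bufs.index(...) calls
    buf_uops.setdefault(buf, FakeUOp(buf.dtype, len(buf_uops)))
  return buf_uops

a, b = FakeBuffer("a", "float32"), FakeBuffer("b", "float32")
uops = make_buf_uops((a, b, a))   # "a" is seen twice but gets a single uop
assert uops[a].idx == 0 and uops[b].idx == 1 and len(uops) == 2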
@@ -127,7 +127,7 @@ def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp:
 # *** List[LazyBuffer] lowering to ScheduleItem ***

 def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
-                   bufs:Tuple[Buffer, ...], assign_targets:Dict[LazyBuffer, LazyBuffer],
+                   buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer],
                    cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
   """recursively create a UOp"""
   if buf is not buf.base: st, buf = buf.st+st, buf.base
@@ -136,32 +136,34 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype

   # buffer ops define ShapeTracker
-  if buf.buffer in bufs and buf not in outputs:
-    # if it's a const, we generate it
-    if buf.op is MetaOps.CONST:
-      unbound_st, st_var_vals = st.simplify().unbind()
-      var_vals.update(st_var_vals)
-      if isinstance(val:=buf.arg, Variable):
-        val, var_val = val.unbind()
-        var_vals[val] = var_val
-      else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
-      return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
-    # if it's realized, it's a load and we add it to the inputs
+  if (ubuf:=buf_uops.get(buf.buffer)) and buf not in outputs:
+    unbound_st, st_var_vals = st.simplify().unbind()
+    var_vals.update(st_var_vals)
+    # if it's a const, we generate it
+    if buf.op is MetaOps.CONST:
+      if isinstance(val:=buf.arg, Variable):
+        val, var_val = val.unbind()
+        var_vals[val] = var_val
+      else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}"
+      return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0))
+    # otherwise, it's a load and we add it to the inputs
     if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \
         ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))):
       # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
     if buf not in assign_targets and buf not in inputs: inputs.append(buf)
-    return UOp(UOps.LOAD, dtype, (UOp(UOps.BUFFER, buf.dtype, (), bufs.index(buf.buffer)), unbound_st.to_uop()))
+    return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))

   # reduce ops change ShapeTracker
   if buf.op in ReduceOps:
-    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, bufs, assign_targets, cache)
+    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, buf_uops, assign_targets, cache)
     return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st))

   # elementwise ops pass shapetracker
-  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, bufs, assign_targets, cache) for x in buf.srcs)
+  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs)
   if buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_uops[0]
@@ -170,7 +172,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
   return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))

-def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[Buffer, ...]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
+def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
     return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {}
@@ -181,14 +183,14 @@ def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[Buffer, ...]) -> Tuple[L
   ast: List[UOp] = []
   inputs: List[LazyBuffer] = []
   for out in outs:
-    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, bufs, assign_targets, cache=cache)
+    src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, buf_uops, assign_targets, cache=cache)
     if out.op is MetaOps.ASSIGN and out.arg:
       assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}"
       output_st = out.arg[0]
     output_st, vv = output_st.simplify().unbind()
     var_vals.update(vv)
-    ast.append(UOp(UOps.STORE, dtypes.void, (UOp(UOps.BUFFER, out.dtype, (), bufs.index(out.buffer)), output_st.to_uop(), src)))
-  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(dedup(bufs.index(x.buffer) for x in outs+inputs))))
+    ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
+  sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(buf_uops[x.buffer].arg for x in outs+inputs)))
   return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals

 # *** DAG creation: decide which LazyBuffers should realize ***
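Two consequences of the hunk above are worth spelling out: the STORE destination becomes the pre-made BUFFER uop itself (buf_uops[out.buffer]), and ScheduleItemContext now gets its numeric buffer indices by reading each uop's arg instead of scanning a tuple with bufs.index. A tiny illustrative sketch of that second lookup, using plain-Python stand-ins rather than the real UOp/ScheduleItemContext API:

# Illustrative only: an index tuple read off pre-made BUFFER uops via .arg,
# mirroring tuple(buf_uops[x.buffer].arg for x in outs+inputs) above.
from typing import Dict, NamedTuple, Tuple

class BufUOp(NamedTuple):
  dtype: str
  arg: int            # the creation index, standing in for UOp.arg

buf_uops: Dict[str, BufUOp] = {"out": BufUOp("float32", 0), "in0": BufUOp("float32", 1)}
kernel_bufs: Tuple[int, ...] = tuple(buf_uops[name].arg for name in ("out", "in0"))
assert kernel_bufs == (0, 1)   # stable indices, no per-lookup bufs.index() scan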
@@ -272,7 +274,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff

 def _get_output_groups(outs:List[LazyBuffer]) -> \
     Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]],  # these are the output groups
-          Tuple[Buffer, ...],  # these are all the realizes in the graph
+          Dict[Buffer, UOp],  # this is a map of realized Buffers to UOps.BUFFER
           Dict[LazyBuffer, LazyBuffer]]:  # these are the buffers we ASSIGN to in this schedule
   """find all the realizes in the graph, group the output LazyBuffers into kernels."""
   # start by just realizing the buffers passed in
@@ -348,21 +350,26 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \
     for tr in group: del realizes[tr]

   output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
+  buf_uops: Dict[Buffer, UOp] = {}
   for buf in realizes:
-    if buf.realized is not None or buf.op is MetaOps.CONST: continue
-    output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)
+    # TODO: const can be in buf_uops too, SWIZZLE on VALID pushes through!
+    if buf.op is MetaOps.CONST: continue
+    if buf.realized is None:
+      output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf)

-    # make things that can't be images not images
-    if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
-                                              not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
-      if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
-      buf.dtype = dtypes.float32
-      # hack the underlying buffer too
-      if buf.base is buf:
-        assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
-        buf.buffer.dtype = dtypes.float32
-        buf.buffer.options = None
-  return output_groups, tuple(dedup(x.buffer for x in realizes)), assign_targets
+      # make things that can't be images not images
+      if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
+                                                not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
+        if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
+        buf.dtype = dtypes.float32
+        # hack the underlying buffer too
+        if buf.base is buf:
+          assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
+          buf.buffer.dtype = dtypes.float32
+          buf.buffer.options = None
+    # NOTE: UOps.BUFFER creation must come after the ImageDType fixup
+    buf_uops.setdefault(buf.buffer, UOp(UOps.BUFFER, buf.buffer.dtype, (), len(buf_uops)))
+  return output_groups, buf_uops, assign_targets

 SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = []
 def _graph_schedule(outs:List[LazyBuffer]) -> \
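The NOTE in the hunk above is the subtle part: the BUFFER uop snapshots buf.buffer.dtype at creation time, so it has to be created after the ImageDType fixup or it would carry the stale image dtype. A small sketch of that ordering concern, using hypothetical stand-in classes rather than tinygrad's Buffer/UOp:

# Sketch only: why the BUFFER uop must be made after the dtype fixup.
# FakeBuffer/FakeBufferUOp are hypothetical stand-ins, not tinygrad classes.
from dataclasses import dataclass
from typing import Dict

@dataclass
class FakeBuffer:
  dtype: str                  # mutable before allocation, like Buffer.dtype

@dataclass(frozen=True)
class FakeBufferUOp:
  dtype: str                  # frozen at creation time, like the uop's dtype
  idx: int

buf = FakeBuffer(dtype="imagef((4,4))")
buf_uops: Dict[int, FakeBufferUOp] = {}

# creating the uop here would freeze "imagef((4,4))" into it; instead,
# do the ImageDType fixup first, then create the uop (the order the NOTE asks for)
buf.dtype = "float32"                                                  # fixup
buf_uops.setdefault(id(buf), FakeBufferUOp(buf.dtype, len(buf_uops)))  # then BUFFER uop
assert buf_uops[id(buf)].dtype == "float32"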
@@ -370,12 +377,12 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \
     DefaultDict[LBScheduleItem, int],  # this is the in-degree of the graph
     Dict[Variable, int]]:  # this has all the var values of the schedule
   """create a graph for realizing the outputs"""
-  output_groups, bufs, assign_targets = _get_output_groups(outs)
+  output_groups, buf_uops, assign_targets = _get_output_groups(outs)
   # preschedule all buffers in realizes
   prescheduled: List[LBScheduleItem] = []
   var_vals: Dict[Variable, int] = {}
   for group in output_groups.values():
-    prescheduled.append((ret:=_lower_lazybuffer(group, bufs))[0])
+    prescheduled.append((ret:=_lower_lazybuffer(group, buf_uops))[0])
     var_vals = merge_dicts([var_vals, ret[1]])
   schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}