diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 7da288fb81..65c656f852 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -127,7 +127,7 @@ def full_ast_rewrite(sink:UOp, ctx:ScheduleItemContext) -> UOp: # *** List[LazyBuffer] lowering to ScheduleItem *** def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer], - bufs:Tuple[Buffer, ...], assign_targets:Dict[LazyBuffer, LazyBuffer], + buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp: """recursively create a UOp""" if buf is not buf.base: st, buf = buf.st+st, buf.base @@ -136,32 +136,34 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype # buffer ops define ShapeTracker - if buf.buffer in bufs and buf not in outputs: + # if it's a const, we generate it + if buf.op is MetaOps.CONST: + unbound_st, st_var_vals = st.simplify().unbind() + var_vals.update(st_var_vals) + if isinstance(val:=buf.arg, Variable): + val, var_val = val.unbind() + var_vals[val] = var_val + else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}" + return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0)) + # if it's realized, it's a load and we add it to the inputs + if (ubuf:=buf_uops.get(buf.buffer)) and buf not in outputs: unbound_st, st_var_vals = st.simplify().unbind() var_vals.update(st_var_vals) - # if it's a const, we generate it - if buf.op is MetaOps.CONST: - if isinstance(val:=buf.arg, Variable): - val, var_val = val.unbind() - var_vals[val] = var_val - else: assert isinstance(val, get_args(ConstType)), f"cannot create ConstBuffer with value {val}" - return UOp(UOps.VALID, dtypes.bool, (unbound_st.to_uop(),)).where(UOp.const(dtype, val), UOp.const(dtype, 0)) - # otherwise, it's a load and we add it to the inputs if buf in assign_targets and not (unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and \ ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask))): # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n" +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green")) if buf not in assign_targets and buf not in inputs: inputs.append(buf) - return UOp(UOps.LOAD, dtype, (UOp(UOps.BUFFER, buf.dtype, (), bufs.index(buf.buffer)), unbound_st.to_uop())) + return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop())) # reduce ops change ShapeTracker if buf.op in ReduceOps: - rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, bufs, assign_targets, cache) + rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, buf_uops, assign_targets, cache) return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).swizzle(st)) # elementwise ops pass shapetracker - in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, bufs, assign_targets, cache) for x in buf.srcs) + in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs) if buf.op is MetaOps.CONTIGUOUS: assert buf in outputs, f"{buf.op} must be writable" return in_uops[0] @@ -170,7 +172,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops)) return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op)) -def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[Buffer, ...]) -> Tuple[LBScheduleItem, Dict[Variable, int]]: +def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]: """describe the computation for a LazyBuffer with UOp + inputs + var_vals""" if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}: return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {} @@ -181,14 +183,14 @@ def _lower_lazybuffer(outs:List[LazyBuffer], bufs:Tuple[Buffer, ...]) -> Tuple[L ast: List[UOp] = [] inputs: List[LazyBuffer] = [] for out in outs: - src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, bufs, assign_targets, cache=cache) + src = _recursive_uop(out, output_st:=ShapeTracker.from_shape(out.shape), tuple(outs), var_vals, inputs, buf_uops, assign_targets, cache=cache) if out.op is MetaOps.ASSIGN and out.arg: assert out.arg[0].shape == out.shape, f"ASSIGN must not override output shape {out.arg[0].shape} != {out.shape}" output_st = out.arg[0] output_st, vv = output_st.simplify().unbind() var_vals.update(vv) - ast.append(UOp(UOps.STORE, dtypes.void, (UOp(UOps.BUFFER, out.dtype, (), bufs.index(out.buffer)), output_st.to_uop(), src))) - sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(dedup(bufs.index(x.buffer) for x in outs+inputs)))) + ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src))) + sink = full_ast_rewrite(ast[0].sink(*ast[1:]), ScheduleItemContext(bufs=tuple(buf_uops[x.buffer].arg for x in outs+inputs))) return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals # *** DAG creation: decide which LazyBuffers should realize *** @@ -272,7 +274,7 @@ def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, LazyBuff def _get_output_groups(outs:List[LazyBuffer]) -> \ Tuple[DefaultDict[LazyBuffer, List[LazyBuffer]], # these are the output groups - Tuple[Buffer, ...], # these are all the realizes in the graph + Dict[Buffer, UOp], # this is a map of realized Buffers to UOps.BUFFER Dict[LazyBuffer, LazyBuffer]]: # these are the buffers we ASSIGN to in this schedule """find all the realizes in the graph, group the output LazyBuffers into kernels.""" # start by just realizing the buffers passed in @@ -348,21 +350,26 @@ def _get_output_groups(outs:List[LazyBuffer]) -> \ for tr in group: del realizes[tr] output_groups: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list) + buf_uops: Dict[Buffer, UOp] = {} for buf in realizes: - if buf.realized is not None or buf.op is MetaOps.CONST: continue - output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf) + # TODO: const can be in buf_uops too, SWIZZLE on VALID pushes through! + if buf.op is MetaOps.CONST: continue + if buf.realized is None: + output_groups[reduce_for_op[buf] if buf in reduce_for_op and MULTIOUTPUT else buf].append(buf) - # make things that can't be images not images - if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or - not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())): - if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32") - buf.dtype = dtypes.float32 - # hack the underlying buffer too - if buf.base is buf: - assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer" - buf.buffer.dtype = dtypes.float32 - buf.buffer.options = None - return output_groups, tuple(dedup(x.buffer for x in realizes)), assign_targets + # make things that can't be images not images + if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or + not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())): + if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32") + buf.dtype = dtypes.float32 + # hack the underlying buffer too + if buf.base is buf: + assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer" + buf.buffer.dtype = dtypes.float32 + buf.buffer.options = None + # NOTE: UOps.BUFFER creation must come after the ImageDType fixup + buf_uops.setdefault(buf.buffer, UOp(UOps.BUFFER, buf.buffer.dtype, (), len(buf_uops))) + return output_groups, buf_uops, assign_targets SCHEDULES: List[Tuple[DefaultDict[LBScheduleItem, List[LBScheduleItem]], DefaultDict[LBScheduleItem, int]]] = [] def _graph_schedule(outs:List[LazyBuffer]) -> \ @@ -370,12 +377,12 @@ def _graph_schedule(outs:List[LazyBuffer]) -> \ DefaultDict[LBScheduleItem, int], # this is the in-degree of the graph Dict[Variable, int]]: # this has all the var values of the schedule """create a graph for realizing the outputs""" - output_groups, bufs, assign_targets = _get_output_groups(outs) + output_groups, buf_uops, assign_targets = _get_output_groups(outs) # preschedule all buffers in realizes prescheduled: List[LBScheduleItem] = [] var_vals: Dict[Variable, int] = {} for group in output_groups.values(): - prescheduled.append((ret:=_lower_lazybuffer(group, bufs))[0]) + prescheduled.append((ret:=_lower_lazybuffer(group, buf_uops))[0]) var_vals = merge_dicts([var_vals, ret[1]]) schedule_targets = {out:lsi for lsi in prescheduled for out in lsi.outputs}