diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index 5d9bc8887b..d2cf1cff05 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -155,17 +155,18 @@ if getenv("RUN_PROCESS_REPLAY"):
 
 # *** List[LazyBuffer] lowering to ScheduleItem ***
 
-def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], cache:Dict[LazyBuffer, UOp]) -> UOp:
+def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], metadata:Dict[UOp, Metadata],
+           cache:Dict[LazyBuffer, UOp]) -> UOp:
   if (r:=cache.get(buf)) is not None: return r
   if buf is not buf.base:
-    cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, cache).view(buf.st)
+    cache[buf] = ret = to_uop(buf.base, outputs, inputs, buf_uops, metadata, cache).view(buf.st)
     return ret
   if buf.op is MetaOps.CONST: return buf_uops[buf.buffer]
   dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype
   if (ubuf:=buf_uops.get(buf.buffer)) is not None and buf not in outputs:
     if not any(x.buffer is buf.buffer for x in outputs) and buf not in inputs: inputs.append(buf)
     return UOp.load(ubuf, buf.st.to_uop(), dtype=dtype)
-  src = tuple(to_uop(x, outputs, inputs, buf_uops, cache) for x in buf.srcs)
+  src = tuple(to_uop(x, outputs, inputs, buf_uops, metadata, cache) for x in buf.srcs)
   if buf.op in ReduceOps: ret = src[0].r(buf.op, buf.arg)
   elif buf.op is MetaOps.CONTIGUOUS: ret = UOp(UOps.CONTIGUOUS, dtype, src)
   elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[1]), buf.arg)
@@ -174,14 +175,16 @@ def to_uop(buf:LazyBuffer, outputs:List[LazyBuffer], inputs:List[LazyBuffer], bu
   elif buf.op is UnaryOps.BITCAST: ret = UOp(UOps.BITCAST, dtype, src)
   else: ret = UOp(UOps.ALU, dtype, src, buf.op)
   cache[buf] = ret
+  if buf.metadata is not None: metadata[ret] = buf.metadata
   return ret
 
 def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_vals:Dict[Variable, int]) -> LBScheduleItem:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
   cache: Dict[LazyBuffer, UOp] = {}
   inputs: List[LazyBuffer] = []
+  metadata: Dict[UOp, Metadata] = {}
   sink = UOp(UOps.SINK, src=tuple(UOp.store(buf_uops[out.buffer], ShapeTracker.from_shape(out.shape).to_uop(),
-                                            to_uop(out, outs, inputs, buf_uops, cache)) for out in outs))
+                                            to_uop(out, outs, inputs, buf_uops, metadata, cache)) for out in outs))
   sink = full_ast_rewrite(sink, tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs), var_vals)
   # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
   if len(assign_targets:=[x.src[0] for x in sink.sparents if x.op is UOps.ASSIGN]) != 0:
@@ -189,8 +192,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp], var_val
         and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in sink.sparents if x.op is UOps.LOAD and x.src[0] in assign_targets):
       raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
                          +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-  return LBScheduleItem(sink, tuple(outs+inputs),
-                        tuple(dedup([x.metadata for x in cache if x.metadata is not None and (x.base in outs or x.base.buffer not in buf_uops)])))
+  return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup(metadata.values())))
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 