diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 1e2a992b9d..b16605a116 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -137,8 +137,7 @@ def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp: # *** List[LazyBuffer] lowering to ScheduleItem *** def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer], - buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer], - cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp: + buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp: """recursively create a UOp""" if buf is not buf.base: st, buf = buf.st+st, buf.base if (buf, st) in cache: return cache[(buf, st)] @@ -168,7 +167,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, .. elif buf.op is MetaOps.CONTIGUOUS: assert buf in outputs, f"{buf.op} must be writable" ret = src[0] - elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (src[1].src[0], src[0])) + elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[0])) elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype) elif buf.op is UnaryOps.BITCAST: ret = src[0].bitcast(dtype) else: ret = UOp(UOps.ALU, dtype, tuple(src), buf.op) @@ -195,7 +194,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl var_vals.update(vv) ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src))) sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs)) - return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals + return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x,_ in cache if x.metadata and x not in inputs]))), var_vals # *** DAG creation: decide which LazyBuffers should realize ***