mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
small changes from the lazy_pm branch [pr] (#7047)
This commit is contained in:
@@ -137,8 +137,7 @@ def full_ast_rewrite(base_sink:UOp, bufs:Tuple[int, ...]) -> UOp:
|
||||
# *** List[LazyBuffer] lowering to ScheduleItem ***
|
||||
|
||||
def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], inputs:List[LazyBuffer],
|
||||
buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer],
|
||||
cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
|
||||
buf_uops:Dict[Buffer, UOp], assign_targets:Dict[LazyBuffer, LazyBuffer], cache:Dict[Tuple[LazyBuffer, ShapeTracker], UOp]) -> UOp:
|
||||
"""recursively create a UOp"""
|
||||
if buf is not buf.base: st, buf = buf.st+st, buf.base
|
||||
if (buf, st) in cache: return cache[(buf, st)]
|
||||
@@ -168,7 +167,7 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
elif buf.op is MetaOps.CONTIGUOUS:
|
||||
assert buf in outputs, f"{buf.op} must be writable"
|
||||
ret = src[0]
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (src[1].src[0], src[0]))
|
||||
elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (buf_uops[buf.buffer], src[0]))
|
||||
elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype)
|
||||
elif buf.op is UnaryOps.BITCAST: ret = src[0].bitcast(dtype)
|
||||
else: ret = UOp(UOps.ALU, dtype, tuple(src), buf.op)
|
||||
@@ -195,7 +194,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tupl
|
||||
var_vals.update(vv)
|
||||
ast.append(UOp(UOps.STORE, dtypes.void, (buf_uops[out.buffer], output_st.to_uop(), src)))
|
||||
sink = full_ast_rewrite(ast[0].sink(*ast[1:]), tuple(buf_uops[x.buffer].arg[0] for x in outs+inputs))
|
||||
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs]))), var_vals
|
||||
return LBScheduleItem(sink, tuple(outs+inputs), tuple(dedup([x.metadata for x,_ in cache if x.metadata and x not in inputs]))), var_vals
|
||||
|
||||
# *** DAG creation: decide which LazyBuffers should realize ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user