give EXT schedules metadata [pr] (#6865)

This commit is contained in:
qazal
2024-10-03 20:14:18 +08:00
committed by GitHub
parent 5517a07a09
commit 17068410e6
2 changed files with 6 additions and 5 deletions

View File

@@ -38,7 +38,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
assign_targets[out.srcs[1]] = out
for x in lsi.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0))
si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)
_exec_si(si, seed)
for out in lsi.outputs:
ground_truth[out] = out.buffer.as_buffer()
@@ -60,7 +60,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
elif x.device == "NPY": rawbufs[x] = x.buffer
# copy the pre realized input
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=bytes(prerealized[x]))
si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0))
si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0), lsi.metadata)
_exec_si(si, seed)
for out in lsi.outputs:
outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))

View File

@@ -23,7 +23,7 @@ BUF_LIMIT = {"METAL": 32}
class ScheduleItem:
ast: UOp
bufs: Tuple[Buffer, ...]
metadata: Optional[Tuple[Metadata, ...]] = None
metadata: Optional[Tuple[Metadata, ...]]
@property
def outputs(self) -> Tuple[Buffer, ...]:
"""Read/write or write only buffers in the schedule."""
@@ -37,7 +37,7 @@ class ScheduleItem:
class LBScheduleItem:
ast: UOp
bufs: Tuple[LazyBuffer, ...]
metadata: Optional[Tuple[Metadata, ...]] = None
metadata: Optional[Tuple[Metadata, ...]]
@property
def outputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1]
@property
@@ -174,7 +174,8 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {}
metadata = (out.metadata,) if out.metadata is not None else None
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs), metadata), {}
# create the stores
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}