mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
give EXT schedules metadata [pr] (#6865)
This commit is contained in:
4
test/external/fuzz_schedule.py
vendored
4
test/external/fuzz_schedule.py
vendored
@@ -38,7 +38,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
|
||||
assign_targets[out.srcs[1]] = out
|
||||
for x in lsi.inputs:
|
||||
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
|
||||
si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0))
|
||||
si = ScheduleItem(lsi.ast, tuple(x.buffer for x in lsi.outputs+lsi.inputs if x.size != 0), lsi.metadata)
|
||||
_exec_si(si, seed)
|
||||
for out in lsi.outputs:
|
||||
ground_truth[out] = out.buffer.as_buffer()
|
||||
@@ -60,7 +60,7 @@ def fuzz_schedule(outs:List[LazyBuffer]):
|
||||
elif x.device == "NPY": rawbufs[x] = x.buffer
|
||||
# copy the pre realized input
|
||||
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=bytes(prerealized[x]))
|
||||
si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0))
|
||||
si = ScheduleItem(lsi.ast, tuple(rawbufs[x] for x in lsi.bufs if x.size != 0), lsi.metadata)
|
||||
_exec_si(si, seed)
|
||||
for out in lsi.outputs:
|
||||
outbuf = np.frombuffer(rawbufs[out].as_buffer(), _to_np_dtype(out.dtype))
|
||||
|
||||
@@ -23,7 +23,7 @@ BUF_LIMIT = {"METAL": 32}
|
||||
class ScheduleItem:
|
||||
ast: UOp
|
||||
bufs: Tuple[Buffer, ...]
|
||||
metadata: Optional[Tuple[Metadata, ...]] = None
|
||||
metadata: Optional[Tuple[Metadata, ...]]
|
||||
@property
|
||||
def outputs(self) -> Tuple[Buffer, ...]:
|
||||
"""Read/write or write only buffers in the schedule."""
|
||||
@@ -37,7 +37,7 @@ class ScheduleItem:
|
||||
class LBScheduleItem:
|
||||
ast: UOp
|
||||
bufs: Tuple[LazyBuffer, ...]
|
||||
metadata: Optional[Tuple[Metadata, ...]] = None
|
||||
metadata: Optional[Tuple[Metadata, ...]]
|
||||
@property
|
||||
def outputs(self) -> Tuple[LazyBuffer, ...]: return self.bufs[:len(self.ast.src)] if self.ast.op is UOps.SINK else self.bufs[0:1]
|
||||
@property
|
||||
@@ -174,7 +174,8 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
|
||||
"""describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
|
||||
if (out:=outs[0]).op in {MetaOps.CUSTOM, MetaOps.COPY, MetaOps.EMPTY, MetaOps.VIEW}:
|
||||
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs)), {}
|
||||
metadata = (out.metadata,) if out.metadata is not None else None
|
||||
return LBScheduleItem(UOp(UOps.EXT, out.dtype, (), (out.op, out.arg)), (out,)+tuple(x.base for x in out.srcs), metadata), {}
|
||||
# create the stores
|
||||
var_vals = merge_dicts([out.st.var_vals.copy() for out in outs])
|
||||
assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
|
||||
|
||||
Reference in New Issue
Block a user