override assign_target in fuzz_schedule (#4342)

* store assign_targets

* cleanup

* override target
This commit is contained in:
qazal
2024-04-29 11:04:04 +03:00
committed by GitHub
parent bb849a57d1
commit 774a9b0bca

View File

@@ -23,6 +23,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
# setup ground truth
ground_truth: Dict[LazyBuffer, memoryview] = {}
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
# IMPORTANT: freeze prerealized bufs before ScheduleItem exec
prerealized: Dict[LazyBuffer, memoryview] = {}
seed = Tensor._seed
@@ -30,7 +31,9 @@ def fuzz_schedule(outs: List[LazyBuffer]):
for key in ts:
for out in (ps:=prescheduled[key]).outputs:
# freeze assign state before exec
if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
if out.op is LoadOps.ASSIGN:
prerealized[out] = out.buffer.as_buffer()
assign_targets[out.srcs[1]] = out
for x in ps.inputs:
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
si = ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0))
@@ -49,7 +52,9 @@ def fuzz_schedule(outs: List[LazyBuffer]):
if out.op is LoadOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
for x in ps.inputs:
if x not in rawbufs:
if x.device == "NPY": rawbufs[x] = x.buffer
# override the assign_target after ASSIGN
if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
elif x.device == "NPY": rawbufs[x] = x.buffer
# copy the pre realized input
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0))