mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
override assign_target in fuzz_schedule (#4342)
* store assign_targets * cleanup * override target
This commit is contained in:
9
test/external/fuzz_schedule.py
vendored
9
test/external/fuzz_schedule.py
vendored
@@ -23,6 +23,7 @@ def fuzz_schedule(outs: List[LazyBuffer]):
|
||||
|
||||
# setup ground truth
|
||||
ground_truth: Dict[LazyBuffer, memoryview] = {}
|
||||
assign_targets: Dict[LazyBuffer, LazyBuffer] = {}
|
||||
# IMPORTANT: freeze prerealized bufs before ScheduleItem exec
|
||||
prerealized: Dict[LazyBuffer, memoryview] = {}
|
||||
seed = Tensor._seed
|
||||
@@ -30,7 +31,9 @@ def fuzz_schedule(outs: List[LazyBuffer]):
|
||||
for key in ts:
|
||||
for out in (ps:=prescheduled[key]).outputs:
|
||||
# freeze assign state before exec
|
||||
if out.op is LoadOps.ASSIGN: prerealized[out] = out.buffer.as_buffer()
|
||||
if out.op is LoadOps.ASSIGN:
|
||||
prerealized[out] = out.buffer.as_buffer()
|
||||
assign_targets[out.srcs[1]] = out
|
||||
for x in ps.inputs:
|
||||
if x not in ground_truth and x.device != "NPY": prerealized[x] = x.buffer.as_buffer()
|
||||
si = ScheduleItem(ps.ast, tuple(x.buffer for x in (ps.outputs+ps.inputs) if x.size != 0))
|
||||
@@ -49,7 +52,9 @@ def fuzz_schedule(outs: List[LazyBuffer]):
|
||||
if out.op is LoadOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out])
|
||||
for x in ps.inputs:
|
||||
if x not in rawbufs:
|
||||
if x.device == "NPY": rawbufs[x] = x.buffer
|
||||
# override the assign_target after ASSIGN
|
||||
if x in assign_targets and assign_targets[x] in rawbufs: rawbufs[x] = rawbufs[assign_targets[x]]
|
||||
elif x.device == "NPY": rawbufs[x] = x.buffer
|
||||
# copy the pre realized input
|
||||
else: rawbufs[x] = Buffer(x.buffer.device, x.buffer.size, x.buffer.dtype, initial_value=prerealized[x])
|
||||
si = ScheduleItem(ps.ast, tuple(rawbufs[x] for x in (ps.outputs+ps.inputs) if x.size != 0))
|
||||
|
||||
Reference in New Issue
Block a user