mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 23:25:04 -05:00
ScheduleItem uses Buffer (#3995)
* schedule Buffer * update * update tests * master * works * remove LoadOps.WAIT * fix compile2 * bad test * rename and note
This commit is contained in:
@@ -36,7 +36,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
|
||||
schedule = create_schedule([ret.lazydata])
|
||||
|
||||
# filter schedule that don't depend on the inputs
|
||||
input_lb = [x.lazydata.base for x in inputs.values()]
|
||||
input_lb = [x.lazydata.base.buffer for x in inputs.values()]
|
||||
depends = set(input_lb)
|
||||
for si in schedule:
|
||||
if any(b in depends for b in si.inputs):
|
||||
@@ -89,10 +89,10 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
|
||||
|
||||
# run code (all buffers have been allocated)
|
||||
GlobalCounters.reset()
|
||||
for si in schedule: lower_schedule_item(si)([x.realized for x in si.outputs+si.inputs], {})
|
||||
for si in schedule: lower_schedule_item(si)(si.outputs+si.inputs, {})
|
||||
|
||||
new_tinygrad_out = Tensor(schedule[-1].outputs[0]).numpy()
|
||||
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
|
||||
new_tinygrad_out = np.frombuffer(schedule[-1].outputs[0].as_buffer(), dtype=schedule[-1].outputs[0].dtype.np)
|
||||
np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
|
||||
print("semi-thneed self-test passed!")
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
Reference in New Issue
Block a user