ScheduleItem uses Buffer (#3995)

* schedule Buffer

* update

* update tests

* master

* works

* remove LoadOps.WAIT

* fix compile2

* bad test

* rename and note
This commit is contained in:
George Hotz
2024-03-29 20:50:27 -07:00
committed by GitHub
parent 1bd4f01da2
commit 9eef44521b
9 changed files with 47 additions and 41 deletions

View File

@@ -36,7 +36,7 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
schedule = create_schedule([ret.lazydata])
# filter schedule that don't depend on the inputs
input_lb = [x.lazydata.base for x in inputs.values()]
input_lb = [x.lazydata.base.buffer for x in inputs.values()]
depends = set(input_lb)
for si in schedule:
if any(b in depends for b in si.inputs):
@@ -89,10 +89,10 @@ def test_vs_onnx(onnx_data, schedule:Optional[List[ScheduleItem]], inputs:Dict[s
# run code (all buffers have been allocated)
GlobalCounters.reset()
for si in schedule: lower_schedule_item(si)([x.realized for x in si.outputs+si.inputs], {})
for si in schedule: lower_schedule_item(si)(si.outputs+si.inputs, {})
new_tinygrad_out = Tensor(schedule[-1].outputs[0]).numpy()
np.testing.assert_allclose(new_torch_out, new_tinygrad_out, atol=1e-4, rtol=1e-2)
new_tinygrad_out = np.frombuffer(schedule[-1].outputs[0].as_buffer(), dtype=schedule[-1].outputs[0].dtype.np)
np.testing.assert_allclose(new_torch_out.reshape(new_tinygrad_out.shape), new_tinygrad_out, atol=1e-4, rtol=1e-2)
print("semi-thneed self-test passed!")
if __name__ == "__main__":