From e88a640ca5436cff81d781b5e71b6a4a4cb102c7 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Wed, 26 Mar 2025 18:42:43 +0700 Subject: [PATCH] fix _access_resources for offset buffers (#9580) * fix _access_resources for offset buffers * test --- test/test_graph.py | 17 +++++++++++++++++ tinygrad/engine/jit.py | 4 +++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/test/test_graph.py b/test/test_graph.py index d37c4c50ba..ffc04844ed 100644 --- a/test/test_graph.py +++ b/test/test_graph.py @@ -38,6 +38,10 @@ def helper_alloc_rawbuffer(device, fill=False): rawbuf.copyin(Tensor(data).realize().lazydata.base.realized.as_buffer()) return rawbuf +def helper_create_offset_rawbuffer(base, offset=0): + x = Buffer(base.device, base.size-offset, base.dtype, base=base, offset=offset) + return x.ensure_allocated() + def helper_run_jit(jis, bufs, out_buffers): for rawbuf in out_buffers: mv = memoryview(bytearray(rawbuf.size * rawbuf.dtype.itemsize)) @@ -229,5 +233,18 @@ class TestGraph(unittest.TestCase): helper_test_graphs(Device[d0].graph, graphs) + def test_graph_offset_bufs(self): + d0 = Device.DEFAULT + if not hasattr(Device[d0].allocator, "_offset"): self.skipTest("device does not support _offset") + + b0 = [helper_alloc_rawbuffer(d0, fill=True) for _ in range(1)] + b0 += [helper_create_offset_rawbuffer(b0[0]), helper_create_offset_rawbuffer(b0[0])] + + graphs = [ + [helper_copy_op(d0, b0[0], b0[2]), helper_exec_op(d0, b0[1], [b0[0], b0[2]])], + ] + + helper_test_graphs(Device[d0].graph, graphs) + if __name__ == '__main__': unittest.main() diff --git a/tinygrad/engine/jit.py b/tinygrad/engine/jit.py index 98e9d43e2d..1400d8d772 100644 --- a/tinygrad/engine/jit.py +++ b/tinygrad/engine/jit.py @@ -120,7 +120,9 @@ class GraphRunner(Runner): if id(rawbuf.base._buf) in self.w_dependency_map: wait_nodes.append(self.w_dependency_map[id(rawbuf.base._buf)]) if i in write: if id(rawbuf.base._buf) in self.r_dependency_map: wait_nodes.extend(self.r_dependency_map.pop(id(rawbuf.base._buf))) - self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency + + for i,rawbuf in enumerate(rawbufs): + if i in write: self.w_dependency_map[id(rawbuf.base._buf)] = new_dependency else: self.r_dependency_map[id(rawbuf.base._buf)].append(new_dependency) return list({id(x):x for x in wait_nodes}.values())