hotfix: fix metal with vars (#5294)

* hotfix: fix metal with vars

* one more place
This commit is contained in:
nimlgen
2024-07-05 16:53:40 +03:00
committed by GitHub
parent 8a548b0b6e
commit d7835a705c

View File

@@ -24,7 +24,7 @@ class MetalGraph(GraphRunner):
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
all_resources = [self.int_buf] if len(self.vars) else []
all_resources = [self.int_buf.buf] if len(self.vars) else []
for j,ji in enumerate(self.jit_cache):
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
@@ -38,7 +38,7 @@ class MetalGraph(GraphRunner):
if b is not None:
icb_command.setKernelBuffer_offset_atIndex_(b._buf.buf, b._buf.offset, i)
all_resources.append(b._buf.buf)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
if j not in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = prg.p.launch_dims(var_vals)
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
@@ -46,7 +46,7 @@ class MetalGraph(GraphRunner):
self.all_resources = dedup(all_resources)
self.command_buffer: Any = None
if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
if len(self.vars): self.int_buf_view = self.int_buf.buf.contents().as_buffer(self.int_buf.buf.length()).cast('i')
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)