mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
hotfix: fix metal with vars (#5294)
* hotfix: fix metal with vars
* one more place
This commit is contained in:
@@ -24,7 +24,7 @@ class MetalGraph(GraphRunner):
|
||||
if self.icb is None: raise GraphException("create indirect command buffer failed, does your system support this?")
|
||||
|
||||
if len(self.vars): self.int_buf = self.device.allocator.alloc(len(self.vars)*dtypes.int32.itemsize)
|
||||
- all_resources = [self.int_buf] if len(self.vars) else []
+ all_resources = [self.int_buf.buf] if len(self.vars) else []
|
||||
|
||||
for j,ji in enumerate(self.jit_cache):
|
||||
prg: CompiledRunner = cast(CompiledRunner, ji.prg)
|
||||
@@ -38,7 +38,7 @@ class MetalGraph(GraphRunner):
|
||||
if b is not None:
|
||||
icb_command.setKernelBuffer_offset_atIndex_(b._buf.buf, b._buf.offset, i)
|
||||
all_resources.append(b._buf.buf)
|
||||
- for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf, self.vars.index(v)*4, len(ji.bufs)+i)
+ for i,v in enumerate(prg.p.vars): icb_command.setKernelBuffer_offset_atIndex_(self.int_buf.buf, self.vars.index(v)*4, len(ji.bufs)+i)
|
||||
if j not in self.jc_idx_with_updatable_launch_dims:
|
||||
global_size, local_size = prg.p.launch_dims(var_vals)
|
||||
icb_command.concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size))
|
||||
@@ -46,7 +46,7 @@ class MetalGraph(GraphRunner):
|
||||
|
||||
self.all_resources = dedup(all_resources)
|
||||
self.command_buffer: Any = None
|
||||
- if len(self.vars): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
+ if len(self.vars): self.int_buf_view = self.int_buf.buf.contents().as_buffer(self.int_buf.buf.length()).cast('i')
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False) -> Optional[float]:
|
||||
if self.command_buffer is not None and self.command_buffer in self.device.mtl_buffers_in_flight: wait_check(self.command_buffer)
|
||||
|
||||
Reference in New Issue
Block a user