remove np from metal graph (#3129)

This commit is contained in:
nimlgen
2024-01-15 19:44:35 +03:00
committed by GitHub
parent 2ef09ca641
commit 5ec66938de

View File

@@ -1,5 +1,4 @@
from typing import List, Any, Dict, cast, Optional
import numpy as np
import Metal
from tinygrad.dtype import dtypes
from tinygrad.helpers import dedup, unwrap2
@@ -50,7 +49,7 @@ class MetalGraph:
icb_command.setBarrier()
self.all_resources = dedup(all_resources)
self.command_buffer: Any = None
if len(var_vals): self.int_buf_view = np.frombuffer(self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32)
if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
# NOTE: you at least can't update the ints if this is running
@@ -61,7 +60,7 @@ class MetalGraph:
for j in self.jc_idx_with_updatable_launch_dims:
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
for j, value in enumerate(var_vals.values()): self.int_buf_view[j] = value
command_buffer = self.device.mtl_queue.commandBuffer()
encoder = command_buffer.computeCommandEncoder()
encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)