From 5ec66938de8658837ee97676bba3154dea0f98e5 Mon Sep 17 00:00:00 2001 From: nimlgen <138685161+nimlgen@users.noreply.github.com> Date: Mon, 15 Jan 2024 19:44:35 +0300 Subject: [PATCH] remove np from metal graph (#3129) --- tinygrad/runtime/graph/metal.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tinygrad/runtime/graph/metal.py b/tinygrad/runtime/graph/metal.py index b03f7ab505..cd0154221f 100644 --- a/tinygrad/runtime/graph/metal.py +++ b/tinygrad/runtime/graph/metal.py @@ -1,5 +1,4 @@ from typing import List, Any, Dict, cast, Optional -import numpy as np import Metal from tinygrad.dtype import dtypes from tinygrad.helpers import dedup, unwrap2 @@ -50,7 +49,7 @@ class MetalGraph: icb_command.setBarrier() self.all_resources = dedup(all_resources) self.command_buffer: Any = None - if len(var_vals): self.int_buf_view = np.frombuffer(self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32) + if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i') def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]: # NOTE: you at least can't update the ints if this is running @@ -61,7 +60,7 @@ class MetalGraph: for j in self.jc_idx_with_updatable_launch_dims: global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals) self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501 - if len(var_vals): self.int_buf_view[:] = list(var_vals.values()) + for j, value in enumerate(var_vals.values()): self.int_buf_view[j] = value command_buffer = self.device.mtl_queue.commandBuffer() encoder = command_buffer.computeCommandEncoder() encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)