mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-28 08:17:58 -05:00
remove np from metal graph (#3129)
This commit is contained in:
@@ -1,5 +1,4 @@
|
||||
from typing import List, Any, Dict, cast, Optional
|
||||
import numpy as np
|
||||
import Metal
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.helpers import dedup, unwrap2
|
||||
@@ -50,7 +49,7 @@ class MetalGraph:
|
||||
icb_command.setBarrier()
|
||||
self.all_resources = dedup(all_resources)
|
||||
self.command_buffer: Any = None
|
||||
if len(var_vals): self.int_buf_view = np.frombuffer(self.int_buf.contents().as_buffer(self.int_buf.length()), np.int32)
|
||||
if len(var_vals): self.int_buf_view = self.int_buf.contents().as_buffer(self.int_buf.length()).cast('i')
|
||||
|
||||
def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False) -> Optional[float]:
|
||||
# NOTE: you at least can't update the ints if this is running
|
||||
@@ -61,7 +60,7 @@ class MetalGraph:
|
||||
for j in self.jc_idx_with_updatable_launch_dims:
|
||||
global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
|
||||
self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
|
||||
if len(var_vals): self.int_buf_view[:] = list(var_vals.values())
|
||||
for j, value in enumerate(var_vals.values()): self.int_buf_view[j] = value
|
||||
command_buffer = self.device.mtl_queue.commandBuffer()
|
||||
encoder = command_buffer.computeCommandEncoder()
|
||||
encoder.useResources_count_usage_(all_resources, len(all_resources), Metal.MTLResourceUsageRead | Metal.MTLResourceUsageWrite)
|
||||
|
||||
Reference in New Issue
Block a user