CompiledASTRunner -> CompiledRunner (#4148)
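This commit is a mechanical rename: every reference to CompiledASTRunner in the CUDA, HSA, and Metal JIT-graph backends (imports, isinstance checks, casts, and type annotations) becomes CompiledRunner, with no behavior change. A rename of this shape can be reproduced with a short script; the following is only a sketch assuming the code lives under a tinygrad/ tree, not how the commit was actually produced:

# Sketch: apply the rename across a checkout. Run from the repo root;
# the tinygrad/ path is an assumption.
import pathlib

OLD, NEW = "CompiledASTRunner", "CompiledRunner"
for path in pathlib.Path("tinygrad").rglob("*.py"):
  src = path.read_text()
  if OLD in src: path.write_text(src.replace(OLD, NEW))

The first diff below is the CUDA backend (CUDAGraph).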
@@ -2,7 +2,7 @@ import ctypes, collections
 from typing import Any, Optional, Tuple, Dict, List, cast
 import tinygrad.runtime.autogen.cuda as cuda
 from tinygrad.helpers import init_c_var, GraphException, getenv
-from tinygrad.device import CompiledASTRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
+from tinygrad.device import CompiledRunner, update_stats, Buffer, MultiDeviceJITGraph, BufferXfer, Device, BufferOptions
 from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
 from tinygrad.shape.symbolic import Variable
 from tinygrad.engine.realize import ExecItem
@@ -11,7 +11,7 @@ from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_wi
 class CUDAGraph(MultiDeviceJITGraph):
   def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
     # Check all jit items are compatible.
-    if not all(isinstance(ji.prg, CompiledASTRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException
+    if not all(isinstance(ji.prg, CompiledRunner) or isinstance(ji.prg, BufferXfer) for ji in jit_cache): raise GraphException

     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
@@ -27,7 +27,7 @@ class CUDAGraph(MultiDeviceJITGraph):
     self.cpu_buffers = []

     for j,ji in enumerate(self.jit_cache):
-      if isinstance(ji.prg, CompiledASTRunner):
+      if isinstance(ji.prg, CompiledRunner):
         global_size, local_size = ji.prg.launch_dims(var_vals)

         new_node = cuda.CUgraphNode()
@@ -81,12 +81,12 @@ class CUDAGraph(MultiDeviceJITGraph):

     # Update var_vals in the c_args struct.
     for j in self.jc_idxs_with_updatable_var_vals:
-      for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
         setattr(self.updatable_nodes[j][2], f'v{i}', var_vals[v])

     # Update launch dims in the kern_params struct.
     for j in self.jc_idxs_with_updatable_launch_dims:
-      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))
+      self.set_kernel_node_launch_dims(self.updatable_nodes[j][1], *cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals))

     # Update graph nodes with the updated structs.
     for node, c_node_params, c_args, is_copy in self.updatable_nodes.values():
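The @@ -81 hunk above also shows why the runner type matters to the graph backends: between replays, CUDAGraph patches new symbolic-variable values straight into prebuilt ctypes argument structs instead of re-encoding kernel arguments. A self-contained sketch of that setattr pattern follows; ArgsStruct and the variable names are invented for illustration:

import ctypes

# Invented stand-in for a kernel's generated c_args struct with two
# symbolic-variable slots, named v0, v1 as in the diff.
class ArgsStruct(ctypes.Structure):
  _fields_ = [("v0", ctypes.c_int), ("v1", ctypes.c_int)]

args = ArgsStruct()
variables = ["batch", "seqlen"]          # stand-ins for symbolic Variables
var_vals = {"batch": 4, "seqlen": 128}   # concrete values for this replay
for i, v in enumerate(variables):
  setattr(args, f"v{i}", var_vals[v])    # same field-naming scheme as the diff
assert (args.v0, args.v1) == (4, 128)

The next diff applies the same rename in the HSA backend (HSAGraph).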
@@ -2,7 +2,7 @@ import ctypes, collections, time, itertools
 from typing import List, Any, Dict, cast, Optional, Union, Tuple
 from tinygrad.helpers import GraphException, init_c_var, round_up
 from tinygrad.buffer import Buffer, BufferOptions
-from tinygrad.device import Compiled, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
+from tinygrad.device import Compiled, CompiledRunner, BufferXfer, MultiDeviceJITGraph, update_stats, Device
 from tinygrad.shape.symbolic import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.engine.realize import ExecItem
@@ -36,7 +36,7 @@ class HSAGraph(MultiDeviceJITGraph):
     # Check all jit items are compatible.
     compiled_devices = set()
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledASTRunner): compiled_devices.add(ji.prg.device)
+      if isinstance(ji.prg, CompiledRunner): compiled_devices.add(ji.prg.device)
       elif isinstance(ji.prg, BufferXfer):
         for x in ji.rawbufs[0:2]: compiled_devices.add(Device[cast(Buffer, x).device])
       else: raise GraphException
@@ -47,13 +47,13 @@ class HSAGraph(MultiDeviceJITGraph):
     # Allocate kernel args.
     kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
     for ji in self.jit_cache:
-      if isinstance(ji.prg, CompiledASTRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
+      if isinstance(ji.prg, CompiledRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
     kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz, BufferOptions()) for dev,sz in kernargs_size.items()}

     # Fill initial arguments.
     self.ji_kargs_structs: Dict[int, ctypes.Structure] = {}
     for j,ji in enumerate(self.jit_cache):
-      if not isinstance(ji.prg, CompiledASTRunner): continue
+      if not isinstance(ji.prg, CompiledRunner): continue
       self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
       kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
       for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf)
@@ -75,7 +75,7 @@ class HSAGraph(MultiDeviceJITGraph):
     for dev in self.devices: self.virt_aql_queues[dev].submit_barrier([], self.kickoff_signals[dev])

     for j,ji in enumerate(self.jit_cache):
-      if isinstance(ji.prg, CompiledASTRunner):
+      if isinstance(ji.prg, CompiledRunner):
         wait_signals = self.access_resources(ji.rawbufs[(outs:=ji.prg.outcount):], ji.rawbufs[:outs], new_dependency=j, sync_with_aql_packets=False)
         for i in range(0, len(wait_signals), 5):
           self.virt_aql_queues[ji.prg.device].submit_barrier(wait_signals[i:i+5])
@@ -131,12 +131,12 @@ class HSAGraph(MultiDeviceJITGraph):

     # Update var_vals
     for j in self.jc_idxs_with_updatable_var_vals:
-      for i,v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
+      for i,v in enumerate(cast(CompiledRunner, self.jit_cache[j].prg).vars):
         self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[v])

     # Update launch dims
     for j in self.jc_idxs_with_updatable_launch_dims:
-      gl, lc = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
+      gl, lc = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
       self.packets[j].workgroup_size_x = lc[0]
       self.packets[j].workgroup_size_y = lc[1]
       self.packets[j].workgroup_size_z = lc[2]
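In the HSA diff, the @@ -47 hunk sizes one kernel-argument region per device by rounding each program's args struct up to 16 bytes, then carves ctypes views out of the region with from_address. Below is a standalone sketch of that packing, with host memory standing in for the device allocation and invented struct layouts:

import ctypes

def round_up(n: int, align: int) -> int: return (n + align - 1) // align * align

# Invented kernel-arg layouts; real ones come from ji.prg.clprg.args_struct_t.
class KernArgsA(ctypes.Structure): _fields_ = [("f0", ctypes.c_uint64), ("f1", ctypes.c_uint64), ("f2", ctypes.c_uint64)]
class KernArgsB(ctypes.Structure): _fields_ = [("f0", ctypes.c_uint64)]

structs = [KernArgsA, KernArgsB]
total = sum(round_up(ctypes.sizeof(s), 16) for s in structs)  # like kernargs_size
arena = (ctypes.c_uint8 * total)()       # host stand-in for the device region
ptr, views = ctypes.addressof(arena), []
for s in structs:
  views.append(s.from_address(ptr))      # struct view at the current offset
  ptr += round_up(ctypes.sizeof(s), 16)  # next struct starts 16-byte aligned
views[0].f0 = 0xdead; views[1].f0 = 0xbeef  # fill initial arguments

The last diff covers the Metal backend (MetalGraph), which captures kernels into an indirect command buffer.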
@@ -2,7 +2,7 @@ from typing import List, Any, Dict, cast, Optional
 import Metal
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import dedup, unwrap2, GraphException
-from tinygrad.device import Buffer, CompiledASTRunner, update_stats
+from tinygrad.device import Buffer, CompiledRunner, update_stats
 from tinygrad.engine.realize import ExecItem
 from tinygrad.engine.jit import get_input_replace, get_jit_stats, get_jc_idxs_with_updatable_launch_dims
 from tinygrad.shape.symbolic import Variable
@@ -10,7 +10,7 @@ from tinygrad.runtime.ops_metal import MetalDevice, wait_check

 class MetalGraph:
   def __init__(self, device:MetalDevice, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int]):
-    if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException
+    if not all(isinstance(ji.prg, CompiledRunner) for ji in jit_cache): raise GraphException

     self.jit_cache = jit_cache
     self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
@@ -30,7 +30,7 @@ class MetalGraph:
     if len(var_vals): self.int_buf = self.device.allocator.alloc(len(var_vals)*dtypes.int32.itemsize)
     all_resources = [self.int_buf] if len(var_vals) else []
     for j,ji in enumerate(self.jit_cache):
-      prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)
+      prg: CompiledRunner = cast(CompiledRunner, ji.prg)
       descriptor = Metal.MTLComputePipelineDescriptor.new()
       descriptor.setComputeFunction_(prg.clprg.fxn)
       descriptor.setSupportIndirectCommandBuffers_(True)
@@ -62,7 +62,7 @@ class MetalGraph:
     for (j,i),input_idx in self.input_replace.items():
       self.icb.indirectComputeCommandAtIndex_(j).setKernelBuffer_offset_atIndex_(input_rawbuffers[input_idx]._buf, 0, i)
     for j in self.jc_idx_with_updatable_launch_dims:
-      global_size, local_size = cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals)
+      global_size, local_size = cast(CompiledRunner, self.jit_cache[j].prg).launch_dims(var_vals)
       self.icb.indirectComputeCommandAtIndex_(j).concurrentDispatchThreadgroups_threadsPerThreadgroup_(Metal.MTLSize(*global_size), Metal.MTLSize(*local_size)) # noqa: E501
     for j, value in enumerate(var_vals.values()): self.int_buf_view[j] = value
     command_buffer = self.device.mtl_queue.commandBuffer()
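Across all three backends the rename leaves the capture gate unchanged in shape: every item in the jit cache must be a runner type the backend can encode, and anything else raises GraphException, rejecting the whole cache for graph capture. CUDA and HSA also accept BufferXfer items; Metal accepts only CompiledRunner. A distilled sketch of that gate with stand-in classes:

from dataclasses import dataclass
from typing import Any, List, Tuple

class GraphException(Exception): pass  # stand-in for tinygrad.helpers.GraphException
class CompiledRunner: pass             # renamed from CompiledASTRunner in this commit
class BufferXfer: pass                 # device-to-device copy item

@dataclass
class ExecItem: prg: Any               # minimal stand-in for a jit cache entry

def check_capturable(jit_cache: List[ExecItem], allowed: Tuple[type, ...]) -> None:
  # One unsupported item rejects the whole cache, mirroring the
  # compatibility checks in the diffs above.
  if not all(isinstance(ji.prg, allowed) for ji in jit_cache): raise GraphException

check_capturable([ExecItem(CompiledRunner())], (CompiledRunner, BufferXfer))  # CUDA/HSA-style
check_capturable([ExecItem(CompiledRunner())], (CompiledRunner,))             # Metal-style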