diff --git a/tinygrad/runtime/graph/hsa.py b/tinygrad/runtime/graph/hsa.py index 5f8fa5cfe2..ca369456d5 100644 --- a/tinygrad/runtime/graph/hsa.py +++ b/tinygrad/runtime/graph/hsa.py @@ -1,6 +1,6 @@ import ctypes, collections, time, itertools from typing import List, Any, Dict, cast, Optional, Union, Tuple -from tinygrad.helpers import GraphException, init_c_var +from tinygrad.helpers import GraphException, init_c_var, round_up from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats from tinygrad.shape.symbolic import Variable from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler @@ -45,9 +45,9 @@ class HSAGraph(MultiDeviceJITGraph): self.devices: List[HSADevice] = list(compiled_devices) #type:ignore # Allocate kernel args. - kernargs_size: Dict[HSADevice, int] = collections.defaultdict(int) + kernargs_size: Dict[Compiled, int] = collections.defaultdict(int) for ji in self.jit_cache: - if isinstance(ji.prg, CompiledASTRunner): kernargs_size[cast(HSADevice, ji.prg.device)] += (ctypes.sizeof(ji.prg.clprg.args_struct_t)+15) & ~15 + if isinstance(ji.prg, CompiledASTRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16) kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz) for dev,sz in kernargs_size.items()} # Fill initial arguments. @@ -55,7 +55,7 @@ class HSAGraph(MultiDeviceJITGraph): for j,ji in enumerate(self.jit_cache): if not isinstance(ji.prg, CompiledASTRunner): continue self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device]) - kernargs_ptrs[ji.prg.device] += (ctypes.sizeof(ji.prg.clprg.args_struct_t) + 15) & ~15 + kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16) for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf) for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]]) diff --git a/tinygrad/runtime/ops_hsa.py b/tinygrad/runtime/ops_hsa.py index 1043cd5150..aabb5e0856 100644 --- a/tinygrad/runtime/ops_hsa.py +++ b/tinygrad/runtime/ops_hsa.py @@ -250,7 +250,7 @@ class HSADevice(Compiled): def alloc_kernargs(self, sz): if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz: self._new_kernargs_region(int(self.kernarg_pool_sz * 2)) result = self.kernarg_next_addr - self.kernarg_next_addr = (self.kernarg_next_addr + sz + 15) & (~15) # align to 16 bytes + self.kernarg_next_addr = round_up(self.kernarg_next_addr + sz, 16) return result def _new_kernargs_region(self, sz:int):