make alignment readable (#3766)

This commit is contained in:
nimlgen
2024-03-15 23:18:40 +03:00
committed by GitHub
parent 8ea53951c1
commit 91e181ee02
2 changed files with 5 additions and 5 deletions

View File

@@ -1,6 +1,6 @@
import ctypes, collections, time, itertools
from typing import List, Any, Dict, cast, Optional, Union, Tuple
from tinygrad.helpers import GraphException, init_c_var
from tinygrad.helpers import GraphException, init_c_var, round_up
from tinygrad.device import Compiled, Buffer, CompiledASTRunner, BufferXfer, MultiDeviceJITGraph, update_stats
from tinygrad.shape.symbolic import Variable
from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
@@ -45,9 +45,9 @@ class HSAGraph(MultiDeviceJITGraph):
self.devices: List[HSADevice] = list(compiled_devices) #type:ignore
# Allocate kernel args.
kernargs_size: Dict[HSADevice, int] = collections.defaultdict(int)
kernargs_size: Dict[Compiled, int] = collections.defaultdict(int)
for ji in self.jit_cache:
if isinstance(ji.prg, CompiledASTRunner): kernargs_size[cast(HSADevice, ji.prg.device)] += (ctypes.sizeof(ji.prg.clprg.args_struct_t)+15) & ~15
if isinstance(ji.prg, CompiledASTRunner): kernargs_size[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
kernargs_ptrs: Dict[Compiled, int] = {dev:dev.allocator._alloc(sz) for dev,sz in kernargs_size.items()}
# Fill initial arguments.
@@ -55,7 +55,7 @@ class HSAGraph(MultiDeviceJITGraph):
for j,ji in enumerate(self.jit_cache):
if not isinstance(ji.prg, CompiledASTRunner): continue
self.ji_kargs_structs[j] = ji.prg.clprg.args_struct_t.from_address(kernargs_ptrs[ji.prg.device])
kernargs_ptrs[ji.prg.device] += (ctypes.sizeof(ji.prg.clprg.args_struct_t) + 15) & ~15
kernargs_ptrs[ji.prg.device] += round_up(ctypes.sizeof(ji.prg.clprg.args_struct_t), 16)
for i in range(len(ji.rawbufs)): self.ji_kargs_structs[j].__setattr__(f'f{i}', cast(Buffer, ji.rawbufs[i])._buf)
for i in range(len(ji.prg.vars)): self.ji_kargs_structs[j].__setattr__(f'v{i}', var_vals[ji.prg.vars[i]])

View File

@@ -250,7 +250,7 @@ class HSADevice(Compiled):
def alloc_kernargs(self, sz):
if self.kernarg_next_addr + sz >= self.kernarg_start_addr + self.kernarg_pool_sz: self._new_kernargs_region(int(self.kernarg_pool_sz * 2))
result = self.kernarg_next_addr
self.kernarg_next_addr = (self.kernarg_next_addr + sz + 15) & (~15) # align to 16 bytes
self.kernarg_next_addr = round_up(self.kernarg_next_addr + sz, 16)
return result
def _new_kernargs_region(self, sz:int):