From 0f167290238bb54f08d6a7eaff629fa81dcd2c70 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Sun, 10 Mar 2024 10:27:52 -0700 Subject: [PATCH] RDNA3: restore launch bounds (#3672) * bring launch bounds back * works * that second flag didn't do anything * fix linter --- tinygrad/renderer/cstyle.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 56eeda23be..c5e58941b9 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -3,7 +3,7 @@ import math, functools from collections import defaultdict, Counter from tinygrad.codegen.linearizer import UOps, UOp from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps -from tinygrad.helpers import strip_parens, getenv +from tinygrad.helpers import strip_parens, getenv, prod from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.codegen.uops import UOpGraph @@ -62,12 +62,13 @@ class CStyleLanguage(NamedTuple): out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]" return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val + def get_kernel_modifier(self, uops:UOpGraph) -> str: return "" def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str: tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501 buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else ("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs] - prg = ''.join([f"{self.kernel_prefix}void {function_name}(",] + + prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] + [', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] + [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" @@ -315,4 +316,10 @@ class HIPLanguage(CStyleLanguage): ("signed int", "int", 4), ("signed int", "int", 2)])) return super().render_kernel(function_name, kernel, bufs, uops, prefix) + def get_kernel_modifier(self, uops:UOpGraph) -> str: + requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop == UOps.SPECIAL and u.arg[1][0] == "l") + # https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size + # NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters + return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))" + HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())