mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-19 02:44:40 -05:00
RDNA3: restore launch bounds (#3672)
* bring launch bounds back * works * that second flag didn't do anything * fix linter
This commit is contained in:
@@ -3,7 +3,7 @@ import math, functools
|
||||
from collections import defaultdict, Counter
|
||||
from tinygrad.codegen.linearizer import UOps, UOp
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
|
||||
from tinygrad.helpers import strip_parens, getenv
|
||||
from tinygrad.helpers import strip_parens, getenv, prod
|
||||
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType
|
||||
from tinygrad.codegen.uops import UOpGraph
|
||||
|
||||
@@ -62,12 +62,13 @@ class CStyleLanguage(NamedTuple):
|
||||
out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
|
||||
return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val
|
||||
|
||||
def get_kernel_modifier(self, uops:UOpGraph) -> str: return ""
|
||||
def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:UOpGraph, prefix=None) -> str:
|
||||
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _,(dtype,_) in bufs) else "" # noqa: E501
|
||||
buftypes = [(name,f"{'write_only' if mutable else 'read_only'} image2d_t" if dtype.name.startswith('image') else
|
||||
("" if mutable else "const ")+self.buffer_prefix+self.render_dtype(dtype)+"*"+self.buffer_suffix if isinstance(dtype, PtrDType) else
|
||||
self.arg_int_prefix if dtype == dtypes.int else None) for name,(dtype,mutable) in bufs]
|
||||
prg = ''.join([f"{self.kernel_prefix}void {function_name}(",] +
|
||||
prg = ''.join([f"{self.kernel_prefix}void {self.get_kernel_modifier(uops)}{function_name}(",] +
|
||||
[', '.join([f'{t} {name}' for name,t in buftypes] + self.extra_args)] +
|
||||
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
|
||||
return prg if prefix is None else "\n".join(prefix)+f"\n{prg}"
|
||||
@@ -315,4 +316,10 @@ class HIPLanguage(CStyleLanguage):
|
||||
("signed int", "int", 4), ("signed int", "int", 2)]))
|
||||
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
|
||||
|
||||
def get_kernel_modifier(self, uops:UOpGraph) -> str:
|
||||
requiredMaxThreadsPerBlock = prod(u.arg[2] for u in uops if u.uop == UOps.SPECIAL and u.arg[1][0] == "l")
|
||||
# https://clang.llvm.org/docs/AttributeReference.html#amdgpu-flat-work-group-size
|
||||
# NOTE: this makes hlb_cifar10 twice as fast, there may be more gains in tweaking these parameters
|
||||
return f"__attribute__((amdgpu_flat_work_group_size(1, {requiredMaxThreadsPerBlock})))"
|
||||
|
||||
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
|
||||
|
||||
Reference in New Issue
Block a user