Initial commit to resolve merge conflicts

rename tl.float8e4 to tl.float8e4nv to align with upstream

ROCM IFU: Fix python arch issues

ROCM IFU: Fix kernel launcher

ROCM IFU: Fix merge conflicts

fix debug build

Set correct threadsPerCTA
This commit is contained in:
Jason Furmanek
2023-09-12 20:43:59 +00:00
parent 74fd8e9754
commit e5d7bb4fae
36 changed files with 414 additions and 1005 deletions

View File

@@ -363,17 +363,10 @@ class JITFunction(KernelInterface[T]):
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
src = f"""
<<<<<<< HEAD
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel
sig_key = {sig_keys},
=======
import triton
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
assert num_ctas > 0