mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
Initial commit to resolve merge conflicts
rename tl.float8e4 to tl.float8e4nv to align with upstream ROCM IFU: Fix python arch issues ROCM IFU: Fix kernel launcher ROCM IFU: Fix merge conflicts fix debug build Set correct threadsPerCTA
This commit is contained in:
@@ -363,17 +363,10 @@ class JITFunction(KernelInterface[T]):
|
||||
args_signature = args_signature + ', ' if len(args_signature) > 0 else ''
|
||||
|
||||
src = f"""
|
||||
<<<<<<< HEAD
|
||||
|
||||
def {self.fn.__name__}({args_signature}, grid=None, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel
|
||||
sig_key = {sig_keys},
|
||||
=======
|
||||
import triton
|
||||
def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None):
|
||||
from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages
|
||||
sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()}
|
||||
>>>>>>> 36fc54b6f28168d3644808bfe299f1ba06a36272
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
assert num_ctas > 0
|
||||
|
||||
Reference in New Issue
Block a user