mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] Added typehints support to speedup triton kernel launch (#1431)
One of the possible optimizations for kernel launch overhead. Basically, we are trying to avoid having to run `hasattr` and `isinstance` for each argument, by adding typehints to the kernel definition. Also, added a unit test to the regression suite to make sure we keep the launch overhead within an expected range.
This commit is contained in:
@@ -1,11 +1,20 @@
|
||||
import gc
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import tracemalloc
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
LATENCY_THRESHOLD_US = 43
|
||||
|
||||
|
||||
def test_memory_leak() -> None:
|
||||
|
||||
@@ -33,3 +42,64 @@ def test_memory_leak() -> None:
|
||||
assert end - begin < 1000
|
||||
finally:
|
||||
tracemalloc.stop()
|
||||
|
||||
|
||||
def test_kernel_launch_latency() -> None:
    """Regression test for kernel launch overhead.

    Generates an empty ``@triton.jit`` kernel whose arguments carry type
    hints (so the launcher can skip per-argument ``hasattr``/``isinstance``
    checks), launches it repeatedly at steady state, and asserts that the
    mean launch latency stays below ``LATENCY_THRESHOLD_US`` microseconds.

    Requires a CUDA device; raises AssertionError if the launch overhead
    regresses past the threshold.
    """

    def define_kernel(kernel_name: str, num_tensor_args: int) -> str:
        """Write a module containing an empty jitted kernel named
        `kernel_name` with `num_tensor_args` type-hinted tensor arguments;
        return the path of the generated file."""
        arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)])
        arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr"
        func_str = f"""
        import torch

        import triton
        import triton.language as tl

        @triton.jit
        def {kernel_name}({arg_str}):
            pass
        """
        # delete=False: the file must outlive the `with` block so it can be
        # imported; the caller is responsible for removing it afterwards.
        with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file:
            temp_file.write(textwrap.dedent(func_str))
            temp_file_path = temp_file.name

        return temp_file_path

    def import_kernel(file_path: str, kernel_name: str):
        """Import the module at `file_path` and return its `kernel_name`
        attribute (the jitted kernel object)."""
        directory, filename = os.path.split(file_path)
        module_name, _ = os.path.splitext(filename)
        sys.path.insert(0, directory)
        try:
            module = importlib.import_module(module_name)
        finally:
            # Don't leave the temp directory on sys.path for later tests.
            sys.path.remove(directory)
        kernel = getattr(module, kernel_name)
        return kernel

    def empty(*kernel_args: torch.Tensor) -> None:
        """Warm up the kernel once (first launch includes JIT compilation),
        then measure the mean steady-state launch latency and assert it is
        under the threshold."""
        first_arg = kernel_args[0]
        n_elements = first_arg.numel()
        grid = (triton.cdiv(n_elements, 1024),)
        device = torch.cuda.current_device()
        # Warmup
        empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device)
        torch.cuda.synchronize()
        # Measure launch overhead at steady state.
        # perf_counter is monotonic and higher-resolution than time.time().
        num_runs = 1000
        start_time = time.perf_counter()
        for _ in range(num_runs):
            empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device)
        end_time = time.perf_counter()
        latency_us = (end_time - start_time) / num_runs * 1e6

        assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!"

    num_tensor_args = 40
    kernel_name = 'empty_kernel'
    file_path = define_kernel(kernel_name, num_tensor_args)
    try:
        empty_kernel = import_kernel(file_path, kernel_name)

        # Initialize random tensors for the empty_kernel
        torch.manual_seed(0)
        size = 1024
        kernel_args = (torch.rand(size, device='cuda') for i in range(num_tensor_args))

        # Run empty, which would run empty_kernel internally
        empty(*kernel_args)
    finally:
        # Clean up the generated temp module so repeated runs don't leak files.
        os.remove(file_path)
|
||||
|
||||
@@ -247,14 +247,23 @@ class JITFunction(KernelInterface[T]):
|
||||
for i, arg in enumerate(regular_args):
|
||||
if i in self.do_not_specialize:
|
||||
continue
|
||||
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
|
||||
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
|
||||
f'else (False,)']
|
||||
arg_annotation = self.__annotations__.get(arg, None)
|
||||
if not arg_annotation:
|
||||
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
|
||||
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
|
||||
f'else (False,)']
|
||||
elif arg_annotation == 'torch.Tensor':
|
||||
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)']
|
||||
elif arg_annotation == 'int':
|
||||
specializations += [f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)']
|
||||
else:
|
||||
specializations += ['(False,)']
|
||||
|
||||
spec_keys = ', '.join(specializations)
|
||||
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
|
||||
|
||||
src = f"""
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
|
||||
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None):
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
@@ -268,8 +277,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
grid_0 = grid[0]
|
||||
grid_1 = grid[1] if grid_size > 1 else 1
|
||||
grid_2 = grid[2] if grid_size > 2 else 1
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
if device is None:
|
||||
device = get_current_device()
|
||||
set_current_device(device)
|
||||
if stream is None and not warmup:
|
||||
stream = get_cuda_stream(device)
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user