[FRONTEND] Added typehints support to speedup triton kernel launch (#1431)

One of the possible optimizations for kernel launch overhead. Basically,
we avoid having to run `hasattr` and `isinstance` for each argument by
adding typehints to the kernel definition. Also added a regression unit
test to make sure the launch overhead stays within an expected range.
This commit is contained in:
zahimoud
2023-03-28 22:37:34 -07:00
committed by GitHub
parent ee593fca0b
commit 73b124155b
2 changed files with 86 additions and 6 deletions

View File

@@ -1,11 +1,20 @@
import gc
import importlib
import os
import sys
import tempfile
import textwrap
import time
import tracemalloc
from typing import Tuple
import torch
import triton
import triton.language as tl
LATENCY_THRESHOLD_US = 43
def test_memory_leak() -> None:
@@ -33,3 +42,64 @@ def test_memory_leak() -> None:
assert end - begin < 1000
finally:
tracemalloc.stop()
def test_kernel_launch_latency() -> None:
    """Regression test for kernel launch overhead.

    Generates an empty @triton.jit kernel with many type-hinted tensor
    arguments, launches it repeatedly, and asserts that the average
    launch latency stays below LATENCY_THRESHOLD_US microseconds.
    """

    def define_kernel(kernel_name: str, num_tensor_args: int) -> str:
        """Write a module defining an empty @triton.jit kernel named
        `kernel_name` with `num_tensor_args` type-hinted tensor args;
        return the path of the generated file (caller must delete it)."""
        arg_str = ",".join([f"arg{i}: torch.Tensor" for i in range(num_tensor_args)])
        arg_str += ", n_elements: int, BLOCK_SIZE: tl.constexpr"
        func_str = f"""
        import torch
        import triton
        import triton.language as tl

        @triton.jit
        def {kernel_name}({arg_str}):
            pass
        """
        with tempfile.NamedTemporaryFile(mode="w+t", suffix=".py", delete=False) as temp_file:
            temp_file.write(textwrap.dedent(func_str))
            return temp_file.name

    def import_kernel(file_path: str, kernel_name: str):
        """Import `kernel_name` from the module at `file_path`."""
        directory, filename = os.path.split(file_path)
        module_name, _ = os.path.splitext(filename)
        sys.path.insert(0, directory)
        try:
            module = importlib.import_module(module_name)
        finally:
            # Undo the path insertion so other tests see an unmodified sys.path.
            sys.path.remove(directory)
        return getattr(module, kernel_name)

    def empty(*kernel_args: torch.Tensor) -> None:
        """Warm up empty_kernel, then measure and check its steady-state
        launch latency."""
        first_arg = kernel_args[0]
        n_elements = first_arg.numel()
        grid = (triton.cdiv(n_elements, 1024),)
        device = torch.cuda.current_device()
        # Warmup: the first launch includes JIT compilation.
        empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device)
        torch.cuda.synchronize()
        # Measure launch overhead at steady state.
        num_runs = 1000
        start_time = time.time()
        for _ in range(num_runs):
            empty_kernel[grid](*kernel_args, n_elements, BLOCK_SIZE=1024, device=device)
        end_time = time.time()
        latency_us = (end_time - start_time) / num_runs * 1e6
        assert latency_us < LATENCY_THRESHOLD_US, "Kernel launch time has increased!"

    num_tensor_args = 40
    kernel_name = 'empty_kernel'
    file_path = define_kernel(kernel_name, num_tensor_args)
    try:
        empty_kernel = import_kernel(file_path, kernel_name)
        # Initialize deterministic random tensors for the empty_kernel.
        torch.manual_seed(0)
        size = 1024
        # Materialize as a tuple (not a generator) so the args can be
        # reused safely and indexed.
        kernel_args = tuple(torch.rand(size, device='cuda') for _ in range(num_tensor_args))
        # Run empty, which launches empty_kernel internally.
        empty(*kernel_args)
    finally:
        # NamedTemporaryFile(delete=False) is not auto-removed; clean up.
        os.remove(file_path)

View File

@@ -247,14 +247,23 @@ class JITFunction(KernelInterface[T]):
for i, arg in enumerate(regular_args):
if i in self.do_not_specialize:
continue
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
f'else (False,)']
arg_annotation = self.__annotations__.get(arg, None)
if not arg_annotation:
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") '
f'else ({arg} % {JITFunction.divisibility} == 0, {arg} == 1) if isinstance({arg}, int) '
f'else (False,)']
elif arg_annotation == 'torch.Tensor':
specializations += [f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)']
elif arg_annotation == 'int':
specializations += [f'({arg} % {JITFunction.divisibility} == 0, {arg} == 1)']
else:
specializations += ['(False,)']
spec_keys = ', '.join(specializations)
grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names])
src = f"""
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False):
def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stages=3, extern_libs=None, stream=None, warmup=False, device=None):
sig_key = {sig_keys},
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
@@ -268,8 +277,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
grid_0 = grid[0]
grid_1 = grid[1] if grid_size > 1 else 1
grid_2 = grid[2] if grid_size > 2 else 1
device = get_current_device()
set_current_device(device)
if device is None:
device = get_current_device()
set_current_device(device)
if stream is None and not warmup:
stream = get_cuda_stream(device)
try: