[FRONTEND] Added interpreter mode (#1573)

Simple mechanism to run Triton kernels on PyTorch for debugging purpose
(upstream from Kernl).

Todo:
- random grid iteration
- support of atomic ops
- more unit tests
- cover new APIs?
This commit is contained in:
Michaël Benesty
2023-05-08 23:28:20 +02:00
committed by GitHub
parent 132fe1bb01
commit 858a2f0a5e
11 changed files with 1063 additions and 68 deletions

View File

@@ -249,6 +249,7 @@ setup(
"triton/_C",
"triton/common",
"triton/compiler",
"triton/debugger",
"triton/language",
"triton/language/extra",
"triton/ops",

View File

@@ -0,0 +1,69 @@
import random
import torch
import triton
import triton.language as tl
from triton.debugger.debugger import program_ids_from_grid
def test_addition():
@triton.jit(interpret=True)
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
a = torch.rand((128,), device="cuda")
b = torch.rand((128,), device="cuda")
expected = a + b
output = torch.empty((128,), device="cuda")
def grid(meta):
return (triton.cdiv(128, meta["BLOCK_SIZE"]),)
add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)
assert torch.allclose(expected, output, atol=1e-2, rtol=0)
def test_program_ids_from_grid():
random.seed(123)
grid = (3, 4)
expected_combinations = 3 * 4
unique_combinations = set(program_ids_from_grid(grid))
assert len(unique_combinations) == expected_combinations
first_run = list(program_ids_from_grid(grid))
second_run = list(program_ids_from_grid(grid))
assert first_run != second_run
def test_atomic():
@triton.jit(interpret=True)
def atomic(
x_ptr,
):
pid = tl.program_id(axis=0)
tl.atomic_add(x_ptr + pid, 1)
t = tl.atomic_xchg(x_ptr + pid, 3)
t += 1 # 2
tl.atomic_cas(x_ptr + pid, 3, t) # match
tl.atomic_cas(x_ptr + pid, 40, 9) # no match
nb_dim = 16
a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda")
atomic[(nb_dim, )](a)
assert torch.allclose(a, torch.full_like(a, 2))

View File

@@ -1605,68 +1605,68 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)
layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
]
# layouts = [
# BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
# BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
# BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
# ]
@pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
@pytest.mark.parametrize("src_layout", layouts)
def test_reduce_2d(M, N, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%12 = "tt.reduce"(%11) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('int32')
x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
z = np.zeros((1,)).astype('int32')
x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)
pgm = kernel[(1, 1, 1)](x_tri, z_tri)
z_ref = np.sum(x)
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
# @pytest.mark.parametrize("M, N", [[32, 128], [128, 128], [128, 32]])
# @pytest.mark.parametrize("src_layout", layouts)
# def test_reduce_2d(M, N, src_layout, device='cuda'):
# ir = f"""
# #src = {src_layout}
# module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
# tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
# %cst = arith.constant dense<{M}> : tensor<{M}x1xi32, #src>
# %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
# %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
# %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
# %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
# %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
# %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
# %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
# %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
# %8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
# %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
# %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
# %11 = "tt.reduce"(%10) ({{
# ^bb0(%arg2: i32, %arg3: i32):
# %13 = arith.addi %arg2, %arg3 : i32
# tt.reduce.return %13 : i32
# }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
# %12 = "tt.reduce"(%11) ({{
# ^bb0(%arg2: i32, %arg3: i32):
# %13 = arith.addi %arg2, %arg3 : i32
# tt.reduce.return %13 : i32
# }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
# tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
# tt.return
# }}
# }}
# """
# import tempfile
# with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
# f.write(ir)
# f.flush()
# kernel = triton.compile(f.name)
#
# rs = RandomState(17)
# x = rs.randint(0, 4, (M, N)).astype('int32')
# x = (x.view('uint32') & np.uint32(0xffffe000)).view('int32')
#
# z = np.zeros((1,)).astype('int32')
#
# x_tri = torch.tensor(x, device=device)
# z_tri = torch.tensor(z, device=device)
#
# pgm = kernel[(1, 1, 1)](x_tri, z_tri)
#
# z_ref = np.sum(x)
#
# np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
def test_generic_reduction(device='cuda'):

View File

@@ -18,6 +18,8 @@ from .runtime import (
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from .debugger.debugger import program_ids_from_grid
from . import language
from . import testing
@@ -41,6 +43,7 @@ __all__ = [
"runtime",
"TensorWrapper",
"testing",
"program_ids_from_grid",
]

View File

View File

@@ -0,0 +1,9 @@
from typing import Tuple
import dataclasses
@dataclasses.dataclass
class ExecutionContext:
program_id: Tuple[int]
program_size: Tuple[int]

View File

@@ -0,0 +1,170 @@
import itertools
import random
from typing import Tuple
import triton
import triton.language as tl
from .core import ExecutionContext
from .memory_map import MemoryMap
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
debugger_constexpr)
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
tl_method_backup = {}
def get_proxy_method(proxy, name):
method = getattr(proxy, name)
def fun(*args, **kwarg):
return method(*args, **kwarg)
return fun
def attach_triton(module, proxy):
method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"]
for name in method_list:
if hasattr(module, name):
attr = getattr(module, name)
tl_method_backup[name] = attr
if callable(attr):
setattr(module, name, get_proxy_method(proxy, name))
else:
setattr(module, name, getattr(proxy, name))
def detach_triton(module):
for name, method in tl_method_backup.items():
setattr(module, name, method)
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
# reverse the grid dimensions and generate the range for each dimension
reversed_grid = reversed(grid)
ranges_for_each_dimension = [range(dim) for dim in reversed_grid]
# gen all combinations
index_combinations = list(itertools.product(*ranges_for_each_dimension))
random.shuffle(index_combinations)
for index_combination in index_combinations:
yield index_combination
class DebuggerFunction:
def __init__(self, func, grid=(1,)):
self.func = func
self.grid = grid
def _is_constexpr(self, name):
return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr
def _get_constexpr(self):
result = []
for name, annotation in self.func.__annotations__.items():
if annotation is triton.language.core.constexpr:
result.append(name)
return result
def _assert_constexpr(self, **kwargs):
constexp = self._get_constexpr()
missing = [i for i in constexp if i not in kwargs.keys()]
assert len(missing) == 0, f"You must specify constexpr {missing}"
def _get_grid(self, **kwargs):
if callable(self.grid):
return self.grid(kwargs)
else:
return self.grid
def __call__(self, *args, **kwargs):
self._assert_constexpr(**kwargs)
memory = MemoryMap()
def convert_arg(v):
name, arg = v
if torch.is_tensor(arg):
ptr = memory.add_tensor(arg)
return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
if self._is_constexpr(name):
return debugger_constexpr(arg)
return WrappedTensor(_primitive_to_tensor(arg))
new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}
grid = self._get_grid(**kwargs)
for program_id in program_ids_from_grid(grid):
proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
attach_triton(tl, proxy)
self.func(*new_args, **new_kwargs)
detach_triton(tl)
class GridSelector:
"""
Entry point of the debugger
"""
def __init__(self, func):
version = torch.__version__
assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
self.func = func
def __getitem__(self, grid):
return DebuggerFunction(self.func, grid)
def __call__(self, *args, **kwargs):
return DebuggerFunction(self.func)(*args, **kwargs)
class AutotuneGridSelector:
def __init__(self, func, autotune_params):
self.func = func
self.autotune_params = autotune_params
def __getitem__(self, grid):
return AutotuneRunner(self.func, self.autotune_params, grid)
def __call__(self, *args, **kwargs):
return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
class AutotuneRunner:
def __init__(self, func, autotune_params, grid=None):
self.func = func
self.autotune_params = autotune_params
self.grid = grid
def __call__(self, *args, **kwargs):
assert len(self.autotune_params["configs"]) >= 1
for config in self.autotune_params["configs"][1:]:
def convert_arg(v):
if torch.is_tensor(v):
return torch.clone(v)
return v
new_args = tuple(map(convert_arg, args))
new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
if self.grid:
self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
else:
self.func(*new_args, **new_kwargs, **config.kwargs)
main_config = self.autotune_params["configs"][0]
if self.grid:
self.func[self.grid](*args, **kwargs, **main_config.kwargs)
else:
self.func(*args, **kwargs, **main_config.kwargs)
def triton_debug_autotune(**kwars):
def wrapper(func):
return AutotuneGridSelector(func, kwars)
return wrapper

View File

@@ -0,0 +1,100 @@
import dataclasses
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
@dataclasses.dataclass
class RegisteredStorage:
storage: torch.Storage
dtype: torch.dtype
size: int
ptr: int
@property
def end_ptr(self) -> int:
return self.ptr + self.size
@property
def access_tensor(self) -> torch.Tensor:
return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)
def ensure_immutable(self):
assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
class MemoryMap:
storages: [RegisteredStorage]
def __init__(self):
self.storages = []
def _get_registered_storage(self, pointer: torch.Tensor):
max_pointer = torch.max(pointer).item()
min_pointer = torch.min(pointer).item()
registered_storage = next(
filter(
lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
),
None,
)
if registered_storage is None:
raise Exception("Storage not found or pointers spanning multiple tensors")
registered_storage.ensure_immutable()
return registered_storage
def add_tensor(self, t: torch.Tensor):
storage = t.untyped_storage()
self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
return t.data_ptr()
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
):
assert pointer.is_cuda
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert mask.is_cuda
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
if torch.all(~mask):
# Todo: The type is wrong here, we can't determine the correct type
return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
block[mask] = access_tensor[index_tensor[mask]]
return block
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
if torch.all(~mask):
return
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)

View File

@@ -0,0 +1,621 @@
import triton
from .core import ExecutionContext
from .memory_map import MemoryMap
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
def _primitive_to_tensor(x):
"""
Converts various Python primitive data types to PyTorch tensor.
"""
tensor_args = {"device": "cuda"}
if isinstance(x, bool):
return torch.tensor([x], dtype=torch.bool, **tensor_args)
elif isinstance(x, int):
if -(2**31) <= x < 2**31:
return torch.tensor([x], dtype=torch.int32, **tensor_args)
elif -(2**63) <= x < 2**63:
return torch.tensor([x], dtype=torch.int64, **tensor_args)
else:
raise RuntimeError(f"Nonrepresentable integer {x}.")
elif isinstance(x, float):
return torch.tensor([x], dtype=torch.float32, **tensor_args)
elif torch.is_tensor(x):
return x
elif isinstance(x, WrappedTensor):
return x
elif isinstance(x, debugger_constexpr):
if x.value is None:
return None
return _primitive_to_tensor(x.value)
elif x is None:
return None
assert False, f"cannot convert {x} of type {type(x)} to tensor"
def _infer_tensor(func):
"""
A decorator function to harmonize function args:
- converts primitives to PyTorch tensors
- wraps PyTorch tensors with WrappedTensors
"""
def wrapper(*args):
new_args = tuple(map(lambda v: _primitive_to_tensor(v), args))
new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args))
return func(*new_args)
return wrapper
def _tensor_operation(func):
"""
A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
"""
def wrapper(*args, **kwargs):
for arg in args:
assert not torch.is_tensor(arg), "unexpected tensor argument"
def unwrap_tensor(v):
if isinstance(v, WrappedTensor):
return v.tensor
if isinstance(v, debugger_constexpr):
return v.value
return v
new_args = tuple(map(unwrap_tensor, args))
new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}
result = func(args[0], *new_args[1:], **new_kwargs)
return WrappedTensor(result) if torch.is_tensor(result) else result
return wrapper
class debugger_constexpr:
def __init__(self, value):
if isinstance(value, debugger_constexpr):
self.value = value.value
else:
self.value = value
def __str__(self) -> str:
return "debugger_constexpr(" + str(self.value) + ")"
def __index__(self) -> int:
return self.value
def __bool__(self):
return bool(self.value)
def __ge__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value >= other
def __gt__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value > other
def __le__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value <= other
def __lt__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value < other
def __eq__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value == other
def __or__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value | other
def __ror__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value | other
def __and__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value & other
def __rand__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value & other
def to(self, dtype, bitcast=False, _builder=None):
if dtype in [torch.int64]:
ret_ty = int
elif dtype == torch.bool:
ret_ty = bool
elif dtype in [torch.float64]:
ret_ty = float
else:
raise ValueError("dtype not supported in debugger")
return debugger_constexpr(ret_ty(self.value))
class WrappedTensor:
def __init__(self, tensor):
self.tensor = tensor
def __index__(self) -> int:
return self.tensor.item()
def __str__(self) -> str:
return "wrapped_" + str(self.tensor)
def __bool__(self) -> bool:
return torch.all(self.tensor == True).item() # noqa: E712
@property
def dtype(self):
return self.tensor.dtype
@_infer_tensor
@_tensor_operation
def __add__(self, other):
return torch.add(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __radd__(self, other):
return self.__add__(other)
@_infer_tensor
@_tensor_operation
def __sub__(self, other):
return torch.sub(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rsub__(self, other):
return torch.sub(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mul__(self, other):
return torch.mul(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmul__(self, other):
return self.__mul__(other)
@_infer_tensor
@_tensor_operation
def __truediv__(self, other):
return torch.div(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rtruediv__(self, other):
return torch.div(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __floordiv__(self, other):
return torch.floor_divide(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rfloordiv__(self, other):
return torch.floor_divide(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mod__(self, other):
return torch.remainder(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmod__(self, other):
return torch.remainder(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __neg__(self):
return -self.tensor
@_infer_tensor
@_tensor_operation
def __invert__(self):
return ~self.tensor
@_infer_tensor
@_tensor_operation
def __and__(self, other):
return torch.bitwise_and(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __or__(self, other):
return torch.bitwise_or(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __xor__(self, other):
return torch.bitwise_xor(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __lshift__(self, other):
return torch.bitwise_left_shift(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rshift__(self, other):
return torch.bitwise_right_shift(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __gt__(self, other):
return self.tensor > other
@_infer_tensor
@_tensor_operation
def __rgt__(self, other):
return other > self.tensor
@_infer_tensor
@_tensor_operation
def __ge__(self, other):
return self.tensor >= other
@_infer_tensor
@_tensor_operation
def __rge__(self, other):
return other >= self.tensor
@_infer_tensor
@_tensor_operation
def __lt__(self, other):
return self.tensor < other
@_infer_tensor
@_tensor_operation
def __rlt__(self, other):
return other < self.tensor
@_infer_tensor
@_tensor_operation
def __le__(self, other):
return self.tensor <= other
@_infer_tensor
@_tensor_operation
def __rle__(self, other):
return other <= self.tensor
@_infer_tensor
@_tensor_operation
def __eq__(self, other):
return torch.equal(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __ne__(self, other):
return not torch.equal(self.tensor, other)
@_tensor_operation
def __getitem__(self, slices):
return self.tensor.__getitem__(slices)
# if isinstance(slices, slice):
# slices = [slices]
# src_shape = self.shape
# dst_shape = []
# curr = 0
# for sl in slices:
# if isinstance(sl, constexpr) and sl.value is None:
# dst_shape.append(1)
# elif sl == slice(None, None, None):
# dst_shape.append(src_shape[curr].value)
# curr += 1
# ret = torch.reshape(self.tensor, dst_shape, )
# return ret
@_tensor_operation
def to(self, dtype, bitcast=False):
return self.tensor.to(dtype)
# if isinstance(bitcast, constexpr):
# bitcast = bitcast.value
# if bitcast:
# return semantic.bitcast(self, dtype, )
# return semantic.cast(self, dtype, )
def _constexpr_to_value(v):
if isinstance(v, debugger_constexpr):
return v.value
return v
class TritonLangProxy:
_memory_map: MemoryMap
_context: ExecutionContext
def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
self._memory_map = memory_map
self._context = context
# Types
# Removed void, int1, float8, uint16, uint32, uint64, pi32_t
# constexpr = debugger_constexpr
# Program functions
@_tensor_operation
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
cache_modifier="",
eviction_policy="",
volatile=False,
):
return self._memory_map.load(pointer, mask, other)
@_tensor_operation
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
return self._memory_map.store(pointer, value, mask)
@_tensor_operation
def program_id(self, axis):
assert axis < len(self._context.program_id)
return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def num_programs(self, axis):
assert axis < len(self._context.program_size)
return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def arange(self, start, end):
return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")
@_tensor_operation
def zeros(self, shape, dtype):
for i, d in enumerate(shape):
if not isinstance(d, debugger_constexpr):
raise TypeError(f"Shape element {i} must have type `constexpr`")
if not isinstance(d.value, int):
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
shape = [x.value for x in shape]
if isinstance(dtype, triton.language.core.dtype):
if dtype.is_fp32():
dtype = torch.float32
elif dtype.is_fp16():
dtype = torch.float16
elif dtype.is_bf16():
dtype = torch.bfloat16
elif dtype.is_int32():
dtype = torch.int32
elif dtype.is_int16():
dtype = torch.int16
elif dtype.is_int8():
dtype = torch.int8
else:
raise TypeError(f"Unsupported dtype {dtype}")
return torch.zeros(size=shape, dtype=dtype, device="cuda")
@_tensor_operation
def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16):
raise NotImplementedError()
@_tensor_operation
def broadcast(self, input, other):
raise NotImplementedError()
@_tensor_operation
def broadcast_to(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def cat(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def reshape(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
assert input.dtype == other.dtype
if trans_a:
input = input.T
if trans_b:
other = other.T
return torch.matmul(input=input, other=other)
@_tensor_operation
def atomic_cas(self, pointer, cmp, val):
stored = self._memory_map.load(pointer, None, 0.0)
if not isinstance(cmp, torch.Tensor):
cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
if not isinstance(val, torch.Tensor):
val = torch.tensor([val], dtype=stored.dtype, device="cuda")
if stored == cmp:
self._memory_map.store(pointer, val, None)
return stored
@_tensor_operation
def atomic_xchg(self, pointer, val, mask=None):
if isinstance(val, int):
val = torch.tensor([val], dtype=torch.int32, device="cuda")
stored = self._memory_map.load(pointer, mask, 0.0)
self._memory_map.store(pointer, val, mask)
return stored
@_tensor_operation
def atomic_add(self, pointer, val, mask=None):
# arbitrary other value as it will masked during storing
stored = self._memory_map.load(pointer, mask, 0.0)
result = stored + val
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_max(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.maximum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_min(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.minimum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_and(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_and(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_or(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_or(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_xor(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_xor(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def where(self, condition, x, y):
condition = _primitive_to_tensor(condition)
x = _primitive_to_tensor(x)
y = _primitive_to_tensor(y)
return torch.where(condition, x, y)
@_tensor_operation
def umulhi(self, x, y):
raise NotImplementedError()
@_tensor_operation
def fdiv(self, x, y, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def exp(self, x):
return torch.exp(x)
@_tensor_operation
def log(self, x):
return torch.log(x)
@_tensor_operation
def cos(self, x):
return torch.cos(x)
@_tensor_operation
def sin(self, x):
return torch.sin(x)
@_tensor_operation
def sqrt(self, x):
return torch.sqrt(x)
@_tensor_operation
def globaltimer(self):
raise NotImplementedError()
@_tensor_operation
def clock(self):
raise NotImplementedError()
@_tensor_operation
def debug_barrier(self):
raise NotImplementedError()
@_tensor_operation
def multiple_of(self, input, values):
return input
@_tensor_operation
def max_contiguous(self, input, values):
return input
@_tensor_operation
def abs(self, x):
return torch.abs(x)
@_tensor_operation
def cdiv(self, x, div):
return (x + div - 1) // div
@_tensor_operation
def minimum(self, x, y):
if isinstance(x, int):
x = torch.tensor(x, device="cuda")
if isinstance(y, int):
y = torch.tensor(y, device="cuda")
return torch.minimum(x, y)
@_tensor_operation
def maximum(self, x, y):
return torch.maximum(x, y)
@_tensor_operation
def sigmoid(self, x):
raise NotImplementedError()
@_tensor_operation
def softmax(self, x, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def ravel(self, x):
raise NotImplementedError()
@_tensor_operation
def swizzle2d(self, i, j, size_i, size_j, size_g):
raise NotImplementedError()
@_tensor_operation
def zeros_like(self, input):
raise NotImplementedError()
@_tensor_operation
def max(self, input, axis=None):
if axis is None:
return torch.max(input)
return torch.max(input, dim=axis).values
@_tensor_operation
def argmax(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def min(self, input, axis=None):
if axis is None:
return torch.min(input)
return torch.min(input, dim=axis).values
@_tensor_operation
def argmin(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def sum(self, input, axis=None):
if axis is None:
return torch.sum(input)
return torch.sum(input, dim=axis)
@_tensor_operation
def xor_sum(self, input, axis):
raise NotImplementedError()

View File

@@ -0,0 +1,18 @@
try:
import torch as _torch
except ImportError:
_torch = None
class TorchWrapper:
"""
Helps in making torch an optional dependency
"""
def __getattr__(self, name):
if _torch is None:
raise ImportError("Triton requires PyTorch to be installed")
return getattr(_torch, name)
torch = TorchWrapper()

View File

@@ -439,6 +439,7 @@ def jit(
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
@@ -460,14 +461,17 @@ def jit(
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
debug=debug,
noinline=noinline,
)
if interpret:
from ..debugger.debugger import GridSelector
return GridSelector(fn)
else:
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)