[FRONTEND] interpreter rewrite (#2321)

This is a new interpreter mode that shares semantic analysis with the
JIT'ed codepath and that the Triton core team is committed to maintaining.
This commit is contained in:
Philippe Tillet
2023-09-17 14:58:50 -07:00
committed by GitHub
parent 2b066000aa
commit e686b4d6d4
17 changed files with 599 additions and 1033 deletions

View File

@@ -59,8 +59,8 @@ class Package(NamedTuple):
def get_pybind11_package_info():
name = "pybind11-2.10.0"
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
name = "pybind11-2.11.1"
url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz"
return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
# llvm
@@ -296,7 +296,6 @@ setup(
"triton/_C",
"triton/common",
"triton/compiler",
"triton/interpreter",
"triton/language",
"triton/language/extra",
"triton/ops",

View File

@@ -64,6 +64,7 @@
#include <stdexcept>
#include <string>
#include <pybind11/numpy.h>
namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy);
@@ -1961,11 +1962,52 @@ void init_triton_translation(py::module &m) {
});
}
// Register the numpy-based interpreter primitives on submodule `m`.
// "load" and "store" emulate Triton's masked memory ops on host memory:
// pointers arrive as uint64 numpy arrays of raw addresses.
void init_triton_interpreter(py::module &&m) {
  // load(ptrs, masks, other, ret_dtype) -> array of ret_dtype.
  // For each element, read ret_dtype.itemsize() bytes from ptrs[i] when
  // masks[i] is set; otherwise copy the fallback from `other`.
  // NOTE(review): assumes `other` already has dtype `ret_dtype`; a mismatch
  // would copy the wrong byte width -- confirm at the call site.
  m.def("load",
        [](py::array_t<uint64_t> ptrs, py::array_t<bool> masks, py::array other,
           py::dtype ret_dtype) -> py::array {
          // py::ssize_t matches the return type of size(); the original
          // `int numel` + `size_t i < ptrs.size()` mixed signedness.
          py::ssize_t numel = ptrs.size();
          auto shape =
              std::vector<ptrdiff_t>(ptrs.shape(), ptrs.shape() + ptrs.ndim());
          py::array ret(ret_dtype, py::array::ShapeContainer{numel});
          // Flatten all operands so one linear loop covers any rank.
          py::array_t<uint64_t> reshaped_ptrs = ptrs.reshape({numel});
          py::array_t<bool> reshaped_masks = masks.reshape({numel});
          py::array reshaped_others = other.reshape({numel});
          for (py::ssize_t i = 0; i < numel; ++i) {
            if (reshaped_masks.at(i))
              memcpy(ret.mutable_data(i),
                     reinterpret_cast<void *>(reshaped_ptrs.at(i)),
                     ret_dtype.itemsize());
            else
              memcpy(ret.mutable_data(i), reshaped_others.data(i),
                     ret_dtype.itemsize());
          }
          // Restore the caller-visible shape.
          return ret.reshape(shape);
        });

  // store(ptrs, values, mask): write values[i] (values' dtype width) to the
  // address ptrs[i] wherever mask[i] is set.
  m.def("store", [](py::array_t<uint64_t> ptrs, py::array values,
                    py::array_t<bool> mask) {
    py::ssize_t numel = ptrs.size();
    py::array_t<uint64_t> reshaped_ptrs = ptrs.reshape({numel});
    // numpy bool is 1 byte; viewing it as bool (not int8, as the original
    // did) keeps this consistent with "load" without changing behavior.
    py::array_t<bool> reshaped_masks = mask.reshape({numel});
    py::array reshaped_values = values.reshape({numel});
    for (py::ssize_t i = 0; i < numel; ++i) {
      if (reshaped_masks.at(i)) {
        memcpy(reinterpret_cast<void *>(reshaped_ptrs.mutable_at(i)),
               reshaped_values.data(i), values.dtype().itemsize());
      }
    }
  });
}
// Top-level entry point: build the `triton` module hierarchy and register
// each binding group on its own submodule.
void init_triton(py::module &m) {
  py::module triton_mod = m.def_submodule("triton");
  init_triton_env_vars(triton_mod);
  // init_triton_codegen(triton_mod.def_submodule("code_gen"));  // disabled
  init_triton_runtime(triton_mod.def_submodule("runtime"));
  init_triton_ir(triton_mod.def_submodule("ir"));
  init_triton_interpreter(triton_mod.def_submodule("interpreter"));
  init_triton_translation(triton_mod);
}

View File

@@ -1,69 +0,0 @@
import random
import torch
import triton
import triton.language as tl
from triton.interpreter.interpreter import program_ids_from_grid
def test_addition():
    """Interpreter-mode smoke test: elementwise vector add on CUDA."""

    @triton.jit(interpret=True)
    def add_kernel(
        x_ptr,
        y_ptr,
        output_ptr,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        start = pid * BLOCK_SIZE
        offs = start + tl.arange(0, BLOCK_SIZE)
        in_bounds = offs < n_elements
        lhs = tl.load(x_ptr + offs, mask=in_bounds)
        rhs = tl.load(y_ptr + offs, mask=in_bounds)
        tl.store(output_ptr + offs, lhs + rhs, mask=in_bounds)

    a = torch.rand((128,), device="cuda")
    b = torch.rand((128,), device="cuda")
    reference = a + b
    result = torch.empty((128,), device="cuda")

    def grid(meta):
        return (triton.cdiv(128, meta["BLOCK_SIZE"]),)

    add_kernel[grid](a, b, result, 128, BLOCK_SIZE=32)
    assert torch.allclose(reference, result, atol=1e-2, rtol=0)
def test_program_ids_from_grid():
    """The grid iterator must cover every index exactly once, and two
    traversals must come back in different (shuffled) orders."""
    random.seed(123)
    grid = (3, 4)
    covered = set(program_ids_from_grid(grid))
    assert len(covered) == 3 * 4
    run_a = list(program_ids_from_grid(grid))
    run_b = list(program_ids_from_grid(grid))
    assert run_a != run_b
def test_atomic():
    """Interpreter-mode test of atomic_add / atomic_xchg / atomic_cas."""

    @triton.jit(interpret=True)
    def atomic(
        x_ptr,
    ):
        pid = tl.program_id(axis=0)
        tl.atomic_add(x_ptr + pid, 1)
        old = tl.atomic_xchg(x_ptr + pid, 3)
        old += 1  # 2
        tl.atomic_cas(x_ptr + pid, 3, old)  # match: slot becomes 2
        tl.atomic_cas(x_ptr + pid, 40, 9)  # no match: slot unchanged

    n = 16
    buf = torch.zeros((n, ), dtype=torch.int32, device="cuda")
    atomic[(n, )](buf)
    assert torch.allclose(buf, torch.full_like(buf, 2))

View File

@@ -2421,7 +2421,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o
if epilogue == 'chain-dot':
z_ref = np.matmul(z_ref, w)
# compare
# print(z_ref[:,0], z_tri[:,0])
if in_dtype == 'float32':
# XXX: Somehow there's a larger difference when we use float32
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3)

View File

@@ -5,10 +5,10 @@ import triton
import triton.ops
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16),
(4, 48, 1024, 32),
(4, 48, 1024, 64),
(4, 48, 1024, 128)])
@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16),
(2, 4, 512, 32),
(2, 4, 512, 64),
(2, 4, 512, 128)])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
@pytest.mark.parametrize('causal', [True, False])
@pytest.mark.parametrize('seq_par', [True, False])
@@ -21,7 +21,8 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par):
pytest.skip('Segmentation fault')
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"]
if not interpreter and capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability < 80")
torch.manual_seed(20)
q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_()

View File

@@ -1,9 +0,0 @@
from typing import Tuple
import dataclasses
@dataclasses.dataclass
class ExecutionContext:
    """Identity of one simulated Triton program instance.

    Attributes:
        program_id: this instance's index along each grid axis.
        program_size: total number of instances along each grid axis.
    """
    # Tuple[int, ...] (variable length): the original `Tuple[int]`
    # annotation wrongly described a 1-tuple.
    program_id: Tuple[int, ...]
    program_size: Tuple[int, ...]

View File

@@ -1,171 +0,0 @@
import itertools
import random
from typing import Tuple
from .. import language as tl
# import .language.core as lcore
from ..language import core as lcore
from . import torch_wrapper
from .core import ExecutionContext
from .memory_map import MemoryMap
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
debugger_constexpr)
torch = torch_wrapper.torch
tl_method_backup = {}
def get_proxy_method(proxy, name):
    """Return a plain function that forwards every call to `proxy.<name>`."""
    bound = getattr(proxy, name)

    def forwarder(*args, **kwargs):
        return bound(*args, **kwargs)

    return forwarder
def attach_triton(module, proxy):
    """Monkey-patch every public TritonLangProxy attribute onto `module`,
    saving the originals in `tl_method_backup` so detach_triton can undo it."""
    public_names = [attr for attr in dir(TritonLangProxy) if attr[0] != "_"]
    for name in public_names:
        if hasattr(module, name):
            original = getattr(module, name)
            tl_method_backup[name] = original
            if callable(original):
                setattr(module, name, get_proxy_method(proxy, name))
            else:
                setattr(module, name, getattr(proxy, name))
def detach_triton(module):
    """Undo attach_triton: restore every attribute saved in tl_method_backup."""
    for name, original in tl_method_backup.items():
        setattr(module, name, original)
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
    """Yield every program id of `grid` exactly once, in random order.

    The grid dimensions are reversed before taking the cartesian product,
    so the first element of each yielded tuple ranges over the LAST grid
    axis. Shuffling uses the module-level `random` state.
    """
    axis_ranges = [range(extent) for extent in reversed(grid)]
    all_ids = list(itertools.product(*axis_ranges))
    random.shuffle(all_ids)
    yield from all_ids
# Runs a kernel eagerly on torch tensors, one simulated program instance at
# a time, by temporarily monkey-patching the `tl` module with a proxy.
class DebuggerFunction:
def __init__(self, func, grid=(1,)):
self.func = func
self.grid = grid
# True if kernel parameter `name` is annotated tl.constexpr.
def _is_constexpr(self, name):
return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr
# All parameter names annotated tl.constexpr.
def _get_constexpr(self):
result = []
for name, annotation in self.func.__annotations__.items():
if annotation is lcore.constexpr:
result.append(name)
return result
# Constexpr params must be passed as keywords; fail early otherwise.
def _assert_constexpr(self, **kwargs):
constexp = self._get_constexpr()
missing = [i for i in constexp if i not in kwargs.keys()]
assert len(missing) == 0, f"You must specify constexpr {missing}"
# Grid may be a tuple or a callable taking the kwargs dict (triton-style).
def _get_grid(self, **kwargs):
if callable(self.grid):
return self.grid(kwargs)
else:
return self.grid
def __call__(self, *args, **kwargs):
self._assert_constexpr(**kwargs)
memory = MemoryMap()
# Map each argument into debugger types: tensors become registered
# pointers, constexprs are boxed, other primitives become 1-elem tensors.
def convert_arg(v):
name, arg = v
if torch.is_tensor(arg):
ptr = memory.add_tensor(arg)
return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
if self._is_constexpr(name):
return debugger_constexpr(arg)
return WrappedTensor(_primitive_to_tensor(arg))
new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
# num_warps/num_stages are launch hints with no meaning in eager mode.
new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}
grid = self._get_grid(**kwargs)
# Serial launch: patch tl, run one program instance, unpatch, repeat.
for program_id in program_ids_from_grid(grid):
proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
attach_triton(tl, proxy)
self.func(*new_args, **new_kwargs)
detach_triton(tl)
class GridSelector:
    """
    Entry point of the debugger: supports both `kernel[grid](...)`
    and direct `kernel(...)` invocation (which uses the default grid).
    """

    def __init__(self, func):
        version = torch.__version__
        assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
        self.func = func

    def __getitem__(self, grid):
        return DebuggerFunction(self.func, grid)

    def __call__(self, *args, **kwargs):
        return DebuggerFunction(self.func)(*args, **kwargs)
class AutotuneGridSelector:
    """Entry point for autotuned kernels in debugger mode; defers the real
    work to AutotuneRunner."""

    def __init__(self, func, autotune_params):
        self.func = func
        self.autotune_params = autotune_params

    def __getitem__(self, grid):
        return AutotuneRunner(self.func, self.autotune_params, grid)

    def __call__(self, *args, **kwargs):
        return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
# Emulates triton.autotune: runs the kernel once per non-primary config on
# cloned tensor arguments (so warm-up runs cannot corrupt user data), then
# runs the primary config (configs[0]) on the real arguments.
class AutotuneRunner:
def __init__(self, func, autotune_params, grid=None):
self.func = func
self.autotune_params = autotune_params
self.grid = grid
def __call__(self, *args, **kwargs):
assert len(self.autotune_params["configs"]) >= 1
# Trial runs for every config except the first, on defensive copies.
for config in self.autotune_params["configs"][1:]:
def convert_arg(v):
if torch.is_tensor(v):
return torch.clone(v)
return v
new_args = tuple(map(convert_arg, args))
new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
if self.grid:
self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
else:
self.func(*new_args, **new_kwargs, **config.kwargs)
# Final, result-producing run: the primary config on the ORIGINAL args.
main_config = self.autotune_params["configs"][0]
if self.grid:
self.func[self.grid](*args, **kwargs, **main_config.kwargs)
else:
self.func(*args, **kwargs, **main_config.kwargs)
def triton_debug_autotune(**kwargs):
    """Debugger-mode replacement for triton.autotune.

    Decorator factory: wraps `func` in an AutotuneGridSelector configured
    with the given autotune parameters (e.g. `configs=[...]`). Fixes the
    `**kwars` typo of the original; the rename is invisible to callers
    because it is only the keyword-collector name.
    """
    def wrapper(func):
        return AutotuneGridSelector(func, kwargs)
    return wrapper

View File

@@ -1,102 +0,0 @@
from __future__ import annotations
import dataclasses
from . import torch_wrapper
torch = torch_wrapper.torch
# Bookkeeping record for one torch storage registered with MemoryMap:
# the raw base address (`ptr`), the element count (`size`), and the dtype
# used to reinterpret the storage on access.
@dataclasses.dataclass
class RegisteredStorage:
storage: torch.Storage
dtype: torch.dtype
size: int
ptr: int
# One-past-the-end address of the registered region.
@property
def end_ptr(self) -> int:
return self.ptr + self.size
# Typed view over the raw storage, built fresh on every access.
@property
def access_tensor(self) -> torch.Tensor:
return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)
# Guard: the underlying storage must not have moved or been resized since
# registration, or recorded pointers would be stale.
def ensure_immutable(self):
assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
# Maps raw integer addresses back to registered torch storages so the
# debugger can emulate tl.load / tl.store on int64 "pointer" tensors.
class MemoryMap:
storages: [RegisteredStorage]
def __init__(self):
self.storages = []
# Find the single storage containing every address in `pointer`.
# Raises if no storage matches or the addresses span multiple tensors.
def _get_registered_storage(self, pointer: torch.Tensor):
max_pointer = torch.max(pointer).item()
min_pointer = torch.min(pointer).item()
registered_storage = next(
filter(
lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
),
None,
)
if registered_storage is None:
raise Exception("Storage not found or pointers spanning multiple tensors")
registered_storage.ensure_immutable()
return registered_storage
# Register a tensor's backing storage and return its base data pointer.
def add_tensor(self, t: torch.Tensor):
storage = t.untyped_storage()
self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
return t.data_ptr()
# Masked gather. NOTE(review): `pointer - ptr` is used directly as an
# element index into the typed view, so pointer arithmetic here is
# element-granular, not byte-granular -- confirm against the proxy's
# pointer math.
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
):
assert pointer.is_cuda
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert mask.is_cuda
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
# Fully-masked load: nothing to read, dtype cannot be inferred.
if torch.all(~mask):
# Todo: The type is wrong here, we can't determine the correct type
return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
# Unmasked lanes get the `other` fill value, like tl.load.
block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
block[mask] = access_tensor[index_tensor[mask]]
return block
# Masked scatter, the inverse of load; values are cast to the storage dtype.
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
# Fully-masked store is a no-op.
if torch.all(~mask):
return
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)

View File

@@ -1,641 +0,0 @@
from __future__ import annotations
from ..language import core as lcore
from . import torch_wrapper
from .core import ExecutionContext
from .memory_map import MemoryMap
torch = torch_wrapper.torch
def _primitive_to_tensor(x):
"""
Converts various Python primitive data types to PyTorch tensor.
"""
tensor_args = {"device": "cuda"}
# bool is checked BEFORE int on purpose: bool is a subclass of int in
# Python, so reordering these branches would mistype True/False as int32.
if isinstance(x, bool):
return torch.tensor([x], dtype=torch.bool, **tensor_args)
elif isinstance(x, int):
# Pick the narrowest of int32/int64 that can represent the value.
if -(2**31) <= x < 2**31:
return torch.tensor([x], dtype=torch.int32, **tensor_args)
elif -(2**63) <= x < 2**63:
return torch.tensor([x], dtype=torch.int64, **tensor_args)
else:
raise RuntimeError(f"Nonrepresentable integer {x}.")
elif isinstance(x, float):
return torch.tensor([x], dtype=torch.float32, **tensor_args)
# Tensors and WrappedTensors pass through unchanged.
elif torch.is_tensor(x):
return x
elif isinstance(x, WrappedTensor):
return x
# constexprs are unboxed and converted recursively (None stays None).
elif isinstance(x, debugger_constexpr):
if x.value is None:
return None
return _primitive_to_tensor(x.value)
elif x is None:
return None
assert False, f"cannot convert {x} of type {type(x)} to tensor"
def _infer_tensor(func):
    """
    A decorator that harmonizes positional args before calling `func`:
    primitives become PyTorch tensors, and raw tensors are wrapped
    in WrappedTensor.
    """
    def wrapper(*args):
        as_tensors = [_primitive_to_tensor(a) for a in args]
        harmonized = tuple(
            WrappedTensor(a) if torch.is_tensor(a) else a for a in as_tensors
        )
        return func(*harmonized)

    return wrapper
def _tensor_operation(func):
"""
A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
"""
def wrapper(*args, **kwargs):
# Raw torch tensors must not reach here; they should be wrapped first.
for arg in args:
assert not torch.is_tensor(arg), "unexpected tensor argument"
def unwrap_tensor(v):
if isinstance(v, WrappedTensor):
return v.tensor
if isinstance(v, debugger_constexpr):
return v.value
return v
new_args = tuple(map(unwrap_tensor, args))
new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}
# Subtle: the receiver (args[0]) is passed through STILL WRAPPED, while
# every other argument is unwrapped; `func` is expected to read
# `self.tensor` itself.
result = func(args[0], *new_args[1:], **new_kwargs)
# Tensor results are re-wrapped so operators keep returning WrappedTensor.
return WrappedTensor(result) if torch.is_tensor(result) else result
return wrapper
class debugger_constexpr:
    """Debugger stand-in for tl.constexpr: a boxed Python value whose
    comparison and bitwise operators transparently unbox other constexprs."""

    def __init__(self, value):
        # Collapse nested constexprs so .value is always a plain value.
        self.value = value.value if isinstance(value, debugger_constexpr) else value

    @staticmethod
    def _unbox(other):
        # Accept either a constexpr or a raw value on the right-hand side.
        return other.value if isinstance(other, debugger_constexpr) else other

    def __str__(self) -> str:
        return "debugger_constexpr(" + str(self.value) + ")"

    def __index__(self) -> int:
        return self.value

    def __bool__(self):
        return bool(self.value)

    def __ge__(self, other):
        return self.value >= debugger_constexpr._unbox(other)

    def __gt__(self, other):
        return self.value > debugger_constexpr._unbox(other)

    def __le__(self, other):
        return self.value <= debugger_constexpr._unbox(other)

    def __lt__(self, other):
        return self.value < debugger_constexpr._unbox(other)

    def __eq__(self, other):
        return self.value == debugger_constexpr._unbox(other)

    def __or__(self, other):
        return self.value | debugger_constexpr._unbox(other)

    def __ror__(self, other):
        return self.value | debugger_constexpr._unbox(other)

    def __and__(self, other):
        return self.value & debugger_constexpr._unbox(other)

    def __rand__(self, other):
        return self.value & debugger_constexpr._unbox(other)

    def to(self, dtype, bitcast=False, _builder=None):
        # Only a handful of scalar casts make sense in the debugger.
        if dtype in [torch.int64]:
            ret_ty = int
        elif dtype == torch.bool:
            ret_ty = bool
        elif dtype in [torch.float64]:
            ret_ty = float
        else:
            raise ValueError("dtype not supported in debugger")
        return debugger_constexpr(ret_ty(self.value))
# Wraps a torch.Tensor and mirrors triton's tensor operator surface.
# Arithmetic/comparison operators route through _infer_tensor (harmonize
# args) and _tensor_operation (unwrap/rewrap) and delegate to torch.
class WrappedTensor:
def __init__(self, tensor):
self.tensor = tensor
def __index__(self) -> int:
return self.tensor.item()
def __str__(self) -> str:
return "wrapped_" + str(self.tensor)
# Truthy only if EVERY element is true (reduction, not elementwise).
def __bool__(self) -> bool:
return torch.all(self.tensor == True).item() # noqa: E712
@property
def dtype(self):
return self.tensor.dtype
# --- arithmetic operators (elementwise, torch semantics) ---
@_infer_tensor
@_tensor_operation
def __add__(self, other):
return torch.add(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __radd__(self, other):
return self.__add__(other)
@_infer_tensor
@_tensor_operation
def __sub__(self, other):
return torch.sub(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rsub__(self, other):
return torch.sub(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mul__(self, other):
return torch.mul(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmul__(self, other):
return self.__mul__(other)
@_infer_tensor
@_tensor_operation
def __truediv__(self, other):
return torch.div(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rtruediv__(self, other):
return torch.div(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __floordiv__(self, other):
return torch.floor_divide(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rfloordiv__(self, other):
return torch.floor_divide(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mod__(self, other):
return torch.remainder(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmod__(self, other):
return torch.remainder(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __neg__(self):
return -self.tensor
@_infer_tensor
@_tensor_operation
def __invert__(self):
return ~self.tensor
# --- bitwise / shift operators ---
@_infer_tensor
@_tensor_operation
def __and__(self, other):
return torch.bitwise_and(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __or__(self, other):
return torch.bitwise_or(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __xor__(self, other):
return torch.bitwise_xor(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __lshift__(self, other):
return torch.bitwise_left_shift(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rshift__(self, other):
return torch.bitwise_right_shift(self.tensor, other)
# --- comparisons (elementwise, except __eq__/__ne__, see note) ---
@_infer_tensor
@_tensor_operation
def __gt__(self, other):
return self.tensor > other
@_infer_tensor
@_tensor_operation
def __rgt__(self, other):
return other > self.tensor
@_infer_tensor
@_tensor_operation
def __ge__(self, other):
return self.tensor >= other
@_infer_tensor
@_tensor_operation
def __rge__(self, other):
return other >= self.tensor
@_infer_tensor
@_tensor_operation
def __lt__(self, other):
return self.tensor < other
@_infer_tensor
@_tensor_operation
def __rlt__(self, other):
return other < self.tensor
@_infer_tensor
@_tensor_operation
def __le__(self, other):
return self.tensor <= other
@_infer_tensor
@_tensor_operation
def __rle__(self, other):
return other <= self.tensor
# NOTE(review): unlike the elementwise comparisons above, __eq__/__ne__
# use torch.equal and return a single Python bool for the whole tensor --
# this diverges from triton's elementwise == semantics; confirm intended.
@_infer_tensor
@_tensor_operation
def __eq__(self, other):
return torch.equal(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __ne__(self, other):
return not torch.equal(self.tensor, other)
@_tensor_operation
def __getitem__(self, slices):
return self.tensor.__getitem__(slices)
# if isinstance(slices, slice):
# slices = [slices]
# src_shape = self.shape
# dst_shape = []
# curr = 0
# for sl in slices:
# if isinstance(sl, constexpr) and sl.value is None:
# dst_shape.append(1)
# elif sl == slice(None, None, None):
# dst_shape.append(src_shape[curr].value)
# curr += 1
# ret = torch.reshape(self.tensor, dst_shape, )
# return ret
# Cast to a torch dtype; the `bitcast` flag is accepted but ignored here.
@_tensor_operation
def to(self, dtype, bitcast=False):
return self.tensor.to(dtype)
# if isinstance(bitcast, constexpr):
# bitcast = bitcast.value
# if bitcast:
# return semantic.bitcast(self, dtype, )
# return semantic.cast(self, dtype, )
def _constexpr_to_value(v):
    """Unbox a debugger_constexpr; pass any other value through unchanged."""
    return v.value if isinstance(v, debugger_constexpr) else v
# Object whose public methods replace the `tl.*` API while a debugged kernel
# runs (see attach_triton). Tensors live on CUDA; loads/stores go through a
# MemoryMap keyed by raw data pointers; _context identifies the current
# simulated program instance.
class TritonLangProxy:
_memory_map: MemoryMap
_context: ExecutionContext
def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
self._memory_map = memory_map
self._context = context
# Types
# Removed void, int1, float8, uint16, uint32, uint64, pi32_t
# constexpr = debugger_constexpr
# Program functions
# Masked load; cache/eviction/volatile hints are accepted for API
# compatibility but ignored in eager emulation.
@_tensor_operation
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
cache_modifier="",
eviction_policy="",
volatile=False,
):
return self._memory_map.load(pointer, mask, other)
@_tensor_operation
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
return self._memory_map.store(pointer, value, mask)
# Current program index along `axis`, as a 1-element int32 tensor.
@_tensor_operation
def program_id(self, axis):
assert axis < len(self._context.program_id)
return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def num_programs(self, axis):
assert axis < len(self._context.program_size)
return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def arange(self, start, end):
return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")
# Shape elements must be constexpr ints, as in triton's tl.zeros.
@_tensor_operation
def zeros(self, shape, dtype):
for i, d in enumerate(shape):
if not isinstance(d, debugger_constexpr):
raise TypeError(f"Shape element {i} must have type `constexpr`")
if not isinstance(d.value, int):
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
shape = [x.value for x in shape]
# Translate the triton dtype to its torch equivalent.
if isinstance(dtype, lcore.dtype):
if dtype.is_fp32():
dtype = torch.float32
elif dtype.is_fp16():
dtype = torch.float16
elif dtype.is_bf16():
dtype = torch.bfloat16
elif dtype.is_int32():
dtype = torch.int32
elif dtype.is_int16():
dtype = torch.int16
elif dtype.is_int8():
dtype = torch.int8
else:
raise TypeError(f"Unsupported dtype {dtype}")
return torch.zeros(size=shape, dtype=dtype, device="cuda")
@_tensor_operation
def dequantize(self, input, scale, shift, nbit, dst_ty=None):
if dst_ty is None:
dst_ty = torch.float16
raise NotImplementedError()
@_tensor_operation
def broadcast(self, input, other):
raise NotImplementedError()
@_tensor_operation
def broadcast_to(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def cat(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def reshape(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
assert input.dtype == other.dtype
if trans_a:
input = input.T
if trans_b:
other = other.T
return torch.matmul(input=input, other=other)
# --- atomics: programs run serially here, so a plain read-modify-write
# is an exact emulation; each returns the PREVIOUS value like tl.atomic_*.
@_tensor_operation
def atomic_cas(self, pointer, cmp, val):
stored = self._memory_map.load(pointer, None, 0.0)
if not isinstance(cmp, torch.Tensor):
cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
if not isinstance(val, torch.Tensor):
val = torch.tensor([val], dtype=stored.dtype, device="cuda")
# Store only on compare match (1-element tensor comparison).
if stored == cmp:
self._memory_map.store(pointer, val, None)
return stored
@_tensor_operation
def atomic_xchg(self, pointer, val, mask=None):
if isinstance(val, int):
val = torch.tensor([val], dtype=torch.int32, device="cuda")
stored = self._memory_map.load(pointer, mask, 0.0)
self._memory_map.store(pointer, val, mask)
return stored
@_tensor_operation
def atomic_add(self, pointer, val, mask=None):
# arbitrary other value as it will masked during storing
stored = self._memory_map.load(pointer, mask, 0.0)
result = stored + val
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_max(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.maximum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_min(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.minimum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_and(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_and(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_or(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_or(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_xor(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_xor(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
# --- elementwise math / selection ---
@_tensor_operation
def where(self, condition, x, y):
condition = _primitive_to_tensor(condition)
x = _primitive_to_tensor(x)
y = _primitive_to_tensor(y)
return torch.where(condition, x, y)
@_tensor_operation
def umulhi(self, x, y):
raise NotImplementedError()
@_tensor_operation
def fdiv(self, x, y, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def exp(self, x):
return torch.exp(x)
@_tensor_operation
def log(self, x):
return torch.log(x)
@_tensor_operation
def cos(self, x):
return torch.cos(x)
@_tensor_operation
def sin(self, x):
return torch.sin(x)
@_tensor_operation
def sqrt(self, x):
return torch.sqrt(x)
@_tensor_operation
def globaltimer(self):
raise NotImplementedError()
@_tensor_operation
def clock(self):
raise NotImplementedError()
@_tensor_operation
def debug_barrier(self):
raise NotImplementedError()
# Compiler hints are identity functions in eager mode.
@_tensor_operation
def multiple_of(self, input, values):
return input
@_tensor_operation
def max_contiguous(self, input, values):
return input
@_tensor_operation
def max_constancy(self, input, values):
return input
@_tensor_operation
def abs(self, x):
return torch.abs(x)
@_tensor_operation
def cdiv(self, x, div):
return (x + div - 1) // div
@_tensor_operation
def minimum(self, x, y):
if isinstance(x, int):
x = torch.tensor(x, device="cuda")
if isinstance(y, int):
y = torch.tensor(y, device="cuda")
return torch.minimum(x, y)
@_tensor_operation
def maximum(self, x, y):
return torch.maximum(x, y)
@_tensor_operation
def sigmoid(self, x):
raise NotImplementedError()
@_tensor_operation
def softmax(self, x, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def ravel(self, x):
raise NotImplementedError()
@_tensor_operation
def swizzle2d(self, i, j, size_i, size_j, size_g):
raise NotImplementedError()
@_tensor_operation
def zeros_like(self, input):
raise NotImplementedError()
# --- reductions: axis=None reduces the whole tensor ---
@_tensor_operation
def max(self, input, axis=None):
if axis is None:
return torch.max(input)
return torch.max(input, dim=axis).values
@_tensor_operation
def argmax(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def min(self, input, axis=None):
if axis is None:
return torch.min(input)
return torch.min(input, dim=axis).values
@_tensor_operation
def argmin(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def sum(self, input, axis=None):
if axis is None:
return torch.sum(input)
return torch.sum(input, dim=axis)
@_tensor_operation
def xor_sum(self, input, axis):
raise NotImplementedError()
# NOTE(review): torch.cumsum/cumprod require a `dim` argument, so the
# axis=None branches below likely raise TypeError -- confirm intended.
@_tensor_operation
def cumsum(self, input, axis=None):
if axis is None:
return torch.cumsum(input)
return torch.cumsum(input, dim=axis)
@_tensor_operation
def cumprod(self, input, axis=None):
if axis is None:
return torch.cumprod(input)
return torch.cumprod(input, dim=axis)

View File

@@ -1,18 +0,0 @@
try:
    import torch as _torch
except ImportError:
    _torch = None


class TorchWrapper:
    """
    Helps in making torch an optional dependency: attribute lookups are
    forwarded to the real torch module, or fail lazily with ImportError
    when torch is absent.
    """

    def __getattr__(self, name):
        if _torch is None:
            raise ImportError("Triton requires PyTorch to be installed")
        return getattr(_torch, name)


# Module-level singleton used in place of `import torch` elsewhere.
torch = TorchWrapper()

View File

@@ -542,14 +542,15 @@ class tensor:
self.shape = [constexpr(s) for s in self.shape]
def __str__(self) -> str:
# ex. "float32[3,4]"
return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']'
# ex. "float32[16, 32]"
return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']'
@builtin
def __add__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.add(self, other, _builder)
@builtin
def __radd__(self, other, _builder=None):
return self.__add__(other, _builder=_builder)
@@ -558,6 +559,7 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.sub(self, other, _builder)
@builtin
def __rsub__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.sub(other, self, _builder)
@@ -567,6 +569,7 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.mul(self, other, _builder)
@builtin
def __rmul__(self, other, _builder=None):
return self.__mul__(other, _builder=_builder)
@@ -575,6 +578,7 @@ class tensor:
other = _to_tensor(other, _builder)
return semantic.truediv(self, other, _builder)
@builtin
def __rtruediv__(self, other, _builder=None):
other = _to_tensor(other, _builder)
return semantic.truediv(other, self, _builder)
@@ -666,8 +670,6 @@ class tensor:
else:
return semantic.lshr(other, self, _builder)
# comparison operators
# >
@builtin
def __gt__(self, other, _builder=None):
@@ -745,7 +747,7 @@ class tensor:
slices = [slices]
ret = self
for dim, sl in enumerate(slices):
if isinstance(sl, constexpr) and sl.value is None:
if sl is None or isinstance(sl, constexpr) and sl.value is None:
ret = semantic.expand_dims(ret, dim, _builder)
elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None:
pass
@@ -830,6 +832,8 @@ def arange(start, end, _builder=None):
def _shape_check_impl(shape):
shape = _constexpr_to_value(shape)
for i, d in enumerate(shape):
if isinstance(d, int):
d = constexpr(d)
if not isinstance(d, constexpr):
raise TypeError(f"Shape element {i} must have type `constexpr`")
if not isinstance(d.value, int):

View File

@@ -1570,6 +1570,8 @@ def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno:
def _convert_elem_to_ir_value(builder, elem, require_i64):
if isinstance(elem, int):
elem = tl.constexpr(elem)
if isinstance(elem, tl.constexpr):
return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value)
elif isinstance(elem, tl.tensor):

View File

@@ -160,7 +160,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr
else:
return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast)
else:
if core.constexpr(input.dtype.primitive_bitwidth) < 32:
if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32):
if core.constexpr(input.dtype.is_floating()):
input = input.to(core.float32)
else:

View File

@@ -0,0 +1,525 @@
import inspect
import numpy as np
import triton
import triton.language as tl
from .._C.libtriton.triton import interpreter as _interpreter
# TODO: duplicate
def str_to_ty(name):
    """Translate a signature type string (e.g. "*fp32", "i64") into the
    corresponding triton.language type; a '*' prefix denotes a pointer
    and recurses on the pointee."""
    if name[0] == "*":
        pointee = str_to_ty(name[1:])
        return tl.pointer_type(pointee)
    scalar_types = {
        "fp8e4nv": tl.float8e4nv,
        "fp8e5": tl.float8e5,
        "fp8e4b15": tl.float8e4b15,
        "fp8e4b15x4": tl.float8e4b15x4,
        "fp16": tl.float16,
        "bf16": tl.bfloat16,
        "fp32": tl.float32,
        "fp64": tl.float64,
        "i1": tl.int1,
        "i8": tl.int8,
        "i16": tl.int16,
        "i32": tl.int32,
        "i64": tl.int64,
        "u8": tl.uint8,
        "u16": tl.uint16,
        "u32": tl.uint32,
        "u64": tl.uint64,
        "B": tl.int1,
    }
    return scalar_types[name]
class TensorHandle:
    """Pairs a host numpy array with the triton dtype it emulates."""

    def __init__(self, data, dtype):
        # data: numpy ndarray holding the values (or uint64 addresses for pointers)
        # dtype: the corresponding triton.language dtype
        self.data = data
        self.dtype = dtype

    def __bool__(self):
        # Truthy only when every element is nonzero, mirroring tensor semantics.
        return bool(np.all(self.data))
class BlockPointerHandle:
    """Interpreter-side stand-in for a triton block pointer (``tl.make_block_ptr``)."""

    def __init__(self, base, shape, strides, offsets, tensor_shape, order):
        # base: TensorHandle holding the base address (pointer dtype)
        # shape / strides / offsets: per-dimension TensorHandles
        # tensor_shape: static block shape (python ints)
        # order: dimension ordering (unused by materialize_pointers)
        self.base = base
        self.shape = shape
        self.strides = strides
        self.offsets = offsets
        self.tensor_shape = tensor_shape
        self.order = order

    def materialize_pointers(self, boundary_check):
        """Expand the block pointer into a dense grid of element addresses.

        Returns ``(ptrs, masks)``: ``ptrs`` is a TensorHandle of byte
        addresses with shape ``tensor_shape``; ``masks`` is a boolean array
        marking in-bounds elements for every dimension listed in
        ``boundary_check``.
        """
        dtype_tt = self.base.dtype.element_ty
        n_bytes = dtype_tt.primitive_bitwidth // 8
        tensor_shape = self.tensor_shape
        ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
        masks = np.ones(self.tensor_shape, dtype=bool)
        for dim in range(len(tensor_shape)):
            # Build an index vector that broadcasts along `dim` only.
            bcast_dims = [1] * len(tensor_shape)
            bcast_dims[dim] = tensor_shape[dim]
            off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
            # Offsets are in elements; scale by stride and element size to get bytes.
            ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
            if dim in boundary_check:
                # NOTE(review): only the upper bound is checked here; negative
                # offsets (possible after advancing backwards) are not masked
                # out — confirm this is intended.
                masks = np.logical_and(masks, off < self.shape[dim].data)
        ptrs = TensorHandle(ptrs, self.base.dtype)
        return ptrs, masks
def wrap_ret(compute_ret_ty):
    """Decorator factory: re-type the returned handle.

    The wrapped function's result keeps its data, but its dtype is recomputed
    by calling ``compute_ret_ty`` with the same arguments.
    """
    def decorate(fn):
        def inner(*args, **kwargs):
            result = fn(*args, **kwargs)
            return TensorHandle(result.data, compute_ret_ty(*args, **kwargs))
        return inner
    return decorate
class Builder:
    """Numpy-backed replacement for the MLIR ``ir.builder``.

    Semantic analysis calls the same ``create_*`` / ``get_*`` entry points as
    on the JIT path, but here every op executes eagerly on host memory:
    values are ``TensorHandle`` objects wrapping numpy arrays.
    """

    def __init__(self) -> None:
        # No target architecture: everything runs on the host CPU.
        self.arch = None
        # Initialized here so create_get_program_id's assert fires instead of
        # an AttributeError when the grid has not been configured yet.
        self.grid_dim = None
        self.grid_idx = None

    # -- grid bookkeeping --------------------------------------------------
    def set_grid_idx(self, x, y, z):
        # Select the program instance (block) currently being simulated.
        assert x < self.grid_dim[0]
        assert y < self.grid_dim[1]
        assert z < self.grid_dim[2]
        self.grid_idx = (x, y, z)

    def set_grid_dim(self, nx, ny, nz):
        self.grid_dim = (nx, ny, nz)

    def np_dtype(self, tt_dtype):
        """Return the numpy dtype used to emulate a triton dtype on the host."""
        if isinstance(tt_dtype, tl.pointer_type):
            # Pointers are emulated as raw 64-bit addresses.
            return np.dtype(np.uint64)
        np_types = {
            tl.float16: np.dtype(np.float16),
            tl.float32: np.dtype(np.float32),
            tl.float64: np.dtype(np.float64),
            tl.int8: np.dtype(np.int8),
            tl.uint8: np.dtype(np.uint8),
            tl.int16: np.dtype(np.int16),
            tl.uint16: np.dtype(np.uint16),
            tl.int32: np.dtype(np.int32),
            tl.uint32: np.dtype(np.uint32),
            tl.int64: np.dtype(np.int64),
            tl.uint64: np.dtype(np.uint64),
        }
        return np_types[tt_dtype]

    # -- constants ---------------------------------------------------------
    def get_half_ty(self):
        return tl.float16

    def get_float_ty(self):
        return tl.float32

    def get_int64_ty(self):
        return tl.int64

    def get_ptr_ty(self, elt_ty, addr_space):
        return tl.pointer_type(elt_ty, addr_space)

    def get_block_ty(self, dtype, shape):
        # NOTE(review): tl.tensor's constructor takes (handle, type); passing
        # (shape, dtype) looks suspicious — possibly meant
        # tl.block_type(dtype, shape). Confirm against callers.
        return tl.tensor(shape, dtype)

    def get_int32(self, value):
        return TensorHandle(np.array([value], dtype=np.int32), tl.int32)

    def get_int64(self, value):
        return TensorHandle(np.array([value], dtype=np.int64), tl.int64)

    def get_fp16(self, value):
        return TensorHandle(np.array([value], dtype=np.float16), tl.float16)

    def get_fp32(self, value):
        return TensorHandle(np.array([value], dtype=np.float32), tl.float32)

    def get_null_value(self, type):
        return TensorHandle(np.array([0], dtype=self.np_dtype(type)), type)

    # -- programming model -------------------------------------------------
    def create_get_program_id(self, axis):
        assert self.grid_idx is not None
        return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32)

    def create_get_num_programs(self, axis):
        return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32)

    # -- memory ops ----------------------------------------------------------
    def create_load(self, ptr, _0, _1, is_volatile):
        # An unmasked load is a masked load with an all-true mask.
        mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
        other = None
        return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile)

    def create_store(self, ptr, val, _0, _1):
        mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1)
        return self.create_masked_store(ptr, val, mask, None, None)

    def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile):
        dtype_tt = ptrs.dtype.element_ty
        dtype_np = self.np_dtype(dtype_tt)
        if other is None:
            # NOTE(review): masked-off lanes default to 1 (ones_like), not 0 —
            # confirm this is intended.
            other = TensorHandle(np.ones_like(ptrs.data, dtype=dtype_np), dtype_tt)
        ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np)
        return TensorHandle(ret, dtype_tt)

    def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy):
        return _interpreter.store(ptrs.data, value.data, mask.data)

    # -- casting ops ---------------------------------------------------------
    def cast_impl(self, src, dst_type):
        if isinstance(dst_type, tl.tensor):
            dst_type = dst_type.dtype
        return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type)

    create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type)
    create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type)

    def create_fp_to_fp(self, src, dst_type):
        # The original `assert "float8 ..."` asserted a truthy string — a
        # no-op that fell through to an implicit `return None`. Fail loudly
        # until float8 conversions are implemented.
        raise NotImplementedError("float8 conversions (create_fp_to_fp) are not implemented in interpreter mode")

    def create_bitcast(self, src, dst_type):
        # Reinterpret the raw bytes without conversion.
        return TensorHandle(src.data.view(self.np_dtype(dst_type)), dst_type)

    # -- binary operators ----------------------------------------------------
    def binary_op(self, lhs, rhs, op):
        # Result dtype is taken from the lhs (comparisons included).
        return TensorHandle(op(lhs.data, rhs.data), lhs.dtype)

    create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
    create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
    create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide)
    create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
    create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
    create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply)
    create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
    create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide)
    create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
    create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder)
    create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add)
    create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract)
    create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift)
    # TODO(review): np.right_shift follows the operand's signedness, so lshr
    # on signed data is actually an arithmetic shift; a true logical shift
    # would need an unsigned view. Confirm whether this matters for callers.
    create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
    create_ashr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift)
    create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
    create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
    create_minf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum)
    create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
    create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
    create_maxf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum)
    create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
    create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
    create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
    create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
    create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less)
    create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater)
    create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal)
    create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal)
    create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal)
    create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal)
    create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and)
    create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor)
    create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or)

    # -- ternary functions ---------------------------------------------------
    def ternary_op(self, lhs, rhs, other, op):
        return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype)

    create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where)

    # -- unary functions -----------------------------------------------------
    def unary_op(self, arg, op):
        return TensorHandle(op(arg.data), arg.dtype)

    create_exp = lambda self, arg: self.unary_op(arg, np.exp)
    create_cos = lambda self, arg: self.unary_op(arg, np.cos)
    create_sin = lambda self, arg: self.unary_op(arg, np.sin)
    create_log = lambda self, arg: self.unary_op(arg, np.log)
    create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt)
    create_fabs = lambda self, arg: self.unary_op(arg, np.abs)
    create_iabs = lambda self, arg: self.unary_op(arg, np.abs)

    # -- tensor operators ----------------------------------------------------
    # (a dead `create_dot` lambda that was immediately shadowed by the
    #  def further below has been removed)
    create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype)
    create_trans = lambda self, arg: self.unary_op(arg, np.transpose)

    def create_dot(self, a, b, d, allow_tf32, maxNumImpreciseAcc):
        # d is the accumulator: matmul then add. allow_tf32 /
        # maxNumImpreciseAcc only affect GPU precision and are ignored here.
        return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype)

    def create_make_range(self, start, stop):
        return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32)

    # -- pointer arithmetic ----------------------------------------------------
    def create_addptr(self, ptr, offset):
        # Offsets are in elements; scale by the element size in bytes.
        dtype_tt = ptr.dtype.element_ty
        return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype)

    def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile):
        ptrs, masks = ptr.materialize_pointers(boundary_check)
        assert padding_option is None
        other = None
        return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile)

    def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy):
        ptrs, masks = ptr.materialize_pointers(boundary_check)
        return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy)

    def create_expand_dims(self, arg, axis):
        return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype)

    def create_broadcast(self, arg, shape):
        return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype)

    def create_int_to_ptr(self, val, dst_ty):
        return TensorHandle(val.data.astype(np.uint64), dst_ty)

    def create_splat(self, arg, shape):
        return TensorHandle(np.full(shape, arg.data[0], dtype=self.np_dtype(arg.dtype)), arg.dtype)

    # Builder entry points not yet implemented in interpreter mode:
    #   create_cat, create_atomic_cas, create_atomic_rmw,
    #   create_extern_elementwise, create_reduce, create_reduce_ret,
    #   create_scan, create_scan_ret, create_ptr_to_int, create_inline_asm,
    #   create_print, create_assert, create_undef, create_barrier

    def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
        return BlockPointerHandle(base, shape, strides, np.array(offsets), tensor_shape, order)

    def create_advance(self, ptr, offsets):
        """Return a new block pointer advanced by `offsets` (in elements).

        The original implementation shared `ptr.offsets` with the returned
        handle, so the in-place `+=` below also mutated the *input* block
        pointer; the offset handles are deep-copied here to avoid that.
        """
        assert len(ptr.offsets) == len(offsets)
        new_offsets = np.array([TensorHandle(o.data.copy(), o.dtype) for o in ptr.offsets])
        ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
        for i in range(len(offsets)):
            ret.offsets[i].data += offsets[i].data
        return ret
def patch_attr(obj, name, member, builder):
    """Replace ``obj.<name>`` with a wrapper that forces ``_builder=builder``.

    Any ``_builder`` passed by the caller is discarded so the interpreter's
    builder always wins.
    """
    def forwarded(*args, member=member, **kwargs):
        cleaned = {k: v for k, v in kwargs.items() if k != '_builder'}
        return member(*args, _builder=builder, **cleaned)
    setattr(obj, name, forwarded)
def _patch_lang_tensor(tensor, builder):
    """Route every builtin method of ``tl.tensor`` through the interpreter builder."""
    for attr_name, attr in inspect.getmembers(tensor):
        if tl.core.is_builtin(attr):
            patch_attr(tensor, attr_name, attr, builder)
    # Allow interpreted tensors to be used as python indices and conditions.
    tensor.__index__ = lambda self: int(self.handle.data)
    tensor.__bool__ = lambda self: True
def _patch_lang_core(lang, builder):
    """Patch every builtin in the ``triton.language`` module so calls execute
    through the interpreter's builder instead of requiring a JIT context."""
    for name, member in inspect.getmembers(lang):
        if tl.core.is_builtin(member):
            patch_attr(lang, name, member, builder)

    # reduce is better off with a separate patch due to how
    # the builder currently interfaces with custom functions
    def _new_reduce(input, axis, combine_fn):
        fn = combine_fn.fn.__name__
        mapping = {
            'maximum': np.max,
            '_sum_combine': np.sum,
        }
        if fn not in mapping:
            # Fail with an actionable error instead of a bare KeyError.
            raise NotImplementedError(
                f"tl.reduce with combine function {fn!r} is not supported in interpreter mode"
            )
        ret = mapping[fn](input.handle.data, axis=axis)
        ret_type = tl.block_type(input.dtype, ret.shape)
        return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type)

    lang.reduce = _new_reduce
def _patch_lang_math(lang, builder):
math = lang.math
mapping = {
'abs': 'abs',
'acos': 'arccos',
'asin': 'arcsin',
'exp2': 'exp2',
'log2': 'log2',
'max': 'maximum',
}
def make_numpy(name):
def impl(*args, **kwargs):
ret_type = args[0].type # TODO: incorrect
ret_dtype = args[0].dtype # TODO: incorrect
args = [arg.handle.data for arg in args]
kwargs = {k: v.handle.data for k, v in kwargs.items()}
ret = getattr(np, mapping[name])(*args, **kwargs)
ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type)
return ret
return impl
def make_fallback(name):
def fallback(*args, **kwargs):
raise NotImplementedError(f"""
{name} not supported in interpreter mode: no known numpy implementation.
If you think that {name} in fact does have a numpy implementation, please add it
to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math.
""")
return fallback
for name, member in inspect.getmembers(math):
if name in mapping:
setattr(math, name, make_numpy(name))
else:
setattr(math, name, make_fallback(name))
# TODO: wrap everything in triton tensors
def _implicit_cvt(arg):
    """Convert a plain python int or a torch-like tensor (anything with
    ``data_ptr``) into an interpreter ``tl.tensor``; other values pass through.
    """
    if isinstance(arg, int):
        ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
        # Pick a numpy dtype wide enough for the deduced triton type; the
        # original always used int32, silently truncating i64/u64 scalars.
        np_dtypes = {tl.int64: np.int64, tl.uint64: np.uint64}
        handle = TensorHandle(np.array([arg], dtype=np_dtypes.get(ty, np.int32)), ty)
        return tl.tensor(handle, ty)
    if hasattr(arg, 'data_ptr'):
        # Tensors are represented by their base pointer (a raw address).
        ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg)))
        handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty)
        return tl.tensor(handle, ty)
    return arg
def _unwrap(tensor):
    """Strip a ``triton.TensorWrapper`` down to its underlying tensor;
    other values are returned unchanged."""
    return tensor.base if isinstance(tensor, triton.TensorWrapper) else tensor
# Process-wide builder instance shared by every interpreted kernel launch.
builder = Builder()
# Launch kwargs that are only meaningful for the compiled path and must be
# stripped before calling the python function.
RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization']
class GridExecutor:
    """Callable that runs a python kernel once per program instance in a grid.

    Mimics the launch behavior of a compiled triton kernel: device tensors
    are copied to the host, the kernel body executes sequentially for every
    (x, y, z) grid index, then side effects are copied back to the device.
    """

    def __init__(self, fn, arg_names, grid):
        from .jit import _normalize_ty  # TODO: modularize
        self.fn = fn
        self.arg_names = arg_names
        # grid may be a tuple or a callable mapping the arg dict to a tuple.
        self.grid = grid
        __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()}
        # Parameters annotated `constexpr` are passed through unconverted.
        self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr']

    def _patch_lang(self, builder):
        # Locate the `triton.language` module as seen from the kernel's own
        # globals and patch its builtins to route through the interpreter.
        lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
        assert len(lang) == 1, "triton.language must be visible from within jit'd function"
        _patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
        _patch_lang_core(lang[0], builder)
        _patch_lang_math(lang[0], builder)

    def __call__(self, *args_dev, **kwargs):
        # we need to copy arguments to the host for the interpreter
        args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev]
        # removes reserved keywords from kwargs
        kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS}
        # remaps core language functions to interpreted ones
        self._patch_lang(builder)
        # implicitly convert tensor arguments to their base pointers
        args = inspect.getcallargs(self.fn, *args_hst, **kwargs)
        args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()}
        # iterate through grid, executing the kernel body once per program id
        grid = self.grid(args) if callable(self.grid) else self.grid
        assert len(grid) <= 3
        grid = grid + (1,) * (3 - len(grid))
        builder.set_grid_dim(*grid)
        for x in range(grid[0]):
            for y in range(grid[1]):
                for z in range(grid[2]):
                    builder.set_grid_idx(x, y, z)
                    self.fn(**args)
        # copy arguments back to propagate side-effects
        for arg_dev, arg_hst in zip(args_dev, args_hst):
            if hasattr(arg_dev, 'data_ptr'):
                _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device))
class InterpretedFunction:
    """Drop-in replacement for JITFunction that interprets the kernel in python."""

    def _patch_lang(self, builder):
        # NOTE(review): unlike GridExecutor._patch_lang this does not patch
        # lang.math — confirm whether that is intentional.
        lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]]
        assert len(lang) == 1, "triton.language must be visible from within jit'd function"
        _patch_lang_tensor(getattr(lang[0], 'tensor'), builder)
        _patch_lang_core(lang[0], builder)

    def __init__(self, fn) -> None:
        self.fn = fn

        # Mirrors JITFunction.run: the grid arrives as a keyword argument and
        # reserved launch kwargs are stripped before execution.
        def run(*args, **kwargs):
            grid = kwargs['grid']
            kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']}
            return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs)

        self.run = run
        signature = inspect.signature(fn)
        self.arg_names = [v.name for v in signature.parameters.values()]

    def __getitem__(self, grid):
        # Supports the `kernel[grid](...)` launch syntax.
        return GridExecutor(self.fn, self.arg_names, grid)

    def __call__(self, *args, **kwargs):
        # Direct call: run the function body once (e.g. for non-kernel helpers).
        self._patch_lang(builder)
        return self.fn(*args, **kwargs)

View File

@@ -14,6 +14,7 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union,
from .._C.libtriton.triton import TMAInfos
from ..common.backend import get_backend, path_to_ptxas
from ..language.core import dtype
from .interpreter import InterpretedFunction
TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TRITON_VERSION = "2.1.0"
@@ -270,10 +271,6 @@ class JITFunction(KernelInterface[T]):
tys[v] = v
return key if isinstance(key, str) else f"*{tys[dtype_str]}"
def _make_signature(self, sig_key):
signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)])
return signature
def _make_constants(self, constexpr_key):
constants = dict(zip(self.constexprs, constexpr_key))
return constants
@@ -568,7 +565,6 @@ def jit(
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
@@ -590,9 +586,8 @@ def jit(
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
if interpret:
from ..interpreter.interpreter import GridSelector
return GridSelector(fn)
if os.getenv("TRITON_INTERPRET", "0") == "1":
return InterpretedFunction(fn)
else:
return JITFunction(
fn,