mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND] cleaned up language; added frontend function for globaltimer special register (#1525)
This commit is contained in:
@@ -420,14 +420,13 @@ def TT_ReduceReturnOp: TT_Op<"reduce.return",
|
||||
//
|
||||
// External Elementwise op
|
||||
//
|
||||
def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAndResultShape,
|
||||
SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
let summary = "ext_elemwise";
|
||||
class TT_ExternElementwiseOpBase<string mnemonic, list<Trait> traits = []> :
|
||||
TT_Op<mnemonic,
|
||||
traits # [SameOperandsAndResultEncoding,
|
||||
SameVariadicOperandSize]> {
|
||||
|
||||
let description = [{
|
||||
call an external function $symbol implemented in $libpath/$libname with $args
|
||||
|
||||
return $libpath/$libname:$symbol($args...)
|
||||
}];
|
||||
|
||||
@@ -435,7 +434,17 @@ def TT_ExtElemwiseOp : TT_Op<"ext_elemwise", [Pure, Elementwise, SameOperandsAnd
|
||||
|
||||
let results = (outs TT_Type:$result);
|
||||
|
||||
let assemblyFormat = "operands attr-dict `:` type(operands) `->` type($result)";
|
||||
let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";
|
||||
}
|
||||
|
||||
|
||||
// Pure variant of the extern element-wise op: instantiates
// TT_ExternElementwiseOpBase under the mnemonic "pure_extern_elementwise"
// and adds the Pure and Elementwise traits to the base trait list.
def TT_PureExternElementwiseOp : TT_ExternElementwiseOpBase<"pure_extern_elementwise", [Pure, Elementwise]> {
|
||||
let summary = "FFI for pure element-wise extern LLVM bitcode functions";
|
||||
}
|
||||
|
||||
// Impure variant of the extern element-wise op: instantiates
// TT_ExternElementwiseOpBase under the mnemonic "impure_extern_elementwise"
// and declares both MemRead and MemWrite memory effects instead of Pure.
def TT_ImpureExternElementwiseOp : TT_ExternElementwiseOpBase<"impure_extern_elementwise", [MemoryEffects<[MemRead]>,
|
||||
MemoryEffects<[MemWrite]>]> {
|
||||
let summary = "FFI for impure element-wise extern LLVM bitcode functions";
|
||||
}
|
||||
|
||||
//
|
||||
|
||||
@@ -734,20 +734,20 @@ struct CmpFOpConversion
|
||||
}
|
||||
};
|
||||
|
||||
struct ExtElemwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<triton::ExtElemwiseOp,
|
||||
ExtElemwiseOpConversion> {
|
||||
using Base = ElementwiseOpConversionBase<triton::ExtElemwiseOp,
|
||||
ExtElemwiseOpConversion>;
|
||||
template <class T>
|
||||
struct ExternElementwiseOpConversion
|
||||
: public ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>> {
|
||||
using Base = ElementwiseOpConversionBase<T, ExternElementwiseOpConversion<T>>;
|
||||
using Base::Base;
|
||||
using Adaptor = typename Base::OpAdaptor;
|
||||
typedef typename Base::OpAdaptor OpAdaptor;
|
||||
|
||||
Value createDestOp(triton::ExtElemwiseOp op, OpAdaptor adaptor,
|
||||
Value createDestOp(T op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter, Type elemTy,
|
||||
ValueRange operands, Location loc) const {
|
||||
StringRef funcName = op.getSymbol();
|
||||
if (funcName.empty())
|
||||
llvm::errs() << "ExtElemwiseOpConversion";
|
||||
llvm::errs() << "ExternElementwiseOpConversion";
|
||||
|
||||
Type funcType = getFunctionType(elemTy, operands);
|
||||
LLVM::LLVMFuncOp funcOp =
|
||||
@@ -761,8 +761,7 @@ private:
|
||||
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
|
||||
}
|
||||
|
||||
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter,
|
||||
triton::ExtElemwiseOp op,
|
||||
LLVM::LLVMFuncOp appendOrGetFuncOp(ConversionPatternRewriter &rewriter, T op,
|
||||
StringRef funcName, Type funcType) const {
|
||||
using LLVM::LLVMFuncOp;
|
||||
|
||||
@@ -771,7 +770,8 @@ private:
|
||||
if (funcOp)
|
||||
return cast<LLVMFuncOp>(*funcOp);
|
||||
|
||||
mlir::OpBuilder b(op->getParentOfType<LLVMFuncOp>());
|
||||
auto parent = ((Operation *)op)->getParentOfType<mlir::LLVM::LLVMFuncOp>();
|
||||
mlir::OpBuilder b(parent);
|
||||
auto ret = b.create<LLVMFuncOp>(op->getLoc(), funcName, funcType);
|
||||
ret.getOperation()->setAttr(
|
||||
"libname", StringAttr::get(op->getContext(), op.getLibname()));
|
||||
@@ -1117,7 +1117,11 @@ void populateElementwiseOpToLLVMPatterns(
|
||||
|
||||
patterns.add<FpToFpOpConversion>(typeConverter, benefit);
|
||||
|
||||
patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ExternElementwiseOpConversion<triton::PureExternElementwiseOp>>(
|
||||
typeConverter, benefit);
|
||||
patterns
|
||||
.add<ExternElementwiseOpConversion<triton::ImpureExternElementwiseOp>>(
|
||||
typeConverter, benefit);
|
||||
// ExpOpConversionApprox will try using ex2.approx if the input type is
|
||||
// FP32. For other input types, ExpOpConversionApprox will return failure and
|
||||
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
|
||||
|
||||
@@ -411,14 +411,16 @@ struct TritonAtomicRMWPattern
|
||||
}
|
||||
};
|
||||
|
||||
struct TritonExtElemwisePattern
|
||||
: public OpConversionPattern<triton::ExtElemwiseOp> {
|
||||
using OpConversionPattern<triton::ExtElemwiseOp>::OpConversionPattern;
|
||||
template <class T>
|
||||
struct TritonExternElementwisePattern : public OpConversionPattern<T> {
|
||||
using OpConversionPattern<T>::OpConversionPattern;
|
||||
using OpConversionPattern<T>::typeConverter;
|
||||
typedef typename OpConversionPattern<T>::OpAdaptor OpAdaptor;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(T op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<T>(
|
||||
op, typeConverter->convertType(op.getType()),
|
||||
adaptor.getArgs(), adaptor.getLibname(),
|
||||
adaptor.getLibpath(), adaptor.getSymbol()),
|
||||
@@ -539,7 +541,9 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
|
||||
TritonReducePattern, TritonReduceReturnPattern, TritonTransPattern,
|
||||
TritonExpandDimsPattern, TritonMakeRangePattern, TritonDotPattern,
|
||||
TritonLoadPattern, TritonStorePattern, TritonExtElemwisePattern,
|
||||
TritonLoadPattern, TritonStorePattern,
|
||||
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
|
||||
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
|
||||
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
@@ -40,6 +40,9 @@ OpTrait::impl::verifySameOperandsEncoding(Operation *op,
|
||||
|
||||
LogicalResult OpTrait::impl::verifySameOperandsAndResultEncoding(
|
||||
Operation *op, bool allowTensorPointerType) {
|
||||
if (op->getNumOperands() == 0)
|
||||
return success();
|
||||
|
||||
if (failed(verifyAtLeastNOperands(op, 1)) ||
|
||||
failed(verifyAtLeastNResults(op, 1)))
|
||||
return failure();
|
||||
|
||||
@@ -234,9 +234,10 @@ static bool linkExternLib(llvm::Module &module, llvm::StringRef name,
|
||||
if (!isROCM) {
|
||||
if (name == "libdevice") {
|
||||
linkLibdevice(module);
|
||||
} else {
|
||||
assert(false && "unknown extern lib: ");
|
||||
}
|
||||
// else {
|
||||
// assert(false && "unknown extern lib: ");
|
||||
// }
|
||||
}
|
||||
|
||||
return false;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
graft src
|
||||
graft triton/third_party
|
||||
graft triton/runtime/backends/
|
||||
graft triton/language/extra
|
||||
|
||||
@@ -224,6 +224,7 @@ setup(
|
||||
"triton/common",
|
||||
"triton/compiler",
|
||||
"triton/language",
|
||||
"triton/language/extra",
|
||||
"triton/ops",
|
||||
"triton/ops/blocksparse",
|
||||
"triton/runtime",
|
||||
|
||||
12
python/src/extra/cuda.ll
Normal file
12
python/src/extra/cuda.ll
Normal file
@@ -0,0 +1,12 @@
|
||||
; ~/.triton/llvm/llvm+mlir-17.0.0-x86_64-linux-gnu-ubuntu-18.04-release/bin/llvm-as ./src/extra/cuda.ll -o ./triton/language/extra/cuda.bc
|
||||
|
||||
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
|
||||
target triple = "nvptx64-nvidia-cuda"
|
||||
|
||||
|
||||
; Returns the value of the PTX %globaltimer special register as an i64.
define i64 @globaltimer() #0 {
|
||||
; Inline PTX: move %globaltimer into the single 64-bit output operand ("=l").
; The asm is marked `sideeffect`.
%1 = call i64 asm sideeffect "mov.u64 $0, %globaltimer;", "=l"() nounwind
|
||||
ret i64 %1
|
||||
}
|
||||
|
||||
attributes #0 = { alwaysinline nounwind }
|
||||
@@ -1267,14 +1267,18 @@ void init_triton_ir(py::module &&m) {
|
||||
ptr, val, mask);
|
||||
})
|
||||
// External
|
||||
.def("create_external_elementwise",
|
||||
.def("create_extern_elementwise",
|
||||
[](mlir::OpBuilder &self, const std::string &libName,
|
||||
const std::string &libPath, const std::string &symbol,
|
||||
std::vector<mlir::Value> &argList,
|
||||
mlir::Type retType) -> mlir::Value {
|
||||
std::vector<mlir::Value> &argList, mlir::Type retType,
|
||||
bool isPure) -> mlir::Value {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::ExtElemwiseOp>(
|
||||
loc, retType, argList, libName, libPath, symbol);
|
||||
if (isPure)
|
||||
return self.create<mlir::triton::PureExternElementwiseOp>(
|
||||
loc, retType, argList, libName, libPath, symbol);
|
||||
else
|
||||
return self.create<mlir::triton::ImpureExternElementwiseOp>(
|
||||
loc, retType, argList, libName, libPath, symbol);
|
||||
})
|
||||
// Built-in instruction
|
||||
.def("create_get_program_id",
|
||||
|
||||
@@ -2312,12 +2312,35 @@ def test_while():
|
||||
# print(m[0])
|
||||
# print(n[0])
|
||||
|
||||
# -----------------------
|
||||
# test extra
|
||||
# -----------------------
|
||||
|
||||
|
||||
def test_globaltimer():
    """
    Launch a kernel that samples `tl.extra.cuda.globaltimer()` before and
    after a block of work, then check the stored difference and the number
    of `%globaltimer` references in the generated PTX.
    """

    # The kernel statements are kept exactly as written: the PTX produced
    # from them is inspected below.
    @triton.jit
    def kernel(Out1, Out2):
        start = tl.extra.cuda.globaltimer()
        off = tl.arange(0, 128)
        for i in range(100):
            tl.store(Out1 + off, tl.load(Out1 + off) + 1)
        end = tl.extra.cuda.globaltimer()
        tl.store(Out2, end - start)

    work_buf = to_triton(np.zeros((128,), dtype=np.int64), device='cuda')
    elapsed_buf = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
    compiled = kernel[(1,)](work_buf, elapsed_buf)

    # The second timer sample must be larger than the first.
    assert elapsed_buf[0] > 0

    # 2 inlined globaltimers + one extra in the wrapper extern function
    ptx_text = compiled.asm["ptx"]
    assert ptx_text.count("%globaltimer") == 3
|
||||
|
||||
# -----------------------
|
||||
# test layout conversions
|
||||
# -----------------------
|
||||
# TODO: backend should be tested separately
|
||||
|
||||
|
||||
layouts = [
|
||||
# MmaLayout(version=1, warps_per_cta=[1, 4]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[1, 4]),
|
||||
|
||||
@@ -2,6 +2,16 @@
|
||||
# Import order is significant here.
|
||||
|
||||
from . import math
|
||||
from . import extra
|
||||
from .standard import (
|
||||
cdiv,
|
||||
sigmoid,
|
||||
softmax,
|
||||
ravel,
|
||||
swizzle2d,
|
||||
zeros,
|
||||
zeros_like,
|
||||
)
|
||||
from .core import (
|
||||
abs,
|
||||
advance,
|
||||
@@ -21,7 +31,6 @@ from .core import (
|
||||
broadcast,
|
||||
broadcast_to,
|
||||
cat,
|
||||
cdiv,
|
||||
constexpr,
|
||||
cos,
|
||||
debug_barrier,
|
||||
@@ -56,18 +65,14 @@ from .core import (
|
||||
pi32_t,
|
||||
pointer_type,
|
||||
program_id,
|
||||
ravel,
|
||||
reduce,
|
||||
reshape,
|
||||
sigmoid,
|
||||
sin,
|
||||
softmax,
|
||||
sqrt,
|
||||
static_assert,
|
||||
static_print,
|
||||
store,
|
||||
sum,
|
||||
swizzle2d,
|
||||
static_range,
|
||||
tensor,
|
||||
trans,
|
||||
@@ -81,8 +86,6 @@ from .core import (
|
||||
void,
|
||||
where,
|
||||
xor_sum,
|
||||
zeros,
|
||||
zeros_like,
|
||||
)
|
||||
from .random import (
|
||||
pair_uniform_to_normal,
|
||||
@@ -127,6 +130,7 @@ __all__ = [
|
||||
"dot",
|
||||
"dtype",
|
||||
"exp",
|
||||
"extra",
|
||||
"fdiv",
|
||||
"float16",
|
||||
"float32",
|
||||
|
||||
@@ -994,10 +994,36 @@ def store(pointer, value, mask=None, boundary_check=(), cache_modifier="", evict
|
||||
return semantic.store(pointer, value, mask, boundary_check, cache_modifier, eviction_policy, _builder)
|
||||
|
||||
|
||||
@builtin
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
    """
    Create a pointer to a block inside a parent tensor.

    :param base: base pointer of the parent tensor
    :param shape: shape of the parent tensor
    :param strides: strides of the parent tensor
    :param offsets: offsets locating the block
    :param block_shape: shape of the block
    :param order: order of the original data format
    """
    # Thin front-end wrapper: all work is delegated to the semantic layer.
    block_ptr = semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
    return block_ptr
|
||||
|
||||
|
||||
@builtin
def advance(base: tensor, offsets, _builder=None):
    """
    Move a block pointer by the given offsets.

    :param base: the block pointer to advance
    :param offsets: the offsets to advance by, a tuple with one entry per dimension
    """
    # Thin front-end wrapper: all work is delegated to the semantic layer.
    advanced_ptr = semantic.advance(base, offsets, _builder)
    return advanced_ptr
|
||||
|
||||
# -----------------------
|
||||
# Atomic Memory Operations
|
||||
# -----------------------
|
||||
|
||||
|
||||
def _add_atomic_docstr(name: str) -> Callable[[T], T]:
|
||||
|
||||
def _decorator(func: T) -> T:
|
||||
@@ -1080,7 +1106,6 @@ def atomic_xor(pointer, val, mask=None, _builder=None):
|
||||
# Conditioning
|
||||
# -----------------------
|
||||
|
||||
|
||||
@builtin
|
||||
def where(condition, x, y, _builder=None):
|
||||
"""
|
||||
@@ -1266,6 +1291,32 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
|
||||
return rindices
|
||||
|
||||
|
||||
@triton.jit
|
||||
def minimum(x, y):
|
||||
"""
Computes the element-wise minimum of :code:`x` and :code:`y`.

:param x: the first input tensor
:type x: Block
:param y: the second input tensor
:type y: Block
"""
|
||||
return where(x < y, x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def maximum(x, y):
|
||||
"""
Computes the element-wise maximum of :code:`x` and :code:`y`.

:param x: the first input tensor
:type x: Block
:param y: the second input tensor
:type y: Block
"""
|
||||
return where(x > y, x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _max_combine(a, b):
|
||||
return maximum(a, b)
|
||||
@@ -1395,127 +1446,6 @@ def max_contiguous(input, values, _builder=None):
|
||||
values = [x.value for x in values]
|
||||
return semantic.max_contiguous(input, values)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Standard library
|
||||
# -----------------------
|
||||
|
||||
|
||||
@triton.jit
|
||||
def cdiv(x, div):
|
||||
"""
|
||||
Computes the ceiling division of :code:`x` by :code:`div`
|
||||
|
||||
:param x: the input number
|
||||
:type input: Block
|
||||
:param div: the divisor
|
||||
:param div: Block
|
||||
"""
|
||||
return (x + div - 1) // div
|
||||
|
||||
|
||||
@triton.jit
|
||||
def minimum(x, y):
|
||||
"""
|
||||
Computes the element-wise minimum of :code:`x` and :code:`y`.
|
||||
|
||||
:param input: the first input tensor
|
||||
:type input: Block
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return triton.language.where(x < y, x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def maximum(x, y):
|
||||
"""
|
||||
Computes the element-wise maximum of :code:`x` and :code:`y`.
|
||||
|
||||
:param input: the first input tensor
|
||||
:type input: Block
|
||||
:param other: the second input tensor
|
||||
:type other: Block
|
||||
"""
|
||||
return triton.language.where(x > y, x, y)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@_add_math_1arg_docstr("sigmoid")
|
||||
def sigmoid(x):
|
||||
return 1 / (1 + triton.language.exp(-x))
|
||||
|
||||
|
||||
@triton.jit
|
||||
@_add_math_1arg_docstr("softmax")
|
||||
def softmax(x, ieee_rounding=False):
|
||||
z = x - triton.language.max(x, 0)
|
||||
num = triton.language.exp(z)
|
||||
den = triton.language.sum(num, 0)
|
||||
return fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def ravel(x):
|
||||
"""
|
||||
Returns a contiguous flattened view of :code:`x`.
|
||||
|
||||
:param x: the input tensor
|
||||
:type x: Block
|
||||
"""
|
||||
return triton.language.view(x, [x.numel])
|
||||
|
||||
|
||||
@triton.jit
|
||||
def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
"""
|
||||
Transforms indices of a row-major size_i*size_j matrix into those
|
||||
of one where indices are row major for each group of size_j rows.
|
||||
For example, for size_i = size_j = 4 and size_g = 2, it will transform
|
||||
[[0 , 1 , 2 , 3 ],
|
||||
[4 , 5 , 6 , 7 ],
|
||||
[8 , 9 , 10, 11],
|
||||
[12, 13, 14, 15]]
|
||||
into
|
||||
[[0, 2, 4 , 6 ],
|
||||
[1, 3, 5 , 7 ],
|
||||
[8, 10, 12, 14],
|
||||
[9, 11, 13, 15]]
|
||||
"""
|
||||
# "unrolled index in array"
|
||||
ij = i * size_j + j
|
||||
# number of elements in `size_g` groups
|
||||
# of `size_j` columns
|
||||
size_gj = size_g * size_j
|
||||
# index of the group in which (i,j) is
|
||||
group_id = ij // size_gj
|
||||
# row-index of the first element of this group
|
||||
off_i = group_id * size_g
|
||||
# last group may have fewer rows
|
||||
size_g = minimum(size_i - off_i, size_g)
|
||||
# new row and column indices
|
||||
new_i = off_i + (ij % size_g)
|
||||
new_j = (ij % size_gj) // size_g
|
||||
return new_i, new_j
|
||||
|
||||
|
||||
@triton.jit
|
||||
def zeros(shape, dtype):
|
||||
"""
|
||||
Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.
|
||||
|
||||
:param shape: Shape of the new array, e.g., (8, 16) or (8, )
|
||||
:type shape: tuple of ints
|
||||
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
|
||||
:type dtype: DType
|
||||
"""
|
||||
return full(shape, 0, dtype)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def zeros_like(input):
|
||||
return zeros(input.shape, input.dtype)
|
||||
|
||||
# -----------------------
|
||||
# Debugging functions
|
||||
# -----------------------
|
||||
@@ -1568,32 +1498,6 @@ def device_assert(cond, msg="", _builder=None):
|
||||
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def make_block_ptr(base: tensor, shape, strides, offsets, block_shape, order, _builder=None):
|
||||
"""
|
||||
Returns a pointer to a block in a parent tensor
|
||||
|
||||
:param base: The base pointer to the parent tensor
|
||||
:param shape: The shape of the parent tensor
|
||||
:param strides: The strides of the parent tensor
|
||||
:param offsets: The offsets to the block
|
||||
:param block_shape: The shape of the block
|
||||
:param order: The order of the original data format
|
||||
"""
|
||||
return semantic.make_block_ptr(base, shape, strides, offsets, block_shape, order, _builder)
|
||||
|
||||
|
||||
@builtin
|
||||
def advance(base: tensor, offsets, _builder=None):
|
||||
"""
|
||||
Advance a block pointer
|
||||
|
||||
:param base: the block pointer to advance
|
||||
:param offsets: the offsets to advance, a tuple by dimension
|
||||
"""
|
||||
return semantic.advance(base, offsets, _builder)
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Iterators
|
||||
# -----------------------
|
||||
@@ -1623,3 +1527,86 @@ class static_range:
|
||||
|
||||
def __next__(self):
|
||||
raise RuntimeError("static_range can only be used in @triton.jit'd functions")
|
||||
|
||||
|
||||
# -----------------------
|
||||
# Extern functions
|
||||
# -----------------------
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None):
    '''
    Dispatch a function to a library

    :param func: the function to dispatch
    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: maps a tuple of argument types to a
        sequence whose first two entries are the symbol and the return type
    :param ret_shape: the shape of the return value; when falsy the return
        type is used as-is instead of being wrapped in a block type
    :param is_pure: forwarded unchanged to :code:`func` as its last argument
    :param _builder: the builder
    :return: the return value of the function
    :raises ValueError: if :code:`arg_type_symbol_dict` is empty, if the
        number of arguments differs from the expected arity, or if no
        signature matches the argument types
    '''
    if len(arg_type_symbol_dict) == 0:
        raise ValueError("arg_type_symbol_dict is empty")

    # The expected arity is read from the first signature key; all keys are
    # presumed to have the same length.
    num_args = len(next(iter(arg_type_symbol_dict)))
    if len(args) != num_args:
        # Report the signature arity as "expected" and the caller-supplied
        # count as "got".
        raise ValueError("length of input args does not match. "
                         f"Expect {num_args}, got {len(args)}")

    arg_types = []
    arg_list = []
    for arg in args:
        if isinstance(arg, tensor):
            # Tensors are matched by dtype and passed on as their handle.
            arg_types.append(arg.dtype)
            arg_list.append(arg.handle)
        else:
            # Anything else is matched by its Python type and passed as-is.
            arg_types.append(type(arg))
            arg_list.append(arg)
    arg_types = tuple(arg_types)

    if arg_types not in arg_type_symbol_dict:
        raise ValueError("input arg type does not match. "
                         f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")

    symbol = arg_type_symbol_dict[arg_types][0]
    ret_type = arg_type_symbol_dict[arg_types][1]
    if ret_shape:
        ret_type = block_type(ret_type, ret_shape)
    return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type)
|
||||
|
||||
|
||||
def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None):
    '''
    Dispatch an elementwise function to a library

    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: the type of the arguments
    :param is_pure: forwarded unchanged to :code:`dispatch`
    :param _builder: the builder
    :return: the return value of the function
    '''
    dispatch_args = args.copy()
    ret_shape = None

    # Convert every argument to a tensor and note whether any is a block.
    has_block = False
    for idx, raw_arg in enumerate(dispatch_args):
        as_tensor = _to_tensor(raw_arg, _builder)
        dispatch_args[idx] = as_tensor
        if as_tensor.type.is_block():
            has_block = True

    if has_block:
        # First pass: fold all arguments into one broadcast representative.
        broadcast_arg = dispatch_args[0]
        for item in dispatch_args:
            _, broadcast_arg = semantic.binary_op_type_checking_impl(
                item, broadcast_arg, _builder)
        # Second pass: reshape each argument against that representative.
        for idx, item in enumerate(dispatch_args):
            dispatch_args[idx], _ = semantic.binary_op_type_checking_impl(
                item, broadcast_arg, _builder)
        ret_shape = broadcast_arg.shape

    func = _builder.create_extern_elementwise
    return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, is_pure, _builder)
|
||||
|
||||
|
||||
def extern(fn):
    """Decorator for external functions: passes `fn` through `builtin` and returns the result."""
    wrapped = builtin(fn)
    return wrapped
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
from __future__ import annotations # remove after python 3.11
|
||||
|
||||
from . import core, semantic
|
||||
|
||||
|
||||
def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, _builder=None):
|
||||
'''
|
||||
Dispatch a function to a library
|
||||
:param func: the function to dispatch
|
||||
:param lib_name: the name of the library
|
||||
:param lib_path: the path of the library
|
||||
:param args: the arguments of the function
|
||||
:param arg_type_symbol_dict: the type of the arguments
|
||||
:param ret_shape: the shape of the return value
|
||||
:param _builder: the builder
|
||||
:return: the return value of the function
|
||||
'''
|
||||
if len(arg_type_symbol_dict) == 0:
|
||||
raise ValueError("arg_type_symbol_dict is empty")
|
||||
|
||||
num_args = len(list(arg_type_symbol_dict.keys())[0])
|
||||
if len(args) != num_args:
|
||||
raise ValueError(f"length of input args does not match."
|
||||
f"Expect {len(args)}, got {num_args}")
|
||||
|
||||
arg_types = []
|
||||
arg_list = []
|
||||
for arg in args:
|
||||
if isinstance(arg, core.tensor):
|
||||
arg_types.append(arg.dtype)
|
||||
arg_list.append(arg.handle)
|
||||
else:
|
||||
arg_types.append(type(arg))
|
||||
arg_list.append(arg)
|
||||
arg_types = tuple(arg_types)
|
||||
|
||||
if arg_types not in arg_type_symbol_dict:
|
||||
raise ValueError(f"input arg type does not match."
|
||||
f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
|
||||
else:
|
||||
symbol = arg_type_symbol_dict[arg_types][0]
|
||||
ret_type = arg_type_symbol_dict[arg_types][1]
|
||||
if ret_shape:
|
||||
ret_type = core.block_type(ret_type, ret_shape)
|
||||
return core.tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder)), ret_type)
|
||||
|
||||
|
||||
def extern(fn):
|
||||
"""A decorator for external functions."""
|
||||
return core.builtin(fn)
|
||||
|
||||
|
||||
def elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, _builder=None):
|
||||
'''
|
||||
Dispatch an elementwise function to a library
|
||||
:param lib_name: the name of the library
|
||||
:param lib_path: the path of the library
|
||||
:param args: the arguments of the function
|
||||
:param arg_type_symbol_dict: the type of the arguments
|
||||
:param _builder: the builder
|
||||
:return: the return value of the function
|
||||
'''
|
||||
dispatch_args = args.copy()
|
||||
all_scalar = True
|
||||
ret_shape = None
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i] = core._to_tensor(dispatch_args[i], _builder)
|
||||
if dispatch_args[i].type.is_block():
|
||||
all_scalar = False
|
||||
if not all_scalar:
|
||||
broadcast_arg = dispatch_args[0]
|
||||
# Get the broadcast shape over all the arguments
|
||||
for i, item in enumerate(dispatch_args):
|
||||
_, broadcast_arg = semantic.binary_op_type_checking_impl(
|
||||
item, broadcast_arg, _builder)
|
||||
# Change the shape of each argument based on the broadcast shape
|
||||
for i in range(len(dispatch_args)):
|
||||
dispatch_args[i], _ = semantic.binary_op_type_checking_impl(
|
||||
dispatch_args[i], broadcast_arg, _builder)
|
||||
ret_shape = broadcast_arg.shape
|
||||
func = getattr(_builder, "create_external_elementwise")
|
||||
return dispatch(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, ret_shape, _builder)
|
||||
3
python/triton/language/extra/__init__.py
Normal file
3
python/triton/language/extra/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from . import cuda
|
||||
|
||||
__all__ = ['cuda']
|
||||
BIN
python/triton/language/extra/cuda.bc
Normal file
BIN
python/triton/language/extra/cuda.bc
Normal file
Binary file not shown.
12
python/triton/language/extra/cuda.py
Normal file
12
python/triton/language/extra/cuda.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import os
|
||||
|
||||
from .. import core
|
||||
|
||||
__path__ = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
@core.extern
def globaltimer(_builder=None):
    """Call the `globaltimer` symbol from the `cuda.bc` file located next to this module."""
    bitcode_path = os.path.join(__path__, "cuda.bc")
    # One zero-argument signature: () -> (symbol name, return dtype).
    signatures = {
        tuple(): ("globaltimer", core.dtype("int64")),
    }
    # Declared impure (is_pure=False).
    return core.extern_elementwise("cuda", bitcode_path, [], signatures,
                                   is_pure=False, _builder=_builder)
|
||||
File diff suppressed because it is too large
Load Diff
98
python/triton/language/standard.py
Normal file
98
python/triton/language/standard.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ..runtime.jit import jit
|
||||
from . import core
|
||||
|
||||
# -----------------------
|
||||
# Standard library
|
||||
# -----------------------
|
||||
|
||||
|
||||
@jit
|
||||
def cdiv(x, div):
|
||||
"""
Computes the ceiling division of :code:`x` by :code:`div`

:param x: the input number
:type x: Block
:param div: the divisor
:type div: Block
"""
|
||||
return (x + div - 1) // div
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_math_1arg_docstr("sigmoid")
|
||||
def sigmoid(x):
|
||||
# Logistic function: 1 / (1 + exp(-x)).
# NOTE(review): the docstring is presumably attached by the decorator above - confirm.
return 1 / (1 + core.exp(-x))
|
||||
|
||||
|
||||
@jit
|
||||
@core._add_math_1arg_docstr("softmax")
|
||||
def softmax(x, ieee_rounding=False):
|
||||
# Subtract the maximum along axis 0 before exponentiating.
z = x - core.max(x, 0)
|
||||
# Numerator: element-wise exponential of the shifted input.
num = core.exp(z)
|
||||
# Denominator: sum of the exponentials along axis 0.
den = core.sum(num, 0)
|
||||
# `ieee_rounding` is forwarded unchanged to `core.fdiv`.
return core.fdiv(num, den, ieee_rounding)
|
||||
|
||||
|
||||
@jit
|
||||
def ravel(x):
|
||||
"""
Returns a contiguous flattened view of :code:`x`.

The result is produced by viewing :code:`x` with the one-dimensional
shape :code:`[x.numel]`.

:param x: the input tensor
:type x: Block
"""
|
||||
return core.view(x, [x.numel])
|
||||
|
||||
|
||||
@jit
|
||||
def swizzle2d(i, j, size_i, size_j, size_g):
|
||||
"""
Transforms indices of a row-major size_i*size_j matrix into those
of one where indices are row major for each group of size_g rows.
For example, for size_i = size_j = 4 and size_g = 2, it will transform
[[0 , 1 , 2 , 3 ],
 [4 , 5 , 6 , 7 ],
 [8 , 9 , 10, 11],
 [12, 13, 14, 15]]
into
[[0, 2,  4 , 6 ],
 [1, 3,  5 , 7 ],
 [8, 10, 12, 14],
 [9, 11, 13, 15]]

:param i: row index in the original matrix
:param j: column index in the original matrix
:param size_i: number of rows of the matrix
:param size_j: number of columns of the matrix
:param size_g: number of rows per group
:return: the pair :code:`(new_i, new_j)` of transformed indices
"""
|
||||
# "unrolled index in array"
|
||||
ij = i * size_j + j
|
||||
# number of elements in a group of `size_g` rows
|
||||
# of `size_j` columns
|
||||
size_gj = size_g * size_j
|
||||
# index of the group in which (i,j) is
|
||||
group_id = ij // size_gj
|
||||
# row-index of the first element of this group
|
||||
off_i = group_id * size_g
|
||||
# last group may have fewer rows
|
||||
size_g = core.minimum(size_i - off_i, size_g)
|
||||
# new row and column indices
|
||||
new_i = off_i + (ij % size_g)
|
||||
new_j = (ij % size_gj) // size_g
|
||||
return new_i, new_j
|
||||
|
||||
|
||||
@jit
|
||||
def zeros(shape, dtype):
|
||||
"""
Returns a tensor filled with the scalar value 0 for the given :code:`shape` and :code:`dtype`.

Implemented as :code:`core.full(shape, 0, dtype)`.

:param shape: Shape of the new array, e.g., (8, 16) or (8, )
:type shape: tuple of ints
:param dtype: Data-type of the new array, e.g., :code:`tl.float16`
:type dtype: DType
"""
|
||||
return core.full(shape, 0, dtype)
|
||||
|
||||
|
||||
@jit
|
||||
def zeros_like(input):
|
||||
"""
Returns :code:`zeros(input.shape, input.dtype)`, i.e. a zero-filled tensor
with the shape and dtype of :code:`input`.

:param input: the tensor whose shape and dtype are reused
"""
return zeros(input.shape, input.dtype)
|
||||
@@ -61,6 +61,9 @@ class CudaDriver(DriverBase):
|
||||
cls.instance = super(CudaDriver, cls).__new__(cls)
|
||||
return cls.instance
|
||||
|
||||
def get_extern_path(self):
    """Return the path `<third_party_dir>/cuda/lib`."""
    path_parts = (self.third_party_dir(), "cuda", "lib")
    return os.path.join(*path_parts)
|
||||
|
||||
def get_libdevice_path(self):
    """Return the path `<third_party_dir>/cuda/lib/libdevice.10.bc`."""
    path_parts = (self.third_party_dir(), "cuda", "lib", "libdevice.10.bc")
    return os.path.join(*path_parts)
|
||||
|
||||
@@ -68,6 +71,7 @@ class CudaDriver(DriverBase):
|
||||
self.utils = CudaUtils()
|
||||
self.backend = self.CUDA
|
||||
self.libdevice_path = self.get_libdevice_path()
|
||||
self.extern_path = self.get_extern_path()
|
||||
|
||||
# -----------------------------
|
||||
# HIP
|
||||
|
||||
15
python/triton/runtime/errors.py
Normal file
15
python/triton/runtime/errors.py
Normal file
@@ -0,0 +1,15 @@
|
||||
|
||||
class OutOfResources(Exception):
    """
    Error describing a resource request that exceeds a hardware limit.

    The constructor arguments are kept as the attributes `required`,
    `limit` and `name`; the formatted text is kept as `message` and is also
    the exception's string value.
    """

    def __init__(self, required, limit, name):
        self.required = required
        self.limit = limit
        self.name = name
        summary = f'out of resource: {name}, Required: {required}, Hardware limit: {limit}'
        self.message = summary + '. Reducing block sizes or `num_stages` may help.'
        super().__init__(self.message)

    def __reduce__(self):
        # Needed to keep this exception picklable: default exception
        # pickling replays `self.args` (only the message string), which does
        # not match the three-argument constructor. Rebuild from the
        # original arguments instead.
        ctor_args = (self.required, self.limit, self.name)
        return (type(self), ctor_args)
|
||||
@@ -156,6 +156,7 @@ class Libdevice(ExternLibrary):
|
||||
'''
|
||||
super().__init__("libdevice", path)
|
||||
self._symbol_groups = {}
|
||||
self.is_pure = True
|
||||
|
||||
@staticmethod
|
||||
def _extract_symbol(line) -> Optional[Symbol]:
|
||||
@@ -287,7 +288,7 @@ class Libdevice(ExternLibrary):
|
||||
# def <op_name>(<args>, _builder=None):
|
||||
# arg_type_symbol_dict = {[arg_type]: {(symbol, ret_type)}}
|
||||
# return core.extern_elementwise("libdevice", <path>, <args>, <arg_type_symbol_dict>, is_pure=<is_pure>, _builder=_builder)
|
||||
import_str = "from . import core, extern\n"
|
||||
import_str = "from . import core\n"
|
||||
import_str += "from ..runtime import driver\n"
|
||||
import_str += "import os\n"
|
||||
|
||||
@@ -295,13 +296,13 @@ class Libdevice(ExternLibrary):
|
||||
header_str += "LIBDEVICE_PATH = os.getenv(\"TRITON_LIBDEVICE_PATH\", driver.libdevice_path)\n"
|
||||
func_str = ""
|
||||
for symbols in self._symbol_groups.values():
|
||||
func_str += "@extern.extern\n"
|
||||
func_str += "@core.extern\n"
|
||||
func_name_str = f"def {symbols[0].op_name}("
|
||||
for arg_name in symbols[0].arg_names:
|
||||
func_name_str += f"{arg_name}, "
|
||||
func_name_str += "_builder=None):\n"
|
||||
|
||||
return_str = f"\treturn extern.elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
|
||||
return_str = f"\treturn core.extern_elementwise(\"{self._name}\", LIBDEVICE_PATH, ["
|
||||
for arg_name in symbols[0].arg_names:
|
||||
return_str += f"{arg_name}, "
|
||||
return_str += "], \n"
|
||||
@@ -316,7 +317,8 @@ class Libdevice(ExternLibrary):
|
||||
arg_type_symbol_dict_str += "}"
|
||||
|
||||
return_str += arg_type_symbol_dict_str
|
||||
return_str += ", _builder)\n"
|
||||
return_str += f", is_pure={self.is_pure}"
|
||||
return_str += ", _builder=_builder)\n"
|
||||
|
||||
func_str += func_name_str + return_str + "\n"
|
||||
file_str = import_str + header_str + func_str
|
||||
|
||||
Reference in New Issue
Block a user