Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03162023

This commit is contained in:
Rohit Santhanam
2023-03-16 13:21:15 +00:00
15 changed files with 111 additions and 47 deletions

View File

@@ -249,9 +249,12 @@ if(TRITON_BUILD_PYTHON_MODULE)
${TRITON_LIBRARIES}
)
else()
target_link_libraries(triton ${LLVM_LIBRARIES} z stdc++fs
target_link_libraries(triton ${LLVM_LIBRARIES} z
${TRITON_LIBRARIES}
)
# TODO: Figure out which target is sufficient to fix errors; triton is
# apparently not enough
link_libraries(stdc++fs)
endif()
target_link_options(triton PRIVATE ${LLVM_LDFLAGS})

View File

@@ -112,7 +112,8 @@ private:
/// The _divisibility_ information maps the `d`-th
/// dimension to the largest power-of-two that
/// divides the first element of all the values along it
/// divides the first element of all groups of
/// _contiguity_ values along it
/// For example:
/// [10, 11, 12, 13, 18, 19, 20, 21]
/// [20, 21, 22, 23, 28, 29, 30, 31]
@@ -123,6 +124,10 @@ private:
/// [14, 18, 22, 26]
/// [15, 19, 23, 27]
// would have divisibility [4, 1]
/// On the other hand:
/// [0, 1, 2, 0, 4, 5, 6, 7]
/// would have divisibility 1 because
/// _contiguity_=1
DimVectorT divisibility;
/// The _constancy_ information maps the `d`-th

View File

@@ -164,16 +164,16 @@ public:
}
};
class ConstantOpAxisInfoVisitor final
: public AxisInfoVisitorImpl<arith::ConstantOp> {
template <typename OpTy>
class ConstantOpAxisInfoVisitor final : public AxisInfoVisitorImpl<OpTy> {
public:
using AxisInfoVisitorImpl<arith::ConstantOp>::AxisInfoVisitorImpl;
using AxisInfoVisitorImpl<OpTy>::AxisInfoVisitorImpl;
AxisInfo
getAxisInfo(arith::ConstantOp op,
getAxisInfo(OpTy op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands) override {
auto intAttr = op.getValue().dyn_cast<IntegerAttr>();
auto boolAttr = op.getValue().dyn_cast<BoolAttr>();
auto intAttr = op.getValue().template dyn_cast<IntegerAttr>();
auto boolAttr = op.getValue().template dyn_cast<BoolAttr>();
if (intAttr || boolAttr) {
int64_t value{};
if (intAttr)
@@ -186,10 +186,10 @@ public:
/*knownConstantValue=*/{value});
}
// TODO: generalize to dense attr
auto splatAttr = op.getValue().dyn_cast<SplatElementsAttr>();
auto splatAttr = op.getValue().template dyn_cast<SplatElementsAttr>();
if (splatAttr && splatAttr.getElementType().isIntOrIndex()) {
int64_t value = splatAttr.getSplatValue<APInt>().getZExtValue();
TensorType ty = splatAttr.getType().cast<TensorType>();
int64_t value = splatAttr.template getSplatValue<APInt>().getZExtValue();
TensorType ty = splatAttr.getType().template cast<TensorType>();
return AxisInfo(
/*contiguity=*/AxisInfo::DimVectorT(ty.getRank(), 1),
/*divisibility=*/
@@ -233,7 +233,8 @@ private:
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value()) {
if constexpr (std::is_same_v<OpTy, arith::AddIOp> ||
std::is_same_v<OpTy, triton::AddPtrOp>) {
std::is_same_v<OpTy, triton::AddPtrOp> ||
std::is_same_v<OpTy, LLVM::AddOp>) {
return {lhs.getConstantValue().value() +
rhs.getConstantValue().value()};
} else if constexpr (std::is_same_v<OpTy, arith::SubIOp>) {
@@ -334,14 +335,11 @@ private:
if (lhs.getConstantValue().has_value() &&
lhs.getConstantValue().value() == 0)
return lhs.getDivisibility(dim);
// Case 2: rhs is constant
if (rhs.getConstantValue().has_value()) {
auto lhsDivisibility = lhs.getDivisibility(dim);
auto rhsValue = rhs.getConstantValue().value();
if (lhsDivisibility % rhsValue == 0)
return lhsDivisibility / rhsValue;
}
// Case 3: both are not constant
// Case 2: rhs is 1
if (rhs.getConstantValue().has_value() &&
rhs.getConstantValue().value() == 1)
return lhs.getDivisibility(dim);
// otherwise: return 1
return 1;
}
@@ -815,11 +813,15 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
CastOpAxisInfoVisitor<triton::gpu::ConvertLayoutOp>,
CastOpAxisInfoVisitor<mlir::UnrealizedConversionCastOp>,
CastOpAxisInfoVisitor<triton::BitcastOp>>();
// TODO: Remove rules for LLVM::ConstantOp, LLVM::AddOp
// when scf.for supports integers induction variable
visitors.append<MakeRangeOpAxisInfoVisitor>();
visitors.append<ConstantOpAxisInfoVisitor>();
visitors.append<ConstantOpAxisInfoVisitor<arith::ConstantOp>,
ConstantOpAxisInfoVisitor<LLVM::ConstantOp>>();
visitors.append<AddSubOpAxisInfoVisitor<triton::AddPtrOp>,
AddSubOpAxisInfoVisitor<arith::AddIOp>,
AddSubOpAxisInfoVisitor<arith::SubIOp>>();
AddSubOpAxisInfoVisitor<arith::SubIOp>,
AddSubOpAxisInfoVisitor<LLVM::AddOp>>();
visitors.append<MulIOpAxisInfoVisitor>();
visitors.append<DivOpAxisInfoVisitor<arith::DivSIOp>,
DivOpAxisInfoVisitor<arith::DivUIOp>>();

View File

@@ -82,7 +82,7 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
return failure();
ret = sliceEncoding.getParent();
}
if (auto view = dyn_cast<triton::ViewOp>(op)) {
if (isa<triton::ViewOp, triton::CatOp>(op)) {
return failure();
}
return success();
@@ -179,8 +179,6 @@ int simulateBackwardRematerialization(
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
if (auto view = dyn_cast<triton::ViewOp>(opArgI))
continue;
// We add one expensive conversion for the current operand
numCvts += 1;

View File

@@ -275,7 +275,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module) {
}
auto optPipeline = mlir::makeOptimizingTransformer(
/*optLevel=*/3, /*sizeLevel=*/0,
/*optLevel=*/0, /*sizeLevel=*/0,
/*targetMachine=*/nullptr);
if (auto err = optPipeline(llvmModule.get())) {
@@ -328,7 +328,6 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
return nullptr;
}
// llvm::outs() << module << "\n";
auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
if (!llvmIR) {
llvm::errs() << "Translate to LLVM IR failed";

View File

@@ -83,8 +83,8 @@ def get_thirdparty_packages(triton_cache_path):
if p.syspath_var_name in os.environ:
package_dir = os.environ[p.syspath_var_name]
version_file_path = os.path.join(package_dir, "version.txt")
if not os.path.exists(version_file_path) or\
Path(version_file_path).read_text() != p.url:
if p.syspath_var_name not in os.environ and\
(not os.path.exists(version_file_path) or Path(version_file_path).read_text() != p.url):
try:
shutil.rmtree(package_root_dir)
except Exception:

View File

@@ -56,7 +56,7 @@ matmul_data = {
'a100': {
(512, 512, 512): {'float16': 0.08, 'float32': 0.13, 'int8': 0.05},
(1024, 1024, 1024): {'float16': 0.33, 'float32': 0.35, 'int8': 0.169},
(2048, 2048, 2048): {'float16': 0.64, 'float32': 0.57, 'int8': 0.34},
(2048, 2048, 2048): {'float16': 0.59, 'float32': 0.57, 'int8': 0.34},
(4096, 4096, 4096): {'float16': 0.81, 'float32': 0.75, 'int8': 0.46},
(8192, 8192, 8192): {'float16': 0.77, 'float32': 0.85, 'int8': 0.51},
# tall-skinny

View File

@@ -295,6 +295,43 @@ def test_floordiv(dtype_x, dtype_y, device='cuda'):
_test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device)
def test_unsigned_name_mangling(device='cuda'):
# Test that uint32 and int32 are mangled differently by the compiler
SIZE = 128
# define the kernel / launch-grid
@triton.jit
def kernel(O1, O2, X, Y, SIZE: tl.constexpr):
off = tl.arange(0, SIZE)
x = tl.load(X + off)
y = tl.load(Y + off)
out1 = tl.abs(x) # uint32 -> nop
out2 = tl.abs(-y) # int32 -> should have an effect
tl.store(O1 + off, out1)
tl.store(O2 + off, out2)
dtype_x = 'uint32'
dtype_y = 'int32'
# inputs
rs = RandomState(17)
x = numpy_random(SIZE, dtype_str=dtype_x, rs=rs)
y = numpy_random(SIZE, dtype_str=dtype_y, rs=rs)
# reference result
expect = (np.abs(x), np.abs(-y))
# triton result
x_tri = to_triton(x, device=device, dst_type=dtype_x)
y_tri = to_triton(y, device=device, dst_type=dtype_y)
actual = tuple(
to_triton(np.empty_like(e), device=device)
for e in expect
)
kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4)
# Bitwise op, so expect exact equality
assert (expect[0] == to_numpy(actual[0])).all()
assert (expect[1] == to_numpy(actual[1])).all()
# ---------------
# test bitwise ops
# ---------------
@@ -1140,7 +1177,8 @@ def test_permute(dtype_str, shape, perm, device='cuda'):
[128, 128, 64, 4],
[64, 128, 128, 4],
[32, 128, 64, 2],
[128, 128, 64, 2],
# triggers nvptx/ptxas bug on V100 currently
# [128, 128, 64, 2],
[64, 128, 128, 2]]
for allow_tf32 in [True]
for col_a in [True, False]

View File

@@ -1,5 +1,5 @@
"""isort:skip_file"""
__version__ = '2.0.0'
__version__ = '2.1.0'
# ---------------------------------------
# Note: import order is significant here.

View File

@@ -63,7 +63,9 @@ def mangle_ty(ty):
if ty.is_ptr():
return 'P' + mangle_ty(ty.element_ty)
if ty.is_int():
return 'i' + str(ty.int_bitwidth)
SIGNED = triton.language.dtype.SIGNEDNESS.SIGNED
prefix = 'i' if ty.int_signedness == SIGNED else 'u'
return prefix + str(ty.int_bitwidth)
if ty.is_fp8():
return 'fp8'
if ty.is_fp16():
@@ -908,6 +910,22 @@ class CodeGenerator(ast.NodeVisitor):
def visit_NoneType(self, node):
return None
def visit_JoinedStr(self, node):
values = list(node.values)
for i, value in enumerate(values):
if isinstance(value, ast.Constant):
values[i] = str(value.value)
elif isinstance(value, ast.FormattedValue):
conversion_code = value.conversion
evaluated = self.visit(value.value)
if not isinstance(evaluated, triton.language.constexpr):
raise NotImplementedError("Cannot evaluate f-string containing non-constexpr conversion values,"
" found conversion of type " + str(type(evaluated)))
values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value)
else:
raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value)))
return ''.join(values)
def visit(self, node):
if node is not None:
self.last_node = node
@@ -1001,7 +1019,7 @@ def build_triton_ir(fn, signature, specialization, constants, debug=False):
generator.visit(fn.parse())
except Exception as e:
node = generator.last_node
if node is None or isinstance(e, (NotImplementedError, CompilationError)):
if node is None or isinstance(e, CompilationError):
raise e
raise CompilationError(fn.src, node) from e
ret = generator.module

View File

@@ -37,8 +37,8 @@ def _to_tensor(x, builder):
class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
SINT_TYPES = ['int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['int1', 'uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8e4', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']

View File

@@ -181,8 +181,8 @@ def _dsd_kernel(
inc_b = tl.load(pinc)
inc_b = tl.multiple_of(inc_b, 8)
for k in range(K, 0, -TILE_K):
a = tl.load(pa, mask=True)
b = tl.load(pb, mask=offs_bn[None, :] < DS0)
a = tl.load(pa)
b = tl.load(pb)
acc += tl.dot(a, b)
pa += inc_a
pb += inc_b * stride_bk

View File

@@ -316,7 +316,9 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
self.__annotations__ = fn.__annotations__
# index of constexprs
self.constexprs = [self.arg_names.index(ann) for ann in self.__annotations__.keys()]
from triton.language.core import \
constexpr # import here rather than at module level due to circular import tangle
self.constexprs = [index for index, ty in self.annotations.items() if issubclass(ty, constexpr)]
# launcher
self.run = self._make_launcher()
# re-use docs of wrapped function

View File

@@ -106,10 +106,8 @@ def allclose(x, y, atol=0, rtol=1e-2):
return torch.sum(x ^ y) == 0
if x.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
rtol = 0
diff = abs(x - y)
x_max = torch.max(x)
y_max = torch.max(y)
return torch.max(diff) <= atol + rtol * torch.max(x_max, y_max)
atol = 0
return torch.allclose(x, y, rtol=rtol, atol=atol)
def nvsmi(attrs):

View File

@@ -82,7 +82,7 @@ func.func @div() {
%3 = arith.divui %1, %0 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%4 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
%5 = arith.divsi %0, %4 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%6 = arith.divsi %4, %0 : tensor<128xi32>
@@ -94,11 +94,12 @@ func.func @div() {
%9 = arith.divui %0, %8 : tensor<128xi32>
// CHECK-NEXT: contiguity = [128], divisibility = [8192], constancy = [1], constant_value = <none>
%10 = tt.make_range {end = 8320 : i32, start = 8192 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [64], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
%11 = arith.divsi %10, %4 : tensor<128xi32>
return
}
// -----
// CHECK-LABEL: @rem
@@ -179,11 +180,11 @@ func.func @logic() {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64
%1 = arith.constant dense<64> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [16777216], constancy = [64], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [64], constant_value = <none>
%2 = arith.divsi %0, %1 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [8], constancy = [128], constant_value = 8
%3 = arith.constant dense<8> : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [134217728], constancy = [8], constant_value = <none>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [8], constant_value = <none>
%4 = arith.divsi %0, %3 : tensor<128xi32>
// CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = <none>
%5 = arith.andi %0, %1 : tensor<128xi32>