[FRONTEND][BACKEND] no longer using indices for loops (#1370)

This commit is contained in:
Philippe Tillet
2023-03-19 14:57:50 -07:00
committed by GitHub
parent 28e05c9799
commit e4b2d1bc3d
8 changed files with 7 additions and 137 deletions

View File

@@ -2,7 +2,6 @@
#define TRITON_CONVERSION_PASSES_H
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"

View File

@@ -49,18 +49,4 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
];
}
def TritonConvertArithToIndex : Pass<"triton-convert-arith-to-index", "mlir::ModuleOp"> {
let summary = "Convert arith to index";
let constructor = "mlir::triton::createTritonConvertArithToIndexPass()";
let description = [{
Convert arith operation on index values to corresponding ops in the index dialect.
We need this because SCFToCF conversion currently generates arith ops on indices.
}];
let dependentDialects = ["mlir::index::IndexDialect"];
}
#endif

View File

@@ -1,20 +0,0 @@
#ifndef TRITON_CONVERSION_ARITH_TO_INDEX_H
#define TRITON_CONVERSION_ARITH_TO_INDEX_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>
namespace mlir {
class ModuleOp;
template <typename T> class OperationPass;
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass();
}
} // namespace mlir
#endif

View File

@@ -1,90 +0,0 @@
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
using namespace mlir;
using namespace mlir::triton;
#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"
namespace {
class TritonArithToIndexConversionTarget : public mlir::ConversionTarget {
public:
static bool hasIndexResultOrOperand(Operation *op) {
if (!op)
return false;
bool hasRetIndex = llvm::find_if(op->getResultTypes(), [](Type type) {
return type.isIndex();
}) != op->getResultTypes().end();
bool hasArgIndex = llvm::find_if(op->getOperandTypes(), [](Type type) {
return type.isIndex();
}) != op->getOperandTypes().end();
return !hasRetIndex && !hasArgIndex;
}
explicit TritonArithToIndexConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<index::IndexDialect>();
addDynamicallyLegalDialect<arith::ArithDialect>(hasIndexResultOrOperand);
}
};
template <class SrcOp, class DstOp>
LogicalResult replaceArithWithIndex(SrcOp op, PatternRewriter &rewriter) {
// if (!hasIndexResultOrOperand(&*op))
// return failure();
rewriter.replaceOpWithNewOp<DstOp>(op, op->getResultTypes(),
op->getOperands(), op->getAttrs());
return success();
}
LogicalResult replaceArithCmpWithIndexCmp(arith::CmpIOp op,
PatternRewriter &rewriter) {
// if (!hasIndexResultOrOperand(&*op))
// return failure();
rewriter.replaceOpWithNewOp<index::CmpOp>(
op, op.getResult().getType(), (index::IndexCmpPredicate)op.getPredicate(),
op.getOperand(0), op.getOperand(1));
return success();
}
class ArithToIndex : public TritonConvertArithToIndexBase<ArithToIndex> {
public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();
TritonArithToIndexConversionTarget target(*context);
RewritePatternSet patterns(context);
patterns.add(replaceArithWithIndex<arith::IndexCastOp, index::CastSOp>);
patterns.add(replaceArithWithIndex<arith::ConstantOp, index::ConstantOp>);
patterns.add(replaceArithWithIndex<arith::AddIOp, index::AddOp>);
patterns.add(replaceArithCmpWithIndexCmp);
if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
return signalPassFailure();
}
}
};
} // namespace
namespace mlir {
namespace triton {
std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass() {
return std::make_unique<::ArithToIndex>();
}
} // namespace triton
} // namespace mlir

View File

@@ -1,5 +1,4 @@
add_mlir_conversion_library(TritonGPUToLLVM
ArithToIndexPass.cpp
ConvertLayoutOpToLLVM.cpp
DotOpToLLVM.cpp
ElementwiseOpToLLVM.cpp

View File

@@ -12,7 +12,6 @@
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/ADT/APInt.h"
@@ -296,7 +295,6 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(createTritonConvertArithToIndexPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
pm.addPass(mlir::createArithToLLVMConversionPass());

View File

@@ -69,7 +69,7 @@ def get_llvm_package_info():
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
version = "llvm-17.0.0-8e5a41e8271f"
version = "llvm-17.0.0-2538e550420f"
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")

View File

@@ -701,14 +701,15 @@ class CodeGenerator(ast.NodeVisitor):
iv_type = triton.language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
iv_type = triton.language.semantic.integer_promote_impl(iv_type, step.dtype)
iv_ir_type = iv_type.to_ir(self.builder)
iv_is_signed = iv_type.int_signedness == triton.language.core.dtype.SIGNEDNESS.SIGNED
# lb/ub/step might be constexpr, we need to cast them to tensor
lb = lb.handle
ub = ub.handle
step = step.handle
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
# Create placeholder for the loop induction variable
iv = self.builder.create_undef(iv_ir_type)
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))
@@ -767,12 +768,9 @@ class CodeGenerator(ast.NodeVisitor):
# update induction variable with actual value, and replace all uses
self.builder.set_insertion_point_to_start(for_op.get_body(0))
iv = self.builder.create_index_to_si(for_op.get_induction_var())
iv = self.builder.create_int_cast(iv, iv_ir_type, True)
iv = for_op.get_induction_var()
if negative_step:
ub_si = self.builder.create_index_to_si(ub)
ub_si = self.builder.create_int_cast(ub_si, iv_ir_type, True)
iv = self.builder.create_sub(ub_si, iv)
iv = self.builder.create_sub(ub, iv)
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))