mirror of
https://github.com/ROCm/ROCm.git
synced 2026-04-05 03:01:17 -04:00
[FRONTEND][BACKEND] no longer using indices for loops (#1370)
This commit is contained in:
@@ -2,7 +2,6 @@
|
||||
#define TRITON_CONVERSION_PASSES_H
|
||||
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
|
||||
|
||||
|
||||
@@ -49,18 +49,4 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
|
||||
];
|
||||
}
|
||||
|
||||
// Module-level pass declaration: command-line name
// "triton-convert-arith-to-index", anchored on mlir::ModuleOp. The pass is
// instantiated through the C++ factory named in `constructor` and declares
// the index dialect as a dependent dialect.
def TritonConvertArithToIndex : Pass<"triton-convert-arith-to-index", "mlir::ModuleOp"> {
    let summary = "Convert arith to index";

    let constructor = "mlir::triton::createTritonConvertArithToIndexPass()";

    let description = [{
      Convert arith operation on index values to corresponding ops in the index dialect.
      We need this because SCFToCF conversion currently generates arith ops on indices.
    }];

    let dependentDialects = ["mlir::index::IndexDialect"];
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
#ifndef TRITON_CONVERSION_ARITH_TO_INDEX_H
|
||||
#define TRITON_CONVERSION_ARITH_TO_INDEX_H
|
||||
|
||||
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include <memory>
|
||||
|
||||
namespace mlir {

// Forward declarations so this header does not need the full pass headers.
class ModuleOp;
template <typename T> class OperationPass;

namespace triton {

/// Creates the pass that converts arith operations on index values into the
/// corresponding index-dialect operations. The returned pass operates on a
/// ModuleOp and is owned by the caller.
std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass();

} // namespace triton
} // namespace mlir
|
||||
|
||||
#endif
|
||||
@@ -1,90 +0,0 @@
|
||||
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
|
||||
#include "mlir/Analysis/DataFlowFramework.h"
|
||||
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
|
||||
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
|
||||
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
|
||||
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
|
||||
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
|
||||
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
|
||||
#include "mlir/Dialect/Index/IR/IndexDialect.h"
|
||||
#include "mlir/Dialect/Index/IR/IndexOps.h"
|
||||
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
||||
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "triton/Dialect/Triton/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Conversion/Passes.h.inc"
|
||||
|
||||
namespace {
|
||||
class TritonArithToIndexConversionTarget : public mlir::ConversionTarget {
|
||||
public:
|
||||
static bool hasIndexResultOrOperand(Operation *op) {
|
||||
if (!op)
|
||||
return false;
|
||||
bool hasRetIndex = llvm::find_if(op->getResultTypes(), [](Type type) {
|
||||
return type.isIndex();
|
||||
}) != op->getResultTypes().end();
|
||||
bool hasArgIndex = llvm::find_if(op->getOperandTypes(), [](Type type) {
|
||||
return type.isIndex();
|
||||
}) != op->getOperandTypes().end();
|
||||
return !hasRetIndex && !hasArgIndex;
|
||||
}
|
||||
|
||||
explicit TritonArithToIndexConversionTarget(MLIRContext &ctx)
|
||||
: ConversionTarget(ctx) {
|
||||
addLegalDialect<index::IndexDialect>();
|
||||
addDynamicallyLegalDialect<arith::ArithDialect>(hasIndexResultOrOperand);
|
||||
}
|
||||
};
|
||||
|
||||
template <class SrcOp, class DstOp>
|
||||
LogicalResult replaceArithWithIndex(SrcOp op, PatternRewriter &rewriter) {
|
||||
// if (!hasIndexResultOrOperand(&*op))
|
||||
// return failure();
|
||||
rewriter.replaceOpWithNewOp<DstOp>(op, op->getResultTypes(),
|
||||
op->getOperands(), op->getAttrs());
|
||||
return success();
|
||||
}
|
||||
|
||||
/// Replaces an `arith::CmpIOp` with an `index::CmpOp` that has the same result
/// type, the same two operands, and the equivalent comparison predicate.
///
/// The predicate is converted by underlying value with a named cast.
/// NOTE(review): this assumes the arith integer-compare predicate enum and
/// index::IndexCmpPredicate use the same numbering for corresponding
/// predicates — confirm against the dialect definitions when upgrading MLIR.
///
/// Always returns success(). No index-type check is performed here (a guard
/// that used to sit at the top of this function is disabled) — presumably the
/// conversion target decides which ops reach this rewrite; confirm with the
/// caller.
LogicalResult replaceArithCmpWithIndexCmp(arith::CmpIOp op,
                                          PatternRewriter &rewriter) {
  // static_cast instead of a C-style cast: greppable, and it cannot silently
  // degrade into a reinterpret/const cast if the types ever change.
  const auto predicate =
      static_cast<index::IndexCmpPredicate>(op.getPredicate());
  rewriter.replaceOpWithNewOp<index::CmpOp>(op, op.getResult().getType(),
                                            predicate, op.getOperand(0),
                                            op.getOperand(1));
  return success();
}
|
||||
|
||||
/// Module pass that rewrites arith ops touching `index` values into their
/// index-dialect counterparts via a partial dialect conversion.
///
/// Rewrites are registered for arith::IndexCastOp, arith::ConstantOp,
/// arith::AddIOp and arith::CmpIOp only.
/// NOTE(review): other arith ops with index-typed results/operands have no
/// pattern here; presumably the conversion then fails — confirm.
class ArithToIndex : public TritonConvertArithToIndexBase<ArithToIndex> {
public:
  void runOnOperation() override {
    MLIRContext &ctx = getContext();
    TritonArithToIndexConversionTarget legality(ctx);

    RewritePatternSet rewrites(&ctx);
    rewrites.add(replaceArithWithIndex<arith::IndexCastOp, index::CastSOp>);
    rewrites.add(replaceArithWithIndex<arith::ConstantOp, index::ConstantOp>);
    rewrites.add(replaceArithWithIndex<arith::AddIOp, index::AddOp>);
    rewrites.add(replaceArithCmpWithIndexCmp);

    ModuleOp moduleOp = getOperation();
    // Flag the pass as failed if the conversion reports failure.
    if (failed(applyPartialConversion(moduleOp, legality, std::move(rewrites))))
      signalPassFailure();
  }
};
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass() {
|
||||
return std::make_unique<::ArithToIndex>();
|
||||
}
|
||||
|
||||
} // namespace triton
|
||||
} // namespace mlir
|
||||
@@ -1,5 +1,4 @@
|
||||
add_mlir_conversion_library(TritonGPUToLLVM
|
||||
ArithToIndexPass.cpp
|
||||
ConvertLayoutOpToLLVM.cpp
|
||||
DotOpToLLVM.cpp
|
||||
ElementwiseOpToLLVM.cpp
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
#include "mlir/Target/LLVMIR/Export.h"
|
||||
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
|
||||
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
|
||||
#include "triton/Tools/Sys/GetEnv.hpp"
|
||||
#include "llvm/ADT/APInt.h"
|
||||
@@ -296,7 +295,6 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
|
||||
/*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);
|
||||
|
||||
pm.addPass(mlir::createConvertSCFToCFPass());
|
||||
pm.addPass(createTritonConvertArithToIndexPass());
|
||||
pm.addPass(mlir::createConvertIndexToLLVMPass());
|
||||
pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
|
||||
pm.addPass(mlir::createArithToLLVMConversionPass());
|
||||
|
||||
@@ -69,7 +69,7 @@ def get_llvm_package_info():
|
||||
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||
release_suffix = "assert" if use_assert_enabled_llvm else "release"
|
||||
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
|
||||
version = "llvm-17.0.0-8e5a41e8271f"
|
||||
version = "llvm-17.0.0-2538e550420f"
|
||||
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
|
||||
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
|
||||
|
||||
@@ -701,14 +701,15 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
iv_type = triton.language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
|
||||
iv_type = triton.language.semantic.integer_promote_impl(iv_type, step.dtype)
|
||||
iv_ir_type = iv_type.to_ir(self.builder)
|
||||
iv_is_signed = iv_type.int_signedness == triton.language.core.dtype.SIGNEDNESS.SIGNED
|
||||
# lb/ub/step might be constexpr, we need to cast them to tensor
|
||||
lb = lb.handle
|
||||
ub = ub.handle
|
||||
step = step.handle
|
||||
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
|
||||
lb = self.builder.create_to_index(lb)
|
||||
ub = self.builder.create_to_index(ub)
|
||||
step = self.builder.create_to_index(step)
|
||||
lb = self.builder.create_int_cast(lb, iv_ir_type, iv_is_signed)
|
||||
ub = self.builder.create_int_cast(ub, iv_ir_type, iv_is_signed)
|
||||
step = self.builder.create_int_cast(step, iv_ir_type, iv_is_signed)
|
||||
# Create placeholder for the loop induction variable
|
||||
iv = self.builder.create_undef(iv_ir_type)
|
||||
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))
|
||||
@@ -767,12 +768,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
|
||||
# update induction variable with actual value, and replace all uses
|
||||
self.builder.set_insertion_point_to_start(for_op.get_body(0))
|
||||
iv = self.builder.create_index_to_si(for_op.get_induction_var())
|
||||
iv = self.builder.create_int_cast(iv, iv_ir_type, True)
|
||||
iv = for_op.get_induction_var()
|
||||
if negative_step:
|
||||
ub_si = self.builder.create_index_to_si(ub)
|
||||
ub_si = self.builder.create_int_cast(ub_si, iv_ir_type, True)
|
||||
iv = self.builder.create_sub(ub_si, iv)
|
||||
iv = self.builder.create_sub(ub, iv)
|
||||
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
|
||||
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user