Mirror of https://github.com/ROCm/ROCm.git (synced 2026-02-21 03:00:39 -05:00)
Merge remote-tracking branch 'openai/main' into IFU-230517

Conflicts:
    lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
    lib/Target/LLVMIR/LLVMIRTranslation.cpp
    python/test/unit/language/assert_helper.py
    python/triton/third_party/cuda/bin/ptxas
    test/Conversion/tritongpu_to_llvm.mlir
.github/workflows/integration-tests.yml (16 lines changed)
@@ -25,9 +25,9 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"]]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
echo '::set-output name=matrix::["ubuntu-latest"]'
fi

Integration-Tests:
@@ -101,6 +101,18 @@ jobs:
cd python/test/unit
python3 -m pytest

- name: Create artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
run: |
tar -czvf artifacts.tar.gz ~/.triton/cache

- name: Upload artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
uses: actions/upload-artifact@v2
with:
name: artifacts
path: artifacts.tar.gz

- name: Run CXX unittests
if: ${{ env.BACKEND != 'ROCM'}}
run: |
.github/workflows/wheels.yml (2 lines changed)
@@ -30,7 +30,7 @@ jobs:
#export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest"
export CIBW_BEFORE_BUILD="pip install cmake;"
export CIBW_SKIP="{cp,pp}35-*"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64 cp3*-musllinux_x86_64"
python3 -m cibuildwheel python --output-dir wheelhouse
.gitmodules (3 lines changed)
@@ -1,3 +0,0 @@
[submodule "deps/dlfcn-win32"]
path = deps/dlfcn-win32
url = https://github.com/dlfcn-win32/dlfcn-win32.git
@@ -49,12 +49,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})

if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
find_package(dlfcn-win32 REQUIRED)
set(CMAKE_DL_LIBS dlfcn-win32::dl)
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")

if (TRITON_USE_ROCM)

@@ -34,6 +34,7 @@ Shape Manipulation Ops
:nosignatures:

broadcast_to
expand_dims
reshape
ravel
@@ -1,6 +1,7 @@
#ifndef TRITON_ANALYSIS_ALLOCATION_H
#define TRITON_ANALYSIS_ALLOCATION_H

#include "triton/Analysis/Utility.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
@@ -49,18 +50,25 @@ private:
T End = std::numeric_limits<T>::max();
};

template <class T> Interval(T, T) -> Interval<T>;

class Allocation {
public:
/// A unique identifier for shared memory buffers
using BufferId = size_t;
using BufferIdSetT = DenseSet<BufferId>;
using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;

static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();

Allocation() = default;
/// Creates a new Allocation analysis that computes the shared memory
/// information for all associated shared memory values.
Allocation(Operation *operation) : operation(operation) { run(); }
explicit Allocation(Operation *operation) : operation(operation) {}

/// Runs allocation analysis on the given top-level operation.
void run(FuncAllocMapT &funcAllocMap);

/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
@@ -75,6 +83,12 @@ public:
return bufferSet.at(bufferId).size;
}

/// Returns the allocated interval of the given buffer.
Interval<size_t> getAllocatedInterval(BufferId bufferId) const {
auto &buffer = bufferSet.at(bufferId);
return Interval<size_t>(buffer.offset, buffer.offset + buffer.size);
}

/// Returns the buffer id of the given value.
/// This interface only returns the allocated buffer id.
/// If you want to get all the buffer ids that are associated with the given
@@ -104,26 +118,28 @@ public:
BufferId getBufferId(Operation *operation) const {
if (opScratch.count(operation)) {
return opScratch.lookup(operation)->id;
} else if (opVirtual.count(operation)) {
return opVirtual.lookup(operation)->id;
} else {
return InvalidBufferId;
}
}

/// Returns whether the given buffer is a virtual buffer.
bool isVirtualBuffer(BufferId bufferId) const {
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
}

/// Returns the size of total shared memory allocated
size_t getSharedMemorySize() const { return sharedMemorySize; }

bool isIntersected(BufferId lhsId, BufferId rhsId) const {
if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
return false;
auto lhsBuffer = bufferSet.at(lhsId);
auto rhsBuffer = bufferSet.at(rhsId);
return lhsBuffer.intersects(rhsBuffer);
}

private:
/// A class that represents a shared memory buffer
struct BufferT {
enum class BufferKind { Explicit, Scratch };
/// Explicit: triton_gpu.alloc_tensor
/// Scratch: triton_gpu.convert_layout
/// Virtual: triton.call
enum class BufferKind { Explicit, Scratch, Virtual };

/// MT: thread-safe
inline static std::atomic<BufferId> nextId = 0;
@@ -142,12 +158,6 @@ private:
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}

bool intersects(const BufferT &other) const {
return Interval<size_t>(offset, offset + size)
.intersects(
Interval<size_t>(other.offset, other.offset + other.size));
}
};

/// Op -> Scratch Buffer
@@ -158,8 +168,6 @@ private:
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = std::map<BufferId, BufferT>;
/// Runs allocation analysis on the given top-level operation.
void run();

private:
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
@@ -168,6 +176,8 @@ private:
bufferSet[buffer.id] = std::move(buffer);
if constexpr (Kind == BufferT::BufferKind::Explicit) {
valueBuffer[key] = &bufferSet[buffer.id];
} else if constexpr (Kind == BufferT::BufferKind::Virtual) {
opVirtual[key] = &bufferSet[buffer.id];
} else {
opScratch[key] = &bufferSet[buffer.id];
}
@@ -178,8 +188,9 @@ private:
}

private:
Operation *operation;
Operation *operation = nullptr;
OpScratchMapT opScratch;
OpScratchMapT opVirtual;
ValueBufferMapT valueBuffer;
AliasBufferMapT aliasBuffer;
BufferSetT bufferSet;
@@ -188,7 +199,53 @@ private:
friend class triton::AllocationAnalysis;
};

template <typename T> Interval(T, T) -> Interval<T>;
/// Static analysis that computes the allocation of shared memory buffers
/// of the entire call graph.
/// The allocation is performed in a post-order walk of the call graph.
/// Each call op is treated like convert_layout that allocates a scratch buffer.
/// At each call, we compute the start offset of the scratch buffer and pass it
/// as an argument to the callee.
class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;

explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.run(funcMap);
});
}

size_t getSharedMemorySize() {
size_t size = 0;
for (auto funcOp : getRoots()) {
auto *alloc = getFuncData(funcOp);
size = std::max(size, alloc->getSharedMemorySize());
}
return size;
}

size_t getSharedMemorySize(FunctionOpInterface funcOp) {
return getFuncData(funcOp)->getSharedMemorySize();
}

void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) {
sharedMemoryValue[funcOp] = value;
}

Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) {
return sharedMemoryValue[funcOp];
}

private:
FuncOffsetMapT sharedMemoryValue;
};

} // namespace mlir
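The Allocation change above keys everything on half-open byte intervals: each buffer is an (offset, size) pair, and two buffers conflict when their intervals overlap. The following is a minimal standalone sketch of that overlap test, not the Triton classes themselves; the names are illustrative.

```cpp
#include <cassert>
#include <cstddef>

// Half-open interval [start, end), mirroring the Interval<T> idea above.
template <class T> struct Range {
  T start{}, end{};
  // Two half-open ranges overlap iff each one starts before the other ends.
  bool intersects(const Range &other) const {
    return start < other.end && other.start < end;
  }
};

int main() {
  Range<std::size_t> a{0, 32};  // buffer at offset 0, 32 bytes
  Range<std::size_t> b{16, 48}; // buffer at offset 16, 32 bytes
  Range<std::size_t> c{32, 64}; // starts exactly where `a` ends
  assert(a.intersects(b));      // bytes [16, 32) are shared
  assert(!a.intersects(c));     // adjacent buffers do not conflict
  return 0;
}
```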
@@ -286,16 +286,71 @@ public:
AxisInfoAnalysis(DataFlowSolver &solver);
using dataflow::SparseDataFlowAnalysis<
dataflow::Lattice<AxisInfo>>::getLatticeElement;
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;

void visitOperation(Operation *op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
};

/// Module level axis info analysis based on the call graph, assuming that we
/// do not have recursive functions.
/// Since each function will be called multiple times, we need to
/// calculate the axis info based on the axis info of all the callers.
/// In the future, we can perform optimization using function cloning so that
/// each call site will have unique axis info.
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
public:
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
: CallGraph<AxisInfoMapT>(moduleOp) {
SmallVector<FunctionOpInterface> funcs;
for (auto root : getRoots()) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
funcs.push_back(funcOp);
funcMap.try_emplace(funcOp, AxisInfoMapT{});
});
}
SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
SymbolTableCollection symbolTable;
for (auto funcOp : llvm::reverse(sortedFuncs)) {
initialize(funcOp);
funcOp.walk([&](CallOpInterface callOp) {
auto callee =
dyn_cast<FunctionOpInterface>(callOp.resolveCallable(&symbolTable));
update(callOp, callee);
});
}
}

AxisInfo *getAxisInfo(Value value) {
auto funcOp =
value.getParentRegion()->getParentOfType<FunctionOpInterface>();
auto *axisInfoMap = getFuncData(funcOp);
if (!axisInfoMap) {
return nullptr;
}
auto it = axisInfoMap->find(value);
if (it == axisInfoMap->end()) {
return nullptr;
}
return &(it->second);
}

unsigned getPtrContiguity(Value ptr);

unsigned getPtrAlignment(Value ptr);

unsigned getMaskAlignment(Value mask);

private:
void initialize(FunctionOpInterface funcOp);

void update(CallOpInterface callOp, FunctionOpInterface funcOp);
};

} // namespace mlir
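One way to picture the "merge across callers" step that ModuleAxisInfoAnalysis performs: a callee argument may only assume what holds at every call site, so per-caller divisibility hints are combined with gcd, matching the tt.divisibility update shown later in AxisInfo.cpp. This is a hedged sketch, not the Triton API.

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Combine the divisibility hint observed at each call site: the callee can
// only rely on the greatest common divisor of all of them.
int64_t combineDivisibility(const std::vector<int64_t> &perCallSite) {
  int64_t result = 0; // gcd(0, x) == x, so 0 acts as the identity
  for (int64_t d : perCallSite)
    result = std::gcd(result, d);
  return result;
}

int main() {
  // One caller passes a pointer divisible by 16, another by 8:
  // the callee can only assume divisibility 8.
  std::cout << combineDivisibility({16, 8}) << "\n"; // prints 8
  return 0;
}
```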
@@ -4,20 +4,75 @@
#include "Allocation.h"
#include "llvm/ADT/SmallPtrSet.h"

#include <set>

namespace mlir {

class OpBuilder;

struct BlockInfo {
using BufferIdSetT = Allocation::BufferIdSetT;
using IntervalSetT = std::set<Interval<size_t>>;

IntervalSetT syncReadIntervals;
IntervalSetT syncWriteIntervals;

BlockInfo() = default;

/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
syncReadIntervals.insert(other.syncReadIntervals.begin(),
other.syncReadIntervals.end());
syncWriteIntervals.insert(other.syncWriteIntervals.begin(),
other.syncWriteIntervals.end());
return *this;
}

/// Returns true if intervals in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other) const {
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) ||
/*WAR*/
isIntersected(syncReadIntervals, other.syncWriteIntervals) ||
/*WAW*/
isIntersected(syncWriteIntervals, other.syncWriteIntervals);
}

/// Clears the intervals because a barrier is inserted.
void sync() {
syncReadIntervals.clear();
syncWriteIntervals.clear();
}

/// Compares two BlockInfo objects.
bool operator==(const BlockInfo &other) const {
return syncReadIntervals == other.syncReadIntervals &&
syncWriteIntervals == other.syncWriteIntervals;
}

bool operator!=(const BlockInfo &other) const { return !(*this == other); }

private:
bool isIntersected(const IntervalSetT &lhsIntervalSet,
const IntervalSetT &rhsIntervalSet) const {
for (auto &lhs : lhsIntervalSet)
for (auto &rhs : rhsIntervalSet)
if (lhs.intersects(rhs))
return true;
return false;
}
};

//===----------------------------------------------------------------------===//
// Shared Memory Barrier Analysis
//===----------------------------------------------------------------------===//
class MembarAnalysis {
public:
using FuncBlockInfoMapT = CallGraph<BlockInfo>::FuncDataMapT;
/// Creates a new Membar analysis that generates the shared memory barrier
/// in the following circumstances:
/// - RAW: If a shared memory write is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// - WAR: If a shared memory read is followed by a shared memory read, and
/// - WAR: If a shared memory read is followed by a shared memory write, and
/// their addresses are intersected, a barrier is inserted.
/// The following circumstances do not require a barrier:
/// - WAW: not possible because overlapped memory allocation is not allowed.
@@ -26,75 +81,14 @@ public:
/// a shared memory read. If the temporary storage is written but not read,
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
MembarAnalysis() = default;
explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {}

/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
void run(FuncBlockInfoMapT &funcBlockInfoMap);

private:
struct BlockInfo {
using BufferIdSetT = Allocation::BufferIdSetT;

BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;

BlockInfo() = default;
BlockInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}

/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
return *this;
}

/// Returns true if buffers in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/
isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation) ||
/*WAW*/
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
allocation);
}

/// Clears the buffers because a barrier is inserted.
void sync() {
syncReadBuffers.clear();
syncWriteBuffers.clear();
}

/// Compares two BlockInfo objects.
bool operator==(const BlockInfo &other) const {
return syncReadBuffers == other.syncReadBuffers &&
syncWriteBuffers == other.syncWriteBuffers;
}

bool operator!=(const BlockInfo &other) const { return !(*this == other); }

private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
Allocation *allocation) const {
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
return allocation->isIntersected(lhsId, rhsId);
});
});
}
};

/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:
@@ -109,18 +103,48 @@ private:
/// op6
/// op7
/// TODO: Explain why we don't use ForwardAnalysis:
void resolve(Operation *operation, OpBuilder *builder);
void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder);

/// Updates the BlockInfo operation based on the operation.
void update(Operation *operation, BlockInfo *blockInfo, OpBuilder *builder);
void update(Operation *operation, BlockInfo *blockInfo,
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder);

/// Collects the successors of the terminator
void visitTerminator(Operation *operation, SmallVector<Block *> &successors);

private:
Allocation *allocation;
DenseMap<Block *, BlockInfo> inputBlockInfoMap;
DenseMap<Block *, BlockInfo> outputBlockInfoMap;
Allocation *allocation = nullptr;
};

/// Postorder traversal on the callgraph to insert membar instructions
/// of each function.
/// Each function maintains a BlockInfo map that includes all potential buffers
/// after returning. This way users do not have to explicitly insert membars
/// before and after function calls, but might be a bit conservative.
class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
public:
ModuleMembarAnalysis(ModuleAllocation *moduleAllocation)
: CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
moduleAllocation(moduleAllocation) {}

void run() {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order walk callback
[&](FunctionOpInterface funcOp) {
auto *allocation = moduleAllocation->getFuncData(funcOp);
auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
if (inserted) {
MembarAnalysis analysis(allocation);
analysis.run(funcMap);
}
});
}

private:
ModuleAllocation *moduleAllocation;
};

} // namespace mlir
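The new BlockInfo tracks pending shared-memory reads and writes as interval sets and asks whether a barrier is needed by checking RAW, WAR, and WAW overlap. A minimal sketch of that hazard test follows; it is illustrative only and uses plain std::pair ranges rather than the Triton types.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <utility>

using ByteRange = std::pair<std::size_t, std::size_t>; // half-open [first, second)

bool overlaps(const ByteRange &a, const ByteRange &b) {
  return a.first < b.second && b.first < a.second;
}

bool anyOverlap(const std::set<ByteRange> &xs, const std::set<ByteRange> &ys) {
  for (const auto &x : xs)
    for (const auto &y : ys)
      if (overlaps(x, y))
        return true;
  return false;
}

// RAW / WAR / WAW check between the accesses pending since the last barrier
// and the accesses performed by the next operation.
bool needsBarrier(const std::set<ByteRange> &pendingReads,
                  const std::set<ByteRange> &pendingWrites,
                  const std::set<ByteRange> &nextReads,
                  const std::set<ByteRange> &nextWrites) {
  return anyOverlap(pendingWrites, nextReads) || // RAW
         anyOverlap(pendingReads, nextWrites) || // WAR
         anyOverlap(pendingWrites, nextWrites);  // WAW
}

int main() {
  std::set<ByteRange> writes{{0, 128}};
  assert(needsBarrier({}, writes, {{64, 96}}, {}));    // reads written bytes
  assert(!needsBarrier({}, writes, {{128, 256}}, {})); // disjoint ranges
  return 0;
}
```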
@@ -39,6 +39,10 @@ public:

unsigned getIntraWarpSize();

unsigned getInterWarpSizeWithUniqueData();

unsigned getIntraWarpSizeWithUniqueData();

unsigned getThreadsReductionAxis();

SmallVector<unsigned> getScratchConfigBasic();
@@ -57,8 +61,6 @@ private:
int axis;
};

bool isSharedEncoding(Value value);

bool maybeSharedAllocationOp(Operation *op);

bool maybeAliasOp(Operation *op);
@@ -116,14 +118,153 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort);

// This uses the toplogicalSort above
/// This uses the toplogicalSort above
SetVector<Operation *>
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
TransitiveFilter forwardFilter = nullptr);

// Create a basic DataFlowSolver with constant and dead code analysis included.
/// Create a basic DataFlowSolver with constant and dead code analysis included.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();

/// This class represents a call graph for a given ModuleOp and holds
/// data of type T associated with each FunctionOpInterface.
template <typename T> class CallGraph {
public:
using FuncDataMapT = DenseMap<FunctionOpInterface, T>;

/// Constructor that builds the call graph for the given moduleOp.
explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); }

/// Walks the call graph and applies the provided update functions
/// to the edges and nodes.
template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
typename UpdateEdgeFn, typename UpdateNodeFn>
void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) {
DenseSet<FunctionOpInterface> visited;
for (auto root : roots) {
doWalk<UpdateEdgeOrder, UpdateNodeOrder>(root, visited, updateEdgeFn,
updateNodeFn);
}
}

/// Retrieves the data associated with a function
T *getFuncData(FunctionOpInterface funcOp) {
if (funcMap.count(funcOp)) {
return &funcMap[funcOp];
}
return nullptr;
}

/// Getters
ModuleOp getModuleOp() const { return moduleOp; }
SmallVector<FunctionOpInterface> getRoots() const { return roots; }
size_t getNumFunctions() const { return funcMap.size(); }

/// Returns true if the given function is a root.
bool isRoot(FunctionOpInterface funcOp) const {
return llvm::is_contained(roots, funcOp);
}

/// Maps the data and the graph nodes associated with a funcOp to a
/// targetFuncOp.
template <typename FROM, typename TO>
void mapFuncOp(FROM funcOp, TO targetFuncOp) {
// Iterate over graph and replace
for (auto &kv : graph) {
for (auto &edge : kv.second) {
if (edge.second == funcOp) {
edge.second = targetFuncOp;
}
}
}
graph[targetFuncOp] = graph[funcOp];
// Replace in roots
for (auto it = roots.begin(); it != roots.end(); ++it) {
if (*it == funcOp) {
*it = targetFuncOp;
break;
}
}
// Replace in funcMap
funcMap[targetFuncOp] = funcMap[funcOp];
}

/// Maps the graph edges associated with a callOp to a targetCallOp.
template <typename FROM, typename TO>
void mapCallOp(FROM callOp, TO targetCallOp) {
// Iterate over graph and replace
for (auto &kv : graph) {
for (auto &edge : kv.second) {
if (edge.first == callOp) {
edge.first = targetCallOp;
}
}
}
}

private:
void build() {
SymbolTableCollection symbolTable;
DenseSet<FunctionOpInterface> visited;
// Build graph
moduleOp.walk([&](Operation *op) {
auto caller = op->getParentOfType<FunctionOpInterface>();
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto *callee = callOp.resolveCallable(&symbolTable);
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callee);
if (funcOp) {
graph[caller].emplace_back(
std::pair<CallOpInterface, FunctionOpInterface>(callOp, funcOp));
visited.insert(funcOp);
}
}
});
// Find roots
moduleOp.walk([&](FunctionOpInterface funcOp) {
if (!visited.count(funcOp)) {
roots.push_back(funcOp);
}
});
}

template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
typename UpdateEdgeFn, typename UpdateNodeFn>
void doWalk(FunctionOpInterface funcOp,
DenseSet<FunctionOpInterface> &visited, UpdateEdgeFn updateEdgeFn,
UpdateNodeFn updateNodeFn) {
if (visited.count(funcOp)) {
llvm::report_fatal_error("Cycle detected in call graph");
}
if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) {
updateNodeFn(funcOp);
}
for (auto [callOp, callee] : graph[funcOp]) {
if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) {
updateEdgeFn(callOp, callee);
}
doWalk<UpdateEdgeOrder, UpdateNodeOrder>(callee, visited, updateEdgeFn,
updateNodeFn);
if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) {
updateEdgeFn(callOp, callee);
}
}
if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) {
updateNodeFn(funcOp);
}
visited.erase(funcOp);
}

protected:
ModuleOp moduleOp;
DenseMap<FunctionOpInterface,
SmallVector<std::pair<CallOpInterface, FunctionOpInterface>>>
graph;
FuncDataMapT funcMap;
SmallVector<FunctionOpInterface> roots;
};

} // namespace mlir

#endif // TRITON_ANALYSIS_UTILITY_H
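The CallGraph<T> template added above walks callees depth-first with configurable pre- and post-order callbacks. The standalone sketch below shows the same traversal idea in a stripped-down form (std::function instead of templated callbacks, and without the visited-set/cycle handling the real class performs); names are made up for illustration.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy call graph: caller name -> callee names.
using Graph = std::map<std::string, std::vector<std::string>>;

// Post-order walk: visit all callees before the caller, which is the order in
// which per-function results (e.g. shared-memory sizes) must be available.
void postOrderWalk(const Graph &g, const std::string &node,
                   const std::function<void(const std::string &)> &visit) {
  auto it = g.find(node);
  if (it != g.end())
    for (const auto &callee : it->second)
      postOrderWalk(g, callee, visit);
  visit(node);
}

int main() {
  Graph g{{"kernel", {"helper_a", "helper_b"}}, {"helper_a", {"helper_b"}}};
  // Prints helper_b, helper_a, helper_b, kernel.
  postOrderWalk(g, "kernel",
                [](const std::string &f) { std::cout << f << "\n"; });
  return 0;
}
```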
@@ -155,7 +155,7 @@ def TT_LoadOp : TT_Op<"load",
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor pointer with boundary check and padding
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
@@ -164,8 +164,9 @@ def TT_LoadOp : TT_Op<"load",
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A utility function to build the operation with all attributes
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "Optional<ArrayRef<int32_t>>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
"std::optional<ArrayRef<int32_t>>":$boundaryCheck,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
];

@@ -600,6 +601,11 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}

/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
}];

let assemblyFormat = [{
||||
@@ -31,8 +31,37 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(Attribute layout);
|
||||
|
||||
// Returns the number of contiguous elements that each thread
|
||||
// has access to, on each dimension of the tensor. E.g.
|
||||
// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
|
||||
// regardless of the shape of the tensor.
|
||||
SmallVector<unsigned> getContigPerThread(Attribute layout);
|
||||
|
||||
// Returns the number of non-replicated contiguous elements that each thread
|
||||
// has access to, on each dimension of the tensor. For a blocked layout
|
||||
// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements
|
||||
// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
|
||||
// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
|
||||
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
|
||||
SmallVector<unsigned> getUniqueContigPerThread(Type type);
|
||||
|
||||
// Returns the number of threads per warp that have access to non-replicated
|
||||
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
|
||||
// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17
|
||||
// have access to the full tensor, whereas the other threads have access to
|
||||
// replicated elements, so this function returns [2, 2].
|
||||
SmallVector<unsigned>
|
||||
getThreadsPerWarpWithUniqueData(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape);
|
||||
|
||||
// Returns the number of warps per CTA that have access to non-replicated
|
||||
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
|
||||
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
|
||||
// returns [1, 1], since the first warp has access to the full tensor, whereas
|
||||
// the other warps have access to replicated elements.
|
||||
SmallVector<unsigned>
|
||||
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
|
||||
|
||||
SmallVector<unsigned>
|
||||
@@ -45,6 +74,9 @@ bool isaDistributedLayout(Attribute layout);
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
bool isSharedEncoding(Value value);
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
|
||||
|
||||
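The comments above distinguish the raw sizePerThread from the number of unique (non-replicated) elements a thread owns, which is additionally limited by the tensor shape. A back-of-the-envelope sketch of that clamping, simplified to a per-dimension min and using the example quoted in the comment (this is my simplification, not the actual getUniqueContigPerThread implementation):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified rule: along each dimension a thread cannot own more unique
// contiguous elements than the tensor extent on that dimension.
std::vector<int64_t> uniqueContigPerThread(std::vector<int64_t> contig,
                                           const std::vector<int64_t> &shape) {
  for (std::size_t d = 0; d < contig.size(); ++d)
    contig[d] = std::min(contig[d], shape[d]);
  return contig;
}

int main() {
  // sizePerThread = [1, 4]: a [128, 1] tensor clamps it to [1, 1], while a
  // [128, 128] tensor keeps [1, 4] -- the two cases quoted in the comment.
  for (int64_t v : uniqueContigPerThread({1, 4}, {128, 1})) std::cout << v << ' ';
  std::cout << '\n';
  for (int64_t v : uniqueContigPerThread({1, 4}, {128, 128})) std::cout << v << ' ';
  std::cout << '\n';
  return 0;
}
```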
@@ -509,10 +509,22 @@ section 9.7.13.4.1 for more details.
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"unsigned":$MMAv2kWidth
);

let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
MmaEncodingAttr parentAttr = parent.dyn_cast<MmaEncodingAttr>();
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
}]>
];

let hasCustomAssemblyFormat = 1;
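The builder above derives MMAv2kWidth as 32 divided by the element bit-width (and 0 for non-Ampere parents). A small worked version of that arithmetic, with the element types listed only as illustrative assumptions:

```cpp
#include <iostream>

// kWidth = how many elements of the operand fit into one 32-bit MMA slot.
unsigned mmaV2KWidth(unsigned elementBitWidth) { return 32 / elementBitWidth; }

int main() {
  std::cout << mmaV2KWidth(16) << '\n'; // 16-bit elements -> 2 per 32 bits
  std::cout << mmaV2KWidth(8) << '\n';  // 8-bit elements  -> 4 per 32 bits
  std::cout << mmaV2KWidth(32) << '\n'; // 32-bit elements -> 1 per 32 bits
  return 0;
}
```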
@@ -57,7 +57,6 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::ModuleOp"> {
"int32_t", /*default*/"80",
"device compute capability">
];

}

def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {
@@ -4,7 +4,6 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Analysis/Alias.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"

@@ -117,8 +116,11 @@ SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {

class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
: operation(operation), allocation(allocation) {
AllocationAnalysis(Operation *operation,
Allocation::FuncAllocMapT *funcAllocMap,
Allocation *allocation)
: operation(operation), funcAllocMap(funcAllocMap),
allocation(allocation) {
run();
}

@@ -219,6 +221,12 @@ private:
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
allocation->addBuffer<BufferT::BufferKind::Virtual>(op, bytes);
}
}

@@ -298,15 +306,19 @@ private:
/// allocated, but is used to store intermediate results.
void resolveScratchBufferLiveness(
const DenseMap<Operation *, size_t> &operationId) {
// Analyze liveness of scratch buffers
for (auto opScratchIter : allocation->opScratch) {
// Any scratch memory's live range is the current operation's live
// range.
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
// Analyze liveness of scratch buffers and virtual buffers.
auto processScratchMemory = [&](const auto &container) {
for (auto opScratchIter : container) {
// Any scratch memory's live range is the current operation's live
// range.
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
};
processScratchMemory(allocation->opScratch);
processScratchMemory(allocation->opVirtual);
}

/// Resolves liveness of all values involved under the root operation.
@@ -499,11 +511,15 @@ private:

private:
Operation *operation;
Allocation::FuncAllocMapT *funcAllocMap;
Allocation *allocation;
BufferRangeMapT bufferRange;
};

} // namespace triton

void Allocation::run() { triton::AllocationAnalysis(getOperation(), this); }
void Allocation::run(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
}

} // namespace mlir
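In the updated AllocationAnalysis a call site is modeled as a "virtual" buffer whose size is the callee's total shared-memory footprint, so the caller reserves enough space for whatever the callee might use. The sketch below is only a rough upper-bound model of that idea; the real analysis also computes offsets and packs buffers by live range.

```cpp
#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

// Rough model: a caller needs its own shared memory plus, for the calls it
// makes, a virtual buffer as large as the biggest callee footprint.
std::size_t callerUpperBound(std::size_t ownBytes,
                             const std::vector<std::size_t> &calleeBytes) {
  std::size_t worstCallee = 0;
  for (std::size_t b : calleeBytes)
    worstCallee = std::max(worstCallee, b);
  return ownBytes + worstCallee;
}

int main() {
  // A kernel using 4 KiB itself that calls helpers needing 8 KiB and 2 KiB
  // reserves at most 4 KiB + 8 KiB = 12 KiB under this model.
  std::cout << callerUpperBound(4096, {8192, 2048}) << '\n';
  return 0;
}
```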
@@ -77,7 +77,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {

if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (auto fun = dyn_cast<triton::FuncOp>(op))
if (auto fun = dyn_cast<FunctionOpInterface>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
@@ -696,13 +696,13 @@ private:
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value()) {
if constexpr (std::is_same<OpTy, arith::AndIOp>::value) {
if constexpr (std::is_same_v<OpTy, arith::AndIOp>) {
return {lhs.getConstantValue().value() &
rhs.getConstantValue().value()};
} else if constexpr (std::is_same<OpTy, arith::OrIOp>::value) {
} else if constexpr (std::is_same_v<OpTy, arith::OrIOp>) {
return {lhs.getConstantValue().value() |
rhs.getConstantValue().value()};
} else if constexpr (std::is_same<OpTy, arith::XOrIOp>::value) {
} else if constexpr (std::is_same_v<OpTy, arith::XOrIOp>) {
return {lhs.getConstantValue().value() ^
rhs.getConstantValue().value()};
}
@@ -907,37 +907,37 @@ void AxisInfoAnalysis::visitOperation(
propagateIfChanged(result, result->join(curr));
}

unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();

// Here order should be ordered by contiguous first, so the first element
// should have the largest contiguous.
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);

unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
contigPerThread = std::min(align, contigPerThread);
contigPerThread = std::min<unsigned>(shape[order[0]], contigPerThread);
auto uniqueContigPerThread = triton::gpu::getUniqueContigPerThread(tensorTy);
assert(order[0] < uniqueContigPerThread.size() &&
"Unexpected uniqueContigPerThread size");
unsigned contiguity = uniqueContigPerThread[order[0]];
contiguity = std::min(align, contiguity);

return contigPerThread;
return contiguity;
}

unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
if (!latticeElement)
auto *axisInfo = getAxisInfo(ptr);
if (!axisInfo)
return 1;
auto axisInfo = latticeElement->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
auto maxContig = axisInfo.getContiguity(order[0]);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
@@ -945,17 +945,69 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
return alignment;
}

unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
if (!latticeElement)
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskAxis = latticeElement->getValue();
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
return alignment;
}

void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(funcOp)))
return;
auto *axisInfoMap = getFuncData(funcOp);
auto updateAxisInfoMap = [&](Value value) {
auto axisInfo = analysis->getLatticeElement(value)->getValue();
AxisInfo curAxisInfo;
if (axisInfoMap->count(value)) {
curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));
} else {
curAxisInfo = axisInfo;
}
(*axisInfoMap)[value] = curAxisInfo;
};
funcOp.walk([&](Operation *op) {
for (auto value : op->getResults()) {
updateAxisInfoMap(value);
}
});
funcOp.walk([&](Block *block) {
for (auto value : block->getArguments()) {
updateAxisInfoMap(value);
}
});
}

void ModuleAxisInfoAnalysis::update(CallOpInterface callOp,
FunctionOpInterface callee) {
auto caller = callOp->getParentOfType<FunctionOpInterface>();
auto *axisInfoMap = getFuncData(caller);
for (auto entry : llvm::enumerate(callOp->getOperands())) {
auto index = entry.index();
auto value = entry.value();
auto setAttrFn = [&](StringRef attrName, int64_t prevValue) {
auto curValue = highestPowOf2Divisor<int64_t>(0);
if (callee.getArgAttrOfType<IntegerAttr>(index, attrName)) {
curValue =
callee.getArgAttrOfType<IntegerAttr>(index, attrName).getInt();
}
auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64),
gcd(prevValue, curValue));
callee.setArgAttr(index, attrName, attr);
};
auto axisInfo = axisInfoMap->lookup(value);
assert(axisInfo.getRank() == 1 && "only scalar arguments are supported");
setAttrFn("tt.contiguity", axisInfo.getContiguity(0));
setAttrFn("tt.divisibility", axisInfo.getDivisibility(0));
setAttrFn("tt.constancy", axisInfo.getConstancy(0));
}
}

} // namespace mlir
@@ -11,4 +11,7 @@ add_mlir_library(TritonAnalysis

LINK_LIBS PUBLIC
MLIRAnalysis
MLIRLLVMDialect
TritonIR
TritonGPUIR
)
@@ -9,16 +9,21 @@

namespace mlir {

void MembarAnalysis::run() {
auto *operation = allocation->getOperation();
OpBuilder builder(operation);
resolve(operation, &builder);
void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
FunctionOpInterface funcOp =
dyn_cast<FunctionOpInterface>(allocation->getOperation());
OpBuilder builder(funcOp.getContext());
resolve(funcOp, &funcBlockInfoMap, &builder);
}

void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
void MembarAnalysis::resolve(FunctionOpInterface funcOp,
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) {
// Initialize the blockList
DenseMap<Block *, BlockInfo> inputBlockInfoMap;
DenseMap<Block *, BlockInfo> outputBlockInfoMap;
std::deque<Block *> blockList;
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
for (auto &op : block->getOperations()) {
// Check if the operation belongs to scf dialect, if so, we need to
// throw an error
@@ -38,13 +43,13 @@ void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
auto *block = blockList.front();
blockList.pop_front();
// Make a copy of the inputblockInfo but not update
auto inputBlockInfo = inputBlockInfoMap.lookup(block);
auto inputBlockInfo = inputBlockInfoMap[block];
SmallVector<Block *> successors;
for (auto &op : block->getOperations()) {
if (op.hasTrait<OpTrait::IsTerminator>()) {
visitTerminator(&op, successors);
} else {
update(&op, &inputBlockInfo, builder);
update(&op, &inputBlockInfo, funcBlockInfoMap, builder);
}
}
// Get the reference because we want to update if it changed
@@ -62,15 +67,22 @@ void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
blockList.emplace_back(successor);
}
}

// Update the final dangling buffers that haven't been synced
auto &funcBlockInfo = (*funcBlockInfoMap)[funcOp];
funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
block->walk([&](triton::ReturnOp returnOp) {
funcBlockInfo.join(outputBlockInfoMap[block]);
});
});
}

void MembarAnalysis::visitTerminator(Operation *op,
SmallVector<Block *> &successors) {
if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
Block *parentBlock = branchInterface->getBlock();
for (Block *successor : parentBlock->getSuccessors()) {
successors.push_back(successor);
}
successors.append(std::begin(parentBlock->getSuccessors()),
std::end(parentBlock->getSuccessors()));
return;
}
// Otherwise, it could be a return op
@@ -81,6 +93,7 @@ void MembarAnalysis::visitTerminator(Operation *op,
}

void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) {
if (isa<triton::gpu::ExtractSliceOp>(op) ||
isa<triton::gpu::AllocTensorOp>(op) || isa<triton::TransOp>(op)) {
@@ -108,36 +121,51 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
}

BlockInfo curBlockInfo;
for (Value value : op->getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always
// alias for now
curBlockInfo.syncWriteBuffers.insert(bufferId);
} else {
// ConvertLayoutOp: shared memory -> registers
curBlockInfo.syncReadBuffers.insert(bufferId);
if (isa<triton::CallOp>(op)) {
// Inter-function dependencies
auto callOpInterface = dyn_cast<CallOpInterface>(op);
if (auto callee =
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable())) {
curBlockInfo = funcBlockInfoMap->lookup(callee);
}
} else {
// Intra-function dependencies
for (Value value : op->getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always
// alias for now
curBlockInfo.syncWriteIntervals.insert(
allocation->getAllocatedInterval(bufferId));
} else {
// ConvertLayoutOp: shared memory -> registers
curBlockInfo.syncReadIntervals.insert(
allocation->getAllocatedInterval(bufferId));
}
}
}
}
}
for (Value value : op->getResults()) {
// ConvertLayoutOp: registers -> shared memory
auto bufferId = allocation->getBufferId(value);
for (Value value : op->getResults()) {
// ConvertLayoutOp: registers -> shared memory
auto bufferId = allocation->getBufferId(value);
if (bufferId != Allocation::InvalidBufferId) {
curBlockInfo.syncWriteIntervals.insert(
allocation->getAllocatedInterval(bufferId));
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curBlockInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncWriteIntervals.insert(
allocation->getAllocatedInterval(bufferId));
curBlockInfo.syncReadIntervals.insert(
allocation->getAllocatedInterval(bufferId));
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curBlockInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncReadBuffers.insert(bufferId);
}

if (blockInfo->isIntersected(curBlockInfo, allocation)) {
if (blockInfo->isIntersected(curBlockInfo)) {
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPoint(op);
builder->create<gpu::BarrierOp>(op->getLoc());
@@ -26,10 +26,27 @@ unsigned ReduceOpHelper::getIntraWarpSize() {
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}

unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTAWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}

unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarpWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}

unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
auto srcShape = getSrcShape();
return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
srcShape)[axis] *
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}

SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
@@ -91,14 +108,8 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
}
return false;
}

bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
auto encoding = tensorType.getEncoding();
return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return true;
}
return false;
}
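getIntraWarpSizeWithUniqueData above caps the per-warp reduction width by the source extent on the reduced axis, so lanes that only hold replicated data are not counted. A small worked sketch of that min(), illustrative arithmetic only and not the ReduceOpHelper API:

```cpp
#include <algorithm>
#include <iostream>

// Only lanes holding non-replicated data along the reduced axis contribute,
// so the intra-warp reduction width is capped by the tensor extent there.
unsigned intraWarpSizeWithUniqueData(unsigned reduceDimSize,
                                     unsigned threadsPerWarpOnAxis) {
  return std::min(reduceDimSize, threadsPerWarpOnAxis);
}

int main() {
  // A [2, 2] tensor with 16 threads mapped along the reduced axis only has
  // 2 unique values per warp, so the intra-warp reduction width is 2.
  std::cout << intraWarpSizeWithUniqueData(2, 16) << '\n';
  return 0;
}
```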
@@ -106,13 +106,22 @@ private:
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentSizePerThread = getSizePerThread(parentEncoding);
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto multiDimOffsetParent =
getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
auto offsets = emitOffsetForLayout(layout, type);
auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy);
SmallVector<int> idxs;
for (SmallVector<unsigned> off : offsets) {
off.insert(off.begin() + dim, 0);
auto it = std::find(parentOffset.begin(), parentOffset.end(), off);
idxs.push_back(std::distance(parentOffset.begin(), it));
}
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
@@ -329,7 +338,8 @@ private:

if (needTrans) {
// do transpose
auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numN = accumSizePerThread / numM;

@@ -619,10 +629,10 @@ private:
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(vecVals[i]);
reorderedVals.push_back(vecVals[i + 2]);
reorderedVals.push_back(vecVals[i + 1]);
reorderedVals.push_back(vecVals[i + 3]);
reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
}

Value view = getTypeConverter()->packLLElements(loc, reorderedVals,
@@ -641,19 +651,19 @@ private:
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
bool isHMMA = supportMMA(dst, mmaLayout.getVersionMajor());

auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
Value res;

if (!isOuter && mmaLayout.isAmpere() && isHMMA) { // tensor core v2
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2

res = SharedToDotOperandMMAv2::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, getTypeConverter(), tid_val());

} else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1
} else if (!isOuter && mmaLayout.isVolta() &&
supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
@@ -675,14 +685,13 @@ private:
}
return res;
}
}; // namespace triton::gpu::ConvertLayoutOp>
}; // namespace triton::gpu::ConvertLayoutOp

void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation,
indexCacheInfo, benefit);
}

@@ -10,8 +10,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;

void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

@@ -358,11 +358,11 @@ SmallVector<CoordTy> getMNCoords(Value thread,
Value _fpw1 = i32_val(fpw[1]);

// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

@@ -19,10 +19,10 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
class MMA16816SmemLoader {
public:
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
int perPhase, int maxPhase, int elemBytes,
ConversionPatternRewriter &rewriter,
int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase,
int elemBytes, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter,
const Location &loc);

@@ -33,7 +33,7 @@ public:
if (canUseLdmatrix)
return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset);
else
return computeLdsMatOffs(warpOff, lane, cSwizzleOffset, elemBytes);
return computeLdsMatOffs(warpOff, lane, cSwizzleOffset);
return {};
}

@@ -45,7 +45,7 @@ public:
Value cSwizzleOffset);
// compute 8-bit matrix offset.
SmallVector<Value> computeLdsMatOffs(Value warpOff, Value lane,
Value cSwizzleOffset, int elemBytes);
Value cSwizzleOffset);

// Load 4 matrices and returns 4 vec<2> elements.
std::tuple<Value, Value, Value, Value>
@@ -55,6 +55,7 @@ public:
private:
SmallVector<uint32_t> order;
int kOrder;
int kWidth;
SmallVector<int64_t> tileShape;
SmallVector<int> instrShape;
SmallVector<int> matShape;
@@ -176,9 +177,7 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,

SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
Value lane,
Value cSwizzleOffset,
int elemBytes) {
assert(elemBytes <= 4);
Value cSwizzleOffset) {
int cTileShape = tileShape[order[0]];
int sTileShape = tileShape[order[1]];
if (!needTrans) {
@@ -187,10 +186,10 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,

SmallVector<Value> offs(numPtrs);

int vecWidth = kWidth;
int threadsPerQuad[2] = {8, 4};
int laneWidth = 4;
int laneHeight = 8;
int vecWidth = 4 / elemBytes;
int quadWidth = laneWidth * vecWidth;
int quadHeight = laneHeight;
int numQuadI = 2;
@@ -232,8 +231,8 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
Value i = add(iBase, mul(iOff, i32_val(quadHeight)));
Value j = add(jBase, mul(jOff, i32_val(quadWidth)));
// wrap around the bounds
i = urem(i, i32_val(cTileShape));
j = urem(j, i32_val(sTileShape));
// i = urem(i, i32_val(cTileShape));
// j = urem(j, i32_val(sTileShape));
if (needTrans) {
offs[idx] = add(i, mul(j, sStride));
} else {
@@ -304,7 +303,6 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
} else {
elemTy = matTy.cast<LLVM::LLVMStructType>().getBody()[0];
// base pointers
std::array<std::array<Value, 4>, 2> ptrs;
int vecWidth = 4 / elemBytes;
@@ -324,39 +322,50 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
std::array<Value, 2> ii = {i0, i1};
// load 4 32-bit values from shared memory
// (equivalent to ldmatrix.x4)
SmallVector<SmallVector<Value>> vals(4, SmallVector<Value>(vecWidth));
SmallVector<SmallVector<Value>> vptrs(4, SmallVector<Value>(vecWidth));
for (int i = 0; i < 4; ++i)
|
||||
for (int j = 0; j < vecWidth; ++j)
|
||||
vals[i][j] = load(gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]));
|
||||
vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]);
|
||||
// row + trans and col + no-trans are equivalent
|
||||
if ((needTrans && kOrder == 1) || (!needTrans && kOrder == 0))
|
||||
std::swap(vals[1], vals[2]);
|
||||
bool isActualTrans =
|
||||
(needTrans && kOrder == 1) || (!needTrans && kOrder == 0);
|
||||
if (isActualTrans)
|
||||
std::swap(vptrs[1], vptrs[2]);
|
||||
// pack loaded vectors into 4 32-bit values
|
||||
int inc = needTrans ? 1 : kWidth;
|
||||
VectorType packedTy = vec_ty(int_ty(8 * elemBytes), inc);
|
||||
int canonBits = std::min(32, 8 * elemBytes * inc);
|
||||
int canonWidth = (8 * elemBytes * inc) / canonBits;
|
||||
Type canonInt = int_ty(canonBits);
|
||||
std::array<Value, 4> retElems;
|
||||
retElems.fill(undef(elemTy));
|
||||
for (int m = 0; m < 4; ++m) {
|
||||
for (int e = 0; e < vecWidth; ++e)
|
||||
retElems[m] = insert_element(retElems[m].getType(), retElems[m],
|
||||
vals[m][e], i32_val(e));
|
||||
retElems.fill(undef(vec_ty(canonInt, 32 / canonBits)));
|
||||
for (int r = 0; r < 2; ++r) {
|
||||
for (int em = 0; em < 2 * vecWidth; em += inc) {
|
||||
int e = em % vecWidth;
|
||||
int m = em / vecWidth;
|
||||
int idx = m * 2 + r;
|
||||
Value ptr = bitcast(vptrs[idx][e], ptr_ty(packedTy, 3));
|
||||
Value val = load(ptr);
|
||||
Value canonval = bitcast(val, vec_ty(canonInt, canonWidth));
|
||||
for (int w = 0; w < canonWidth; ++w) {
|
||||
retElems[idx + w * kWidth / vecWidth] =
|
||||
insert_element(retElems[idx + w * kWidth / vecWidth],
|
||||
extract_element(canonval, i32_val(w)), i32_val(e));
|
||||
}
|
||||
}
|
||||
}
|
||||
if (elemBytes == 1)
|
||||
return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
|
||||
bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
|
||||
else
|
||||
return {retElems[0], retElems[1], retElems[2], retElems[3]};
|
||||
return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
|
||||
bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
|
||||
}
|
||||
|
||||
assert(false && "Invalid smem load");
|
||||
return {Value{}, Value{}, Value{}, Value{}};
|
||||
}
|
||||
|
||||
MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
|
||||
int wpt, ArrayRef<uint32_t> order, uint32_t kOrder, int kWidth,
|
||||
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
|
||||
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
|
||||
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
|
||||
: order(order.begin(), order.end()), kOrder(kOrder),
|
||||
: order(order.begin(), order.end()), kOrder(kOrder), kWidth(kWidth),
|
||||
tileShape(tileShape.begin(), tileShape.end()),
|
||||
instrShape(instrShape.begin(), instrShape.end()),
|
||||
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
|
||||
@@ -369,7 +378,8 @@ MMA16816SmemLoader::MMA16816SmemLoader(
|
||||
|
||||
// rule: k must be the fast-changing axis.
|
||||
needTrans = kOrder != order[0];
|
||||
canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16
|
||||
canUseLdmatrix = elemBytes == 2 || (!needTrans);
|
||||
canUseLdmatrix = canUseLdmatrix && (kWidth == 4 / elemBytes);
|
||||
|
||||
if (canUseLdmatrix) {
|
||||
// Each CTA, the warps is arranged as [1xwpt] if not transposed,
|
||||
@@ -409,42 +419,12 @@ Type getShemPtrTy(Type argType) {
|
||||
return ptr_ty(type::i16Ty(ctx), 3);
|
||||
else if (argType.isF32())
|
||||
return ptr_ty(type::f32Ty(ctx), 3);
|
||||
else if (argType.isInteger(8))
|
||||
else if (argType.getIntOrFloatBitWidth() == 8)
|
||||
return ptr_ty(type::i8Ty(ctx), 3);
|
||||
else
|
||||
llvm::report_fatal_error("mma16816 data type not supported");
|
||||
}
|
||||
|
||||
Type getMatType(Type argType) {
|
||||
MLIRContext *ctx = argType.getContext();
|
||||
// floating point types
|
||||
Type fp32x1Ty = vec_ty(type::f32Ty(ctx), 1);
|
||||
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
|
||||
Type i16x2Ty = vec_ty(type::i16Ty(ctx), 2);
|
||||
Type fp16x2Pack4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp16x2Ty));
|
||||
// LLVM 14.0 does not support bf16 type, so we use i16 instead.
|
||||
Type bf16x2Pack4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i16x2Ty));
|
||||
Type fp32Pack4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32x1Ty));
|
||||
// integer types
|
||||
Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
|
||||
Type i8x4Pack4Ty =
|
||||
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
|
||||
|
||||
if (argType.isF16())
|
||||
return fp16x2Pack4Ty;
|
||||
else if (argType.isBF16())
|
||||
return bf16x2Pack4Ty;
|
||||
else if (argType.isF32())
|
||||
return fp32Pack4Ty;
|
||||
else if (argType.isInteger(8))
|
||||
return i8x4Pack4Ty;
|
||||
else
|
||||
llvm::report_fatal_error("mma16816 data type not supported");
|
||||
}
|
||||
|
||||
Value composeValuesToDotOperandLayoutStruct(
|
||||
const ValueTable &vals, int n0, int n1,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Location loc,
|
||||
@@ -470,7 +450,7 @@ Value composeValuesToDotOperandLayoutStruct(
|
||||
|
||||
std::function<void(int, int)>
|
||||
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
|
||||
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
|
||||
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, int kWidth,
|
||||
SmallVector<int> instrShape, SmallVector<int> matShape,
|
||||
Value warpId, Value lane, ValueTable &vals, bool isA,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter,
|
||||
@@ -485,143 +465,105 @@ getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
|
||||
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
|
||||
auto order = sharedLayout.getOrder();
|
||||
|
||||
// the original register_lds2, but discard the prefetch logic.
|
||||
auto ld2 = [](ValueTable &vals, int mn, int k, Value val) {
|
||||
vals[{mn, k}] = val;
|
||||
};
|
||||
|
||||
// (a, b) is the coordinate.
|
||||
auto load = [=, &rewriter, &vals, &ld2](int a, int b) {
|
||||
auto load = [=, &rewriter, &vals](int a, int b) {
|
||||
MMA16816SmemLoader loader(
|
||||
wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
|
||||
wpt, sharedLayout.getOrder(), kOrder, kWidth, smemObj.strides,
|
||||
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
|
||||
maxPhase, elemBytes, rewriter, typeConverter, loc);
|
||||
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
|
||||
SmallVector<Value> offs =
|
||||
loader.computeOffsets(warpId, lane, cSwizzleOffset);
|
||||
// initialize pointers
|
||||
const int numPtrs = loader.getNumPtrs();
|
||||
SmallVector<Value> ptrs(numPtrs);
|
||||
|
||||
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
|
||||
|
||||
Type smemPtrTy = getShemPtrTy(eltTy);
|
||||
for (int i = 0; i < numPtrs; ++i) {
|
||||
ptrs[i] =
|
||||
bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy);
|
||||
}
|
||||
|
||||
for (int i = 0; i < numPtrs; ++i)
|
||||
ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy);
|
||||
// actually load from shared memory
|
||||
auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(),
|
||||
SmallVector<Type>(4, i32_ty));
|
||||
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
|
||||
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
|
||||
ptrs, getMatType(eltTy), getShemPtrTy(eltTy));
|
||||
|
||||
if (isA) {
|
||||
ld2(vals, a, b, ha0);
|
||||
ld2(vals, a + 1, b, ha1);
|
||||
ld2(vals, a, b + 1, ha2);
|
||||
ld2(vals, a + 1, b + 1, ha3);
|
||||
} else {
|
||||
ld2(vals, a, b, ha0);
|
||||
ld2(vals, a + 1, b, ha2);
|
||||
ld2(vals, a, b + 1, ha1);
|
||||
ld2(vals, a + 1, b + 1, ha3);
|
||||
}
|
||||
ptrs, matTy, getShemPtrTy(eltTy));
|
||||
if (!isA)
|
||||
std::swap(ha1, ha2);
|
||||
// the following is incorrect
|
||||
// but causes dramatically better performance in ptxas
|
||||
// although it only changes the order of operands in
|
||||
// `mma.sync`
|
||||
// if(isA)
|
||||
// std::swap(ha1, ha2);
|
||||
// update user-provided values in-place
|
||||
vals[{a, b}] = ha0;
|
||||
vals[{a + 1, b}] = ha1;
|
||||
vals[{a, b + 1}] = ha2;
|
||||
vals[{a + 1, b + 1}] = ha3;
|
||||
};
|
||||
|
||||
return load;
|
||||
}
|
||||
|
||||
Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
DotOperandEncodingAttr aEncoding, const SharedMemoryObject &smemObj,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
|
||||
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
int bitwidth = aTensorTy.getElementTypeBitWidth();
|
||||
auto mmaLayout = aEncoding.getParent().cast<MmaEncodingAttr>();
|
||||
|
||||
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
|
||||
aTensorTy.getShape().end());
|
||||
|
||||
ValueTable ha;
|
||||
std::function<void(int, int)> loadFn;
|
||||
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
|
||||
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
|
||||
|
||||
auto numRep = aEncoding.getMMAv2Rep(aTensorTy.getShape(), bitwidth);
|
||||
int numRepM = numRep[0];
|
||||
int numRepK = numRep[1];
|
||||
|
||||
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
|
||||
int wpt0 = mmaLayout.getWarpsPerCTA()[0];
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value lane = urem(thread, i32_val(32));
|
||||
Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16));
|
||||
// load from smem
|
||||
// we use ldmatrix.x4 so each warp processes 16x16 elements.
|
||||
int wpt = std::min<int>(wpt0, shape[0] / 16);
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/,
|
||||
{mmaInstrM, mmaInstrK} /*instrShape*/,
|
||||
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/,
|
||||
ha /*vals*/, true /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
|
||||
// load from registers, used in gemm fuse
|
||||
// TODO(Superjomn) Port the logic.
|
||||
assert(false && "Loading A from register is not supported yet.");
|
||||
} else {
|
||||
assert(false && "A's layout is not supported.");
|
||||
}
|
||||
|
||||
// step1. Perform loading.
|
||||
for (int m = 0; m < numRepM; ++m)
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
loadFn(2 * m, 2 * k);
|
||||
|
||||
// step2. Format the values to LLVM::Struct to passing to mma codegen.
|
||||
return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK,
|
||||
typeConverter, loc, rewriter);
|
||||
}
|
||||
|
||||
Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
|
||||
ValueTable hb;
|
||||
Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
|
||||
DotOperandEncodingAttr encoding,
|
||||
const SharedMemoryObject &smemObj,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value thread,
|
||||
bool isA) {
|
||||
auto tensorTy = tensor.getType().cast<RankedTensorType>();
|
||||
int bitwidth = tensorTy.getElementTypeBitWidth();
|
||||
auto mmaLayout = bEncoding.getParent().cast<MmaEncodingAttr>();
|
||||
auto mmaLayout = encoding.getParent().cast<MmaEncodingAttr>();
|
||||
|
||||
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
|
||||
tensorTy.getShape().end());
|
||||
|
||||
ValueTable vals;
|
||||
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
|
||||
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
|
||||
|
||||
auto numRep = bEncoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
|
||||
int numRepK = numRep[0];
|
||||
int numRepN = numRep[1];
|
||||
auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
|
||||
int kWidth = encoding.getMMAv2kWidth();
|
||||
|
||||
int wpt0 = mmaLayout.getWarpsPerCTA()[0];
|
||||
int wpt1 = mmaLayout.getWarpsPerCTA()[1];
|
||||
Value warp = udiv(thread, i32_val(32));
|
||||
Value lane = urem(thread, i32_val(32));
|
||||
Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16));
|
||||
Value warpMN = udiv(warp, i32_val(wpt0));
|
||||
Value warpN = urem(urem(warpMN, i32_val(wpt1)), i32_val(shape[1] / 8));
|
||||
// we use ldmatrix.x4 so each warp processes 16x16 elements.
|
||||
int wpt = std::min<int>(wpt1, shape[1] / 16);
|
||||
auto loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/,
|
||||
{mmaInstrK, mmaInstrN} /*instrShape*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/,
|
||||
hb /*vals*/, false /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
|
||||
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
|
||||
int wpt;
|
||||
if (isA)
|
||||
wpt = std::min<int>(wpt0, shape[0] / 16);
|
||||
else
|
||||
wpt = std::min<int>(wpt1, shape[1] / 16);
|
||||
|
||||
std::function<void(int, int)> loadFn;
|
||||
if (isA)
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, kWidth,
|
||||
{mmaInstrM, mmaInstrK} /*instrShape*/,
|
||||
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/,
|
||||
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
else
|
||||
loadFn = getLoadMatrixFn(
|
||||
tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, kWidth,
|
||||
{mmaInstrK, mmaInstrN} /*instrShape*/,
|
||||
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/,
|
||||
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
|
||||
rewriter /*rewriter*/, loc /*loc*/);
|
||||
|
||||
// Perform loading.
|
||||
int numRepOuter = isA ? numRep[0] : std::max<int>(numRep[1] / 2, 1);
|
||||
int numRepK = isA ? numRep[1] : numRep[0];
|
||||
for (int m = 0; m < numRepOuter; ++m)
|
||||
for (int k = 0; k < numRepK; ++k)
|
||||
loadFn(2 * n, 2 * k);
|
||||
}
|
||||
loadFn(2 * m, 2 * k);
|
||||
|
||||
Value result = composeValuesToDotOperandLayoutStruct(
|
||||
hb, std::max(numRepN / 2, 1), numRepK, typeConverter, loc, rewriter);
|
||||
return result;
|
||||
// Format the values to LLVM::Struct to passing to mma codegen.
|
||||
return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK,
|
||||
typeConverter, loc, rewriter);
|
||||
}
|
||||
|
||||
namespace SharedToDotOperandMMAv2 {
|
||||
@@ -630,12 +572,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
|
||||
const SharedMemoryObject &smemObj,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
|
||||
if (opIdx == 0)
|
||||
return loadA(rewriter, loc, tensor, encoding, smemObj, typeConverter,
|
||||
thread);
|
||||
return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter,
|
||||
thread, true);
|
||||
else {
|
||||
assert(opIdx == 1);
|
||||
return loadB(rewriter, loc, tensor, encoding, smemObj, typeConverter,
|
||||
thread);
|
||||
return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter,
|
||||
thread, false);
|
||||
}
|
||||
}
|
||||
} // namespace SharedToDotOperandMMAv2
|
||||
|
||||
@@ -62,9 +62,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
|
||||
};
|
||||
|
||||
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
RewritePatternSet &patterns,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
|
||||
patterns.add<DotOpConversion>(typeConverter, allocation, benefit);
|
||||
}
|
||||
|
||||
@@ -7,9 +7,8 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
RewritePatternSet &patterns,
|
||||
ModuleAllocation &allocation,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -4,11 +4,140 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
using ::mlir::triton::gpu::getTotalElemsPerThread;

static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
Type inType, Type ouType) {
auto inTensorTy = inType.dyn_cast<RankedTensorType>();
auto ouTensorTy = ouType.dyn_cast<RankedTensorType>();
if (!inTensorTy || !ouTensorTy)
return values;
auto inEncoding =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(inTensorTy.getEncoding());
auto ouEncoding =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(ouTensorTy.getEncoding());
assert(inEncoding == ouEncoding);
if (!inEncoding)
return values;
size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
}
return ret;
}
if (inBitWidth == 8 && ouBitWidth == 16) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i + 0]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
ret.push_back(values[i + 12]);
ret.push_back(values[i + 13]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
}
return ret;
// for (unsigned i = 0; i < values.size(); i += 16) {
// ret.push_back(values[i]);
// ret.push_back(values[i + 1]);
// ret.push_back(values[i + 4]);
// ret.push_back(values[i + 5]);
// ret.push_back(values[i + 8]);
// ret.push_back(values[i + 9]);
// ret.push_back(values[i + 12]);
// ret.push_back(values[i + 13]);

// ret.push_back(values[i + 2]);
// ret.push_back(values[i + 3]);
// ret.push_back(values[i + 6]);
// ret.push_back(values[i + 7]);
// ret.push_back(values[i + 10]);
// ret.push_back(values[i + 11]);
// ret.push_back(values[i + 14]);
// ret.push_back(values[i + 15]);
// }
return values;
}
llvm_unreachable("unimplemented code path");
}
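
As an aside on the hunk above: the 16-bit-to-32-bit branch of reorderValues only permutes indices within each group of eight values. A minimal standalone C++ sketch of that permutation, assuming plain integers in place of MLIR Values (reorder16to32 and its test values are illustrative, not part of this commit):

#include <cstdio>
#include <vector>

// Standalone sketch: models the 16-bit -> 32-bit branch of reorderValues
// above on plain integers instead of MLIR Values.
static std::vector<int> reorder16to32(const std::vector<int> &values) {
  std::vector<int> ret;
  for (size_t i = 0; i + 8 <= values.size(); i += 8) {
    // pairs (0,1) and (4,5) come first, then (2,3) and (6,7)
    for (int off : {0, 1, 4, 5, 2, 3, 6, 7})
      ret.push_back(values[i + off]);
  }
  return ret;
}

int main() {
  std::vector<int> in = {0, 1, 2, 3, 4, 5, 6, 7};
  for (int v : reorder16to32(in))
    std::printf("%d ", v); // prints: 0 1 4 5 2 3 6 7
  std::printf("\n");
}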

inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
Type srcTy,
ConversionPatternRewriter &rewriter,
Location loc,
TypeConverter *typeConverter) {
auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
if (!tensorTy)
return inValues;
auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
if (!(encoding && encoding.getParent().isa<MmaEncodingAttr>()))
return inValues;
SmallVector<Value> outValues;
for (auto v : inValues) {
// cast i32 to appropriate eltType vector and extract elements
auto eltType = typeConverter->convertType(tensorTy.getElementType());
auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth());
auto vec = bitcast(v, vecType);
for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) {
outValues.push_back(extract_element(vec, i32_val(i)));
}
}
return outValues;
}

inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
Type srcTy,
ConversionPatternRewriter &rewriter,
Location loc, TypeConverter *typeConverter) {
auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
if (!tensorTy)
return inValues;
auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
if (!(encoding && encoding.getParent().isa<MmaEncodingAttr>()))
return inValues;
SmallVector<Value> outValues;
auto eltType = typeConverter->convertType(tensorTy.getElementType());
int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
auto vecType = vec_ty(eltType, vecWidth);
for (int i = 0; i < inValues.size(); i += vecWidth) {
Value vec = undef(vecType);
for (int j = 0; j < vecWidth; j++) {
vec = insert_element(vec, inValues[i + j], i32_val(j));
}
outValues.push_back(bitcast(vec, i32_ty));
}
return outValues;
}

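The unpackI32/packI32 pair above round-trips between packed 32-bit registers and per-element values for MMA dot-operand layouts. A minimal standalone C++ sketch of the same idea at the bit level, assuming 16-bit elements (two per 32-bit word) and plain integers rather than LLVM dialect values:

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

// Pack two 16-bit lanes into one 32-bit word, then unpack them again,
// mirroring the bitcast-based packing used by packI32/unpackI32 above.
static uint32_t pack2xu16(uint16_t lo, uint16_t hi) {
  return static_cast<uint32_t>(lo) | (static_cast<uint32_t>(hi) << 16);
}

static std::array<uint16_t, 2> unpack2xu16(uint32_t word) {
  return {static_cast<uint16_t>(word & 0xffff),
          static_cast<uint16_t>(word >> 16)};
}

int main() {
  uint32_t packed = pack2xu16(0x1234, 0xabcd);
  auto lanes = unpack2xu16(packed);
  assert(lanes[0] == 0x1234 && lanes[1] == 0xabcd);
  std::printf("0x%08x -> 0x%04x 0x%04x\n", packed, lanes[0], lanes[1]);
}
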
struct FpToFpOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
|
||||
using ConvertTritonGPUOpToLLVMPattern<
|
||||
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
typedef std::function<SmallVector<Value>(
|
||||
Location, ConversionPatternRewriter &, const Value &, const Value &,
|
||||
const Value &, const Value &)>
|
||||
ConvertorT;
|
||||
/* ------------------ */
|
||||
// FP8 -> FP16
|
||||
/* ------------------ */
|
||||
@@ -747,35 +876,14 @@ struct FpToFpOpConversion
|
||||
#endif
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto srcTensorType = op.getFrom().getType().cast<mlir::RankedTensorType>();
|
||||
auto dstTensorType =
|
||||
op.getResult().getType().cast<mlir::RankedTensorType>();
|
||||
auto srcEltType = srcTensorType.getElementType();
|
||||
auto dstEltType = dstTensorType.getElementType();
|
||||
auto loc = op->getLoc();
|
||||
auto elems = getTotalElemsPerThread(dstTensorType);
|
||||
SmallVector<Value> resultVals;
|
||||
bool isSrcFP8 =
|
||||
srcEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
|
||||
bool isDstFP8 =
|
||||
dstEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
|
||||
|
||||
// Select convertor
|
||||
typedef std::function<SmallVector<Value>(
|
||||
Location, ConversionPatternRewriter &, const Value &, const Value &,
|
||||
const Value &, const Value &)>
|
||||
ConvertorT;
|
||||
|
||||
ConvertorT getConversionFunc(Type srcTy, Type dstTy) const {
|
||||
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNType>();
|
||||
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
|
||||
auto F16TyID = TypeID::get<mlir::Float16Type>();
|
||||
auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
|
||||
auto F32TyID = TypeID::get<mlir::Float32Type>();
|
||||
auto F64TyID = TypeID::get<mlir::Float64Type>();
|
||||
DenseMap<std::pair<TypeID, TypeID>, ConvertorT> convertorMap = {
|
||||
static DenseMap<std::pair<TypeID, TypeID>, ConvertorT> convertorMap = {
|
||||
// F8 -> F16
|
||||
{{F8E4M3TyID, F16TyID}, convertFp8E4M3x4ToFp16x4},
|
||||
{{F8E5M2TyID, F16TyID}, convertFp8E5M2x4ToFp16x4},
|
||||
@@ -796,28 +904,46 @@ struct FpToFpOpConversion
|
||||
{{F32TyID, F8E5M2TyID}, convertFp32x4ToFp8E5M2x4},
|
||||
};
|
||||
|
||||
std::pair<TypeID, TypeID> key = {srcEltType.getTypeID(),
|
||||
dstEltType.getTypeID()};
|
||||
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
|
||||
if (convertorMap.count(key) == 0) {
|
||||
llvm::errs() << "Unsupported conversion from " << srcEltType << " to "
|
||||
<< dstEltType << "\n";
|
||||
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
|
||||
<< "\n";
|
||||
llvm_unreachable("");
|
||||
}
|
||||
auto convertor = convertorMap.lookup(key);
|
||||
return convertorMap.lookup(key);
|
||||
}
|
||||
|
||||
// Vectorized casting
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
// llvm::outs() << 0 << "\n";
|
||||
auto srcTensorType = op.getFrom().getType().cast<mlir::RankedTensorType>();
|
||||
auto dstTensorType =
|
||||
op.getResult().getType().cast<mlir::RankedTensorType>();
|
||||
auto loc = op->getLoc();
|
||||
// check that the number of elements is divisible by 4
// Get convertor
auto cvtFunc = getConversionFunc(srcTensorType.getElementType(),
dstTensorType.getElementType());
// Unpack value
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getFrom(),
rewriter, srcTensorType);
inVals =
unpackI32(inVals, srcTensorType, rewriter, loc, getTypeConverter());
// Cast
SmallVector<Value> outVals;
auto elems = inVals.size();
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getTypeConverter()->unpackLLElements(
loc, adaptor.getFrom(), rewriter, srcTensorType);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}

assert(resultVals.size() == elems);
auto result = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
for (size_t i = 0; i < elems; i += 4)
outVals.append(cvtFunc(loc, rewriter, inVals[i], inVals[i + 1],
inVals[i + 2], inVals[i + 3]));
// Pack values
assert(outVals.size() == elems);
outVals = reorderValues(outVals, srcTensorType, dstTensorType);
outVals =
packI32(outVals, dstTensorType, rewriter, loc, getTypeConverter());
auto result = getTypeConverter()->packLLElements(loc, outVals, rewriter,
dstTensorType);
rewriter.replaceOp(op, result);
return success();
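
The rewritten FpToFpOp lowering above feeds unpacked elements to the selected convertor four at a time (hence the elems % 4 assertion). A minimal standalone C++ sketch of that grouping, with plain floats and a lambda standing in for the PTX-emitting convertor (convertBy4 and the doubler callback are illustrative only):

#include <array>
#include <cassert>
#include <cstdio>
#include <functional>
#include <vector>

// Apply a 4-in/4-out converter across a flat list of elements, mirroring the
// group-of-four loop used by the FP8 conversion path above.
using Cvt4 = std::function<std::array<float, 4>(float, float, float, float)>;

static std::vector<float> convertBy4(const std::vector<float> &in, Cvt4 cvt) {
  assert(in.size() % 4 == 0 && "element count must be 4-aligned");
  std::vector<float> out;
  for (size_t i = 0; i < in.size(); i += 4) {
    auto r = cvt(in[i], in[i + 1], in[i + 2], in[i + 3]);
    out.insert(out.end(), r.begin(), r.end());
  }
  return out;
}

int main() {
  auto doubler = [](float a, float b, float c, float d) {
    return std::array<float, 4>{2 * a, 2 * b, 2 * c, 2 * d};
  };
  for (float v : convertBy4({1, 2, 3, 4, 5, 6, 7, 8}, doubler))
    std::printf("%.0f ", v); // prints: 2 4 6 8 10 12 14 16
  std::printf("\n");
}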
@@ -849,43 +975,44 @@ public:
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto resultTy = op.getType();
|
||||
Location loc = op->getLoc();
|
||||
|
||||
unsigned elems = getTotalElemsPerThread(resultTy);
|
||||
// element type
|
||||
auto resultElementTy = getElementTypeOrSelf(resultTy);
|
||||
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
|
||||
SmallVector<Type> types(elems, elemTy);
|
||||
Type structTy = this->getTypeConverter()->convertType(resultTy);
|
||||
|
||||
auto *concreteThis = static_cast<const ConcreteT *>(this);
|
||||
auto operands = getOperands(rewriter, adaptor, resultTy, elems, loc);
|
||||
SmallVector<Value> resultVals(elems);
|
||||
for (unsigned i = 0; i < elems; ++i) {
|
||||
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
|
||||
operands[i], loc);
|
||||
if (!bool(resultVals[i]))
|
||||
return failure();
|
||||
SmallVector<Value> resultVals;
|
||||
//
|
||||
SmallVector<SmallVector<Value>> allOperands;
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto argTy = op->getOperand(0).getType();
|
||||
auto sub_operands = this->getTypeConverter()->unpackLLElements(
|
||||
loc, operand, rewriter, argTy);
|
||||
sub_operands = unpackI32(sub_operands, argTy, rewriter, loc,
|
||||
this->getTypeConverter());
|
||||
allOperands.resize(sub_operands.size());
|
||||
for (auto v : llvm::enumerate(sub_operands))
|
||||
allOperands[v.index()].push_back(v.value());
|
||||
}
|
||||
if (allOperands.size() == 0)
|
||||
allOperands.push_back({});
|
||||
for (const SmallVector<Value> &operands : allOperands) {
|
||||
Value curr =
|
||||
((ConcreteT *)(this))
|
||||
->createDestOp(op, adaptor, rewriter, elemTy, operands, loc);
|
||||
if (!bool(curr))
|
||||
return failure();
|
||||
resultVals.push_back(curr);
|
||||
}
|
||||
if (op->getNumOperands() > 0) {
|
||||
auto argTy = op->getOperand(0).getType();
|
||||
resultVals = reorderValues(resultVals, argTy, resultTy);
|
||||
}
|
||||
resultVals =
|
||||
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
|
||||
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
protected:
|
||||
SmallVector<SmallVector<Value>>
|
||||
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
|
||||
Type operandTy, const unsigned elems, Location loc) const {
|
||||
SmallVector<SmallVector<Value>> operands(elems);
|
||||
for (auto operand : adaptor.getOperands()) {
|
||||
auto sub_operands = this->getTypeConverter()->unpackLLElements(
|
||||
loc, operand, rewriter, operandTy);
|
||||
for (size_t i = 0; i < elems; ++i) {
|
||||
operands[i].push_back(sub_operands[i]);
|
||||
}
|
||||
}
|
||||
return operands;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp, typename DestOp>
|
||||
@@ -1344,8 +1471,7 @@ struct AbsFOpConversion
|
||||
|
||||
void populateElementwiseOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem, PatternBenefit benefit) {
|
||||
PatternBenefit benefit) {
|
||||
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
|
||||
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
|
||||
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
|
||||
|
||||
@@ -8,8 +8,7 @@ using namespace mlir::triton;
|
||||
|
||||
void populateElementwiseOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem, PatternBenefit benefit);
|
||||
PatternBenefit benefit);
|
||||
|
||||
bool isLegalElementwiseOp(Operation *op);
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
|
||||
|
||||
// Contains some helper functions for both Load and Store conversions.
|
||||
struct LoadStoreConversionBase {
|
||||
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
|
||||
explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass)
|
||||
: axisAnalysisPass(axisAnalysisPass) {}
|
||||
|
||||
unsigned getContiguity(Value ptr) const {
|
||||
@@ -38,7 +38,7 @@ struct LoadStoreConversionBase {
|
||||
}
|
||||
|
||||
protected:
|
||||
AxisInfoAnalysis &axisAnalysisPass;
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass;
|
||||
};
|
||||
|
||||
struct LoadOpConversion
|
||||
@@ -48,7 +48,8 @@ struct LoadOpConversion
|
||||
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
@@ -293,7 +294,8 @@ struct StoreOpConversion
|
||||
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
@@ -335,14 +337,7 @@ struct StoreOpConversion
|
||||
vec = std::min(vec, maskAlign);
|
||||
}
|
||||
|
||||
// numElements = 1 for scalar
|
||||
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
mask = and_(mask,
|
||||
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
|
||||
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
const size_t dtsize =
|
||||
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
|
||||
const size_t valueElemNBits = dtsize * 8;
|
||||
@@ -431,11 +426,11 @@ struct AtomicCASOpConversion
|
||||
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
const Allocation *allocation, Value smem,
|
||||
AxisInfoAnalysis &axisAnalysisPass,
|
||||
ModuleAllocation &allocation,
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
|
||||
converter, allocation, smem, benefit),
|
||||
converter, allocation, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -526,13 +521,13 @@ struct AtomicCASOpConversion
|
||||
auto valElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llVal, rewriter, op.getVal().getType());
|
||||
|
||||
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto valueTy = op.getResult().getType();
|
||||
auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
: valueTy;
|
||||
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto tid = tid_val();
|
||||
Value pred = icmp_eq(tid, i32_val(0));
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
PTXBuilder ptxBuilderMemfence;
|
||||
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
|
||||
memfence();
|
||||
@@ -552,7 +547,7 @@ struct AtomicCASOpConversion
|
||||
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
|
||||
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
|
||||
atom.global().o("cas").o("b32");
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
|
||||
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
|
||||
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
|
||||
barrier();
|
||||
|
||||
@@ -561,7 +556,7 @@ struct AtomicCASOpConversion
|
||||
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
|
||||
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
|
||||
st.shared().o("b32");
|
||||
st(dstOprStore, valOprStore).predicate(pred);
|
||||
st(dstOprStore, valOprStore).predicate(mask);
|
||||
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
barrier();
|
||||
@@ -580,11 +575,11 @@ struct AtomicRMWOpConversion
|
||||
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
|
||||
const Allocation *allocation, Value smem,
|
||||
AxisInfoAnalysis &axisAnalysisPass,
|
||||
ModuleAllocation &allocation,
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
|
||||
converter, allocation, smem, benefit),
|
||||
converter, allocation, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
@@ -747,10 +742,11 @@ struct AtomicRMWOpConversion
|
||||
maskElements = getTypeConverter()->unpackLLElements(
|
||||
loc, llMask, rewriter, op.getMask().getType());
|
||||
|
||||
auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto valueTy = op.getResult().getType();
|
||||
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
|
||||
Type valueElemTy =
|
||||
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
|
||||
: op.getResult().getType();
|
||||
: valueTy;
|
||||
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
|
||||
auto elemsPerThread = getTotalElemsPerThread(val.getType());
|
||||
// vec = 1, numElements = 1 for scalar
|
||||
@@ -763,10 +759,7 @@ struct AtomicRMWOpConversion
|
||||
// mask
|
||||
numElems = tensorTy.getNumElements();
|
||||
}
|
||||
Value mask = int_val(1, 1);
|
||||
auto tid = tid_val();
|
||||
mask = and_(mask,
|
||||
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
|
||||
Value mask = getMask(valueTy, rewriter, loc);
|
||||
|
||||
auto vecTy = vec_ty(valueElemTy, vec);
|
||||
SmallVector<Value> resultVals(elemsPerThread);
|
||||
@@ -846,7 +839,6 @@ struct AtomicRMWOpConversion
|
||||
memfenc();
|
||||
auto ASMReturnTy = void_ty(ctx);
|
||||
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
|
||||
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
|
||||
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
|
||||
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
|
||||
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
|
||||
@@ -889,7 +881,9 @@ struct InsertSliceOpConversion
|
||||
Value dst = op.getDest();
|
||||
Value src = op.getSource();
|
||||
Value res = op.getResult();
|
||||
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
|
||||
auto funcOp = op->getParentOfType<FunctionOpInterface>();
|
||||
auto *funcAllocation = allocation->getFuncData(funcOp);
|
||||
assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
|
||||
"Only support in-place insert_slice for now");
|
||||
|
||||
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
|
||||
@@ -949,12 +943,11 @@ struct InsertSliceAsyncOpConversion
|
||||
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
|
||||
|
||||
InsertSliceAsyncOpConversion(
|
||||
TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
|
||||
Value smem,
|
||||
TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
|
||||
converter, allocation, smem, indexCacheInfo, benefit),
|
||||
converter, allocation, indexCacheInfo, benefit),
|
||||
LoadStoreConversionBase(axisAnalysisPass) {}
|
||||
|
||||
LogicalResult
|
||||
@@ -967,7 +960,9 @@ struct InsertSliceAsyncOpConversion
|
||||
Value res = op.getResult();
|
||||
Value mask = op.getMask();
|
||||
Value other = op.getOther();
|
||||
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
|
||||
auto funcOp = op->getParentOfType<FunctionOpInterface>();
|
||||
auto *funcAllocation = allocation->getFuncData(funcOp);
|
||||
assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
|
||||
"Only support in-place insert_slice_async for now");
|
||||
|
||||
auto srcTy = src.getType().cast<RankedTensorType>();
|
||||
@@ -1107,19 +1102,17 @@ struct InsertSliceAsyncOpConversion
|
||||
|
||||
void populateLoadStoreOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
|
||||
patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
|
||||
patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
|
||||
axisInfoAnalysis, benefit);
|
||||
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
|
||||
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation,
|
||||
axisInfoAnalysis, benefit);
|
||||
patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
|
||||
patterns.add<InsertSliceOpConversion>(typeConverter, allocation,
|
||||
indexCacheInfo, benefit);
|
||||
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
|
||||
indexCacheInfo, axisInfoAnalysis,
|
||||
benefit);
|
||||
patterns.add<InsertSliceAsyncOpConversion>(
|
||||
typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
|
||||
}
|
||||
|
||||
@@ -8,8 +8,7 @@ using namespace mlir::triton;
|
||||
|
||||
void populateLoadStoreOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
|
||||
@@ -87,6 +87,15 @@ private:
|
||||
Attribute layout, SmallVector<Value> &index,
|
||||
SmallVector<Value> &writeIdx,
|
||||
std::map<int, Value> &ints, unsigned axis) const {
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto dim = sliceLayout.getDim();
|
||||
assert(dim != axis && "Reduction axis cannot be sliced");
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints,
|
||||
axis);
|
||||
return;
|
||||
}
|
||||
|
||||
writeIdx = index;
|
||||
auto sizePerThread = triton::gpu::getSizePerThread(layout);
|
||||
Value axisSizePerThread = ints[sizePerThread[axis]];
|
||||
@@ -100,9 +109,10 @@ private:
|
||||
// to map every `axisSizePerThread` to 1 value in smem as:
|
||||
// writeIdx[axis] = index[axis] / axisSizePerThread
|
||||
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
|
||||
}
|
||||
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
|
||||
if (mmaLayout && mmaLayout.isAmpere()) {
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (!mmaLayout.isAmpere()) {
|
||||
llvm::report_fatal_error("Unsupported layout");
|
||||
}
|
||||
if (axis == 0) {
|
||||
// Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8
|
||||
// rows in smem would correspond to a warp. The mapping
|
||||
@@ -113,8 +123,7 @@ private:
|
||||
// Same as BlockedEncodingAttr case
|
||||
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
|
||||
}
|
||||
}
|
||||
if (mmaLayout && !mmaLayout.isAmpere()) {
|
||||
} else {
|
||||
llvm::report_fatal_error("Unsupported layout");
|
||||
}
|
||||
}
|
||||
@@ -327,8 +336,8 @@ private:
|
||||
elemPtrTys[i]);
|
||||
}
|
||||
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSize();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSize();
|
||||
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
|
||||
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
|
||||
|
||||
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
|
||||
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
|
||||
@@ -495,10 +504,9 @@ private:
|
||||
|
||||
void populateReduceOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
|
||||
indexCacheInfo, benefit);
|
||||
patterns.add<ReduceOpConversion>(typeConverter, allocation, indexCacheInfo,
|
||||
benefit);
|
||||
}
|
||||
|
||||
@@ -8,8 +8,7 @@ using namespace mlir::triton;
|
||||
|
||||
void populateReduceOpToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
|
||||
@@ -401,8 +401,7 @@ struct MakeRangeOpConversion
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit)
|
||||
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
|
||||
converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo,
|
||||
benefit) {}
|
||||
converter, indexCacheInfo, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
|
||||
@@ -669,18 +668,17 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ModuleAllocation &moduleAllocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
|
||||
patterns.add<AllocTensorOpConversion>(typeConverter, moduleAllocation,
|
||||
benefit);
|
||||
patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
|
||||
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
|
||||
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
|
||||
|
||||
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
|
||||
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
|
||||
benefit);
|
||||
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
|
||||
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
|
||||
|
||||
@@ -8,8 +8,7 @@ using namespace mlir::triton;
|
||||
|
||||
void populateTritonGPUToLLVMPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
ModuleAllocation &allocation,
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
|
||||
PatternBenefit benefit);
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
#include "triton/Analysis/AxisInfo.h"
|
||||
|
||||
#include <set>
|
||||
using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
@@ -41,7 +41,7 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
|
||||
// All the rights are reserved by the LLVM community.
|
||||
|
||||
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<triton::FuncOp> {
|
||||
private:
|
||||
protected:
|
||||
/// Only retain those attributes that are not constructed by
|
||||
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
|
||||
/// attributes.
|
||||
@@ -184,14 +184,18 @@ public:
|
||||
: converter(&typeConverter) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
|
||||
Value smem)
|
||||
: converter(&typeConverter), allocation(allocation), smem(smem) {}
|
||||
TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
IndexCacheInfo indexCacheInfo)
|
||||
: converter(&typeConverter), indexCacheInfo(indexCacheInfo) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
|
||||
Value smem, IndexCacheInfo indexCacheInfo)
|
||||
: converter(&typeConverter), allocation(allocation), smem(smem),
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation)
|
||||
: converter(&typeConverter), allocation(&allocation) {}
|
||||
|
||||
explicit ConvertTritonGPUOpToLLVMPatternBase(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
|
||||
IndexCacheInfo indexCacheInfo)
|
||||
: converter(&typeConverter), allocation(&allocation),
|
||||
indexCacheInfo(indexCacheInfo) {}
|
||||
|
||||
TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; }
|
||||
@@ -228,9 +232,17 @@ public:
|
||||
T value) const {
|
||||
auto ptrTy = LLVM::LLVMPointerType::get(
|
||||
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
|
||||
auto bufferId = allocation->getBufferId(value);
|
||||
FunctionOpInterface funcOp;
|
||||
if constexpr (std::is_pointer_v<T>)
|
||||
funcOp = value->template getParentOfType<FunctionOpInterface>();
|
||||
else
|
||||
funcOp = value.getParentRegion()
|
||||
->template getParentOfType<FunctionOpInterface>();
|
||||
auto *funcAllocation = allocation->getFuncData(funcOp);
|
||||
auto smem = allocation->getFunctionSharedMemoryBase(funcOp);
|
||||
auto bufferId = funcAllocation->getBufferId(value);
|
||||
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
|
||||
size_t offset = allocation->getOffset(bufferId);
|
||||
size_t offset = funcAllocation->getOffset(bufferId);
|
||||
Value offVal = i32_val(offset);
|
||||
Value base = gep(ptrTy, smem, offVal);
|
||||
return base;
|
||||
@@ -409,6 +421,46 @@ public:
|
||||
// -----------------------------------------------------------------------
|
||||
// Utilities
|
||||
// -----------------------------------------------------------------------
Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
Location loc) const {
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
Value mask = int_val(1, 1);
auto tid = tid_val();
if (tensorTy) {
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
Value warpSize = i32_val(32);
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTA[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
i32_val(shape[dim])));
}
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 can
// write
mask = and_(mask, icmp_eq(tid, i32_val(0)));
}
return mask;
}
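
The new getMask helper above predicates stores and atomics so that only one thread writes each logical element when a layout replicates data across threads. A small standalone C++ sketch of the per-dimension decision, with made-up example parameters (the warp/lane counts and sizes below are illustrative, not taken from a real Triton layout):

#include <cstdio>

// Decide whether a given thread should write along one dimension, mirroring
// the per-dimension check in getMask above: skip threads whose first element
// index already falls outside the tensor shape (i.e. replicated data).
static bool threadWrites(int warpId, int laneId, int threadsPerWarp,
                         int sizePerThread, int shapeDim) {
  int threadDim = warpId * threadsPerWarp + laneId; // thread index in the CTA
  return threadDim * sizePerThread < shapeDim;
}

int main() {
  // Illustrative numbers only: 2 warps of 4 threads, 2 elements per thread,
  // but the dimension is only 8 wide, so the second warp would replicate data.
  for (int warp = 0; warp < 2; ++warp)
    for (int lane = 0; lane < 4; ++lane)
      std::printf("warp %d lane %d -> %s\n", warp, lane,
                  threadWrites(warp, lane, 4, 2, 8) ? "write" : "skip");
}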
|
||||
// Convert an \param index to a multi-dim coordinate given \param shape and
|
||||
// \param order.
|
||||
@@ -505,13 +557,13 @@ public:
RankedTensorType type) const {
IndexCacheKeyT key = std::make_pair(layout, type);
auto cache = indexCacheInfo.baseIndexCache;
assert(cache && "baseIndexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
if (cache && cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
restoreInsertionPointIfSet(insertPt, rewriter);
if (cache)
restoreInsertionPointIfSet(insertPt, rewriter);
SmallVector<Value> result;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
result =
@@ -521,11 +573,20 @@ public:
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentLayout);
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy);
result.erase(result.begin() + sliceLayout.getDim());
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
cache->insert(std::make_pair(key, result));
*insertPt = rewriter.saveInsertionPoint();
if (cache) {
cache->insert(std::make_pair(key, result));
*insertPt = rewriter.saveInsertionPoint();
}
return result;
}
}
@@ -540,6 +601,8 @@ public:
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return emitOffsetForSliceLayout(sliceLayout, type);
llvm_unreachable("unsupported emitOffsetForLayout");
}

@@ -552,27 +615,29 @@ public:
RankedTensorType type) const {
IndexCacheKeyT key(layout, type);
auto cache = indexCacheInfo.indexCache;
assert(cache && "indexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
if (cache && cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(b);
restoreInsertionPointIfSet(insertPt, b);
if (cache)
restoreInsertionPointIfSet(insertPt, b);
SmallVector<SmallVector<Value>> result;
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, blocked, type);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, type);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, type);
result = emitIndicesForDistributedLayout(loc, b, slice, type);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
"implemented yet");
}
cache->insert(std::make_pair(key, result));
*insertPt = b.saveInsertionPoint();
if (cache) {
cache->insert(std::make_pair(key, result));
*insertPt = b.saveInsertionPoint();
}
return result;
}
}
@@ -722,11 +787,11 @@ private:
Value _fpw1 = i32_val(fpw[1]);

// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

@@ -783,12 +848,12 @@ private:
// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
// A info
auto aEncoding =
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding =
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

@@ -891,24 +956,29 @@ private:
return multiDimIdx;
}

SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
RankedTensorType type) const {
SmallVector<SmallVector<unsigned>>
emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,
RankedTensorType type) const {
auto parentEncoding = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentEncoding);
auto parentIndices = emitIndices(loc, rewriter, parentEncoding, parentTy);
unsigned numIndices = parentIndices.size();
SmallVector<SmallVector<Value>> resultIndices;
for (unsigned i = 0; i < numIndices; ++i) {
SmallVector<Value> indices = parentIndices[i];
indices.erase(indices.begin() + dim);
resultIndices.push_back(indices);
auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy);

unsigned numOffsets = parentOffsets.size();
SmallVector<SmallVector<unsigned>> resultOffsets;
std::set<SmallVector<unsigned>> uniqueOffsets;

for (unsigned i = 0; i < numOffsets; ++i) {
SmallVector<unsigned> offsets = parentOffsets[i];
offsets.erase(offsets.begin() + dim);
if (uniqueOffsets.find(offsets) == uniqueOffsets.end()) {
resultOffsets.push_back(offsets);
uniqueOffsets.insert(offsets);
}
}
return resultIndices;
return resultOffsets;
}

#ifdef USE_ROCM
@@ -1057,8 +1127,7 @@ protected:

protected:
TritonGPUToLLVMTypeConverter *converter;
const Allocation *allocation;
Value smem;
ModuleAllocation *allocation;
IndexCacheInfo indexCacheInfo;
};

@@ -1075,16 +1144,22 @@ public:
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}

explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, PatternBenefit benefit = 1)
TritonGPUToLLVMTypeConverter &typeConverter,
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, indexCacheInfo) {}

explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {}

explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
indexCacheInfo) {}

protected:

@@ -2,7 +2,7 @@

#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@@ -61,17 +61,38 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
LogicalResult
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();

// Currently, Triton kernel function always return nothing.
// TODO(Superjomn) add support for non-inline device function
if (numArguments > 0) {
return rewriter.notifyMatchFailure(
op, "Only kernel function with nothing returned is supported.");
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
if (funcOp->hasAttr("nvvm.kernel")) {
// A GPU kernel
if (op.getNumOperands() > 0) {
return rewriter.notifyMatchFailure(
op, "Kernel functions do not support return with operands");
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
} else {
// A device function
LLVM::ReturnOp newOp;
if (adaptor.getOperands().size() < 2) {
// Single or no return value.
newOp =
rewriter.create<LLVM::ReturnOp>(op.getLoc(), adaptor.getOperands());
} else {
// Pack the results into a struct.
auto packedResultsTy = this->getTypeConverter()->packFunctionResults(
funcOp.getResultTypes());
Value packedResults =
rewriter.create<LLVM::UndefOp>(op.getLoc(), packedResultsTy);
auto loc = op.getLoc();
for (auto it : llvm::enumerate(adaptor.getOperands())) {
packedResults = insert_val(packedResultsTy, packedResults, it.value(),
it.index());
}
newOp = rewriter.create<LLVM::ReturnOp>(op.getLoc(), packedResults);
}
newOp->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newOp->getResults());
}

rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
};
@@ -81,19 +102,57 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
ModuleAllocation &allocation, PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), numWarps(numWarps),
allocation(allocation) {}

triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
// Push back a variable that indicates the current stack pointer of shared
// memory to the function arguments.
auto loc = funcOp.getLoc();
auto ctx = funcOp->getContext();
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
// 1. Modify the function type to add the new argument.
auto funcTy = funcOp.getFunctionType();
auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
amendedInputTy.push_back(ptrTy);
auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy,
funcTy.getResults());
// 2. Modify the argument attributes to add the new argument.
SmallVector<NamedAttribute> amendedAttrs;
filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs);
auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs());
amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx));
amendedAttrs.push_back(rewriter.getNamedAttr(
funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs)));
// 3. Add a new argument to the region
auto amendedFuncOp = rewriter.create<triton::FuncOp>(
funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs);
auto &region = funcOp.getBody();
region.addArgument(ptrTy, loc);
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
amendedFuncOp.end());
return amendedFuncOp;
}

LogicalResult
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
// Prevent LLVM's inliner to inline this function
auto amendedFuncOp = funcOp;
if (!allocation.isRoot(funcOp))
amendedFuncOp = amendFuncOp(funcOp, rewriter);

auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
if (!newFuncOp) {
return failure();
}

auto ctx = funcOp->getContext();

<<<<<<< HEAD
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
@@ -102,6 +161,25 @@ struct FuncOpConversion : public FuncOpConversionBase {
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps));
#endif
=======
if (allocation.isRoot(funcOp)) {
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
} else {
// The noinline attribute will be used by the LLVM codegen to prevent
// inlining.
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
newFuncOp.setPassthroughAttr(
ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
rewriter.eraseOp(amendedFuncOp);
}
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps));
// The call graph is updated by mapping the old function to the new one.
allocation.mapFuncOp(funcOp, newFuncOp);
>>>>>>> openai/main

rewriter.eraseOp(funcOp);
return success();
@@ -109,6 +187,99 @@ struct FuncOpConversion : public FuncOpConversionBase {

private:
int numWarps{0};
ModuleAllocation &allocation;
};

// CallOpInterfaceLowering is adapted from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485
struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
CallOpConversion(LLVMTypeConverter &converter, int numWarps,
ModuleAllocation &allocation, PatternBenefit benefit)
: ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
numWarps(numWarps), allocation(allocation) {}

LogicalResult
matchAndRewrite(triton::CallOp callOp,
typename triton::CallOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto promotedOperands = promoteOperands(callOp, adaptor, rewriter);
auto newCallOp =
convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter);
if (!newCallOp)
return failure();
allocation.mapCallOp(callOp, newCallOp);
auto results = getCallOpResults(callOp, newCallOp, rewriter);
rewriter.replaceOp(callOp, results);
return success();
}

private:
SmallVector<Value, 4>
promoteOperands(triton::CallOp callOp,
typename triton::CallOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Get the last argument of the caller, which is the current stack pointer
// of shared memory and append it to the operands of the callOp.
auto loc = callOp.getLoc();
auto caller = callOp->getParentOfType<FunctionOpInterface>();
auto base = allocation.getFunctionSharedMemoryBase(caller);
auto *funcAllocation = allocation.getFuncData(caller);
auto bufferId = funcAllocation->getBufferId(callOp);
auto offset = funcAllocation->getOffset(bufferId);
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()),
NVVM::kSharedMemorySpace);
auto offsetValue = gep(ptrTy, base, i32_val(offset));
auto promotedOperands = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter);
promotedOperands.push_back(offsetValue);
return promotedOperands;
}

LLVM::CallOp
convertCallOpToLLVMCallOp(triton::CallOp callOp,
ArrayRef<Value> promotedOperands,
ConversionPatternRewriter &rewriter) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());

if (numResults != 0) {
if (!(packedResult =
this->getTypeConverter()->packFunctionResults(resultTypes)))
return nullptr;
}

auto newCallOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promotedOperands, callOp->getAttrs());
return newCallOp;
}

SmallVector<Value>
getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp,
ConversionPatternRewriter &rewriter) const {
auto numResults = callOp.getNumResults();
SmallVector<Value> results;
if (numResults < 2) {
// If < 2 results, packing did not do anything and we can just return.
results.append(newCallOp.result_begin(), newCallOp.result_end());
} else {
// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
callOp.getLoc(), newCallOp->getResult(0), i));
}
}
return results;
}

int numWarps{0};
ModuleAllocation &allocation;
};

class TritonLLVMConversionTarget : public ConversionTarget {
@@ -145,26 +316,25 @@ public:
TritonLLVMConversionTarget target(*context, isROCM);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

/* preprocess */
// Preprocess
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
if (failed(decomposeInsertSliceAsyncOp(mod)))
return signalPassFailure();

/* allocate shared memory and set barrier */
Allocation allocation(mod);
MembarAnalysis membarPass(&allocation);
// Allocate shared memory and set barrier
ModuleAllocation allocation(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();

/* lower functions */
// Lower functions
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, allocation,
/*benefit=*/1);
funcPatterns.add<ReturnOpConversion>(typeConverter);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
funcPatterns);
if (failed(
@@ -172,36 +342,49 @@ public:
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(mod)))
|
||||
return signalPassFailure();
|
||||
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
|
||||
mod->setAttr("triton_gpu.shared",
|
||||
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
|
||||
allocation.getSharedMemorySize()));
|
||||
// initSharedMemory is run before the conversion of call and ret ops,
|
||||
// because the call op has to know the shared memory base address of each
|
||||
// function
|
||||
initSharedMemory(allocation, typeConverter);
|
||||
|
||||
/* rewrite ops */
|
||||
// Convert call and ret ops
|
||||
{
|
||||
mlir::LowerToLLVMOptions option(context);
|
||||
TritonGPUToLLVMTypeConverter typeConverter(context, option);
|
||||
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
|
||||
RewritePatternSet funcPatterns(context);
|
||||
funcPatterns.add<CallOpConversion>(typeConverter, numWarps, allocation,
|
||||
/*benefit=*/1);
|
||||
funcPatterns.add<ReturnOpConversion>(typeConverter, /*benefit=*/1);
|
||||
if (failed(
|
||||
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
|
||||
// Rewrite ops
|
||||
RewritePatternSet patterns(context);
|
||||
// TritonGPU lowering patterns
|
||||
OpBuilder::InsertPoint indexInsertPoint;
|
||||
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
|
||||
&baseIndexCache, &indexCache, &indexInsertPoint};
|
||||
auto populatePatterns1 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
|
||||
&allocation, smem, indexCacheInfo, /*benefit*/ 1);
|
||||
};
|
||||
auto populatePatterns2 = [&](auto populateFunc) {
|
||||
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
|
||||
&allocation, smem, /*benefit*/ 1);
|
||||
};
|
||||
populatePatterns1(populateTritonGPUToLLVMPatterns);
|
||||
populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
|
||||
populatePatterns2(populateDotOpToLLVMPatterns);
|
||||
populatePatterns2(populateElementwiseOpToLLVMPatterns);
|
||||
populatePatterns1(populateLoadStoreOpToLLVMPatterns);
|
||||
populatePatterns1(populateReduceOpToLLVMPatterns);
|
||||
populatePatterns2(populateViewOpToLLVMPatterns);
|
||||
// TODO: enable index cache if there are multiple functions
|
||||
if (axisInfoAnalysis.getNumFunctions() > 1) {
|
||||
indexCacheInfo = {nullptr, nullptr, nullptr};
|
||||
}
|
||||
populateTritonGPUToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
indexCacheInfo, /*benefit=*/1);
|
||||
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
indexCacheInfo, /*benefit=*/1);
|
||||
populateDotOpToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
/*benefit=*/1);
|
||||
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
|
||||
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
|
||||
allocation, indexCacheInfo,
|
||||
/*benefit=*/1);
|
||||
populateReduceOpToLLVMPatterns(typeConverter, patterns, allocation,
|
||||
indexCacheInfo, /*benefit=*/1);
|
||||
populateViewOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
|
||||
|
||||
// Native lowering patterns
|
||||
if (isROCM) {
|
||||
@@ -218,8 +401,6 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
Value smem;
|
||||
|
||||
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
|
||||
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
|
||||
baseIndexCache;
|
||||
@@ -230,10 +411,11 @@ private:
|
||||
int computeCapability{};
|
||||
bool isROCM{};
|
||||
|
||||
void initSharedMemory(size_t size,
|
||||
void initSharedMemory(ModuleAllocation &allocation,
|
||||
TritonGPUToLLVMTypeConverter &typeConverter) {
|
||||
ModuleOp mod = getOperation();
|
||||
OpBuilder b(mod.getBodyRegion());
|
||||
auto ctx = mod.getContext();
|
||||
auto loc = mod.getLoc();
|
||||
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
|
||||
// Set array size 0 and external linkage indicates that we use dynamic
|
||||
@@ -244,15 +426,23 @@ private:
|
||||
"global_smem", /*value=*/Attribute(), /*alignment=*/0,
|
||||
// Add ROCm support.
|
||||
static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace));
|
||||
SmallVector<LLVM::LLVMFuncOp> funcs;
|
||||
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
|
||||
assert(funcs.size() == 1 &&
|
||||
"Inliner pass is expected before TritonGPUToLLVM");
|
||||
b.setInsertionPointToStart(&funcs[0].getBody().front());
|
||||
smem = b.create<LLVM::AddressOfOp>(loc, global);
|
||||
auto ptrTy =
|
||||
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
|
||||
smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
|
||||
mod.walk([&](FunctionOpInterface funcOp) {
|
||||
Value funcSmem;
|
||||
b.setInsertionPointToStart(&funcOp.getFunctionBody().front());
|
||||
if (allocation.isRoot(funcOp)) {
|
||||
funcSmem = b.create<LLVM::AddressOfOp>(loc, global);
|
||||
} else {
|
||||
funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1);
|
||||
}
|
||||
auto ptrTy =
|
||||
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()),
|
||||
NVVM::NVVMMemorySpace::kSharedMemorySpace);
|
||||
funcSmem = b.create<LLVM::BitcastOp>(loc, ptrTy, funcSmem);
|
||||
allocation.setFunctionSharedMemoryValue(funcOp, funcSmem);
|
||||
});
|
||||
mod->setAttr("triton_gpu.shared",
|
||||
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
|
||||
allocation.getSharedMemorySize()));
|
||||
}
|
||||
|
||||
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const {
|
||||
@@ -310,10 +500,7 @@ private:
|
||||
}
|
||||
|
||||
LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(mod)))
|
||||
return failure();
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
|
||||
// TODO(Keren): This is a hacky knob that may cause performance regression
|
||||
// when decomposition has been performed. We should remove this knob once we
|
||||
// have thorough analysis on async wait. Currently, we decompose
|
||||
@@ -347,7 +534,7 @@ private:
|
||||
auto resSharedLayout =
|
||||
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
|
||||
auto resElemTy = dstTy.getElementType();
|
||||
unsigned inVec = axisInfoAnalysis->getPtrContiguity(src);
|
||||
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
|
||||
unsigned outVec = resSharedLayout.getVec();
|
||||
unsigned minVec = std::min(outVec, inVec);
|
||||
auto maxBitWidth =
|
||||
|
||||
@@ -106,17 +106,8 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
|
||||
return elemTy;
|
||||
if (mmaParent.isAmpere()) {
|
||||
int bitwidth = elemTy.getIntOrFloatBitWidth();
|
||||
// sub-word integer types need to be packed for perf reasons
|
||||
if (elemTy.isa<IntegerType>() && bitwidth < 32)
|
||||
return IntegerType::get(ctx, 32);
|
||||
// TODO: unify everything to use packed integer-types
|
||||
// otherwise, vector types are ok
|
||||
const llvm::DenseMap<int, Type> elemTyMap = {
|
||||
{32, vec_ty(elemTy, 1)},
|
||||
{16, vec_ty(elemTy, 2)},
|
||||
{8, vec_ty(elemTy, 4)},
|
||||
};
|
||||
return elemTyMap.lookup(bitwidth);
|
||||
assert(bitwidth <= 32);
|
||||
return IntegerType::get(ctx, 32);
|
||||
} else {
|
||||
assert(mmaParent.isVolta());
|
||||
return vec_ty(elemTy, 2);
|
||||
|
||||
@@ -88,6 +88,7 @@
|
||||
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
|
||||
|
||||
// Types
|
||||
#define int_ty(width) rewriter.getIntegerType(width)
|
||||
#define i64_ty rewriter.getIntegerType(64)
|
||||
#define i32_ty rewriter.getIntegerType(32)
|
||||
#define i16_ty rewriter.getIntegerType(16)
|
||||
|
||||
@@ -116,23 +116,64 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename SourceOp>
|
||||
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
|
||||
using OpAdaptor = typename SourceOp::Adaptor;
|
||||
explicit ViewLikeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
|
||||
struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
|
||||
using OpAdaptor = typename ViewOp::Adaptor;
|
||||
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
|
||||
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
auto vals = this->getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
|
||||
Value view =
|
||||
Value ret =
|
||||
this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy);
|
||||
rewriter.replaceOp(op, view);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ExpandDimsOpConversion
|
||||
: public ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp> {
|
||||
using OpAdaptor = typename ExpandDimsOp::Adaptor;
|
||||
explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
PatternBenefit benefit = 1)
|
||||
: ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp>(typeConverter, benefit) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Location loc = op->getLoc();
|
||||
auto srcVals = this->getTypeConverter()->unpackLLElements(
|
||||
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
|
||||
|
||||
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
|
||||
auto resultTy = op.getType().template cast<RankedTensorType>();
|
||||
|
||||
assert(srcTy.getEncoding().isa<SliceEncodingAttr>() &&
|
||||
"ExpandDimsOp only support SliceEncodingAttr");
|
||||
auto srcLayout = srcTy.getEncoding().dyn_cast<SliceEncodingAttr>();
|
||||
auto resultLayout = resultTy.getEncoding();
|
||||
|
||||
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
|
||||
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
|
||||
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
|
||||
for (size_t i = 0; i < srcOffsets.size(); i++) {
|
||||
srcValues[srcOffsets[i]] = srcVals[i];
|
||||
}
|
||||
|
||||
SmallVector<Value> resultVals;
|
||||
for (size_t i = 0; i < resultOffsets.size(); i++) {
|
||||
auto offset = resultOffsets[i];
|
||||
offset.erase(offset.begin() + srcLayout.getDim());
|
||||
resultVals.push_back(srcValues.lookup(offset));
|
||||
}
|
||||
Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
|
||||
rewriter, resultTy);
|
||||
rewriter.replaceOp(op, ret);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -161,13 +202,10 @@ struct TransOpConversion
|
||||
};
|
||||
|
||||
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
RewritePatternSet &patterns,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
|
||||
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
|
||||
benefit);
|
||||
patterns.add<ViewOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
|
||||
patterns.add<SplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
|
||||
patterns.add<CatOpConversion>(typeConverter, benefit);
|
||||
|
||||
@@ -7,9 +7,7 @@ using namespace mlir;
|
||||
using namespace mlir::triton;
|
||||
|
||||
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns, int numWarps,
|
||||
AxisInfoAnalysis &axisInfoAnalysis,
|
||||
const Allocation *allocation, Value smem,
|
||||
RewritePatternSet &patterns,
|
||||
PatternBenefit benefit);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -14,6 +14,7 @@ add_mlir_conversion_library(TritonToTritonGPU
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRPass
|
||||
MLIRTransforms
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
TritonGPUTransforms
|
||||
|
||||
@@ -67,19 +67,21 @@ public:
|
||||
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type retType = getTypeConverter()->convertType(op.getType());
|
||||
auto retShapedType = retType.cast<ShapedType>();
|
||||
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
|
||||
if (dyn_cast<RankedTensorType>(retType)) {
|
||||
if (dyn_cast<RankedTensorType>(retShapedType)) {
|
||||
assert(value);
|
||||
if (value.getElementType().isInteger(1) && value.isSplat())
|
||||
// Workaround until https://reviews.llvm.org/D133743 is included.
|
||||
value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
|
||||
value =
|
||||
DenseElementsAttr::get(retShapedType, value.getSplatValue<bool>());
|
||||
else
|
||||
// This is a hack. We just want to add encoding
|
||||
value = value.reshape(retType);
|
||||
value = value.reshape(retShapedType);
|
||||
}
|
||||
addNamedAttrs(
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
|
||||
adaptor.getAttributes());
|
||||
addNamedAttrs(rewriter.replaceOpWithNewOp<arith::ConstantOp>(
|
||||
op, retShapedType, value),
|
||||
adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
@@ -165,8 +167,6 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
MLIRContext *context = patterns.getContext();
|
||||
// Rewrite rule
|
||||
patterns.add<StdSelectPattern>(typeConverter, context);
|
||||
target.addLegalOp<triton::ReturnOp>(); // this is ok because all functions are
|
||||
// inlined by the frontend
|
||||
}
|
||||
|
||||
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
|
||||
@@ -274,6 +274,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
// a & b must be of smem layout
|
||||
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
|
||||
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
|
||||
Type aEltType = aType.getElementType();
|
||||
Type bEltType = bType.getElementType();
|
||||
Attribute aEncoding = aType.getEncoding();
|
||||
Attribute bEncoding = bType.getEncoding();
|
||||
if (!aEncoding || !bEncoding)
|
||||
@@ -282,17 +284,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
|
||||
Value b = adaptor.getB();
|
||||
Value c = adaptor.getC();
|
||||
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
|
||||
auto dstType = RankedTensorType::get(aType.getShape(),
|
||||
aType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
getContext(), 0, dEncoding, aEltType);
|
||||
auto dstType =
|
||||
RankedTensorType::get(aType.getShape(), aEltType, encoding);
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
|
||||
}
|
||||
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
|
||||
Attribute encoding =
|
||||
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
|
||||
auto dstType = RankedTensorType::get(bType.getShape(),
|
||||
bType.getElementType(), encoding);
|
||||
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
getContext(), 1, dEncoding, bEltType);
|
||||
auto dstType =
|
||||
RankedTensorType::get(bType.getShape(), bEltType, encoding);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
|
||||
}
|
||||
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
|
||||
@@ -533,6 +535,52 @@ struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
|
||||
}
|
||||
};
|
||||
|
||||
class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
|
||||
public:
|
||||
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
|
||||
op, op.getName(), op.getFunctionType());
|
||||
addNamedAttrs(newOp, adaptor.getAttributes());
|
||||
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
|
||||
newOp.getBody().end());
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TritonCallOpPattern : public OpConversionPattern<triton::CallOp> {
|
||||
public:
|
||||
using OpConversionPattern<triton::CallOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::CallOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::CallOp>(
|
||||
op, op.getCallee(), op.getResultTypes(), adaptor.getOperands());
|
||||
addNamedAttrs(newOp, adaptor.getAttributes());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
class TritonReturnOpPattern : public OpConversionPattern<ReturnOp> {
|
||||
public:
|
||||
using OpConversionPattern<ReturnOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
@@ -550,7 +598,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
TritonLoadPattern, TritonStorePattern,
|
||||
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
|
||||
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
|
||||
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(
|
||||
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
|
||||
TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
@@ -752,31 +801,10 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class FuncOpPattern : public OpConversionPattern<triton::FuncOp> {
|
||||
public:
|
||||
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto converter = getTypeConverter();
|
||||
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
|
||||
op, op.getName(), op.getFunctionType());
|
||||
addNamedAttrs(newOp, adaptor.getAttributes());
|
||||
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
|
||||
newOp.getBody().end());
|
||||
if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
|
||||
return failure();
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
void populateCFPatterns(TritonGPUTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
MLIRContext *context = patterns.getContext();
|
||||
patterns.add<FuncOpPattern, CFCondBranchPattern, CFBranchPattern>(
|
||||
typeConverter, context);
|
||||
patterns.add<CFCondBranchPattern, CFBranchPattern>(typeConverter, context);
|
||||
}
|
||||
//
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
add_mlir_dialect_library(TritonIR
|
||||
Interfaces.cpp
|
||||
Dialect.cpp
|
||||
Ops.cpp
|
||||
Types.cpp
|
||||
@@ -11,5 +10,6 @@ add_mlir_dialect_library(TritonIR
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRIR
|
||||
MLIRArithDialect
|
||||
MLIRMathDialect
|
||||
MLIRSCFDialect
|
||||
)
|
||||
|
||||
@@ -22,14 +22,22 @@ using namespace mlir::triton;
|
||||
namespace {
|
||||
struct TritonInlinerInterface : public DialectInlinerInterface {
|
||||
using DialectInlinerInterface::DialectInlinerInterface;
|
||||
|
||||
bool isLegalToInline(Operation *call, Operation *callable,
|
||||
bool wouldBeCloned) const final {
|
||||
auto funcOp = dyn_cast<triton::FuncOp>(callable);
|
||||
if (!funcOp)
|
||||
return true;
|
||||
if (funcOp->hasAttr("noinline"))
|
||||
return !funcOp->getAttrOfType<BoolAttr>("noinline").getValue();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
|
||||
IRMapping &valueMapping) const final {
|
||||
return true;
|
||||
}
|
||||
|
||||
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
|
||||
IRMapping &) const final {
|
||||
return true;
|
||||
@@ -83,5 +91,5 @@ void TritonDialect::initialize() {
|
||||
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
|
||||
Attribute value, Type type,
|
||||
Location loc) {
|
||||
return builder.create<arith::ConstantOp>(loc, type, value);
|
||||
return arith::ConstantOp::materialize(builder, value, type, loc);
|
||||
}
|
||||
|
||||
@@ -230,6 +230,54 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
state.addTypes({resultType});
|
||||
}
|
||||
|
||||
// load(ptr, splat(1), ...) -> load(ptr, ...)
|
||||
// load(ptr, splat(0), other, ...) -> other
|
||||
struct CanonicalizeMaskedLoadPattern
|
||||
: public mlir::OpRewritePattern<triton::LoadOp> {
|
||||
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::LoadOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::LoadOp loadOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = loadOp.getMask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
|
||||
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
|
||||
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
|
||||
// If there's no "other", the value is "undef". Perhaps we want to
// optimize it in the future.
|
||||
auto otherVal = loadOp.getOther();
|
||||
if (!otherVal)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOp(loadOp, otherVal);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedLoadPattern>(context);
|
||||
}
|
||||
|
||||
//-- StoreOp --
|
||||
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
::mlir::Value ptr, ::mlir::Value value,
|
||||
@@ -257,10 +305,51 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
|
||||
evict);
|
||||
}
|
||||
|
||||
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
|
||||
// store(ptr, value, splat(0), ...) -> [none]
|
||||
struct CanonicalizeMaskedStorePattern
|
||||
: public mlir::OpRewritePattern<triton::StoreOp> {
|
||||
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::StoreOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::StoreOp storeOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = storeOp.getMask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(),
|
||||
storeOp.getEvict());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
rewriter.eraseOp(storeOp);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedStorePattern>(context);
|
||||
}
|
||||
|
||||
//-- TransOp --
|
||||
mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
// type is the same as the input
|
||||
auto argTy = operands[0].getType().cast<RankedTensorType>();
|
||||
@@ -287,7 +376,7 @@ mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
|
||||
//-- DotOp --
|
||||
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
// type is the same as the accumulator
|
||||
auto accTy = operands[2].getType().cast<RankedTensorType>();
|
||||
@@ -355,7 +444,7 @@ void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
|
||||
|
||||
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
for (auto arg : operands) {
|
||||
auto argTy = arg.getType().cast<RankedTensorType>();
|
||||
@@ -462,7 +551,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
|
||||
//-- ExpandDimsOp --
|
||||
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
|
||||
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
// infer shape
|
||||
auto arg = operands[0];
|
||||
|
||||
@@ -9,4 +9,9 @@ add_mlir_dialect_library(TritonTransforms
|
||||
DEPENDS
|
||||
TritonTransformsIncGen
|
||||
TritonCombineIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRPass
|
||||
MLIRTransformUtils
|
||||
TritonIR
|
||||
)
|
||||
|
||||
@@ -37,7 +37,7 @@ bool isBroadcastConstantCombinable(Attribute value) {
|
||||
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
|
||||
Value bcast_res) {
|
||||
|
||||
Type resType = bcast_res.getType();
|
||||
auto resType = bcast_res.getType().cast<ShapedType>();
|
||||
DenseElementsAttr res;
|
||||
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
|
||||
res =
|
||||
@@ -101,95 +101,6 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
// load(ptr, splat(1), ...) -> load(ptr, ...)
|
||||
// load(ptr, splat(0), other, ...) -> other
|
||||
struct CanonicalizeMaskedLoadPattern
|
||||
: public mlir::OpRewritePattern<triton::LoadOp> {
|
||||
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::LoadOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::LoadOp loadOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = loadOp.getMask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::LoadOp>(
|
||||
loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
|
||||
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
|
||||
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
|
||||
// If there's no "other", the value is "undef". Perhaps we want to
// optimize it in the future.
|
||||
auto otherVal = loadOp.getOther();
|
||||
if (!otherVal)
|
||||
return mlir::failure();
|
||||
rewriter.replaceOp(loadOp, otherVal);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedLoadPattern>(context);
|
||||
}
|
||||
|
||||
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
|
||||
// store(ptr, value, splat(0), ...) -> [none]
|
||||
struct CanonicalizeMaskedStorePattern
|
||||
: public mlir::OpRewritePattern<triton::StoreOp> {
|
||||
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
|
||||
: OpRewritePattern<triton::StoreOp>(context, 1) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(triton::StoreOp storeOp,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto mask = storeOp.getMask();
|
||||
if (!mask)
|
||||
return mlir::failure();
|
||||
|
||||
auto constantMask =
|
||||
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
|
||||
if (!constantMask)
|
||||
return mlir::failure();
|
||||
|
||||
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!splatMask)
|
||||
return mlir::failure();
|
||||
|
||||
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
|
||||
// mask = splat(1)
|
||||
rewriter.replaceOpWithNewOp<triton::StoreOp>(
|
||||
storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(),
|
||||
storeOp.getEvict());
|
||||
} else {
|
||||
// mask = splat(0)
|
||||
rewriter.eraseOp(storeOp);
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
|
||||
MLIRContext *context) {
|
||||
results.add<CanonicalizeMaskedStorePattern>(context);
|
||||
}
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"
|
||||
|
||||
|
||||
@@ -167,10 +167,10 @@ public:
|
||||
auto otherTensorType = RankedTensorType::get(tensorShape, elementType);
|
||||
|
||||
// Set zero padding value
|
||||
Attribute attr =
|
||||
TypedAttr attr =
|
||||
elementType.isIntOrIndex()
|
||||
? builder.getIntegerAttr(elementType, 0).cast<Attribute>()
|
||||
: builder.getFloatAttr(elementType, 0).cast<Attribute>();
|
||||
? builder.getIntegerAttr(elementType, 0).cast<TypedAttr>()
|
||||
: builder.getFloatAttr(elementType, 0).cast<TypedAttr>();
|
||||
|
||||
// Float NaN padding case
|
||||
if (padding.value() == triton::PaddingOption::PAD_NAN) {
|
||||
|
||||
@@ -7,5 +7,6 @@ add_mlir_dialect_library(TritonGPUIR
|
||||
TritonGPUAttrDefsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRGPUOps
|
||||
TritonIR
|
||||
)
|
||||
|
||||
@@ -81,10 +81,41 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
|
||||
if (mmaLayout.isAmpere())
|
||||
return {8, 4};
|
||||
}
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parent = sliceLayout.getParent();
|
||||
auto parentThreadsPerWarp = getThreadsPerWarp(parent);
|
||||
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
|
||||
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
|
||||
for (unsigned i = 0; i < threadsPerWarp.size(); i++)
|
||||
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
|
||||
return threadsPerWarp;
|
||||
}
|
||||
assert(0 && "getThreadsPerWarp not implemented");
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getThreadsPerWarpWithUniqueData(Attribute layout,
|
||||
ArrayRef<int64_t> tensorShape) {
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
auto parentShape = sliceLayout.paddedShape(tensorShape);
|
||||
auto parentThreadsPerWarp =
|
||||
getThreadsPerWarpWithUniqueData(parentLayout, parentShape);
|
||||
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
|
||||
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
|
||||
return threadsPerWarp;
|
||||
}
|
||||
auto threadsPerWarp = getThreadsPerWarp(layout);
|
||||
assert(threadsPerWarp.size() == tensorShape.size() &&
|
||||
"layout and tensor shape must have the same rank");
|
||||
for (unsigned i = 0; i < threadsPerWarp.size(); i++) {
|
||||
threadsPerWarp[i] = std::min<unsigned>(threadsPerWarp[i], tensorShape[i]);
|
||||
}
|
||||
|
||||
return threadsPerWarp;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
|
||||
@@ -94,19 +125,51 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
|
||||
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
|
||||
mmaLayout.getWarpsPerCTA().end());
|
||||
}
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parent = sliceLayout.getParent();
|
||||
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
|
||||
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
|
||||
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
|
||||
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
|
||||
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
|
||||
return warpsPerCTA;
|
||||
}
|
||||
assert(0 && "getWarpsPerCTA not implemented");
|
||||
return {};
|
||||
}
|
||||
|
||||
SmallVector<unsigned>
|
||||
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape) {
|
||||
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
auto parentShape = sliceLayout.paddedShape(tensorShape);
|
||||
auto parentWarpsPerCTA =
|
||||
getWarpsPerCTAWithUniqueData(parentLayout, parentShape);
|
||||
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
|
||||
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
|
||||
return warpsPerCTA;
|
||||
}
|
||||
auto warpsPerCTA = getWarpsPerCTA(layout);
|
||||
assert(warpsPerCTA.size() == tensorShape.size() &&
|
||||
"layout and tensor shape must have the same rank");
|
||||
for (unsigned i = 0; i < warpsPerCTA.size(); i++) {
|
||||
auto sizePerWarp =
|
||||
getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i];
|
||||
auto maxWarpsPerDim = ceil<unsigned>(tensorShape[i], sizePerWarp);
|
||||
warpsPerCTA[i] = std::min<unsigned>(warpsPerCTA[i], maxWarpsPerDim);
|
||||
}
|
||||
|
||||
return warpsPerCTA;
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getSizePerThread(Attribute layout) {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
|
||||
blockedLayout.getSizePerThread().end());
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto ret = getSizePerThread(sliceLayout.getParent());
|
||||
return ret;
|
||||
// ret.erase(ret.begin() + sliceLayout.getDim());
|
||||
return ret;
|
||||
auto sizePerThread = getSizePerThread(sliceLayout.getParent());
|
||||
sizePerThread.erase(sizePerThread.begin() + sliceLayout.getDim());
|
||||
return sizePerThread;
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
return {2, 2};
|
||||
@@ -146,11 +209,43 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
|
||||
return {1, 2};
|
||||
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
auto parentLayout = sliceLayout.getParent();
|
||||
return getContigPerThread(parentLayout);
|
||||
} else {
|
||||
return getSizePerThread(layout);
|
||||
}
|
||||
}
|
||||
|
||||
SmallVector<unsigned> getUniqueContigPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
return SmallVector<unsigned>(1, 1);
auto tensorType = type.cast<RankedTensorType>();
auto shape = tensorType.getShape();
// If slice layout, call recursively on parent layout, and drop
// sliced dim
if (auto sliceLayout =
tensorType.getEncoding().dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(
parentShape, tensorType.getElementType(), parentLayout);
auto parentUniqueContigPerThread = getUniqueContigPerThread(parentTy);
parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() +
sliceLayout.getDim());
return parentUniqueContigPerThread;
}
// Base case
auto rank = shape.size();
SmallVector<unsigned> ret(rank);
auto contigPerThread = getContigPerThread(tensorType.getEncoding());
assert(contigPerThread.size() == rank && "Unexpected contigPerThread size");
for (int d = 0; d < rank; ++d) {
ret[d] = std::min<unsigned>(shape[d], contigPerThread[d]);
}
return ret;
}
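A brief worked example of the base case above (editor-added illustration; the values are hypothetical):

// getUniqueContigPerThread on a rank-2 tensor with no slice layout involved:
//   shape           = {16, 1}
//   contigPerThread = {1, 2}
//   result          = {min(16, 1), min(1, 2)} = {1, 1}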
|
||||
|
||||
SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
SmallVector<unsigned> threads;
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
@@ -158,7 +253,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
|
||||
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
|
||||
blockedLayout.getWarpsPerCTA()[d]);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.getVersionMajor() == 2) {
|
||||
if (mmaLayout.isAmpere()) {
|
||||
threads = {8 * mmaLayout.getWarpsPerCTA()[0],
|
||||
4 * mmaLayout.getWarpsPerCTA()[1]};
|
||||
} else
|
||||
@@ -261,6 +356,16 @@ bool isaDistributedLayout(Attribute layout) {
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace triton
|
||||
|
||||
bool isSharedEncoding(Value value) {
|
||||
auto type = value.getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
auto encoding = tensorType.getEncoding();
|
||||
return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace mlir
|
||||
|
||||
static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
|
||||
@@ -375,6 +480,7 @@ SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
|
||||
auto parent = getParent();
|
||||
auto parentElemsPerThread =
|
||||
::getElemsPerThread(parent, paddedShape(shape), eltTy);
|
||||
parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim());
|
||||
return parentElemsPerThread;
|
||||
}
|
||||
unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
|
||||
@@ -774,14 +880,27 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
|
||||
return {};
|
||||
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
|
||||
Attribute parent = attrs.get("parent");
|
||||
auto mmaParent = parent.dyn_cast<MmaEncodingAttr>();
|
||||
unsigned kWidth = 0;
|
||||
Attribute _kWidth = attrs.get("kWidth");
|
||||
if (_kWidth) {
|
||||
if (!mmaParent || mmaParent.isVolta()) {
|
||||
auto loc = parser.getNameLoc();
|
||||
parser.emitError(loc, "kWidth only supported for MMAv2+ parent");
|
||||
return Attribute();
|
||||
}
|
||||
kWidth = _kWidth.cast<IntegerAttr>().getInt();
|
||||
}
|
||||
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
|
||||
parent);
|
||||
parent, kWidth);
|
||||
}
|
||||
|
||||
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
|
||||
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
|
||||
printer << "<{"
|
||||
<< "opIdx = " << getOpIdx() << ", "
|
||||
<< "parent = " << getParent();
|
||||
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
|
||||
if (mmaParent && mmaParent.isAmpere())
|
||||
printer << ", kWidth = " << getMMAv2kWidth();
|
||||
printer << "}>";
|
||||
}
|
||||
|
||||
@@ -1029,9 +1148,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
return mlir::failure();
|
||||
}
|
||||
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
// Ensure that the new insert_slice op is placed in the same place as the
|
||||
// old insert_slice op. Otherwise, the new insert_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
// Ensure that the new insert_slice op is placed in the same place as
|
||||
// the old insert_slice op. Otherwise, the new insert_slice op may be
|
||||
// placed after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(insert_slice);
|
||||
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -1059,9 +1178,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
auto resType = RankedTensorType::get(
|
||||
origResType.getShape(), origResType.getElementType(),
|
||||
extract_slice.getType().cast<RankedTensorType>().getEncoding());
|
||||
// Ensure that the new extract_slice op is placed in the same place as the
|
||||
// old extract_slice op. Otherwise, the new extract_slice op may be placed
|
||||
// after the async_wait op, which is not allowed.
|
||||
// Ensure that the new extract_slice op is placed in the same place as
|
||||
// the old extract_slice op. Otherwise, the new extract_slice op may be
|
||||
// placed after the async_wait op, which is not allowed.
|
||||
OpBuilder::InsertionGuard guard(rewriter);
|
||||
rewriter.setInsertionPoint(extract_slice);
|
||||
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
|
||||
@@ -1109,8 +1228,8 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
|
||||
// cvt(type, constant) -> constant
|
||||
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
|
||||
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
|
||||
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
|
||||
ret.getSplatValue<Attribute>());
|
||||
auto ty = op->getResultTypes().front().cast<ShapedType>();
|
||||
auto newRet = SplatElementsAttr::get(ty, ret.getSplatValue<Attribute>());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "triton/Dialect/TritonGPU/IR/Traits.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
|
||||
mlir::LogicalResult
|
||||
mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {
|
||||
|
||||
@@ -47,12 +47,13 @@ SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
|
||||
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
|
||||
const ArrayRef<int64_t> shape,
|
||||
int numWarps) {
|
||||
SetVector<Operation *> slices;
|
||||
mlir::getForwardSlice(dotOp.getResult(), &slices);
|
||||
if (llvm::find_if(slices, [](Operation *op) {
|
||||
return isa<triton::DotOp>(op);
|
||||
}) != slices.end())
|
||||
return {(unsigned)numWarps, 1};
|
||||
auto filter = [&dotOp](Operation *op) {
|
||||
return op->getParentRegion() == dotOp->getParentRegion();
|
||||
};
|
||||
auto slices = mlir::getSlice(dotOp, filter);
|
||||
for (Operation *op : slices)
|
||||
if (isa<triton::DotOp>(op) && (op != dotOp))
|
||||
return {(unsigned)numWarps, 1};
|
||||
|
||||
SmallVector<unsigned, 2> ret = {1, 1};
|
||||
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
|
||||
@@ -173,14 +174,17 @@ public:
|
||||
.cast<triton::gpu::BlockedEncodingAttr>()
|
||||
.getOrder();
|
||||
|
||||
auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldAType.getContext(), 0, newRetType.getEncoding(),
|
||||
oldAType.getElementType());
|
||||
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
|
||||
oldBType.getContext(), 1, newRetType.getEncoding(),
|
||||
oldBType.getElementType());
|
||||
|
||||
auto newAType = RankedTensorType::get(
|
||||
oldAType.getShape(), oldAType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
|
||||
newRetType.getEncoding()));
|
||||
oldAType.getShape(), oldAType.getElementType(), newAEncoding);
|
||||
auto newBType = RankedTensorType::get(
|
||||
oldBType.getShape(), oldBType.getElementType(),
|
||||
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
|
||||
newRetType.getEncoding()));
|
||||
oldBType.getShape(), oldBType.getElementType(), newBEncoding);
|
||||
|
||||
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
|
||||
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);
|
||||
|
||||
@@ -14,7 +14,9 @@ add_mlir_dialect_library(TritonGPUTransforms
|
||||
TritonGPUTransformsIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRTransforms
|
||||
MLIRTransformUtils
|
||||
TritonAnalysis
|
||||
TritonIR
|
||||
TritonGPUIR
|
||||
MLIRTransformUtils
|
||||
)
|
||||
|
||||
@@ -22,16 +22,13 @@ template <class T> SmallVector<unsigned, 4> argSort(const T &arr) {
|
||||
typedef DenseMap<Value, std::function<Type(Type)>> LayoutMap;
|
||||
|
||||
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
|
||||
int numWarps) {
|
||||
Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis,
|
||||
Value ptr, int numWarps) {
|
||||
auto origType = ptr.getType().cast<RankedTensorType>();
|
||||
// Get the shape of the tensor.
|
||||
size_t rank = origType.getRank();
|
||||
dataflow::Lattice<AxisInfo> *latticeElement =
|
||||
axisInfo.getLatticeElement(ptr);
|
||||
AxisInfo info = latticeElement ? latticeElement->getValue() : AxisInfo();
|
||||
// Get the contiguity order of `ptr`
|
||||
auto order = argSort(info.getContiguity());
|
||||
auto order = argSort(axisInfoAnalysis.getAxisInfo(ptr)->getContiguity());
|
||||
// The desired divisibility is the maximum divisibility
|
||||
// among all dependent pointers who have the same order as
|
||||
// `ptr`
|
||||
@@ -42,8 +39,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
for (Value val : op->getResults()) {
|
||||
if (val.getType() != origType)
|
||||
continue;
|
||||
auto valInfo = axisInfo.getLatticeElement(val);
|
||||
auto currOrder = argSort(valInfo->getValue().getContiguity());
|
||||
auto currOrder =
|
||||
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
|
||||
if (order == currOrder)
|
||||
withSameOrder.insert(val);
|
||||
}
|
||||
@@ -61,10 +58,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
|
||||
unsigned perThread = 1;
|
||||
for (Value val : withSameOrder) {
|
||||
AxisInfo info = axisInfo.getLatticeElement(val)->getValue();
|
||||
unsigned maxMultipleBytes = info.getDivisibility(order[0]);
|
||||
unsigned maxMultipleBytes =
|
||||
axisInfoAnalysis.getAxisInfo(val)->getDivisibility(order[0]);
|
||||
unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
|
||||
unsigned maxContig = info.getContiguity(order[0]);
|
||||
unsigned maxContig =
|
||||
axisInfoAnalysis.getAxisInfo(val)->getContiguity(order[0]);
|
||||
unsigned alignment = std::min(maxMultiple, maxContig);
|
||||
unsigned currPerThread = std::min(alignment, 128 / elemNumBits);
|
||||
perThread = std::max(perThread, currPerThread);
|
||||
@@ -78,9 +76,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
return encoding;
|
||||
}
|
||||
|
||||
std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
|
||||
Value ptr, int numWarps) {
|
||||
Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
|
||||
std::function<Type(Type)>
|
||||
getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr,
|
||||
int numWarps) {
|
||||
Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps);
|
||||
return [encoding](Type _type) {
|
||||
RankedTensorType type = _type.cast<RankedTensorType>();
|
||||
return RankedTensorType::get(type.getShape(), type.getElementType(),
|
||||
@@ -127,17 +126,14 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
}
|
||||
|
||||
void runOnOperation() override {
|
||||
Operation *op = getOperation();
|
||||
// Run axis info analysis
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfo = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(op)))
|
||||
return signalPassFailure();
|
||||
ModuleOp moduleOp = getOperation();
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
|
||||
|
||||
// For each i/o operation, we determine what layout
|
||||
// the pointers should have for best memory coalescing
|
||||
LayoutMap layoutMap;
|
||||
op->walk([&](Operation *curr) {
|
||||
moduleOp.walk([&](Operation *curr) {
|
||||
Value ptr;
|
||||
if (auto op = dyn_cast<triton::LoadOp>(curr))
|
||||
ptr = op.getPtr();
|
||||
@@ -154,10 +150,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
|
||||
if (!ty || !ty.getElementType().isa<PointerType>())
|
||||
return;
|
||||
AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue();
|
||||
auto mod = curr->getParentOfType<ModuleOp>();
|
||||
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
|
||||
auto convertType = getTypeConverter(*axisInfo, ptr, numWarps);
|
||||
auto convertType = getTypeConverter(axisInfoAnalysis, ptr, numWarps);
|
||||
layoutMap[ptr] = convertType;
|
||||
});
|
||||
|
||||
@@ -168,7 +163,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
|
||||
// produces a tensor with layout L2
|
||||
// 4. Convert the output of this new memory op back to L1
|
||||
// 5. Replace all the uses of the original memory op by the new one
|
||||
op->walk([&](Operation *curr) {
|
||||
moduleOp.walk([&](Operation *curr) {
|
||||
OpBuilder builder(curr);
|
||||
if (auto load = dyn_cast<triton::LoadOp>(curr)) {
|
||||
coalesceOp<triton::LoadOp>(layoutMap, curr, load.getPtr(), builder);
|
||||
|
||||
@@ -72,6 +72,76 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
//
|
||||
|
||||
class MoveOpAfterLayoutConversion : public mlir::RewritePattern {
|
||||
|
||||
public:
|
||||
MoveOpAfterLayoutConversion(mlir::MLIRContext *context)
|
||||
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
|
||||
1, context) {}
|
||||
|
||||
mlir::LogicalResult
|
||||
matchAndRewrite(mlir::Operation *op,
|
||||
mlir::PatternRewriter &rewriter) const override {
|
||||
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
|
||||
auto srcTy = cvt.getOperand().getType().cast<RankedTensorType>();
|
||||
auto retTy = cvt.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
auto retEncoding =
|
||||
retTy.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto srcEncoding =
|
||||
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
|
||||
if (!retTy)
|
||||
return failure();
|
||||
if (!retEncoding)
|
||||
return failure();
|
||||
auto retEncodingParent =
|
||||
retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
|
||||
if (!retEncodingParent || retEncodingParent.isVolta())
|
||||
return failure();
|
||||
if (!srcEncoding)
|
||||
return failure();
|
||||
// don't move things around when cvt operand is a block arg
|
||||
Operation *argOp = cvt.getOperand().getDefiningOp();
|
||||
if (!argOp)
|
||||
return failure();
|
||||
//
|
||||
SetVector<Operation *> processed;
|
||||
SetVector<Attribute> layout;
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
int numCvts = simulateBackwardRematerialization(cvt, processed, layout,
|
||||
toConvert, retEncoding);
|
||||
if (numCvts > 1 || toConvert.size() == 1)
|
||||
return failure();
|
||||
for (Operation *op : processed) {
|
||||
if (op->getNumOperands() != 1)
|
||||
continue;
|
||||
auto srcTy = op->getOperand(0).getType().cast<RankedTensorType>();
|
||||
auto dstTy = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
// we don't want to push conversions backward if there is a downcast
|
||||
// since it would result in more shared memory traffic
|
||||
if (srcTy.getElementType().getIntOrFloatBitWidth() >
|
||||
dstTy.getElementType().getIntOrFloatBitWidth())
|
||||
return failure();
|
||||
// we only push back when the first op in the chain has a load operand
|
||||
if ((op == processed.back()) &&
|
||||
!isa<triton::LoadOp>(op->getOperand(0).getDefiningOp()))
|
||||
return failure();
|
||||
// we don't want to use ldmatrix for 8-bit data that requires trans
|
||||
// since Nvidia GPUs can't do it efficiently
|
||||
bool isTrans =
|
||||
(retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0);
|
||||
bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8;
|
||||
if (isTrans && isInt8)
|
||||
return failure();
|
||||
}
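As a quick illustration of the transpose check above (editor-added; the operand/order combinations are hypothetical):

// isTrans = (retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0)
//   opIdx = 0, order = {1, 0}  ->  false ^ false = false
//   opIdx = 1, order = {1, 0}  ->  true  ^ false = true   (rejected above when the data is also int8)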
|
||||
IRMapping mapping;
|
||||
rematerializeConversionChain(toConvert, rewriter, processed, mapping);
|
||||
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
|
||||
return mlir::success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
#define GEN_PASS_CLASSES
|
||||
@@ -93,6 +163,7 @@ public:
|
||||
|
||||
mlir::RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertTransConvert>(context);
|
||||
patterns.add<MoveOpAfterLayoutConversion>(context);
|
||||
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
|
||||
signalPassFailure();
|
||||
if (fixupLoops(m).failed())
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#include "Utility.h"
|
||||
#include "mlir/Dialect/Tensor/IR/Tensor.h"
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "mlir/IR/TypeUtilities.h"
|
||||
@@ -42,6 +43,9 @@ class LoopPipeliner {
|
||||
|
||||
/// Loads to be pipelined
|
||||
SetVector<Value> loads;
|
||||
/// Smallest data-type for each load (used to optimize swizzle and
/// create DotOpEncoding layout)
DenseMap<Value, Type> loadsSmallestType;
|
||||
/// The value that each load will be mapped to (after layout conversion)
|
||||
DenseMap<Value, Value> loadsMapping;
|
||||
/// load => buffer
|
||||
@@ -181,24 +185,19 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
|
||||
/// this pass?)
|
||||
LogicalResult LoopPipeliner::initialize() {
|
||||
Block *loop = forOp.getBody();
|
||||
|
||||
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
|
||||
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
|
||||
if (failed(solver->initializeAndRun(forOp->getParentOfType<ModuleOp>()))) {
|
||||
return failure();
|
||||
}
|
||||
ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
|
||||
ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
|
||||
|
||||
// can we use forOp.walk(...) here?
|
||||
SmallVector<triton::LoadOp, 2> validLoads;
|
||||
for (Operation &op : *loop)
|
||||
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
|
||||
auto ptr = loadOp.getPtr();
|
||||
unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr);
|
||||
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
|
||||
|
||||
if (auto mask = loadOp.getMask())
|
||||
vec = std::min<unsigned>(vec, axisInfoAnalysis->getMaskAlignment(mask));
|
||||
vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
|
||||
|
||||
auto lattice = axisInfoAnalysis->getLatticeElement(ptr)->getValue();
|
||||
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorTy || tensorTy.getRank() < 2)
|
||||
continue;
|
||||
@@ -256,33 +255,62 @@ LogicalResult LoopPipeliner::initialize() {
|
||||
use = *use->getResult(0).getUsers().begin();
|
||||
}
|
||||
|
||||
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
|
||||
if (auto tensorType = convertLayout.getResult()
|
||||
.getType()
|
||||
.dyn_cast<RankedTensorType>()) {
|
||||
if (auto dotOpEnc = tensorType.getEncoding()
|
||||
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
|
||||
isCandidate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
auto sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
|
||||
loadsBufferType[loadOp] = RankedTensorType::get(
|
||||
bufferShape, ty.getElementType(), sharedEnc);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else
|
||||
auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use);
|
||||
if (!convertLayout)
|
||||
continue;
|
||||
auto tensorType =
|
||||
convertLayout.getResult().getType().dyn_cast<RankedTensorType>();
|
||||
if (!tensorType)
|
||||
continue;
|
||||
auto dotOpEnc =
|
||||
tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
|
||||
if (!dotOpEnc)
|
||||
continue;
|
||||
isCandidate = true;
|
||||
loadsMapping[loadOp] = convertLayout;
|
||||
}
|
||||
|
||||
else
|
||||
isCandidate = false;
|
||||
|
||||
if (isCandidate)
|
||||
loads.insert(loadOp);
|
||||
}
|
||||
|
||||
// we need to find the smallest common dtype
// since this determines the layout of `mma.sync` operands
// in mixed-precision mode
|
||||
Type smallestType;
|
||||
for (auto loadCvt : loadsMapping) {
|
||||
auto loadOp = loadCvt.first;
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
Type eltTy = ty.getElementType();
|
||||
if (!smallestType ||
|
||||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
|
||||
smallestType = eltTy;
|
||||
}
|
||||
|
||||
for (auto loadCvt : loadsMapping)
|
||||
loadsSmallestType[loadCvt.first] = smallestType;
|
||||
|
||||
for (auto loadCvt : loadsMapping) {
|
||||
auto loadOp = loadCvt.first;
|
||||
Value cvt = loadCvt.second;
|
||||
auto dotOpEnc = cvt.getType()
|
||||
.cast<RankedTensorType>()
|
||||
.getEncoding()
|
||||
.cast<ttg::DotOperandEncodingAttr>();
|
||||
auto ty = loadOp.getType().cast<RankedTensorType>();
|
||||
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
|
||||
ty.getShape().end());
|
||||
bufferShape.insert(bufferShape.begin(), numStages);
|
||||
auto sharedEnc = ttg::SharedEncodingAttr::get(
|
||||
ty.getContext(), dotOpEnc, ty.getShape(),
|
||||
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
|
||||
loadsBufferType[loadOp] =
|
||||
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
|
||||
}
|
||||
|
||||
// We have some loads to pipeline
|
||||
if (!loads.empty()) {
|
||||
// Update depArgs & depOps
|
||||
@@ -336,10 +364,6 @@ Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
|
||||
}
|
||||
|
||||
void LoopPipeliner::emitPrologue() {
|
||||
// llvm::errs() << "loads to pipeline...:\n";
|
||||
// for (Value load : loads)
|
||||
// llvm::errs() << load << "\n";
|
||||
|
||||
OpBuilder builder(forOp);
|
||||
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
|
||||
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
|
||||
@@ -364,7 +388,7 @@ void LoopPipeliner::emitPrologue() {
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
else if (loads.contains(op.getResult(0)))
|
||||
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
|
||||
orderedDeps.push_back(&op);
|
||||
}
|
||||
assert(depOps.size() + loads.size() == orderedDeps.size() &&
|
||||
@@ -541,44 +565,44 @@ scf::ForOp LoopPipeliner::createNewForOp() {
|
||||
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
|
||||
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
|
||||
|
||||
// 2.1 clone the loop body, replace original args with args of the new ForOp
|
||||
// 2. clone the loop body, replace original args with args of the new ForOp
|
||||
// Insert async wait if necessary.
|
||||
DenseSet<Value> isModified;
|
||||
for (Operation &op : forOp.getBody()->without_terminator()) {
|
||||
Operation *newOp = builder.clone(op, mapping);
|
||||
// update mapping of results
|
||||
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
|
||||
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
|
||||
}
|
||||
// is modified
|
||||
auto it = std::find(loads.begin(), loads.end(), op.getOperand(0));
|
||||
if (it == loads.end()) {
|
||||
Operation *newOp = cloneWithInferType(builder, &op, mapping);
|
||||
continue;
|
||||
}
|
||||
|
||||
// 3. replace loads with block args (from prologue)
|
||||
for (size_t idx = 0; idx < loads.size(); ++idx) {
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
Value load = loads[idx];
|
||||
assert(load.hasOneUse() &&
|
||||
"we assume that this load has one use (ConvertLayout)");
|
||||
Value loadUse = load.getUsers().begin()->getResult(0);
|
||||
// set insertion point
|
||||
Value newLoad = mapping.lookup(load);
|
||||
Value newLoadUse = mapping.lookup(loadUse);
|
||||
builder.setInsertionPoint(newLoadUse.getDefiningOp());
|
||||
// create conversion
|
||||
// we replace the use new load use with a convert layout
|
||||
size_t i = std::distance(loads.begin(), it);
|
||||
auto cvtDstTy = op.getResult(0).getType().cast<RankedTensorType>();
|
||||
auto cvtDstEnc =
|
||||
cvtDstTy.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
|
||||
if (!cvtDstEnc) {
|
||||
builder.clone(op, mapping);
|
||||
continue;
|
||||
}
|
||||
auto newDstTy = RankedTensorType::get(
|
||||
cvtDstTy.getShape(), cvtDstTy.getElementType(),
|
||||
ttg::DotOperandEncodingAttr::get(
|
||||
cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(),
|
||||
loadsSmallestType[op.getOperand(0)]));
|
||||
auto cvt = builder.create<ttg::ConvertLayoutOp>(
|
||||
loadUse.getLoc(), loadUse.getType(),
|
||||
newForOp.getRegionIterArgs()[loadIdx + idx]);
|
||||
|
||||
// replace uses
|
||||
newLoadUse.replaceAllUsesWith(cvt.getResult());
|
||||
// delete old load and layout conversion
|
||||
newLoadUse.getDefiningOp()->erase();
|
||||
newLoad.getDefiningOp()->erase();
|
||||
op.getResult(0).getLoc(), newDstTy,
|
||||
newForOp.getRegionIterArgs()[loadIdx + i]);
|
||||
mapping.map(op.getResult(0), cvt.getResult());
|
||||
isModified.insert(op.getResult(0));
|
||||
}
|
||||
|
||||
// 4. prefetch the next iteration
|
||||
// 3. prefetch the next iteration
|
||||
SmallVector<Operation *> orderedDeps;
|
||||
for (Operation &op : forOp.getLoopBody().front()) {
|
||||
if (depOps.contains(&op))
|
||||
orderedDeps.push_back(&op);
|
||||
else if (loads.contains(op.getResult(0)))
|
||||
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
|
||||
orderedDeps.push_back(&op);
|
||||
}
|
||||
assert(depOps.size() + loads.size() == orderedDeps.size() &&
|
||||
|
||||
@@ -27,7 +27,6 @@
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/IR/IRMapping.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
|
||||
@@ -45,7 +44,7 @@ class Prefetcher {
|
||||
scf::YieldOp yieldOp;
|
||||
///
|
||||
// TODO: add a hook to infer prefetchWidth
|
||||
unsigned prefetchWidth = 16;
|
||||
unsigned prefetchWidth = 32;
|
||||
|
||||
/// dots to be prefetched
|
||||
SetVector<Value> dots;
|
||||
@@ -56,6 +55,8 @@ class Prefetcher {
|
||||
DenseMap<Value, Value> dot2bHeaderDef;
|
||||
DenseMap<Value, Value> dot2aYield;
|
||||
DenseMap<Value, Value> dot2bYield;
|
||||
DenseMap<Value, SmallVector<Value>> dot2aVals;
|
||||
DenseMap<Value, SmallVector<Value>> dot2bVals;
|
||||
/// operand => defining
|
||||
DenseMap<Value, Value> operand2headPrefetch;
|
||||
|
||||
@@ -66,6 +67,9 @@ class Prefetcher {
|
||||
std::optional<int64_t> offsetK = std::nullopt,
|
||||
std::optional<int64_t> shapeK = std::nullopt);
|
||||
|
||||
void cloneElementwiseOps(Value &bRem, const SmallVector<Value> &vals,
|
||||
OpBuilder &builder);
|
||||
|
||||
public:
|
||||
Prefetcher() = delete;
|
||||
|
||||
@@ -80,6 +84,24 @@ public:
|
||||
scf::ForOp createNewForOp();
|
||||
};
|
||||
|
||||
void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector<Value> &vals,
|
||||
OpBuilder &builder) {
|
||||
IRMapping mapping;
|
||||
mapping.map(vals[0], ret);
|
||||
for (int i = 1; i < vals.size(); i++) {
|
||||
Value v = vals[i];
|
||||
Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0);
|
||||
auto retType = RankedTensorType::get(
|
||||
ret.getType().cast<RankedTensorType>().getShape(),
|
||||
curr.getType().cast<RankedTensorType>().getElementType(),
|
||||
curr.getType().cast<RankedTensorType>().getEncoding());
|
||||
curr.setType(retType);
|
||||
mapping.map(v, curr);
|
||||
}
|
||||
if (vals.size() > 1)
|
||||
ret = mapping.lookup(vals.back());
|
||||
}
|
||||
|
||||
Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
|
||||
Attribute dotEncoding, OpBuilder &builder,
|
||||
std::optional<int64_t> offsetK,
|
||||
@@ -110,7 +132,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
|
||||
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
|
||||
|
||||
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
|
||||
builder.getContext(), opIdx, dotEncoding);
|
||||
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
|
||||
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
|
||||
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
|
||||
newSmem);
|
||||
@@ -135,11 +157,32 @@ LogicalResult Prefetcher::initialize() {
|
||||
return failure();
|
||||
|
||||
// returns source of cvt
|
||||
auto getPrefetchSrc = [](Value v) -> Value {
|
||||
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
|
||||
if (isSharedEncoding(cvt.getOperand()))
|
||||
return cvt.getSrc();
|
||||
return Value();
|
||||
|
||||
// returns source of cvt
|
||||
auto getPrefetchSrc = [](Value v) -> SmallVector<Value> {
|
||||
// walk back to conversion
|
||||
Operation *op = v.getDefiningOp();
|
||||
bool foundConvertFromShared = false;
|
||||
SmallVector<Value> rets;
|
||||
rets.push_back(op->getResult(0));
|
||||
while (op) {
|
||||
if (op->getNumOperands() != 1)
|
||||
break;
|
||||
if (!op->getResult(0).hasOneUse())
|
||||
break;
|
||||
rets.push_back(op->getOperand(0));
|
||||
if (auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(op))
|
||||
if (isSharedEncoding(cvt.getOperand())) {
|
||||
foundConvertFromShared = true;
|
||||
break;
|
||||
}
|
||||
op = op->getOperand(0).getDefiningOp();
|
||||
}
|
||||
std::reverse(rets.begin(), rets.end());
|
||||
|
||||
if (foundConvertFromShared)
|
||||
return rets;
|
||||
return {};
|
||||
};
|
||||
|
||||
auto getIncomingOp = [this](Value v) -> Value {
|
||||
@@ -156,24 +199,39 @@ LogicalResult Prefetcher::initialize() {
|
||||
};
|
||||
|
||||
for (triton::DotOp dot : dotsInFor) {
|
||||
auto kSize = dot.getA().getType().cast<RankedTensorType>().getShape()[1];
|
||||
auto aType = dot.getA().getType().cast<RankedTensorType>();
|
||||
auto bType = dot.getB().getType().cast<RankedTensorType>();
|
||||
auto aEnc = aType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
auto bEnc = bType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
|
||||
int aKWidth = aEnc.getMMAv2kWidth();
|
||||
int bKWidth = bEnc.getMMAv2kWidth();
|
||||
assert(aKWidth == bKWidth);
|
||||
|
||||
auto kSize = aType.getShape()[1];

// works better with nvidia tensor cores
unsigned elementWidth =
dot.getA().getType().cast<RankedTensorType>().getElementTypeBitWidth();
prefetchWidth = 256 / elementWidth;
unsigned elementWidth = aType.getElementTypeBitWidth();
if (aKWidth == 0)
prefetchWidth = 256 / elementWidth;
else
prefetchWidth = 8 * aKWidth;
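For concreteness, the prefetch width chosen above works out as follows (editor-added illustration; the element types are hypothetical):

// fp16 operands (elementWidth = 16), kWidth not set:  prefetchWidth = 256 / 16 = 16
// fp16 operands with aKWidth = 4:                     prefetchWidth = 8 * 4 = 32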
|
||||
|
||||
// Skip prefetching if kSize is less than prefetchWidth
|
||||
if (kSize < prefetchWidth)
|
||||
continue;
|
||||
Value aSmem = getPrefetchSrc(dot.getA());
|
||||
Value bSmem = getPrefetchSrc(dot.getB());
|
||||
if (aSmem && bSmem) {
|
||||
auto aVals = getPrefetchSrc(dot.getA());
|
||||
auto bVals = getPrefetchSrc(dot.getB());
|
||||
|
||||
if (aVals.size() && bVals.size()) {
|
||||
Value aSmem = aVals.front();
|
||||
Value bSmem = bVals.front();
|
||||
Value aHeaderDef = getIncomingOp(aSmem);
|
||||
Value bHeaderDef = getIncomingOp(bSmem);
|
||||
// Only prefetch loop arg
|
||||
if (aHeaderDef && bHeaderDef) {
|
||||
dots.insert(dot);
|
||||
dot2aVals[dot] = aVals;
|
||||
dot2bVals[dot] = bVals;
|
||||
dot2aHeaderDef[dot] = aHeaderDef;
|
||||
dot2bHeaderDef[dot] = bHeaderDef;
|
||||
dot2aLoopArg[dot] = aSmem;
|
||||
@@ -195,10 +253,13 @@ void Prefetcher::emitPrologue() {
|
||||
dot.getType().cast<RankedTensorType>().getEncoding();
|
||||
Value aPrefetched =
|
||||
generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
|
||||
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getA()] =
|
||||
aPrefetched;
|
||||
cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder);
|
||||
Value bPrefetched =
|
||||
generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
|
||||
cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder);
|
||||
|
||||
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getA()] =
|
||||
aPrefetched;
|
||||
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getB()] =
|
||||
bPrefetched;
|
||||
}
|
||||
@@ -256,9 +317,11 @@ scf::ForOp Prefetcher::createNewForOp() {
|
||||
Value aRem =
|
||||
generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
|
||||
dotEncoding, builder, kOff, kShape);
|
||||
cloneElementwiseOps(aRem, dot2aVals[dot], builder);
|
||||
Value bRem =
|
||||
generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
|
||||
dotEncoding, builder, kOff, kShape);
|
||||
cloneElementwiseOps(bRem, dot2bVals[dot], builder);
|
||||
builder.restoreInsertionPoint(insertionPoint);
|
||||
newOp = builder.clone(*dot, mapping);
|
||||
newOp->setOperand(0, aRem);
|
||||
@@ -281,10 +344,15 @@ scf::ForOp Prefetcher::createNewForOp() {
|
||||
for (Value dot : dots) {
|
||||
Attribute dotEncoding =
|
||||
dot.getType().cast<RankedTensorType>().getEncoding();
|
||||
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
|
||||
true, dotEncoding, builder));
|
||||
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
|
||||
true, dotEncoding, builder));
|
||||
Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true,
|
||||
dotEncoding, builder);
|
||||
cloneElementwiseOps(aToYield, dot2aVals[dot], builder);
|
||||
yieldValues.push_back(aToYield);
|
||||
// bToYield
|
||||
Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true,
|
||||
dotEncoding, builder);
|
||||
cloneElementwiseOps(bToYield, dot2bVals[dot], builder);
|
||||
yieldValues.push_back(bToYield);
|
||||
}
|
||||
// Update ops of yield
|
||||
if (!yieldValues.empty())
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
#include "mlir/Transforms/RegionUtils.h"
|
||||
#include "triton/Analysis/Utility.h"
|
||||
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
|
||||
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
|
||||
@@ -341,7 +340,10 @@ public:
|
||||
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
|
||||
auto dstEncoding =
|
||||
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
|
||||
// XXX: why is this needed?
|
||||
if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>() ||
|
||||
dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
|
||||
return failure();
|
||||
// heuristics for flash attention
|
||||
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
|
||||
return failure();
|
||||
SetVector<Operation *> cvtSlices;
|
||||
@@ -365,7 +367,7 @@ public:
|
||||
// don't rematerialize non-element-wise
|
||||
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
|
||||
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
|
||||
!isa<triton::StoreOp>(op)) {
|
||||
!isa<triton::StoreOp>(op) && !isa<triton::ReduceOp>(op)) {
|
||||
return failure();
|
||||
}
|
||||
// don't rematerialize if it adds an extra conversion that can't
|
||||
@@ -375,9 +377,10 @@ public:
|
||||
SetVector<Operation *> processed;
|
||||
SetVector<Attribute> layout;
|
||||
llvm::MapVector<Value, Attribute> toConvert;
|
||||
if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
|
||||
simulateBackwardRematerialization(argOp, processed, layout,
|
||||
toConvert, srcEncoding) > 0) {
|
||||
int numAddedConvs = simulateBackwardRematerialization(
|
||||
argOp, processed, layout, toConvert, srcEncoding);
|
||||
if (argOp && !isa<triton::gpu::ConvertLayoutOp>(argOp) &&
|
||||
cvtSlices.count(argOp) == 0 && numAddedConvs > 0) {
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,11 +89,11 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
|
||||
}
|
||||
|
||||
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
// Case 1a: A size 1 tensor is not expensive since all threads will load the
|
||||
// Case 1: A size 1 tensor is not expensive since all threads will load the
|
||||
// same
|
||||
if (isSingleValue(op->getOperand(0)))
|
||||
return false;
|
||||
// Case 1b: Tensor of pointers has more threads than elements
|
||||
// Case 2: Tensor of pointers has more threads than elements
|
||||
// we can presume a high hit-rate that makes it cheap to load
|
||||
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
|
||||
IntegerAttr numWarps =
|
||||
@@ -104,28 +104,6 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
|
||||
if (ptrType.getNumElements() < numWarps.getInt() * 32)
|
||||
return false;
|
||||
}
|
||||
// auto ptr = op->getOperand(0);
|
||||
//// Case 2: We assume that `evict_last` loads/stores have high hit rate
|
||||
// if (auto load = dyn_cast<triton::LoadOp>(op))
|
||||
// if (load.getEvict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
// return false;
|
||||
// if (auto store = dyn_cast<triton::StoreOp>(op))
|
||||
// if (store.getEvict() == triton::EvictionPolicy::EVICT_LAST)
|
||||
// return false;
|
||||
// if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
|
||||
// auto encoding = tensorTy.getEncoding();
|
||||
// // Case 3: Different type conversion is expensive (e.g., mma <->
|
||||
// block) if (encoding.getTypeID() != targetEncoding.getTypeID())
|
||||
// return true;
|
||||
// auto sizePerThread = triton::gpu::getSizePerThread(encoding);
|
||||
// auto targetSizePerThread =
|
||||
// triton::gpu::getSizePerThread(targetEncoding); auto order =
|
||||
// triton::gpu::getOrder(encoding); auto targetOrder =
|
||||
// triton::gpu::getOrder(targetEncoding);
|
||||
// // Case 4: The targeEncoding may expose more vectorization
|
||||
// opportunities return sizePerThread[order[0]] >=
|
||||
// targetSizePerThread[targetOrder[0]];
|
||||
// }
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -144,6 +122,12 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool canFoldConversion(Operation *op) {
|
||||
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
|
||||
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp,
|
||||
triton::CatOp>(*op);
|
||||
}
|
||||
|
||||
int simulateBackwardRematerialization(
|
||||
Operation *initOp, SetVector<Operation *> &processed,
|
||||
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
|
||||
@@ -189,10 +173,7 @@ int simulateBackwardRematerialization(
|
||||
continue;
|
||||
// If the conversion can be folded into opArgI then
|
||||
// we don't count this conversion as expensive
|
||||
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
|
||||
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
|
||||
continue;
|
||||
if (isa<triton::ViewOp, triton::CatOp>(opArgI))
|
||||
if (canFoldConversion(opArgI))
|
||||
continue;
|
||||
|
||||
// We add one expensive conversion for the current operand
|
||||
@@ -206,11 +187,24 @@ int simulateBackwardRematerialization(
|
||||
|
||||
//
|
||||
|
||||
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping) {
|
||||
Operation *newOp = rewriter.clone(*op, mapping);
|
||||
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
|
||||
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
|
||||
// if input types haven't changed, we're done
|
||||
bool preserveTypes =
|
||||
std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) {
|
||||
return !mapping.contains(v) ||
|
||||
v.getType() == mapping.lookup(v).getType();
|
||||
});
|
||||
if (preserveTypes)
|
||||
return newOp;
|
||||
|
||||
if (newOp->getNumResults() == 0)
|
||||
return newOp;
|
||||
auto origType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
|
||||
auto argType = newOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
|
||||
if (!origType || !argType)
|
||||
return newOp;
|
||||
auto newType = RankedTensorType::get(
|
||||
origType.getShape(), origType.getElementType(), argType.getEncoding());
|
||||
newOp->getResult(0).setType(newType);
|
||||
@@ -219,9 +213,12 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
SmallVector<Type, 1> newTypes;
|
||||
auto success = typeInfer.inferReturnTypes(
|
||||
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
|
||||
newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
|
||||
if (succeeded(success))
|
||||
newOp->getResult(0).setType(newTypes.front());
|
||||
newOp->getAttrDictionary(), newOp->getPropertiesStorage(),
|
||||
newOp->getRegions(), newTypes);
|
||||
if (succeeded(success)) {
|
||||
for (size_t i = 0; i < newTypes.size(); i++)
|
||||
newOp->getResult(i).setType(newTypes[i]);
|
||||
}
|
||||
}
|
||||
return newOp;
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ int simulateBackwardRematerialization(
|
||||
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
|
||||
Attribute targetEncoding);
|
||||
|
||||
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
|
||||
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
|
||||
IRMapping &mapping);
|
||||
|
||||
void rematerializeConversionChain(
|
||||
|
||||
@@ -5,8 +5,17 @@ add_mlir_translation_library(TritonLLVMIR
|
||||
Core
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRArithToLLVM
|
||||
MLIRBuiltinToLLVMIRTranslation
|
||||
MLIRExecutionEngineUtils
|
||||
MLIRIndexToLLVM
|
||||
MLIRIR
|
||||
MLIRLLVMDialect
|
||||
MLIRLLVMToLLVMIRTranslation
|
||||
MLIRNVVMToLLVMIRTranslation
|
||||
MLIRROCDLToLLVMIRTranslation
|
||||
MLIRSCFToControlFlow
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIRExport
|
||||
TritonGPUToLLVM
|
||||
)
|
||||
|
||||
@@ -25,12 +25,22 @@
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Linker/Linker.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include <iostream>

#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#else
#include <dlfcn.h>
#endif
|
||||
#include <filesystem>
|
||||
#include <iterator>
|
||||
|
||||
namespace fs = std::filesystem;
|
||||
|
||||
namespace mlir {
|
||||
namespace triton {
|
||||
|
||||
@@ -117,6 +127,32 @@ extractNVVMMetadata(mlir::ModuleOp module,
|
||||
}
|
||||
}
|
||||
|
||||
static std::filesystem::path getThisLibraryPath() {
|
||||
#ifdef _WIN32
|
||||
/* Get module of the specified address */
|
||||
HMODULE hModule;
|
||||
GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
|
||||
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
|
||||
reinterpret_cast<LPCSTR>(&getThisLibraryPath), &hModule);
|
||||
if (NULL == hModule) {
|
||||
return std::filesystem::path();
|
||||
}
|
||||
|
||||
char fileName[1024]; // this is way beyond Windows MAX_PATH limit.
|
||||
DWORD dwSize = GetModuleFileNameA(hModule, fileName, sizeof(fileName));
|
||||
if (0 == dwSize || sizeof(fileName) == dwSize) {
|
||||
return std::filesystem::path();
|
||||
}
|
||||
return std::filesystem::path(fileName);
|
||||
#else
|
||||
Dl_info fileinfo;
|
||||
if (dladdr(reinterpret_cast<void *>(&getThisLibraryPath), &fileinfo) == 0) {
|
||||
return std::filesystem::path();
|
||||
}
|
||||
return std::filesystem::path(fileinfo.dli_fname);
|
||||
#endif
|
||||
}
|
||||
|
||||
static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
std::map<std::string, std::string> externLibs;
|
||||
SmallVector<LLVM::LLVMFuncOp> funcs;
|
||||
@@ -156,17 +192,10 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
|
||||
externLibs.try_emplace(libdevice, env_path);
|
||||
return externLibs;
|
||||
}
|
||||
namespace fs = std::filesystem;
|
||||
// Search for libdevice relative to its library path if used from Python
|
||||
// Then native code is in `triton/_C/libtriton.so` and libdevice in
|
||||
// `triton/third_party/cuda/lib/libdevice.10.bc`
|
||||
static const auto this_library_path = [] {
|
||||
Dl_info fileinfo;
|
||||
if (dladdr(reinterpret_cast<void *>(&getExternLibs), &fileinfo) == 0) {
|
||||
return std::filesystem::path();
|
||||
}
|
||||
return std::filesystem::path(fileinfo.dli_fname);
|
||||
}();
|
||||
static const auto this_library_path = getThisLibraryPath();
|
||||
static const auto runtime_path =
|
||||
this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
|
||||
"lib" / "libdevice.10.bc";
|
||||
|
||||
@@ -68,7 +68,7 @@ def get_llvm_package_info():
|
||||
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
|
||||
release_suffix = "assert" if use_assert_enabled_llvm else "release"
|
||||
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
|
||||
version = "llvm-17.0.0-f733b4fb9b8b"
|
||||
version = "llvm-17.0.0-c5dede880d17"
|
||||
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
|
||||
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
|
||||
|
||||
@@ -249,6 +249,7 @@ setup(
|
||||
"triton/_C",
|
||||
"triton/common",
|
||||
"triton/compiler",
|
||||
"triton/debugger",
|
||||
"triton/language",
|
||||
"triton/language/extra",
|
||||
"triton/ops",
|
||||
@@ -294,7 +295,6 @@ setup(
|
||||
"numpy",
|
||||
"pytest",
|
||||
"scipy>=1.7.1",
|
||||
"torch",
|
||||
],
|
||||
"tutorials": [
|
||||
"matplotlib",
|
||||
|
||||
@@ -264,6 +264,11 @@ void init_triton_ir(py::module &&m) {
|
||||
return !self.empty() &&
|
||||
self.back().hasTrait<mlir::OpTrait::IsTerminator>();
|
||||
})
|
||||
.def("has_return",
|
||||
[](mlir::Block &self) {
|
||||
return !self.empty() &&
|
||||
self.back().hasTrait<mlir::OpTrait::ReturnLike>();
|
||||
})
|
||||
.def("erase", [](mlir::Block &self) { self.erase(); });
|
||||
|
||||
// using eattr = ir::attribute_kind_t;
|
||||
@@ -430,6 +435,25 @@ void init_triton_ir(py::module &&m) {
|
||||
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
|
||||
},
|
||||
ret::reference)
|
||||
.def("finalize",
|
||||
[](mlir::triton::FuncOp &self) -> void {
|
||||
// Remove dead code
|
||||
// 1. Unreachable code after return
|
||||
self.walk([&](mlir::Block *block) {
|
||||
mlir::Operation *retOp = nullptr;
|
||||
block->walk([&](mlir::Operation *op) {
|
||||
if (mlir::isa<mlir::triton::ReturnOp>(op))
|
||||
if (retOp == nullptr)
|
||||
retOp = op;
|
||||
});
|
||||
if (retOp && retOp != &block->back()) {
|
||||
auto pos = retOp->getIterator();
|
||||
pos++;
|
||||
auto *newBlock = block->splitBlock(pos);
|
||||
newBlock->erase();
|
||||
}
|
||||
});
|
||||
})
|
||||
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
|
||||
.def("reset_type", &mlir::triton::FuncOp::setType);
|
||||
|
||||
@@ -454,7 +478,8 @@ void init_triton_ir(py::module &&m) {
|
||||
[](mlir::OpBuilder &self, mlir::triton::FuncOp &func,
|
||||
std::vector<mlir::Value> &args) -> mlir::OpState {
|
||||
auto loc = self.getUnknownLoc();
|
||||
return self.create<mlir::triton::CallOp>(loc, func, args);
|
||||
auto callOp = self.create<mlir::triton::CallOp>(loc, func, args);
|
||||
return callOp;
|
||||
})
|
||||
// insertion block/point
|
||||
.def("set_insertion_point_to_start",
|
||||
@@ -645,14 +670,16 @@ void init_triton_ir(py::module &&m) {
|
||||
.def("get_or_insert_function",
|
||||
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
|
||||
std::string &funcName, mlir::Type &funcType,
|
||||
std::string &visibility) -> mlir::triton::FuncOp {
|
||||
std::string &visibility, bool noinline) -> mlir::triton::FuncOp {
|
||||
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
|
||||
return llvm::dyn_cast<mlir::triton::FuncOp>(funcOperation);
|
||||
auto loc = self.getUnknownLoc();
|
||||
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
|
||||
llvm::SmallVector<mlir::NamedAttribute> attrs = {
|
||||
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
|
||||
self.getStringAttr(visibility))};
|
||||
self.getStringAttr(visibility)),
|
||||
mlir::NamedAttribute(self.getStringAttr("noinline"),
|
||||
self.getBoolAttr(noinline))};
|
||||
return self.create<mlir::triton::FuncOp>(loc, funcName, funcTy,
|
||||
attrs);
|
||||
}
|
||||
@@ -1699,54 +1726,59 @@ void init_triton_translation(py::module &m) {
|
||||
m.def("compile_ptx_to_cubin",
|
||||
[](const std::string &ptxCode, const std::string &ptxasPath,
|
||||
int capability) -> py::object {
|
||||
py::gil_scoped_release allow_threads;
|
||||
std::string cubin;
|
||||
{
|
||||
py::gil_scoped_release allow_threads;
|
||||
|
||||
// compile ptx with ptxas
|
||||
llvm::SmallString<64> fsrc;
|
||||
llvm::SmallString<64> flog;
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
|
||||
std::string fbin = std::string(fsrc) + ".o";
|
||||
llvm::FileRemover logRemover(flog);
|
||||
llvm::FileRemover binRemover(fbin);
|
||||
const char *_fsrc = fsrc.c_str();
|
||||
const char *_flog = flog.c_str();
|
||||
const char *_fbin = fbin.c_str();
|
||||
std::ofstream ofs(_fsrc);
|
||||
ofs << ptxCode << std::endl;
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
|
||||
(capability == 90 ? "a " : " ") + _fsrc + " -o " + _fsrc +
|
||||
".o 2> " + _flog;
|
||||
// compile ptx with ptxas
|
||||
llvm::SmallString<64> fsrc;
|
||||
llvm::SmallString<64> flog;
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
|
||||
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
|
||||
std::string fbin = std::string(fsrc) + ".o";
|
||||
llvm::FileRemover logRemover(flog);
|
||||
llvm::FileRemover binRemover(fbin);
|
||||
const char *_fsrc = fsrc.c_str();
|
||||
const char *_flog = flog.c_str();
|
||||
const char *_fbin = fbin.c_str();
|
||||
std::ofstream ofs(_fsrc);
|
||||
ofs << ptxCode << std::endl;
|
||||
ofs.close();
|
||||
std::string cmd;
|
||||
int err;
|
||||
cmd = ptxasPath + " -v --gpu-name=sm_" +
|
||||
std::to_string(capability) + (capability == 90 ? "a " : " ") +
|
||||
_fsrc + " -o " + _fsrc + ".o 2> " + _flog;
|
||||
|
||||
err = system(cmd.c_str());
|
||||
if (err != 0) {
|
||||
err >>= 8;
|
||||
std::ifstream _log(_flog);
|
||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||
if (err == 255) {
|
||||
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
|
||||
log);
|
||||
} else if (err == 128 + SIGSEGV) {
|
||||
throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
|
||||
"` to confirm that this is a "
|
||||
"bug in `ptxas`\n" +
|
||||
log);
|
||||
err = system(cmd.c_str());
|
||||
if (err != 0) {
|
||||
err >>= 8;
|
||||
std::ifstream _log(_flog);
|
||||
std::string log(std::istreambuf_iterator<char>(_log), {});
|
||||
if (err == 255) {
|
||||
throw std::runtime_error(
|
||||
"Internal Triton PTX codegen error: \n" + log);
|
||||
} else if (err == 128 + SIGSEGV) {
|
||||
throw std::runtime_error("Please run `ptxas " +
|
||||
fsrc.str().str() +
|
||||
"` to confirm that this is a "
|
||||
"bug in `ptxas`\n" +
|
||||
log);
|
||||
} else {
|
||||
throw std::runtime_error("`ptxas` failed with error code " +
|
||||
std::to_string(err) + ": \n" + log);
|
||||
}
|
||||
return {};
|
||||
} else {
|
||||
throw std::runtime_error("`ptxas` failed with error code " +
|
||||
std::to_string(err) + ": \n" + log);
|
||||
llvm::FileRemover srcRemover(fsrc);
|
||||
std::ifstream _cubin(_fbin, std::ios::binary);
|
||||
cubin = std::string(std::istreambuf_iterator<char>(_cubin), {});
|
||||
_cubin.close();
|
||||
// Do not return here, exit the gil scope and return below
|
||||
}
|
||||
return {};
|
||||
} else {
|
||||
llvm::FileRemover srcRemover(fsrc);
|
||||
std::ifstream _cubin(_fbin, std::ios::binary);
|
||||
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
|
||||
_cubin.close();
|
||||
py::bytes bytes(cubin);
|
||||
return std::move(bytes);
|
||||
}
|
||||
py::bytes bytes(cubin);
|
||||
return std::move(bytes);
|
||||
});
|
||||
|
||||
m.def("add_external_libs",
|
||||
|
||||
68
python/test/regression/test_functional_regressions.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
def test_chained_matmul():
|
||||
# Regression test for issue #1601
|
||||
def chained_matmul_reference(a, b, c):
|
||||
intermediate = torch.einsum('MK,NK->MN', a, b)
|
||||
return torch.einsum('MN,NK->MK', intermediate, c)
|
||||
|
||||
@triton.jit
|
||||
def chained_matmul_kernel(
|
||||
A, # shape: (m, k)
|
||||
B, # shape: (n, k)
|
||||
C, # shape: (n, k)
|
||||
out, # shape: (m, k)
|
||||
m, n, k: tl.constexpr,
|
||||
block_m: tl.constexpr,
|
||||
block_n: tl.constexpr,
|
||||
block_k: tl.constexpr):
|
||||
|
||||
tl.static_assert(block_k == k,
|
||||
f"expected block_k == k but got {block_k} != {k}")
|
||||
|
||||
block_ix = tl.program_id(0)
|
||||
a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \
|
||||
+ tl.arange(0, block_k)[None, :]
|
||||
|
||||
a = tl.load(A + a_tile, mask=a_tile < m * k, other=0.0)
|
||||
|
||||
acc = tl.zeros([block_m, block_k], dtype=tl.float32)
|
||||
|
||||
for loop_block_start in range(0, n, block_n):
|
||||
bc_tile = (loop_block_start + tl.arange(0, block_n))[:, None] * block_k \
|
||||
+ tl.arange(0, block_k)[None, :]
|
||||
b = tl.load(B + bc_tile, mask=bc_tile < n * k, other=0.0)
|
||||
|
||||
intermediate = tl.dot(a, tl.trans(b))
|
||||
intermediate_mask = ((loop_block_start + tl.arange(0, block_n)) < n)[None, :] \
|
||||
* (tl.arange(0, block_m) < m)[:, None]
|
||||
|
||||
intermediate = tl.where(intermediate_mask, intermediate, 0.0)
|
||||
|
||||
c = tl.load(C + bc_tile, mask=bc_tile < n * k)
|
||||
|
||||
acc += tl.dot(intermediate.to(A.dtype.element_ty), c)
|
||||
|
||||
tl.store(out + a_tile, acc.to(A.dtype.element_ty), mask=a_tile < m * k)
|
||||
|
||||
m, n, k = 32, 64, 128
|
||||
block_m, block_n, block_k = 16, 32, k
|
||||
|
||||
grid = (triton.cdiv(m, block_m),)
|
||||
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16,
|
||||
device='cuda')
|
||||
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16,
|
||||
device='cuda')
|
||||
c = torch.randint_like(b, low=0, high=2)
|
||||
triton_result = torch.zeros_like(a)
|
||||
|
||||
torch_result = chained_matmul_reference(a, b, c)
chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k,
block_m=block_m, block_n=block_n,
block_k=block_k)

assert (torch_result == triton_result).all()
|
||||
69
python/test/unit/debugger/test_debugger.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import random
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from triton.debugger.debugger import program_ids_from_grid
|
||||
|
||||
|
||||
def test_addition():
|
||||
|
||||
@triton.jit(interpret=True)
|
||||
def add_kernel(
|
||||
x_ptr,
|
||||
y_ptr,
|
||||
output_ptr,
|
||||
n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
output = x + y
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
a = torch.rand((128,), device="cuda")
|
||||
b = torch.rand((128,), device="cuda")
|
||||
expected = a + b
|
||||
output = torch.empty((128,), device="cuda")
|
||||
|
||||
def grid(meta):
|
||||
return (triton.cdiv(128, meta["BLOCK_SIZE"]),)
|
||||
|
||||
add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)
|
||||
|
||||
assert torch.allclose(expected, output, atol=1e-2, rtol=0)
|
||||
|
||||
|
||||
def test_program_ids_from_grid():
|
||||
random.seed(123)
|
||||
grid = (3, 4)
|
||||
expected_combinations = 3 * 4
|
||||
unique_combinations = set(program_ids_from_grid(grid))
|
||||
assert len(unique_combinations) == expected_combinations
|
||||
|
||||
first_run = list(program_ids_from_grid(grid))
|
||||
second_run = list(program_ids_from_grid(grid))
|
||||
assert first_run != second_run
|
||||
|
||||
|
||||
def test_atomic():
|
||||
@triton.jit(interpret=True)
|
||||
def atomic(
|
||||
x_ptr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
tl.atomic_add(x_ptr + pid, 1)
|
||||
t = tl.atomic_xchg(x_ptr + pid, 3)
|
||||
t += 1 # 2
|
||||
tl.atomic_cas(x_ptr + pid, 3, t) # match
|
||||
tl.atomic_cas(x_ptr + pid, 40, 9) # no match
|
||||
nb_dim = 16
|
||||
a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda")
|
||||
|
||||
atomic[(nb_dim, )](a)
|
||||
assert torch.allclose(a, torch.full_like(a, 2))
|
||||
@@ -14,6 +14,14 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
# Trivial assert
|
||||
tl.device_assert(0 == 0, "x != 0")
|
||||
tl.store(Y + tl.arange(0, BLOCK), x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def kernel_assert(X, Y, BLOCK: tl.constexpr):
|
||||
x = tl.load(X + tl.arange(0, BLOCK))
|
||||
@@ -33,7 +41,12 @@ def test_assert(func: str):
|
||||
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
|
||||
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
|
||||
if func == "device_assert":
|
||||
<<<<<<< HEAD
|
||||
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
=======
|
||||
kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
|
||||
kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0])
|
||||
>>>>>>> openai/main
|
||||
elif func == "assert":
|
||||
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
|
||||
elif func == "static_assert":
|
||||
|
||||
@@ -457,6 +457,86 @@ def test_broadcast(dtype):
|
||||
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
|
||||
|
||||
|
||||
# ----------------
|
||||
# test expand_dims
|
||||
# ----------------
|
||||
def test_expand_dims():
|
||||
@triton.jit
|
||||
def expand_dims_kernel(dummy, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, N)
|
||||
|
||||
t = tl.expand_dims(offset1, 0)
|
||||
tl.static_assert(t.shape == [1, N])
|
||||
|
||||
t = tl.expand_dims(offset1, 1)
|
||||
tl.static_assert(t.shape == [N, 1])
|
||||
|
||||
t = tl.expand_dims(offset1, -1)
|
||||
tl.static_assert(t.shape == [N, 1])
|
||||
|
||||
t = tl.expand_dims(offset1, -2)
|
||||
tl.static_assert(t.shape == [1, N])
|
||||
|
||||
t = tl.expand_dims(offset1, (0, -1))
|
||||
tl.static_assert(t.shape == [1, N, 1])
|
||||
|
||||
t = tl.expand_dims(offset1, (0, 1, 3))
|
||||
tl.static_assert(t.shape == [1, 1, N, 1])
|
||||
|
||||
t = tl.expand_dims(offset1, (-4, 2, -1))
|
||||
tl.static_assert(t.shape == [1, N, 1, 1])
|
||||
|
||||
t = tl.expand_dims(offset1, (3, 1, 2))
|
||||
tl.static_assert(t.shape == [N, 1, 1, 1])
|
||||
|
||||
N = 32
|
||||
dummy_tensor = torch.empty((), device="cuda")
|
||||
expand_dims_kernel[(1,)](dummy_tensor, N)
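
# A minimal NumPy sketch of the same axis-normalization semantics asserted
# above (illustrative only; tl.expand_dims does not go through NumPy):
#   import numpy as np
#   x = np.arange(32)
#   assert np.expand_dims(x, 0).shape == (1, 32)
#   assert np.expand_dims(x, -1).shape == (32, 1)
#   assert np.expand_dims(x, (0, -1)).shape == (1, 32, 1)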
|
||||
|
||||
|
||||
def test_expand_dims_error_cases():
|
||||
@triton.jit
|
||||
def dim_out_of_range1(dummy, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, N)
|
||||
|
||||
t = tl.expand_dims(offset1, -2)
|
||||
t = tl.expand_dims(offset1, -3)
|
||||
|
||||
@triton.jit
|
||||
def dim_out_of_range2(dummy, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, N)
|
||||
|
||||
t = tl.expand_dims(offset1, 1)
|
||||
t = tl.expand_dims(offset1, 2)
|
||||
|
||||
@triton.jit
|
||||
def duplicate_dim1(dummy, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, N)
|
||||
|
||||
t = tl.expand_dims(offset1, (0, 0))
|
||||
|
||||
@triton.jit
|
||||
def duplicate_dim2(dummy, N: tl.constexpr):
|
||||
offset1 = tl.arange(0, N)
|
||||
|
||||
t = tl.expand_dims(offset1, (0, -3))
|
||||
|
||||
N = 32
|
||||
dummy_tensor = torch.empty((), device="cuda")
|
||||
|
||||
with pytest.raises(triton.CompilationError, match="invalid axis -3"):
|
||||
dim_out_of_range1[(1,)](dummy_tensor, N)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match="invalid axis 2"):
|
||||
dim_out_of_range2[(1,)](dummy_tensor, N)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
|
||||
duplicate_dim1[(1,)](dummy_tensor, N)
|
||||
|
||||
with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
|
||||
duplicate_dim2[(1,)](dummy_tensor, N)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test where
|
||||
# ---------------
|
||||
@@ -688,7 +768,7 @@ def test_index1d(expr, dtype_str, device='cuda'):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def fn(a, b):
|
||||
def tuples_fn(a, b):
|
||||
return a + b, \
|
||||
a - b, \
|
||||
a * b
|
||||
@@ -701,7 +781,7 @@ def test_tuples():
|
||||
def with_fn(X, Y, A, B, C):
|
||||
x = tl.load(X)
|
||||
y = tl.load(Y)
|
||||
a, b, c = fn(x, y)
|
||||
a, b, c = tuples_fn(x, y)
|
||||
tl.store(A, a)
|
||||
tl.store(B, b)
|
||||
tl.store(C, c)
|
||||
@@ -728,6 +808,92 @@ def test_tuples():
|
||||
assert c_tri == c_ref
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_simple_fn(x, y, Z):
|
||||
z = x + y
|
||||
tl.store(Z, z)
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_call_graph_fn1(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_call_graph_fn2(y):
|
||||
return y + 2
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_call_graph_fn(x, y, Z):
|
||||
t0 = noinline_call_graph_fn1(x)
|
||||
t1 = noinline_call_graph_fn2(y)
|
||||
z = t0 + t1
|
||||
tl.store(Z, z)
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_shared_fn(x, y, Z):
|
||||
offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
|
||||
z = tl.load(Z + offs)
|
||||
z = tl.dot(z, z) + x + y
|
||||
tl.store(Z + offs, z)
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_dynamic_fn(x, y, Z):
|
||||
if x >= 1:
|
||||
x = noinline_call_graph_fn1(x)
|
||||
else:
|
||||
x = noinline_call_graph_fn2(x)
|
||||
if y >= 2:
|
||||
y = noinline_call_graph_fn2(y)
|
||||
else:
|
||||
y = noinline_call_graph_fn1(y)
|
||||
z = x + y
|
||||
tl.store(Z, z)
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_call_multi_values_fn(x, y):
|
||||
return x + 1, y + 2
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def noinline_multi_values_fn(x, y, Z):
|
||||
x, y = noinline_call_multi_values_fn(x, y)
|
||||
z = x + y
|
||||
tl.store(Z, z)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"])
|
||||
def test_noinline(mode):
|
||||
device = 'cuda'
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, Y, Z):
|
||||
x = tl.load(X)
|
||||
y = tl.load(Y)
|
||||
GENERATE_TEST_HERE(x, y, Z)
|
||||
|
||||
func_name = f'noinline_{mode}_fn'
|
||||
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name})
|
||||
x = torch.tensor([1.0], device=device, dtype=torch.float32)
|
||||
y = torch.tensor([2.0], device=device, dtype=torch.float32)
|
||||
if mode == "shared":
|
||||
z = torch.ones((16, 16), device=device, dtype=torch.float32)
|
||||
else:
|
||||
z = torch.tensor([0.0], device=device, dtype=torch.float32)
|
||||
kernel[(1,)](x, y, z, num_warps=1)
|
||||
if mode == "simple":
|
||||
assert torch.equal(z, x + y)
|
||||
elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values":
|
||||
assert torch.equal(z, x + 1 + y + 2)
|
||||
elif mode == "shared":
|
||||
ref = torch.full((16, 16), 16, device=device, dtype=torch.float32)
|
||||
assert torch.equal(z, ref + x + y)
|
||||
|
||||
|
||||
# ---------------
|
||||
# test atomics
|
||||
# ---------------
|
||||
@@ -1334,6 +1500,99 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
|
||||
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
layouts = [
|
||||
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [32, 64, 128, 256])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
def test_store_op(M, src_layout, device='cuda'):
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
||||
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
|
||||
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
||||
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
|
||||
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
|
||||
tt.store %8, %4 : tensor<{M}x1xf32, #src>
|
||||
tt.return
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
||||
f.write(ir)
|
||||
f.flush()
|
||||
store_kernel = triton.compile(f.name)
|
||||
|
||||
rs = RandomState(17)
|
||||
x = rs.randint(0, 4, (M, 1)).astype('float32')
|
||||
y = np.zeros((M, 1), dtype='float32')
|
||||
x_tri = torch.tensor(x, device=device)
|
||||
y_tri = torch.tensor(y, device=device)
|
||||
|
||||
pgm = store_kernel[(1, 1, 1)](x_tri, y_tri)
|
||||
y_ref = x
|
||||
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
layouts = [
|
||||
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
|
||||
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M", [64, 128, 256])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
@pytest.mark.parametrize("dst_layout", layouts)
|
||||
@pytest.mark.parametrize("src_dim", [0, 1])
|
||||
@pytest.mark.parametrize("dst_dim", [0, 1])
|
||||
def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device='cuda'):
|
||||
ir = f"""
|
||||
#dst = {dst_layout}
|
||||
#src = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
||||
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
|
||||
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
%7 = triton_gpu.convert_layout %3 : (tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
tt.store %6, %7 : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
|
||||
tt.return
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
||||
f.write(ir)
|
||||
f.flush()
|
||||
kernel = triton.compile(f.name)
|
||||
|
||||
rs = RandomState(17)
|
||||
x = rs.randint(0, 4, (M, )).astype('int32')
|
||||
y = np.zeros((M, ), dtype='int32')
|
||||
x_tri = torch.tensor(x, device=device)
|
||||
y_tri = torch.tensor(y, device=device)
|
||||
pgm = kernel[(1, 1, 1)](x_tri, y_tri)
|
||||
y_ref = x
|
||||
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
|
||||
delta = mean_2 - mean_1
|
||||
@@ -1346,6 +1605,68 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
|
||||
)
|
||||
|
||||
|
||||
layouts = [
|
||||
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
|
||||
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
|
||||
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
|
||||
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]])
|
||||
@pytest.mark.parametrize("src_layout", layouts)
|
||||
def test_chain_reduce(M, N, src_layout, device='cuda'):
|
||||
ir = f"""
|
||||
#src = {src_layout}
|
||||
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
|
||||
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
|
||||
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
|
||||
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
|
||||
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
|
||||
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
|
||||
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
|
||||
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
|
||||
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
|
||||
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
|
||||
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
|
||||
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
|
||||
%11 = "tt.reduce"(%10) ({{
|
||||
^bb0(%arg2: i32, %arg3: i32):
|
||||
%13 = arith.addi %arg2, %arg3 : i32
|
||||
tt.reduce.return %13 : i32
|
||||
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
|
||||
%12 = "tt.reduce"(%11) ({{
|
||||
^bb0(%arg2: i32, %arg3: i32):
|
||||
%13 = arith.addi %arg2, %arg3 : i32
|
||||
tt.reduce.return %13 : i32
|
||||
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
|
||||
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
|
||||
tt.return
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
|
||||
f.write(ir)
|
||||
f.flush()
|
||||
kernel = triton.compile(f.name)
|
||||
|
||||
rs = RandomState(17)
|
||||
x = rs.randint(0, 4, (M, N)).astype('int32')
|
||||
|
||||
z = np.zeros((1,)).astype('int32')
|
||||
|
||||
x_tri = torch.tensor(x, device=device)
|
||||
z_tri = torch.tensor(z, device=device)
|
||||
|
||||
pgm = kernel[(1, 1, 1)](x_tri, z_tri)
|
||||
z_ref = np.sum(x)
|
||||
|
||||
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
|
||||
|
||||
|
||||
def test_generic_reduction(device='cuda'):
|
||||
|
||||
@triton.jit
|
||||
@@ -2013,57 +2334,79 @@ def val_multiplier(val, i):
|
||||
return val * i
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def val_multiplier_noinline(val, i):
|
||||
return val * i
|
||||
|
||||
|
||||
@triton.jit
|
||||
def vecmul_kernel(ptr, n_elements, rep):
|
||||
def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr):
|
||||
pid = tl.program_id(axis=0)
|
||||
offsets = pid * 128 + tl.arange(0, 128)
|
||||
mask = offsets < n_elements
|
||||
vec = tl.load(ptr + offsets, mask=mask)
|
||||
for i in range(1, rep):
|
||||
vec = val_multiplier(vec, i)
|
||||
if type == "inline":
|
||||
vec = val_multiplier(vec, i)
|
||||
else:
|
||||
vec = val_multiplier_noinline(vec, i)
|
||||
tl.store(ptr + offsets, vec, mask=mask)
|
||||
|
||||
|
||||
def test_call():
|
||||
@pytest.mark.parametrize("type", ["inline", "noinline"])
|
||||
def test_call(type):
|
||||
|
||||
@triton.jit
|
||||
def kernel(ptr, n_elements, num1, num2):
|
||||
vecmul_kernel(ptr, n_elements, num1)
|
||||
vecmul_kernel(ptr, n_elements, num2)
|
||||
def kernel(ptr, n_elements, num1, num2, type: tl.constexpr):
|
||||
vecmul_kernel(ptr, n_elements, num1, type)
|
||||
vecmul_kernel(ptr, n_elements, num2, type)
|
||||
|
||||
size = 1024
|
||||
rand_val = numpy_random((size,), dtype_str="float32")
|
||||
rand_val_tri = to_triton(rand_val, device='cuda')
|
||||
kernel[(size // 128,)](rand_val_tri, size, 3, 5)
|
||||
err_msg = ""
|
||||
try:
|
||||
kernel[(size // 128,)](rand_val_tri, size, 3, 5, type)
|
||||
except Exception as e:
|
||||
err_msg = str(e)
|
||||
|
||||
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
|
||||
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
|
||||
if type == "noinline":
|
||||
assert err_msg != ""
|
||||
else:
|
||||
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
|
||||
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
|
||||
|
||||
# -------------
|
||||
# test if
|
||||
# -------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp"])
|
||||
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
|
||||
def test_if(if_type):
|
||||
|
||||
@triton.jit
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr):
|
||||
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
cond = tl.load(Cond)
|
||||
if IfType == "if":
|
||||
if pid % 2:
|
||||
if pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
else:
|
||||
elif IfType == "if_exp":
|
||||
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
|
||||
elif IfType == "if_and":
|
||||
if BoolVar and pid % 2 == 0:
|
||||
tl.store(Ret, tl.load(XTrue))
|
||||
else:
|
||||
tl.store(Ret, tl.load(XFalse))
|
||||
|
||||
cond = torch.ones(1, dtype=torch.int32, device='cuda')
|
||||
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
|
||||
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
|
||||
ret = torch.empty(1, dtype=torch.float32, device='cuda')
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type)
|
||||
kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
|
||||
assert torch.equal(ret, x_true)
|
||||
|
||||
|
||||
def test_num_warps_pow2():
|
||||
@@ -2227,24 +2570,105 @@ def test_if_else():
|
||||
assert to_numpy(out)[0] == false_val[0]
|
||||
|
||||
|
||||
def test_if_return():
|
||||
@pytest.mark.parametrize("mode", ["dynamic", "static"])
|
||||
def test_if_return(mode):
|
||||
|
||||
@triton.jit
|
||||
def kernel(ExitEarly, Out):
|
||||
if tl.load(ExitEarly):
|
||||
tl.store(Out, 0)
|
||||
return
|
||||
def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr):
|
||||
if mode == "dynamic":
|
||||
if tl.load(ExitEarly):
|
||||
tl.store(Out, 0)
|
||||
return
|
||||
else:
|
||||
if cond:
|
||||
tl.store(Out, 0)
|
||||
return
|
||||
tl.store(Out, 1)
|
||||
|
||||
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
||||
exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
||||
# exit early path taken
|
||||
exit_early[0] = 1
|
||||
kernel[(1,)](exit_early, out)
|
||||
kernel[(1,)](exit_early, out, True, mode)
|
||||
assert to_numpy(out)[0] == 0
|
||||
# exit early path not taken
|
||||
exit_early[0] = 0
|
||||
kernel[(1,)](exit_early, out)
|
||||
kernel[(1,)](exit_early, out, False, mode)
|
||||
assert to_numpy(out)[0] == 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_fn(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
@triton.jit(noinline=True)
|
||||
def add_fn_noinline(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_fn_return(x, pid):
|
||||
if pid == 0:
|
||||
return x + 1
|
||||
else:
|
||||
return x + 2
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_fn_expr(Out, x):
|
||||
tl.store(Out, x)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def add_fn_static_cond(x, cond: tl.constexpr):
|
||||
if cond == "":
|
||||
return x
|
||||
else:
|
||||
return x + 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
|
||||
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
|
||||
def test_if_call(call_type):
|
||||
@triton.jit
|
||||
def kernel(Out, call_type: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
o = tl.load(Out)
|
||||
if pid == 0:
|
||||
if call_type == "attribute":
|
||||
# call attribute
|
||||
a = o + 1
|
||||
a = a.to(tl.int32).to(tl.int32)
|
||||
o = a
|
||||
else:
|
||||
a = o
|
||||
if call_type == "jit_function":
|
||||
# regular function call
|
||||
a = add_fn(a)
|
||||
elif call_type == "jit_function_return":
|
||||
# function without end_if block
|
||||
a = add_fn_return(a, pid)
|
||||
elif call_type == "ifexp":
|
||||
# ifexp expression
|
||||
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
|
||||
elif call_type == "expr":
|
||||
if pid == 1:
|
||||
return
|
||||
a = add_fn(a)
|
||||
if pid == 0:
|
||||
# call without return
|
||||
add_fn_expr(Out, a)
|
||||
elif call_type == "jit_function_static_cond":
|
||||
a = add_fn_static_cond(a, call_type)
|
||||
elif call_type == "jit_function_noinline":
|
||||
a = add_fn_noinline(a)
|
||||
o = a
|
||||
|
||||
tl.store(Out, o)
|
||||
|
||||
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
|
||||
kernel[(1,)](out, call_type)
|
||||
assert to_numpy(out)[0] == 1
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
from collections import namedtuple
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
@@ -170,34 +168,32 @@ def test_jit_debug() -> None:
|
||||
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
|
||||
@triton.jit
|
||||
def add_fn(a, b, o, N: tl.constexpr):
|
||||
idx = tl.arange(0, N)
|
||||
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
|
||||
|
||||
|
||||
def test_jit_noinline() -> None:
|
||||
@triton.jit
|
||||
def kernel_sub(a, b, o, N: tl.constexpr):
|
||||
idx = tl.arange(0, N)
|
||||
tl.store(o + idx,
|
||||
tl.load(a + idx) - tl.load(b + idx) * 777)
|
||||
def kernel_add_device(a, b, o, N: tl.constexpr):
|
||||
add_fn(a, b, o, N)
|
||||
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = namedtuple("instance_descriptor", [
|
||||
"divisible_by_16", "equal_to_1"])(
|
||||
tuple(range(4)),
|
||||
())
|
||||
|
||||
proc = multiprocessing.Process(
|
||||
target=triton.compile,
|
||||
kwargs=dict(
|
||||
fn=kernel_sub,
|
||||
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
|
||||
device=0,
|
||||
constants={3: 32},
|
||||
configs=[config],
|
||||
warm_cache_only=True,
|
||||
cc=cc,
|
||||
))
|
||||
proc.start()
|
||||
proc.join()
|
||||
assert proc.exitcode == 0
|
||||
device = torch.cuda.current_device()
|
||||
assert len(kernel_add_device.cache[device]) == 0
|
||||
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add_device.cache[device]) == 1
|
||||
bins = list(kernel_add_device.cache[device].values())
|
||||
inline_ttir = bins[0].asm['ttir']
|
||||
add_fn.noinline = True
|
||||
add_fn.hash = None
|
||||
kernel_add_device.hash = None
|
||||
kernel_add_device.cache[device].clear()
|
||||
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
|
||||
assert len(kernel_add_device.cache[device]) == 1
|
||||
bins = list(kernel_add_device.cache[device].values())
|
||||
noinline_ttir = bins[0].asm['ttir']
|
||||
assert inline_ttir != noinline_ttir
|
||||
|
||||
|
||||
def test_memory_leak() -> None:
|
||||
|
||||
83
python/test/unit/runtime/test_subproc.py
Normal file
@@ -0,0 +1,83 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
from collections import namedtuple
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
tmpdir = ".tmp"
|
||||
|
||||
|
||||
def reset_tmp_dir():
|
||||
os.environ["TRITON_CACHE_DIR"] = tmpdir
|
||||
if os.path.exists(tmpdir):
|
||||
shutil.rmtree(tmpdir)
|
||||
|
||||
|
||||
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
|
||||
|
||||
|
||||
def compile_fn(config, cc):
|
||||
@triton.jit
|
||||
def kernel_sub(a, b, o, N: tl.constexpr):
|
||||
idx = tl.arange(0, N)
|
||||
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
|
||||
triton.compile(
|
||||
fn=kernel_sub,
|
||||
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
|
||||
device=0,
|
||||
constants={3: 32},
|
||||
configs=[config],
|
||||
warm_cache_only=True,
|
||||
cc=cc,
|
||||
)
|
||||
|
||||
|
||||
def test_compile_in_subproc() -> None:
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = instance_descriptor(tuple(range(4)), ())
|
||||
|
||||
multiprocessing.set_start_method('fork')
|
||||
proc = multiprocessing.Process(
|
||||
target=compile_fn,
|
||||
args=(config, cc))
|
||||
proc.start()
|
||||
proc.join()
|
||||
assert proc.exitcode == 0
|
||||
|
||||
|
||||
def compile_fn_dot(config, cc):
|
||||
@triton.jit
|
||||
def kernel_dot(Z):
|
||||
offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
|
||||
z = tl.load(Z + offs)
|
||||
z = tl.dot(z, z)
|
||||
tl.store(Z + offs, z)
|
||||
|
||||
triton.compile(
|
||||
fn=kernel_dot,
|
||||
signature={0: "*fp32"},
|
||||
device=0,
|
||||
configs=[config],
|
||||
warm_cache_only=True,
|
||||
cc=cc,
|
||||
)
|
||||
|
||||
|
||||
def test_compile_in_forked_subproc() -> None:
|
||||
reset_tmp_dir()
|
||||
major, minor = torch.cuda.get_device_capability(0)
|
||||
cc = major * 10 + minor
|
||||
config = instance_descriptor(tuple(range(1)), ())
|
||||
|
||||
assert multiprocessing.get_start_method() == 'fork'
|
||||
proc = multiprocessing.Process(
|
||||
target=compile_fn_dot,
|
||||
args=(config, cc))
|
||||
proc.start()
|
||||
proc.join()
|
||||
assert proc.exitcode == 0
|
||||
@@ -4,10 +4,6 @@ __version__ = '2.1.0'
|
||||
# ---------------------------------------
|
||||
# Note: import order is significant here.
|
||||
|
||||
# TODO: torch needs to be imported first
|
||||
# or pybind11 shows `munmap_chunk(): invalid pointer`
|
||||
import torch # noqa: F401
|
||||
|
||||
# submodules
|
||||
from .runtime import (
|
||||
autotune,
|
||||
@@ -22,6 +18,8 @@ from .runtime import (
|
||||
)
|
||||
from .runtime.jit import jit
|
||||
from .compiler import compile, CompilationError
|
||||
from .debugger.debugger import program_ids_from_grid
|
||||
|
||||
from . import language
|
||||
from . import testing
|
||||
|
||||
@@ -45,6 +43,7 @@ __all__ = [
|
||||
"runtime",
|
||||
"TensorWrapper",
|
||||
"testing",
|
||||
"program_ids_from_grid",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -46,6 +46,8 @@ def mangle_fn(name, arg_tys, constants):
|
||||
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
|
||||
mangled_constants = mangled_constants.replace('.', '_d_')
|
||||
mangled_constants = mangled_constants.replace("'", '_sq_')
|
||||
# [ and ] are not allowed in LLVM identifiers
|
||||
mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
|
||||
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
|
||||
return ret
|
||||
|
||||
@@ -58,10 +60,21 @@ def _is_constexpr(o: Any) -> bool:
|
||||
return isinstance(o, constexpr)
|
||||
|
||||
|
||||
def _is_triton_scalar(o: Any) -> bool:
|
||||
return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
|
||||
|
||||
|
||||
def _unwrap_if_constexpr(o: Any):
|
||||
return o.value if isinstance(o, constexpr) else o
|
||||
|
||||
|
||||
def _check_fn_args(node, fn, args):
|
||||
if fn.noinline:
|
||||
for idx, arg in enumerate(args):
|
||||
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
|
||||
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
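
# Hypothetical sketch of the call shape this check rejects (names here are
# illustrative, not from the diff): a noinline callee may only receive
# constexprs or scalar tensors, so passing a block tensor raises
# UnsupportedLanguageConstruct.
#
#   @triton.jit(noinline=True)
#   def callee(block):
#       return block + 1
#
#   @triton.jit
#   def caller(X, N: tl.constexpr):
#       x = tl.load(X + tl.arange(0, N))   # block tensor, not a scalar
#       y = callee(x)                      # rejected by _check_fn_args
#       tl.store(X + tl.arange(0, N), y)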
|
||||
|
||||
|
||||
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
|
||||
|
||||
|
||||
@@ -86,7 +99,8 @@ class enter_sub_region:
|
||||
|
||||
class CodeGenerator(ast.NodeVisitor):
|
||||
def __init__(self, context, prototype, gscope, attributes, constants, function_name,
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None, debug=False):
|
||||
module=None, is_kernel=False, function_types: Optional[Dict] = None,
|
||||
debug=False, noinline=False):
|
||||
self.builder = ir.builder(context)
|
||||
self.module = self.builder.create_module() if module is None else module
|
||||
self.function_ret_types = {} if function_types is None else function_types
|
||||
@@ -99,7 +113,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.is_kernel = is_kernel
|
||||
self.last_node = None
|
||||
self.debug = debug
|
||||
self.noinline = noinline
|
||||
self.scf_stack = []
|
||||
self.last_ret_type = None
|
||||
# SSA-construction
|
||||
# name => language.tensor
|
||||
self.local_defs: Dict[str, tensor] = {}
|
||||
@@ -134,7 +150,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
def set_value(self, name: str,
|
||||
value: Union[tensor, constexpr]) -> None:
|
||||
''' This function:
|
||||
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
|
||||
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
|
||||
1. record local defined name (FIXME: should consider control flow)
|
||||
2. store tensor in self.lvalue
|
||||
'''
|
||||
@@ -146,10 +162,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
#
|
||||
def visit_compound_statement(self, stmts):
|
||||
for stmt in stmts:
|
||||
self.last_ret_type = self.visit(stmt)
|
||||
if isinstance(stmt, ast.Return):
|
||||
break
|
||||
return stmts and isinstance(stmt, ast.Return)
|
||||
ret_type = self.visit(stmt)
|
||||
if ret_type is not None and isinstance(stmt, ast.Return):
|
||||
self.last_ret_type = ret_type
|
||||
|
||||
# TODO: should be its own AST visitor
|
||||
def contains_return_op(self, node):
|
||||
@@ -164,8 +179,23 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
pred = lambda s: self.contains_return_op(s)
|
||||
return any(pred(s) for s in node.body)
|
||||
elif isinstance(node, ast.Call):
|
||||
def check_undefined_name(cur_node):
|
||||
# Check if name is an undefined local variable,
|
||||
# which can only be a tensor or a constexpr
|
||||
if isinstance(cur_node.func, ast.Attribute):
|
||||
if isinstance(cur_node.func.value, ast.Name):
|
||||
name = cur_node.func.value.id
|
||||
if name not in self.lscope and name not in self.gscope:
|
||||
return True
|
||||
return False
|
||||
# chain of calls
|
||||
# e.g., tl.load(a).to(tl.float32)
|
||||
return check_undefined_name(cur_node.func.value)
|
||||
return False
|
||||
if check_undefined_name(node):
|
||||
return False
|
||||
fn = self.visit(node.func)
|
||||
if isinstance(fn, JITFunction):
|
||||
if isinstance(fn, JITFunction) and fn.noinline is not True:
|
||||
old_gscope = self.gscope
|
||||
self.gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
ret = self.contains_return_op(fn.parse())
|
||||
@@ -178,6 +208,18 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if node.orelse:
|
||||
ret = ret or any(pred(s) for s in node.orelse)
|
||||
return ret
|
||||
elif isinstance(node, ast.IfExp):
|
||||
return self.contains_return_op(node.body) or self.contains_return_op(node.orelse)
|
||||
elif isinstance(node, ast.Expr):
|
||||
ret = False
|
||||
for _, value in ast.iter_fields(node):
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
if isinstance(item, ast.AST):
|
||||
ret = ret or self.contains_return_op(item)
|
||||
elif isinstance(value, ast.AST):
|
||||
ret = ret or self.contains_return_op(value)
|
||||
return ret
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -228,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit(init_node)
|
||||
# initialize function
|
||||
visibility = "public" if self.is_kernel else "private"
|
||||
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility)
|
||||
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
|
||||
self.module.push_back(fn)
|
||||
entry = fn.add_entry_block()
|
||||
arg_values = []
|
||||
@@ -251,9 +293,9 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.set_value(arg_name, arg_value)
|
||||
self.builder.set_insertion_point_to_start(entry)
|
||||
# visit function body
|
||||
has_ret = self.visit_compound_statement(node.body)
|
||||
self.visit_compound_statement(node.body)
|
||||
# finalize function
|
||||
if not has_ret:
|
||||
if self.last_ret_type is None:
|
||||
self.builder.ret([])
|
||||
else:
|
||||
# update return type
|
||||
@@ -265,6 +307,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
fn.reset_type(self.prototype.to_ir(self.builder))
|
||||
if insert_pt:
|
||||
self.builder.set_insertion_point_to_end(insert_pt)
|
||||
# Remove dead code
|
||||
fn.finalize()
|
||||
|
||||
def visit_arguments(self, node):
|
||||
arg_names = []
|
||||
@@ -415,6 +459,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types
|
||||
|
||||
def visit_if_top_level(self, cond, node):
|
||||
has_endif_block = True
|
||||
with enter_sub_region(self) as sr:
|
||||
liveins, ip_block = sr
|
||||
then_block = self.builder.create_block()
|
||||
@@ -429,20 +474,25 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
self.visit_then_else_blocks(node, liveins, then_block, else_block)
|
||||
# then terminator
|
||||
self.builder.set_insertion_point_to_end(then_block)
|
||||
if not then_block.has_terminator():
|
||||
if then_block.has_return() and else_block.has_return():
|
||||
has_endif_block = False
|
||||
endif_block.erase()
|
||||
if not then_block.has_terminator() and has_endif_block:
|
||||
self.builder.create_branch(endif_block, [then_defs[n].handle for n in names])
|
||||
# else terminator
|
||||
self.builder.set_insertion_point_to_end(else_block)
|
||||
if not else_block.has_terminator():
|
||||
if not else_block.has_terminator() and has_endif_block:
|
||||
self.builder.create_branch(endif_block, [else_defs[n].handle for n in names])
|
||||
for ty in ir_ret_types:
|
||||
endif_block.add_argument(ty)
|
||||
# change block
|
||||
self.builder.set_insertion_point_to_start(endif_block)
|
||||
# update value
|
||||
for i, name in enumerate(names):
|
||||
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
|
||||
self.set_value(name, new_tensor)
|
||||
if has_endif_block:
|
||||
for ty in ir_ret_types:
|
||||
endif_block.add_argument(ty)
|
||||
if has_endif_block:
|
||||
# change block
|
||||
self.builder.set_insertion_point_to_start(endif_block)
|
||||
# update value
|
||||
for i, name in enumerate(names):
|
||||
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
|
||||
self.set_value(name, new_tensor)
|
||||
|
||||
# TODO: refactor
|
||||
def visit_if_scf(self, cond, node):
|
||||
@@ -650,6 +700,8 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
ub = language.core._to_tensor(ub, self.builder)
|
||||
step = language.core._to_tensor(step, self.builder)
|
||||
# induction variable type
|
||||
if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
|
||||
raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
|
||||
iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
|
||||
iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype)
|
||||
iv_ir_type = iv_type.to_ir(self.builder)
|
||||
@@ -773,7 +825,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not self.module.has_function(fn_name):
|
||||
prototype = language.function_type([], arg_types)
|
||||
gscope = sys.modules[fn.fn.__module__].__dict__
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
|
||||
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=fn.debug, noinline=fn.noinline)
|
||||
generator.visit(fn.parse())
|
||||
callee_ret_type = generator.last_ret_type
|
||||
self.function_ret_types[fn_name] = callee_ret_type
|
||||
@@ -805,6 +857,7 @@ class CodeGenerator(ast.NodeVisitor):
|
||||
if not self.debug:
|
||||
return
|
||||
if isinstance(fn, JITFunction):
|
||||
_check_fn_args(node, fn, args)
|
||||
return self.call_JitFunction(fn, args, kws)
|
||||
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
|
||||
extra_kwargs = dict(_builder=self.builder)
|
||||
|
||||
@@ -11,8 +11,6 @@ from collections import namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import triton
|
||||
import triton._C.libtriton.triton as _triton
|
||||
from ..runtime import driver
|
||||
@@ -254,7 +252,7 @@ def convert_type_repr(x):
|
||||
return x
|
||||
|
||||
|
||||
def make_hash(fn, **kwargs):
|
||||
def make_hash(fn, arch, **kwargs):
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
configs = kwargs["configs"]
|
||||
signature = kwargs["signature"]
|
||||
@@ -265,7 +263,7 @@ def make_hash(fn, **kwargs):
|
||||
# Get unique key for the compiled code
|
||||
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
|
||||
configs_key = [get_conf_key(conf) for conf in configs]
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}"
|
||||
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}"
|
||||
return hashlib.md5(key.encode("utf-8")).hexdigest()
|
||||
assert isinstance(fn, str)
|
||||
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
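
# Illustrative sketch of why `arch` is folded into the key (hypothetical
# helper, not the real make_hash): different architectures now hash to
# different cache entries, so cubins built for one GPU are not reused on
# another.
#
#   import hashlib
#   def toy_key(fn_key, arch):
#       return hashlib.md5(f"{fn_key}-{arch}".encode("utf-8")).hexdigest()
#   assert toy_key("kernel-sig-cfg", 80) != toy_key("kernel-sig-cfg", 90)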
|
||||
@@ -330,6 +328,10 @@ def is_hip():
|
||||
|
||||
|
||||
def get_architecture_descriptor(capability):
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
raise ImportError("Triton requires PyTorch to be installed")
|
||||
if capability is None:
|
||||
if torch.version.hip is None:
|
||||
device = triton.runtime.jit.get_current_device()
|
||||
@@ -428,7 +430,7 @@ def compile(fn, **kwargs):
|
||||
# cache manager
|
||||
so_path = make_stub(name, signature, constants)
|
||||
# create cache manager
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, **kwargs))
|
||||
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
|
||||
# determine name and extension type of provided function
|
||||
if isinstance(fn, triton.runtime.JITFunction):
|
||||
name, ext = fn.__name__, "ast"
|
||||
|
||||
9
python/triton/debugger/core.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from typing import Tuple
|
||||
|
||||
import dataclasses
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ExecutionContext:
|
||||
program_id: Tuple[int]
|
||||
program_size: Tuple[int]
|
||||
170
python/triton/debugger/debugger.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import itertools
|
||||
import random
|
||||
from typing import Tuple
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
|
||||
debugger_constexpr)
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
tl_method_backup = {}
|
||||
|
||||
|
||||
def get_proxy_method(proxy, name):
|
||||
method = getattr(proxy, name)
|
||||
|
||||
def fun(*args, **kwarg):
|
||||
return method(*args, **kwarg)
|
||||
|
||||
return fun
|
||||
|
||||
|
||||
def attach_triton(module, proxy):
|
||||
method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"]
|
||||
for name in method_list:
|
||||
if hasattr(module, name):
|
||||
attr = getattr(module, name)
|
||||
tl_method_backup[name] = attr
|
||||
if callable(attr):
|
||||
setattr(module, name, get_proxy_method(proxy, name))
|
||||
else:
|
||||
setattr(module, name, getattr(proxy, name))
|
||||
|
||||
|
||||
def detach_triton(module):
|
||||
for name, method in tl_method_backup.items():
|
||||
setattr(module, name, method)
|
||||
|
||||
|
||||
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
|
||||
# reverse the grid dimensions and generate the range for each dimension
|
||||
reversed_grid = reversed(grid)
|
||||
ranges_for_each_dimension = [range(dim) for dim in reversed_grid]
|
||||
|
||||
# gen all combinations
|
||||
index_combinations = list(itertools.product(*ranges_for_each_dimension))
|
||||
random.shuffle(index_combinations)
|
||||
|
||||
for index_combination in index_combinations:
|
||||
yield index_combination
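
# Illustrative usage sketch: the generator yields every program-id combination
# of the grid exactly once, in a randomized order (the property that
# test_program_ids_from_grid checks).
#
#   ids = list(program_ids_from_grid((3, 4)))
#   assert len(ids) == len(set(ids)) == 12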
|
||||
|
||||
|
||||
class DebuggerFunction:
|
||||
def __init__(self, func, grid=(1,)):
|
||||
self.func = func
|
||||
self.grid = grid
|
||||
|
||||
def _is_constexpr(self, name):
|
||||
return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr
|
||||
|
||||
def _get_constexpr(self):
|
||||
result = []
|
||||
for name, annotation in self.func.__annotations__.items():
|
||||
if annotation is triton.language.core.constexpr:
|
||||
result.append(name)
|
||||
return result
|
||||
|
||||
def _assert_constexpr(self, **kwargs):
|
||||
constexp = self._get_constexpr()
|
||||
missing = [i for i in constexp if i not in kwargs.keys()]
|
||||
assert len(missing) == 0, f"You must specify constexpr {missing}"
|
||||
|
||||
def _get_grid(self, **kwargs):
|
||||
if callable(self.grid):
|
||||
return self.grid(kwargs)
|
||||
else:
|
||||
return self.grid
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._assert_constexpr(**kwargs)
|
||||
|
||||
memory = MemoryMap()
|
||||
|
||||
def convert_arg(v):
|
||||
name, arg = v
|
||||
if torch.is_tensor(arg):
|
||||
ptr = memory.add_tensor(arg)
|
||||
return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
|
||||
if self._is_constexpr(name):
|
||||
return debugger_constexpr(arg)
|
||||
return WrappedTensor(_primitive_to_tensor(arg))
|
||||
|
||||
new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
|
||||
new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}
|
||||
|
||||
grid = self._get_grid(**kwargs)
|
||||
for program_id in program_ids_from_grid(grid):
|
||||
proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
|
||||
attach_triton(tl, proxy)
|
||||
self.func(*new_args, **new_kwargs)
|
||||
detach_triton(tl)
|
||||
|
||||
|
||||
class GridSelector:
|
||||
"""
|
||||
Entry point of the debugger
|
||||
"""
|
||||
|
||||
def __init__(self, func):
|
||||
version = torch.__version__
|
||||
assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
|
||||
self.func = func
|
||||
|
||||
def __getitem__(self, grid):
|
||||
return DebuggerFunction(self.func, grid)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return DebuggerFunction(self.func)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneGridSelector:
|
||||
def __init__(self, func, autotune_params):
|
||||
self.func = func
|
||||
self.autotune_params = autotune_params
|
||||
|
||||
def __getitem__(self, grid):
|
||||
return AutotuneRunner(self.func, self.autotune_params, grid)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
|
||||
|
||||
|
||||
class AutotuneRunner:
|
||||
def __init__(self, func, autotune_params, grid=None):
|
||||
self.func = func
|
||||
self.autotune_params = autotune_params
|
||||
self.grid = grid
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert len(self.autotune_params["configs"]) >= 1
|
||||
|
||||
for config in self.autotune_params["configs"][1:]:
|
||||
|
||||
def convert_arg(v):
|
||||
if torch.is_tensor(v):
|
||||
return torch.clone(v)
|
||||
return v
|
||||
|
||||
new_args = tuple(map(convert_arg, args))
|
||||
new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
|
||||
if self.grid:
|
||||
self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
|
||||
else:
|
||||
self.func(*new_args, **new_kwargs, **config.kwargs)
|
||||
|
||||
main_config = self.autotune_params["configs"][0]
|
||||
if self.grid:
|
||||
self.func[self.grid](*args, **kwargs, **main_config.kwargs)
|
||||
else:
|
||||
self.func(*args, **kwargs, **main_config.kwargs)
|
||||
|
||||
|
||||
def triton_debug_autotune(**kwargs):
|
||||
def wrapper(func):
|
||||
return AutotuneGridSelector(func, kwargs)
|
||||
|
||||
return wrapper
|
||||
100
python/triton/debugger/memory_map.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import dataclasses
|
||||
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class RegisteredStorage:
|
||||
storage: torch.Storage
|
||||
dtype: torch.dtype
|
||||
size: int
|
||||
ptr: int
|
||||
|
||||
@property
|
||||
def end_ptr(self) -> int:
|
||||
return self.ptr + self.size
|
||||
|
||||
@property
|
||||
def access_tensor(self) -> torch.Tensor:
|
||||
return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)
|
||||
|
||||
def ensure_immutable(self):
|
||||
assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
|
||||
|
||||
|
||||
class MemoryMap:
|
||||
storages: [RegisteredStorage]
|
||||
|
||||
def __init__(self):
|
||||
self.storages = []
|
||||
|
||||
def _get_registered_storage(self, pointer: torch.Tensor):
|
||||
max_pointer = torch.max(pointer).item()
|
||||
min_pointer = torch.min(pointer).item()
|
||||
|
||||
registered_storage = next(
|
||||
filter(
|
||||
lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
|
||||
),
|
||||
None,
|
||||
)
|
||||
if registered_storage is None:
|
||||
raise Exception("Storage not found or pointers spanning multiple tensors")
|
||||
registered_storage.ensure_immutable()
|
||||
return registered_storage
|
||||
|
||||
def add_tensor(self, t: torch.Tensor):
|
||||
storage = t.untyped_storage()
|
||||
self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
|
||||
return t.data_ptr()
|
||||
|
||||
def load(
|
||||
self,
|
||||
pointer: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
other=0.0,
|
||||
):
|
||||
assert pointer.is_cuda
|
||||
assert 0 < pointer.dim() < 3
|
||||
assert pointer.dtype == torch.int64
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(pointer).bool()
|
||||
assert mask.is_cuda
|
||||
assert 0 < mask.dim() < 3
|
||||
assert mask.dtype == torch.bool
|
||||
mask = mask.expand(pointer.size())
|
||||
|
||||
if torch.all(~mask):
|
||||
# Todo: The type is wrong here, we can't determine the correct type
|
||||
return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")
|
||||
|
||||
registered_storage = self._get_registered_storage(pointer[mask])
|
||||
access_tensor = registered_storage.access_tensor
|
||||
|
||||
index_tensor = pointer - registered_storage.ptr
|
||||
|
||||
block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
|
||||
block[mask] = access_tensor[index_tensor[mask]]
|
||||
return block
|
||||
|
||||
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
|
||||
assert 0 < pointer.dim() < 3
|
||||
assert pointer.dtype == torch.int64
|
||||
|
||||
if mask is None:
|
||||
mask = torch.ones_like(pointer).bool()
|
||||
assert 0 < mask.dim() < 3
|
||||
assert mask.dtype == torch.bool
|
||||
mask = mask.expand(pointer.size())
|
||||
|
||||
if torch.all(~mask):
|
||||
return
|
||||
|
||||
registered_storage = self._get_registered_storage(pointer[mask])
|
||||
access_tensor = registered_storage.access_tensor
|
||||
|
||||
index_tensor = pointer - registered_storage.ptr
|
||||
access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)
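
# Illustrative usage sketch (assumes a CUDA build of torch): a tensor is
# registered once, its data pointer plus element offsets are handed to the
# interpreted kernel, and load/store resolve those pointers back to the
# registered storage.
#
#   mm = MemoryMap()
#   t = torch.arange(8, dtype=torch.float32, device="cuda")
#   base = mm.add_tensor(t)                                    # raw data_ptr
#   ptrs = base + torch.arange(8, device="cuda", dtype=torch.int64)
#   vals = mm.load(ptrs)                                       # equals t
#   mm.store(ptrs, vals + 1.0)                                 # writes into t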
|
||||
621
python/triton/debugger/tl_lang.py
Normal file
@@ -0,0 +1,621 @@
|
||||
import triton
|
||||
from .core import ExecutionContext
|
||||
from .memory_map import MemoryMap
|
||||
from triton.debugger import torch_wrapper
|
||||
|
||||
torch = torch_wrapper.torch
|
||||
|
||||
|
||||
def _primitive_to_tensor(x):
|
||||
"""
|
||||
Converts various Python primitive data types to PyTorch tensor.
|
||||
"""
|
||||
tensor_args = {"device": "cuda"}
|
||||
if isinstance(x, bool):
|
||||
return torch.tensor([x], dtype=torch.bool, **tensor_args)
|
||||
elif isinstance(x, int):
|
||||
if -(2**31) <= x < 2**31:
|
||||
return torch.tensor([x], dtype=torch.int32, **tensor_args)
|
||||
elif -(2**63) <= x < 2**63:
|
||||
return torch.tensor([x], dtype=torch.int64, **tensor_args)
|
||||
else:
|
||||
raise RuntimeError(f"Nonrepresentable integer {x}.")
|
||||
elif isinstance(x, float):
|
||||
return torch.tensor([x], dtype=torch.float32, **tensor_args)
|
||||
elif torch.is_tensor(x):
|
||||
return x
|
||||
elif isinstance(x, WrappedTensor):
|
||||
return x
|
||||
elif isinstance(x, debugger_constexpr):
|
||||
if x.value is None:
|
||||
return None
|
||||
return _primitive_to_tensor(x.value)
|
||||
elif x is None:
|
||||
return None
|
||||
assert False, f"cannot convert {x} of type {type(x)} to tensor"
|
||||
|
||||
|
||||
def _infer_tensor(func):
|
||||
"""
|
||||
A decorator function to harmonize function args:
|
||||
- converts primitives to PyTorch tensors
|
||||
- wraps PyTorch tensors with WrappedTensors
|
||||
"""
|
||||
def wrapper(*args):
|
||||
new_args = tuple(map(lambda v: _primitive_to_tensor(v), args))
|
||||
new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args))
|
||||
|
||||
return func(*new_args)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _tensor_operation(func):
|
||||
"""
|
||||
A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
|
||||
Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor).
|
||||
"""
|
||||
def wrapper(*args, **kwargs):
|
||||
for arg in args:
|
||||
assert not torch.is_tensor(arg), "unexpected tensor argument"
|
||||
|
||||
def unwrap_tensor(v):
|
||||
if isinstance(v, WrappedTensor):
|
||||
return v.tensor
|
||||
if isinstance(v, debugger_constexpr):
|
||||
return v.value
|
||||
return v
|
||||
|
||||
new_args = tuple(map(unwrap_tensor, args))
|
||||
new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}
|
||||
|
||||
result = func(args[0], *new_args[1:], **new_kwargs)
|
||||
return WrappedTensor(result) if torch.is_tensor(result) else result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
class debugger_constexpr:
|
||||
def __init__(self, value):
|
||||
if isinstance(value, debugger_constexpr):
|
||||
self.value = value.value
|
||||
else:
|
||||
self.value = value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "debugger_constexpr(" + str(self.value) + ")"
|
||||
|
||||
def __index__(self) -> int:
|
||||
return self.value
|
||||
|
||||
def __bool__(self):
|
||||
return bool(self.value)
|
||||
|
||||
def __ge__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value >= other
|
||||
|
||||
def __gt__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value > other
|
||||
|
||||
def __le__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value <= other
|
||||
|
||||
def __lt__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value < other
|
||||
|
||||
def __eq__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value == other
|
||||
|
||||
def __or__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value | other
|
||||
|
||||
def __ror__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value | other
|
||||
|
||||
def __and__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value & other
|
||||
|
||||
def __rand__(self, other):
|
||||
other = other.value if isinstance(other, debugger_constexpr) else other
|
||||
return self.value & other
|
||||
|
||||
def to(self, dtype, bitcast=False, _builder=None):
|
||||
if dtype in [torch.int64]:
|
||||
ret_ty = int
|
||||
elif dtype == torch.bool:
|
||||
ret_ty = bool
|
||||
elif dtype in [torch.float64]:
|
||||
ret_ty = float
|
||||
else:
|
||||
raise ValueError("dtype not supported in debugger")
|
||||
return debugger_constexpr(ret_ty(self.value))
|
||||
|
||||
|
||||
class WrappedTensor:
    def __init__(self, tensor):
        self.tensor = tensor

    def __index__(self) -> int:
        return self.tensor.item()

    def __str__(self) -> str:
        return "wrapped_" + str(self.tensor)

    def __bool__(self) -> bool:
        return torch.all(self.tensor == True).item()  # noqa: E712

    @property
    def dtype(self):
        return self.tensor.dtype

    @_infer_tensor
    @_tensor_operation
    def __add__(self, other):
        return torch.add(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __radd__(self, other):
        return self.__add__(other)

    @_infer_tensor
    @_tensor_operation
    def __sub__(self, other):
        return torch.sub(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rsub__(self, other):
        return torch.sub(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mul__(self, other):
        return torch.mul(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmul__(self, other):
        return self.__mul__(other)

    @_infer_tensor
    @_tensor_operation
    def __truediv__(self, other):
        return torch.div(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rtruediv__(self, other):
        return torch.div(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __floordiv__(self, other):
        return torch.floor_divide(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rfloordiv__(self, other):
        return torch.floor_divide(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __mod__(self, other):
        return torch.remainder(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rmod__(self, other):
        return torch.remainder(other, self.tensor)

    @_infer_tensor
    @_tensor_operation
    def __neg__(self):
        return -self.tensor

    @_infer_tensor
    @_tensor_operation
    def __invert__(self):
        return ~self.tensor

    @_infer_tensor
    @_tensor_operation
    def __and__(self, other):
        return torch.bitwise_and(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __or__(self, other):
        return torch.bitwise_or(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __xor__(self, other):
        return torch.bitwise_xor(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __lshift__(self, other):
        return torch.bitwise_left_shift(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __rshift__(self, other):
        return torch.bitwise_right_shift(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __gt__(self, other):
        return self.tensor > other

    @_infer_tensor
    @_tensor_operation
    def __rgt__(self, other):
        return other > self.tensor

    @_infer_tensor
    @_tensor_operation
    def __ge__(self, other):
        return self.tensor >= other

    @_infer_tensor
    @_tensor_operation
    def __rge__(self, other):
        return other >= self.tensor

    @_infer_tensor
    @_tensor_operation
    def __lt__(self, other):
        return self.tensor < other

    @_infer_tensor
    @_tensor_operation
    def __rlt__(self, other):
        return other < self.tensor

    @_infer_tensor
    @_tensor_operation
    def __le__(self, other):
        return self.tensor <= other

    @_infer_tensor
    @_tensor_operation
    def __rle__(self, other):
        return other <= self.tensor

    @_infer_tensor
    @_tensor_operation
    def __eq__(self, other):
        return torch.equal(self.tensor, other)

    @_infer_tensor
    @_tensor_operation
    def __ne__(self, other):
        return not torch.equal(self.tensor, other)

    @_tensor_operation
    def __getitem__(self, slices):
        return self.tensor.__getitem__(slices)
        # if isinstance(slices, slice):
        #     slices = [slices]
        # src_shape = self.shape
        # dst_shape = []
        # curr = 0
        # for sl in slices:
        #     if isinstance(sl, constexpr) and sl.value is None:
        #         dst_shape.append(1)
        #     elif sl == slice(None, None, None):
        #         dst_shape.append(src_shape[curr].value)
        #         curr += 1
        # ret = torch.reshape(self.tensor, dst_shape, )
        # return ret

    @_tensor_operation
    def to(self, dtype, bitcast=False):
        return self.tensor.to(dtype)
        # if isinstance(bitcast, constexpr):
        #     bitcast = bitcast.value
        # if bitcast:
        #     return semantic.bitcast(self, dtype, )
        # return semantic.cast(self, dtype, )


def _constexpr_to_value(v):
    if isinstance(v, debugger_constexpr):
        return v.value
    return v

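# Editor's sketch (not part of the original file): WrappedTensor forwards arithmetic to
# torch and re-wraps tensor results, so kernel-style expressions compose, e.g.
#     a = WrappedTensor(torch.arange(4, device="cuda"))
#     b = (a * 2 + 1) % 3                            # WrappedTensor wrapping tensor([1, 0, 2, 1])
#     _constexpr_to_value(debugger_constexpr(128))   # -> 128
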
class TritonLangProxy:
    _memory_map: MemoryMap
    _context: ExecutionContext

    def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
        self._memory_map = memory_map
        self._context = context

    # Types
    # Removed void, int1, float8, uint16, uint32, uint64, pi32_t

    # constexpr = debugger_constexpr

    # Program functions

    @_tensor_operation
    def load(
        self,
        pointer: torch.Tensor,
        mask: torch.Tensor = None,
        other=0.0,
        cache_modifier="",
        eviction_policy="",
        volatile=False,
    ):
        return self._memory_map.load(pointer, mask, other)

    @_tensor_operation
    def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
        return self._memory_map.store(pointer, value, mask)

    @_tensor_operation
    def program_id(self, axis):
        assert axis < len(self._context.program_id)
        return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")

    @_tensor_operation
    def num_programs(self, axis):
        assert axis < len(self._context.program_size)
        return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")

    @_tensor_operation
    def arange(self, start, end):
        return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")

    @_tensor_operation
    def zeros(self, shape, dtype):
        for i, d in enumerate(shape):
            if not isinstance(d, debugger_constexpr):
                raise TypeError(f"Shape element {i} must have type `constexpr`")
            if not isinstance(d.value, int):
                raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
        shape = [x.value for x in shape]
        if isinstance(dtype, triton.language.core.dtype):
            if dtype.is_fp32():
                dtype = torch.float32
            elif dtype.is_fp16():
                dtype = torch.float16
            elif dtype.is_bf16():
                dtype = torch.bfloat16
            elif dtype.is_int32():
                dtype = torch.int32
            elif dtype.is_int16():
                dtype = torch.int16
            elif dtype.is_int8():
                dtype = torch.int8
            else:
                raise TypeError(f"Unsupported dtype {dtype}")
        return torch.zeros(size=shape, dtype=dtype, device="cuda")

    @_tensor_operation
    def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16):
        raise NotImplementedError()

    @_tensor_operation
    def broadcast(self, input, other):
        raise NotImplementedError()

    @_tensor_operation
    def broadcast_to(self, input, shape):
        raise NotImplementedError()

    @_tensor_operation
    def cat(self, input, shape):
        raise NotImplementedError()

    @_tensor_operation
    def reshape(self, input, shape):
        raise NotImplementedError()

    @_tensor_operation
    def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
        assert input.dtype == other.dtype
        if trans_a:
            input = input.T
        if trans_b:
            other = other.T
        return torch.matmul(input=input, other=other)

    @_tensor_operation
    def atomic_cas(self, pointer, cmp, val):
        stored = self._memory_map.load(pointer, None, 0.0)
        if not isinstance(cmp, torch.Tensor):
            cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
        if not isinstance(val, torch.Tensor):
            val = torch.tensor([val], dtype=stored.dtype, device="cuda")
        if stored == cmp:
            self._memory_map.store(pointer, val, None)
        return stored

    @_tensor_operation
    def atomic_xchg(self, pointer, val, mask=None):
        if isinstance(val, int):
            val = torch.tensor([val], dtype=torch.int32, device="cuda")
        stored = self._memory_map.load(pointer, mask, 0.0)
        self._memory_map.store(pointer, val, mask)
        return stored

    @_tensor_operation
    def atomic_add(self, pointer, val, mask=None):
        # arbitrary `other` value, as it will be masked out during storing
        stored = self._memory_map.load(pointer, mask, 0.0)
        result = stored + val
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_max(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0.0)
        result = torch.maximum(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_min(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0.0)
        result = torch.minimum(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_and(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0)
        result = torch.bitwise_and(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_or(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0)
        result = torch.bitwise_or(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def atomic_xor(self, pointer, val, mask=None):
        stored = self._memory_map.load(pointer, mask, 0)
        result = torch.bitwise_xor(stored, val)
        self._memory_map.store(pointer, result, mask)
        return stored

    @_tensor_operation
    def where(self, condition, x, y):
        condition = _primitive_to_tensor(condition)
        x = _primitive_to_tensor(x)
        y = _primitive_to_tensor(y)
        return torch.where(condition, x, y)

    @_tensor_operation
    def umulhi(self, x, y):
        raise NotImplementedError()

    @_tensor_operation
    def fdiv(self, x, y, ieee_rounding=False):
        raise NotImplementedError()

    @_tensor_operation
    def exp(self, x):
        return torch.exp(x)

    @_tensor_operation
    def log(self, x):
        return torch.log(x)

    @_tensor_operation
    def cos(self, x):
        return torch.cos(x)

    @_tensor_operation
    def sin(self, x):
        return torch.sin(x)

    @_tensor_operation
    def sqrt(self, x):
        return torch.sqrt(x)

    @_tensor_operation
    def globaltimer(self):
        raise NotImplementedError()

    @_tensor_operation
    def clock(self):
        raise NotImplementedError()

    @_tensor_operation
    def debug_barrier(self):
        raise NotImplementedError()

    @_tensor_operation
    def multiple_of(self, input, values):
        return input

    @_tensor_operation
    def max_contiguous(self, input, values):
        return input

    @_tensor_operation
    def abs(self, x):
        return torch.abs(x)

    @_tensor_operation
    def cdiv(self, x, div):
        return (x + div - 1) // div

    @_tensor_operation
    def minimum(self, x, y):
        if isinstance(x, int):
            x = torch.tensor(x, device="cuda")
        if isinstance(y, int):
            y = torch.tensor(y, device="cuda")
        return torch.minimum(x, y)

    @_tensor_operation
    def maximum(self, x, y):
        return torch.maximum(x, y)

    @_tensor_operation
    def sigmoid(self, x):
        raise NotImplementedError()

    @_tensor_operation
    def softmax(self, x, ieee_rounding=False):
        raise NotImplementedError()

    @_tensor_operation
    def ravel(self, x):
        raise NotImplementedError()

    @_tensor_operation
    def swizzle2d(self, i, j, size_i, size_j, size_g):
        raise NotImplementedError()

    @_tensor_operation
    def zeros_like(self, input):
        raise NotImplementedError()

    @_tensor_operation
    def max(self, input, axis=None):
        if axis is None:
            return torch.max(input)
        return torch.max(input, dim=axis).values

    @_tensor_operation
    def argmax(self, input, axis):
        raise NotImplementedError()

    @_tensor_operation
    def min(self, input, axis=None):
        if axis is None:
            return torch.min(input)
        return torch.min(input, dim=axis).values

    @_tensor_operation
    def argmin(self, input, axis):
        raise NotImplementedError()

    @_tensor_operation
    def sum(self, input, axis=None):
        if axis is None:
            return torch.sum(input)
        return torch.sum(input, dim=axis)

    @_tensor_operation
    def xor_sum(self, input, axis):
        raise NotImplementedError()

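# Editor's sketch (assumptions: MemoryMap and ExecutionContext are the debugger-internal
# classes defined elsewhere in this file; exact constructors may differ). TritonLangProxy
# stands in for `tl` while a kernel is interpreted, so inside a kernel body one would write:
#     offs = tl.arange(0, 128) + tl.program_id(0) * 128      # torch.int32 tensor on cuda
#     x = tl.load(x_ptr + offs, mask=offs < n, other=0.0)    # served by MemoryMap.load
#     tl.store(y_ptr + offs, x * 2, mask=offs < n)           # served by MemoryMap.store
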
18  python/triton/debugger/torch_wrapper.py  Normal file
@@ -0,0 +1,18 @@
try:
    import torch as _torch
except ImportError:
    _torch = None


class TorchWrapper:
    """
    Helps in making torch an optional dependency
    """

    def __getattr__(self, name):
        if _torch is None:
            raise ImportError("Triton requires PyTorch to be installed")
        return getattr(_torch, name)


torch = TorchWrapper()

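# Editor's note (illustrative, not part of the diff): the wrapper defers the import error
# to first attribute access, so `import triton` keeps working without PyTorch installed:
#     from triton.debugger.torch_wrapper import torch
#     torch.zeros(4)   # works if PyTorch is installed, otherwise raises
#                      # ImportError("Triton requires PyTorch to be installed")
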
@@ -39,6 +39,7 @@ from .core import (
    dot,
    dtype,
    exp,
    expand_dims,
    full,
    fdiv,
    float16,
@@ -130,6 +131,7 @@ __all__ = [
    "dot",
    "dtype",
    "exp",
    "expand_dims",
    "extra",
    "fdiv",
    "float16",

@@ -3,7 +3,7 @@ from __future__ import annotations
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from typing import Callable, List, TypeVar
from typing import Callable, List, Sequence, TypeVar

import triton
from . import semantic
@@ -903,6 +903,41 @@ def reshape(input, shape, _builder=None):
    shape = _shape_check_impl(shape)
    return semantic.reshape(input, shape, _builder)


def _wrap_axis(axis, ndim):
    if not (-ndim <= axis < ndim):
        raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}")

    return axis if axis >= 0 else axis + ndim


@builtin
def expand_dims(input, axis, _builder=None):
    """
    Expand the shape of a tensor, by inserting new length-1 dimensions.

    Axis indices are with respect to the resulting tensor, so
    ``result.shape[axis]`` will be 1 for each axis.

    :param input: The input tensor.
    :type input: tl.tensor
    :param axis: The indices to add new axes
    :type axis: int | Sequence[int]

    """
    axis = _constexpr_to_value(axis)
    axes = list(axis) if isinstance(axis, Sequence) else [axis]
    new_ndim = len(input.shape) + len(axes)
    axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]

    if len(set(axes)) != len(axes):
        raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}")

    ret = input
    for a in sorted(axes):
        ret = semantic.expand_dims(ret, a, _builder)
    return ret

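# Editor's example (a minimal sketch, not part of the diff, assuming a Triton build that
# includes the tl.expand_dims builtin added above): the row/column shapes below follow
# directly from the docstring.
import torch
import triton
import triton.language as tl


@triton.jit
def outer_kernel(out_ptr, BLOCK: tl.constexpr):
    i = tl.arange(0, BLOCK)
    row = tl.expand_dims(i, 1)                           # shape (BLOCK, 1)
    col = tl.expand_dims(i, 0)                           # shape (1, BLOCK)
    prod = row * col                                     # broadcasts to (BLOCK, BLOCK)
    offs = tl.expand_dims(i, 1) * BLOCK + tl.expand_dims(i, 0)
    tl.store(out_ptr + offs, prod)


out = torch.empty((16, 16), dtype=torch.int32, device="cuda")
outer_kernel[(1,)](out, BLOCK=16)
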
# -----------------------
# Linear Algebra
# -----------------------
@@ -1301,9 +1336,9 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):

    if len(input.shape) > 1:
        # Broadcast index across the non-reduced axes
        expand_dims_index = [constexpr(None)] * len(input.shape)
        expand_dims_index[axis] = slice(None)
        index = index.__getitem__(expand_dims_index, _builder=_builder)
        axes_to_expand = [constexpr(d) for d in range(len(input.shape))]
        del axes_to_expand[axis]
        index = expand_dims(index, axes_to_expand, _builder=_builder)
        index = broadcast_to(index, input.shape, _builder=_builder)

    rvalue, rindices = reduce((input, index), axis, combine_fn,

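# Editor's sketch of what the new _argreduce index broadcasting does, in plain torch
# (illustrative only; `input` and `axis` here are stand-ins, not the tl.tensor objects above):
import torch

input = torch.rand(4, 8)
axis = 1
index = torch.arange(input.shape[axis])       # 1-D index along the reduced axis
for d in range(input.ndim):
    if d != axis:
        index = index.unsqueeze(d)            # expand_dims on every non-reduced axis
index = index.expand(input.shape)             # broadcast_to(input.shape)
assert index.shape == input.shape
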
@@ -3,9 +3,6 @@ from __future__ import annotations  # remove after python 3.11
from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar

import torch

import triton
from . import core as tl
from triton._C.libtriton.triton import ir

@@ -665,7 +662,7 @@ def bitcast(input: tl.tensor,
    src_bits = src_sca_ty.primitive_bitwidth
    dst_bits = dst_sca_ty.primitive_bitwidth
    if src_bits != dst_bits:
        raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to "
        raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
                         "data-type of size " + str(dst_bits))
    return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
                     dst_ty)
@@ -1190,14 +1187,6 @@ def dot(lhs: tl.tensor,
        allow_tf32: bool,
        out_dtype: tl.dtype,
        builder: ir.builder) -> tl.tensor:
    if torch.version.hip is None:
        device = triton.runtime.jit.get_current_device()
        capability = triton.runtime.jit.get_device_capability(device)
        capability = capability[0] * 10 + capability[1]
        if capability < 70:
            assert (
                not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
            ), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
    assert lhs.type.is_block() and rhs.type.is_block()
    assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
    assert len(lhs.shape) == 2 and len(rhs.shape) == 2
@@ -1398,6 +1387,10 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.


def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
    cond_ty = cond.type
    if not cond_ty.is_block():
        cond_ty = tl.block_type(cond_ty.scalar, (1,))
        cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty)
    return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)

@@ -4,17 +4,6 @@ import triton
import triton.language as tl


def next_power_of_2(n):
    n -= 1
    n |= n >> 1
    n |= n >> 2
    n |= n >> 4
    n |= n >> 8
    n |= n >> 16
    n += 1
    return n


def num_warps(N):
    if N < 2048:
        return 4
@@ -24,7 +13,7 @@ def num_warps(N):


@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)
@@ -49,7 +38,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):


@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
    row = tl.program_id(0)

|
||||
class Config:
|
||||
"""
|
||||
An object that represents a possible kernel configuration for the auto-tuner to try.
|
||||
|
||||
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
|
||||
:type meta: dict[Str, Any]
|
||||
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
|
||||
@@ -173,8 +174,10 @@ class Config:
|
||||
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
|
||||
"""
|
||||
Decorator for auto-tuning a :code:`triton.jit`'d function.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.autotune(configs=[
|
||||
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
|
||||
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
|
||||
@@ -223,8 +226,10 @@ def heuristics(values):
|
||||
"""
|
||||
Decorator for specifying how the values of certain meta-parameters may be computed.
|
||||
This is useful for cases where auto-tuning is prohibitevely expensive, or just not applicable.
|
||||
|
||||
.. highlight:: python
|
||||
.. code-block:: python
|
||||
|
||||
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
|
||||
@triton.jit
|
||||
def kernel(x_ptr, x_size, **META):
|
||||
|
||||
@@ -83,7 +83,8 @@ class DependenciesFinder(ast.NodeVisitor):
            finder = DependenciesFinder(func.__globals__, func.src)
            finder.visit(tree)
            func.hash = finder.ret
            self.ret = (self.ret + func.hash).encode("utf-8")
            noinline = str(getattr(func, 'noinline', False))
            self.ret = (self.ret + func.hash + noinline).encode("utf-8")
            self.ret = hashlib.md5(self.ret).hexdigest()

# -----------------------------------------------------------------------------
@@ -175,7 +176,7 @@ class JITFunction(KernelInterface[T]):
                return True
            return False
        divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
        equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
        return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
        # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)

@@ -298,13 +299,13 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
  set_current_device(device)
  if stream is None and not warmup:
    stream = get_cuda_stream(device)
  try:
    bin = cache[device][key]
  bin = cache[device].get(key, None)
  if bin is not None:
    if not warmup:
      bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
    return bin
  # kernel not cached -- compile
  except KeyError:
  else:
    # build dict of constant values
    args = [{args}]
    all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
@@ -334,7 +335,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
        exec(src, scope)
        return scope[self.fn.__name__]

    def __init__(self, fn, version=None, do_not_specialize=None, debug=None):
    def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
        self.fn = fn
        self.module = fn.__module__
        self.version = version
@@ -356,6 +357,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
        self.kernel_decorators = []
        self.kernel = None
        self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
        self.noinline = noinline
        # annotations
        normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
        self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
@@ -425,6 +427,7 @@ def jit(
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
    ...

@@ -435,6 +438,8 @@ def jit(
    version=None,
    do_not_specialize: Optional[Iterable[int]] = None,
    debug: Optional[bool] = None,
    noinline: Optional[bool] = None,
    interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
    """
    Decorator for JIT-compiling a function using the Triton compiler.
@@ -456,13 +461,17 @@ def jit(

    def decorator(fn: T) -> JITFunction[T]:
        assert callable(fn)
        return JITFunction(
            fn,
            version=version,
            do_not_specialize=do_not_specialize,
            debug=debug,
        )

        if interpret:
            from ..debugger.debugger import GridSelector
            return GridSelector(fn)
        else:
            return JITFunction(
                fn,
                version=version,
                do_not_specialize=do_not_specialize,
                debug=debug,
                noinline=noinline,
            )
    if fn is not None:
        return decorator(fn)

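# Editor's sketch (hedged: behavior as added by this diff; `add_kernel`, `add_one` and the
# pointer names are illustrative) showing the two new triton.jit keyword arguments.
import triton
import triton.language as tl


@triton.jit(noinline=True)                 # kept as a real function call (tt.call) in the IR
def add_one(x):
    return x + 1


@triton.jit(interpret=True)                # routed through the Python debugger's GridSelector
def add_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(y_ptr + offs, add_one(x), mask=mask)
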
BIN  python/triton/third_party/cuda/bin/ptxas  vendored  Executable file
Binary file not shown.
@@ -6,8 +6,8 @@
|
||||
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
|
||||
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
|
||||
|
||||
// CHECK-LABEL: matmul_loop
|
||||
// There shouldn't be any aliasing with the dot op encoding.
|
||||
|
||||
@@ -402,8 +402,6 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
|
||||
// This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer.
|
||||
// CHECK-LABEL: @store_constant_align
|
||||
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
|
||||
@@ -433,8 +431,6 @@ tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// This IR is dumped from vecadd test.
|
||||
@@ -491,3 +487,88 @@ tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
|
||||
tt.store %15, %13, %10 : tensor<64xf32>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
|
||||
// We don't use function cloning here, so the alignment info is the gcd of all call sites.
|
||||
// CHECK-LABEL: @addptr_hints
|
||||
tt.func @addptr_hints(%arg0: !tt.ptr<i32>) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
%cst1 = arith.constant 1 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%1 = tt.addptr %arg0, %cst1 : !tt.ptr<i32>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4
|
||||
%cst4 = arith.constant 4 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%2 = tt.addptr %arg0, %cst4 : !tt.ptr<i32>, i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
|
||||
%cst16 = arith.constant 16 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%3 = tt.addptr %arg0, %cst4 : !tt.ptr<i32>, i32
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @kernel_div16
|
||||
tt.func @kernel_div16(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @kernel_div8
|
||||
tt.func @kernel_div8(%arg0: !tt.ptr<i32> {tt.divisibility = 8 : i32}) {
|
||||
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @kernel_div4
|
||||
tt.func @kernel_div4(%arg0: !tt.ptr<i32> {tt.divisibility = 4 : i32}) {
|
||||
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module {
|
||||
|
||||
// We don't use function cloning here, so the alignment info is the gcd of all call sites.
|
||||
// CHECK-LABEL: @mul
|
||||
tt.func @mul(%arg0: i32) {
|
||||
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
|
||||
%cst1 = arith.constant 1 : i32
|
||||
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%1 = arith.muli %arg0, %cst1 : i32
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @bar
|
||||
tt.func @bar(%arg0: i32) {
|
||||
tt.call @mul(%arg0) : (i32) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @foo
|
||||
tt.func @foo(%arg0: i32) {
|
||||
tt.call @mul(%arg0) : (i32) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @call_graph
|
||||
tt.func @call_graph(%arg0: i32) {
|
||||
// CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 12
|
||||
%cst12 = arith.constant 12 : i32
|
||||
// CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
|
||||
%0 = arith.muli %arg0, %cst12 : i32
|
||||
tt.call @foo(%0) : (i32) -> ()
|
||||
// CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = 8
|
||||
%cst8 = arith.constant 8 : i32
|
||||
// CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
|
||||
%1 = arith.muli %arg0, %cst8 : i32
|
||||
tt.call @bar(%1) : (i32) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
|
||||
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
@@ -28,10 +28,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
|
||||
|
||||
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
|
||||
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK: offset = 0, size = 4608
|
||||
// CHECK: scratch offset = 0, size = 4608
|
||||
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
|
||||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
|
||||
// CHECK-NEXT: offset = 0, size = 4224
|
||||
// CHECK-NEXT: scratch offset = 0, size = 4224
|
||||
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
|
||||
|
||||
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
|
||||
@@ -56,17 +56,17 @@ tt.func @reusable(%A : !tt.ptr<f16>) {
|
||||
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
|
||||
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
|
||||
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK-NEXT: offset = 0, size = 4608
|
||||
// CHECK-NEXT: scratch offset = 0, size = 4608
|
||||
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
|
||||
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||
// CHECK-NEXT: offset = 0, size = 1152
|
||||
// CHECK-NEXT: scratch offset = 0, size = 1152
|
||||
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
|
||||
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
|
||||
// CHECK-NEXT: offset = 0, size = 4608
|
||||
// CHECK-NEXT: scratch offset = 0, size = 4608
|
||||
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
|
||||
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
|
||||
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
|
||||
// CHECK-NEXT: offset = 0, size = 1152
|
||||
// CHECK-NEXT: scratch offset = 0, size = 1152
|
||||
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
|
||||
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
|
||||
tt.return
|
||||
@@ -396,3 +396,127 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: alloc1
|
||||
tt.func @alloc1(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc2
|
||||
tt.func @alloc2(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 1024
|
||||
%cst0 = triton_gpu.alloc_tensor : tensor<32x16xf16, #A_SHARED>
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 1024
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc3
|
||||
tt.func @alloc3(%cond : i1) {
|
||||
scf.if %cond {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
} else {
|
||||
// CHECK-NEXT: offset = 0, size = 1024
|
||||
%cst0 = triton_gpu.alloc_tensor : tensor<16x32xf16, #A_SHARED>
|
||||
}
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 1024
|
||||
}
|
||||
|
||||
// CHECK-LABEL: alloc4
|
||||
tt.func @alloc4(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
scf.if %cond {
|
||||
// CHECK: virtual offset = 0, size = 1024
|
||||
tt.call @alloc3(%cond) : (i1) -> ()
|
||||
} else {
|
||||
// CHECK-NEXT: virtual offset = 0, size = 512
|
||||
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
|
||||
}
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 1024
|
||||
}
|
||||
|
||||
// CHECK-LABEL: single_call
|
||||
tt.func @single_call(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
// CHECK-NEXT: virtual offset = 0, size = 512
|
||||
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
// CHECK-LABEL: multiple_calls
|
||||
tt.func @multiple_calls(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: virtual offset = 0, size = 512
|
||||
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
// CHECK-NEXT: virtual offset = 0, size = 1024
|
||||
tt.call @alloc2(%A) : (!tt.ptr<f16>) -> ()
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 1024
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_else_calls
|
||||
tt.func @if_else_calls(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
scf.if %cond {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: offset = 0, size = 1024
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
// CHECK-NEXT: virtual offset = 0, size = 512
|
||||
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
|
||||
} else {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
// CHECK-NEXT: virtual offset = 0, size = 1024
|
||||
tt.call @alloc2(%A) : (!tt.ptr<f16>) -> ()
|
||||
}
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 1024
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_calls
|
||||
tt.func @for_calls(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
%lb = arith.constant 0 : index
|
||||
%ub = arith.constant 10 : index
|
||||
%step = arith.constant 1 : index
|
||||
scf.for %iv = %lb to %ub step %step {
|
||||
// CHECK-NEXT: virtual offset = 0, size = 512
|
||||
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
|
||||
}
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
|
||||
// CHECK-LABEL: call_graph_1
|
||||
tt.func @call_graph_1(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: virtual offset = 0, size = 1024
|
||||
tt.call @alloc3(%cond) : (i1) -> ()
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 1024
|
||||
}
|
||||
|
||||
// CHECK-LABEL: call_graph_2
|
||||
tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK-NEXT: virtual offset = 0, size = 1024
|
||||
tt.call @alloc4(%A, %cond) : (!tt.ptr<f16>, i1) -> ()
|
||||
tt.return
|
||||
// CHECK-NEXT: size = 1024
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
|
||||
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
|
||||
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
|
||||
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
|
||||
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
@@ -503,3 +503,136 @@ tt.func @cf_if_else_return(%i1 : i1) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
// CHECK-LABEL: convert_layout1
|
||||
tt.func @convert_layout1(%A : !tt.ptr<f16>) {
|
||||
// CHECK-NOT: gpu.barrier
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: convert_layout2
|
||||
tt.func @convert_layout2(%A : !tt.ptr<f16>) {
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
%1 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
%2 = tt.cat %1, %1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK-NEXT: gpu.barrier
|
||||
%3 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
%4 = tt.cat %2, %2 {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #AL>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: convert_layout3
|
||||
tt.func @convert_layout3(%cond : i1) {
|
||||
scf.if %cond {
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x64xf16, #A_SHARED>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK-NOT: gpu.barrier
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<16x64xf16, #A_SHARED>) -> tensor<16x64xf16, #AL>
|
||||
} else {
|
||||
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK-NEXT: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
%2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
|
||||
}
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: convert_layout4
|
||||
tt.func @convert_layout4(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
// CHECK-NOT: gpu.barrier
|
||||
scf.if %cond {
|
||||
tt.call @convert_layout3(%cond) : (i1) -> ()
|
||||
} else {
|
||||
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
|
||||
}
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: single_call_sync
|
||||
tt.func @single_call_sync(%A : !tt.ptr<f16>) {
|
||||
%0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
// CHECK: tt.call
|
||||
// CHECK-NEXT: gpu.barrier
|
||||
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<16x32xf16, #AL>) -> tensor<16x32xf16, #BL>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: single_call_no_sync
|
||||
// %1 can reuse %0 in convert_layout2, which has been synced
|
||||
tt.func @single_call_no_sync(%A : !tt.ptr<f16>) {
|
||||
// CHECK-NOT: gpu.barrier
|
||||
%0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
|
||||
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #BL>
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: multiple_calls
|
||||
tt.func @multiple_calls(%A : !tt.ptr<f16>) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: if_else_calls
|
||||
tt.func @if_else_calls(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
scf.if %cond {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.call
|
||||
// CHECK-NEXT: gpu.barrier
|
||||
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
|
||||
} else {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
// CHECK: tt.call
|
||||
// CHECK-NOT: gpu.barrier
|
||||
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
|
||||
}
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: for_calls
|
||||
tt.func @for_calls(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
|
||||
%lb = arith.constant 0 : index
|
||||
%ub = arith.constant 10 : index
|
||||
%step = arith.constant 1 : index
|
||||
scf.for %iv = %lb to %ub step %step {
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.call
|
||||
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
|
||||
}
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: call_graph_1
|
||||
tt.func @call_graph_1(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: tt.call
|
||||
tt.call @convert_layout3(%cond) : (i1) -> ()
|
||||
tt.return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: call_graph_2
|
||||
tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
|
||||
tt.call @convert_layout4(%A, %cond) : (!tt.ptr<f16>, i1) -> ()
|
||||
// CHECK: tt.call
|
||||
// CHECK-NEXT: gpu.barrier
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
|
||||
tt.return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1132,8 +1132,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
|
||||
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
|
||||
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// PTX-LABEL: convert_dot
|
||||
// This test is disabled for GCN target, because it is PTX specific
|
||||
@@ -1250,7 +1250,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
// CHECK-LABEL: convert_blocked1d_to_slice1
|
||||
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
|
||||
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||
// CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
|
||||
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
|
||||
tt.return
|
||||
}
|
||||
@@ -1279,8 +1279,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-LABEL: matmul_kernel_dot_operand_layout
|
||||
// This test is disabled for GCN target, because it is PTX specific
|
||||
@@ -1357,8 +1357,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
|
||||
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
|
||||
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
|
||||
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
|
||||
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-LABEL: matmul_tf32dot
|
||||
// This test is disabled for GCN target, because it is PTX specific
|
||||
@@ -1366,12 +1366,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
|
||||
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
|
||||
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
|
||||
<<<<<<< HEAD
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
// PTX-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>)
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
// PTX-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>)
|
||||
=======
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
// CHECK-SAME: (i32, i32, i32, i32)
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
|
||||
// CHECK-SAME: (i32, i32, i32, i32)
|
||||
>>>>>>> openai/main
|
||||
%a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
|
||||
%b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
|
||||
|
||||
@@ -1399,6 +1408,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32
|
||||
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
<<<<<<< HEAD
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
|
||||
// PTX: llvm.icmp "slt"
|
||||
@@ -1406,6 +1416,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
=======
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
>>>>>>> openai/main
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
||||
tt.return
|
||||
}
|
||||
@@ -1416,11 +1432,18 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32_scalar
|
||||
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
|
||||
<<<<<<< HEAD
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
|
||||
// PTX: llvm.icmp "eq"
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
=======
|
||||
// CHECK: llvm.icmp "eq"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$3 atom.global.gpu.add.f32
|
||||
>>>>>>> openai/main
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
tt.return
|
||||
}
|
||||
@@ -1432,6 +1455,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32
|
||||
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
|
||||
<<<<<<< HEAD
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
|
||||
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
|
||||
@@ -1440,6 +1464,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// PTX-SAME: @$2 st.global.b32
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$2 st.global.b32
|
||||
=======
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
>>>>>>> openai/main
|
||||
tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
|
||||
tt.return
|
||||
}
|
||||
@@ -1450,11 +1480,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32_scalar
|
||||
tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
|
||||
<<<<<<< HEAD
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
|
||||
// PTX: llvm.icmp "slt"
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$2 st.global.b32
|
||||
=======
|
||||
// CHECK: llvm.icmp "eq"
|
||||
// CHECK: llvm.inline_asm
|
||||
// CHECK-SAME: @$2 st.global.b32
|
||||
>>>>>>> openai/main
|
||||
tt.store %arg0, %arg1 : f32
|
||||
tt.return
|
||||
}
|
||||
|
||||
22  test/Target/tritongpu_to_llvmir_noinline.mlir  Normal file
@@ -0,0 +1,22 @@
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s

// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
// CHECK: define void @test_func
// CHECK: define void @test_kernel
// CHECK: tail call void @test_func

module attributes {"triton_gpu.num-warps" = 4 : i32} {

tt.func @test_func(%lb : index, %A : !tt.ptr<f16>) attributes { noinline = true } {
  %0 = arith.constant 1.0 : f16
  tt.store %A, %0 : f16
  tt.return
}

tt.func @test_kernel(%lb : index, %A : !tt.ptr<f16>) {
  tt.call @test_func(%lb, %A) : (index, !tt.ptr<f16>) -> ()
  tt.return
}

}
52  test/TritonGPU/dot-operands.mlir  Normal file
@@ -0,0 +1,52 @@
|
||||
// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -tritongpu-remove-layout-conversions -canonicalize | FileCheck %s

#Cv2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#Av2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}>
#Bv2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}>
#Cv1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}>
#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}>
#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>

// CHECK: tt.func @push_elementwise1
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[C:.*]] = tt.dot %[[AF16]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
tt.func @push_elementwise1(
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av2>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}

// CHECK: tt.func @push_elementwise2
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
tt.func @push_elementwise2(
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv1>) -> tensor<16x16xf32, #Cv1>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av1>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv1>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
tt.return %newc : tensor<16x16xf32, #Cv1>
}
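Both functions above load A as i8, reinterpret it as f8E5M2, widen it to f16, and only then convert into the dot-operand layout. The CHECK lines capture what -tritongpu-optimize-dot-operands is expected to do with that chain: for the MMA v2 operand (kWidth = 2) the layout conversion is pulled up to the raw load so the bitcast and fp_to_fp already run in the dot-operand layout, while for MMA v1 the conversion stays after the elementwise ops. A condensed sketch of the rewritten v2 sequence, paraphrasing the CHECK lines rather than quoting actual pass output (value names are illustrative):

// Sketch only: the order the push_elementwise1 CHECK lines require.
%raw = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
%cvt = triton_gpu.convert_layout %raw : (tensor<16x16xi8, #AL>) -> tensor<16x16xi8, #Av2>
%f8  = tt.bitcast %cvt : tensor<16x16xi8, #Av2> -> tensor<16x16xf8E5M2, #Av2>
%f16 = tt.fp_to_fp %f8 : tensor<16x16xf8E5M2, #Av2> -> tensor<16x16xf16, #Av2>
%acc = tt.dot %f16, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>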
@@ -8,8 +8,8 @@
#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}>
#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>

// CHECK: tt.func @matmul_loop
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32

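The only functional change in this hunk is the explicit kWidth on the dot-operand encodings used by the pipelining test. As far as this diff shows, kWidth is a new parameter of #triton_gpu.dot_op for MMA v2 parents (the v1 encodings in dot-operands.mlir above omit it). The sketch below merely restates the attribute syntax; reading kWidth as the number of contiguous elements a thread holds along K is an assumption on my part, not something stated in the diff:

#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#opA = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>  // A operand of a tt.dot
#opB = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>  // B operand of a tt.dot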
@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-prefetch | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-prefetch -canonicalize | FileCheck %s

// 4 warps
// matmul: 128x32 @ 32x128 -> 128x128
@@ -7,33 +7,36 @@
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>


// CHECK: tt.func @matmul_loop
// CHECK: tt.func @matmul_loop_mixed
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128]
// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.convert_layout %[[B0_PREFETCH_SMEM]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_PREFETCH]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_a0]][0, 16] [128, 16]
// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.convert_layout %[[A_REM_SMEM]]
// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]]
// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_b0]][16, 0] [16, 128]
// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.convert_layout %[[B_REM_SMEM]]
// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}}
// CHECK: tt.dot %[[A_REM]], %[[B_REM]], %[[D_FIRST:.*]]
// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]]
// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [128, 16]
// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_A_PREFETCH_SMEM]]
// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]]
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]]
tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f8E5M2>, %B : !tt.ptr<f16>) -> tensor<128x128xf32, #C>{
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f8E5M2>) -> tensor<128x32x!tt.ptr<f8E5M2>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>

%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E5M2, #AL>
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
%b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL>
%c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
@@ -41,24 +44,25 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

%a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL>
%a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A>
%b_ = tt.load %b_ptr_init, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b_init = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) {
%a_op = triton_gpu.convert_layout %a : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A_OP>
%loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) {
%a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP>
%a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP>
%b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>

%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
%next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL>
%next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A>
%next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%next_b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>

scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
}
tt.return
tt.return %loop#4 : tensor<128x128xf32, #C>
}

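Rewritten as matmul_loop_mixed, the pipelined loop now exercises mixed precision: A stays in f8E5M2 through the load, shared-memory prefetch, and loop-carried values, and is widened to f16 by tt.fp_to_fp only right before each tt.dot, which is what the *_CVT CHECK lines above pin down. A condensed sketch of the per-iteration pattern the checks describe, reusing the encodings defined in the test (value names are illustrative):

// Sketch only: widen the prefetched fp8 A tile just before the dot.
%a_f8  = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP>
%a_f16 = tt.fp_to_fp %a_f8 : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP>
%b_op  = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
%acc   = tt.dot %a_f16, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>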
@@ -5,6 +5,7 @@ add_mlir_library(TritonTestAnalysis
TestMembar.cpp

LINK_LIBS PUBLIC
MLIRPass
TritonAnalysis
${dialect_libs}
)

@@ -6,7 +6,7 @@ using namespace mlir;
namespace {

struct TestAllocationPass
: public PassWrapper<TestAllocationPass, OperationPass<triton::FuncOp>> {
: public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);

@@ -16,31 +16,37 @@ struct TestAllocationPass
}

void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
ModuleOp moduleOp = getOperation();
// Convert to std::string can remove quotes from opName
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
Allocation allocation(operation);
operation->walk([&](Operation *op) {
auto scratchBufferId = allocation.getBufferId(op);
if (scratchBufferId != Allocation::InvalidBufferId) {
size_t offset = allocation.getOffset(scratchBufferId);
size_t size = allocation.getAllocatedSize(scratchBufferId);
os << "scratch offset = " << offset << ", size = " << size << "\n";
}
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
auto bufferId = allocation.getBufferId(result);
if (bufferId != Allocation::InvalidBufferId) {
size_t offset = allocation.getOffset(bufferId);
size_t size = allocation.getAllocatedSize(bufferId);
os << "offset = " << offset << ", size = " << size << "\n";
ModuleAllocation moduleAllocation(moduleOp);
moduleOp.walk([&](triton::FuncOp funcOp) {
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << opName << "\n";
auto *allocation = moduleAllocation.getFuncData(funcOp);
funcOp.walk([&](Operation *op) {
auto scratchBufferId = allocation->getBufferId(op);
if (scratchBufferId != Allocation::InvalidBufferId) {
size_t offset = allocation->getOffset(scratchBufferId);
size_t size = allocation->getAllocatedSize(scratchBufferId);
if (allocation->isVirtualBuffer(scratchBufferId))
os << "virtual offset = " << offset << ", size = " << size << "\n";
else
os << "scratch offset = " << offset << ", size = " << size << "\n";
}
}
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
auto bufferId = allocation->getBufferId(result);
if (bufferId != Allocation::InvalidBufferId) {
size_t offset = allocation->getOffset(bufferId);
size_t size = allocation->getAllocatedSize(bufferId);
os << "offset = " << offset << ", size = " << size << "\n";
}
}
});
os << "size = " << allocation->getSharedMemorySize() << "\n";
});
os << "size = " << allocation.getSharedMemorySize() << "\n";
}
};


@@ -7,7 +7,7 @@ using namespace mlir;
namespace {

struct TestAxisInfoPass
: public PassWrapper<TestAxisInfoPass, OperationPass<triton::FuncOp>> {
: public PassWrapper<TestAxisInfoPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);

@@ -18,23 +18,24 @@ struct TestAxisInfoPass

void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << "@" << opName << "\n";

std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(operation)))
return signalPassFailure();
operation->walk([&](Operation *op) {
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
result.print(os);
os << " => ";
analysis->getLatticeElement(result)->getValue().print(os);
os << "\n";
}
ModuleOp moduleOp = cast<ModuleOp>(operation);
ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp);
moduleOp.walk([&](triton::FuncOp funcOp) {
auto &os = llvm::errs();
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << "@" << opName << "\n";
funcOp.walk([&](Operation *op) {
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
result.print(os);
os << " => ";
auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result);
if (axisInfo)
axisInfo->print(os);
os << "\n";
}
});
});
}
};

@@ -11,7 +11,7 @@ using namespace mlir;
namespace {

struct TestMembarPass
: public PassWrapper<TestMembarPass, OperationPass<triton::FuncOp>> {
: public PassWrapper<TestMembarPass, OperationPass<ModuleOp>> {

MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);

@@ -22,17 +22,11 @@ struct TestMembarPass

void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
// Convert to std::string can remove quotes from op_name
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";

ModuleOp moduleOp = cast<ModuleOp>(operation);
// Print all ops after membar pass
Allocation allocation(operation);
MembarAnalysis membarPass(&allocation);
ModuleAllocation allocation(moduleOp);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();

os << *operation << "\n";
}
};


@@ -1,5 +1,8 @@
add_triton_ut(
NAME TestTritonAnalysis
SRCS UtilityTest.cpp
LIBS TritonAnalysis
LIBS
TritonAnalysis
TritonIR
TritonGPUIR
)

@@ -30,7 +30,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
// create encoding
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1});
auto encoding =
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent, 0);

// create element type
Type eltType = IntegerType::get(&ctx, params.typeWidth);

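The swizzle unit test keeps the same MMA parent but now passes a trailing 0 to DotOperandEncodingAttr::get, matching the extra kWidth parameter added throughout this change; reading 0 as "kWidth not specified" is an assumption here, made by analogy with the v1 encodings in dot-operands.mlir that omit the field. In textual IR the two variants this diff exercises look like:

#mmaV1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}>
#mmaV2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#aV1 = #triton_gpu.dot_op<{opIdx = 0, parent = #mmaV1}>             // no kWidth, as in the v1 tests
#aV2 = #triton_gpu.dot_op<{opIdx = 0, parent = #mmaV2, kWidth = 2}> // explicit kWidth for v2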