Merge remote-tracking branch 'openai/main' into IFU-230517

Conflicts:
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/test/unit/language/assert_helper.py
	python/triton/third_party/cuda/bin/ptxas
	test/Conversion/tritongpu_to_llvm.mlir

Author: Jason Furmanek
Date: 2023-05-17 15:03:42 +00:00
99 changed files with 4561 additions and 1251 deletions

View File

@@ -25,9 +25,9 @@ jobs:
id: set-matrix
run: |
if [ x"${{ github.repository }}" == x"openai/triton" ]; then
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"], "macos-10.15"]'
echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"]]'
else
echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]'
echo '::set-output name=matrix::["ubuntu-latest"]'
fi
Integration-Tests:
@@ -101,6 +101,18 @@ jobs:
cd python/test/unit
python3 -m pytest
- name: Create artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
run: |
tar -czvf artifacts.tar.gz ~/.triton/cache
- name: Upload artifacts archive
if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}}
uses: actions/upload-artifact@v2
with:
name: artifacts
path: artifacts.tar.gz
- name: Run CXX unittests
if: ${{ env.BACKEND != 'ROCM'}}
run: |

View File

@@ -30,7 +30,7 @@ jobs:
#export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest"
export CIBW_BEFORE_BUILD="pip install cmake;"
export CIBW_SKIP="{cp,pp}35-*"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64"
export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64 cp3*-musllinux_x86_64"
python3 -m cibuildwheel python --output-dir wheelhouse

.gitmodules (vendored, 3 lines changed)
View File

@@ -1,3 +0,0 @@
[submodule "deps/dlfcn-win32"]
path = deps/dlfcn-win32
url = https://github.com/dlfcn-win32/dlfcn-win32.git

View File

@@ -49,12 +49,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
# Third-party
include_directories(${PYBIND11_INCLUDE_DIR})
if(WIN32)
SET(BUILD_SHARED_LIBS OFF)
find_package(dlfcn-win32 REQUIRED)
set(CMAKE_DL_LIBS dlfcn-win32::dl)
endif()
set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
if (TRITON_USE_ROCM)

View File

@@ -34,6 +34,7 @@ Shape Manipulation Ops
:nosignatures:
broadcast_to
expand_dims
reshape
ravel

View File

@@ -1,6 +1,7 @@
#ifndef TRITON_ANALYSIS_ALLOCATION_H
#define TRITON_ANALYSIS_ALLOCATION_H
#include "triton/Analysis/Utility.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SetVector.h"
@@ -49,18 +50,25 @@ private:
T End = std::numeric_limits<T>::max();
};
template <class T> Interval(T, T) -> Interval<T>;
class Allocation {
public:
/// A unique identifier for shared memory buffers
using BufferId = size_t;
using BufferIdSetT = DenseSet<BufferId>;
using FuncAllocMapT = CallGraph<Allocation>::FuncDataMapT;
static constexpr BufferId InvalidBufferId =
std::numeric_limits<BufferId>::max();
Allocation() = default;
/// Creates a new Allocation analysis that computes the shared memory
/// information for all associated shared memory values.
Allocation(Operation *operation) : operation(operation) { run(); }
explicit Allocation(Operation *operation) : operation(operation) {}
/// Runs allocation analysis on the given top-level operation.
void run(FuncAllocMapT &funcAllocMap);
/// Returns the operation this analysis was constructed from.
Operation *getOperation() const { return operation; }
@@ -75,6 +83,12 @@ public:
return bufferSet.at(bufferId).size;
}
/// Returns the allocated interval of the given buffer.
Interval<size_t> getAllocatedInterval(BufferId bufferId) const {
auto &buffer = bufferSet.at(bufferId);
return Interval<size_t>(buffer.offset, buffer.offset + buffer.size);
}
/// Returns the buffer id of the given value.
/// This interface only returns the allocated buffer id.
/// If you want to get all the buffer ids that are associated with the given
@@ -104,26 +118,28 @@ public:
BufferId getBufferId(Operation *operation) const {
if (opScratch.count(operation)) {
return opScratch.lookup(operation)->id;
} else if (opVirtual.count(operation)) {
return opVirtual.lookup(operation)->id;
} else {
return InvalidBufferId;
}
}
/// Returns whether the given buffer is a virtual buffer.
bool isVirtualBuffer(BufferId bufferId) const {
return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual;
}
/// Returns the total size of shared memory allocated
size_t getSharedMemorySize() const { return sharedMemorySize; }
bool isIntersected(BufferId lhsId, BufferId rhsId) const {
if (lhsId == InvalidBufferId || rhsId == InvalidBufferId)
return false;
auto lhsBuffer = bufferSet.at(lhsId);
auto rhsBuffer = bufferSet.at(rhsId);
return lhsBuffer.intersects(rhsBuffer);
}
private:
/// A class that represents a shared memory buffer
struct BufferT {
enum class BufferKind { Explicit, Scratch };
/// Explicit: triton_gpu.alloc_tensor
/// Scratch: triton_gpu.convert_layout
/// Virtual: triton.call
enum class BufferKind { Explicit, Scratch, Virtual };
/// MT: thread-safe
inline static std::atomic<BufferId> nextId = 0;
@@ -142,12 +158,6 @@ private:
BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {}
BufferT(BufferKind kind, size_t size, size_t offset)
: kind(kind), id(nextId++), size(size), offset(offset) {}
bool intersects(const BufferT &other) const {
return Interval<size_t>(offset, offset + size)
.intersects(
Interval<size_t>(other.offset, other.offset + other.size));
}
};
/// Op -> Scratch Buffer
@@ -158,8 +168,6 @@ private:
using AliasBufferMapT = llvm::MapVector<Value, llvm::SetVector<BufferT *>>;
/// BufferId -> Buffer
using BufferSetT = std::map<BufferId, BufferT>;
/// Runs allocation analysis on the given top-level operation.
void run();
private:
template <BufferT::BufferKind Kind, typename KeyType, typename... Args>
@@ -168,6 +176,8 @@ private:
bufferSet[buffer.id] = std::move(buffer);
if constexpr (Kind == BufferT::BufferKind::Explicit) {
valueBuffer[key] = &bufferSet[buffer.id];
} else if constexpr (Kind == BufferT::BufferKind::Virtual) {
opVirtual[key] = &bufferSet[buffer.id];
} else {
opScratch[key] = &bufferSet[buffer.id];
}
@@ -178,8 +188,9 @@ private:
}
private:
Operation *operation;
Operation *operation = nullptr;
OpScratchMapT opScratch;
OpScratchMapT opVirtual;
ValueBufferMapT valueBuffer;
AliasBufferMapT aliasBuffer;
BufferSetT bufferSet;
@@ -188,7 +199,53 @@ private:
friend class triton::AllocationAnalysis;
};
template <typename T> Interval(T, T) -> Interval<T>;
/// Static analysis that computes the allocation of shared memory buffers
/// of the entire call graph.
/// The allocation is performed in a post-order walk of the call graph.
/// Each call op is treated like convert_layout that allocates a scratch buffer.
/// At each call, we compute the start offset of the scratch buffer and pass it
/// as an argument to the callee.
class ModuleAllocation : public CallGraph<Allocation> {
public:
using FuncOffsetMapT = DenseMap<FunctionOpInterface, Value>;
explicit ModuleAllocation(ModuleOp moduleOp)
: CallGraph<Allocation>(moduleOp) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp);
if (inserted)
iter->second.run(funcMap);
});
}
size_t getSharedMemorySize() {
size_t size = 0;
for (auto funcOp : getRoots()) {
auto *alloc = getFuncData(funcOp);
size = std::max(size, alloc->getSharedMemorySize());
}
return size;
}
size_t getSharedMemorySize(FunctionOpInterface funcOp) {
return getFuncData(funcOp)->getSharedMemorySize();
}
void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) {
sharedMemoryValue[funcOp] = value;
}
Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) {
return sharedMemoryValue[funcOp];
}
private:
FuncOffsetMapT sharedMemoryValue;
};
} // namespace mlir
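Illustration (not part of this commit): the new getAllocatedInterval accessor and Allocation::isIntersected both describe a buffer as the half-open byte range [offset, offset + size), and two buffers conflict exactly when those ranges overlap; this is the role the removed BufferT::intersects helper used to play. A minimal standalone C++ sketch of that test, using a plain struct instead of the Triton Interval template:

// --- Standalone sketch: half-open interval overlap (not part of this commit) ---
#include <cstddef>
#include <iostream>

struct ByteInterval {
  size_t start = 0;
  size_t end = 0; // one past the last byte, i.e. [start, end)
  bool intersects(const ByteInterval &other) const {
    return start < other.end && other.start < end;
  }
};

// Mirrors what getAllocatedInterval returns for a buffer at (offset, size).
ByteInterval allocatedInterval(size_t offset, size_t size) {
  return {offset, offset + size};
}

int main() {
  ByteInterval a = allocatedInterval(/*offset=*/0, /*size=*/1024);
  ByteInterval b = allocatedInterval(/*offset=*/512, /*size=*/1024);
  ByteInterval c = allocatedInterval(/*offset=*/2048, /*size=*/256);
  std::cout << a.intersects(b) << "\n"; // 1: the byte ranges overlap
  std::cout << a.intersects(c) << "\n"; // 0: disjoint allocations
}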

View File

@@ -286,16 +286,71 @@ public:
AxisInfoAnalysis(DataFlowSolver &solver);
using dataflow::SparseDataFlowAnalysis<
dataflow::Lattice<AxisInfo>>::getLatticeElement;
using FuncAxisInfoMapT = DenseMap<FunctionOpInterface, AxisInfo>;
void visitOperation(Operation *op,
ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
ArrayRef<dataflow::Lattice<AxisInfo> *> results) override;
};
/// Module level axis info analysis based on the call graph, assuming that we
/// do not have recursive functions.
/// Since each function will be called multiple times, we need to
/// calculate the axis info based on the axis info of all the callers.
/// In the future, we can perform optimization using function cloning so that
/// each call site will have unique axis info.
using AxisInfoMapT = DenseMap<Value, AxisInfo>;
class ModuleAxisInfoAnalysis : public CallGraph<AxisInfoMapT> {
public:
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
: CallGraph<AxisInfoMapT>(moduleOp) {
SmallVector<FunctionOpInterface> funcs;
for (auto root : getRoots()) {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order edge walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order node walk callback
[&](FunctionOpInterface funcOp) {
funcs.push_back(funcOp);
funcMap.try_emplace(funcOp, AxisInfoMapT{});
});
}
SetVector<FunctionOpInterface> sortedFuncs(funcs.begin(), funcs.end());
SymbolTableCollection symbolTable;
for (auto funcOp : llvm::reverse(sortedFuncs)) {
initialize(funcOp);
funcOp.walk([&](CallOpInterface callOp) {
auto callee =
dyn_cast<FunctionOpInterface>(callOp.resolveCallable(&symbolTable));
update(callOp, callee);
});
}
}
AxisInfo *getAxisInfo(Value value) {
auto funcOp =
value.getParentRegion()->getParentOfType<FunctionOpInterface>();
auto *axisInfoMap = getFuncData(funcOp);
if (!axisInfoMap) {
return nullptr;
}
auto it = axisInfoMap->find(value);
if (it == axisInfoMap->end()) {
return nullptr;
}
return &(it->second);
}
unsigned getPtrContiguity(Value ptr);
unsigned getPtrAlignment(Value ptr);
unsigned getMaskAlignment(Value mask);
private:
void initialize(FunctionOpInterface funcOp);
void update(CallOpInterface callOp, FunctionOpInterface funcOp);
};
} // namespace mlir

View File

@@ -4,20 +4,75 @@
#include "Allocation.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <set>
namespace mlir {
class OpBuilder;
struct BlockInfo {
using BufferIdSetT = Allocation::BufferIdSetT;
using IntervalSetT = std::set<Interval<size_t>>;
IntervalSetT syncReadIntervals;
IntervalSetT syncWriteIntervals;
BlockInfo() = default;
/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
syncReadIntervals.insert(other.syncReadIntervals.begin(),
other.syncReadIntervals.end());
syncWriteIntervals.insert(other.syncWriteIntervals.begin(),
other.syncWriteIntervals.end());
return *this;
}
/// Returns true if intervals in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other) const {
return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) ||
/*WAR*/
isIntersected(syncReadIntervals, other.syncWriteIntervals) ||
/*WAW*/
isIntersected(syncWriteIntervals, other.syncWriteIntervals);
}
/// Clears the intervals because a barrier is inserted.
void sync() {
syncReadIntervals.clear();
syncWriteIntervals.clear();
}
/// Compares two BlockInfo objects.
bool operator==(const BlockInfo &other) const {
return syncReadIntervals == other.syncReadIntervals &&
syncWriteIntervals == other.syncWriteIntervals;
}
bool operator!=(const BlockInfo &other) const { return !(*this == other); }
private:
bool isIntersected(const IntervalSetT &lhsIntervalSet,
const IntervalSetT &rhsIntervalSet) const {
for (auto &lhs : lhsIntervalSet)
for (auto &rhs : rhsIntervalSet)
if (lhs.intersects(rhs))
return true;
return false;
}
};
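Illustration (not part of this commit): the BlockInfo struct above tracks the not-yet-synced read and write intervals of a block and flags a hazard when a pending write overlaps a new read (RAW), a pending read overlaps a new write (WAR), or two writes overlap (WAW); inserting a barrier then clears both sets via sync(). A standalone sketch of that logic with simplified types:

// --- Standalone sketch: BlockInfo-style hazard test (not part of this commit) ---
#include <cstddef>
#include <iostream>
#include <set>
#include <utility>

using ByteInterval = std::pair<std::size_t, std::size_t>; // [start, end)
using IntervalSet = std::set<ByteInterval>;

bool overlaps(const ByteInterval &a, const ByteInterval &b) {
  return a.first < b.second && b.first < a.second;
}

bool anyOverlap(const IntervalSet &lhs, const IntervalSet &rhs) {
  for (const auto &l : lhs)
    for (const auto &r : rhs)
      if (overlaps(l, r))
        return true;
  return false;
}

struct PendingAccesses {
  IntervalSet reads, writes; // shared-memory intervals not yet synced
  bool needsBarrierBefore(const PendingAccesses &next) const {
    return anyOverlap(writes, next.reads) || // RAW
           anyOverlap(reads, next.writes) || // WAR
           anyOverlap(writes, next.writes);  // WAW
  }
  void sync() { reads.clear(); writes.clear(); } // a barrier clears everything
};

int main() {
  PendingAccesses pending;
  pending.writes.insert({0, 1024}); // a prior write to [0, 1024)
  PendingAccesses next;
  next.reads.insert({512, 768});    // a later read of [512, 768)
  std::cout << pending.needsBarrierBefore(next) << "\n"; // 1 -> insert a barrier
  pending.sync();
  std::cout << pending.needsBarrierBefore(next) << "\n"; // 0 after the barrier
}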
//===----------------------------------------------------------------------===//
// Shared Memory Barrier Analysis
//===----------------------------------------------------------------------===//
class MembarAnalysis {
public:
using FuncBlockInfoMapT = CallGraph<BlockInfo>::FuncDataMapT;
/// Creates a new Membar analysis that generates the shared memory barrier
/// in the following circumstances:
/// - RAW: If a shared memory write is followed by a shared memory read, and
/// their addresses are intersected, a barrier is inserted.
/// - WAR: If a shared memory read is followed by a shared memory read, and
/// - WAR: If a shared memory read is followed by a shared memory write, and
/// their addresses are intersected, a barrier is inserted.
/// The following circumstances do not require a barrier:
/// - WAW: not possible because overlapped memory allocation is not allowed.
@@ -26,75 +81,14 @@ public:
/// a shared memory read. If the temporary storage is written but not read,
/// it is considered as the problem of the operation itself but not the membar
/// analysis.
/// The following circumstances are not considered yet:
/// - Double buffers
/// - N buffers
MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
MembarAnalysis() = default;
explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {}
/// Runs the membar analysis to the given operation, inserts a barrier if
/// necessary.
void run();
void run(FuncBlockInfoMapT &funcBlockInfoMap);
private:
struct BlockInfo {
using BufferIdSetT = Allocation::BufferIdSetT;
BufferIdSetT syncReadBuffers;
BufferIdSetT syncWriteBuffers;
BlockInfo() = default;
BlockInfo(const BufferIdSetT &syncReadBuffers,
const BufferIdSetT &syncWriteBuffers)
: syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) {
}
/// Unions two BlockInfo objects.
BlockInfo &join(const BlockInfo &other) {
syncReadBuffers.insert(other.syncReadBuffers.begin(),
other.syncReadBuffers.end());
syncWriteBuffers.insert(other.syncWriteBuffers.begin(),
other.syncWriteBuffers.end());
return *this;
}
/// Returns true if buffers in two BlockInfo objects are intersected.
bool isIntersected(const BlockInfo &other, Allocation *allocation) const {
return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers,
allocation) ||
/*WAR*/
isIntersected(syncReadBuffers, other.syncWriteBuffers,
allocation) ||
/*WAW*/
isIntersected(syncWriteBuffers, other.syncWriteBuffers,
allocation);
}
/// Clears the buffers because a barrier is inserted.
void sync() {
syncReadBuffers.clear();
syncWriteBuffers.clear();
}
/// Compares two BlockInfo objects.
bool operator==(const BlockInfo &other) const {
return syncReadBuffers == other.syncReadBuffers &&
syncWriteBuffers == other.syncWriteBuffers;
}
bool operator!=(const BlockInfo &other) const { return !(*this == other); }
private:
/// Returns true if buffers in two sets are intersected.
bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs,
Allocation *allocation) const {
return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) {
return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) {
return allocation->isIntersected(lhsId, rhsId);
});
});
}
};
/// Applies the barrier analysis based on the SCF dialect, in which each
/// region has a single basic block only.
/// Example:
@@ -109,18 +103,48 @@ private:
/// op6
/// op7
/// TODO: Explain why we don't use ForwardAnalysis:
void resolve(Operation *operation, OpBuilder *builder);
void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder);
/// Updates the BlockInfo based on the given operation.
void update(Operation *operation, BlockInfo *blockInfo, OpBuilder *builder);
void update(Operation *operation, BlockInfo *blockInfo,
FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder);
/// Collects the successors of the terminator
void visitTerminator(Operation *operation, SmallVector<Block *> &successors);
private:
Allocation *allocation;
DenseMap<Block *, BlockInfo> inputBlockInfoMap;
DenseMap<Block *, BlockInfo> outputBlockInfoMap;
Allocation *allocation = nullptr;
};
/// Postorder traversal on the callgraph to insert membar instructions
/// of each function.
/// Each function maintains a BlockInfo recording the buffers that may still be
/// unsynced when it returns. This way users do not have to explicitly insert
/// membars before and after function calls, though this may be a bit conservative.
class ModuleMembarAnalysis : public CallGraph<BlockInfo> {
public:
ModuleMembarAnalysis(ModuleAllocation *moduleAllocation)
: CallGraph<BlockInfo>(moduleAllocation->getModuleOp()),
moduleAllocation(moduleAllocation) {}
void run() {
walk<WalkOrder::PreOrder, WalkOrder::PostOrder>(
// Pre-order walk callback
[](CallOpInterface callOp, FunctionOpInterface funcOp) {},
// Post-order walk callback
[&](FunctionOpInterface funcOp) {
auto *allocation = moduleAllocation->getFuncData(funcOp);
auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo());
if (inserted) {
MembarAnalysis analysis(allocation);
analysis.run(funcMap);
}
});
}
private:
ModuleAllocation *moduleAllocation;
};
} // namespace mlir

View File

@@ -39,6 +39,10 @@ public:
unsigned getIntraWarpSize();
unsigned getInterWarpSizeWithUniqueData();
unsigned getIntraWarpSizeWithUniqueData();
unsigned getThreadsReductionAxis();
SmallVector<unsigned> getScratchConfigBasic();
@@ -57,8 +61,6 @@ private:
int axis;
};
bool isSharedEncoding(Value value);
bool maybeSharedAllocationOp(Operation *op);
bool maybeAliasOp(Operation *op);
@@ -116,14 +118,153 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy);
SetVector<Operation *>
multiRootTopologicalSort(const SetVector<Operation *> &toSort);
// This uses the topologicalSort above
/// This uses the topologicalSort above
SetVector<Operation *>
multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr,
TransitiveFilter forwardFilter = nullptr);
// Create a basic DataFlowSolver with constant and dead code analysis included.
/// Create a basic DataFlowSolver with constant and dead code analysis included.
std::unique_ptr<DataFlowSolver> createDataFlowSolver();
/// This class represents a call graph for a given ModuleOp and holds
/// data of type T associated with each FunctionOpInterface.
template <typename T> class CallGraph {
public:
using FuncDataMapT = DenseMap<FunctionOpInterface, T>;
/// Constructor that builds the call graph for the given moduleOp.
explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); }
/// Walks the call graph and applies the provided update functions
/// to the edges and nodes.
template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
typename UpdateEdgeFn, typename UpdateNodeFn>
void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) {
DenseSet<FunctionOpInterface> visited;
for (auto root : roots) {
doWalk<UpdateEdgeOrder, UpdateNodeOrder>(root, visited, updateEdgeFn,
updateNodeFn);
}
}
/// Retrieves the data associated with a function
T *getFuncData(FunctionOpInterface funcOp) {
if (funcMap.count(funcOp)) {
return &funcMap[funcOp];
}
return nullptr;
}
/// Getters
ModuleOp getModuleOp() const { return moduleOp; }
SmallVector<FunctionOpInterface> getRoots() const { return roots; }
size_t getNumFunctions() const { return funcMap.size(); }
/// Returns true if the given function is a root.
bool isRoot(FunctionOpInterface funcOp) const {
return llvm::is_contained(roots, funcOp);
}
/// Maps the data and the graph nodes associated with a funcOp to a
/// targetFuncOp.
template <typename FROM, typename TO>
void mapFuncOp(FROM funcOp, TO targetFuncOp) {
// Iterate over graph and replace
for (auto &kv : graph) {
for (auto &edge : kv.second) {
if (edge.second == funcOp) {
edge.second = targetFuncOp;
}
}
}
graph[targetFuncOp] = graph[funcOp];
// Replace in roots
for (auto it = roots.begin(); it != roots.end(); ++it) {
if (*it == funcOp) {
*it = targetFuncOp;
break;
}
}
// Replace in funcMap
funcMap[targetFuncOp] = funcMap[funcOp];
}
/// Maps the graph edges associated with a callOp to a targetCallOp.
template <typename FROM, typename TO>
void mapCallOp(FROM callOp, TO targetCallOp) {
// Iterate over graph and replace
for (auto &kv : graph) {
for (auto &edge : kv.second) {
if (edge.first == callOp) {
edge.first = targetCallOp;
}
}
}
}
private:
void build() {
SymbolTableCollection symbolTable;
DenseSet<FunctionOpInterface> visited;
// Build graph
moduleOp.walk([&](Operation *op) {
auto caller = op->getParentOfType<FunctionOpInterface>();
if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto *callee = callOp.resolveCallable(&symbolTable);
auto funcOp = dyn_cast_or_null<FunctionOpInterface>(callee);
if (funcOp) {
graph[caller].emplace_back(
std::pair<CallOpInterface, FunctionOpInterface>(callOp, funcOp));
visited.insert(funcOp);
}
}
});
// Find roots
moduleOp.walk([&](FunctionOpInterface funcOp) {
if (!visited.count(funcOp)) {
roots.push_back(funcOp);
}
});
}
template <WalkOrder UpdateEdgeOrder = WalkOrder::PreOrder,
WalkOrder UpdateNodeOrder = WalkOrder::PreOrder,
typename UpdateEdgeFn, typename UpdateNodeFn>
void doWalk(FunctionOpInterface funcOp,
DenseSet<FunctionOpInterface> &visited, UpdateEdgeFn updateEdgeFn,
UpdateNodeFn updateNodeFn) {
if (visited.count(funcOp)) {
llvm::report_fatal_error("Cycle detected in call graph");
}
if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) {
updateNodeFn(funcOp);
}
for (auto [callOp, callee] : graph[funcOp]) {
if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) {
updateEdgeFn(callOp, callee);
}
doWalk<UpdateEdgeOrder, UpdateNodeOrder>(callee, visited, updateEdgeFn,
updateNodeFn);
if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) {
updateEdgeFn(callOp, callee);
}
}
if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) {
updateNodeFn(funcOp);
}
visited.erase(funcOp);
}
protected:
ModuleOp moduleOp;
DenseMap<FunctionOpInterface,
SmallVector<std::pair<CallOpInterface, FunctionOpInterface>>>
graph;
FuncDataMapT funcMap;
SmallVector<FunctionOpInterface> roots;
};
} // namespace mlir
#endif // TRITON_ANALYSIS_UTILITY_H
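Illustration (not part of this commit): CallGraph<T>::doWalk above is a depth-first traversal from each root that can fire edge and node callbacks in pre- or post-order and reports a fatal error on recursion; the module analyses in this commit use the post-order node callback so that callee data exists before the caller is processed. A standalone sketch of that walk, with strings standing in for FunctionOpInterface:

// --- Standalone sketch: post-order call-graph walk (not part of this commit) ---
#include <cassert>
#include <functional>
#include <iostream>
#include <map>
#include <set>
#include <string>
#include <vector>

using Graph = std::map<std::string, std::vector<std::string>>;

// Post-order DFS from a root, with the same cycle check as CallGraph::doWalk:
// a node already on the current path means a recursive call graph.
void postOrderWalk(const std::string &func, const Graph &graph,
                   std::set<std::string> &onPath,
                   const std::function<void(const std::string &)> &visitNode) {
  assert(!onPath.count(func) && "Cycle detected in call graph");
  onPath.insert(func);
  auto it = graph.find(func);
  if (it != graph.end())
    for (const auto &callee : it->second)
      postOrderWalk(callee, graph, onPath, visitNode);
  onPath.erase(func);
  visitNode(func); // callees are visited before their callers
}

int main() {
  // kernel calls helperA and helperB; helperA also calls helperB.
  Graph graph = {{"kernel", {"helperA", "helperB"}}, {"helperA", {"helperB"}}};
  std::set<std::string> onPath;
  postOrderWalk("kernel", graph, onPath,
                [](const std::string &f) { std::cout << f << "\n"; });
  // Prints: helperB, helperA, helperB, kernel. A callee's data is always
  // computed before its caller needs it, which is what the Allocation, Membar
  // and AxisInfo module analyses in this commit rely on; they cache
  // per-function results in funcMap, so repeated visits are cheap.
}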

View File

@@ -155,7 +155,7 @@ def TT_LoadOp : TT_Op<"load",
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor pointer with boundary check and padding
OpBuilder<(ins "Value":$ptr, "ArrayRef<int32_t>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A tensor of pointers or a pointer to a scalar with mask
OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache,
@@ -164,8 +164,9 @@ def TT_LoadOp : TT_Op<"load",
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>,
// A utility function to build the operation with all attributes
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "Optional<ArrayRef<int32_t>>":$boundaryCheck,
"Optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other,
"std::optional<ArrayRef<int32_t>>":$boundaryCheck,
"std::optional<triton::PaddingOption>":$padding, "triton::CacheModifier":$cache,
"triton::EvictionPolicy":$evict, "bool":$isVolatile)>
];
@@ -600,6 +601,11 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
CallInterfaceCallable getCallableForCallee() {
return (*this)->getAttrOfType<SymbolRefAttr>("callee");
}
/// Set the callee for this operation.
void setCalleeFromCallable(CallInterfaceCallable callee) {
(*this)->setAttr("callee", callee.get<SymbolRefAttr>());
}
}];
let assemblyFormat = [{

View File

@@ -31,8 +31,37 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout);
SmallVector<unsigned> getSizePerThread(Attribute layout);
// Returns the number of contiguous elements that each thread
// has access to, on each dimension of the tensor. E.g.
// for a blocked layout with sizePerThread = [1, 4], returns [1, 4],
// regardless of the shape of the tensor.
SmallVector<unsigned> getContigPerThread(Attribute layout);
// Returns the number of non-replicated contiguous elements that each thread
// has access to, on each dimension of the tensor. For a blocked layout
// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements
// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1,
// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be
// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4].
SmallVector<unsigned> getUniqueContigPerThread(Type type);
// Returns the number of threads per warp that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17
// have access to the full tensor, whereas the other threads have access to
// replicated elements, so this function returns [2, 2].
SmallVector<unsigned>
getThreadsPerWarpWithUniqueData(Attribute layout,
ArrayRef<int64_t> tensorShape);
// Returns the number of warps per CTA that have access to non-replicated
// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1,
// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2],
// returns [1, 1], since the first warp has access to the full tensor, whereas
// the other warps have access to replicated elements.
SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);
SmallVector<unsigned> getThreadsPerCTA(Attribute layout);
SmallVector<unsigned>
@@ -45,6 +74,9 @@ bool isaDistributedLayout(Attribute layout);
} // namespace gpu
} // namespace triton
bool isSharedEncoding(Value value);
} // namespace mlir
#endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_
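Illustration (not part of this commit): the *WithUniqueData helpers documented above cap the per-thread and per-warp figures by how many distinct elements the tensor actually has along each dimension. The sketch below only reproduces the worked example from the getUniqueContigPerThread comment; the clamping rule it uses (a plain min against the shape) is an assumption made for illustration, not the actual Triton implementation:

// --- Standalone sketch: "unique data" clamping (not part of this commit) ---
// The min-against-the-shape rule below is an assumption chosen to reproduce
// the worked examples in the comments; it is not copied from the Triton code.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

std::vector<unsigned>
uniqueContigPerThread(const std::vector<unsigned> &sizePerThread,
                      const std::vector<int64_t> &shape) {
  std::vector<unsigned> result(sizePerThread.size());
  for (std::size_t d = 0; d < result.size(); ++d)
    result[d] = std::min<unsigned>(sizePerThread[d],
                                   static_cast<unsigned>(shape[d]));
  return result;
}

int main() {
  // sizePerThread = [1, 4]:
  //   shape [128, 1]   -> [1, 1] (only one distinct element along dim 1)
  //   shape [128, 128] -> [1, 4]
  for (const auto &shape :
       {std::vector<int64_t>{128, 1}, std::vector<int64_t>{128, 128}}) {
    auto r = uniqueContigPerThread({1, 4}, shape);
    std::cout << "[" << r[0] << ", " << r[1] << "]\n";
  }
}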

View File

@@ -509,10 +509,22 @@ section 9.7.13.4.1 for more details.
let parameters = (
ins
"unsigned":$opIdx,
"Attribute":$parent
"Attribute":$parent,
"unsigned":$MMAv2kWidth
);
let builders = [
// Specially for MMAV1(Volta)
AttrBuilder<(ins "unsigned":$opIdx,
"Attribute":$parent,
"Type":$eltTy), [{
MmaEncodingAttr parentAttr = parent.dyn_cast<MmaEncodingAttr>();
if (!parentAttr || !parentAttr.isAmpere())
return $_get(context, opIdx, parent, 0);
unsigned bitwidth = eltTy.getIntOrFloatBitWidth();
unsigned MMAv2kWidth = 32 / bitwidth;
return $_get(context, opIdx, parent, MMAv2kWidth);
}]>
];
let hasCustomAssemblyFormat = 1;
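Illustration (not part of this commit): the new builder derives MMAv2kWidth from the element type as 32 / bitwidth when the parent MMA layout is Ampere, and 0 otherwise. A tiny standalone sketch of that arithmetic; reading kWidth as "elements per 32-bit register" is an interpretation, not something stated in the diff:

// --- Standalone sketch: MMAv2kWidth arithmetic (not part of this commit) ---
#include <cstdio>

// Mirrors the new builder above: non-Ampere parents get kWidth = 0, otherwise
// kWidth = 32 / bitwidth (read here as "elements per 32-bit register", which
// is an interpretation, not a statement from the commit).
unsigned mmaV2KWidth(unsigned elementBitWidth, bool parentIsAmpere) {
  if (!parentIsAmpere)
    return 0;
  return 32 / elementBitWidth;
}

int main() {
  std::printf("fp32 -> %u\n", mmaV2KWidth(32, true)); // 1
  std::printf("fp16 -> %u\n", mmaV2KWidth(16, true)); // 2
  std::printf("int8 -> %u\n", mmaV2KWidth(8, true));  // 4
  std::printf("volta parent -> %u\n", mmaV2KWidth(16, false)); // 0
}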

View File

@@ -57,7 +57,6 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul
"int32_t", /*default*/"80",
"device compute capability">
];
}
def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {

View File

@@ -4,7 +4,6 @@
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "triton/Analysis/Alias.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "llvm/ADT/SmallVector.h"
@@ -117,8 +116,11 @@ SmallVector<unsigned> getScratchConfigForAtomicCAS(triton::AtomicCASOp op) {
class AllocationAnalysis {
public:
AllocationAnalysis(Operation *operation, Allocation *allocation)
: operation(operation), allocation(allocation) {
AllocationAnalysis(Operation *operation,
Allocation::FuncAllocMapT *funcAllocMap,
Allocation *allocation)
: operation(operation), funcAllocMap(funcAllocMap),
allocation(allocation) {
run();
}
@@ -219,6 +221,12 @@ private:
? elems * kPtrBitWidth / 8
: elems * elemTy.getIntOrFloatBitWidth() / 8;
allocation->addBuffer<BufferT::BufferKind::Scratch>(op, bytes);
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto callable = callOp.resolveCallable();
auto funcOp = dyn_cast<FunctionOpInterface>(callable);
auto *funcAlloc = &(*funcAllocMap)[funcOp];
auto bytes = funcAlloc->getSharedMemorySize();
allocation->addBuffer<BufferT::BufferKind::Virtual>(op, bytes);
}
}
@@ -298,15 +306,19 @@ private:
/// allocated, but is used to store intermediate results.
void resolveScratchBufferLiveness(
const DenseMap<Operation *, size_t> &operationId) {
// Analyze liveness of scratch buffers
for (auto opScratchIter : allocation->opScratch) {
// Any scratch memory's live range is the current operation's live
// range.
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
// Analyze liveness of scratch buffers and virtual buffers.
auto processScratchMemory = [&](const auto &container) {
for (auto opScratchIter : container) {
// Any scratch memory's live range is the current operation's live
// range.
auto *op = opScratchIter.first;
auto *buffer = opScratchIter.second;
bufferRange.insert({buffer, Interval(operationId.lookup(op),
operationId.lookup(op) + 1)});
}
};
processScratchMemory(allocation->opScratch);
processScratchMemory(allocation->opVirtual);
}
/// Resolves liveness of all values involved under the root operation.
@@ -499,11 +511,15 @@ private:
private:
Operation *operation;
Allocation::FuncAllocMapT *funcAllocMap;
Allocation *allocation;
BufferRangeMapT bufferRange;
};
} // namespace triton
void Allocation::run() { triton::AllocationAnalysis(getOperation(), this); }
void Allocation::run(FuncAllocMapT &funcAllocMap) {
triton::AllocationAnalysis(getOperation(), &funcAllocMap, this);
}
} // namespace mlir

View File

@@ -77,7 +77,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) {
if (blockArg && blockArg.getOwner()->isEntryBlock()) {
Operation *op = blockArg.getOwner()->getParentOp();
if (auto fun = dyn_cast<triton::FuncOp>(op))
if (auto fun = dyn_cast<FunctionOpInterface>(op))
initPessimisticStateFromFunc(blockArg.getArgNumber(), fun,
&knownContiguity, &knownDivisibility,
&knownConstancy);
@@ -696,13 +696,13 @@ private:
const AxisInfo &rhs) override {
if (lhs.getConstantValue().has_value() &&
rhs.getConstantValue().has_value()) {
if constexpr (std::is_same<OpTy, arith::AndIOp>::value) {
if constexpr (std::is_same_v<OpTy, arith::AndIOp>) {
return {lhs.getConstantValue().value() &
rhs.getConstantValue().value()};
} else if constexpr (std::is_same<OpTy, arith::OrIOp>::value) {
} else if constexpr (std::is_same_v<OpTy, arith::OrIOp>) {
return {lhs.getConstantValue().value() |
rhs.getConstantValue().value()};
} else if constexpr (std::is_same<OpTy, arith::XOrIOp>::value) {
} else if constexpr (std::is_same_v<OpTy, arith::XOrIOp>) {
return {lhs.getConstantValue().value() ^
rhs.getConstantValue().value()};
}
@@ -907,37 +907,37 @@ void AxisInfoAnalysis::visitOperation(
propagateIfChanged(result, result->join(curr));
}
unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) {
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
// Here `order` is sorted with the most contiguous dimension first, so the
// first element has the largest contiguity.
auto order = triton::gpu::getOrder(layout);
unsigned align = getPtrAlignment(ptr);
unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]];
contigPerThread = std::min(align, contigPerThread);
contigPerThread = std::min<unsigned>(shape[order[0]], contigPerThread);
auto uniqueContigPerThread = triton::gpu::getUniqueContigPerThread(tensorTy);
assert(order[0] < uniqueContigPerThread.size() &&
"Unxpected uniqueContigPerThread size");
unsigned contiguity = uniqueContigPerThread[order[0]];
contiguity = std::min(align, contiguity);
return contigPerThread;
return contiguity;
}
unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(ptr);
if (!latticeElement)
auto *axisInfo = getAxisInfo(ptr);
if (!axisInfo)
return 1;
auto axisInfo = latticeElement->getValue();
auto layout = tensorTy.getEncoding();
auto order = triton::gpu::getOrder(layout);
auto maxMultipleBytes = axisInfo.getDivisibility(order[0]);
auto maxContig = axisInfo.getContiguity(order[0]);
auto maxMultipleBytes = axisInfo->getDivisibility(order[0]);
auto maxContig = axisInfo->getContiguity(order[0]);
auto elemNumBits = triton::getPointeeBitWidth(tensorTy);
auto elemNumBytes = std::max<unsigned>(elemNumBits / 8, 1);
auto maxMultiple = std::max<int64_t>(maxMultipleBytes / elemNumBytes, 1);
@@ -945,17 +945,69 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) {
return alignment;
}
unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) {
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
auto tensorTy = mask.getType().dyn_cast<RankedTensorType>();
if (!tensorTy)
return 1;
dataflow::Lattice<AxisInfo> *latticeElement = getLatticeElement(mask);
if (!latticeElement)
auto *axisInfo = getAxisInfo(mask);
if (!axisInfo)
return 1;
auto maskAxis = latticeElement->getValue();
auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding());
auto alignment = std::max<unsigned>(maskAxis.getConstancy(maskOrder[0]), 1);
auto alignment = std::max<unsigned>(axisInfo->getConstancy(maskOrder[0]), 1);
return alignment;
}
void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) {
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(funcOp)))
return;
auto *axisInfoMap = getFuncData(funcOp);
auto updateAxisInfoMap = [&](Value value) {
auto axisInfo = analysis->getLatticeElement(value)->getValue();
AxisInfo curAxisInfo;
if (axisInfoMap->count(value)) {
curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value));
} else {
curAxisInfo = axisInfo;
}
(*axisInfoMap)[value] = curAxisInfo;
};
funcOp.walk([&](Operation *op) {
for (auto value : op->getResults()) {
updateAxisInfoMap(value);
}
});
funcOp.walk([&](Block *block) {
for (auto value : block->getArguments()) {
updateAxisInfoMap(value);
}
});
}
void ModuleAxisInfoAnalysis::update(CallOpInterface callOp,
FunctionOpInterface callee) {
auto caller = callOp->getParentOfType<FunctionOpInterface>();
auto *axisInfoMap = getFuncData(caller);
for (auto entry : llvm::enumerate(callOp->getOperands())) {
auto index = entry.index();
auto value = entry.value();
auto setAttrFn = [&](StringRef attrName, int64_t prevValue) {
auto curValue = highestPowOf2Divisor<int64_t>(0);
if (callee.getArgAttrOfType<IntegerAttr>(index, attrName)) {
curValue =
callee.getArgAttrOfType<IntegerAttr>(index, attrName).getInt();
}
auto attr = IntegerAttr::get(IntegerType::get(callee.getContext(), 64),
gcd(prevValue, curValue));
callee.setArgAttr(index, attrName, attr);
};
auto axisInfo = axisInfoMap->lookup(value);
assert(axisInfo.getRank() == 1 && "only scalar arguments are supported");
setAttrFn("tt.contiguity", axisInfo.getContiguity(0));
setAttrFn("tt.divisibility", axisInfo.getDivisibility(0));
setAttrFn("tt.constancy", axisInfo.getConstancy(0));
}
}
} // namespace mlir
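Illustration (not part of this commit): ModuleAxisInfoAnalysis::update above merges per-argument hints across call sites by setting each callee argument attribute (tt.contiguity, tt.divisibility, tt.constancy) to the gcd of the caller's value at this call site and the callee's existing attribute, defaulting to highestPowOf2Divisor(0) when absent. A standalone sketch of that merging:

// --- Standalone sketch: gcd-merging of argument hints (not part of this commit) ---
#include <cstdint>
#include <iostream>
#include <numeric> // std::gcd

// The callee keeps the gcd of what every caller can guarantee. "No attribute
// yet" is modeled here by a large power of two, in the spirit of
// highestPowOf2Divisor(0) (an assumption, not the exact constant Triton uses).
int64_t mergeHint(int64_t calleeHint, int64_t callerValue) {
  return std::gcd(calleeHint, callerValue);
}

int main() {
  int64_t divisibility = int64_t(1) << 62;    // unconstrained until proven otherwise
  divisibility = mergeHint(divisibility, 16); // first call site: divisible by 16
  divisibility = mergeHint(divisibility, 8);  // second call site: divisible by 8
  std::cout << "tt.divisibility = " << divisibility << "\n"; // prints 8
}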

View File

@@ -11,4 +11,7 @@ add_mlir_library(TritonAnalysis
LINK_LIBS PUBLIC
MLIRAnalysis
MLIRLLVMDialect
TritonIR
TritonGPUIR
)

View File

@@ -9,16 +9,21 @@
namespace mlir {
void MembarAnalysis::run() {
auto *operation = allocation->getOperation();
OpBuilder builder(operation);
resolve(operation, &builder);
void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) {
FunctionOpInterface funcOp =
dyn_cast<FunctionOpInterface>(allocation->getOperation());
OpBuilder builder(funcOp.getContext());
resolve(funcOp, &funcBlockInfoMap, &builder);
}
void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
void MembarAnalysis::resolve(FunctionOpInterface funcOp,
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) {
// Initialize the blockList
DenseMap<Block *, BlockInfo> inputBlockInfoMap;
DenseMap<Block *, BlockInfo> outputBlockInfoMap;
std::deque<Block *> blockList;
operation->walk<WalkOrder::PreOrder>([&](Block *block) {
funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
for (auto &op : block->getOperations()) {
// Check if the operation belongs to scf dialect, if so, we need to
// throw an error
@@ -38,13 +43,13 @@ void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
auto *block = blockList.front();
blockList.pop_front();
// Make a copy of the input BlockInfo but do not update it
auto inputBlockInfo = inputBlockInfoMap.lookup(block);
auto inputBlockInfo = inputBlockInfoMap[block];
SmallVector<Block *> successors;
for (auto &op : block->getOperations()) {
if (op.hasTrait<OpTrait::IsTerminator>()) {
visitTerminator(&op, successors);
} else {
update(&op, &inputBlockInfo, builder);
update(&op, &inputBlockInfo, funcBlockInfoMap, builder);
}
}
// Get the reference because we want to update if it changed
@@ -62,15 +67,22 @@ void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) {
blockList.emplace_back(successor);
}
}
// Update the final dangling buffers that haven't been synced
auto &funcBlockInfo = (*funcBlockInfoMap)[funcOp];
funcOp.walk<WalkOrder::PreOrder>([&](Block *block) {
block->walk([&](triton::ReturnOp returnOp) {
funcBlockInfo.join(outputBlockInfoMap[block]);
});
});
}
void MembarAnalysis::visitTerminator(Operation *op,
SmallVector<Block *> &successors) {
if (auto branchInterface = dyn_cast<BranchOpInterface>(op)) {
Block *parentBlock = branchInterface->getBlock();
for (Block *successor : parentBlock->getSuccessors()) {
successors.push_back(successor);
}
successors.append(std::begin(parentBlock->getSuccessors()),
std::end(parentBlock->getSuccessors()));
return;
}
// Otherwise, it could be a return op
@@ -81,6 +93,7 @@ void MembarAnalysis::visitTerminator(Operation *op,
}
void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
FuncBlockInfoMapT *funcBlockInfoMap,
OpBuilder *builder) {
if (isa<triton::gpu::ExtractSliceOp>(op) ||
isa<triton::gpu::AllocTensorOp>(op) || isa<triton::TransOp>(op)) {
@@ -108,36 +121,51 @@ void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
}
BlockInfo curBlockInfo;
for (Value value : op->getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always
// alias for now
curBlockInfo.syncWriteBuffers.insert(bufferId);
} else {
// ConvertLayoutOp: shared memory -> registers
curBlockInfo.syncReadBuffers.insert(bufferId);
if (isa<triton::CallOp>(op)) {
// Inter-function dependencies
auto callOpInterface = dyn_cast<CallOpInterface>(op);
if (auto callee =
dyn_cast<FunctionOpInterface>(callOpInterface.resolveCallable())) {
curBlockInfo = funcBlockInfoMap->lookup(callee);
}
} else {
// Intra-function dependencies
for (Value value : op->getOperands()) {
for (auto bufferId : allocation->getBufferIds(value)) {
if (bufferId != Allocation::InvalidBufferId) {
if (isa<triton::gpu::InsertSliceAsyncOp>(op) ||
isa<tensor::InsertSliceOp>(op)) {
// FIXME(Keren): insert_slice and insert_slice_async are always
// alias for now
curBlockInfo.syncWriteIntervals.insert(
allocation->getAllocatedInterval(bufferId));
} else {
// ConvertLayoutOp: shared memory -> registers
curBlockInfo.syncReadIntervals.insert(
allocation->getAllocatedInterval(bufferId));
}
}
}
}
}
for (Value value : op->getResults()) {
// ConvertLayoutOp: registers -> shared memory
auto bufferId = allocation->getBufferId(value);
for (Value value : op->getResults()) {
// ConvertLayoutOp: registers -> shared memory
auto bufferId = allocation->getBufferId(value);
if (bufferId != Allocation::InvalidBufferId) {
curBlockInfo.syncWriteIntervals.insert(
allocation->getAllocatedInterval(bufferId));
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curBlockInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncWriteIntervals.insert(
allocation->getAllocatedInterval(bufferId));
curBlockInfo.syncReadIntervals.insert(
allocation->getAllocatedInterval(bufferId));
}
}
// Scratch buffer is considered as both shared memory write & read
auto bufferId = allocation->getBufferId(op);
if (bufferId != Allocation::InvalidBufferId) {
curBlockInfo.syncWriteBuffers.insert(bufferId);
curBlockInfo.syncReadBuffers.insert(bufferId);
}
if (blockInfo->isIntersected(curBlockInfo, allocation)) {
if (blockInfo->isIntersected(curBlockInfo)) {
OpBuilder::InsertionGuard g(*builder);
builder->setInsertionPoint(op);
builder->create<gpu::BarrierOp>(op->getLoc());

View File

@@ -26,10 +26,27 @@ unsigned ReduceOpHelper::getIntraWarpSize() {
triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]);
}
unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData();
return std::min(srcReduceDimSize / sizeIntraWarps,
triton::gpu::getWarpsPerCTAWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}
unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() {
auto srcReduceDimSize = static_cast<unsigned>(srcShape[axis]);
return std::min(srcReduceDimSize,
triton::gpu::getThreadsPerWarpWithUniqueData(
getSrcLayout(), getSrcShape())[axis]);
}
unsigned ReduceOpHelper::getThreadsReductionAxis() {
auto srcLayout = getSrcLayout();
return triton::gpu::getThreadsPerWarp(srcLayout)[axis] *
triton::gpu::getWarpsPerCTA(srcLayout)[axis];
auto srcShape = getSrcShape();
return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout,
srcShape)[axis] *
triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis];
}
SmallVector<unsigned> ReduceOpHelper::getScratchConfigBasic() {
@@ -91,14 +108,8 @@ bool ReduceOpHelper::isSupportedLayout() {
return true;
}
}
return false;
}
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
auto encoding = tensorType.getEncoding();
return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
if (auto sliceLayout = srcLayout.dyn_cast<triton::gpu::SliceEncodingAttr>()) {
return true;
}
return false;
}

View File

@@ -106,13 +106,22 @@ private:
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
unsigned dim = sliceLayout.getDim();
auto parentEncoding = sliceLayout.getParent();
auto parentSizePerThread = getSizePerThread(parentEncoding);
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
parentEncoding);
auto multiDimOffsetParent =
getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
auto offsets = emitOffsetForLayout(layout, type);
auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy);
SmallVector<int> idxs;
for (SmallVector<unsigned> off : offsets) {
off.insert(off.begin() + dim, 0);
auto it = std::find(parentOffset.begin(), parentOffset.end(), off);
idxs.push_back(std::distance(parentOffset.begin(), it));
}
auto multiDimOffsetParent = getMultiDimOffset(
parentEncoding, loc, rewriter, idxs[elemId], parentTy,
sliceLayout.paddedShape(multiDimCTAInRepId),
sliceLayout.paddedShape(shapePerCTA));
SmallVector<Value> multiDimOffset(rank);
for (unsigned d = 0; d < rank + 1; ++d) {
if (d == dim)
@@ -329,7 +338,8 @@ private:
if (needTrans) {
// do transpose
auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
auto aEncoding =
DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0);
int numM = aEncoding.getMMAv1NumOuter(shape);
int numN = accumSizePerThread / numM;
@@ -619,10 +629,10 @@ private:
// is implemented
SmallVector<Value> reorderedVals;
for (unsigned i = 0; i < vecVals.size(); i += 4) {
reorderedVals.push_back(vecVals[i]);
reorderedVals.push_back(vecVals[i + 2]);
reorderedVals.push_back(vecVals[i + 1]);
reorderedVals.push_back(vecVals[i + 3]);
reorderedVals.push_back(bitcast(vecVals[i], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty));
reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty));
}
Value view = getTypeConverter()->packLLElements(loc, reorderedVals,
@@ -641,19 +651,19 @@ private:
auto loc = op.getLoc();
Value src = op.getSrc();
Value dst = op.getResult();
bool isHMMA = supportMMA(dst, mmaLayout.getVersionMajor());
auto smemObj =
getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter);
Value res;
if (!isOuter && mmaLayout.isAmpere() && isHMMA) { // tensor core v2
if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2
res = SharedToDotOperandMMAv2::convertLayout(
dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout,
smemObj, getTypeConverter(), tid_val());
} else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1
} else if (!isOuter && mmaLayout.isVolta() &&
supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1
bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
auto srcSharedLayout = src.getType()
.cast<RankedTensorType>()
@@ -675,14 +685,13 @@ private:
}
return res;
}
}; // namespace triton::gpu::ConvertLayoutOp>
}; // namespace triton::gpu::ConvertLayoutOp
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation, smem,
patterns.add<ConvertLayoutOpConversion>(typeConverter, allocation,
indexCacheInfo, benefit);
}

View File

@@ -10,8 +10,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr;
void populateConvertLayoutOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -358,11 +358,11 @@ SmallVector<CoordTy> getMNCoords(Value thread,
Value _fpw1 = i32_val(fpw[1]);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

View File

@@ -19,10 +19,10 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
class MMA16816SmemLoader {
public:
MMA16816SmemLoader(int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape,
int perPhase, int maxPhase, int elemBytes,
ConversionPatternRewriter &rewriter,
int kWidth, ArrayRef<Value> smemStrides,
ArrayRef<int64_t> tileShape, ArrayRef<int> instrShape,
ArrayRef<int> matShape, int perPhase, int maxPhase,
int elemBytes, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter,
const Location &loc);
@@ -33,7 +33,7 @@ public:
if (canUseLdmatrix)
return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset);
else
return computeLdsMatOffs(warpOff, lane, cSwizzleOffset, elemBytes);
return computeLdsMatOffs(warpOff, lane, cSwizzleOffset);
return {};
}
@@ -45,7 +45,7 @@ public:
Value cSwizzleOffset);
// compute 8-bit matrix offset.
SmallVector<Value> computeLdsMatOffs(Value warpOff, Value lane,
Value cSwizzleOffset, int elemBytes);
Value cSwizzleOffset);
// Load 4 matrices and returns 4 vec<2> elements.
std::tuple<Value, Value, Value, Value>
@@ -55,6 +55,7 @@ public:
private:
SmallVector<uint32_t> order;
int kOrder;
int kWidth;
SmallVector<int64_t> tileShape;
SmallVector<int> instrShape;
SmallVector<int> matShape;
@@ -176,9 +177,7 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane,
SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
Value lane,
Value cSwizzleOffset,
int elemBytes) {
assert(elemBytes <= 4);
Value cSwizzleOffset) {
int cTileShape = tileShape[order[0]];
int sTileShape = tileShape[order[1]];
if (!needTrans) {
@@ -187,10 +186,10 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
SmallVector<Value> offs(numPtrs);
int vecWidth = kWidth;
int threadsPerQuad[2] = {8, 4};
int laneWidth = 4;
int laneHeight = 8;
int vecWidth = 4 / elemBytes;
int quadWidth = laneWidth * vecWidth;
int quadHeight = laneHeight;
int numQuadI = 2;
@@ -232,8 +231,8 @@ SmallVector<Value> MMA16816SmemLoader::computeLdsMatOffs(Value warpOff,
Value i = add(iBase, mul(iOff, i32_val(quadHeight)));
Value j = add(jBase, mul(jOff, i32_val(quadWidth)));
// wrap around the bounds
i = urem(i, i32_val(cTileShape));
j = urem(j, i32_val(sTileShape));
// i = urem(i, i32_val(cTileShape));
// j = urem(j, i32_val(sTileShape));
if (needTrans) {
offs[idx] = add(i, mul(j, sStride));
} else {
@@ -304,7 +303,6 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1),
extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)};
} else {
elemTy = matTy.cast<LLVM::LLVMStructType>().getBody()[0];
// base pointers
std::array<std::array<Value, 4>, 2> ptrs;
int vecWidth = 4 / elemBytes;
@@ -324,39 +322,50 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef<Value> offs,
std::array<Value, 2> ii = {i0, i1};
// load 4 32-bit values from shared memory
// (equivalent to ldmatrix.x4)
SmallVector<SmallVector<Value>> vals(4, SmallVector<Value>(vecWidth));
SmallVector<SmallVector<Value>> vptrs(4, SmallVector<Value>(vecWidth));
for (int i = 0; i < 4; ++i)
for (int j = 0; j < vecWidth; ++j)
vals[i][j] = load(gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]));
vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]);
// row + trans and col + no-trans are equivalent
if ((needTrans && kOrder == 1) || (!needTrans && kOrder == 0))
std::swap(vals[1], vals[2]);
bool isActualTrans =
(needTrans && kOrder == 1) || (!needTrans && kOrder == 0);
if (isActualTrans)
std::swap(vptrs[1], vptrs[2]);
// pack loaded vectors into 4 32-bit values
int inc = needTrans ? 1 : kWidth;
VectorType packedTy = vec_ty(int_ty(8 * elemBytes), inc);
int canonBits = std::min(32, 8 * elemBytes * inc);
int canonWidth = (8 * elemBytes * inc) / canonBits;
Type canonInt = int_ty(canonBits);
std::array<Value, 4> retElems;
retElems.fill(undef(elemTy));
for (int m = 0; m < 4; ++m) {
for (int e = 0; e < vecWidth; ++e)
retElems[m] = insert_element(retElems[m].getType(), retElems[m],
vals[m][e], i32_val(e));
retElems.fill(undef(vec_ty(canonInt, 32 / canonBits)));
for (int r = 0; r < 2; ++r) {
for (int em = 0; em < 2 * vecWidth; em += inc) {
int e = em % vecWidth;
int m = em / vecWidth;
int idx = m * 2 + r;
Value ptr = bitcast(vptrs[idx][e], ptr_ty(packedTy, 3));
Value val = load(ptr);
Value canonval = bitcast(val, vec_ty(canonInt, canonWidth));
for (int w = 0; w < canonWidth; ++w) {
retElems[idx + w * kWidth / vecWidth] =
insert_element(retElems[idx + w * kWidth / vecWidth],
extract_element(canonval, i32_val(w)), i32_val(e));
}
}
}
if (elemBytes == 1)
return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
else
return {retElems[0], retElems[1], retElems[2], retElems[3]};
return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty),
bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)};
}
assert(false && "Invalid smem load");
return {Value{}, Value{}, Value{}, Value{}};
}
MMA16816SmemLoader::MMA16816SmemLoader(
int wpt, ArrayRef<uint32_t> order, uint32_t kOrder,
int wpt, ArrayRef<uint32_t> order, uint32_t kOrder, int kWidth,
ArrayRef<Value> smemStrides, ArrayRef<int64_t> tileShape,
ArrayRef<int> instrShape, ArrayRef<int> matShape, int perPhase,
int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc)
: order(order.begin(), order.end()), kOrder(kOrder),
: order(order.begin(), order.end()), kOrder(kOrder), kWidth(kWidth),
tileShape(tileShape.begin(), tileShape.end()),
instrShape(instrShape.begin(), instrShape.end()),
matShape(matShape.begin(), matShape.end()), perPhase(perPhase),
@@ -369,7 +378,8 @@ MMA16816SmemLoader::MMA16816SmemLoader(
// rule: k must be the fast-changing axis.
needTrans = kOrder != order[0];
canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16
canUseLdmatrix = elemBytes == 2 || (!needTrans);
canUseLdmatrix = canUseLdmatrix && (kWidth == 4 / elemBytes);
if (canUseLdmatrix) {
// Each CTA, the warps is arranged as [1xwpt] if not transposed,
@@ -409,42 +419,12 @@ Type getShemPtrTy(Type argType) {
return ptr_ty(type::i16Ty(ctx), 3);
else if (argType.isF32())
return ptr_ty(type::f32Ty(ctx), 3);
else if (argType.isInteger(8))
else if (argType.getIntOrFloatBitWidth() == 8)
return ptr_ty(type::i8Ty(ctx), 3);
else
llvm::report_fatal_error("mma16816 data type not supported");
}
Type getMatType(Type argType) {
MLIRContext *ctx = argType.getContext();
// floating point types
Type fp32x1Ty = vec_ty(type::f32Ty(ctx), 1);
Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2);
Type i16x2Ty = vec_ty(type::i16Ty(ctx), 2);
Type fp16x2Pack4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp16x2Ty));
// LLVM 14.0 does not support bf16 type, so we use i16 instead.
Type bf16x2Pack4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i16x2Ty));
Type fp32Pack4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, fp32x1Ty));
// integer types
Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4);
Type i8x4Pack4Ty =
LLVM::LLVMStructType::getLiteral(ctx, SmallVector<Type>(4, i8x4Ty));
if (argType.isF16())
return fp16x2Pack4Ty;
else if (argType.isBF16())
return bf16x2Pack4Ty;
else if (argType.isF32())
return fp32Pack4Ty;
else if (argType.isInteger(8))
return i8x4Pack4Ty;
else
llvm::report_fatal_error("mma16816 data type not supported");
}
Value composeValuesToDotOperandLayoutStruct(
const ValueTable &vals, int n0, int n1,
TritonGPUToLLVMTypeConverter *typeConverter, Location loc,
@@ -470,7 +450,7 @@ Value composeValuesToDotOperandLayoutStruct(
std::function<void(int, int)>
getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder,
MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, int kWidth,
SmallVector<int> instrShape, SmallVector<int> matShape,
Value warpId, Value lane, ValueTable &vals, bool isA,
TritonGPUToLLVMTypeConverter *typeConverter,
@@ -485,143 +465,105 @@ getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj,
const int elemBytes = tensorTy.getElementTypeBitWidth() / 8;
auto order = sharedLayout.getOrder();
// the original register_lds2, but discard the prefetch logic.
auto ld2 = [](ValueTable &vals, int mn, int k, Value val) {
vals[{mn, k}] = val;
};
// (a, b) is the coordinate.
auto load = [=, &rewriter, &vals, &ld2](int a, int b) {
auto load = [=, &rewriter, &vals](int a, int b) {
MMA16816SmemLoader loader(
wpt, sharedLayout.getOrder(), kOrder, smemObj.strides,
wpt, sharedLayout.getOrder(), kOrder, kWidth, smemObj.strides,
tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase,
maxPhase, elemBytes, rewriter, typeConverter, loc);
Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]);
SmallVector<Value> offs =
loader.computeOffsets(warpId, lane, cSwizzleOffset);
// initialize pointers
const int numPtrs = loader.getNumPtrs();
SmallVector<Value> ptrs(numPtrs);
Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
Type smemPtrTy = getShemPtrTy(eltTy);
for (int i = 0; i < numPtrs; ++i) {
ptrs[i] =
bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy);
}
for (int i = 0; i < numPtrs; ++i)
ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy);
// actually load from shared memory
auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(),
SmallVector<Type>(4, i32_ty));
auto [ha0, ha1, ha2, ha3] = loader.loadX4(
(kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? b : a /*mat1*/, offs,
ptrs, getMatType(eltTy), getShemPtrTy(eltTy));
if (isA) {
ld2(vals, a, b, ha0);
ld2(vals, a + 1, b, ha1);
ld2(vals, a, b + 1, ha2);
ld2(vals, a + 1, b + 1, ha3);
} else {
ld2(vals, a, b, ha0);
ld2(vals, a + 1, b, ha2);
ld2(vals, a, b + 1, ha1);
ld2(vals, a + 1, b + 1, ha3);
}
ptrs, matTy, getShemPtrTy(eltTy));
if (!isA)
std::swap(ha1, ha2);
// the following is incorrect
// but causes dramatically better performance in ptxas
// although it only changes the order of operands in
// `mma.sync`
// if(isA)
// std::swap(ha1, ha2);
// update user-provided values in-place
vals[{a, b}] = ha0;
vals[{a + 1, b}] = ha1;
vals[{a, b + 1}] = ha2;
vals[{a + 1, b + 1}] = ha3;
};
return load;
}
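The load lambda above writes the four fragments produced by one ldmatrix.x4 into a 2x2 block of keys in the user-provided value table, swapping the two middle fragments for operand B. A small standalone sketch of that bookkeeping, with std::string standing in for LLVM values and a std::map standing in for ValueTable:

#include <cstdio>
#include <map>
#include <string>
#include <utility>

using ValueTable = std::map<std::pair<int, int>, std::string>;

// Store four ldmatrix.x4 fragments at a 2x2 block of (row, col) keys;
// for operand B the two middle fragments are swapped first.
void storeFragments(ValueTable &vals, int a, int b, std::string ha0,
                    std::string ha1, std::string ha2, std::string ha3,
                    bool isA) {
  if (!isA)
    std::swap(ha1, ha2);
  vals[{a, b}] = ha0;
  vals[{a + 1, b}] = ha1;
  vals[{a, b + 1}] = ha2;
  vals[{a + 1, b + 1}] = ha3;
}

int main() {
  ValueTable vals;
  storeFragments(vals, 0, 0, "f0", "f1", "f2", "f3", /*isA=*/false);
  for (const auto &kv : vals)
    std::printf("(%d,%d) -> %s\n", kv.first.first, kv.first.second,
                kv.second.c_str());
}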
Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
DotOperandEncodingAttr aEncoding, const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
auto aTensorTy = tensor.getType().cast<RankedTensorType>();
int bitwidth = aTensorTy.getElementTypeBitWidth();
auto mmaLayout = aEncoding.getParent().cast<MmaEncodingAttr>();
SmallVector<int64_t> shape(aTensorTy.getShape().begin(),
aTensorTy.getShape().end());
ValueTable ha;
std::function<void(int, int)> loadFn;
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
auto numRep = aEncoding.getMMAv2Rep(aTensorTy.getShape(), bitwidth);
int numRepM = numRep[0];
int numRepK = numRep[1];
if (aTensorTy.getEncoding().isa<SharedEncodingAttr>()) {
int wpt0 = mmaLayout.getWarpsPerCTA()[0];
Value warp = udiv(thread, i32_val(32));
Value lane = urem(thread, i32_val(32));
Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16));
// load from smem
// we use ldmatrix.x4 so each warp processes 16x16 elements.
int wpt = std::min<int>(wpt0, shape[0] / 16);
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/,
{mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/,
ha /*vals*/, true /*isA*/, typeConverter /* typeConverter */,
rewriter /*rewriter*/, loc /*loc*/);
} else if (aTensorTy.getEncoding().isa<BlockedEncodingAttr>()) {
// load from registers, used in gemm fuse
// TODO(Superjomn) Port the logic.
assert(false && "Loading A from register is not supported yet.");
} else {
assert(false && "A's layout is not supported.");
}
// step1. Perform loading.
for (int m = 0; m < numRepM; ++m)
for (int k = 0; k < numRepK; ++k)
loadFn(2 * m, 2 * k);
// step2. Format the values to LLVM::Struct to passing to mma codegen.
return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK,
typeConverter, loc, rewriter);
}
Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
ValueTable hb;
Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor,
DotOperandEncodingAttr encoding,
const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread,
bool isA) {
auto tensorTy = tensor.getType().cast<RankedTensorType>();
int bitwidth = tensorTy.getElementTypeBitWidth();
auto mmaLayout = bEncoding.getParent().cast<MmaEncodingAttr>();
auto mmaLayout = encoding.getParent().cast<MmaEncodingAttr>();
SmallVector<int64_t> shape(tensorTy.getShape().begin(),
tensorTy.getShape().end());
ValueTable vals;
int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth;
int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth;
auto numRep = bEncoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
int numRepK = numRep[0];
int numRepN = numRep[1];
auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth);
int kWidth = encoding.getMMAv2kWidth();
int wpt0 = mmaLayout.getWarpsPerCTA()[0];
int wpt1 = mmaLayout.getWarpsPerCTA()[1];
Value warp = udiv(thread, i32_val(32));
Value lane = urem(thread, i32_val(32));
Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16));
Value warpMN = udiv(warp, i32_val(wpt0));
Value warpN = urem(urem(warpMN, i32_val(wpt1)), i32_val(shape[1] / 8));
// we use ldmatrix.x4 so each warp processes 16x16 elements.
int wpt = std::min<int>(wpt1, shape[1] / 16);
auto loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/,
{mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/,
hb /*vals*/, false /*isA*/, typeConverter /* typeConverter */,
rewriter /*rewriter*/, loc /*loc*/);
for (int n = 0; n < std::max(numRepN / 2, 1); ++n) {
int wpt;
if (isA)
wpt = std::min<int>(wpt0, shape[0] / 16);
else
wpt = std::min<int>(wpt1, shape[1] / 16);
std::function<void(int, int)> loadFn;
if (isA)
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, kWidth,
{mmaInstrM, mmaInstrK} /*instrShape*/,
{matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/,
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
rewriter /*rewriter*/, loc /*loc*/);
else
loadFn = getLoadMatrixFn(
tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, kWidth,
{mmaInstrK, mmaInstrN} /*instrShape*/,
{matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/,
vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */,
rewriter /*rewriter*/, loc /*loc*/);
// Perform loading.
int numRepOuter = isA ? numRep[0] : std::max<int>(numRep[1] / 2, 1);
int numRepK = isA ? numRep[1] : numRep[0];
for (int m = 0; m < numRepOuter; ++m)
for (int k = 0; k < numRepK; ++k)
loadFn(2 * n, 2 * k);
}
loadFn(2 * m, 2 * k);
Value result = composeValuesToDotOperandLayoutStruct(
hb, std::max(numRepN / 2, 1), numRepK, typeConverter, loc, rewriter);
return result;
// Format the values to LLVM::Struct for passing to mma codegen.
return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK,
typeConverter, loc, rewriter);
}
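loadArg merges the former loadA/loadB paths; apart from kOrder and the warp coordinate, the only difference is how the repetition counts map onto the (outer, k) loop bounds. A sketch of that selection, assuming numRep holds {repM, repK} for A and {repK, repN} for B:

#include <algorithm>
#include <array>
#include <cstdio>

struct RepCounts {
  int outer;
  int k;
};

// For A, numRep is {repM, repK}; for B it is {repK, repN} and the N
// repetitions are halved (one ldmatrix.x4 covers two N tiles), never
// dropping below 1.
RepCounts pickRepCounts(std::array<int, 2> numRep, bool isA) {
  if (isA)
    return {numRep[0], numRep[1]};
  return {std::max(numRep[1] / 2, 1), numRep[0]};
}

int main() {
  RepCounts a = pickRepCounts({4, 2}, /*isA=*/true);
  RepCounts b = pickRepCounts({2, 4}, /*isA=*/false);
  std::printf("A: outer=%d k=%d\nB: outer=%d k=%d\n", a.outer, a.k, b.outer,
              b.k);
}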
namespace SharedToDotOperandMMAv2 {
@@ -630,12 +572,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter,
const SharedMemoryObject &smemObj,
TritonGPUToLLVMTypeConverter *typeConverter, Value thread) {
if (opIdx == 0)
return loadA(rewriter, loc, tensor, encoding, smemObj, typeConverter,
thread);
return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter,
thread, true);
else {
assert(opIdx == 1);
return loadB(rewriter, loc, tensor, encoding, smemObj, typeConverter,
thread);
return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter,
thread, false);
}
}
} // namespace SharedToDotOperandMMAv2

View File

@@ -62,9 +62,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
};
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
RewritePatternSet &patterns,
ModuleAllocation &allocation,
PatternBenefit benefit) {
patterns.add<DotOpConversion>(typeConverter, allocation, smem, benefit);
patterns.add<DotOpConversion>(typeConverter, allocation, benefit);
}

View File

@@ -7,9 +7,8 @@ using namespace mlir;
using namespace mlir::triton;
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
RewritePatternSet &patterns,
ModuleAllocation &allocation,
PatternBenefit benefit);
#endif

View File

@@ -4,11 +4,140 @@ using namespace mlir;
using namespace mlir::triton;
using ::mlir::triton::gpu::getTotalElemsPerThread;
static SmallVector<Value> reorderValues(const SmallVector<Value> &values,
Type inType, Type ouType) {
auto inTensorTy = inType.dyn_cast<RankedTensorType>();
auto ouTensorTy = ouType.dyn_cast<RankedTensorType>();
if (!inTensorTy || !ouTensorTy)
return values;
auto inEncoding =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(inTensorTy.getEncoding());
auto ouEncoding =
dyn_cast<triton::gpu::DotOperandEncodingAttr>(ouTensorTy.getEncoding());
assert(inEncoding == ouEncoding);
if (!inEncoding)
return values;
size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth();
size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth();
auto ouEltTy = ouTensorTy.getElementType();
if (inBitWidth == ouBitWidth)
return values;
if (inBitWidth == 16 && ouBitWidth == 32) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 8) {
ret.push_back(values[i]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
}
return ret;
}
if (inBitWidth == 8 && ouBitWidth == 16) {
SmallVector<Value> ret;
for (unsigned i = 0; i < values.size(); i += 16) {
ret.push_back(values[i + 0]);
ret.push_back(values[i + 1]);
ret.push_back(values[i + 2]);
ret.push_back(values[i + 3]);
ret.push_back(values[i + 8]);
ret.push_back(values[i + 9]);
ret.push_back(values[i + 10]);
ret.push_back(values[i + 11]);
ret.push_back(values[i + 4]);
ret.push_back(values[i + 5]);
ret.push_back(values[i + 6]);
ret.push_back(values[i + 7]);
ret.push_back(values[i + 12]);
ret.push_back(values[i + 13]);
ret.push_back(values[i + 14]);
ret.push_back(values[i + 15]);
}
return ret;
// for (unsigned i = 0; i < values.size(); i += 16) {
// ret.push_back(values[i]);
// ret.push_back(values[i + 1]);
// ret.push_back(values[i + 4]);
// ret.push_back(values[i + 5]);
// ret.push_back(values[i + 8]);
// ret.push_back(values[i + 9]);
// ret.push_back(values[i + 12]);
// ret.push_back(values[i + 13]);
// ret.push_back(values[i + 2]);
// ret.push_back(values[i + 3]);
// ret.push_back(values[i + 6]);
// ret.push_back(values[i + 7]);
// ret.push_back(values[i + 10]);
// ret.push_back(values[i + 11]);
// ret.push_back(values[i + 14]);
// ret.push_back(values[i + 15]);
// }
return values;
}
llvm_unreachable("unimplemented code path");
}
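Both branches of reorderValues apply a fixed intra-group permutation: groups of eight values use [0, 1, 4, 5, 2, 3, 6, 7] for the 16-to-32-bit case, and groups of sixteen use [0-3, 8-11, 4-7, 12-15] for the 8-to-16-bit case. A standalone C++ sketch of that reordering, with plain integers standing in for LLVM values and the input length assumed to be a multiple of the group size:

#include <cstdio>
#include <vector>

// Apply a fixed permutation to every consecutive group of pattern.size()
// values, mirroring the index shuffles in reorderValues.
std::vector<int> reorder(const std::vector<int> &values,
                         const std::vector<int> &pattern) {
  std::vector<int> ret;
  ret.reserve(values.size());
  for (size_t i = 0; i < values.size(); i += pattern.size())
    for (int p : pattern)
      ret.push_back(values[i + p]);
  return ret;
}

int main() {
  // 16-bit -> 32-bit case: groups of 8.
  std::vector<int> vals = {0, 1, 2, 3, 4, 5, 6, 7};
  for (int v : reorder(vals, {0, 1, 4, 5, 2, 3, 6, 7}))
    std::printf("%d ", v);
  std::printf("\n");
}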
inline SmallVector<Value> unpackI32(const SmallVector<Value> &inValues,
Type srcTy,
ConversionPatternRewriter &rewriter,
Location loc,
TypeConverter *typeConverter) {
auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
if (!tensorTy)
return inValues;
auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
if (!(encoding && encoding.getParent().isa<MmaEncodingAttr>()))
return inValues;
SmallVector<Value> outValues;
for (auto v : inValues) {
// cast i32 to appropriate eltType vector and extract elements
auto eltType = typeConverter->convertType(tensorTy.getElementType());
auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth());
auto vec = bitcast(v, vecType);
for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) {
outValues.push_back(extract_element(vec, i32_val(i)));
}
}
return outValues;
}
inline SmallVector<Value> packI32(const SmallVector<Value> &inValues,
Type srcTy,
ConversionPatternRewriter &rewriter,
Location loc, TypeConverter *typeConverter) {
auto tensorTy = srcTy.dyn_cast<RankedTensorType>();
if (!tensorTy)
return inValues;
auto encoding = tensorTy.getEncoding().dyn_cast<DotOperandEncodingAttr>();
if (!(encoding && encoding.getParent().isa<MmaEncodingAttr>()))
return inValues;
SmallVector<Value> outValues;
auto eltType = typeConverter->convertType(tensorTy.getElementType());
int vecWidth = 32 / eltType.getIntOrFloatBitWidth();
auto vecType = vec_ty(eltType, vecWidth);
for (int i = 0; i < inValues.size(); i += vecWidth) {
Value vec = undef(vecType);
for (int j = 0; j < vecWidth; j++) {
vec = insert_element(vec, inValues[i + j], i32_val(j));
}
outValues.push_back(bitcast(vec, i32_ty));
}
return outValues;
}
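packI32 and unpackI32 round-trip between i32-packed registers and individual elements for dot-operand-of-MMA layouts, with 32 / bitwidth elements per register. The same bit-level round trip can be sketched in plain C++ for 8-bit elements (four per 32-bit word); the input length is assumed to be a multiple of four:

#include <cstdint>
#include <cstdio>
#include <vector>

// Pack 8-bit elements into 32-bit words, four per word.
std::vector<uint32_t> packBytesToWords(const std::vector<uint8_t> &elems) {
  std::vector<uint32_t> out;
  for (size_t i = 0; i < elems.size(); i += 4) {
    uint32_t word = 0;
    for (int j = 0; j < 4; ++j)
      word |= static_cast<uint32_t>(elems[i + j]) << (8 * j);
    out.push_back(word);
  }
  return out;
}

// Split each 32-bit word back into its four 8-bit elements.
std::vector<uint8_t> unpackWordsToBytes(const std::vector<uint32_t> &words) {
  std::vector<uint8_t> out;
  for (uint32_t w : words)
    for (int j = 0; j < 4; ++j)
      out.push_back(static_cast<uint8_t>(w >> (8 * j)));
  return out;
}

int main() {
  std::vector<uint8_t> elems = {1, 2, 3, 4, 5, 6, 7, 8};
  auto roundTrip = unpackWordsToBytes(packBytesToWords(elems));
  std::printf("round trip ok: %d\n", int(roundTrip == elems));
}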
struct FpToFpOpConversion
: public ConvertTritonGPUOpToLLVMPattern<triton::FpToFpOp> {
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;
typedef std::function<SmallVector<Value>(
Location, ConversionPatternRewriter &, const Value &, const Value &,
const Value &, const Value &)>
ConvertorT;
/* ------------------ */
// FP8 -> FP16
/* ------------------ */
@@ -747,35 +876,14 @@ struct FpToFpOpConversion
#endif
}
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto srcTensorType = op.getFrom().getType().cast<mlir::RankedTensorType>();
auto dstTensorType =
op.getResult().getType().cast<mlir::RankedTensorType>();
auto srcEltType = srcTensorType.getElementType();
auto dstEltType = dstTensorType.getElementType();
auto loc = op->getLoc();
auto elems = getTotalElemsPerThread(dstTensorType);
SmallVector<Value> resultVals;
bool isSrcFP8 =
srcEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
bool isDstFP8 =
dstEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
// Select convertor
typedef std::function<SmallVector<Value>(
Location, ConversionPatternRewriter &, const Value &, const Value &,
const Value &, const Value &)>
ConvertorT;
ConvertorT getConversionFunc(Type srcTy, Type dstTy) const {
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F16TyID = TypeID::get<mlir::Float16Type>();
auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
auto F32TyID = TypeID::get<mlir::Float32Type>();
auto F64TyID = TypeID::get<mlir::Float64Type>();
DenseMap<std::pair<TypeID, TypeID>, ConvertorT> convertorMap = {
static DenseMap<std::pair<TypeID, TypeID>, ConvertorT> convertorMap = {
// F8 -> F16
{{F8E4M3TyID, F16TyID}, convertFp8E4M3x4ToFp16x4},
{{F8E5M2TyID, F16TyID}, convertFp8E5M2x4ToFp16x4},
@@ -796,28 +904,46 @@ struct FpToFpOpConversion
{{F32TyID, F8E5M2TyID}, convertFp32x4ToFp8E5M2x4},
};
std::pair<TypeID, TypeID> key = {srcEltType.getTypeID(),
dstEltType.getTypeID()};
std::pair<TypeID, TypeID> key = {srcTy.getTypeID(), dstTy.getTypeID()};
if (convertorMap.count(key) == 0) {
llvm::errs() << "Unsupported conversion from " << srcEltType << " to "
<< dstEltType << "\n";
llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy
<< "\n";
llvm_unreachable("");
}
auto convertor = convertorMap.lookup(key);
return convertorMap.lookup(key);
}
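getConversionFunc dispatches through a static map keyed by the (source, destination) TypeID pair and fails loudly on unsupported combinations. A standalone sketch of that dispatch shape, with a hypothetical Ty enum standing in for MLIR TypeIDs and identity lambdas standing in for the real convertors:

#include <cstdio>
#include <functional>
#include <map>
#include <stdexcept>
#include <utility>

// Hypothetical type tags standing in for MLIR TypeIDs.
enum class Ty { F8E4M3, F8E5M2, F16, F32 };

using Convertor = std::function<float(float)>;

// Look up a conversion by (src, dst) pair; unknown pairs are rejected.
Convertor getConversionFunc(Ty src, Ty dst) {
  static const std::map<std::pair<Ty, Ty>, Convertor> convertorMap = {
      {{Ty::F8E4M3, Ty::F16}, [](float x) { return x; }},  // placeholder body
      {{Ty::F16, Ty::F32}, [](float x) { return x; }},     // placeholder body
  };
  auto it = convertorMap.find({src, dst});
  if (it == convertorMap.end())
    throw std::runtime_error("unsupported conversion");
  return it->second;
}

int main() {
  Convertor cvt = getConversionFunc(Ty::F16, Ty::F32);
  std::printf("%f\n", cvt(1.5f));
}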
// Vectorized casting
LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// llvm::outs() << 0 << "\n";
auto srcTensorType = op.getFrom().getType().cast<mlir::RankedTensorType>();
auto dstTensorType =
op.getResult().getType().cast<mlir::RankedTensorType>();
auto loc = op->getLoc();
// check that the number of elements is divisible by 4
// Get convertor
auto cvtFunc = getConversionFunc(srcTensorType.getElementType(),
dstTensorType.getElementType());
// Unpack value
auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getFrom(),
rewriter, srcTensorType);
inVals =
unpackI32(inVals, srcTensorType, rewriter, loc, getTypeConverter());
// Cast
SmallVector<Value> outVals;
auto elems = inVals.size();
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getTypeConverter()->unpackLLElements(
loc, adaptor.getFrom(), rewriter, srcTensorType);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
assert(resultVals.size() == elems);
auto result = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
for (size_t i = 0; i < elems; i += 4)
outVals.append(cvtFunc(loc, rewriter, inVals[i], inVals[i + 1],
inVals[i + 2], inVals[i + 3]));
// Pack values
assert(outVals.size() == elems);
outVals = reorderValues(outVals, srcTensorType, dstTensorType);
outVals =
packI32(outVals, dstTensorType, rewriter, loc, getTypeConverter());
auto result = getTypeConverter()->packLLElements(loc, outVals, rewriter,
dstTensorType);
rewriter.replaceOp(op, result);
return success();
@@ -849,43 +975,44 @@ public:
ConversionPatternRewriter &rewriter) const override {
auto resultTy = op.getType();
Location loc = op->getLoc();
unsigned elems = getTotalElemsPerThread(resultTy);
// element type
auto resultElementTy = getElementTypeOrSelf(resultTy);
Type elemTy = this->getTypeConverter()->convertType(resultElementTy);
SmallVector<Type> types(elems, elemTy);
Type structTy = this->getTypeConverter()->convertType(resultTy);
auto *concreteThis = static_cast<const ConcreteT *>(this);
auto operands = getOperands(rewriter, adaptor, resultTy, elems, loc);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
operands[i], loc);
if (!bool(resultVals[i]))
return failure();
SmallVector<Value> resultVals;
//
SmallVector<SmallVector<Value>> allOperands;
for (auto operand : adaptor.getOperands()) {
auto argTy = op->getOperand(0).getType();
auto sub_operands = this->getTypeConverter()->unpackLLElements(
loc, operand, rewriter, argTy);
sub_operands = unpackI32(sub_operands, argTy, rewriter, loc,
this->getTypeConverter());
allOperands.resize(sub_operands.size());
for (auto v : llvm::enumerate(sub_operands))
allOperands[v.index()].push_back(v.value());
}
if (allOperands.size() == 0)
allOperands.push_back({});
for (const SmallVector<Value> &operands : allOperands) {
Value curr =
((ConcreteT *)(this))
->createDestOp(op, adaptor, rewriter, elemTy, operands, loc);
if (!bool(curr))
return failure();
resultVals.push_back(curr);
}
if (op->getNumOperands() > 0) {
auto argTy = op->getOperand(0).getType();
resultVals = reorderValues(resultVals, argTy, resultTy);
}
resultVals =
packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter());
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);
return success();
}
protected:
SmallVector<SmallVector<Value>>
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
Type operandTy, const unsigned elems, Location loc) const {
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = this->getTypeConverter()->unpackLLElements(
loc, operand, rewriter, operandTy);
for (size_t i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
}
return operands;
}
};
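The rewritten elementwise base pattern unpacks every operand and transposes the result, so that allOperands[i] holds the i-th element of each operand and createDestOp can be invoked once per element. The transpose itself, sketched on plain integers:

#include <cstdio>
#include <vector>

// Transpose per-operand element vectors into per-element operand tuples;
// nullary ops still get one (empty) entry so the loop body runs once.
std::vector<std::vector<int>>
transposeOperands(const std::vector<std::vector<int>> &operands) {
  std::vector<std::vector<int>> allOperands;
  for (const auto &subOperands : operands) {
    allOperands.resize(subOperands.size());
    for (size_t i = 0; i < subOperands.size(); ++i)
      allOperands[i].push_back(subOperands[i]);
  }
  if (allOperands.empty())
    allOperands.push_back({});
  return allOperands;
}

int main() {
  auto t = transposeOperands({{1, 2, 3}, {10, 20, 30}});
  for (const auto &row : t) {
    for (int v : row)
      std::printf("%d ", v);
    std::printf("\n");
  }
}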
template <typename SourceOp, typename DestOp>
@@ -1344,8 +1471,7 @@ struct AbsFOpConversion
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit) {
PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)

View File

@@ -8,8 +8,7 @@ using namespace mlir::triton;
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit);
PatternBenefit benefit);
bool isLegalElementwiseOp(Operation *op);

View File

@@ -13,7 +13,7 @@ using ::mlir::triton::gpu::SharedEncodingAttr;
// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass)
: axisAnalysisPass(axisAnalysisPass) {}
unsigned getContiguity(Value ptr) const {
@@ -38,7 +38,7 @@ struct LoadStoreConversionBase {
}
protected:
AxisInfoAnalysis &axisAnalysisPass;
ModuleAxisInfoAnalysis &axisAnalysisPass;
};
struct LoadOpConversion
@@ -48,7 +48,8 @@ struct LoadOpConversion
triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;
LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
@@ -293,7 +294,8 @@ struct StoreOpConversion
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
@@ -335,14 +337,7 @@ struct StoreOpConversion
vec = std::min(vec, maskAlign);
}
// numElements = 1 for scalar
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
Value mask = getMask(valueTy, rewriter, loc);
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;
@@ -431,11 +426,11 @@ struct AtomicCASOpConversion
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
ModuleAllocation &allocation,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(
converter, allocation, smem, benefit),
converter, allocation, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
#ifdef USE_ROCM
@@ -526,13 +521,13 @@ struct AtomicCASOpConversion
auto valElements = getTypeConverter()->unpackLLElements(
loc, llVal, rewriter, op.getVal().getType());
auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto valueTy = op.getResult().getType();
auto TensorTy = valueTy.dyn_cast<RankedTensorType>();
Type valueElemTy =
TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
: op.getResult().getType();
: valueTy;
auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto tid = tid_val();
Value pred = icmp_eq(tid, i32_val(0));
Value mask = getMask(valueTy, rewriter, loc);
PTXBuilder ptxBuilderMemfence;
auto memfence = ptxBuilderMemfence.create<PTXInstr>("membar")->o("gl");
memfence();
@@ -552,7 +547,7 @@ struct AtomicCASOpConversion
auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r");
auto &atom = *ptxBuilderAtomicCAS.create<PTXInstr>("atom");
atom.global().o("cas").o("b32");
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred);
atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask);
auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy);
barrier();
@@ -561,7 +556,7 @@ struct AtomicCASOpConversion
auto *valOprStore = ptxBuilderStore.newOperand(old, "r");
auto &st = *ptxBuilderStore.create<PTXInstr>("st");
st.shared().o("b32");
st(dstOprStore, valOprStore).predicate(pred);
st(dstOprStore, valOprStore).predicate(mask);
ptxBuilderStore.launch(rewriter, loc, ASMReturnTy);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
barrier();
@@ -580,11 +575,11 @@ struct AtomicRMWOpConversion
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
const Allocation *allocation, Value smem,
AxisInfoAnalysis &axisAnalysisPass,
ModuleAllocation &allocation,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(
converter, allocation, smem, benefit),
converter, allocation, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
#ifdef USE_ROCM
@@ -747,10 +742,11 @@ struct AtomicRMWOpConversion
maskElements = getTypeConverter()->unpackLLElements(
loc, llMask, rewriter, op.getMask().getType());
auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
auto valueTy = op.getResult().getType();
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
Type valueElemTy =
tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
: op.getResult().getType();
: valueTy;
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
auto elemsPerThread = getTotalElemsPerThread(val.getType());
// vec = 1, numElements = 1 for scalar
@@ -763,10 +759,7 @@ struct AtomicRMWOpConversion
// mask
numElems = tensorTy.getNumElements();
}
Value mask = int_val(1, 1);
auto tid = tid_val();
mask = and_(mask,
icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
Value mask = getMask(valueTy, rewriter, loc);
auto vecTy = vec_ty(valueElemTy, vec);
SmallVector<Value> resultVals(elemsPerThread);
@@ -846,7 +839,6 @@ struct AtomicRMWOpConversion
memfenc();
auto ASMReturnTy = void_ty(ctx);
ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy);
rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy);
Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());
@@ -889,7 +881,9 @@ struct InsertSliceOpConversion
Value dst = op.getDest();
Value src = op.getSource();
Value res = op.getResult();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
auto funcOp = op->getParentOfType<FunctionOpInterface>();
auto *funcAllocation = allocation->getFuncData(funcOp);
assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice for now");
auto srcTy = src.getType().dyn_cast<RankedTensorType>();
@@ -949,12 +943,11 @@ struct InsertSliceAsyncOpConversion
triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;
InsertSliceAsyncOpConversion(
TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
Value smem,
TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(
converter, allocation, smem, indexCacheInfo, benefit),
converter, allocation, indexCacheInfo, benefit),
LoadStoreConversionBase(axisAnalysisPass) {}
LogicalResult
@@ -967,7 +960,9 @@ struct InsertSliceAsyncOpConversion
Value res = op.getResult();
Value mask = op.getMask();
Value other = op.getOther();
assert(allocation->getBufferId(res) == Allocation::InvalidBufferId &&
auto funcOp = op->getParentOfType<FunctionOpInterface>();
auto *funcAllocation = allocation->getFuncData(funcOp);
assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId &&
"Only support in-place insert_slice_async for now");
auto srcTy = src.getType().cast<RankedTensorType>();
@@ -1107,19 +1102,17 @@ struct InsertSliceAsyncOpConversion
void populateLoadStoreOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<LoadOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<StoreOpConversion>(typeConverter, axisInfoAnalysis, benefit);
patterns.add<AtomicCASOpConversion>(typeConverter, allocation, smem,
patterns.add<AtomicCASOpConversion>(typeConverter, allocation,
axisInfoAnalysis, benefit);
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation, smem,
patterns.add<AtomicRMWOpConversion>(typeConverter, allocation,
axisInfoAnalysis, benefit);
patterns.add<InsertSliceOpConversion>(typeConverter, allocation, smem,
patterns.add<InsertSliceOpConversion>(typeConverter, allocation,
indexCacheInfo, benefit);
patterns.add<InsertSliceAsyncOpConversion>(typeConverter, allocation, smem,
indexCacheInfo, axisInfoAnalysis,
benefit);
patterns.add<InsertSliceAsyncOpConversion>(
typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit);
}

View File

@@ -8,8 +8,7 @@ using namespace mlir::triton;
void populateLoadStoreOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -87,6 +87,15 @@ private:
Attribute layout, SmallVector<Value> &index,
SmallVector<Value> &writeIdx,
std::map<int, Value> &ints, unsigned axis) const {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto dim = sliceLayout.getDim();
assert(dim != axis && "Reduction axis cannot be sliced");
auto parentLayout = sliceLayout.getParent();
getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints,
axis);
return;
}
writeIdx = index;
auto sizePerThread = triton::gpu::getSizePerThread(layout);
Value axisSizePerThread = ints[sizePerThread[axis]];
@@ -100,9 +109,10 @@ private:
// to map every `axisSizePerThread` to 1 value in smem as:
// writeIdx[axis] = index[axis] / axisSizePerThread
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>();
if (mmaLayout && mmaLayout.isAmpere()) {
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (!mmaLayout.isAmpere()) {
llvm::report_fatal_error("Unsupported layout");
}
if (axis == 0) {
// Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8
// rows in smem would correspond to a warp. The mapping
@@ -113,8 +123,7 @@ private:
// Same as BlockedEncodingAttr case
writeIdx[axis] = udiv(index[axis], axisSizePerThread);
}
}
if (mmaLayout && !mmaLayout.isAmpere()) {
} else {
llvm::report_fatal_error("Unsupported layout");
}
}
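Along the reduction axis every sizePerThread[axis] consecutive elements owned by a thread collapse onto one shared-memory slot, which is what the udiv above computes; slice layouts simply recurse into their parent layout. The scalar arithmetic, as a sketch:

#include <cstdio>

// Every axisSizePerThread consecutive indices along the reduction axis map
// to the same shared-memory slot.
int writeIndexAlongAxis(int indexAlongAxis, int axisSizePerThread) {
  return indexAlongAxis / axisSizePerThread;
}

int main() {
  const int axisSizePerThread = 4;
  for (int i = 0; i < 8; ++i)
    std::printf("index %d -> smem slot %d\n", i,
                writeIndexAlongAxis(i, axisSizePerThread));
}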
@@ -327,8 +336,8 @@ private:
elemPtrTys[i]);
}
unsigned sizeIntraWarps = helper.getIntraWarpSize();
unsigned sizeInterWarps = helper.getInterWarpSize();
unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData();
unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData();
unsigned srcElems = getTotalElemsPerThread(srcTys[0]);
auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]);
@@ -495,10 +504,9 @@ private:
void populateReduceOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<ReduceOpConversion>(typeConverter, allocation, smem,
indexCacheInfo, benefit);
patterns.add<ReduceOpConversion>(typeConverter, allocation, indexCacheInfo,
benefit);
}

View File

@@ -8,8 +8,7 @@ using namespace mlir::triton;
void populateReduceOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -401,8 +401,7 @@ struct MakeRangeOpConversion
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(
converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo,
benefit) {}
converter, indexCacheInfo, benefit) {}
LogicalResult
matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
@@ -669,18 +668,17 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &moduleAllocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit) {
patterns.add<AddPtrOpConversion>(typeConverter, benefit);
patterns.add<AllocTensorOpConversion>(typeConverter, allocation, smem,
patterns.add<AllocTensorOpConversion>(typeConverter, moduleAllocation,
benefit);
patterns.add<AsyncCommitGroupOpConversion>(typeConverter, benefit);
patterns.add<AsyncWaitOpConversion>(typeConverter, benefit);
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
patterns.add<ExtractSliceOpConversion>(typeConverter, allocation, smem,
patterns.add<ExtractSliceOpConversion>(typeConverter, moduleAllocation,
benefit);
patterns.add<GetProgramIdOpConversion>(typeConverter, benefit);
patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);

View File

@@ -8,8 +8,7 @@ using namespace mlir::triton;
void populateTritonGPUToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
ModuleAllocation &allocation,
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
PatternBenefit benefit);

View File

@@ -11,7 +11,7 @@
#include "Utility.h"
#include "mlir/IR/TypeUtilities.h"
#include "triton/Analysis/AxisInfo.h"
#include <set>
using namespace mlir;
using namespace mlir::triton;
@@ -41,7 +41,7 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
// All the rights are reserved by the LLVM community.
struct FuncOpConversionBase : public ConvertOpToLLVMPattern<triton::FuncOp> {
private:
protected:
/// Only retain those attributes that are not constructed by
/// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument
/// attributes.
@@ -184,14 +184,18 @@ public:
: converter(&typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem)
: converter(&typeConverter), allocation(allocation), smem(smem) {}
TritonGPUToLLVMTypeConverter &typeConverter,
IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), indexCacheInfo(indexCacheInfo) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), allocation(allocation), smem(smem),
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation)
: converter(&typeConverter), allocation(&allocation) {}
explicit ConvertTritonGPUOpToLLVMPatternBase(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
IndexCacheInfo indexCacheInfo)
: converter(&typeConverter), allocation(&allocation),
indexCacheInfo(indexCacheInfo) {}
TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; }
@@ -228,9 +232,17 @@ public:
T value) const {
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
auto bufferId = allocation->getBufferId(value);
FunctionOpInterface funcOp;
if constexpr (std::is_pointer_v<T>)
funcOp = value->template getParentOfType<FunctionOpInterface>();
else
funcOp = value.getParentRegion()
->template getParentOfType<FunctionOpInterface>();
auto *funcAllocation = allocation->getFuncData(funcOp);
auto smem = allocation->getFunctionSharedMemoryBase(funcOp);
auto bufferId = funcAllocation->getBufferId(value);
assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
size_t offset = allocation->getOffset(bufferId);
size_t offset = funcAllocation->getOffset(bufferId);
Value offVal = i32_val(offset);
Value base = gep(ptrTy, smem, offVal);
return base;
@@ -409,6 +421,46 @@ public:
// -----------------------------------------------------------------------
// Utilities
// -----------------------------------------------------------------------
Value getMask(Type valueTy, ConversionPatternRewriter &rewriter,
Location loc) const {
auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
Value mask = int_val(1, 1);
auto tid = tid_val();
if (tensorTy) {
auto layout = tensorTy.getEncoding();
auto shape = tensorTy.getShape();
unsigned rank = shape.size();
auto sizePerThread = triton::gpu::getSizePerThread(layout);
auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout);
auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout);
auto order = triton::gpu::getOrder(layout);
auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape);
Value warpSize = i32_val(32);
Value laneId = urem(tid, warpSize);
Value warpId = udiv(tid, warpSize);
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, order);
SmallVector<Value> multiDimThreadId =
delinearize(rewriter, loc, laneId, threadsPerWarp, order);
for (unsigned dim = 0; dim < rank; ++dim) {
// if there is no data replication across threads on this dimension
if (shape[dim] >= shapePerCTA[dim])
continue;
// Otherwise, we need to mask threads that will replicate data on this
// dimension. Calculate the thread index on this dimension for the CTA
Value threadDim =
add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])),
multiDimThreadId[dim]);
mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])),
i32_val(shape[dim])));
}
} else {
// If the tensor is not ranked, then it is a scalar and only thread 0 can
// write
mask = and_(mask, icmp_eq(tid, i32_val(0)));
}
return mask;
}
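getMask keeps a thread active only if, on every dimension where the tensor is smaller than the region covered by one CTA, the thread's first element index still falls inside the tensor; scalar values fall back to thread 0. A standalone sketch of that predicate, with all layout parameters made up and the thread coordinate already delinearized:

#include <cstdio>
#include <vector>

// Decide whether a thread should perform a store. Along every dimension
// where the tensor is smaller than what one CTA covers, threads whose first
// element index falls outside the tensor hold replicated data and are
// masked off.
bool threadMask(const std::vector<int> &threadIdx,      // per-dim thread index in the CTA
                const std::vector<int> &sizePerThread,  // elements per thread per dim
                const std::vector<int> &shape,          // tensor shape
                const std::vector<int> &shapePerCTA) {  // elements covered per CTA per dim
  bool mask = true;
  for (size_t dim = 0; dim < shape.size(); ++dim) {
    if (shape[dim] >= shapePerCTA[dim])
      continue;  // no replication on this dimension
    mask = mask && (threadIdx[dim] * sizePerThread[dim] < shape[dim]);
  }
  return mask;
}

int main() {
  // 8 threads along dim 0, 1 element each, but the tensor has only 4 rows
  // while a CTA covers 8: threads 4..7 would write copies of rows 0..3.
  for (int t = 0; t < 8; ++t)
    std::printf("thread %d active: %d\n", t,
                int(threadMask({t}, {1}, {4}, {8})));
}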
// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
@@ -505,13 +557,13 @@ public:
RankedTensorType type) const {
IndexCacheKeyT key = std::make_pair(layout, type);
auto cache = indexCacheInfo.baseIndexCache;
assert(cache && "baseIndexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
if (cache && cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
restoreInsertionPointIfSet(insertPt, rewriter);
if (cache)
restoreInsertionPointIfSet(insertPt, rewriter);
SmallVector<Value> result;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
result =
@@ -521,11 +573,20 @@ public:
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
if (mmaLayout.isAmpere())
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentLayout);
result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy);
result.erase(result.begin() + sliceLayout.getDim());
} else {
llvm_unreachable("unsupported emitBaseIndexForLayout");
}
cache->insert(std::make_pair(key, result));
*insertPt = rewriter.saveInsertionPoint();
if (cache) {
cache->insert(std::make_pair(key, result));
*insertPt = rewriter.saveInsertionPoint();
}
return result;
}
}
@@ -540,6 +601,8 @@ public:
if (mmaLayout.isAmpere())
return emitOffsetForMmaLayoutV2(mmaLayout, type);
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>())
return emitOffsetForSliceLayout(sliceLayout, type);
llvm_unreachable("unsupported emitOffsetForLayout");
}
@@ -552,27 +615,29 @@ public:
RankedTensorType type) const {
IndexCacheKeyT key(layout, type);
auto cache = indexCacheInfo.indexCache;
assert(cache && "indexCache is nullptr");
auto insertPt = indexCacheInfo.indexInsertPoint;
if (cache->count(key) > 0) {
if (cache && cache->count(key) > 0) {
return cache->lookup(key);
} else {
ConversionPatternRewriter::InsertionGuard guard(b);
restoreInsertionPointIfSet(insertPt, b);
if (cache)
restoreInsertionPointIfSet(insertPt, b);
SmallVector<SmallVector<Value>> result;
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, blocked, type);
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
result = emitIndicesForDistributedLayout(loc, b, mma, type);
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
result = emitIndicesForSliceLayout(loc, b, slice, type);
result = emitIndicesForDistributedLayout(loc, b, slice, type);
} else {
llvm_unreachable(
"emitIndices for layouts other than blocked & slice not "
"implemented yet");
}
cache->insert(std::make_pair(key, result));
*insertPt = b.saveInsertionPoint();
if (cache) {
cache->insert(std::make_pair(key, result));
*insertPt = b.saveInsertionPoint();
}
return result;
}
}
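Both index-emission helpers now treat the cache as optional: with a cache they memoize results and restore the shared insertion point, without one they simply recompute (the pass below disables the cache when the module has more than one function). A sketch of the optional-cache pattern, with strings standing in for layouts and emitted index vectors:

#include <cstdio>
#include <map>
#include <string>

// Optional-cache lookup: memoize only when a cache is supplied, otherwise
// recompute on every call.
std::string computeIndices(const std::string &key,
                           std::map<std::string, std::string> *cache) {
  if (cache) {
    auto it = cache->find(key);
    if (it != cache->end())
      return it->second;  // cache hit
  }
  std::string result = "indices(" + key + ")";  // stand-in for real emission
  if (cache)
    (*cache)[key] = result;
  return result;
}

int main() {
  std::map<std::string, std::string> cache;
  std::printf("%s\n", computeIndices("blocked<4x4>", &cache).c_str());
  std::printf("%s\n", computeIndices("blocked<4x4>", nullptr).c_str());
}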
@@ -722,11 +787,11 @@ private:
Value _fpw1 = i32_val(fpw[1]);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();
@@ -783,12 +848,12 @@ private:
// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
// A info
auto aEncoding =
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout, 0);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding =
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout, 0);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();
@@ -891,24 +956,29 @@ private:
return multiDimIdx;
}
SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
RankedTensorType type) const {
SmallVector<SmallVector<unsigned>>
emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout,
RankedTensorType type) const {
auto parentEncoding = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentEncoding);
auto parentIndices = emitIndices(loc, rewriter, parentEncoding, parentTy);
unsigned numIndices = parentIndices.size();
SmallVector<SmallVector<Value>> resultIndices;
for (unsigned i = 0; i < numIndices; ++i) {
SmallVector<Value> indices = parentIndices[i];
indices.erase(indices.begin() + dim);
resultIndices.push_back(indices);
auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy);
unsigned numOffsets = parentOffsets.size();
SmallVector<SmallVector<unsigned>> resultOffsets;
std::set<SmallVector<unsigned>> uniqueOffsets;
for (unsigned i = 0; i < numOffsets; ++i) {
SmallVector<unsigned> offsets = parentOffsets[i];
offsets.erase(offsets.begin() + dim);
if (uniqueOffsets.find(offsets) == uniqueOffsets.end()) {
resultOffsets.push_back(offsets);
uniqueOffsets.insert(offsets);
}
}
return resultIndices;
return resultOffsets;
}
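A slice layout's offsets are its parent's offsets with the sliced dimension erased; several parent offsets collapse onto the same result, so only the first occurrence of each is kept. The same deduplication on plain unsigned vectors:

#include <cstdio>
#include <set>
#include <vector>

// Drop one dimension from every parent offset and keep only the first
// occurrence of each resulting offset.
std::vector<std::vector<unsigned>>
sliceOffsets(const std::vector<std::vector<unsigned>> &parentOffsets,
             unsigned dim) {
  std::vector<std::vector<unsigned>> result;
  std::set<std::vector<unsigned>> seen;
  for (auto offsets : parentOffsets) {
    offsets.erase(offsets.begin() + dim);
    if (seen.insert(offsets).second)
      result.push_back(offsets);
  }
  return result;
}

int main() {
  // Slicing dim 1 of a 2x2 offset grid leaves two unique 1-D offsets.
  auto out = sliceOffsets({{0, 0}, {0, 1}, {1, 0}, {1, 1}}, 1);
  for (const auto &o : out)
    std::printf("%u\n", o[0]);
}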
#ifdef USE_ROCM
@@ -1057,8 +1127,7 @@ protected:
protected:
TritonGPUToLLVMTypeConverter *converter;
const Allocation *allocation;
Value smem;
ModuleAllocation *allocation;
IndexCacheInfo indexCacheInfo;
};
@@ -1075,16 +1144,22 @@ public:
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, PatternBenefit benefit = 1)
TritonGPUToLLVMTypeConverter &typeConverter,
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, indexCacheInfo) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {}
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation,
IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation,
indexCacheInfo) {}
protected:

View File

@@ -2,7 +2,7 @@
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
@@ -61,17 +61,38 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
LogicalResult
matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();
// Currently, Triton kernel function always return nothing.
// TODO(Superjomn) add support for non-inline device function
if (numArguments > 0) {
return rewriter.notifyMatchFailure(
op, "Only kernel function with nothing returned is supported.");
auto funcOp = op->getParentOfType<LLVM::LLVMFuncOp>();
if (funcOp->hasAttr("nvvm.kernel")) {
// A GPU kernel
if (op.getNumOperands() > 0) {
return rewriter.notifyMatchFailure(
op, "Kernel functions do not support return with operands");
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
} else {
// A device function
LLVM::ReturnOp newOp;
if (adaptor.getOperands().size() < 2) {
// Single or no return value.
newOp =
rewriter.create<LLVM::ReturnOp>(op.getLoc(), adaptor.getOperands());
} else {
// Pack the results into a struct.
auto packedResultsTy = this->getTypeConverter()->packFunctionResults(
funcOp.getResultTypes());
Value packedResults =
rewriter.create<LLVM::UndefOp>(op.getLoc(), packedResultsTy);
auto loc = op.getLoc();
for (auto it : llvm::enumerate(adaptor.getOperands())) {
packedResults = insert_val(packedResultsTy, packedResults, it.value(),
it.index());
}
newOp = rewriter.create<LLVM::ReturnOp>(op.getLoc(), packedResults);
}
newOp->setAttrs(op->getAttrs());
rewriter.replaceOp(op, newOp->getResults());
}
rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
};
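The rewritten ReturnOpConversion distinguishes three cases: kernels must return nothing, device functions with fewer than two results return their operands directly, and larger result lists are packed into a single struct. A sketch of that decision only (the strings merely describe which lowering is chosen):

#include <cstdio>
#include <string>

// Select the return lowering strategy, mirroring the branch structure above.
std::string lowerReturn(bool isKernel, size_t numResults) {
  if (isKernel)
    return numResults == 0 ? "llvm.return (void)"
                           : "error: kernels cannot return values";
  if (numResults < 2)
    return "llvm.return the operands as-is";
  return "insert all operands into one struct, return the struct";
}

int main() {
  std::printf("%s\n", lowerReturn(true, 0).c_str());
  std::printf("%s\n", lowerReturn(false, 1).c_str());
  std::printf("%s\n", lowerReturn(false, 3).c_str());
}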
@@ -81,19 +102,57 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern<triton::ReturnOp> {
/// information.
struct FuncOpConversion : public FuncOpConversionBase {
FuncOpConversion(LLVMTypeConverter &converter, int numWarps,
PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), numWarps(numWarps) {}
ModuleAllocation &allocation, PatternBenefit benefit)
: FuncOpConversionBase(converter, benefit), numWarps(numWarps),
allocation(allocation) {}
triton::FuncOp amendFuncOp(triton::FuncOp funcOp,
ConversionPatternRewriter &rewriter) const {
// Push back a variable that indicates the current stack pointer of shared
// memory to the function arguments.
auto loc = funcOp.getLoc();
auto ctx = funcOp->getContext();
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
// 1. Modify the function type to add the new argument.
auto funcTy = funcOp.getFunctionType();
auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs());
amendedInputTy.push_back(ptrTy);
auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy,
funcTy.getResults());
// 2. Modify the argument attributes to add the new argument.
SmallVector<NamedAttribute> amendedAttrs;
filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs);
auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs());
amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx));
amendedAttrs.push_back(rewriter.getNamedAttr(
funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs)));
// 3. Add a new argument to the region
auto amendedFuncOp = rewriter.create<triton::FuncOp>(
funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs);
auto &region = funcOp.getBody();
region.addArgument(ptrTy, loc);
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
amendedFuncOp.end());
return amendedFuncOp;
}
LogicalResult
matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
// Prevent LLVM's inliner to inline this function
auto amendedFuncOp = funcOp;
if (!allocation.isRoot(funcOp))
amendedFuncOp = amendFuncOp(funcOp, rewriter);
auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter);
if (!newFuncOp) {
return failure();
}
auto ctx = funcOp->getContext();
<<<<<<< HEAD
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
@@ -102,6 +161,25 @@ struct FuncOpConversion : public FuncOpConversionBase {
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps));
#endif
=======
if (allocation.isRoot(funcOp)) {
// Set an attribute to indicate this function is a kernel entry.
newFuncOp->setAttr("nvvm.kernel",
rewriter.getIntegerAttr(type::u1Ty(ctx), 1));
} else {
// The noinline attribute will be used by the LLVM codegen to prevent
// inlining.
// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267
newFuncOp.setPassthroughAttr(
ArrayAttr::get(ctx, rewriter.getStringAttr("noinline")));
rewriter.eraseOp(amendedFuncOp);
}
// Set an attribute for maxntidx, it could be used in latter LLVM codegen
// for `nvvm.annotation` metadata.
newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps));
// The call graph is updated by mapping the old function to the new one.
allocation.mapFuncOp(funcOp, newFuncOp);
>>>>>>> openai/main
rewriter.eraseOp(funcOp);
return success();
@@ -109,6 +187,99 @@ struct FuncOpConversion : public FuncOpConversionBase {
private:
int numWarps{0};
ModuleAllocation &allocation;
};
// CallOpInterfaceLowering is adapted from
// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485
struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
CallOpConversion(LLVMTypeConverter &converter, int numWarps,
ModuleAllocation &allocation, PatternBenefit benefit)
: ConvertOpToLLVMPattern<triton::CallOp>(converter, benefit),
numWarps(numWarps), allocation(allocation) {}
LogicalResult
matchAndRewrite(triton::CallOp callOp,
typename triton::CallOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto promotedOperands = promoteOperands(callOp, adaptor, rewriter);
auto newCallOp =
convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter);
if (!newCallOp)
return failure();
allocation.mapCallOp(callOp, newCallOp);
auto results = getCallOpResults(callOp, newCallOp, rewriter);
rewriter.replaceOp(callOp, results);
return success();
}
private:
SmallVector<Value, 4>
promoteOperands(triton::CallOp callOp,
typename triton::CallOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Get the last argument of the caller, which is the current stack pointer
// of shared memory and append it to the operands of the callOp.
auto loc = callOp.getLoc();
auto caller = callOp->getParentOfType<FunctionOpInterface>();
auto base = allocation.getFunctionSharedMemoryBase(caller);
auto *funcAllocation = allocation.getFuncData(caller);
auto bufferId = funcAllocation->getBufferId(callOp);
auto offset = funcAllocation->getOffset(bufferId);
auto ptrTy = LLVM::LLVMPointerType::get(
this->getTypeConverter()->convertType(rewriter.getI8Type()),
NVVM::kSharedMemorySpace);
auto offsetValue = gep(ptrTy, base, i32_val(offset));
auto promotedOperands = this->getTypeConverter()->promoteOperands(
callOp.getLoc(), /*opOperands=*/callOp->getOperands(),
adaptor.getOperands(), rewriter);
promotedOperands.push_back(offsetValue);
return promotedOperands;
}
LLVM::CallOp
convertCallOpToLLVMCallOp(triton::CallOp callOp,
ArrayRef<Value> promotedOperands,
ConversionPatternRewriter &rewriter) const {
// Pack the result types into a struct.
Type packedResult = nullptr;
unsigned numResults = callOp.getNumResults();
auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes());
if (numResults != 0) {
if (!(packedResult =
this->getTypeConverter()->packFunctionResults(resultTypes)))
return nullptr;
}
auto newCallOp = rewriter.create<LLVM::CallOp>(
callOp.getLoc(), packedResult ? TypeRange(packedResult) : TypeRange(),
promotedOperands, callOp->getAttrs());
return newCallOp;
}
SmallVector<Value>
getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp,
ConversionPatternRewriter &rewriter) const {
auto numResults = callOp.getNumResults();
SmallVector<Value> results;
if (numResults < 2) {
// If < 2 results, packing did not do anything and we can just return.
results.append(newCallOp.result_begin(), newCallOp.result_end());
} else {
// Otherwise, it had been converted to an operation producing a structure.
// Extract individual results from the structure and return them as list.
results.reserve(numResults);
for (unsigned i = 0; i < numResults; ++i) {
results.push_back(rewriter.create<LLVM::ExtractValueOp>(
callOp.getLoc(), newCallOp->getResult(0), i));
}
}
return results;
}
int numWarps{0};
ModuleAllocation &allocation;
};
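promoteOperands threads a shared-memory "stack pointer" through the call graph: each call site passes the caller's base plus the offset reserved for that callee as an extra trailing operand. A sketch of how the bases accumulate across nested calls, with made-up byte offsets and plain integers standing in for pointers:

#include <cstdint>
#include <cstdio>

// The base handed to a callee is the caller's own base plus the offset the
// caller's allocation reserved for that call site.
uint32_t calleeBase(uint32_t callerBase, uint32_t callSiteOffset) {
  return callerBase + callSiteOffset;
}

int main() {
  uint32_t kernelBase = 0;                        // root: address of global_smem
  uint32_t fBase = calleeBase(kernelBase, 1024);  // kernel reserves 1 KiB before f's region
  uint32_t gBase = calleeBase(fBase, 256);        // f reserves 256 B before g's region
  std::printf("f base = %u, g base = %u\n", fBase, gBase);
}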
class TritonLLVMConversionTarget : public ConversionTarget {
@@ -145,26 +316,25 @@ public:
TritonLLVMConversionTarget target(*context, isROCM);
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
/* preprocess */
// Preprocess
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);
if (failed(decomposeInsertSliceAsyncOp(mod)))
return signalPassFailure();
/* allocate shared memory and set barrier */
Allocation allocation(mod);
MembarAnalysis membarPass(&allocation);
// Allocate shared memory and set barrier
ModuleAllocation allocation(mod);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();
/* lower functions */
// Lower functions
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, allocation,
/*benefit=*/1);
funcPatterns.add<ReturnOpConversion>(typeConverter);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
funcPatterns);
if (failed(
@@ -172,36 +342,49 @@ public:
return signalPassFailure();
}
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(mod)))
return signalPassFailure();
initSharedMemory(allocation.getSharedMemorySize(), typeConverter);
mod->setAttr("triton_gpu.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
allocation.getSharedMemorySize()));
// initSharedMemory is run before the conversion of call and ret ops,
// because the call op has to know the shared memory base address of each
// function
initSharedMemory(allocation, typeConverter);
/* rewrite ops */
// Convert call and ret ops
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM);
RewritePatternSet funcPatterns(context);
funcPatterns.add<CallOpConversion>(typeConverter, numWarps, allocation,
/*benefit=*/1);
funcPatterns.add<ReturnOpConversion>(typeConverter, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();
}
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
// Rewrite ops
RewritePatternSet patterns(context);
// TritonGPU lowering patterns
OpBuilder::InsertPoint indexInsertPoint;
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
&baseIndexCache, &indexCache, &indexInsertPoint};
auto populatePatterns1 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
&allocation, smem, indexCacheInfo, /*benefit*/ 1);
};
auto populatePatterns2 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
&allocation, smem, /*benefit*/ 1);
};
populatePatterns1(populateTritonGPUToLLVMPatterns);
populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
populatePatterns2(populateDotOpToLLVMPatterns);
populatePatterns2(populateElementwiseOpToLLVMPatterns);
populatePatterns1(populateLoadStoreOpToLLVMPatterns);
populatePatterns1(populateReduceOpToLLVMPatterns);
populatePatterns2(populateViewOpToLLVMPatterns);
// TODO: enable index cache if there are multiple functions
if (axisInfoAnalysis.getNumFunctions() > 1) {
indexCacheInfo = {nullptr, nullptr, nullptr};
}
populateTritonGPUToLLVMPatterns(typeConverter, patterns, allocation,
indexCacheInfo, /*benefit=*/1);
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, allocation,
indexCacheInfo, /*benefit=*/1);
populateDotOpToLLVMPatterns(typeConverter, patterns, allocation,
/*benefit=*/1);
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis,
allocation, indexCacheInfo,
/*benefit=*/1);
populateReduceOpToLLVMPatterns(typeConverter, patterns, allocation,
indexCacheInfo, /*benefit=*/1);
populateViewOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1);
// Native lowering patterns
if (isROCM) {
@@ -218,8 +401,6 @@ public:
}
private:
Value smem;
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
baseIndexCache;
@@ -230,10 +411,11 @@ private:
int computeCapability{};
bool isROCM{};
void initSharedMemory(size_t size,
void initSharedMemory(ModuleAllocation &allocation,
TritonGPUToLLVMTypeConverter &typeConverter) {
ModuleOp mod = getOperation();
OpBuilder b(mod.getBodyRegion());
auto ctx = mod.getContext();
auto loc = mod.getLoc();
auto elemTy = typeConverter.convertType(b.getIntegerType(8));
// Set array size 0 and external linkage indicates that we use dynamic
@@ -244,15 +426,23 @@ private:
"global_smem", /*value=*/Attribute(), /*alignment=*/0,
// Add ROCm support.
static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace));
SmallVector<LLVM::LLVMFuncOp> funcs;
mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); });
assert(funcs.size() == 1 &&
"Inliner pass is expected before TritonGPUToLLVM");
b.setInsertionPointToStart(&funcs[0].getBody().front());
smem = b.create<LLVM::AddressOfOp>(loc, global);
auto ptrTy =
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3);
smem = b.create<LLVM::BitcastOp>(loc, ptrTy, smem);
mod.walk([&](FunctionOpInterface funcOp) {
Value funcSmem;
b.setInsertionPointToStart(&funcOp.getFunctionBody().front());
if (allocation.isRoot(funcOp)) {
funcSmem = b.create<LLVM::AddressOfOp>(loc, global);
} else {
funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1);
}
auto ptrTy =
LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()),
NVVM::NVVMMemorySpace::kSharedMemorySpace);
funcSmem = b.create<LLVM::BitcastOp>(loc, ptrTy, funcSmem);
allocation.setFunctionSharedMemoryValue(funcOp, funcSmem);
});
mod->setAttr("triton_gpu.shared",
mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32),
allocation.getSharedMemorySize()));
}
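// Illustrative summary (not part of the patch text): a root kernel takes its
// base pointer from llvm.mlir.addressof @global_smem, while every callee reads
// the same !llvm.ptr<i8, 3> base from its trailing argument; both values are
// recorded through setFunctionSharedMemoryValue for use during op lowering.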
void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const {
@@ -310,10 +500,7 @@ private:
}
LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const {
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(mod)))
return failure();
ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
// TODO(Keren): This is a hacky knob that may cause performance regression
// when decomposition has been performed. We should remove this knob once we
// have thorough analysis on async wait. Currently, we decompose
@@ -347,7 +534,7 @@ private:
auto resSharedLayout =
dstTy.getEncoding().dyn_cast<triton::gpu::SharedEncodingAttr>();
auto resElemTy = dstTy.getElementType();
unsigned inVec = axisInfoAnalysis->getPtrContiguity(src);
unsigned inVec = axisInfoAnalysis.getPtrContiguity(src);
unsigned outVec = resSharedLayout.getVec();
unsigned minVec = std::min(outVec, inVec);
auto maxBitWidth =

View File

@@ -106,17 +106,8 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct(
return elemTy;
if (mmaParent.isAmpere()) {
int bitwidth = elemTy.getIntOrFloatBitWidth();
// sub-word integer types need to be packed for perf reasons
if (elemTy.isa<IntegerType>() && bitwidth < 32)
return IntegerType::get(ctx, 32);
// TODO: unify everything to use packed integer-types
// otherwise, vector types are ok
const llvm::DenseMap<int, Type> elemTyMap = {
{32, vec_ty(elemTy, 1)},
{16, vec_ty(elemTy, 2)},
{8, vec_ty(elemTy, 4)},
};
return elemTyMap.lookup(bitwidth);
assert(bitwidth <= 32);
return IntegerType::get(ctx, 32);
} else {
assert(mmaParent.isVolta());
return vec_ty(elemTy, 2);
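// Net effect (illustrative): on Ampere every operand element is now carried as
// a packed i32 register value (e.g. two f16 or four i8 values per element),
// while the Volta path in this branch keeps the <2 x elemTy> vector form.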

View File

@@ -88,6 +88,7 @@
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)
// Types
#define int_ty(width) rewriter.getIntegerType(width)
#define i64_ty rewriter.getIntegerType(64)
#define i32_ty rewriter.getIntegerType(32)
#define i16_ty rewriter.getIntegerType(16)

View File

@@ -116,23 +116,64 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
}
};
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}
struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern<ViewOp> {
using OpAdaptor = typename ViewOp::Adaptor;
explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ViewOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
matchAndRewrite(ViewOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
Value view =
Value ret =
this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy);
rewriter.replaceOp(op, view);
rewriter.replaceOp(op, ret);
return success();
}
};
struct ExpandDimsOpConversion
: public ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp> {
using OpAdaptor = typename ExpandDimsOp::Adaptor;
explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<ExpandDimsOp>(typeConverter, benefit) {}
LogicalResult
matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
auto srcVals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
auto srcTy = op.getSrc().getType().cast<RankedTensorType>();
auto resultTy = op.getType().template cast<RankedTensorType>();
assert(srcTy.getEncoding().isa<SliceEncodingAttr>() &&
"ExpandDimsOp only support SliceEncodingAttr");
auto srcLayout = srcTy.getEncoding().dyn_cast<SliceEncodingAttr>();
auto resultLayout = resultTy.getEncoding();
auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
for (size_t i = 0; i < srcOffsets.size(); i++) {
srcValues[srcOffsets[i]] = srcVals[i];
}
SmallVector<Value> resultVals;
for (size_t i = 0; i < resultOffsets.size(); i++) {
auto offset = resultOffsets[i];
offset.erase(offset.begin() + srcLayout.getDim());
resultVals.push_back(srcValues.lookup(offset));
}
Value ret = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, ret);
return success();
}
};
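// Worked example (shapes assumed for illustration): expanding a source with
// slice-encoded offsets {0}, {1}, {2}, {3} along dim = 0 yields result offsets
// {0,0}, {0,1}, {0,2}, {0,3}; erasing dim 0 from each result offset recovers
// the matching source offset, so every result element simply reuses the source
// register and no data movement is generated.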
@@ -161,13 +202,10 @@ struct TransOpConversion
};
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
RewritePatternSet &patterns,
PatternBenefit benefit) {
patterns.add<ViewLikeOpConversion<triton::ViewOp>>(typeConverter, benefit);
patterns.add<ViewLikeOpConversion<triton::ExpandDimsOp>>(typeConverter,
benefit);
patterns.add<ViewOpConversion>(typeConverter, benefit);
patterns.add<ExpandDimsOpConversion>(typeConverter, benefit);
patterns.add<SplatOpConversion>(typeConverter, benefit);
patterns.add<ArithConstantSplatOpConversion>(typeConverter, benefit);
patterns.add<CatOpConversion>(typeConverter, benefit);

View File

@@ -7,9 +7,7 @@ using namespace mlir;
using namespace mlir::triton;
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,
RewritePatternSet &patterns,
PatternBenefit benefit);
#endif

View File

@@ -14,6 +14,7 @@ add_mlir_conversion_library(TritonToTritonGPU
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
MLIRTransforms
TritonIR
TritonGPUIR
TritonGPUTransforms

View File

@@ -67,19 +67,21 @@ public:
matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Type retType = getTypeConverter()->convertType(op.getType());
auto retShapedType = retType.cast<ShapedType>();
auto value = adaptor.getValue().dyn_cast<DenseElementsAttr>();
if (dyn_cast<RankedTensorType>(retType)) {
if (dyn_cast<RankedTensorType>(retShapedType)) {
assert(value);
if (value.getElementType().isInteger(1) && value.isSplat())
// Workaround until https://reviews.llvm.org/D133743 is included.
value = DenseElementsAttr::get(retType, value.getSplatValue<bool>());
value =
DenseElementsAttr::get(retShapedType, value.getSplatValue<bool>());
else
// This is a hack. We just want to add encoding
value = value.reshape(retType);
value = value.reshape(retShapedType);
}
addNamedAttrs(
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
adaptor.getAttributes());
addNamedAttrs(rewriter.replaceOpWithNewOp<arith::ConstantOp>(
op, retShapedType, value),
adaptor.getAttributes());
return success();
}
};
@@ -165,8 +167,6 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
MLIRContext *context = patterns.getContext();
// Rewrite rule
patterns.add<StdSelectPattern>(typeConverter, context);
target.addLegalOp<triton::ReturnOp>(); // this is ok because all functions are
// inlined by the frontend
}
void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter,
@@ -274,6 +274,8 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
// a & b must be of smem layout
auto aType = adaptor.getA().getType().cast<RankedTensorType>();
auto bType = adaptor.getB().getType().cast<RankedTensorType>();
Type aEltType = aType.getElementType();
Type bEltType = bType.getElementType();
Attribute aEncoding = aType.getEncoding();
Attribute bEncoding = bType.getEncoding();
if (!aEncoding || !bEncoding)
@@ -282,17 +284,17 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
Value b = adaptor.getB();
Value c = adaptor.getC();
if (!aEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding);
auto dstType = RankedTensorType::get(aType.getShape(),
aType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 0, dEncoding, aEltType);
auto dstType =
RankedTensorType::get(aType.getShape(), aEltType, encoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), dstType, a);
}
if (!bEncoding.isa<triton::gpu::DotOperandEncodingAttr>()) {
Attribute encoding =
triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding);
auto dstType = RankedTensorType::get(bType.getShape(),
bType.getElementType(), encoding);
Attribute encoding = triton::gpu::DotOperandEncodingAttr::get(
getContext(), 1, dEncoding, bEltType);
auto dstType =
RankedTensorType::get(bType.getShape(), bEltType, encoding);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), dstType, b);
}
c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);
@@ -533,6 +535,52 @@ struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
}
};
class TritonFuncOpPattern : public OpConversionPattern<triton::FuncOp> {
public:
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = getTypeConverter();
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
op, op.getName(), op.getFunctionType());
addNamedAttrs(newOp, adaptor.getAttributes());
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
newOp.getBody().end());
if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
return failure();
return success();
}
};
class TritonCallOpPattern : public OpConversionPattern<triton::CallOp> {
public:
using OpConversionPattern<triton::CallOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::CallOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newOp = rewriter.replaceOpWithNewOp<triton::CallOp>(
op, op.getCallee(), op.getResultTypes(), adaptor.getOperands());
addNamedAttrs(newOp, adaptor.getAttributes());
return success();
}
};
class TritonReturnOpPattern : public OpConversionPattern<ReturnOp> {
public:
using OpConversionPattern<ReturnOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<ReturnOp>(op, adaptor.getOperands());
return success();
}
};
void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
@@ -550,7 +598,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
TritonLoadPattern, TritonStorePattern,
TritonExternElementwisePattern<triton::PureExternElementwiseOp>,
TritonExternElementwisePattern<triton::ImpureExternElementwiseOp>,
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>(
TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern,
TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>(
typeConverter, context);
}
@@ -752,31 +801,10 @@ public:
}
};
class FuncOpPattern : public OpConversionPattern<triton::FuncOp> {
public:
using OpConversionPattern<triton::FuncOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto converter = getTypeConverter();
auto newOp = rewriter.replaceOpWithNewOp<triton::FuncOp>(
op, op.getName(), op.getFunctionType());
addNamedAttrs(newOp, adaptor.getAttributes());
rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(),
newOp.getBody().end());
if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter)))
return failure();
return success();
}
};
void populateCFPatterns(TritonGPUTypeConverter &typeConverter,
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
patterns.add<FuncOpPattern, CFCondBranchPattern, CFBranchPattern>(
typeConverter, context);
patterns.add<CFCondBranchPattern, CFBranchPattern>(typeConverter, context);
}
//

View File

@@ -1,5 +1,4 @@
add_mlir_dialect_library(TritonIR
Interfaces.cpp
Dialect.cpp
Ops.cpp
Types.cpp
@@ -11,5 +10,6 @@ add_mlir_dialect_library(TritonIR
LINK_LIBS PUBLIC
MLIRIR
MLIRArithDialect
MLIRMathDialect
MLIRSCFDialect
)

View File

@@ -22,14 +22,22 @@ using namespace mlir::triton;
namespace {
struct TritonInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
bool isLegalToInline(Operation *call, Operation *callable,
bool wouldBeCloned) const final {
auto funcOp = dyn_cast<triton::FuncOp>(callable);
if (!funcOp)
return true;
if (funcOp->hasAttr("noinline"))
return !funcOp->getAttrOfType<BoolAttr>("noinline").getValue();
return true;
}
bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
IRMapping &valueMapping) const final {
return true;
}
bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
IRMapping &) const final {
return true;
@@ -83,5 +91,5 @@ void TritonDialect::initialize() {
Operation *TritonDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
return builder.create<arith::ConstantOp>(loc, type, value);
return arith::ConstantOp::materialize(builder, value, type, loc);
}

View File

@@ -230,6 +230,54 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
state.addTypes({resultType});
}
// load(ptr, splat(1), ...) -> load(ptr, ...)
// load(ptr, splat(0), other, ...) -> other
struct CanonicalizeMaskedLoadPattern
: public mlir::OpRewritePattern<triton::LoadOp> {
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::LoadOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::LoadOp loadOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = loadOp.getMask();
if (!mask)
return mlir::failure();
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::LoadOp>(
loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
} else {
// mask = splat(0)
// If there's no "other", the value is "undef". Perhaps we want to
// optimize it in the future.
auto otherVal = loadOp.getOther();
if (!otherVal)
return mlir::failure();
rewriter.replaceOp(loadOp, otherVal);
}
return mlir::success();
}
};
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedLoadPattern>(context);
}
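// Before/after sketch of the rewrite (hypothetical IR, for illustration only):
//   %v = tt.load %ptr, %mask = splat(true), %other   ->   %v = tt.load %ptr
//   %v = tt.load %ptr, %mask = splat(false), %other  ->   %v = %other (load removed)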
//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
::mlir::Value ptr, ::mlir::Value value,
@@ -257,10 +305,51 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
evict);
}
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
// store(ptr, value, splat(0), ...) -> [none]
struct CanonicalizeMaskedStorePattern
: public mlir::OpRewritePattern<triton::StoreOp> {
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::StoreOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::StoreOp storeOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = storeOp.getMask();
if (!mask)
return mlir::failure();
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::StoreOp>(
storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(),
storeOp.getEvict());
} else {
// mask = splat(0)
rewriter.eraseOp(storeOp);
}
return mlir::success();
}
};
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedStorePattern>(context);
}
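// Minimal usage sketch (not part of the patch; `module` is a hypothetical
// enclosing op and mlir/Transforms/GreedyPatternRewriteDriver.h is assumed to
// be available): both masked-access canonicalizers are ordinary op
// canonicalization patterns and can be driven to a fixpoint as follows.
static LogicalResult runMaskedAccessCanonicalizers(ModuleOp module) {
  MLIRContext *ctx = module.getContext();
  RewritePatternSet patterns(ctx);
  // Collect the patterns registered by getCanonicalizationPatterns above.
  triton::LoadOp::getCanonicalizationPatterns(patterns, ctx);
  triton::StoreOp::getCanonicalizationPatterns(patterns, ctx);
  // Rewrite until no masked loads/stores with constant splat masks remain.
  return applyPatternsAndFoldGreedily(module, std::move(patterns));
}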
//-- TransOp --
mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// type is the same as the input
auto argTy = operands[0].getType().cast<RankedTensorType>();
@@ -287,7 +376,7 @@ mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes(
//-- DotOp --
mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// type is the same as the accumulator
auto accTy = operands[2].getType().cast<RankedTensorType>();
@@ -355,7 +444,7 @@ void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state,
mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
for (auto arg : operands) {
auto argTy = arg.getType().cast<RankedTensorType>();
@@ -462,7 +551,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) {
//-- ExpandDimsOp --
mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes(
MLIRContext *context, std::optional<Location> loc, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
// infer shape
auto arg = operands[0];

View File

@@ -9,4 +9,9 @@ add_mlir_dialect_library(TritonTransforms
DEPENDS
TritonTransformsIncGen
TritonCombineIncGen
LINK_LIBS PUBLIC
MLIRPass
MLIRTransformUtils
TritonIR
)

View File

@@ -37,7 +37,7 @@ bool isBroadcastConstantCombinable(Attribute value) {
DenseElementsAttr getConstantValue(Builder &builder, Attribute value,
Value bcast_res) {
Type resType = bcast_res.getType();
auto resType = bcast_res.getType().cast<ShapedType>();
DenseElementsAttr res;
if (auto denseValue = value.dyn_cast<DenseElementsAttr>()) {
res =
@@ -101,95 +101,6 @@ public:
}
};
// load(ptr, splat(1), ...) -> load(ptr, ...)
// load(ptr, splat(0), other, ...) -> other
struct CanonicalizeMaskedLoadPattern
: public mlir::OpRewritePattern<triton::LoadOp> {
CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::LoadOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::LoadOp loadOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = loadOp.getMask();
if (!mask)
return mlir::failure();
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::LoadOp>(
loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(),
loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(),
loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile());
} else {
// mask = splat(0)
// If there's no "other", the value is "undef". Perhaps we want to
// optimize it in the future.
auto otherVal = loadOp.getOther();
if (!otherVal)
return mlir::failure();
rewriter.replaceOp(loadOp, otherVal);
}
return mlir::success();
}
};
void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedLoadPattern>(context);
}
// store(ptr, value, splat(1), ...) -> store(ptr, value, ...)
// store(ptr, value, splat(0), ...) -> [none]
struct CanonicalizeMaskedStorePattern
: public mlir::OpRewritePattern<triton::StoreOp> {
CanonicalizeMaskedStorePattern(mlir::MLIRContext *context)
: OpRewritePattern<triton::StoreOp>(context, 1) {}
mlir::LogicalResult
matchAndRewrite(triton::StoreOp storeOp,
mlir::PatternRewriter &rewriter) const override {
auto mask = storeOp.getMask();
if (!mask)
return mlir::failure();
auto constantMask =
llvm::dyn_cast_or_null<arith::ConstantOp>(mask.getDefiningOp());
if (!constantMask)
return mlir::failure();
auto splatMask = constantMask.getValue().dyn_cast<SplatElementsAttr>();
if (!splatMask)
return mlir::failure();
if (splatMask.getSplatValue<IntegerAttr>().getValue() == true) {
// mask = splat(1)
rewriter.replaceOpWithNewOp<triton::StoreOp>(
storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(),
storeOp.getEvict());
} else {
// mask = splat(0)
rewriter.eraseOp(storeOp);
}
return mlir::success();
}
};
void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<CanonicalizeMaskedStorePattern>(context);
}
#define GEN_PASS_CLASSES
#include "triton/Dialect/Triton/Transforms/Passes.h.inc"

View File

@@ -167,10 +167,10 @@ public:
auto otherTensorType = RankedTensorType::get(tensorShape, elementType);
// Set zero padding value
Attribute attr =
TypedAttr attr =
elementType.isIntOrIndex()
? builder.getIntegerAttr(elementType, 0).cast<Attribute>()
: builder.getFloatAttr(elementType, 0).cast<Attribute>();
? builder.getIntegerAttr(elementType, 0).cast<TypedAttr>()
: builder.getFloatAttr(elementType, 0).cast<TypedAttr>();
// Float NaN padding case
if (padding.value() == triton::PaddingOption::PAD_NAN) {

View File

@@ -7,5 +7,6 @@ add_mlir_dialect_library(TritonGPUIR
TritonGPUAttrDefsIncGen
LINK_LIBS PUBLIC
MLIRGPUOps
TritonIR
)

View File

@@ -81,10 +81,41 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout) {
if (mmaLayout.isAmpere())
return {8, 4};
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentThreadsPerWarp = getThreadsPerWarp(parent);
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < threadsPerWarp.size(); i++)
threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()];
return threadsPerWarp;
}
assert(0 && "getThreadsPerWarp not implemented");
return {};
}
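// Worked example (parent layout assumed): for a #blocked parent with
// threadsPerWarp = [4, 8] sliced along dim = 0, the sliced dimension is erased
// ([8]) and its 4 threads are folded into the remaining one, giving [32], so
// the full warp still covers the 1-D slice.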
SmallVector<unsigned>
getThreadsPerWarpWithUniqueData(Attribute layout,
ArrayRef<int64_t> tensorShape) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(tensorShape);
auto parentThreadsPerWarp =
getThreadsPerWarpWithUniqueData(parentLayout, parentShape);
SmallVector<unsigned> threadsPerWarp = parentThreadsPerWarp;
threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim());
return threadsPerWarp;
}
auto threadsPerWarp = getThreadsPerWarp(layout);
assert(threadsPerWarp.size() == tensorShape.size() &&
"layout and tensor shape must have the same rank");
for (unsigned i = 0; i < threadsPerWarp.size(); i++) {
threadsPerWarp[i] = std::min<unsigned>(threadsPerWarp[i], tensorShape[i]);
}
return threadsPerWarp;
}
SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getWarpsPerCTA().begin(),
@@ -94,19 +125,51 @@ SmallVector<unsigned> getWarpsPerCTA(Attribute layout) {
return SmallVector<unsigned>(mmaLayout.getWarpsPerCTA().begin(),
mmaLayout.getWarpsPerCTA().end());
}
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parent = sliceLayout.getParent();
auto parentWarpsPerCTA = getWarpsPerCTA(parent);
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
for (unsigned i = 0; i < warpsPerCTA.size(); i++)
warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()];
return warpsPerCTA;
}
assert(0 && "getWarpsPerCTA not implemented");
return {};
}
SmallVector<unsigned>
getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape) {
if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(tensorShape);
auto parentWarpsPerCTA =
getWarpsPerCTAWithUniqueData(parentLayout, parentShape);
SmallVector<unsigned> warpsPerCTA = parentWarpsPerCTA;
warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim());
return warpsPerCTA;
}
auto warpsPerCTA = getWarpsPerCTA(layout);
assert(warpsPerCTA.size() == tensorShape.size() &&
"layout and tensor shape must have the same rank");
for (unsigned i = 0; i < warpsPerCTA.size(); i++) {
auto sizePerWarp =
getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i];
auto maxWarpsPerDim = ceil<unsigned>(tensorShape[i], sizePerWarp);
warpsPerCTA[i] = std::min<unsigned>(warpsPerCTA[i], maxWarpsPerDim);
}
return warpsPerCTA;
}
SmallVector<unsigned> getSizePerThread(Attribute layout) {
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
return SmallVector<unsigned>(blockedLayout.getSizePerThread().begin(),
blockedLayout.getSizePerThread().end());
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto ret = getSizePerThread(sliceLayout.getParent());
return ret;
// ret.erase(ret.begin() + sliceLayout.getDim());
return ret;
auto sizePerThread = getSizePerThread(sliceLayout.getParent());
sizePerThread.erase(sizePerThread.begin() + sliceLayout.getDim());
return sizePerThread;
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.isAmpere()) {
return {2, 2};
@@ -146,11 +209,43 @@ SmallVector<unsigned> getContigPerThread(Attribute layout) {
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
assert(mmaLayout.isVolta() || mmaLayout.isAmpere());
return {1, 2};
} else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
return getContigPerThread(parentLayout);
} else {
return getSizePerThread(layout);
}
}
SmallVector<unsigned> getUniqueContigPerThread(Type type) {
if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
return SmallVector<unsigned>(1, 1);
auto tensorType = type.cast<RankedTensorType>();
auto shape = tensorType.getShape();
// If slice layout, call recursively on parent layout, and drop
// sliced dim
if (auto sliceLayout =
tensorType.getEncoding().dyn_cast<SliceEncodingAttr>()) {
auto parentLayout = sliceLayout.getParent();
auto parentShape = sliceLayout.paddedShape(shape);
auto parentTy = RankedTensorType::get(
parentShape, tensorType.getElementType(), parentLayout);
auto parentUniqueContigPerThread = getUniqueContigPerThread(parentTy);
parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() +
sliceLayout.getDim());
return parentUniqueContigPerThread;
}
// Base case
auto rank = shape.size();
SmallVector<unsigned> ret(rank);
auto contigPerThread = getContigPerThread(tensorType.getEncoding());
assert(contigPerThread.size() == rank && "Unexpected contigPerThread size");
for (int d = 0; d < rank; ++d) {
ret[d] = std::min<unsigned>(shape[d], contigPerThread[d]);
}
return ret;
}
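// Example (shape and layout assumed): with contigPerThread = [1, 4] and tensor
// shape [128, 2], the per-dimension clamp gives [min(128, 1), min(2, 4)] =
// [1, 2], i.e. replication beyond the tensor extent is not counted as unique
// contiguous data.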
SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
SmallVector<unsigned> threads;
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
@@ -158,7 +253,7 @@ SmallVector<unsigned> getThreadsPerCTA(Attribute layout) {
threads.push_back(blockedLayout.getThreadsPerWarp()[d] *
blockedLayout.getWarpsPerCTA()[d]);
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
if (mmaLayout.getVersionMajor() == 2) {
if (mmaLayout.isAmpere()) {
threads = {8 * mmaLayout.getWarpsPerCTA()[0],
4 * mmaLayout.getWarpsPerCTA()[1]};
} else
@@ -261,6 +356,16 @@ bool isaDistributedLayout(Attribute layout) {
} // namespace gpu
} // namespace triton
bool isSharedEncoding(Value value) {
auto type = value.getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
auto encoding = tensorType.getEncoding();
return encoding && encoding.isa<triton::gpu::SharedEncodingAttr>();
}
return false;
}
} // namespace mlir
static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr,
@@ -375,6 +480,7 @@ SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
auto parent = getParent();
auto parentElemsPerThread =
::getElemsPerThread(parent, paddedShape(shape), eltTy);
parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim());
return parentElemsPerThread;
}
unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
@@ -774,14 +880,27 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
return {};
unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
Attribute parent = attrs.get("parent");
auto mmaParent = parent.dyn_cast<MmaEncodingAttr>();
unsigned kWidth = 0;
Attribute _kWidth = attrs.get("kWidth");
if (_kWidth) {
if (!mmaParent || mmaParent.isVolta()) {
auto loc = parser.getNameLoc();
parser.emitError(loc, "kWidth only supported for MMAv2+ parent");
return Attribute();
}
kWidth = _kWidth.cast<IntegerAttr>().getInt();
}
return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
parent);
parent, kWidth);
}
void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>();
printer << "<{"
<< "opIdx = " << getOpIdx() << ", "
<< "parent = " << getParent();
<< "opIdx = " << getOpIdx() << ", parent = " << getParent();
if (mmaParent && mmaParent.isAmpere())
printer << ", kWidth = " << getMMAv2kWidth();
printer << "}>";
}
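// Printed form with an Ampere MMA parent (illustrative example):
//   #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>
// For Volta or blocked parents the kWidth field is omitted.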
@@ -1029,9 +1148,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
return mlir::failure();
}
auto newType = op->getResult(0).getType().cast<RankedTensorType>();
// Ensure that the new insert_slice op is placed in the same place as the
// old insert_slice op. Otherwise, the new insert_slice op may be placed
// after the async_wait op, which is not allowed.
// Ensure that the new insert_slice op is placed in the same place as
// the old insert_slice op. Otherwise, the new insert_slice op may be
// placed after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(insert_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -1059,9 +1178,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
auto resType = RankedTensorType::get(
origResType.getShape(), origResType.getElementType(),
extract_slice.getType().cast<RankedTensorType>().getEncoding());
// Ensure that the new extract_slice op is placed in the same place as the
// old extract_slice op. Otherwise, the new extract_slice op may be placed
// after the async_wait op, which is not allowed.
// Ensure that the new extract_slice op is placed in the same place as
// the old extract_slice op. Otherwise, the new extract_slice op may be
// placed after the async_wait op, which is not allowed.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(extract_slice);
auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
@@ -1109,8 +1228,8 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
// cvt(type, constant) -> constant
if (auto cst = llvm::dyn_cast<arith::ConstantOp>(arg))
if (auto ret = cst.getValue().dyn_cast<SplatElementsAttr>()) {
auto newRet = SplatElementsAttr::get(op->getResultTypes().front(),
ret.getSplatValue<Attribute>());
auto ty = op->getResultTypes().front().cast<ShapedType>();
auto newRet = SplatElementsAttr::get(ty, ret.getSplatValue<Attribute>());
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newRet);
return mlir::success();
}

View File

@@ -1,5 +1,5 @@
#include "triton/Dialect/TritonGPU/IR/Traits.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
mlir::LogicalResult
mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) {

View File

@@ -47,12 +47,13 @@ SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
const ArrayRef<int64_t> shape,
int numWarps) {
SetVector<Operation *> slices;
mlir::getForwardSlice(dotOp.getResult(), &slices);
if (llvm::find_if(slices, [](Operation *op) {
return isa<triton::DotOp>(op);
}) != slices.end())
return {(unsigned)numWarps, 1};
auto filter = [&dotOp](Operation *op) {
return op->getParentRegion() == dotOp->getParentRegion();
};
auto slices = mlir::getSlice(dotOp, filter);
for (Operation *op : slices)
if (isa<triton::DotOp>(op) && (op != dotOp))
return {(unsigned)numWarps, 1};
SmallVector<unsigned, 2> ret = {1, 1};
SmallVector<int64_t, 2> shapePerWarp = {16, 8};
@@ -173,14 +174,17 @@ public:
.cast<triton::gpu::BlockedEncodingAttr>()
.getOrder();
auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldAType.getContext(), 0, newRetType.getEncoding(),
oldAType.getElementType());
auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get(
oldBType.getContext(), 1, newRetType.getEncoding(),
oldBType.getElementType());
auto newAType = RankedTensorType::get(
oldAType.getShape(), oldAType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
newRetType.getEncoding()));
oldAType.getShape(), oldAType.getElementType(), newAEncoding);
auto newBType = RankedTensorType::get(
oldBType.getShape(), oldBType.getElementType(),
triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
newRetType.getEncoding()));
oldBType.getShape(), oldBType.getElementType(), newBEncoding);
a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);

View File

@@ -14,7 +14,9 @@ add_mlir_dialect_library(TritonGPUTransforms
TritonGPUTransformsIncGen
LINK_LIBS PUBLIC
MLIRTransforms
MLIRTransformUtils
TritonAnalysis
TritonIR
TritonGPUIR
MLIRTransformUtils
)

View File

@@ -22,16 +22,13 @@ template <class T> SmallVector<unsigned, 4> argSort(const T &arr) {
typedef DenseMap<Value, std::function<Type(Type)>> LayoutMap;
struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr,
int numWarps) {
Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis,
Value ptr, int numWarps) {
auto origType = ptr.getType().cast<RankedTensorType>();
// Get the shape of the tensor.
size_t rank = origType.getRank();
dataflow::Lattice<AxisInfo> *latticeElement =
axisInfo.getLatticeElement(ptr);
AxisInfo info = latticeElement ? latticeElement->getValue() : AxisInfo();
// Get the contiguity order of `ptr`
auto order = argSort(info.getContiguity());
auto order = argSort(axisInfoAnalysis.getAxisInfo(ptr)->getContiguity());
// The desired divisibility is the maximum divisibility
// among all dependent pointers who have the same order as
// `ptr`
@@ -42,8 +39,8 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
for (Value val : op->getResults()) {
if (val.getType() != origType)
continue;
auto valInfo = axisInfo.getLatticeElement(val);
auto currOrder = argSort(valInfo->getValue().getContiguity());
auto currOrder =
argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity());
if (order == currOrder)
withSameOrder.insert(val);
}
@@ -61,10 +58,11 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
unsigned elemNumBytes = std::max(elemNumBits / 8, 1u);
unsigned perThread = 1;
for (Value val : withSameOrder) {
AxisInfo info = axisInfo.getLatticeElement(val)->getValue();
unsigned maxMultipleBytes = info.getDivisibility(order[0]);
unsigned maxMultipleBytes =
axisInfoAnalysis.getAxisInfo(val)->getDivisibility(order[0]);
unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u);
unsigned maxContig = info.getContiguity(order[0]);
unsigned maxContig =
axisInfoAnalysis.getAxisInfo(val)->getContiguity(order[0]);
unsigned alignment = std::min(maxMultiple, maxContig);
unsigned currPerThread = std::min(alignment, 128 / elemNumBits);
perThread = std::max(perThread, currPerThread);
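// Numeric illustration (values assumed): for f16 accesses (elemNumBits = 16,
// elemNumBytes = 2) with a 16-byte divisibility and a contiguity of 128 along
// order[0], maxMultiple = 16 / 2 = 8, alignment = min(8, 128) = 8, and
// currPerThread = min(8, 128 / 16) = 8, i.e. one 128-bit vector access per
// thread.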
@@ -78,9 +76,10 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
return encoding;
}
std::function<Type(Type)> getTypeConverter(AxisInfoAnalysis &axisInfo,
Value ptr, int numWarps) {
Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps);
std::function<Type(Type)>
getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr,
int numWarps) {
Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps);
return [encoding](Type _type) {
RankedTensorType type = _type.cast<RankedTensorType>();
return RankedTensorType::get(type.getShape(), type.getElementType(),
@@ -127,17 +126,14 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
}
void runOnOperation() override {
Operation *op = getOperation();
// Run axis info analysis
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfo = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(op)))
return signalPassFailure();
ModuleOp moduleOp = getOperation();
ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
// For each i/o operation, we determine what layout
// the pointers should have for best memory coalescing
LayoutMap layoutMap;
op->walk([&](Operation *curr) {
moduleOp.walk([&](Operation *curr) {
Value ptr;
if (auto op = dyn_cast<triton::LoadOp>(curr))
ptr = op.getPtr();
@@ -154,10 +150,9 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
RankedTensorType ty = ptr.getType().template dyn_cast<RankedTensorType>();
if (!ty || !ty.getElementType().isa<PointerType>())
return;
AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue();
auto mod = curr->getParentOfType<ModuleOp>();
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
auto convertType = getTypeConverter(*axisInfo, ptr, numWarps);
auto convertType = getTypeConverter(axisInfoAnalysis, ptr, numWarps);
layoutMap[ptr] = convertType;
});
@@ -168,7 +163,7 @@ struct CoalescePass : public TritonGPUCoalesceBase<CoalescePass> {
// produces a tensor with layout L2
// 4. Convert the output of this new memory op back to L1
// 5. Replace all the uses of the original memory op by the new one
op->walk([&](Operation *curr) {
moduleOp.walk([&](Operation *curr) {
OpBuilder builder(curr);
if (auto load = dyn_cast<triton::LoadOp>(curr)) {
coalesceOp<triton::LoadOp>(layoutMap, curr, load.getPtr(), builder);

View File

@@ -72,6 +72,76 @@ public:
}
};
//
class MoveOpAfterLayoutConversion : public mlir::RewritePattern {
public:
MoveOpAfterLayoutConversion(mlir::MLIRContext *context)
: mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
1, context) {}
mlir::LogicalResult
matchAndRewrite(mlir::Operation *op,
mlir::PatternRewriter &rewriter) const override {
auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
auto srcTy = cvt.getOperand().getType().cast<RankedTensorType>();
auto retTy = cvt.getResult().getType().dyn_cast<RankedTensorType>();
auto retEncoding =
retTy.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
auto srcEncoding =
srcTy.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>();
if (!retTy)
return failure();
if (!retEncoding)
return failure();
auto retEncodingParent =
retEncoding.getParent().dyn_cast<triton::gpu::MmaEncodingAttr>();
if (!retEncodingParent || retEncodingParent.isVolta())
return failure();
if (!srcEncoding)
return failure();
// don't move things around when cvt operand is a block arg
Operation *argOp = cvt.getOperand().getDefiningOp();
if (!argOp)
return failure();
//
SetVector<Operation *> processed;
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
int numCvts = simulateBackwardRematerialization(cvt, processed, layout,
toConvert, retEncoding);
if (numCvts > 1 || toConvert.size() == 1)
return failure();
for (Operation *op : processed) {
if (op->getNumOperands() != 1)
continue;
auto srcTy = op->getOperand(0).getType().cast<RankedTensorType>();
auto dstTy = op->getResult(0).getType().cast<RankedTensorType>();
// we don't want to push conversions backward if there is a downcast
// since it would result in more shared memory traffic
if (srcTy.getElementType().getIntOrFloatBitWidth() >
dstTy.getElementType().getIntOrFloatBitWidth())
return failure();
// we only push back when the first op in the chain has a load operand
if ((op == processed.back()) &&
!isa<triton::LoadOp>(op->getOperand(0).getDefiningOp()))
return failure();
// we don't want to use ldmatrix for 8-bit data that requires trans
// since Nvidia GPUs can't do it efficiently
bool isTrans =
(retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0);
bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8;
if (isTrans && isInt8)
return failure();
}
IRMapping mapping;
rematerializeConversionChain(toConvert, rewriter, processed, mapping);
rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
return mlir::success();
}
};
} // namespace
#define GEN_PASS_CLASSES
@@ -93,6 +163,7 @@ public:
mlir::RewritePatternSet patterns(context);
patterns.add<ConvertTransConvert>(context);
patterns.add<MoveOpAfterLayoutConversion>(context);
if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
signalPassFailure();
if (fixupLoops(m).failed())

View File

@@ -1,3 +1,4 @@
#include "Utility.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
@@ -42,6 +43,9 @@ class LoopPipeliner {
/// Loads to be pipelined
SetVector<Value> loads;
/// Smallest data-type for each load (used to optimize swizzle and
/// create DotOpEncoding layout)
DenseMap<Value, Type> loadsSmallestType;
/// The value that each load will be mapped to (after layout conversion)
DenseMap<Value, Value> loadsMapping;
/// load => buffer
@@ -181,24 +185,19 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op,
/// this pass?)
LogicalResult LoopPipeliner::initialize() {
Block *loop = forOp.getBody();
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(forOp->getParentOfType<ModuleOp>()))) {
return failure();
}
ModuleOp moduleOp = forOp->getParentOfType<ModuleOp>();
ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp);
// can we use forOp.walk(...) here?
SmallVector<triton::LoadOp, 2> validLoads;
for (Operation &op : *loop)
if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
auto ptr = loadOp.getPtr();
unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr);
unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr);
if (auto mask = loadOp.getMask())
vec = std::min<unsigned>(vec, axisInfoAnalysis->getMaskAlignment(mask));
vec = std::min<unsigned>(vec, axisInfoAnalysis.getMaskAlignment(mask));
auto lattice = axisInfoAnalysis->getLatticeElement(ptr)->getValue();
auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
if (!tensorTy || tensorTy.getRank() < 2)
continue;
@@ -256,33 +255,62 @@ LogicalResult LoopPipeliner::initialize() {
use = *use->getResult(0).getUsers().begin();
}
if (auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use)) {
if (auto tensorType = convertLayout.getResult()
.getType()
.dyn_cast<RankedTensorType>()) {
if (auto dotOpEnc = tensorType.getEncoding()
.dyn_cast<ttg::DotOperandEncodingAttr>()) {
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), ty.getElementType());
loadsBufferType[loadOp] = RankedTensorType::get(
bufferShape, ty.getElementType(), sharedEnc);
}
}
}
} else
auto convertLayout = llvm::dyn_cast<ttg::ConvertLayoutOp>(use);
if (!convertLayout)
continue;
auto tensorType =
convertLayout.getResult().getType().dyn_cast<RankedTensorType>();
if (!tensorType)
continue;
auto dotOpEnc =
tensorType.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
if (!dotOpEnc)
continue;
isCandidate = true;
loadsMapping[loadOp] = convertLayout;
}
else
isCandidate = false;
if (isCandidate)
loads.insert(loadOp);
}
// we need to find the smallest common dtype
// since this determines the layout of `mma.sync` operands
// in mixed-precision mode
Type smallestType;
for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
auto ty = loadOp.getType().cast<RankedTensorType>();
Type eltTy = ty.getElementType();
if (!smallestType ||
(eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth()))
smallestType = eltTy;
}
for (auto loadCvt : loadsMapping)
loadsSmallestType[loadCvt.first] = smallestType;
for (auto loadCvt : loadsMapping) {
auto loadOp = loadCvt.first;
Value cvt = loadCvt.second;
auto dotOpEnc = cvt.getType()
.cast<RankedTensorType>()
.getEncoding()
.cast<ttg::DotOperandEncodingAttr>();
auto ty = loadOp.getType().cast<RankedTensorType>();
SmallVector<int64_t> bufferShape(ty.getShape().begin(),
ty.getShape().end());
bufferShape.insert(bufferShape.begin(), numStages);
auto sharedEnc = ttg::SharedEncodingAttr::get(
ty.getContext(), dotOpEnc, ty.getShape(),
triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]);
loadsBufferType[loadOp] =
RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc);
}
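// Mixed-precision illustration (dtypes assumed): if one pipelined load feeds
// an i8 operand and the other an f16 operand, smallestType is the 8-bit type;
// each buffer keeps its own element type, but both shared encodings are
// derived from the 8-bit type so the swizzling of the two mma.sync operands
// stays compatible.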
// We have some loads to pipeline
if (!loads.empty()) {
// Update depArgs & depOps
@@ -336,10 +364,6 @@ Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask,
}
void LoopPipeliner::emitPrologue() {
// llvm::errs() << "loads to pipeline...:\n";
// for (Value load : loads)
// llvm::errs() << load << "\n";
OpBuilder builder(forOp);
for (BlockArgument &arg : forOp.getRegionIterArgs()) {
OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg);
@@ -364,7 +388,7 @@ void LoopPipeliner::emitPrologue() {
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&
@@ -541,44 +565,44 @@ scf::ForOp LoopPipeliner::createNewForOp() {
mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]);
mapping.map(forOp.getInductionVar(), newForOp.getInductionVar());
// 2.1 clone the loop body, replace original args with args of the new ForOp
// 2. clone the loop body, replace original args with args of the new ForOp
// Insert async wait if necessary.
DenseSet<Value> isModified;
for (Operation &op : forOp.getBody()->without_terminator()) {
Operation *newOp = builder.clone(op, mapping);
// update mapping of results
for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults()))
mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx));
}
// is modified
auto it = std::find(loads.begin(), loads.end(), op.getOperand(0));
if (it == loads.end()) {
Operation *newOp = cloneWithInferType(builder, &op, mapping);
continue;
}
// 3. replace loads with block args (from prologue)
for (size_t idx = 0; idx < loads.size(); ++idx) {
OpBuilder::InsertionGuard guard(builder);
Value load = loads[idx];
assert(load.hasOneUse() &&
"we assume that this load has one use (ConvertLayout)");
Value loadUse = load.getUsers().begin()->getResult(0);
// set insertion point
Value newLoad = mapping.lookup(load);
Value newLoadUse = mapping.lookup(loadUse);
builder.setInsertionPoint(newLoadUse.getDefiningOp());
// create conversion
// we replace the new load use with a convert layout
size_t i = std::distance(loads.begin(), it);
auto cvtDstTy = op.getResult(0).getType().cast<RankedTensorType>();
auto cvtDstEnc =
cvtDstTy.getEncoding().dyn_cast<ttg::DotOperandEncodingAttr>();
if (!cvtDstEnc) {
builder.clone(op, mapping);
continue;
}
auto newDstTy = RankedTensorType::get(
cvtDstTy.getShape(), cvtDstTy.getElementType(),
ttg::DotOperandEncodingAttr::get(
cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(),
loadsSmallestType[op.getOperand(0)]));
auto cvt = builder.create<ttg::ConvertLayoutOp>(
loadUse.getLoc(), loadUse.getType(),
newForOp.getRegionIterArgs()[loadIdx + idx]);
// replace uses
newLoadUse.replaceAllUsesWith(cvt.getResult());
// delete old load and layout conversion
newLoadUse.getDefiningOp()->erase();
newLoad.getDefiningOp()->erase();
op.getResult(0).getLoc(), newDstTy,
newForOp.getRegionIterArgs()[loadIdx + i]);
mapping.map(op.getResult(0), cvt.getResult());
isModified.insert(op.getResult(0));
}
// 4. prefetch the next iteration
// 3. prefetch the next iteration
SmallVector<Operation *> orderedDeps;
for (Operation &op : forOp.getLoopBody().front()) {
if (depOps.contains(&op))
orderedDeps.push_back(&op);
else if (loads.contains(op.getResult(0)))
else if (op.getNumResults() > 0 && loads.contains(op.getResult(0)))
orderedDeps.push_back(&op);
}
assert(depOps.size() + loads.size() == orderedDeps.size() &&

View File

@@ -27,7 +27,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/IRMapping.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -45,7 +44,7 @@ class Prefetcher {
scf::YieldOp yieldOp;
///
// TODO: add a hook to infer prefetchWidth
unsigned prefetchWidth = 16;
unsigned prefetchWidth = 32;
/// dots to be prefetched
SetVector<Value> dots;
@@ -56,6 +55,8 @@ class Prefetcher {
DenseMap<Value, Value> dot2bHeaderDef;
DenseMap<Value, Value> dot2aYield;
DenseMap<Value, Value> dot2bYield;
DenseMap<Value, SmallVector<Value>> dot2aVals;
DenseMap<Value, SmallVector<Value>> dot2bVals;
/// operand => defining
DenseMap<Value, Value> operand2headPrefetch;
@@ -66,6 +67,9 @@ class Prefetcher {
std::optional<int64_t> offsetK = std::nullopt,
std::optional<int64_t> shapeK = std::nullopt);
void cloneElementwiseOps(Value &bRem, const SmallVector<Value> &vals,
OpBuilder &builder);
public:
Prefetcher() = delete;
@@ -80,6 +84,24 @@ public:
scf::ForOp createNewForOp();
};
void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector<Value> &vals,
OpBuilder &builder) {
IRMapping mapping;
mapping.map(vals[0], ret);
for (int i = 1; i < vals.size(); i++) {
Value v = vals[i];
Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0);
auto retType = RankedTensorType::get(
ret.getType().cast<RankedTensorType>().getShape(),
curr.getType().cast<RankedTensorType>().getElementType(),
curr.getType().cast<RankedTensorType>().getEncoding());
curr.setType(retType);
mapping.map(v, curr);
}
if (vals.size() > 1)
ret = mapping.lookup(vals.back());
}
Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
Attribute dotEncoding, OpBuilder &builder,
std::optional<int64_t> offsetK,
@@ -110,7 +132,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});
auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
builder.getContext(), opIdx, dotEncoding);
builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
Value prefetchSlice = builder.create<triton::gpu::ConvertLayoutOp>(
v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
newSmem);
@@ -135,11 +157,32 @@ LogicalResult Prefetcher::initialize() {
return failure();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> Value {
if (auto cvt = v.getDefiningOp<triton::gpu::ConvertLayoutOp>())
if (isSharedEncoding(cvt.getOperand()))
return cvt.getSrc();
return Value();
// returns source of cvt
auto getPrefetchSrc = [](Value v) -> SmallVector<Value> {
// walk back to conversion
Operation *op = v.getDefiningOp();
bool foundConvertFromShared = false;
SmallVector<Value> rets;
rets.push_back(op->getResult(0));
while (op) {
if (op->getNumOperands() != 1)
break;
if (!op->getResult(0).hasOneUse())
break;
rets.push_back(op->getOperand(0));
if (auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(op))
if (isSharedEncoding(cvt.getOperand())) {
foundConvertFromShared = true;
break;
}
op = op->getOperand(0).getDefiningOp();
}
std::reverse(rets.begin(), rets.end());
if (foundConvertFromShared)
return rets;
return {};
};
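// Example chain (hypothetical IR): for %a = elementwise(convert_layout(%a_smem))
// feeding tt.dot, the returned vector is {%a_smem, %cvt, %a}: it starts at the
// shared-memory value and keeps every single-use intermediate op so the same
// elementwise ops can later be re-cloned onto each prefetched sub-tile.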
auto getIncomingOp = [this](Value v) -> Value {
@@ -156,24 +199,39 @@ LogicalResult Prefetcher::initialize() {
};
for (triton::DotOp dot : dotsInFor) {
auto kSize = dot.getA().getType().cast<RankedTensorType>().getShape()[1];
auto aType = dot.getA().getType().cast<RankedTensorType>();
auto bType = dot.getB().getType().cast<RankedTensorType>();
auto aEnc = aType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
auto bEnc = bType.getEncoding().cast<triton::gpu::DotOperandEncodingAttr>();
int aKWidth = aEnc.getMMAv2kWidth();
int bKWidth = bEnc.getMMAv2kWidth();
assert(aKWidth == bKWidth);
auto kSize = aType.getShape()[1];
// works better with nvidia tensor cores
unsigned elementWidth =
dot.getA().getType().cast<RankedTensorType>().getElementTypeBitWidth();
prefetchWidth = 256 / elementWidth;
unsigned elementWidth = aType.getElementTypeBitWidth();
if (aKWidth == 0)
prefetchWidth = 256 / elementWidth;
else
prefetchWidth = 8 * aKWidth;
// Skip prefetching if kSize is less than prefetchWidth
if (kSize < prefetchWidth)
continue;
Value aSmem = getPrefetchSrc(dot.getA());
Value bSmem = getPrefetchSrc(dot.getB());
if (aSmem && bSmem) {
auto aVals = getPrefetchSrc(dot.getA());
auto bVals = getPrefetchSrc(dot.getB());
if (aVals.size() && bVals.size()) {
Value aSmem = aVals.front();
Value bSmem = bVals.front();
Value aHeaderDef = getIncomingOp(aSmem);
Value bHeaderDef = getIncomingOp(bSmem);
// Only prefetch loop arg
if (aHeaderDef && bHeaderDef) {
dots.insert(dot);
dot2aVals[dot] = aVals;
dot2bVals[dot] = bVals;
dot2aHeaderDef[dot] = aHeaderDef;
dot2bHeaderDef[dot] = bHeaderDef;
dot2aLoopArg[dot] = aSmem;
@@ -195,10 +253,13 @@ void Prefetcher::emitPrologue() {
dot.getType().cast<RankedTensorType>().getEncoding();
Value aPrefetched =
generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getA()] =
aPrefetched;
cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder);
Value bPrefetched =
generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder);
cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder);
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getA()] =
aPrefetched;
operand2headPrefetch[dot.getDefiningOp<triton::DotOp>().getB()] =
bPrefetched;
}
@@ -256,9 +317,11 @@ scf::ForOp Prefetcher::createNewForOp() {
Value aRem =
generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false,
dotEncoding, builder, kOff, kShape);
cloneElementwiseOps(aRem, dot2aVals[dot], builder);
Value bRem =
generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false,
dotEncoding, builder, kOff, kShape);
cloneElementwiseOps(bRem, dot2bVals[dot], builder);
builder.restoreInsertionPoint(insertionPoint);
newOp = builder.clone(*dot, mapping);
newOp->setOperand(0, aRem);
@@ -281,10 +344,15 @@ scf::ForOp Prefetcher::createNewForOp() {
for (Value dot : dots) {
Attribute dotEncoding =
dot.getType().cast<RankedTensorType>().getEncoding();
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0,
true, dotEncoding, builder));
yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1,
true, dotEncoding, builder));
Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true,
dotEncoding, builder);
cloneElementwiseOps(aToYield, dot2aVals[dot], builder);
yieldValues.push_back(aToYield);
// bToYield
Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true,
dotEncoding, builder);
cloneElementwiseOps(bToYield, dot2bVals[dot], builder);
yieldValues.push_back(bToYield);
}
// Update ops of yield
if (!yieldValues.empty())
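For reference, a minimal Triton-level sketch (kernel name, shapes and dtypes are assumptions, not part of this patch) of the loop shape the updated prefetcher targets: a dot operand that passes through an elementwise cast between the shared-memory load and tl.dot, which cloneElementwiseOps now replays on each prefetched slice.

import triton
import triton.language as tl

@triton.jit
def _prefetch_pattern(A, B, C, K,
                      BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        a = tl.load(A + offs_m[:, None] * K + (k + offs_k)[None, :])
        b = tl.load(B + (k + offs_k)[:, None] * BLOCK_N + offs_n[None, :])
        # elementwise cast sitting between the (pipelined) loads and the dot;
        # the prefetcher clones this cast onto the prefetched sub-tiles
        acc += tl.dot(a.to(tl.float16), b.to(tl.float16))
    tl.store(C + offs_m[:, None] * BLOCK_N + offs_n[None, :], acc)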

View File

@@ -13,7 +13,6 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/RegionUtils.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
@@ -341,7 +340,10 @@ public:
cvt.getOperand().getType().cast<RankedTensorType>().getEncoding();
auto dstEncoding =
cvt.getResult().getType().cast<RankedTensorType>().getEncoding();
// XXX: why is this needed?
if (srcEncoding.isa<triton::gpu::SharedEncodingAttr>() ||
dstEncoding.isa<triton::gpu::SharedEncodingAttr>())
return failure();
// heuristics for flash attention
if (srcEncoding.isa<triton::gpu::SliceEncodingAttr>())
return failure();
SetVector<Operation *> cvtSlices;
@@ -365,7 +367,7 @@ public:
// don't rematerialize non-element-wise
if (!op->hasTrait<mlir::OpTrait::SameOperandsAndResultEncoding>() &&
!op->hasTrait<mlir::OpTrait::Elementwise>() &&
!isa<triton::StoreOp>(op)) {
!isa<triton::StoreOp>(op) && !isa<triton::ReduceOp>(op)) {
return failure();
}
// don't rematerialize if it adds an extra conversion that can't
@@ -375,9 +377,10 @@ public:
SetVector<Operation *> processed;
SetVector<Attribute> layout;
llvm::MapVector<Value, Attribute> toConvert;
if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
simulateBackwardRematerialization(argOp, processed, layout,
toConvert, srcEncoding) > 0) {
int numAddedConvs = simulateBackwardRematerialization(
argOp, processed, layout, toConvert, srcEncoding);
if (argOp && !isa<triton::gpu::ConvertLayoutOp>(argOp) &&
cvtSlices.count(argOp) == 0 && numAddedConvs > 0) {
return failure();
}
}

View File

@@ -89,11 +89,11 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
}
bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
// Case 1a: A size 1 tensor is not expensive since all threads will load the
// Case 1: A size 1 tensor is not expensive since all threads will load the
// same
if (isSingleValue(op->getOperand(0)))
return false;
// Case 1b: Tensor of pointers has more threads than elements
// Case 2: Tensor of pointers has more threads than elements
// we can presume a high hit-rate that makes it cheap to load
auto ptrType = op->getOperand(0).getType().cast<RankedTensorType>();
IntegerAttr numWarps =
@@ -104,28 +104,6 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
if (ptrType.getNumElements() < numWarps.getInt() * 32)
return false;
}
// auto ptr = op->getOperand(0);
//// Case 2: We assume that `evict_last` loads/stores have high hit rate
// if (auto load = dyn_cast<triton::LoadOp>(op))
// if (load.getEvict() == triton::EvictionPolicy::EVICT_LAST)
// return false;
// if (auto store = dyn_cast<triton::StoreOp>(op))
// if (store.getEvict() == triton::EvictionPolicy::EVICT_LAST)
// return false;
// if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
// auto encoding = tensorTy.getEncoding();
// // Case 3: Different type conversion is expensive (e.g., mma <->
// block) if (encoding.getTypeID() != targetEncoding.getTypeID())
// return true;
// auto sizePerThread = triton::gpu::getSizePerThread(encoding);
// auto targetSizePerThread =
// triton::gpu::getSizePerThread(targetEncoding); auto order =
// triton::gpu::getOrder(encoding); auto targetOrder =
// triton::gpu::getOrder(targetEncoding);
// // Case 4: The targeEncoding may expose more vectorization
// opportunities return sizePerThread[order[0]] >=
// targetSizePerThread[targetOrder[0]];
// }
return true;
}
@@ -144,6 +122,12 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
return false;
}
bool canFoldConversion(Operation *op) {
return isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp, triton::ViewOp,
triton::CatOp>(*op);
}
int simulateBackwardRematerialization(
Operation *initOp, SetVector<Operation *> &processed,
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
@@ -189,10 +173,7 @@ int simulateBackwardRematerialization(
continue;
// If the conversion can be folded into opArgI then
// we don't count this conversion as expensive
if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
triton::MakeRangeOp, triton::SplatOp>(*opArgI))
continue;
if (isa<triton::ViewOp, triton::CatOp>(opArgI))
if (canFoldConversion(opArgI))
continue;
// We add one expensive conversion for the current operand
@@ -206,11 +187,24 @@ int simulateBackwardRematerialization(
//
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping) {
Operation *newOp = rewriter.clone(*op, mapping);
auto origType = op->getResult(0).getType().cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
// if input types haven't changed, we're done
bool preserveTypes =
std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) {
return !mapping.contains(v) ||
v.getType() == mapping.lookup(v).getType();
});
if (preserveTypes)
return newOp;
if (newOp->getNumResults() == 0)
return newOp;
auto origType = op->getResult(0).getType().dyn_cast<RankedTensorType>();
auto argType = newOp->getOperand(0).getType().dyn_cast<RankedTensorType>();
if (!origType || !argType)
return newOp;
auto newType = RankedTensorType::get(
origType.getShape(), origType.getElementType(), argType.getEncoding());
newOp->getResult(0).setType(newType);
@@ -219,9 +213,12 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
SmallVector<Type, 1> newTypes;
auto success = typeInfer.inferReturnTypes(
newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
if (succeeded(success))
newOp->getResult(0).setType(newTypes.front());
newOp->getAttrDictionary(), newOp->getPropertiesStorage(),
newOp->getRegions(), newTypes);
if (succeeded(success)) {
for (size_t i = 0; i < newTypes.size(); i++)
newOp->getResult(i).setType(newTypes[i]);
}
}
return newOp;
}

View File

@@ -21,7 +21,7 @@ int simulateBackwardRematerialization(
SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
Attribute targetEncoding);
Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op,
IRMapping &mapping);
void rematerializeConversionChain(

View File

@@ -5,8 +5,17 @@ add_mlir_translation_library(TritonLLVMIR
Core
LINK_LIBS PUBLIC
MLIRArithToLLVM
MLIRBuiltinToLLVMIRTranslation
MLIRExecutionEngineUtils
MLIRIndexToLLVM
MLIRIR
MLIRLLVMDialect
MLIRLLVMToLLVMIRTranslation
MLIRNVVMToLLVMIRTranslation
MLIRROCDLToLLVMIRTranslation
MLIRSCFToControlFlow
MLIRSupport
MLIRTargetLLVMIRExport
TritonGPUToLLVM
)

View File

@@ -25,12 +25,22 @@
#include "llvm/IRReader/IRReader.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/SourceMgr.h"
#include <iostream>
#ifdef _WIN32
#define WIN32_LEAN_AND_MEAN
#include <windows.h>
#else
#include <dlfcn.h>
#endif
#include <filesystem>
#include <iterator>
namespace fs = std::filesystem;
namespace mlir {
namespace triton {
@@ -117,6 +127,32 @@ extractNVVMMetadata(mlir::ModuleOp module,
}
}
static std::filesystem::path getThisLibraryPath() {
#ifdef _WIN32
/* Get module of the specified address */
HMODULE hModule;
GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS |
GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
reinterpret_cast<LPCSTR>(&getThisLibraryPath), &hModule);
if (NULL == hModule) {
return std::filesystem::path();
}
char fileName[1024]; // this is way beyond Windows MAX_PATH limit.
DWORD dwSize = GetModuleFileNameA(hModule, fileName, sizeof(fileName));
if (0 == dwSize || sizeof(fileName) == dwSize) {
return std::filesystem::path();
}
return std::filesystem::path(fileName);
#else
Dl_info fileinfo;
if (dladdr(reinterpret_cast<void *>(&getThisLibraryPath), &fileinfo) == 0) {
return std::filesystem::path();
}
return std::filesystem::path(fileinfo.dli_fname);
#endif
}
static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
std::map<std::string, std::string> externLibs;
SmallVector<LLVM::LLVMFuncOp> funcs;
@@ -156,17 +192,10 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) {
externLibs.try_emplace(libdevice, env_path);
return externLibs;
}
namespace fs = std::filesystem;
// Search for libdevice relative to its library path if used from Python
// Then native code is in `triton/_C/libtriton.so` and libdevice in
// `triton/third_party/cuda/lib/libdevice.10.bc`
static const auto this_library_path = [] {
Dl_info fileinfo;
if (dladdr(reinterpret_cast<void *>(&getExternLibs), &fileinfo) == 0) {
return std::filesystem::path();
}
return std::filesystem::path(fileinfo.dli_fname);
}();
static const auto this_library_path = getThisLibraryPath();
static const auto runtime_path =
this_library_path.parent_path().parent_path() / "third_party" / "cuda" /
"lib" / "libdevice.10.bc";

View File

@@ -68,7 +68,7 @@ def get_llvm_package_info():
use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
release_suffix = "assert" if use_assert_enabled_llvm else "release"
name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
version = "llvm-17.0.0-f733b4fb9b8b"
version = "llvm-17.0.0-c5dede880d17"
url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
@@ -249,6 +249,7 @@ setup(
"triton/_C",
"triton/common",
"triton/compiler",
"triton/debugger",
"triton/language",
"triton/language/extra",
"triton/ops",
@@ -294,7 +295,6 @@ setup(
"numpy",
"pytest",
"scipy>=1.7.1",
"torch",
],
"tutorials": [
"matplotlib",

View File

@@ -264,6 +264,11 @@ void init_triton_ir(py::module &&m) {
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::IsTerminator>();
})
.def("has_return",
[](mlir::Block &self) {
return !self.empty() &&
self.back().hasTrait<mlir::OpTrait::ReturnLike>();
})
.def("erase", [](mlir::Block &self) { self.erase(); });
// using eattr = ir::attribute_kind_t;
@@ -430,6 +435,25 @@ void init_triton_ir(py::module &&m) {
self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val));
},
ret::reference)
.def("finalize",
[](mlir::triton::FuncOp &self) -> void {
// Remove dead code
// 1. Unreachable code after return
self.walk([&](mlir::Block *block) {
mlir::Operation *retOp = nullptr;
block->walk([&](mlir::Operation *op) {
if (mlir::isa<mlir::triton::ReturnOp>(op))
if (retOp == nullptr)
retOp = op;
});
if (retOp && retOp != &block->back()) {
auto pos = retOp->getIterator();
pos++;
auto *newBlock = block->splitBlock(pos);
newBlock->erase();
}
});
})
.def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType)
.def("reset_type", &mlir::triton::FuncOp::setType);
@@ -454,7 +478,8 @@ void init_triton_ir(py::module &&m) {
[](mlir::OpBuilder &self, mlir::triton::FuncOp &func,
std::vector<mlir::Value> &args) -> mlir::OpState {
auto loc = self.getUnknownLoc();
return self.create<mlir::triton::CallOp>(loc, func, args);
auto callOp = self.create<mlir::triton::CallOp>(loc, func, args);
return callOp;
})
// insertion block/point
.def("set_insertion_point_to_start",
@@ -645,14 +670,16 @@ void init_triton_ir(py::module &&m) {
.def("get_or_insert_function",
[](mlir::OpBuilder &self, mlir::ModuleOp &module,
std::string &funcName, mlir::Type &funcType,
std::string &visibility) -> mlir::triton::FuncOp {
std::string &visibility, bool noinline) -> mlir::triton::FuncOp {
if (mlir::Operation *funcOperation = module.lookupSymbol(funcName))
return llvm::dyn_cast<mlir::triton::FuncOp>(funcOperation);
auto loc = self.getUnknownLoc();
if (auto funcTy = funcType.dyn_cast<mlir::FunctionType>()) {
llvm::SmallVector<mlir::NamedAttribute> attrs = {
mlir::NamedAttribute(self.getStringAttr("sym_visibility"),
self.getStringAttr(visibility))};
self.getStringAttr(visibility)),
mlir::NamedAttribute(self.getStringAttr("noinline"),
self.getBoolAttr(noinline))};
return self.create<mlir::triton::FuncOp>(loc, funcName, funcTy,
attrs);
}
@@ -1699,54 +1726,59 @@ void init_triton_translation(py::module &m) {
m.def("compile_ptx_to_cubin",
[](const std::string &ptxCode, const std::string &ptxasPath,
int capability) -> py::object {
py::gil_scoped_release allow_threads;
std::string cubin;
{
py::gil_scoped_release allow_threads;
// compile ptx with ptxas
llvm::SmallString<64> fsrc;
llvm::SmallString<64> flog;
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
std::string fbin = std::string(fsrc) + ".o";
llvm::FileRemover logRemover(flog);
llvm::FileRemover binRemover(fbin);
const char *_fsrc = fsrc.c_str();
const char *_flog = flog.c_str();
const char *_fbin = fbin.c_str();
std::ofstream ofs(_fsrc);
ofs << ptxCode << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
(capability == 90 ? "a " : " ") + _fsrc + " -o " + _fsrc +
".o 2> " + _flog;
// compile ptx with ptxas
llvm::SmallString<64> fsrc;
llvm::SmallString<64> flog;
llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
std::string fbin = std::string(fsrc) + ".o";
llvm::FileRemover logRemover(flog);
llvm::FileRemover binRemover(fbin);
const char *_fsrc = fsrc.c_str();
const char *_flog = flog.c_str();
const char *_fbin = fbin.c_str();
std::ofstream ofs(_fsrc);
ofs << ptxCode << std::endl;
ofs.close();
std::string cmd;
int err;
cmd = ptxasPath + " -v --gpu-name=sm_" +
std::to_string(capability) + (capability == 90 ? "a " : " ") +
_fsrc + " -o " + _fsrc + ".o 2> " + _flog;
err = system(cmd.c_str());
if (err != 0) {
err >>= 8;
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
if (err == 255) {
throw std::runtime_error("Internal Triton PTX codegen error: \n" +
log);
} else if (err == 128 + SIGSEGV) {
throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
"` to confirm that this is a "
"bug in `ptxas`\n" +
log);
err = system(cmd.c_str());
if (err != 0) {
err >>= 8;
std::ifstream _log(_flog);
std::string log(std::istreambuf_iterator<char>(_log), {});
if (err == 255) {
throw std::runtime_error(
"Internal Triton PTX codegen error: \n" + log);
} else if (err == 128 + SIGSEGV) {
throw std::runtime_error("Please run `ptxas " +
fsrc.str().str() +
"` to confirm that this is a "
"bug in `ptxas`\n" +
log);
} else {
throw std::runtime_error("`ptxas` failed with error code " +
std::to_string(err) + ": \n" + log);
}
return {};
} else {
throw std::runtime_error("`ptxas` failed with error code " +
std::to_string(err) + ": \n" + log);
llvm::FileRemover srcRemover(fsrc);
std::ifstream _cubin(_fbin, std::ios::binary);
cubin = std::string(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
// Do not return here, exit the gil scope and return below
}
return {};
} else {
llvm::FileRemover srcRemover(fsrc);
std::ifstream _cubin(_fbin, std::ios::binary);
std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
_cubin.close();
py::bytes bytes(cubin);
return std::move(bytes);
}
py::bytes bytes(cubin);
return std::move(bytes);
});
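A hedged usage sketch for the binding above, assuming it is exposed on the top-level libtriton module the way compiler.py imports it; the PTX string, ptxas path and compute capability are placeholders, not values taken from this patch.

import triton._C.libtriton.triton as _triton

ptx_code = "..."  # PTX produced by an earlier compilation stage (placeholder)
ptxas_path = "/usr/local/cuda/bin/ptxas"  # assumed location of ptxas
cubin = _triton.compile_ptx_to_cubin(ptx_code, ptxas_path, 80)
# `cubin` is a bytes object holding the assembled binary; on failure the
# binding raises with the captured ptxas log instead of returning.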
m.def("add_external_libs",

View File

@@ -0,0 +1,68 @@
import torch
import triton
import triton.language as tl
def test_chained_matmul():
# Regression test for issue #1601
def chained_matmul_reference(a, b, c):
intermediate = torch.einsum('MK,NK->MN', a, b)
return torch.einsum('MN,NK->MK', intermediate, c)
@triton.jit
def chained_matmul_kernel(
A, # shape: (m, k)
B, # shape: (n, k)
C, # shape: (n, k)
out, # shape: (m, k)
m, n, k: tl.constexpr,
block_m: tl.constexpr,
block_n: tl.constexpr,
block_k: tl.constexpr):
tl.static_assert(block_k == k,
f"expected block_k == k but got {block_k} != {k}")
block_ix = tl.program_id(0)
a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \
+ tl.arange(0, block_k)[None, :]
a = tl.load(A + a_tile, mask=a_tile < m * k, other=0.0)
acc = tl.zeros([block_m, block_k], dtype=tl.float32)
for loop_block_start in range(0, n, block_n):
bc_tile = (loop_block_start + tl.arange(0, block_n))[:, None] * block_k \
+ tl.arange(0, block_k)[None, :]
b = tl.load(B + bc_tile, mask=bc_tile < n * k, other=0.0)
intermediate = tl.dot(a, tl.trans(b))
intermediate_mask = ((loop_block_start + tl.arange(0, block_n)) < n)[None, :] \
* (tl.arange(0, block_m) < m)[:, None]
intermediate = tl.where(intermediate_mask, intermediate, 0.0)
c = tl.load(C + bc_tile, mask=bc_tile < n * k)
acc += tl.dot(intermediate.to(A.dtype.element_ty), c)
tl.store(out + a_tile, acc.to(A.dtype.element_ty), mask=a_tile < m * k)
m, n, k = 32, 64, 128
block_m, block_n, block_k = 16, 32, k
grid = (triton.cdiv(m, block_m),)
a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16,
device='cuda')
b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16,
device='cuda')
c = torch.randint_like(b, low=0, high=2)
triton_result = torch.zeros_like(a)
torch_result = chained_matmul_reference(a, b, c)
chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k,
block_m=block_m, block_n=block_n,
block_k=block_k)
assert (torch_result == triton_result).all()

View File

@@ -0,0 +1,69 @@
import random
import torch
import triton
import triton.language as tl
from triton.debugger.debugger import program_ids_from_grid
def test_addition():
@triton.jit(interpret=True)
def add_kernel(
x_ptr,
y_ptr,
output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
output = x + y
tl.store(output_ptr + offsets, output, mask=mask)
a = torch.rand((128,), device="cuda")
b = torch.rand((128,), device="cuda")
expected = a + b
output = torch.empty((128,), device="cuda")
def grid(meta):
return (triton.cdiv(128, meta["BLOCK_SIZE"]),)
add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32)
assert torch.allclose(expected, output, atol=1e-2, rtol=0)
def test_program_ids_from_grid():
random.seed(123)
grid = (3, 4)
expected_combinations = 3 * 4
unique_combinations = set(program_ids_from_grid(grid))
assert len(unique_combinations) == expected_combinations
first_run = list(program_ids_from_grid(grid))
second_run = list(program_ids_from_grid(grid))
assert first_run != second_run
def test_atomic():
@triton.jit(interpret=True)
def atomic(
x_ptr,
):
pid = tl.program_id(axis=0)
tl.atomic_add(x_ptr + pid, 1)
t = tl.atomic_xchg(x_ptr + pid, 3)
t += 1 # 2
tl.atomic_cas(x_ptr + pid, 3, t) # match
tl.atomic_cas(x_ptr + pid, 40, 9) # no match
nb_dim = 16
a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda")
atomic[(nb_dim, )](a)
assert torch.allclose(a, torch.full_like(a, 2))

View File

@@ -14,6 +14,14 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
# Trivial assert
tl.device_assert(0 == 0, "x != 0")
tl.store(Y + tl.arange(0, BLOCK), x)
@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
x = tl.load(X + tl.arange(0, BLOCK))
@@ -33,7 +41,12 @@ def test_assert(func: str):
x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
y = torch.zeros(shape, dtype=x.dtype, device="cuda")
if func == "device_assert":
kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
kernel_device_assert_scalar[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "assert":
kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0])
elif func == "static_assert":

View File

@@ -457,6 +457,86 @@ def test_broadcast(dtype):
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()
# ----------------
# test expand_dims
# ----------------
def test_expand_dims():
@triton.jit
def expand_dims_kernel(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)
t = tl.expand_dims(offset1, 0)
tl.static_assert(t.shape == [1, N])
t = tl.expand_dims(offset1, 1)
tl.static_assert(t.shape == [N, 1])
t = tl.expand_dims(offset1, -1)
tl.static_assert(t.shape == [N, 1])
t = tl.expand_dims(offset1, -2)
tl.static_assert(t.shape == [1, N])
t = tl.expand_dims(offset1, (0, -1))
tl.static_assert(t.shape == [1, N, 1])
t = tl.expand_dims(offset1, (0, 1, 3))
tl.static_assert(t.shape == [1, 1, N, 1])
t = tl.expand_dims(offset1, (-4, 2, -1))
tl.static_assert(t.shape == [1, N, 1, 1])
t = tl.expand_dims(offset1, (3, 1, 2))
tl.static_assert(t.shape == [N, 1, 1, 1])
N = 32
dummy_tensor = torch.empty((), device="cuda")
expand_dims_kernel[(1,)](dummy_tensor, N)
def test_expand_dims_error_cases():
@triton.jit
def dim_out_of_range1(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)
t = tl.expand_dims(offset1, -2)
t = tl.expand_dims(offset1, -3)
@triton.jit
def dim_out_of_range2(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)
t = tl.expand_dims(offset1, 1)
t = tl.expand_dims(offset1, 2)
@triton.jit
def duplicate_dim1(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)
t = tl.expand_dims(offset1, (0, 0))
@triton.jit
def duplicate_dim2(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)
t = tl.expand_dims(offset1, (0, -3))
N = 32
dummy_tensor = torch.empty((), device="cuda")
with pytest.raises(triton.CompilationError, match="invalid axis -3"):
dim_out_of_range1[(1,)](dummy_tensor, N)
with pytest.raises(triton.CompilationError, match="invalid axis 2"):
dim_out_of_range2[(1,)](dummy_tensor, N)
with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
duplicate_dim1[(1,)](dummy_tensor, N)
with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
duplicate_dim2[(1,)](dummy_tensor, N)
# ---------------
# test where
# ---------------
@@ -688,7 +768,7 @@ def test_index1d(expr, dtype_str, device='cuda'):
@triton.jit
def fn(a, b):
def tuples_fn(a, b):
return a + b, \
a - b, \
a * b
@@ -701,7 +781,7 @@ def test_tuples():
def with_fn(X, Y, A, B, C):
x = tl.load(X)
y = tl.load(Y)
a, b, c = fn(x, y)
a, b, c = tuples_fn(x, y)
tl.store(A, a)
tl.store(B, b)
tl.store(C, c)
@@ -728,6 +808,92 @@ def test_tuples():
assert c_tri == c_ref
@triton.jit(noinline=True)
def noinline_simple_fn(x, y, Z):
z = x + y
tl.store(Z, z)
@triton.jit(noinline=True)
def noinline_call_graph_fn1(x):
return x + 1
@triton.jit(noinline=True)
def noinline_call_graph_fn2(y):
return y + 2
@triton.jit(noinline=True)
def noinline_call_graph_fn(x, y, Z):
t0 = noinline_call_graph_fn1(x)
t1 = noinline_call_graph_fn2(y)
z = t0 + t1
tl.store(Z, z)
@triton.jit(noinline=True)
def noinline_shared_fn(x, y, Z):
offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
z = tl.load(Z + offs)
z = tl.dot(z, z) + x + y
tl.store(Z + offs, z)
@triton.jit(noinline=True)
def noinline_dynamic_fn(x, y, Z):
if x >= 1:
x = noinline_call_graph_fn1(x)
else:
x = noinline_call_graph_fn2(x)
if y >= 2:
y = noinline_call_graph_fn2(y)
else:
y = noinline_call_graph_fn1(y)
z = x + y
tl.store(Z, z)
@triton.jit(noinline=True)
def noinline_call_multi_values_fn(x, y):
return x + 1, y + 2
@triton.jit(noinline=True)
def noinline_multi_values_fn(x, y, Z):
x, y = noinline_call_multi_values_fn(x, y)
z = x + y
tl.store(Z, z)
@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"])
def test_noinline(mode):
device = 'cuda'
@triton.jit
def kernel(X, Y, Z):
x = tl.load(X)
y = tl.load(Y)
GENERATE_TEST_HERE(x, y, Z)
func_name = f'noinline_{mode}_fn'
kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name})
x = torch.tensor([1.0], device=device, dtype=torch.float32)
y = torch.tensor([2.0], device=device, dtype=torch.float32)
if mode == "shared":
z = torch.ones((16, 16), device=device, dtype=torch.float32)
else:
z = torch.tensor([0.0], device=device, dtype=torch.float32)
kernel[(1,)](x, y, z, num_warps=1)
if mode == "simple":
assert torch.equal(z, x + y)
elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values":
assert torch.equal(z, x + 1 + y + 2)
elif mode == "shared":
ref = torch.full((16, 16), 16, device=device, dtype=torch.float32)
assert torch.equal(z, ref + x + y)
# ---------------
# test atomics
# ---------------
@@ -1334,6 +1500,99 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'):
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]
@pytest.mark.parametrize("M", [32, 64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
def test_store_op(M, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<f32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr<f32>, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%7 = tt.splat %arg1 : (!tt.ptr<f32>) -> tensor<{M}x1x!tt.ptr<f32>, #src>
%8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr<f32>, #src>, tensor<{M}x1xi32, #src>
tt.store %8, %4 : tensor<{M}x1xf32, #src>
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
store_kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(0, 4, (M, 1)).astype('float32')
y = np.zeros((M, 1), dtype='float32')
x_tri = torch.tensor(x, device=device)
y_tri = torch.tensor(y, device=device)
pgm = store_kernel[(1, 1, 1)](x_tri, y_tri)
y_ref = x
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
MmaLayout(version=(2, 0), warps_per_cta=[4, 1])
]
@pytest.mark.parametrize("M", [64, 128, 256])
@pytest.mark.parametrize("src_layout", layouts)
@pytest.mark.parametrize("dst_layout", layouts)
@pytest.mark.parametrize("src_dim", [0, 1])
@pytest.mark.parametrize("dst_dim", [0, 1])
def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device='cuda'):
ir = f"""
#dst = {dst_layout}
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @kernel(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%0 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>
%4 = tt.splat %arg1 : (!tt.ptr<i32>) -> tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr<i32>, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
%7 = triton_gpu.convert_layout %3 : (tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.store %6, %7 : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(0, 4, (M, )).astype('int32')
y = np.zeros((M, ), dtype='int32')
x_tri = torch.tensor(x, device=device)
y_tri = torch.tensor(y, device=device)
pgm = kernel[(1, 1, 1)](x_tri, y_tri)
y_ref = x
np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
@triton.jit
def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
delta = mean_2 - mean_1
@@ -1346,6 +1605,68 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
)
layouts = [
BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]),
BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]),
BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]),
BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1])
]
@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]])
@pytest.mark.parametrize("src_layout", layouts)
def test_chain_reduce(M, N, src_layout, device='cuda'):
ir = f"""
#src = {src_layout}
module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{
tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr<i32> {{tt.divisibility = 16 : i32}}) {{
%cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src>
%0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src>
%2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src>
%3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
%4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src>
%5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src>
%7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src>
%8 = tt.splat %arg0 : (!tt.ptr<i32>) -> tensor<{M}x{N}x!tt.ptr<i32>, #src>
%9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr<i32>, #src>, tensor<{M}x{N}xi32, #src>
%10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src>
%11 = "tt.reduce"(%10) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
%12 = "tt.reduce"(%11) ({{
^bb0(%arg2: i32, %arg3: i32):
%13 = arith.addi %arg2, %arg3 : i32
tt.reduce.return %13 : i32
}}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32
tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32
tt.return
}}
}}
"""
import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f:
f.write(ir)
f.flush()
kernel = triton.compile(f.name)
rs = RandomState(17)
x = rs.randint(0, 4, (M, N)).astype('int32')
z = np.zeros((1,)).astype('int32')
x_tri = torch.tensor(x, device=device)
z_tri = torch.tensor(z, device=device)
pgm = kernel[(1, 1, 1)](x_tri, z_tri)
z_ref = np.sum(x)
np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3)
def test_generic_reduction(device='cuda'):
@triton.jit
@@ -2013,57 +2334,79 @@ def val_multiplier(val, i):
return val * i
@triton.jit(noinline=True)
def val_multiplier_noinline(val, i):
return val * i
@triton.jit
def vecmul_kernel(ptr, n_elements, rep):
def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr):
pid = tl.program_id(axis=0)
offsets = pid * 128 + tl.arange(0, 128)
mask = offsets < n_elements
vec = tl.load(ptr + offsets, mask=mask)
for i in range(1, rep):
vec = val_multiplier(vec, i)
if type == "inline":
vec = val_multiplier(vec, i)
else:
vec = val_multiplier_noinline(vec, i)
tl.store(ptr + offsets, vec, mask=mask)
def test_call():
@pytest.mark.parametrize("type", ["inline", "noinline"])
def test_call(type):
@triton.jit
def kernel(ptr, n_elements, num1, num2):
vecmul_kernel(ptr, n_elements, num1)
vecmul_kernel(ptr, n_elements, num2)
def kernel(ptr, n_elements, num1, num2, type: tl.constexpr):
vecmul_kernel(ptr, n_elements, num1, type)
vecmul_kernel(ptr, n_elements, num2, type)
size = 1024
rand_val = numpy_random((size,), dtype_str="float32")
rand_val_tri = to_triton(rand_val, device='cuda')
kernel[(size // 128,)](rand_val_tri, size, 3, 5)
err_msg = ""
try:
kernel[(size // 128,)](rand_val_tri, size, 3, 5, type)
except Exception as e:
err_msg = str(e)
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
if type == "noinline":
assert err_msg != ""
else:
ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4
np.testing.assert_equal(to_numpy(rand_val_tri), ans)
# -------------
# test if
# -------------
@pytest.mark.parametrize("if_type", ["if", "if_exp"])
@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"])
def test_if(if_type):
@triton.jit
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr):
def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr):
pid = tl.program_id(0)
cond = tl.load(Cond)
if IfType == "if":
if pid % 2:
if pid % 2 == 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
else:
elif IfType == "if_exp":
tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse))
elif IfType == "if_and":
if BoolVar and pid % 2 == 0:
tl.store(Ret, tl.load(XTrue))
else:
tl.store(Ret, tl.load(XFalse))
cond = torch.ones(1, dtype=torch.int32, device='cuda')
x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda')
x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda')
ret = torch.empty(1, dtype=torch.float32, device='cuda')
kernel[(1,)](cond, x_true, x_false, ret, if_type)
kernel[(1,)](cond, x_true, x_false, ret, if_type, True)
assert torch.equal(ret, x_true)
def test_num_warps_pow2():
@@ -2227,24 +2570,105 @@ def test_if_else():
assert to_numpy(out)[0] == false_val[0]
def test_if_return():
@pytest.mark.parametrize("mode", ["dynamic", "static"])
def test_if_return(mode):
@triton.jit
def kernel(ExitEarly, Out):
if tl.load(ExitEarly):
tl.store(Out, 0)
return
def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr):
if mode == "dynamic":
if tl.load(ExitEarly):
tl.store(Out, 0)
return
else:
if cond:
tl.store(Out, 0)
return
tl.store(Out, 1)
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
# exit early path taken
exit_early[0] = 1
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, True, mode)
assert to_numpy(out)[0] == 0
# exit early path not taken
exit_early[0] = 0
kernel[(1,)](exit_early, out)
kernel[(1,)](exit_early, out, False, mode)
assert to_numpy(out)[0] == 1
@triton.jit
def add_fn(x):
return x + 1
@triton.jit(noinline=True)
def add_fn_noinline(x):
return x + 1
@triton.jit
def add_fn_return(x, pid):
if pid == 0:
return x + 1
else:
return x + 2
@triton.jit
def add_fn_expr(Out, x):
tl.store(Out, x)
@triton.jit
def add_fn_static_cond(x, cond: tl.constexpr):
if cond == "":
return x
else:
return x + 1
@pytest.mark.parametrize("call_type", ["attribute", "jit_function", "jit_function_return",
"ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"])
def test_if_call(call_type):
@triton.jit
def kernel(Out, call_type: tl.constexpr):
pid = tl.program_id(0)
o = tl.load(Out)
if pid == 0:
if call_type == "attribute":
# call attribute
a = o + 1
a = a.to(tl.int32).to(tl.int32)
o = a
else:
a = o
if call_type == "jit_function":
# regular function call
a = add_fn(a)
elif call_type == "jit_function_return":
# function without end_if block
a = add_fn_return(a, pid)
elif call_type == "ifexp":
# ifexp expression
a = add_fn(a) if pid == 0 else add_fn_return(a, pid)
elif call_type == "expr":
if pid == 1:
return
a = add_fn(a)
if pid == 0:
# call without return
add_fn_expr(Out, a)
elif call_type == "jit_function_static_cond":
a = add_fn_static_cond(a, call_type)
elif call_type == "jit_function_noinline":
a = add_fn_noinline(a)
o = a
tl.store(Out, o)
out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda')
kernel[(1,)](out, call_type)
assert to_numpy(out)[0] == 1

View File

@@ -1,7 +1,5 @@
import multiprocessing
import os
import shutil
from collections import namedtuple
import pytest
import torch
@@ -170,34 +168,32 @@ def test_jit_debug() -> None:
assert bins[0].asm['ttir'] != bins[1].asm['ttir']
def test_compile_in_subproc() -> None:
@triton.jit
def add_fn(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx))
def test_jit_noinline() -> None:
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx,
tl.load(a + idx) - tl.load(b + idx) * 777)
def kernel_add_device(a, b, o, N: tl.constexpr):
add_fn(a, b, o, N)
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = namedtuple("instance_descriptor", [
"divisible_by_16", "equal_to_1"])(
tuple(range(4)),
())
proc = multiprocessing.Process(
target=triton.compile,
kwargs=dict(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
))
proc.start()
proc.join()
assert proc.exitcode == 0
device = torch.cuda.current_device()
assert len(kernel_add_device.cache[device]) == 0
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add_device.cache[device]) == 1
bins = list(kernel_add_device.cache[device].values())
inline_ttir = bins[0].asm['ttir']
add_fn.noinline = True
add_fn.hash = None
kernel_add_device.hash = None
kernel_add_device.cache[device].clear()
kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add_device.cache[device]) == 1
bins = list(kernel_add_device.cache[device].values())
noinline_ttir = bins[0].asm['ttir']
assert inline_ttir != noinline_ttir
def test_memory_leak() -> None:

View File

@@ -0,0 +1,83 @@
import multiprocessing
import os
import shutil
from collections import namedtuple
import torch
import triton
import triton.language as tl
tmpdir = ".tmp"
def reset_tmp_dir():
os.environ["TRITON_CACHE_DIR"] = tmpdir
if os.path.exists(tmpdir):
shutil.rmtree(tmpdir)
instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])
def compile_fn(config, cc):
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777)
triton.compile(
fn=kernel_sub,
signature={0: "*fp32", 1: "*fp32", 2: "*fp32"},
device=0,
constants={3: 32},
configs=[config],
warm_cache_only=True,
cc=cc,
)
def test_compile_in_subproc() -> None:
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(4)), ())
multiprocessing.set_start_method('fork')
proc = multiprocessing.Process(
target=compile_fn,
args=(config, cc))
proc.start()
proc.join()
assert proc.exitcode == 0
def compile_fn_dot(config, cc):
@triton.jit
def kernel_dot(Z):
offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :]
z = tl.load(Z + offs)
z = tl.dot(z, z)
tl.store(Z + offs, z)
triton.compile(
fn=kernel_dot,
signature={0: "*fp32"},
device=0,
configs=[config],
warm_cache_only=True,
cc=cc,
)
def test_compile_in_forked_subproc() -> None:
reset_tmp_dir()
major, minor = torch.cuda.get_device_capability(0)
cc = major * 10 + minor
config = instance_descriptor(tuple(range(1)), ())
assert multiprocessing.get_start_method() == 'fork'
proc = multiprocessing.Process(
target=compile_fn_dot,
args=(config, cc))
proc.start()
proc.join()
assert proc.exitcode == 0

View File

@@ -4,10 +4,6 @@ __version__ = '2.1.0'
# ---------------------------------------
# Note: import order is significant here.
# TODO: torch needs to be imported first
# or pybind11 shows `munmap_chunk(): invalid pointer`
import torch # noqa: F401
# submodules
from .runtime import (
autotune,
@@ -22,6 +18,8 @@ from .runtime import (
)
from .runtime.jit import jit
from .compiler import compile, CompilationError
from .debugger.debugger import program_ids_from_grid
from . import language
from . import testing
@@ -45,6 +43,7 @@ __all__ = [
"runtime",
"TensorWrapper",
"testing",
"program_ids_from_grid",
]

View File

@@ -46,6 +46,8 @@ def mangle_fn(name, arg_tys, constants):
mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)])
mangled_constants = mangled_constants.replace('.', '_d_')
mangled_constants = mangled_constants.replace("'", '_sq_')
# [ and ] are not allowed in LLVM identifiers
mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
return ret
@@ -58,10 +60,21 @@ def _is_constexpr(o: Any) -> bool:
return isinstance(o, constexpr)
def _is_triton_scalar(o: Any) -> bool:
return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1)
def _unwrap_if_constexpr(o: Any):
return o.value if isinstance(o, constexpr) else o
def _check_fn_args(node, fn, args):
if fn.noinline:
for idx, arg in enumerate(args):
if not _is_constexpr(arg) and not _is_triton_scalar(arg):
raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}')
_condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels
@@ -86,7 +99,8 @@ class enter_sub_region:
class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name,
module=None, is_kernel=False, function_types: Optional[Dict] = None, debug=False):
module=None, is_kernel=False, function_types: Optional[Dict] = None,
debug=False, noinline=False):
self.builder = ir.builder(context)
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = {} if function_types is None else function_types
@@ -99,7 +113,9 @@ class CodeGenerator(ast.NodeVisitor):
self.is_kernel = is_kernel
self.last_node = None
self.debug = debug
self.noinline = noinline
self.scf_stack = []
self.last_ret_type = None
# SSA-construction
# name => language.tensor
self.local_defs: Dict[str, tensor] = {}
@@ -134,7 +150,7 @@ class CodeGenerator(ast.NodeVisitor):
def set_value(self, name: str,
value: Union[tensor, constexpr]) -> None:
''' This function:
called by visit_Assign() & visit_FuncDef() to store left value (lvalue)
called by visit_Assign() & visit_FunctionDef() to store left value (lvalue)
1. record local defined name (FIXME: should consider control flow)
2. store tensor in self.lvalue
'''
@@ -146,10 +162,9 @@ class CodeGenerator(ast.NodeVisitor):
#
def visit_compound_statement(self, stmts):
for stmt in stmts:
self.last_ret_type = self.visit(stmt)
if isinstance(stmt, ast.Return):
break
return stmts and isinstance(stmt, ast.Return)
ret_type = self.visit(stmt)
if ret_type is not None and isinstance(stmt, ast.Return):
self.last_ret_type = ret_type
# TODO: should be its own AST visitor
def contains_return_op(self, node):
@@ -164,8 +179,23 @@ class CodeGenerator(ast.NodeVisitor):
pred = lambda s: self.contains_return_op(s)
return any(pred(s) for s in node.body)
elif isinstance(node, ast.Call):
def check_undefined_name(cur_node):
# Check if name is an undefined local variable,
# which can only be a tensor or a constexpr
if isinstance(cur_node.func, ast.Attribute):
if isinstance(cur_node.func.value, ast.Name):
name = cur_node.func.value.id
if name not in self.lscope and name not in self.gscope:
return True
return False
# chain of calls
# e.g., tl.load(a).to(tl.float32)
return check_undefined_name(cur_node.func.value)
return False
if check_undefined_name(node):
return False
fn = self.visit(node.func)
if isinstance(fn, JITFunction):
if isinstance(fn, JITFunction) and fn.noinline is not True:
old_gscope = self.gscope
self.gscope = sys.modules[fn.fn.__module__].__dict__
ret = self.contains_return_op(fn.parse())
@@ -178,6 +208,18 @@ class CodeGenerator(ast.NodeVisitor):
if node.orelse:
ret = ret or any(pred(s) for s in node.orelse)
return ret
elif isinstance(node, ast.IfExp):
return self.contains_return_op(node.body) or self.contains_return_op(node.orelse)
elif isinstance(node, ast.Expr):
ret = False
for _, value in ast.iter_fields(node):
if isinstance(value, list):
for item in value:
if isinstance(item, ast.AST):
ret = ret or self.contains_return_op(item)
elif isinstance(value, ast.AST):
ret = ret or self.contains_return_op(value)
return ret
else:
return False
@@ -228,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor):
self.visit(init_node)
# initialize function
visibility = "public" if self.is_kernel else "private"
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility)
fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline)
self.module.push_back(fn)
entry = fn.add_entry_block()
arg_values = []
@@ -251,9 +293,9 @@ class CodeGenerator(ast.NodeVisitor):
self.set_value(arg_name, arg_value)
self.builder.set_insertion_point_to_start(entry)
# visit function body
has_ret = self.visit_compound_statement(node.body)
self.visit_compound_statement(node.body)
# finalize function
if not has_ret:
if self.last_ret_type is None:
self.builder.ret([])
else:
# update return type
@@ -265,6 +307,8 @@ class CodeGenerator(ast.NodeVisitor):
fn.reset_type(self.prototype.to_ir(self.builder))
if insert_pt:
self.builder.set_insertion_point_to_end(insert_pt)
# Remove dead code
fn.finalize()
def visit_arguments(self, node):
arg_names = []
@@ -415,6 +459,7 @@ class CodeGenerator(ast.NodeVisitor):
return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types
def visit_if_top_level(self, cond, node):
has_endif_block = True
with enter_sub_region(self) as sr:
liveins, ip_block = sr
then_block = self.builder.create_block()
@@ -429,20 +474,25 @@ class CodeGenerator(ast.NodeVisitor):
self.visit_then_else_blocks(node, liveins, then_block, else_block)
# then terminator
self.builder.set_insertion_point_to_end(then_block)
if not then_block.has_terminator():
if then_block.has_return() and else_block.has_return():
has_endif_block = False
endif_block.erase()
if not then_block.has_terminator() and has_endif_block:
self.builder.create_branch(endif_block, [then_defs[n].handle for n in names])
# else terminator
self.builder.set_insertion_point_to_end(else_block)
if not else_block.has_terminator():
if not else_block.has_terminator() and has_endif_block:
self.builder.create_branch(endif_block, [else_defs[n].handle for n in names])
for ty in ir_ret_types:
endif_block.add_argument(ty)
# change block
self.builder.set_insertion_point_to_start(endif_block)
# update value
for i, name in enumerate(names):
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
self.set_value(name, new_tensor)
if has_endif_block:
for ty in ir_ret_types:
endif_block.add_argument(ty)
if has_endif_block:
# change block
self.builder.set_insertion_point_to_start(endif_block)
# update value
for i, name in enumerate(names):
new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i])
self.set_value(name, new_tensor)
# TODO: refactor
def visit_if_scf(self, cond, node):
@@ -650,6 +700,8 @@ class CodeGenerator(ast.NodeVisitor):
ub = language.core._to_tensor(ub, self.builder)
step = language.core._to_tensor(step, self.builder)
# induction variable type
if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int():
raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})")
iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype)
iv_ir_type = iv_type.to_ir(self.builder)
@@ -773,7 +825,7 @@ class CodeGenerator(ast.NodeVisitor):
if not self.module.has_function(fn_name):
prototype = language.function_type([], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=fn.debug, noinline=fn.noinline)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -805,6 +857,7 @@ class CodeGenerator(ast.NodeVisitor):
if not self.debug:
return
if isinstance(fn, JITFunction):
_check_fn_args(node, fn, args)
return self.call_JitFunction(fn, args, kws)
if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn):
extra_kwargs = dict(_builder=self.builder)

View File

@@ -11,8 +11,6 @@ from collections import namedtuple
from pathlib import Path
from typing import Any, Tuple
import torch
import triton
import triton._C.libtriton.triton as _triton
from ..runtime import driver
@@ -254,7 +252,7 @@ def convert_type_repr(x):
return x
def make_hash(fn, **kwargs):
def make_hash(fn, arch, **kwargs):
if isinstance(fn, triton.runtime.JITFunction):
configs = kwargs["configs"]
signature = kwargs["signature"]
@@ -265,7 +263,7 @@ def make_hash(fn, **kwargs):
# Get unique key for the compiled code
get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1))
configs_key = [get_conf_key(conf) for conf in configs]
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}"
key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}"
return hashlib.md5(key.encode("utf-8")).hexdigest()
assert isinstance(fn, str)
return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest()
@@ -330,6 +328,10 @@ def is_hip():
def get_architecture_descriptor(capability):
try:
import torch
except ImportError:
raise ImportError("Triton requires PyTorch to be installed")
if capability is None:
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
@@ -428,7 +430,7 @@ def compile(fn, **kwargs):
# cache manager
so_path = make_stub(name, signature, constants)
# create cache manager
fn_cache_manager = get_cache_manager(make_hash(fn, **kwargs))
fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs))
# determine name and extension type of provided function
if isinstance(fn, triton.runtime.JITFunction):
name, ext = fn.__name__, "ast"

View File

@@ -0,0 +1,9 @@
from typing import Tuple
import dataclasses
@dataclasses.dataclass
class ExecutionContext:
program_id: Tuple[int]
program_size: Tuple[int]

View File

@@ -0,0 +1,170 @@
import itertools
import random
from typing import Tuple
import triton
import triton.language as tl
from .core import ExecutionContext
from .memory_map import MemoryMap
from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor,
debugger_constexpr)
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
tl_method_backup = {}
def get_proxy_method(proxy, name):
method = getattr(proxy, name)
def fun(*args, **kwarg):
return method(*args, **kwarg)
return fun
def attach_triton(module, proxy):
method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"]
for name in method_list:
if hasattr(module, name):
attr = getattr(module, name)
tl_method_backup[name] = attr
if callable(attr):
setattr(module, name, get_proxy_method(proxy, name))
else:
setattr(module, name, getattr(proxy, name))
def detach_triton(module):
for name, method in tl_method_backup.items():
setattr(module, name, method)
def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]:
# reverse the grid dimensions and generate the range for each dimension
reversed_grid = reversed(grid)
ranges_for_each_dimension = [range(dim) for dim in reversed_grid]
# gen all combinations
index_combinations = list(itertools.product(*ranges_for_each_dimension))
random.shuffle(index_combinations)
for index_combination in index_combinations:
yield index_combination
class DebuggerFunction:
def __init__(self, func, grid=(1,)):
self.func = func
self.grid = grid
def _is_constexpr(self, name):
return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr
def _get_constexpr(self):
result = []
for name, annotation in self.func.__annotations__.items():
if annotation is triton.language.core.constexpr:
result.append(name)
return result
def _assert_constexpr(self, **kwargs):
constexp = self._get_constexpr()
missing = [i for i in constexp if i not in kwargs.keys()]
assert len(missing) == 0, f"You must specify constexpr {missing}"
def _get_grid(self, **kwargs):
if callable(self.grid):
return self.grid(kwargs)
else:
return self.grid
def __call__(self, *args, **kwargs):
self._assert_constexpr(**kwargs)
memory = MemoryMap()
def convert_arg(v):
name, arg = v
if torch.is_tensor(arg):
ptr = memory.add_tensor(arg)
return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda"))
if self._is_constexpr(name):
return debugger_constexpr(arg)
return WrappedTensor(_primitive_to_tensor(arg))
new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args)))
new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]}
grid = self._get_grid(**kwargs)
for program_id in program_ids_from_grid(grid):
proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid))
attach_triton(tl, proxy)
self.func(*new_args, **new_kwargs)
detach_triton(tl)
class GridSelector:
"""
Entry point of the debugger
"""
def __init__(self, func):
version = torch.__version__
assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}"
self.func = func
def __getitem__(self, grid):
return DebuggerFunction(self.func, grid)
def __call__(self, *args, **kwargs):
return DebuggerFunction(self.func)(*args, **kwargs)
class AutotuneGridSelector:
def __init__(self, func, autotune_params):
self.func = func
self.autotune_params = autotune_params
def __getitem__(self, grid):
return AutotuneRunner(self.func, self.autotune_params, grid)
def __call__(self, *args, **kwargs):
return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs)
class AutotuneRunner:
def __init__(self, func, autotune_params, grid=None):
self.func = func
self.autotune_params = autotune_params
self.grid = grid
def __call__(self, *args, **kwargs):
assert len(self.autotune_params["configs"]) >= 1
for config in self.autotune_params["configs"][1:]:
def convert_arg(v):
if torch.is_tensor(v):
return torch.clone(v)
return v
new_args = tuple(map(convert_arg, args))
new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()}
if self.grid:
self.func[self.grid](*new_args, **new_kwargs, **config.kwargs)
else:
self.func(*new_args, **new_kwargs, **config.kwargs)
main_config = self.autotune_params["configs"][0]
if self.grid:
self.func[self.grid](*args, **kwargs, **main_config.kwargs)
else:
self.func(*args, **kwargs, **main_config.kwargs)
def triton_debug_autotune(**kwargs):
def wrapper(func):
return AutotuneGridSelector(func, kwargs)
return wrapper
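The runner above executes each program id sequentially in plain Python, so ordinary breakpoints and prints work inside the kernel body. A minimal usage sketch follows; it assumes a CUDA device and torch >= 2.0, and the kernel name, tensor sizes and block size are illustrative only.

import torch
import triton
import triton.language as tl

@triton.jit(interpret=True)  # returns a GridSelector instead of compiling
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

x = torch.randn(1024, device="cuda")
y = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
# Constexpr arguments must be passed as keywords so DebuggerFunction can detect them.
add_kernel[(triton.cdiv(1024, 256),)](x, y, out, 1024, BLOCK=256)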

View File

@@ -0,0 +1,100 @@
import dataclasses
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
@dataclasses.dataclass
class RegisteredStorage:
storage: torch.Storage
dtype: torch.dtype
size: int
ptr: int
@property
def end_ptr(self) -> int:
return self.ptr + self.size
@property
def access_tensor(self) -> torch.Tensor:
return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device)
def ensure_immutable(self):
assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size
class MemoryMap:
storages: [RegisteredStorage]
def __init__(self):
self.storages = []
def _get_registered_storage(self, pointer: torch.Tensor):
max_pointer = torch.max(pointer).item()
min_pointer = torch.min(pointer).item()
registered_storage = next(
filter(
lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages
),
None,
)
if registered_storage is None:
raise Exception("Storage not found or pointers spanning multiple tensors")
registered_storage.ensure_immutable()
return registered_storage
def add_tensor(self, t: torch.Tensor):
storage = t.untyped_storage()
self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr()))
return t.data_ptr()
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
):
assert pointer.is_cuda
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert mask.is_cuda
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
if torch.all(~mask):
# TODO: the dtype is wrong here; we cannot determine the correct type
return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda")
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda")
block[mask] = access_tensor[index_tensor[mask]]
return block
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
assert 0 < pointer.dim() < 3
assert pointer.dtype == torch.int64
if mask is None:
mask = torch.ones_like(pointer).bool()
assert 0 < mask.dim() < 3
assert mask.dtype == torch.bool
mask = mask.expand(pointer.size())
if torch.all(~mask):
return
registered_storage = self._get_registered_storage(pointer[mask])
access_tensor = registered_storage.access_tensor
index_tensor = pointer - registered_storage.ptr
access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype)
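A small sketch of how the map above is meant to be driven (illustrative only; it assumes a CUDA build of PyTorch and uses the class defined in this file directly rather than through the proxy):

import torch

mm = MemoryMap()
t = torch.arange(8, dtype=torch.float32, device="cuda")
base = mm.add_tensor(t)  # registers t's storage and returns its data_ptr()

# Pointers are modelled as data_ptr() plus per-element offsets, as in the proxy.
ptrs = torch.arange(4, dtype=torch.int64, device="cuda") + base
vals = mm.load(ptrs)            # reads elements 0..3 of the registered storage
mm.store(ptrs, vals * 10.0)     # writes back through the same storage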

View File

@@ -0,0 +1,621 @@
import triton
from .core import ExecutionContext
from .memory_map import MemoryMap
from triton.debugger import torch_wrapper
torch = torch_wrapper.torch
def _primitive_to_tensor(x):
"""
Converts various Python primitive values to a PyTorch tensor.
"""
tensor_args = {"device": "cuda"}
if isinstance(x, bool):
return torch.tensor([x], dtype=torch.bool, **tensor_args)
elif isinstance(x, int):
if -(2**31) <= x < 2**31:
return torch.tensor([x], dtype=torch.int32, **tensor_args)
elif -(2**63) <= x < 2**63:
return torch.tensor([x], dtype=torch.int64, **tensor_args)
else:
raise RuntimeError(f"Nonrepresentable integer {x}.")
elif isinstance(x, float):
return torch.tensor([x], dtype=torch.float32, **tensor_args)
elif torch.is_tensor(x):
return x
elif isinstance(x, WrappedTensor):
return x
elif isinstance(x, debugger_constexpr):
if x.value is None:
return None
return _primitive_to_tensor(x.value)
elif x is None:
return None
assert False, f"cannot convert {x} of type {type(x)} to tensor"
def _infer_tensor(func):
"""
A decorator function to harmonize function args:
- converts primitives to PyTorch tensors
- wraps PyTorch tensors with WrappedTensors
"""
def wrapper(*args):
new_args = tuple(map(lambda v: _primitive_to_tensor(v), args))
new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args))
return func(*new_args)
return wrapper
def _tensor_operation(func):
"""
A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function.
Can be combined with the _infer_tensor decorator to harmonize args (converting everything to torch tensors).
"""
def wrapper(*args, **kwargs):
for arg in args:
assert not torch.is_tensor(arg), "unexpected tensor argument"
def unwrap_tensor(v):
if isinstance(v, WrappedTensor):
return v.tensor
if isinstance(v, debugger_constexpr):
return v.value
return v
new_args = tuple(map(unwrap_tensor, args))
new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()}
result = func(args[0], *new_args[1:], **new_kwargs)
return WrappedTensor(result) if torch.is_tensor(result) else result
return wrapper
class debugger_constexpr:
def __init__(self, value):
if isinstance(value, debugger_constexpr):
self.value = value.value
else:
self.value = value
def __str__(self) -> str:
return "debugger_constexpr(" + str(self.value) + ")"
def __index__(self) -> int:
return self.value
def __bool__(self):
return bool(self.value)
def __ge__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value >= other
def __gt__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value > other
def __le__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value <= other
def __lt__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value < other
def __eq__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value == other
def __or__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value | other
def __ror__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value | other
def __and__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value & other
def __rand__(self, other):
other = other.value if isinstance(other, debugger_constexpr) else other
return self.value & other
def to(self, dtype, bitcast=False, _builder=None):
if dtype in [torch.int64]:
ret_ty = int
elif dtype == torch.bool:
ret_ty = bool
elif dtype in [torch.float64]:
ret_ty = float
else:
raise ValueError("dtype not supported in debugger")
return debugger_constexpr(ret_ty(self.value))
class WrappedTensor:
def __init__(self, tensor):
self.tensor = tensor
def __index__(self) -> int:
return self.tensor.item()
def __str__(self) -> str:
return "wrapped_" + str(self.tensor)
def __bool__(self) -> bool:
return torch.all(self.tensor == True).item() # noqa: E712
@property
def dtype(self):
return self.tensor.dtype
@_infer_tensor
@_tensor_operation
def __add__(self, other):
return torch.add(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __radd__(self, other):
return self.__add__(other)
@_infer_tensor
@_tensor_operation
def __sub__(self, other):
return torch.sub(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rsub__(self, other):
return torch.sub(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mul__(self, other):
return torch.mul(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmul__(self, other):
return self.__mul__(other)
@_infer_tensor
@_tensor_operation
def __truediv__(self, other):
return torch.div(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rtruediv__(self, other):
return torch.div(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __floordiv__(self, other):
return torch.floor_divide(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rfloordiv__(self, other):
return torch.floor_divide(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __mod__(self, other):
return torch.remainder(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rmod__(self, other):
return torch.remainder(other, self.tensor)
@_infer_tensor
@_tensor_operation
def __neg__(self):
return -self.tensor
@_infer_tensor
@_tensor_operation
def __invert__(self):
return ~self.tensor
@_infer_tensor
@_tensor_operation
def __and__(self, other):
return torch.bitwise_and(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __or__(self, other):
return torch.bitwise_or(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __xor__(self, other):
return torch.bitwise_xor(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __lshift__(self, other):
return torch.bitwise_left_shift(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __rshift__(self, other):
return torch.bitwise_right_shift(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __gt__(self, other):
return self.tensor > other
@_infer_tensor
@_tensor_operation
def __rgt__(self, other):
return other > self.tensor
@_infer_tensor
@_tensor_operation
def __ge__(self, other):
return self.tensor >= other
@_infer_tensor
@_tensor_operation
def __rge__(self, other):
return other >= self.tensor
@_infer_tensor
@_tensor_operation
def __lt__(self, other):
return self.tensor < other
@_infer_tensor
@_tensor_operation
def __rlt__(self, other):
return other < self.tensor
@_infer_tensor
@_tensor_operation
def __le__(self, other):
return self.tensor <= other
@_infer_tensor
@_tensor_operation
def __rle__(self, other):
return other <= self.tensor
@_infer_tensor
@_tensor_operation
def __eq__(self, other):
return torch.equal(self.tensor, other)
@_infer_tensor
@_tensor_operation
def __ne__(self, other):
return not torch.equal(self.tensor, other)
@_tensor_operation
def __getitem__(self, slices):
return self.tensor.__getitem__(slices)
# if isinstance(slices, slice):
# slices = [slices]
# src_shape = self.shape
# dst_shape = []
# curr = 0
# for sl in slices:
# if isinstance(sl, constexpr) and sl.value is None:
# dst_shape.append(1)
# elif sl == slice(None, None, None):
# dst_shape.append(src_shape[curr].value)
# curr += 1
# ret = torch.reshape(self.tensor, dst_shape, )
# return ret
@_tensor_operation
def to(self, dtype, bitcast=False):
return self.tensor.to(dtype)
# if isinstance(bitcast, constexpr):
# bitcast = bitcast.value
# if bitcast:
# return semantic.bitcast(self, dtype, )
# return semantic.cast(self, dtype, )
def _constexpr_to_value(v):
if isinstance(v, debugger_constexpr):
return v.value
return v
class TritonLangProxy:
_memory_map: MemoryMap
_context: ExecutionContext
def __init__(self, memory_map: MemoryMap, context: ExecutionContext):
self._memory_map = memory_map
self._context = context
# Types
# Removed void, int1, float8, uint16, uint32, uint64, pi32_t
# constexpr = debugger_constexpr
# Program functions
@_tensor_operation
def load(
self,
pointer: torch.Tensor,
mask: torch.Tensor = None,
other=0.0,
cache_modifier="",
eviction_policy="",
volatile=False,
):
return self._memory_map.load(pointer, mask, other)
@_tensor_operation
def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None):
return self._memory_map.store(pointer, value, mask)
@_tensor_operation
def program_id(self, axis):
assert axis < len(self._context.program_id)
return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def num_programs(self, axis):
assert axis < len(self._context.program_size)
return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda")
@_tensor_operation
def arange(self, start, end):
return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda")
@_tensor_operation
def zeros(self, shape, dtype):
for i, d in enumerate(shape):
if not isinstance(d, debugger_constexpr):
raise TypeError(f"Shape element {i} must have type `constexpr`")
if not isinstance(d.value, int):
raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]")
shape = [x.value for x in shape]
if isinstance(dtype, triton.language.core.dtype):
if dtype.is_fp32():
dtype = torch.float32
elif dtype.is_fp16():
dtype = torch.float16
elif dtype.is_bf16():
dtype = torch.bfloat16
elif dtype.is_int32():
dtype = torch.int32
elif dtype.is_int16():
dtype = torch.int16
elif dtype.is_int8():
dtype = torch.int8
else:
raise TypeError(f"Unsupported dtype {dtype}")
return torch.zeros(size=shape, dtype=dtype, device="cuda")
@_tensor_operation
def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16):
raise NotImplementedError()
@_tensor_operation
def broadcast(self, input, other):
raise NotImplementedError()
@_tensor_operation
def broadcast_to(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def cat(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def reshape(self, input, shape):
raise NotImplementedError()
@_tensor_operation
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True):
assert input.dtype == other.dtype
if trans_a:
input = input.T
if trans_b:
other = other.T
return torch.matmul(input=input, other=other)
@_tensor_operation
def atomic_cas(self, pointer, cmp, val):
stored = self._memory_map.load(pointer, None, 0.0)
if not isinstance(cmp, torch.Tensor):
cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda")
if not isinstance(val, torch.Tensor):
val = torch.tensor([val], dtype=stored.dtype, device="cuda")
if stored == cmp:
self._memory_map.store(pointer, val, None)
return stored
@_tensor_operation
def atomic_xchg(self, pointer, val, mask=None):
if isinstance(val, int):
val = torch.tensor([val], dtype=torch.int32, device="cuda")
stored = self._memory_map.load(pointer, mask, 0.0)
self._memory_map.store(pointer, val, mask)
return stored
@_tensor_operation
def atomic_add(self, pointer, val, mask=None):
# arbitrary `other` value, as it will be masked during the store
stored = self._memory_map.load(pointer, mask, 0.0)
result = stored + val
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_max(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.maximum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_min(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0.0)
result = torch.minimum(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_and(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_and(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_or(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_or(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def atomic_xor(self, pointer, val, mask=None):
stored = self._memory_map.load(pointer, mask, 0)
result = torch.bitwise_xor(stored, val)
self._memory_map.store(pointer, result, mask)
return stored
@_tensor_operation
def where(self, condition, x, y):
condition = _primitive_to_tensor(condition)
x = _primitive_to_tensor(x)
y = _primitive_to_tensor(y)
return torch.where(condition, x, y)
@_tensor_operation
def umulhi(self, x, y):
raise NotImplementedError()
@_tensor_operation
def fdiv(self, x, y, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def exp(self, x):
return torch.exp(x)
@_tensor_operation
def log(self, x):
return torch.log(x)
@_tensor_operation
def cos(self, x):
return torch.cos(x)
@_tensor_operation
def sin(self, x):
return torch.sin(x)
@_tensor_operation
def sqrt(self, x):
return torch.sqrt(x)
@_tensor_operation
def globaltimer(self):
raise NotImplementedError()
@_tensor_operation
def clock(self):
raise NotImplementedError()
@_tensor_operation
def debug_barrier(self):
raise NotImplementedError()
@_tensor_operation
def multiple_of(self, input, values):
return input
@_tensor_operation
def max_contiguous(self, input, values):
return input
@_tensor_operation
def abs(self, x):
return torch.abs(x)
@_tensor_operation
def cdiv(self, x, div):
return (x + div - 1) // div
@_tensor_operation
def minimum(self, x, y):
if isinstance(x, int):
x = torch.tensor(x, device="cuda")
if isinstance(y, int):
y = torch.tensor(y, device="cuda")
return torch.minimum(x, y)
@_tensor_operation
def maximum(self, x, y):
return torch.maximum(x, y)
@_tensor_operation
def sigmoid(self, x):
raise NotImplementedError()
@_tensor_operation
def softmax(self, x, ieee_rounding=False):
raise NotImplementedError()
@_tensor_operation
def ravel(self, x):
raise NotImplementedError()
@_tensor_operation
def swizzle2d(self, i, j, size_i, size_j, size_g):
raise NotImplementedError()
@_tensor_operation
def zeros_like(self, input):
raise NotImplementedError()
@_tensor_operation
def max(self, input, axis=None):
if axis is None:
return torch.max(input)
return torch.max(input, dim=axis).values
@_tensor_operation
def argmax(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def min(self, input, axis=None):
if axis is None:
return torch.min(input)
return torch.min(input, dim=axis).values
@_tensor_operation
def argmin(self, input, axis):
raise NotImplementedError()
@_tensor_operation
def sum(self, input, axis=None):
if axis is None:
return torch.sum(input)
return torch.sum(input, dim=axis)
@_tensor_operation
def xor_sum(self, input, axis):
raise NotImplementedError()
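The two decorators used throughout this proxy mean that arithmetic on WrappedTensor values transparently accepts Python scalars, constexprs and other wrapped values, and always hands plain torch tensors to the underlying op. A tiny illustrative sketch (requires CUDA, since _primitive_to_tensor allocates on "cuda"):

import torch

a = WrappedTensor(torch.tensor([1.0, 2.0, 3.0], device="cuda"))
b = a + 1            # the scalar is promoted to a tensor, then unwrapped for torch.add
print(b)             # e.g. wrapped_tensor([2., 3., 4.], device='cuda:0')
c = (a * 2) >= 4.0   # comparisons return wrapped boolean tensors
print(bool(c))       # False: __bool__ is true only if every element is True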

View File

@@ -0,0 +1,18 @@
try:
import torch as _torch
except ImportError:
_torch = None
class TorchWrapper:
"""
Helps make torch an optional dependency
"""
def __getattr__(self, name):
if _torch is None:
raise ImportError("Triton requires PyTorch to be installed")
return getattr(_torch, name)
torch = TorchWrapper()
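The point of this indirection is that importing Triton's debugger modules never fails outright; the ImportError only fires on first use. A minimal sketch:

from triton.debugger import torch_wrapper

# Resolved lazily via TorchWrapper.__getattr__; raises
# "Triton requires PyTorch to be installed" only if torch is missing.
x = torch_wrapper.torch.zeros(4)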

View File

@@ -39,6 +39,7 @@ from .core import (
dot,
dtype,
exp,
expand_dims,
full,
fdiv,
float16,
@@ -130,6 +131,7 @@ __all__ = [
"dot",
"dtype",
"exp",
"expand_dims",
"extra",
"fdiv",
"float16",

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from typing import Callable, List, TypeVar
from typing import Callable, List, Sequence, TypeVar
import triton
from . import semantic
@@ -903,6 +903,41 @@ def reshape(input, shape, _builder=None):
shape = _shape_check_impl(shape)
return semantic.reshape(input, shape, _builder)
def _wrap_axis(axis, ndim):
if not (-ndim <= axis < ndim):
raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}")
return axis if axis >= 0 else axis + ndim
@builtin
def expand_dims(input, axis, _builder=None):
"""
Expand the shape of a tensor by inserting new length-1 dimensions.
Axis indices are with respect to the resulting tensor, so
``result.shape[axis]`` will be 1 for each axis.
:param input: The input tensor.
:type input: tl.tensor
:param axis: The indices at which to insert new axes
:type axis: int | Sequence[int]
"""
axis = _constexpr_to_value(axis)
axes = list(axis) if isinstance(axis, Sequence) else [axis]
new_ndim = len(input.shape) + len(axes)
axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]
if len(set(axes)) != len(axes):
raise ValueError(f"expand_dims recieved duplicate axes, normalized axes = {axes}")
ret = input
for a in sorted(axes):
ret = semantic.expand_dims(ret, a, _builder)
return ret
# -----------------------
# Linear Algebra
# -----------------------
@@ -1301,9 +1336,9 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):
if len(input.shape) > 1:
# Broadcast index across the non-reduced axes
expand_dims_index = [constexpr(None)] * len(input.shape)
expand_dims_index[axis] = slice(None)
index = index.__getitem__(expand_dims_index, _builder=_builder)
axes_to_expand = [constexpr(d) for d in range(len(input.shape))]
del axes_to_expand[axis]
index = expand_dims(index, axes_to_expand, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)
rvalue, rindices = reduce((input, index), axis, combine_fn,

View File

@@ -3,9 +3,6 @@ from __future__ import annotations # remove after python 3.11
from functools import wraps
from typing import List, Optional, Sequence, Tuple, TypeVar
import torch
import triton
from . import core as tl
from triton._C.libtriton.triton import ir
@@ -665,7 +662,7 @@ def bitcast(input: tl.tensor,
src_bits = src_sca_ty.primitive_bitwidth
dst_bits = dst_sca_ty.primitive_bitwidth
if src_bits != dst_bits:
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to "
raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to "
"data-type of size " + str(dst_bits))
return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)),
dst_ty)
@@ -1190,14 +1187,6 @@ def dot(lhs: tl.tensor,
allow_tf32: bool,
out_dtype: tl.dtype,
builder: ir.builder) -> tl.tensor:
if torch.version.hip is None:
device = triton.runtime.jit.get_current_device()
capability = triton.runtime.jit.get_device_capability(device)
capability = capability[0] * 10 + capability[1]
if capability < 70:
assert (
not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8()
), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)"
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!"
assert len(lhs.shape) == 2 and len(rhs.shape) == 2
@@ -1398,6 +1387,10 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
cond_ty = cond.type
if not cond_ty.is_block():
cond_ty = tl.block_type(cond_ty.scalar, (1,))
cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty)
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)

View File

@@ -4,17 +4,6 @@ import triton
import triton.language as tl
def next_power_of_2(n):
n -= 1
n |= n >> 1
n |= n >> 2
n |= n >> 4
n |= n >> 8
n |= n >> 16
n += 1
return n
def num_warps(N):
if N < 2048:
return 4
@@ -24,7 +13,7 @@ def num_warps(N):
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
@@ -49,7 +38,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr):
@triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])})
@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])})
@triton.jit
def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr):
row = tl.program_id(0)
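The local helper is dropped in favour of the library routine of the same name, which performs the same bit-twiddling round-up; a few illustrative values:

import triton

assert triton.next_power_of_2(1) == 1
assert triton.next_power_of_2(5) == 8
assert triton.next_power_of_2(1024) == 1024
assert triton.next_power_of_2(1025) == 2048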

View File

@@ -142,6 +142,7 @@ class Autotuner(KernelInterface):
class Config:
"""
An object that represents a possible kernel configuration for the auto-tuner to try.
:ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments.
:type meta: dict[Str, Any]
:ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if
@@ -173,8 +174,10 @@ class Config:
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
"""
Decorator for auto-tuning a :code:`triton.jit`'d function.
.. highlight:: python
.. code-block:: python
@triton.autotune(configs=[
triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4),
triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8),
@@ -223,8 +226,10 @@ def heuristics(values):
"""
Decorator for specifying how the values of certain meta-parameters may be computed.
This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable.
.. highlight:: python
.. code-block:: python
@triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))})
@triton.jit
def kernel(x_ptr, x_size, **META):

View File

@@ -83,7 +83,8 @@ class DependenciesFinder(ast.NodeVisitor):
finder = DependenciesFinder(func.__globals__, func.src)
finder.visit(tree)
func.hash = finder.ret
self.ret = (self.ret + func.hash).encode("utf-8")
noinline = str(getattr(func, 'noinline', False))
self.ret = (self.ret + func.hash + noinline).encode("utf-8")
self.ret = hashlib.md5(self.ret).hexdigest()
# -----------------------------------------------------------------------------
@@ -175,7 +176,7 @@ class JITFunction(KernelInterface[T]):
return True
return False
divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize}
return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1))
# return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1)
@@ -298,13 +299,13 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
set_current_device(device)
if stream is None and not warmup:
stream = get_cuda_stream(device)
try:
bin = cache[device][key]
bin = cache[device].get(key, None)
if bin is not None:
if not warmup:
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args})
return bin
# kernel not cached -- compile
except KeyError:
else:
# build dict of constant values
args = [{args}]
all_args = {', '.join([f'{arg}' for arg in self.arg_names])},
@@ -334,7 +335,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
exec(src, scope)
return scope[self.fn.__name__]
def __init__(self, fn, version=None, do_not_specialize=None, debug=None):
def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None):
self.fn = fn
self.module = fn.__module__
self.version = version
@@ -356,6 +357,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
self.kernel_decorators = []
self.kernel = None
self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
self.noinline = noinline
# annotations
normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty
self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()}
@@ -425,6 +427,7 @@ def jit(
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
) -> Callable[[T], JITFunction[T]]:
...
@@ -435,6 +438,8 @@ def jit(
version=None,
do_not_specialize: Optional[Iterable[int]] = None,
debug: Optional[bool] = None,
noinline: Optional[bool] = None,
interpret: Optional[bool] = None,
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
"""
Decorator for JIT-compiling a function using the Triton compiler.
@@ -456,13 +461,17 @@ def jit(
def decorator(fn: T) -> JITFunction[T]:
assert callable(fn)
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
debug=debug,
)
if interpret:
from ..debugger.debugger import GridSelector
return GridSelector(fn)
else:
return JITFunction(
fn,
version=version,
do_not_specialize=do_not_specialize,
debug=debug,
noinline=noinline,
)
if fn is not None:
return decorator(fn)
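A short sketch of the new noinline flag added above: the helper is compiled as its own tt.func and reached through tt.call instead of being inlined into the caller (kernel and names are illustrative).

import triton
import triton.language as tl

@triton.jit(noinline=True)
def add_one(x):
    return x + 1

@triton.jit
def kernel(x_ptr, out_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    x = tl.load(x_ptr + offs)
    tl.store(out_ptr + offs, add_one(x))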

BIN
python/triton/third_party/cuda/bin/ptxas vendored Executable file

Binary file not shown.

View File

@@ -6,8 +6,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
// CHECK-LABEL: matmul_loop
// There shouldn't be any aliasing with the dot op encoding.

View File

@@ -402,8 +402,6 @@ tt.func @permute_2d(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: i32
// -----
module {
// This is a tiny test for verifying StoreOp-related alignment. It simply stores a constant to a buffer.
// CHECK-LABEL: @store_constant_align
tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) {
@@ -433,8 +431,6 @@ tt.func @store_constant_align(%addr: !tt.ptr<f32> {tt.divisibility = 16 : i32},
tt.return
}
}
// -----
// This IR is dumped from vecadd test.
@@ -491,3 +487,88 @@ tt.func @vecadd_mask_align_1(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %
tt.store %15, %13, %10 : tensor<64xf32>
tt.return
}
// -----
module {
// We don't use function cloning here, so the alignment info is the gcd of all call sites.
// CHECK-LABEL: @addptr_hints
tt.func @addptr_hints(%arg0: !tt.ptr<i32>) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst1 = arith.constant 1 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%1 = tt.addptr %arg0, %cst1 : !tt.ptr<i32>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4
%cst4 = arith.constant 4 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%2 = tt.addptr %arg0, %cst4 : !tt.ptr<i32>, i32
// CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16
%cst16 = arith.constant 16 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%3 = tt.addptr %arg0, %cst4 : !tt.ptr<i32>, i32
tt.return
}
// CHECK-LABEL: @kernel_div16
tt.func @kernel_div16(%arg0: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
tt.return
}
// CHECK-LABEL: @kernel_div8
tt.func @kernel_div8(%arg0: !tt.ptr<i32> {tt.divisibility = 8 : i32}) {
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
tt.return
}
// CHECK-LABEL: @kernel_div4
tt.func @kernel_div4(%arg0: !tt.ptr<i32> {tt.divisibility = 4 : i32}) {
tt.call @addptr_hints(%arg0) : (!tt.ptr<i32>) -> ()
tt.return
}
}
// -----
module {
// We don't use function cloning here, so the alignment info is the gcd of all call sites.
// CHECK-LABEL: @mul
tt.func @mul(%arg0: i32) {
// CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1
%cst1 = arith.constant 1 : i32
// CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%1 = arith.muli %arg0, %cst1 : i32
tt.return
}
// CHECK-LABEL: @bar
tt.func @bar(%arg0: i32) {
tt.call @mul(%arg0) : (i32) -> ()
tt.return
}
// CHECK-LABEL: @foo
tt.func @foo(%arg0: i32) {
tt.call @mul(%arg0) : (i32) -> ()
tt.return
}
// CHECK-LABEL: @call_graph
tt.func @call_graph(%arg0: i32) {
// CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 12
%cst12 = arith.constant 12 : i32
// CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = <none>
%0 = arith.muli %arg0, %cst12 : i32
tt.call @foo(%0) : (i32) -> ()
// CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = 8
%cst8 = arith.constant 8 : i32
// CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = <none>
%1 = arith.muli %arg0, %cst8 : i32
tt.call @bar(%1) : (i32) -> ()
tt.return
}
}

View File

@@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
@@ -28,10 +28,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
%a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK: offset = 0, size = 4608
// CHECK: scratch offset = 0, size = 4608
%a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
// CHECK-NEXT: offset = 0, size = 4224
// CHECK-NEXT: scratch offset = 0, size = 4224
%b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT>
%c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
@@ -56,17 +56,17 @@ tt.func @reusable(%A : !tt.ptr<f16>) {
%a_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
%b_ptr = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #AL>
%a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK-NEXT: offset = 0, size = 4608
// CHECK-NEXT: scratch offset = 0, size = 4608
%a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
// CHECK-NEXT: offset = 0, size = 1152
// CHECK-NEXT: scratch offset = 0, size = 1152
%a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
%a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
// CHECK-NEXT: offset = 0, size = 4608
// CHECK-NEXT: scratch offset = 0, size = 4608
%a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT>
%c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
%a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL>
// CHECK-NEXT: offset = 0, size = 1152
// CHECK-NEXT: scratch offset = 0, size = 1152
%a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT>
%c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C>
tt.return
@@ -396,3 +396,127 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
}
}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: alloc1
tt.func @alloc1(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: alloc2
tt.func @alloc2(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 1024
%cst0 = triton_gpu.alloc_tensor : tensor<32x16xf16, #A_SHARED>
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: alloc3
tt.func @alloc3(%cond : i1) {
scf.if %cond {
// CHECK: offset = 0, size = 512
%cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
} else {
// CHECK-NEXT: offset = 0, size = 1024
%cst0 = triton_gpu.alloc_tensor : tensor<16x32xf16, #A_SHARED>
}
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: alloc4
tt.func @alloc4(%A : !tt.ptr<f16>, %cond : i1) {
scf.if %cond {
// CHECK: virtual offset = 0, size = 1024
tt.call @alloc3(%cond) : (i1) -> ()
} else {
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: single_call
tt.func @single_call(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: multiple_calls
tt.func @multiple_calls(%A : !tt.ptr<f16>) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc2(%A) : (!tt.ptr<f16>) -> ()
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: if_else_calls
tt.func @if_else_calls(%A : !tt.ptr<f16>, %cond : i1) {
scf.if %cond {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: offset = 0, size = 1024
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
} else {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc2(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: for_calls
tt.func @for_calls(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index
%step = arith.constant 1 : index
scf.for %iv = %lb to %ub step %step {
// CHECK-NEXT: virtual offset = 0, size = 512
tt.call @alloc1(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
// CHECK-NEXT: size = 512
}
// CHECK-LABEL: call_graph_1
tt.func @call_graph_1(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc3(%cond) : (i1) -> ()
tt.return
// CHECK-NEXT: size = 1024
}
// CHECK-LABEL: call_graph_2
tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK: offset = 0, size = 512
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK-NEXT: virtual offset = 0, size = 1024
tt.call @alloc4(%A, %cond) : (!tt.ptr<f16>, i1) -> ()
tt.return
// CHECK-NEXT: size = 1024
}
}

View File

@@ -7,8 +7,8 @@
#A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}>
#B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
@@ -503,3 +503,136 @@ tt.func @cf_if_else_return(%i1 : i1) {
}
}
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: convert_layout1
tt.func @convert_layout1(%A : !tt.ptr<f16>) {
// CHECK-NOT: gpu.barrier
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
tt.return
}
// CHECK-LABEL: convert_layout2
tt.func @convert_layout2(%A : !tt.ptr<f16>) {
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%1 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
%2 = tt.cat %1, %1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED>
// CHECK: triton_gpu.convert_layout
// CHECK-NEXT: gpu.barrier
%3 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
%4 = tt.cat %2, %2 {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #AL>
tt.return
}
// CHECK-LABEL: convert_layout3
tt.func @convert_layout3(%cond : i1) {
scf.if %cond {
%0 = triton_gpu.alloc_tensor : tensor<16x64xf16, #A_SHARED>
// CHECK: triton_gpu.convert_layout
// CHECK-NOT: gpu.barrier
%1 = triton_gpu.convert_layout %0 : (tensor<16x64xf16, #A_SHARED>) -> tensor<16x64xf16, #AL>
} else {
%0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED>
// CHECK: triton_gpu.convert_layout
// CHECK-NEXT: gpu.barrier
// CHECK-NEXT: triton_gpu.convert_layout
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
%2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED>
}
tt.return
}
// CHECK-LABEL: convert_layout4
tt.func @convert_layout4(%A : !tt.ptr<f16>, %cond : i1) {
// CHECK-NOT: gpu.barrier
scf.if %cond {
tt.call @convert_layout3(%cond) : (i1) -> ()
} else {
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
}
// CHECK-LABEL: single_call_sync
tt.func @single_call_sync(%A : !tt.ptr<f16>) {
%0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK: tt.call
// CHECK-NEXT: gpu.barrier
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
%1 = triton_gpu.convert_layout %0 : (tensor<16x32xf16, #AL>) -> tensor<16x32xf16, #BL>
tt.return
}
// CHECK-LABEL: single_call_no_sync
// %1 can reuse %0 in convert_layout2, which has been synced
tt.func @single_call_no_sync(%A : !tt.ptr<f16>) {
// CHECK-NOT: gpu.barrier
%0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL>
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #BL>
tt.return
}
// CHECK-LABEL: multiple_calls
tt.func @multiple_calls(%A : !tt.ptr<f16>) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
tt.return
}
// CHECK-LABEL: if_else_calls
tt.func @if_else_calls(%A : !tt.ptr<f16>, %cond : i1) {
scf.if %cond {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.call
// CHECK-NEXT: gpu.barrier
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED>
} else {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
// CHECK: tt.call
// CHECK-NOT: gpu.barrier
tt.call @convert_layout2(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
}
// CHECK-LABEL: for_calls
tt.func @for_calls(%A : !tt.ptr<f16>, %cond : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL>
%lb = arith.constant 0 : index
%ub = arith.constant 10 : index
%step = arith.constant 1 : index
scf.for %iv = %lb to %ub step %step {
// CHECK: gpu.barrier
// CHECK-NEXT: tt.call
tt.call @convert_layout1(%A) : (!tt.ptr<f16>) -> ()
}
tt.return
}
// CHECK-LABEL: call_graph_1
tt.func @call_graph_1(%A : !tt.ptr<f16>, %cond : i1) {
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
// CHECK: gpu.barrier
// CHECK-NEXT: tt.call
tt.call @convert_layout3(%cond) : (i1) -> ()
tt.return
}
// CHECK-LABEL: call_graph_2
tt.func @call_graph_2(%A : !tt.ptr<f16>, %cond : i1) {
tt.call @convert_layout4(%A, %cond) : (!tt.ptr<f16>, i1) -> ()
// CHECK: tt.call
// CHECK-NEXT: gpu.barrier
%cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED>
tt.return
}
}

View File

@@ -1132,8 +1132,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// PTX-LABEL: convert_dot
// This test is disabled for GCN target, because it is PTX specific
@@ -1250,7 +1250,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
module attributes {"triton_gpu.num-warps" = 1 : i32} {
// CHECK-LABEL: convert_blocked1d_to_slice1
tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
// CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
// CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
%cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
tt.return
}
@@ -1279,8 +1279,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// PTX-LABEL: matmul_kernel_dot_operand_layout
// This test is disabled for GCN target, because it is PTX specific
@@ -1357,8 +1357,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
#mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// PTX-LABEL: matmul_tf32dot
// This test is disabled for GCN target, because it is PTX specific
@@ -1366,12 +1366,21 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32},
%a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
%cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
<<<<<<< HEAD
// PTX: llvm.inline_asm
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// PTX-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>)
// PTX: llvm.inline_asm
// PTX-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// PTX-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>)
=======
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// CHECK-SAME: (i32, i32, i32, i32)
// CHECK: llvm.inline_asm
// CHECK-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
// CHECK-SAME: (i32, i32, i32, i32)
>>>>>>> openai/main
%a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
%b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
@@ -1399,6 +1408,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
<<<<<<< HEAD
// GCN-NOT: llvm.inline_asm
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
// PTX: llvm.icmp "slt"
@@ -1406,6 +1416,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// PTX-SAME: @$3 atom.global.gpu.add.f32
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.f32
=======
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
>>>>>>> openai/main
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
tt.return
}
@@ -1416,11 +1432,18 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f32_scalar
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
<<<<<<< HEAD
// GCN-NOT: llvm.inline_asm
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
// PTX: llvm.icmp "eq"
// PTX: llvm.inline_asm
// PTX-SAME: @$3 atom.global.gpu.add.f32
=======
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK: llvm.inline_asm
// CHECK-SAME: @$3 atom.global.gpu.add.f32
>>>>>>> openai/main
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
tt.return
}
@@ -1432,6 +1455,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32
tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
<<<<<<< HEAD
// GCN-NOT: llvm.inline_asm
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
@@ -1440,6 +1464,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
// PTX-SAME: @$2 st.global.b32
// PTX: llvm.inline_asm
// PTX-SAME: @$2 st.global.b32
=======
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
>>>>>>> openai/main
tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
tt.return
}
@@ -1450,11 +1480,17 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
module attributes {"triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: store_f32_scalar
tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
<<<<<<< HEAD
// GCN-NOT: llvm.inline_asm
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
// PTX: llvm.icmp "slt"
// PTX: llvm.inline_asm
// PTX-SAME: @$2 st.global.b32
=======
// CHECK: llvm.icmp "eq"
// CHECK: llvm.inline_asm
// CHECK-SAME: @$2 st.global.b32
>>>>>>> openai/main
tt.store %arg0, %arg1 : f32
tt.return
}

View File

@@ -0,0 +1,22 @@
// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s
// == LLVM IR check begin ==
// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
// CHECK: define void @test_func
// CHECK: define void @test_kernel
// CHECK: tail call void @test_func
module attributes {"triton_gpu.num-warps" = 4 : i32} {
tt.func @test_func(%lb : index, %A : !tt.ptr<f16>) attributes { noinline = true } {
%0 = arith.constant 1.0 : f16
tt.store %A, %0 : f16
tt.return
}
tt.func @test_kernel(%lb : index, %A : !tt.ptr<f16>) {
tt.call @test_func(%lb, %A) : (index, !tt.ptr<f16>) -> ()
tt.return
}
}

View File

@@ -0,0 +1,52 @@
// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -tritongpu-remove-layout-conversions -canonicalize | FileCheck %s
#Cv2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#Av2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}>
#Bv2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}>
#Cv1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}>
#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}>
#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}>
#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
// CHECK: tt.func @push_elementwise1
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[C:.*]] = tt.dot %[[AF16]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
tt.func @push_elementwise1(
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av2>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv2>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
tt.return %newc : tensor<16x16xf32, #Cv2>
}
// CHECK: tt.func @push_elementwise2
// CHECK: %[[ALOAD:.*]] = tt.load %arg0
// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
tt.func @push_elementwise2(
%pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
%c: tensor<16x16xf32, #Cv1>) -> tensor<16x16xf32, #Cv1>{
%ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
%b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
%af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
%a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
%dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av1>
%dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv1>
%newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
tt.return %newc : tensor<16x16xf32, #Cv1>
}

View File

@@ -8,8 +8,8 @@
#BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}>
#BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}>
#C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
// CHECK: tt.func @matmul_loop
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32

View File

@@ -1,4 +1,4 @@
// RUN: triton-opt %s -split-input-file -tritongpu-prefetch | FileCheck %s
// RUN: triton-opt %s -split-input-file -tritongpu-prefetch -canonicalize | FileCheck %s
// 4 warps
// matmul: 128x32 @ 32x128 -> 128x128
@@ -7,33 +7,36 @@
#A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
#C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>
// CHECK: tt.func @matmul_loop
// CHECK: tt.func @matmul_loop_mixed
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128]
// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.convert_layout %[[B0_PREFETCH_SMEM]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_PREFETCH]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_a0]][0, 16] [128, 16]
// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.convert_layout %[[A_REM_SMEM]]
// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]]
// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_b0]][16, 0] [16, 128]
// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.convert_layout %[[B_REM_SMEM]]
// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}}
// CHECK: tt.dot %[[A_REM]], %[[B_REM]], %[[D_FIRST:.*]]
// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]]
// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [128, 16]
// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_A_PREFETCH_SMEM]]
// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]]
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]]
tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f8E5M2>, %B : !tt.ptr<f16>) -> tensor<128x128xf32, #C>{
%a_ptr_init = tt.broadcast %A : (!tt.ptr<f8E5M2>) -> tensor<128x32x!tt.ptr<f8E5M2>, #AL>
%b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
%a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
%a_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E5M2, #AL>
%b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
%b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL>
%c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
@@ -41,24 +44,25 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
%a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
%b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
%a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL>
%a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A>
%b_ = tt.load %b_ptr_init, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%b_init = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) {
%a_op = triton_gpu.convert_layout %a : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A_OP>
%loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) {
%a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP>
%a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP>
%b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
%c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
%next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<128x32xi32, #AL>
%next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
%next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
%next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
%next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL>
%next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A>
%next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
%next_b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
}
tt.return
tt.return %loop#4 : tensor<128x128xf32, #C>
}
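
The prefetch test is now matmul_loop_mixed: the A operand is loaded as f8E5M2, and every prefetched k-slice is extract_slice'd, converted to the dot-operand layout, and then upcast with tt.fp_to_fp before feeding the dot, in the prologue (%[[A0_CVT]]) as well as inside the loop (%[[A_REM_CVT]] and %[[NEXT_A_PREFETCH_CVT]]). The helper below is a rough sketch of that emission step, with an invented name and signature; it is not the prefetch pass's actual code.

#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

// Sketch only: turn one prefetched shared-memory k-slice into a dot operand,
// adding the tt.fp_to_fp upcast when the stored element type is f8E5M2.
static mlir::Value emitPrefetchOperand(mlir::OpBuilder &b, mlir::Location loc,
                                       mlir::Value slice,
                                       mlir::RankedTensorType dotOperandTy) {
  namespace ttg = mlir::triton::gpu;
  auto sliceTy = slice.getType().cast<mlir::RankedTensorType>();
  // Keep the stored element type through the layout conversion ...
  auto cvtTy = mlir::RankedTensorType::get(
      sliceTy.getShape(), sliceTy.getElementType(), dotOperandTy.getEncoding());
  mlir::Value cvt = b.create<ttg::ConvertLayoutOp>(loc, cvtTy, slice);
  if (!sliceTy.getElementType().isFloat8E5M2())
    return cvt;
  // ... and upcast fp8 inputs to the dot's f16 operand type afterwards,
  // which is where the CHECK lines place the tt.fp_to_fp.
  return b.create<mlir::triton::FpToFpOp>(loc, dotOperandTy, cvt);
}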

View File

@@ -5,6 +5,7 @@ add_mlir_library(TritonTestAnalysis
TestMembar.cpp
LINK_LIBS PUBLIC
MLIRPass
TritonAnalysis
${dialect_libs}
)

View File

@@ -6,7 +6,7 @@ using namespace mlir;
namespace {
struct TestAllocationPass
: public PassWrapper<TestAllocationPass, OperationPass<triton::FuncOp>> {
: public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
@@ -16,31 +16,37 @@ struct TestAllocationPass
}
void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
ModuleOp moduleOp = getOperation();
// Convert to std::string can remove quotes from opName
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
Allocation allocation(operation);
operation->walk([&](Operation *op) {
auto scratchBufferId = allocation.getBufferId(op);
if (scratchBufferId != Allocation::InvalidBufferId) {
size_t offset = allocation.getOffset(scratchBufferId);
size_t size = allocation.getAllocatedSize(scratchBufferId);
os << "scratch offset = " << offset << ", size = " << size << "\n";
}
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
auto bufferId = allocation.getBufferId(result);
if (bufferId != Allocation::InvalidBufferId) {
size_t offset = allocation.getOffset(bufferId);
size_t size = allocation.getAllocatedSize(bufferId);
os << "offset = " << offset << ", size = " << size << "\n";
ModuleAllocation moduleAllocation(moduleOp);
moduleOp.walk([&](triton::FuncOp funcOp) {
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << opName << "\n";
auto *allocation = moduleAllocation.getFuncData(funcOp);
funcOp.walk([&](Operation *op) {
auto scratchBufferId = allocation->getBufferId(op);
if (scratchBufferId != Allocation::InvalidBufferId) {
size_t offset = allocation->getOffset(scratchBufferId);
size_t size = allocation->getAllocatedSize(scratchBufferId);
if (allocation->isVirtualBuffer(scratchBufferId))
os << "virtual offset = " << offset << ", size = " << size << "\n";
else
os << "scratch offset = " << offset << ", size = " << size << "\n";
}
}
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
auto bufferId = allocation->getBufferId(result);
if (bufferId != Allocation::InvalidBufferId) {
size_t offset = allocation->getOffset(bufferId);
size_t size = allocation->getAllocatedSize(bufferId);
os << "offset = " << offset << ", size = " << size << "\n";
}
}
});
os << "size = " << allocation->getSharedMemorySize() << "\n";
});
os << "size = " << allocation.getSharedMemorySize() << "\n";
}
};

View File

@@ -7,7 +7,7 @@ using namespace mlir;
namespace {
struct TestAxisInfoPass
: public PassWrapper<TestAxisInfoPass, OperationPass<triton::FuncOp>> {
: public PassWrapper<TestAxisInfoPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
@@ -18,23 +18,24 @@ struct TestAxisInfoPass
void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << "@" << opName << "\n";
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(operation)))
return signalPassFailure();
operation->walk([&](Operation *op) {
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
result.print(os);
os << " => ";
analysis->getLatticeElement(result)->getValue().print(os);
os << "\n";
}
ModuleOp moduleOp = cast<ModuleOp>(operation);
ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp);
moduleOp.walk([&](triton::FuncOp funcOp) {
auto &os = llvm::errs();
auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
os << "@" << opName << "\n";
funcOp.walk([&](Operation *op) {
if (op->getNumResults() < 1)
return;
for (Value result : op->getResults()) {
result.print(os);
os << " => ";
auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result);
if (axisInfo)
axisInfo->print(os);
os << "\n";
}
});
});
}
};

View File

@@ -11,7 +11,7 @@ using namespace mlir;
namespace {
struct TestMembarPass
: public PassWrapper<TestMembarPass, OperationPass<triton::FuncOp>> {
: public PassWrapper<TestMembarPass, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass);
@@ -22,17 +22,11 @@ struct TestMembarPass
void runOnOperation() override {
Operation *operation = getOperation();
auto &os = llvm::errs();
// Convert to std::string can remove quotes from op_name
auto opName = SymbolTable::getSymbolName(operation).getValue().str();
os << opName << "\n";
ModuleOp moduleOp = cast<ModuleOp>(operation);
// Print all ops after membar pass
Allocation allocation(operation);
MembarAnalysis membarPass(&allocation);
ModuleAllocation allocation(moduleOp);
ModuleMembarAnalysis membarPass(&allocation);
membarPass.run();
os << *operation << "\n";
}
};

View File

@@ -1,5 +1,8 @@
add_triton_ut(
NAME TestTritonAnalysis
SRCS UtilityTest.cpp
LIBS TritonAnalysis
LIBS
TritonAnalysis
TritonIR
TritonGPUIR
)

View File

@@ -30,7 +30,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) {
// create encoding
auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1});
auto encoding =
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent);
triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent, 0);
// create element type
Type eltType = IntegerType::get(&ctx, params.typeWidth);