diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index adb883143..81674e7d4 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -25,9 +25,9 @@ jobs: id: set-matrix run: | if [ x"${{ github.repository }}" == x"openai/triton" ]; then - echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"], "macos-10.15"]' + echo '::set-output name=matrix::[["self-hosted", "A100"], ["self-hosted", "V100"], ["self-hosted", "gfx908"]]' else - echo '::set-output name=matrix::["ubuntu-latest", "macos-10.15"]' + echo '::set-output name=matrix::["ubuntu-latest"]' fi Integration-Tests: @@ -101,6 +101,18 @@ jobs: cd python/test/unit python3 -m pytest + - name: Create artifacts archive + if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}} + run: | + tar -czvf artifacts.tar.gz ~/.triton/cache + + - name: Upload artifacts archive + if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100')}} + uses: actions/upload-artifact@v2 + with: + name: artifacts + path: artifacts.tar.gz + - name: Run CXX unittests if: ${{ env.BACKEND != 'ROCM'}} run: | diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index 915439279..18706ebae 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -30,7 +30,7 @@ jobs: #export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" export CIBW_BEFORE_BUILD="pip install cmake;" export CIBW_SKIP="{cp,pp}35-*" - export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64" + export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64 cp3*-musllinux_x86_64" python3 -m cibuildwheel python --output-dir wheelhouse diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 2754cffc4..000000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "deps/dlfcn-win32"] - path = deps/dlfcn-win32 - url = https://github.com/dlfcn-win32/dlfcn-win32.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e50d41ac..8eae0bdbe 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,12 +49,6 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) # Third-party include_directories(${PYBIND11_INCLUDE_DIR}) -if(WIN32) - SET(BUILD_SHARED_LIBS OFF) - find_package(dlfcn-win32 REQUIRED) - set(CMAKE_DL_LIBS dlfcn-win32::dl) -endif() - set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden") if (TRITON_USE_ROCM) diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 58bce15f3..5013a0242 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -34,6 +34,7 @@ Shape Manipulation Ops :nosignatures: broadcast_to + expand_dims reshape ravel diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 89b77034c..f7986a44c 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -1,6 +1,7 @@ #ifndef TRITON_ANALYSIS_ALLOCATION_H #define TRITON_ANALYSIS_ALLOCATION_H +#include "triton/Analysis/Utility.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" @@ -49,18 +50,25 @@ private: T End = std::numeric_limits::max(); }; +template Interval(T, T) -> Interval; + class Allocation { public: /// A unique identifier for shared memory buffers using BufferId = size_t; using 
BufferIdSetT = DenseSet; + using FuncAllocMapT = CallGraph::FuncDataMapT; static constexpr BufferId InvalidBufferId = std::numeric_limits::max(); + Allocation() = default; /// Creates a new Allocation analysis that computes the shared memory /// information for all associated shared memory values. - Allocation(Operation *operation) : operation(operation) { run(); } + explicit Allocation(Operation *operation) : operation(operation) {} + + /// Runs allocation analysis on the given top-level operation. + void run(FuncAllocMapT &funcAllocMap); /// Returns the operation this analysis was constructed from. Operation *getOperation() const { return operation; } @@ -75,6 +83,12 @@ public: return bufferSet.at(bufferId).size; } + /// Returns the allocated interval of the given buffer. + Interval getAllocatedInterval(BufferId bufferId) const { + auto &buffer = bufferSet.at(bufferId); + return Interval(buffer.offset, buffer.offset + buffer.size); + } + /// Returns the buffer id of the given value. /// This interface only returns the allocated buffer id. /// If you want to get all the buffer ids that are associated with the given @@ -104,26 +118,28 @@ public: BufferId getBufferId(Operation *operation) const { if (opScratch.count(operation)) { return opScratch.lookup(operation)->id; + } else if (opVirtual.count(operation)) { + return opVirtual.lookup(operation)->id; } else { return InvalidBufferId; } } + /// Returns true if the given buffer is a virtual buffer. + bool isVirtualBuffer(BufferId bufferId) const { + return bufferSet.at(bufferId).kind == BufferT::BufferKind::Virtual; + } + /// Returns the size of total shared memory allocated size_t getSharedMemorySize() const { return sharedMemorySize; } - bool isIntersected(BufferId lhsId, BufferId rhsId) const { - if (lhsId == InvalidBufferId || rhsId == InvalidBufferId) - return false; - auto lhsBuffer = bufferSet.at(lhsId); - auto rhsBuffer = bufferSet.at(rhsId); - return lhsBuffer.intersects(rhsBuffer); - } - private: /// A class that represents a shared memory buffer struct BufferT { - enum class BufferKind { Explicit, Scratch }; + /// Explicit: triton_gpu.alloc_tensor + /// Scratch: triton_gpu.convert_layout + /// Virtual: triton.call + enum class BufferKind { Explicit, Scratch, Virtual }; /// MT: thread-safe inline static std::atomic nextId = 0; @@ -142,12 +158,6 @@ private: BufferT(BufferKind kind, size_t size) : BufferT(kind, size, 0) {} BufferT(BufferKind kind, size_t size, size_t offset) : kind(kind), id(nextId++), size(size), offset(offset) {} - - bool intersects(const BufferT &other) const { - return Interval(offset, offset + size) - .intersects( - Interval(other.offset, other.offset + other.size)); - } }; /// Op -> Scratch Buffer @@ -158,8 +168,6 @@ private: using AliasBufferMapT = llvm::MapVector>; /// BufferId -> Buffer using BufferSetT = std::map; - /// Runs allocation analysis on the given top-level operation. 
- void run(); private: template @@ -168,6 +176,8 @@ private: bufferSet[buffer.id] = std::move(buffer); if constexpr (Kind == BufferT::BufferKind::Explicit) { valueBuffer[key] = &bufferSet[buffer.id]; + } else if constexpr (Kind == BufferT::BufferKind::Virtual) { + opVirtual[key] = &bufferSet[buffer.id]; } else { opScratch[key] = &bufferSet[buffer.id]; } @@ -178,8 +188,9 @@ private: } private: - Operation *operation; + Operation *operation = nullptr; OpScratchMapT opScratch; + OpScratchMapT opVirtual; ValueBufferMapT valueBuffer; AliasBufferMapT aliasBuffer; BufferSetT bufferSet; @@ -188,7 +199,53 @@ private: friend class triton::AllocationAnalysis; }; -template Interval(T, T) -> Interval; +/// Static analysis that computes the allocation of shared memory buffers +/// of the entire call graph. +/// The allocation is performed in a post-order walk of the call graph. +/// Each call op is treated like convert_layout that allocates a scratch buffer. +/// At each call, we compute the start offset of the scratch buffer and pass it +/// as an argument to the callee. +class ModuleAllocation : public CallGraph { +public: + using FuncOffsetMapT = DenseMap; + + explicit ModuleAllocation(ModuleOp moduleOp) + : CallGraph(moduleOp) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + auto [iter, inserted] = funcMap.try_emplace(funcOp, funcOp); + if (inserted) + iter->second.run(funcMap); + }); + } + + size_t getSharedMemorySize() { + size_t size = 0; + for (auto funcOp : getRoots()) { + auto *alloc = getFuncData(funcOp); + size = std::max(size, alloc->getSharedMemorySize()); + } + return size; + } + + size_t getSharedMemorySize(FunctionOpInterface funcOp) { + return getFuncData(funcOp)->getSharedMemorySize(); + } + + void setFunctionSharedMemoryValue(FunctionOpInterface funcOp, Value value) { + sharedMemoryValue[funcOp] = value; + } + + Value getFunctionSharedMemoryBase(FunctionOpInterface funcOp) { + return sharedMemoryValue[funcOp]; + } + +private: + FuncOffsetMapT sharedMemoryValue; +}; } // namespace mlir diff --git a/include/triton/Analysis/AxisInfo.h b/include/triton/Analysis/AxisInfo.h index 7f51f2d60..af5b04ad2 100644 --- a/include/triton/Analysis/AxisInfo.h +++ b/include/triton/Analysis/AxisInfo.h @@ -286,16 +286,71 @@ public: AxisInfoAnalysis(DataFlowSolver &solver); using dataflow::SparseDataFlowAnalysis< dataflow::Lattice>::getLatticeElement; + using FuncAxisInfoMapT = DenseMap; void visitOperation(Operation *op, ArrayRef *> operands, ArrayRef *> results) override; +}; + +/// Module level axis info analysis based on the call graph, assuming that we +/// do not have recursive functions. +/// Since each function will be called multiple times, we need to +/// calculate the axis info based on the axis info of all the callers. +/// In the future, we can perform optimization using function cloning so that +/// each call site will have unique axis info. 
+using AxisInfoMapT = DenseMap; +class ModuleAxisInfoAnalysis : public CallGraph { +public: + explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp) + : CallGraph(moduleOp) { + SmallVector funcs; + for (auto root : getRoots()) { + walk( + // Pre-order edge walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order node walk callback + [&](FunctionOpInterface funcOp) { + funcs.push_back(funcOp); + funcMap.try_emplace(funcOp, AxisInfoMapT{}); + }); + } + SetVector sortedFuncs(funcs.begin(), funcs.end()); + SymbolTableCollection symbolTable; + for (auto funcOp : llvm::reverse(sortedFuncs)) { + initialize(funcOp); + funcOp.walk([&](CallOpInterface callOp) { + auto callee = + dyn_cast(callOp.resolveCallable(&symbolTable)); + update(callOp, callee); + }); + } + } + + AxisInfo *getAxisInfo(Value value) { + auto funcOp = + value.getParentRegion()->getParentOfType(); + auto *axisInfoMap = getFuncData(funcOp); + if (!axisInfoMap) { + return nullptr; + } + auto it = axisInfoMap->find(value); + if (it == axisInfoMap->end()) { + return nullptr; + } + return &(it->second); + } unsigned getPtrContiguity(Value ptr); unsigned getPtrAlignment(Value ptr); unsigned getMaskAlignment(Value mask); + +private: + void initialize(FunctionOpInterface funcOp); + + void update(CallOpInterface callOp, FunctionOpInterface funcOp); }; } // namespace mlir diff --git a/include/triton/Analysis/Membar.h b/include/triton/Analysis/Membar.h index ccec8e7ab..cce981b3f 100644 --- a/include/triton/Analysis/Membar.h +++ b/include/triton/Analysis/Membar.h @@ -4,20 +4,75 @@ #include "Allocation.h" #include "llvm/ADT/SmallPtrSet.h" +#include + namespace mlir { class OpBuilder; +struct BlockInfo { + using BufferIdSetT = Allocation::BufferIdSetT; + using IntervalSetT = std::set>; + + IntervalSetT syncReadIntervals; + IntervalSetT syncWriteIntervals; + + BlockInfo() = default; + + /// Unions two BlockInfo objects. + BlockInfo &join(const BlockInfo &other) { + syncReadIntervals.insert(other.syncReadIntervals.begin(), + other.syncReadIntervals.end()); + syncWriteIntervals.insert(other.syncWriteIntervals.begin(), + other.syncWriteIntervals.end()); + return *this; + } + + /// Returns true if intervals in two BlockInfo objects are intersected. + bool isIntersected(const BlockInfo &other) const { + return /*RAW*/ isIntersected(syncWriteIntervals, other.syncReadIntervals) || + /*WAR*/ + isIntersected(syncReadIntervals, other.syncWriteIntervals) || + /*WAW*/ + isIntersected(syncWriteIntervals, other.syncWriteIntervals); + } + + /// Clears the intervals because a barrier is inserted. + void sync() { + syncReadIntervals.clear(); + syncWriteIntervals.clear(); + } + + /// Compares two BlockInfo objects. 
+ bool operator==(const BlockInfo &other) const { + return syncReadIntervals == other.syncReadIntervals && + syncWriteIntervals == other.syncWriteIntervals; + } + + bool operator!=(const BlockInfo &other) const { return !(*this == other); } + +private: + bool isIntersected(const IntervalSetT &lhsIntervalSet, + const IntervalSetT &rhsIntervalSet) const { + for (auto &lhs : lhsIntervalSet) + for (auto &rhs : rhsIntervalSet) + if (lhs.intersects(rhs)) + return true; + return false; + } +}; + //===----------------------------------------------------------------------===// // Shared Memory Barrier Analysis //===----------------------------------------------------------------------===// class MembarAnalysis { public: + using FuncBlockInfoMapT = CallGraph::FuncDataMapT; /// Creates a new Membar analysis that generates the shared memory barrier /// in the following circumstances: /// - RAW: If a shared memory write is followed by a shared memory read, and /// their addresses are intersected, a barrier is inserted. - /// - WAR: If a shared memory read is followed by a shared memory read, and + /// - WAR: If a shared memory read is followed by a shared memory write, and /// their addresses are intersected, a barrier is inserted. /// The following circumstances do not require a barrier: /// - WAW: not possible because overlapped memory allocation is not allowed. @@ -26,75 +81,14 @@ public: /// a shared memory read. If the temporary storage is written but not read, /// it is considered as the problem of the operation itself but not the membar /// analysis. - /// The following circumstances are not considered yet: - /// - Double buffers - /// - N buffers - MembarAnalysis(Allocation *allocation) : allocation(allocation) {} + MembarAnalysis() = default; + explicit MembarAnalysis(Allocation *allocation) : allocation(allocation) {} /// Runs the membar analysis to the given operation, inserts a barrier if /// necessary. - void run(); + void run(FuncBlockInfoMapT &funcBlockInfoMap); private: - struct BlockInfo { - using BufferIdSetT = Allocation::BufferIdSetT; - - BufferIdSetT syncReadBuffers; - BufferIdSetT syncWriteBuffers; - - BlockInfo() = default; - BlockInfo(const BufferIdSetT &syncReadBuffers, - const BufferIdSetT &syncWriteBuffers) - : syncReadBuffers(syncReadBuffers), syncWriteBuffers(syncWriteBuffers) { - } - - /// Unions two BlockInfo objects. - BlockInfo &join(const BlockInfo &other) { - syncReadBuffers.insert(other.syncReadBuffers.begin(), - other.syncReadBuffers.end()); - syncWriteBuffers.insert(other.syncWriteBuffers.begin(), - other.syncWriteBuffers.end()); - return *this; - } - - /// Returns true if buffers in two BlockInfo objects are intersected. - bool isIntersected(const BlockInfo &other, Allocation *allocation) const { - return /*RAW*/ isIntersected(syncWriteBuffers, other.syncReadBuffers, - allocation) || - /*WAR*/ - isIntersected(syncReadBuffers, other.syncWriteBuffers, - allocation) || - /*WAW*/ - isIntersected(syncWriteBuffers, other.syncWriteBuffers, - allocation); - } - - /// Clears the buffers because a barrier is inserted. - void sync() { - syncReadBuffers.clear(); - syncWriteBuffers.clear(); - } - - /// Compares two BlockInfo objects. - bool operator==(const BlockInfo &other) const { - return syncReadBuffers == other.syncReadBuffers && - syncWriteBuffers == other.syncWriteBuffers; - } - - bool operator!=(const BlockInfo &other) const { return !(*this == other); } - - private: - /// Returns true if buffers in two sets are intersected. 
- bool isIntersected(const BufferIdSetT &lhs, const BufferIdSetT &rhs, - Allocation *allocation) const { - return std::any_of(lhs.begin(), lhs.end(), [&](auto lhsId) { - return std::any_of(rhs.begin(), rhs.end(), [&](auto rhsId) { - return allocation->isIntersected(lhsId, rhsId); - }); - }); - } - }; - /// Applies the barrier analysis based on the SCF dialect, in which each /// region has a single basic block only. /// Example: @@ -109,18 +103,48 @@ private: /// op6 /// op7 /// TODO: Explain why we don't use ForwardAnalysis: - void resolve(Operation *operation, OpBuilder *builder); + void resolve(FunctionOpInterface funcOp, FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder); /// Updates the BlockInfo operation based on the operation. - void update(Operation *operation, BlockInfo *blockInfo, OpBuilder *builder); + void update(Operation *operation, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder); /// Collects the successors of the terminator void visitTerminator(Operation *operation, SmallVector &successors); private: - Allocation *allocation; - DenseMap inputBlockInfoMap; - DenseMap outputBlockInfoMap; + Allocation *allocation = nullptr; +}; + +/// Postorder traversal on the callgraph to insert membar instructions +/// of each function. +/// Each function maintains a BlockInfo map that includes all potential buffers +/// after returning. This way users do not have to explicitly insert membars +/// before and after function calls, but might be a bit conservative. +class ModuleMembarAnalysis : public CallGraph { +public: + ModuleMembarAnalysis(ModuleAllocation *moduleAllocation) + : CallGraph(moduleAllocation->getModuleOp()), + moduleAllocation(moduleAllocation) {} + + void run() { + walk( + // Pre-order walk callback + [](CallOpInterface callOp, FunctionOpInterface funcOp) {}, + // Post-order walk callback + [&](FunctionOpInterface funcOp) { + auto *allocation = moduleAllocation->getFuncData(funcOp); + auto [it, inserted] = funcMap.try_emplace(funcOp, BlockInfo()); + if (inserted) { + MembarAnalysis analysis(allocation); + analysis.run(funcMap); + } + }); + } + +private: + ModuleAllocation *moduleAllocation; }; } // namespace mlir diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index e7873c646..1c9fb5019 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -39,6 +39,10 @@ public: unsigned getIntraWarpSize(); + unsigned getInterWarpSizeWithUniqueData(); + + unsigned getIntraWarpSizeWithUniqueData(); + unsigned getThreadsReductionAxis(); SmallVector getScratchConfigBasic(); @@ -57,8 +61,6 @@ private: int axis; }; -bool isSharedEncoding(Value value); - bool maybeSharedAllocationOp(Operation *op); bool maybeAliasOp(Operation *op); @@ -116,14 +118,153 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); SetVector multiRootTopologicalSort(const SetVector &toSort); -// This uses the toplogicalSort above +/// This uses the topologicalSort above SetVector multiRootGetSlice(Operation *op, TransitiveFilter backwardFilter = nullptr, TransitiveFilter forwardFilter = nullptr); -// Create a basic DataFlowSolver with constant and dead code analysis included. +/// Create a basic DataFlowSolver with constant and dead code analysis included. std::unique_ptr createDataFlowSolver(); +/// This class represents a call graph for a given ModuleOp and holds +data of type T associated with each FunctionOpInterface. 
+template class CallGraph { +public: + using FuncDataMapT = DenseMap; + + /// Constructor that builds the call graph for the given moduleOp. + explicit CallGraph(ModuleOp moduleOp) : moduleOp(moduleOp) { build(); } + + /// Walks the call graph and applies the provided update functions + /// to the edges and nodes. + template + void walk(UpdateEdgeFn updateEdgeFn, UpdateNodeFn updateNodeFn) { + DenseSet visited; + for (auto root : roots) { + doWalk(root, visited, updateEdgeFn, + updateNodeFn); + } + } + + /// Retrieves the data associated with a function + T *getFuncData(FunctionOpInterface funcOp) { + if (funcMap.count(funcOp)) { + return &funcMap[funcOp]; + } + return nullptr; + } + + /// Getters + ModuleOp getModuleOp() const { return moduleOp; } + SmallVector getRoots() const { return roots; } + size_t getNumFunctions() const { return funcMap.size(); } + + /// Returns true if the given function is a root. + bool isRoot(FunctionOpInterface funcOp) const { + return llvm::is_contained(roots, funcOp); + } + + /// Maps the data and the graph nodes associated with a funcOp to a + /// targetFuncOp. + template + void mapFuncOp(FROM funcOp, TO targetFuncOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.second == funcOp) { + edge.second = targetFuncOp; + } + } + } + graph[targetFuncOp] = graph[funcOp]; + // Replace in roots + for (auto it = roots.begin(); it != roots.end(); ++it) { + if (*it == funcOp) { + *it = targetFuncOp; + break; + } + } + // Replace in funcMap + funcMap[targetFuncOp] = funcMap[funcOp]; + } + + /// Maps the graph edges associated with a callOp to a targetCallOp. + template + void mapCallOp(FROM callOp, TO targetCallOp) { + // Iterate over graph and replace + for (auto &kv : graph) { + for (auto &edge : kv.second) { + if (edge.first == callOp) { + edge.first = targetCallOp; + } + } + } + } + +private: + void build() { + SymbolTableCollection symbolTable; + DenseSet visited; + // Build graph + moduleOp.walk([&](Operation *op) { + auto caller = op->getParentOfType(); + if (auto callOp = dyn_cast(op)) { + auto *callee = callOp.resolveCallable(&symbolTable); + auto funcOp = dyn_cast_or_null(callee); + if (funcOp) { + graph[caller].emplace_back( + std::pair(callOp, funcOp)); + visited.insert(funcOp); + } + } + }); + // Find roots + moduleOp.walk([&](FunctionOpInterface funcOp) { + if (!visited.count(funcOp)) { + roots.push_back(funcOp); + } + }); + } + + template + void doWalk(FunctionOpInterface funcOp, + DenseSet &visited, UpdateEdgeFn updateEdgeFn, + UpdateNodeFn updateNodeFn) { + if (visited.count(funcOp)) { + llvm::report_fatal_error("Cycle detected in call graph"); + } + if constexpr (UpdateNodeOrder == WalkOrder::PreOrder) { + updateNodeFn(funcOp); + } + for (auto [callOp, callee] : graph[funcOp]) { + if constexpr (UpdateEdgeOrder == WalkOrder::PreOrder) { + updateEdgeFn(callOp, callee); + } + doWalk(callee, visited, updateEdgeFn, + updateNodeFn); + if constexpr (UpdateEdgeOrder == WalkOrder::PostOrder) { + updateEdgeFn(callOp, callee); + } + } + if constexpr (UpdateNodeOrder == WalkOrder::PostOrder) { + updateNodeFn(funcOp); + } + visited.erase(funcOp); + } + +protected: + ModuleOp moduleOp; + DenseMap>> + graph; + FuncDataMapT funcMap; + SmallVector roots; +}; + } // namespace mlir #endif // TRITON_ANALYSIS_UTILITY_H diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 91d1bd0b8..f6e92b1e9 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ 
b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -155,7 +155,7 @@ def TT_LoadOp : TT_Op<"load", "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, // A tensor pointer with boundary check and padding OpBuilder<(ins "Value":$ptr, "ArrayRef":$boundaryCheck, - "Optional":$padding, "triton::CacheModifier":$cache, + "std::optional":$padding, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, // A tensor of pointers or a pointer to a scalar with mask OpBuilder<(ins "Value":$ptr, "Value":$mask, "triton::CacheModifier":$cache, @@ -164,8 +164,9 @@ def TT_LoadOp : TT_Op<"load", OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)>, // A utility function to build the operation with all attributes - OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, "Optional>":$boundaryCheck, - "Optional":$padding, "triton::CacheModifier":$cache, + OpBuilder<(ins "Value":$ptr, "Value":$mask, "Value":$other, + "std::optional>":$boundaryCheck, + "std::optional":$padding, "triton::CacheModifier":$cache, "triton::EvictionPolicy":$evict, "bool":$isVolatile)> ]; @@ -600,6 +601,11 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI CallInterfaceCallable getCallableForCallee() { return (*this)->getAttrOfType("callee"); } + + /// Set the callee for this operation. + void setCalleeFromCallable(CallInterfaceCallable callee) { + (*this)->setAttr("callee", callee.get()); + } }]; let assemblyFormat = [{ diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 6eecbdbbc..ba6c7e4c6 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -31,8 +31,37 @@ SmallVector getWarpsPerCTA(Attribute layout); SmallVector getSizePerThread(Attribute layout); +// Returns the number of contiguous elements that each thread +// has access to, on each dimension of the tensor. E.g. +// for a blocked layout with sizePerThread = [1, 4], returns [1, 4], +// regardless of the shape of the tensor. SmallVector getContigPerThread(Attribute layout); +// Returns the number of non-replicated contiguous elements that each thread +// has access to, on each dimension of the tensor. For a blocked layout +// with sizePerThread = [1, 4] and tensor shape = [128, 1], the elements +// for thread 0 would be [A_{0, 0}, A_{0, 0}, A_{0, 0}, A_{0, 0}], returns [1, +// 1]. Whereas for a tensor shape [128, 128], the elements for thread 0 would be +// [A_{0, 0}, A_{0, 1}, A_{0, 2}, A_{0, 3}], returns [1, 4]. +SmallVector getUniqueContigPerThread(Type type); + +// Returns the number of threads per warp that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16] and tensor shape = [2, 2], threads 0, 1, 16, 17 +// have access to the full tensor, whereas the other threads have access to +// replicated elements, so this function returns [2, 2]. +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape); + +// Returns the number of warps per CTA that have access to non-replicated +// elements of the tensor. E.g. for a blocked layout with sizePerThread = [1, +// 1], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4] and tensor shape = [2, 2], +// returns [1, 1], since the first warp has access to the full tensor, whereas +// the other warps have access to replicated elements. 
+SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape); + SmallVector getThreadsPerCTA(Attribute layout); SmallVector @@ -45,6 +74,9 @@ bool isaDistributedLayout(Attribute layout); } // namespace gpu } // namespace triton + +bool isSharedEncoding(Value value); + } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index fcf96044c..e07d3a8be 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -509,10 +509,22 @@ section 9.7.13.4.1 for more details. let parameters = ( ins "unsigned":$opIdx, - "Attribute":$parent + "Attribute":$parent, + "unsigned":$MMAv2kWidth ); let builders = [ + // Specially for MMAV1(Volta) + AttrBuilder<(ins "unsigned":$opIdx, + "Attribute":$parent, + "Type":$eltTy), [{ + MmaEncodingAttr parentAttr = parent.dyn_cast(); + if (!parentAttr || !parentAttr.isAmpere()) + return $_get(context, opIdx, parent, 0); + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + unsigned MMAv2kWidth = 32 / bitwidth; + return $_get(context, opIdx, parent, MMAv2kWidth); + }]> ]; let hasCustomAssemblyFormat = 1; diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index d71838089..b8f53eb95 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -57,7 +57,6 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul "int32_t", /*default*/"80", "device compute capability"> ]; - } def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> { diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 996cd172d..6eee987b0 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -4,7 +4,6 @@ #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "triton/Analysis/Alias.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "llvm/ADT/SmallVector.h" @@ -117,8 +116,11 @@ SmallVector getScratchConfigForAtomicCAS(triton::AtomicCASOp op) { class AllocationAnalysis { public: - AllocationAnalysis(Operation *operation, Allocation *allocation) - : operation(operation), allocation(allocation) { + AllocationAnalysis(Operation *operation, + Allocation::FuncAllocMapT *funcAllocMap, + Allocation *allocation) + : operation(operation), funcAllocMap(funcAllocMap), + allocation(allocation) { run(); } @@ -219,6 +221,12 @@ private: ? elems * kPtrBitWidth / 8 : elems * elemTy.getIntOrFloatBitWidth() / 8; allocation->addBuffer(op, bytes); + } else if (auto callOp = dyn_cast(op)) { + auto callable = callOp.resolveCallable(); + auto funcOp = dyn_cast(callable); + auto *funcAlloc = &(*funcAllocMap)[funcOp]; + auto bytes = funcAlloc->getSharedMemorySize(); + allocation->addBuffer(op, bytes); } } @@ -298,15 +306,19 @@ private: /// allocated, but is used to store intermediate results. void resolveScratchBufferLiveness( const DenseMap &operationId) { - // Analyze liveness of scratch buffers - for (auto opScratchIter : allocation->opScratch) { - // Any scratch memory's live range is the current operation's live - // range. 
- auto *op = opScratchIter.first; - auto *buffer = opScratchIter.second; - bufferRange.insert({buffer, Interval(operationId.lookup(op), - operationId.lookup(op) + 1)}); - } + // Analyze liveness of scratch buffers and virtual buffers. + auto processScratchMemory = [&](const auto &container) { + for (auto opScratchIter : container) { + // Any scratch memory's live range is the current operation's live + // range. + auto *op = opScratchIter.first; + auto *buffer = opScratchIter.second; + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } + }; + processScratchMemory(allocation->opScratch); + processScratchMemory(allocation->opVirtual); } /// Resolves liveness of all values involved under the root operation. @@ -499,11 +511,15 @@ private: private: Operation *operation; + Allocation::FuncAllocMapT *funcAllocMap; Allocation *allocation; BufferRangeMapT bufferRange; }; + } // namespace triton -void Allocation::run() { triton::AllocationAnalysis(getOperation(), this); } +void Allocation::run(FuncAllocMapT &funcAllocMap) { + triton::AllocationAnalysis(getOperation(), &funcAllocMap, this); +} } // namespace mlir diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index b2884c072..8fa4a0cba 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -77,7 +77,7 @@ AxisInfo AxisInfo::getPessimisticValueState(Value value) { if (blockArg && blockArg.getOwner()->isEntryBlock()) { Operation *op = blockArg.getOwner()->getParentOp(); - if (auto fun = dyn_cast(op)) + if (auto fun = dyn_cast(op)) initPessimisticStateFromFunc(blockArg.getArgNumber(), fun, &knownContiguity, &knownDivisibility, &knownConstancy); @@ -696,13 +696,13 @@ private: const AxisInfo &rhs) override { if (lhs.getConstantValue().has_value() && rhs.getConstantValue().has_value()) { - if constexpr (std::is_same::value) { + if constexpr (std::is_same_v) { return {lhs.getConstantValue().value() & rhs.getConstantValue().value()}; - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same_v) { return {lhs.getConstantValue().value() | rhs.getConstantValue().value()}; - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same_v) { return {lhs.getConstantValue().value() ^ rhs.getConstantValue().value()}; } @@ -907,37 +907,37 @@ void AxisInfoAnalysis::visitOperation( propagateIfChanged(result, result->join(curr)); } -unsigned AxisInfoAnalysis::getPtrContiguity(Value ptr) { +unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) { auto tensorTy = ptr.getType().dyn_cast(); if (!tensorTy) return 1; auto layout = tensorTy.getEncoding(); - auto shape = tensorTy.getShape(); // Here order should be ordered by contiguous first, so the first element // should have the largest contiguous. 
auto order = triton::gpu::getOrder(layout); unsigned align = getPtrAlignment(ptr); - unsigned contigPerThread = triton::gpu::getSizePerThread(layout)[order[0]]; - contigPerThread = std::min(align, contigPerThread); - contigPerThread = std::min(shape[order[0]], contigPerThread); + auto uniqueContigPerThread = triton::gpu::getUniqueContigPerThread(tensorTy); + assert(order[0] < uniqueContigPerThread.size() && + "Unexpected uniqueContigPerThread size"); + unsigned contiguity = uniqueContigPerThread[order[0]]; + contiguity = std::min(align, contiguity); - return contigPerThread; + return contiguity; } -unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) { +unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) { auto tensorTy = ptr.getType().dyn_cast(); if (!tensorTy) return 1; - dataflow::Lattice *latticeElement = getLatticeElement(ptr); - if (!latticeElement) + auto *axisInfo = getAxisInfo(ptr); + if (!axisInfo) return 1; - auto axisInfo = latticeElement->getValue(); auto layout = tensorTy.getEncoding(); auto order = triton::gpu::getOrder(layout); - auto maxMultipleBytes = axisInfo.getDivisibility(order[0]); - auto maxContig = axisInfo.getContiguity(order[0]); + auto maxMultipleBytes = axisInfo->getDivisibility(order[0]); + auto maxContig = axisInfo->getContiguity(order[0]); auto elemNumBits = triton::getPointeeBitWidth(tensorTy); auto elemNumBytes = std::max(elemNumBits / 8, 1); auto maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1); @@ -945,17 +945,69 @@ unsigned AxisInfoAnalysis::getPtrAlignment(Value ptr) { return alignment; } -unsigned AxisInfoAnalysis::getMaskAlignment(Value mask) { +unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) { auto tensorTy = mask.getType().dyn_cast(); if (!tensorTy) return 1; - dataflow::Lattice *latticeElement = getLatticeElement(mask); - if (!latticeElement) + auto *axisInfo = getAxisInfo(mask); + if (!axisInfo) return 1; - auto maskAxis = latticeElement->getValue(); auto maskOrder = triton::gpu::getOrder(tensorTy.getEncoding()); - auto alignment = std::max(maskAxis.getConstancy(maskOrder[0]), 1); + auto alignment = std::max(axisInfo->getConstancy(maskOrder[0]), 1); return alignment; } +void ModuleAxisInfoAnalysis::initialize(FunctionOpInterface funcOp) { + std::unique_ptr solver = createDataFlowSolver(); + AxisInfoAnalysis *analysis = solver->load(); + if (failed(solver->initializeAndRun(funcOp))) + return; + auto *axisInfoMap = getFuncData(funcOp); + auto updateAxisInfoMap = [&](Value value) { + auto axisInfo = analysis->getLatticeElement(value)->getValue(); + AxisInfo curAxisInfo; + if (axisInfoMap->count(value)) { + curAxisInfo = AxisInfo::join(axisInfo, axisInfoMap->lookup(value)); + } else { + curAxisInfo = axisInfo; + } + (*axisInfoMap)[value] = curAxisInfo; + }; + funcOp.walk([&](Operation *op) { + for (auto value : op->getResults()) { + updateAxisInfoMap(value); + } + }); + funcOp.walk([&](Block *block) { + for (auto value : block->getArguments()) { + updateAxisInfoMap(value); + } + }); +} + +void ModuleAxisInfoAnalysis::update(CallOpInterface callOp, + FunctionOpInterface callee) { + auto caller = callOp->getParentOfType(); + auto *axisInfoMap = getFuncData(caller); + for (auto entry : llvm::enumerate(callOp->getOperands())) { + auto index = entry.index(); + auto value = entry.value(); + auto setAttrFn = [&](StringRef attrName, int64_t prevValue) { + auto curValue = highestPowOf2Divisor(0); + if (callee.getArgAttrOfType(index, attrName)) { + curValue = + callee.getArgAttrOfType(index, attrName).getInt(); + } + auto attr = 
IntegerAttr::get(IntegerType::get(callee.getContext(), 64), + gcd(prevValue, curValue)); + callee.setArgAttr(index, attrName, attr); + }; + auto axisInfo = axisInfoMap->lookup(value); + assert(axisInfo.getRank() == 1 && "only scalar arguments are supported"); + setAttrFn("tt.contiguity", axisInfo.getContiguity(0)); + setAttrFn("tt.divisibility", axisInfo.getDivisibility(0)); + setAttrFn("tt.constancy", axisInfo.getConstancy(0)); + } +} + } // namespace mlir diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 1f761f845..bb9a21ccf 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -11,4 +11,7 @@ add_mlir_library(TritonAnalysis LINK_LIBS PUBLIC MLIRAnalysis + MLIRLLVMDialect + TritonIR + TritonGPUIR ) diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index 3d045d2de..36c54a731 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -9,16 +9,21 @@ namespace mlir { -void MembarAnalysis::run() { - auto *operation = allocation->getOperation(); - OpBuilder builder(operation); - resolve(operation, &builder); +void MembarAnalysis::run(FuncBlockInfoMapT &funcBlockInfoMap) { + FunctionOpInterface funcOp = + dyn_cast(allocation->getOperation()); + OpBuilder builder(funcOp.getContext()); + resolve(funcOp, &funcBlockInfoMap, &builder); } -void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) { +void MembarAnalysis::resolve(FunctionOpInterface funcOp, + FuncBlockInfoMapT *funcBlockInfoMap, + OpBuilder *builder) { // Initialize the blockList + DenseMap inputBlockInfoMap; + DenseMap outputBlockInfoMap; std::deque blockList; - operation->walk([&](Block *block) { + funcOp.walk([&](Block *block) { for (auto &op : block->getOperations()) { // Check if the operation belongs to scf dialect, if so, we need to // throw an error @@ -38,13 +43,13 @@ void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) { auto *block = blockList.front(); blockList.pop_front(); // Make a copy of the inputblockInfo but not update - auto inputBlockInfo = inputBlockInfoMap.lookup(block); + auto inputBlockInfo = inputBlockInfoMap[block]; SmallVector successors; for (auto &op : block->getOperations()) { if (op.hasTrait()) { visitTerminator(&op, successors); } else { - update(&op, &inputBlockInfo, builder); + update(&op, &inputBlockInfo, funcBlockInfoMap, builder); } } // Get the reference because we want to update if it changed @@ -62,15 +67,22 @@ void MembarAnalysis::resolve(Operation *operation, OpBuilder *builder) { blockList.emplace_back(successor); } } + + // Update the final dangling buffers that haven't been synced + auto &funcBlockInfo = (*funcBlockInfoMap)[funcOp]; + funcOp.walk([&](Block *block) { + block->walk([&](triton::ReturnOp returnOp) { + funcBlockInfo.join(outputBlockInfoMap[block]); + }); + }); } void MembarAnalysis::visitTerminator(Operation *op, SmallVector &successors) { if (auto branchInterface = dyn_cast(op)) { Block *parentBlock = branchInterface->getBlock(); - for (Block *successor : parentBlock->getSuccessors()) { - successors.push_back(successor); - } + successors.append(std::begin(parentBlock->getSuccessors()), + std::end(parentBlock->getSuccessors())); return; } // Otherwise, it could be a return op @@ -81,6 +93,7 @@ void MembarAnalysis::visitTerminator(Operation *op, } void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, + FuncBlockInfoMapT *funcBlockInfoMap, OpBuilder *builder) { if (isa(op) || isa(op) || isa(op)) { @@ -108,36 +121,51 @@ void MembarAnalysis::update(Operation *op, 
BlockInfo *blockInfo, } BlockInfo curBlockInfo; - for (Value value : op->getOperands()) { - for (auto bufferId : allocation->getBufferIds(value)) { - if (bufferId != Allocation::InvalidBufferId) { - if (isa(op) || - isa(op)) { - // FIXME(Keren): insert_slice and insert_slice_async are always - // alias for now - curBlockInfo.syncWriteBuffers.insert(bufferId); - } else { - // ConvertLayoutOp: shared memory -> registers - curBlockInfo.syncReadBuffers.insert(bufferId); + if (isa(op)) { + // Inter-function dependencies + auto callOpInterface = dyn_cast(op); + if (auto callee = + dyn_cast(callOpInterface.resolveCallable())) { + curBlockInfo = funcBlockInfoMap->lookup(callee); + } + } else { + // Intra-function dependencies + for (Value value : op->getOperands()) { + for (auto bufferId : allocation->getBufferIds(value)) { + if (bufferId != Allocation::InvalidBufferId) { + if (isa(op) || + isa(op)) { + // FIXME(Keren): insert_slice and insert_slice_async are always + // alias for now + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } else { + // ConvertLayoutOp: shared memory -> registers + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } } } } - } - for (Value value : op->getResults()) { - // ConvertLayoutOp: registers -> shared memory - auto bufferId = allocation->getBufferId(value); + for (Value value : op->getResults()) { + // ConvertLayoutOp: registers -> shared memory + auto bufferId = allocation->getBufferId(value); + if (bufferId != Allocation::InvalidBufferId) { + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + } + } + // Scratch buffer is considered as both shared memory write & read + auto bufferId = allocation->getBufferId(op); if (bufferId != Allocation::InvalidBufferId) { - curBlockInfo.syncWriteBuffers.insert(bufferId); + curBlockInfo.syncWriteIntervals.insert( + allocation->getAllocatedInterval(bufferId)); + curBlockInfo.syncReadIntervals.insert( + allocation->getAllocatedInterval(bufferId)); } } - // Scratch buffer is considered as both shared memory write & read - auto bufferId = allocation->getBufferId(op); - if (bufferId != Allocation::InvalidBufferId) { - curBlockInfo.syncWriteBuffers.insert(bufferId); - curBlockInfo.syncReadBuffers.insert(bufferId); - } - if (blockInfo->isIntersected(curBlockInfo, allocation)) { + if (blockInfo->isIntersected(curBlockInfo)) { OpBuilder::InsertionGuard g(*builder); builder->setInsertionPoint(op); builder->create(op->getLoc()); diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index ce2ca57c1..51cb8649e 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -26,10 +26,27 @@ unsigned ReduceOpHelper::getIntraWarpSize() { triton::gpu::getThreadsPerWarp(getSrcLayout())[axis]); } +unsigned ReduceOpHelper::getInterWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + unsigned sizeIntraWarps = getIntraWarpSizeWithUniqueData(); + return std::min(srcReduceDimSize / sizeIntraWarps, + triton::gpu::getWarpsPerCTAWithUniqueData( + getSrcLayout(), getSrcShape())[axis]); +} + +unsigned ReduceOpHelper::getIntraWarpSizeWithUniqueData() { + auto srcReduceDimSize = static_cast(srcShape[axis]); + return std::min(srcReduceDimSize, + triton::gpu::getThreadsPerWarpWithUniqueData( + getSrcLayout(), getSrcShape())[axis]); +} + unsigned ReduceOpHelper::getThreadsReductionAxis() { auto srcLayout = getSrcLayout(); - return triton::gpu::getThreadsPerWarp(srcLayout)[axis] * - 
triton::gpu::getWarpsPerCTA(srcLayout)[axis]; + auto srcShape = getSrcShape(); + return triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, + srcShape)[axis] * + triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; } SmallVector ReduceOpHelper::getScratchConfigBasic() { @@ -91,14 +108,8 @@ bool ReduceOpHelper::isSupportedLayout() { return true; } } - return false; -} - -bool isSharedEncoding(Value value) { - auto type = value.getType(); - if (auto tensorType = type.dyn_cast()) { - auto encoding = tensorType.getEncoding(); - return encoding && encoding.isa(); + if (auto sliceLayout = srcLayout.dyn_cast()) { + return true; } return false; } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 35890d279..95cbd6c5e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -106,13 +106,22 @@ private: if (auto sliceLayout = layout.dyn_cast()) { unsigned dim = sliceLayout.getDim(); auto parentEncoding = sliceLayout.getParent(); + auto parentSizePerThread = getSizePerThread(parentEncoding); auto parentShape = sliceLayout.paddedShape(shape); auto parentTy = RankedTensorType::get(parentShape, type.getElementType(), parentEncoding); - auto multiDimOffsetParent = - getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy, - sliceLayout.paddedShape(multiDimCTAInRepId), - sliceLayout.paddedShape(shapePerCTA)); + auto offsets = emitOffsetForLayout(layout, type); + auto parentOffset = emitOffsetForLayout(parentEncoding, parentTy); + SmallVector idxs; + for (SmallVector off : offsets) { + off.insert(off.begin() + dim, 0); + auto it = std::find(parentOffset.begin(), parentOffset.end(), off); + idxs.push_back(std::distance(parentOffset.begin(), it)); + } + auto multiDimOffsetParent = getMultiDimOffset( + parentEncoding, loc, rewriter, idxs[elemId], parentTy, + sliceLayout.paddedShape(multiDimCTAInRepId), + sliceLayout.paddedShape(shapePerCTA)); SmallVector multiDimOffset(rank); for (unsigned d = 0; d < rank + 1; ++d) { if (d == dim) @@ -329,7 +338,8 @@ private: if (needTrans) { // do transpose - auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma); + auto aEncoding = + DotOperandEncodingAttr::get(mma.getContext(), 0, mma, 0); int numM = aEncoding.getMMAv1NumOuter(shape); int numN = accumSizePerThread / numM; @@ -619,10 +629,10 @@ private: // is implemented SmallVector reorderedVals; for (unsigned i = 0; i < vecVals.size(); i += 4) { - reorderedVals.push_back(vecVals[i]); - reorderedVals.push_back(vecVals[i + 2]); - reorderedVals.push_back(vecVals[i + 1]); - reorderedVals.push_back(vecVals[i + 3]); + reorderedVals.push_back(bitcast(vecVals[i], i32_ty)); + reorderedVals.push_back(bitcast(vecVals[i + 2], i32_ty)); + reorderedVals.push_back(bitcast(vecVals[i + 1], i32_ty)); + reorderedVals.push_back(bitcast(vecVals[i + 3], i32_ty)); } Value view = getTypeConverter()->packLLElements(loc, reorderedVals, @@ -641,19 +651,19 @@ private: auto loc = op.getLoc(); Value src = op.getSrc(); Value dst = op.getResult(); - bool isHMMA = supportMMA(dst, mmaLayout.getVersionMajor()); auto smemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(), rewriter); Value res; - if (!isOuter && mmaLayout.isAmpere() && isHMMA) { // tensor core v2 + if (!isOuter && mmaLayout.isAmpere()) { // tensor core v2 res = SharedToDotOperandMMAv2::convertLayout( dotOperandLayout.getOpIdx(), rewriter, loc, src, dotOperandLayout, smemObj, 
getTypeConverter(), tid_val()); - } else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1 + } else if (!isOuter && mmaLayout.isVolta() && + supportMMA(dst, mmaLayout.getVersionMajor())) { // tensor core v1 bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow(); auto srcSharedLayout = src.getType() .cast() @@ -675,14 +685,13 @@ private: } return res; } -}; // namespace triton::gpu::ConvertLayoutOp> +}; // namespace triton::gpu::ConvertLayoutOp void populateConvertLayoutOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { - patterns.add(typeConverter, allocation, smem, + patterns.add(typeConverter, allocation, indexCacheInfo, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h index 0b6efaeff..c8f3396b9 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.h @@ -10,8 +10,7 @@ using ::mlir::triton::gpu::DotOperandEncodingAttr; void populateConvertLayoutOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp index ada552ab6..5a568379e 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv1.cpp @@ -358,11 +358,11 @@ SmallVector getMNCoords(Value thread, Value _fpw1 = i32_val(fpw[1]); // A info - auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout); + auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0); auto aRep = aEncoding.getMMAv1Rep(); auto aSpw = aEncoding.getMMAv1ShapePerWarp(); // B info - auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout); + auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0); auto bSpw = bEncoding.getMMAv1ShapePerWarp(); auto bRep = bEncoding.getMMAv1Rep(); diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp index 2bffe4e32..21ed46ba4 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMMAv2.cpp @@ -19,10 +19,10 @@ using ::mlir::triton::gpu::SharedEncodingAttr; class MMA16816SmemLoader { public: MMA16816SmemLoader(int wpt, ArrayRef order, uint32_t kOrder, - ArrayRef smemStrides, ArrayRef tileShape, - ArrayRef instrShape, ArrayRef matShape, - int perPhase, int maxPhase, int elemBytes, - ConversionPatternRewriter &rewriter, + int kWidth, ArrayRef smemStrides, + ArrayRef tileShape, ArrayRef instrShape, + ArrayRef matShape, int perPhase, int maxPhase, + int elemBytes, ConversionPatternRewriter &rewriter, TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc); @@ -33,7 +33,7 @@ public: if 
(canUseLdmatrix) return computeLdmatrixMatOffs(warpOff, lane, cSwizzleOffset); else - return computeLdsMatOffs(warpOff, lane, cSwizzleOffset, elemBytes); + return computeLdsMatOffs(warpOff, lane, cSwizzleOffset); return {}; } @@ -45,7 +45,7 @@ public: Value cSwizzleOffset); // compute 8-bit matrix offset. SmallVector computeLdsMatOffs(Value warpOff, Value lane, - Value cSwizzleOffset, int elemBytes); + Value cSwizzleOffset); // Load 4 matrices and returns 4 vec<2> elements. std::tuple @@ -55,6 +55,7 @@ public: private: SmallVector order; int kOrder; + int kWidth; SmallVector tileShape; SmallVector instrShape; SmallVector matShape; @@ -176,9 +177,7 @@ MMA16816SmemLoader::computeLdmatrixMatOffs(Value warpId, Value lane, SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value warpOff, Value lane, - Value cSwizzleOffset, - int elemBytes) { - assert(elemBytes <= 4); + Value cSwizzleOffset) { int cTileShape = tileShape[order[0]]; int sTileShape = tileShape[order[1]]; if (!needTrans) { @@ -187,10 +186,10 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value warpOff, SmallVector offs(numPtrs); + int vecWidth = kWidth; int threadsPerQuad[2] = {8, 4}; int laneWidth = 4; int laneHeight = 8; - int vecWidth = 4 / elemBytes; int quadWidth = laneWidth * vecWidth; int quadHeight = laneHeight; int numQuadI = 2; @@ -232,8 +231,8 @@ SmallVector MMA16816SmemLoader::computeLdsMatOffs(Value warpOff, Value i = add(iBase, mul(iOff, i32_val(quadHeight))); Value j = add(jBase, mul(jOff, i32_val(quadWidth))); // wrap around the bounds - i = urem(i, i32_val(cTileShape)); - j = urem(j, i32_val(sTileShape)); + // i = urem(i, i32_val(cTileShape)); + // j = urem(j, i32_val(sTileShape)); if (needTrans) { offs[idx] = add(i, mul(j, sStride)); } else { @@ -304,7 +303,6 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef offs, return {extract_val(elemTy, resV4, 0), extract_val(elemTy, resV4, 1), extract_val(elemTy, resV4, 2), extract_val(elemTy, resV4, 3)}; } else { - elemTy = matTy.cast().getBody()[0]; // base pointers std::array, 2> ptrs; int vecWidth = 4 / elemBytes; @@ -324,39 +322,50 @@ MMA16816SmemLoader::loadX4(int mat0, int mat1, ArrayRef offs, std::array ii = {i0, i1}; // load 4 32-bit values from shared memory // (equivalent to ldmatrix.x4) - SmallVector> vals(4, SmallVector(vecWidth)); + SmallVector> vptrs(4, SmallVector(vecWidth)); for (int i = 0; i < 4; ++i) for (int j = 0; j < vecWidth; ++j) - vals[i][j] = load(gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2])); + vptrs[i][j] = gep(shemPtrTy, ptrs[i / 2][j], ii[i % 2]); // row + trans and col + no-trans are equivalent - if ((needTrans && kOrder == 1) || (!needTrans && kOrder == 0)) - std::swap(vals[1], vals[2]); + bool isActualTrans = + (needTrans && kOrder == 1) || (!needTrans && kOrder == 0); + if (isActualTrans) + std::swap(vptrs[1], vptrs[2]); // pack loaded vectors into 4 32-bit values + int inc = needTrans ? 
1 : kWidth; + VectorType packedTy = vec_ty(int_ty(8 * elemBytes), inc); + int canonBits = std::min(32, 8 * elemBytes * inc); + int canonWidth = (8 * elemBytes * inc) / canonBits; + Type canonInt = int_ty(canonBits); std::array retElems; - retElems.fill(undef(elemTy)); - for (int m = 0; m < 4; ++m) { - for (int e = 0; e < vecWidth; ++e) - retElems[m] = insert_element(retElems[m].getType(), retElems[m], - vals[m][e], i32_val(e)); + retElems.fill(undef(vec_ty(canonInt, 32 / canonBits))); + for (int r = 0; r < 2; ++r) { + for (int em = 0; em < 2 * vecWidth; em += inc) { + int e = em % vecWidth; + int m = em / vecWidth; + int idx = m * 2 + r; + Value ptr = bitcast(vptrs[idx][e], ptr_ty(packedTy, 3)); + Value val = load(ptr); + Value canonval = bitcast(val, vec_ty(canonInt, canonWidth)); + for (int w = 0; w < canonWidth; ++w) { + retElems[idx + w * kWidth / vecWidth] = + insert_element(retElems[idx + w * kWidth / vecWidth], + extract_element(canonval, i32_val(w)), i32_val(e)); + } + } } - if (elemBytes == 1) - return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), - bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; - else - return {retElems[0], retElems[1], retElems[2], retElems[3]}; + return {bitcast(retElems[0], i32_ty), bitcast(retElems[1], i32_ty), + bitcast(retElems[2], i32_ty), bitcast(retElems[3], i32_ty)}; } - - assert(false && "Invalid smem load"); - return {Value{}, Value{}, Value{}, Value{}}; } MMA16816SmemLoader::MMA16816SmemLoader( - int wpt, ArrayRef order, uint32_t kOrder, + int wpt, ArrayRef order, uint32_t kOrder, int kWidth, ArrayRef smemStrides, ArrayRef tileShape, ArrayRef instrShape, ArrayRef matShape, int perPhase, int maxPhase, int elemBytes, ConversionPatternRewriter &rewriter, TritonGPUToLLVMTypeConverter *typeConverter, const Location &loc) - : order(order.begin(), order.end()), kOrder(kOrder), + : order(order.begin(), order.end()), kOrder(kOrder), kWidth(kWidth), tileShape(tileShape.begin(), tileShape.end()), instrShape(instrShape.begin(), instrShape.end()), matShape(matShape.begin(), matShape.end()), perPhase(perPhase), @@ -369,7 +378,8 @@ MMA16816SmemLoader::MMA16816SmemLoader( // rule: k must be the fast-changing axis. needTrans = kOrder != order[0]; - canUseLdmatrix = elemBytes == 2 || (!needTrans); // b16 + canUseLdmatrix = elemBytes == 2 || (!needTrans); + canUseLdmatrix = canUseLdmatrix && (kWidth == 4 / elemBytes); if (canUseLdmatrix) { // Each CTA, the warps is arranged as [1xwpt] if not transposed, @@ -409,42 +419,12 @@ Type getShemPtrTy(Type argType) { return ptr_ty(type::i16Ty(ctx), 3); else if (argType.isF32()) return ptr_ty(type::f32Ty(ctx), 3); - else if (argType.isInteger(8)) + else if (argType.getIntOrFloatBitWidth() == 8) return ptr_ty(type::i8Ty(ctx), 3); else llvm::report_fatal_error("mma16816 data type not supported"); } -Type getMatType(Type argType) { - MLIRContext *ctx = argType.getContext(); - // floating point types - Type fp32x1Ty = vec_ty(type::f32Ty(ctx), 1); - Type fp16x2Ty = vec_ty(type::f16Ty(ctx), 2); - Type i16x2Ty = vec_ty(type::i16Ty(ctx), 2); - Type fp16x2Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp16x2Ty)); - // LLVM 14.0 does not support bf16 type, so we use i16 instead. 
- Type bf16x2Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i16x2Ty)); - Type fp32Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, fp32x1Ty)); - // integer types - Type i8x4Ty = vec_ty(type::i8Ty(ctx), 4); - Type i8x4Pack4Ty = - LLVM::LLVMStructType::getLiteral(ctx, SmallVector(4, i8x4Ty)); - - if (argType.isF16()) - return fp16x2Pack4Ty; - else if (argType.isBF16()) - return bf16x2Pack4Ty; - else if (argType.isF32()) - return fp32Pack4Ty; - else if (argType.isInteger(8)) - return i8x4Pack4Ty; - else - llvm::report_fatal_error("mma16816 data type not supported"); -} - Value composeValuesToDotOperandLayoutStruct( const ValueTable &vals, int n0, int n1, TritonGPUToLLVMTypeConverter *typeConverter, Location loc, @@ -470,7 +450,7 @@ Value composeValuesToDotOperandLayoutStruct( std::function getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, - MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, + MmaEncodingAttr mmaLayout, int wpt, uint32_t kOrder, int kWidth, SmallVector instrShape, SmallVector matShape, Value warpId, Value lane, ValueTable &vals, bool isA, TritonGPUToLLVMTypeConverter *typeConverter, @@ -485,143 +465,105 @@ getLoadMatrixFn(Value tensor, const SharedMemoryObject &smemObj, const int elemBytes = tensorTy.getElementTypeBitWidth() / 8; auto order = sharedLayout.getOrder(); - // the original register_lds2, but discard the prefetch logic. - auto ld2 = [](ValueTable &vals, int mn, int k, Value val) { - vals[{mn, k}] = val; - }; - // (a, b) is the coordinate. - auto load = [=, &rewriter, &vals, &ld2](int a, int b) { + auto load = [=, &rewriter, &vals](int a, int b) { MMA16816SmemLoader loader( - wpt, sharedLayout.getOrder(), kOrder, smemObj.strides, + wpt, sharedLayout.getOrder(), kOrder, kWidth, smemObj.strides, tensorTy.getShape() /*tileShape*/, instrShape, matShape, perPhase, maxPhase, elemBytes, rewriter, typeConverter, loc); Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); SmallVector offs = loader.computeOffsets(warpId, lane, cSwizzleOffset); + // initialize pointers const int numPtrs = loader.getNumPtrs(); SmallVector ptrs(numPtrs); - Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter); - Type smemPtrTy = getShemPtrTy(eltTy); - for (int i = 0; i < numPtrs; ++i) { - ptrs[i] = - bitcast(gep(smemPtrTy, smemBase, ValueRange({offs[i]})), smemPtrTy); - } - + for (int i = 0; i < numPtrs; ++i) + ptrs[i] = bitcast(gep(smemPtrTy, smemBase, offs[i]), smemPtrTy); + // actually load from shared memory + auto matTy = LLVM::LLVMStructType::getLiteral(eltTy.getContext(), + SmallVector(4, i32_ty)); auto [ha0, ha1, ha2, ha3] = loader.loadX4( (kOrder == 1) ? a : b /*mat0*/, (kOrder == 1) ? 
b : a /*mat1*/, offs, - ptrs, getMatType(eltTy), getShemPtrTy(eltTy)); - - if (isA) { - ld2(vals, a, b, ha0); - ld2(vals, a + 1, b, ha1); - ld2(vals, a, b + 1, ha2); - ld2(vals, a + 1, b + 1, ha3); - } else { - ld2(vals, a, b, ha0); - ld2(vals, a + 1, b, ha2); - ld2(vals, a, b + 1, ha1); - ld2(vals, a + 1, b + 1, ha3); - } + ptrs, matTy, getShemPtrTy(eltTy)); + if (!isA) + std::swap(ha1, ha2); + // The following is incorrect, but it causes dramatically better + // performance in ptxas, even though it only changes the order of + // operands in `mma.sync`: + // if(isA) + // std::swap(ha1, ha2); + // update user-provided values in-place + vals[{a, b}] = ha0; + vals[{a + 1, b}] = ha1; + vals[{a, b + 1}] = ha2; + vals[{a + 1, b + 1}] = ha3; }; return load; }
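Editor's aside (not part of the patch): `loadX4` hands back four 8x8 sub-tiles of a 16x16 tile, and the code above writes them to (a, b), (a+1, b), (a, b+1), (a+1, b+1), swapping the middle two for operand B so the (mn, k) coordinates match what `mma.sync` consumes. A toy model of that placement:

```cpp
#include <cstdio>
#include <utility>

int main() {
  for (bool isA : {true, false}) {
    int frag[4] = {0, 1, 2, 3}; // ha0..ha3 as returned by loadX4
    if (!isA)
      std::swap(frag[1], frag[2]); // operand B transposes the fragment order
    std::printf("%s: (a,b)=ha%d (a+1,b)=ha%d (a,b+1)=ha%d (a+1,b+1)=ha%d\n",
                isA ? "A" : "B", frag[0], frag[1], frag[2], frag[3]);
  }
  return 0;
}
```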
-Value loadA(ConversionPatternRewriter &rewriter, Location loc, Value tensor, - DotOperandEncodingAttr aEncoding, const SharedMemoryObject &smemObj, - TritonGPUToLLVMTypeConverter *typeConverter, Value thread) { - auto aTensorTy = tensor.getType().cast(); - int bitwidth = aTensorTy.getElementTypeBitWidth(); - auto mmaLayout = aEncoding.getParent().cast(); - - SmallVector shape(aTensorTy.getShape().begin(), - aTensorTy.getShape().end()); - - ValueTable ha; - std::function loadFn; - int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; - int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; - - auto numRep = aEncoding.getMMAv2Rep(aTensorTy.getShape(), bitwidth); - int numRepM = numRep[0]; - int numRepK = numRep[1]; - - if (aTensorTy.getEncoding().isa()) { - int wpt0 = mmaLayout.getWarpsPerCTA()[0]; - Value warp = udiv(thread, i32_val(32)); - Value lane = urem(thread, i32_val(32)); - Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16)); - // load from smem - // we use ldmatrix.x4 so each warp processes 16x16 elements. - int wpt = std::min(wpt0, shape[0] / 16); - loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, - {mmaInstrM, mmaInstrK} /*instrShape*/, - {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/, - ha /*vals*/, true /*isA*/, typeConverter /* typeConverter */, - rewriter /*rewriter*/, loc /*loc*/); - } else if (aTensorTy.getEncoding().isa()) { - // load from registers, used in gemm fuse - // TODO(Superjomn) Port the logic. - assert(false && "Loading A from register is not supported yet."); - } else { - assert(false && "A's layout is not supported."); - } - - // step1. Perform loading. - for (int m = 0; m < numRepM; ++m) - for (int k = 0; k < numRepK; ++k) - loadFn(2 * m, 2 * k); - - // step2. Format the values to LLVM::Struct to passing to mma codegen. - return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK, - typeConverter, loc, rewriter); -} - -Value loadB(ConversionPatternRewriter &rewriter, Location loc, Value tensor, - DotOperandEncodingAttr bEncoding, const SharedMemoryObject &smemObj, - TritonGPUToLLVMTypeConverter *typeConverter, Value thread) { - ValueTable hb; +Value loadArg(ConversionPatternRewriter &rewriter, Location loc, Value tensor, + DotOperandEncodingAttr encoding, + const SharedMemoryObject &smemObj, + TritonGPUToLLVMTypeConverter *typeConverter, Value thread, + bool isA) { auto tensorTy = tensor.getType().cast(); int bitwidth = tensorTy.getElementTypeBitWidth(); - auto mmaLayout = bEncoding.getParent().cast(); + auto mmaLayout = encoding.getParent().cast(); SmallVector shape(tensorTy.getShape().begin(), tensorTy.getShape().end()); + ValueTable vals; int mmaInstrM = 16, mmaInstrN = 8, mmaInstrK = 4 * 64 / bitwidth; int matShapeM = 8, matShapeN = 8, matShapeK = 2 * 64 / bitwidth; - auto numRep = bEncoding.getMMAv2Rep(tensorTy.getShape(), bitwidth); - int numRepK = numRep[0]; - int numRepN = numRep[1]; + auto numRep = encoding.getMMAv2Rep(tensorTy.getShape(), bitwidth); + int kWidth = encoding.getMMAv2kWidth(); int wpt0 = mmaLayout.getWarpsPerCTA()[0]; int wpt1 = mmaLayout.getWarpsPerCTA()[1]; Value warp = udiv(thread, i32_val(32)); Value lane = urem(thread, i32_val(32)); + Value warpM = urem(urem(warp, i32_val(wpt0)), i32_val(shape[0] / 16)); Value warpMN = udiv(warp, i32_val(wpt0)); Value warpN = urem(urem(warpMN, i32_val(wpt1)), i32_val(shape[1] / 8)); - // we use ldmatrix.x4 so each warp processes 16x16 elements. - int wpt = std::min(wpt1, shape[1] / 16); - auto loadFn = getLoadMatrixFn( - tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, - {mmaInstrK, mmaInstrN} /*instrShape*/, - {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/, - hb /*vals*/, false /*isA*/, typeConverter /* typeConverter */, - rewriter /*rewriter*/, loc /*loc*/); - for (int n = 0; n < std::max(numRepN / 2, 1); ++n) { + int wpt; + if (isA) + wpt = std::min(wpt0, shape[0] / 16); + else + wpt = std::min(wpt1, shape[1] / 16); + + std::function loadFn; + if (isA) + loadFn = getLoadMatrixFn( + tensor, smemObj, mmaLayout, wpt /*wpt*/, 1 /*kOrder*/, kWidth, + {mmaInstrM, mmaInstrK} /*instrShape*/, + {matShapeM, matShapeK} /*matShape*/, warpM /*warpId*/, lane /*laneId*/, + vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */, + rewriter /*rewriter*/, loc /*loc*/); + else + loadFn = getLoadMatrixFn( + tensor, smemObj, mmaLayout, wpt /*wpt*/, 0 /*kOrder*/, kWidth, + {mmaInstrK, mmaInstrN} /*instrShape*/, + {matShapeK, matShapeN} /*matShape*/, warpN /*warpId*/, lane /*laneId*/, + vals /*vals*/, isA /*isA*/, typeConverter /* typeConverter */, + rewriter /*rewriter*/, loc /*loc*/); + + // Perform loading. + int numRepOuter = isA ? numRep[0] : std::max(numRep[1] / 2, 1); + int numRepK = isA ? numRep[1] : numRep[0]; + for (int m = 0; m < numRepOuter; ++m) for (int k = 0; k < numRepK; ++k) - loadFn(2 * n, 2 * k); - } + loadFn(2 * m, 2 * k); - Value result = composeValuesToDotOperandLayoutStruct( - hb, std::max(numRepN / 2, 1), numRepK, typeConverter, loc, rewriter); - return result; + // Format the values to LLVM::Struct before passing to mma codegen.
+ return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK, + typeConverter, loc, rewriter); } namespace SharedToDotOperandMMAv2 { @@ -630,12 +572,12 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, const SharedMemoryObject &smemObj, TritonGPUToLLVMTypeConverter *typeConverter, Value thread) { if (opIdx == 0) - return loadA(rewriter, loc, tensor, encoding, smemObj, typeConverter, - thread); + return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter, + thread, true); else { assert(opIdx == 1); - return loadB(rewriter, loc, tensor, encoding, smemObj, typeConverter, - thread); + return loadArg(rewriter, loc, tensor, encoding, smemObj, typeConverter, + thread, false); } } } // namespace SharedToDotOperandMMAv2 diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp index 11065c048..34d2da5cf 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.cpp @@ -62,9 +62,8 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern { }; void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + RewritePatternSet &patterns, + ModuleAllocation &allocation, PatternBenefit benefit) { - patterns.add(typeConverter, allocation, smem, benefit); + patterns.add(typeConverter, allocation, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h index 500a3e59b..92ad51f38 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM.h @@ -7,9 +7,8 @@ using namespace mlir; using namespace mlir::triton; void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + RewritePatternSet &patterns, + ModuleAllocation &allocation, PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 8e4673249..8e147e0b8 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -4,11 +4,140 @@ using namespace mlir; using namespace mlir::triton; using ::mlir::triton::gpu::getTotalElemsPerThread; +static SmallVector reorderValues(const SmallVector &values, + Type inType, Type ouType) { + auto inTensorTy = inType.dyn_cast(); + auto ouTensorTy = ouType.dyn_cast(); + if (!inTensorTy || !ouTensorTy) + return values; + auto inEncoding = + dyn_cast(inTensorTy.getEncoding()); + auto ouEncoding = + dyn_cast(ouTensorTy.getEncoding()); + assert(inEncoding == ouEncoding); + if (!inEncoding) + return values; + size_t inBitWidth = inTensorTy.getElementType().getIntOrFloatBitWidth(); + size_t ouBitWidth = ouTensorTy.getElementType().getIntOrFloatBitWidth(); + auto ouEltTy = ouTensorTy.getElementType(); + if (inBitWidth == ouBitWidth) + return values; + if (inBitWidth == 16 && ouBitWidth == 32) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 8) { + ret.push_back(values[i]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + } + 
+ return ret; + } + if (inBitWidth == 8 && ouBitWidth == 16) { + SmallVector ret; + for (unsigned i = 0; i < values.size(); i += 16) { + ret.push_back(values[i + 0]); + ret.push_back(values[i + 1]); + ret.push_back(values[i + 2]); + ret.push_back(values[i + 3]); + ret.push_back(values[i + 8]); + ret.push_back(values[i + 9]); + ret.push_back(values[i + 10]); + ret.push_back(values[i + 11]); + ret.push_back(values[i + 4]); + ret.push_back(values[i + 5]); + ret.push_back(values[i + 6]); + ret.push_back(values[i + 7]); + ret.push_back(values[i + 12]); + ret.push_back(values[i + 13]); + ret.push_back(values[i + 14]); + ret.push_back(values[i + 15]); + } + return ret; + // An alternative interleaved ordering, kept for reference but unused: + // for (unsigned i = 0; i < values.size(); i += 16) { + // ret.push_back(values[i]); + // ret.push_back(values[i + 1]); + // ret.push_back(values[i + 4]); + // ret.push_back(values[i + 5]); + // ret.push_back(values[i + 8]); + // ret.push_back(values[i + 9]); + // ret.push_back(values[i + 12]); + // ret.push_back(values[i + 13]); + + // ret.push_back(values[i + 2]); + // ret.push_back(values[i + 3]); + // ret.push_back(values[i + 6]); + // ret.push_back(values[i + 7]); + // ret.push_back(values[i + 10]); + // ret.push_back(values[i + 11]); + // ret.push_back(values[i + 14]); + // ret.push_back(values[i + 15]); + // } + } + llvm_unreachable("unimplemented code path"); +} + +inline SmallVector unpackI32(const SmallVector &inValues, + Type srcTy, + ConversionPatternRewriter &rewriter, + Location loc, + TypeConverter *typeConverter) { + auto tensorTy = srcTy.dyn_cast(); + if (!tensorTy) + return inValues; + auto encoding = tensorTy.getEncoding().dyn_cast(); + if (!(encoding && encoding.getParent().isa())) + return inValues; + SmallVector outValues; + for (auto v : inValues) { + // cast i32 to appropriate eltType vector and extract elements + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + auto vecType = vec_ty(eltType, 32 / eltType.getIntOrFloatBitWidth()); + auto vec = bitcast(v, vecType); + for (int i = 0; i < 32 / eltType.getIntOrFloatBitWidth(); i++) { + outValues.push_back(extract_element(vec, i32_val(i))); + } + } + return outValues; +} + +inline SmallVector packI32(const SmallVector &inValues, + Type srcTy, + ConversionPatternRewriter &rewriter, + Location loc, TypeConverter *typeConverter) { + auto tensorTy = srcTy.dyn_cast(); + if (!tensorTy) + return inValues; + auto encoding = tensorTy.getEncoding().dyn_cast(); + if (!(encoding && encoding.getParent().isa())) + return inValues; + SmallVector outValues; + auto eltType = typeConverter->convertType(tensorTy.getElementType()); + int vecWidth = 32 / eltType.getIntOrFloatBitWidth(); + auto vecType = vec_ty(eltType, vecWidth); + for (int i = 0; i < inValues.size(); i += vecWidth) { + Value vec = undef(vecType); + for (int j = 0; j < vecWidth; j++) { + vec = insert_element(vec, inValues[i + j], i32_val(j)); + } + outValues.push_back(bitcast(vec, i32_ty)); + } + return outValues; +} +
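Editor's aside (not part of the patch): `unpackI32` and `packI32` above are inverses; sub-word dot-operand values travel through the lowering packed into i32 registers and are only expanded to scalar lanes around elementwise ops. A self-contained host-side model of the same round trip, assuming two 16-bit lanes per word in little-endian lane order (matching the `bitcast`/`extract_element` sequence above):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Split each 32-bit register into two 16-bit lanes (cf. unpackI32).
std::vector<uint16_t> unpack(const std::vector<uint32_t> &in) {
  std::vector<uint16_t> out;
  for (uint32_t v : in) {
    out.push_back(static_cast<uint16_t>(v & 0xffff)); // lane 0
    out.push_back(static_cast<uint16_t>(v >> 16));    // lane 1
  }
  return out;
}

// Reassemble two 16-bit lanes per 32-bit register (cf. packI32).
std::vector<uint32_t> pack(const std::vector<uint16_t> &in) {
  std::vector<uint32_t> out;
  for (std::size_t i = 0; i < in.size(); i += 2)
    out.push_back(uint32_t(in[i]) | (uint32_t(in[i + 1]) << 16));
  return out;
}

int main() {
  std::vector<uint32_t> regs = {0xdeadbeef, 0x01234567};
  auto lanes = unpack(regs); // elementwise ops act on these scalar lanes
  auto repacked = pack(lanes);
  std::printf("round trip ok: %d\n", repacked == regs ? 1 : 0);
  return 0;
}
```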
struct FpToFpOpConversion : public ConvertTritonGPUOpToLLVMPattern { using ConvertTritonGPUOpToLLVMPattern< triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern; + typedef std::function( + Location, ConversionPatternRewriter &, const Value &, const Value &, + const Value &, const Value &)> + ConvertorT; /* ------------------ */ // FP8 -> FP16 /* ------------------ */ @@ -747,35 +876,14 @@ struct FpToFpOpConversion #endif } - LogicalResult - matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcTensorType = op.getFrom().getType().cast(); - auto dstTensorType = - op.getResult().getType().cast(); - auto srcEltType = srcTensorType.getElementType(); - auto dstEltType = dstTensorType.getElementType(); - auto loc = op->getLoc(); - auto elems = getTotalElemsPerThread(dstTensorType); - SmallVector resultVals; - bool isSrcFP8 = - srcEltType.isa(); - bool isDstFP8 = - dstEltType.isa(); - - // Select convertor - typedef std::function( - Location, ConversionPatternRewriter &, const Value &, const Value &, - const Value &, const Value &)> - ConvertorT; + ConvertorT getConversionFunc(Type srcTy, Type dstTy) const { auto F8E4M3TyID = TypeID::get(); auto F8E5M2TyID = TypeID::get(); auto F16TyID = TypeID::get(); auto BF16TyID = TypeID::get(); auto F32TyID = TypeID::get(); auto F64TyID = TypeID::get(); - DenseMap, ConvertorT> convertorMap = { + static DenseMap, ConvertorT> convertorMap = { // F8 -> F16 {{F8E4M3TyID, F16TyID}, convertFp8E4M3x4ToFp16x4}, {{F8E5M2TyID, F16TyID}, convertFp8E5M2x4ToFp16x4}, @@ -796,28 +904,46 @@ struct FpToFpOpConversion {{F32TyID, F8E5M2TyID}, convertFp32x4ToFp8E5M2x4}, }; - std::pair key = {srcEltType.getTypeID(), - dstEltType.getTypeID()}; + std::pair key = {srcTy.getTypeID(), dstTy.getTypeID()}; if (convertorMap.count(key) == 0) { - llvm::errs() << "Unsupported conversion from " << srcEltType << " to " - << dstEltType << "\n"; + llvm::errs() << "Unsupported conversion from " << srcTy << " to " << dstTy + << "\n"; llvm_unreachable(""); } - auto convertor = convertorMap.lookup(key); + return convertorMap.lookup(key); + } - // Vectorized casting + LogicalResult + matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcTensorType = op.getFrom().getType().cast(); + auto dstTensorType = + op.getResult().getType().cast(); + auto loc = op->getLoc(); + // Get convertor + auto cvtFunc = getConversionFunc(srcTensorType.getElementType(), + dstTensorType.getElementType()); + // Unpack value + auto inVals = getTypeConverter()->unpackLLElements(loc, adaptor.getFrom(), + rewriter, srcTensorType); + inVals = + unpackI32(inVals, srcTensorType, rewriter, loc, getTypeConverter()); + // Cast; the number of elements must be divisible by 4 for the x4 convertors + SmallVector outVals; + auto elems = inVals.size(); assert(elems % 4 == 0 && "FP8 casting only supports tensors with 4-aligned sizes"); - auto elements = getTypeConverter()->unpackLLElements( - loc, adaptor.getFrom(), rewriter, srcTensorType); - for (size_t i = 0; i < elems; i += 4) { - auto converted = convertor(loc, rewriter, elements[i], elements[i + 1], - elements[i + 2], elements[i + 3]); - resultVals.append(converted); - } - - assert(resultVals.size() == elems); - auto result = getTypeConverter()->packLLElements(loc, resultVals, rewriter, + for (size_t i = 0; i < elems; i += 4) + outVals.append(cvtFunc(loc, rewriter, inVals[i], inVals[i + 1], + inVals[i + 2], inVals[i + 3])); + // Pack values + assert(outVals.size() == elems); + outVals = reorderValues(outVals, srcTensorType, dstTensorType); + outVals = + packI32(outVals, dstTensorType, rewriter, loc, getTypeConverter()); + auto result = getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTensorType); rewriter.replaceOp(op, result); return success(); @@ -849,43 +975,44 @@ public: ConversionPatternRewriter &rewriter) const override { auto resultTy = op.getType(); Location loc = op->getLoc(); - - unsigned elems = getTotalElemsPerThread(resultTy); + // element type auto
resultElementTy = getElementTypeOrSelf(resultTy); Type elemTy = this->getTypeConverter()->convertType(resultElementTy); - SmallVector types(elems, elemTy); - Type structTy = this->getTypeConverter()->convertType(resultTy); - - auto *concreteThis = static_cast(this); - auto operands = getOperands(rewriter, adaptor, resultTy, elems, loc); - SmallVector resultVals(elems); - for (unsigned i = 0; i < elems; ++i) { - resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy, - operands[i], loc); - if (!bool(resultVals[i])) - return failure(); + SmallVector resultVals; + // + SmallVector> allOperands; + for (auto operand : adaptor.getOperands()) { + auto argTy = op->getOperand(0).getType(); + auto sub_operands = this->getTypeConverter()->unpackLLElements( + loc, operand, rewriter, argTy); + sub_operands = unpackI32(sub_operands, argTy, rewriter, loc, + this->getTypeConverter()); + allOperands.resize(sub_operands.size()); + for (auto v : llvm::enumerate(sub_operands)) + allOperands[v.index()].push_back(v.value()); } + if (allOperands.size() == 0) + allOperands.push_back({}); + for (const SmallVector &operands : allOperands) { + Value curr = + ((ConcreteT *)(this)) + ->createDestOp(op, adaptor, rewriter, elemTy, operands, loc); + if (!bool(curr)) + return failure(); + resultVals.push_back(curr); + } + if (op->getNumOperands() > 0) { + auto argTy = op->getOperand(0).getType(); + resultVals = reorderValues(resultVals, argTy, resultTy); + } + resultVals = + packI32(resultVals, resultTy, rewriter, loc, this->getTypeConverter()); Value view = this->getTypeConverter()->packLLElements(loc, resultVals, rewriter, resultTy); rewriter.replaceOp(op, view); return success(); } - -protected: - SmallVector> - getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor, - Type operandTy, const unsigned elems, Location loc) const { - SmallVector> operands(elems); - for (auto operand : adaptor.getOperands()) { - auto sub_operands = this->getTypeConverter()->unpackLLElements( - loc, operand, rewriter, operandTy); - for (size_t i = 0; i < elems; ++i) { - operands[i].push_back(sub_operands[i]); - } - } - return operands; - } }; template @@ -1344,8 +1471,7 @@ struct AbsFOpConversion void populateElementwiseOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, PatternBenefit benefit) { + PatternBenefit benefit) { #define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \ patterns.add>(typeConverter, benefit); POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp) diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h index 20404f875..2187f798b 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h @@ -8,8 +8,7 @@ using namespace mlir::triton; void populateElementwiseOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, PatternBenefit benefit); + PatternBenefit benefit); bool isLegalElementwiseOp(Operation *op); diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 06d0e45b1..4bc991787 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -13,7 +13,7 @@ using 
::mlir::triton::gpu::SharedEncodingAttr; // Contains some helper functions for both Load and Store conversions. struct LoadStoreConversionBase { - explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass) + explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass) : axisAnalysisPass(axisAnalysisPass) {} unsigned getContiguity(Value ptr) const { @@ -38,7 +38,7 @@ struct LoadStoreConversionBase { } protected: - AxisInfoAnalysis &axisAnalysisPass; + ModuleAxisInfoAnalysis &axisAnalysisPass; }; struct LoadOpConversion @@ -48,7 +48,8 @@ struct LoadOpConversion triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern; LoadOpConversion(TritonGPUToLLVMTypeConverter &converter, - AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(axisAnalysisPass) {} @@ -293,7 +294,8 @@ struct StoreOpConversion triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern; StoreOpConversion(TritonGPUToLLVMTypeConverter &converter, - AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) + ModuleAxisInfoAnalysis &axisAnalysisPass, + PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern(converter, benefit), LoadStoreConversionBase(axisAnalysisPass) {} @@ -335,14 +337,7 @@ struct StoreOpConversion vec = std::min(vec, maskAlign); } - // numElements = 1 for scalar - auto tensorTy = valueTy.dyn_cast(); - auto numElems = tensorTy ? tensorTy.getNumElements() : 1; - Value mask = int_val(1, 1); - auto tid = tid_val(); - mask = and_(mask, - icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); - + Value mask = getMask(valueTy, rewriter, loc); const size_t dtsize = std::max(1, valueElemTy.getIntOrFloatBitWidth() / 8); const size_t valueElemNBits = dtsize * 8; @@ -431,11 +426,11 @@ struct AtomicCASOpConversion triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern; AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter, - const Allocation *allocation, Value smem, - AxisInfoAnalysis &axisAnalysisPass, + ModuleAllocation &allocation, + ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( - converter, allocation, smem, benefit), + converter, allocation, benefit), LoadStoreConversionBase(axisAnalysisPass) {} #ifdef USE_ROCM @@ -526,13 +521,13 @@ struct AtomicCASOpConversion auto valElements = getTypeConverter()->unpackLLElements( loc, llVal, rewriter, op.getVal().getType()); - auto TensorTy = op.getResult().getType().dyn_cast(); + auto valueTy = op.getResult().getType(); + auto TensorTy = valueTy.dyn_cast(); Type valueElemTy = TensorTy ? 
getTypeConverter()->convertType(TensorTy.getElementType()) - : op.getResult().getType(); + : valueTy; auto valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); - auto tid = tid_val(); - Value pred = icmp_eq(tid, i32_val(0)); + Value mask = getMask(valueTy, rewriter, loc); PTXBuilder ptxBuilderMemfence; auto memfence = ptxBuilderMemfence.create("membar")->o("gl"); memfence(); @@ -552,7 +547,7 @@ struct AtomicCASOpConversion auto *valOpr = ptxBuilderAtomicCAS.newOperand(casVal, "r"); auto &atom = *ptxBuilderAtomicCAS.create("atom"); atom.global().o("cas").o("b32"); - atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(pred); + atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); barrier(); @@ -561,7 +556,7 @@ struct AtomicCASOpConversion auto *valOprStore = ptxBuilderStore.newOperand(old, "r"); auto &st = *ptxBuilderStore.create("st"); st.shared().o("b32"); - st(dstOprStore, valOprStore).predicate(pred); + st(dstOprStore, valOprStore).predicate(mask); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy); barrier(); @@ -580,11 +575,11 @@ struct AtomicRMWOpConversion triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern; AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter, - const Allocation *allocation, Value smem, - AxisInfoAnalysis &axisAnalysisPass, + ModuleAllocation &allocation, + ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( - converter, allocation, smem, benefit), + converter, allocation, benefit), LoadStoreConversionBase(axisAnalysisPass) {} #ifdef USE_ROCM @@ -747,10 +742,11 @@ struct AtomicRMWOpConversion maskElements = getTypeConverter()->unpackLLElements( loc, llMask, rewriter, op.getMask().getType()); - auto tensorTy = op.getResult().getType().dyn_cast(); + auto valueTy = op.getResult().getType(); + auto tensorTy = valueTy.dyn_cast(); Type valueElemTy = tensorTy ? 
getTypeConverter()->convertType(tensorTy.getElementType()) - : op.getResult().getType(); + : valueTy; const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth(); auto elemsPerThread = getTotalElemsPerThread(val.getType()); // vec = 1, numElements = 1 for scalar @@ -763,10 +759,7 @@ struct AtomicRMWOpConversion // mask numElems = tensorTy.getNumElements(); } - Value mask = int_val(1, 1); - auto tid = tid_val(); - mask = and_(mask, - icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems))); + Value mask = getMask(valueTy, rewriter, loc); auto vecTy = vec_ty(valueElemTy, vec); SmallVector resultVals(elemsPerThread); @@ -846,7 +839,6 @@ struct AtomicRMWOpConversion memfenc(); auto ASMReturnTy = void_ty(ctx); ptxBuilderMemfence.launch(rewriter, loc, ASMReturnTy); - rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0))); atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask); auto old = ptxBuilderAtomicRMW.launch(rewriter, loc, valueElemTy); Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation()); @@ -889,7 +881,9 @@ struct InsertSliceOpConversion Value dst = op.getDest(); Value src = op.getSource(); Value res = op.getResult(); - assert(allocation->getBufferId(res) == Allocation::InvalidBufferId && + auto funcOp = op->getParentOfType(); + auto *funcAllocation = allocation->getFuncData(funcOp); + assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId && "Only support in-place insert_slice for now"); auto srcTy = src.getType().dyn_cast(); @@ -949,12 +943,11 @@ struct InsertSliceAsyncOpConversion triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern; InsertSliceAsyncOpConversion( - TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation, - Value smem, + TritonGPUToLLVMTypeConverter &converter, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, - AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) + ModuleAxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( - converter, allocation, smem, indexCacheInfo, benefit), + converter, allocation, indexCacheInfo, benefit), LoadStoreConversionBase(axisAnalysisPass) {} LogicalResult @@ -967,7 +960,9 @@ struct InsertSliceAsyncOpConversion Value res = op.getResult(); Value mask = op.getMask(); Value other = op.getOther(); - assert(allocation->getBufferId(res) == Allocation::InvalidBufferId && + auto funcOp = op->getParentOfType(); + auto *funcAllocation = allocation->getFuncData(funcOp); + assert(funcAllocation->getBufferId(res) == Allocation::InvalidBufferId && "Only support in-place insert_slice_async for now"); auto srcTy = src.getType().cast(); @@ -1107,19 +1102,17 @@ struct InsertSliceAsyncOpConversion void populateLoadStoreOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { patterns.add(typeConverter, axisInfoAnalysis, benefit); patterns.add(typeConverter, axisInfoAnalysis, benefit); - patterns.add(typeConverter, allocation, smem, + patterns.add(typeConverter, allocation, axisInfoAnalysis, benefit); - patterns.add(typeConverter, allocation, smem, + patterns.add(typeConverter, allocation, axisInfoAnalysis, benefit); - patterns.add(typeConverter, allocation, smem, + patterns.add(typeConverter, allocation, 
indexCacheInfo, benefit); - patterns.add(typeConverter, allocation, smem, - indexCacheInfo, axisInfoAnalysis, - benefit); + patterns.add( + typeConverter, allocation, indexCacheInfo, axisInfoAnalysis, benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h index 830aa8718..b3f9f52af 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.h @@ -8,8 +8,7 @@ using namespace mlir::triton; void populateLoadStoreOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index dd045b9ed..568667dba 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -87,6 +87,15 @@ private: Attribute layout, SmallVector &index, SmallVector &writeIdx, std::map &ints, unsigned axis) const { + if (auto sliceLayout = layout.dyn_cast()) { + auto dim = sliceLayout.getDim(); + assert(dim != axis && "Reduction axis cannot be sliced"); + auto parentLayout = sliceLayout.getParent(); + getWriteIndexBasic(rewriter, loc, parentLayout, index, writeIdx, ints, + axis); + return; + } + writeIdx = index; auto sizePerThread = triton::gpu::getSizePerThread(layout); Value axisSizePerThread = ints[sizePerThread[axis]]; @@ -100,9 +109,10 @@ private: // to map every `axisSizePerThread` to 1 value in smem as: // writeIdx[axis] = index[axis] / axisSizePerThread writeIdx[axis] = udiv(index[axis], axisSizePerThread); - } - auto mmaLayout = layout.dyn_cast(); - if (mmaLayout && mmaLayout.isAmpere()) { + } else if (auto mmaLayout = layout.dyn_cast()) { + if (!mmaLayout.isAmpere()) { + llvm::report_fatal_error("Unsupported layout"); + } if (axis == 0) { // Because warpTileSize = [16, 8] and threadsPerWarp = [8, 4], each 8 // rows in smem would correspond to a warp. 
The mapping @@ -113,8 +123,7 @@ private: // Same as BlockedEncodingAttr case writeIdx[axis] = udiv(index[axis], axisSizePerThread); } - } - if (mmaLayout && !mmaLayout.isAmpere()) { + } else { llvm::report_fatal_error("Unsupported layout"); } } @@ -327,8 +336,8 @@ private: elemPtrTys[i]); } - unsigned sizeIntraWarps = helper.getIntraWarpSize(); - unsigned sizeInterWarps = helper.getInterWarpSize(); + unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); unsigned srcElems = getTotalElemsPerThread(srcTys[0]); auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTys[0]); @@ -495,10 +504,9 @@ private: void populateReduceOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { - patterns.add(typeConverter, allocation, smem, - indexCacheInfo, benefit); + patterns.add(typeConverter, allocation, indexCacheInfo, + benefit); } diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h index a7c1c9912..677da86e3 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.h @@ -8,8 +8,7 @@ using namespace mlir::triton; void populateReduceOpToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp index d416dc2b5..f58f6960f 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp @@ -401,8 +401,7 @@ struct MakeRangeOpConversion ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) : ConvertTritonGPUOpToLLVMPattern( - converter, /*Allocation*/ nullptr, Value{}, indexCacheInfo, - benefit) {} + converter, indexCacheInfo, benefit) {} LogicalResult matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, @@ -669,18 +668,17 @@ void vprintf_array(Value thread, ArrayRef arr, std::string info, void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAllocation &moduleAllocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit) { patterns.add(typeConverter, benefit); - patterns.add(typeConverter, allocation, smem, + patterns.add(typeConverter, moduleAllocation, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); - patterns.add(typeConverter, allocation, smem, + patterns.add(typeConverter, moduleAllocation, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h index 407116a99..50ba1355d 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h +++ 
b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.h @@ -8,8 +8,7 @@ using namespace mlir::triton; void populateTritonGPUToLLVMPatterns( TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - int numWarps, AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + ModuleAllocation &allocation, ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, PatternBenefit benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index 34c1246df..011887adb 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -11,7 +11,7 @@ #include "Utility.h" #include "mlir/IR/TypeUtilities.h" #include "triton/Analysis/AxisInfo.h" - +#include using namespace mlir; using namespace mlir::triton; @@ -41,7 +41,7 @@ void vprintf_array(Value thread, ArrayRef arr, std::string info, // All the rights are reserved by the LLVM community. struct FuncOpConversionBase : public ConvertOpToLLVMPattern { -private: +protected: /// Only retain those attributes that are not constructed by /// `LLVMFuncOp::build`. If `filterArgAttrs` is set, also filter out argument /// attributes. @@ -184,14 +184,18 @@ public: : converter(&typeConverter) {} explicit ConvertTritonGPUOpToLLVMPatternBase( - TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation, - Value smem) - : converter(&typeConverter), allocation(allocation), smem(smem) {} + TritonGPUToLLVMTypeConverter &typeConverter, + IndexCacheInfo indexCacheInfo) + : converter(&typeConverter), indexCacheInfo(indexCacheInfo) {} explicit ConvertTritonGPUOpToLLVMPatternBase( - TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation, - Value smem, IndexCacheInfo indexCacheInfo) - : converter(&typeConverter), allocation(allocation), smem(smem), + TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation) + : converter(&typeConverter), allocation(&allocation) {} + + explicit ConvertTritonGPUOpToLLVMPatternBase( + TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, + IndexCacheInfo indexCacheInfo) + : converter(&typeConverter), allocation(&allocation), indexCacheInfo(indexCacheInfo) {} TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; } @@ -228,9 +232,17 @@ public: T value) const { auto ptrTy = LLVM::LLVMPointerType::get( this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); - auto bufferId = allocation->getBufferId(value); + FunctionOpInterface funcOp; + if constexpr (std::is_pointer_v) + funcOp = value->template getParentOfType(); + else + funcOp = value.getParentRegion() + ->template getParentOfType(); + auto *funcAllocation = allocation->getFuncData(funcOp); + auto smem = allocation->getFunctionSharedMemoryBase(funcOp); + auto bufferId = funcAllocation->getBufferId(value); assert(bufferId != Allocation::InvalidBufferId && "BufferId not found"); - size_t offset = allocation->getOffset(bufferId); + size_t offset = funcAllocation->getOffset(bufferId); Value offVal = i32_val(offset); Value base = gep(ptrTy, smem, offVal); return base; @@ -409,6 +421,46 @@ public: // ----------------------------------------------------------------------- // Utilities // ----------------------------------------------------------------------- + Value getMask(Type valueTy, ConversionPatternRewriter &rewriter, + Location loc) const { + auto tensorTy = valueTy.dyn_cast(); + Value mask = int_val(1, 1); + auto tid = 
tid_val(); + if (tensorTy) { + auto layout = tensorTy.getEncoding(); + auto shape = tensorTy.getShape(); + unsigned rank = shape.size(); + auto sizePerThread = triton::gpu::getSizePerThread(layout); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(layout); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(layout); + auto order = triton::gpu::getOrder(layout); + auto shapePerCTA = triton::gpu::getShapePerCTA(layout, shape); + Value warpSize = i32_val(32); + Value laneId = urem(tid, warpSize); + Value warpId = udiv(tid, warpSize); + SmallVector multiDimWarpId = + delinearize(rewriter, loc, warpId, warpsPerCTA, order); + SmallVector multiDimThreadId = + delinearize(rewriter, loc, laneId, threadsPerWarp, order); + for (unsigned dim = 0; dim < rank; ++dim) { + // if there is no data replication across threads on this dimension + if (shape[dim] >= shapePerCTA[dim]) + continue; + // Otherwise, we need to mask threads that will replicate data on this + // dimension. Calculate the thread index on this dimension for the CTA + Value threadDim = + add(mul(multiDimWarpId[dim], i32_val(threadsPerWarp[dim])), + multiDimThreadId[dim]); + mask = and_(mask, icmp_slt(mul(threadDim, i32_val(sizePerThread[dim])), + i32_val(shape[dim]))); + } + } else { + // If the tensor is not ranked, then it is a scalar and only thread 0 can + // write + mask = and_(mask, icmp_eq(tid, i32_val(0))); + } + return mask; + } // Convert an \param index to a multi-dim coordinate given \param shape and // \param order. @@ -505,13 +557,13 @@ public: RankedTensorType type) const { IndexCacheKeyT key = std::make_pair(layout, type); auto cache = indexCacheInfo.baseIndexCache; - assert(cache && "baseIndexCache is nullptr"); auto insertPt = indexCacheInfo.indexInsertPoint; - if (cache->count(key) > 0) { + if (cache && cache->count(key) > 0) { return cache->lookup(key); } else { ConversionPatternRewriter::InsertionGuard guard(rewriter); - restoreInsertionPointIfSet(insertPt, rewriter); + if (cache) + restoreInsertionPointIfSet(insertPt, rewriter); SmallVector result; if (auto blockedLayout = layout.dyn_cast()) { result = @@ -521,11 +573,20 @@ public: result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type); if (mmaLayout.isAmpere()) result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type); + } else if (auto sliceLayout = layout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(type.getShape()); + RankedTensorType parentTy = RankedTensorType::get( + parentShape, type.getElementType(), parentLayout); + result = emitBaseIndexForLayout(loc, rewriter, parentLayout, parentTy); + result.erase(result.begin() + sliceLayout.getDim()); } else { llvm_unreachable("unsupported emitBaseIndexForLayout"); } - cache->insert(std::make_pair(key, result)); - *insertPt = rewriter.saveInsertionPoint(); + if (cache) { + cache->insert(std::make_pair(key, result)); + *insertPt = rewriter.saveInsertionPoint(); + } return result; } } @@ -540,6 +601,8 @@ public: if (mmaLayout.isAmpere()) return emitOffsetForMmaLayoutV2(mmaLayout, type); } + if (auto sliceLayout = layout.dyn_cast()) + return emitOffsetForSliceLayout(sliceLayout, type); llvm_unreachable("unsupported emitOffsetForLayout"); } @@ -552,27 +615,29 @@ public: RankedTensorType type) const { IndexCacheKeyT key(layout, type); auto cache = indexCacheInfo.indexCache; - assert(cache && "indexCache is nullptr"); auto insertPt = indexCacheInfo.indexInsertPoint; - if (cache->count(key) > 0) { + if (cache && cache->count(key) > 0) { 
return cache->lookup(key); } else { ConversionPatternRewriter::InsertionGuard guard(b); - restoreInsertionPointIfSet(insertPt, b); + if (cache) + restoreInsertionPointIfSet(insertPt, b); SmallVector> result; if (auto blocked = layout.dyn_cast()) { result = emitIndicesForDistributedLayout(loc, b, blocked, type); } else if (auto mma = layout.dyn_cast()) { result = emitIndicesForDistributedLayout(loc, b, mma, type); } else if (auto slice = layout.dyn_cast()) { - result = emitIndicesForSliceLayout(loc, b, slice, type); + result = emitIndicesForDistributedLayout(loc, b, slice, type); } else { llvm_unreachable( "emitIndices for layouts other than blocked & slice not " "implemented yet"); } - cache->insert(std::make_pair(key, result)); - *insertPt = b.saveInsertionPoint(); + if (cache) { + cache->insert(std::make_pair(key, result)); + *insertPt = b.saveInsertionPoint(); + } return result; } } @@ -722,11 +787,11 @@ private: Value _fpw1 = i32_val(fpw[1]); // A info - auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout); + auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout, 0); auto aRep = aEncoding.getMMAv1Rep(); auto aSpw = aEncoding.getMMAv1ShapePerWarp(); // B info - auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout); + auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout, 0); auto bSpw = bEncoding.getMMAv1ShapePerWarp(); auto bRep = bEncoding.getMMAv1Rep(); @@ -783,12 +848,12 @@ private: // TODO: seems like the apttern below to get `rep`/`spw` appears quite often // A info auto aEncoding = - DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout); + DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout, 0); auto aRep = aEncoding.getMMAv1Rep(); auto aSpw = aEncoding.getMMAv1ShapePerWarp(); // B info auto bEncoding = - DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout); + DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout, 0); auto bSpw = bEncoding.getMMAv1ShapePerWarp(); auto bRep = bEncoding.getMMAv1Rep(); @@ -891,24 +956,29 @@ private: return multiDimIdx; } - SmallVector> - emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter, - const SliceEncodingAttr &sliceLayout, - RankedTensorType type) const { + SmallVector> + emitOffsetForSliceLayout(const SliceEncodingAttr &sliceLayout, + RankedTensorType type) const { auto parentEncoding = sliceLayout.getParent(); unsigned dim = sliceLayout.getDim(); auto parentShape = sliceLayout.paddedShape(type.getShape()); RankedTensorType parentTy = RankedTensorType::get( parentShape, type.getElementType(), parentEncoding); - auto parentIndices = emitIndices(loc, rewriter, parentEncoding, parentTy); - unsigned numIndices = parentIndices.size(); - SmallVector> resultIndices; - for (unsigned i = 0; i < numIndices; ++i) { - SmallVector indices = parentIndices[i]; - indices.erase(indices.begin() + dim); - resultIndices.push_back(indices); + auto parentOffsets = emitOffsetForLayout(parentEncoding, parentTy); + + unsigned numOffsets = parentOffsets.size(); + SmallVector> resultOffsets; + std::set> uniqueOffsets; + + for (unsigned i = 0; i < numOffsets; ++i) { + SmallVector offsets = parentOffsets[i]; + offsets.erase(offsets.begin() + dim); + if (uniqueOffsets.find(offsets) == uniqueOffsets.end()) { + resultOffsets.push_back(offsets); + uniqueOffsets.insert(offsets); + } } - return resultIndices; + return resultOffsets; } #ifdef USE_ROCM @@ -1057,8 +1127,7 @@ protected: protected: TritonGPUToLLVMTypeConverter *converter; - const Allocation *allocation; - Value smem; + 
ModuleAllocation *allocation; IndexCacheInfo indexCacheInfo; }; @@ -1075,16 +1144,22 @@ public: ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {} explicit ConvertTritonGPUOpToLLVMPattern( - TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation, - Value smem, PatternBenefit benefit = 1) + TritonGPUToLLVMTypeConverter &typeConverter, + IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit), - ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {} + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, indexCacheInfo) {} explicit ConvertTritonGPUOpToLLVMPattern( - TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation, - Value smem, IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1) + TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, + PatternBenefit benefit = 1) : ConvertOpToLLVMPattern(typeConverter, benefit), - ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem, + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation) {} + + explicit ConvertTritonGPUOpToLLVMPattern( + TritonGPUToLLVMTypeConverter &typeConverter, ModuleAllocation &allocation, + IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit), + ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, indexCacheInfo) {} protected: diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 5a9c2285c..49c5b7dce 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -2,7 +2,7 @@ #include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" -#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h" +#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h" #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h" #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" #include "mlir/Conversion/LLVMCommon/VectorPattern.h" @@ -61,17 +61,38 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern { LogicalResult matchAndRewrite(triton::ReturnOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - unsigned numArguments = op.getNumOperands(); - - // Currently, Triton kernel function always return nothing. - // TODO(Superjomn) add support for non-inline device function - if (numArguments > 0) { - return rewriter.notifyMatchFailure( - op, "Only kernel function with nothing returned is supported."); + auto funcOp = op->getParentOfType(); + if (funcOp->hasAttr("nvvm.kernel")) { + // A GPU kernel + if (op.getNumOperands() > 0) { + return rewriter.notifyMatchFailure( + op, "Kernel functions do not support return with operands"); + } + rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), + op->getAttrs()); + } else { + // A device function + LLVM::ReturnOp newOp; + if (adaptor.getOperands().size() < 2) { + // Single or no return value. + newOp = + rewriter.create(op.getLoc(), adaptor.getOperands()); + } else { + // Pack the results into a struct. 
+ auto packedResultsTy = this->getTypeConverter()->packFunctionResults( funcOp.getResultTypes()); + Value packedResults = rewriter.create(op.getLoc(), packedResultsTy); + auto loc = op.getLoc(); + for (auto it : llvm::enumerate(adaptor.getOperands())) { + packedResults = insert_val(packedResultsTy, packedResults, it.value(), + it.index()); + } + newOp = rewriter.create(op.getLoc(), packedResults); + } + newOp->setAttrs(op->getAttrs()); + rewriter.replaceOp(op, newOp->getResults()); } - - rewriter.replaceOpWithNewOp(op, TypeRange(), ValueRange(), - op->getAttrs()); return success(); } }; @@ -81,19 +102,57 @@ struct ReturnOpConversion : public ConvertOpToLLVMPattern { /// information. struct FuncOpConversion : public FuncOpConversionBase { FuncOpConversion(LLVMTypeConverter &converter, int numWarps, - PatternBenefit benefit) - : FuncOpConversionBase(converter, benefit), numWarps(numWarps) {} + ModuleAllocation &allocation, PatternBenefit benefit) - : FuncOpConversionBase(converter, benefit), numWarps(numWarps) {} + : FuncOpConversionBase(converter, benefit), numWarps(numWarps), + allocation(allocation) {} + + triton::FuncOp amendFuncOp(triton::FuncOp funcOp, + ConversionPatternRewriter &rewriter) const { + // Push back an argument that carries the current stack pointer of shared + // memory onto the function's argument list. + auto loc = funcOp.getLoc(); + auto ctx = funcOp->getContext(); + auto ptrTy = LLVM::LLVMPointerType::get( + this->getTypeConverter()->convertType(rewriter.getI8Type()), 3); + // 1. Modify the function type to add the new argument. + auto funcTy = funcOp.getFunctionType(); + auto amendedInputTy = llvm::to_vector<4>(funcTy.getInputs()); + amendedInputTy.push_back(ptrTy); + auto amendedFuncTy = FunctionType::get(funcTy.getContext(), amendedInputTy, + funcTy.getResults()); + // 2. Modify the argument attributes to add the new argument. + SmallVector amendedAttrs; + filterFuncAttributes(funcOp, /*filterArgAttrs=*/true, amendedAttrs); + auto amendedArgAttrs = llvm::to_vector<4>(funcOp.getAllArgAttrs()); + amendedArgAttrs.emplace_back(DictionaryAttr::get(ctx)); + amendedAttrs.push_back(rewriter.getNamedAttr( + funcOp.getArgAttrsAttrName(), rewriter.getArrayAttr(amendedArgAttrs))); + // 3. Add a new argument to the region + auto amendedFuncOp = rewriter.create( + funcOp.getLoc(), funcOp.getName(), amendedFuncTy, amendedAttrs); + auto &region = funcOp.getBody(); + region.addArgument(ptrTy, loc); + rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(), + amendedFuncOp.end()); + return amendedFuncOp; + } LogicalResult matchAndRewrite(triton::FuncOp funcOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter); + // Prevent LLVM's inliner from inlining this function + auto amendedFuncOp = funcOp; + if (!allocation.isRoot(funcOp)) + amendedFuncOp = amendFuncOp(funcOp, rewriter); + + auto newFuncOp = convertFuncOpToLLVMFuncOp(amendedFuncOp, rewriter); if (!newFuncOp) { return failure(); } auto ctx = funcOp->getContext(); - // Set an attribute to indicate this function is a kernel entry. - newFuncOp->setAttr("nvvm.kernel", rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); - // Set an attribute for maxntidx, it could be used in latter LLVM codegen - // for `nvvm.annotation` metadata. - newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps)); -#endif + if (allocation.isRoot(funcOp)) { + // Set an attribute to indicate this function is a kernel entry. + newFuncOp->setAttr("nvvm.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + } + // Set an attribute for maxntidx; it can be used later in LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps)); + // The call graph is updated by mapping the old function to the new one. + allocation.mapFuncOp(funcOp, newFuncOp); rewriter.eraseOp(funcOp); return success(); } @@ -109,6 +187,99 @@ struct FuncOpConversion : public FuncOpConversionBase { private: int numWarps{0}; + ModuleAllocation &allocation; +}; +
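Editor's aside (not part of the patch): the combined effect of `FuncOpConversion` and `amendFuncOp` above is a simple calling convention: only the root (kernel) function addresses `global_smem` directly, and every device function receives its shared-memory slice as a trailing pointer argument offset by the caller's allocation. A host-side model of that convention; all names here are illustrative, not Triton APIs:

```cpp
#include <cstdio>

char global_smem[1024]; // stands in for the LLVM global with external linkage

// "device function": receives its shared-memory slice as an appended argument
void callee(int x, char *smem) { smem[0] = static_cast<char>(x); }

// "kernel" (root function): owns the base pointer and applies the callee's
// allocation offset at each call site, as CallOpConversion does below
void kernel() {
  char *base = global_smem;
  int calleeOffset = 256; // analogous to funcAllocation->getOffset(bufferId)
  callee(42, base + calleeOffset);
}

int main() {
  kernel();
  std::printf("%d\n", global_smem[256]); // prints 42
  return 0;
}
```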
+ newFuncOp->setAttr("nvvm.kernel", + rewriter.getIntegerAttr(type::u1Ty(ctx), 1)); + } else { + // The noinline attribute will be used by the LLVM codegen to prevent + // inlining. + // https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/LLVMIR/IR/LLVMInlining.cpp#L267 + newFuncOp.setPassthroughAttr( + ArrayAttr::get(ctx, rewriter.getStringAttr("noinline"))); + rewriter.eraseOp(amendedFuncOp); + } + // Set an attribute for maxntidx, it could be used in latter LLVM codegen + // for `nvvm.annotation` metadata. + newFuncOp->setAttr("nvvm.maxntid", rewriter.getI32ArrayAttr(32 * numWarps)); + // The call graph is updated by mapping the old function to the new one. + allocation.mapFuncOp(funcOp, newFuncOp); +>>>>>>> openai/main rewriter.eraseOp(funcOp); return success(); @@ -109,6 +187,99 @@ struct FuncOpConversion : public FuncOpConversionBase { private: int numWarps{0}; + ModuleAllocation &allocation; +}; + +// CallOpInterfaceLowering is adapted from +// https://github.com/llvm/llvm-project/blob/fae656b2dd80246c3c6f01e9c77c49560368752c/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp#L485 +struct CallOpConversion : public ConvertOpToLLVMPattern { + CallOpConversion(LLVMTypeConverter &converter, int numWarps, + ModuleAllocation &allocation, PatternBenefit benefit) + : ConvertOpToLLVMPattern(converter, benefit), + numWarps(numWarps), allocation(allocation) {} + + LogicalResult + matchAndRewrite(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto promotedOperands = promoteOperands(callOp, adaptor, rewriter); + auto newCallOp = + convertCallOpToLLVMCallOp(callOp, promotedOperands, rewriter); + if (!newCallOp) + return failure(); + allocation.mapCallOp(callOp, newCallOp); + auto results = getCallOpResults(callOp, newCallOp, rewriter); + rewriter.replaceOp(callOp, results); + return success(); + } + +private: + SmallVector + promoteOperands(triton::CallOp callOp, + typename triton::CallOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + // Get the last argument of the caller, which is the current stack pointer + // of shared memory and append it to the operands of the callOp. + auto loc = callOp.getLoc(); + auto caller = callOp->getParentOfType(); + auto base = allocation.getFunctionSharedMemoryBase(caller); + auto *funcAllocation = allocation.getFuncData(caller); + auto bufferId = funcAllocation->getBufferId(callOp); + auto offset = funcAllocation->getOffset(bufferId); + auto ptrTy = LLVM::LLVMPointerType::get( + this->getTypeConverter()->convertType(rewriter.getI8Type()), + NVVM::kSharedMemorySpace); + auto offsetValue = gep(ptrTy, base, i32_val(offset)); + auto promotedOperands = this->getTypeConverter()->promoteOperands( + callOp.getLoc(), /*opOperands=*/callOp->getOperands(), + adaptor.getOperands(), rewriter); + promotedOperands.push_back(offsetValue); + return promotedOperands; + } + + LLVM::CallOp + convertCallOpToLLVMCallOp(triton::CallOp callOp, + ArrayRef promotedOperands, + ConversionPatternRewriter &rewriter) const { + // Pack the result types into a struct. + Type packedResult = nullptr; + unsigned numResults = callOp.getNumResults(); + auto resultTypes = llvm::to_vector<4>(callOp.getResultTypes()); + + if (numResults != 0) { + if (!(packedResult = + this->getTypeConverter()->packFunctionResults(resultTypes))) + return nullptr; + } + + auto newCallOp = rewriter.create( + callOp.getLoc(), packedResult ? 
+ SmallVector + getCallOpResults(triton::CallOp callOp, LLVM::CallOp newCallOp, + ConversionPatternRewriter &rewriter) const { + auto numResults = callOp.getNumResults(); + SmallVector results; + if (numResults < 2) { + // If < 2 results, packing did not do anything and we can just return. + results.append(newCallOp.result_begin(), newCallOp.result_end()); + } else { + // Otherwise, the call has been converted to an operation producing a + // structure. Extract individual results from the structure and return + // them as a list. + results.reserve(numResults); + for (unsigned i = 0; i < numResults; ++i) { + results.push_back(rewriter.create( + callOp.getLoc(), newCallOp->getResult(0), i)); + } + } + return results; + } + + int numWarps{0}; + ModuleAllocation &allocation; }; class TritonLLVMConversionTarget : public ConversionTarget { @@ -145,26 +316,25 @@ public: TritonLLVMConversionTarget target(*context, isROCM); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - /* preprocess */ + // Preprocess decomposeMmaToDotOperand(mod, numWarps); decomposeBlockedToDotOperand(mod); if (failed(decomposeInsertSliceAsyncOp(mod))) return signalPassFailure(); - /* allocate shared memory and set barrier */ - Allocation allocation(mod); - MembarAnalysis membarPass(&allocation); + // Allocate shared memory and set barrier + ModuleAllocation allocation(mod); + ModuleMembarAnalysis membarPass(&allocation); membarPass.run(); - /* lower functions */ + // Lower functions { mlir::LowerToLLVMOptions option(context); TritonGPUToLLVMTypeConverter typeConverter(context, option); TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM); RewritePatternSet funcPatterns(context); - funcPatterns.add(typeConverter, numWarps, + funcPatterns.add(typeConverter, numWarps, allocation, /*benefit=*/1); - funcPatterns.add(typeConverter); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, funcPatterns); if (failed( applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) return signalPassFailure(); } - std::unique_ptr solver = createDataFlowSolver(); - AxisInfoAnalysis *axisInfoAnalysis = solver->load(); - if (failed(solver->initializeAndRun(mod))) - return signalPassFailure(); - initSharedMemory(allocation.getSharedMemorySize(), typeConverter); - mod->setAttr("triton_gpu.shared", - mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32), - allocation.getSharedMemorySize())); + // initSharedMemory is run before the conversion of call and ret ops, + // because the call op has to know the shared memory base address of each + // function. + initSharedMemory(allocation, typeConverter); - /* rewrite ops */ + // Convert call and ret ops + { + mlir::LowerToLLVMOptions option(context); + TritonGPUToLLVMTypeConverter typeConverter(context, option); + TritonLLVMFunctionConversionTarget funcTarget(*context, isROCM); + RewritePatternSet funcPatterns(context); + funcPatterns.add(typeConverter, numWarps, allocation, + /*benefit=*/1); + funcPatterns.add(typeConverter, /*benefit=*/1); + if (failed( + applyPartialConversion(mod, funcTarget, std::move(funcPatterns)))) + return signalPassFailure(); + } + + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); + // Rewrite ops RewritePatternSet patterns(context); // TritonGPU lowering patterns OpBuilder::InsertPoint indexInsertPoint; ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{ &baseIndexCache, &indexCache, &indexInsertPoint}; - auto populatePatterns1 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis, - &allocation, smem, indexCacheInfo, /*benefit*/ 1); - }; - auto populatePatterns2 = [&](auto populateFunc) { - populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis, - &allocation, smem, /*benefit*/ 1); - }; - populatePatterns1(populateTritonGPUToLLVMPatterns); - populatePatterns1(populateConvertLayoutOpToLLVMPatterns); - populatePatterns2(populateDotOpToLLVMPatterns); - populatePatterns2(populateElementwiseOpToLLVMPatterns); - populatePatterns1(populateLoadStoreOpToLLVMPatterns); - populatePatterns1(populateReduceOpToLLVMPatterns); - populatePatterns2(populateViewOpToLLVMPatterns); + // TODO: enable index cache if there are multiple functions + if (axisInfoAnalysis.getNumFunctions() > 1) { + indexCacheInfo = {nullptr, nullptr, nullptr}; + } + populateTritonGPUToLLVMPatterns(typeConverter, patterns, allocation, + indexCacheInfo, /*benefit=*/1); + populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, allocation, + indexCacheInfo, /*benefit=*/1); + populateDotOpToLLVMPatterns(typeConverter, patterns, allocation, + /*benefit=*/1); + populateElementwiseOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1); + populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, + allocation, indexCacheInfo, + /*benefit=*/1); + populateReduceOpToLLVMPatterns(typeConverter, patterns, allocation, + indexCacheInfo, /*benefit=*/1); + populateViewOpToLLVMPatterns(typeConverter, patterns, /*benefit=*/1); // Native lowering patterns if (isROCM) { @@ -218,8 +401,6 @@ public: } private: - Value smem; - using IndexCacheKeyT = std::pair; DenseMap, CacheKeyDenseMapInfo> baseIndexCache; @@ -230,10 +411,11 @@ private: int computeCapability{}; bool isROCM{}; - void initSharedMemory(size_t size, + void initSharedMemory(ModuleAllocation &allocation, TritonGPUToLLVMTypeConverter &typeConverter) { ModuleOp mod = getOperation(); OpBuilder b(mod.getBodyRegion()); + auto ctx = mod.getContext(); auto loc = mod.getLoc(); auto elemTy = typeConverter.convertType(b.getIntegerType(8)); // Set array size 0 and external linkage indicates that we use dynamic @@ -244,15 +426,23 @@ private: "global_smem", /*value=*/Attribute(), /*alignment=*/0, // Add ROCm support. 
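// (kSharedMemorySpace corresponds to LLVM addrspace(3), the NVPTX shared-memory space)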
static_cast(NVVM::NVVMMemorySpace::kSharedMemorySpace)); - SmallVector funcs; - mod.walk([&](LLVM::LLVMFuncOp func) { funcs.push_back(func); }); - assert(funcs.size() == 1 && - "Inliner pass is expected before TritonGPUToLLVM"); - b.setInsertionPointToStart(&funcs[0].getBody().front()); - smem = b.create(loc, global); - auto ptrTy = - LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), 3); - smem = b.create(loc, ptrTy, smem); + mod.walk([&](FunctionOpInterface funcOp) { + Value funcSmem; + b.setInsertionPointToStart(&funcOp.getFunctionBody().front()); + if (allocation.isRoot(funcOp)) { + funcSmem = b.create(loc, global); + } else { + funcSmem = funcOp.getArgument(funcOp.getNumArguments() - 1); + } + auto ptrTy = + LLVM::LLVMPointerType::get(typeConverter.convertType(b.getI8Type()), + NVVM::NVVMMemorySpace::kSharedMemorySpace); + funcSmem = b.create(loc, ptrTy, funcSmem); + allocation.setFunctionSharedMemoryValue(funcOp, funcSmem); + }); + mod->setAttr("triton_gpu.shared", + mlir::IntegerAttr::get(mlir::IntegerType::get(ctx, 32), + allocation.getSharedMemorySize())); } void decomposeMmaToDotOperand(ModuleOp mod, int numWarps) const { @@ -310,10 +500,7 @@ private: } LogicalResult decomposeInsertSliceAsyncOp(ModuleOp mod) const { - std::unique_ptr solver = createDataFlowSolver(); - AxisInfoAnalysis *axisInfoAnalysis = solver->load(); - if (failed(solver->initializeAndRun(mod))) - return failure(); + ModuleAxisInfoAnalysis axisInfoAnalysis(mod); // TODO(Keren): This is a hacky knob that may cause performance regression // when decomposition has been performed. We should remove this knob once we // have thorough analysis on async wait. Currently, we decompose @@ -347,7 +534,7 @@ private: auto resSharedLayout = dstTy.getEncoding().dyn_cast(); auto resElemTy = dstTy.getElementType(); - unsigned inVec = axisInfoAnalysis->getPtrContiguity(src); + unsigned inVec = axisInfoAnalysis.getPtrContiguity(src); unsigned outVec = resSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); auto maxBitWidth = diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index 347988b83..ef8c474bd 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -106,17 +106,8 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( return elemTy; if (mmaParent.isAmpere()) { int bitwidth = elemTy.getIntOrFloatBitWidth(); - // sub-word integer types need to be packed for perf reasons - if (elemTy.isa() && bitwidth < 32) - return IntegerType::get(ctx, 32); - // TODO: unify everything to use packed integer-types - // otherwise, vector types are ok - const llvm::DenseMap elemTyMap = { - {32, vec_ty(elemTy, 1)}, - {16, vec_ty(elemTy, 2)}, - {8, vec_ty(elemTy, 4)}, - }; - return elemTyMap.lookup(bitwidth); + assert(bitwidth <= 32); + return IntegerType::get(ctx, 32); } else { assert(mmaParent.isVolta()); return vec_ty(elemTy, 2); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 9bf1ff5e4..d7f029acc 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -88,6 +88,7 @@ #define call(...) 
rewriter.create(loc, __VA_ARGS__) // Types +#define int_ty(width) rewriter.getIntegerType(width) #define i64_ty rewriter.getIntegerType(64) #define i32_ty rewriter.getIntegerType(32) #define i16_ty rewriter.getIntegerType(16) diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp index 56759e842..000903a55 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp @@ -116,23 +116,64 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern { } }; -template -struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern { - using OpAdaptor = typename SourceOp::Adaptor; - explicit ViewLikeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, - PatternBenefit benefit = 1) - : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} +struct ViewOpConversion : public ConvertTritonGPUOpToLLVMPattern { + using OpAdaptor = typename ViewOp::Adaptor; + explicit ViewOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} LogicalResult - matchAndRewrite(SourceOp op, OpAdaptor adaptor, + matchAndRewrite(ViewOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op->getLoc(); auto resultTy = op.getType().template cast(); auto vals = this->getTypeConverter()->unpackLLElements( loc, adaptor.getSrc(), rewriter, op.getOperand().getType()); - Value view = + Value ret = this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy); - rewriter.replaceOp(op, view); + rewriter.replaceOp(op, ret); + return success(); + } +}; + +struct ExpandDimsOpConversion + : public ConvertTritonGPUOpToLLVMPattern { + using OpAdaptor = typename ExpandDimsOp::Adaptor; + explicit ExpandDimsOpConversion(TritonGPUToLLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit) {} + + LogicalResult + matchAndRewrite(ExpandDimsOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto srcVals = this->getTypeConverter()->unpackLLElements( + loc, adaptor.getSrc(), rewriter, op.getOperand().getType()); + + auto srcTy = op.getSrc().getType().cast(); + auto resultTy = op.getType().template cast(); + + assert(srcTy.getEncoding().isa() && + "ExpandDimsOp only supports SliceEncodingAttr"); + auto srcLayout = srcTy.getEncoding().dyn_cast(); + auto resultLayout = resultTy.getEncoding(); + + auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy); + auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy); + DenseMap, Value, SmallVectorKeyInfo> srcValues; + for (size_t i = 0; i < srcOffsets.size(); i++) { + srcValues[srcOffsets[i]] = srcVals[i]; + } + + SmallVector resultVals; + for (size_t i = 0; i < resultOffsets.size(); i++) { + auto offset = resultOffsets[i]; + offset.erase(offset.begin() + srcLayout.getDim()); + resultVals.push_back(srcValues.lookup(offset)); + } + Value ret = this->getTypeConverter()->packLLElements(loc, resultVals, + rewriter, resultTy); + rewriter.replaceOp(op, ret); return success(); } }; @@ -161,13 +202,10 @@ struct TransOpConversion }; void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + RewritePatternSet &patterns, PatternBenefit benefit) { -
patterns.add>(typeConverter, benefit); - patterns.add>(typeConverter, - benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h index 84a83ad97..1a8aef391 100644 --- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.h @@ -7,9 +7,7 @@ using namespace mlir; using namespace mlir::triton; void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter, - RewritePatternSet &patterns, int numWarps, - AxisInfoAnalysis &axisInfoAnalysis, - const Allocation *allocation, Value smem, + RewritePatternSet &patterns, PatternBenefit benefit); #endif diff --git a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt index 9d7f961cc..834d10a4d 100644 --- a/lib/Conversion/TritonToTritonGPU/CMakeLists.txt +++ b/lib/Conversion/TritonToTritonGPU/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_conversion_library(TritonToTritonGPU LINK_LIBS PUBLIC MLIRIR MLIRPass + MLIRTransforms TritonIR TritonGPUIR TritonGPUTransforms diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index 08b9b45be..da5be1ba2 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -67,19 +67,21 @@ public: matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Type retType = getTypeConverter()->convertType(op.getType()); + auto retShapedType = retType.cast(); auto value = adaptor.getValue().dyn_cast(); - if (dyn_cast(retType)) { + if (dyn_cast(retShapedType)) { assert(value); if (value.getElementType().isInteger(1) && value.isSplat()) // Workaround until https://reviews.llvm.org/D133743 is included. - value = DenseElementsAttr::get(retType, value.getSplatValue()); + value = + DenseElementsAttr::get(retShapedType, value.getSplatValue()); else // This is a hack. 
We just want to add encoding - value = value.reshape(retType); + value = value.reshape(retShapedType); } - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, retType, value), - adaptor.getAttributes()); + addNamedAttrs(rewriter.replaceOpWithNewOp( + op, retShapedType, value), + adaptor.getAttributes()); return success(); } }; @@ -165,8 +167,6 @@ void populateStdPatternsAndLegality(TritonGPUTypeConverter &typeConverter, MLIRContext *context = patterns.getContext(); // Rewrite rule patterns.add(typeConverter, context); - target.addLegalOp(); // this is ok because all functions are - // inlined by the frontend } void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, @@ -274,6 +274,8 @@ struct TritonDotPattern : public OpConversionPattern { // a & b must be of smem layout auto aType = adaptor.getA().getType().cast(); auto bType = adaptor.getB().getType().cast(); + Type aEltType = aType.getElementType(); + Type bEltType = bType.getElementType(); Attribute aEncoding = aType.getEncoding(); Attribute bEncoding = bType.getEncoding(); if (!aEncoding || !bEncoding) @@ -282,17 +284,17 @@ struct TritonDotPattern : public OpConversionPattern { Value b = adaptor.getB(); Value c = adaptor.getC(); if (!aEncoding.isa()) { - Attribute encoding = - triton::gpu::DotOperandEncodingAttr::get(getContext(), 0, dEncoding); - auto dstType = RankedTensorType::get(aType.getShape(), - aType.getElementType(), encoding); + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 0, dEncoding, aEltType); + auto dstType = + RankedTensorType::get(aType.getShape(), aEltType, encoding); a = rewriter.create(a.getLoc(), dstType, a); } if (!bEncoding.isa()) { - Attribute encoding = - triton::gpu::DotOperandEncodingAttr::get(getContext(), 1, dEncoding); - auto dstType = RankedTensorType::get(bType.getShape(), - bType.getElementType(), encoding); + Attribute encoding = triton::gpu::DotOperandEncodingAttr::get( + getContext(), 1, dEncoding, bEltType); + auto dstType = + RankedTensorType::get(bType.getShape(), bEltType, encoding); b = rewriter.create(b.getLoc(), dstType, b); } c = rewriter.create(c.getLoc(), retType, c); @@ -533,6 +535,52 @@ struct TritonAssertPattern : public OpConversionPattern { } }; +class TritonFuncOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = getTypeConverter(); + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getName(), op.getFunctionType()); + addNamedAttrs(newOp, adaptor.getAttributes()); + rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), + newOp.getBody().end()); + if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter))) + return failure(); + + return success(); + } +}; + +class TritonCallOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(triton::CallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto newOp = rewriter.replaceOpWithNewOp( + op, op.getCallee(), op.getResultTypes(), adaptor.getOperands()); + addNamedAttrs(newOp, adaptor.getAttributes()); + return success(); + } +}; + +class TritonReturnOpPattern : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(ReturnOp op, ReturnOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const 
override { + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + return success(); + } +}; + void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); @@ -550,7 +598,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, TritonLoadPattern, TritonStorePattern, TritonExternElementwisePattern, TritonExternElementwisePattern, - TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern>( + TritonPrintPattern, TritonAssertPattern, TritonAtomicRMWPattern, + TritonFuncOpPattern, TritonReturnOpPattern, TritonCallOpPattern>( typeConverter, context); } @@ -752,31 +801,10 @@ public: } }; -class FuncOpPattern : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::FuncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto converter = getTypeConverter(); - auto newOp = rewriter.replaceOpWithNewOp( - op, op.getName(), op.getFunctionType()); - addNamedAttrs(newOp, adaptor.getAttributes()); - rewriter.inlineRegionBefore(op.getBody(), newOp.getBody(), - newOp.getBody().end()); - if (failed(rewriter.convertRegionTypes(&newOp.getBody(), *converter))) - return failure(); - - return success(); - } -}; - void populateCFPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add( - typeConverter, context); + patterns.add(typeConverter, context); } // diff --git a/lib/Dialect/Triton/IR/CMakeLists.txt b/lib/Dialect/Triton/IR/CMakeLists.txt index 9488db8d8..6ee110718 100644 --- a/lib/Dialect/Triton/IR/CMakeLists.txt +++ b/lib/Dialect/Triton/IR/CMakeLists.txt @@ -1,5 +1,4 @@ add_mlir_dialect_library(TritonIR - Interfaces.cpp Dialect.cpp Ops.cpp Types.cpp @@ -11,5 +10,6 @@ add_mlir_dialect_library(TritonIR LINK_LIBS PUBLIC MLIRIR MLIRArithDialect + MLIRMathDialect MLIRSCFDialect ) diff --git a/lib/Dialect/Triton/IR/Dialect.cpp b/lib/Dialect/Triton/IR/Dialect.cpp index 9d1e5ae44..086279123 100644 --- a/lib/Dialect/Triton/IR/Dialect.cpp +++ b/lib/Dialect/Triton/IR/Dialect.cpp @@ -22,14 +22,22 @@ using namespace mlir::triton; namespace { struct TritonInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; + bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final { + auto funcOp = dyn_cast(callable); + if (!funcOp) + return true; + if (funcOp->hasAttr("noinline")) + return !funcOp->getAttrOfType("noinline").getValue(); return true; } + bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, IRMapping &valueMapping) const final { return true; } + bool isLegalToInline(Operation *, Region *, bool wouldBeCloned, IRMapping &) const final { return true; @@ -83,5 +91,5 @@ void TritonDialect::initialize() { Operation *TritonDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { - return builder.create(loc, type, value); + return arith::ConstantOp::materialize(builder, value, type, loc); } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 3d97b4e98..86ea91077 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -230,6 +230,54 @@ void LoadOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, state.addTypes({resultType}); } +// load(ptr, splat(1), ...) -> load(ptr, ...) +// load(ptr, splat(0), other, ...) 
-> other +struct CanonicalizeMaskedLoadPattern + : public mlir::OpRewritePattern { + CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::LoadOp loadOp, + mlir::PatternRewriter &rewriter) const override { + auto mask = loadOp.getMask(); + if (!mask) + return mlir::failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return mlir::failure(); + + auto splatMask = constantMask.getValue().dyn_cast(); + if (!splatMask) + return mlir::failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), + loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + } else { + // mask = splat(0) + + // If there's no "other", the value is "undef". Perhaps we want to + // optimize it in the future. + auto otherVal = loadOp.getOther(); + if (!otherVal) + return mlir::failure(); + rewriter.replaceOp(loadOp, otherVal); + } + return mlir::success(); + } +}; + +void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //-- StoreOp -- void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, ::mlir::Value ptr, ::mlir::Value value, @@ -257,10 +305,51 @@ void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state, evict); } +// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) +// store(ptr, value, splat(0), ...) -> [none] +struct CanonicalizeMaskedStorePattern + : public mlir::OpRewritePattern { + CanonicalizeMaskedStorePattern(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(triton::StoreOp storeOp, + mlir::PatternRewriter &rewriter) const override { + auto mask = storeOp.getMask(); + if (!mask) + return mlir::failure(); + + auto constantMask = + llvm::dyn_cast_or_null(mask.getDefiningOp()); + if (!constantMask) + return mlir::failure(); + + auto splatMask = constantMask.getValue().dyn_cast(); + if (!splatMask) + return mlir::failure(); + + if (splatMask.getSplatValue().getValue() == true) { + // mask = splat(1) + rewriter.replaceOpWithNewOp( + storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), + storeOp.getEvict()); + } else { + // mask = splat(0) + rewriter.eraseOp(storeOp); + } + return mlir::success(); + } +}; + +void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, + MLIRContext *context) { + results.add(context); +} + //-- TransOp -- mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { // type is the same as the input auto argTy = operands[0].getType().cast(); @@ -287,7 +376,7 @@ mlir::LogicalResult mlir::triton::TransOp::inferReturnTypes( //-- DotOp -- mlir::LogicalResult mlir::triton::DotOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { // type is the same as the accumulator auto accTy = operands[2].getType().cast(); @@ -355,7
+444,7 @@ void ReduceOp::build(mlir::OpBuilder &builder, mlir::OperationState &state, mlir::LogicalResult mlir::triton::ReduceOp::inferReturnTypes( MLIRContext *context, std::optional location, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { for (auto arg : operands) { auto argTy = arg.getType().cast(); @@ -462,7 +551,7 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { //-- ExpandDimsOp -- mlir::LogicalResult mlir::triton::ExpandDimsOp::inferReturnTypes( MLIRContext *context, std::optional loc, ValueRange operands, - DictionaryAttr attributes, RegionRange regions, + DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnTypes) { // infer shape auto arg = operands[0]; diff --git a/lib/Dialect/Triton/Transforms/CMakeLists.txt b/lib/Dialect/Triton/Transforms/CMakeLists.txt index b6e3b1f54..dcd050a45 100644 --- a/lib/Dialect/Triton/Transforms/CMakeLists.txt +++ b/lib/Dialect/Triton/Transforms/CMakeLists.txt @@ -9,4 +9,9 @@ add_mlir_dialect_library(TritonTransforms DEPENDS TritonTransformsIncGen TritonCombineIncGen + + LINK_LIBS PUBLIC + MLIRPass + MLIRTransformUtils + TritonIR ) diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 850182366..8e0059382 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -37,7 +37,7 @@ bool isBroadcastConstantCombinable(Attribute value) { DenseElementsAttr getConstantValue(Builder &builder, Attribute value, Value bcast_res) { - Type resType = bcast_res.getType(); + auto resType = bcast_res.getType().cast(); DenseElementsAttr res; if (auto denseValue = value.dyn_cast()) { res = @@ -101,95 +101,6 @@ public: } }; -// load(ptr, splat(1), ...) -> load(ptr, ...) -// load(ptr, splat(0), other, ...) -> other -struct CanonicalizeMaskedLoadPattern - : public mlir::OpRewritePattern { - CanonicalizeMaskedLoadPattern(mlir::MLIRContext *context) - : OpRewritePattern(context, 1) {} - - mlir::LogicalResult - matchAndRewrite(triton::LoadOp loadOp, - mlir::PatternRewriter &rewriter) const override { - auto mask = loadOp.getMask(); - if (!mask) - return mlir::failure(); - - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); - if (!constantMask) - return mlir::failure(); - - auto splatMask = constantMask.getValue().dyn_cast(); - if (!splatMask) - return mlir::failure(); - - if (splatMask.getSplatValue().getValue() == true) { - // mask = splat(1) - rewriter.replaceOpWithNewOp( - loadOp, loadOp.getType(), loadOp.getPtr(), Value(), Value(), - loadOp.getBoundaryCheckAttr(), loadOp.getPaddingAttr(), - loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); - } else { - // mask = splat(0) - - // If there's no "other", the value is "undef". Perhaps we want to - // optimize it in the future.x - auto otherVal = loadOp.getOther(); - if (!otherVal) - return mlir::failure(); - rewriter.replaceOp(loadOp, otherVal); - } - return mlir::success(); - } -}; - -void triton::LoadOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - -// store(ptr, value, splat(1), ...) -> store(ptr, value, ...) -// store(ptr, value, splat(0), ...) 
-> [none] -struct CanonicalizeMaskedStorePattern - : public mlir::OpRewritePattern { - CanonicalizeMaskedStorePattern(mlir::MLIRContext *context) - : OpRewritePattern(context, 1) {} - - mlir::LogicalResult - matchAndRewrite(triton::StoreOp storeOp, - mlir::PatternRewriter &rewriter) const override { - auto mask = storeOp.getMask(); - if (!mask) - return mlir::failure(); - - auto constantMask = - llvm::dyn_cast_or_null(mask.getDefiningOp()); - if (!constantMask) - return mlir::failure(); - - auto splatMask = constantMask.getValue().dyn_cast(); - if (!splatMask) - return mlir::failure(); - - if (splatMask.getSplatValue().getValue() == true) { - // mask = splat(1) - rewriter.replaceOpWithNewOp( - storeOp, storeOp.getPtr(), storeOp.getValue(), storeOp.getCache(), - storeOp.getEvict()); - } else { - // mask = splat(0) - rewriter.eraseOp(storeOp); - } - return mlir::success(); - } -}; - -void triton::StoreOp::getCanonicalizationPatterns(RewritePatternSet &results, - MLIRContext *context) { - results.add(context); -} - #define GEN_PASS_CLASSES #include "triton/Dialect/Triton/Transforms/Passes.h.inc" diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 5c174bb46..89a6e9168 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -167,10 +167,10 @@ public: auto otherTensorType = RankedTensorType::get(tensorShape, elementType); // Set zero padding value - Attribute attr = + TypedAttr attr = elementType.isIntOrIndex() - ? builder.getIntegerAttr(elementType, 0).cast() - : builder.getFloatAttr(elementType, 0).cast(); + ? builder.getIntegerAttr(elementType, 0).cast() + : builder.getFloatAttr(elementType, 0).cast(); // Float NaN padding case if (padding.value() == triton::PaddingOption::PAD_NAN) { diff --git a/lib/Dialect/TritonGPU/IR/CMakeLists.txt b/lib/Dialect/TritonGPU/IR/CMakeLists.txt index 903dfc318..20f6f9851 100644 --- a/lib/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/IR/CMakeLists.txt @@ -7,5 +7,6 @@ add_mlir_dialect_library(TritonGPUIR TritonGPUAttrDefsIncGen LINK_LIBS PUBLIC + MLIRGPUOps TritonIR ) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 4fd648b7f..4d6bf9d17 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -81,10 +81,41 @@ SmallVector getThreadsPerWarp(Attribute layout) { if (mmaLayout.isAmpere()) return {8, 4}; } + if (auto sliceLayout = layout.dyn_cast()) { + auto parent = sliceLayout.getParent(); + auto parentThreadsPerWarp = getThreadsPerWarp(parent); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + for (unsigned i = 0; i < threadsPerWarp.size(); i++) + threadsPerWarp[i] *= parentThreadsPerWarp[sliceLayout.getDim()]; + return threadsPerWarp; + } assert(0 && "getThreadsPerWarp not implemented"); return {}; } +SmallVector +getThreadsPerWarpWithUniqueData(Attribute layout, + ArrayRef tensorShape) { + if (auto sliceLayout = layout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentThreadsPerWarp = + getThreadsPerWarpWithUniqueData(parentLayout, parentShape); + SmallVector threadsPerWarp = parentThreadsPerWarp; + threadsPerWarp.erase(threadsPerWarp.begin() + sliceLayout.getDim()); + return threadsPerWarp; + } + auto threadsPerWarp = getThreadsPerWarp(layout); + 
assert(threadsPerWarp.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < threadsPerWarp.size(); i++) { + threadsPerWarp[i] = std::min(threadsPerWarp[i], tensorShape[i]); + } + + return threadsPerWarp; +} + SmallVector getWarpsPerCTA(Attribute layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getWarpsPerCTA().begin(), @@ -94,19 +125,51 @@ SmallVector getWarpsPerCTA(Attribute layout) { return SmallVector(mmaLayout.getWarpsPerCTA().begin(), mmaLayout.getWarpsPerCTA().end()); } + if (auto sliceLayout = layout.dyn_cast()) { + auto parent = sliceLayout.getParent(); + auto parentWarpsPerCTA = getWarpsPerCTA(parent); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); + for (unsigned i = 0; i < warpsPerCTA.size(); i++) + warpsPerCTA[i] *= parentWarpsPerCTA[sliceLayout.getDim()]; + return warpsPerCTA; + } assert(0 && "getWarpsPerCTA not implemented"); return {}; } +SmallVector +getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef tensorShape) { + if (auto sliceLayout = layout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(tensorShape); + auto parentWarpsPerCTA = + getWarpsPerCTAWithUniqueData(parentLayout, parentShape); + SmallVector warpsPerCTA = parentWarpsPerCTA; + warpsPerCTA.erase(warpsPerCTA.begin() + sliceLayout.getDim()); + return warpsPerCTA; + } + auto warpsPerCTA = getWarpsPerCTA(layout); + assert(warpsPerCTA.size() == tensorShape.size() && + "layout and tensor shape must have the same rank"); + for (unsigned i = 0; i < warpsPerCTA.size(); i++) { + auto sizePerWarp = + getSizePerThread(layout)[i] * getThreadsPerWarp(layout)[i]; + auto maxWarpsPerDim = ceil(tensorShape[i], sizePerWarp); + warpsPerCTA[i] = std::min(warpsPerCTA[i], maxWarpsPerDim); + } + + return warpsPerCTA; +} + SmallVector getSizePerThread(Attribute layout) { if (auto blockedLayout = layout.dyn_cast()) { return SmallVector(blockedLayout.getSizePerThread().begin(), blockedLayout.getSizePerThread().end()); } else if (auto sliceLayout = layout.dyn_cast()) { - auto ret = getSizePerThread(sliceLayout.getParent()); - return ret; - // ret.erase(ret.begin() + sliceLayout.getDim()); - return ret; + auto sizePerThread = getSizePerThread(sliceLayout.getParent()); + sizePerThread.erase(sizePerThread.begin() + sliceLayout.getDim()); + return sizePerThread; } else if (auto mmaLayout = layout.dyn_cast()) { if (mmaLayout.isAmpere()) { return {2, 2}; @@ -146,11 +209,43 @@ SmallVector getContigPerThread(Attribute layout) { if (auto mmaLayout = layout.dyn_cast()) { assert(mmaLayout.isVolta() || mmaLayout.isAmpere()); return {1, 2}; + } else if (auto sliceLayout = layout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + return getContigPerThread(parentLayout); } else { return getSizePerThread(layout); } } +SmallVector getUniqueContigPerThread(Type type) { + if (type.isIntOrIndexOrFloat() || type.isa()) + return SmallVector(1, 1); + auto tensorType = type.cast(); + auto shape = tensorType.getShape(); + // If slice layout, call recursively on parent layout, and drop + // sliced dim + if (auto sliceLayout = + tensorType.getEncoding().dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto parentShape = sliceLayout.paddedShape(shape); + auto parentTy = RankedTensorType::get( + parentShape, tensorType.getElementType(), parentLayout); + auto parentUniqueContigPerThread = 
getUniqueContigPerThread(parentTy); + parentUniqueContigPerThread.erase(parentUniqueContigPerThread.begin() + + sliceLayout.getDim()); + return parentUniqueContigPerThread; + } + // Base case + auto rank = shape.size(); + SmallVector ret(rank); + auto contigPerThread = getContigPerThread(tensorType.getEncoding()); + assert(contigPerThread.size() == rank && "Unexpected contigPerThread size"); + for (int d = 0; d < rank; ++d) { + ret[d] = std::min(shape[d], contigPerThread[d]); + } + return ret; +} + SmallVector getThreadsPerCTA(Attribute layout) { SmallVector threads; if (auto blockedLayout = layout.dyn_cast()) { @@ -158,7 +253,7 @@ SmallVector getThreadsPerCTA(Attribute layout) { threads.push_back(blockedLayout.getThreadsPerWarp()[d] * blockedLayout.getWarpsPerCTA()[d]); } else if (auto mmaLayout = layout.dyn_cast()) { - if (mmaLayout.getVersionMajor() == 2) { + if (mmaLayout.isAmpere()) { threads = {8 * mmaLayout.getWarpsPerCTA()[0], 4 * mmaLayout.getWarpsPerCTA()[1]}; } else @@ -261,6 +356,16 @@ bool isaDistributedLayout(Attribute layout) { } // namespace gpu } // namespace triton + +bool isSharedEncoding(Value value) { + auto type = value.getType(); + if (auto tensorType = type.dyn_cast()) { + auto encoding = tensorType.getEncoding(); + return encoding && encoding.isa(); + } + return false; +} + } // namespace mlir static LogicalResult parseIntAttrValue(AsmParser &parser, Attribute attr, @@ -375,6 +480,7 @@ SliceEncodingAttr::getElemsPerThread(ArrayRef shape, auto parent = getParent(); auto parentElemsPerThread = ::getElemsPerThread(parent, paddedShape(shape), eltTy); + parentElemsPerThread.erase(parentElemsPerThread.begin() + getDim()); return parentElemsPerThread; } unsigned SliceEncodingAttr::getTotalElemsPerThread(ArrayRef shape, @@ -774,14 +880,27 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) { return {}; unsigned opIdx = attrs.get("opIdx").cast().getInt(); Attribute parent = attrs.get("parent"); + auto mmaParent = parent.dyn_cast(); + unsigned kWidth = 0; + Attribute _kWidth = attrs.get("kWidth"); + if (_kWidth) { + if (!mmaParent || mmaParent.isVolta()) { + auto loc = parser.getNameLoc(); + parser.emitError(loc, "kWidth only supported for MMAv2+ parent"); + return Attribute(); + } + kWidth = _kWidth.cast().getInt(); + } return parser.getChecked(parser.getContext(), opIdx, - parent); + parent, kWidth); } void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const { + auto mmaParent = getParent().dyn_cast(); printer << "<{" - << "opIdx = " << getOpIdx() << ", " - << "parent = " << getParent(); + << "opIdx = " << getOpIdx() << ", parent = " << getParent(); + if (mmaParent && mmaParent.isAmpere()) + printer << ", kWidth = " << getMMAv2kWidth(); printer << "}>"; } @@ -1029,9 +1148,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, return mlir::failure(); } auto newType = op->getResult(0).getType().cast(); - // Ensure that the new insert_slice op is placed in the same place as the - // old insert_slice op. Otherwise, the new insert_slice op may be placed - // after the async_wait op, which is not allowed. + // Ensure that the new insert_slice op is placed in the same place as + // the old insert_slice op. Otherwise, the new insert_slice op may be + // placed after the async_wait op, which is not allowed. 
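+ // (the InsertionGuard below restores the rewriter's insertion point once the rewrite is done)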
OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(insert_slice); auto newArg = rewriter.create( @@ -1059,9 +1178,9 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, auto resType = RankedTensorType::get( origResType.getShape(), origResType.getElementType(), extract_slice.getType().cast().getEncoding()); - // Ensure that the new extract_slice op is placed in the same place as the - // old extract_slice op. Otherwise, the new extract_slice op may be placed - // after the async_wait op, which is not allowed. + // Ensure that the new extract_slice op is placed in the same place as + // the old extract_slice op. Otherwise, the new extract_slice op may be + // placed after the async_wait op, which is not allowed. OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(extract_slice); auto newArg = rewriter.create( @@ -1109,8 +1228,8 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op, // cvt(type, constant) -> constant if (auto cst = llvm::dyn_cast(arg)) if (auto ret = cst.getValue().dyn_cast()) { - auto newRet = SplatElementsAttr::get(op->getResultTypes().front(), - ret.getSplatValue()); + auto ty = op->getResultTypes().front().cast(); + auto newRet = SplatElementsAttr::get(ty, ret.getSplatValue()); rewriter.replaceOpWithNewOp(op, newRet); return mlir::success(); } diff --git a/lib/Dialect/TritonGPU/IR/Traits.cpp b/lib/Dialect/TritonGPU/IR/Traits.cpp index 03253e12c..c3d0c859e 100644 --- a/lib/Dialect/TritonGPU/IR/Traits.cpp +++ b/lib/Dialect/TritonGPU/IR/Traits.cpp @@ -1,5 +1,5 @@ #include "triton/Dialect/TritonGPU/IR/Traits.h" -#include "triton/Analysis/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" mlir::LogicalResult mlir::OpTrait::impl::verifyResultsAreSharedEncoding(Operation *op) { diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 120281eb7..651abd157 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -47,12 +47,13 @@ SmallVector mmaVersionToShapePerWarp(int version) { SmallVector warpsPerTileV2(triton::DotOp dotOp, const ArrayRef shape, int numWarps) { - SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); - if (llvm::find_if(slices, [](Operation *op) { - return isa(op); - }) != slices.end()) - return {(unsigned)numWarps, 1}; + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + auto slices = mlir::getSlice(dotOp, filter); + for (Operation *op : slices) + if (isa(op) && (op != dotOp)) + return {(unsigned)numWarps, 1}; SmallVector ret = {1, 1}; SmallVector shapePerWarp = {16, 8}; @@ -173,14 +174,17 @@ public: .cast() .getOrder(); + auto newAEncoding = triton::gpu::DotOperandEncodingAttr::get( + oldAType.getContext(), 0, newRetType.getEncoding(), + oldAType.getElementType()); + auto newBEncoding = triton::gpu::DotOperandEncodingAttr::get( + oldBType.getContext(), 1, newRetType.getEncoding(), + oldBType.getElementType()); + auto newAType = RankedTensorType::get( - oldAType.getShape(), oldAType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0, - newRetType.getEncoding())); + oldAType.getShape(), oldAType.getElementType(), newAEncoding); auto newBType = RankedTensorType::get( - oldBType.getShape(), oldBType.getElementType(), - triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1, - newRetType.getEncoding())); + oldBType.getShape(), 
oldBType.getElementType(), newBEncoding); a = rewriter.create(a.getLoc(), newAType, a); b = rewriter.create(b.getLoc(), newBType, b); diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 37fbfb046..ced082069 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -14,7 +14,9 @@ add_mlir_dialect_library(TritonGPUTransforms TritonGPUTransformsIncGen LINK_LIBS PUBLIC + MLIRTransforms + MLIRTransformUtils + TritonAnalysis TritonIR TritonGPUIR - MLIRTransformUtils ) diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 432d4d5db..f27bfa06f 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -22,16 +22,13 @@ template SmallVector argSort(const T &arr) { typedef DenseMap> LayoutMap; struct CoalescePass : public TritonGPUCoalesceBase { - Attribute getCoalescedEncoding(AxisInfoAnalysis &axisInfo, Value ptr, - int numWarps) { + Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, + Value ptr, int numWarps) { auto origType = ptr.getType().cast(); // Get the shape of the tensor. size_t rank = origType.getRank(); - dataflow::Lattice *latticeElement = - axisInfo.getLatticeElement(ptr); - AxisInfo info = latticeElement ? latticeElement->getValue() : AxisInfo(); // Get the contiguity order of `ptr` - auto order = argSort(info.getContiguity()); + auto order = argSort(axisInfoAnalysis.getAxisInfo(ptr)->getContiguity()); // The desired divisibility is the maximum divisibility // among all dependent pointers who have the same order as // `ptr` @@ -42,8 +39,8 @@ struct CoalescePass : public TritonGPUCoalesceBase { for (Value val : op->getResults()) { if (val.getType() != origType) continue; - auto valInfo = axisInfo.getLatticeElement(val); - auto currOrder = argSort(valInfo->getValue().getContiguity()); + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); if (order == currOrder) withSameOrder.insert(val); } @@ -61,10 +58,11 @@ struct CoalescePass : public TritonGPUCoalesceBase { unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned perThread = 1; for (Value val : withSameOrder) { - AxisInfo info = axisInfo.getLatticeElement(val)->getValue(); - unsigned maxMultipleBytes = info.getDivisibility(order[0]); + unsigned maxMultipleBytes = + axisInfoAnalysis.getAxisInfo(val)->getDivisibility(order[0]); unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); - unsigned maxContig = info.getContiguity(order[0]); + unsigned maxContig = + axisInfoAnalysis.getAxisInfo(val)->getContiguity(order[0]); unsigned alignment = std::min(maxMultiple, maxContig); unsigned currPerThread = std::min(alignment, 128 / elemNumBits); perThread = std::max(perThread, currPerThread); @@ -78,9 +76,10 @@ struct CoalescePass : public TritonGPUCoalesceBase { return encoding; } - std::function getTypeConverter(AxisInfoAnalysis &axisInfo, - Value ptr, int numWarps) { - Attribute encoding = getCoalescedEncoding(axisInfo, ptr, numWarps); + std::function + getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr, + int numWarps) { + Attribute encoding = getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps); return [encoding](Type _type) { RankedTensorType type = _type.cast(); return RankedTensorType::get(type.getShape(), type.getElementType(), @@ -127,17 +126,14 @@ struct CoalescePass : public TritonGPUCoalesceBase { } void 
runOnOperation() override { - Operation *op = getOperation(); // Run axis info analysis - std::unique_ptr solver = createDataFlowSolver(); - AxisInfoAnalysis *axisInfo = solver->load(); - if (failed(solver->initializeAndRun(op))) - return signalPassFailure(); + ModuleOp moduleOp = getOperation(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); // For each i/o operation, we determine what layout // the pointers should have for best memory coalescing LayoutMap layoutMap; - op->walk([&](Operation *curr) { + moduleOp.walk([&](Operation *curr) { Value ptr; if (auto op = dyn_cast(curr)) ptr = op.getPtr(); @@ -154,10 +150,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { RankedTensorType ty = ptr.getType().template dyn_cast(); if (!ty || !ty.getElementType().isa()) return; - AxisInfo info = axisInfo->getLatticeElement(ptr)->getValue(); auto mod = curr->getParentOfType(); int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - auto convertType = getTypeConverter(*axisInfo, ptr, numWarps); + auto convertType = getTypeConverter(axisInfoAnalysis, ptr, numWarps); layoutMap[ptr] = convertType; }); @@ -168,7 +163,7 @@ struct CoalescePass : public TritonGPUCoalesceBase { // produces a tensor with layout L2 // 4. Convert the output of this new memory op back to L1 // 5. Replace all the uses of the original memory op by the new one - op->walk([&](Operation *curr) { + moduleOp.walk([&](Operation *curr) { OpBuilder builder(curr); if (auto load = dyn_cast(curr)) { coalesceOp(layoutMap, curr, load.getPtr(), builder); diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index a30af6d6b..32605e605 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -72,6 +72,76 @@ public: } }; +// + +class MoveOpAfterLayoutConversion : public mlir::RewritePattern { + +public: + MoveOpAfterLayoutConversion(mlir::MLIRContext *context) + : mlir::RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), + 1, context) {} + + mlir::LogicalResult + matchAndRewrite(mlir::Operation *op, + mlir::PatternRewriter &rewriter) const override { + auto cvt = cast(op); + auto srcTy = cvt.getOperand().getType().cast(); + auto retTy = cvt.getResult().getType().dyn_cast(); + auto retEncoding = + retTy.getEncoding().dyn_cast(); + auto srcEncoding = + srcTy.getEncoding().dyn_cast(); + if (!retTy) + return failure(); + if (!retEncoding) + return failure(); + auto retEncodingParent = + retEncoding.getParent().dyn_cast(); + if (!retEncodingParent || retEncodingParent.isVolta()) + return failure(); + if (!srcEncoding) + return failure(); + // don't move things around when cvt operand is a block arg + Operation *argOp = cvt.getOperand().getDefiningOp(); + if (!argOp) + return failure(); + // + SetVector processed; + SetVector layout; + llvm::MapVector toConvert; + int numCvts = simulateBackwardRematerialization(cvt, processed, layout, + toConvert, retEncoding); + if (numCvts > 1 || toConvert.size() == 1) + return failure(); + for (Operation *op : processed) { + if (op->getNumOperands() != 1) + continue; + auto srcTy = op->getOperand(0).getType().cast(); + auto dstTy = op->getResult(0).getType().cast(); + // we don't want to push conversions backward if there is a downcast + // since it would result in more shared memory traffic + if (srcTy.getElementType().getIntOrFloatBitWidth() > + dstTy.getElementType().getIntOrFloatBitWidth()) + return failure(); + // we only push 
back when the first op in the chain has a load operand + if ((op == processed.back()) && + !isa(op->getOperand(0).getDefiningOp())) + return failure(); + // we don't want to use ldmatrix for 8-bit data that requires trans + // since Nvidia GPUs can't do it efficiently + bool isTrans = + (retEncoding.getOpIdx() == 1) ^ (srcEncoding.getOrder()[0] == 0); + bool isInt8 = srcTy.getElementType().getIntOrFloatBitWidth() == 8; + if (isTrans && isInt8) + return failure(); + } + IRMapping mapping; + rematerializeConversionChain(toConvert, rewriter, processed, mapping); + rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0))); + return mlir::success(); + } +}; + } // namespace #define GEN_PASS_CLASSES @@ -93,6 +163,7 @@ public: mlir::RewritePatternSet patterns(context); patterns.add(context); + patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); if (fixupLoops(m).failed()) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index 15e639366..1f3a75352 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -1,3 +1,4 @@ +#include "Utility.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/IRMapping.h" #include "mlir/IR/TypeUtilities.h" @@ -42,6 +43,9 @@ class LoopPipeliner { /// Loads to be pipelined SetVector loads; + /// Smallest data-type for each load (used to optimize swizzle and + /// create DotOpEncoding layout) + DenseMap loadsSmallestType; /// The value that each load will be mapped to (after layout conversion) DenseMap loadsMapping; /// load => buffer @@ -181,24 +185,19 @@ ttg::AllocTensorOp LoopPipeliner::allocateEmptyBuffer(Operation *op, /// this pass?) LogicalResult LoopPipeliner::initialize() { Block *loop = forOp.getBody(); - - std::unique_ptr solver = createDataFlowSolver(); - AxisInfoAnalysis *axisInfoAnalysis = solver->load(); - if (failed(solver->initializeAndRun(forOp->getParentOfType()))) { - return failure(); - } + ModuleOp moduleOp = forOp->getParentOfType(); + ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); // can we use forOp.walk(...) here?
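// (as written, only ops directly in the loop body block are visited below; loads inside nested regions are not considered for pipelining)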
SmallVector validLoads; for (Operation &op : *loop) if (auto loadOp = dyn_cast(&op)) { auto ptr = loadOp.getPtr(); - unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr); + unsigned vec = axisInfoAnalysis.getPtrContiguity(ptr); if (auto mask = loadOp.getMask()) - vec = std::min(vec, axisInfoAnalysis->getMaskAlignment(mask)); + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); - auto lattice = axisInfoAnalysis->getLatticeElement(ptr)->getValue(); auto tensorTy = ptr.getType().dyn_cast(); if (!tensorTy || tensorTy.getRank() < 2) continue; @@ -256,33 +255,62 @@ LogicalResult LoopPipeliner::initialize() { use = *use->getResult(0).getUsers().begin(); } - if (auto convertLayout = llvm::dyn_cast(use)) { - if (auto tensorType = convertLayout.getResult() - .getType() - .dyn_cast()) { - if (auto dotOpEnc = tensorType.getEncoding() - .dyn_cast()) { - isCandidate = true; - loadsMapping[loadOp] = convertLayout; - auto ty = loadOp.getType().cast(); - SmallVector bufferShape(ty.getShape().begin(), - ty.getShape().end()); - bufferShape.insert(bufferShape.begin(), numStages); - auto sharedEnc = ttg::SharedEncodingAttr::get( - ty.getContext(), dotOpEnc, ty.getShape(), - triton::gpu::getOrder(ty.getEncoding()), ty.getElementType()); - loadsBufferType[loadOp] = RankedTensorType::get( - bufferShape, ty.getElementType(), sharedEnc); - } - } - } - } else + auto convertLayout = llvm::dyn_cast(use); + if (!convertLayout) + continue; + auto tensorType = + convertLayout.getResult().getType().dyn_cast(); + if (!tensorType) + continue; + auto dotOpEnc = + tensorType.getEncoding().dyn_cast(); + if (!dotOpEnc) + continue; + isCandidate = true; + loadsMapping[loadOp] = convertLayout; + } + + else isCandidate = false; if (isCandidate) loads.insert(loadOp); } + // we need to find the smallest common dtype + // since this determines the layout of `mma.sync` operands + // in mixed-precision mode + Type smallestType; + for (auto loadCvt : loadsMapping) { + auto loadOp = loadCvt.first; + auto ty = loadOp.getType().cast(); + Type eltTy = ty.getElementType(); + if (!smallestType || + (eltTy.getIntOrFloatBitWidth() < smallestType.getIntOrFloatBitWidth())) + smallestType = eltTy; + } + + for (auto loadCvt : loadsMapping) + loadsSmallestType[loadCvt.first] = smallestType; + + for (auto loadCvt : loadsMapping) { + auto loadOp = loadCvt.first; + Value cvt = loadCvt.second; + auto dotOpEnc = cvt.getType() + .cast() + .getEncoding() + .cast(); + auto ty = loadOp.getType().cast(); + SmallVector bufferShape(ty.getShape().begin(), + ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), numStages); + auto sharedEnc = ttg::SharedEncodingAttr::get( + ty.getContext(), dotOpEnc, ty.getShape(), + triton::gpu::getOrder(ty.getEncoding()), loadsSmallestType[loadOp]); + loadsBufferType[loadOp] = + RankedTensorType::get(bufferShape, ty.getElementType(), sharedEnc); + } + // We have some loads to pipeline if (!loads.empty()) { // Update depArgs & depOps @@ -336,10 +364,6 @@ Value LoopPipeliner::getLoadMask(triton::LoadOp loadOp, Value mappedMask, } void LoopPipeliner::emitPrologue() { - // llvm::errs() << "loads to pipeline...:\n"; - // for (Value load : loads) - // llvm::errs() << load << "\n"; - OpBuilder builder(forOp); for (BlockArgument &arg : forOp.getRegionIterArgs()) { OpOperand &operand = forOp.getOpOperandForRegionIterArg(arg); @@ -364,7 +388,7 @@ void LoopPipeliner::emitPrologue() { for (Operation &op : forOp.getLoopBody().front()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); - else if
(loads.contains(op.getResult(0))) + else if (op.getNumResults() > 0 && loads.contains(op.getResult(0))) orderedDeps.push_back(&op); } assert(depOps.size() + loads.size() == orderedDeps.size() && @@ -541,44 +565,44 @@ scf::ForOp LoopPipeliner::createNewForOp() { mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); - // 2.1 clone the loop body, replace original args with args of the new ForOp + // 2. clone the loop body, replace original args with args of the new ForOp // Insert async wait if necessary. + DenseSet isModified; for (Operation &op : forOp.getBody()->without_terminator()) { - Operation *newOp = builder.clone(op, mapping); - // update mapping of results - for (unsigned dstIdx : llvm::seq(unsigned(0), op.getNumResults())) - mapping.map(op.getResult(dstIdx), newOp->getResult(dstIdx)); - } + // does this op consume one of the pipelined loads? + auto it = std::find(loads.begin(), loads.end(), op.getOperand(0)); + if (it == loads.end()) { + Operation *newOp = cloneWithInferType(builder, &op, mapping); + continue; + } - // 3. replace loads with block args (from prologue) - for (size_t idx = 0; idx < loads.size(); ++idx) { - OpBuilder::InsertionGuard guard(builder); - Value load = loads[idx]; - assert(load.hasOneUse() && - "we assume that this load has one use (ConvertLayout)"); - Value loadUse = load.getUsers().begin()->getResult(0); - // set insertion point - Value newLoad = mapping.lookup(load); - Value newLoadUse = mapping.lookup(loadUse); - builder.setInsertionPoint(newLoadUse.getDefiningOp()); - // create conversion + // we replace the new load use with a convert layout + size_t i = std::distance(loads.begin(), it); + auto cvtDstTy = op.getResult(0).getType().cast(); + auto cvtDstEnc = + cvtDstTy.getEncoding().dyn_cast(); + if (!cvtDstEnc) { + builder.clone(op, mapping); + continue; + } + auto newDstTy = RankedTensorType::get( + cvtDstTy.getShape(), cvtDstTy.getElementType(), + ttg::DotOperandEncodingAttr::get( + cvtDstEnc.getContext(), cvtDstEnc.getOpIdx(), cvtDstEnc.getParent(), + loadsSmallestType[op.getOperand(0)])); auto cvt = builder.create( - loadUse.getLoc(), loadUse.getType(), - newForOp.getRegionIterArgs()[loadIdx + idx]); - - // replace uses - newLoadUse.replaceAllUsesWith(cvt.getResult()); - // delete old load and layout conversion - newLoadUse.getDefiningOp()->erase(); - newLoad.getDefiningOp()->erase(); + op.getResult(0).getLoc(), newDstTy, + newForOp.getRegionIterArgs()[loadIdx + i]); + mapping.map(op.getResult(0), cvt.getResult()); + isModified.insert(op.getResult(0)); } - // 4. prefetch the next iteration + // 3.
prefetch the next iteration SmallVector orderedDeps; for (Operation &op : forOp.getLoopBody().front()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); - else if (loads.contains(op.getResult(0))) + else if (op.getNumResults() > 0 && loads.contains(op.getResult(0))) orderedDeps.push_back(&op); } assert(depOps.size() + loads.size() == orderedDeps.size() && diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 4296ef58e..ad1c64b42 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -27,7 +27,6 @@ //===----------------------------------------------------------------------===// #include "mlir/IR/IRMapping.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" @@ -45,7 +44,7 @@ class Prefetcher { scf::YieldOp yieldOp; /// // TODO: add a hook to infer prefetchWidth - unsigned prefetchWidth = 16; + unsigned prefetchWidth = 32; /// dots to be prefetched SetVector dots; @@ -56,6 +55,8 @@ class Prefetcher { DenseMap dot2bHeaderDef; DenseMap dot2aYield; DenseMap dot2bYield; + DenseMap> dot2aVals; + DenseMap> dot2bVals; /// operand => defining DenseMap operand2headPrefetch; @@ -66,6 +67,9 @@ class Prefetcher { std::optional offsetK = std::nullopt, std::optional shapeK = std::nullopt); + void cloneElementwiseOps(Value &bRem, const SmallVector &vals, + OpBuilder &builder); + public: Prefetcher() = delete; @@ -80,6 +84,24 @@ public: scf::ForOp createNewForOp(); }; +void Prefetcher::cloneElementwiseOps(Value &ret, const SmallVector &vals, + OpBuilder &builder) { + IRMapping mapping; + mapping.map(vals[0], ret); + for (int i = 1; i < vals.size(); i++) { + Value v = vals[i]; + Value curr = builder.clone(*v.getDefiningOp(), mapping)->getResult(0); + auto retType = RankedTensorType::get( + ret.getType().cast().getShape(), + curr.getType().cast().getElementType(), + curr.getType().cast().getEncoding()); + curr.setType(retType); + mapping.map(v, curr); + } + if (vals.size() > 1) + ret = mapping.lookup(vals.back()); +} + Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, Attribute dotEncoding, OpBuilder &builder, std::optional offsetK, std::optional shapeK) { @@ -110,7 +132,7 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue, SmallVector{intAttr(1), intAttr(1)}); auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get( - builder.getContext(), opIdx, dotEncoding); + builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8); Value prefetchSlice = builder.create( v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), newSmem); @@ -135,11 +157,32 @@ LogicalResult Prefetcher::initialize() { return failure(); // returns source of cvt - auto getPrefetchSrc = [](Value v) -> Value { - if (auto cvt = v.getDefiningOp()) - if (isSharedEncoding(cvt.getOperand())) - return cvt.getSrc(); - return Value(); + + auto getPrefetchSrc = [](Value v) -> SmallVector { + // walk back to conversion + Operation *op = v.getDefiningOp(); + bool foundConvertFromShared = false; + SmallVector rets; + rets.push_back(op->getResult(0)); + while (op) { + if (op->getNumOperands() != 1) + break; + if (!op->getResult(0).hasOneUse()) + break; + rets.push_back(op->getOperand(0)); + if (auto cvt = dyn_cast_or_null(op)) + if (isSharedEncoding(cvt.getOperand())) { + foundConvertFromShared = true; + break; + } + op = op->getOperand(0).getDefiningOp(); + } +
std::reverse(rets.begin(), rets.end()); + + if (foundConvertFromShared) + return rets; + return {}; }; auto getIncomingOp = [this](Value v) -> Value { @@ -156,24 +199,39 @@ LogicalResult Prefetcher::initialize() { }; for (triton::DotOp dot : dotsInFor) { - auto kSize = dot.getA().getType().cast().getShape()[1]; + auto aType = dot.getA().getType().cast(); + auto bType = dot.getB().getType().cast(); + auto aEnc = aType.getEncoding().cast(); + auto bEnc = bType.getEncoding().cast(); + int aKWidth = aEnc.getMMAv2kWidth(); + int bKWidth = bEnc.getMMAv2kWidth(); + assert(aKWidth == bKWidth); + + auto kSize = aType.getShape()[1]; // works better with nvidia tensor cores - unsigned elementWidth = - dot.getA().getType().cast().getElementTypeBitWidth(); - prefetchWidth = 256 / elementWidth; + unsigned elementWidth = aType.getElementTypeBitWidth(); + if (aKWidth == 0) + prefetchWidth = 256 / elementWidth; + else + prefetchWidth = 8 * aKWidth; // Skip prefetching if kSize is less than prefetchWidth if (kSize < prefetchWidth) continue; - Value aSmem = getPrefetchSrc(dot.getA()); - Value bSmem = getPrefetchSrc(dot.getB()); - if (aSmem && bSmem) { + auto aVals = getPrefetchSrc(dot.getA()); + auto bVals = getPrefetchSrc(dot.getB()); + + if (aVals.size() && bVals.size()) { + Value aSmem = aVals.front(); + Value bSmem = bVals.front(); Value aHeaderDef = getIncomingOp(aSmem); Value bHeaderDef = getIncomingOp(bSmem); // Only prefetch loop arg if (aHeaderDef && bHeaderDef) { dots.insert(dot); + dot2aVals[dot] = aVals; + dot2bVals[dot] = bVals; dot2aHeaderDef[dot] = aHeaderDef; dot2bHeaderDef[dot] = bHeaderDef; dot2aLoopArg[dot] = aSmem; @@ -195,10 +253,13 @@ void Prefetcher::emitPrologue() { dot.getType().cast().getEncoding(); Value aPrefetched = generatePrefetch(dot2aHeaderDef[dot], 0, true, dotEncoding, builder); - operand2headPrefetch[dot.getDefiningOp().getA()] = - aPrefetched; + cloneElementwiseOps(aPrefetched, dot2aVals[dot], builder); Value bPrefetched = generatePrefetch(dot2bHeaderDef[dot], 1, true, dotEncoding, builder); + cloneElementwiseOps(bPrefetched, dot2bVals[dot], builder); + + operand2headPrefetch[dot.getDefiningOp().getA()] = + aPrefetched; operand2headPrefetch[dot.getDefiningOp().getB()] = bPrefetched; } @@ -256,9 +317,11 @@ scf::ForOp Prefetcher::createNewForOp() { Value aRem = generatePrefetch(mapping.lookup(dot2aLoopArg[dot]), 0, false, dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(aRem, dot2aVals[dot], builder); Value bRem = generatePrefetch(mapping.lookup(dot2bLoopArg[dot]), 1, false, dotEncoding, builder, kOff, kShape); + cloneElementwiseOps(bRem, dot2bVals[dot], builder); builder.restoreInsertionPoint(insertionPoint); newOp = builder.clone(*dot, mapping); newOp->setOperand(0, aRem); @@ -281,10 +344,15 @@ scf::ForOp Prefetcher::createNewForOp() { for (Value dot : dots) { Attribute dotEncoding = dot.getType().cast().getEncoding(); - yieldValues.push_back(generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, - true, dotEncoding, builder)); - yieldValues.push_back(generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, - true, dotEncoding, builder)); + Value aToYield = generatePrefetch(mapping.lookup(dot2aYield[dot]), 0, true, + dotEncoding, builder); + cloneElementwiseOps(aToYield, dot2aVals[dot], builder); + yieldValues.push_back(aToYield); + // bToYield + Value bToYield = generatePrefetch(mapping.lookup(dot2bYield[dot]), 1, true, + dotEncoding, builder); + cloneElementwiseOps(bToYield, dot2bVals[dot], builder); + yieldValues.push_back(bToYield); } // Update ops of yield if 
(!yieldValues.empty()) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index f6cc7c2ea..58a747d7b 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -13,7 +13,6 @@ #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/Transforms/Passes.h" #include "mlir/Transforms/RegionUtils.h" -#include "triton/Analysis/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" @@ -341,7 +340,10 @@ public: cvt.getOperand().getType().cast().getEncoding(); auto dstEncoding = cvt.getResult().getType().cast().getEncoding(); - // XXX: why is this needed? + if (srcEncoding.isa() || + dstEncoding.isa()) + return failure(); + // heuristics for flash attention if (srcEncoding.isa()) return failure(); SetVector cvtSlices; @@ -365,7 +367,7 @@ public: // don't rematerialize non-element-wise if (!op->hasTrait() && !op->hasTrait() && - !isa(op)) { + !isa(op) && !isa(op)) { return failure(); } // don't rematerialize if it adds an extra conversion that can't @@ -375,9 +377,10 @@ public: SetVector processed; SetVector layout; llvm::MapVector toConvert; - if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 && - simulateBackwardRematerialization(argOp, processed, layout, - toConvert, srcEncoding) > 0) { + int numAddedConvs = simulateBackwardRematerialization( + argOp, processed, layout, toConvert, srcEncoding); + if (argOp && !isa(argOp) && + cvtSlices.count(argOp) == 0 && numAddedConvs > 0) { return failure(); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 327d955b5..27d73e73d 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -89,11 +89,11 @@ LogicalResult invertEncoding(Attribute targetEncoding, Operation *op, } bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) { - // Case 1a: A size 1 tensor is not expensive since all threads will load the + // Case 1: A size 1 tensor is not expensive since all threads will load the // same if (isSingleValue(op->getOperand(0))) return false; - // Case 1b: Tensor of pointers has more threads than elements + // Case 2: Tensor of pointers has more threads than elements // we can presume a high hit-rate that makes it cheap to load auto ptrType = op->getOperand(0).getType().cast(); IntegerAttr numWarps = @@ -104,28 +104,6 @@ bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) { if (ptrType.getNumElements() < numWarps.getInt() * 32) return false; } - // auto ptr = op->getOperand(0); - //// Case 2: We assume that `evict_last` loads/stores have high hit rate - // if (auto load = dyn_cast(op)) - // if (load.getEvict() == triton::EvictionPolicy::EVICT_LAST) - // return false; - // if (auto store = dyn_cast(op)) - // if (store.getEvict() == triton::EvictionPolicy::EVICT_LAST) - // return false; - // if (auto tensorTy = ptr.getType().dyn_cast()) { - // auto encoding = tensorTy.getEncoding(); - // // Case 3: Different type conversion is expensive (e.g., mma <-> - // block) if (encoding.getTypeID() != targetEncoding.getTypeID()) - // return true; - // auto sizePerThread = triton::gpu::getSizePerThread(encoding); - // auto targetSizePerThread = - // triton::gpu::getSizePerThread(targetEncoding); auto order = - // 
triton::gpu::getOrder(encoding); auto targetOrder = - // triton::gpu::getOrder(targetEncoding); - // // Case 4: The targeEncoding may expose more vectorization - // opportunities return sizePerThread[order[0]] >= - // targetSizePerThread[targetOrder[0]]; - // } return true; } @@ -144,6 +122,12 @@ bool expensiveToRemat(Operation *op, Attribute &targetEncoding) { return false; } +bool canFoldConversion(Operation *op) { + return isa(*op); +} + int simulateBackwardRematerialization( Operation *initOp, SetVector &processed, SetVector &layout, llvm::MapVector &toConvert, @@ -189,10 +173,7 @@ int simulateBackwardRematerialization( continue; // If the conversion can be folded into opArgI then // we don't count this conversion as expensive - if (isa(*opArgI)) - continue; - if (isa(opArgI)) + if (canFoldConversion(opArgI)) continue; // We add one expensive conversion for the current operand @@ -206,11 +187,24 @@ int simulateBackwardRematerialization( // -Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op, +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping) { Operation *newOp = rewriter.clone(*op, mapping); - auto origType = op->getResult(0).getType().cast(); - auto argType = newOp->getOperand(0).getType().cast(); + // if input types haven't changed, we're done + bool preserveTypes = + std::all_of(op->operand_begin(), op->operand_end(), [&](Value v) { + return !mapping.contains(v) || + v.getType() == mapping.lookup(v).getType(); + }); + if (preserveTypes) + return newOp; + + if (newOp->getNumResults() == 0) + return newOp; + auto origType = op->getResult(0).getType().dyn_cast(); + auto argType = newOp->getOperand(0).getType().dyn_cast(); + if (!origType || !argType) + return newOp; auto newType = RankedTensorType::get( origType.getShape(), origType.getElementType(), argType.getEncoding()); newOp->getResult(0).setType(newType); @@ -219,9 +213,12 @@ Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op, SmallVector newTypes; auto success = typeInfer.inferReturnTypes( newOp->getContext(), newOp->getLoc(), newOp->getOperands(), - newOp->getAttrDictionary(), newOp->getRegions(), newTypes); - if (succeeded(success)) - newOp->getResult(0).setType(newTypes.front()); + newOp->getAttrDictionary(), newOp->getPropertiesStorage(), + newOp->getRegions(), newTypes); + if (succeeded(success)) { + for (size_t i = 0; i < newTypes.size(); i++) + newOp->getResult(i).setType(newTypes[i]); + } } return newOp; } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.h b/lib/Dialect/TritonGPU/Transforms/Utility.h index cf8cc4093..607b974c8 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.h +++ b/lib/Dialect/TritonGPU/Transforms/Utility.h @@ -21,7 +21,7 @@ int simulateBackwardRematerialization( SetVector &layout, llvm::MapVector &toConvert, Attribute targetEncoding); -Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op, +Operation *cloneWithInferType(mlir::OpBuilder &rewriter, Operation *op, IRMapping &mapping); void rematerializeConversionChain( diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index ac8973ad1..1430a6af2 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -5,8 +5,17 @@ add_mlir_translation_library(TritonLLVMIR Core LINK_LIBS PUBLIC + MLIRArithToLLVM + MLIRBuiltinToLLVMIRTranslation + MLIRExecutionEngineUtils + MLIRIndexToLLVM MLIRIR MLIRLLVMDialect + MLIRLLVMToLLVMIRTranslation + MLIRNVVMToLLVMIRTranslation + 
MLIRROCDLToLLVMIRTranslation + MLIRSCFToControlFlow MLIRSupport MLIRTargetLLVMIRExport + TritonGPUToLLVM ) diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 11f4a65e2..bdd7b6483 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -25,12 +25,22 @@ #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" #include "llvm/Support/SourceMgr.h" +#ifdef _WIN32 +#define WIN32_LEAN_AND_MEAN +#include <windows.h> +#else #include <dlfcn.h> +#endif #include #include +namespace fs = std::filesystem; + namespace mlir { namespace triton { @@ -117,6 +127,32 @@ extractNVVMMetadata(mlir::ModuleOp module, } } +static std::filesystem::path getThisLibraryPath() { +#ifdef _WIN32 + /* Get module of the specified address */ + HMODULE hModule; + GetModuleHandleExA(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | + GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast<LPCSTR>(&getThisLibraryPath), &hModule); + if (NULL == hModule) { + return std::filesystem::path(); + } + + char fileName[1024]; // this is way beyond Windows MAX_PATH limit. + DWORD dwSize = GetModuleFileNameA(hModule, fileName, sizeof(fileName)); + if (0 == dwSize || sizeof(fileName) == dwSize) { + return std::filesystem::path(); + } + return std::filesystem::path(fileName); +#else + Dl_info fileinfo; + if (dladdr(reinterpret_cast<void *>(&getThisLibraryPath), &fileinfo) == 0) { + return std::filesystem::path(); + } + return std::filesystem::path(fileinfo.dli_fname); +#endif +} + static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) { std::map<std::string, std::string> externLibs; SmallVector funcs; @@ -156,17 +192,10 @@ static std::map<std::string, std::string> getExternLibs(mlir::ModuleOp module) { externLibs.try_emplace(libdevice, env_path); return externLibs; } - namespace fs = std::filesystem; // Search for libdevice relative to its library path if used from Python // Then native code is in `triton/_C/libtriton.so` and libdevice in // `triton/third_party/cuda/lib/libdevice.10.bc` - static const auto this_library_path = [] { - Dl_info fileinfo; - if (dladdr(reinterpret_cast<void *>(&getExternLibs), &fileinfo) == 0) { - return std::filesystem::path(); - } - return std::filesystem::path(fileinfo.dli_fname); - }(); + static const auto this_library_path = getThisLibraryPath(); static const auto runtime_path = this_library_path.parent_path().parent_path() / "third_party" / "cuda" / "lib" / "libdevice.10.bc"; diff --git a/python/setup.py b/python/setup.py index b953045b4..68c6cfee9 100644 --- a/python/setup.py +++ b/python/setup.py @@ -68,7 +68,7 @@ def get_llvm_package_info(): use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False") release_suffix = "assert" if use_assert_enabled_llvm else "release" name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}' - version = "llvm-17.0.0-f733b4fb9b8b" + version = "llvm-17.0.0-c5dede880d17" url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz" return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH") @@ -249,6 +249,7 @@ setup( "triton/_C", "triton/common", "triton/compiler", + "triton/debugger", "triton/language", "triton/language/extra", "triton/ops", @@ -294,7 +295,6 @@ setup( "numpy", "pytest", "scipy>=1.7.1", - "torch", ], "tutorials": [ "matplotlib", diff --git a/python/src/triton.cc b/python/src/triton.cc index ed41f8267..04b6598a9 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -264,6 +264,11 @@ void
init_triton_ir(py::module &&m) { return !self.empty() && self.back().hasTrait(); }) + .def("has_return", + [](mlir::Block &self) { + return !self.empty() && + self.back().hasTrait(); + }) .def("erase", [](mlir::Block &self) { self.erase(); }); // using eattr = ir::attribute_kind_t; @@ -430,6 +435,25 @@ void init_triton_ir(py::module &&m) { self.setArgAttr(arg_no, name, mlir::IntegerAttr::get(attrTy, val)); }, ret::reference) + .def("finalize", + [](mlir::triton::FuncOp &self) -> void { + // Remove dead code + // 1. Unreachable code after return + self.walk([&](mlir::Block *block) { + mlir::Operation *retOp = nullptr; + block->walk([&](mlir::Operation *op) { + if (mlir::isa(op)) + if (retOp == nullptr) + retOp = op; + }); + if (retOp && retOp != &block->back()) { + auto pos = retOp->getIterator(); + pos++; + auto *newBlock = block->splitBlock(pos); + newBlock->erase(); + } + }); + }) .def_property_readonly("type", &mlir::triton::FuncOp::getFunctionType) .def("reset_type", &mlir::triton::FuncOp::setType); @@ -454,7 +478,8 @@ void init_triton_ir(py::module &&m) { [](mlir::OpBuilder &self, mlir::triton::FuncOp &func, std::vector &args) -> mlir::OpState { auto loc = self.getUnknownLoc(); - return self.create(loc, func, args); + auto callOp = self.create(loc, func, args); + return callOp; }) // insertion block/point .def("set_insertion_point_to_start", @@ -645,14 +670,16 @@ void init_triton_ir(py::module &&m) { .def("get_or_insert_function", [](mlir::OpBuilder &self, mlir::ModuleOp &module, std::string &funcName, mlir::Type &funcType, - std::string &visibility) -> mlir::triton::FuncOp { + std::string &visibility, bool noinline) -> mlir::triton::FuncOp { if (mlir::Operation *funcOperation = module.lookupSymbol(funcName)) return llvm::dyn_cast(funcOperation); auto loc = self.getUnknownLoc(); if (auto funcTy = funcType.dyn_cast()) { llvm::SmallVector attrs = { mlir::NamedAttribute(self.getStringAttr("sym_visibility"), - self.getStringAttr(visibility))}; + self.getStringAttr(visibility)), + mlir::NamedAttribute(self.getStringAttr("noinline"), + self.getBoolAttr(noinline))}; return self.create(loc, funcName, funcTy, attrs); } @@ -1699,54 +1726,59 @@ void init_triton_translation(py::module &m) { m.def("compile_ptx_to_cubin", [](const std::string &ptxCode, const std::string &ptxasPath, int capability) -> py::object { - py::gil_scoped_release allow_threads; + std::string cubin; + { + py::gil_scoped_release allow_threads; - // compile ptx with ptxas - llvm::SmallString<64> fsrc; - llvm::SmallString<64> flog; - llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc); - llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog); - std::string fbin = std::string(fsrc) + ".o"; - llvm::FileRemover logRemover(flog); - llvm::FileRemover binRemover(fbin); - const char *_fsrc = fsrc.c_str(); - const char *_flog = flog.c_str(); - const char *_fbin = fbin.c_str(); - std::ofstream ofs(_fsrc); - ofs << ptxCode << std::endl; - ofs.close(); - std::string cmd; - int err; - cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) + - (capability == 90 ? 
"a " : " ") + _fsrc + " -o " + _fsrc + - ".o 2> " + _flog; + // compile ptx with ptxas + llvm::SmallString<64> fsrc; + llvm::SmallString<64> flog; + llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc); + llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog); + std::string fbin = std::string(fsrc) + ".o"; + llvm::FileRemover logRemover(flog); + llvm::FileRemover binRemover(fbin); + const char *_fsrc = fsrc.c_str(); + const char *_flog = flog.c_str(); + const char *_fbin = fbin.c_str(); + std::ofstream ofs(_fsrc); + ofs << ptxCode << std::endl; + ofs.close(); + std::string cmd; + int err; + cmd = ptxasPath + " -v --gpu-name=sm_" + + std::to_string(capability) + (capability == 90 ? "a " : " ") + + _fsrc + " -o " + _fsrc + ".o 2> " + _flog; - err = system(cmd.c_str()); - if (err != 0) { - err >>= 8; - std::ifstream _log(_flog); - std::string log(std::istreambuf_iterator(_log), {}); - if (err == 255) { - throw std::runtime_error("Internal Triton PTX codegen error: \n" + - log); - } else if (err == 128 + SIGSEGV) { - throw std::runtime_error("Please run `ptxas " + fsrc.str().str() + - "` to confirm that this is a " - "bug in `ptxas`\n" + - log); + err = system(cmd.c_str()); + if (err != 0) { + err >>= 8; + std::ifstream _log(_flog); + std::string log(std::istreambuf_iterator(_log), {}); + if (err == 255) { + throw std::runtime_error( + "Internal Triton PTX codegen error: \n" + log); + } else if (err == 128 + SIGSEGV) { + throw std::runtime_error("Please run `ptxas " + + fsrc.str().str() + + "` to confirm that this is a " + "bug in `ptxas`\n" + + log); + } else { + throw std::runtime_error("`ptxas` failed with error code " + + std::to_string(err) + ": \n" + log); + } + return {}; } else { - throw std::runtime_error("`ptxas` failed with error code " + - std::to_string(err) + ": \n" + log); + llvm::FileRemover srcRemover(fsrc); + std::ifstream _cubin(_fbin, std::ios::binary); + cubin = std::string(std::istreambuf_iterator(_cubin), {}); + _cubin.close(); + // Do not return here, exit the gil scope and return below } - return {}; - } else { - llvm::FileRemover srcRemover(fsrc); - std::ifstream _cubin(_fbin, std::ios::binary); - std::string cubin(std::istreambuf_iterator(_cubin), {}); - _cubin.close(); - py::bytes bytes(cubin); - return std::move(bytes); } + py::bytes bytes(cubin); + return std::move(bytes); }); m.def("add_external_libs", diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py new file mode 100644 index 000000000..099399d96 --- /dev/null +++ b/python/test/regression/test_functional_regressions.py @@ -0,0 +1,68 @@ +import torch + +import triton +import triton.language as tl + + +def test_chained_matmul(): + # Regression test for issue #1601 + def chained_matmul_reference(a, b, c): + intermediate = torch.einsum('MK,NK->MN', a, b) + return torch.einsum('MN,NK->MK', intermediate, c) + + @triton.jit + def chained_matmul_kernel( + A, # shape: (m, k) + B, # shape: (n, k) + C, # shape: (n, k) + out, # shape: (m, k) + m, n, k: tl.constexpr, + block_m: tl.constexpr, + block_n: tl.constexpr, + block_k: tl.constexpr): + + tl.static_assert(block_k == k, + f"expected block_k == k but got {block_k} != {k}") + + block_ix = tl.program_id(0) + a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \ + + tl.arange(0, block_k)[None, :] + + a = tl.load(A + a_tile, mask=a_tile < m * k, other=0.0) + + acc = tl.zeros([block_m, block_k], dtype=tl.float32) + + for loop_block_start in range(0, n, block_n): 
+ bc_tile = (loop_block_start + tl.arange(0, block_n))[:, None] * block_k \ + + tl.arange(0, block_k)[None, :] + b = tl.load(B + bc_tile, mask=bc_tile < n * k, other=0.0) + + intermediate = tl.dot(a, tl.trans(b)) + intermediate_mask = ((loop_block_start + tl.arange(0, block_n)) < n)[None, :] \ + * (tl.arange(0, block_m) < m)[:, None] + + intermediate = tl.where(intermediate_mask, intermediate, 0.0) + + c = tl.load(C + bc_tile, mask=bc_tile < n * k) + + acc += tl.dot(intermediate.to(A.dtype.element_ty), c) + + tl.store(out + a_tile, acc.to(A.dtype.element_ty), mask=a_tile < m * k) + + m, n, k = 32, 64, 128 + block_m, block_n, block_k = 16, 32, k + + grid = (triton.cdiv(m, block_m),) + a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, + device='cuda') + b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, + device='cuda') + c = torch.randint_like(b, low=0, high=2) + triton_result = torch.zeros_like(a) + + torch_result = chained_matmul_reference(a, b, c) + chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k, + block_m=block_m, block_n=block_n, + block_k=block_k) + + assert (torch_result == triton_result).all() diff --git a/python/test/unit/debugger/test_debugger.py b/python/test/unit/debugger/test_debugger.py new file mode 100644 index 000000000..741fcab3b --- /dev/null +++ b/python/test/unit/debugger/test_debugger.py @@ -0,0 +1,69 @@ +import random + +import torch + +import triton +import triton.language as tl +from triton.debugger.debugger import program_ids_from_grid + + +def test_addition(): + + @triton.jit(interpret=True) + def add_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(axis=0) + block_start = pid * BLOCK_SIZE + offsets = block_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements + x = tl.load(x_ptr + offsets, mask=mask) + y = tl.load(y_ptr + offsets, mask=mask) + output = x + y + tl.store(output_ptr + offsets, output, mask=mask) + + a = torch.rand((128,), device="cuda") + b = torch.rand((128,), device="cuda") + expected = a + b + output = torch.empty((128,), device="cuda") + + def grid(meta): + return (triton.cdiv(128, meta["BLOCK_SIZE"]),) + + add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32) + + assert torch.allclose(expected, output, atol=1e-2, rtol=0) + + +def test_program_ids_from_grid(): + random.seed(123) + grid = (3, 4) + expected_combinations = 3 * 4 + unique_combinations = set(program_ids_from_grid(grid)) + assert len(unique_combinations) == expected_combinations + + first_run = list(program_ids_from_grid(grid)) + second_run = list(program_ids_from_grid(grid)) + assert first_run != second_run + + +def test_atomic(): + @triton.jit(interpret=True) + def atomic( + x_ptr, + ): + pid = tl.program_id(axis=0) + tl.atomic_add(x_ptr + pid, 1) + t = tl.atomic_xchg(x_ptr + pid, 3) + t += 1 # 2 + tl.atomic_cas(x_ptr + pid, 3, t) # match + tl.atomic_cas(x_ptr + pid, 40, 9) # no match + nb_dim = 16 + a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda") + + atomic[(nb_dim, )](a) + assert torch.allclose(a, torch.full_like(a, 2)) diff --git a/python/test/unit/language/assert_helper.py b/python/test/unit/language/assert_helper.py index ddda846a4..b24a5d500 100644 --- a/python/test/unit/language/assert_helper.py +++ b/python/test/unit/language/assert_helper.py @@ -14,6 +14,14 @@ def kernel_device_assert(X, Y, BLOCK: tl.constexpr): tl.store(Y + tl.arange(0, BLOCK), x) +@triton.jit +def kernel_device_assert_scalar(X, Y, BLOCK: tl.constexpr): + x = tl.load(X + 
tl.arange(0, BLOCK)) + # Trivial assert + tl.device_assert(0 == 0, "x != 0") + tl.store(Y + tl.arange(0, BLOCK), x) + + @triton.jit def kernel_assert(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) @@ -33,7 +41,12 @@ def test_assert(func: str): x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda') y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_assert": - kernel_device_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0]) + kernel_device_assert[(1,)](x, y, BLOCK=shape[0]) + kernel_device_assert_scalar[(1,)](x, y, BLOCK=shape[0]) elif func == "assert": kernel_assert[(1,)](x, y, num_warps=2, BLOCK=shape[0]) elif func == "static_assert": diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 747f23641..293acd494 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -457,6 +457,86 @@ def test_broadcast(dtype): assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() +# ---------------- +# test expand_dims +# ---------------- +def test_expand_dims(): + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + N = 32 + dummy_tensor = torch.empty((), device="cuda") + expand_dims_kernel[(1,)](dummy_tensor, N) + + +def test_expand_dims_error_cases(): + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device="cuda") + + with pytest.raises(triton.CompilationError, match="invalid axis -3"): + dim_out_of_range1[(1,)](dummy_tensor, N) + + with pytest.raises(triton.CompilationError, match="invalid axis 2"): + dim_out_of_range2[(1,)](dummy_tensor, N) + + with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): + duplicate_dim1[(1,)](dummy_tensor, N) + + with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): + duplicate_dim2[(1,)](dummy_tensor, N) + + # --------------- # test where # --------------- @@ -688,7 +768,7 @@ def test_index1d(expr, dtype_str, device='cuda'): @triton.jit -def fn(a, b): +def tuples_fn(a, b): return a + b, \ a - b, \ a * b @@ -701,7 +781,7 @@ def test_tuples(): def with_fn(X, Y, A, B, C): x = tl.load(X) y = tl.load(Y) - a, b, c = fn(x, y) + a, b, c = tuples_fn(x, y) tl.store(A, a) tl.store(B,
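tl.expand_dims follows numpy.expand_dims semantics, accepting negative axes and tuples of axes, as the tests above pin down; a short hypothetical usage sketch (kernel name illustrative):

import triton
import triton.language as tl


@triton.jit
def expand_dims_demo(N: tl.constexpr):
    x = tl.arange(0, N)                # shape [N]
    col = tl.expand_dims(x, 1)         # shape [N, 1], equivalent to x[:, None]
    row = tl.expand_dims(x, -2)        # shape [1, N], equivalent to x[None, :]
    box = tl.expand_dims(x, (0, -1))   # shape [1, N, 1]
    tl.static_assert(col.shape == [N, 1])
    tl.static_assert(row.shape == [1, N])
    tl.static_assert(box.shape == [1, N, 1])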
b) tl.store(C, c) @@ -728,6 +808,92 @@ def test_tuples(): assert c_tri == c_ref +@triton.jit(noinline=True) +def noinline_simple_fn(x, y, Z): + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_graph_fn1(x): + return x + 1 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn2(y): + return y + 2 + + +@triton.jit(noinline=True) +def noinline_call_graph_fn(x, y, Z): + t0 = noinline_call_graph_fn1(x) + t1 = noinline_call_graph_fn2(y) + z = t0 + t1 + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_shared_fn(x, y, Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + x + y + tl.store(Z + offs, z) + + +@triton.jit(noinline=True) +def noinline_dynamic_fn(x, y, Z): + if x >= 1: + x = noinline_call_graph_fn1(x) + else: + x = noinline_call_graph_fn2(x) + if y >= 2: + y = noinline_call_graph_fn2(y) + else: + y = noinline_call_graph_fn1(y) + z = x + y + tl.store(Z, z) + + +@triton.jit(noinline=True) +def noinline_call_multi_values_fn(x, y): + return x + 1, y + 2 + + +@triton.jit(noinline=True) +def noinline_multi_values_fn(x, y, Z): + x, y = noinline_call_multi_values_fn(x, y) + z = x + y + tl.store(Z, z) + + +@pytest.mark.parametrize("mode", ["simple", "call_graph", "shared", "dynamic", "multi_values"]) +def test_noinline(mode): + device = 'cuda' + + @triton.jit + def kernel(X, Y, Z): + x = tl.load(X) + y = tl.load(Y) + GENERATE_TEST_HERE(x, y, Z) + + func_name = f'noinline_{mode}_fn' + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': func_name}) + x = torch.tensor([1.0], device=device, dtype=torch.float32) + y = torch.tensor([2.0], device=device, dtype=torch.float32) + if mode == "shared": + z = torch.ones((16, 16), device=device, dtype=torch.float32) + else: + z = torch.tensor([0.0], device=device, dtype=torch.float32) + kernel[(1,)](x, y, z, num_warps=1) + if mode == "simple": + assert torch.equal(z, x + y) + elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": + assert torch.equal(z, x + 1 + y + 2) + elif mode == "shared": + ref = torch.full((16, 16), 16, device=device, dtype=torch.float32) + assert torch.equal(z, ref + x + y) + + # --------------- # test atomics # --------------- @@ -1334,6 +1500,99 @@ def test_reduce_layouts(M, N, src_layout, axis, device='cuda'): np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) +layouts = [ + BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), + BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1]) +] + + +@pytest.mark.parametrize("M", [32, 64, 128, 256]) +@pytest.mark.parametrize("src_layout", layouts) +def test_store_op(M, src_layout, device='cuda'): + ir = f""" + #src = {src_layout} + module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, 
#triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> + tt.store %8, %4 : tensor<{M}x1xf32, #src> + tt.return + }} + }} + """ + + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + store_kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, 1)).astype('float32') + y = np.zeros((M, 1), dtype='float32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + + pgm = store_kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + +layouts = [ + BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), + BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1]) +] + + +@pytest.mark.parametrize("M", [64, 128, 256]) +@pytest.mark.parametrize("src_layout", layouts) +@pytest.mark.parametrize("dst_layout", layouts) +@pytest.mark.parametrize("src_dim", [0, 1]) +@pytest.mark.parametrize("dst_dim", [0, 1]) +def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device='cuda'): + ir = f""" + #dst = {dst_layout} + #src = {src_layout} + module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> + %7 = triton_gpu.convert_layout %3 : (tensor<{M}xi32, #triton_gpu.slice<{{dim = {src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.store %6, %7 : tensor<{M}xi32, #triton_gpu.slice<{{dim = {dst_dim}, parent = #dst}}>> + tt.return + }} + }} + """ + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, )).astype('int32') + y = np.zeros((M, ), dtype='int32') + x_tri = torch.tensor(x, device=device) + y_tri = torch.tensor(y, device=device) + pgm = kernel[(1, 1, 1)](x_tri, y_tri) + y_ref = x + np.testing.assert_allclose(y_ref, y_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + @triton.jit def 
_welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): delta = mean_2 - mean_1 @@ -1346,6 +1605,68 @@ def _welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2): ) +layouts = [ + BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0]), + BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0]), + BlockedLayout([1, 4], [1, 32], [1, 4], [1, 0]), + BlockedLayout([1, 4], [8, 4], [2, 2], [0, 1]) +] + + +@pytest.mark.parametrize("M, N", [[128, 128], [256, 128], [256, 256], [128, 256]]) +@pytest.mark.parametrize("src_layout", layouts) +def test_chain_reduce(M, N, src_layout, device='cuda'): + ir = f""" + #src = {src_layout} + module attributes {{"triton_gpu.num-warps" = 4 : i32}} {{ + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #src> + %3 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %4 = tt.expand_dims %3 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> + %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src> + %11 = "tt.reduce"(%10) ({{ + ^bb0(%arg2: i32, %arg3: i32): + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32 + }}) {{axis = 1 : i32}} : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %12 = "tt.reduce"(%11) ({{ + ^bb0(%arg2: i32, %arg3: i32): + %13 = arith.addi %arg2, %arg3 : i32 + tt.reduce.return %13 : i32 + }}) {{axis = 0 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> i32 + tt.store %arg1, %12 {{cache = 1 : i32, evict = 1 : i32}} : i32 + tt.return + }} + }} + """ + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + rs = RandomState(17) + x = rs.randint(0, 4, (M, N)).astype('int32') + + z = np.zeros((1,)).astype('int32') + + x_tri = torch.tensor(x, device=device) + z_tri = torch.tensor(z, device=device) + + pgm = kernel[(1, 1, 1)](x_tri, z_tri) + z_ref = np.sum(x) + + np.testing.assert_allclose(z_ref, z_tri.cpu().numpy(), rtol=0.01, atol=1e-3) + + def test_generic_reduction(device='cuda'): @triton.jit @@ -2013,57 +2334,79 @@ def val_multiplier(val, i): return val * i +@triton.jit(noinline=True) +def val_multiplier_noinline(val, i): + return val * i + + @triton.jit -def vecmul_kernel(ptr, n_elements, rep): +def vecmul_kernel(ptr, n_elements, rep, type: tl.constexpr): pid = tl.program_id(axis=0) offsets = pid * 128 + tl.arange(0, 128) mask = offsets < n_elements vec = tl.load(ptr + offsets, mask=mask) for i in range(1, rep): - vec = val_multiplier(vec, i) + if type == "inline": + vec = val_multiplier(vec, i) + 
else: + vec = val_multiplier_noinline(vec, i) tl.store(ptr + offsets, vec, mask=mask) -def test_call(): +@pytest.mark.parametrize("type", ["inline", "noinline"]) +def test_call(type): @triton.jit - def kernel(ptr, n_elements, num1, num2): - vecmul_kernel(ptr, n_elements, num1) - vecmul_kernel(ptr, n_elements, num2) + def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): + vecmul_kernel(ptr, n_elements, num1, type) + vecmul_kernel(ptr, n_elements, num2, type) size = 1024 rand_val = numpy_random((size,), dtype_str="float32") rand_val_tri = to_triton(rand_val, device='cuda') - kernel[(size // 128,)](rand_val_tri, size, 3, 5) + err_msg = "" + try: + kernel[(size // 128,)](rand_val_tri, size, 3, 5, type) + except Exception as e: + err_msg = str(e) - ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 - np.testing.assert_equal(to_numpy(rand_val_tri), ans) + if type == "noinline": + assert err_msg != "" + else: + ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 + np.testing.assert_equal(to_numpy(rand_val_tri), ans) # ------------- # test if # ------------- -@pytest.mark.parametrize("if_type", ["if", "if_exp"]) +@pytest.mark.parametrize("if_type", ["if", "if_exp", "if_and"]) def test_if(if_type): @triton.jit - def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr): + def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr): pid = tl.program_id(0) cond = tl.load(Cond) if IfType == "if": - if pid % 2: + if pid % 2 == 0: tl.store(Ret, tl.load(XTrue)) else: tl.store(Ret, tl.load(XFalse)) - else: + elif IfType == "if_exp": - tl.store(Ret, tl.load(XTrue)) if pid % 2 else tl.store(Ret, tl.load(XFalse)) + tl.store(Ret, tl.load(XTrue)) if pid % 2 == 0 else tl.store(Ret, tl.load(XFalse)) + elif IfType == "if_and": + if BoolVar and pid % 2 == 0: + tl.store(Ret, tl.load(XTrue)) + else: + tl.store(Ret, tl.load(XFalse)) cond = torch.ones(1, dtype=torch.int32, device='cuda') x_true = torch.tensor([3.14], dtype=torch.float32, device='cuda') x_false = torch.tensor([1.51], dtype=torch.float32, device='cuda') ret = torch.empty(1, dtype=torch.float32, device='cuda') - kernel[(1,)](cond, x_true, x_false, ret, if_type) + kernel[(1,)](cond, x_true, x_false, ret, if_type, True) + assert torch.equal(ret, x_true) def test_num_warps_pow2(): @@ -2227,24 +2570,105 @@ def test_if_else(): assert to_numpy(out)[0] == false_val[0] -def test_if_return(): +@pytest.mark.parametrize("mode", ["dynamic", "static"]) +def test_if_return(mode): @triton.jit - def kernel(ExitEarly, Out): - if tl.load(ExitEarly): - tl.store(Out, 0) - return + def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): + if mode == "dynamic": + if tl.load(ExitEarly): + tl.store(Out, 0) + return + else: + if cond: + tl.store(Out, 0) + return tl.store(Out, 1) out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda') exit_early = to_triton(np.zeros((1,), dtype=np.int32), device='cuda') # exit early path taken exit_early[0] = 1 - kernel[(1,)](exit_early, out) + kernel[(1,)](exit_early, out, True, mode) assert to_numpy(out)[0] == 0 # exit early path not taken exit_early[0] = 0 - kernel[(1,)](exit_early, out) + kernel[(1,)](exit_early, out, False, mode) + assert to_numpy(out)[0] == 1 + + +@triton.jit +def add_fn(x): + return x + 1 + + +@triton.jit(noinline=True) +def add_fn_noinline(x): + return x + 1 + + +@triton.jit +def add_fn_return(x, pid): + if pid == 0: + return x + 1 + else: + return x + 2 + + +@triton.jit +def add_fn_expr(Out, x): + tl.store(Out, x) + + +@triton.jit +def add_fn_static_cond(x, cond: tl.constexpr): + if cond == "": + return x + else: + return x + 1 + + +@pytest.mark.parametrize("call_type", 
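The noinline leg of test_call above is expected to fail by construction: the _check_fn_args helper added to code_generator.py further below turns a @triton.jit(noinline=True) function into a real call that may only receive scalar or constexpr arguments. A hypothetical sketch of the rejected pattern (names illustrative):

import triton
import triton.language as tl


@triton.jit(noinline=True)
def scale(v, i):
    return v * i


@triton.jit
def bad_kernel(ptr):
    vec = tl.load(ptr + tl.arange(0, 128))  # block tensor
    vec = scale(vec, 2)                     # non-scalar argument -> compilation error
    tl.store(ptr + tl.arange(0, 128), vec)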
["attribute", "jit_function", "jit_function_return", + "ifexp", "expr", "jit_function_static_cond", "jit_function_noinline"]) +def test_if_call(call_type): + @triton.jit + def kernel(Out, call_type: tl.constexpr): + pid = tl.program_id(0) + o = tl.load(Out) + if pid == 0: + if call_type == "attribute": + # call attribute + a = o + 1 + a = a.to(tl.int32).to(tl.int32) + o = a + else: + a = o + if call_type == "jit_function": + # regular function call + a = add_fn(a) + elif call_type == "jit_function_return": + # function without end_if block + a = add_fn_return(a, pid) + elif call_type == "ifexp": + # ifexp expression + a = add_fn(a) if pid == 0 else add_fn_return(a, pid) + elif call_type == "expr": + if pid == 1: + return + a = add_fn(a) + if pid == 0: + # call without return + add_fn_expr(Out, a) + elif call_type == "jit_function_static_cond": + a = add_fn_static_cond(a, call_type) + elif call_type == "jit_function_noinline": + a = add_fn_noinline(a) + o = a + + tl.store(Out, o) + + out = to_triton(np.zeros((1,), dtype=np.int32), device='cuda') + kernel[(1,)](out, call_type) assert to_numpy(out)[0] == 1 diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 38d396cf0..2bca3f283 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -1,7 +1,5 @@ -import multiprocessing import os import shutil -from collections import namedtuple import pytest import torch @@ -170,34 +168,32 @@ def test_jit_debug() -> None: assert bins[0].asm['ttir'] != bins[1].asm['ttir'] -def test_compile_in_subproc() -> None: +@triton.jit +def add_fn(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) + + +def test_jit_noinline() -> None: @triton.jit - def kernel_sub(a, b, o, N: tl.constexpr): - idx = tl.arange(0, N) - tl.store(o + idx, - tl.load(a + idx) - tl.load(b + idx) * 777) + def kernel_add_device(a, b, o, N: tl.constexpr): + add_fn(a, b, o, N) - major, minor = torch.cuda.get_device_capability(0) - cc = major * 10 + minor - config = namedtuple("instance_descriptor", [ - "divisible_by_16", "equal_to_1"])( - tuple(range(4)), - ()) - - proc = multiprocessing.Process( - target=triton.compile, - kwargs=dict( - fn=kernel_sub, - signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, - device=0, - constants={3: 32}, - configs=[config], - warm_cache_only=True, - cc=cc, - )) - proc.start() - proc.join() - assert proc.exitcode == 0 + device = torch.cuda.current_device() + assert len(kernel_add_device.cache[device]) == 0 + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + inline_ttir = bins[0].asm['ttir'] + add_fn.noinline = True + add_fn.hash = None + kernel_add_device.hash = None + kernel_add_device.cache[device].clear() + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + assert len(kernel_add_device.cache[device]) == 1 + bins = list(kernel_add_device.cache[device].values()) + noinline_ttir = bins[0].asm['ttir'] + assert inline_ttir != noinline_ttir def test_memory_leak() -> None: diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py new file mode 100644 index 000000000..0e0d33c6f --- /dev/null +++ b/python/test/unit/runtime/test_subproc.py @@ -0,0 +1,83 @@ +import multiprocessing +import os +import shutil +from collections import namedtuple + +import torch + +import 
triton +import triton.language as tl + +tmpdir = ".tmp" + + +def reset_tmp_dir(): + os.environ["TRITON_CACHE_DIR"] = tmpdir + if os.path.exists(tmpdir): + shutil.rmtree(tmpdir) + + +instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"]) + + +def compile_fn(config, cc): + @triton.jit + def kernel_sub(a, b, o, N: tl.constexpr): + idx = tl.arange(0, N) + tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + triton.compile( + fn=kernel_sub, + signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, + device=0, + constants={3: 32}, + configs=[config], + warm_cache_only=True, + cc=cc, + ) + + +def test_compile_in_subproc() -> None: + major, minor = torch.cuda.get_device_capability(0) + cc = major * 10 + minor + config = instance_descriptor(tuple(range(4)), ()) + + multiprocessing.set_start_method('fork') + proc = multiprocessing.Process( + target=compile_fn, + args=(config, cc)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + +def compile_fn_dot(config, cc): + @triton.jit + def kernel_dot(Z): + offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] + z = tl.load(Z + offs) + z = tl.dot(z, z) + tl.store(Z + offs, z) + + triton.compile( + fn=kernel_dot, + signature={0: "*fp32"}, + device=0, + configs=[config], + warm_cache_only=True, + cc=cc, + ) + + +def test_compile_in_forked_subproc() -> None: + reset_tmp_dir() + major, minor = torch.cuda.get_device_capability(0) + cc = major * 10 + minor + config = instance_descriptor(tuple(range(1)), ()) + + assert multiprocessing.get_start_method() == 'fork' + proc = multiprocessing.Process( + target=compile_fn_dot, + args=(config, cc)) + proc.start() + proc.join() + assert proc.exitcode == 0 diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 15539ecac..14c9d61bd 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -4,10 +4,6 @@ __version__ = '2.1.0' # --------------------------------------- # Note: import order is significant here. -# TODO: torch needs to be imported first -# or pybind11 shows `munmap_chunk(): invalid pointer` -import torch # noqa: F401 - # submodules from .runtime import ( autotune, @@ -22,6 +18,8 @@ from .runtime import ( ) from .runtime.jit import jit from .compiler import compile, CompilationError +from .debugger.debugger import program_ids_from_grid + from . import language from . 
import testing @@ -45,6 +43,7 @@ __all__ = [ "runtime", "TensorWrapper", "testing", + "program_ids_from_grid", ] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 66db4aaaa..5702773c9 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -46,6 +46,8 @@ def mangle_fn(name, arg_tys, constants): mangled_constants = '_'.join([f'{i}c{repr(constants[i])}' for i in sorted(constants)]) mangled_constants = mangled_constants.replace('.', '_d_') mangled_constants = mangled_constants.replace("'", '_sq_') + # [ and ] are not allowed in LLVM identifiers + mangled_constants = mangled_constants.replace('[', '_').replace(']', '_') ret = f'{name}__{mangled_arg_names}__{mangled_constants}' return ret @@ -58,10 +60,21 @@ def _is_constexpr(o: Any) -> bool: return isinstance(o, constexpr) +def _is_triton_scalar(o: Any) -> bool: + return _is_triton_tensor(o) and (not o.type.is_block() or o.type.numel == 1) + + def _unwrap_if_constexpr(o: Any): return o.value if isinstance(o, constexpr) else o +def _check_fn_args(node, fn, args): + if fn.noinline: + for idx, arg in enumerate(args): + if not _is_constexpr(arg) and not _is_triton_scalar(arg): + raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}') + + _condition_types = {bool, int, type(None)} # Python types accepted for conditionals inside kernels @@ -86,7 +99,8 @@ class enter_sub_region: class CodeGenerator(ast.NodeVisitor): def __init__(self, context, prototype, gscope, attributes, constants, function_name, - module=None, is_kernel=False, function_types: Optional[Dict] = None, debug=False): + module=None, is_kernel=False, function_types: Optional[Dict] = None, + debug=False, noinline=False): self.builder = ir.builder(context) self.module = self.builder.create_module() if module is None else module self.function_ret_types = {} if function_types is None else function_types @@ -99,7 +113,9 @@ class CodeGenerator(ast.NodeVisitor): self.is_kernel = is_kernel self.last_node = None self.debug = debug + self.noinline = noinline self.scf_stack = [] + self.last_ret_type = None # SSA-construction # name => language.tensor self.local_defs: Dict[str, tensor] = {} @@ -134,7 +150,7 @@ class CodeGenerator(ast.NodeVisitor): def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: ''' This function: - called by visit_Assign() & visit_FuncDef() to store left value (lvalue) + called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) 1. record local defined name (FIXME: should consider control flow) 2. 
store tensor in self.lvalue ''' @@ -146,10 +162,9 @@ class CodeGenerator(ast.NodeVisitor): # def visit_compound_statement(self, stmts): for stmt in stmts: - self.last_ret_type = self.visit(stmt) - if isinstance(stmt, ast.Return): - break - return stmts and isinstance(stmt, ast.Return) + ret_type = self.visit(stmt) + if ret_type is not None and isinstance(stmt, ast.Return): + self.last_ret_type = ret_type # TODO: should be its own AST visitor def contains_return_op(self, node): @@ -164,8 +179,23 @@ class CodeGenerator(ast.NodeVisitor): pred = lambda s: self.contains_return_op(s) return any(pred(s) for s in node.body) elif isinstance(node, ast.Call): + def check_undefined_name(cur_node): + # Check if name is an undefined local variable, + # which can only be a tensor or a constexpr + if isinstance(cur_node.func, ast.Attribute): + if isinstance(cur_node.func.value, ast.Name): + name = cur_node.func.value.id + if name not in self.lscope and name not in self.gscope: + return True + return False + # chain of calls + # e.g., tl.load(a).to(tl.float32) + return check_undefined_name(cur_node.func.value) + return False + if check_undefined_name(node): + return False fn = self.visit(node.func) - if isinstance(fn, JITFunction): + if isinstance(fn, JITFunction) and fn.noinline is not True: old_gscope = self.gscope self.gscope = sys.modules[fn.fn.__module__].__dict__ ret = self.contains_return_op(fn.parse()) @@ -178,6 +208,18 @@ class CodeGenerator(ast.NodeVisitor): if node.orelse: ret = ret or any(pred(s) for s in node.orelse) return ret + elif isinstance(node, ast.IfExp): + return self.contains_return_op(node.body) or self.contains_return_op(node.orelse) + elif isinstance(node, ast.Expr): + ret = False + for _, value in ast.iter_fields(node): + if isinstance(value, list): + for item in value: + if isinstance(item, ast.AST): + ret = ret or self.contains_return_op(item) + elif isinstance(value, ast.AST): + ret = ret or self.contains_return_op(value) + return ret else: return False @@ -228,7 +270,7 @@ class CodeGenerator(ast.NodeVisitor): self.visit(init_node) # initialize function visibility = "public" if self.is_kernel else "private" - fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility) + fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline) self.module.push_back(fn) entry = fn.add_entry_block() arg_values = [] @@ -251,9 +293,9 @@ class CodeGenerator(ast.NodeVisitor): self.set_value(arg_name, arg_value) self.builder.set_insertion_point_to_start(entry) # visit function body - has_ret = self.visit_compound_statement(node.body) + self.visit_compound_statement(node.body) # finalize function - if not has_ret: + if self.last_ret_type is None: self.builder.ret([]) else: # update return type @@ -265,6 +307,8 @@ class CodeGenerator(ast.NodeVisitor): fn.reset_type(self.prototype.to_ir(self.builder)) if insert_pt: self.builder.set_insertion_point_to_end(insert_pt) + # Remove dead code + fn.finalize() def visit_arguments(self, node): arg_names = [] @@ -415,6 +459,7 @@ class CodeGenerator(ast.NodeVisitor): return then_defs, else_defs, then_block, else_block, names, ret_types, ir_ret_types def visit_if_top_level(self, cond, node): + has_endif_block = True with enter_sub_region(self) as sr: liveins, ip_block = sr then_block = self.builder.create_block() @@ -429,20 +474,25 @@ class CodeGenerator(ast.NodeVisitor): self.visit_then_else_blocks(node, liveins, then_block, 
else_block) # then terminator self.builder.set_insertion_point_to_end(then_block) - if not then_block.has_terminator(): + if then_block.has_return() and else_block.has_return(): + has_endif_block = False + endif_block.erase() + if not then_block.has_terminator() and has_endif_block: self.builder.create_branch(endif_block, [then_defs[n].handle for n in names]) # else terminator self.builder.set_insertion_point_to_end(else_block) - if not else_block.has_terminator(): + if not else_block.has_terminator() and has_endif_block: self.builder.create_branch(endif_block, [else_defs[n].handle for n in names]) - for ty in ir_ret_types: - endif_block.add_argument(ty) - # change block - self.builder.set_insertion_point_to_start(endif_block) - # update value - for i, name in enumerate(names): - new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) - self.set_value(name, new_tensor) + if has_endif_block: + for ty in ir_ret_types: + endif_block.add_argument(ty) + if has_endif_block: + # change block + self.builder.set_insertion_point_to_start(endif_block) + # update value + for i, name in enumerate(names): + new_tensor = language.core.tensor(endif_block.arg(i), ret_types[i]) + self.set_value(name, new_tensor) # TODO: refactor def visit_if_scf(self, cond, node): @@ -650,6 +700,8 @@ class CodeGenerator(ast.NodeVisitor): ub = language.core._to_tensor(ub, self.builder) step = language.core._to_tensor(step, self.builder) # induction variable type + if not lb.dtype.is_int() or not ub.dtype.is_int() or not step.dtype.is_int(): + raise TypeError(f"For loop bounds and step must all be ints, are ({lb.dtype}, {ub.dtype}, {step.dtype})") iv_type = language.semantic.integer_promote_impl(lb.dtype, ub.dtype) iv_type = language.semantic.integer_promote_impl(iv_type, step.dtype) iv_ir_type = iv_type.to_ir(self.builder) @@ -773,7 +825,7 @@ class CodeGenerator(ast.NodeVisitor): if not self.module.has_function(fn_name): prototype = language.function_type([], arg_types) gscope = sys.modules[fn.fn.__module__].__dict__ - generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug) + generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=fn.debug, noinline=fn.noinline) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -805,6 +857,7 @@ class CodeGenerator(ast.NodeVisitor): if not self.debug: return if isinstance(fn, JITFunction): + _check_fn_args(node, fn, args) return self.call_JitFunction(fn, args, kws) if (hasattr(fn, '__self__') and _is_triton_tensor(fn.__self__)) or language.core.is_builtin(fn): extra_kwargs = dict(_builder=self.builder) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 5e000d45d..3fe3f5385 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -11,8 +11,6 @@ from collections import namedtuple from pathlib import Path from typing import Any, Tuple -import torch - import triton import triton._C.libtriton.triton as _triton from ..runtime import driver @@ -254,7 +252,7 @@ def convert_type_repr(x): return x -def make_hash(fn, **kwargs): +def make_hash(fn, arch, **kwargs): if isinstance(fn, triton.runtime.JITFunction): configs = kwargs["configs"] signature = kwargs["signature"] @@ -265,7 +263,7 @@ def 
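Given the new integer check on lb/ub/step in the loop handling above, a range() with a non-integer bound now fails fast at kernel compile time; a hypothetical sketch:

import triton
import triton.language as tl


@triton.jit
def bad_loop(out_ptr):
    acc = 0.0
    # float upper bound -> TypeError: "For loop bounds and step must all be ints, ..."
    for i in range(0, 4.5):
        acc += 1.0
    tl.store(out_ptr, acc)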
make_hash(fn, **kwargs): # Get unique key for the compiled code get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1)) configs_key = [get_conf_key(conf) for conf in configs] - key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}" + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{debug}-{arch}" return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) return hashlib.md5((Path(fn).read_text() + triton.runtime.jit.version_key()).encode("utf-8")).hexdigest() @@ -330,6 +328,10 @@ def is_hip(): def get_architecture_descriptor(capability): + try: + import torch + except ImportError: + raise ImportError("Triton requires PyTorch to be installed") if capability is None: if torch.version.hip is None: device = triton.runtime.jit.get_current_device() @@ -428,7 +430,7 @@ def compile(fn, **kwargs): # cache manager so_path = make_stub(name, signature, constants) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, **kwargs)) + fn_cache_manager = get_cache_manager(make_hash(fn, arch, **kwargs)) # determine name and extension type of provided function if isinstance(fn, triton.runtime.JITFunction): name, ext = fn.__name__, "ast" diff --git a/lib/Dialect/Triton/IR/Interfaces.cpp b/python/triton/debugger/__init__.py similarity index 100% rename from lib/Dialect/Triton/IR/Interfaces.cpp rename to python/triton/debugger/__init__.py diff --git a/python/triton/debugger/core.py b/python/triton/debugger/core.py new file mode 100644 index 000000000..82f3f43a2 --- /dev/null +++ b/python/triton/debugger/core.py @@ -0,0 +1,9 @@ +from typing import Tuple + +import dataclasses + + +@dataclasses.dataclass +class ExecutionContext: + program_id: Tuple[int, ...] + program_size: Tuple[int, ...] diff --git a/python/triton/debugger/debugger.py b/python/triton/debugger/debugger.py new file mode 100644 index 000000000..5c5b97292 --- /dev/null +++ b/python/triton/debugger/debugger.py @@ -0,0 +1,170 @@ +import itertools +import random +from typing import Iterator, Tuple + +import triton +import triton.language as tl +from .core import ExecutionContext +from .memory_map import MemoryMap +from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor, + debugger_constexpr) +from triton.debugger import torch_wrapper

torch = torch_wrapper.torch tl_method_backup = {}


def get_proxy_method(proxy, name): + method = getattr(proxy, name) + + def fun(*args, **kwargs): + return method(*args, **kwargs) + + return fun + + +def attach_triton(module, proxy): + method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"] + for name in method_list: + if hasattr(module, name): + attr = getattr(module, name) + tl_method_backup[name] = attr + if callable(attr): + setattr(module, name, get_proxy_method(proxy, name)) + else: + setattr(module, name, getattr(proxy, name)) + + +def detach_triton(module): + for name, method in tl_method_backup.items(): + setattr(module, name, method) + + +def program_ids_from_grid(grid: Tuple[int, ...]) -> Iterator[Tuple[int, ...]]: + # reverse the grid dimensions and generate the range for each dimension + reversed_grid = reversed(grid) + ranges_for_each_dimension = [range(dim) for dim in reversed_grid] + + # generate all combinations + index_combinations = list(itertools.product(*ranges_for_each_dimension)) + random.shuffle(index_combinations) + + for index_combination in index_combinations: + yield index_combination + + +class
DebuggerFunction: + def __init__(self, func, grid=(1,)): + self.func = func + self.grid = grid + + def _is_constexpr(self, name): + return name in self.func.__annotations__ and self.func.__annotations__[name] is triton.language.core.constexpr + + def _get_constexpr(self): + result = [] + for name, annotation in self.func.__annotations__.items(): + if annotation is triton.language.core.constexpr: + result.append(name) + return result + + def _assert_constexpr(self, **kwargs): + constexp = self._get_constexpr() + missing = [i for i in constexp if i not in kwargs.keys()] + assert len(missing) == 0, f"You must specify constexpr {missing}" + + def _get_grid(self, **kwargs): + if callable(self.grid): + return self.grid(kwargs) + else: + return self.grid + + def __call__(self, *args, **kwargs): + self._assert_constexpr(**kwargs) + + memory = MemoryMap() + + def convert_arg(v): + name, arg = v + if torch.is_tensor(arg): + ptr = memory.add_tensor(arg) + return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda")) + if self._is_constexpr(name): + return debugger_constexpr(arg) + return WrappedTensor(_primitive_to_tensor(arg)) + + new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args))) + new_kwargs = {k: convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]} + + grid = self._get_grid(**kwargs) + for program_id in program_ids_from_grid(grid): + proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid)) + attach_triton(tl, proxy) + self.func(*new_args, **new_kwargs) + detach_triton(tl) + + +class GridSelector: + """ + Entry point of the debugger + """ + + def __init__(self, func): + version = torch.__version__ + assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}" + self.func = func + + def __getitem__(self, grid): + return DebuggerFunction(self.func, grid) + + def __call__(self, *args, **kwargs): + return DebuggerFunction(self.func)(*args, **kwargs) + + +class AutotuneGridSelector: + def __init__(self, func, autotune_params): + self.func = func + self.autotune_params = autotune_params + + def __getitem__(self, grid): + return AutotuneRunner(self.func, self.autotune_params, grid) + + def __call__(self, *args, **kwargs): + return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs) + + +class AutotuneRunner: + def __init__(self, func, autotune_params, grid=None): + self.func = func + self.autotune_params = autotune_params + self.grid = grid + + def __call__(self, *args, **kwargs): + assert len(self.autotune_params["configs"]) >= 1 + + for config in self.autotune_params["configs"][1:]: + + def convert_arg(v): + if torch.is_tensor(v): + return torch.clone(v) + return v + + new_args = tuple(map(convert_arg, args)) + new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()} + if self.grid: + self.func[self.grid](*new_args, **new_kwargs, **config.kwargs) + else: + self.func(*new_args, **new_kwargs, **config.kwargs) + + main_config = self.autotune_params["configs"][0] + if self.grid: + self.func[self.grid](*args, **kwargs, **main_config.kwargs) + else: + self.func(*args, **kwargs, **main_config.kwargs) + + +def triton_debug_autotune(**kwargs): + def wrapper(func): + return AutotuneGridSelector(func, kwargs) + + return wrapper diff --git a/python/triton/debugger/memory_map.py b/python/triton/debugger/memory_map.py new file mode 100644 index 000000000..edf4c3f77 --- /dev/null +++ b/python/triton/debugger/memory_map.py @@ -0,0 +1,101 @@ +import dataclasses +from typing import List + +from triton.debugger
import torch_wrapper + +torch = torch_wrapper.torch + + +@dataclasses.dataclass +class RegisteredStorage: + storage: torch.Storage + dtype: torch.dtype + size: int + ptr: int + + @property + def end_ptr(self) -> int: + return self.ptr + self.size + + @property + def access_tensor(self) -> torch.Tensor: + return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device) + + def ensure_immutable(self): + assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size + + +class MemoryMap: + storages: List[RegisteredStorage] + + def __init__(self): + self.storages = [] + + def _get_registered_storage(self, pointer: torch.Tensor): + max_pointer = torch.max(pointer).item() + min_pointer = torch.min(pointer).item() + + registered_storage = next( + filter( + lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages + ), + None, + ) + if registered_storage is None: + raise Exception("Storage not found or pointers spanning multiple tensors") + registered_storage.ensure_immutable() + return registered_storage + + def add_tensor(self, t: torch.Tensor): + storage = t.untyped_storage() + self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr())) + return t.data_ptr() + + def load( + self, + pointer: torch.Tensor, + mask: torch.Tensor = None, + other=0.0, + ): + assert pointer.is_cuda + assert 0 < pointer.dim() < 3 + assert pointer.dtype == torch.int64 + + if mask is None: + mask = torch.ones_like(pointer).bool() + assert mask.is_cuda + assert 0 < mask.dim() < 3 + assert mask.dtype == torch.bool + mask = mask.expand(pointer.size()) + + if torch.all(~mask): + # TODO: the dtype is wrong here; we cannot determine the correct dtype + return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda") + + registered_storage = self._get_registered_storage(pointer[mask]) + access_tensor = registered_storage.access_tensor + + index_tensor = pointer - registered_storage.ptr + + block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda") + block[mask] = access_tensor[index_tensor[mask]] + return block + + def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): + assert 0 < pointer.dim() < 3 + assert pointer.dtype == torch.int64 + + if mask is None: + mask = torch.ones_like(pointer).bool() + assert 0 < mask.dim() < 3 + assert mask.dtype == torch.bool + mask = mask.expand(pointer.size()) + + if torch.all(~mask): + return + + registered_storage = self._get_registered_storage(pointer[mask]) + access_tensor = registered_storage.access_tensor + + index_tensor = pointer - registered_storage.ptr + access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype) diff --git a/python/triton/debugger/tl_lang.py b/python/triton/debugger/tl_lang.py new file mode 100644 index 000000000..6364b77a3 --- /dev/null +++ b/python/triton/debugger/tl_lang.py @@ -0,0 +1,621 @@ +import triton +from .core import ExecutionContext +from .memory_map import MemoryMap +from triton.debugger import torch_wrapper + +torch = torch_wrapper.torch + + +def _primitive_to_tensor(x): + """ + Converts various Python primitive data types to a PyTorch tensor.
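+    For example (illustrative, assuming a CUDA device is available): + _primitive_to_tensor(3) yields a one-element torch.int32 tensor on "cuda", + and _primitive_to_tensor(2.5) a one-element torch.float32 tensor.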
+ """ + tensor_args = {"device": "cuda"} + if isinstance(x, bool): + return torch.tensor([x], dtype=torch.bool, **tensor_args) + elif isinstance(x, int): + if -(2**31) <= x < 2**31: + return torch.tensor([x], dtype=torch.int32, **tensor_args) + elif -(2**63) <= x < 2**63: + return torch.tensor([x], dtype=torch.int64, **tensor_args) + else: + raise RuntimeError(f"Nonrepresentable integer {x}.") + elif isinstance(x, float): + return torch.tensor([x], dtype=torch.float32, **tensor_args) + elif torch.is_tensor(x): + return x + elif isinstance(x, WrappedTensor): + return x + elif isinstance(x, debugger_constexpr): + if x.value is None: + return None + return _primitive_to_tensor(x.value) + elif x is None: + return None + assert False, f"cannot convert {x} of type {type(x)} to tensor" + + +def _infer_tensor(func): + """ + A decorator function to harmonize function args: + - converts primitives to PyTorch tensors + - wraps PyTorch tensors with WrappedTensors + """ + def wrapper(*args): + new_args = tuple(map(lambda v: _primitive_to_tensor(v), args)) + new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args)) + + return func(*new_args) + + return wrapper + + +def _tensor_operation(func): + """ + A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function. + Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor). + """ + def wrapper(*args, **kwargs): + for arg in args: + assert not torch.is_tensor(arg), "unexpected tensor argument" + + def unwrap_tensor(v): + if isinstance(v, WrappedTensor): + return v.tensor + if isinstance(v, debugger_constexpr): + return v.value + return v + + new_args = tuple(map(unwrap_tensor, args)) + new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()} + + result = func(args[0], *new_args[1:], **new_kwargs) + return WrappedTensor(result) if torch.is_tensor(result) else result + + return wrapper + + +class debugger_constexpr: + def __init__(self, value): + if isinstance(value, debugger_constexpr): + self.value = value.value + else: + self.value = value + + def __str__(self) -> str: + return "debugger_constexpr(" + str(self.value) + ")" + + def __index__(self) -> int: + return self.value + + def __bool__(self): + return bool(self.value) + + def __ge__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value >= other + + def __gt__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value > other + + def __le__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value <= other + + def __lt__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value < other + + def __eq__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value == other + + def __or__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value | other + + def __ror__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value | other + + def __and__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value & other + + def __rand__(self, other): + other = other.value if isinstance(other, debugger_constexpr) else other + return self.value & other + + def to(self, dtype, bitcast=False, _builder=None): + if dtype 
in [torch.int64]: + ret_ty = int + elif dtype == torch.bool: + ret_ty = bool + elif dtype in [torch.float64]: + ret_ty = float + else: + raise ValueError("dtype not supported in debugger") + return debugger_constexpr(ret_ty(self.value)) + + +class WrappedTensor: + def __init__(self, tensor): + self.tensor = tensor + + def __index__(self) -> int: + return self.tensor.item() + + def __str__(self) -> str: + return "wrapped_" + str(self.tensor) + + def __bool__(self) -> bool: + return torch.all(self.tensor == True).item() # noqa: E712 + + @property + def dtype(self): + return self.tensor.dtype + + @_infer_tensor + @_tensor_operation + def __add__(self, other): + return torch.add(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __radd__(self, other): + return self.__add__(other) + + @_infer_tensor + @_tensor_operation + def __sub__(self, other): + return torch.sub(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rsub__(self, other): + return torch.sub(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __mul__(self, other): + return torch.mul(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rmul__(self, other): + return self.__mul__(other) + + @_infer_tensor + @_tensor_operation + def __truediv__(self, other): + return torch.div(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rtruediv__(self, other): + return torch.div(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __floordiv__(self, other): + return torch.floor_divide(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rfloordiv__(self, other): + return torch.floor_divide(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __mod__(self, other): + return torch.remainder(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rmod__(self, other): + return torch.remainder(other, self.tensor) + + @_infer_tensor + @_tensor_operation + def __neg__(self): + return -self.tensor + + @_infer_tensor + @_tensor_operation + def __invert__(self): + return ~self.tensor + + @_infer_tensor + @_tensor_operation + def __and__(self, other): + return torch.bitwise_and(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __or__(self, other): + return torch.bitwise_or(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __xor__(self, other): + return torch.bitwise_xor(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __lshift__(self, other): + return torch.bitwise_left_shift(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __rshift__(self, other): + return torch.bitwise_right_shift(self.tensor, other) + + @_infer_tensor + @_tensor_operation + def __gt__(self, other): + return self.tensor > other + + @_infer_tensor + @_tensor_operation + def __rgt__(self, other): + return other > self.tensor + + @_infer_tensor + @_tensor_operation + def __ge__(self, other): + return self.tensor >= other + + @_infer_tensor + @_tensor_operation + def __rge__(self, other): + return other >= self.tensor + + @_infer_tensor + @_tensor_operation + def __lt__(self, other): + return self.tensor < other + + @_infer_tensor + @_tensor_operation + def __rlt__(self, other): + return other < self.tensor + + @_infer_tensor + @_tensor_operation + def __le__(self, other): + return self.tensor <= other + + @_infer_tensor + @_tensor_operation + def __rle__(self, other): + return other <= self.tensor + + @_infer_tensor + @_tensor_operation + def __eq__(self, other): + return torch.equal(self.tensor, 
other) + + @_infer_tensor + @_tensor_operation + def __ne__(self, other): + return not torch.equal(self.tensor, other) + + @_tensor_operation + def __getitem__(self, slices): + return self.tensor.__getitem__(slices) + # if isinstance(slices, slice): + # slices = [slices] + # src_shape = self.shape + # dst_shape = [] + # curr = 0 + # for sl in slices: + # if isinstance(sl, constexpr) and sl.value is None: + # dst_shape.append(1) + # elif sl == slice(None, None, None): + # dst_shape.append(src_shape[curr].value) + # curr += 1 + # ret = torch.reshape(self.tensor, dst_shape, ) + # return ret + + @_tensor_operation + def to(self, dtype, bitcast=False): + return self.tensor.to(dtype) + # if isinstance(bitcast, constexpr): + # bitcast = bitcast.value + # if bitcast: + # return semantic.bitcast(self, dtype, ) + # return semantic.cast(self, dtype, ) + + +def _constexpr_to_value(v): + if isinstance(v, debugger_constexpr): + return v.value + return v + + +class TritonLangProxy: + _memory_map: MemoryMap + _context: ExecutionContext + + def __init__(self, memory_map: MemoryMap, context: ExecutionContext): + self._memory_map = memory_map + self._context = context + + # Types + # Removed void, int1, float8, uint16, uint32, uint64, pi32_t + + # constexpr = debugger_constexpr + + # Program functions + + @_tensor_operation + def load( + self, + pointer: torch.Tensor, + mask: torch.Tensor = None, + other=0.0, + cache_modifier="", + eviction_policy="", + volatile=False, + ): + return self._memory_map.load(pointer, mask, other) + + @_tensor_operation + def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): + return self._memory_map.store(pointer, value, mask) + + @_tensor_operation + def program_id(self, axis): + assert axis < len(self._context.program_id) + return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda") + + @_tensor_operation + def num_programs(self, axis): + assert axis < len(self._context.program_size) + return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda") + + @_tensor_operation + def arange(self, start, end): + return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda") + + @_tensor_operation + def zeros(self, shape, dtype): + for i, d in enumerate(shape): + if not isinstance(d, debugger_constexpr): + raise TypeError(f"Shape element {i} must have type `constexpr`") + if not isinstance(d.value, int): + raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]`") + shape = [x.value for x in shape] + if isinstance(dtype, triton.language.core.dtype): + if dtype.is_fp32(): + dtype = torch.float32 + elif dtype.is_fp16(): + dtype = torch.float16 + elif dtype.is_bf16(): + dtype = torch.bfloat16 + elif dtype.is_int32(): + dtype = torch.int32 + elif dtype.is_int16(): + dtype = torch.int16 + elif dtype.is_int8(): + dtype = torch.int8 + else: + raise TypeError(f"Unsupported dtype {dtype}") + return torch.zeros(size=shape, dtype=dtype, device="cuda") + + @_tensor_operation + def dequantize(self, input, scale, shift, nbit, dst_ty=torch.float16): + raise NotImplementedError() + + @_tensor_operation + def broadcast(self, input, other): + raise NotImplementedError() + + @_tensor_operation + def broadcast_to(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def cat(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def reshape(self, input, shape): + raise NotImplementedError() + + @_tensor_operation + def dot(self, input,
other, trans_a=False, trans_b=False, allow_tf32=True): + assert input.dtype == other.dtype + if trans_a: + input = input.T + if trans_b: + other = other.T + return torch.matmul(input=input, other=other) + + @_tensor_operation + def atomic_cas(self, pointer, cmp, val): + stored = self._memory_map.load(pointer, None, 0.0) + if not isinstance(cmp, torch.Tensor): + cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda") + if not isinstance(val, torch.Tensor): + val = torch.tensor([val], dtype=stored.dtype, device="cuda") + if stored == cmp: + self._memory_map.store(pointer, val, None) + return stored + + @_tensor_operation + def atomic_xchg(self, pointer, val, mask=None): + if isinstance(val, int): + val = torch.tensor([val], dtype=torch.int32, device="cuda") + stored = self._memory_map.load(pointer, mask, 0.0) + self._memory_map.store(pointer, val, mask) + return stored + + @_tensor_operation + def atomic_add(self, pointer, val, mask=None): + # arbitrary `other` value, as it will be masked during storing + stored = self._memory_map.load(pointer, mask, 0.0) + result = stored + val + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_max(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0.0) + result = torch.maximum(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_min(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0.0) + result = torch.minimum(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_and(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_and(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_or(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_or(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def atomic_xor(self, pointer, val, mask=None): + stored = self._memory_map.load(pointer, mask, 0) + result = torch.bitwise_xor(stored, val) + self._memory_map.store(pointer, result, mask) + return stored + + @_tensor_operation + def where(self, condition, x, y): + condition = _primitive_to_tensor(condition) + x = _primitive_to_tensor(x) + y = _primitive_to_tensor(y) + return torch.where(condition, x, y) + + @_tensor_operation + def umulhi(self, x, y): + raise NotImplementedError() + + @_tensor_operation + def fdiv(self, x, y, ieee_rounding=False): + raise NotImplementedError() + + @_tensor_operation + def exp(self, x): + return torch.exp(x) + + @_tensor_operation + def log(self, x): + return torch.log(x) + + @_tensor_operation + def cos(self, x): + return torch.cos(x) + + @_tensor_operation + def sin(self, x): + return torch.sin(x) + + @_tensor_operation + def sqrt(self, x): + return torch.sqrt(x) + + @_tensor_operation + def globaltimer(self): + raise NotImplementedError() + + @_tensor_operation + def clock(self): + raise NotImplementedError() + + @_tensor_operation + def debug_barrier(self): + raise NotImplementedError() + + @_tensor_operation + def multiple_of(self, input, values): + return input + + @_tensor_operation + def max_contiguous(self, input, values): + return input + + @_tensor_operation + def abs(self, x): + return torch.abs(x) + + @_tensor_operation + def cdiv(self, x, div): + return (x + div - 1) // div + +
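+    # Illustrative sanity check of the helpers above (a sketch, not part of the test suite): + # cdiv(10, 3) == 4, via the usual (x + div - 1) // div ceiling-division trick, while + # multiple_of and max_contiguous are pure compile-time hints, so the debugger just + # returns the input unchanged. +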
@_tensor_operation + def minimum(self, x, y): + if isinstance(x, int): + x = torch.tensor(x, device="cuda") + if isinstance(y, int): + y = torch.tensor(y, device="cuda") + return torch.minimum(x, y) + + @_tensor_operation + def maximum(self, x, y): + return torch.maximum(x, y) + + @_tensor_operation + def sigmoid(self, x): + raise NotImplementedError() + + @_tensor_operation + def softmax(self, x, ieee_rounding=False): + raise NotImplementedError() + + @_tensor_operation + def ravel(self, x): + raise NotImplementedError() + + @_tensor_operation + def swizzle2d(self, i, j, size_i, size_j, size_g): + raise NotImplementedError() + + @_tensor_operation + def zeros_like(self, input): + raise NotImplementedError() + + @_tensor_operation + def max(self, input, axis=None): + if axis is None: + return torch.max(input) + return torch.max(input, dim=axis).values + + @_tensor_operation + def argmax(self, input, axis): + raise NotImplementedError() + + @_tensor_operation + def min(self, input, axis=None): + if axis is None: + return torch.min(input) + return torch.min(input, dim=axis).values + + @_tensor_operation + def argmin(self, input, axis): + raise NotImplementedError() + + @_tensor_operation + def sum(self, input, axis=None): + if axis is None: + return torch.sum(input) + return torch.sum(input, dim=axis) + + @_tensor_operation + def xor_sum(self, input, axis): + raise NotImplementedError() diff --git a/python/triton/debugger/torch_wrapper.py b/python/triton/debugger/torch_wrapper.py new file mode 100644 index 000000000..44aa17eb1 --- /dev/null +++ b/python/triton/debugger/torch_wrapper.py @@ -0,0 +1,18 @@ +try: + import torch as _torch +except ImportError: + _torch = None + + +class TorchWrapper: + """ + Helps make torch an optional dependency + """ + + def __getattr__(self, name): + if _torch is None: + raise ImportError("Triton requires PyTorch to be installed") + return getattr(_torch, name) + + +torch = TorchWrapper() diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 92619bf27..7485f374b 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -39,6 +39,7 @@ from .core import ( dot, dtype, exp, + expand_dims, full, fdiv, float16, @@ -130,6 +131,7 @@ __all__ = [ "dot", "dtype", "exp", + "expand_dims", "extra", "fdiv", "float16", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 81c6417b4..749f98089 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -3,7 +3,7 @@ from __future__ import annotations from contextlib import contextmanager from enum import Enum from functools import wraps -from typing import Callable, List, TypeVar +from typing import Callable, List, Sequence, TypeVar import triton from . import semantic @@ -903,6 +903,41 @@ def reshape(input, shape, _builder=None): shape = _shape_check_impl(shape) return semantic.reshape(input, shape, _builder) + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor.
+ :type input: tl.tensor + :param axis: The indices at which to insert new axes + :type axis: int | Sequence[int] + + """ + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + # ----------------------- # Linear Algebra # ----------------------- @@ -1301,9 +1336,9 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None): if len(input.shape) > 1: # Broadcast index across the non-reduced axes - expand_dims_index = [constexpr(None)] * len(input.shape) - expand_dims_index[axis] = slice(None) - index = index.__getitem__(expand_dims_index, _builder=_builder) + axes_to_expand = [constexpr(d) for d in range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) index = broadcast_to(index, input.shape, _builder=_builder) rvalue, rindices = reduce((input, index), axis, combine_fn, diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 7e84ccd05..a73bfa18e 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -3,9 +3,6 @@ from __future__ import annotations # remove after python 3.11 from functools import wraps from typing import List, Optional, Sequence, Tuple, TypeVar -import torch - -import triton from . import core as tl from triton._C.libtriton.triton import ir @@ -665,7 +662,7 @@ def bitcast(input: tl.tensor, src_bits = src_sca_ty.primitive_bitwidth dst_bits = dst_sca_ty.primitive_bitwidth if src_bits != dst_bits: - raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + "to " + raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " "data-type of size " + str(dst_bits)) return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) @@ -1190,14 +1187,6 @@ def dot(lhs: tl.tensor, allow_tf32: bool, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: - if torch.version.hip is None: - device = triton.runtime.jit.get_current_device() - capability = triton.runtime.jit.get_device_capability(device) - capability = capability[0] * 10 + capability[1] - if capability < 70: - assert ( - not rhs.dtype.is_fp16() and not rhs.dtype.is_fp8() - ), "Float8 and Float16 types are not supported for compute capability < 70 (use Float32 or above)" assert lhs.type.is_block() and rhs.type.is_block() assert lhs.dtype == rhs.dtype, "lhs and rhs must have the same dtype!" assert len(lhs.shape) == 2 and len(rhs.shape) == 2 @@ -1398,6 +1387,10 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.
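# Aside (hypothetical usage of the expand_dims builtin above, not part of this diff): # for a tensor x of shape (4,), tl.expand_dims(x, 0) gives shape (1, 4) and # tl.expand_dims(x, (0, -1)) gives shape (1, 4, 1); negative indices are wrapped # relative to the resulting rank, as in numpy.expand_dims.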
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: + cond_ty = cond.type + if not cond_ty.is_block(): + cond_ty = tl.block_type(cond_ty.scalar, (1,)) + cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty) return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 63ce81074..f66cddf37 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -4,17 +4,6 @@ import triton import triton.language as tl -def next_power_of_2(n): - n -= 1 - n |= n >> 1 - n |= n >> 2 - n |= n >> 4 - n |= n >> 8 - n |= n >> 16 - n += 1 - return n - - def num_warps(N): if N < 2048: return 4 @@ -24,7 +13,7 @@ def num_warps(N): @triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) -@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])}) @triton.jit def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): row = tl.program_id(0) @@ -49,7 +38,7 @@ def _forward(LOGITS, PROBS, IDX, LOSS, N, BLOCK: tl.constexpr): @triton.heuristics({'num_warps': lambda nargs: num_warps(nargs['N'])}) -@triton.heuristics({'BLOCK': lambda nargs: next_power_of_2(nargs['N'])}) +@triton.heuristics({'BLOCK': lambda nargs: triton.next_power_of_2(nargs['N'])}) @triton.jit def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): row = tl.program_id(0) diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 0a6394ac1..3cb9f9dbe 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -142,6 +142,7 @@ class Autotuner(KernelInterface): class Config: """ An object that represents a possible kernel configuration for the auto-tuner to try. + :ivar meta: a dictionary of meta-parameters to pass to the kernel as keyword arguments. :type meta: dict[str, Any] :ivar num_warps: the number of warps to use for the kernel when compiled for GPUs. For example, if @@ -173,8 +174,10 @@ class Config: def autotune(configs, key, prune_configs_by=None, reset_to_zero=None): """ Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python .. code-block:: python + @triton.autotune(configs=[ triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), @@ -223,8 +226,10 @@ def heuristics(values): """ Decorator for specifying how the values of certain meta-parameters may be computed. This is useful for cases where auto-tuning is prohibitively expensive, or just not applicable. + .. highlight:: python ..
code-block:: python + @triton.heuristics(values={'BLOCK_SIZE': lambda args: 2 ** int(math.ceil(math.log2(args[1])))}) @triton.jit def kernel(x_ptr, x_size, **META): diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index c9700de75..102a6d1ba 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -83,7 +83,8 @@ class DependenciesFinder(ast.NodeVisitor): finder = DependenciesFinder(func.__globals__, func.src) finder.visit(tree) func.hash = finder.ret - self.ret = (self.ret + func.hash).encode("utf-8") + noinline = str(getattr(func, 'noinline', False)) + self.ret = (self.ret + func.hash + noinline).encode("utf-8") self.ret = hashlib.md5(self.ret).hexdigest() # ----------------------------------------------------------------------------- @@ -175,7 +176,7 @@ class JITFunction(KernelInterface[T]): return True return False divisible_by_16 = {i for i, arg in enumerate(args) if is_divisible_by_16(arg) and i not in self.do_not_specialize} - equal_to_1 = {i for i, arg in enumerate(args) if isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} + equal_to_1 = {i for i, arg in enumerate(args) if not isinstance(arg, bool) and isinstance(arg, int) and arg == 1 and i not in self.do_not_specialize} return namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1"])(tuple(divisible_by_16), tuple(equal_to_1)) # return _triton.code_gen.instance_descriptor(divisible_by_16, equal_to_1) @@ -298,13 +299,13 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage set_current_device(device) if stream is None and not warmup: stream = get_cuda_stream(device) - try: - bin = cache[device][key] + bin = cache[device].get(key, None) + if bin is not None: if not warmup: bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, {args}) return bin # kernel not cached -- compile - except KeyError: + else: # build dict of constant values args = [{args}] all_args = {', '.join([f'{arg}' for arg in self.arg_names])}, @@ -334,7 +335,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage exec(src, scope) return scope[self.fn.__name__] - def __init__(self, fn, version=None, do_not_specialize=None, debug=None): + def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinline=None): self.fn = fn self.module = fn.__module__ self.version = version @@ -356,6 +357,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage self.kernel_decorators = [] self.kernel = None self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug + self.noinline = noinline # annotations normalize_ty = lambda ty: ty.__name__ if isinstance(ty, type) else ty self.__annotations__ = {name: normalize_ty(ty) for name, ty in fn.__annotations__.items()} @@ -425,6 +427,7 @@ def jit( version=None, do_not_specialize: Optional[Iterable[int]] = None, debug: Optional[bool] = None, + noinline: Optional[bool] = None, ) -> Callable[[T], JITFunction[T]]: ... @@ -435,6 +438,8 @@ def jit( version=None, do_not_specialize: Optional[Iterable[int]] = None, debug: Optional[bool] = None, + noinline: Optional[bool] = None, + interpret: Optional[bool] = None, ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: """ Decorator for JIT-compiling a function using the Triton compiler. 
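# Aside (illustrative sketch, not part of this diff; src and dst are assumed to be CUDA # tensors): with the new flag, a kernel can run eagerly through the debugger instead of # being compiled, e.g. #   @triton.jit(interpret=True) #   def copy_kernel(src_ptr, dst_ptr, BLOCK: tl.constexpr): #       offs = tl.arange(0, BLOCK) #       tl.store(dst_ptr + offs, tl.load(src_ptr + offs)) #   copy_kernel[(1,)](src, dst, BLOCK=1024)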
@@ -456,13 +461,17 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - return JITFunction( - fn, - version=version, - do_not_specialize=do_not_specialize, - debug=debug, - ) - + if interpret: + from ..debugger.debugger import GridSelector + return GridSelector(fn) + else: + return JITFunction( + fn, + version=version, + do_not_specialize=do_not_specialize, + debug=debug, + noinline=noinline, + ) if fn is not None: return decorator(fn) diff --git a/python/triton/third_party/cuda/bin/ptxas b/python/triton/third_party/cuda/bin/ptxas new file mode 100755 index 000000000..8b47936ea Binary files /dev/null and b/python/triton/third_party/cuda/bin/ptxas differ diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 67c39a7df..6a5b2e5f5 100644 --- a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -6,8 +6,8 @@ #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> // CHECK-LABEL: matmul_loop // There shouldn't be any aliasing with the dot op encoding. diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index a4fd69264..5017bfdc4 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -402,8 +402,6 @@ tt.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 // ----- -module { - // This is a tiny test for verifying StoreOp-related alignment, It simply store a constant to a buffer. // CHECK-LABEL: @store_constant_align tt.func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, %n: i32 {tt.divisibility = 16 : i32}) { @@ -433,8 +431,6 @@ tt.func @store_constant_align(%addr: !tt.ptr {tt.divisibility = 16 : i32}, tt.return } -} - // ----- // This IR is dumped from vecadd test. @@ -491,3 +487,88 @@ tt.func @vecadd_mask_align_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, % tt.store %15, %13, %10 : tensor<64xf32> tt.return } + +// ----- + +module { + +// We don't use function cloning here, so the alignment info is the gcd of all call sites. 
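+// For example, @addptr_hints below is called from kernels whose pointer argument carries +// divisibility hints of 16, 8 and 4, so inside the callee the pointer's divisibility is +// gcd(16, 8, 4) = 4, which is what the CHECK lines assert.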
+// CHECK-LABEL: @addptr_hints +tt.func @addptr_hints(%arg0: !tt.ptr) { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %cst1 = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %1 = tt.addptr %arg0, %cst1 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 4 + %cst4 = arith.constant 4 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %2 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + // CHECK-NEXT: contiguity = [1], divisibility = [16], constancy = [1], constant_value = 16 + %cst16 = arith.constant 16 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %3 = tt.addptr %arg0, %cst4 : !tt.ptr, i32 + tt.return +} + +// CHECK-LABEL: @kernel_div16 +tt.func @kernel_div16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +// CHECK-LABEL: @kernel_div8 +tt.func @kernel_div8(%arg0: !tt.ptr {tt.divisibility = 8 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +// CHECK-LABEL: @kernel_div4 +tt.func @kernel_div4(%arg0: !tt.ptr {tt.divisibility = 4 : i32}) { + tt.call @addptr_hints(%arg0) : (!tt.ptr) -> () + tt.return +} + +} + +// ----- + +module { + +// We don't use function cloning here, so the alignment info is the gcd of all call sites. +// CHECK-LABEL: @mul +tt.func @mul(%arg0: i32) { + // CHECK: contiguity = [1], divisibility = [1], constancy = [1], constant_value = 1 + %cst1 = arith.constant 1 : i32 + // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %1 = arith.muli %arg0, %cst1 : i32 + tt.return +} + +// CHECK-LABEL: @bar +tt.func @bar(%arg0: i32) { + tt.call @mul(%arg0) : (i32) -> () + tt.return +} + +// CHECK-LABEL: @foo +tt.func @foo(%arg0: i32) { + tt.call @mul(%arg0) : (i32) -> () + tt.return +} + +// CHECK-LABEL: @call_graph +tt.func @call_graph(%arg0: i32) { + // CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = 12 + %cst12 = arith.constant 12 : i32 + // CHECK: contiguity = [1], divisibility = [4], constancy = [1], constant_value = + %0 = arith.muli %arg0, %cst12 : i32 + tt.call @foo(%0) : (i32) -> () + // CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = 8 + %cst8 = arith.constant 8 : i32 + // CHECK: contiguity = [1], divisibility = [8], constancy = [1], constant_value = + %1 = arith.muli %arg0, %cst8 : i32 + tt.call @bar(%1) : (i32) -> () + tt.return +} + +} diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 174be59b6..820ab2729 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -7,8 +7,8 @@ #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> module attributes {"triton_gpu.num-warps" = 4 : i32} { @@ -28,10 +28,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = 
%b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK: offset = 0, size = 4608 + // CHECK: scratch offset = 0, size = 4608 %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> - // CHECK-NEXT: offset = 0, size = 4224 + // CHECK-NEXT: scratch offset = 0, size = 4224 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> @@ -56,17 +56,17 @@ tt.func @reusable(%A : !tt.ptr) { %a_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<128x32x!tt.ptr, #AL> %b_ptr = tt.broadcast %A : (!tt.ptr) -> tensor<32x128x!tt.ptr, #AL> %a1_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK-NEXT: offset = 0, size = 4608 + // CHECK-NEXT: scratch offset = 0, size = 4608 %a1 = triton_gpu.convert_layout %a1_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %a2_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK-NEXT: offset = 0, size = 1152 + // CHECK-NEXT: scratch offset = 0, size = 1152 %a2 = triton_gpu.convert_layout %a2_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> - // CHECK-NEXT: offset = 0, size = 4608 + // CHECK-NEXT: scratch offset = 0, size = 4608 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> - // CHECK-NEXT: offset = 0, size = 1152 + // CHECK-NEXT: scratch offset = 0, size = 1152 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> tt.return @@ -396,3 +396,127 @@ tt.func @for_for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr, } } + +module attributes {"triton_gpu.num-warps" = 4 : i32} { + +// CHECK-LABEL: alloc1 +tt.func @alloc1(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + tt.return + // CHECK-NEXT: size = 512 +} + +// CHECK-LABEL: alloc2 +tt.func @alloc2(%A : !tt.ptr) { + // CHECK: offset = 0, size = 1024 + %cst0 = triton_gpu.alloc_tensor : tensor<32x16xf16, #A_SHARED> + tt.return + // CHECK-NEXT: size = 1024 +} + +// CHECK-LABEL: alloc3 +tt.func @alloc3(%cond : i1) { + scf.if %cond { + // CHECK: offset = 0, size = 512 + %cst0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + } else { + // CHECK-NEXT: offset = 0, size = 1024 + %cst0 = triton_gpu.alloc_tensor : tensor<16x32xf16, #A_SHARED> + } + tt.return + // CHECK-NEXT: size = 1024 +} + +// CHECK-LABEL: alloc4 
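+// A tt.call site is modeled as a 'virtual' buffer charged with the callee's total +// shared-memory footprint, e.g. the call to @alloc3 below accounts for @alloc3's +// total of 1024 bytes (its two branches overlap at offset 0).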
+tt.func @alloc4(%A : !tt.ptr, %cond : i1) { + scf.if %cond { + // CHECK: virtual offset = 0, size = 1024 + tt.call @alloc3(%cond) : (i1) -> () + } else { + // CHECK-NEXT: virtual offset = 0, size = 512 + tt.call @alloc1(%A) : (!tt.ptr) -> () + } + tt.return + // CHECK-NEXT: size = 1024 +} + +// CHECK-LABEL: single_call +tt.func @single_call(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK-NEXT: virtual offset = 0, size = 512 + tt.call @alloc1(%A) : (!tt.ptr) -> () + tt.return + // CHECK-NEXT: size = 512 +} + +// CHECK-LABEL: multiple_calls +tt.func @multiple_calls(%A : !tt.ptr) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: virtual offset = 0, size = 512 + tt.call @alloc1(%A) : (!tt.ptr) -> () + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK-NEXT: virtual offset = 0, size = 1024 + tt.call @alloc2(%A) : (!tt.ptr) -> () + tt.return + // CHECK-NEXT: size = 1024 +} + +// CHECK-LABEL: if_else_calls +tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { + scf.if %cond { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: offset = 0, size = 1024 + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> + // CHECK-NEXT: virtual offset = 0, size = 512 + tt.call @alloc1(%A) : (!tt.ptr) -> () + } else { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK-NEXT: virtual offset = 0, size = 1024 + tt.call @alloc2(%A) : (!tt.ptr) -> () + } + tt.return + // CHECK-NEXT: size = 1024 +} + +// CHECK-LABEL: for_calls +tt.func @for_calls(%A : !tt.ptr, %cond : i1) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %lb = arith.constant 0 : index + %ub = arith.constant 10 : index + %step = arith.constant 1 : index + scf.for %iv = %lb to %ub step %step { + // CHECK-NEXT: virtual offset = 0, size = 512 + tt.call @alloc1(%A) : (!tt.ptr) -> () + } + tt.return + // CHECK-NEXT: size = 512 +} + +// CHECK-LABEL: call_graph_1 +tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: virtual offset = 0, size = 1024 + tt.call @alloc3(%cond) : (i1) -> () + tt.return + // CHECK-NEXT: size = 1024 +} + +// CHECK-LABEL: call_graph_2 +tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { + // CHECK: offset = 0, size = 512 + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK-NEXT: virtual offset = 0, size = 1024 + tt.call @alloc4(%A, %cond) : (!tt.ptr, i1) -> () + tt.return + // CHECK-NEXT: size = 1024 +} + +} diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 4946eeef5..621dd10e2 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -7,8 +7,8 @@ #A_SHARED_T = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [0, 1]}> #B_SHARED = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}> #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> -#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C}> -#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C}> +#A_DOT = 
#triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}> module attributes {"triton_gpu.num-warps" = 4 : i32} { @@ -503,3 +503,136 @@ tt.func @cf_if_else_return(%i1 : i1) { } } + +module attributes {"triton_gpu.num-warps" = 4 : i32} { + +// CHECK-LABEL: convert_layout1 +tt.func @convert_layout1(%A : !tt.ptr) { + // CHECK-NOT: gpu.barrier + %0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + %1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: convert_layout2 +tt.func @convert_layout2(%A : !tt.ptr) { + %0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + %1 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + %2 = tt.cat %1, %1 {axis = 0} : (tensor<16x16xf16, #A_SHARED>, tensor<16x16xf16, #A_SHARED>) -> tensor<32x16xf16, #A_SHARED> + // CHECK: triton_gpu.convert_layout + // CHECK-NEXT: gpu.barrier + %3 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> + %4 = tt.cat %2, %2 {axis = 0} : (tensor<32x16xf16, #A_SHARED>, tensor<32x16xf16, #A_SHARED>) -> tensor<64x16xf16, #AL> + tt.return +} + +// CHECK-LABEL: convert_layout3 +tt.func @convert_layout3(%cond : i1) { + scf.if %cond { + %0 = triton_gpu.alloc_tensor : tensor<16x64xf16, #A_SHARED> + // CHECK: triton_gpu.convert_layout + // CHECK-NOT: gpu.barrier + %1 = triton_gpu.convert_layout %0 : (tensor<16x64xf16, #A_SHARED>) -> tensor<16x64xf16, #AL> + } else { + %0 = triton_gpu.alloc_tensor : tensor<16x16xf16, #A_SHARED> + // CHECK: triton_gpu.convert_layout + // CHECK-NEXT: gpu.barrier + // CHECK-NEXT: triton_gpu.convert_layout + %1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL> + %2 = triton_gpu.convert_layout %1 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A_SHARED> + } + tt.return +} + +// CHECK-LABEL: convert_layout4 +tt.func @convert_layout4(%A : !tt.ptr, %cond : i1) { + // CHECK-NOT: gpu.barrier + scf.if %cond { + tt.call @convert_layout3(%cond) : (i1) -> () + } else { + tt.call @convert_layout2(%A) : (!tt.ptr) -> () + } + tt.return +} + +// CHECK-LABEL: single_call_sync +tt.func @single_call_sync(%A : !tt.ptr) { + %0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK: tt.call + // CHECK-NEXT: gpu.barrier + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + %1 = triton_gpu.convert_layout %0 : (tensor<16x32xf16, #AL>) -> tensor<16x32xf16, #BL> + tt.return +} + +// CHECK-LABEL: single_call_no_sync +// %1 can reuse %0 in convert_layout2, which has been synced +tt.func @single_call_no_sync(%A : !tt.ptr) { + // CHECK-NOT: gpu.barrier + %0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> + tt.call @convert_layout2(%A) : (!tt.ptr) -> () + %1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #BL> + tt.return +} + +// CHECK-LABEL: multiple_calls +tt.func @multiple_calls(%A : !tt.ptr) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + tt.call @convert_layout2(%A) : (!tt.ptr) -> () + tt.return +} + +// CHECK-LABEL: if_else_calls +tt.func @if_else_calls(%A : !tt.ptr, %cond : i1) { + scf.if %cond { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK: gpu.barrier + // CHECK-NEXT: tt.call + // CHECK-NEXT: gpu.barrier + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + 
%cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #A_SHARED> + } else { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + // CHECK: tt.call + // CHECK-NOT: gpu.barrier + tt.call @convert_layout2(%A) : (!tt.ptr) -> () + } + tt.return +} + +// CHECK-LABEL: for_calls +tt.func @for_calls(%A : !tt.ptr, %cond : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + %cst1 = arith.constant dense<0.000000e+00> : tensor<16x32xf16, #AL> + %lb = arith.constant 0 : index + %ub = arith.constant 10 : index + %step = arith.constant 1 : index + scf.for %iv = %lb to %ub step %step { + // CHECK: gpu.barrier + // CHECK-NEXT: tt.call + tt.call @convert_layout1(%A) : (!tt.ptr) -> () + } + tt.return +} + +// CHECK-LABEL: call_graph_1 +tt.func @call_graph_1(%A : !tt.ptr, %cond : i1) { + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + // CHECK: gpu.barrier + // CHECK-NEXT: tt.call + tt.call @convert_layout3(%cond) : (i1) -> () + tt.return +} + +// CHECK-LABEL: call_graph_2 +tt.func @call_graph_2(%A : !tt.ptr, %cond : i1) { + tt.call @convert_layout4(%A, %cond) : (!tt.ptr, i1) -> () + // CHECK: tt.call + // CHECK-NEXT: gpu.barrier + %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #A_SHARED> + tt.return +} + +} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index cd52d45bc..3d9667483 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1132,8 +1132,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}> #shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}> #mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}> module attributes {"triton_gpu.num-warps" = 1 : i32} { // PTX-LABEL: convert_dot // This test is disabled for GCN target, because it is PTX specific @@ -1250,7 +1250,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { // CHECK-LABEL: convert_blocked1d_to_slice1 tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) { - // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr, 3> + // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr, 3> %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> tt.return } @@ -1279,8 +1279,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}> #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> #mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}> -#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}> -#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}> +#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}> +#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}> module attributes {"triton_gpu.num-warps" = 4 : i32} { // PTX-LABEL: matmul_kernel_dot_operand_layout // This test is disabled for GCN target, because it is 
diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir
index cd52d45bc..3d9667483 100644
--- a/test/Conversion/tritongpu_to_llvm.mlir
+++ b/test/Conversion/tritongpu_to_llvm.mlir
@@ -1132,8 +1132,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [1, 1], order = [1, 0]}>
 #shared0 = #triton_gpu.shared<{vec = 1, perPhase=2, maxPhase=8 ,order = [1, 0]}>
 #mma0 = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[1,1]}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma0, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma0, kWidth=2}>
 module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // PTX-LABEL: convert_dot
   // This test is disabled for GCN target, because it is PTX specific
@@ -1250,7 +1250,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
   // CHECK-LABEL: convert_blocked1d_to_slice1
   tt.func @convert_blocked1d_to_slice1(%src:tensor<32xi32, #blocked0>) {
-    // CHECK-COUNT-32: llvm.load {{.*}} : !llvm.ptr<vector<1xi32>, 3>
+    // CHECK-COUNT-8: llvm.load {{.*}} : !llvm.ptr<vector<4xi32>, 3>
     %cvt = triton_gpu.convert_layout %src : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>>
     tt.return
   }
@@ -1279,8 +1279,8 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} {
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #mma = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [2, 2]}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=2}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=2}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // PTX-LABEL: matmul_kernel_dot_operand_layout
   // This test is disabled for GCN target, because it is PTX specific
@@ -1357,8 +1357,8 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 #mma = #triton_gpu.mma<{versionMajor=2, warpsPerCTA=[2, 2]}>
 #shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
-#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma}>
-#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma}>
+#dot_operand_a = #triton_gpu.dot_op<{opIdx=0, parent=#mma, kWidth=1}>
+#dot_operand_b = #triton_gpu.dot_op<{opIdx=1, parent=#mma, kWidth=1}>
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // PTX-LABEL: matmul_tf32dot
   // This test is disabled for GCN target, because it is PTX specific
@@ -1366,12 +1366,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
   tt.func @matmul_tf32dot(%ptr:!tt.ptr<f32> {tt.divisibility = 16 : i32}, %a:tensor<32x16xf32, #shared>, %b:tensor<16x32xf32, #shared>) {
     %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
     // PTX: llvm.inline_asm
     // PTX-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
-    // PTX-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>)
+    // PTX-SAME: (i32, i32, i32, i32)
     // PTX: llvm.inline_asm
     // PTX-SAME: ldmatrix.sync.aligned.m8n8.x4.shared.b16
-    // PTX-SAME: (vector<1xf32>, vector<1xf32>, vector<1xf32>, vector<1xf32>)
+    // PTX-SAME: (i32, i32, i32, i32)
     %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a>
     %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b>
@@ -1399,13 +1399,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: atomic_add_f32
   tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
     // GCN-NOT: llvm.inline_asm
     // GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32>, f32
-    // PTX: llvm.icmp "slt"
     // PTX: llvm.inline_asm
     // PTX-SAME: @$3 atom.global.gpu.add.f32
     // PTX: llvm.inline_asm
     // PTX-SAME: @$3 atom.global.gpu.add.f32
     %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
     tt.return
   }
@@ -1416,11 +1415,12 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: atomic_add_f32_scalar
   tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
     // GCN-NOT: llvm.inline_asm
     // GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32>, f32
     // PTX: llvm.icmp "eq"
+    // PTX: llvm.inline_asm
     // PTX: llvm.inline_asm
     // PTX-SAME: @$3 atom.global.gpu.add.f32
     %0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
     tt.return
   }
@@ -1432,14 +1432,13 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: store_f32
   tt.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
     // GCN-NOT: llvm.inline_asm
     // GCN: llvm.store {{.*}} : !llvm.ptr<f32>
     // GCN: llvm.store {{.*}} : !llvm.ptr<f32>
-    // PTX: llvm.icmp "slt"
     // PTX: llvm.inline_asm
     // PTX-SAME: @$2 st.global.b32
     // PTX: llvm.inline_asm
     // PTX-SAME: @$2 st.global.b32
     tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
     tt.return
   }
@@ -1450,11 +1449,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
 module attributes {"triton_gpu.num-warps" = 4 : i32} {
   // CHECK-LABEL: store_f32_scalar
   tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
     // GCN-NOT: llvm.inline_asm
     // GCN: llvm.store {{.*}} : !llvm.ptr<f32>
-    // PTX: llvm.icmp "slt"
+    // PTX: llvm.icmp "eq"
     // PTX: llvm.inline_asm
     // PTX-SAME: @$2 st.global.b32
     tt.store %arg0, %arg1 : f32
     tt.return
   }
diff --git a/test/Target/tritongpu_to_llvmir_noinline.mlir b/test/Target/tritongpu_to_llvmir_noinline.mlir
new file mode 100644
index 000000000..d4784af22
--- /dev/null
+++ b/test/Target/tritongpu_to_llvmir_noinline.mlir
@@ -0,0 +1,22 @@
+// RUN: %PYTHON -m triton.tools.aot %s --target=llvm-ir --sm=80 | FileCheck %s
+
+// == LLVM IR check begin ==
+// CHECK-LABEL: ; ModuleID = 'LLVMDialectModule'
+// CHECK: define void @test_func
+// CHECK: define void @test_kernel
+// CHECK: tail call void @test_func
+
+module attributes {"triton_gpu.num-warps" = 4 : i32} {
+
+tt.func @test_func(%lb : index, %A : !tt.ptr<f16>) attributes { noinline = true } {
+  %0 = arith.constant 1.0 : f16
+  tt.store %A, %0 : f16
+  tt.return
+}
+
+tt.func @test_kernel(%lb : index, %A : !tt.ptr<f16>) {
+  tt.call @test_func(%lb, %A) : (index, !tt.ptr<f16>) -> ()
+  tt.return
+}
+
+}
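The new test/TritonGPU/dot-operands.mlir suite that follows exercises -tritongpu-optimize-dot-operands, which moves cheap unary casts such as tt.bitcast and tt.fp_to_fp to the other side of the triton_gpu.convert_layout feeding tt.dot, so the narrow i8/fp8 data is what travels through the layout conversion and the upcast to f16 happens once the value is already in the dot-operand layout. A rough sketch of that kind of rewrite, written against generic MLIR pattern APIs; this is a simplified illustration of the idea, not the pass's actual implementation, and the matching conditions are assumptions:

    #include "mlir/IR/PatternMatch.h"
    #include "triton/Dialect/Triton/IR/Dialect.h"
    #include "triton/Dialect/TritonGPU/IR/Dialect.h"

    using namespace mlir;

    struct PushElementwiseBelowConvert : public RewritePattern {
      PushElementwiseBelowConvert(MLIRContext *ctx)
          : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(),
                           /*benefit=*/1, ctx) {}

      LogicalResult matchAndRewrite(Operation *cvt,
                                    PatternRewriter &rewriter) const override {
        // Match a convert_layout fed by a unary elementwise cast.
        Operation *def = cvt->getOperand(0).getDefiningOp();
        if (!def || def->getNumOperands() != 1 || def->getNumResults() != 1 ||
            !isa<triton::BitcastOp, triton::FpToFpOp>(def))
          return failure();
        auto srcTy = def->getOperand(0).getType().cast<RankedTensorType>();
        auto dstTy = cvt->getResult(0).getType().cast<RankedTensorType>();
        // Convert the narrow input to the dot-operand encoding first...
        auto newCvtTy = RankedTensorType::get(
            srcTy.getShape(), srcTy.getElementType(), dstTy.getEncoding());
        Value newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
            cvt->getLoc(), newCvtTy, def->getOperand(0));
        // ...then re-apply the cast on the already-converted value.
        Operation *newDef = rewriter.clone(*def);
        newDef->setOperand(0, newCvt);
        newDef->getResult(0).setType(dstTy);
        rewriter.replaceOp(cvt, newDef->getResult(0));
        return success();
      }
    };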
diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir
new file mode 100644
index 000000000..41a65cce4
--- /dev/null
+++ b/test/TritonGPU/dot-operands.mlir
@@ -0,0 +1,52 @@
+// RUN: triton-opt %s -split-input-file -tritongpu-optimize-dot-operands -tritongpu-remove-layout-conversions -canonicalize | FileCheck %s
+
+#Cv2 = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
+#Av2 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv2, kWidth=2}>
+#Bv2 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv2, kWidth=2}>
+#Cv1 = #triton_gpu.mma<{versionMajor = 1, warpsPerCTA = [4, 1]}>
+#Av1 = #triton_gpu.dot_op<{opIdx = 0, parent = #Cv1}>
+#Bv1 = #triton_gpu.dot_op<{opIdx = 1, parent = #Cv1}>
+#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}>
+
+// CHECK: tt.func @push_elementwise1
+// CHECK: %[[ALOAD:.*]] = tt.load %arg0
+// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[ALOAD]]
+// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ACVT]]
+// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
+// CHECK: %[[C:.*]] = tt.dot %[[AF16]]
+// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma>
+tt.func @push_elementwise1(
+    %pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
+    %pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
+    %c: tensor<16x16xf32, #Cv2>) -> tensor<16x16xf32, #Cv2>{
+  %ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
+  %b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
+  %af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
+  %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
+  %dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av2>
+  %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv2>
+  %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2> * tensor<16x16xf16, #Bv2> -> tensor<16x16xf32, #Cv2>
+  tt.return %newc : tensor<16x16xf32, #Cv2>
+}
+
+// CHECK: tt.func @push_elementwise2
+// CHECK: %[[ALOAD:.*]] = tt.load %arg0
+// CHECK: %[[AF8E5:.*]] = tt.bitcast %[[ALOAD]]
+// CHECK: %[[AF16:.*]] = tt.fp_to_fp %[[AF8E5]]
+// CHECK: %[[ACVT:.*]] = triton_gpu.convert_layout %[[AF16]]
+// CHECK: %[[C:.*]] = tt.dot %[[ACVT]]
+// CHECK: tt.return %[[C]] : tensor<16x16xf32, #mma1>
+tt.func @push_elementwise2(
+    %pa: tensor<16x16x!tt.ptr<i8>, #AL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
+    %pb: tensor<16x16x!tt.ptr<f16>, #BL> {tt.divisibility=16: i32, tt.contiguity=2 : i32},
+    %c: tensor<16x16xf32, #Cv1>) -> tensor<16x16xf32, #Cv1>{
+  %ai8 = tt.load %pa {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xi8, #AL>
+  %b = tt.load %pb {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL>
+  %af8 = tt.bitcast %ai8: tensor<16x16xi8, #AL> -> tensor<16x16xf8E5M2, #AL>
+  %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #AL> -> tensor<16x16xf16, #AL>
+  %dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #Av1>
+  %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #Bv1>
+  %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av1> * tensor<16x16xf16, #Bv1> -> tensor<16x16xf32, #Cv1>
+  tt.return %newc : tensor<16x16xf32, #Cv1>
+}
diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir
index 5a3a2e437..7ce3f7b59 100644
--- a/test/TritonGPU/loop-pipeline.mlir
+++ b/test/TritonGPU/loop-pipeline.mlir
@@ -8,8 +8,8 @@
 #BLs0 = #triton_gpu.slice<{parent=#BL, dim=0}>
 #BLs1 = #triton_gpu.slice<{parent=#BL, dim=1}>
 #C = #triton_gpu.mma<{versionMajor = 2, warpsPerCTA = [4, 1]}>
-#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
-#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
+#A = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth=2}>
+#B = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth=2}>
 
 // CHECK: tt.func @matmul_loop
 // CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir
index e9b0487fd..b820f4034 100644
--- a/test/TritonGPU/prefetch.mlir
+++ b/test/TritonGPU/prefetch.mlir
@@ -1,4 +1,4 @@
-// RUN: triton-opt %s -split-input-file -tritongpu-prefetch | FileCheck %s
+// RUN: triton-opt %s -split-input-file -tritongpu-prefetch -canonicalize | FileCheck %s
 
 // 4 warps
 // matmul: 128x32 @ 32x128 -> 128x128
@@ -7,33 +7,36 @@
 #A = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 #B = #triton_gpu.shared<{vec = 2, perPhase = 2, maxPhase = 4, order = [1, 0]}>
 #C = #triton_gpu.mma<{version = 2, warpsPerCTA = [4, 1]}>
-#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C}>
-#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C}>
+#A_OP = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>
+#B_OP = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>
 
-// CHECK: tt.func @matmul_loop
+// CHECK: tt.func @matmul_loop_mixed
 // CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16]
 // CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
+// CHECK-DAG: %[[A0_CVT:.*]] = tt.fp_to_fp %[[A0_PREFETCH]]
 // CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128]
 // CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.convert_layout %[[B0_PREFETCH_SMEM]]
-// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_PREFETCH]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
+// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_CVT]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
 // CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_a0]][0, 16] [128, 16]
 // CHECK-DAG: %[[A_REM:.*]] = triton_gpu.convert_layout %[[A_REM_SMEM]]
+// CHECK-DAG: %[[A_REM_CVT:.*]] = tt.fp_to_fp %[[A_REM]]
 // CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_b0]][16, 0] [16, 128]
 // CHECK-DAG: %[[B_REM:.*]] = triton_gpu.convert_layout %[[B_REM_SMEM]]
 // CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}}
-// CHECK: tt.dot %[[A_REM]], %[[B_REM]], %[[D_FIRST:.*]]
+// CHECK: tt.dot %[[A_REM_CVT]], %[[B_REM]], %[[D_FIRST:.*]]
 // CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [128, 16]
 // CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_A_PREFETCH_SMEM]]
+// CHECK-DAG: %[[NEXT_A_PREFETCH_CVT:.*]] = tt.fp_to_fp %[[NEXT_A_PREFETCH]]
 // CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128]
 // CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
-// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
-tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
-  %a_ptr_init = tt.broadcast %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
+// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH_CVT]], %[[NEXT_B_PREFETCH]]
+tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f8E5M2>, %B : !tt.ptr<f16>) -> tensor<128x128xf32, #C>{
+  %a_ptr_init = tt.broadcast %A : (!tt.ptr<f8E5M2>) -> tensor<128x32x!tt.ptr<f8E5M2>, #AL>
   %b_ptr_init = tt.broadcast %B : (!tt.ptr<f16>) -> tensor<32x128x!tt.ptr<f16>, #BL>
   %a_mask = arith.constant dense<true> : tensor<128x32xi1, #AL>
-  %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
+  %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf8E5M2, #AL>
   %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
   %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL>
   %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
@@ -41,24 +44,25 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>,
   %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
   %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>
 
-  %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-  %a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
+  %a_ = tt.load %a_ptr_init, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL>
+  %a_init = triton_gpu.convert_layout %a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A>
   %b_ = tt.load %b_ptr_init, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
   %b_init = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
 
-  scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) {
-    %a_op = triton_gpu.convert_layout %a : (tensor<128x32xf16, #A>) -> tensor<128x32xf16, #A_OP>
+  %loop:5 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %a = %a_init, %b = %b_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>) {
+    %a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP>
+    %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP>
     %b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
     %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
 
-    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<128x32xi32, #AL>
+    %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<128x32xi32, #AL>
     %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
-    %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
-    %next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
+    %next_a_ = tt.load %next_a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf8E5M2, #AL>
+    %next_a = triton_gpu.convert_layout %next_a_ : (tensor<128x32xf8E5M2, #AL>) -> tensor<128x32xf8E5M2, #A>
     %next_b_ = tt.load %next_b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
     %next_b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
 
-    scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf16, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
+    scf.yield %next_a_ptr, %next_b_ptr, %next_a, %next_b, %c : tensor<128x32x!tt.ptr<f8E5M2>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x32xf8E5M2, #A>, tensor<32x128xf16, #B>, tensor<128x128xf32, #C>
   }
-  tt.return
+  tt.return %loop#4 : tensor<128x128xf32, #C>
 }
diff --git a/test/lib/Analysis/CMakeLists.txt b/test/lib/Analysis/CMakeLists.txt
index deafc7b8a..e3c774379 100644
--- a/test/lib/Analysis/CMakeLists.txt
+++ b/test/lib/Analysis/CMakeLists.txt
@@ -5,6 +5,7 @@ add_mlir_library(TritonTestAnalysis
   TestMembar.cpp
 
   LINK_LIBS PUBLIC
+    MLIRPass
     TritonAnalysis
     ${dialect_libs}
 )
diff --git a/test/lib/Analysis/TestAllocation.cpp b/test/lib/Analysis/TestAllocation.cpp
index 71cc9f2e2..772e0258b 100644
--- a/test/lib/Analysis/TestAllocation.cpp
+++ b/test/lib/Analysis/TestAllocation.cpp
@@ -6,7 +6,7 @@ using namespace mlir;
 namespace {
 
 struct TestAllocationPass
-    : public PassWrapper<TestAllocationPass, OperationPass<triton::FuncOp>> {
+    : public PassWrapper<TestAllocationPass, OperationPass<ModuleOp>> {
 
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAllocationPass);
 
@@ -16,31 +16,37 @@ struct TestAllocationPass
   }
 
   void runOnOperation() override {
-    Operation *operation = getOperation();
     auto &os = llvm::errs();
+    ModuleOp moduleOp = getOperation();
     // Convert to std::string can remove quotes from opName
-    auto opName = SymbolTable::getSymbolName(operation).getValue().str();
-    os << opName << "\n";
-    Allocation allocation(operation);
-    operation->walk([&](Operation *op) {
-      auto scratchBufferId = allocation.getBufferId(op);
-      if (scratchBufferId != Allocation::InvalidBufferId) {
-        size_t offset = allocation.getOffset(scratchBufferId);
-        size_t size = allocation.getAllocatedSize(scratchBufferId);
-        os << "scratch offset = " << offset << ", size = " << size << "\n";
-      }
-      if (op->getNumResults() < 1)
-        return;
-      for (Value result : op->getResults()) {
-        auto bufferId = allocation.getBufferId(result);
-        if (bufferId != Allocation::InvalidBufferId) {
-          size_t offset = allocation.getOffset(bufferId);
-          size_t size = allocation.getAllocatedSize(bufferId);
-          os << "offset = " << offset << ", size = " << size << "\n";
-        }
-      }
+    ModuleAllocation moduleAllocation(moduleOp);
+    moduleOp.walk([&](triton::FuncOp funcOp) {
+      auto opName = SymbolTable::getSymbolName(funcOp).getValue().str();
+      os << opName << "\n";
+      auto *allocation = moduleAllocation.getFuncData(funcOp);
+      funcOp.walk([&](Operation *op) {
+        auto scratchBufferId = allocation->getBufferId(op);
+        if (scratchBufferId != Allocation::InvalidBufferId) {
+          size_t offset = allocation->getOffset(scratchBufferId);
+          size_t size = allocation->getAllocatedSize(scratchBufferId);
+          if (allocation->isVirtualBuffer(scratchBufferId))
+            os << "virtual offset = " << offset << ", size = " << size << "\n";
+          else
+            os << "scratch offset = " << offset << ", size = " << size << "\n";
+        }
+        if (op->getNumResults() < 1)
+          return;
+        for (Value result : op->getResults()) {
+          auto bufferId = allocation->getBufferId(result);
+          if (bufferId != Allocation::InvalidBufferId) {
+            size_t offset = allocation->getOffset(bufferId);
+            size_t size = allocation->getAllocatedSize(bufferId);
+            os << "offset = " << offset << ", size = " << size << "\n";
+          }
+        }
+      });
+      os << "size = " << allocation->getSharedMemorySize() << "\n";
     });
-    os << "size = " << allocation.getSharedMemorySize() << "\n";
   }
 };
diff --git a/test/lib/Analysis/TestAxisInfo.cpp b/test/lib/Analysis/TestAxisInfo.cpp
index 13d869eeb..57f8895d1 100644
--- a/test/lib/Analysis/TestAxisInfo.cpp
+++ b/test/lib/Analysis/TestAxisInfo.cpp
@@ -7,7 +7,7 @@ using namespace mlir;
 namespace {
 
 struct TestAxisInfoPass
-    : public PassWrapper<TestAxisInfoPass, OperationPass<triton::FuncOp>> {
+    : public PassWrapper<TestAxisInfoPass, OperationPass<ModuleOp>> {
 
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestAxisInfoPass);
 
@@ -18,23 +18,24 @@ struct TestAxisInfoPass
 
   void runOnOperation() override {
     Operation *operation = getOperation();
-    auto &os = llvm::errs();
-    auto opName = SymbolTable::getSymbolName(operation).getValue().str();
-    os << "@" << opName << "\n";
-
-    std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
-    AxisInfoAnalysis *analysis = solver->load<AxisInfoAnalysis>();
-    if (failed(solver->initializeAndRun(operation)))
-      return signalPassFailure();
-    operation->walk([&](Operation *op) {
-      if (op->getNumResults() < 1)
-        return;
-      for (Value result : op->getResults()) {
-        result.print(os);
-        os << " => ";
-        analysis->getLatticeElement(result)->getValue().print(os);
-        os << "\n";
<< "\n"; - } + ModuleOp moduleOp = cast(operation); + ModuleAxisInfoAnalysis moduleAxisInfoAnalysis(moduleOp); + moduleOp.walk([&](triton::FuncOp funcOp) { + auto &os = llvm::errs(); + auto opName = SymbolTable::getSymbolName(funcOp).getValue().str(); + os << "@" << opName << "\n"; + funcOp.walk([&](Operation *op) { + if (op->getNumResults() < 1) + return; + for (Value result : op->getResults()) { + result.print(os); + os << " => "; + auto *axisInfo = moduleAxisInfoAnalysis.getAxisInfo(result); + if (axisInfo) + axisInfo->print(os); + os << "\n"; + } + }); }); } }; diff --git a/test/lib/Analysis/TestMembar.cpp b/test/lib/Analysis/TestMembar.cpp index 4764a69aa..5e7bbb0c8 100644 --- a/test/lib/Analysis/TestMembar.cpp +++ b/test/lib/Analysis/TestMembar.cpp @@ -11,7 +11,7 @@ using namespace mlir; namespace { struct TestMembarPass - : public PassWrapper> { + : public PassWrapper> { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestMembarPass); @@ -22,17 +22,11 @@ struct TestMembarPass void runOnOperation() override { Operation *operation = getOperation(); - auto &os = llvm::errs(); - // Convert to std::string can remove quotes from op_name - auto opName = SymbolTable::getSymbolName(operation).getValue().str(); - os << opName << "\n"; - + ModuleOp moduleOp = cast(operation); // Print all ops after membar pass - Allocation allocation(operation); - MembarAnalysis membarPass(&allocation); + ModuleAllocation allocation(moduleOp); + ModuleMembarAnalysis membarPass(&allocation); membarPass.run(); - - os << *operation << "\n"; } }; diff --git a/unittest/Analysis/CMakeLists.txt b/unittest/Analysis/CMakeLists.txt index 880c8117b..829d0bff7 100644 --- a/unittest/Analysis/CMakeLists.txt +++ b/unittest/Analysis/CMakeLists.txt @@ -1,5 +1,8 @@ add_triton_ut( NAME TestTritonAnalysis SRCS UtilityTest.cpp - LIBS TritonAnalysis + LIBS + TritonAnalysis + TritonIR + TritonGPUIR ) diff --git a/unittest/Dialect/TritonGPU/SwizzleTest.cpp b/unittest/Dialect/TritonGPU/SwizzleTest.cpp index c7dc33d0a..0482327e4 100644 --- a/unittest/Dialect/TritonGPU/SwizzleTest.cpp +++ b/unittest/Dialect/TritonGPU/SwizzleTest.cpp @@ -30,7 +30,7 @@ TEST_P(SwizzleDotOperandTestFixture, DotOperands) { // create encoding auto parent = triton::gpu::MmaEncodingAttr::get(&ctx, 2, 0, {1, 1}); auto encoding = - triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent); + triton::gpu::DotOperandEncodingAttr::get(&ctx, params.opIdx, parent, 0); // create element type Type eltType = IntegerType::get(&ctx, params.typeWidth);