diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml index d413a3dca..7993b5733 100644 --- a/.github/workflows/documentation.yml +++ b/.github/workflows/documentation.yml @@ -25,6 +25,7 @@ jobs: pip3 install tabulate pip3 install cmake pip3 install sphinx + pip3 install myst_parser #- name: Fetch dependent branches # run: | diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 1a7881e0d..b4f4495cf 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -33,6 +33,7 @@ jobs: echo '::set-output name=matrix-optional::["ubuntu-latest"]' fi + Integration-Tests-Nvidia: needs: Runner-Preparation @@ -44,14 +45,14 @@ jobs: steps: - name: Checkout - uses: actions/checkout@v2 - + uses: actions/checkout@v3 + with: + submodules: 'true' - name: Set CUDA ENV if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'V100' || matrix.runner[1] == 'A100' || matrix.runner[1] == 'H100')}} run: | echo "BACKEND=CUDA" >> "${GITHUB_ENV}" echo "ENABLE_TMA=0" >> "${GITHUB_ENV}" - echo "ENABLE_MMA_V3=0" >> "${GITHUB_ENV}" echo "TRITON_DISABLE_LINE_INFO=1" >> "${GITHUB_ENV}" - name: Clear cache @@ -88,24 +89,26 @@ jobs: fi lit -v "${LIT_TEST_DIR}" - - name: Enable MMAV3 and TMA + - name: Enable TMA if: ${{(matrix.runner[0] == 'self-hosted') && (matrix.runner[1] == 'H100')}} run: | echo "ENABLE_TMA=1" >> "${GITHUB_ENV}" - echo "ENABLE_MMA_V3=1" >> "${GITHUB_ENV}" - - name: Run python tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} + - name: Run python tests on CUDA with ENABLE_TMA=1 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | cd python/test/unit - python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py + python3 -m pytest -n 8 --ignore=runtime --ignore=operators --ignore=language/test_line_info.py --ignore=language/test_subprocess.py + python3 -m pytest -n 8 language/test_subprocess.py # run runtime tests serially to avoid race condition with cache handling. 
python3 -m pytest runtime/ # run test_line_info.py separately with TRITON_DISABLE_LINE_INFO=0 TRITON_DISABLE_LINE_INFO=0 python3 -m pytest language/test_line_info.py + # run hopper/test_flashattention.py separately to avoid running out of GPU memory + python3 -m pytest hopper/test_flashattention.py - - name: Run python tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}} + - name: Run python tests on CUDA with ENABLE_TMA=0 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} run: | cd python/test/unit python3 -m pytest -n 8 --ignore=runtime --ignore=hopper --ignore=operators --ignore=language/test_line_info.py @@ -118,14 +121,22 @@ jobs: run: | rm -rf ~/.triton - - name: Run partial tests on CUDA with ENABLE_TMA=1 and ENABLE_MMA_V3=1 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1' && env.ENABLE_MMA_V3 == '1'}} + - name: Run interpreter tests + env: + # TRITON_INTERPRET: "1" + CUDA_VISIBLE_DEVICES: "" + run: | + cd python/test/unit + python3 -m pytest -vs operators/test_flash_attention.py + + - name: Run partial tests on CUDA with ENABLE_TMA=1 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '1'}} run: | cd python/test/unit python3 -m pytest -n 8 operators - - name: Run partial tests on CUDA with ENABLE_TMA=0 and ENABLE_MMA_V3=0 - if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0' && env.ENABLE_MMA_V3 == '0'}} + - name: Run partial tests on CUDA with ENABLE_TMA=0 + if: ${{ env.BACKEND == 'CUDA' && env.ENABLE_TMA == '0'}} run: | cd python/test/unit python3 -m pytest -n 8 operators @@ -160,6 +171,50 @@ jobs: python3 -m pytest -vs . --reruns 10 sudo nvidia-smi -i 0 -rgc + Integration-Tests-Shared-Middle-Layer: + + runs-on: ubuntu-latest + + steps: + - name: Checkout + uses: actions/checkout@v2 + + - name: Clear cache + run: | + rm -rf ~/.triton + + - name: Update PATH + run: | + echo "PATH=${HOME}/.local/bin:${PATH}" >> "${GITHUB_ENV}" + + - name: Check pre-commit + run: | + python3 -m pip install --upgrade pre-commit + python3 -m pre_commit run --all-files --verbose + + - name: Install Triton + run: | + export TRITON_CODEGEN_TRITON_SHARED=1 + git submodule update --init --recursive + cd python + python3 -m pip install --upgrade pip + python3 -m pip install cmake==3.24 + python3 -m pip install ninja + python3 -m pip uninstall -y triton + python3 setup.py build + python3 -m pip install --no-build-isolation -vvv '.[tests]' + + - name: Run shared middle-layer lit tests + run: | + python3 -m pip install lit + cd python + LIT_TEST_DIR="build/$(ls build | grep -i cmake)/third_party/triton_shared/test" + if [ !
-d "${LIT_TEST_DIR}" ]; then + echo "Coult not find '${LIT_TEST_DIR}'" ; exit -1 + fi + lit -v "${LIT_TEST_DIR}" + + Integration-Tests-Third-Party: needs: Runner-Preparation if: false diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index b841e8942..98f05e720 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -46,8 +46,8 @@ jobs: export CIBW_MANYLINUX_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" #export CIBW_MANYLINUX_PYPY_X86_64_IMAGE="quay.io/pypa/manylinux2014_x86_64:latest" export CIBW_BEFORE_BUILD="pip install cmake;" - export CIBW_SKIP="{cp,pp}{35,36}-*" - export CIBW_BUILD="{cp,pp}3*-manylinux_x86_64" + export CIBW_SKIP="cp{35,36}-*" + export CIBW_BUILD="cp3*-manylinux_x86_64" python3 -m cibuildwheel python --output-dir wheelhouse - name: Install Azure CLI diff --git a/.gitignore b/.gitignore index d03c11525..47fe5db52 100644 --- a/.gitignore +++ b/.gitignore @@ -26,8 +26,27 @@ venv.bak/ cmake-build-* # Third-party binaries +cuobjdump +nvdisasm ptxas +<<<<<<< HEAD # HIP log* python/triton/third_party/cuda/bin/ptxas +======= +# Docs +docs/_build/ +docs/python-api/generated/ +docs/dialects/ +docs/getting-started/tutorials +!python/tutorials/*.py +!python/tutorials/*.rst + +# clangd index. (".clangd" is a config file now, thus trailing slash) +.clangd/ +.cache +/compile_commands.json +.vscode +.vs +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 diff --git a/.gitmodules b/.gitmodules index 30ba43425..3a989c6cc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -5,3 +5,6 @@ path = third_party/amd_hip_backend url = https://github.com/ROCmSoftwarePlatform/triton branch = third_party_backend_2 +[submodule "third_party/triton_shared"] + path = third_party/triton_shared + url = https://github.com/microsoft/triton-shared diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index dfe756c27..1729fc92e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,7 @@ repos: ^docs/conf.py$ ) - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v14.0.6 + rev: v16.0.6 hooks: - id: clang-format stages: [commit, push, manual] diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ac51329b..d4c2e4a06 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -248,7 +248,6 @@ if(TRITON_BUILD_PYTHON_MODULE) TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX - TritonHSACO ${dialect_libs} ${conversion_libs} diff --git a/README.md b/README.md index 5f88299c9..7a0df5bc8 100644 --- a/README.md +++ b/README.md @@ -10,32 +10,6 @@ We're hiring! If you are interested in working on Triton at OpenAI, we have role ------------------- | [![Documentation](https://github.com/openai/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) -# Triton Developer Conference Registration Open -The Triton Developer Conference will be held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference will be held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. Please use the link below to register to attend either in-person or virtually online. 
- -Registration Link for Triton Developer Conference is [here](https://forms.office.com/r/m4jQXShDts) - -Tentative Agenda for the conference (subject to change): - -|Time |Title |Speaker -|--------|-------|-------| -|10:00 AM|Welcome|Kevin Scott (Microsoft)| -|10:20 AM|The Triton Compiler: Past, Present and Future|Phil Tillet (OpenAI)| -|11:00 AM|**Break**|| -|11:20 AM|Hopper support in Triton|Gustav Zhu (Nvidia)| -|11:40 AM|Bringing Triton to AMD GPUs|Jason Furmanek, Lixun Zhang (AMD)| -|12:00 PM|Intel XPU Backend for Triton|Eikan Wang (Intel)| -|12:20 PM|Vectorization of Triton Kernels for Qualcomm Hexagon Backend|Javed Absar (Qualcomm)| -|12:30 PM|**Lunch**|| -|1:40 PM |Triton for MTIA|Roman Levenstein et al, (Meta)| -|2:00 PM |Using Triton IR for high-performance fusions in XLA|George Karpenkov (Google)| -|2:20 PM |Triton for All: Triton as a device-independent language|Ian Bearman (Microsoft)| -|2:40 PM|**Break**|| -|3:00 PM|PyTorch 2.0 and TorchInductor|Jason Ansel, Horace He (Meta)| -|3:20 PM|Pallas: A JAX Kernel Language|Sharad Vikram (Google)| -|3:40 PM|Writing Grouped GEMMs in Triton|Vinod Grover (Nvidia)| -|4:00 PM|**Reception**|| - # Triton @@ -86,9 +60,29 @@ lit -v test ``` git clone https://github.com/openai/triton.git; +<<<<<<< HEAD cd triton/python; pip install ninja cmake; # build-time dependencies pip install -e . +======= +cd triton; + +pip install ninja cmake wheel; # build-time dependencies +pip install -e python +``` + +Or with a virtualenv: + +``` +git clone https://github.com/openai/triton.git; +cd triton; + +python -m venv .venv --prompt triton; +source .venv/bin/activate; + +pip install ninja cmake wheel; # build-time dependencies +pip install -e python +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ``` # Building with a custom LLVM @@ -111,6 +105,7 @@ arbitrary LLVM version. modifications to LLVM. 3. [Build LLVM](https://llvm.org/docs/CMake.html). For example, you might run +<<<<<<< HEAD $ cd $HOME/llvm-project # your clone of LLVM. $ mkdir build @@ -130,6 +125,62 @@ arbitrary LLVM version. LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ LLVM_SYSPATH=$LLVM_BUILD_DIR \ pip install -e . +======= + + $ cd $HOME/llvm-project # your clone of LLVM. + $ mkdir build + $ cd build + $ cmake -G Ninja -DCMAKE_BUILD_TYPE=Release -DLLVM_ENABLE_ASSERTIONS=ON ../llvm -DLLVM_ENABLE_PROJECTS="mlir;llvm" + $ ninja + +4. Grab a snack, this will take a while. + +5. Build Triton as above, but set the following environment variables. + + # Modify as appropriate to point to your LLVM build. + $ export LLVM_BUILD_DIR=$HOME/llvm-project/build + + $ cd + $ LLVM_INCLUDE_DIRS=$LLVM_BUILD_DIR/include \ + LLVM_LIBRARY_DIR=$LLVM_BUILD_DIR/lib \ + LLVM_SYSPATH=$LLVM_BUILD_DIR \ + pip install -e python + +# Tips for building + +- Set `TRITON_BUILD_WITH_CLANG_LLD=true` as an environment variable to use clang + and lld. lld in particular results in faster builds. + +- Set `TRITON_BUILD_WITH_CCACHE=true` to build with ccache. + +- Pass `--no-build-isolation` to `pip install` to make nop builds faster. + Without this, every invocation of `pip install` uses a different symlink to + cmake, and this forces ninja to rebuild most of the `.a` files. + +# Running tests + +There currently isn't a turnkey way to run all the Triton tests, but you can +follow the following recipe. + +```shell +# One-time setup. Note we have to reinstall local Triton because torch +# overwrites it with the public version. +$ pip install scipy numpy torch pytest lit && pip install -e python + +# Run Python tests using your local GPU. 
+$ python3 -m pytest python/test/unit + +# Move to builddir. Fill in <...> with the full path, e.g. +# `cmake.linux-x86_64-cpython-3.11`. +$ cd python/build/cmake<...> + +# Run C++ unit tests. +$ ninja test + +# Run lit tests. +$ lit test +``` +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # Changelog diff --git a/bin/CMakeLists.txt b/bin/CMakeLists.txt index a8966c5e7..9da8e5628 100644 --- a/bin/CMakeLists.txt +++ b/bin/CMakeLists.txt @@ -53,7 +53,6 @@ llvm_update_compile_flags(triton-translate) TritonNvidiaGPUTransforms TritonLLVMIR TritonPTX - TritonHSACO ${dialect_libs} ${conversion_libs} # tests @@ -80,3 +79,20 @@ llvm_update_compile_flags(triton-translate) MLIRROCDLToLLVMIRTranslation ) mlir_check_all_link_libraries(triton-translate) + +add_llvm_executable(triton-llvm-opt + triton-llvm-opt.cpp + + DEPENDS + intrinsics_gen + SUPPORT_PLUGINS + ) +target_link_libraries(triton-llvm-opt PRIVATE + TritonLLVMIR + + LLVMCore + LLVMSupport + LLVMOption + LLVMCodeGen + ) +export_executable_symbols_for_plugins(triton-llvm-opt) diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 5cf1c3a25..e88a9a539 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -1,4 +1,5 @@ #pragma once +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" @@ -11,6 +12,7 @@ #include "triton/Conversion/TritonGPUToLLVM/Passes.h" #include "triton/Conversion/TritonToTritonGPU/Passes.h" +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/InitAllPasses.h" namespace mlir { @@ -40,5 +42,6 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect, mlir::triton::gpu::TritonGPUDialect, mlir::math::MathDialect, mlir::arith::ArithDialect, mlir::scf::SCFDialect, - mlir::gpu::GPUDialect>(); + mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, + mlir::NVVM::NVVMDialect, mlir::triton::nvgpu::NVGPUDialect>(); } diff --git a/bin/triton-llvm-opt.cpp b/bin/triton-llvm-opt.cpp new file mode 100644 index 000000000..fe82a1dce --- /dev/null +++ b/bin/triton-llvm-opt.cpp @@ -0,0 +1,114 @@ +/// Trimmed down clone of llvm opt to be able to test triton custom llvm ir +/// passes. 
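+// Illustrative usage (an assumption based on the cl::opt flags declared below,
+// not documented elsewhere in this patch): read LLVM IR from a positional input
+// file (or stdin by default), optionally run the break-struct-phi-nodes pass,
+// and print the resulting module to stdout, e.g.
+//   triton-llvm-opt -break-struct-phi-nodes input.ll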
+#include "lib/Target/LLVMIR/LLVMPasses.h" +#include "llvm/CodeGen/CommandFlags.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" +#include "llvm/IRReader/IRReader.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FileSystem.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/SystemUtils.h" +#include "llvm/Support/ToolOutputFile.h" +#include "llvm/TargetParser/Triple.h" +#include + +using namespace llvm; + +static cl::opt InputFilename(cl::Positional, + cl::desc(""), + cl::init("-"), + cl::value_desc("filename")); + +static cl::opt ClDataLayout("data-layout", + cl::desc("data layout string to use"), + cl::value_desc("layout-string"), + cl::init("")); +static cl::opt + TargetTriple("mtriple", cl::desc("Override target triple for module")); + +static cl::opt + BreakStructPhiNodes("break-struct-phi-nodes", + llvm::cl::desc("run pass to break phi struct"), + cl::init(false)); + +namespace { +static std::function makeOptimizingPipeline() { + return [](Module *m) -> Error { + PipelineTuningOptions tuningOptions; + PassBuilder pb(nullptr, tuningOptions); + + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + llvm::FunctionPassManager fpm; + if (BreakStructPhiNodes) + fpm.addPass(BreakStructPhiNodesPass()); + mpm.addPass(createModuleToFunctionPassAdaptor(std::move(fpm))); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + +int main(int argc, char **argv) { + InitLLVM X(argc, argv); + cl::ParseCommandLineOptions( + argc, argv, "llvm .bc -> .bc modular optimizer and analysis printer\n"); + + LLVMContext Context; + SMDiagnostic Err; + + // Load the input module... + auto SetDataLayout = [](StringRef, StringRef) -> std::optional { + if (ClDataLayout.empty()) + return std::nullopt; + return ClDataLayout; + }; + std::unique_ptr M; + M = parseIRFile(InputFilename, Err, Context, ParserCallbacks(SetDataLayout)); + if (!M) { + Err.print(argv[0], errs()); + return 1; + } + // If we are supposed to override the target triple or data layout, do so now. + if (!TargetTriple.empty()) + M->setTargetTriple(Triple::normalize(TargetTriple)); + auto optPipeline = makeOptimizingPipeline(); + if (auto err = optPipeline(M.get())) { + llvm::errs() << "Failed to optimize LLVM IR " << err << "\n"; + } + + if (verifyModule(*M, &errs())) { + errs() << argv[0] << ": " << InputFilename + << ": error: input module is broken!\n"; + return 1; + } + + // Write to standard output. 
+ std::unique_ptr Out; + std::string OutputFilename = "-"; + std::error_code EC; + sys::fs::OpenFlags Flags = sys::fs::OF_TextWithCRLF; + Out.reset(new ToolOutputFile(OutputFilename, EC, Flags)); + if (EC) { + errs() << EC.message() << '\n'; + return 1; + } + Out->os() << *M << "\n"; + return 0; +} diff --git a/bin/triton-translate.cpp b/bin/triton-translate.cpp index ee1da6a20..1b81425ce 100644 --- a/bin/triton-translate.cpp +++ b/bin/triton-translate.cpp @@ -15,7 +15,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" -#include "triton/Target/HSACO/HSACOTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" #include "llvm/IR/LLVMContext.h" @@ -143,11 +142,14 @@ LogicalResult tritonTranslateMain(int argc, char **argv, } else if (targetKind == "ptx") { llvm::outs() << ::triton::translateLLVMIRToPTX(*llvmir, SMArch.getValue(), ptxVersion.getValue()); +<<<<<<< HEAD } else if (targetKind == "hsaco") { auto [module, hsaco] = mlir::triton::translateLLVMIRToHSACO( *llvmir, GCNArch.getValue(), GCNTriple.getValue(), GCNFeatures.getValue()); llvm::outs() << hsaco; +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } else { llvm::errs() << "Error: Unknown target specified: " << targetKind << "\n"; return failure(); diff --git a/docs/conf.py b/docs/conf.py index 23ff8ecc9..54ca52468 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -24,7 +24,10 @@ import os +import shutil import sys +import sysconfig +from pathlib import Path import sphinx_rtd_theme from sphinx_gallery.sorting import FileNameSortKey @@ -36,6 +39,58 @@ def process_sig(app, what, name, obj, options, signature, return_annotation): return (signature, return_annotation) +def get_cmake_dir(): + plat_name = sysconfig.get_platform() + python_version = sysconfig.get_python_version() + dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" + cmake_dir = Path("../python") / "build" / dir_name + return cmake_dir + + +def setup_generated_mlir_docs(): + dst_path = Path("dialects") + os.makedirs(dst_path, exist_ok=True) + + cmake_dir = get_cmake_dir() + src_dir = cmake_dir / "docs" / "dialects" + assert os.path.isdir(src_dir) + + shutil.copytree(src_dir, dst_path, dirs_exist_ok=True) + + files = os.listdir(dst_path) + + dialects = "\n ".join(["./" + f for f in files if "Dialect" in f]) + ops = [f for f in files if "Ops" in f] + + # Add titles + for op in ops: + with open(dst_path / op, 'r+') as f: + lines = f.readlines() + lines.insert(0, "# " + op.split(".md")[0]) + f.seek(0) + f.writelines(lines) + ops = "\n ".join(["./" + op for op in ops]) + + rst_string = f""" +Triton MLIR Dialects and Ops +===================== + +.. toctree:: + :maxdepth: 1 + :caption: Dialects + + {dialects} + +.. 
toctree:: + :maxdepth: 1 + :caption: Dialect Ops + + {ops} +""" + with open(dst_path / "dialects.rst", "w+") as f: + f.write(rst_string) + + def setup(app): """Customize function args retrieving to get args under decorator.""" import os @@ -44,6 +99,7 @@ def setup(app): app.connect("autodoc-process-signature", process_sig) os.system("pip install -e ../python") + setup_generated_mlir_docs() def forward_jit_fn(func): old = func @@ -82,7 +138,8 @@ extensions = [ 'sphinx.ext.autosummary', 'sphinx.ext.coverage', 'sphinx.ext.napoleon', - 'sphinx_multiversion'] + 'sphinx_multiversion', + 'myst_parser'] autosummary_generate = True # versioning config diff --git a/docs/index.rst b/docs/index.rst index 080f94239..b72c9352e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -37,6 +37,18 @@ Python API python-api/triton.testing +Triton MLIR Dialects and Ops +-------------------- + +- :doc:`Triton MLIR Dialects and Ops ` + +.. toctree:: + :maxdepth: 1 + :caption: Triton MLIR Dialects + :hidden: + + dialects/dialects + Going Further ------------- diff --git a/docs/meetups/tritonconf2023.md b/docs/meetups/tritonconf2023.md new file mode 100644 index 000000000..27719b107 --- /dev/null +++ b/docs/meetups/tritonconf2023.md @@ -0,0 +1,27 @@ +The conference slides are available [here](https://drive.google.com/drive/folders/1yDFc4ElNN_GGhWDdMlM4wcm5uFEFFVQk?usp=sharing) + +The conference videos will be available [here](https://youtube.com/playlist?list=PLc_vA1r0qoiRZfUC3o4_yjj0FtWvodKAz&feature=shared) when ready. + +# Triton Developer Conference +The Triton Developer Conference was held in a hybrid mode at the Microsoft Silicon Valley Campus in Mountain View, California. The conference was held on September 20th from 10am to 4pm, followed by a reception till 5:30 pm. + +Agenda for the conference: + +|Time |Title |Speaker +|--------|-------|-------| +|10:00 AM|Welcome|Kevin Scott (Microsoft)| +|10:20 AM|The Triton Compiler: Past, Present and Future|Phil Tillet (OpenAI)| +|11:00 AM|**Break**|| +|11:20 AM|Hopper support in Triton|Gustav Zhu (Nvidia)| +|11:40 AM|Bringing Triton to AMD GPUs|Jason Furmanek, Lixun Zhang (AMD)| +|12:00 PM|Intel XPU Backend for Triton|Eikan Wang (Intel)| +|12:20 PM|Vectorization of Triton Kernels for Qualcomm Hexagon Backend|Javed Absar (Qualcomm)| +|12:30 PM|**Lunch**|| +|1:40 PM |Triton for MTIA|Roman Levenstein et al, (Meta)| +|2:00 PM |Using Triton IR for high-performance fusions in XLA|George Karpenkov (Google)| +|2:20 PM |Triton for All: Triton as a device-independent language|Ian Bearman (Microsoft)| +|2:40 PM|**Break**|| +|3:00 PM|PyTorch 2.0 and TorchInductor|Jason Ansel, Horace He (Meta)| +|3:20 PM|Pallas: A JAX Kernel Language|Sharad Vikram (Google)| +|3:40 PM|Writing Grouped GEMMs in Triton|Vinod Grover (Nvidia)| +|4:00 PM|**Reception**|| diff --git a/include/triton/Analysis/Allocation.h b/include/triton/Analysis/Allocation.h index 22106f3e1..4d861bb02 100644 --- a/include/triton/Analysis/Allocation.h +++ b/include/triton/Analysis/Allocation.h @@ -21,6 +21,7 @@ class AllocationAnalysis; SmallVector getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, unsigned &outVec); +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op); } // namespace triton diff --git a/include/triton/Analysis/Utility.h b/include/triton/Analysis/Utility.h index ef167f561..658edd183 100644 --- a/include/triton/Analysis/Utility.h +++ b/include/triton/Analysis/Utility.h @@ -36,7 +36,9 @@ public: triton::ReduceOp getOperation() { return op; } - bool 
isFastReduction(); + bool isReductionOnLayoutFastAxis(); + + unsigned getThreadOffsetOnReductionAxis(); bool isWarpSynchronous(); @@ -50,14 +52,16 @@ public: unsigned getThreadsReductionAxis(); - SmallVector getScratchConfigBasic(); - - SmallVector> getScratchConfigsFast(); + SmallVector getScratchConfig(); unsigned getScratchSizeInBytes(); bool isSupportedLayout(); + bool isReduceWithinCTA(); + + unsigned getAxis() { return axis; } + private: triton::ReduceOp op; ArrayRef srcShape; @@ -84,8 +88,12 @@ public: unsigned getNonAxisNumThreadsPerCTA(); // Return the number of warps per CTA along axis dim. unsigned getAxisNumWarps(); + // Return the number of warps per CTA along axis dim with unique data. + unsigned getAxisNumWarpsWithUniqueData(); // Return the number of threads per warp along axis dim. unsigned getAxisNumThreadsPerWarp(); + // Return the number of threads per warp along axis dim with unique data. + unsigned getAxisNumThreadsPerWarpWithUniqueData(); // Return the number of blocks along axis dim. unsigned getAxisNumBlocks(); // Return the number of blocks along non axis dim. @@ -103,6 +111,7 @@ public: Location getLoc() { return scanOp.getLoc(); } unsigned getAxis() { return scanOp.getAxis(); } triton::gpu::BlockedEncodingAttr getEncoding(); + llvm::ArrayRef getShape(); Region &getCombineOp(); private: @@ -128,6 +137,10 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy); +// Return true if the src and dst layout match. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy); + // TODO: Move utility functions that belong to ConvertLayoutOp to class // ConvertLayoutOpHelper in the future bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout); diff --git a/include/triton/Conversion/TritonGPUToLLVM/Passes.td b/include/triton/Conversion/TritonGPUToLLVM/Passes.td index f94b8d30a..3d4e1d4e3 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Passes.td +++ b/include/triton/Conversion/TritonGPUToLLVM/Passes.td @@ -27,9 +27,6 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp" Option<"computeCapability", "compute-capability", "int32_t", /*default*/"80", "device compute capability">, - Option<"tmaMetadata", "tma-metadata", - "mlir::triton::gpu::TMAMetadataTy*", /*default*/"nullptr", - "tma metadata to the runtime">, Option<"target", "target", "enum Target", "mlir::triton::Target::Default", "compile for target compatible LLVM", "llvm::cl::values(" diff --git a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h index 3be5c9009..54f26145b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h @@ -21,7 +21,8 @@ enum Target { NVVM, ROCDL, Default = NVVM }; std::unique_ptr> createConvertTritonGPUToLLVMPass(); std::unique_ptr> -createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options); +createConvertTritonGPUToLLVMPass(int32_t computeCapability, Target target, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata); } // namespace triton diff --git a/include/triton/Dialect/NVGPU/IR/CMakeLists.txt b/include/triton/Dialect/NVGPU/IR/CMakeLists.txt index aa965dac6..f8932cdc4 100644 --- a/include/triton/Dialect/NVGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/NVGPU/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + 
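+# Assumed rationale for the line above: add_mlir_doc() emits its generated markdown
+# under ${MLIR_BINARY_DIR}/docs, so MLIR_BINARY_DIR is pointed at this project's
+# build tree; the setup_generated_mlir_docs() helper added to docs/conf.py then
+# copies the dialect docs from there.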
set(LLVM_TARGET_DEFINITIONS NVGPUOps.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=nvgpu) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=nvgpu) @@ -6,6 +8,8 @@ mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(NVGPUDialect NVGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(NVGPUOps NVGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(NVGPUTableGen) set(LLVM_TARGET_DEFINITIONS NVGPUAttrDefs.td) diff --git a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td index a9451984c..896a27c17 100644 --- a/include/triton/Dialect/NVGPU/IR/NVGPUOps.td +++ b/include/triton/Dialect/NVGPU/IR/NVGPUOps.td @@ -97,7 +97,15 @@ def WGMMADesc_ModeAttr : I32EnumAttr<"WGMMADescMode", } def NVGPU_WGMMADescCreateOp : NVGPU_Op<"wgmma_desc_create", []> { - let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode); + let arguments = (ins LLVM_AnyPointer:$buffer, I32:$height, WGMMADesc_ModeAttr:$mode, I64Attr:$swizzling); + let builders = [ + OpBuilder<(ins "Value":$buffer, + "Value":$height, + "WGMMADescMode":$mode), [{ + uint32_t mode_ = static_cast(mode); + uint64_t swizzling = (mode_ == 1 ? 128 : mode_ == 2 ? 64 : 32); + build($_builder, $_state, $_builder.getIntegerType(64), buffer, height, WGMMADescModeAttr::get($_builder.getContext(), mode), $_builder.getI64IntegerAttr(swizzling)); + }]>]; let results = (outs I64:$res); let assemblyFormat = "$buffer `,` $height attr-dict `:` functional-type(operands, results)"; } @@ -140,12 +148,12 @@ def WGMMA_EltTypeAttr : I32EnumAttr<"WGMMAEltType", def WGMMA_OperandType : AnyTypeOf<[LLVM_AnyStruct, I64], "wgmma operand A/B type">; def NVGPU_WGMMAOp : NVGPU_Op<"wgmma", []> { - let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, LLVM_AnyStruct:$opC, + let arguments = (ins WGMMA_OperandType:$opA, WGMMA_OperandType:$opB, Optional:$opC, I32Attr:$m, I32Attr:$n, I32Attr:$k, WGMMA_EltTypeAttr:$eltTypeC, WGMMA_EltTypeAttr:$eltTypeA, WGMMA_EltTypeAttr:$eltTypeB, WGMMA_LayoutAttr:$layoutA, WGMMA_LayoutAttr:$layoutB); let results = (outs LLVM_AnyStruct:$res); - let assemblyFormat = "$opA `,` $opB `,` $opC attr-dict `:` functional-type(operands, $res)"; + let assemblyFormat = "$opA `,` $opB (`,` $opC^)? 
attr-dict `:` functional-type(operands, $res)"; } def NVGPU_CGABarrierSyncOp : NVGPU_Op<"cga_barrier_sync", []> { diff --git a/include/triton/Dialect/Triton/IR/CMakeLists.txt b/include/triton/Dialect/Triton/IR/CMakeLists.txt index 84bd723f6..42e6c039d 100644 --- a/include/triton/Dialect/Triton/IR/CMakeLists.txt +++ b/include/triton/Dialect/Triton/IR/CMakeLists.txt @@ -1,12 +1,16 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(LLVM_TARGET_DEFINITIONS TritonOps.td) mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(OpsEnums.h.inc -gen-enum-decls) mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs) +add_mlir_doc(TritonOps TritonOps dialects/ -gen-op-doc) set(LLVM_TARGET_DEFINITIONS TritonDialect.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs) +add_mlir_doc(TritonDialect TritonDialect dialects/ -gen-dialect-doc) set(LLVM_TARGET_DEFINITIONS TritonTypes.td) mlir_tablegen(Types.h.inc -gen-typedef-decls) diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 69cad2bcf..575db87be 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -394,7 +394,12 @@ def TT_DotOp : TT_Op<"dot", [Pure, $d = matrix_multiply($a, $b) + $c }]; - let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); + let arguments = (ins + TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + BoolAttr:$allowTF32, + I32Attr:$maxNumImpreciseAcc); let results = (outs TT_FpIntTensor:$d); diff --git a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt index d32192749..7b7ca5593 100644 --- a/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonGPU/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(LLVM_TARGET_DEFINITIONS TritonGPUOps.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_gpu) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_gpu) @@ -5,6 +7,8 @@ mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_gpu) mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_gpu) +add_mlir_doc(TritonGPUDialect TritonGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonGPUOps TritonGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonGPUTableGen) set(LLVM_TARGET_DEFINITIONS TritonGPUAttrDefs.td) diff --git a/include/triton/Dialect/TritonGPU/IR/Dialect.h b/include/triton/Dialect/TritonGPU/IR/Dialect.h index 10192b0f4..fad35fe00 100644 --- a/include/triton/Dialect/TritonGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonGPU/IR/Dialect.h @@ -7,7 +7,6 @@ #include "mlir/IR/Dialect.h" // TritonGPU depends on Triton -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h.inc" diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td index 947aa3f4d..8e24c915a 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td @@ -113,6 +113,7 @@ compared to 1*64 when the hasLeadingOffset is false. 
"ArrayRef":$order, "CTALayoutAttr":$CTALayout, "unsigned":$typeWidthInBit), [{ +<<<<<<< HEAD #ifdef USE_ROCM // ---- begin GFX908/GFX90A ---- @@ -155,6 +156,18 @@ compared to 1*64 when the hasLeadingOffset is false. } } #endif +======= + bool needTrans = false; // default value + return get(context, dotOpEnc, shape, order, CTALayout, typeWidthInBit, needTrans); + }]>, + + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "unsigned":$typeWidthInBit, + "bool":$needTrans), [{ +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 auto mmaEnc = dotOpEnc.getParent().dyn_cast(); if(!mmaEnc) @@ -194,16 +207,23 @@ compared to 1*64 when the hasLeadingOffset is false. // --- handle A operand --- if (opIdx == 0) { // compute swizzling for A operand - int vec = (order[0] == 1) ? matShape[2] : matShape[0]; // k : m - int mmaStride = (order[0] == 1) ? matShape[0] : matShape[2]; + int m = (needTrans) ? matShape[2] : matShape[0]; + int k = (needTrans) ? matShape[0] : matShape[2]; + int vec = (order[0] == 1) ? k : m; + int mmaStride = (order[0] == 1) ? m : k; int maxPhase = mmaStride / perPhase; return get(context, vec, perPhase, maxPhase, order, CTALayout); } // --- handle B operand --- if (opIdx == 1) { - int vec = (order[0] == 1) ? matShape[1] : matShape[2]; // n : k - int mmaStride = (order[0] == 1) ? matShape[2] : matShape[1]; + // we compute vec and maxPhase m, n and k size of the mma + // instruction. when matmul operands is transposed, we should + // consider that to get m, n and k. + int n = needTrans ? matShape[2] : matShape[1]; + int k = needTrans ? matShape[1] : matShape[2]; + int vec = (order[0] == 1) ? n : k; + int mmaStride = (order[0] == 1) ? k : n; int maxPhase = mmaStride / perPhase; return get(context, vec, perPhase, maxPhase, order, CTALayout); } @@ -231,6 +251,16 @@ compared to 1*64 when the hasLeadingOffset is false. 
return get(context, dotOpEnc, shape, order, CTALayout, bitwidth); }]>, + AttrBuilder<(ins "DotOperandEncodingAttr":$dotOpEnc, + "ArrayRef":$shape, + "ArrayRef":$order, + "CTALayoutAttr":$CTALayout, + "Type":$eltTy, + "bool":$needTrans), [{ + unsigned bitwidth = eltTy.getIntOrFloatBitWidth(); + return get(context, dotOpEnc, shape, order, CTALayout, bitwidth, needTrans); + }]>, + AttrBuilder<(ins "ArrayRef":$shape, "ArrayRef":$order, "CTALayoutAttr":$CTALayout, diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td index a4cc9eca0..d52647856 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td @@ -16,7 +16,6 @@ def TritonGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", - "mlir::triton::nvgpu::NVGPUDialect", "mlir::gpu::GPUDialect", "tensor::TensorDialect", ]; diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index d91fa0764..c69b62eb5 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -28,7 +28,7 @@ def TTG_ConvertLayoutOp : TTG_Op<"convert_layout", let results = (outs TT_Tensor:$result); - let hasCanonicalizeMethod = 1; + let hasCanonicalizer = 1; let assemblyFormat = "$src attr-dict `:` functional-type(operands, results)"; } diff --git a/include/triton/Dialect/TritonGPU/Transforms/Utility.h b/include/triton/Dialect/TritonGPU/Transforms/Utility.h index 6c0193182..fe9f9f8c5 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Utility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Utility.h @@ -141,13 +141,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, ArrayRef shape); -// Returns null if the op is not inside a agent region (warp specialization -// mode). Note that there should be at most one agent id attached to the -// operation. 
-std::optional getWSAgentId(Operation *op); -std::optional getWSRoleId(Operation *op); -void setRoleId(Operation *op, int roleId); - } // namespace mlir #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt index aba08ab13..b7ce83fe7 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/CMakeLists.txt @@ -1,3 +1,5 @@ +set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR}) + set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUOps.td) mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=triton_nvidia_gpu) mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=triton_nvidia_gpu) @@ -5,6 +7,8 @@ mlir_tablegen(Ops.h.inc -gen-op-decls) mlir_tablegen(Ops.cpp.inc -gen-op-defs) mlir_tablegen(Types.h.inc -gen-typedef-decls -typedefs-dialect=triton_nvidia_gpu) mlir_tablegen(Types.cpp.inc -gen-typedef-defs -typedefs-dialect=triton_nvidia_gpu) +add_mlir_doc(TritonNvidiaGPUDialect TritonNvidiaGPUDialect dialects/ -gen-dialect-doc) +add_mlir_doc(TritonNvidiaGPUOps TritonNvidiaGPUOps dialects/ -gen-op-doc) add_public_tablegen_target(TritonNvidiaGPUTableGen) set(LLVM_TARGET_DEFINITIONS TritonNvidiaGPUAttrDefs.td) diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h index 680af81ac..fc8a99457 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h @@ -30,7 +30,6 @@ #include "mlir/IR/Dialect.h" // TritonNvidiaGPU depends on Triton -#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Traits.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h.inc" @@ -43,4 +42,15 @@ #define GET_OP_CLASSES #include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc" +namespace mlir { + +// Returns null if the op is not inside a agent region (warp specialization +// mode). Note that there should be at most one agent id attached to the +// operation. 
+std::optional getWSAgentId(Operation *op); +std::optional getWSRoleId(Operation *op); +void setRoleId(Operation *op, int roleId); + +} // namespace mlir + #endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_ diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td index 08ff21f52..f2ab288c1 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td @@ -38,7 +38,6 @@ def TritonNvidiaGPU_Dialect : Dialect { let dependentDialects = [ "triton::TritonDialect", "triton::gpu::TritonGPUDialect", - "mlir::triton::nvgpu::NVGPUDialect", "mlir::gpu::GPUDialect", "tensor::TensorDialect", ]; diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index cdf114690..7d8cc7b41 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -258,7 +258,11 @@ def TTNG_DotAsyncOp : TTNG_Op<"dot_async", [Pure, $d = matrix_multiply($a, $b) + $c }]; - let arguments = (ins TT_FpIntTensor:$a, TT_FpIntTensor:$b, TT_FpIntTensor:$c, BoolAttr:$allowTF32); + let arguments = (ins TT_FpIntTensor:$a, + TT_FpIntTensor:$b, + TT_FpIntTensor:$c, + BoolAttr:$allowTF32, + I32Attr:$maxNumImpreciseAcc); let results = (outs TT_FpIntTensor:$d); diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index bf682af94..d1ff8ab83 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -30,7 +30,7 @@ namespace triton { const std::set ENV_VARS = { - "ENABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", + "DISABLE_MMA_V3", "TRITON_DISABLE_LINE_INFO", "DISABLE_FAST_REDUCTION", "ENABLE_TMA", "MLIR_ENABLE_DUMP", "LLVM_IR_ENABLE_DUMP", "AMDGCN_ENABLE_DUMP"}; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index 94d0b6337..bb4efa69d 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -18,7 +18,11 @@ using ::mlir::triton::gpu::getOrder; using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; using ::mlir::triton::gpu::getSizePerThread; +<<<<<<< HEAD using ::mlir::triton::gpu::MfmaEncodingAttr; +======= +using ::mlir::triton::gpu::getUniqueContigPerThread; +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 using ::mlir::triton::gpu::MmaEncodingAttr; using ::mlir::triton::gpu::SharedEncodingAttr; using ::mlir::triton::gpu::SliceEncodingAttr; @@ -51,9 +55,7 @@ getCvtOrder(Attribute srcLayout, Attribute dstLayout) { return {inOrd, outOrd}; } -SmallVector -getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, - unsigned &outVec) { +SmallVector getRepShapeForCvtLayout(triton::gpu::ConvertLayoutOp op) { auto srcTy = op.getSrc().getType().cast(); auto dstTy = op.getResult().getType().cast(); Attribute srcLayout = srcTy.getEncoding(); @@ -77,6 +79,7 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, } } +<<<<<<< HEAD #ifdef USE_ROCM if (srcLayout.isa() && srcLayout.dyn_cast().getIsTransposed() && @@ -94,6 +97,9 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, // that we cannot do vectorization. inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; outVec = outOrd[0] == 0 ? 
1 : dstContigPerThread; +======= + assert(srcLayout && dstLayout && "Unexpected layout in getRepShape()"); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 auto srcShapePerCTA = getShapePerCTA(srcTy); auto dstShapePerCTA = getShapePerCTA(dstTy); @@ -101,21 +107,46 @@ getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape()); unsigned rank = dstTy.getRank(); - SmallVector paddedRepShape(rank); - unsigned pad = std::max(inVec, outVec); + SmallVector repShape(rank); for (unsigned d = 0; d < rank; ++d) { - paddedRepShape[d] = + repShape[d] = std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]), std::min(dstShapePerCTA[d], dstShapePerCTATile[d])); } - if (rank == 1) - return paddedRepShape; + return repShape; +} + +SmallVector +getScratchConfigForCvtLayout(triton::gpu::ConvertLayoutOp op, unsigned &inVec, + unsigned &outVec) { + auto repShape = getRepShapeForCvtLayout(op); + if (repShape.empty()) + return repShape; + + auto srcTy = op.getSrc().getType().cast(); + auto dstTy = op.getResult().getType().cast(); + Attribute srcLayout = srcTy.getEncoding(); + Attribute dstLayout = dstTy.getEncoding(); + + auto [inOrd, outOrd] = getCvtOrder(srcLayout, dstLayout); + unsigned srcContigPerThread = + getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]]; + unsigned dstContigPerThread = + getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]]; + // TODO: Fix the legacy issue that ourOrd[0] == 0 always means + // that we cannot do vectorization. + inVec = outOrd[0] == 0 ? 1 : inOrd[0] == 0 ? 1 : srcContigPerThread; + outVec = outOrd[0] == 0 ? 1 : dstContigPerThread; + + if (repShape.size() <= 1) + return repShape; unsigned paddedDim = 1; if (auto dstBlockedLayout = dstLayout.dyn_cast()) { paddedDim = dstBlockedLayout.getOrder()[0]; } - paddedRepShape[paddedDim] += pad; - return paddedRepShape; + unsigned pad = std::max(inVec, outVec); + repShape[paddedDim] += pad; + return repShape; } SmallVector diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index dd782b187..14e766a1a 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -301,10 +301,19 @@ private: int64_t getDivisibility(arith::MulIOp op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - // lhs = k * d_lhs - // rhs = p * d_rhs - // lhs * rhs = k * d_lhs * p * d_rhs = k * p * d_lhs * d_rhs - return lhs.getDivisibility(dim) * rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && + !(rhs.getConstantValue().has_value() && rhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto rhsDivisibility = rhs.getDivisibility(dim); + if (rhs.getContiguity(dim) > 1 && + !(lhs.getConstantValue().has_value() && lhs.getConstantValue() == 1)) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + rhsDivisibility = 1; + } + return lhsDivisibility * rhsDivisibility; } std::optional getConstantValue(arith::MulIOp op, const AxisInfo &lhs, @@ -511,8 +520,23 @@ public: AxisInfo::DimVectorT contiguity = opInfo.getContiguity(); AxisInfo::DimVectorT divisibility = opInfo.getDivisibility(); AxisInfo::DimVectorT constancy = opInfo.getConstancy(); + int64_t newDivisibility = 1; + if (opInfo.getConstantValue().has_value()) { + // The tensor is constant, same as ConstantOpAxisInfoVisitor + newDivisibility = highestPowOf2Divisor(opInfo.getConstantValue().value()); + } else if 
(opInfo.getRank()) { + // Otherwise, calculate the GCD as the new divisibility + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + newDivisibility = + opInfo.getContiguity(0) > 1 ? 1 : opInfo.getDivisibility(0); + for (int d = 1; d < opInfo.getRank(); ++d) { + newDivisibility = + gcd(newDivisibility, + opInfo.getContiguity(d) > 1 ? 1 : opInfo.getDivisibility(d)); + } + } contiguity.insert(contiguity.begin() + op.getAxis(), 1); - divisibility.insert(divisibility.begin() + op.getAxis(), 1); + divisibility.insert(divisibility.begin() + op.getAxis(), newDivisibility); constancy.insert(constancy.begin() + op.getAxis(), 1); return AxisInfo(contiguity, divisibility, constancy, operands[0]->getValue().getConstantValue()); @@ -756,12 +780,17 @@ private: auto shift = rhs.getConstantValue().has_value() ? rhs.getConstantValue().value() : rhs.getDivisibility(dim); - auto numBits = log2Int(lhs.getDivisibility(dim)); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + auto numBits = log2Int(lhsDivisibility); auto maxBits = log2Int(highestPowOf2Divisor(0)); // Make sure the return value doesn't exceed highestPowOf2Divisor(0) if (shift + numBits > maxBits) return highestPowOf2Divisor(0); - return lhs.getDivisibility(dim) << shift; + return lhsDivisibility << shift; } int64_t getConstancy(arith::ShLIOp op, const AxisInfo &lhs, @@ -795,12 +824,15 @@ private: int64_t getDivisibility(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, int dim) override { - if (rhs.getConstantValue().has_value()) - return std::max(1, lhs.getDivisibility(dim) / - (1 << rhs.getConstantValue().value())); - else - return std::max(1, lhs.getDivisibility(dim) / - (1 << rhs.getDivisibility(dim))); + auto shift = rhs.getConstantValue().has_value() + ? rhs.getConstantValue().value() + : rhs.getDivisibility(dim); + auto lhsDivisibility = lhs.getDivisibility(dim); + if (lhs.getContiguity(dim) > 1 && shift) { + // Treat [2^n,2^n+1,...]'s divisibility as 1 instead of 2^n + lhsDivisibility = 1; + } + return std::max(1, lhsDivisibility / (1 << shift)); } int64_t getConstancy(OpTy op, const AxisInfo &lhs, const AxisInfo &rhs, diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index df1fe4066..aecc2345a 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -10,6 +10,7 @@ add_mlir_library(TritonAnalysis TritonGPUAttrDefsIncGen LINK_LIBS PUBLIC + ASMBuilder MLIRAnalysis MLIRLLVMDialect TritonIR diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 39a916224..da2323660 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -33,14 +33,39 @@ SmallVector getParentOrder(Attribute layout) { } // namespace -bool ReduceOpHelper::isFastReduction() { - // Disable fast reduction only for debugging purpose - if (::triton::tools::getBoolEnv("DISABLE_FAST_REDUCTION")) - return false; +bool ReduceOpHelper::isReductionOnLayoutFastAxis() { return getParentAxis(getSrcLayout(), axis) == getParentOrder(getSrcLayout())[0]; } +// Thread offset is the thread index offset of two adjacent threads on the +// reduction axis within the warp. 
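+// For instance (illustrative values only): with a 2-D blocked layout whose fast
+// axis is 1 and threadsPerWarp = [4, 8], a reduction along axis 0 means two
+// threads adjacent on the reduction axis sit 8 lanes apart, so the offset
+// returned below is threadsPerWarp[1] == 8.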
+unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() { + auto srcLayout = getSrcLayout(); + + // If the reduction axis is the fast axis of the parent layout + if (isReductionOnLayoutFastAxis()) { + return 1; + } + + unsigned threadOffset = 1; + if (auto sliceLayout = + srcLayout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(parentLayout); + threadOffset = threadsPerWarp[sliceLayout.getDim()]; + } else { + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcLayout); + if (threadsPerWarp.size() == 1) { + threadOffset = 1; + } else { + assert(threadsPerWarp.size() == 2 && "Only supports 2D layouts"); + threadOffset = axis == 0 ? threadsPerWarp[1] : threadsPerWarp[0]; + } + } + return threadOffset; +} + // Cases where distributed shared memory is not required in ConvertLayout: // (1) numCTAs == 1 // (2) numCTAs > 1 but srcCTALayout == dstCTALayout @@ -124,53 +149,26 @@ unsigned ReduceOpHelper::getThreadsReductionAxis() { triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape)[axis]; } -SmallVector ReduceOpHelper::getScratchConfigBasic() { - auto smemShape = convertType(getSrcShape()); - smemShape[axis] = std::min(smemShape[axis], getThreadsReductionAxis()); +bool ReduceOpHelper::isWarpSynchronous() { + auto argsLayout = getSrcLayout(); + return triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1; +} + +SmallVector ReduceOpHelper::getScratchConfig() { + SmallVector smemShape; + // that case doesn't need inter-warp communication + if (isWarpSynchronous()) + return {0, 0}; + + smemShape = convertType(getSrcShape()); + smemShape[axis] = getInterWarpSizeWithUniqueData(); + return smemShape; } -bool ReduceOpHelper::isWarpSynchronous() { - auto argsLayout = getSrcLayout(); - return isFastReduction() && - (triton::gpu::getWarpsPerCTA(argsLayout)[axis] == 1); -} - -SmallVector> ReduceOpHelper::getScratchConfigsFast() { - SmallVector> smemShapes(3); - - auto argLayout = getSrcLayout(); - auto argLayoutMma = argLayout.dyn_cast(); - - // that case doesn't need inter-warp communication - if (isWarpSynchronous()) - return {{0, 0}, {0, 0}}; - - /// shared memory block0 - smemShapes[0] = convertType(getSrcShape()); - smemShapes[0][axis] = getInterWarpSize(); - - /// FIXME(Qingyi): This size is actually larger than required. 
- /// shared memory block1: - auto mod = op->getParentOfType(); - unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); - unsigned threadsPerWarp = - triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - smemShapes[1].push_back(numWarps * threadsPerWarp); - - return smemShapes; -} - unsigned ReduceOpHelper::getScratchSizeInBytes() { - unsigned elems = 0; - if (isFastReduction()) { - auto smemShapes = getScratchConfigsFast(); - for (const auto &smemShape : smemShapes) - elems = std::max(elems, product(smemShape)); - } else { - auto smemShape = getScratchConfigBasic(); - elems = product(smemShape); - } + auto smemShape = getScratchConfig(); + auto elems = product(smemShape); unsigned bytesPerElem = 0; for (const auto &ty : srcElementTypes) { @@ -179,7 +177,21 @@ unsigned ReduceOpHelper::getScratchSizeInBytes() { return bytesPerElem * elems; } +bool ReduceOpHelper::isReduceWithinCTA() { + auto axis = getAxis(); + auto srcLayout = getSrcLayout(); + auto CTASplitNum = mlir::triton::gpu::getCTASplitNum(srcLayout); + assert(axis < CTASplitNum.size()); + return CTASplitNum[axis] == 1; +} + bool ReduceOpHelper::isSupportedLayout() { + // Layout optimization passes such as PlanCTAPass and + // RemoveLayoutConversionPass should avoid cross-CTA reduction + if (!isReduceWithinCTA()) { + return false; + } + auto srcLayout = getSrcLayout(); if (srcLayout.isa()) { return true; @@ -203,7 +215,8 @@ unsigned ScanLoweringHelper::getAxisNumElementsPerThread() { } unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() { - SmallVector sizePerThreads = getContigPerThread(getEncoding()); + SmallVector sizePerThreads = + triton::gpu::getContigPerThread(getEncoding()); sizePerThreads[getAxis()] = 1; return product(sizePerThreads); } @@ -214,6 +227,11 @@ unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() { return triton::gpu::getThreadsPerWarp(getEncoding())[getAxis()]; } +unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() { + return triton::gpu::getThreadsPerWarpWithUniqueData(getEncoding(), + getShape())[getAxis()]; +} + unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() { auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); threadsPerWarp[getAxis()] = 1; @@ -228,35 +246,38 @@ unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() { unsigned numParallelWarpsPerCTA = product(warpsPerCTA); return numParallelThreadsPerWarp * numParallelWarpsPerCTA; } + unsigned ScanLoweringHelper::getAxisNumWarps() { - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); - return warpsPerCTA[getAxis()]; + return triton::gpu::getWarpsPerCTA(getEncoding())[getAxis()]; +} + +unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() { + return triton::gpu::getWarpsPerCTAWithUniqueData(getEncoding(), + getShape())[getAxis()]; } unsigned ScanLoweringHelper::getAxisNumBlocks() { - auto type = scanOp.getOperand(0).getType().cast(); - auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); unsigned axis = getAxis(); return ceil( - type.getShape()[axis], + getShape()[axis], (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis])); } unsigned ScanLoweringHelper::getNonAxisNumBlocks() { - auto type = 
scanOp.getOperand(0).getType().cast(); - auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); unsigned axis = getAxis(); unsigned numBlocks = 1; for (unsigned i = 0; i < sizePerThreads.size(); i++) { if (i == axis) continue; - numBlocks *= ceil( - type.getShape()[i], - (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i])); + numBlocks *= + ceil(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] * + warpsPerCTA[i])); } return numBlocks; } @@ -265,7 +286,7 @@ bool ScanLoweringHelper::isSupported() { // TODO: Support the following cases: // 1. Scan on non-blocking encodings // 2. Scan with multiple operands - if (!isa(srcEncoding)) + if (!isa(getEncoding())) return false; if (scanOp.getNumOperands() != 1) return false; @@ -273,6 +294,9 @@ bool ScanLoweringHelper::isSupported() { } unsigned ScanLoweringHelper::getScratchSizeInBytes() { + unsigned axisNumWarps = getAxisNumWarpsWithUniqueData(); + if (axisNumWarps == 1) + return 0; auto type = scanOp.getOperand(0).getType().cast(); unsigned elementSizeInBytes = type.getElementTypeBitWidth() / 8; auto mod = scanOp->getParentOfType(); @@ -288,8 +312,12 @@ triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() { return srcEncoding.cast(); } +llvm::ArrayRef ScanLoweringHelper::getShape() { + return scanOp.getOperand(0).getType().cast().getShape(); +} + unsigned ScanLoweringHelper::getAxisElementStride() { - auto order = triton::gpu::getOrder(srcEncoding); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) @@ -300,7 +328,7 @@ unsigned ScanLoweringHelper::getAxisElementStride() { } unsigned ScanLoweringHelper::getAxisThreadStride() { - auto order = triton::gpu::getOrder(srcEncoding); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; for (unsigned dim : order) { if (dim == getAxis()) @@ -311,18 +339,23 @@ unsigned ScanLoweringHelper::getAxisThreadStride() { } unsigned ScanLoweringHelper::getAxisBlockStride() { - auto order = triton::gpu::getOrder(srcEncoding); + auto order = triton::gpu::getOrder(getEncoding()); unsigned stride = 1; - auto type = scanOp.getOperand(0).getType().cast(); - auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding); - auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); - auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding()); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding()); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding()); for (unsigned dim : order) { if (dim == getAxis()) return stride; +<<<<<<< HEAD stride *= ceil(type.getShape()[dim], sizePerThreads[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]); +======= + stride *= ceil(getShape()[dim], sizePerThreads[dim] * + threadsPerWarp[dim] * + warpsPerCTA[dim]); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } llvm_unreachable("Axis not found in order"); } @@ -357,7 +390,7 @@ bool supportMMA(triton::DotOp op, int version) { auto aElemTy = op.getA().getType().cast().getElementType(); auto bElemTy = op.getB().getType().cast().getElementType(); if (version == 3) { - if 
(!::triton::tools::getBoolEnv("ENABLE_MMA_V3")) + if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) return false; auto retType = op.getResult().getType().cast(); auto retShapePerCTA = triton::gpu::getShapePerCTA(retType); @@ -370,6 +403,12 @@ bool supportMMA(triton::DotOp op, int version) { aElemTy.isF32()))) { return false; } + // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op. + if (op.getMaxNumImpreciseAcc() < 32 && + (aElemTy.isFloat8E5M2() || aElemTy.isFloat8E4M3FNUZ()) && + op.getType().cast().getElementType().isF32()) { + return false; + } } if (aElemTy.isF32() && bElemTy.isF32()) { return (op.getAllowTF32() && version == 2) || version == 3; @@ -452,13 +491,45 @@ bool supportMMA(Value value, int version) { // FP8 is not natively supported on all mma versions but it can always be // promoted to fp16 therefore we can always support it. bool isFP8 = elemTy.isFloat8E5M2() || elemTy.isFloat8E4M3FN() || - elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ(); + elemTy.isFloat8E5M2FNUZ() || elemTy.isFloat8E4M3FNUZ() || + elemTy.isFloat8E4M3B11FNUZ(); return isFP8 || elemTy.isF16() || elemTy.isBF16() || (elemTy.isF32() && version >= 2) || (elemTy.isInteger(8) && version >= 2); } +static bool isMmaToMmaShortcut(Attribute srcEncoding, Attribute dstEncoding) { + auto src = srcEncoding.dyn_cast(); + auto dst = dstEncoding.dyn_cast(); + if (!src || !dst) + return false; + auto srcInstrShape = src.getInstrShape(); + auto dstInstrShape = dst.getInstrShape(); + // when #mma = MmaEncoding + return src && dst && src.getVersionMajor() == 3 && + src.getWarpsPerCTA()[1] == 1 && dst.getVersionMajor() == 3 && + dst.getWarpsPerCTA()[1] == 1 && srcInstrShape[2] == dstInstrShape[2]; +} + +bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + return isMmaToMmaShortcut(srcTy.getEncoding(), dstTy.getEncoding()); +} + +// For MMAV3 dotOperand layout matches mma operand for f16 case. +bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto srcLayout = srcTy.getEncoding(); + auto dstLayout = dstTy.getEncoding(); + auto mmaLayout = srcLayout.cast(); + auto dotOperandLayout = dstLayout.cast(); + return mmaLayout.getVersionMajor() == 3 && dotOperandLayout.getOpIdx() == 0 && + isMmaToMmaShortcut(dotOperandLayout.getParent(), srcLayout) && + srcTy.getElementType().isF16(); +} + bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) + return true; // dot_op = #mma // when #mma = MmaEncoding auto srcLayout = srcTy.getEncoding(); @@ -472,6 +543,7 @@ bool isMmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { !srcTy.getElementType().isF32(); } +<<<<<<< HEAD #ifdef USE_ROCM bool isMfmaToDotShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { auto srcLayout = srcTy.getEncoding(); @@ -501,6 +573,8 @@ bool isMmaToMmaShortcut(RankedTensorType &srcTy, RankedTensorType &dstTy) { srcElemsPerThread == dstElemsPerThread; } +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 bool isSingleValue(Value value) { // Don't consider load as expensive if it is loading a scalar. 
if (auto tensorTy = value.getType().dyn_cast()) diff --git a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp index f52256b3a..02b4c024f 100644 --- a/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/lib/Conversion/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -6,6 +6,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Conversion/TritonGPUToLLVM/PTXAsmFormat.h" @@ -21,12 +22,246 @@ using ::mlir::LLVM::getSRegValue; namespace { +using OperandsAndConstraints = std::vector>; +typedef std::vector Constraints; + +const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;"; +const std::string Wgmma_Fence_Op = "wgmma.fence.sync.aligned;"; +const std::string Cga_Barrier_Sync_op = "barrier.cluster.sync.aligned;"; +const std::string Wgmma_Commit_Group_Op = "wgmma.commit_group.sync.aligned;"; +const std::string Wgmma_Wait_Group_Op = + "wgmma.wait_group.sync.aligned #pendings;"; +const std::string Cluster_Wait_Op = "barrier.cluster.wait.aligned;"; +const std::string Fence_Mbarrier_Init_Op = + "fence.mbarrier_init.release.cluster;"; +const std::string Cga_Barrier_Arrive_Op = "barrier.cluster.arrive;"; +const std::string Cga_Barrier_Wait_Op = "barrier.cluster.wait;"; +const std::string Reg_Dealloc_Op = "setmaxnreg.dec.sync.aligned.u32 #regCount;"; +const std::string Wgmma_Desc_Create_op = + "{\n" + ".reg .u64 a<5>; \n" + "mov.u64 a0, #swizzling;\n" + "shl.b64 a1, a0, 3;\n" // stride dimension + "shr.b64 a1, a1, 4;\n" // stride dimension + "mul.lo.u64 a2, $2, #swizzling;\n" // leadingDimension + "shr.b64 a2, a2, 4;\n" // leadingDimension + "shl.b64 a3, $1, 46; \n" // startAddr + "shr.b64 a3, a3, 50; \n" // startAddr + "mov.u64 a4, #mode; \n" // mode + "shl.b64 a4, a4, 62; \n" + "shl.b64 a1, a1, 32; \n" + "or.b64 a1, a4, a1; \n" + "shl.b64 a2, a2, 16; \n" + "or.b64 a1, a1, a2; \n" + "or.b64 $0, a1, a3; \n" + "}"; + +const std::string Mbarrier_Init_Op = + "@$1 mbarrier.init.shared.b64 [$0], #count;"; +const std::string Mbarrier_Wait_Op = + "{ \n" + ".reg .pred P1; \n" + "LAB_WAIT: \n" + "mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, 0x989680; \n" + "@P1 bra.uni DONE; \n" + "bra.uni LAB_WAIT; \n" + "DONE: \n" + "} \n"; +const std::string Named_Barrier_Arrive_Op = "bar.arrive $0, $1;"; +const std::string Named_Barrier_Wait_Op = "bar.sync $0, $1;"; +const std::string Sts64_Op = "st.shared.v2.b32 [$0], {$1, $2};"; +const std::string Cluster_Cta_Id_Op = "{\n" + ".reg .u32 a<5>; \n" + "mov.u32 a0, %cluster_ctaid.x;\n" // x + "mov.u32 a1, %cluster_ctaid.y;\n" // y + "mov.u32 a2, %cluster_ctaid.z;\n" // z + "mov.u32 a3, %cluster_nctaid.x;\n" // nx + "mov.u32 a4, %cluster_nctaid.y;\n" // ny + "mad.lo.u32 a1, a2, a4, a1; \n" + "mad.lo.u32 $0, a1, a3, a0; \n" + "}"; + +bool isNumber(const std::string &s) { + return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { + return !std::isdigit(c); + }) == s.end(); +} + +Type getTypeFromConstraint(char constraint, mlir::PatternRewriter &rewriter) { + Type ty; + if (constraint == 'b') + ty = IntegerType::get(rewriter.getContext(), 1); + else if (constraint == 'h') + ty = IntegerType::get(rewriter.getContext(), 16); + else if (constraint == 'r') + ty = IntegerType::get(rewriter.getContext(), 32); + else if (constraint == 'l') + ty = IntegerType::get(rewriter.getContext(), 64); + else if (constraint == 'f') + ty = FloatType::getF32(rewriter.getContext()); + else 
if (constraint == 'd') + ty = FloatType::getF64(rewriter.getContext()); + else { + assert(false && "Unsupported constraint"); + } + return ty; +} + template class NVGPUOpPatternBase : public mlir::RewritePattern { public: explicit NVGPUOpPatternBase(mlir::MLIRContext *context) : mlir::RewritePattern(SourceOp::getOperationName(), 1, context) {} + // Converts the given value to the type represented by the constraint + // E.g. if val is of type llvmptr and constraint is 'r', then we convert + // val to i32 using ptrtoint(i32_ty, val) + mlir::Value convertToType(mlir::Value val, std::string constraint, + Location &loc, + mlir::PatternRewriter &rewriter) const { + auto isConstraintNumber = isNumber(constraint); + if (!isConstraintNumber) { + auto ty = getTypeFromConstraint(constraint[0], rewriter); + if (val.getType().isa()) { + return ptrtoint(ty, val); + } else { + assert(val.getType().getIntOrFloatBitWidth() <= + ty.getIntOrFloatBitWidth() && + "Cannot convert to a smaller type"); + if (val.getType().getIntOrFloatBitWidth() < ty.getIntOrFloatBitWidth()) + return zext(ty, val); + } + } + return val; + } + + SmallVector + getPtxOutputs(std::vector &outputConstraints, + PTXBuilder &ptxBuilder) const { + SmallVector ptxOutputs; + for (unsigned i = 0; i < outputConstraints.size(); i++) { + auto *ptxOutput = ptxBuilder.newOperand(outputConstraints[i]); + ptxOutputs.push_back(ptxOutput); + } + return ptxOutputs; + } + + OperandsAndConstraints + unpackOperands(OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location &loc, + mlir::PatternRewriter &rewriter) const { + OperandsAndConstraints unpackedOperands; + for (auto &[operand, constraint] : operandsAndConstraints) { + auto llvmStruct = llvm::dyn_cast(operand.getType()); + // if a constraint is a number, then we are doing input/output tying + // if the operand is a struct, then we need to unpack it, and + // add the constraint to each of the unpacked operands uses the constraint + // as an offset + auto isConstraintNumber = isNumber(constraint); + if (llvmStruct) { + for (unsigned i = 0; i < llvmStruct.getBody().size(); i++) { + if (isConstraintNumber) { + auto constraintInt = std::stoi(constraint) + i; + unpackedOperands.push_back( + {extract_val(llvmStruct.getBody()[i], operand, i), + std::to_string(constraintInt)}); + } else { + unpackedOperands.push_back( + {extract_val(llvmStruct.getBody()[i], operand, i), constraint}); + } + } + } else { + unpackedOperands.push_back({operand, constraint}); + } + } + return unpackedOperands; + } + + SmallVector + getPtxOperands(OperandsAndConstraints &operandsAndConstraints, + PTXBuilder &ptxBuilder, Location &loc, + mlir::PatternRewriter &rewriter) const { + SmallVector ptxOperands; + auto unpackedOperandsAndConstraints = + unpackOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + for (auto &[operand, constraint] : unpackedOperandsAndConstraints) { + auto convertedOperand = convertToType(operand, constraint, loc, rewriter); + auto *ptxOperand = ptxBuilder.newOperand(convertedOperand, constraint); + ptxOperands.push_back(ptxOperand); + } + return ptxOperands; + } + + virtual std::vector getOutputConstraints(SourceOp op) const { + return {}; + } + + virtual OperandsAndConstraints getOperandsAndConstraints(SourceOp op) const { + return {}; + } + + Type getReturnType(std::vector outputConstraints, + mlir::PatternRewriter &rewriter) const { + auto ctx = rewriter.getContext(); + Type resTy; + if (outputConstraints.empty()) { + resTy = void_ty(ctx); + } else { + SmallVector 
retTys; + for (auto &outputConstraint : outputConstraints) { + assert(outputConstraint[0] == '=' && + "Constraint must be for an output"); + Type retTy = getTypeFromConstraint(outputConstraint[1], rewriter); + retTys.push_back(retTy); + } + if (retTys.size() == 1) { + resTy = retTys[0]; + } else { + resTy = struct_ty(retTys); + } + } + return resTy; + } + + std::string patchPtxAsm(mlir::Operation *op, std::string ptxAsm) const { + std::vector> patchLocations; + std::vector patchValues; + auto start = ptxAsm.find("#", 0); + while (start != std::string::npos) { + auto endIterator = + std::find_if(ptxAsm.begin() + start + 1, ptxAsm.end(), + [](unsigned char c) { return !std::isalnum(c); }); + + assert(endIterator != ptxAsm.end() && "unexpected asm format"); + + auto end = std::distance(ptxAsm.begin(), endIterator); + auto patchLocation = std::make_pair(start, end); + patchLocations.push_back(patchLocation); + auto patchValue = ptxAsm.substr(start + 1, end - start - 1); + patchValues.push_back(patchValue); + start = ptxAsm.find("#", end); + } + assert(patchLocations.size() == patchValues.size() && + "patchLocations and patchValues should have the same size"); + if (patchLocations.size() == 0) { + return ptxAsm; + } + std::string res = ""; + size_t prevStart = 0; + unsigned i = 0; + for (auto &[start, end] : patchLocations) { + res += ptxAsm.substr(prevStart, start - prevStart); + auto integerAttr = op->getAttrOfType(patchValues[i]); + auto attr = integerAttr.getInt(); + res += std::to_string(attr); + prevStart = end; + i++; + } + if (prevStart < ptxAsm.size()) + res += ptxAsm.substr(prevStart, ptxAsm.size() - prevStart); + return res; + } + LogicalResult matchAndRewrite(mlir::Operation *op, mlir::PatternRewriter &rewriter) const override { @@ -35,30 +270,62 @@ public: auto sourceOp = llvm::dyn_cast(op); if (!sourceOp) return mlir::failure(); - auto ptxAsm = static_cast(this)->getPtxAsm(sourceOp); + auto concrete = static_cast(this); + auto ptxAsm = concrete->getPtxAsm(sourceOp); + auto ptxAsmPatched = patchPtxAsm(sourceOp, ptxAsm); auto hasSideEffects = !isMemoryEffectFree(sourceOp); + auto operandsAndConstraints = concrete->getOperandsAndConstraints(sourceOp); + auto outputConstraints = concrete->getOutputConstraints(sourceOp); + PTXBuilder ptxBuilder; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - ptxInstr({}, /*onlyAttachMLIRArgs=*/true); - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, - /*hasSideEffects*/ hasSideEffects); - rewriter.eraseOp(op); + auto ptxOutputs = getPtxOutputs(outputConstraints, ptxBuilder); + auto ptxOperands = + getPtxOperands(operandsAndConstraints, ptxBuilder, loc, rewriter); + SmallVector outputsAndOperands = ptxOutputs; + outputsAndOperands.append(ptxOperands.begin(), ptxOperands.end()); + auto &ptxInstr = *ptxBuilder.create(ptxAsmPatched); + ptxInstr(outputsAndOperands, /*onlyAttachMLIRArgs=*/true); + auto retTy = getReturnType(outputConstraints, rewriter); + auto res = ptxBuilder.launch(rewriter, loc, retTy, + /*hasSideEffects*/ hasSideEffects); + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, res); + } + return mlir::success(); } }; -class CGABarrierSyncOpPattern - : public NVGPUOpPatternBase { +template +class NVGPUOpGenericPattern + : public NVGPUOpPatternBase> { public: - using Base = - NVGPUOpPatternBase; - using Base::Base; + explicit NVGPUOpGenericPattern(mlir::MLIRContext *context, std::string ptxAsm, + std::vector outputConstraints, + std::vector inputConstraints) + : 
NVGPUOpPatternBase>(context), + ptxAsm(ptxAsm), outputConstraints(outputConstraints), + inputConstraints(inputConstraints) {} - std::string getPtxAsm(ttn::CGABarrierSyncOp op) const { - return "barrier.cluster.sync.aligned;"; + std::vector getOutputConstraints(SourceOp op) const { + return outputConstraints; } + OperandsAndConstraints getOperandsAndConstraints(SourceOp op) const { + OperandsAndConstraints operandsAndConstraints; + for (unsigned i = 0; i < inputConstraints.size(); i++) { + operandsAndConstraints.push_back( + {op->getOperand(i), inputConstraints[i]}); + } + return operandsAndConstraints; + } + std::string getPtxAsm(SourceOp op) const { return ptxAsm; } + +private: + std::string ptxAsm; + std::vector outputConstraints; + std::vector inputConstraints; }; class FenceAsyncSharedOpPattern @@ -78,226 +345,6 @@ public: } }; -class WGMMAFenceOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::WGMMAFenceOp op) const { - return "wgmma.fence.sync.aligned;"; - } -}; - -class WGMMACommitGroupOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::WGMMACommitGroupOp op) const { - return "wgmma.commit_group.sync.aligned;"; - } -}; - -class WGMMAWaitGroupOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::WGMMAWaitGroupOp op) const { - auto pendings = op.getPendings(); - return "wgmma.wait_group.sync.aligned " + std::to_string(pendings) + ";"; - } -}; - -class StoreMatrixOpPattern : public mlir::RewritePattern { -public: - StoreMatrixOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::StoreMatrixOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto storeMatrixOp = llvm::dyn_cast(op); - if (!storeMatrixOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto addr = storeMatrixOp.getAddr(); - auto datas = storeMatrixOp.getDatas(); - - assert(datas.size() == 1 || datas.size() == 2 || - datas.size() == 4 && "Invalid size for StoreMatrixOp"); - PTXBuilder ptxBuilder; - auto &ptxInstr = *ptxBuilder.create( - "stmatrix.sync.aligned.m8n8.x" + std::to_string(datas.size()) + - ".shared.b16"); - auto *addrOpr = ptxBuilder.newAddrOperand(ptrtoint(i32_ty, addr), "r"); - - SmallVector> args; - for (unsigned i = 0; i < datas.size(); ++i) { - args.push_back({datas[i], "r"}); - } - auto *operands = ptxBuilder.newListOperand(args); - - ptxInstr(addrOpr, operands); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class MBarrierInitOpPattern : public mlir::RewritePattern { -public: - MBarrierInitOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::MBarrierInitOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto mBarrierInitOp = llvm::dyn_cast(op); - if (!mBarrierInitOp) - return mlir::failure(); - auto loc = op->getLoc(); - Value mbarrier = mBarrierInitOp.getMbarrier(); - Value pred = mBarrierInitOp.getPred(); - uint32_t count = mBarrierInitOp.getCount(); - PTXBuilder ptxBuilder; - - auto &ptxInstr = 
*ptxBuilder.create("mbarrier.init.shared.b64"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - auto *expectedOpr = ptxBuilder.newConstantOperand(count); - - ptxInstr(barOpr, expectedOpr).predicate(pred, "b"); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class MBarrierArriveOpPattern : public mlir::RewritePattern { -public: - MBarrierArriveOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::MBarrierArriveOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto mbarrierArriveOp = llvm::dyn_cast(op); - if (!mbarrierArriveOp) - return mlir::failure(); - auto loc = op->getLoc(); - Value mbarrier = mbarrierArriveOp.getMbarrier(); - Value pred = mbarrierArriveOp.getPred(); - Value ctaId = mbarrierArriveOp.getCtaId(); - auto arriveType = mbarrierArriveOp.getArriveType(); - uint32_t txCount = mbarrierArriveOp.getTxCount(); - - PTXBuilder ptxBuilder; - if (arriveType == ttn::MBarriveType::normal) { - auto &ptxInstr = - *ptxBuilder.create("mbarrier.arrive.shared.b64 _,"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - - ptxInstr(barOpr).predicate(pred, "b"); - } else if (arriveType == ttn::MBarriveType::cp_async) { - auto &ptxInstr = *ptxBuilder.create( - "cp.async.mbarrier.arrive.noinc.shared.b64"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - - ptxInstr(barOpr).predicate(pred, "b"); - } else if (arriveType == ttn::MBarriveType::expect_tx) { - assert(txCount > 0 && "txCount should be valid"); - auto &ptxInstr = *ptxBuilder.create( - "mbarrier.arrive.expect_tx.shared.b64 _,"); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - auto *expectedOpr = ptxBuilder.newConstantOperand(txCount); - - ptxInstr(barOpr, expectedOpr).predicate(pred, "b"); - } else if (arriveType == ttn::MBarriveType::remote) { - assert(ctaId && "ctaId should have a valid value"); - auto ptxAsm = - " { .reg .b32 remAddr32; \n" - " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" - " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *barOpr = - ptxBuilder.newAddrOperand(ptrtoint(i32_ty, mbarrier), "r", 0); - auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - - ptxInstr({barOpr, ctaIdOpr, predOpr}, /*onlyAttachMLIRArgs=*/true); - } else { - assert(false && - "Unsupported mbarrier arrive type"); // TODO: is this the right way - // to assert in LLVM pass ? 
- } - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; -class MBarrierWaitOpPattern : public mlir::RewritePattern { -public: - MBarrierWaitOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::MBarrierWaitOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto mBarrierWaitOp = llvm::dyn_cast(op); - if (!mBarrierWaitOp) - return mlir::failure(); - auto loc = op->getLoc(); - Value mbarrier = mBarrierWaitOp.getMbarrier(); - Value phase = mBarrierWaitOp.getPhase(); - PTXBuilder ptxBuilder; - - auto ptxAsm = - "{\n" - ".reg .pred P1; \n" - "LAB_WAIT: \n" - "mbarrier.try_wait.parity.shared.b64 P1, [$0], $1, 0x989680; \n" - "@P1 bra.uni DONE; \n" - "bra.uni LAB_WAIT; \n" - "DONE: \n" - "}"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, mbarrier), "r"); - auto *phaseOpr = ptxBuilder.newOperand(zext(i32_ty, phase), "r"); - - ptxInstr({barOpr, phaseOpr}, - /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - class ClusterArriveOpPattern : public NVGPUOpPatternBase { public: @@ -313,202 +360,327 @@ public: } }; -class ClusterWaitOpPattern - : public NVGPUOpPatternBase { +class StoreMatrixOpPattern + : public NVGPUOpPatternBase { public: - using Base = NVGPUOpPatternBase; + using Base = NVGPUOpPatternBase; using Base::Base; - std::string getPtxAsm(ttn::ClusterWaitOp op) const { - return "barrier.cluster.wait.aligned;"; + + OperandsAndConstraints + getOperandsAndConstraints(ttn::StoreMatrixOp op) const { + OperandsAndConstraints operandsAndTypes; + auto addr = op.getAddr(); + auto datas = op.getDatas(); + operandsAndTypes.push_back({addr, "r"}); + for (unsigned i = 0; i < datas.size(); i++) { + operandsAndTypes.push_back({datas[i], "r"}); + } + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::StoreMatrixOp op) const { + auto datas = op.getDatas(); + std::string ptxAsm; + switch (datas.size()) { + case 1: + ptxAsm = "stmatrix.sync.aligned.m8n8.x1.shared.b16 [$0], {$1};"; + break; + case 2: + ptxAsm = "stmatrix.sync.aligned.m8n8.x2.shared.b16 [$0], {$1, $2};"; + break; + case 4: + ptxAsm = + "stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};"; + break; + default: + assert(false && "Invalid size"); + } + return ptxAsm; } }; -class TMALoadTiledOpPattern : public mlir::RewritePattern { +class MBarrierArriveOpPattern + : public NVGPUOpPatternBase { public: - TMALoadTiledOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::TMALoadTiledOp::getOperationName(), 1, - context) {} + using Base = + NVGPUOpPatternBase; + using Base::Base; - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto tmaLoadTiledOp = llvm::dyn_cast(op); - if (!tmaLoadTiledOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto dst = tmaLoadTiledOp.getDst(); - auto mbarrier = tmaLoadTiledOp.getMbarrier(); - auto tmaDesc = tmaLoadTiledOp.getTmaDesc(); - auto l2Desc = tmaLoadTiledOp.getL2Desc(); - auto pred = tmaLoadTiledOp.getPred(); - auto coords = tmaLoadTiledOp.getCoords(); - auto mcastMask = tmaLoadTiledOp.getMcastMask(); + OperandsAndConstraints + 
getOperandsAndConstraints(ttn::MBarrierArriveOp op) const { + OperandsAndConstraints operandsAndTypes; + Value mbarrier = op.getMbarrier(); + Value pred = op.getPred(); + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + + switch (arriveType) { + case ttn::MBarriveType::normal: + case ttn::MBarriveType::cp_async: + case ttn::MBarriveType::expect_tx: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + case ttn::MBarriveType::remote: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::MBarrierArriveOp op) const { + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + uint32_t txCount = op.getTxCount(); + std::string ptxAsm; + switch (arriveType) { + case ttn::MBarriveType::normal: + ptxAsm = "@$1 mbarrier.arrive.shared.b64 _, [$0];"; + break; + case ttn::MBarriveType::cp_async: + ptxAsm = "@$1 cp.async.mbarrier.arrive.noinc.shared.b64 [$0];"; + break; + case ttn::MBarriveType::expect_tx: + assert(txCount > 0 && "txCount should be valid"); + ptxAsm = "@$1 mbarrier.arrive.expect_tx.shared.b64 _, [$0], " + + std::to_string(txCount) + ";"; + break; + case ttn::MBarriveType::remote: + assert(ctaId && "ctaId should have a valid value"); + ptxAsm = + " { .reg .b32 remAddr32; \n" + " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" + " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return ptxAsm; + } +}; + +class TMALoadTiledOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + + OperandsAndConstraints + getOperandsAndConstraints(ttn::TMALoadTiledOp op) const { + OperandsAndConstraints operandsAndTypes; + auto dst = op.getDst(); + auto mbarrier = op.getMbarrier(); + auto tmaDesc = op.getTmaDesc(); + auto l2Desc = op.getL2Desc(); + auto pred = op.getPred(); + auto coords = op.getCoords(); + auto mcastMask = op.getMcastMask(); auto dimSize = coords.size(); + assert(dimSize == 2 || (dimSize == 4 && mcastMask == nullptr) && + "Does not support TMA configuration"); - PTXBuilder ptxBuilder; + operandsAndTypes.push_back({dst, "r"}); + operandsAndTypes.push_back({tmaDesc, "l"}); + for (unsigned i = 0; i < coords.size(); i++) { + operandsAndTypes.push_back({coords[i], "r"}); + } + operandsAndTypes.push_back({mbarrier, "l"}); + if (mcastMask) { + operandsAndTypes.push_back({mcastMask, "h"}); + } + operandsAndTypes.push_back({l2Desc, "l"}); + operandsAndTypes.push_back({pred, "b"}); + + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::TMALoadTiledOp op) const { + auto coords = op.getCoords(); + auto mcastMask = op.getMcastMask(); + auto dimSize = coords.size(); + std::string ptxAsm; if (dimSize == 2) { if (mcastMask == nullptr) { - auto ptxAsm = - "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier:" - ":complete_tx" - "::bytes.L2::cache_hint [$0], [$1, {$2, $3}], [$4], $5;"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto 
*c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); - auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - - ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, barOpr, l2DescOpr, predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$6 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier:" + ":complete_tx" + "::bytes.L2::cache_hint [$0], [$1, {$2, $3}], [$4], $5;"; } else { - auto ptxAsm = - "@$7 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" - "complete_tx::bytes.multicast::cluster.L2::cache_hint" - " [$0], [$1, {$2, $3}], [$4], $5, $6;"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); - auto *maskOpr = ptxBuilder.newOperand(mcastMask, "h"); - auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, barOpr, maskOpr, l2DescOpr, - predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$7 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::" + "complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [$0], [$1, {$2, $3}], [$4], $5, $6;"; } } else if (dimSize == 4) { assert(mcastMask == nullptr && "Does not support multicast"); - auto ptxAsm = "@$8 " - "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier:" - ":complete_tx" - "::bytes.L2::cache_hint [$0], [$1, {$2, $3, $4, $5}], " - "[$6], $7;"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *dstOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, dst), "r"); - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); - auto *c3Opr = ptxBuilder.newOperand(coords[3], "r"); - auto *barOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, mbarrier), "r"); - auto *l2DescOpr = ptxBuilder.newOperand(l2Desc, "l"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({dstOpr, descOpr, c0Opr, c1Opr, c2Opr, c3Opr, barOpr, l2DescOpr, - predOpr}, - /*onlyAttachMLIRArgs=*/true); + ptxAsm = "@$8 " + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier:" + ":complete_tx" + "::bytes.L2::cache_hint [$0], [$1, {$2, $3, $4, $5}], " + "[$6], $7;"; } else { - assert(false && "invalid dim size"); + llvm::errs() << "Unsupported dimSize " << dimSize << "\n"; + llvm_unreachable(""); } - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); - rewriter.eraseOp(op); - return mlir::success(); + return ptxAsm; } }; -class TMAStoreTiledOpPattern : public mlir::RewritePattern { +class TMAStoreTiledOpPattern + : public NVGPUOpPatternBase { public: - TMAStoreTiledOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::TMAStoreTiledOp::getOperationName(), 1, - context) {} + using Base = NVGPUOpPatternBase; + using Base::Base; - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto tmaStoreTiledOp = llvm::dyn_cast(op); - if (!tmaStoreTiledOp) - return mlir::failure(); - auto loc = 
op->getLoc(); - auto src = tmaStoreTiledOp.getSrc(); - auto tmaDesc = tmaStoreTiledOp.getTmaDesc(); - auto pred = tmaStoreTiledOp.getPred(); - auto coords = tmaStoreTiledOp.getCoords(); + OperandsAndConstraints + getOperandsAndConstraints(ttn::TMAStoreTiledOp op) const { + OperandsAndConstraints operandsAndTypes; + auto src = op.getSrc(); + auto tmaDesc = op.getTmaDesc(); + auto pred = op.getPred(); + auto coords = op.getCoords(); auto dimSize = coords.size(); - - PTXBuilder ptxBuilder; - if (dimSize == 2) { - auto ptxAsm = "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" - "[$0, {$2, $3}], [$1];"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, predOpr}, - /*onlyAttachMLIRArgs=*/true); - } else if (dimSize == 3) { - auto ptxAsm = "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group" - "[$0, {$2, $3, $4}], [$1];"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, c2Opr, predOpr}, - /*onlyAttachMLIRArgs=*/true); - } else if (dimSize == 4) { - auto ptxAsm = "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" - "[$0, {$2, $3, $4, $5}], [$1];"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - auto *descOpr = ptxBuilder.newOperand(ptrtoint(i64_ty, tmaDesc), "l"); - auto *srcOpr = ptxBuilder.newOperand(ptrtoint(i32_ty, src), "r"); - auto *c0Opr = ptxBuilder.newOperand(coords[0], "r"); - auto *c1Opr = ptxBuilder.newOperand(coords[1], "r"); - auto *c2Opr = ptxBuilder.newOperand(coords[2], "r"); - auto *c3Opr = ptxBuilder.newOperand(coords[3], "r"); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - ptxInstr({descOpr, srcOpr, c0Opr, c1Opr, c2Opr, c3Opr, predOpr}, - /*onlyAttachMLIRArgs=*/true); - } else { - assert(false && "invalid dim size"); + if (dimSize != 2 && dimSize != 3 && dimSize != 4) { + llvm::errs() << "Unsupported dimSize " << dimSize << "\n"; + llvm_unreachable(""); } + operandsAndTypes.push_back({tmaDesc, "l"}); + operandsAndTypes.push_back({src, "r"}); + for (unsigned i = 0; i < dimSize; i++) { + operandsAndTypes.push_back({coords[i], "r"}); + } + operandsAndTypes.push_back({pred, "b"}); - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); - rewriter.eraseOp(op); - return mlir::success(); + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::TMAStoreTiledOp op) const { + auto coords = op.getCoords(); + auto dimSize = coords.size(); + std::string ptxAsm; + if (dimSize == 2) { + ptxAsm = "@$4 cp.async.bulk.tensor.2d.global.shared::cta.bulk_group" + "[$0, {$2, $3}], [$1];"; + } else if (dimSize == 3) { + ptxAsm = "@$5 cp.async.bulk.tensor.3d.global.shared::cta.bulk_group" + "[$0, {$2, $3, $4}], [$1];"; + } else if (dimSize == 4) { + ptxAsm = "@$6 cp.async.bulk.tensor.4d.global.shared::cta.bulk_group" + "[$0, {$2, $3, $4, $5}], [$1];"; + } else { + llvm::errs() << 
"Unsupported dimSize " << dimSize << "\n"; + llvm_unreachable(""); + } + return ptxAsm; } }; -class LoadDSmemOpPattern : public mlir::RewritePattern { +class StoreDSmemOpPattern + : public NVGPUOpPatternBase { public: - LoadDSmemOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::LoadDSmemOp::getOperationName(), 1, context) { + using Base = NVGPUOpPatternBase; + using Base::Base; + + OperandsAndConstraints getOperandsAndConstraints(ttn::StoreDSmemOp op) const { + OperandsAndConstraints operandsAndTypes; + auto addr = op.getAddr(); + auto ctaId = op.getCtaId(); + auto values = op.getValues(); + auto pred = op.getPred(); + auto bitwidth = op.getBitwidth(); + operandsAndTypes.push_back({addr, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + operandsAndTypes.push_back({pred, "b"}); + std::string c = bitwidth == 16 ? "h" : (bitwidth == 32 ? "r" : "l"); + for (unsigned i = 0; i < values.size(); i++) { + operandsAndTypes.push_back({values[i], c}); + } + return operandsAndTypes; } - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto loadDSmemOp = llvm::dyn_cast(op); - if (!loadDSmemOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto addr = loadDSmemOp.getAddr(); - auto ctaId = loadDSmemOp.getCtaId(); - auto bitwidth = loadDSmemOp.getBitwidth(); - auto vec = loadDSmemOp.getVec(); + std::string getPtxAsm(ttn::StoreDSmemOp op) const { + auto bitwidth = op.getBitwidth(); + auto vec = op.getVec(); + auto values = op.getValues(); + assert( + (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && + "invalid bitwidth"); + assert((vec == 1 || vec == 2 || vec == 4) && vec == values.size() && + "invalid vec size"); + std::string ptxAsm; + if (vec == 1) { + ptxAsm = "{ \n" + ".reg .u32 remoteAddr; \n" + "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" + ".reg .pred p; \n" + "mov.pred p, $2; \n" + "@p st.shared::cluster.u#bitwidth [remoteAddr], $3; \n" + "}\n"; + } + if (vec == 2) { + ptxAsm = "{ \n" + ".reg .u32 remoteAddr; \n" + "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" + ".reg .pred p; \n" + "mov.pred p, $2; \n" + "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4}; \n" + "}\n"; + } + if (vec == 4) { + ptxAsm = "{ \n" + ".reg .u32 remoteAddr; \n" + "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n" + ".reg .pred p; \n" + "mov.pred p, $2; \n" + "@p st.shared::cluster.v.u#bitwidth [remoteAddr], {$3, $4, $5, " + "$6}; \n" + "}\n"; + } + return ptxAsm; + } +}; + +class LoadDSmemOpPattern + : public NVGPUOpPatternBase { +public: + using Base = NVGPUOpPatternBase; + using Base::Base; + + std::vector getOutputConstraints(ttn::LoadDSmemOp op) const { + auto bitwidth = op.getBitwidth(); + std::string c = bitwidth == 16 ? "=h" : (bitwidth == 32 ? 
"=r" : "=l"); + auto vec = op.getVec(); + return std::vector(vec, c); + } + OperandsAndConstraints getOperandsAndConstraints(ttn::LoadDSmemOp op) const { + OperandsAndConstraints operandsAndTypes; + auto addr = op.getAddr(); + auto ctaId = op.getCtaId(); + + operandsAndTypes.push_back({addr, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::LoadDSmemOp op) const { + auto addr = op.getAddr(); + auto ctaId = op.getCtaId(); + auto bitwidth = op.getBitwidth(); + auto vec = op.getVec(); assert( (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && "invalid bitwidth"); assert((vec == 1 || vec == 2 || vec == 4) && "invalid vec size"); - PTXBuilder ptxBuilder; std::string o1 = vec > 1 ? ".v.u" : ".u"; std::string vecStr = vec == 1 ? "$0" @@ -524,68 +696,74 @@ public: o1 + std::to_string(bitwidth) + " " + vecStr + ", [remoteAddr];\n" "}\n"; - - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - std::string c = bitwidth == 16 ? "=h" : (bitwidth == 32 ? "=r" : "=l"); - SmallVector oprs; - for (unsigned i = 0; i < vec; ++i) { - auto *ret = ptxBuilder.newOperand(c); - oprs.push_back(ret); - } - auto *addrOpr = ptxBuilder.newOperand(addr, "r"); - auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); - oprs.push_back(addrOpr); - oprs.push_back(ctaIdOpr); - - Type retTy = IntegerType::get(rewriter.getContext(), bitwidth); - SmallVector retTys(vec, retTy); - if (vec > 1) - retTy = struct_ty(retTys); - - ptxInstr(oprs, - /*onlyAttachMLIRArgs=*/true); - - auto res = ptxBuilder.launch(rewriter, loc, retTy); - rewriter.replaceOp(op, {res}); - return mlir::success(); + return ptxAsm; } }; -class WGMMAOpPattern : public mlir::RewritePattern { +class WGMMAOpPattern : public NVGPUOpPatternBase { public: - WGMMAOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::WGMMAOp::getOperationName(), 1, context) {} + using Base = NVGPUOpPatternBase; + using Base::Base; - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { + std::vector getOutputConstraints(ttn::WGMMAOp op) const { + // TODO (zahi): Return type must always be a struct for wgmma, currently + // we rely on the size of output constraints vector to determine whether + // the output is a struct or not. We should find a way to pass this info + auto resultType = op.getType(); + + auto outputStructType = resultType.dyn_cast(); + uint32_t numOutputRegs = outputStructType.getBody().size(); + std::string output = + outputStructType.getBody().front().isF32() ? "=f" : "=r"; + return std::vector(numOutputRegs, output); + } + + OperandsAndConstraints getOperandsAndConstraints(ttn::WGMMAOp op) const { + OperandsAndConstraints operandsAndConstraints; + auto opA = op.getOpA(); + auto opB = op.getOpB(); + auto opC = op.getOpC(); + auto typeA = opA.getType(); + + auto structTypeA = typeA.dyn_cast(); + + // TODO (zahi): is this the best way to tie inputs/outputs ? 
+ if (opC) + operandsAndConstraints.push_back({opC, "0"}); + + if (structTypeA) { + operandsAndConstraints.push_back({opA, "r"}); + } else { + operandsAndConstraints.push_back({opA, "l"}); + } + + // Operand B (must be `desc`) + operandsAndConstraints.push_back({opB, "l"}); + return operandsAndConstraints; + } + + std::string getPtxAsm(ttn::WGMMAOp op) const { using namespace ttn; - auto ctx = rewriter.getContext(); - auto wgmmaOp = llvm::dyn_cast(op); - if (!wgmmaOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto opA = wgmmaOp.getOpA(); - auto opB = wgmmaOp.getOpB(); - auto opC = wgmmaOp.getOpC(); - auto m = wgmmaOp.getM(); - auto n = wgmmaOp.getN(); - auto k = wgmmaOp.getK(); - auto eltTypeC = wgmmaOp.getEltTypeC(); - auto eltTypeA = wgmmaOp.getEltTypeA(); - auto eltTypeB = wgmmaOp.getEltTypeB(); - auto layoutA = wgmmaOp.getLayoutA(); - auto layoutB = wgmmaOp.getLayoutB(); + auto opA = op.getOpA(); + auto opB = op.getOpB(); + auto m = op.getM(); + auto n = op.getN(); + auto k = op.getK(); + auto eltTypeC = op.getEltTypeC(); + auto eltTypeA = op.getEltTypeA(); + auto eltTypeB = op.getEltTypeB(); + auto layoutA = op.getLayoutA(); + auto layoutB = op.getLayoutB(); // Register checks auto typeA = opA.getType(); auto typeB = opB.getType(); - auto typeC = opC.getType(); + auto typeOutput = op.getType(); auto structTypeA = typeA.dyn_cast(); auto structTypeB = typeB.dyn_cast(); - auto structTypeC = typeC.dyn_cast(); + auto structTypeOutput = typeOutput.dyn_cast(); assert(!structTypeB && "Operand B can not be registers"); - assert(structTypeC && "Operand C must be registers"); + assert(structTypeOutput && "Output and C operand must be registers"); // Element type, MNK shape and transposing support check // Reference: @@ -624,65 +802,41 @@ public: (m == 64 && 8 <= n && n <= 224 && k == 32); } assert(supported && "WGMMA type or shape is not supported"); - PTXBuilder ptxBuilder; - SmallVector oprs; // Operands uint32_t asmOpIdx = 0; - - // Operand C - uint32_t numCRegs = structTypeC.getBody().size(); - std::string args = ""; + + // Output and operand C + uint32_t numCRegs = structTypeOutput.getBody().size(); + args += "{"; for (uint32_t i = 0; i < numCRegs; ++i) { args += "$" + std::to_string(asmOpIdx++) + (i == numCRegs - 1 ? "" : ","); - // LLVM does not support `+` semantic, we must repeat the arguments for - // both input and outputs - PTXBuilder::Operand *opr; - if (structTypeC.getBody().front().isF32()) - opr = ptxBuilder.newOperand( - extract_val(structTypeC.getBody()[i], opC, i), "=f"); - else - opr = ptxBuilder.newOperand( - extract_val(structTypeC.getBody()[i], opC, i), "=r"); - oprs.push_back(opr); } args += "}, "; - for (uint32_t i = asmOpIdx - numCRegs; i < asmOpIdx; ++i) { - auto *opr = ptxBuilder.newOperand(i); - oprs.push_back(opr); - } + if (op.getOpC()) + asmOpIdx += numCRegs; - // Note that LLVM will not skip the indexed repeating placeholders - asmOpIdx += numCRegs; // Operand A if (structTypeA) { - uint32_t numARegs = m * k / 128; - assert(numARegs == structTypeA.getBody().size()); + uint32_t numARegs = structTypeA.getBody().size(); args += "{"; for (uint32_t i = 0; i < numARegs; ++i) { args += "$" + std::to_string(asmOpIdx++) + (i == numARegs - 1 ? 
"" : ","); - auto *opr = ptxBuilder.newOperand( - extract_val(structTypeA.getBody()[i], opA, i), "f"); - oprs.push_back(opr); } args += "}, "; } else { args += "$" + std::to_string(asmOpIdx++) + ", "; - auto *opr = ptxBuilder.newOperand(opA, "l"); - oprs.push_back(opr); } // Operand B (must be `desc`) args += "$" + std::to_string(asmOpIdx++) + ", "; - auto *opr = ptxBuilder.newOperand(opB, "l"); - oprs.push_back(opr); - // `scale-d` is 1 by default - args += "1"; + // `scale-d` is 1 if we have a C operand. + args += op.getOpC() ? "1" : "0"; // `imm-scale-a`, and `imm-scale-b` are 1 by default only for float-based // WGMMA @@ -690,8 +844,11 @@ public: args += ", 1, 1"; // Push `trans-a` and `trans-b` args if needed (determined as constant) - if (needTransArgs) - args += ", " + std::to_string(transA) + ", " + std::to_string(transB); + if (needTransArgs) { + if (!structTypeA) + args += ", " + std::to_string(transA); + args += ", " + std::to_string(transB); + } auto ptxAsm = "wgmma.mma_async.sync.aligned" ".m" + @@ -699,338 +856,37 @@ public: std::to_string(k) + "." + stringifyEnum(eltTypeC).str() + "." + stringifyEnum(eltTypeA).str() + "." + stringifyEnum(eltTypeB).str() + " " + args + ";"; - - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - ptxInstr(oprs, - /*onlyAttachMLIRArgs=*/true); - - auto res = - ptxBuilder.launch(rewriter, loc, structTypeC, /*hasSideEffect*/ true); - rewriter.replaceOp(op, {res}); - return mlir::success(); + return ptxAsm; } }; -class FenceMBarrierInitOpPattern - : public NVGPUOpPatternBase { +class OffsetOfSts64OpPattern + : public NVGPUOpPatternBase { public: - using Base = - NVGPUOpPatternBase; + using Base = NVGPUOpPatternBase; using Base::Base; - std::string getPtxAsm(ttn::FenceMBarrierInitOp op) const { - return "fence.mbarrier_init.release.cluster;"; + std::vector getOutputConstraints(ttn::OffsetOfSts64Op op) const { + return {"=r"}; } -}; -class NamedBarrierArriveOpPattern : public mlir::RewritePattern { -public: - NamedBarrierArriveOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::NamedBarrierArriveOp::getOperationName(), 1, - context) {} + OperandsAndConstraints + getOperandsAndConstraints(ttn::OffsetOfSts64Op op) const { + OperandsAndConstraints operandsAndConstraints; + auto threadId = op.getThreadId(); + auto rowOfWarp = op.getRowOfWarp(); + auto elemIdx = op.getElemIdx(); - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto namedBarrierArriveOp = llvm::dyn_cast(op); - if (!namedBarrierArriveOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto bar = namedBarrierArriveOp.getBar(); - auto numThreads = namedBarrierArriveOp.getNumThreads(); - PTXBuilder ptxBuilder; + operandsAndConstraints.push_back({threadId, "r"}); + operandsAndConstraints.push_back({elemIdx, "r"}); + operandsAndConstraints.push_back({rowOfWarp, "r"}); - auto &ptxInstr = *ptxBuilder.create("bar.arrive $0, $1;"); - auto *barOpr = ptxBuilder.newOperand(bar, "r"); - auto *numThreadsOpr = ptxBuilder.newOperand(numThreads, "r"); - ptxInstr({barOpr, numThreadsOpr}, /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); + return operandsAndConstraints; } -}; -class NamedBarrierWaitOpPattern : public mlir::RewritePattern { -public: - NamedBarrierWaitOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::NamedBarrierWaitOp::getOperationName(), 1, - 
context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto namedBarrierWaitOp = llvm::dyn_cast(op); - if (!namedBarrierWaitOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto bar = namedBarrierWaitOp.getBar(); - auto numThreads = namedBarrierWaitOp.getNumThreads(); - PTXBuilder ptxBuilder; - - auto &ptxInstr = *ptxBuilder.create("bar.sync $0, $1;"); - auto *barOpr = ptxBuilder.newOperand(bar, "r"); - auto *numThreadsOpr = ptxBuilder.newOperand(numThreads, "r"); - ptxInstr({barOpr, numThreadsOpr}, /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class CGABarrierArriveOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - std::string getPtxAsm(ttn::CGABarrierArriveOp op) const { - return "barrier.cluster.arrive;"; - } -}; - -class CGABarrierWaitOpPattern - : public NVGPUOpPatternBase { -public: - using Base = - NVGPUOpPatternBase; - using Base::Base; - std::string getPtxAsm(ttn::CGABarrierWaitOp op) const { - return "barrier.cluster.wait;"; - } -}; - -class StoreDSmemOpPattern : public mlir::RewritePattern { -public: - StoreDSmemOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::StoreDSmemOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto storeDSmemOp = llvm::dyn_cast(op); - if (!storeDSmemOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto addr = storeDSmemOp.getAddr(); - auto ctaId = storeDSmemOp.getCtaId(); - auto values = storeDSmemOp.getValues(); - auto pred = storeDSmemOp.getPred(); - - auto bitwidth = storeDSmemOp.getBitwidth(); - auto vec = storeDSmemOp.getVec(); - assert( - (bitwidth == 8 || bitwidth == 16 || bitwidth == 32 || bitwidth == 64) && - "invalid bitwidth"); - assert((vec == 1 || vec == 2 || vec == 4) && vec == values.size() && - "invalid vec size"); - - PTXBuilder ptxBuilder; - - std::string ptxAsm = "{\n\t" - ".reg .u32 remoteAddr;\n\t" - "mapa.shared::cluster.u32 remoteAddr, $0, $1;\n\t" - ".reg .pred p;\n\t" - "mov.pred p, $2;\n\t" - "@p st.shared::cluster"; - if (vec > 1) - ptxAsm += ".v" + std::to_string(vec); - ptxAsm += ".u" + std::to_string(bitwidth) + " [remoteAddr], "; - if (vec == 1) - ptxAsm += "$3"; - else if (vec == 2) - ptxAsm += "{$3, $4}"; - else if (vec == 4) - ptxAsm += "{$3, $4, $5, $6}"; - ptxAsm += ";\n\t"; - ptxAsm += "}\n"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - std::string c = bitwidth == 16 ? "h" : (bitwidth == 32 ? 
"r" : "l"); - SmallVector oprs; - auto *addrOpr = ptxBuilder.newOperand(addr, "r"); - oprs.push_back(addrOpr); - auto *ctaIdOpr = ptxBuilder.newOperand(ctaId, "r"); - oprs.push_back(ctaIdOpr); - auto *predOpr = ptxBuilder.newOperand(pred, "b"); - oprs.push_back(predOpr); - for (unsigned i = 0; i < values.size(); i++) { - auto *valueOpr = ptxBuilder.newOperand(values[i], c); - oprs.push_back(valueOpr); - } - ptxInstr(oprs, - /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy, /*hasSideEffect*/ true); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class Sts64OpPattern : public mlir::RewritePattern { -public: - Sts64OpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::Sts64Op::getOperationName(), 1, context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto sts64Op = llvm::dyn_cast(op); - if (!sts64Op) - return mlir::failure(); - auto loc = op->getLoc(); - auto offset = sts64Op.getOffset(); - auto d0 = sts64Op.getD0(); - auto d1 = sts64Op.getD1(); - - PTXBuilder ptxBuilder; - - std::string ptxAsm = "st.shared.v2.b32 [$0], {$1, $2}"; - auto &ptxInstr = *ptxBuilder.create(ptxAsm); - - SmallVector oprs; - auto *addrOpr = ptxBuilder.newOperand(offset, "r"); - auto *d0Opr = ptxBuilder.newOperand(d0, "r"); - auto *d1Opr = ptxBuilder.newOperand(d1, "r"); - - ptxInstr({addrOpr, d0Opr, d1Opr}, - /*onlyAttachMLIRArgs=*/true); - - auto asmReturnTy = void_ty(ctx); - ptxBuilder.launch(rewriter, loc, asmReturnTy); - rewriter.eraseOp(op); - return mlir::success(); - } -}; - -class RegAllocOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::RegAllocOp op) const { - auto regCount = op.getRegCount(); - return "setmaxnreg.inc.sync.aligned.u32 " + std::to_string(regCount) + ";"; - } -}; - -class RegDeallocOpPattern - : public NVGPUOpPatternBase { -public: - using Base = NVGPUOpPatternBase; - using Base::Base; - - std::string getPtxAsm(ttn::RegDeallocOp op) const { - auto regCount = op.getRegCount(); - return "setmaxnreg.dec.sync.aligned.u32 " + std::to_string(regCount) + ";"; - } -}; - -class ClusterCTAIdOpPattern : public mlir::RewritePattern { -public: - ClusterCTAIdOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::ClusterCTAIdOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto clusterCTAIdOp = llvm::dyn_cast(op); - if (!clusterCTAIdOp) - return mlir::failure(); - auto loc = op->getLoc(); - - auto x = getSRegValue(rewriter, loc, "%cluster_ctaid.x"); - auto y = getSRegValue(rewriter, loc, "%cluster_ctaid.y"); - auto z = getSRegValue(rewriter, loc, "%cluster_ctaid.z"); - auto nx = getSRegValue(rewriter, loc, "%cluster_nctaid.x"); - auto ny = getSRegValue(rewriter, loc, "%cluster_nctaid.y"); - auto res = add(x, mul(add(y, mul(z, ny)), nx)); - rewriter.replaceOp(op, {res}); - return mlir::success(); - } -}; - -class WGMMADescCreateOpPattern : public mlir::RewritePattern { -public: - WGMMADescCreateOpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::WGMMADescCreateOp::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - 
auto wgmmaDescCreateOp = llvm::dyn_cast(op); - if (!wgmmaDescCreateOp) - return mlir::failure(); - auto loc = op->getLoc(); - auto buffer = wgmmaDescCreateOp.getBuffer(); - auto height = wgmmaDescCreateOp.getHeight(); - uint32_t mode = static_cast(wgmmaDescCreateOp.getMode()); - - auto smem_nvvm_pointer = ptrtoint(i64_ty, buffer); - - Value desc = int_val(64, 0); - uint64_t swizzling = (mode == 1 ? 128 : mode == 2 ? 64 : 32); - Value swizzling_ = int_val(64, swizzling); - Value smem_address_bit = smem_nvvm_pointer; - - Value strideDimension = - lshr(shl(swizzling_, int_val(64, 3)), int_val(64, 4)); - Value height64 = zext(i64_ty, height); - Value leadingDimension = lshr(mul(height64, swizzling_), int_val(64, 4)); - - // Value baseOffset = int_val(64, 0); - Value startAddr = - lshr(shl(smem_address_bit, int_val(64, 46)), int_val(64, 50)); - - Value mode_ = int_val(64, mode); - desc = or_(desc, shl(mode_, int_val(64, 62))); - desc = or_(desc, shl(strideDimension, int_val(64, 32))); - desc = or_(desc, shl(leadingDimension, int_val(64, 16))); - // desc = or_(desc, shl(baseOffset, int_val(64, 49))); - desc = or_(desc, startAddr); - - rewriter.replaceOp(op, {desc}); - return mlir::success(); - } -}; - -class OffsetOfSts64OpPattern : public mlir::RewritePattern { -public: - OffsetOfSts64OpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::OffsetOfSts64Op::getOperationName(), 1, - context) {} - - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto offsetOfSts64Op = llvm::dyn_cast(op); - if (!offsetOfSts64Op) - return mlir::failure(); - auto loc = op->getLoc(); - auto threadId = offsetOfSts64Op.getThreadId(); - auto rowOfWarp = offsetOfSts64Op.getRowOfWarp(); - auto elemIdx = offsetOfSts64Op.getElemIdx(); - auto leadingDimOffset = offsetOfSts64Op.getLeadingDimOffset(); - auto rowStride = offsetOfSts64Op.getRowStride(); - auto swizzleEnabled = offsetOfSts64Op.getSwizzleEnabled(); + std::string getPtxAsm(ttn::OffsetOfSts64Op op) const { + auto rowStride = op.getRowStride(); + auto swizzleEnabled = op.getSwizzleEnabled(); if (swizzleEnabled) { assert((rowStride == 32 || rowStride == 64 || rowStride == 128) && @@ -1048,51 +904,77 @@ public: } else if (rowStride == 32) { perPhase = 4; maxPhase = 2; + } else { + assert(false && "Unsupported rowStride"); } - auto laneId = and_(threadId, i32_val(0x1f)); - auto myRow = - add(mul(and_(lshr(elemIdx, i32_val(1)), i32_val(0x1)), i32_val(8)), - udiv(laneId, i32_val(4))); - auto myCol = add(mul(udiv(elemIdx, i32_val(4)), i32_val(8)), - mul(urem(laneId, i32_val(4)), i32_val(2))); - myRow = add(myRow, rowOfWarp); - auto phase = urem(udiv(myRow, i32_val(perPhase)), i32_val(maxPhase)); - auto lineOffset = - add(mul(urem(myRow, i32_val(perPhase)), i32_val(rowStride)), - mul(myCol, i32_val(4))); - auto colOffset = - add(mul(xor_(udiv(lineOffset, i32_val(16)), phase), i32_val(16)), - urem(lineOffset, i32_val(16))); - auto offset = - add(mul(udiv(myRow, i32_val(perPhase)), i32_val(128)), colOffset); - - rewriter.replaceOp(op, {offset}); - return mlir::success(); + auto ptxAsm = "{\n" + ".reg .u32 a<9>; \n" + "and.b32 a0, $1, 0x1f;\n" // laneid + "shr.b32 a1, $2, 4; \n" + "and.b32 a1, a1, 0x1; \n" + "div.u32 a2, a0, 4; \n" + "mad.lo.u32 a2, a1, 8, a2; \n" // myRow + "div.u32 a3, $2, 4; \n" + "rem.u32 a4, a0, 4; \n" + "mul.lo.u32 a4, a4, 2; \n" + "mad.lo.u32 a4, a3, 8, a4; \n" // myCol + "add.u32 a2, a2, $3; \n" // myRow = myRow + rowOfWarp + "div.u32 a3, 
a2, " + + std::to_string(perPhase) + + "; \n" + "rem.u32 a3, a3, " + + std::to_string(maxPhase) + + "; \n" // phase + "rem.u32 a5, a2, " + + std::to_string(perPhase) + + "; \n" // lineOffset + "mul.lo.u32 a5, a5, #rowStride; \n" + "mad.lo.u32 a5, a4, 4, a5; \n" // lineOffset + "div.u32 a6, a5, 16; \n" + "xor.b32 a6, a6, a3; \n" // colOffset + "rem.u32 a7, a5, 16; \n" + "mad.lo.u32 a7, a6, 16, a7; \n" // colOffset + "div.u32 a8, a2, #perPhase; \n" + "mad.lo.u32 $0, a8, 128, a7; \n" // offset + "}"; + return ptxAsm; } }; -class OffsetOfStmatrixV4OpPattern : public mlir::RewritePattern { +class OffsetOfStmatrixV4OpPattern + : public NVGPUOpPatternBase { public: - OffsetOfStmatrixV4OpPattern(mlir::MLIRContext *context) - : mlir::RewritePattern(ttn::OffsetOfStmatrixV4Op::getOperationName(), 1, - context) {} + using Base = NVGPUOpPatternBase; + using Base::Base; - mlir::LogicalResult - matchAndRewrite(mlir::Operation *op, - mlir::PatternRewriter &rewriter) const override { - auto ctx = rewriter.getContext(); - auto offsetOfStmatrixV4Op = llvm::dyn_cast(op); - if (!offsetOfStmatrixV4Op) - return mlir::failure(); - auto loc = op->getLoc(); - auto threadId = offsetOfStmatrixV4Op.getThreadId(); - auto rowOfWarp = offsetOfStmatrixV4Op.getRowOfWarp(); - auto elemIdx = offsetOfStmatrixV4Op.getElemIdx(); - auto leadingDimOffset = offsetOfStmatrixV4Op.getLeadingDimOffset(); - auto rowStride = offsetOfStmatrixV4Op.getRowStride(); - auto swizzleEnabled = offsetOfStmatrixV4Op.getSwizzleEnabled(); + std::vector + getOutputConstraints(ttn::OffsetOfStmatrixV4Op op) const { + return {"=r"}; + } + OperandsAndConstraints + getOperandsAndConstraints(ttn::OffsetOfStmatrixV4Op op) const { + OperandsAndConstraints operandsAndConstraints; + auto threadId = op.getThreadId(); + auto rowOfWarp = op.getRowOfWarp(); + auto elemIdx = op.getElemIdx(); + + operandsAndConstraints.push_back({threadId, "r"}); + operandsAndConstraints.push_back({elemIdx, "r"}); + operandsAndConstraints.push_back({rowOfWarp, "r"}); + + return operandsAndConstraints; + } + + std::string getPtxAsm(ttn::OffsetOfStmatrixV4Op op) const { + auto leadingDimOffset = op.getLeadingDimOffset(); + auto rowStride = op.getRowStride(); + auto swizzleEnabled = op.getSwizzleEnabled(); + + std::string ptxAsm; if (swizzleEnabled) { uint32_t perPhase = 0; uint32_t maxPhase = 0; @@ -1105,43 +987,71 @@ public: } else if (rowStride == 16) { perPhase = 4; maxPhase = 2; + } else { + assert(false && "Unsupported rowStride"); } - Value iterOfCol = udiv(elemIdx, i32_val(8)); - Value myRow = add(rowOfWarp, and_(threadId, i32_val(0xf))); - Value myCol = - mul(and_(lshr(threadId, i32_val(4)), i32_val(0x1)), i32_val(8)); - myCol = add(myCol, mul(iterOfCol, i32_val(16))); - - Value offset0 = - mul(udiv(myCol, i32_val(rowStride)), i32_val(leadingDimOffset)); - myCol = urem(myCol, i32_val(rowStride)); - - Value phase = urem(udiv(myRow, i32_val(perPhase)), i32_val(maxPhase)); - - Value lineOffset = - add(mul(urem(myRow, i32_val(perPhase)), i32_val(rowStride)), myCol); - Value colOffset = - add(mul(xor_(udiv(lineOffset, i32_val(8)), phase), i32_val(8)), - urem(lineOffset, i32_val(8))); - Value offset1 = - add(mul(udiv(myRow, i32_val(perPhase)), i32_val(64)), colOffset); - - Value res = add(offset1, offset0); - - rewriter.replaceOp(op, {res}); + ptxAsm = + "{\n" + ".reg .u32 a<10>; \n" + "div.u32 a0, $2, 8; \n" // iterOfCol = udiv(elemIdx, i32_val(8)) + "and.b32 a1, $1, 0xf; \n" // myRow = and_(threadId, i32_val(0xf)) + "add.u32 a1, a1, $3; \n" // myRow = myRow + rowOfWarp + "shr.b32 
a2, $1, 4; \n" // myCol = lshr(threadId, i32_val(4)) + "and.b32 a2, a2, 0x1; \n" // myCol = and_(myCol, i32_val(0x1)) + "mul.lo.u32 a2, a2, 8; \n" // myCol = mul(myCol, i32_val(8)) + "mad.lo.u32 a2, a0, 16, a2; \n" // myCol = add(myCol, + // mul(iterOfCol, i32_val(16))) + "div.u32 a3, a2, #rowStride; \n" // offset0 = udiv(myCol, + // i32_val(rowStride)) + "mul.lo.u32 a3, a3, #leadingDimOffset; \n" // offset0 = mul(offset0, + // i32_val(leadingDimOffset)) + "rem.u32 a2, a2, #rowStride; \n" // myCol = myCol % rowStride + "div.u32 a4, a1, " + + std::to_string(perPhase) + + "; \n" // phase = myrow // perPhase + "rem.u32 a4, a4, " + + std::to_string(maxPhase) + + "; \n" // phase = phase % maxPhase + "rem.u32 a5, a1, " + + std::to_string(perPhase) + + "; \n" // lineOffset = urem(myRow, i32_val(perPhase)) + "mad.lo.u32 a5, a5, #rowStride, a2; \n" // lineOffset = + // add(mul(lineOffset, + // rowStride), myCol) + "div.u32 a6, a5, 8; \n" // colOffset = udiv(lineOffset, i32_val(8) + "xor.b32 a6, a6, a4; \n" // colOffset = xor_(colOffset, phase) + "rem.u32 a7, a5, 8; \n" // temp = urem(lineOffset, i32_val(8) + "mad.lo.u32 a7, a6, 8, a7; \n" // colOffset = add(mul(colOffset, + // i32_val(8)), temp) + "div.u32 a8, a1, " + + std::to_string(perPhase) + + "; \n" // offset1 = udiv(myRow, i32_val(perPhase)) + "mad.lo.u32 a9, a8, 64, a7; \n" // offset1 = add(mul(offset1, + // i32_val(64)), colOffset) + "add.u32 $0, a9, a3; \n" // result = add(offset1, offset0) + "}"; } else { - Value iterOfCol = udiv(elemIdx, i32_val(4)); - Value myRow = add(rowOfWarp, and_(threadId, i32_val(0xf))); - Value myCol = - mul(and_(lshr(threadId, i32_val(4)), i32_val(0x1)), i32_val(8)); - myCol = add(myCol, mul(iterOfCol, i32_val(16))); - - Value offset = - add(mul(myRow, i32_val(rowStride)), mul(myCol, i32_val(2))); - rewriter.replaceOp(op, {offset}); + ptxAsm = "{\n" + ".reg .u64 a<5>; \n" + "div.u32 a0, $2, 4; \n" // iterOfCol = udiv(elemIdx, + // i32_val(4)) + "and.b32 a1, $1, 0xf; \n" // myRow = and_(threadId, + // i32_val(0xf)) + "add.u32 a1, a1, $3; \n" // myRow = myRow + rowOfWarp + "shr.b32 a2, $1, 4; \n" // myCol = lshr(threadId, + // i32_val(4)) + "and.b32 a2, a2, 0x1; \n" // myCol = and_(myCol, + // i32_val(0x1)) + "mul.lo.u32 a2, a2, 8; \n" // myCol = mul(myCol, + // i32_val(8)) + "mul.u32 a3, a1, #rowStride; \n" // offset = myRow * rowStride + "mad.lo.u32 $0, a2, 2, a3; \n" // result = add(mul(myCol, + // i32_val(2)), offset) + "}\n"; } - return mlir::success(); + + return ptxAsm; } }; @@ -1155,35 +1065,43 @@ public: ModuleOp mod = getOperation(); RewritePatternSet patterns(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); +#define POPULATE_NVGPU_OP(SRC_OP, ASM) \ + patterns.add>(context, ASM, Constraints(), \ + Constraints()); + POPULATE_NVGPU_OP(ttn::RegAllocOp, Reg_Alloc_Op) + POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op) + POPULATE_NVGPU_OP(ttn::CGABarrierSyncOp, Cga_Barrier_Sync_op) + POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op) + POPULATE_NVGPU_OP(ttn::WGMMAWaitGroupOp, 
Wgmma_Wait_Group_Op) + POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op) + POPULATE_NVGPU_OP(ttn::FenceMBarrierInitOp, Fence_Mbarrier_Init_Op) + POPULATE_NVGPU_OP(ttn::CGABarrierArriveOp, Cga_Barrier_Arrive_Op) + POPULATE_NVGPU_OP(ttn::CGABarrierWaitOp, Cga_Barrier_Wait_Op) + POPULATE_NVGPU_OP(ttn::RegDeallocOp, Reg_Dealloc_Op) +#undef POPULATE_NVGPU_OP + patterns.add>( + context, Mbarrier_Init_Op, Constraints(), Constraints({"r", "b"})); + patterns.add>( + context, Mbarrier_Wait_Op, Constraints(), Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Arrive_Op, Constraints(), + Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Wait_Op, Constraints(), Constraints({"r", "r"})); + patterns.add>( + context, Sts64_Op, Constraints(), Constraints({"r", "r", "r"})); + patterns.add>( + context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints()); + patterns.add>( + context, Wgmma_Desc_Create_op, Constraints({"=l"}), + Constraints({"l", "l"})); + + patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); - patterns.add(context); if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); } diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index a1d069fbb..68c8dfa71 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -1,5 +1,14 @@ +<<<<<<< HEAD add_library(rocm_libraries SHARED IMPORTED ) set_target_properties(rocm_libraries PROPERTIES IMPORTED_LOCATION ${ROCM_LIBRARIES}) +======= +# Separate out PTX/GCN builders to avoid cyclic dependencies as TritonAnalysis +# depends on it. +set(LLVM_OPTIONAL_SOURCES + GCNAsmFormat.cpp + PTXAsmFormat.cpp + ) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 add_mlir_conversion_library(TritonGPUToLLVM ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -15,6 +24,7 @@ add_mlir_conversion_library(TritonGPUToLLVM LoadStoreOpToLLVM.cpp BarrierOpToLLVM.cpp TritonGPUToLLVM.cpp +<<<<<<< HEAD GCNAsmFormat.cpp PTXAsmFormat.cpp ConvertLayoutOpToLLVM/SharedToDotOperandFMA.cpp @@ -30,9 +40,9 @@ add_mlir_conversion_library(TritonGPUToLLVM ElementwiseOpToLLVM.cpp LoadStoreOpToLLVM.cpp TritonGPUToLLVM.cpp +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 TritonGPUToLLVMPass.cpp - GCNAsmFormat.cpp - PTXAsmFormat.cpp ReduceOpToLLVM.cpp ScanOpToLLVM.cpp TypeConverter.cpp @@ -53,6 +63,7 @@ add_mlir_conversion_library(TritonGPUToLLVM Core LINK_LIBS PUBLIC + ASMBuilder MLIRIR MLIRPass MLIRGPUOps @@ -67,3 +78,15 @@ add_mlir_conversion_library(TritonGPUToLLVM NVGPUIR rocm_libraries ) + +add_mlir_library(ASMBuilder + GCNAsmFormat.cpp + PTXAsmFormat.cpp + + DEPENDS + TritonTableGen + + LINK_LIBS PUBLIC + MLIRAnalysis + MLIRLLVMDialect +) diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 04a6455a4..de08b7753 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -91,8 +91,7 @@ public: // forwarding on mma->mma shortcut, lower distributed->distributed otherwise if (srcLayout.isa() && dstLayout.isa()) { if (isMmaToMmaShortcut(srcTy, dstTy)) { - rewriter.replaceOp(op, op.getSrc()); - return success(); + return lowerMmaToMma(op, adaptor, rewriter); } } if (isaDistributedLayout(srcLayout) && isaDistributedLayout(dstLayout)) { @@ -195,9 +194,11 @@ 
private: Value _16 = i32_val(16); if (mmaLayout.isAmpere() || mmaLayout.isHopper()) { multiDimWarpId[0] = - urem(multiDimWarpId[0], i32_val(shapePerCTA[0] / instrShape[0])); + urem(multiDimWarpId[0], + i32_val(ceil(shapePerCTA[0], instrShape[0]))); multiDimWarpId[1] = - urem(multiDimWarpId[1], i32_val(shapePerCTA[1] / instrShape[1])); + urem(multiDimWarpId[1], + i32_val(ceil(shapePerCTA[1], instrShape[1]))); Value mmaGrpId = udiv(laneId, _4); Value mmaGrpIdP8 = add(mmaGrpId, _8); @@ -267,12 +268,30 @@ private: llvm_unreachable("unexpected layout in getMultiDimOffset"); } + SmallVector + getWrappedMultiDimOffset(ConversionPatternRewriter &rewriter, Location loc, + ArrayRef multiDimOffset, + ArrayRef shape, + SmallVector shapePerCTATile, + SmallVector shapePerCTA) const { + unsigned rank = shape.size(); + SmallVector multiDimOffsetWrapped(rank); + for (unsigned d = 0; d < rank; ++d) { + if (shapePerCTATile[d] > shapePerCTA[d]) + multiDimOffsetWrapped[d] = urem(multiDimOffset[d], i32_val(shape[d])); + else + multiDimOffsetWrapped[d] = multiDimOffset[d]; + } + return multiDimOffsetWrapped; + } + // shared memory rd/st for blocked or mma layout with data padding void processReplica(Location loc, ConversionPatternRewriter &rewriter, bool stNotRd, RankedTensorType type, ArrayRef numCTAsEachRep, ArrayRef multiDimRepId, unsigned vec, ArrayRef paddedRepShape, + ArrayRef origRepShape, ArrayRef outOrd, SmallVector &vals, Value smemBase) const { auto accumNumCTAsEachRep = product(numCTAsEachRep); @@ -316,8 +335,11 @@ private: SmallVector multiDimOffset = getMultiDimOffset(layout, loc, rewriter, elemId, type, multiDimCTAInRepId, shapePerCTATile); - Value offset = - linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd); + SmallVector multiDimOffsetWrapped = getWrappedMultiDimOffset( + rewriter, loc, multiDimOffset, origRepShape, shapePerCTATile, + shapePerCTA); + Value offset = linearize(rewriter, loc, multiDimOffsetWrapped, + paddedRepShape, outOrd); auto elemPtrTy = ptr_ty(llvmElemTy, 3); Value ptr = gep(elemPtrTy, smemBase, offset); auto vecTy = vec_ty(llvmElemTy, vec); @@ -376,7 +398,7 @@ private: SmallVector numCTAsEachRep(rank, 1); SmallVector shapePerCTATile = getShapePerCTATile(layout, shape); SmallVector shapePerCTA = getShapePerCTA(layout, shape); - auto elemTy = type.getElementType(); + auto elemTy = getTypeConverter()->convertType(type.getElementType()); int ctaId = 0; @@ -605,6 +627,7 @@ private: rewriter, srcTy); unsigned inVec = 0; unsigned outVec = 0; + auto origRepShape = getRepShapeForCvtLayout(op); auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec); if (getElementTypeOrSelf(op.getType()) .isa()) { @@ -651,7 +674,7 @@ private: else processReplica(loc, rewriter, /*stNotRd*/ true, srcTy, inNumCTAsEachRep, multiDimRepId, inVec, paddedRepShape, - outOrd, vals, smemBase); + origRepShape, outOrd, vals, smemBase); } else { assert(0 && "ConvertLayout with input layout not implemented"); return failure(); @@ -687,7 +710,8 @@ private: else processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, - paddedRepShape, outOrd, outVals, smemBase); + paddedRepShape, origRepShape, outOrd, outVals, + smemBase); } else { assert(0 && "ConvertLayout with output layout not implemented"); return failure(); @@ -962,6 +986,11 @@ private: auto loc = op.getLoc(); auto srcTy = op.getSrc().getType().cast(); auto dstTy = op.getResult().getType().cast(); + if (matchMmaV3AndDotOperandLayout(srcTy, dstTy)) { + rewriter.replaceOp(op, op.getSrc()); + 
return success(); + } + if (isMmaToDotShortcut(srcTy, dstTy)) { // get source values auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), @@ -1030,6 +1059,7 @@ private: return failure(); } +<<<<<<< HEAD #ifdef USE_ROCM // shared -> dot_operand if the result layout is mma Value lowerSharedToDotOperandMFMA( @@ -1054,6 +1084,44 @@ private: return res; } #endif +======= + // mma -> mma + LogicalResult lowerMmaToMma(triton::gpu::ConvertLayoutOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); + auto srcTy = op.getSrc().getType().cast(); + auto dstTy = op.getResult().getType().cast(); + if (triton::gpu::getTotalElemsPerThread(srcTy) == + triton::gpu::getTotalElemsPerThread(dstTy)) { + rewriter.replaceOp(op, op.getSrc()); + return success(); + } + // get source values + auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(), + rewriter, srcTy); + SmallVector retVals; + SmallVector dstElementPerThread = + triton::gpu::getElemsPerThread(dstTy); + SmallVector srcElementPerThread = + triton::gpu::getElemsPerThread(srcTy); + for (unsigned j = 0; j < dstElementPerThread[0]; j++) { + for (unsigned i = 0; i < dstElementPerThread[1]; i++) { + if (i >= srcElementPerThread[1] || j >= srcElementPerThread[0]) { + retVals.push_back(undef(vals[0].getType())); + continue; + } + unsigned index = i + j * srcElementPerThread[1]; + retVals.push_back(vals[index]); + } + } + assert(retVals.size() == triton::gpu::getTotalElemsPerThread(dstTy)); + Value view = + getTypeConverter()->packLLElements(loc, retVals, rewriter, dstTy); + rewriter.replaceOp(op, view); + return success(); + } +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 // shared -> dot_operand if the result layout is mma Value lowerSharedToDotOperandMMA( diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp index bd222f9e6..870c8980b 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/MMAv2.cpp @@ -145,6 +145,9 @@ inline static const std::map mmaInstrPtxTuring = { {TensorCoreType::FP32_FP16_FP16_FP32, "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"}, + {TensorCoreType::INT32_INT8_INT8_INT32, + "mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32"}, + {TensorCoreType::FP16_FP16_FP16_FP16, "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16"}, }; @@ -168,6 +171,107 @@ inline static const std::map mmaInstrPtxAmpere = { "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16"}, }; +static void callMmaTuringInt8(PTXBuilder &builder, unsigned m, unsigned n, + unsigned k, mlir::triton::PTXInstr &mma, + unsigned numMmaRets, unsigned colsPerThread, + int numCPackedElem, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc) { + auto retArgs1 = builder.newListOperand(numMmaRets / 2, "=r"); + auto retArgs2 = builder.newListOperand(numMmaRets / 2, "=r"); + auto cArgs1 = builder.newListOperand(); + for (int i = 0; i < numMmaRets / 2; ++i) { + cArgs1->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto cArgs2 = builder.newListOperand(); + for (int i = numMmaRets / 2; i < numMmaRets; ++i) { + cArgs2->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs1 = builder.newListOperand({ + {ha[{m, k}], "r"}, + }); + auto bArgs1 = builder.newListOperand({ + 
{hb[{n, k}], "r"}, + }); + auto aArgs2 = builder.newListOperand({ + {ha[{m, k + 1}], "r"}, + }); + auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); + auto aArgs3 = builder.newListOperand({ + {ha[{m + 1, k}], "r"}, + }); + auto bArgs3 = builder.newListOperand({ + {hb[{n, k}], "r"}, + }); + auto aArgs4 = builder.newListOperand({ + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs4 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); + mma(retArgs1, aArgs1, bArgs1, cArgs1); + mma(retArgs1, aArgs2, bArgs2, cArgs1); + mma(retArgs2, aArgs3, bArgs3, cArgs2); + mma(retArgs2, aArgs4, bArgs4, cArgs2); +} + +static void callMmaTuringFp16(PTXBuilder &builder, unsigned m, unsigned n, + unsigned k, mlir::triton::PTXInstr &mma, + unsigned numMmaRets, unsigned colsPerThread, + int numCPackedElem, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc, + bool isAccF16) { + auto retArgs = builder.newListOperand(numMmaRets, isAccF16 ? "=r" : "=f"); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < numMmaRets; ++i) { + cArgs->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs1 = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + }); + auto bArgs1 = builder.newListOperand({{hb[{n, k}], "r"}}); + auto aArgs2 = builder.newListOperand({ + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); + mma(retArgs, aArgs1, bArgs1, cArgs); + mma(retArgs, aArgs2, bArgs2, cArgs); +} + +static void callMmaAmpere(PTXBuilder &builder, unsigned m, unsigned n, + unsigned k, mlir::triton::PTXInstr &mma, + unsigned numMmaRets, unsigned colsPerThread, + int numCPackedElem, ValueTableV2 &ha, + ValueTableV2 &hb, const SmallVector &fc, + bool isAccF16, bool isIntMMA) { + auto retArgs = + builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? "=r" : "=f"); + auto cArgs = builder.newListOperand(); + for (int i = 0; i < numMmaRets; ++i) { + cArgs->listAppend( + builder.newOperand(fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], + std::to_string(i))); + // reuse the output registers + } + auto aArgs = builder.newListOperand({ + {ha[{m, k}], "r"}, + {ha[{m + 1, k}], "r"}, + {ha[{m, k + 1}], "r"}, + {ha[{m + 1, k + 1}], "r"}, + }); + auto bArgs = + builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); + mma(retArgs, aArgs, bArgs, cArgs); +} + LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Value a, Value b, Value c, Value d, Value loadedA, @@ -215,42 +319,19 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, // using =r for float32 works but leads to less readable ptx. bool isIntMMA = dTensorTy.getElementType().isInteger(32); bool isAccF16 = dTensorTy.getElementType().isF16(); - auto retArgs = - builder.newListOperand(numMmaRets, isIntMMA || isAccF16 ? 
"=r" : "=f"); - auto cArgs = builder.newListOperand(); - for (int i = 0; i < numMmaRets; ++i) { - cArgs->listAppend(builder.newOperand( - fc[(m * colsPerThread + 4 * n) / numCPackedElem + i], - std::to_string(i))); - // reuse the output registers - } if (isTuring) { - auto aArgs1 = builder.newListOperand({ - {ha[{m, k}], "r"}, - {ha[{m + 1, k}], "r"}, - }); - auto bArgs1 = builder.newListOperand({ - {hb[{n, k}], "r"}, - }); - auto aArgs2 = builder.newListOperand({ - {ha[{m, k + 1}], "r"}, - {ha[{m + 1, k + 1}], "r"}, - }); - auto bArgs2 = builder.newListOperand({{hb[{n, k + 1}], "r"}}); - mma(retArgs, aArgs1, bArgs1, cArgs); - mma(retArgs, aArgs2, bArgs2, cArgs); - } else { - auto aArgs = builder.newListOperand({ - {ha[{m, k}], "r"}, - {ha[{m + 1, k}], "r"}, - {ha[{m, k + 1}], "r"}, - {ha[{m + 1, k + 1}], "r"}, - }); - auto bArgs = - builder.newListOperand({{hb[{n, k}], "r"}, {hb[{n, k + 1}], "r"}}); - mma(retArgs, aArgs, bArgs, cArgs); + if (isIntMMA) // Turing int8 + callMmaTuringInt8(builder, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc); + else // Turing fp16 + callMmaTuringFp16(builder, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc, isAccF16); + } else { // Ampere + callMmaAmpere(builder, m, n, k, mma, numMmaRets, colsPerThread, + numCPackedElem, ha, hb, fc, isAccF16, isIntMMA); } + Value mmaOut = builder.launch(rewriter, loc, getMmaRetType(mmaType, op.getContext())); diff --git a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp index fbccdeefc..b16aee5d8 100644 --- a/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DotOpToLLVM/WGMMA.cpp @@ -21,8 +21,8 @@ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ -#include "DotOpToLLVM.h" -#include "Utility.h" +#include "../DotOpToLLVM.h" +#include "../Utility.h" using namespace mlir; using namespace mlir::triton; @@ -95,13 +95,13 @@ getModeFromLayout(const SharedEncodingAttr &layout, uint32_t widthInByte) { class DotOpMmaV3SmemLoader { public: - DotOpMmaV3SmemLoader(Value tensor, const SharedMemoryObject &smemObj, - SmallVector shape, Value warpId, - unsigned int dimWpt, bool trans, + DotOpMmaV3SmemLoader() {} + DotOpMmaV3SmemLoader(Value tensor, Value base, SmallVector shape, + Value warpId, unsigned int dimWpt, bool trans, SmallVector instrShape, ConversionPatternRewriter &rewriter, Location loc) - : base(smemObj.base), shape(shape), warpId(warpId), dimWpt(dimWpt), - trans(trans), instrShape(instrShape), rewriter(rewriter), loc(loc) { + : base(base), shape(shape), warpId(warpId), dimWpt(dimWpt), trans(trans), + instrShape(instrShape) { auto tensorTy = tensor.getType().cast(); auto sharedLayout = tensorTy.getEncoding().cast(); ord = sharedLayout.getOrder(); @@ -115,10 +115,11 @@ public: mode = getModeFromLayout(sharedLayout, widthInByte); baseDesc = rewriter.create( - loc, i64_ty, base, i32_val(shape[ord[1]]), mode); + loc, base, i32_val(shape[ord[1]]), mode); } - Value smemLoad(int a, int b) { + Value smemLoad(int a, int b, ConversionPatternRewriter &rewriter, + Location loc) { Value k = i32_val(b * instrShape[1]); Value m = add(i32_val(a * dimWpt * instrShape[0]), mul(warpId, i32_val(instrShape[0]))); @@ -146,8 +147,6 @@ private: mlir::triton::nvgpu::WGMMADescMode mode; SmallVector instrShape; ArrayRef ord; - ConversionPatternRewriter &rewriter; - Location loc; int elemsPerSwizzlingRow; int elemBytes; Value baseDesc; @@ -156,7 +155,7 @@ private: DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, const MmaEncodingAttr &mmaEncoding, Value tensor, - const SharedMemoryObject &smemObj, Value thread) { + Value smemObjBase, Value thread) { auto aTensorTy = tensor.getType().cast(); auto aSharedLayout = aTensorTy.getEncoding().dyn_cast(); assert(aSharedLayout && "only support load dot operand from shared."); @@ -174,7 +173,7 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, Value warpId = urem(warpM, i32_val(shapePerCTA[0] / instrShape[0])); return {tensor, - smemObj, + smemObjBase, shapePerCTA, warpId, wpt[0], @@ -187,7 +186,7 @@ DotOpMmaV3SmemLoader loadA(TritonGPUToLLVMTypeConverter *typeConverter, DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, MmaEncodingAttr &mmaEncoding, Value tensor, - const SharedMemoryObject &smemObj, Value thread) { + Value base, Value thread) { auto bTensorTy = tensor.getType().cast(); auto bSharedLayout = bTensorTy.getEncoding().cast(); assert(bSharedLayout && "only support load B from shared."); @@ -206,7 +205,7 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, Value warpId = urem(warpN, i32_val(shapePerCTA[1] / instrShape[1])); return {tensor, - smemObj, + base, shapePerCTA, warpId, wpt[1], @@ -218,9 +217,13 @@ DotOpMmaV3SmemLoader loadB(TritonGPUToLLVMTypeConverter *typeConverter, // Return a vector of Value of the accumulator start at startIndex and pack the // values into 32bits in case the accumulator is fp16. 
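// Editorial aside (not part of the patch): a minimal, standalone C++ sketch of the
// packing that the comment above describes. When the accumulator elements are
// 16-bit (fp16), consecutive pairs are fused into one 32-bit register before being
// handed to wgmma; wider elements are passed through one per register. The pair
// ordering (which element lands in the low half) and the helper name packAccRegs
// are assumptions of this sketch, not taken from the patch.
#include <cstdint>
#include <vector>

std::vector<uint32_t> packAccRegs(const std::vector<uint16_t> &elements,
                                  int startIndex, int numElements) {
  std::vector<uint32_t> packed;
  for (int i = 0; i < numElements; i += 2) {
    uint32_t lo = elements[startIndex + i];     // assumed: first element -> bits 0-15
    uint32_t hi = elements[startIndex + i + 1]; // assumed: second element -> bits 16-31
    packed.push_back(lo | (hi << 16));
  }
  return packed;
}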
-llvm::SmallVector loadC(ConversionPatternRewriter &rewriter, - Location loc, const SmallVector &elements, - int startIndex, int numElements) { +llvm::SmallVector loadReg(ConversionPatternRewriter &rewriter, + Location loc, + const SmallVector &elements, + int startIndex, int numElements, + Operation *insertBefore) { + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPoint(insertBefore); if (!elements[0].getType().isF16()) { llvm::SmallVector mmaOut(numElements); for (int i = 0; i < numElements; ++i) @@ -260,22 +263,59 @@ SmallVector unpackAccumulator(ConversionPatternRewriter &rewriter, return results; } +static bool isFP8(triton::nvgpu::WGMMAEltType eltType) { + return eltType == triton::nvgpu::WGMMAEltType::e5m2 || + eltType == triton::nvgpu::WGMMAEltType::e4m3; +} + +static Value faddAccumulate(ConversionPatternRewriter &rewriter, Location loc, + Value a, Value b) { + int numEl = a.getType().cast().getBody().size(); + Value newStruct = rewriter.create(loc, a.getType()); + for (int i = 0; i < numEl; ++i) { + Value lhs = rewriter.create(loc, a, i); + Value rhs = rewriter.create(loc, b, i); + Value add = rewriter.create(loc, lhs, rhs); + newStruct = rewriter.create(loc, newStruct, add, i); + } + return newStruct; +} + +static bool isZero(Value v) { + auto constantOp = v.getDefiningOp(); + if (!constantOp) + return false; + if (auto denseAttr = dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + if (auto denseAttr = + dyn_cast(constantOp.getValueAttr())) + return denseAttr.isSplat() && denseAttr.getSplatValue().isZero(); + return false; +} + LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Location loc, Operation *op, Value a, Value b, Value c, Value d, Value loadedA, Value loadedB, Value loadedC, - bool allowTF32, const SharedMemoryObject &smemObjA, - const SharedMemoryObject &smemObjB, bool sync, + bool allowTF32, uint32_t maxNumImpreciseAcc, bool sync, Value thread) { auto aTensorTy = a.getType().cast(); auto bTensorTy = b.getType().cast(); auto dTensorTy = d.getType().cast(); - auto aSharedLayout = aTensorTy.getEncoding().cast(); + auto aSharedLayout = aTensorTy.getEncoding().dyn_cast(); auto bSharedLayout = bTensorTy.getEncoding().cast(); auto mmaEncoding = dTensorTy.getEncoding().cast(); - auto aOrd = aSharedLayout.getOrder(); auto bOrd = bSharedLayout.getOrder(); - bool transA = aOrd[0] == 0; + bool transA = false; + Value baseA; + Value baseB; + if (aSharedLayout) + baseA = getSharedMemoryObjectFromStruct(loc, loadedA, rewriter).base; + baseB = getSharedMemoryObjectFromStruct(loc, loadedB, rewriter).base; + if (aSharedLayout) { + auto aOrd = aSharedLayout.getOrder(); + transA = aOrd[0] == 0; + } bool transB = bOrd[0] == 1; auto dShapePerCTA = getShapePerCTA(dTensorTy); auto instrShape = mmaEncoding.getInstrShape(); @@ -283,16 +323,22 @@ LogicalResult convertDot(TritonGPUToLLVMTypeConverter *typeConverter, int M = 4 * instrShape[0]; int N = instrShape[1]; int K = instrShape[2]; - + bool zeroAcc = isZero(c); auto shapePerCTATile = getShapePerCTATile(mmaEncoding); int numRepM = ceil(dShapePerCTA[0], shapePerCTATile[0]); int numRepN = ceil(dShapePerCTA[1], shapePerCTATile[1]); int numRepK = ceil(aTensorTy.getShape()[1], instrShape[2]); - - DotOpMmaV3SmemLoader aLoader = - loadA(typeConverter, rewriter, loc, mmaEncoding, a, smemObjA, thread); + DotOpMmaV3SmemLoader aLoader; + SmallVector structA; + if (aSharedLayout) { + aLoader = + loadA(typeConverter, rewriter, 
loc, mmaEncoding, a, baseA, thread); + } else { + structA = + typeConverter->unpackLLElements(loc, loadedA, rewriter, aTensorTy); + } DotOpMmaV3SmemLoader bLoader = - loadB(typeConverter, rewriter, loc, mmaEncoding, b, smemObjB, thread); + loadB(typeConverter, rewriter, loc, mmaEncoding, b, baseB, thread); auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy); @@ -308,28 +354,69 @@ auto func = op->getParentOfType(); int numTMADescs = func->getAttr(kAttrNumTMALoadDescsName).cast().getInt(); + Operation *startSequence = nullptr; if (numTMADescs == 0) - rewriter.create(loc, 0); - rewriter.create(loc); - + startSequence = rewriter.create(loc, 0); + Operation *fenceOp = rewriter.create(loc); + if (startSequence == nullptr) + startSequence = fenceOp; + // WGMMA fp8 -> fp32 accumulates in lower precision than fp32. + bool needsPartialAccumulator = isFP8(eltTypeA) && + eltTypeC == triton::nvgpu::WGMMAEltType::f32 && + maxNumImpreciseAcc <= aTensorTy.getShape()[1]; SmallVector mmaResults; for (int m = 0; m < numRepM; ++m) { for (int n = 0; n < numRepN; ++n) { llvm::SmallVector mmaOut = - loadC(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize); + loadReg(rewriter, loc, fc, (m * numRepN + n) * accSize, accSize, + startSequence); llvm::SmallVector elemTypes; for (Value accEl : mmaOut) elemTypes.push_back(accEl.getType()); auto accTy = LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes); - Value d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy); + Value d; + if (!zeroAcc) + d = typeConverter->packLLElements(loc, mmaOut, rewriter, accTy); + uint32_t numLowPrecisionAcc = 0; + Value partialAcc; for (int k = 0; k < numRepK; ++k) { - auto a = aLoader.smemLoad(m, k); - auto b = bLoader.smemLoad(n, k); + Value a; + if (aSharedLayout) { + a = aLoader.smemLoad(m, k, rewriter, loc); + } else { + unsigned regASize = (instrShape[0] * instrShape[2]) / 32; + llvm::SmallVector regA = + loadReg(rewriter, loc, structA, (m * numRepK + k) * regASize, + regASize, startSequence); + auto regATy = LLVM::LLVMStructType::getLiteral( + rewriter.getContext(), + SmallVector(regA.size(), regA[0].getType())); + a = typeConverter->packLLElements(loc, regA, rewriter, regATy); + } + auto b = bLoader.smemLoad(n, k, rewriter, loc); ValueRange operands{a, b, d}; - d = rewriter.create(loc, accTy, a, b, d, M, N, - K, eltTypeC, eltTypeA, - eltTypeB, layoutA, layoutB); + numLowPrecisionAcc += K; + // If using native accumulation would cause us to do more low-precision + // accumulation than allowed, allocate a separate accumulator. + bool requireAddAccumulator = + needsPartialAccumulator && + (numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1); + Value mmaAcc = needsPartialAccumulator ? partialAcc : d; + mmaAcc = rewriter.create( + loc, accTy, a, b, mmaAcc, M, N, K, eltTypeC, eltTypeA, eltTypeB, + layoutA, layoutB); + if (needsPartialAccumulator) + partialAcc = mmaAcc; + else + d = mmaAcc; + // If we need to accumulate separately to have higher precision, insert + // adds. + if (requireAddAccumulator) { + d = faddAccumulate(rewriter, loc, d, partialAcc); + numLowPrecisionAcc = 0; + partialAcc = Value(); + } } auto acc = typeConverter->unpackLLElements(loc, d, rewriter, accTy); for (int i = 0; i < acc.size(); ++i) { @@ -354,28 +441,6 @@ return success(); } -// Loading $c to registers, returns a Value. 
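// Editorial aside (not part of the patch): the control flow behind
// needsPartialAccumulator above, reduced to plain C++ so the flush condition is
// easy to follow. Every wgmma step adds K low-precision accumulations; once the
// running count reaches maxNumImpreciseAcc, or the last K-step is reached, the
// partial accumulator is folded into the fp32 result with a full-precision add
// and both the counter and the partial accumulator are reset. The concrete
// numbers below are made up for illustration.
#include <cstdio>

int main() {
  const unsigned K = 32, numRepK = 8, maxNumImpreciseAcc = 64;
  float d = 0.0f;          // full-precision (fp32) result accumulator
  float partialAcc = 0.0f; // low-precision partial accumulator, modelled as a float
  unsigned numLowPrecisionAcc = 0;
  for (unsigned k = 0; k < numRepK; ++k) {
    partialAcc += 1.0f;    // stands in for one wgmma issued into partialAcc
    numLowPrecisionAcc += K;
    bool requireAddAccumulator =
        numLowPrecisionAcc >= maxNumImpreciseAcc || k == numRepK - 1;
    if (requireAddAccumulator) {  // the faddAccumulate step in the patch
      d += partialAcc;
      numLowPrecisionAcc = 0;
      partialAcc = 0.0f;
      std::printf("flush after k=%u, d=%.1f\n", k, d);
    }
  }
  return 0;
}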
-Value loadC(Value tensor, Value llTensor) { - auto tensorTy = tensor.getType().cast(); - auto mmaEncoding = tensorTy.getEncoding().dyn_cast(); - assert(mmaEncoding && "Currently, we only support $c with a mma layout."); - auto instrShape = mmaEncoding.getInstrShape(); - auto wpt = mmaEncoding.getWarpsPerCTA(); - auto shapePerCTA = getShapePerCTA(tensorTy); - auto shapePerCTATile = getShapePerCTATile(mmaEncoding); - - int numRepM = ceil(shapePerCTA[0], shapePerCTATile[0]); - int numRepN = ceil(shapePerCTA[1], shapePerCTATile[1]); - - size_t fcSize = 2 * (instrShape[1] / 4) * numRepM * numRepN; - - auto structTy = llTensor.getType().cast(); - assert(structTy.getBody().size() == fcSize && - "DotOp's $c operand should pass the same number of values as $d in " - "mma layout."); - return llTensor; -} - LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, TritonGPUToLLVMTypeConverter *typeConverter, ConversionPatternRewriter &rewriter, Value thread) { @@ -386,20 +451,19 @@ LogicalResult convertWGMMA(triton::DotOp op, triton::DotOp::Adaptor adaptor, auto ATensorTy = A.getType().cast(); auto BTensorTy = B.getType().cast(); - assert(ATensorTy.getEncoding().isa() && - BTensorTy.getEncoding().isa() && - "Both $a and %b should be Shared layout."); + assert(ATensorTy.getEncoding().isa() || + ATensorTy.getEncoding().isa()); + assert(BTensorTy.getEncoding().isa() && + "Operand B should use Shared layout."); Value llA, llB, llC; llA = adaptor.getA(); llB = adaptor.getB(); - llC = loadC(C, adaptor.getC()); + llC = adaptor.getC(); - auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); - auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, - op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA, - smemObjB, true, thread); + op.getD(), llA, llB, llC, op.getAllowTF32(), + op.getMaxNumImpreciseAcc(), true, thread); } LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, @@ -414,18 +478,17 @@ LogicalResult convertAsyncWGMMA(triton::nvidia_gpu::DotAsyncOp op, auto ATensorTy = A.getType().cast(); auto BTensorTy = B.getType().cast(); - assert(ATensorTy.getEncoding().isa() && - BTensorTy.getEncoding().isa() && - "Both $a and %b should be Shared layout."); + assert(ATensorTy.getEncoding().isa() || + ATensorTy.getEncoding().isa()); + assert(BTensorTy.getEncoding().isa() && + "Operand B should use Shared layout."); Value llA, llB, llC; llA = adaptor.getA(); llB = adaptor.getB(); - llC = loadC(C, adaptor.getC()); + llC = adaptor.getC(); - auto smemObjA = getSharedMemoryObjectFromStruct(loc, llA, rewriter); - auto smemObjB = getSharedMemoryObjectFromStruct(loc, llB, rewriter); return convertDot(typeConverter, rewriter, loc, op.getOperation(), A, B, C, - op.getD(), llA, llB, llC, op.getAllowTF32(), smemObjA, - smemObjB, false, thread); + op.getD(), llA, llB, llC, op.getAllowTF32(), + op.getMaxNumImpreciseAcc(), false, thread); } diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp index 46aa1d82e..b84e0a5ca 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -471,8 +471,8 @@ const std::string Fp16_to_Fp8E4M3B15(bool has_minx2) { ".reg .b16 c<4>; \n" ".reg .b16 max_val_f16; \n" ".reg .b32 max_val_f16x2; \n" - "mov.b16 max_val_f16, 0x3F80; \n" - "mov.b32 max_val_f16x2, 0x3F803F80; \n" + "mov.b16 max_val_f16, 0x3F00; \n" + "mov.b32 
max_val_f16x2, 0x3F003F00; \n" "and.b32 a0, $1, 0x7fff7fff; \n" "and.b32 a1, $2, 0x7fff7fff; \n"; if (has_minx2) @@ -912,6 +912,30 @@ const std::string Fp16_to_Fp8E4M3Nv = "{ \n" "cvt.rn.satfinite.e4m3x2.f16x2 $0, $1; \n" "}"; +// Fp8E4M3 (x2) -> Fp16 (x2) (packed) +const std::string Fp8E4M3Nv_to_Bf16 = + "{ \n" + ".reg .b32 a; \n" + ".reg .f16 a<2>; \n" + ".reg .b16 b<2>; \n" + "cvt.rn.f16x2.e4m3x2 a, $1; \n" + "mov.b32 {a0, a1}, a; \n" + "cvt.bf16.f16 b0, a0; \n" + "cvt.bf16.f16 b1, a1; \n" + "mov.b32 $0, {b0, b1}; \n" + "}"; + +// Bf16 (x2) -> Fp8E4M3 (x2) (packed) +const std::string Bf16_to_Fp8E4M3Nv = + "{ \n" + ".reg .b16 a<2>; \n" + ".reg .f32 b<2>; \n" + "mov.b32 {a0, a1}, $1; \n" + "cvt.f32.bf16 b0, a0; \n" + "cvt.f32.bf16 b1, a1; \n" + "cvt.rn.satfinite.e4m3x2.f32 $0, b0, b1; \n" + "}"; + /* ----- Packed integer to BF16 ------ */ #ifndef USE_ROCM const std::string S8_to_Bf16 = @@ -1374,8 +1398,10 @@ struct FpToFpOpConversion {{F16TyID, F8E5M2FNUZTyID}, Fp16_to_Fp8E5M2FNUZ}, // F8 -> BF16 {{F8E5M2TyID, BF16TyID}, Fp8E5M2_to_Bf16}, + {{F8E4M3TyID, BF16TyID}, Fp8E4M3Nv_to_Bf16}, // BF16 -> F8 {{BF16TyID, F8E5M2TyID}, Bf16_to_Fp8E5M2}, + {{BF16TyID, F8E4M3TyID}, Bf16_to_Fp8E4M3Nv}, }; int inVecWidthBits = 32; @@ -1418,26 +1444,28 @@ struct FpToFpOpConversion Location loc) const { auto srcElementType = getElementType(op.getFrom()); auto dstElementType = getElementType(op.getResult()); - int numElements = 4; + + size_t numElements = 4; if (srcElementType.isFloat8E4M3FNUZ() || dstElementType.isFloat8E4M3FNUZ()) { numElements = 2; } - assert(operands.size() % numElements == 0 && - "FP8 casting only support tensors with aligned sizes"); bool isSrcFP32 = srcElementType.isF32(); bool isDstFP32 = dstElementType.isF32(); auto cvtFunc = getConversionFunc(isSrcFP32 ? f16_ty : srcElementType, isDstFP32 ? 
f16_ty : dstElementType); SmallVector inVals; - for (unsigned i = 0; i < numElements; i++) { + for (unsigned i = 0; i < std::min(numElements, operands.size()); i++) { inVals.push_back(operands[i][0]); } if (isSrcFP32) for (Value &v : inVals) v = convertFp32ToFp16(loc, rewriter, v); + inVals.resize(numElements, + undef(typeConverter->convertType(srcElementType))); SmallVector outVals = cvtFunc(loc, rewriter, inVals); assert(outVals.size() == inVals.size()); + outVals.resize(std::min(numElements, operands.size())); if (isDstFP32) for (Value &v : outVals) v = convertFp16ToFp32(loc, rewriter, v); @@ -2141,161 +2169,3 @@ void populateElementwiseOpToLLVMPatterns( // __nv_expf for higher-precision calculation patterns.add(typeConverter, benefit); } - -struct FPExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::FPExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isF32() && srcTy.isF16()) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::FPExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - return { - FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0][0])}; - } -}; - -struct FPTruncOpConversion - : ElementwiseOpConversionBase { - using Base = - ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::FPTruncOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isF16() && srcTy.isF32()) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::FPTruncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - return { - FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0][0])}; - } -}; - -struct TruncOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::TruncOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(16) && srcTy.isInteger(32)) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::TruncOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.u16.u32"); - auto res = builder.newOperand("=h"); - auto operand = builder.newOperand(operands[0][0], "r"); - cvt(res, operand); - return {builder.launch(rewriter, loc, i16_ty, false)}; - } -}; - -struct SExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::SExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(32) && srcTy.isInteger(16)) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::SExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.s32.s16"); - auto res = builder.newOperand("=r"); - auto 
operand = builder.newOperand(operands[0][0], "h"); - cvt(res, operand); - return {builder.launch(rewriter, loc, i32_ty, false)}; - } -}; - -struct ZExtOpConversion - : ElementwiseOpConversionBase { - using Base = ElementwiseOpConversionBase; - using Base::Base; - using Adaptor = typename Base::OpAdaptor; - - static bool isLegalOp(LLVM::ZExtOp op) { - auto retTy = op.getResult().getType(); - auto srcTy = op.getOperand().getType(); - if (retTy.isInteger(32) && srcTy.isInteger(16)) { - return false; - } - return true; - } - - SmallVector createDestOps(LLVM::ZExtOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter, - Type elemTy, MultipleOperandsRange operands, - Location loc) const { - PTXBuilder builder; - auto &cvt = *builder.create("cvt.u32.u16"); - auto res = builder.newOperand("=r"); - auto operand = builder.newOperand(operands[0][0], "h"); - cvt(res, operand); - return {builder.launch(rewriter, loc, i32_ty, false)}; - } -}; - -bool isLegalElementwiseOp(Operation *op) { - if (isa(op)) { - return FPExtOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return FPTruncOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return TruncOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return SExtOpConversion::isLegalOp(cast(op)); - } else if (isa(op)) { - return ZExtOpConversion::isLegalOp(cast(op)); - } - return true; -} - -void populateElementwiseOpToPTXPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit) { - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); - patterns.add(typeConverter, benefit); -} diff --git a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h index fbcbe95bd..22b2e2101 100644 --- a/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h +++ b/lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.h @@ -13,10 +13,4 @@ void populateElementwiseOpToLLVMPatterns( ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo, int computeCapability, PatternBenefit benefit); -bool isLegalElementwiseOp(Operation *op); - -void populateElementwiseOpToPTXPatterns( - TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns, - PatternBenefit benefit); - #endif diff --git a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp index 5e4fa3c11..1fd621079 100644 --- a/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -21,6 +21,19 @@ using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getTotalElemsPerThread; using ::mlir::triton::gpu::SharedEncodingAttr; +static CUtensorMapDataType getCUtensorMapDataType(Type ty) { + if (ty.isF16()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } else if (ty.isBF16()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + } else if (ty.isF32()) { + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + } else { + llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op"); + return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + } +} + // Contains some helper functions for both Load and Store conversions. 
struct LoadStoreConversionBase { explicit LoadStoreConversionBase(ModuleAxisInfoAnalysis &axisAnalysisPass) @@ -804,7 +817,7 @@ struct StoreAsyncOpConversion typeConverter->convertType(rewriter.getI8Type()), 3); auto threadId = getThreadId(rewriter, loc); - Value pred = icmp_eq(urem(threadId, i32_val(32)), i32_val(0)); + Value pred = int_val(1, 1); auto llCoord = getTypeConverter()->unpackLLElements(loc, llDst, rewriter, dst.getType()); @@ -912,17 +925,6 @@ struct StoreAsyncOpConversion } private: - CUtensorMapDataType getCUtensorMapDataType(Type ty) const { - if (ty.isF16()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if (ty.isF32()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else { - llvm::report_fatal_error("Unsupported elemTy for StoreAsyncOp"); - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } - } - unsigned getArgIdx(Value v) const { if (auto op = v.getDefiningOp()) { return -1 - @@ -969,6 +971,18 @@ private: const TensorPtrMapT *tensorPtrMap; }; +namespace { +void createBarrier(ConversionPatternRewriter &rewriter, Location loc, + int numCTAs) { + if (numCTAs == 1) { + barrier(); + } else { + rewriter.create(loc, false); + rewriter.create(loc); + } +} +} // namespace + struct AtomicCASOpConversion : public ConvertTritonGPUOpToLLVMPattern, public LoadStoreConversionBase { @@ -1060,6 +1074,10 @@ struct AtomicCASOpConversion auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for AtomicCASOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + Value llPtr = adaptor.getPtr(); Value llCmp = adaptor.getCmp(); Value llVal = adaptor.getVal(); @@ -1097,7 +1115,7 @@ struct AtomicCASOpConversion atom.global().o(semStr).o("cas").o("b32"); atom(dstOpr, ptrOpr, cmpOpr, valOpr).predicate(mask); auto old = ptxBuilderAtomicCAS.launch(rewriter, loc, valueElemTy); - barrier(); + createBarrier(rewriter, loc, numCTAs); PTXBuilder ptxBuilderStore; auto *dstOprStore = ptxBuilderStore.newAddrOperand(atomPtr, "r"); @@ -1107,9 +1125,9 @@ struct AtomicCASOpConversion st(dstOprStore, valOprStore).predicate(mask); auto ASMReturnTy = void_ty(ctx); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); - barrier(); + createBarrier(rewriter, loc, numCTAs); Value ret = load(atomPtr); - barrier(); + createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); return success(); } @@ -1279,7 +1297,11 @@ struct AtomicRMWOpConversion ConversionPatternRewriter &rewriter) const override { auto loc = op.getLoc(); MLIRContext *ctx = rewriter.getContext(); - // + + auto moduleOp = op->getParentOfType(); + assert(moduleOp && "Parent ModuleOp not found for AtomicRMWOp"); + int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(moduleOp); + auto atomicRmwAttr = op.getAtomicRmwOp(); Value val = op.getVal(); @@ -1352,7 +1374,7 @@ struct AtomicRMWOpConversion sTy = "b" + sBits; break; case RMWOp::ADD: - sTy = "s" + sBits; + sTy = "u" + sBits; break; case RMWOp::FADD: rmwOp = "add"; @@ -1410,9 +1432,9 @@ struct AtomicRMWOpConversion auto *valOpr = ptxBuilderStore.newOperand(old, tyId); storeShared(ptrOpr, valOpr).predicate(rmwMask); ptxBuilderStore.launch(rewriter, loc, void_ty(ctx)); - barrier(); + createBarrier(rewriter, loc, numCTAs); Value ret = load(atomPtr); - barrier(); + createBarrier(rewriter, loc, numCTAs); rewriter.replaceOp(op, {ret}); } } @@ -1980,17 +2002,6 @@ private: return bcastMask; } - CUtensorMapDataType 
getCUtensorMapDataType(Type ty) const { - if (ty.isF16()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } else if (ty.isF32()) { - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32; - } else { - llvm::report_fatal_error("Unsupported elemTy for InsertSliceAsyncV2Op"); - return CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16; - } - } - unsigned getArgIdx(Value v) const { if (auto op = v.getDefiningOp()) { return -1 - diff --git a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp index 01681c48b..1fdfed388 100644 --- a/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp @@ -9,6 +9,7 @@ using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::linearize; +using ::mlir::LLVM::loadShared; using ::mlir::LLVM::shflSync; using ::mlir::LLVM::storeShared; using ::mlir::triton::gpu::getOrder; @@ -28,9 +29,54 @@ public: LogicalResult matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - if (ReduceOpHelper(op).isFastReduction()) - return matchAndRewriteFast(op, adaptor, rewriter); - return matchAndRewriteBasic(op, adaptor, rewriter); + ReduceOpHelper helper(op); + assert(helper.isSupportedLayout() && + "Unexpected srcLayout in ReduceOpConversion"); + Location loc = op->getLoc(); + + auto srcValues = unpackInputs(loc, op, adaptor, rewriter); + std::map, SmallVector> accs; + std::map, SmallVector> indices; + // First reduce all the values along axis within each thread. + reduceWithinThreads(helper, srcValues, accs, indices, rewriter); + + // Then reduce across threads within a warp. + reduceWithinWarps(helper, accs, rewriter); + + if (helper.isWarpSynchronous()) { + // If all the values to be reduced are within the same warp there is + // nothing left to do. + packResults(helper, accs, rewriter); + return success(); + } + + // Compute a shared memory base per operand. + auto smemShape = helper.getScratchConfig(); + + SmallVector smemBases = + getSmemBases(helper, op, smemShape, rewriter); + + storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); + + sync(rewriter, loc, op); + + // The second round of shuffle reduction + // now the problem size: sizeInterWarps, s1, s2, .. , sn + // where sizeInterWarps is 2^m + // + // Each thread needs to process: + // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads + accumulatePartialReductions(helper, smemBases, rewriter); + + // We could avoid this barrier in some of the layouts, however this is not + // the general case. + // TODO: optimize the barrier in case the layouts are accepted. + sync(rewriter, loc, op); + + // set output values + loadReductionAndPackResult(helper, smemShape, smemBases, rewriter); + + return success(); } private: @@ -87,6 +133,7 @@ private: return srcValues; } +<<<<<<< HEAD // Calculates the write index in the shared memory where we would be writing // the within-thread accumulations before we start doing across-threads // accumulations. 
`index` is the index of the within-thread accumulations in @@ -183,12 +230,28 @@ private: auto llvmIndexTy = getTypeConverter()->getIndexType(); auto smemShape = helper.getScratchConfigBasic(); +======= + SmallVector getSmemBases(ReduceOpHelper &helper, triton::ReduceOp op, + SmallVector smemShape, + ConversionPatternRewriter &rewriter) const { + auto loc = op.getLoc(); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 unsigned elems = product(smemShape); - - SmallVector smemBases(op.getNumOperands()); - smemBases[0] = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys[0]); + // indices will store the index of the op operands in descending order + // of their bitwidths + std::vector indices(op.getNumOperands()); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&](unsigned i, unsigned j) { + return op.getElementTypes()[i].getIntOrFloatBitWidth() > + op.getElementTypes()[j].getIntOrFloatBitWidth(); + }); + // Assign base index to each operand in their order in indices + std::map indexToBase; + indexToBase[indices[0]] = + bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), + getElementPtrType(op, indices[0])); for (unsigned i = 1; i < op.getNumOperands(); ++i) { +<<<<<<< HEAD smemBases[i] = bitcast(gep(elemPtrTys[i - 1], smemBases[i - 1], i32_val(elems)), elemPtrTys[i]); @@ -269,43 +332,19 @@ private: store(acc[i], writePtrs[i]); } } +======= + indexToBase[indices[i]] = + bitcast(gep(getElementPtrType(op, indices[i - 1]), + indexToBase[indices[i - 1]], i32_val(elems)), + getElementPtrType(op, indices[i])); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } - - sync(rewriter, loc, op); - - // set output values - SmallVector results(op.getNumOperands()); + // smemBases[k] is the base pointer for the k-th operand + SmallVector smemBases(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { - if (auto resultTy = - op.getResult()[i].getType().dyn_cast()) { - // nd-tensor where n >= 1 - - auto resultLayout = resultTy.getEncoding(); - - unsigned resultElems = getTotalElemsPerThread(resultTy); - auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy); - assert(resultIndices.size() == resultElems); - - SmallVector resultVals(resultElems); - for (unsigned j = 0; j < resultElems; ++j) { - SmallVector readIdx = resultIndices[j]; - readIdx.insert(readIdx.begin() + axis, ints[0]); - Value readOffset = - linearize(rewriter, loc, readIdx, smemShape, srcOrd); - Value readPtr = gep(elemPtrTys[i], smemBases[i], readOffset); - resultVals[j] = load(readPtr); - } - results[i] = getTypeConverter()->packLLElements(loc, resultVals, - rewriter, resultTy); - } else { - // 0d-tensor -> scalar - results[i] = load(smemBases[i]); - } + smemBases[i] = indexToBase[i]; } - - auto parentBlock = op.getOperation()->getBlock(); - rewriter.replaceOp(op, results); - return success(); + return smemBases; } void sync(ConversionPatternRewriter &rewriter, Location loc, @@ -390,7 +429,7 @@ private: // region and the accumulator values as source. void warpReduce(ConversionPatternRewriter &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce) const { + unsigned numLaneToReduce, unsigned interleave) const { if (auto kind = matchReduxKind(op)) { // Based on benchmarking on A100 redux op gives a speed up only when doing // a single reduction (not partioned) and when the mask is static. 
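// Editorial aside (not part of the patch): what the new interleave argument to
// warpReduce changes, simulated on an array standing in for the 32 lanes of a
// warp. The surrounding loop over N (halving each round) and the butterfly/XOR
// behaviour of shflSync are not visible in this hunk, so both are assumptions of
// this sketch; the point is only that the exchange offset becomes N * interleave,
// so the lanes that get combined are `interleave` apart instead of adjacent.
#include <cstdio>

int main() {
  const unsigned warpSize = 32, numLaneToReduce = 4, interleave = 8;
  int lanes[warpSize];
  for (unsigned i = 0; i < warpSize; ++i) lanes[i] = 1;
  for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
    int shfl[warpSize];
    for (unsigned i = 0; i < warpSize; ++i)
      shfl[i] = lanes[i ^ (N * interleave)]; // shflSync(..., N * interleave)
    for (unsigned i = 0; i < warpSize; ++i)
      lanes[i] += shfl[i];                   // accumulate(...)
  }
  // Lane 0 now holds the combination of lanes {0, 8, 16, 24}: numLaneToReduce
  // values spaced interleave apart.
  std::printf("lane 0 combined %d values\n", lanes[0]);
  return 0;
}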
@@ -446,7 +485,11 @@ private: } #endif for (unsigned i = 0; i < acc.size(); ++i) { +<<<<<<< HEAD shfl[i] = shflSync(loc, rewriter, acc[i], shuffleIdx); +======= + shfl[i] = shflSync(loc, rewriter, acc[i], N * interleave); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } accumulate(rewriter, op.getCombineOp(), acc, shfl, false); } @@ -459,10 +502,13 @@ private: ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); unsigned sizeIntraWarps = helper.getIntraWarpSizeWithUniqueData(); + unsigned threadOffsetOnReductionAxis = + helper.getThreadOffsetOnReductionAxis(); for (auto it : accs) { const SmallVector &key = it.first; SmallVector &acc = accs[key]; - warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps); + warpReduce(rewriter, op.getLoc(), acc, op, sizeIntraWarps, + threadOffsetOnReductionAxis); } } @@ -502,6 +548,32 @@ private: return LLVM::LLVMPointerType::get(llvmElemTy, 3); } + SmallVector + getMultiDimWarpId(ReduceOpHelper &helper, Value &warpId, Location &loc, + ConversionPatternRewriter &rewriter) const { + auto srcLayout = helper.getSrcLayout(); + auto srcShape = helper.getSrcShape(); + auto order = getOrder(srcLayout); + SmallVector multiDimWarpId; + + // 2x2 warps with slice dim = 0, warpId = 2 ends up writing at the same + // address as warpId = 0 since the warpsPerCTA is [1, 2], need to figure out + // a way to properly delinearize warpId in the slice case + if (auto sliceLayout = srcLayout.dyn_cast()) { + auto parentLayout = sliceLayout.getParent(); + auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA(parentLayout); + auto parentOrder = triton::gpu::getOrder(parentLayout); + multiDimWarpId = + delinearize(rewriter, loc, warpId, parentWarpsPerCTA, parentOrder); + multiDimWarpId.erase(multiDimWarpId.begin() + sliceLayout.getDim()); + } else { + auto warpsPerCTA = + triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape); + multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); + } + return multiDimWarpId; + } + void storeWarpReduceToSharedMemory( ReduceOpHelper &helper, std::map, SmallVector> &accs, @@ -518,15 +590,14 @@ private: Value laneId = urem(threadId, warpSize); auto srcShape = helper.getSrcShape(); unsigned axis = op.getAxis(); - auto smemShapes = helper.getScratchConfigsFast(); + auto smemShape = helper.getScratchConfig(); auto threadsPerWarp = triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape); - auto warpsPerCTA = - triton::gpu::getWarpsPerCTAWithUniqueData(srcLayout, srcShape); auto order = getOrder(srcLayout); SmallVector multiDimLaneId = delinearize(rewriter, loc, laneId, threadsPerWarp, order); +<<<<<<< HEAD SmallVector multiDimWarpId = delinearize(rewriter, loc, warpId, warpsPerCTA, order); @@ -545,20 +616,26 @@ private: } #endif +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 Value laneIdAxis = multiDimLaneId[axis]; - Value warpIdAxis = multiDimWarpId[axis]; - Value zero = i32_val(0); Value laneZero = icmp_eq(laneIdAxis, zero); + SmallVector multiDimWarpId = + getMultiDimWarpId(helper, warpId, loc, rewriter); + Value warpIdAxis = multiDimWarpId[axis]; + + if (!helper.isReductionOnLayoutFastAxis()) { + std::reverse(order.begin(), order.end()); + } for (auto it : accs) { const SmallVector &key = it.first; SmallVector acc = it.second; SmallVector writeIdx = indices[key]; writeIdx[axis] = warpIdAxis; - Value writeOffset = - linearize(rewriter, loc, writeIdx, smemShapes[0], order); + Value writeOffset = linearize(rewriter, loc, writeIdx, smemShape, order); for (unsigned i 
= 0; i < op.getNumOperands(); ++i) { auto elemPtrTy = getElementPtrType(op, i); Value writePtr = gep(elemPtrTy, smemBases[i], writeOffset); @@ -574,8 +651,8 @@ private: ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); auto srcLayout = helper.getSrcLayout(); - auto smemShapes = helper.getScratchConfigsFast(); - unsigned elems = product(smemShapes[0]); + auto smemShape = helper.getScratchConfig(); + unsigned elems = product(smemShape); unsigned sizeInterWarps = helper.getInterWarpSizeWithUniqueData(); Location loc = op.getLoc(); @@ -589,18 +666,16 @@ private: product(triton::gpu::getWarpsPerCTA(srcLayout)) * triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); unsigned elemsPerThread = std::max(elems / numThreads, 1); + Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); Value readOffset = threadId; for (unsigned round = 0; round < elemsPerThread; ++round) { - // FIXME(Qingyi): need predicate icmp_slt(threadId, - // i32_val(sizeInerWarps)) SmallVector acc(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { auto elemPtrTy = getElementPtrType(op, i); Value readPtr = gep(elemPtrTy, smemBases[i], readOffset); - acc[i] = load(readPtr); + acc[i] = loadShared(rewriter, loc, readPtr, threadIsNeeded); } - warpReduce(rewriter, loc, acc, op, sizeInterWarps); - + warpReduce(rewriter, loc, acc, op, sizeInterWarps, 1 /* interleave */); // only the first thread in each sizeInterWarps is writing Value writeOffset = readOffset; SmallVector writePtrs(op.getNumOperands()); @@ -608,7 +683,7 @@ private: auto elemPtrTy = getElementPtrType(op, i); writePtrs[i] = gep(elemPtrTy, smemBases[i], writeOffset); } - Value threadIsNeeded = icmp_slt(threadId, i32_val(elems)); + Value laneIdModSizeInterWarps = urem(laneId, i32_val(sizeInterWarps)); Value laneIdModSizeInterWarpsIsZero = icmp_eq(laneIdModSizeInterWarps, zero); @@ -638,12 +713,17 @@ private: // Load the final reduction from shared memory and replace the reduce result // with it. 
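// Editorial aside (not part of the patch): the sizing arithmetic used by
// accumulatePartialReductions above, on made-up numbers. The scratch buffer
// holds `elems` partial results, swept in `elemsPerThread` rounds of
// `numThreads`; the new threadIsNeeded predicate masks off threads whose offset
// never lands inside the buffer, and only the first lane of every
// sizeInterWarps group writes its combined value back.
#include <algorithm>
#include <cstdio>

int main() {
  const unsigned elems = 64;          // product(smemShape)
  const unsigned numThreads = 128;    // warpsPerCTA * threadsPerWarp
  const unsigned sizeInterWarps = 4;  // partial results per output element
  const unsigned elemsPerThread = std::max(elems / numThreads, 1u);
  for (unsigned tid = 0; tid < 8; ++tid) {
    bool threadIsNeeded = tid < elems;                  // icmp_slt(threadId, elems)
    bool writesBack = (tid % 32) % sizeInterWarps == 0; // laneId % sizeInterWarps == 0
    std::printf("tid=%u rounds=%u needed=%d writes=%d\n", tid, elemsPerThread,
                threadIsNeeded, writesBack);
  }
  return 0;
}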
void loadReductionAndPackResult(ReduceOpHelper &helper, + SmallVector smemShape, SmallVector &smemBases, ConversionPatternRewriter &rewriter) const { triton::ReduceOp op = helper.getOperation(); Location loc = op.getLoc(); - auto smemShapes = helper.getScratchConfigsFast(); - auto order = getOrder(helper.getSrcLayout()); + auto srcLayout = helper.getSrcLayout(); + auto axis = op.getAxis(); + auto order = getOrder(srcLayout); + if (!helper.isReductionOnLayoutFastAxis()) { + std::reverse(order.begin(), order.end()); + } SmallVector results(op.getNumOperands()); for (unsigned i = 0; i < op.getNumOperands(); ++i) { if (auto resultTy = @@ -659,7 +739,7 @@ private: SmallVector readIdx = resultIndices[j]; readIdx.insert(readIdx.begin() + op.getAxis(), i32_val(0)); Value readOffset = - linearize(rewriter, loc, readIdx, smemShapes[0], order); + linearize(rewriter, loc, readIdx, smemShape, order); Value readPtr = gep(getElementPtrType(op, i), smemBases[i], readOffset); resultVals[j] = load(readPtr); @@ -674,67 +754,6 @@ private: } rewriter.replaceOp(op, results); } - - // Use warp shuffle for reduction within warps and shared memory for data - // exchange across warps - LogicalResult matchAndRewriteFast(triton::ReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const { - ReduceOpHelper helper(op); - assert(helper.isSupportedLayout() && - "Unexpected srcLayout in ReduceOpConversion"); - Location loc = op->getLoc(); - - auto srcValues = unpackInputs(loc, op, adaptor, rewriter); - std::map, SmallVector> accs; - std::map, SmallVector> indices; - // First reduce all the values along axis within each thread. - reduceWithinThreads(helper, srcValues, accs, indices, rewriter); - - // Then reduce across threads within a warp. - reduceWithinWarps(helper, accs, rewriter); - - if (helper.isWarpSynchronous()) { - // If all the values to be reduced are within the same warp there is - // nothing left to do. - packResults(helper, accs, rewriter); - return success(); - } - - // Compute a shared memory base per operand. - auto smemShapes = helper.getScratchConfigsFast(); - unsigned elems = product(smemShapes[0]); - unsigned maxElems = std::max(elems, product(smemShapes[1])); - SmallVector smemBases(op.getNumOperands()); - smemBases[0] = - bitcast(getSharedMemoryBase(loc, rewriter, op.getOperation()), - getElementPtrType(op, 0)); - for (unsigned i = 1; i < op.getNumOperands(); ++i) { - smemBases[i] = bitcast(gep(getElementPtrType(op, i - 1), smemBases[i - 1], - i32_val(maxElems)), - getElementPtrType(op, i)); - } - storeWarpReduceToSharedMemory(helper, accs, indices, smemBases, rewriter); - - sync(rewriter, loc, op); - - // The second round of shuffle reduction - // now the problem size: sizeInterWarps, s1, s2, .. , sn - // where sizeInterWarps is 2^m - // - // Each thread needs to process: - // elemsPerThread = sizeInterWarps * s1 * s2 .. Sn / numThreads - accumulatePartialReductions(helper, smemBases, rewriter); - - // We could avoid this barrier in some of the layouts, however this is not - // the general case. - // TODO: optimize the barrier in case the layouts are accepted. 
- sync(rewriter, loc, op); - - // set output values - loadReductionAndPackResult(helper, smemBases, rewriter); - - return success(); - } }; void populateReduceOpToLLVMPatterns( diff --git a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp index 397abc1ea..c1276a4de 100644 --- a/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp @@ -7,16 +7,19 @@ using namespace mlir::triton; using ::mlir::LLVM::delinearize; using ::mlir::LLVM::linearize; +using ::mlir::LLVM::shflIdxSync; using ::mlir::LLVM::shflUpSync; using ::mlir::LLVM::storeShared; -// Apply the region of the scan op to the acc and cur values and update acc -// inplace with the result. -static void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, - Value &acc, Value cur) { - if (!acc) { - acc = cur; - return; +// apply combine region to a and b and return the result. If a or b is null, +// return the other operand. +static Value accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, + Value a, Value b) { + if (!a) { + return b; + } + if (!b) { + return a; } // Create a new copy of the reduce block, and inline it Block *currentBlock = rewriter.getBlock(); @@ -24,13 +27,14 @@ static void accumulate(ConversionPatternRewriter &rewriter, Region &combineOp, rewriter.cloneRegionBefore(combineOp, &parent.front()); auto &newScan = parent.front(); auto returnOp = dyn_cast(newScan.getTerminator()); - llvm::SmallVector combineArgs = {acc, cur}; + llvm::SmallVector combineArgs = {a, b}; rewriter.inlineBlockBefore(&newScan, &*rewriter.getInsertionPoint(), combineArgs); auto results = returnOp.getResult(); - acc = results[0]; + Value acc = results[0]; // Delete the terminator, which is no longer used rewriter.eraseOp(returnOp); + return acc; } // Scan a contiguous elements within a thread and update `srcValues` in place. @@ -41,7 +45,6 @@ static void scanThreadContiguousElements(SmallVector &srcValues, // contiguous in srcValues. Keep track of what elements belong to the same // chunk of contiguous elements. unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); - unsigned parallelElementsPerThread = helper.getAxisNumElementsPerThread(); unsigned numChunks = srcValues.size() / scanElementsPerThreads; unsigned stride = helper.getAxisElementStride(); SmallVector accs(numChunks); @@ -49,8 +52,8 @@ static void scanThreadContiguousElements(SmallVector &srcValues, unsigned accIndex = (srcIndex % stride) + ((srcIndex / stride) / scanElementsPerThreads) * stride; - accumulate(rewriter, helper.getCombineOp(), accs[accIndex], - srcValues[srcIndex]); + accs[accIndex] = accumulate(rewriter, helper.getCombineOp(), accs[accIndex], + srcValues[srcIndex]); srcValues[srcIndex] = accs[accIndex]; } } @@ -59,13 +62,17 @@ static void scanThreadContiguousElements(SmallVector &srcValues, // contiguous group of elements. 
static void warpScan(SmallVector &srcValues, ConversionPatternRewriter &rewriter, - ScanLoweringHelper &helper, Value laneIdAxis, Value laneId) { + ScanLoweringHelper &helper, Value laneIdAxis) { Location loc = helper.getLoc(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); - unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; // Only consider the last element of each contiguous chunk of elements. @@ -74,9 +81,14 @@ static void warpScan(SmallVector &srcValues, // Reduce within warps. Value acc = srcValues[srcIndex]; for (unsigned i = 1; i <= (scanDim) / 2; i = i << 1) { - Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride, laneId); - Value tempAcc = acc; - accumulate(rewriter, helper.getCombineOp(), tempAcc, shfl); + Value shfl = shflUpSync(loc, rewriter, acc, i * threadStride); + Value tempAcc = accumulate(rewriter, helper.getCombineOp(), shfl, acc); Value mask = icmp_slt(laneIdAxis, i32_val(i)); acc = select(mask, acc, tempAcc); } @@ -97,9 +109,9 @@ static void storeWarpAccumulator(SmallVector &srcValues, Value parallelLaneId) { Location loc = helper.getLoc(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); - unsigned scanDim = helper.getAxisNumThreadsPerWarp(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); - unsigned numWarps = helper.getAxisNumWarps(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); unsigned chunkId = 0; unsigned elementStride = helper.getAxisElementStride(); for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { @@ -110,7 +122,7 @@ static void storeWarpAccumulator(SmallVector &srcValues, Value lastElement = srcValues[srcIndex]; Value mask = icmp_eq(laneId, i32_val(scanDim - 1)); Value index = add(parallelLaneId, mul(warpId, i32_val(numParallelLane))); - index = add(index, i32_val(chunkId * numParallelLane * numWarps)); + index = add(index, i32_val(chunkId * numParallelLane * axisNumWarps)); Value writePtr = gep(baseSharedMemPtr.getType(), baseSharedMemPtr, index); storeShared(rewriter, loc, writePtr, lastElement, mask); chunkId++; @@ -126,14 +138,18 @@ static void AddPartialReduce(SmallVector &srcValues, ConversionPatternRewriter &rewriter, ScanLoweringHelper &helper, Value sharedMemoryPtr, Value warpId, Value laneIdAxis, - Value parallelLaneId, Value laneId) { + Value parallelLaneId) { Location loc = helper.getLoc(); unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); - unsigned numWarps = helper.getAxisNumWarps(); unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); unsigned elementStride = helper.getAxisElementStride(); unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); Value
maskFirstThread = and_(maskFirstWarp, maskFirstLane); @@ -165,9 +181,10 @@ static void AddPartialReduce(SmallVector &srcValues, unsigned accumulatorIndex = chunkId % parallelElementsPerThread + parallelBlockId * parallelElementsPerThread; Accumulator &accumulator = accumulators[accumulatorIndex]; - for (unsigned i = 0; i < numWarps; ++i) { - Value index = add(parallelLaneId, - i32_val(numParallelLane * (i + chunkId * numWarps))); + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + for (unsigned i = 0; i < axisNumWarps; ++i) { + Value index = add(parallelLaneId, i32_val(numParallelLane * + (i + chunkId * axisNumWarps))); Value ptr = gep(sharedMemoryPtr.getType(), sharedMemoryPtr, index); Value partialReduce = load(ptr); if (!accumulator.acc) { @@ -175,15 +192,14 @@ static void AddPartialReduce(SmallVector &srcValues, accumulator.maskedAcc = partialReduce; continue; } - accumulate(rewriter, helper.getCombineOp(), accumulator.acc, - partialReduce); + accumulator.acc = accumulate(rewriter, helper.getCombineOp(), + accumulator.acc, partialReduce); Value mask = icmp_slt(warpId, i32_val(i + 1)); accumulator.maskedAcc = select(mask, accumulator.maskedAcc, accumulator.acc); } - Value temp = srcValues[srcIndex]; - accumulate(rewriter, helper.getCombineOp(), temp, accumulator.maskedAcc); - unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + Value temp = accumulate(rewriter, helper.getCombineOp(), + accumulator.maskedAcc, srcValues[srcIndex]); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. @@ -196,7 +212,8 @@ static void AddPartialReduce(SmallVector &srcValues, lastElement = select(maskFirstLane, accumulator.maskedAcc, lastElement); for (unsigned i = 1; i < scanElementsPerThreads; ++i) { Value laneValue = srcValues[srcIndex - i * elementStride]; - accumulate(rewriter, helper.getCombineOp(), laneValue, lastElement); + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); if (axisBlockId == 0) { // For the first warp and first chunk we don't have anything to // accumulate. 
@@ -212,6 +229,76 @@ static void AddPartialReduce(SmallVector &srcValues, } } +static void AddPartialReduceOneWarp(SmallVector &srcValues, + ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, Value warpId, + Value laneIdAxis, Value laneIdLast) { + Location loc = helper.getLoc(); + unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread(); + unsigned parallelElementsPerThread = helper.getNonAxisNumElementsPerThread(); + unsigned elementStride = helper.getAxisElementStride(); + unsigned threadStride = helper.getAxisThreadStride(); + unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA(); + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + Value maskFirstWarp = icmp_eq(warpId, i32_val(0)); + Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0)); + Value maskFirstThread = and_(maskFirstWarp, maskFirstLane); + unsigned numScanBlocks = helper.getAxisNumBlocks(); + unsigned numParallelBlocks = helper.getNonAxisNumBlocks(); + assert(numScanBlocks * numParallelBlocks * parallelElementsPerThread * + scanElementsPerThreads == + srcValues.size()); + SmallVector accumulators(numParallelBlocks * + parallelElementsPerThread); + unsigned chunkId = 0; + unsigned blockStride = helper.getAxisBlockStride(); + for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) { + unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads; + // Only consider the last element of each contiguous chunk of elements. + if (elementIdx != scanElementsPerThreads - 1) + continue; + unsigned blockId = chunkId / parallelElementsPerThread; + unsigned parallelBlockId = + blockId % blockStride + + ((blockId / blockStride) / numScanBlocks) * blockStride; + unsigned accumulatorIndex = chunkId % parallelElementsPerThread + + parallelBlockId * parallelElementsPerThread; + Value &accumulator = accumulators[accumulatorIndex]; + unsigned axisBlockId = (blockId / blockStride) % numScanBlocks; + if (axisBlockId == 0) // First chunk and first block + accumulator = srcValues[srcIndex]; + else + srcValues[srcIndex] = accumulate(rewriter, helper.getCombineOp(), + accumulator, srcValues[srcIndex]); + // Update the rest of the contiguous elements. + Value lastElement = srcValues[srcIndex]; + if (scanDim > 1) { + lastElement = + shflUpSync(loc, rewriter, srcValues[srcIndex], threadStride); + lastElement = select(maskFirstLane, accumulator, lastElement); + if (numScanBlocks > 1) + // Update accumulator with the value from the last lane. + accumulator = + shflIdxSync(loc, rewriter, srcValues[srcIndex], laneIdLast); + } + for (unsigned i = 1; i < scanElementsPerThreads; ++i) { + Value laneValue = srcValues[srcIndex - i * elementStride]; + laneValue = + accumulate(rewriter, helper.getCombineOp(), lastElement, laneValue); + if (axisBlockId == 0) + // For the first warp and first chunk we don't have anything to + // accumulate. + laneValue = select(maskFirstThread, + srcValues[srcIndex - i * elementStride], laneValue); + srcValues[srcIndex - i * elementStride] = laneValue; + } + // For the next chunk start back from the value containing the + // accumulated value of all the warps. 
+ chunkId++; + } +} + namespace { struct ScanOpConversion : public ConvertTritonGPUOpToLLVMPattern { @@ -228,6 +315,12 @@ public: } private: + SmallVector getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const; + SmallVector getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const; std::tuple getDelinearizedIds(ConversionPatternRewriter &rewriter, ScanLoweringHelper &helper, Value laneId, @@ -236,6 +329,34 @@ private: ConversionPatternRewriter &rewriter) const; }; +SmallVector +ScanOpConversion::getMultiDimLaneId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value laneId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + return delinearize(rewriter, loc, laneId, threadsPerWarp, order); +} + +SmallVector +ScanOpConversion::getMultiDimWarpId(ConversionPatternRewriter &rewriter, + ScanLoweringHelper &helper, + Value warpId) const { + auto loc = helper.getLoc(); + unsigned axis = helper.getAxis(); + auto srcEncoding = helper.getEncoding(); + + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding); + auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding); + auto order = triton::gpu::getOrder(srcEncoding); + return delinearize(rewriter, loc, warpId, warpsPerCTA, order); +} + // Break up the threadId into lane and warp id along the scan dimension and // compute a flat id for the parallel dimensions. std::tuple @@ -280,7 +401,10 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, if (!helper.isSupported()) return failure(); - // Obtain global laneId and pass it around Value threadId = getThreadId(rewriter, loc); auto mod = op->getParentOfType(); unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); @@ -292,6 +416,8 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, getDelinearizedIds(rewriter, helper, laneId, warpId); auto input = adaptor.getOperands()[0]; auto type = op.getOperand(0).getType().cast(); + auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData(); + warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps)); SmallVector srcValues = getTypeConverter()->unpackLLElements(loc, input, rewriter, type); @@ -301,6 +427,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, // elements. - warpScan(srcValues, rewriter, helper, laneIdAxis, laneId); + warpScan(srcValues, rewriter, helper, laneIdAxis); - // Store the partial reduction of each warp into shared memory. - Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3); - Value baseSharedMemPtr = bitcast( - getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys); @@ -313,6 +440,35 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor, - // accumulated value from the previous lane. - AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis, - laneIdAxis, flatIdParallel, laneId); + if (axisNumWarps > 1) { + // Slow path for the case where there are multiple warps with unique data on + // the axis.
+ Type elemPtrTys = LLVM::LLVMPointerType::get(srcValues[0].getType(), 3); + Value baseSharedMemPtr = bitcast( + getSharedMemoryBase(loc, rewriter, op.getOperation()), elemPtrTys); + // Store the partial reduction of each warp into shared memory. + storeWarpAccumulator(srcValues, rewriter, helper, laneIdAxis, warpIdAxis, + baseSharedMemPtr, flatIdParallel); + barrier(); + // Read back the partial reduction of each warp and accumulate them based on + // warpId. Then update each chunk of contiguous elements by adding the + // accumulated value from the previous lane. + AddPartialReduce(srcValues, rewriter, helper, baseSharedMemPtr, warpIdAxis, + laneIdAxis, flatIdParallel); + } else if (srcValues.size() > 1) { + // Fast path for the case where there is only one warp with unique data on + // the axis. + unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData(); + auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId); + multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1); + auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding()); + auto laneIdLast = linearize(rewriter, loc, multiDimLaneId, threadsPerWarp, + triton::gpu::getOrder(helper.getEncoding())); + AddPartialReduceOneWarp(srcValues, rewriter, helper, warpIdAxis, laneIdAxis, + laneIdLast); + } // else axisNumWarps == 1 and srcValues.size() == 1, nothing to do. Value results = getTypeConverter()->packLLElements(loc, srcValues, rewriter, input.getType()); diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h index df3e877d1..cf031a339 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMBase.h @@ -11,6 +11,7 @@ #include "Utility.h" #include "mlir/IR/TypeUtilities.h" #include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Target/PTX/TmaMetadata.h" #include @@ -340,6 +341,9 @@ public: // Order auto inOrder = triton::gpu::getOrder(srcEncoding); auto outOrder = triton::gpu::getOrder(resSharedLayout); + assert(maxPhase == 1 || + outVec * maxPhase <= srcShape[outOrder[0]] && + "Swizzling would generate out of bounds memory accesses"); // Tensor indices held by the current thread, as LLVM values auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy, false); // Swizzling with leading offsets (e.g. Hopper GMMA) @@ -453,18 +457,19 @@ public: auto dstElemTy = dstTy.getElementType(); auto inOrd = triton::gpu::getOrder(srcSharedLayout); auto outOrd = triton::gpu::getOrder(dstDistributedLayout); - unsigned outVec = - inOrd == outOrd - ? triton::gpu::getContigPerThread(dstDistributedLayout)[outOrd[0]] - : 1; + unsigned outVec = inOrd == outOrd + ?
triton::gpu::getUniqueContigPerThread( + dstDistributedLayout, dstShape)[outOrd[0]] + : 1; unsigned inVec = srcSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned outElems = triton::gpu::getTotalElemsPerThread(dstTy); + SmallVector offsetVals = {i32_val(0), i32_val(0)}; assert(outElems == dstIndices.size()); - DenseMap sharedPtrs = getSwizzledSharedPtrs( - loc, outVec, dstTy, srcSharedLayout, srcElemTy, smemObj, rewriter, - smemObj.offsets, smemObj.strides); + DenseMap sharedPtrs = + getSwizzledSharedPtrs(loc, outVec, dstTy, srcSharedLayout, srcElemTy, + smemObj, rewriter, offsetVals, smemObj.strides); assert(outElems % minVec == 0 && "Unexpected number of elements"); unsigned numVecs = outElems / minVec; auto wordTy = vec_ty(elemTy, minVec); @@ -502,10 +507,10 @@ public: auto dstElemTy = dstTy.getElementType(); auto inOrd = triton::gpu::getOrder(srcDistributedLayout); auto outOrd = dstSharedLayout.getOrder(); - unsigned inVec = - inOrd == outOrd - ? triton::gpu::getContigPerThread(srcDistributedLayout)[inOrd[0]] - : 1; + unsigned inVec = inOrd == outOrd + ? triton::gpu::getUniqueContigPerThread( + srcDistributedLayout, srcShape)[inOrd[0]] + : 1; unsigned outVec = dstSharedLayout.getVec(); unsigned minVec = std::min(outVec, inVec); unsigned numElems = triton::gpu::getTotalElemsPerThread(srcTy); @@ -543,6 +548,7 @@ public: auto tensorTy = valueTy.dyn_cast(); Value mask = int_val(1, 1); auto tid = tid_val(); + auto clusterCTAId = getClusterCTAId(rewriter, loc); if (tensorTy) { auto layout = tensorTy.getEncoding(); auto shape = tensorTy.getShape(); @@ -578,7 +584,6 @@ public: auto CTASplitNum = triton::gpu::getCTASplitNum(layout); auto CTAOrder = triton::gpu::getCTAOrder(layout); - auto clusterCTAId = getClusterCTAId(rewriter, loc); auto multiDimClusterCTAId = delinearize(rewriter, loc, clusterCTAId, CTAsPerCGA, CTAOrder); @@ -588,14 +593,23 @@ public: continue; // This wrapping rule must be consistent with emitCTAOffsetForLayout unsigned splitNum = std::min(shape[dim], CTASplitNum[dim]); - multiDimClusterCTAId[dim] = - urem(multiDimClusterCTAId[dim], i32_val(splitNum)); - mask = and_(mask, icmp_eq(multiDimClusterCTAId[dim], _0)); + Value repId = udiv(multiDimClusterCTAId[dim], i32_val(splitNum)); + // Consider the example where CTAsPerCGA = [4] and CTASplitNum = [2]: + // CTA0 and CTA2 holds data of block0, + // CTA1 and CTA3 holds data of block1. + // Only CTA0 and CTA1 are expected to write while CTA2 and CTA3 should + // be masked. We add the following mask: + // multiDimClusterCTAId[dim] / splitNum == 0 + // Actually in all existing cases of multicast, splitNum is always 1. 
+ // The mask is equivalent to: + // multiDimClusterCTAId[dim] == 0 + mask = and_(mask, icmp_eq(repId, _0)); } } } else { - // If the tensor is not ranked, then it is a scalar and only thread 0 can - // write + // If the tensor is not ranked, then it is a scalar and only thread 0 of + // CTA0 can write + mask = and_(mask, icmp_eq(clusterCTAId, i32_val(0))); mask = and_(mask, icmp_eq(tid, i32_val(0))); } return mask; diff --git a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp index 97df7231c..94ed93fa1 100644 --- a/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp @@ -18,6 +18,7 @@ #include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" +#include "triton/Dialect/NVGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #ifndef USE_ROCM @@ -398,6 +399,16 @@ struct ConvertTritonGPUToLLVM using ConvertTritonGPUToLLVMBase< ConvertTritonGPUToLLVM>::ConvertTritonGPUToLLVMBase; + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + ConvertTritonGPUToLLVM(int32_t computeCapability, Target target, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata) + : ConvertTritonGPUToLLVMBase({computeCapability, target}), + tmaMetadata(tmaMetadata) {} + void runOnOperation() override { MLIRContext *context = &getContext(); ModuleOp mod = getOperation(); @@ -591,6 +602,7 @@ private: DenseMap>, CacheKeyDenseMapInfo> indexCache; + mlir::triton::gpu::TMAMetadataTy *tmaMetadata = nullptr; void initSharedMemory(ModuleAllocation &allocation, TritonGPUToLLVMTypeConverter &typeConverter) { @@ -1036,7 +1048,7 @@ private: bool isNativeHopperFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FNUZ(); bool isFP8 = isNativeHopperFP8 || AElType.isFloat8E5M2FNUZ() || - AElType.isFloat8E4M3FN(); + AElType.isFloat8E4M3FN() || AElType.isFloat8E4M3B11FNUZ(); if (!isFP8 || (isNativeHopperFP8 && mmaLayout.isHopper())) return; promoteType = builder.getF16Type(); @@ -1090,9 +1102,11 @@ namespace triton { std::unique_ptr> createConvertTritonGPUToLLVMPass() { return std::make_unique(); } -std::unique_ptr> -createConvertTritonGPUToLLVMPass(const ConvertTritonGPUToLLVMOptions &options) { - return std::make_unique(options); +std::unique_ptr> createConvertTritonGPUToLLVMPass( + int32_t computeCapability, Target target, + mlir::triton::gpu::TMAMetadataTy *tmaMetadata) { + return std::make_unique(computeCapability, target, + tmaMetadata); } } // namespace triton diff --git a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp index f04cee67e..ca53a82dd 100644 --- a/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp +++ b/lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp @@ -179,7 +179,7 @@ Type TritonGPUToLLVMTypeConverter::getElementTypeForStruct( #endif auto mmaParent = dotOpLayout.getParent().dyn_cast(); - if (!mmaParent) + if (!mmaParent || mmaParent.isHopper()) return elemTy; int bitwidth = elemTy.getIntOrFloatBitWidth(); assert(bitwidth <= 32); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 3cc3425b4..5385dd135 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -1,6 +1,7 @@ #include "Utility.h" #include "TypeConverter.h" - +#include "mlir/Dialect/LLVMIR/NVVMDialect.h" +#include 
"triton/Dialect/NVGPU/IR/Dialect.h" namespace mlir { namespace LLVM { @@ -243,7 +244,7 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, return val; #else MLIRContext *ctx = rewriter.getContext(); - unsigned bits = val.getType().getIntOrFloatBitWidth(); + unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth()); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); PTXBuilder builder; @@ -255,9 +256,32 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, #endif } +Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value pred) { + MLIRContext *ctx = rewriter.getContext(); + auto ptrTy = ptr.getType().cast(); + assert(ptrTy.getAddressSpace() == 3 && "Invalid addr space for loadShared"); + auto elemTy = ptrTy.getElementType(); + unsigned bitwidth = std::max(8u, elemTy.getIntOrFloatBitWidth()); + + const char *c = bitwidth == 64 ? "=l" : (bitwidth == 16 ? "=h" : "=r"); + + PTXBuilder builder; + auto *dOpr = builder.newOperand(c); + auto *ptrOpr = builder.newAddrOperand(ptr, "r"); + auto &ld = builder.create<>("ld")->shared().b(bitwidth); + ld(dOpr, ptrOpr).predicate(pred, "b"); + return builder.launch(rewriter, loc, elemTy); +} + static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, +<<<<<<< HEAD Value val, int i, const std::string &shuffleType, const std::string &clamp, Value laneId = Value()) { +======= + Value val, Value i, NVVM::ShflKind mode, + Value clamp) { +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 unsigned bits = val.getType().getIntOrFloatBitWidth(); #ifdef USE_ROCM @@ -275,13 +299,19 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, Value vec = bitcast(val, vecTy); Value val0 = extract_element(f32_ty, vec, i32_val(0)); Value val1 = extract_element(f32_ty, vec, i32_val(1)); +<<<<<<< HEAD val0 = commonShflSync(loc, rewriter, val0, i, shuffleType, clamp, laneId); val1 = commonShflSync(loc, rewriter, val1, i, shuffleType, clamp, laneId); +======= + val0 = commonShflSync(loc, rewriter, val0, i, mode, clamp); + val1 = commonShflSync(loc, rewriter, val1, i, mode, clamp); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 vec = undef(vecTy); vec = insert_element(vecTy, vec, val0, i32_val(0)); vec = insert_element(vecTy, vec, val1, i32_val(1)); return bitcast(vec, val.getType()); } +<<<<<<< HEAD #ifdef USE_ROCM GCNBuilder builder; @@ -342,17 +372,51 @@ static Value commonShflSync(Location loc, ConversionPatternRewriter &rewriter, shfl(dOpr, aOpr, bOpr, cOpr, maskOpr); return builder.launch(rewriter, loc, val.getType(), false); #endif +======= + Type type = val.getType(); + if (type != i32_ty) { + val = bitcast(val, int_ty(bits)); + val = zext(i32_ty, val); + } + Value mask = i32_val(0xFFFFFFFF); + Value result = rewriter.create(loc, i32_ty, mask, val, i, clamp, + mode, UnitAttr()); + if (type != i32_ty) { + result = trunc(int_ty(bits), result); + result = bitcast(result, type); + } + return result; +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i) { - return commonShflSync(loc, rewriter, val, i, "bfly", "0x1f"); + return commonShflSync(loc, rewriter, val, i32_val(i), NVVM::ShflKind::bfly, + i32_val(0x1f)); } Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, +<<<<<<< HEAD int i, Value laneId) { return commonShflSync(loc, rewriter, val, i, "up", "0x0", laneId); +======= + int i) { + return commonShflSync(loc, 
rewriter, val, i32_val(i), NVVM::ShflKind::up, + i32_val(0x0)); } + +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + int i) { + return shflIdxSync(loc, rewriter, val, i32_val(i)); +} + +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + Value i) { + return commonShflSync(loc, rewriter, val, i, NVVM::ShflKind::idx, + i32_val(0x1f)); +} + Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) { PTXBuilder builder; auto &mov = builder.create("mov")->o("u32"); diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.h b/lib/Conversion/TritonGPUToLLVM/Utility.h index 3cfdcd444..31d7950c4 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.h +++ b/lib/Conversion/TritonGPUToLLVM/Utility.h @@ -240,21 +240,24 @@ SmallVector getStridesFromShapeAndOrder(ArrayRef shape, ArrayRef order, Location loc, ConversionPatternRewriter &rewriter); struct SharedMemoryObject { - Value base; // i32 ptr. The start address of the shared memory object. - // We need to store strides as Values but not integers because the + Value base; // i32 ptr. The start address of the shared memory object after + // the initial allocation or the last slicing operation. + // We need to store strides as Values, not integers, because the // extract_slice instruction can take a slice at arbitrary offsets. - // Take $a[16:32, 16:32] as an example, though we know the stride of $a[0] is - // 32, we need to let the instruction that uses $a to be aware of that. + // Take $a[16:32, 16:32] as an example; though we know the stride of $a[0] is + // 32, we need to let the instruction that uses $a be aware of that. // Otherwise, when we use $a, we only know that the shape of $a is 16x16. If // we store strides into an attribute array of integers, the information // cannot pass through block argument assignment because attributes are - // associated with operations but not Values. + // associated with operations, not Values. // TODO(Keren): We may need to figure out a way to store strides as integers // if we want to support more optimizations. SmallVector strides; // i32 int. The strides of the shared memory object. - SmallVector offsets; // i32 int. The offsets of the shared memory - // objects from the originally allocated object. + SmallVector offsets; // i32 int. + // Offsets are applied at the last slicing operation. + // We can use offsets to recover the previous base. + // The offsets are zero at the initial allocation.
SharedMemoryObject(Value base, ArrayRef strides, ArrayRef offsets) @@ -328,11 +331,22 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc, Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, Value val, Value pred); +Value loadShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, + Value pred); + Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val, int i); Value shflUpSync(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i, Value laneId); + int i); +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + int i); +Value shflIdxSync(Location loc, ConversionPatternRewriter &rewriter, Value val, + Value i); Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, StringRef key, StringRef content); diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index b03d76ac4..b4f091d7b 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -29,17 +29,19 @@ static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) { op->setAttr(attr.getName(), attr.getValue()); } -template class GenericOpPattern : public OpConversionPattern { -public: +template struct GenericOpPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(Op op, typename Op::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()), - adaptor.getAttributes()); + SmallVector retTypes; + if (failed(this->getTypeConverter()->convertTypes(op->getResultTypes(), + retTypes))) + return failure(); + rewriter.replaceOpWithNewOp(op, retTypes, adaptor.getOperands(), + op->getAttrs()); + return success(); } }; @@ -88,22 +90,6 @@ public: } }; -class ConvertArithOp : public ConversionPattern { -public: - ConvertArithOp(TritonGPUTypeConverter &typeConverter, MLIRContext *context) - : ConversionPattern(typeConverter, MatchAnyOpTypeTag(), /*benefit=*/1, - context) {} - - LogicalResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - Dialect *dialect = op->getDialect(); - if (dialect->getTypeID() != mlir::TypeID::get()) - return failure(); - return success(); - } -}; - void populateArithPatternsAndLegality(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns, TritonGPUConversionTarget &target) { @@ -199,22 +185,6 @@ void populateMathPatternsAndLegality(TritonGPUTypeConverter &typeConverter, // // Triton patterns // -// TODO: Do we need to put them in anonymous namespace?
-struct TritonMakeRangePattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type retType = getTypeConverter()->convertType(op.getType()); - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, retType, adaptor.getStart(), adaptor.getEnd()), - adaptor.getAttributes()); - return success(); - } -}; - struct TritonExpandDimsPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -342,7 +312,8 @@ struct TritonDotPattern : public OpConversionPattern { c = rewriter.create(c.getLoc(), retType, c); addNamedAttrs(rewriter.replaceOpWithNewOp( - op, retType, a, b, c, adaptor.getAllowTF32()), + op, retType, a, b, c, adaptor.getAllowTF32(), + adaptor.getMaxNumImpreciseAcc()), adaptor.getAttributes()); return success(); } @@ -434,85 +405,6 @@ struct TritonTransPattern : public OpConversionPattern { } }; -struct TritonLoadPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(), - adaptor.getBoundaryCheckAttr(), adaptor.getPaddingAttr(), - adaptor.getCache(), adaptor.getEvict(), - adaptor.getIsVolatile()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonStorePattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getPtr(), adaptor.getValue(), - adaptor.getMask(), adaptor.getCache(), - adaptor.getEvict()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonAtomicCASPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal(), - op.getSem()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonAtomicRMWPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), - adaptor.getAtomicRmwOp(), adaptor.getPtr(), - adaptor.getVal(), adaptor.getMask(), op.getSem()), - adaptor.getAttributes()); - return success(); - } -}; - -template -struct TritonGenericPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(Op op, typename Op::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - Type retType = this->getTypeConverter()->convertType(op.getType()); - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, retType, adaptor.getOperands()), - adaptor.getAttributes()); - return success(); - } -}; - struct TritonBroadcastPattern : public OpConversionPattern { using 
OpConversionPattern::OpConversionPattern; @@ -554,20 +446,6 @@ struct TritonReducePattern : public OpConversionPattern { } }; -struct TritonReduceReturnPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ReduceReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getResult()), - adaptor.getAttributes()); - return success(); - } -}; - struct TritonScanPattern : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -586,48 +464,6 @@ struct TritonScanPattern : public OpConversionPattern { } }; -struct TritonScanReturnPattern - : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::ScanReturnOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getResult()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonPrintPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::PrintOp op, typename triton::PrintOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, op.getPrefixAttr(), adaptor.getOperands()), - adaptor.getAttributes()); - return success(); - } -}; - -struct TritonAssertPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(triton::AssertOp op, - typename triton::AssertOp::Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - addNamedAttrs(rewriter.replaceOpWithNewOp( - op, adaptor.getCondition(), op.getMessageAttr(), - op.getFileAttr(), op.getFuncAttr(), op.getLineAttr()), - adaptor.getAttributes()); - return success(); - } -}; - class TritonFuncOpPattern : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -679,22 +515,23 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, MLIRContext *context = patterns.getContext(); patterns.insert< // TODO: view should have custom pattern that views the // layout - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, - TritonGenericPattern, TritonBroadcastPattern, - TritonGenericPattern, TritonCatPattern, - TritonGenericPattern, TritonReducePattern, - TritonReduceReturnPattern, TritonScanPattern, TritonScanReturnPattern, - TritonTransPattern, TritonExpandDimsPattern, TritonMakeRangePattern, - TritonDotPattern, TritonLoadPattern, TritonStorePattern, - TritonGenericPattern, TritonPrintPattern, - TritonAssertPattern, TritonAtomicRMWPattern, TritonFuncOpPattern, - TritonReturnOpPattern, TritonCallOpPattern>(typeConverter, context); + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + TritonBroadcastPattern, GenericOpPattern, + TritonCatPattern, GenericOpPattern, + TritonReducePattern, GenericOpPattern, + TritonScanPattern, GenericOpPattern, + GenericOpPattern, TritonExpandDimsPattern, + TritonTransPattern, TritonDotPattern, GenericOpPattern, + GenericOpPattern, + GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, + GenericOpPattern, 
TritonFuncOpPattern>(typeConverter, + context); } // @@ -710,8 +547,8 @@ struct SCFForPattern : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { auto newOp = cast(rewriter.cloneWithoutRegions(*op.getOperation())); - rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(), - newOp.getLoopBody().end()); + rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(), + newOp.getRegion().end()); // Now, update all the types. @@ -720,7 +557,7 @@ struct SCFForPattern : public OpConversionPattern { // The entry block may have a special conversion if `entryConversion` is // provided. On success, the new entry block to the region is returned for // convenience. Otherwise, failure is returned. - if (failed(rewriter.convertRegionTypes(&newOp.getLoopBody(), + if (failed(rewriter.convertRegionTypes(&newOp.getRegion(), *getTypeConverter()))) { return rewriter.notifyMatchFailure(op, "could not convert body types"); } @@ -744,22 +581,6 @@ struct SCFForPattern : public OpConversionPattern { } }; -struct SCFYieldPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(scf::YieldOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock()); - // rewriter.create(op.getLoc(), adaptor.getOperands()); - // op.erase(); - addNamedAttrs( - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()), - adaptor.getAttributes()); - return success(); - } -}; - // This is borrowed from ConvertFIfOpTypes in // SCF/Transforms/StructuralTypeConversions.cpp class SCFIfPattern : public OpConversionPattern { @@ -847,8 +668,8 @@ public: void populateSCFPatterns(TritonGPUTypeConverter &typeConverter, RewritePatternSet &patterns) { MLIRContext *context = patterns.getContext(); - patterns.add(typeConverter, context); + patterns.add, SCFForPattern, SCFIfPattern, + SCFWhilePattern, SCFConditionPattern>(typeConverter, context); } // CF diff --git a/lib/Dialect/NVGPU/IR/CMakeLists.txt b/lib/Dialect/NVGPU/IR/CMakeLists.txt index 4e9e1ada1..24a93ce58 100644 --- a/lib/Dialect/NVGPU/IR/CMakeLists.txt +++ b/lib/Dialect/NVGPU/IR/CMakeLists.txt @@ -6,4 +6,5 @@ add_mlir_dialect_library(NVGPUIR NVGPUAttrDefsIncGen LINK_LIBS PUBLIC + MLIRLLVMDialect ) diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 99a891b9d..0ee9f96eb 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -181,8 +181,9 @@ public: op->getLoc(), newAccType, rewriter.create(op->getLoc(), rewriter.getF32FloatAttr(0))); - rewriter.replaceOpWithNewOp( - op, expandLhsOp.getOperand(), expandRhsOp.getOperand(), newAcc, true); + rewriter.replaceOpWithNewOp(op, expandLhsOp.getOperand(), + expandRhsOp.getOperand(), newAcc, + true, 0); return mlir::success(); } }; diff --git a/lib/Dialect/Triton/Transforms/Combine.td b/lib/Dialect/Triton/Transforms/Combine.td index 974e09924..39c4ad234 100644 --- a/lib/Dialect/Triton/Transforms/Combine.td +++ b/lib/Dialect/Triton/Transforms/Combine.td @@ -12,22 +12,24 @@ include "mlir/IR/PatternBase.td" // AddIOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) // AddFOp(d, DotOp(a, b, c)) and c==0 => DotOp(a, b, d) def CombineDotAddIPattern : Pat< - (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32)), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddIOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc)), + (TT_DotOp $a, $b, $d, $allowTF32, 
$maxNumImpreciseAcc), [(Constraint> $c)]>; def CombineDotAddFPattern : Pat< - (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32), $fastmath), - (TT_DotOp $a, $b, $d, $allowTF32), - [(Constraint> $c)]>; + (Arith_AddFOp $d, (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $fastmath), + (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc)]>; def CombineDotAddIRevPattern : Pat< - (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d), - (TT_DotOp $a, $b, $d, $allowTF32), + (Arith_AddIOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d), + (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc), [(Constraint> $c)]>; def CombineDotAddFRevPattern : Pat< - (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32), $d, $fastmath), - (TT_DotOp $a, $b, $d, $allowTF32), - [(Constraint> $c)]>; + (Arith_AddFOp (TT_DotOp:$res $a, $b, $c, $allowTF32, $maxNumImpreciseAcc), $d, $fastmath), + (TT_DotOp $a, $b, $d, $allowTF32, $maxNumImpreciseAcc), + [(Constraint> $c), + (Constraint($0).getInt() == 0">> $maxNumImpreciseAcc)]>; // TODO: this fails for addptr(addptr(ptr, i32), i64) // Commented out until fixed diff --git a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp index 1930ab9f6..931777bfa 100644 --- a/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp +++ b/lib/Dialect/Triton/Transforms/ReorderBroadcast.cpp @@ -116,18 +116,23 @@ struct MoveBroadcastAfterElementwisePattern auto operands = op->getOperands(); bool seenBroadcast = false; + ArrayRef srcShape; for (auto operand : operands) { auto definingOp = operand.getDefiningOp(); if (!definingOp) { return mlir::failure(); } - + auto getSrcShape = [](triton::BroadcastOp b) { + return b.getSrc().getType().cast().getShape(); + }; if (auto broadcastOp = llvm::dyn_cast(definingOp)) { - if (seenBroadcast) { - // Only support one broadcasted argument for now + if (!seenBroadcast) { + seenBroadcast = true; + srcShape = getSrcShape(broadcastOp); + } else if (srcShape != getSrcShape(broadcastOp)) { + // If the broadcast have different types we cannot re-order. 
return mlir::failure(); } - seenBroadcast = true; } else if (!isSplat(definingOp)) { // Not splat or broadcast return mlir::failure(); @@ -149,8 +154,7 @@ struct MoveBroadcastAfterElementwisePattern } } - auto src = broadcastOp.getSrc(); - auto srcTy = src.getType().dyn_cast(); + auto srcTy = broadcastOp.getSrc().getType().dyn_cast(); auto srcShape = srcTy.getShape(); auto srcEncoding = srcTy.getEncoding(); @@ -158,8 +162,9 @@ struct MoveBroadcastAfterElementwisePattern llvm::SmallVector newOperands; for (auto operand : operands) { auto definingOp = operand.getDefiningOp(); - if (llvm::isa(definingOp)) { - newOperands.push_back(src); + if (auto broadcastSrcOp = + llvm::dyn_cast(definingOp)) { + newOperands.push_back(broadcastSrcOp.getSrc()); continue; } auto elemTy = diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 44eb9e0b1..a3f51a6dc 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -326,9 +326,9 @@ public: Operation *rewriteForOp(OpBuilder &builder, scf::ForOp op, std::stack &eraser) { // Generate new iteration operands and set rewrited information - SmallVector oldIterOperands = op.getIterOperands(); - SmallVector newIterOperands = op.getIterOperands(); - for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size; + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; ++i, ++oldI) { if (!triton::isTensorPointerType(newIterOperands[i].getType())) continue; @@ -351,7 +351,7 @@ public: // mapping. It may refer to a value in the old loop, but we will rewrite it // later IRMapping mapping; - for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands(); + for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; ++i, ++oldI) { auto oldRegionIterArg = op.getRegionIterArg(oldI); if (triton::isTensorPointerType(oldRegionIterArg.getType())) { @@ -378,7 +378,7 @@ public: } // Replace later usages - assert(op.getNumResults() == op.getNumIterOperands()); + assert(op.getNumResults() == op.getInitArgs().size()); for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { auto oldResult = op.getResult(oldI); if (triton::isTensorPointerType(oldResult.getType())) { diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 5ff25a0a5..113e709c1 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1031,6 +1031,11 @@ unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef shape, if (auto mmaParent = getParent().dyn_cast()) { int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0]; int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1]; + // H100 + if (mmaParent.isHopper()) { + if (eltTy.isF16()) + return mmaParent.getTotalElemsPerThread(shape, eltTy); + } // A100 if (mmaParent.isAmpere()) { auto rep = getMMAv2Rep(shapePerCTA, eltTy.getIntOrFloatBitWidth()); @@ -1718,10 +1723,13 @@ struct TritonGPUInferLayoutInterface std::optional location) const override { auto mmaRetEncoding = retEncoding.dyn_cast(); if (mmaRetEncoding && mmaRetEncoding.isHopper()) { - // TODO: support gmma when A/B does not reside in shared memory - if (!operandEncoding.isa()) + auto dotOpEnc = operandEncoding.dyn_cast(); + if (!operandEncoding.isa() && + !(opIdx == 0 && dotOpEnc && dotOpEnc.getOpIdx() == 0 && 
+ dotOpEnc.getParent().isa())) { return emitOptionalError( location, "unexpected operand layout for MmaEncodingAttr v3"); + } } else if (auto dotOpEnc = operandEncoding.dyn_cast()) { if (opIdx != dotOpEnc.getOpIdx()) @@ -1743,6 +1751,10 @@ struct TritonGPUInferLayoutInterface operandEncodingB.dyn_cast(); if (!aEncoding && !bEncoding) return mlir::success(); + auto mmaAEncoding = + aEncoding.getParent().dyn_cast_or_null(); + if (mmaAEncoding && mmaAEncoding.isHopper()) + return success(); // Verify that the encodings are valid. if (!aEncoding || !bEncoding) return op->emitError("mismatching encoding between A and B operands"); @@ -1756,6 +1768,7 @@ struct TritonGPUInferLayoutInterface // Canonicalizer //===----------------------------------------------------------------------===// +struct CanonicalizeConvertFromView : public mlir::OpRewritePattern { - // convert to the same layout -- we can delete - if (op->getResultTypes() == op->getOperandTypes()) { - rewriter.replaceOp(op, op->getOperands()); - return mlir::success(); - } - Operation *arg = op->getOperand(0).getDefiningOp(); - // block argument - if (!arg) - return mlir::failure(); - // cvt(view) -> view - if (auto view = dyn_cast(arg)) { - rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), - view.getResult()); - return mlir::success(); - } - // cvt(cat) -> cat - if (auto cat = dyn_cast(arg)) { - auto encoding = - op->getResult(0).getType().cast().getEncoding(); - if (isExpensiveCat(cat, encoding)) - return mlir::failure(); - rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), - cat.getOperands()); - return mlir::success(); - } - // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) - auto alloc_tensor = dyn_cast(arg); - if (alloc_tensor) { - if (!triton::gpu::isSharedEncoding(op->getResult(0))) { - return mlir::failure(); - } - rewriter.replaceOpWithNewOp( - op, op->getResult(0).getType()); - return mlir::success(); - } - // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) - auto insert_slice = dyn_cast(arg); - if (insert_slice) { - if (!triton::gpu::isSharedEncoding(op->getResult(0))) { - return mlir::failure(); - } - auto newType = op->getResult(0).getType().cast(); - // Ensure that the new insert_slice op is placed in the same place as - // the old insert_slice op. Otherwise, the new insert_slice op may be - // placed after the async_wait op, which is not allowed.
- OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(insert_slice); - auto newArg = rewriter.create( - op->getLoc(), newType, insert_slice.getDst()); - rewriter.replaceOpWithNewOp( - op, newType, insert_slice.getSrc(), newArg.getResult(), - insert_slice.getIndex(), insert_slice.getMask(), - insert_slice.getOther(), insert_slice.getCache(), - insert_slice.getEvict(), insert_slice.getIsVolatile(), - insert_slice.getAxis()); - return mlir::success(); - } - // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) - auto extract_slice = dyn_cast(arg); - if (extract_slice) { - if (!triton::gpu::isSharedEncoding(op->getResult(0))) { - return mlir::failure(); - } - auto origType = - extract_slice.getSource().getType().cast(); - auto newType = RankedTensorType::get( - origType.getShape(), origType.getElementType(), - op->getResult(0).getType().cast().getEncoding()); - auto origResType = op->getResult(0).getType().cast(); - auto resType = RankedTensorType::get( - origResType.getShape(), origResType.getElementType(), - extract_slice.getType().cast().getEncoding()); - // Ensure that the new extract_slice op is placed in the same place as - // the old extract_slice op. Otherwise, the new extract_slice op may be - // placed after the async_wait op, which is not allowed. - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(extract_slice); - auto newArg = rewriter.create( - op->getLoc(), newType, extract_slice.getSource()); - rewriter.replaceOpWithNewOp( - op, resType, newArg.getResult(), extract_slice.getOffsets(), - extract_slice.getSizes(), extract_slice.getStrides(), - extract_slice.getStaticOffsets(), extract_slice.getStaticSizes(), - extract_slice.getStaticStrides()); - return mlir::success(); - } + CanonicalizeConvertFromView(MLIRContext *context) + : OpRewritePattern(context, 1) {} - // cvt(cvt(x, type1), type2) -> cvt(x, type2) - if (llvm::isa(arg)) { - if (arg->getOperand(0).getDefiningOp() && - !triton::gpu::isSharedEncoding(arg->getOperand(0)) && - triton::gpu::isSharedEncoding(op.getOperand()) && - !triton::gpu::isSharedEncoding(op.getResult())) { + mlir::LogicalResult + matchAndRewrite(triton::ViewOp op, PatternRewriter &rewriter) const override { + Operation *arg = op->getOperand(0).getDefiningOp(); + if (!arg) return mlir::failure(); - } - if (triton::gpu::isSharedEncoding(op.getOperand()) && - triton::gpu::isSharedEncoding(op.getResult())) { - return mlir::failure(); - } - auto srcType = op.getOperand().getType().cast(); - auto srcShared = - srcType.getEncoding().dyn_cast(); - if (srcShared && srcShared.getVec() > 1) - return mlir::failure(); - rewriter.replaceOpWithNewOp( - op, op->getResultTypes().front(), arg->getOperand(0)); - return mlir::success(); - } - // cvt(type1, splat(type2, x)) -> splat(type1, x) - if (auto splat = llvm::dyn_cast(arg)) { - rewriter.replaceOpWithNewOp(op, op->getResultTypes(), - splat.getSrc()); - return mlir::success(); - } - // cvt(type1, make_range(type2, x)) -> make_range(type1, x) - if (auto range = llvm::dyn_cast(arg)) { - rewriter.replaceOpWithNewOp( - op, op->getResultTypes(), range.getStart(), range.getEnd()); - return mlir::success(); - } - // cvt(type, constant) -> constant - if (auto cst = llvm::dyn_cast(arg)) - if (auto ret = cst.getValue().dyn_cast()) { - auto ty = op->getResultTypes().front().cast(); - auto newRet = SplatElementsAttr::get(ty, ret.getSplatValue()); - rewriter.replaceOpWithNewOp(op, newRet); + // view(convert) -> view + if (auto convert = dyn_cast(arg)) { + 
rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), convert.getOperand()); return mlir::success(); } - return mlir::failure(); + return mlir::failure(); + } +}; + +struct CanonicalizeConvertFromConvert + : public mlir::OpRewritePattern { + + CanonicalizeConvertFromConvert(mlir::MLIRContext *context) + : OpRewritePattern(context, 1) {} + + mlir::LogicalResult + matchAndRewrite(ConvertLayoutOp op, + mlir::PatternRewriter &rewriter) const override { + // we don't handle conversions to DotOperandEncodingAttr + // this is a heuristics to accommodate fused attention + auto srcType = op.getOperand().getType().cast(); + auto dstType = op.getType().cast(); + if (dstType.getEncoding().isa() && + srcType.getEncoding().isa()) + return mlir::failure(); + // for hopper MMAv3 + if (!op.use_empty()) { + bool hasDotUser = false; + for (Operation *dot : op.getResult().getUsers()) + if (isa(dot)) + hasDotUser = true; + + if (hasDotUser) { + if (dstType.getEncoding().isa() && + srcType.getEncoding().isa()) + return mlir::failure(); + } + } + + // convert to the same layout -- we can delete + if (op->getResultTypes() == op->getOperandTypes()) { + rewriter.replaceOp(op, op->getOperands()); + return mlir::success(); + } + Operation *arg = op->getOperand(0).getDefiningOp(); + // block argument + if (!arg) + return mlir::failure(); + // cvt(view) -> view + if (auto view = dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType(), view.getResult()); + return mlir::success(); + } + // cvt(cat) -> cat + if (auto cat = dyn_cast(arg)) { + auto encoding = + op->getResult(0).getType().cast().getEncoding(); + if (isExpensiveCat(cat, encoding)) + return mlir::failure(); + rewriter.replaceOpWithNewOp(op, op->getResult(0).getType(), + cat.getOperands()); + return mlir::success(); + } + // cvt(alloc_tensor(x), type2) -> alloc_tensor(x, type2) + auto alloc_tensor = dyn_cast(arg); + if (alloc_tensor) { + if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + rewriter.replaceOpWithNewOp( + op, op->getResult(0).getType()); + return mlir::success(); + } + // cvt(insert_slice(x), type2) -> insert_slice(cvt(x, type2)) + auto insert_slice = dyn_cast(arg); + if (insert_slice) { + if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + auto newType = op->getResult(0).getType().cast(); + // Ensure that the new insert_slice op is placed in the same place as + // the old insert_slice op. Otherwise, the new insert_slice op may be + // placed after the async_wait op, which is not allowed. 
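    // Sketch of this rewrite (encodings and the async insert op name are
    // illustrative; the extraction stripped the template arguments):
    //   %ins = insert_slice_async %src, %dst, %idx ... : ... -> tensor<..., #shared1>
    //   %res = convert_layout %ins                     : #shared1 -> #shared2
    // is rewritten to
    //   %dst2 = convert_layout %dst                      : #shared1 -> #shared2
    //   %res  = insert_slice_async %src, %dst2, %idx ... : ... -> tensor<..., #shared2>
    // Both new ops are created at the position of the old insert op (see the
    // guard below) so they cannot end up after the async_wait that synchronizes
    // the copy.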
+ OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(insert_slice); + auto newArg = rewriter.create( + op->getLoc(), newType, insert_slice.getDst()); + rewriter.replaceOpWithNewOp( + op, newType, insert_slice.getSrc(), newArg.getResult(), + insert_slice.getIndex(), insert_slice.getMask(), + insert_slice.getOther(), insert_slice.getCache(), + insert_slice.getEvict(), insert_slice.getIsVolatile(), + insert_slice.getAxis()); + return mlir::success(); + } + // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2)) + auto extract_slice = dyn_cast(arg); + if (extract_slice) { + if (!triton::gpu::isSharedEncoding(op->getResult(0))) { + return mlir::failure(); + } + auto origType = + extract_slice.getSource().getType().cast(); + auto newType = RankedTensorType::get( + origType.getShape(), origType.getElementType(), + op->getResult(0).getType().cast().getEncoding()); + auto origResType = op->getResult(0).getType().cast(); + auto resType = RankedTensorType::get( + origResType.getShape(), origResType.getElementType(), + extract_slice.getType().cast().getEncoding()); + // Ensure that the new extract_slice op is placed in the same place as + // the old extract_slice op. Otherwise, the new extract_slice op may be + // placed after the async_wait op, which is not allowed. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(extract_slice); + auto newArg = rewriter.create( + op->getLoc(), newType, extract_slice.getSource()); + rewriter.replaceOpWithNewOp( + op, resType, newArg.getResult(), extract_slice.getOffsets(), + extract_slice.getSizes(), extract_slice.getStrides(), + extract_slice.getStaticOffsets(), extract_slice.getStaticSizes(), + extract_slice.getStaticStrides()); + return mlir::success(); + } + + // cvt(cvt(x, type1), type2) -> cvt(x, type2) + if (llvm::isa(arg)) { + if (arg->getOperand(0).getDefiningOp() && + !triton::gpu::isSharedEncoding(arg->getOperand(0)) && + triton::gpu::isSharedEncoding(op.getOperand()) && + !triton::gpu::isSharedEncoding(op.getResult())) { + return mlir::failure(); + } + if (triton::gpu::isSharedEncoding(op.getOperand()) && + triton::gpu::isSharedEncoding(op.getResult())) { + return mlir::failure(); + } + auto srcType = op.getOperand().getType().cast(); + auto srcShared = + srcType.getEncoding().dyn_cast(); + if (srcShared && srcShared.getVec() > 1) + return mlir::failure(); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), arg->getOperand(0)); + return mlir::success(); + } + // cvt(type1, splat(type2, x)) -> splat(type1, x) + if (auto splat = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + splat.getSrc()); + return mlir::success(); + } + // cvt(type1, make_range(type2, x)) -> make_range(type1, x) + if (auto range = llvm::dyn_cast(arg)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), range.getStart(), range.getEnd()); + return mlir::success(); + } + // cvt(type, constant) -> constant + if (auto cst = llvm::dyn_cast(arg)) + if (auto ret = cst.getValue().dyn_cast()) { + auto ty = op->getResultTypes().front().cast(); + auto newRet = + SplatElementsAttr::get(ty, ret.getSplatValue()); + rewriter.replaceOpWithNewOp(op, newRet); + return mlir::success(); + } + return mlir::failure(); + } +}; + +void ConvertLayoutOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(context); + patterns.add(context); } //===----------------------------------------------------------------------===// diff --git 
a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index a61faa356..1fe5a1c63 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -124,8 +124,7 @@ warpsPerTileV3(tt::DotOp dotOp, const ArrayRef shape, int numWarps, class BlockedToMMA : public mlir::RewritePattern { int computeCapability; mutable int mmaV1Counter{}; // used to generate ID for MMAv1 encoding - mutable llvm::SmallVector> dotOpSetVector; - mutable llvm::SmallVector mmaV3InstrNs; + mutable llvm::DenseMap dotOpInstNs; static bool bwdFilter(Operation *op) { return op->getNumOperands() == 1 && @@ -168,40 +167,11 @@ public: } } - unsigned getMmaV3InstrN(tt::DotOp dotOp, unsigned currN) const { - auto type = dotOp.getResult().getType().cast(); - if (type.getEncoding().isa()) - return currN; - for (size_t i = 0; i < dotOpSetVector.size(); ++i) { - if (dotOpSetVector[i].count(dotOp.getOperation()) > 0) - return mmaV3InstrNs[i]; - } - - SetVector slices; - mlir::getForwardSlice(dotOp.getResult(), &slices); - mlir::getBackwardSlice(dotOp.getOperation(), &slices); - unsigned N = currN; - llvm::SetVector dotOpSet; - for (Operation *iter : slices) { - if (auto nextDotOp = dyn_cast(iter)) { - auto type = nextDotOp.getResult().getType().cast(); - auto AType = nextDotOp.getOperand(0).getType().cast(); - auto shapePerCTA = ttg::getShapePerCTA(type); - auto instrShape = mmaVersionToInstrShape(3, shapePerCTA, AType); - dotOpSet.insert(iter); - if (instrShape[1] < N) - N = instrShape[1]; - } - } - mmaV3InstrNs.push_back(N); - dotOpSetVector.push_back(dotOpSet); - return N; - } - static Value getMMAv3Operand(Value v, mlir::PatternRewriter &rewriter, int opIdx) { - auto cvtOp = dyn_cast_or_null(v.getDefiningOp()); - auto arg = cvtOp.getSrc(); + Value arg = v; + if (auto cvtOp = v.getDefiningOp()) + arg = cvtOp.getSrc(); auto argType = arg.getType().cast(); auto eltType = argType.getElementType(); assert(argType.getEncoding() && "unexpected tensor type"); @@ -255,9 +225,6 @@ public: auto instrShape = mmaVersionToInstrShape(versionMajor, retShapePerCTA, AType); - if (versionMajor == 3) - instrShape[1] = getMmaV3InstrN(dotOp, instrShape[1]); - // operands Value a = dotOp.getA(); Value b = dotOp.getB(); @@ -338,7 +305,8 @@ public: } // convert dot instruction auto newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, - newAcc, dotOp.getAllowTF32()); + newAcc, dotOp.getAllowTF32(), + dotOp.getMaxNumImpreciseAcc()); rewriter.replaceOpWithNewOp(op, oldRetType, newDot.getResult()); diff --git a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp index 5ebc88083..4e32f7ab2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Coalesce.cpp @@ -21,11 +21,39 @@ template SmallVector argSort(const T &arr) { return ret; } -typedef DenseMap> LayoutMap; +unsigned getElementBitWidth(const Value &val) { + auto valType = val.getType(); + if (valType.isa()) + valType = valType.cast().getPointeeType(); + auto tensorType = valType.cast(); + + auto typeForMem = + tensorType.getElementType().isa() + ? 
tensorType.getElementType().cast().getPointeeType() + : tensorType.getElementType(); + return typeForMem.getIntOrFloatBitWidth(); +} + +static Value getMemAccessPtr(Operation *op) { + if (auto ld = dyn_cast(op)) + return ld.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto atomic = dyn_cast(op)) + return atomic.getPtr(); + if (auto insert = dyn_cast(op)) + return insert.getSrc(); + if (auto store = dyn_cast(op)) + return store.getPtr(); + return nullptr; +} struct CoalescePass : public TritonGPUCoalesceBase { - Attribute getCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, - Value ptr, int numWarps, int threadsPerWarp) { + void + setCoalescedEncoding(ModuleAxisInfoAnalysis &axisInfoAnalysis, Operation *op, + int numWarps, int threadsPerWarp, + llvm::MapVector &layoutMap) { + Value ptr = getMemAccessPtr(op); auto refType = ptr.getType(); if (refType.isa()) refType = refType.cast().getPointeeType(); @@ -74,21 +102,35 @@ struct CoalescePass : public TritonGPUCoalesceBase { order = argSort(queryAxisInfo(ptr).getContiguity()); } + auto matchesShape = [&refTensorType](const Value &val) { + if (val.getType() == refTensorType) { + return true; + } + + auto rttType = val.getType().dyn_cast(); + if (!rttType) { + return false; + } + return rttType.getShape() == refTensorType.getShape(); + }; + // The desired divisibility is the maximum divisibility // among all dependent pointers who have the same order as // `ptr`. // We only do it for normal tensors of pointers, not tensor pointers. - SetVector withSameOrder; - withSameOrder.insert(ptr); + llvm::SmallSetVector memAccessesSameOrder; + memAccessesSameOrder.insert(op); if (refType.isa() && ptr.getDefiningOp()) { - for (Operation *op : mlir::multiRootGetSlice(ptr.getDefiningOp())) { - for (Value val : op->getResults()) { - if (val.getType() != refTensorType) - continue; - auto currOrder = - argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); - if (order == currOrder) - withSameOrder.insert(val); + for (Operation *use : mlir::multiRootGetSlice(op)) { + Value val = getMemAccessPtr(use); + if (!val) + continue; + if (!matchesShape(val)) + continue; + auto currOrder = + argSort(axisInfoAnalysis.getAxisInfo(val)->getContiguity()); + if (order == currOrder) { + memAccessesSameOrder.insert(use); } } } @@ -107,73 +149,82 @@ struct CoalescePass : public TritonGPUCoalesceBase { .getPointeeType() : refTensorType.getElementType(); - // Thread tile size depends on memory alignment - SmallVector sizePerThread(refTensorType.getRank(), 1); - unsigned elemNumBits = typeForMem.getIntOrFloatBitWidth(); - unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); - unsigned perThread = 1; - for (Value val : withSameOrder) { + auto getNumElementPerThread = [&](Operation *op) { + Value val = getMemAccessPtr(op); auto valInfo = queryAxisInfo(val); + unsigned elemNumBits = getElementBitWidth(val); + unsigned elemNumBytes = std::max(elemNumBits / 8, 1u); unsigned maxMultipleBytes = valInfo.getDivisibility(order[0]); unsigned maxMultiple = std::max(maxMultipleBytes / elemNumBytes, 1u); unsigned maxContig = std::min(valInfo.getContiguity(order[0]), shapePerCTA[order[0]]); unsigned alignment = std::min(maxMultiple, maxContig); unsigned currPerThread = std::min(alignment, 128 / elemNumBits); + return currPerThread; + }; + unsigned perThread = getNumElementPerThread(op); + for (Operation *op : memAccessesSameOrder) { + unsigned currPerThread = getNumElementPerThread(op); perThread = std::max(perThread, currPerThread); } - 
sizePerThread[order[0]] = std::min(perThread, numElemsPerThread); + + perThread = std::min(perThread, numElemsPerThread); + + if (!dyn_cast(op)) { + // For ops that can result in a global memory write, we should enforce + // that each thread handles at most 128 bits, which is the widest + // available vectorized store op; otherwise, the store will have "gaps" + // in the memory write at the warp level, resulting in worse performance. + // For loads, we can expect that the gaps won't matter due to the L1 + // cache. + unsigned elemNumBits = getElementBitWidth(ptr); + perThread = std::min(perThread, getNumElementPerThread(op)); + } + SmallVector sizePerThread(refTensorType.getRank(), 1); + sizePerThread[order[0]] = perThread; auto CTALayout = triton::gpu::getCTALayout(refTensorType.getEncoding()); - return triton::gpu::BlockedEncodingAttr::get( + layoutMap[op] = triton::gpu::BlockedEncodingAttr::get( &getContext(), refTensorType.getShape(), sizePerThread, order, numWarps, threadsPerWarp, CTALayout); } - std::function - getTypeConverter(ModuleAxisInfoAnalysis &axisInfoAnalysis, Value ptr, - int numWarps, int threadsPerWarp) { - Attribute encoding = - getCoalescedEncoding(axisInfoAnalysis, ptr, numWarps, threadsPerWarp); - return [encoding](Type type) { - RankedTensorType tensorType = type.cast(); - return RankedTensorType::get(tensorType.getShape(), - tensorType.getElementType(), encoding); - }; + static Type getNewType(Type type, Attribute encoding) { + RankedTensorType tensorType = type.cast(); + return RankedTensorType::get(tensorType.getShape(), + tensorType.getElementType(), encoding); } - template - void coalesceOp(LayoutMap &layoutMap, Operation *op, Value ptr, - OpBuilder builder) { - if (!layoutMap.count(ptr)) - return; - + void coalesceOp(Attribute encoding, Operation *op) { + OpBuilder builder(op); // Convert operands // For load/store with tensor pointers, we don't have to change the // operands' type, we do this by changing the outputs' type of // `make_tensor_ptr` - auto convertType = layoutMap.lookup(ptr); SmallVector newArgs; for (auto operand : op->getOperands()) { auto tensorType = operand.getType().dyn_cast(); if (tensorType && - !tensorType.getEncoding().isa()) + !tensorType.getEncoding().isa()) { + Type newType = getNewType(tensorType, encoding); newArgs.push_back(builder.create( - op->getLoc(), convertType(tensorType), operand)); - else + op->getLoc(), newType, operand)); + } else { newArgs.push_back(operand); + } } // Convert output types SmallVector newTypes; for (auto t : op->getResultTypes()) { - bool isAsync = std::is_same::value; - newTypes.push_back(isAsync ? t : convertType(t)); + bool isAsync = isa(op); + newTypes.push_back(isAsync ? 
t : getNewType(t, encoding)); } // Construct new op with the new encoding Operation *newOp = - builder.create(op->getLoc(), newTypes, newArgs, op->getAttrs()); + builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs, + newTypes, op->getAttrs()); // Cast the results back to the original layout for (size_t i = 0; i < op->getNumResults(); i++) { @@ -187,25 +238,6 @@ struct CoalescePass : public TritonGPUCoalesceBase { op->erase(); } - void coalesceMakeTensorPtrOpResult(LayoutMap &layoutMap, Operation *op, - Value ptr, OpBuilder builder) { - if (!layoutMap.count(ptr)) - return; - - // Convert result type - auto convertType = layoutMap.lookup(ptr); - auto ptrType = ptr.getType().cast(); - auto resultTensorType = convertType(ptrType.getPointeeType()); - auto newResultType = - PointerType::get(resultTensorType, ptrType.getAddressSpace()); - - // Build new operation and replace - Operation *newOp = builder.create( - op->getLoc(), newResultType, op->getOperands(), op->getAttrs()); - op->getResult(0).replaceAllUsesWith(newOp->getResult(0)); - op->erase(); - } - void runOnOperation() override { // Run axis info analysis ModuleOp moduleOp = getOperation(); @@ -213,19 +245,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { // For each i/o operation, we determine what layout // the pointers should have for best memory coalescing - LayoutMap layoutMap; + llvm::MapVector layoutMap; moduleOp.walk([&](Operation *curr) { - Value ptr; - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); - if (auto op = dyn_cast(curr)) - ptr = op.getSrc(); - if (auto op = dyn_cast(curr)) - ptr = op.getPtr(); + Value ptr = getMemAccessPtr(curr); if (!ptr) return; // We only convert `tensor>` or `tt.ptr>` load/store @@ -240,9 +262,8 @@ struct CoalescePass : public TritonGPUCoalesceBase { int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - auto convertType = - getTypeConverter(axisInfoAnalysis, ptr, numWarps, threadsPerWarp); - layoutMap[ptr] = convertType; + setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp, + layoutMap); }); // For each memory op that has a layout L1: @@ -252,30 +273,9 @@ struct CoalescePass : public TritonGPUCoalesceBase { // produces a tensor with layout L2 // 4. Convert the output of this new memory op back to L1 // 5. Replace all the uses of the original memory op by the new one - moduleOp.walk([&](Operation *curr) { - OpBuilder builder(curr); - if (auto load = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, load.getPtr(), builder); - return; - } - if (auto op = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, op.getPtr(), builder); - return; - } - if (auto op = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, op.getPtr(), builder); - return; - } - if (auto load = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, - load.getSrc(), builder); - return; - } - if (auto store = dyn_cast(curr)) { - coalesceOp(layoutMap, curr, store.getPtr(), builder); - return; - } - }); + for (auto &kv : layoutMap) { + coalesceOp(kv.second, kv.first); + } } }; diff --git a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp index 43f411bca..8e13719b7 100644 --- a/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp +++ b/lib/Dialect/TritonGPU/Transforms/OptimizeDotOperands.cpp @@ -60,9 +60,14 @@ public: // used here. 
For tests where numCTAs = 1, this is not a problem since all // CTALayouts are the same. auto newXOrder = triton::gpu::getOrder(argEncoding); + // set needTrans to true here. newXEncoding is computed based on argEncoding + // which is before the transpose. without needTrans we will compute vec and + // maxPhase based on incorrect m, n and k size of mma. the type inference of + // TransOp simply swap the order but doesn't fix the vec and maxPhase for + // the YType, hence it would causing incorrect swizzling code. auto newXEncoding = triton::gpu::SharedEncodingAttr::get( getContext(), ZEncoding, XType.getShape(), newXOrder, - XEncoding.getCTALayout(), XType.getElementType()); + XEncoding.getCTALayout(), XType.getElementType(), true); auto newXType = RankedTensorType::get(XType.getShape(), XType.getElementType(), newXEncoding); if (XEncoding == newXEncoding) @@ -226,6 +231,44 @@ public: } }; +struct MMAV3UseRegOperand : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(triton::DotOp dotOp, + PatternRewriter &rewriter) const override { + auto convertLhs = + dotOp.getOperand(0).getDefiningOp(); + if (!convertLhs) + return failure(); + auto getEncoding = [](Value v) { + return v.getType().cast().getEncoding(); + }; + if (!getEncoding(dotOp.getOperand(0)).isa()) + return failure(); + auto srcEncoding = + getEncoding(convertLhs.getSrc()).dyn_cast(); + auto dstEncoding = + getEncoding(dotOp.getResult()).dyn_cast(); + if (!srcEncoding || srcEncoding.getVersionMajor() != 3 || !dstEncoding || + dstEncoding.getVersionMajor() != 3) + return failure(); + // We currently only support convert from f16 mma to f16 dot operand as the + // other types require shuffling data across threads. + // TODO: extend it to more types. 
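    // Sketch of the MMAV3UseRegOperand rewrite (types and encoding syntax are
    // illustrative):
    //   %a = convert_layout %acc : tensor<MxKxf16, #mma_v3> -> tensor<MxKxf16, #other>
    //   %d = tt.dot %a, %b, %c
    // becomes
    //   %a2 = convert_layout %acc : #mma_v3 -> #dot_op<{opIdx = 0, parent = #mma_v3}>
    //   %d  = tt.dot %a2, %b, %c
    // i.e. the A operand of an MMAv3 dot is fed directly from registers instead
    // of going through the original layout conversion.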
+ auto srcType = convertLhs.getSrc().getType().cast(); + if (!srcType.getElementType().isF16()) + return failure(); + auto dotOperandEncoding = + DotOperandEncodingAttr::get(dotOp.getContext(), 0, srcEncoding, 0); + auto newType = RankedTensorType::get( + srcType.getShape(), srcType.getElementType(), dotOperandEncoding); + Value newOperand = rewriter.create(dotOp.getLoc(), newType, + convertLhs.getSrc()); + rewriter.updateRootInPlace(dotOp, + [&]() { dotOp.setOperand(0, newOperand); }); + return success(); + } +}; } // namespace #define GEN_PASS_CLASSES @@ -247,8 +290,10 @@ public: mlir::RewritePatternSet patterns(context); patterns.add(context); - patterns.add(context); + if (triton::gpu::TritonGPUDialect::getComputeCapability(m) >= 80) + patterns.add(context); patterns.add(context); + patterns.add(context); if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed()) signalPassFailure(); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp index db5513d92..67b86e238 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeline.cpp @@ -347,7 +347,7 @@ LogicalResult LoopPipeliner::collectOps(SetVector &ops) { void LoopPipeliner::collectValueDep(Value v, int stage, SetVector &deps) { // Loop-invariant value, skip - if (v.getParentRegion() != &forOp.getLoopBody()) + if (v.getParentRegion() != &forOp.getRegion()) return; // Since we only need to peel the loop numStages-1 times, don't worry @@ -652,10 +652,12 @@ void LoopPipeliner::createBufferTypes() { .getEncoding() .dyn_cast()) { // MMAv1 and MMAv2 + bool needTrans = dyn_cast_or_null( + cvt.getDefiningOp()->getOperand(0).getDefiningOp()); unsigned bitWidth = ty.getElementType().getIntOrFloatBitWidth(); sharedEnc = ttg::SharedEncodingAttr::get( ty.getContext(), dotOpEnc, ty.getShape(), - ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth); + ttg::getOrder(ty.getEncoding()), CTALayout, bitWidth, needTrans); } else { // MMAv3 sharedEnc = ttg::SharedEncodingAttr::get(ty.getContext(), ty.getShape(), @@ -669,7 +671,7 @@ void LoopPipeliner::createBufferTypes() { } void LoopPipeliner::createOrderedDeps() { - for (Operation &op : forOp.getLoopBody().front()) { + for (Operation &op : *forOp.getBody()) { if (depOps.contains(&op)) orderedDeps.push_back(&op); else if (op.getNumResults() > 0 && validLoads.contains(op.getResult(0))) @@ -1005,7 +1007,7 @@ SmallVector LoopPipeliner::collectNewLoopArgs() { // We need this to update operands for yield // original block arg => new arg's idx SmallVector newLoopArgs; - for (auto v : forOp.getIterOperands()) + for (auto v : forOp.getInitArgs()) newLoopArgs.push_back(v); bufferIdx = newLoopArgs.size(); @@ -1638,7 +1640,8 @@ void PipelinePass::asyncLaunchDots(scf::ForOp forOp) { auto dotOp = cast(dot.getDefiningOp()); builder.setInsertionPoint(dot.getDefiningOp()); auto dotAsync = builder.create( - loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32()); + loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), dotOp.getAllowTF32(), + dotOp.getMaxNumImpreciseAcc()); dot.replaceAllUsesWith(dotAsync.getResult()); updateConsumerReleaseInfo(dot.getDefiningOp(), dotWait, /*stage=*/1); dot.getDefiningOp()->erase(); diff --git a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp index 07e982dbf..a597ada6c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Prefetch.cpp @@ -269,7 +269,7 @@ scf::ForOp 
Prefetcher::createNewForOp() { OpBuilder builder(forOp); SmallVector loopArgs; - for (auto v : forOp.getIterOperands()) + for (auto v : forOp.getInitArgs()) loopArgs.push_back(v); for (Value dot : dots) { loopArgs.push_back( diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index d3ad7350f..71981b97a 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -136,7 +136,8 @@ public: op->getLoc(), dotOp.getResult().getType(), _0f); auto newDot = rewriter.create( op->getLoc(), dotOp.getResult().getType(), dotOp.getOperand(0), - dotOp.getOperand(1), _0, dotOp.getAllowTF32()); + dotOp.getOperand(1), _0, dotOp.getAllowTF32(), + dotOp.getMaxNumImpreciseAcc()); auto newCvt = rewriter.create( op->getLoc(), dstTy, newDot.getResult()); rewriter.replaceOpWithNewOp(op, newCvt, cvtOp.getOperand()); @@ -235,7 +236,8 @@ static bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { if (convertOp.getResult() .getType() .cast() - .getEncoding() == encoding) + .getEncoding() + .isa()) return true; } auto yield = dyn_cast(op); @@ -298,7 +300,8 @@ static bool hasConvertToMFMATransisitiveUse(Operation *op, Attribute encoding) { static bool isLayoutAnchor(Operation *op) { if (isa(op)) return isExpensiveLoadOrStore(op); - if (isa(op)) + if (isa(op)) return true; return false; } @@ -315,6 +318,7 @@ void LayoutPropagation::initAnchorLayout() { if (tensorType.getEncoding().isa() && !hasConvertToMMATransisitiveUse(op, tensorType.getEncoding())) continue; +<<<<<<< HEAD #ifdef USE_ROCM // Workaround to not propagate MFMA layout in case there are // no chained dots MFMA layout is expensive to convert, so we want @@ -328,6 +332,9 @@ void LayoutPropagation::initAnchorLayout() { continue; #endif layouts.insert({result, tensorType.getEncoding()}); +======= + layouts.insert({result, LayoutInfo(tensorType.getEncoding())}); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } } } @@ -431,14 +438,20 @@ void LayoutPropagation::propagateLayout() { void LayoutPropagation::resolveConflicts() { for (auto &it : layouts) { + Operation *op = it.first.getDefiningOp(); LayoutInfo &info = it.second; if (info.encodings.size() <= 1) continue; // Hacky resolve, prefer block encoding. // TODO: add a proper heuristic. + int maxSizePerThread = 1; Attribute encoding = *info.encodings.begin(); + bool isLoadOrStore = + op && isa(op); for (Attribute e : info.encodings) { - if (e.isa()) { + if ((isLoadOrStore && e.isa()) || + (!isLoadOrStore && e.isa())) { encoding = e; break; } @@ -829,16 +842,15 @@ static scf::ForOp replaceForOpWithNewSignature(OpBuilder &rewriter, // Create a new loop before the existing one, with the extra operands. 
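  // The accessor renames applied throughout this patch (here and in
  // Pipeline.cpp, Prefetch.cpp, Utility.cpp, WSPipeline.cpp, ...) follow the
  // newer MLIR scf::ForOp API:
  //   forOp.getIterOperands()      ->  forOp.getInitArgs()
  //   forOp.getLoopBody()          ->  forOp.getRegion()
  //   forOp.getLoopBody().front()  ->  *forOp.getBody()
  // Minimal sketch, assuming `forOp` is an scf::ForOp:
  //   SmallVector<Value> inits = llvm::to_vector(forOp.getInitArgs());
  //   for (Operation &op : *forOp.getBody()) { /* ... */ }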
rewriter.setInsertionPoint(loop); - auto operands = llvm::to_vector<4>(loop.getIterOperands()); + auto operands = llvm::to_vector<4>(loop.getInitArgs()); operands.append(newIterOperands.begin(), newIterOperands.end()); scf::ForOp newLoop = rewriter.create( loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(), operands); newLoop.getBody()->erase(); - newLoop.getLoopBody().getBlocks().splice( - newLoop.getLoopBody().getBlocks().begin(), - loop.getLoopBody().getBlocks()); + newLoop.getRegion().getBlocks().splice( + newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks()); for (Value operand : newIterOperands) newLoop.getBody()->addArgument(operand.getType(), operand.getLoc()); @@ -873,9 +885,9 @@ static void rewriteSlice(SetVector &slice, for (auto arg : forOp.getRegionIterArgs()) { if (slice.count(arg)) { OpOperand &initVal = forOp.getOpOperandForRegionIterArg(arg); - argMapping.push_back( - std::make_pair(*forOp.getIterArgNumberForOpOperand(initVal), - forOp.getNumIterOperands() + newOperands.size())); + argMapping.push_back(std::make_pair( + forOp.getResultForOpOperand(initVal).getResultNumber(), + forOp.getInitArgs().size() + newOperands.size())); newOperands.push_back(mapping.lookup(initVal.get())); } } @@ -984,7 +996,7 @@ static void backwardRematerialization(ConvertLayoutOp convertOp) { // For convert left we try to hoist them above type extension to reduce the cost // of the convert. -static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { +static void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp) { // we don't want to rematerialize any conversion to/from shared if (triton::gpu::isSharedEncoding(convertOp.getResult()) || triton::gpu::isSharedEncoding(convertOp.getOperand())) @@ -995,9 +1007,15 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { if (targetType.getEncoding().isa()) return; +<<<<<<< HEAD #ifndef USE_ROCM auto isExtOp = [](Operation *op) { return isa(op); +======= + auto isExtOrBroadcastOp = [](Operation *op) { + return isa(op); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 }; #else auto isExtOp = [](Operation *op) { @@ -1008,19 +1026,20 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { // 1. Take a backward slice of all the tensor dependencies. SetVector slice; DenseMap layout; - LogicalResult result = getRematerializableSlice( - convertOp.getOperand(), targetType.getEncoding(), slice, layout, isExtOp); + LogicalResult result = + getRematerializableSlice(convertOp.getOperand(), targetType.getEncoding(), + slice, layout, isExtOrBroadcastOp); if (result.failed()) return; - Operation *extOp = nullptr; + Operation *extOrBroadcatOp = nullptr; unsigned sliceSize = slice.size(); for (unsigned i = 0; i < sliceSize; i++) { Value v = slice[i]; Operation *op = v.getDefiningOp(); if (!op) continue; - if (isExtOp(op)) { + if (isExtOrBroadcastOp(op)) { SetVector tempSlice; DenseMap tempLayout; std::optional srcEncoding = inferSrcEncoding(op, layout[v]); @@ -1037,26 +1056,38 @@ static void hoistConvertOnTopOfExt(ConvertLayoutOp convertOp) { } // Only apply it if there is a single ext op otherwise we would have to // duplicate the convert. 
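  // Why hoist: converting before the ext/broadcast means the layout conversion
  // runs on the narrower (pre-extension) element type or the smaller
  // (pre-broadcast) shape, so it moves less data. Illustrative IR, with
  // placeholder layouts #L1 and #L2:
  //   %e = arith.extf %x : tensor<128x64xf16, #L1> to tensor<128x64xf32, #L1>
  //   %c = convert_layout %e : #L1 -> #L2
  // becomes
  //   %x2 = convert_layout %x : #L1 -> #L2   // converts f16, half the bytes
  //   %e  = arith.extf %x2 ...               // slice rewritten in #L2
  // The single-op restriction enforced just below avoids duplicating the convert.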
- if (extOp != nullptr) + if (extOrBroadcatOp != nullptr) return; - extOp = op; + extOrBroadcatOp = op; } } - if (extOp == nullptr) + if (extOrBroadcatOp == nullptr) + return; + std::optional srcEncoding = + inferSrcEncoding(extOrBroadcatOp, layout[extOrBroadcatOp->getResult(0)]); + if (!srcEncoding) return; std::optional srcEncoding = inferSrcEncoding(extOp, layout[extOp->getResult(0)]); // Move the convert before the ext op and rewrite the slice. +<<<<<<< HEAD OpBuilder builder(extOp); auto tensorType = extOp->getOperand(0).getType().cast(); auto newType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), *srcEncoding); +======= + OpBuilder builder(extOrBroadcatOp); + auto tensorType = + extOrBroadcatOp->getOperand(0).getType().cast(); + auto newType = RankedTensorType::get( + tensorType.getShape(), tensorType.getElementType(), *srcEncoding); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 auto newConvertOp = builder.create( - convertOp.getLoc(), newType, extOp->getOperand(0)); + convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); IRMapping mapping; - mapping.map(extOp->getOperand(0), newConvertOp.getResult()); + mapping.map(extOrBroadcatOp->getOperand(0), newConvertOp.getResult()); // 3. Rewrite the slice. rewriteSlice(slice, layout, convertOp, mapping); } @@ -1075,7 +1106,7 @@ static void hoistConvert(ModuleOp module) { module.walk( [&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); }); for (ConvertLayoutOp convertOp : convertOps) { - hoistConvertOnTopOfExt(convertOp); + hoistConvertOnTopOfExtOrBroadcast(convertOp); } } diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 6e92fb290..f315fe5ad 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -492,30 +492,6 @@ Value linearize(OpBuilder &b, Location loc, ArrayRef multiDim, return linear; } -std::optional getWSAgentId(Operation *op) { - int prevAgentId = -1; - if (auto attr = op->getAttrOfType("async_agent")) { - for (auto agentId : attr.getValues()) { - assert(prevAgentId == -1 && "support at most one agent id"); - prevAgentId = agentId; - } - } - if (prevAgentId == -1) - return std::nullopt; - return prevAgentId; -} - -std::optional getWSRoleId(Operation *op) { - if (!op->hasAttr("agent.mutex_role")) - return std::nullopt; - return op->getAttrOfType("agent.mutex_role").getInt(); -} - -void setRoleId(Operation *op, int roleId) { - auto attr = IntegerAttr::get(IntegerType::get(op->getContext(), 32), roleId); - op->setAttr("agent.mutex_role", attr); -} - namespace { /// Detect dead arguments in scf.for op by assuming all the values are dead and @@ -597,7 +573,7 @@ struct ForOpDeadArgElimination : public OpRewritePattern { Value yieldOperand = forOwner.getBody()->getTerminator()->getOperand(iterIdx); markLive(yieldOperand); - markLive(forOwner.getIterOperands()[iterIdx]); + markLive(forOwner.getInitArgs()[iterIdx]); } } SmallVector deadArg; diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp index 0a982ce05..c7985a927 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp @@ -67,3 +67,27 @@ TritonNvidiaGPUDialect::verifyOperationAttribute(Operation *op, // TODO: fill this. 
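  // (The warp-specialization helpers added below, getWSAgentId / getWSRoleId /
  //  setRoleId, are the ones removed from TritonGPU/Transforms/Utility.cpp in
  //  the hunk above; only their home changes.)
  // Hypothetical usage, assuming the stripped template argument is `int`:
  //   if (std::optional<int> agent = mlir::getWSAgentId(op)) {
  //     mlir::setRoleId(op, /*roleId=*/0);
  //     std::optional<int> role = mlir::getWSRoleId(op);  // now 0
  //   }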
return success(); } + +std::optional mlir::getWSAgentId(Operation *op) { + int prevAgentId = -1; + if (auto attr = op->getAttrOfType("async_agent")) { + for (auto agentId : attr.getValues()) { + assert(prevAgentId == -1 && "support at most one agent id"); + prevAgentId = agentId; + } + } + if (prevAgentId == -1) + return std::nullopt; + return prevAgentId; +} + +std::optional mlir::getWSRoleId(Operation *op) { + if (!op->hasAttr("agent.mutex_role")) + return std::nullopt; + return op->getAttrOfType("agent.mutex_role").getInt(); +} + +void mlir::setRoleId(Operation *op, int roleId) { + auto attr = IntegerAttr::get(IntegerType::get(op->getContext(), 32), roleId); + op->setAttr("agent.mutex_role", attr); +} diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp index 3175fbbfb..25a43529d 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp @@ -3,6 +3,7 @@ #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/Debug.h" //===----------------------------------------------------------------------===// @@ -32,16 +33,19 @@ public: FenceInsertionPass(int computeCapability) { this->computeCapability = computeCapability; } - // TODO: support more patterns to insert fences - // only support insertion between convert layout ops and dot ops to protect - // flashattention + // TODO: support more general patterns to insert fences. eg. any op(generic) + // to shared in use-def chain which refers by async proxy. We have generic( + // convertlayout with sts/stmatix) + fence + async(wgmma) up to now void runOnOperation() override { // Only insert fences for compute capability 9.0 if (computeCapability < 90) return; + if (::triton::tools::getBoolEnv("DISABLE_MMA_V3")) + return; ModuleOp mod = getOperation(); mod.walk([&](Operation *op) { - if (isa(op)) { + if (isa(op)) { + OpBuilder builder(op); auto a = op->getOperand(0); auto b = op->getOperand(1); auto mmaEncoding = op->getResult(0) @@ -50,21 +54,70 @@ public: .getEncoding() .dyn_cast(); auto isHopperEncoding = mmaEncoding && mmaEncoding.isHopper(); - if (isHopperEncoding && (isa(a.getDefiningOp()) && - ttg::isSharedEncoding(a)) || - (isa(b.getDefiningOp()) && - ttg::isSharedEncoding(b))) { - - // TODO: check whether cluster fence is needed - OpBuilder builder(op); + if (isHopperEncoding && + (dependOnSharedEncOperand(a) || dependOnSharedEncOperand(b))) { builder.create(op->getLoc(), false /*bCluster*/); } } }); } -}; +private: + bool dependOnSharedEncOperand(Value operand) { + static DenseSet> trace; + auto op = operand.getDefiningOp(); + // avoid redundant insertion + if (op && isa(op)) + return false; + // reach convertlayout + if (op && isa(op) && ttg::isSharedEncoding(operand)) + return true; + // root and not BlockArgument + if (!op && !isa(operand)) + return false; + // op and not BlockArgument + if (op && !isa(operand)) { + for (auto v : op->getOperands()) { + if (dependOnSharedEncOperand(v)) + return true; + } + } + // reach BlockArgument + // TODO: support other scf ops, IfOp, WhileOp, etc. 
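  // Index bookkeeping for the scf.for case below: block argument 0 of the loop
  // body is the induction variable, so the iter_arg with argNumber N maps to
  // init operand N-1 and to yield operand N-1. The `trace` set records visited
  // (yield, argNumber) pairs so that recursing through the yield operand cannot
  // loop forever on cyclic iter_arg dependencies.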
+ if (BlockArgument arg = dyn_cast(operand)) { + unsigned argNum = arg.getArgNumber(); + Operation *argOwner = arg.getOwner()->getParentOp(); + // suport ForOp only + if (auto forOp = dyn_cast(argOwner)) { + // prologue + auto iterOperands = forOp.getIterOperands(); + if (argNum == 0) + return false; + if (dependOnSharedEncOperand(iterOperands[argNum - 1])) + return true; + // yield + auto yieldOp = forOp.getBody()->getTerminator(); + Value v = yieldOp->getOperand(argNum - 1); + auto entry = std::make_pair(std::move(yieldOp), + std::move(argNum)); + // avoid cyclic + if (trace.contains(entry)) + return false; + else + trace.insert(entry); + + if (dependOnSharedEncOperand(v)) + return true; + } else if (auto whileOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported WhileOp"); + } else if (auto ifOp = dyn_cast(argOwner)) { + assert(false && "FenceInsertionPass does not supported IfOp"); + } + } + return false; + } +}; } // namespace std::unique_ptr diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp index e13cf8bd9..5a2d3beaa 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/RewriteTensorPointer.cpp @@ -523,9 +523,9 @@ public: std::stack &eraser, DenseSet &valueToRemove) { // Generate new iteration operands and set rewrited information - SmallVector oldIterOperands = op.getIterOperands(); - SmallVector newIterOperands = op.getIterOperands(); - for (unsigned i = 0, oldI = 0, size = op.getNumIterOperands(); i < size; + SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); + SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); + for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; ++i, ++oldI) { if (!tt::isTensorPointerType(newIterOperands[i].getType())) continue; @@ -550,7 +550,7 @@ public: // mapping. It may refer to a value in the old loop, but we will rewrite it // later IRMapping mapping; - for (unsigned i = 0, oldI = 0; oldI < op.getNumIterOperands(); + for (unsigned i = 0, oldI = 0; oldI < op.getInitArgs().size(); ++i, ++oldI) { auto oldRegionIterArg = op.getRegionIterArg(oldI); if (tt::isTensorPointerType(oldRegionIterArg.getType()) && @@ -586,7 +586,7 @@ public: valueToRemove.insert(v); // Replace later usages - assert(op.getNumResults() == op.getNumIterOperands()); + assert(op.getNumResults() == op.getInitArgs().size()); for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { auto oldResult = op.getResult(oldI); if (tt::isTensorPointerType(oldResult.getType()) && @@ -787,8 +787,8 @@ public: } } if (auto forOp = dyn_cast(op)) { - SmallVector iterOperands = forOp.getIterOperands(); - for (unsigned i = 0, size = forOp.getNumIterOperands(); i < size; ++i) { + SmallVector iterOperands = llvm::to_vector(forOp.getInitArgs()); + for (unsigned i = 0, size = forOp.getInitArgs().size(); i < size; ++i) { if (tt::isTensorPointerType(iterOperands[i].getType())) { auto makeTensorPtrOp = getMakeTensorPtrOp(iterOperands[i]); if (shouldRemove(makeTensorPtrOp, computeCapability)) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp index a7bab9ff1..dd39e94b8 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSMutex.cpp @@ -18,7 +18,9 @@ namespace ttng = triton::nvidia_gpu; namespace { // Target operations: dot, load, store. 
Add more when necessary. -#define KEY_TYPES triton::DotOp, ttg::InsertSliceOp, triton::StoreOp +#define KEY_TYPES \ + triton::DotOp, triton::nvidia_gpu::DotAsyncOp, ttg::InsertSliceOp, \ + triton::StoreOp template void getKeyTypeId(Operation *op, int &id, bool &found) { @@ -151,7 +153,9 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp, Value newIdx = builder.createWithAgentIds(loc, pipelineIdx, curRoleId); - persistentForOp.setIterArg(persistentForOp.getNumIterOperands() - 1, newIdx); + persistentForOp.getInitArgsMutable() + .slice(persistentForOp.getInitArgs().size() - 1, 1) + .assign(newIdx); auto yield = llvm::cast(persistentForOp.getBody()->getTerminator()); auto idxPlusOneOp = @@ -209,6 +213,45 @@ void mutexSync(ModuleOp &mod, scf::IfOp &ifOp, scf::ForOp &persistentForOp, unlockLocs[i] = op; } + // Update unlockLocs + // ====================== IR after async launch dots ====================== + // * %0:2 = scf.for %arg0 = %c0 to %1 step %c1 iter_args(%arg1 = %2, arg2 = + // %3) { + // * triton_nvidia_gpu.producer_wait arg2 + // * %5 = triton_nvidia_gpu.dot_async %4, %5 + // * triton_nvidia_gpu.dot_wait {pendings = 1} + // * %6 = arith.cmpi sgt, arg0, %c0 + // * scf.if %6 { + // * %7 = arith.subi arg2, c1 + // * triton_nvidia_gpu.consumer_release %7 + // * } + // * %8 = arith.addi arg2, c1 + // * scf.yield %5, %8 + // * } + // * triton_nvidia_gpu.dot_wait {pendings = 0} + // * %9 = arith.subi %0#1, c1 + // * triton_nvidia_gpu.consumer_release %9 + // * ======================================================================= + // after async launch dots, there will be outstanding consumerReleaseOp after + // ForOp. we should expend the unlockLocs from ForOp to the outstanding + // consumerReleaseOp. + for (int i = 0; i < numRoles; ++i) { + Operation *unlockOp = unlockLocs[i]; + auto filter = [&](Operation *op) { + return op->getBlock() == unlockOp->getBlock(); + }; + if (isa(unlockOp)) { + SetVector slices; + mlir::getForwardSlice(unlockOp->getResults().back(), &slices, {filter}); + auto iter = llvm::find_if(slices, [](Operation *op) { + return isa(op); + }); + if (iter != slices.end()) { + unlockLocs[i] = *iter; + } + } + } + // Only cases where all lock/unlock locations are in same level make sense. 
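  // The same-block filter in the forward slice above is what limits the search
  // to the consumer_release emitted right after the ForOp for the outstanding
  // dot_async (see the IR sketch above); releases nested in other regions are
  // not considered.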
for (int i = 1; i < numRoles; ++i) { if (lockLocs[i]->getParentOp() != lockLocs[i - 1]->getParentOp() || diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp index 459eff719..5d6417fab 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/WSPipeline.cpp @@ -162,7 +162,7 @@ scf::ForOp appendPipelineIdxToLoopArgs(scf::ForOp forOp, int numStages, // Copy iter operands of forOp SmallVector newLoopArgs; - for (auto operand : forOp.getIterOperands()) + for (auto operand : llvm::to_vector(forOp.getInitArgs())) newLoopArgs.push_back(operand); // Append initial value of pipelineIdx to newLoopArgs @@ -302,7 +302,7 @@ DenseMap createForOpsForEachAgentId(scf::ForOp forOp) { // Prepare newLoopArgs SmallVector newLoopArgs; for (unsigned argNumber : usedArgs) - newLoopArgs.push_back(forOp.getIterOperands()[argNumber]); + newLoopArgs.push_back(forOp.getInitArgs()[argNumber]); // Create newForOp builder.setAgentIdsFromArray({agentId}); @@ -341,7 +341,9 @@ DenseMap createForOpsForEachAgentId(scf::ForOp forOp) { for (unsigned i = 0; i < usedArgs.size(); ++i) { auto oldResult = forOp.getResult(usedArgs[i]); auto newResult = newForOp.getResult(i); - oldResult.replaceAllUsesWith(newResult); + oldResult.replaceUsesWithIf(newResult, [&](OpOperand &operand) -> bool { + return hasAgentId(operand.getOwner(), agentId); + }); } agentsToForOp[agentId] = newForOp; @@ -642,6 +644,30 @@ void buildAsyncComm(const DenseMap> &map, agentsPC.insert(agentsPC.end(), agentP.begin(), agentP.end()); agentsPC.insert(agentsPC.end(), agentC.begin(), agentC.end()); }; + + // Don't pipeline dots that depend on ops other than scf.yield and scf.for. + // Because the DotOp will be replaced by a DotAsyncOp, which will be issued in + // iter_i but waited in iter_i+1. The use of DotAsyncOp should not be ops + // other than scf.for and scf.yield because the result of DotAsyncOp is not + // ready in iter_i. + auto getValidDot = [&](const SmallVector &block) -> Operation * { + Operation *headConsumer = block.front()->dstOp; + if (block.size() == 2 && + isa(*headConsumer->getUsers().begin()) && + headConsumer->getParentOfType()) { + auto dotOp = cast(*headConsumer->getUsers().begin()); + auto dot = dotOp.getResult(); + auto resTy = dot.getType().dyn_cast(); + auto cArg = dotOp.getOperand(2).dyn_cast(); + if (auto resEnc = resTy.getEncoding().dyn_cast()) + if (resEnc.isHopper() && dot.hasOneUse() && + isa(*dot.getUsers().begin()) && cArg && + cArg.hasOneUse()) + return dotOp.getOperation(); + } + return nullptr; + }; + // TODO: try to optimize locations of arriving and waiting token // for fused-attention for (auto kv : map) { @@ -694,12 +720,69 @@ void buildAsyncComm(const DenseMap> &map, builder.createWithAgentIds(headConsumer->getLoc(), token, pipelineIdx); - // insert ConsumerReleaseOp - auto consumerReleasePoint = - consumerReleaseHeutistic(tailProducer, tailConsumer); - builder.setInsertionPointAfter(consumerReleasePoint); - builder.createWithAgentIds( - consumerReleasePoint->getLoc(), token, pipelineIdx); + /// async launch dots + if (auto cvg = getValidDot(kv.second)) { + auto dotOp = cast(cvg); + auto dot = dotOp.getResult(); + auto loc = dot.getLoc(); + auto forOp = cvg->getParentOfType(); + + auto agentIds = collectAgentIds(dotOp); + OpBuilderWithAgentIds builder(dotOp.getContext()); + builder.setAgentIdsFromArray(agentIds); + builder.setInsertionPoint(dotOp); + + // 0. 
replace Dot with DotAsync + auto dotAsync = + builder.createWithAgentIds( + loc, dotOp.getA(), dotOp.getB(), dotOp.getC(), + dotOp.getAllowTF32(), dotOp.getMaxNumImpreciseAcc()); + dot.replaceAllUsesWith(dotAsync.getResult()); + builder.createWithAgentIds(loc, 1); + + // 1. insert ConsumerReleaseOp for DotAsyncOps + Value cond = builder.createWithAgentIds( + loc, arith::CmpIPredicate::sgt, forOp.getInductionVar(), + forOp.getLowerBound()); + auto ifOp = + builder.createWithAgentIds(loc, ArrayRef{}, cond, + /*hasElse*/ false); + builder.setInsertionPointToStart(ifOp.thenBlock()); + Value one = builder.createWithAgentIds( + headConsumer->getLoc(), 1, 32); + auto oriIdx = forOp.getBody()->getArguments().back(); + Value consumerReleaseIdx = + builder.createWithAgentIds(loc, oriIdx, one); + consumerReleaseIdx = builder.createWithAgentIds( + loc, consumerReleaseIdx, numStagesVal); + builder.createWithAgentIds(loc, token, + consumerReleaseIdx); + setAgentIds(ifOp.thenYield().getOperation(), agentIds); + + // 2. If there's any outstanding DotAsyncOps, we need to wait for them. + builder.setInsertionPointAfter(forOp); + builder.createWithAgentIds(forOp.getLoc(), + 0); + + // 3. insert ConsumerReleaseOp for outstanding DotAsyncOps + Value one_ = builder.createWithAgentIds( + headConsumer->getLoc(), 1, 32); + consumerReleaseIdx = forOp.getResults().back(); + consumerReleaseIdx = builder.createWithAgentIds( + loc, consumerReleaseIdx, one_); + consumerReleaseIdx = builder.createWithAgentIds( + loc, consumerReleaseIdx, numStagesVal); + builder.createWithAgentIds(loc, token, + consumerReleaseIdx); + dotOp->erase(); + } else { + // insert ConsumerReleaseOp + auto consumerReleasePoint = + consumerReleaseHeutistic(tailProducer, tailConsumer); + builder.setInsertionPointAfter(consumerReleasePoint); + builder.createWithAgentIds( + consumerReleasePoint->getLoc(), token, pipelineIdx); + } /*****************Buffer related*****************/ /// splitLoadsInForLoop diff --git a/lib/Target/CMakeLists.txt b/lib/Target/CMakeLists.txt index 99cf364fa..9b24f0ff2 100644 --- a/lib/Target/CMakeLists.txt +++ b/lib/Target/CMakeLists.txt @@ -1,3 +1,2 @@ add_subdirectory(LLVMIR) add_subdirectory(PTX) -add_subdirectory(HSACO) diff --git a/lib/Target/LLVMIR/CMakeLists.txt b/lib/Target/LLVMIR/CMakeLists.txt index fbaefe683..9c0a6c26e 100644 --- a/lib/Target/LLVMIR/CMakeLists.txt +++ b/lib/Target/LLVMIR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_translation_library(TritonLLVMIR LLVMIRTranslation.cpp LLVMDIScope.cpp + LLVMIRBreakPhiStruct.cpp LINK_COMPONENTS Core diff --git a/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp b/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp new file mode 100644 index 000000000..44afcfd21 --- /dev/null +++ b/lib/Target/LLVMIR/LLVMIRBreakPhiStruct.cpp @@ -0,0 +1,60 @@ +//===----------------------------------------------------------------------===// +/// Implements a trivial pass breaking up 1 level deep structure in phi nodes. +/// This handles the common case generated by Triton and allow better +/// optimizations down the compiler pipeline. 
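/// Illustrative before/after (names and types are made up):
///   %s = phi { float, float } [ %a, %bb0 ], [ %b, %bb1 ]
/// becomes, with the extractvalues emitted before each predecessor's terminator:
///   %s0 = phi float [ %a0, %bb0 ], [ %b0, %bb1 ]
///   %s1 = phi float [ %a1, %bb0 ], [ %b1, %bb1 ]
///   %s  = insertvalue { float, float } undef, %s0, 0   ; then %s1 at index 1
/// and all uses of the original struct phi are redirected to the rebuilt value.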
+//===----------------------------------------------------------------------===// +#include "LLVMPasses.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" + +using namespace llvm; + +static bool processPhiStruct(PHINode *phiNode) { + StructType *STy = dyn_cast(phiNode->getType()); + if (!STy) + return false; + IRBuilder<> builder(phiNode); + unsigned numOperands = phiNode->getNumIncomingValues(); + unsigned numScalarEl = STy->getNumElements(); + Value *newStruct = UndefValue::get(STy); + builder.SetInsertPoint(phiNode->getParent()->getFirstNonPHI()); + llvm::IRBuilderBase::InsertPoint insertInsertPt = builder.saveIP(); + for (unsigned i = 0; i < numScalarEl; i++) { + builder.SetInsertPoint(phiNode); + PHINode *newPhiNode = + builder.CreatePHI(STy->getElementType(i), numOperands); + for (unsigned j = 0; j < numOperands; ++j) { + Value *operand = phiNode->getIncomingValue(j); + builder.SetInsertPoint(phiNode->getIncomingBlock(j)->getTerminator()); + newPhiNode->addIncoming(builder.CreateExtractValue(operand, i), + phiNode->getIncomingBlock(j)); + } + builder.restoreIP(insertInsertPt); + newStruct = builder.CreateInsertValue(newStruct, newPhiNode, i); + insertInsertPt = builder.saveIP(); + } + phiNode->replaceAllUsesWith(newStruct); + return true; +} + +static bool runOnFunction(Function &F) { + bool Changed = false; + SmallVector PhiNodes; + for (BasicBlock &BB : F) { + for (Instruction &inst : BB) { + if (PHINode *phiNode = dyn_cast(&inst)) { + Changed |= processPhiStruct(phiNode); + continue; + } + break; + } + } + return Changed; +} + +PreservedAnalyses BreakStructPhiNodesPass::run(Function &F, + FunctionAnalysisManager &AM) { + + bool b = runOnFunction(F); + return b ? PreservedAnalyses::none() : PreservedAnalyses::all(); +} diff --git a/lib/Target/LLVMIR/LLVMIRTranslation.cpp b/lib/Target/LLVMIR/LLVMIRTranslation.cpp index 45f0fedaf..077e1a3af 100644 --- a/lib/Target/LLVMIR/LLVMIRTranslation.cpp +++ b/lib/Target/LLVMIR/LLVMIRTranslation.cpp @@ -1,8 +1,13 @@ #include "triton/Target/LLVMIR/LLVMIRTranslation.h" +<<<<<<< HEAD #include "mlir/Conversion/Passes.h" +======= +#include "LLVMPasses.h" +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" #include "mlir/Conversion/IndexToLLVM/IndexToLLVM.h" +#include "mlir/Conversion/Passes.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/Passes.h" @@ -28,13 +33,25 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Module.h" #include "llvm/IRReader/IRReader.h" #include "llvm/Linker/Linker.h" +#include "llvm/Passes/OptimizationLevel.h" +#include "llvm/Passes/PassBuilder.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/FormatVariadic.h" #include "llvm/Support/SourceMgr.h" +<<<<<<< HEAD #include +======= +#include "llvm/Target/TargetMachine.h" +#include "llvm/Transforms/InstCombine/InstCombine.h" +#include +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 #ifdef _WIN32 #define WIN32_LEAN_AND_MEAN #include @@ -46,6 +63,91 @@ namespace fs = std::filesystem; +namespace { +using namespace llvm; + +static std::optional mapToLevel(unsigned optLevel, + unsigned sizeLevel) { + switch (optLevel) { + case 0: + return OptimizationLevel::O0; + + case 1: + return OptimizationLevel::O1; + + case 2: + switch 
(sizeLevel) { + case 0: + return OptimizationLevel::O2; + + case 1: + return OptimizationLevel::Os; + + case 2: + return OptimizationLevel::Oz; + } + break; + case 3: + return OptimizationLevel::O3; + } + return std::nullopt; +} + +// Create and return a lambda that uses LLVM pass manager builder to set up +// optimizations based on the given level. +static std::function +makeOptimizingPipeline(unsigned optLevel, unsigned sizeLevel, + TargetMachine *targetMachine) { + return [optLevel, sizeLevel, targetMachine](Module *m) -> Error { + std::optional ol = mapToLevel(optLevel, sizeLevel); + if (!ol) { + return make_error( + formatv("invalid optimization/size level {0}/{1}", optLevel, + sizeLevel) + .str(), + inconvertibleErrorCode()); + } + LoopAnalysisManager lam; + FunctionAnalysisManager fam; + CGSCCAnalysisManager cgam; + ModuleAnalysisManager mam; + + PipelineTuningOptions tuningOptions; + tuningOptions.LoopUnrolling = true; + tuningOptions.LoopInterleaving = true; + tuningOptions.LoopVectorization = true; + // TODO: currently we run SLP vectorizer with an empty target machine. This + // cause the vectorizer to create larger vector which could be bad. + // Disabling it would currently cause regressions as this pass also applies + // some scheduling that helps performance in some cases. We should work on + // using NVPTX target instead and address the performance regressions with + // some scheduling solution. + tuningOptions.SLPVectorization = true; + + PassBuilder pb(targetMachine, tuningOptions); + + pb.registerModuleAnalyses(mam); + pb.registerCGSCCAnalyses(cgam); + pb.registerFunctionAnalyses(fam); + pb.registerLoopAnalyses(lam); + pb.crossRegisterProxies(lam, fam, cgam, mam); + + ModulePassManager mpm; + pb.registerVectorizerStartEPCallback( + [&](llvm::FunctionPassManager &fpm, llvm::OptimizationLevel level) { + // Triton generates large structure of scalars which may pessimise + // optimizations, we run a pass to break up phi of struct to make sure + // all the struct are removed for the following passes. + fpm.addPass(BreakStructPhiNodesPass()); + fpm.addPass(InstCombinePass()); + }); + mpm.addPass(pb.buildPerModuleDefaultPipeline(*ol)); + mpm.run(*m, mam); + return Error::success(); + }; +} +} // namespace + namespace mlir { namespace triton { @@ -318,7 +420,7 @@ translateLLVMToLLVMIR(llvm::LLVMContext *llvmContext, mlir::ModuleOp module, return nullptr; } - auto optPipeline = mlir::makeOptimizingTransformer( + auto optPipeline = makeOptimizingPipeline( /*optLevel=*/3, /*sizeLevel=*/0, /*targetMachine=*/nullptr); @@ -367,8 +469,12 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext, pm.addPass(mlir::createConvertSCFToCFPass()); pm.addPass(mlir::createConvertIndexToLLVMPass()); pm.addPass( +<<<<<<< HEAD createConvertTritonGPUToLLVMPass({computeCapability, &tmaInfos, target})); #ifndef USE_ROCM +======= + createConvertTritonGPUToLLVMPass(computeCapability, target, &tmaInfos)); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.addPass(createConvertNVGPUToLLVMPass()); #endif pm.addPass(mlir::createArithToLLVMConversionPass()); diff --git a/lib/Target/LLVMIR/LLVMPasses.h b/lib/Target/LLVMIR/LLVMPasses.h new file mode 100644 index 000000000..1dcdb2992 --- /dev/null +++ b/lib/Target/LLVMIR/LLVMPasses.h @@ -0,0 +1,16 @@ +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/CodeGen.h" + +namespace llvm { + +// Pass to pre-process LLVM IR before optimization and break up phi of struct. 
+// Breaking up those phis into elementary types allows better optimizations +// downstream. +struct BreakStructPhiNodesPass : PassInfoMixin { + PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM); + + static StringRef name() { return "BreakStructPhiNodesPass"; } +}; + +} // namespace llvm diff --git a/lib/Target/PTX/PTXTranslation.cpp b/lib/Target/PTX/PTXTranslation.cpp index fe8841997..3ae1bac1a 100644 --- a/lib/Target/PTX/PTXTranslation.cpp +++ b/lib/Target/PTX/PTXTranslation.cpp @@ -16,6 +16,7 @@ #include #include +#include #include #include @@ -87,9 +88,9 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) { opt.UnsafeFPMath = false; opt.NoInfsFPMath = false; opt.NoNaNsFPMath = true; - llvm::TargetMachine *machine = target->createTargetMachine( + std::unique_ptr machine{target->createTargetMachine( module.getTargetTriple(), proc, features, opt, llvm::Reloc::PIC_, - std::nullopt, llvm::CodeGenOpt::Aggressive); + std::nullopt, llvm::CodeGenOpt::Aggressive)}; // set data layout if (layout.empty()) module.setDataLayout(machine->createDataLayout()); diff --git a/python/setup.py b/python/setup.py index ff0d19930..f420e7f18 100644 --- a/python/setup.py +++ b/python/setup.py @@ -8,6 +8,7 @@ import sysconfig import tarfile import tempfile import urllib.request +from distutils.command.clean import clean from pathlib import Path from typing import NamedTuple @@ -58,8 +59,8 @@ class Package(NamedTuple): def get_pybind11_package_info(): - name = "pybind11-2.10.0" - url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz" + name = "pybind11-2.11.1" + url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz" return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH") # llvm @@ -124,15 +125,12 @@ def get_thirdparty_packages(triton_cache_path): # ---- package data --- -def download_and_copy_ptxas(): - +def download_and_copy(src_path, version, url_func): base_dir = os.path.dirname(__file__) - src_path = "bin/ptxas" - version = "12.1.105" arch = platform.machine() if arch == "x86_64": arch = "64" - url = f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2" + url = url_func(arch, version) dst_prefix = os.path.join(base_dir, "triton") dst_suffix = os.path.join("third_party", "cuda", src_path) dst_path = os.path.join(dst_prefix, dst_suffix) @@ -155,9 +153,28 @@ def download_and_copy_ptxas(): shutil.copy(src_path, dst_path) return dst_suffix - # ---- cmake extension ---- + +def get_base_dir(): + return os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) + + +def get_cmake_dir(): + plat_name = sysconfig.get_platform() + python_version = sysconfig.get_python_version() + dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" + cmake_dir = Path(get_base_dir()) / "python" / "build" / dir_name + cmake_dir.mkdir(parents=True, exist_ok=True) + return cmake_dir + + +class CMakeClean(clean): + def initialize_options(self): + clean.initialize_options(self) + self.build_temp = get_cmake_dir() + + class CMakeBuildPy(build_py): def run(self) -> None: self.run_command('build_ext') @@ -178,10 +195,7 @@ class CMakeBuild(build_ext): def initialize_options(self): build_ext.initialize_options(self) - self.base_dir = os.path.abspath( - os.path.join( - os.path.dirname(__file__), - os.pardir)) + self.base_dir = get_base_dir() def finalize_options(self): build_ext.finalize_options(self) @@ -200,14 +214,6 @@ class CMakeBuild(build_ext): for ext in 
self.extensions: self.build_extension(ext) - def get_cmake_dir(self): - plat_name = sysconfig.get_platform() - python_version = sysconfig.get_python_version() - dir_name = f"cmake.{plat_name}-{sys.implementation.name}-{python_version}" - cmake_dir = Path(self.base_dir) / "python" / "build" / dir_name - cmake_dir.mkdir(parents=True, exist_ok=True) - return cmake_dir - def build_extension(self, ext): lit_dir = shutil.which('lit') ninja_dir = shutil.which('ninja') @@ -267,14 +273,21 @@ class CMakeBuild(build_ext): "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld", "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"] + if check_env_flag("TRITON_BUILD_WITH_CCACHE"): + cmake_args += [ + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + ] + env = os.environ.copy() - cmake_dir = self.get_cmake_dir() + cmake_dir = get_cmake_dir() subprocess.check_call(["cmake", self.base_dir] + cmake_args, cwd=cmake_dir, env=env) subprocess.check_call(["cmake", "--build", "."] + build_args, cwd=cmake_dir) + subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -download_and_copy_ptxas() - +download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2") +download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2") +download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2") setup( name="triton", @@ -288,7 +301,6 @@ setup( "triton/_C", "triton/common", "triton/compiler", - "triton/interpreter", "triton/language", "triton/language/extra", "triton/ops", @@ -304,7 +316,7 @@ setup( ], include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], - cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy}, + cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean}, zip_safe=False, # for PyPI keywords=["Compiler", "Deep Learning"], diff --git a/python/src/triton.cc b/python/src/triton.cc index 3e43e52dd..b05202a68 100644 --- a/python/src/triton.cc +++ b/python/src/triton.cc @@ -33,7 +33,6 @@ #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" -#include "triton/Target/HSACO/HSACOTranslation.h" #include "triton/Target/LLVMIR/LLVMIRTranslation.h" #include "triton/Target/PTX/PTXTranslation.h" #include "triton/Target/HSACO/HSACOTranslation.h" @@ -66,6 +65,7 @@ #include #include +#include namespace py = pybind11; PYBIND11_MAKE_OPAQUE(mlir::triton::gpu::TMAMetadataTy); @@ -1592,9 +1592,10 @@ void init_triton_ir(py::module &&m) { }) .def("create_dot", [](TritonOpBuilder &self, mlir::Value &a, mlir::Value &b, - mlir::Value &c, bool allowTF32) -> mlir::Value { - return self.create(c.getType(), a, b, c, - allowTF32); + mlir::Value &c, bool allowTF32, + int maxNumImpreciseAcc) -> mlir::Value { + return self.create( + c.getType(), a, b, c, allowTF32, maxNumImpreciseAcc); }) .def("create_exp", [](TritonOpBuilder &self, mlir::Value &val) -> mlir::Value { @@ -2095,7 +2096,9 @@ void init_triton_translation(py::module &m) { const std::vector &paths) { ::mlir::triton::addExternalLibs(op, names, paths); }); +} +<<<<<<< HEAD m.def( "translate_llvmir_to_hsaco", [](const 
std::string llvmIR, std::string gfx_arch, std::string gfx_triple, @@ -2113,6 +2116,46 @@ void init_triton_translation(py::module &m) { return hsacoCode; }, ret::take_ownership); +======= +void init_triton_interpreter(py::module &&m) { + using ret = py::return_value_policy; + + m.def("load", + [](py::array_t ptrs, py::array_t masks, py::array other, + py::dtype ret_dtype) -> py::array { + int numel = ptrs.size(); + auto shape = + std::vector(ptrs.shape(), ptrs.shape() + ptrs.ndim()); + py::array ret(ret_dtype, py::array::ShapeContainer{numel}); + py::array_t reshaped_ptrs = ptrs.reshape({numel}); + py::array_t reshaped_masks = masks.reshape({numel}); + py::array reshaped_others = other.reshape({numel}); + for (size_t i = 0; i < ptrs.size(); ++i) { + if (reshaped_masks.at(i)) + memcpy(ret.mutable_data(i), + reinterpret_cast(reshaped_ptrs.at(i)), + ret_dtype.itemsize()); + else + memcpy(ret.mutable_data(i), reshaped_others.data(i), + ret_dtype.itemsize()); + } + return ret.reshape(shape); + }); + + m.def("store", [](py::array_t ptrs, py::array values, + py::array_t mask) { + int numel = ptrs.size(); + py::array_t reshaped_ptrs = ptrs.reshape({numel}); + py::array_t reshaped_masks = mask.reshape({numel}); + py::array reshaped_values = values.reshape({numel}); + for (size_t i = 0; i < ptrs.size(); ++i) { + if (reshaped_masks.at(i)) { + memcpy(reinterpret_cast(reshaped_ptrs.mutable_at(i)), + reshaped_values.data(i), values.dtype().itemsize()); + } + } + }); +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 } void init_triton(py::module &m) { @@ -2121,5 +2164,6 @@ void init_triton(py::module &m) { // init_triton_codegen(subm.def_submodule("code_gen")); init_triton_runtime(subm.def_submodule("runtime")); init_triton_ir(subm.def_submodule("ir")); + init_triton_interpreter(subm.def_submodule("interpreter")); init_triton_translation(subm); } diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py index 684cbfb4d..b873db7a3 100644 --- a/python/test/regression/test_functional_regressions.py +++ b/python/test/regression/test_functional_regressions.py @@ -227,4 +227,4 @@ def test_iv_dependent_matmul(type): b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1), BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, num_stages=num_stages) - torch.testing.assert_allclose(torch_output, triton_output, rtol=1e-2, atol=1e-2) + torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index b22fea3e5..ec9966d29 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -1,13 +1,10 @@ -import subprocess -import sys - import pytest import torch import triton import triton.language as tl import triton.ops -from triton.testing import get_dram_gbps, get_max_tensorcore_tflops +from triton.testing import get_dram_gbps, get_max_tensorcore_tflops, nvsmi DEVICE_NAME = {7: 'v100', 8: 'a100'}[torch.cuda.get_device_capability()[0]] @@ -21,15 +18,6 @@ def print_perf(cur_ms, cur_util, ref_util): print(f'{cur_ms:.3f} ms \t cur: {cur_util:.3f} \t ref: {ref_util:.3f} \t dif={cur_util - ref_util:.3f}', end='\t') -def nvsmi(attrs): - attrs = ','.join(attrs) - cmd = ['nvidia-smi', '-i', '0', '--query-gpu=' + attrs, '--format=csv,noheader,nounits'] - out = subprocess.check_output(cmd) - ret = out.decode(sys.stdout.encoding).split(',') - ret = 
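# --- hedged example (not part of the patch) ---------------------------------
# A rough NumPy model of the interpreter load/store bindings added to
# python/src/triton.cc above: `memory` stands in for addressable memory and
# `ptrs` are indices into it (the real bindings dereference raw pointers).
# Only lanes with the mask set touch memory; the rest take or keep `other`.
import numpy as np

def interp_load(memory, ptrs, mask, other):
    out = np.array(other, copy=True)        # unmasked lanes keep `other`
    out[mask] = memory[ptrs[mask]]          # masked lanes read through the pointer
    return out

def interp_store(memory, ptrs, values, mask):
    memory[ptrs[mask]] = values[mask]       # only masked lanes write back
# -----------------------------------------------------------------------------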
[int(x) for x in ret] - return ret - - ####################### # Matrix Multiplication ####################### @@ -41,22 +29,22 @@ matmul_data = { # NOTE: 'a100': { # square - (512, 512, 512): {'float16': 0.061, 'float32': 0.097, 'int8': 0.05}, - (1024, 1024, 1024): {'float16': 0.283, 'float32': 0.313, 'int8': 0.169}, - (2048, 2048, 2048): {'float16': 0.618, 'float32': 0.532, 'int8': 0.34}, - (8192, 8192, 8192): {'float16': 0.786, 'float32': 0.754, 'int8': 0.51}, + (512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05}, + (1024, 1024, 1024): {'float16': 0.355, 'float32': 0.313, 'int8': 0.169}, + (2048, 2048, 2048): {'float16': 0.653, 'float32': 0.532, 'int8': 0.34}, + (8192, 8192, 8192): {'float16': 0.839, 'float32': 0.754, 'int8': 0.51}, # tall-skinny - (16, 1024, 1024): {'float16': 0.006, 'float32': 0.009, 'int8': 0.005}, - (16, 4096, 4096): {'float16': 0.057, 'float32': 0.051, 'int8': 0.026}, - (16, 8192, 8192): {'float16': 0.077, 'float32': 0.077, 'int8': 0.043}, - (64, 1024, 1024): {'float16': 0.018, 'float32': 0.023, 'int8': 0.017}, - (64, 4096, 4096): {'float16': 0.150, 'float32': 0.000, 'int8': 0.097}, - (64, 8192, 8192): {'float16': 0.338, 'float32': 0.000, 'int8': 0.174}, - (1024, 64, 1024): {'float16': 0.029, 'float32': 0.046, 'int8': 0.017}, - (4096, 64, 4096): {'float16': 0.179, 'float32': 0.214, 'int8': 0.102}, - (8192, 64, 8192): {'float16': 0.278, 'float32': 0.000, 'int8': 0.177}, + (16, 1024, 1024): {'float16': 0.015, 'float32': 0.009, 'int8': 0.005}, + (16, 4096, 4096): {'float16': 0.080, 'float32': 0.051, 'int8': 0.026}, + (16, 8192, 8192): {'float16': 0.083, 'float32': 0.077, 'int8': 0.043}, + (64, 1024, 1024): {'float16': 0.045, 'float32': 0.023, 'int8': 0.017}, + (64, 4096, 4096): {'float16': 0.170, 'float32': 0.000, 'int8': 0.097}, + (64, 8192, 8192): {'float16': 0.227, 'float32': 0.000, 'int8': 0.174}, + (1024, 64, 1024): {'float16': 0.040, 'float32': 0.046, 'int8': 0.017}, + (4096, 64, 4096): {'float16': 0.160, 'float32': 0.214, 'int8': 0.102}, + (8192, 64, 8192): {'float16': 0.272, 'float32': 0.000, 'int8': 0.177}, # test EVEN_K==False - (8192, 8192, 8176): {'float16': 0.786, 'float32': 0.743, 'int8': 0.51}, + (8192, 8192, 8176): {'float16': 0.828, 'float32': 0.743, 'int8': 0.51}, } } @@ -112,15 +100,15 @@ def _add(x_ptr, y_ptr, output_ptr, n_elements, elementwise_data = { 'a100': { - 1024 * 16: {'float16': 0.003, 'float32': 0.007}, - 1024 * 64: {'float16': 0.013, 'float32': 0.026}, - 1024 * 256: {'float16': 0.053, 'float32': 0.105}, - 1024 * 1024: {'float16': 0.212, 'float32': 0.420}, - 1024 * 16384: {'float16': 0.762, 'float32': 0.812}, - 1024 * 65536: {'float16': 0.846, 'float32': 0.869}, + 1024 * 16: {'float16': 0.031, 'float32': 0.060}, + 1024 * 64: {'float16': 0.120, 'float32': 0.224}, + 1024 * 256: {'float16': 0.394, 'float32': 0.691}, + 1024 * 1024: {'float16': 1.06, 'float32': 1.453}, + 1024 * 16384: {'float16': 0.832, 'float32': 0.862}, + 1024 * 65536: {'float16': 0.873, 'float32': 0.882}, # Non pow 2 - 1020 * 100: {'float16': 0.020, 'float32': 0.041}, - 10003 * 7007: {'float16': 0.513, 'float32': 0.861}, + 1020 * 100: {'float16': 0.173, 'float32': 0.327}, + 10003 * 7007: {'float16': 0.522, 'float32': 0.873}, } } @@ -155,30 +143,30 @@ def test_elementwise(N, dtype_str): flash_attention_data = { "a100": { - (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.532, + (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542, (4, 48, 4096, 64, True, True, 'forward', 'bfloat16'): 0.471, - (4, 48, 1024, 16, True, True, 'forward', 'float32'): 
0.150, - (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.204, + (4, 48, 1024, 16, True, True, 'forward', 'float32'): 0.155, + (4, 48, 4096, 64, True, True, 'backward', 'float16'): 0.203, (4, 48, 4096, 64, True, True, 'backward', 'bfloat16'): 0.202, - (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.089, - (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.298, - (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.263, - (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.095, - (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.136, + (4, 48, 1024, 16, True, True, 'backward', 'float32'): 0.108, + (4, 48, 4096, 64, True, False, 'forward', 'float16'): 0.306, + (4, 48, 4096, 64, True, False, 'forward', 'bfloat16'): 0.266, + (4, 48, 1024, 16, True, False, 'forward', 'float32'): 0.098, + (4, 48, 4096, 64, True, False, 'backward', 'float16'): 0.134, (4, 48, 4096, 64, True, False, 'backward', 'bfloat16'): 0.135, - (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.052, - (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.525, + (4, 48, 1024, 16, True, False, 'backward', 'float32'): 0.066, + (4, 48, 4096, 64, False, True, 'forward', 'float16'): 0.541, (4, 48, 4096, 64, False, True, 'forward', 'bfloat16'): 0.471, (4, 48, 1024, 16, False, True, 'forward', 'float32'): 0.150, - (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.265, - (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.257, - (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.128, - (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.297, - (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.263, - (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.095, + (4, 48, 4096, 64, False, True, 'backward', 'float16'): 0.263, + (4, 48, 4096, 64, False, True, 'backward', 'bfloat16'): 0.255, + (4, 48, 1024, 16, False, True, 'backward', 'float32'): 0.144, + (4, 48, 4096, 64, False, False, 'forward', 'float16'): 0.306, + (4, 48, 4096, 64, False, False, 'forward', 'bfloat16'): 0.266, + (4, 48, 1024, 16, False, False, 'forward', 'float32'): 0.098, (4, 48, 4096, 64, False, False, 'backward', 'float16'): 0.159, - (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.138, - (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.076, + (4, 48, 4096, 64, False, False, 'backward', 'bfloat16'): 0.136, + (4, 48, 1024, 16, False, False, 'backward', 'float32'): 0.088, } } @@ -250,8 +238,8 @@ def _sum(x_ptr, y_ptr, output_ptr, n_elements, reduction_data = { 'a100': { - 1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.015, 'int32': 0.031}, - 1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.015, 'int32': 0.032}, + 1024 * 16384: {'float16': 0.016, 'float32': 0.031, 'int16': 0.022, 'int32': 0.048}, + 1024 * 65536: {'float16': 0.016, 'float32': 0.032, 'int16': 0.022, 'int32': 0.049}, } } diff --git a/python/test/unit/hopper/test_flashattention.py b/python/test/unit/hopper/test_flashattention.py index e46e1c1f2..60006613b 100644 --- a/python/test/unit/hopper/test_flashattention.py +++ b/python/test/unit/hopper/test_flashattention.py @@ -368,14 +368,15 @@ class _attention(torch.autograd.Function): attention = _attention.apply -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 128, 64), - # (4, 48, 256, 64), - # (4, 48, 512, 64), - # (4, 48, 1024, 64), - # (4, 48, 2048, 64), - # (4, 48, 4096, 64), - # (4, 48, 8192, 64), out of memory - ]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 128, 64), + (4, 48, 256, 64), + (4, 48, 512, 64), + 
(4, 48, 1024, 64), + (4, 48, 2048, 64), + (4, 48, 4096, 64), + # (4, 48, 8192, 64), out of memory +]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) diff --git a/python/test/unit/hopper/test_gemm.py b/python/test/unit/hopper/test_gemm.py index af236d0de..a3e3f80b9 100644 --- a/python/test/unit/hopper/test_gemm.py +++ b/python/test/unit/hopper/test_gemm.py @@ -331,7 +331,6 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, ]: pytest.skip('shapePerCTA[1] < 16 not supported') - # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ '16-32-64-4-1-256-256-256-False', '16-32-64-4-2-256-256-256-False', @@ -444,7 +443,7 @@ def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, atol=1e-3, check_dtype=False) - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - if enable_mmav3 in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: + disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() + if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: ptx = pgm.asm['ptx'] assert re.search(r'wgmma.mma_async.sync.aligned.m\d+n{}k16(?:.row.col)?.f32.f16.f16'.format(BLOCK_N), ptx) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index fd7c14e6c..32c04c33b 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -141,7 +141,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO c = torch.empty((M, N), device=a.device, dtype=torch.float32) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (num_SMs,) + grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) if USE_TMA: static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) @@ -149,7 +149,7 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) th_c = torch.matmul(a, b) - torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit @@ -300,7 +300,7 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K enable_warp_specialization=True) th_c = torch.matmul(a, b) - torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit @@ -432,7 +432,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N c = torch.empty((M, N), device=a.device, dtype=torch.float32) 
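# --- hedged example (not part of the patch) ---------------------------------
# A small worked example of the new persistent-kernel grid used above: instead
# of always launching one program per SM, the grid is clamped by the number of
# output tiles (NUM_SM = 108 is just the A100 SM count, for illustration).
import triton

def persistent_grid(M, N, META):
    tiles = triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])
    return (min(META['NUM_SM'], tiles),)

# A 256x256 output tiled 128x128 needs only 4 programs, not 108:
assert persistent_grid(256, 256, {'BLOCK_M': 128, 'BLOCK_N': 128, 'NUM_SM': 108}) == (4,)
# -----------------------------------------------------------------------------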
num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (num_SMs,) + grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) if USE_TMA: static_persistent_tma_warp_specialized_matmul_kernel[grid]( @@ -456,7 +456,7 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N enable_warp_specialization=True) th_c = torch.matmul(a, b) - torch.testing.assert_allclose(th_c, c, atol=1e-2, rtol=0) + torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit @@ -818,7 +818,6 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR ]: pytest.skip('shapePerCTA[1] < 16 not supported') - # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ '16-32-64-4-1-256-256-256-False', '16-32-64-4-2-256-256-256-False', @@ -899,7 +898,7 @@ def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WAR num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count def grid(META): - return (num_SMs,) + return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) full_static_persistent_matmul_kernel[grid]( a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, M=M, N=N, K=K, diff --git a/python/test/unit/interpreter/test_interpreter.py b/python/test/unit/interpreter/test_interpreter.py deleted file mode 100644 index b6bb6b79c..000000000 --- a/python/test/unit/interpreter/test_interpreter.py +++ /dev/null @@ -1,69 +0,0 @@ -import random - -import torch - -import triton -import triton.language as tl -from triton.interpreter.interpreter import program_ids_from_grid - - -def test_addition(): - - @triton.jit(interpret=True) - def add_kernel( - x_ptr, - y_ptr, - output_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, - ): - pid = tl.program_id(axis=0) - block_start = pid * BLOCK_SIZE - offsets = block_start + tl.arange(0, BLOCK_SIZE) - mask = offsets < n_elements - x = tl.load(x_ptr + offsets, mask=mask) - y = tl.load(y_ptr + offsets, mask=mask) - output = x + y - tl.store(output_ptr + offsets, output, mask=mask) - - a = torch.rand((128,), device="cuda") - b = torch.rand((128,), device="cuda") - expected = a + b - output = torch.empty((128,), device="cuda") - - def grid(meta): - return (triton.cdiv(128, meta["BLOCK_SIZE"]),) - - add_kernel[grid](a, b, output, 128, BLOCK_SIZE=32) - - assert torch.allclose(expected, output, atol=1e-2, rtol=0) - - -def test_program_ids_from_grid(): - random.seed(123) - grid = (3, 4) - expected_combinations = 3 * 4 - unique_combinations = set(program_ids_from_grid(grid)) - assert len(unique_combinations) == expected_combinations - - first_run = list(program_ids_from_grid(grid)) - second_run = list(program_ids_from_grid(grid)) - assert first_run != second_run - - -def test_atomic(): - @triton.jit(interpret=True) - def atomic( - x_ptr, - ): - pid = tl.program_id(axis=0) - tl.atomic_add(x_ptr + pid, 1) - t = tl.atomic_xchg(x_ptr + pid, 3) - t += 1 # 2 - tl.atomic_cas(x_ptr + pid, 3, t) # match - tl.atomic_cas(x_ptr + pid, 40, 9) # no match - nb_dim = 16 - a = torch.zeros((nb_dim, ), dtype=torch.int32, device="cuda") - - atomic[(nb_dim, )](a) - assert torch.allclose(a, torch.full_like(a, 2)) diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index accfdbdbb..b1802b908 100644 --- a/python/test/unit/language/test_block_pointer.py +++ 
b/python/test/unit/language/test_block_pointer.py @@ -99,4 +99,4 @@ def test_block_ptr_matmul_no_scf(shape, num_warps): BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, num_warps=num_warps) golden = torch.matmul(a, b) - torch.testing.assert_allclose(c, golden) + torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 9961dc51c..7f3530862 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -20,6 +20,7 @@ uint_dtypes = ['uint8', 'uint16', 'uint32', 'uint64'] float_dtypes = ['float16', 'float32', 'float64'] dtypes = int_dtypes + uint_dtypes + float_dtypes dtypes_with_bfloat16 = dtypes + ['bfloat16'] +torch_float8_dtypes = ['float8_e4m3fn', 'float8_e5m2'] torch_dtypes = ['bool'] + int_dtypes + ['uint8'] + float_dtypes + ['bfloat16'] # TODO: enable multiple cta cluster testing. @@ -131,8 +132,8 @@ def check_type_supported(dtype, device): cc = torch.cuda.get_device_capability() if cc[0] < 8 and (dtype is tl.bfloat16 or dtype == "bfloat16" or dtype is torch.bfloat16): pytest.skip("bfloat16 is only supported on NVGPU with cc >= 80") - if cc[0] < 9 and (dtype is tl.float8e4nv or dtype == "float8e4"): - pytest.skip("float8e4 is only supported on NVGPU with cc >= 90") + if cc[0] < 9 and dtype in {tl.float8e4nv, "float8e4nv", "float8_e4m3fn"}: + pytest.skip("float8e4nv is only supported on NVGPU with cc >= 90") class MmaLayout: @@ -564,35 +565,6 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): _test_binary(dtype_x, dtype_y, expr, numpy_expr, mode_x=mode_x, mode_y=mode_y, device=device, num_ctas=num_ctas) -# --------------- -# test broadcast -# --------------- -@pytest.mark.parametrize("dtype", dtypes_with_bfloat16) -def test_broadcast(dtype): - @triton.jit - def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): - offset1 = tl.arange(0, M) - offset2 = tl.arange(0, N) - x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :]) - y = tl.load(y_ptr + offset2) - _, y_broadcasted = tl.broadcast(x, y) - tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted) - - M = 32 - N = 64 - rs = RandomState(17) - x = numpy_random((M, N), dtype_str=dtype, rs=rs) - y = numpy_random(N, dtype_str=dtype, rs=rs) - _, y_broadcasted_np = np.broadcast_arrays(x, y) - - x_tri = to_triton(x, device='cuda', dst_type=dtype) - y_tri = to_triton(y, device='cuda', dst_type=dtype) - y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype) - - broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) - assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() - - # --------------- # test broadcast # --------------- @@ -621,6 +593,36 @@ def test_broadcast(dtype, device): broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() +# ---------- +# test slice +# ---------- + + +def test_slice(device): + + @triton.jit + def slice_kernel(XBLOCK: tl.constexpr): + data = tl.arange(0, XBLOCK) + tl.static_assert(data.shape == [XBLOCK]) + + t = data[None, :] + tl.static_assert(t.shape == [1, XBLOCK]) + + t = data[None, :, None] + tl.static_assert(t.shape == [1, XBLOCK, 1]) + + scalar = tl.full([], 1, tl.int32) + tl.static_assert(scalar.shape == []) + + t = scalar[None] + tl.static_assert(t.shape == [1]) + + t = scalar[None, None] + tl.static_assert(t.shape == [1, 1]) + + 
slice_kernel[(1,)](XBLOCK=32) + + # ------------------ # test invalid slice # ------------------ @@ -669,6 +671,14 @@ def test_expand_dims(device): t = tl.expand_dims(offset1, (3, 1, 2)) tl.static_assert(t.shape == [N, 1, 1, 1]) + scalar = tl.sum(offset1) + tl.static_assert(scalar.shape == []) + t = tl.expand_dims(scalar, 0) + tl.static_assert(t.shape == [1]) + + t = tl.expand_dims(scalar, -1) + tl.static_assert(t.shape == [1]) + N = 32 dummy_tensor = torch.empty((), device=device) expand_dims_kernel[(1,)](dummy_tensor, N) @@ -689,6 +699,13 @@ def test_expand_dims_error_cases(device): t = tl.expand_dims(offset1, 1) t = tl.expand_dims(offset1, 2) + @triton.jit + def dim_out_of_range3(dummy, N: tl.constexpr): + offset1 = tl.arange(0, 1) + scalar = tl.sum(offset1) + + t = tl.expand_dims(scalar, 1) + @triton.jit def duplicate_dim1(dummy, N: tl.constexpr): offset1 = tl.arange(0, N) @@ -710,6 +727,9 @@ def test_expand_dims_error_cases(device): with pytest.raises(triton.CompilationError, match="invalid axis 2"): dim_out_of_range2[(1,)](dummy_tensor, N) + with pytest.raises(triton.CompilationError, match="invalid axis 1"): + dim_out_of_range3[(1,)](dummy_tensor, N) + with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): duplicate_dim1[(1,)](dummy_tensor, N) @@ -727,7 +747,7 @@ def test_invalid_pid_axis(device): def _kernel(dst): pid = tl.program_id(20) - with pytest.raises(triton.CompilationError, match=r"program_id must be in \[0,3\]"): + with pytest.raises(triton.CompilationError, match=r"program_id axis must be 0, 1, or 2 but got 20"): _kernel[(1,)](dst) @@ -881,7 +901,7 @@ def test_abs_fp8(in_dtype, device): f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) expect = f32_tensor.abs() actual_f8 = convert_float_to_float32(out_f8, in_dtype) - torch.testing.assert_allclose(actual_f8, expect) + torch.testing.assert_close(actual_f8, expect, equal_nan=True) # ---------------- @@ -1100,8 +1120,11 @@ def test_noinline(mode, device): [ ('add', 'float16', mode, sem), ('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), ('add', 'int64', mode, sem), ('add', 'float64', mode, sem), ('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), ('max', 'int64', mode, sem), ('max', 'float64', mode, sem), ('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), ('min', 'int64', mode, sem), ('min', 'float64', mode, sem), ] for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) @@ -1120,6 +1143,7 @@ def test_atomic_rmw(op, dtype_x_str, mode, sem, device): pid = tl.program_id(0) x = tl.load(X + pid) old = GENERATE_TEST_HERE + tl.static_assert(old.dtype == x.dtype) sem_arg = sem if sem is None else f'"{sem}"' kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.atomic_{op}(Z, x, sem={sem_arg})'}) @@ -1258,23 +1282,33 @@ def test_atomic_cas(sem, num_ctas, device): # --------------- -@pytest.mark.parametrize("dtype_x, dtype_z, bitcast", [ - (dtype_x, dtype_z, False) +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [ + (dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes ] + [ - ('float32', 'bfloat16', False), - ('bfloat16', 'float32', False), - ('float32', 'int32', True), - ('float32', 'int1', False), - ('int8', 'bfloat16', False), + ('float32', 'bfloat16', False, 1024), 
+ ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), ] + [ - (f'uint{x}', f'int{x}', True) for x in [8, 16, 32, 64] + (f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64] ] + [ - (f'int{x}', f'uint{x}', True) for x in [8, 16, 32, 64] -]) + (f'int{x}', f'uint{x}', True, 1024) for x in [8, 16, 32, 64] +] + (([ + (dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32] +] + [ + (dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32] +]) if torch.__version__ >= "2.1" else [])) @pytest.mark.parametrize("num_ctas", num_ctas_list) -def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device): +def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): # bfloat16 on cc < 80 will not be tested check_type_supported(dtype_x, device) check_type_supported(dtype_z, device) @@ -1282,10 +1316,15 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device): if is_hip() and (dtype_z == "bfloat16"): pytest.skip(f'test_cast{(dtype_x, dtype_z)} cast to bfloat16 not supported on HIP.') +<<<<<<< HEAD size = 1024 +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # This is tricky because numpy doesn't have bfloat, and torch doesn't have uints. if dtype_x.startswith('bfloat'): x_tri = torch.randn(size, dtype=getattr(torch, dtype_x), device=device) + elif dtype_x.startswith('float8'): + x_tri = torch.randn(size, dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_x)) else: x = numpy_random(size, dtype_str=dtype_x, low=-10, high=10) * 10 # Triton clamps negative values to zero, while numpy wraps around @@ -1308,11 +1347,13 @@ def test_cast(dtype_x, dtype_z, bitcast, num_ctas, device): # triton result if dtype_z.startswith('bfloat'): z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device) + elif dtype_z.startswith('float8'): + z_tri = torch.empty((size,), dtype=torch.float, device=device) else: z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas) # torch result - if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat'): + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith('float8') or dtype_x.startswith('float8'): assert bitcast is False z_ref = x_tri.to(z_tri.dtype) torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) @@ -1489,22 +1530,26 @@ def test_fp8_fpN_roundtrip(in_dtype, out_dtype, device): # initialize array containing all possible f8 values except NaN ref_fp8 = np.array(range(-128, 128), dtype=np.int8) - is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width exp_mask = 0b01111111 ^ ((1 << in_dtype.fp_mantissa_width) - 1) + is_nan = (ref_fp8 & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width is_subnormal = np.logical_or((ref_fp8 & exp_mask) == 0, (ref_fp8 & exp_mask) == exp_mask) - ref_fp8[is_nan] = 0 - ref_fp8[is_subnormal] = 0 tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda() + # check that non-subnormal fp8 are correctly converted to fp16 tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda") copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) - ref_fp8 = torch.from_numpy(ref_fp8).cuda() ref_fp16 = convert_float_to_float32(ref_fp8, 
in_dtype) assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal]) - + # check that values are properly converted back to float8 ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8) copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) - assert torch.all(tri_fp8 == ref_fp8) + if in_dtype == tl.float8e4b15: + assert torch.all(tri_fp8[:127] == ref_fp8[:127]) + assert torch.all(tri_fp8[128:255] == ref_fp8[128:255]) + assert ref_fp8[126] == ref_fp8[127] # -1.875 saturates to -1.75 + assert ref_fp8[254] == ref_fp8[255] # 1.875 saturates to 1.75 + else: + assert torch.all(tri_fp8[~is_subnormal] == ref_fp8[~is_subnormal]) # --------------- # test reduce @@ -1693,10 +1738,16 @@ scan_configs = [ for type in ['int32', 'float32'] for axis in [1, 0] for shape in scan2d_shapes - for op in ['cumsum', 'cumprod'] + for op in ['cumsum', 'cumprod', 'get_first_element'] ] +@triton.jit +# trivial associative but not commutative function +def get_first_element(a, b): + return a + + @pytest.mark.parametrize("op, dtype_str, shape, axis, num_warps", scan_configs) def test_scan2d(op, dtype_str, shape, axis, num_warps, device): if is_hip(): @@ -1712,15 +1763,26 @@ def test_scan2d(op, dtype_str, shape, axis, num_warps, device): z = GENERATE_TEST_HERE tl.store(Z + range_m[:, None] * BLOCK_N + range_n[None, :], z) - kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'}) + if op == 'cumsum' or op == 'cumprod': + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.{op}(x, axis={axis})'}) + else: + kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f'tl.associative_scan(x, axis={axis}, combine_fn={op})'}) # input rs = RandomState(17) x = numpy_random(shape, dtype_str=dtype_str, rs=rs) z = np.empty_like(x) x_tri = to_triton(x, device=device) - numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] - z_dtype_str = dtype_str - z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + if op == 'cumsum' or op == 'cumprod': + numpy_op = {'cumsum': np.cumsum, 'cumprod': np.cumprod}[op] + z_dtype_str = dtype_str + z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) + else: + assert op == 'get_first_element' + z_ref = x + if axis == 0: + z_ref[1:] = x[0] + else: + z_ref[:, 1:] = x[:, 0:1] # triton result z_tri = to_triton(z, device=device) kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) @@ -1750,7 +1812,7 @@ scan_layouts = [ ] -@pytest.mark.parametrize("M, N", [[32, 32], [32, 64], [64, 32]]) +@pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]]) @pytest.mark.parametrize("src_layout", scan_layouts) @pytest.mark.parametrize("axis", [0, 1]) def test_scan_layouts(M, N, src_layout, axis, device): @@ -1760,28 +1822,28 @@ def test_scan_layouts(M, N, src_layout, axis, device): ir = f""" #blocked = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ - tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, 
#triton_gpu.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> %2 = arith.muli %1, %cst : tensor<{M}x1xi32, #blocked> - %3 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> - %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %3 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> + %4 = tt.addptr %3, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> %5 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>> %6 = tt.expand_dims %5 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> - %7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> + %7 = tt.broadcast %4 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> %8 = tt.broadcast %6 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> - %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %9 = tt.addptr %7, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #blocked> %11 = "tt.scan"(%10) <{{axis = {axis} : i32}}> ({{ ^bb0(%arg2: i32, %arg3: i32): %16 = arith.addi %arg2, %arg3 : i32 tt.scan.return %16 : i32 }}) : (tensor<{M}x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> - %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> - %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> - %14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> - %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %12 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> + %13 = tt.addptr %12, %2 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> + %14 = tt.broadcast %13 : (tensor<{M}x1x!tt.ptr, #blocked>) -> tensor<{M}x{N}x!tt.ptr, #blocked> + %15 = tt.addptr %14, %8 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> tt.store %15, %11 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{M}x{N}xi32, #blocked> tt.return }} @@ -1812,30 +1874,79 @@ layouts = [ BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 16, 16]), ] @pytest.mark.parametrize("M, N", [[128, 16], [128, 128], [32, 128], [32, 32]]) @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("axis", [0, 1]) +<<<<<<< HEAD def test_reduce_layouts(M, N, src_layout, axis, device): if is_hip(): pytest.skip("test_reduce_layouts is not supported in HIP") rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" +======= +@pytest.mark.parametrize("reduce2d", [False, True]) +@pytest.mark.parametrize("dtype_str", ["int32", "float32", "float16"]) +@pytest.mark.parametrize("reduce_op", ["sum", 
"max"]) +def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, device): + if is_hip(): + pytest.skip("test_reduce_layouts is not supported in HIP") + if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024: + pytest.skip("Skipping sum reduction on float16 due to accuracy issues") + + ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] + arith_op = { + "max": {"int32": "arith.maxsi", "float32": "arith.maxf", "float16": "arith.maxf"}, + "sum": {"int32": "arith.addi", "float32": "arith.addf", "float16": "arith.addf"} + }[reduce_op][dtype_str] + numpy_op = { + "max": np.max, + "sum": np.sum + }[reduce_op] +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 rdims_1d = f"{N}" if axis == 0 else f"{M}" + rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" store_range = "%7" if axis == 0 else "%1" blocked = BlockedLayout([1, 1], [32, 1], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]) + epilogue = f""" + %14 = "tt.reduce"(%13) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = 0 : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> {ty} + tt.store %arg2, %14 {{cache = 1 : i32, evict = 1 : i32}} : {ty} + tt.return + }} + }} + """ if reduce2d else f""" + %14 = tt.splat %arg2 : (!tt.ptr<{ty}, 1>) -> tensor<{rdims_2d}x!tt.ptr<{ty}, 1>, #blocked> + %15 = tt.addptr %14, {store_range} : tensor<{rdims_2d}x!tt.ptr<{ty}>, #blocked>, tensor<{rdims_2d}xi32, #blocked> + %16 = {GPU_DIALECT}.convert_layout %13 : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>> + %17 = tt.expand_dims %16 {{axis = {axis} : i32}} : (tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #blocked}}>>) -> tensor<{rdims_2d}x{ty}, #blocked> + tt.store %15, %17 {{cache = 1 : i32, evict = 1 : i32}} : tensor<{rdims_2d}x{ty}, #blocked> + tt.return + }} + }} + """ + ir = f""" #blocked = {blocked} #src = {src_layout} module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ +<<<<<<< HEAD tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ +======= + tt.func public @kernel_0d1d2c3d4c(%arg0: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}, %arg1: i32 {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr<{ty}, 1> {{tt.divisibility = 16 : i32}}) {{ +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #blocked}}>>) -> tensor<{M}x1xi32, #blocked> %2 = tt.splat %arg1 : (i32) -> tensor<{M}x1xi32, #blocked> %3 = arith.muli %1, %2 : tensor<{M}x1xi32, #blocked> +<<<<<<< HEAD %4 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #blocked> %5 = tt.addptr %4, %3 : tensor<{M}x1x!tt.ptr, #blocked>, tensor<{M}x1xi32, #blocked> %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> @@ -1859,6 +1970,23 @@ def test_reduce_layouts(M, N, src_layout, axis, device): }} }} """ +======= + %4 = tt.splat %arg0 : (!tt.ptr<{ty}, 1>) -> tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked> + %5 = tt.addptr %4, %3 : 
tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x1xi32, #blocked> + %6 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>> + %7 = tt.expand_dims %6 {{axis = 0 : i32}} : (tensor<{N}xi32, #{GPU_DIALECT}.slice<{{dim = 0, parent = #blocked}}>>) -> tensor<1x{N}xi32, #blocked> + %8 = tt.broadcast %5 : (tensor<{M}x1x!tt.ptr<{ty}, 1>, #blocked>) -> tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked> + %9 = tt.broadcast %7 : (tensor<1x{N}xi32, #blocked>) -> tensor<{M}x{N}xi32, #blocked> + %10 = tt.addptr %8, %9 : tensor<{M}x{N}x!tt.ptr<{ty}, 1>, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x{ty}, #blocked> + %12 = {GPU_DIALECT}.convert_layout %11 : (tensor<{M}x{N}x{ty}, #blocked>) -> tensor<{M}x{N}x{ty}, #src> + %13 = "tt.reduce"(%12) ({{ + ^bb0(%arg3: {ty}, %arg4: {ty}): + %17 = {arith_op} %arg3, %arg4 : {ty} + tt.reduce.return %17 : {ty} + }}) {{axis = {axis} : i32}} : (tensor<{M}x{N}x{ty}, #src>) -> tensor<{rdims_1d}x{ty}, #{GPU_DIALECT}.slice<{{dim = {axis}, parent = #src}}>> + """ + epilogue +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 import tempfile with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: @@ -1867,21 +1995,20 @@ def test_reduce_layouts(M, N, src_layout, axis, device): kernel = triton.compile(f.name) rs = RandomState(17) - x = rs.randint(0, 20, (M, N)).astype('int32') - - if axis == 0: - z = np.zeros((1, N)).astype('int32') - else: - z = np.zeros((M, 1)).astype('int32') + x = numpy_random((M, N), dtype_str=dtype_str, rs=rs, low=0, high=10) + z_shape = (1, 1) if reduce2d else (1, N) if axis == 0 else (M, 1) + z = np.zeros(z_shape).astype(dtype_str) x_tri = torch.tensor(x, device=device) z_tri = torch.tensor(z, device=device) - pgm = kernel[(1, 1, 4)](x_tri, x_tri.stride(0), z_tri) + pgm = kernel[(1, 1, 1)](x_tri, x_tri.stride(0), z_tri) + z_ref = numpy_op(x) if reduce2d else numpy_op(x, axis=axis, keepdims=True) - z_ref = np.sum(x, axis=axis, keepdims=True) - - np.testing.assert_equal(z_ref, z_tri.cpu().numpy()) + if dtype_str == 'float16': + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-2) + else: + np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) layouts = [ @@ -1900,16 +2027,28 @@ def test_store_op(M, src_layout, device): ir = f""" #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "{GPU_DIALECT}.num-ctas" = 1 : i32, "{GPU_DIALECT}.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ +<<<<<<< HEAD tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> +======= + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %1 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> + %2 = tt.addptr %1, %0 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = 1, 
parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %4 = tt.expand_dims %3 {{axis = 1 : i32}} : (tensor<{M}xf32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xf32, #src> %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %6 = tt.expand_dims %5 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> +<<<<<<< HEAD %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> +======= + %7 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x1x!tt.ptr, #src> + %8 = tt.addptr %7, %6 : tensor<{M}x1x!tt.ptr, #src>, tensor<{M}x1xi32, #src> +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tt.store %8, %4 : tensor<{M}x1xf32, #src> tt.return }} @@ -1953,6 +2092,7 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): #dst = {dst_layout} #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ +<<<<<<< HEAD tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> @@ -1961,6 +2101,16 @@ def test_convert1d(M, src_layout, dst_layout, src_dim, dst_dim, device): %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> +======= + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %0 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %2 = tt.addptr %0, %1 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %3 = tt.load %2 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {src_dim}, parent = #src}}>> + %4 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %5 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> + %6 = tt.addptr %4, %5 : tensor<{M}x!tt.ptr, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>>, tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %7 = {GPU_DIALECT}.convert_layout %3 : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 
{src_dim}, parent = #src}}>>) -> tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> tt.store %6, %7 : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = {dst_dim}, parent = #dst}}>> tt.return @@ -2024,7 +2174,11 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): ir = f""" #src = {src_layout} module attributes {{"{GPU_DIALECT}.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = {THREADS_PER_WARP} : i32}} {{ +<<<<<<< HEAD tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ +======= + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>> %1 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #{GPU_DIALECT}.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> @@ -2034,8 +2188,8 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): %5 = tt.broadcast %2 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> %6 = tt.broadcast %4 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> %7 = arith.addi %5, %6 : tensor<{M}x{N}xi32, #src> - %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> - %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %8 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %9 = tt.addptr %8, %7 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> %10 = tt.load %9 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xi32, #src> %11 = "tt.reduce"(%10) ({{ ^bb0(%arg2: i32, %arg3: i32): @@ -2195,7 +2349,10 @@ def test_permute(dtype_str, shape, perm, num_ctas, device): for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), - ('float32', 'float32')]]) + ('float32', 'float32')]] + + + [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') + for col_a in [True, False] for col_b in [True, False]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device): check_cuda_only(device) @@ -2215,12 +2372,12 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") if capability[0] < 8: - if in_dtype == 'int8': - pytest.skip("Only test int8 on devices with sm >= 80") - elif allow_tf32: + if capability[1] == 0 and in_dtype == 'int8': + pytest.skip("Only test int8 on devices with sm >= 75") + if allow_tf32: pytest.skip("Only test tf32 on devices with sm >= 80") if capability[0] == 7: - if (M, N, K, num_warps) == (128, 256, 32, 8): + if (M, N, K, num_warps) in [(128, 256, 32, 8), (64, 128, 128, 4), (64, 128, 128, 2)]: pytest.skip("shared memory out of resource") if out_dtype == 'float16': # TODO: support out_dtype=float16 for tl.dot on V100 @@ -2349,11 +2506,18 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o red_code = ptx[start:end] assert len(red_code) > 0 import os +<<<<<<< HEAD enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() # skip this check on hopper because 
there are some functions whose name contain "shared" in ptx. # TODO: we should eliminate these unused functions in ptx code. if not (enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]): +======= + + # skip this check on hopper because there are some functions whose name contain "shared" in ptx. + # TODO: we should eliminate these unused functions in ptx code. + if not (capability[0] >= 9): +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 assert "shared" not in red_code assert "bar.sync" not in red_code # torch result @@ -2376,7 +2540,6 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o if epilogue == 'chain-dot': z_ref = np.matmul(z_ref, w) # compare - # print(z_ref[:,0], z_tri[:,0]) if in_dtype == 'float32': # XXX: Somehow there's a larger difference when we use float32 np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0.01, atol=1e-3) @@ -2406,8 +2569,11 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o else: assert re.search(r'[mma|wgmma.mma_async].sync.aligned.m\d+n\d+k16(?:.row.col)?.f16.f16.f16', ptx) elif in_dtype == 'int8': - assert 'wgmma.mma_async.sync.aligned' in ptx or\ - 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx + if capability[0] == 7 and capability[1] == 5: # Turing + assert 'mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32' in ptx + else: + assert 'wgmma.mma_async.sync.aligned' in ptx or\ + 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx @pytest.mark.parametrize('in_dtype', ['float32']) @@ -2454,20 +2620,20 @@ def test_dot_mulbroadcastred(in_dtype, device): if is_hip(): return assert "tt.dot" in h.asm['ttir'] - # with option ENABLE_MMA_V3 on, we will not pipeline the load op for Y + # when using MMAv3, we will not pipeline the load op for Y # as the loaded value is in rowmajor. But MMAv3 requires it's second # operand is in colmajor because transpose is not supported for MMAv3 # with float32 input. 
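# --- hedged example (not part of the patch) ---------------------------------
# The ENABLE_MMA_V3 environment variable is retired in favour of checking the
# device itself; a sketch of the gating pattern the surrounding dot/PTX
# assertions now rely on (variable names here are illustrative):
import torch

major, minor = torch.cuda.get_device_capability()
uses_mma_v3 = major >= 9              # Hopper (sm_90+) takes the wgmma path
is_turing = (major, minor) == (7, 5)  # Turing asserts the m8n8k16 int8 MMA form
# -----------------------------------------------------------------------------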
import os - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() - if enable_mmav3 in ["on", "true", "1"]: + if capability[0] >= 9: assert "triton_gpu.async_wait {num = 1 : i32}" in h.asm['ttgir'] else: assert "triton_gpu.async_wait {num = 2 : i32}" in h.asm['ttgir'] @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) -def test_full(dtype_str, device): +@pytest.mark.parametrize("shape", [(), (1,), (128,)]) +def test_full(dtype_str, shape, device): if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): # PyTorch only has unsigned 8, but not 16, 32, or 64 dtype = getattr(torch, dtype_str[1:]) # uintx -> intx @@ -2478,21 +2644,28 @@ def test_full(dtype_str, device): @triton.jit def kernel_static(out): a = GENERATE_TEST_HERE + tl.static_assert(a.shape == SHAPE) out_ptr = out + tl.arange(0, 128)[:] tl.store(out_ptr, a) @triton.jit def kernel_dynamic(out, val, dtype: tl.constexpr): - a = tl.full((128,), val, dtype) + a = tl.full(SHAPE, val, dtype) + tl.static_assert(a.shape == SHAPE) out_ptr = out + tl.arange(0, 128)[:] tl.store(out_ptr, a) - kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"}) + kernel_static_patched = patch_kernel(kernel_static, { + 'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})", + 'SHAPE': str(list(shape)), + }) out_static = torch.zeros((128), dtype=dtype, device=device) kernel_static_patched[(1,)](out_static) - out_dynamic = torch.zeros((128), dtype=dtype, device=device) - kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_static == 2) + + kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) + out_dynamic = torch.zeros((128), dtype=dtype, device=device) + kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_dynamic == 2) @@ -2594,7 +2767,7 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device): reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device))) # print((output - reference_out).nonzero()) - torch.testing.assert_allclose(output, reference_out) + torch.testing.assert_close(output, reference_out) # Testing masked loads with an intermate copy to shared memory run. 
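# --- hedged example (not part of the patch) ---------------------------------
# The reworked test_full above also exercises tl.full with a scalar shape ().
# A minimal stand-alone kernel showing that form, assuming a CUDA device:
import torch
import triton
import triton.language as tl

@triton.jit
def fill_two(out_ptr, BLOCK: tl.constexpr):
    val = tl.full((), 2, tl.int32)       # 0-d (scalar) value
    tl.static_assert(val.shape == [])
    offs = tl.arange(0, BLOCK)
    tl.store(out_ptr + offs, val)        # the scalar broadcasts across the block

out = torch.zeros(128, dtype=torch.int32, device="cuda")
fill_two[(1,)](out, BLOCK=128)
assert bool((out == 2).all())
# -----------------------------------------------------------------------------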
@@ -2649,7 +2822,7 @@ def test_masked_load_shared_memory(dtype, device): M=M, N=N, K=K) reference_out = torch.matmul(in1, in2) - torch.testing.assert_allclose(out, reference_out, atol=1e-2, rtol=0) + torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) @pytest.mark.parametrize("cache", ["", ".ca", ".cg"]) @@ -3000,7 +3173,7 @@ def test_call(type, num_ctas, device): err_msg = str(e) if type == "noinline": - assert err_msg is not "" + assert err_msg != "" else: ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 np.testing.assert_equal(to_numpy(rand_val_tri), ans) @@ -3507,6 +3680,22 @@ def test_while(device): # test extra # ----------------------- +def test_num_threads(device): + if is_hip(): + pytest.skip("test_num_threads is not supported in HIP") + check_cuda_only(device) + + @triton.jit + def kernel(Out): + num_threads: tl.constexpr = tl.extra.cuda.num_threads() + offs = tl.arange(0, num_threads) + tl.store(Out + offs, 1) + + num_threads = 256 + out = to_triton(np.zeros((num_threads,), dtype=np.int32), device=device) + kernel[(1,)](out, num_warps=num_threads // 32) + assert torch.sum(out) == 256 + def test_globaltimer(device): if is_hip(): @@ -3554,6 +3743,7 @@ layouts = [ # MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), # MmaLayout(1, [4, 1], [1, 1], [0, 1]), # MmaLayout((2, 0), [4, 1], [1, 1], [1, 1], [0, 1], [16, 8]), + BlockedLayout([1, 16], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), @@ -3571,15 +3761,23 @@ intermediate_layouts = [ ] -@pytest.mark.parametrize("shape", [(128, 128)]) +@pytest.mark.parametrize("M, N", [[64, 1], [64, 64], [128, 128], [1, 64]]) @pytest.mark.parametrize("dtype", ['float16']) @pytest.mark.parametrize("src_layout", layouts) @pytest.mark.parametrize("interm_layout", intermediate_layouts) @pytest.mark.parametrize("dst_layout", layouts) +<<<<<<< HEAD def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): if is_hip(): pytest.skip("test_convert2d is not supported in HIP") +======= +def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): + if is_hip(): + pytest.skip("test_convert2d is not supported in HIP") + if (M == 1 or N == 1) and interm_layout: + pytest.skip("Out of bound access when maxPhase > 1") +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 if str(src_layout) == str(dst_layout): pytest.skip() if 'mma' in str(src_layout) and 'mma' in str(dst_layout): @@ -3595,43 +3793,43 @@ def test_convert2d(dtype, shape, src_layout, interm_layout, dst_layout, device): """ conversion = f""" - %12 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst> - %13 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst> + %12 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst> """ if interm_layout is None else f""" - %15 = triton_gpu.convert_layout %9 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #interm> - %16 = triton_gpu.convert_layout %15 : (tensor<128x128xi32, #interm>) -> tensor<128x128xi32, #src> - %17 = triton_gpu.convert_layout %11 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #interm> - %18 = triton_gpu.convert_layout %17 : (tensor<128x128xf16, #interm>) -> 
tensor<128x128xf16, #src> + %15 = triton_gpu.convert_layout %9 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #interm> + %16 = triton_gpu.convert_layout %15 : (tensor<{M}x{N}xi32, #interm>) -> tensor<{M}x{N}xi32, #src> + %17 = triton_gpu.convert_layout %11 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #interm> + %18 = triton_gpu.convert_layout %17 : (tensor<{M}x{N}xf16, #interm>) -> tensor<{M}x{N}xf16, #src> - %12 = triton_gpu.convert_layout %16 : (tensor<128x128xi32, #src>) -> tensor<128x128xi32, #dst> - %13 = triton_gpu.convert_layout %18 : (tensor<128x128xf16, #src>) -> tensor<128x128xf16, #dst> + %12 = triton_gpu.convert_layout %16 : (tensor<{M}x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #dst> + %13 = triton_gpu.convert_layout %18 : (tensor<{M}x{N}xf16, #src>) -> tensor<{M}x{N}xf16, #dst> """ - ir = layouts + """ - module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { - tt.func public @kernel_0d1d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}) { - %cst = arith.constant dense<128> : tensor<128x1xi32, #src> - %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>> - %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>> - %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #src> - %4 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #src}>>) -> tensor<128x1xi32, #src> - %5 = arith.muli %4, %cst : tensor<128x1xi32, #src> - %6 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #src}>>) -> tensor<1x128xi32, #src> - %7 = tt.broadcast %6 : (tensor<1x128xi32, #src>) -> tensor<128x128xi32, #src> - %8 = tt.broadcast %5 : (tensor<128x1xi32, #src>) -> tensor<128x128xi32, #src> - %9 = arith.addi %8, %7 : tensor<128x128xi32, #src> - %10 = tt.addptr %2, %9 : tensor<128x128x!tt.ptr, #src>, tensor<128x128xi32, #src> - %11 = tt.load %10 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x128xf16, #src> - %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<128x128x!tt.ptr, #dst> - """ + conversion + """ - %14 = tt.addptr %3, %12 : tensor<128x128x!tt.ptr, #dst>, tensor<128x128xi32, #dst> - tt.store %14, %13 : tensor<128x128xf16, #dst> + ir = layouts + f""" + module attributes {{"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.threads-per-warp" = 32 : i32}} {{ + tt.func public @kernel_0d1d(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #src> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> + %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> + %2 = tt.splat %arg0 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #src> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : (tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>) -> tensor<{M}x1xi32, #src> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : (tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>) -> tensor<1x{N}xi32, #src> + %7 = tt.broadcast %6 : (tensor<1x{N}xi32, #src>) -> tensor<{M}x{N}xi32, #src> + %8 = tt.broadcast %5 : (tensor<{M}x1xi32, #src>) -> tensor<{M}x{N}xi32, #src> + 
%9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #src>, tensor<{M}x{N}xi32, #src> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}xf16, #src> + %3 = tt.splat %arg1 : (!tt.ptr) -> tensor<{M}x{N}x!tt.ptr, #dst> + """ + conversion + f""" + %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr, #dst>, tensor<{M}x{N}xi32, #dst> + tt.store %14, %13 : tensor<{M}x{N}xf16, #dst> tt.return - } -} + }} +}} """ - x = to_triton(numpy_random(shape, dtype_str=dtype), device=device) + x = to_triton(numpy_random((M, N), dtype_str=dtype), device=device) z = torch.empty_like(x) # write the IR to a temporary file using mkstemp @@ -3695,3 +3893,86 @@ def test_ptx_cast(dtype_str, device): buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) assert buf14.to(torch.float32).mean() == -2.0 + +# ----------------------- +# test fp8 -> fp32 dot +# ----------------------- + + +def f8_to_f16(x, dtype): + @triton.jit + def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < N + x = tl.load(X + offs, mask=mask) + tl.store(Y + offs, x, mask=mask) + + ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + dtype = getattr(tl, dtype) + kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) + return ret + + +@triton.jit +def matmul_kernel( + a_ptr, b_ptr, c_ptr, + M, N, K, + stride_am, stride_ak, + stride_bk, stride_bn, + stride_cm, stride_cn, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + low_precision_acc: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + pid_m = pid % num_pid_m + pid_n = pid // num_pid_m + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + accumulator = tl.dot(a, b, acc=accumulator, max_num_imprecise_acc=low_precision_acc) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + tl.store(c_ptrs, accumulator) + + +@pytest.mark.parametrize("in_type_str", ['float8e5', 'float8e4nv']) +@pytest.mark.parametrize("low_precision_acc", [0, 32, 64, 128]) +def test_fp8_dot_acc(in_type_str, low_precision_acc, device): + check_type_supported(in_type_str, device) + M, N, K = 128, 256, 256 + BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 128 + A = numpy_random((M, K), dtype_str=in_type_str) + B = numpy_random((K, N), dtype_str=in_type_str) + Bt = B.T + C = torch.empty((M, N), dtype=torch.float32, device='cuda') + num_warps = 8 + a = to_triton(A, device='cuda', dst_type=in_type_str) + b = to_triton(B, device='cuda', dst_type=in_type_str) + grid = (triton.cdiv(M, BLOCK_M), 1) + matmul_kernel[grid](a, b, C, M, N, K, + 
a.stride(0), a.stride(1), b.stride(0), b.stride( + 1), C.stride(0), C.stride(1), + BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) + torch_a = torch.from_numpy(A) + th_a = f8_to_f16(torch_a.cuda(), in_type_str) + torch_b = torch.from_numpy(B) + th_b = f8_to_f16(torch_b.cuda(), in_type_str) + ref_out = torch.matmul(th_a, th_b).to(torch.float32) + if in_type_str == 'float8e4nv': + torch.testing.assert_close(ref_out, C, rtol=0.01, atol=0.01) + elif low_precision_acc > 32: + torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3) + else: + torch.testing.assert_close(ref_out, C) diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index 2823cf929..fc73f2bf3 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -6,6 +6,7 @@ import torch import triton import triton.language as tl +from triton.common.backend import path_to_nvdisasm @triton.jit @@ -50,10 +51,11 @@ def kernel_multi_files(X, Y, BLOCK: tl.constexpr): def extract_file_lines(asm): + nvdisasm, _ = path_to_nvdisasm() fd, path = tempfile.mkstemp() with open(fd, 'wb') as cubin: cubin.write(asm) - asm = subprocess.check_output(["nvdisasm", "-g", path]).decode("utf-8") + asm = subprocess.check_output([nvdisasm, "-g", path]).decode("utf-8") file_lines = [] lines = asm.splitlines() for line in lines: @@ -80,7 +82,7 @@ func_types = ["single", "call", "call_noinline", "multi_files"] @pytest.mark.parametrize("func", func_types) def test_line_info(func: str): try: - subprocess.check_output(["nvdisasm", "-h"]) + _, _ = path_to_nvdisasm() except BaseException: pytest.skip("nvdisasm is not available") @@ -99,20 +101,20 @@ def test_line_info(func: str): file_lines = extract_file_lines(kernel_info.asm["cubin"]) if func == "single": - assert (check_file_lines(file_lines, "test_line_info.py", 15)) assert (check_file_lines(file_lines, "test_line_info.py", 16)) + assert (check_file_lines(file_lines, "test_line_info.py", 17)) elif func == "call": - assert (check_file_lines(file_lines, "test_line_info.py", 28)) - assert (check_file_lines(file_lines, "test_line_info.py", 21)) - assert (check_file_lines(file_lines, "test_line_info.py", 30)) + assert (check_file_lines(file_lines, "test_line_info.py", 29)) + assert (check_file_lines(file_lines, "test_line_info.py", 22)) + assert (check_file_lines(file_lines, "test_line_info.py", 31)) elif func == "call_noinline": - assert (check_file_lines(file_lines, "test_line_info.py", 42)) - assert (check_file_lines(file_lines, "test_line_info.py", 35)) + assert (check_file_lines(file_lines, "test_line_info.py", 43)) assert (check_file_lines(file_lines, "test_line_info.py", 36)) assert (check_file_lines(file_lines, "test_line_info.py", 37)) + assert (check_file_lines(file_lines, "test_line_info.py", 38)) elif func == "multi_files": - assert (check_file_lines(file_lines, "test_line_info.py", 47)) - assert (check_file_lines(file_lines, "test_line_info.py", 49)) + assert (check_file_lines(file_lines, "test_line_info.py", 48)) + assert (check_file_lines(file_lines, "test_line_info.py", 50)) assert (check_file_lines(file_lines, "standard.py", 33)) assert (check_file_lines(file_lines, "standard.py", 34)) assert (check_file_lines(file_lines, "standard.py", 36)) diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 5f94cd8b3..7e6f820a3 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py 
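The test_line_info.py hunk above swaps the bare "nvdisasm" invocation for the new path_to_nvdisasm helper, which resolves the binary bundled under third_party/cuda and raises RuntimeError when it cannot be found, so the test can skip cleanly. A hedged sketch of the same call pattern; disassemble_cubin is an illustrative name, not an API added by the patch.

import subprocess
import tempfile

from triton.common.backend import path_to_nvdisasm


def disassemble_cubin(cubin_bytes):
    # Mirrors extract_file_lines above; illustrative only, not part of the patch.
    nvdisasm, _version = path_to_nvdisasm()  # raises RuntimeError if missing
    fd, path = tempfile.mkstemp()
    with open(fd, "wb") as f:
        f.write(cubin_bytes)
    # "-g" keeps the source-line annotations that check_file_lines looks for.
    return subprocess.check_output([nvdisasm, "-g", path]).decode("utf-8")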
@@ -86,9 +86,9 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= da_tri = a_tri.grad db_tri = b_tri.grad # compare - torch.testing.assert_allclose(c_ref, c_tri) - torch.testing.assert_allclose(da_ref, da_tri) - torch.testing.assert_allclose(db_ref, db_tri) + torch.testing.assert_close(c_ref, c_tri) + torch.testing.assert_close(da_ref, da_tri) + torch.testing.assert_close(db_ref, db_tri) configs = [ @@ -138,8 +138,8 @@ def test_softmax(BLOCK, WIDTH, is_dense, Z=2, H=2, is_causal=True, scale=0.4): out_tri.backward(dout_tri) da_tri = a_tri.grad # compare - torch.testing.assert_allclose(out_tri, out_ref) - torch.testing.assert_allclose(da_tri, da_ref) + torch.testing.assert_close(out_tri, out_ref, equal_nan=True) + torch.testing.assert_close(da_tri, da_ref, equal_nan=True) @pytest.mark.parametrize("block", [16, 32, 64]) @@ -195,9 +195,9 @@ def test_attention_fwd_bwd( # comparison # print(f"Triton loss {loss} and torch loss {torch_loss}. Also checking grads...") - torch.testing.assert_allclose(loss, torch_loss, atol=1e-3, rtol=0) + torch.testing.assert_close(loss, torch_loss, atol=1e-3, rtol=0) for g1, g2 in zip(grads, torch_grads): - torch.testing.assert_allclose(g1, g2) + torch.testing.assert_close(g1, g2) @pytest.mark.parametrize("block", [16, 32, 64]) diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index f4e40d3a6..f6ae42ac3 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -25,7 +25,7 @@ def test_op(M, N, dtype, mode): tt_y = triton.ops.cross_entropy(x, idx) th_y = torch.nn.CrossEntropyLoss(reduction="none")(x, idx) if mode == 'forward': - torch.testing.assert_allclose(th_y, tt_y) + torch.testing.assert_close(th_y, tt_y) # backward pass elif mode == 'backward': dy = torch.randn_like(tt_y) @@ -36,5 +36,7 @@ def test_op(M, N, dtype, mode): x.grad = None th_y.backward(dy) th_dx = x.grad.clone() - - torch.testing.assert_allclose(th_dx, tt_dx) + if dtype == torch.float16: + torch.testing.assert_close(th_dx, tt_dx, rtol=0.001, atol=0.001) + else: + torch.testing.assert_close(th_dx, tt_dx) diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 8bf7f2e21..b43184265 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -5,22 +5,22 @@ import triton import triton.ops -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 16), - (4, 48, 1024, 32), - (4, 48, 1024, 64), - (4, 48, 1024, 128)]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128)]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('seq_par', [True, False]) def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): - # with ENABLE_TMA=0 and ENABLE_MMA_V3=0 import os - enable_mmav3 = os.environ.get('ENABLE_MMA_V3', 'not found').lower() enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() - if enable_mmav3 in ["on", "true", "1"] and enable_tma in ["on", "true", "1"]: - pytest.skip('Segmentation fault') + if enable_tma in ["on", "true", "1"]: + if dtype == torch.bfloat16: + pytest.skip('bfloat16 tma not support currently') capability = torch.cuda.get_device_capability() +<<<<<<< HEAD if torch.version.hip is not None: if dtype != torch.float16: 
pytest.skip("Currently flash attention on AMD gpu is only supported in fp16.") @@ -31,6 +31,11 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): if capability[0] < 8: pytest.skip("Flash attention only supported for compute capability < 80") +======= + interpreter = os.environ.get("TRITON_INTERPRET", 'not found') in ["on", "true", "1"] + if not interpreter and capability[0] < 8: + pytest.skip("Flash attention only supported for compute capability >= 80") +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 torch.manual_seed(20) q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() @@ -63,8 +68,15 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): tri_dq, q.grad = q.grad.clone(), None # compare atol = 1e-1 if dtype == torch.bfloat16 else 1e-2 +<<<<<<< HEAD torch.testing.assert_allclose(ref_out, tri_out, atol=atol, rtol=0) if torch.version.hip is None: torch.testing.assert_allclose(ref_dv, tri_dv, atol=atol, rtol=0) torch.testing.assert_allclose(ref_dk, tri_dk, atol=atol, rtol=0) torch.testing.assert_allclose(ref_dq, tri_dq, atol=atol, rtol=0) +======= + torch.testing.assert_close(ref_out, tri_out, atol=atol, rtol=0) + torch.testing.assert_close(ref_dv, tri_dv, atol=atol, rtol=0) + torch.testing.assert_close(ref_dk, tri_dk, atol=atol, rtol=0) + torch.testing.assert_close(ref_dq, tri_dq, atol=atol, rtol=0) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py index f7e2ce2aa..579d0ad93 100644 --- a/python/test/unit/operators/test_inductor.py +++ b/python/test/unit/operators/test_inductor.py @@ -1,3 +1,4 @@ +import pytest import torch import triton @@ -52,7 +53,7 @@ def test_normalization_with_remat(): arg8_1 = torch.rand(64, device="cuda") arg9_1 = torch.rand(64, device="cuda") triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) - torch.testing.assert_allclose(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) + torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) def test_avg_pool_bw(): @@ -152,4 +153,43 @@ def test_avg_pool_bw(): out_ref[:, :, 1:7, 0::7] = 2 / 3 out_ref[:, :, 0::7, 1:7] = 2 / 3 out_ref[:, :, 0::7, 0::7] = 4 / 9 - torch.testing.assert_allclose(out, out_ref) + torch.testing.assert_close(out, out_ref) + + +@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) +@pytest.mark.parametrize("num_warps", [1, 4]) +def test_scan2d_broadcast(RBLOCK, num_warps): + @triton.jit(debug=True) + def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + rindex = tl.arange(0, RBLOCK)[None, :] + xindex = tl.arange(0, XBLOCK)[:, None] + data = tl.load(in_ptr + rindex) + scan = tl.cumsum(data, 1) + expected_max = tl.sum(data, 1) + tl.device_assert(scan <= expected_max) + tl.store(out_ptr + xindex * RBLOCK + rindex, scan) + + XBLOCK = 4 + input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda') + output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda') + fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps) + ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) + torch.testing.assert_close(output, ref) + + +def test_scan2d_for(): + @triton.jit + def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): + rbase = tl.arange(0, RBLOCK)[None, :] + for roffset in range(0, rnumel, RBLOCK): + rindex = 
roffset + rbase + rmask = rindex < rnumel + tmp3 = tl.where(rmask, 1, 0) + tmp6 = tl.cumsum(tmp3, 1) + tl.store(out_ptr0 + rindex, tmp6, rmask) + + RBLOCK = 8 + out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64) + fn[(1,)](out0, RBLOCK, RBLOCK) + ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1 + torch.testing.assert_close(out0, ref) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index a7afa02f1..642b0982b 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -26,61 +26,61 @@ def f8_to_f16(x, dtype): @pytest.mark.parametrize( - "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32", + "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM", itertools.chain( *[ [ # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (32, 128, 64, 1, 4, 2, None, None, None, 
AT, BT, DTYPE, DTYPE, True), + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True), + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), # variable input - (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True), - (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True), - (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True), - (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True), + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] ], # n-stage *[ [ - (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True), - (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True), - (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True), - (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True), - (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True), + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] ], # mixed-precision *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"), ("float8e4nv", "float8e4nv"), ("float8e5", "float8e4nv"), @@ -91,14 +91,14 @@ def f8_to_f16(x, dtype): ("float16", "float32"), ("float32", "float16"), ("bfloat16", "float32"), - ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] + ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False] ], # mixed-precision block layout *[ [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, 
BDTYPE, False), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False), + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"), ("float16", "float8e5"), ("float16", "float32"), @@ -108,7 +108,7 @@ def f8_to_f16(x, dtype): ], ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32): +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -176,7 +176,7 @@ def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, a = triton.reinterpret(a, getattr(tl, ADTYPE)) if b_fp8: b = triton.reinterpret(b, getattr(tl, BDTYPE)) - tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32) - torch.testing.assert_allclose(th_c, tt_c, atol=0, rtol=0) + tt_c = triton.ops.matmul(a, b, None, ALLOW_TF32, F8_FASTACCUM) + torch.testing.assert_close(th_c, tt_c) except triton.OutOfResources as e: pytest.skip(str(e)) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index 6f9b94d90..f75fa7c32 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -105,6 +105,21 @@ def test_specialize(mode): assert counter == target +def test_annotation(): + @triton.jit + def kernel(X, i: tl.int32): + tl.store(X, i) + + x = torch.empty(1, dtype=torch.int32, device='cuda') + + device = torch.cuda.current_device() + kernel[(1,)](x, 1) + kernel[(1,)](x, 8) + kernel[(1,)](x, 16) + kernel[(1,)](x, 17) + assert len(kernel.cache[device]) == 4 + + def test_constexpr_not_callable() -> None: @triton.jit def kernel(X, c: tl.constexpr): @@ -138,13 +153,14 @@ def test_jit_warmup_cache() -> None: torch.randn(32, dtype=torch.float32, device="cuda"), 32, ] - assert len(kernel_add.cache) == 0 + device = torch.cuda.current_device() + assert len(kernel_add.cache[device]) == 0 kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) - assert len(kernel_add.cache) == 1 + assert len(kernel_add.cache[device]) == 1 kernel_add.warmup(*args, grid=(1,)) - assert len(kernel_add.cache) == 1 + assert len(kernel_add.cache[device]) == 1 kernel_add.warmup(*args, grid=(1,)) - assert len(kernel_add.cache) == 1 + assert len(kernel_add.cache[device]) == 1 def test_jit_debug() -> None: diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index 2c6fe88a2..06a7ed2b1 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -330,7 +330,7 @@ def test_compile_link_autotune_matmul(): def test_ttgir_to_ptx(): src = """ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32, "triton_gpu.num-ctas" = 1 : i32} { - tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { + tt.func public @sum_kernel_0d1d(%arg0: !tt.ptr, %arg1: !tt.ptr) { tt.return } } diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index 5ec015d18..b76953f5d 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -101,20 +101,34 @@ def get_backend(device_type: 
str): return _backends[device_type] if device_type in _backends else None -@functools.lru_cache() -def path_to_ptxas(): +def _path_to_binary(binary: str): base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ os.environ.get("TRITON_PTXAS_PATH", ""), - os.path.join(base_dir, "third_party", "cuda", "bin", "ptxas") + os.path.join(base_dir, "third_party", "cuda", "bin", binary) ] - for ptxas in paths: - ptxas_bin = ptxas.split(" ")[0] - if os.path.exists(ptxas_bin) and os.path.isfile(ptxas_bin): - result = subprocess.check_output([ptxas_bin, "--version"], stderr=subprocess.STDOUT) + for p in paths: + bin = p.split(" ")[0] + if os.path.exists(bin) and os.path.isfile(bin): + result = subprocess.check_output([bin, "--version"], stderr=subprocess.STDOUT) if result is not None: version = re.search(r".*release (\d+\.\d+).*", result.decode("utf-8"), flags=re.MULTILINE) if version is not None: - return ptxas, version.group(1) - raise RuntimeError("Cannot find ptxas") + return p, version.group(1) + raise RuntimeError(f"Cannot find {binary}") + + +@functools.lru_cache() +def path_to_ptxas(): + return _path_to_binary("ptxas") + + +@functools.lru_cache() +def path_to_cuobjdump(): + return _path_to_binary("cuobjdump") + + +@functools.lru_cache() +def path_to_nvdisasm(): + return _path_to_binary("nvdisasm") diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index aa7dfcbfb..21bd0cbf4 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -199,7 +199,7 @@ class ContainsReturnChecker(ast.NodeVisitor): class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, arch, + def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None, is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0): self.context = context @@ -208,7 +208,7 @@ class CodeGenerator(ast.NodeVisitor): # node.lineno starts from 1, so we need to subtract 1 self.begin_line = begin_line - 1 self.builder.set_loc(file_name, begin_line, 0) - self.builder.arch = arch + self.builder.target = target self.module = self.builder.create_module() if module is None else module self.function_ret_types = {} if function_types is None else function_types self.prototype = prototype @@ -912,7 +912,7 @@ class CodeGenerator(ast.NodeVisitor): file_name, begin_line = _get_fn_file_line(fn) generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline, - file_name=file_name, begin_line=begin_line, arch=self.builder.arch) + file_name=file_name, begin_line=begin_line, target=self.builder.target) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -1108,7 +1108,7 @@ def kernel_suffix(signature, specialization): return suffix -def ast_to_ttir(fn, signature, specialization, constants, debug, arch): +def ast_to_ttir(fn, signature, specialization, constants, debug, target): # canonicalize signature if isinstance(signature, str): signature = {k: v.strip() for k, v in enumerate(signature.split(","))} @@ -1137,7 +1137,7 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, arch): generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, 
function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line, - arch=arch) + target=target) try: generator.visit(fn.parse()) except CompilationError as e: diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index dd6dbb760..8cce0d7a3 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -5,10 +5,18 @@ import hashlib import json import os import re +<<<<<<< HEAD import tempfile from collections import namedtuple from pathlib import Path from typing import Any +======= +from collections import namedtuple +from pathlib import Path +from typing import Any + +from dataclasses import dataclass +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, get_num_warps, @@ -20,11 +28,11 @@ from ..common.build import is_hip # from ..runtime import driver, jit, JITFunction # TODO: runtime.errors from ..runtime.autotuner import OutOfResources -from ..runtime.cache import get_cache_manager +from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager from ..runtime.driver import driver from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability, version_key) -from ..tools.disasm import extract +from ..tools.disasm import get_sass from .code_generator import ast_to_ttir from .make_launcher import make_stub from .utils import (InfoFromBackendForTensorMap, TensorMapManager, @@ -32,6 +40,24 @@ from .utils import (InfoFromBackendForTensorMap, TensorMapManager, CUDA_DEFAULT_WARP_SIZE = 32 +@dataclass +class CudaTargetDescriptor: + capability: int + num_warps: int + + +def _is_cuda(target): + return isinstance(target, CudaTargetDescriptor) + + +class LazyDict(dict): + def __getitem__(self, key): + val = dict.__getitem__(self, key) + if callable(val): + return val() + return val + + def inline_triton_ir(mod): pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -40,11 +66,12 @@ def inline_triton_ir(mod): return mod -def ttir_compute_capability_rewrite(mod, arch): +def ttir_compute_capability_rewrite(mod, target): # For hardware without support, we must rewrite all load/store # with block (tensor) pointers into tensors of pointers pm = ir.pass_manager(mod.context) pm.enable_debug() +<<<<<<< HEAD if _is_cuda(arch): pm.add_rewrite_tensor_pointer_pass(arch, False) elif is_hip(): @@ -52,13 +79,17 @@ def ttir_compute_capability_rewrite(mod, arch): pm.add_rewrite_tensor_pointer_pass(capability, True) else: assert(False, "unsupported target") +======= + if _is_cuda(target): + pm.add_rewrite_tensor_pointer_pass(target.capability) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.run(mod) return mod -def optimize_ttir(mod, arch): +def optimize_ttir(mod, target): mod = inline_triton_ir(mod) - mod = ttir_compute_capability_rewrite(mod, arch) + mod = ttir_compute_capability_rewrite(mod, target) pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_inliner_pass() @@ -72,6 +103,7 @@ def optimize_ttir(mod, arch): return mod +<<<<<<< HEAD def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch): pm = ir.pass_manager(mod.context) pm.enable_debug() @@ -79,21 +111,36 @@ def ttir_to_ttgir(mod, num_warps, warpsize, num_ctas, arch): pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, 0) else: pm.add_convert_triton_to_tritongpu_pass(num_warps, warpsize, num_ctas, arch) +======= +def ttir_to_ttgir(mod, 
num_warps, num_ctas, target): + pm = ir.pass_manager(mod.context) + pm.enable_debug() + pm.add_convert_triton_to_tritongpu_pass(num_warps, 32, num_ctas, target.capability) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.run(mod) return mod +<<<<<<< HEAD def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue, matrix_inst_type): +======= +def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, + cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue): + is_cuda = _is_cuda(target) + if is_cuda: + capability = target.capability +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm = ir.pass_manager(mod.context) pm.enable_debug() pm.add_tritongpu_coalesce_pass() # TODO(Qingyi): Move PlanCTAPass to the front of CoalescePass pm.add_plan_cta_pass(cluster_info) - if _is_cuda(arch): - pm.add_tritongpu_rewrite_tensor_pointer_pass(arch) + if is_cuda: + pm.add_tritongpu_rewrite_tensor_pointer_pass(capability) pm.add_plan_cta_pass(cluster_info) pm.add_tritongpu_remove_layout_conversions_pass() +<<<<<<< HEAD if _is_cuda(arch): pm.add_tritongpu_accelerate_matmul_pass(arch) # TODO change interface of accelerate_matmul_pass @@ -101,6 +148,10 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, matrix_core_version = gpu_matrix_core_version() matrix_inst_size = matrix_inst_type pm.add_tritonamdgpu_accelerate_matmul_pass(matrix_core_version, matrix_inst_size) +======= + if is_cuda: + pm.add_tritongpu_accelerate_matmul_pass(capability) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.add_tritongpu_remove_layout_conversions_pass() if optimize_epilogue: pm.add_tritongpu_optimize_epilogue_pass() @@ -114,20 +165,25 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, # it's the responsibility of the compiler to figure out the exact # `num_warps` to use. # TODO: support the case where `num_warps` from user is not 4. 
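Throughout these compiler.py hunks the bare `arch` integer is replaced by a target descriptor, so pass gating reads `target.capability`. A minimal sketch of the descriptor and the kind of check used for the Hopper-only passes; the concrete values are illustrative.

from dataclasses import dataclass


@dataclass
class CudaTargetDescriptor:
    capability: int  # e.g. 90 for sm_90 (Hopper)
    num_warps: int


target = CudaTargetDescriptor(capability=90, num_warps=4)
# Same style of gate as above: Hopper-only passes (warp specialization,
# fence insertion) fire when the major capability is at least 9.
wants_hopper_passes = target.capability // 10 >= 9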
+<<<<<<< HEAD if _is_cuda(arch) and arch // 10 >= 9 and enable_warp_specialization and num_warps == 4: pm.add_tritongpu_ws_feasibility_checking_pass(arch) +======= + if capability // 10 >= 9 and enable_warp_specialization and num_warps == 4: + pm.add_tritongpu_ws_feasibility_checking_pass(capability) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.run(mod) ws_enabled = ir.is_ws_supported(mod) pm = ir.pass_manager(mod.context) pm.enable_debug() if ws_enabled: - pm.add_tritongpu_wsdecomposing_pass(arch) - pm.add_tritongpu_wspipeline_pass( - num_stages, num_warps, arch) - pm.add_tritongpu_wsmutex_pass(arch) - pm.add_tritongpu_wsmaterialization_pass(arch) + pm.add_tritongpu_wsdecomposing_pass(capability) + pm.add_tritongpu_wspipeline_pass(num_stages, num_warps, capability) + pm.add_tritongpu_wsmutex_pass(capability) + pm.add_tritongpu_wsmaterialization_pass(capability) pm.add_cse_pass() else: +<<<<<<< HEAD if is_hip(): pm.add_tritongpu_pipeline_pass( num_stages, num_warps, num_ctas, 0) @@ -139,6 +195,11 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, else: pm.add_tritongpu_materialize_load_store_pass(num_warps, arch) if _is_cuda(arch) and arch // 10 <= 8: +======= + pm.add_tritongpu_pipeline_pass(num_stages, num_warps, num_ctas, capability) + pm.add_tritongpu_materialize_load_store_pass(num_warps, capability) + if capability // 10 <= 8: +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.add_tritongpu_prefetch_pass() pm.add_tritongpu_optimize_dot_operands_pass() pm.add_tritongpu_remove_layout_conversions_pass() @@ -148,7 +209,11 @@ def optimize_ttgir(mod, num_stages, num_warps, num_ctas, arch, pm.add_tritongpu_reorder_instructions_pass() pm.add_cse_pass() pm.add_symbol_dce_pass() +<<<<<<< HEAD if _is_cuda(arch) and arch // 10 >= 9: +======= + if capability // 10 >= 9: +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 pm.add_tritongpu_fence_insertion_pass() pm.add_tritongpu_ws_fixup_missing_attrs_pass() pm.run(mod) @@ -162,12 +227,21 @@ def _add_external_libs(mod, libs): add_external_libs(mod, list(libs.keys()), list(libs.values())) +<<<<<<< HEAD def ttgir_to_llir(mod, extern_libs, arch, tma_infos, waves_per_eu=0): if extern_libs: _add_external_libs(mod, extern_libs) # TODO: separate tritongpu_to_llvmir for different backends if _is_cuda(arch): return translate_triton_gpu_to_llvmir(mod, arch, tma_infos, runtime.TARGET.NVVM, waves_per_eu) +======= +def ttgir_to_llir(mod, extern_libs, target, tma_infos): + if extern_libs: + _add_external_libs(mod, extern_libs) + # TODO: separate tritongpu_to_llvmir for different backends + if _is_cuda(target): + return translate_triton_gpu_to_llvmir(mod, target.capability, tma_infos, runtime.TARGET.NVVM) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 else: return translate_triton_gpu_to_llvmir(mod, 0, TMAInfos(), runtime.TARGET.ROCDL, waves_per_eu) @@ -190,7 +264,7 @@ def ptx_get_version(cuda_version) -> int: raise RuntimeError("Triton only support CUDA 10.0 or higher") -def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str: +def llir_to_ptx(mod: Any, target: CudaTargetDescriptor, ptx_version: int = None) -> str: ''' Translate TritonGPU module to PTX code. 
:param mod: a TritonGPU dialect module @@ -199,10 +273,10 @@ def llir_to_ptx(mod: Any, arch: int, ptx_version: int = None) -> str: if ptx_version is None: _, cuda_version = path_to_ptxas() ptx_version = ptx_get_version(cuda_version) - return translate_llvmir_to_ptx(mod, arch, ptx_version) + return translate_llvmir_to_ptx(mod, target.capability, ptx_version) -def ptx_to_cubin(ptx: str, arch: int): +def ptx_to_cubin(ptx: str, target: CudaTargetDescriptor): ''' Compile TritonGPU module to cubin. :param ptx: ptx code @@ -210,7 +284,11 @@ def ptx_to_cubin(ptx: str, arch: int): :return: str ''' ptxas, _ = path_to_ptxas() +<<<<<<< HEAD return compile_ptx_to_cubin(ptx, ptxas, arch) +======= + return compile_ptx_to_cubin(ptx, ptxas, target.capability) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # ------------------------------------------------------------------------------ @@ -230,13 +308,15 @@ def get_kernel_name(src: str, pattern: str) -> str: def convert_type_repr(x): - match = re.search(r'!tt\.ptr<(.*)>', x) + # Currently we only capture the pointer type and assume the pointer is on global memory. + # TODO: Capture and support shared memory space + match = re.search(r'!tt\.ptr<([^,]+)', x) if match is not None: return '*' + convert_type_repr(match.group(1)) return x -def make_hash(fn, arch, env_vars, **kwargs): +def make_hash(fn, target, env_vars, **kwargs): if isinstance(fn, JITFunction): configs = kwargs["configs"] signature = kwargs["signature"] @@ -253,9 +333,16 @@ def make_hash(fn, arch, env_vars, **kwargs): get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] +<<<<<<< HEAD key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{waves_per_eu}-{matrix_instr_nonkdim}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{arch}-{env_vars_list}" +======= + key = f"{fn.cache_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 return hashlib.md5(key.encode("utf-8")).hexdigest() assert isinstance(fn, str) + ignore_version = kwargs.get('ignore_version', False) + if (ignore_version): + return hashlib.md5((Path(fn).read_text()).encode("utf-8")).hexdigest() return hashlib.md5((Path(fn).read_text() + version_key()).encode("utf-8")).hexdigest() @@ -266,7 +353,8 @@ def make_hash(fn, arch, env_vars, **kwargs): # (letters, digits, or underscores), and capture it as group 1 (the function name) # - (\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\)) : match a pair of parentheses enclosing # zero or more arguments separated by commas, and capture it as group 2 (the argument list) -mlir_prototype_pattern = r'^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: \S+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*\{\s*$' +# - (attributes \{[\S\s]+\})? 
: optionally match attributes enclosed in braces and capture it as group 3 +mlir_prototype_pattern = r"^\s*tt\.func\s+(?:public\s+)?(@\w+)(\((?:%\w+: [\S\s]+(?: \{\S+ = \S+ : \S+\})?(?:, )?)*\))\s*(attributes \{[\S\s]+\})?\s+\{\s*$" ptx_prototype_pattern = r"\.(?:visible|extern)\s+\.(?:entry|func)\s+(\w+)\s*\(([^)]*)\)" prototype_pattern = { "ttir": mlir_prototype_pattern, @@ -274,7 +362,11 @@ prototype_pattern = { "ptx": ptx_prototype_pattern, } -mlir_arg_type_pattern = r'%\w+: ([^,^\)\s]+)(?: \{\S+ = \S+ : \S+\})?,?' +# - ((?:[^,\s<]+|<[^>]+>)+): Capturing group that matches one or more of either: +# [^,\s<]+: One or more characters that are not a comma, whitespace, or the < symbol. +# |: OR +# <[^>]+>: A string that starts with < and ends with >, containing any characters except > in between. +mlir_arg_type_pattern = r'%\w+: ((?:[^,\s<]+|<[^>]+>)+),?' ptx_arg_type_pattern = r"\.param\s+\.(\w+)" arg_type_pattern = { "ttir": mlir_arg_type_pattern, @@ -282,7 +374,11 @@ arg_type_pattern = { "ptx": ptx_arg_type_pattern, } if is_hip(): +<<<<<<< HEAD ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' +======= + ttgir_num_warps_pattern = r'"triton_gpu_rocm.num-warps"\s?=\s?(\d+)\s?:' +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 else: ttgir_num_warps_pattern = r'"triton_gpu.num-warps"\s?=\s?(\d+)\s?:' @@ -311,6 +407,7 @@ def parse_mlir_module(path, context): instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()]) +<<<<<<< HEAD # TODO: architecture descriptor class def _is_cuda(arch): return isinstance(arch, int) @@ -337,6 +434,15 @@ def get_architecture_descriptor(capability): capability = get_device_capability(device) capability = capability[0] * 10 + capability[1] return capability +======= +def get_cuda_capability(capability): + if capability is None: + device = get_current_device() + capability = get_device_capability(device) + capability = capability[0] * 10 + capability[1] + return capability + +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 @functools.lru_cache def get_arch_default_num_warps(device_type): @@ -347,15 +453,19 @@ def get_arch_default_num_warps(device_type): assert _device_backend arch = _device_backend.get_architecture_descriptor() num_warps = arch["num_warps"] - return num_warps @functools.lru_cache def get_arch_default_num_stages(device_type, capability=None): +<<<<<<< HEAD if device_type in ["cuda"]: arch = get_architecture_descriptor(capability) is_cuda = device_type == "cuda" and _is_cuda(arch) num_stages = 3 if is_cuda and arch >= 75 else 2 +======= + if device_type == "cuda": + num_stages = 3 if get_cuda_capability(capability) >= 75 else 2 +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 else: _device_backend = get_backend(device_type) assert _device_backend @@ -365,11 +475,16 @@ def get_arch_default_num_stages(device_type, capability=None): return num_stages +<<<<<<< HEAD def add_cuda_stages(arch, extern_libs, stages): +======= +def add_cuda_stages(target, extern_libs, stages): + +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 stages["ptx"] = (lambda path: Path(path).read_text(), - lambda src: llir_to_ptx(src, arch)) + lambda src: llir_to_ptx(src, target)) stages["cubin"] = (lambda path: Path(path).read_bytes(), - lambda src: ptx_to_cubin(src, arch)) + lambda src: ptx_to_cubin(src, target)) def compile(fn, **kwargs): @@ -379,6 +494,7 @@ def compile(fn, **kwargs): if is_hip(): device_type = "hip" +<<<<<<< HEAD capability = None 
if device_type == "cuda": @@ -393,6 +509,12 @@ def compile(fn, **kwargs): if is_hip(): is_cuda = False warp_size = CUDA_DEFAULT_WARP_SIZE if _is_cuda(arch) else arch["warp_size"] +======= + is_cuda = device_type == "cuda" + if is_hip(): + is_cuda = False + +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 context = ir.context() constants = kwargs.get("constants", dict()) num_warps = kwargs.get("num_warps", get_arch_default_num_warps(device_type)) @@ -420,11 +542,23 @@ def compile(fn, **kwargs): cluster_info.clusterDimY = kwargs["clusterDims"][1] cluster_info.clusterDimZ = kwargs["clusterDims"][2] tma_infos = TMAInfos() +<<<<<<< HEAD +======= + # build architecture descriptor + if device_type == "cuda": + _device_backend = get_backend(device_type) + target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps) + else: + _device_backend = get_backend(device_type) + assert _device_backend + target = _device_backend.get_architecture_descriptor(**kwargs) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # build compilation stages stages = dict() stages["ast"] = (lambda path: fn, None) stages["ttir"] = (lambda path: parse_mlir_module(path, context), +<<<<<<< HEAD lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch)) if is_cuda: stages["ttgir"] = (lambda path: parse_mlir_module(path, context), @@ -456,12 +590,23 @@ def compile(fn, **kwargs): stages["llir"] = (lambda path: Path(path).read_text(), lambda src: ttgir_to_llir(src, extern_libs, arch, tma_infos)) _device_backend.add_stages(arch, extern_libs, stages) +======= + lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) + if is_cuda: + stages["ttgir"] = (lambda path: parse_mlir_module(path, context), + lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + stages["llir"] = (lambda path: Path(path).read_text(), + lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos)) + add_cuda_stages(target, extern_libs, stages) + elif device_type == "hip": + _device_backend.add_stages(target, extern_libs, stages, num_warps=num_warps, num_stages=num_stages) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 else: # pass the user's configuration to the backend device. 
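One convention in the stages table assembled in this hunk is easy to miss in diff form: each entry maps an IR name to a pair of callables, the first reloading a cached or user-overridden artifact from a file path, the second lowering the previous stage's module. A small sketch of that shape; the "demo" stage and its trivial lowering are made up for illustration.

from pathlib import Path

stages = {}
# (reload_from_path, lower_from_previous_stage): the same pair shape used by the
# ttir/ttgir/llir/ptx/cubin entries registered in this function.
stages["demo"] = (lambda path: Path(path).read_text(),
                  lambda prev_module: f"; lowered from {type(prev_module).__name__}")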
- arch["num_warps"] = num_warps - arch["num_stages"] = num_stages - arch["num_ctas"] = num_ctas - _device_backend.add_stages(arch, extern_libs, stages) + target["num_warps"] = num_warps + target["num_stages"] = num_stages + target["num_ctas"] = num_ctas + _device_backend.add_stages(target, extern_libs, stages) # find out the signature of the function if isinstance(fn, JITFunction): @@ -482,6 +627,7 @@ def compile(fn, **kwargs): src = Path(fn).read_text() import re match = re.search(prototype_pattern[ir_name], src, re.MULTILINE) + # TODO: support function attributes at group 3 (e.g., device function) name, signature = match.group(1), match.group(2) types = re.findall(arg_type_pattern[ir_name], signature) if ir_name == 'ttgir': @@ -494,7 +640,12 @@ def compile(fn, **kwargs): first_stage = list(stages.keys()).index(ir_name) # create cache manager - fn_cache_manager = get_cache_manager(make_hash(fn, arch, get_env_vars(), **kwargs)) + fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), **kwargs)) + # managers used to dump and override IR for debugging + enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" + fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True)) + fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), **kwargs, ignore_version=True)) + # determine name and extension type of provided function if isinstance(fn, JITFunction): name, ext = fn.__name__, "ast" @@ -529,7 +680,7 @@ def compile(fn, **kwargs): "enable_persistent": enable_persistent, "constants": _get_jsonable_constants(constants), "debug": debug, - "arch": arch, } + "target": target, } metadata.update(get_env_vars()) if ext == "ptx": assert "shared" in kwargs, "ptx compilation must provide shared memory size" @@ -539,7 +690,7 @@ def compile(fn, **kwargs): metadata["device_type"] = device_type first_stage = list(stages.keys()).index(ext) - asm = dict() + asm = LazyDict() module = fn # run compilation pipeline and populate metadata for ir_name, (parse, compile_kernel) in list(stages.items())[first_stage:]: @@ -557,7 +708,11 @@ def compile(fn, **kwargs): metadata_group[extra_file_name] = fn_cache_manager.put(next_module[1], extra_file_name) else: metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename) - fn_cache_manager.put(next_module, ir_filename) + fn_dump_manager.put(next_module, ir_filename) + if (enable_override and fn_override_manager.has_file(ir_filename)): + print(f"\nOverriding kernel with file {ir_filename}") + full_name = fn_override_manager.get_file(ir_filename) + next_module = parse(full_name) else: if ir_name == "amdgcn": extra_file_name = f"{name}.hsaco_path" @@ -569,6 +724,7 @@ def compile(fn, **kwargs): if ir_name == "cubin": asm[ir_name] = next_module + asm["sass"] = lambda: get_sass(next_module) elif ir_name == "amdgcn": asm[ir_name] = str(next_module[0]) else: @@ -579,11 +735,19 @@ def compile(fn, **kwargs): else: metadata["shared"] = get_shared_memory_size(module) if ir_name == "ttgir": +<<<<<<< HEAD metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) if metadata["enable_warp_specialization"]: if is_hip(): metadata["num_warps"] = _device_backend.get_num_warps(next_module) else: +======= + if is_hip(): + metadata["num_warps"] = _device_backend.get_num_warps(next_module) + else: + metadata["enable_warp_specialization"] = ir.is_ws_supported(next_module) + if metadata["enable_warp_specialization"]: +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 
metadata["num_warps"] = get_num_warps(next_module) if ir_name == "ptx": metadata["name"] = get_kernel_name(next_module, pattern='// .globl') @@ -723,16 +887,3 @@ class CompiledKernel: self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) return runner - - def get_sass(self, fun=None): - if 'sass' in self.asm: - return self.asm['sass'] - fd, path = tempfile.mkstemp() - try: - with open(fd, 'wb') as cubin: - cubin.write(self.asm['cubin']) - self.sass = extract(path, fun) - finally: - os.remove(path) - self.asm['sass'] = self.sass - return self.sass diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index a97d7d11b..ea58455f3 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -63,8 +63,9 @@ def ty_to_cpp(ty): def generate_launcher(constants, signature, ids): - start_desc = len(signature) - signature = generate_cu_signature(constants, signature, ids) + # Record the end of regular arguments; + # subsequent arguments are architecture-specific descriptors, such as tensor descriptors for CUDA. + signature, desc_start_idx = generate_cu_signature(constants, signature, ids) arg_decls = ', '.join(f"{ty_to_cpp(ty)} arg{i}" for i, ty in signature.items()) def _extracted_type(ty): @@ -99,7 +100,11 @@ def generate_launcher(constants, signature, ids): # generate glue code folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] +<<<<<<< HEAD params = [i for i in signature.keys() if i >= start_desc or (i not in constants and i not in folded_without_constexprs)] +======= + params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)] +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 src = f""" #include \"cuda.h\" #include @@ -116,7 +121,10 @@ static inline void gpuAssert(CUresult code, const char *file, int line) char err[1024] = {{0}}; strcat(err, prefix); strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); }} }} @@ -251,6 +259,9 @@ static PyObject* launch(PyObject* self, PyObject* args) {{ Py_BEGIN_ALLOW_THREADS; _launch(gridX, gridY, gridZ, num_warps, num_ctas, clusterDimX, clusterDimY, clusterDimZ, shared_memory, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(f"ptr_info{i}.dev_ptr" if ty[0]=="*" else f"_arg{i}"for i, ty in signature.items()) if len(signature) > 0 else ''}); Py_END_ALLOW_THREADS; + if (PyErr_Occurred()) {{ + return NULL; + }} if (launch_exit_hook != Py_None && !PyObject_CallObject(launch_exit_hook, args)) {{ return NULL; diff --git a/python/triton/interpreter/__init__.py b/python/triton/compiler/target.py similarity index 100% rename from python/triton/interpreter/__init__.py rename to python/triton/compiler/target.py diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index cb4f1f3ab..d4b24a93e 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -26,12 +26,11 @@ from ..runtime import driver def generate_cu_signature(constants, signature, ids): # CUtensorMap*s are always the last arguments + num_regular_signatures = max(signature.keys()) + 1 if len(signature) > 0 else 0 if ids["ids_of_tensormaps"] is not None: - signature = signature.copy() - num_signature = len(signature) for i, _ in 
enumerate(ids["ids_of_tensormaps"]): - signature[num_signature + i] = '*CUtensorMap' - return signature + signature[num_regular_signatures + i] = '*CUtensorMap' + return signature, num_regular_signatures def dummy_tensormaps_info(n=2): diff --git a/python/triton/interpreter/core.py b/python/triton/interpreter/core.py deleted file mode 100644 index 82f3f43a2..000000000 --- a/python/triton/interpreter/core.py +++ /dev/null @@ -1,9 +0,0 @@ -from typing import Tuple - -import dataclasses - - -@dataclasses.dataclass -class ExecutionContext: - program_id: Tuple[int] - program_size: Tuple[int] diff --git a/python/triton/interpreter/interpreter.py b/python/triton/interpreter/interpreter.py deleted file mode 100644 index 001b80ec9..000000000 --- a/python/triton/interpreter/interpreter.py +++ /dev/null @@ -1,171 +0,0 @@ -import itertools -import random -from typing import Tuple - -from .. import language as tl -# import .language.core as lcore -from ..language import core as lcore -from . import torch_wrapper -from .core import ExecutionContext -from .memory_map import MemoryMap -from .tl_lang import (TritonLangProxy, WrappedTensor, _primitive_to_tensor, - debugger_constexpr) - -torch = torch_wrapper.torch -tl_method_backup = {} - - -def get_proxy_method(proxy, name): - method = getattr(proxy, name) - - def fun(*args, **kwarg): - return method(*args, **kwarg) - - return fun - - -def attach_triton(module, proxy): - method_list = [func for func in dir(TritonLangProxy) if func[0] != "_"] - for name in method_list: - if hasattr(module, name): - attr = getattr(module, name) - tl_method_backup[name] = attr - if callable(attr): - setattr(module, name, get_proxy_method(proxy, name)) - else: - setattr(module, name, getattr(proxy, name)) - - -def detach_triton(module): - for name, method in tl_method_backup.items(): - setattr(module, name, method) - - -def program_ids_from_grid(grid: Tuple[int, ...]) -> Tuple[int, ...]: - # reverse the grid dimensions and generate the range for each dimension - reversed_grid = reversed(grid) - ranges_for_each_dimension = [range(dim) for dim in reversed_grid] - - # gen all combinations - index_combinations = list(itertools.product(*ranges_for_each_dimension)) - random.shuffle(index_combinations) - - for index_combination in index_combinations: - yield index_combination - - -class DebuggerFunction: - def __init__(self, func, grid=(1,)): - self.func = func - self.grid = grid - - def _is_constexpr(self, name): - return name in self.func.__annotations__ and self.func.__annotations__[name] is lcore.constexpr - - def _get_constexpr(self): - result = [] - for name, annotation in self.func.__annotations__.items(): - if annotation is lcore.constexpr: - result.append(name) - return result - - def _assert_constexpr(self, **kwargs): - constexp = self._get_constexpr() - missing = [i for i in constexp if i not in kwargs.keys()] - assert len(missing) == 0, f"You must specify constexpr {missing}" - - def _get_grid(self, **kwargs): - if callable(self.grid): - return self.grid(kwargs) - else: - return self.grid - - def __call__(self, *args, **kwargs): - self._assert_constexpr(**kwargs) - - memory = MemoryMap() - - def convert_arg(v): - name, arg = v - if torch.is_tensor(arg): - ptr = memory.add_tensor(arg) - return WrappedTensor(torch.tensor([ptr], dtype=torch.int64, device="cuda")) - if self._is_constexpr(name): - return debugger_constexpr(arg) - return WrappedTensor(_primitive_to_tensor(arg)) - - new_args = tuple(map(convert_arg, zip(self.func.__code__.co_varnames, args))) - new_kwargs = {k: 
convert_arg((k, v)) for (k, v) in kwargs.items() if k not in ["num_warps", "num_stages"]} - - grid = self._get_grid(**kwargs) - for program_id in program_ids_from_grid(grid): - proxy = TritonLangProxy(memory, ExecutionContext(program_id, grid)) - attach_triton(tl, proxy) - self.func(*new_args, **new_kwargs) - detach_triton(tl) - - -class GridSelector: - """ - Entry point of the debugger - """ - - def __init__(self, func): - version = torch.__version__ - assert version[0] == "2", f"Triton Debugger only supports torch >= 2.0, using {version}" - self.func = func - - def __getitem__(self, grid): - return DebuggerFunction(self.func, grid) - - def __call__(self, *args, **kwargs): - return DebuggerFunction(self.func)(*args, **kwargs) - - -class AutotuneGridSelector: - def __init__(self, func, autotune_params): - self.func = func - self.autotune_params = autotune_params - - def __getitem__(self, grid): - return AutotuneRunner(self.func, self.autotune_params, grid) - - def __call__(self, *args, **kwargs): - return AutotuneRunner(self.func, self.autotune_params)(*args, **kwargs) - - -class AutotuneRunner: - def __init__(self, func, autotune_params, grid=None): - self.func = func - self.autotune_params = autotune_params - self.grid = grid - - def __call__(self, *args, **kwargs): - assert len(self.autotune_params["configs"]) >= 1 - - for config in self.autotune_params["configs"][1:]: - - def convert_arg(v): - if torch.is_tensor(v): - return torch.clone(v) - return v - - new_args = tuple(map(convert_arg, args)) - new_kwargs = {k: convert_arg(v) for k, v in kwargs.items()} - if self.grid: - self.func[self.grid](*new_args, **new_kwargs, **config.kwargs) - else: - self.func(*new_args, **new_kwargs, **config.kwargs) - - main_config = self.autotune_params["configs"][0] - if self.grid: - self.func[self.grid](*args, **kwargs, **main_config.kwargs) - else: - self.func(*args, **kwargs, **main_config.kwargs) - - -def triton_debug_autotune(**kwars): - def wrapper(func): - return AutotuneGridSelector(func, kwars) - - return wrapper diff --git a/python/triton/interpreter/memory_map.py b/python/triton/interpreter/memory_map.py deleted file mode 100644 index d0ff732a7..000000000 --- a/python/triton/interpreter/memory_map.py +++ /dev/null @@ -1,102 +0,0 @@ -from __future__ import annotations - -import dataclasses - -from . 
import torch_wrapper - -torch = torch_wrapper.torch - - -@dataclasses.dataclass -class RegisteredStorage: - storage: torch.Storage - dtype: torch.dtype - size: int - ptr: int - - @property - def end_ptr(self) -> int: - return self.ptr + self.size - - @property - def access_tensor(self) -> torch.Tensor: - return torch.tensor(self.storage, dtype=self.dtype, device=self.storage.device) - - def ensure_immutable(self): - assert self.storage.data_ptr() == self.ptr and self.storage.size() == self.size - - -class MemoryMap: - storages: [RegisteredStorage] - - def __init__(self): - self.storages = [] - - def _get_registered_storage(self, pointer: torch.Tensor): - max_pointer = torch.max(pointer).item() - min_pointer = torch.min(pointer).item() - - registered_storage = next( - filter( - lambda registered: min_pointer >= registered.ptr and max_pointer < registered.end_ptr, self.storages - ), - None, - ) - if registered_storage is None: - raise Exception("Storage not found or pointers spanning multiple tensors") - registered_storage.ensure_immutable() - return registered_storage - - def add_tensor(self, t: torch.Tensor): - storage = t.untyped_storage() - self.storages.append(RegisteredStorage(storage, t.dtype, storage.size(), storage.data_ptr())) - return t.data_ptr() - - def load( - self, - pointer: torch.Tensor, - mask: torch.Tensor = None, - other=0.0, - ): - assert pointer.is_cuda - assert 0 < pointer.dim() < 3 - assert pointer.dtype == torch.int64 - - if mask is None: - mask = torch.ones_like(pointer).bool() - assert mask.is_cuda - assert 0 < mask.dim() < 3 - assert mask.dtype == torch.bool - mask = mask.expand(pointer.size()) - - if torch.all(~mask): - # Todo: The type is wrong here, we can't determine the correct type - return torch.full_like(pointer, fill_value=other, dtype=torch.float16, device="cuda") - - registered_storage = self._get_registered_storage(pointer[mask]) - access_tensor = registered_storage.access_tensor - - index_tensor = pointer - registered_storage.ptr - - block = torch.full_like(pointer, fill_value=other, dtype=access_tensor.dtype, device="cuda") - block[mask] = access_tensor[index_tensor[mask]] - return block - - def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): - assert 0 < pointer.dim() < 3 - assert pointer.dtype == torch.int64 - - if mask is None: - mask = torch.ones_like(pointer).bool() - assert 0 < mask.dim() < 3 - assert mask.dtype == torch.bool - mask = mask.expand(pointer.size()) - - if torch.all(~mask): - return - - registered_storage = self._get_registered_storage(pointer[mask]) - access_tensor = registered_storage.access_tensor - - index_tensor = pointer - registered_storage.ptr - access_tensor[index_tensor[mask]] = value[mask].to(access_tensor.dtype) diff --git a/python/triton/interpreter/tl_lang.py b/python/triton/interpreter/tl_lang.py deleted file mode 100644 index e2a578fa5..000000000 --- a/python/triton/interpreter/tl_lang.py +++ /dev/null @@ -1,641 +0,0 @@ -from __future__ import annotations - -from ..language import core as lcore -from . import torch_wrapper -from .core import ExecutionContext -from .memory_map import MemoryMap - -torch = torch_wrapper.torch - - -def _primitive_to_tensor(x): - """ - Converts various Python primitive data types to PyTorch tensor. 
- """ - tensor_args = {"device": "cuda"} - if isinstance(x, bool): - return torch.tensor([x], dtype=torch.bool, **tensor_args) - elif isinstance(x, int): - if -(2**31) <= x < 2**31: - return torch.tensor([x], dtype=torch.int32, **tensor_args) - elif -(2**63) <= x < 2**63: - return torch.tensor([x], dtype=torch.int64, **tensor_args) - else: - raise RuntimeError(f"Nonrepresentable integer {x}.") - elif isinstance(x, float): - return torch.tensor([x], dtype=torch.float32, **tensor_args) - elif torch.is_tensor(x): - return x - elif isinstance(x, WrappedTensor): - return x - elif isinstance(x, debugger_constexpr): - if x.value is None: - return None - return _primitive_to_tensor(x.value) - elif x is None: - return None - assert False, f"cannot convert {x} of type {type(x)} to tensor" - - -def _infer_tensor(func): - """ - A decorator function to harmonize function args: - - converts primitives to PyTorch tensors - - wraps PyTorch tensors with WrappedTensors - """ - def wrapper(*args): - new_args = tuple(map(lambda v: _primitive_to_tensor(v), args)) - new_args = tuple(map(lambda v: WrappedTensor(v) if torch.is_tensor(v) else v, new_args)) - - return func(*new_args) - - return wrapper - - -def _tensor_operation(func): - """ - A decorator function to unwrap WrappedTensors and debugger_constexpr before calling the function. - Can be combined with _infer_tensor decorator to harmonize args (everything to torch tensor). - """ - def wrapper(*args, **kwargs): - for arg in args: - assert not torch.is_tensor(arg), "unexpected tensor argument" - - def unwrap_tensor(v): - if isinstance(v, WrappedTensor): - return v.tensor - if isinstance(v, debugger_constexpr): - return v.value - return v - - new_args = tuple(map(unwrap_tensor, args)) - new_kwargs = {k: unwrap_tensor(v) for k, v in kwargs.items()} - - result = func(args[0], *new_args[1:], **new_kwargs) - return WrappedTensor(result) if torch.is_tensor(result) else result - - return wrapper - - -class debugger_constexpr: - def __init__(self, value): - if isinstance(value, debugger_constexpr): - self.value = value.value - else: - self.value = value - - def __str__(self) -> str: - return "debugger_constexpr(" + str(self.value) + ")" - - def __index__(self) -> int: - return self.value - - def __bool__(self): - return bool(self.value) - - def __ge__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value >= other - - def __gt__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value > other - - def __le__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value <= other - - def __lt__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value < other - - def __eq__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value == other - - def __or__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value | other - - def __ror__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value | other - - def __and__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value & other - - def __rand__(self, other): - other = other.value if isinstance(other, debugger_constexpr) else other - return self.value & other - - def to(self, dtype, bitcast=False, _builder=None): - if dtype 
in [torch.int64]: - ret_ty = int - elif dtype == torch.bool: - ret_ty = bool - elif dtype in [torch.float64]: - ret_ty = float - else: - raise ValueError("dtype not supported in debugger") - return debugger_constexpr(ret_ty(self.value)) - - -class WrappedTensor: - def __init__(self, tensor): - self.tensor = tensor - - def __index__(self) -> int: - return self.tensor.item() - - def __str__(self) -> str: - return "wrapped_" + str(self.tensor) - - def __bool__(self) -> bool: - return torch.all(self.tensor == True).item() # noqa: E712 - - @property - def dtype(self): - return self.tensor.dtype - - @_infer_tensor - @_tensor_operation - def __add__(self, other): - return torch.add(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __radd__(self, other): - return self.__add__(other) - - @_infer_tensor - @_tensor_operation - def __sub__(self, other): - return torch.sub(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rsub__(self, other): - return torch.sub(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __mul__(self, other): - return torch.mul(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rmul__(self, other): - return self.__mul__(other) - - @_infer_tensor - @_tensor_operation - def __truediv__(self, other): - return torch.div(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rtruediv__(self, other): - return torch.div(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __floordiv__(self, other): - return torch.floor_divide(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rfloordiv__(self, other): - return torch.floor_divide(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __mod__(self, other): - return torch.remainder(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rmod__(self, other): - return torch.remainder(other, self.tensor) - - @_infer_tensor - @_tensor_operation - def __neg__(self): - return -self.tensor - - @_infer_tensor - @_tensor_operation - def __invert__(self): - return ~self.tensor - - @_infer_tensor - @_tensor_operation - def __and__(self, other): - return torch.bitwise_and(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __or__(self, other): - return torch.bitwise_or(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __xor__(self, other): - return torch.bitwise_xor(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __lshift__(self, other): - return torch.bitwise_left_shift(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __rshift__(self, other): - return torch.bitwise_right_shift(self.tensor, other) - - @_infer_tensor - @_tensor_operation - def __gt__(self, other): - return self.tensor > other - - @_infer_tensor - @_tensor_operation - def __rgt__(self, other): - return other > self.tensor - - @_infer_tensor - @_tensor_operation - def __ge__(self, other): - return self.tensor >= other - - @_infer_tensor - @_tensor_operation - def __rge__(self, other): - return other >= self.tensor - - @_infer_tensor - @_tensor_operation - def __lt__(self, other): - return self.tensor < other - - @_infer_tensor - @_tensor_operation - def __rlt__(self, other): - return other < self.tensor - - @_infer_tensor - @_tensor_operation - def __le__(self, other): - return self.tensor <= other - - @_infer_tensor - @_tensor_operation - def __rle__(self, other): - return other <= self.tensor - - @_infer_tensor - @_tensor_operation - def __eq__(self, other): - return torch.equal(self.tensor, 
other) - - @_infer_tensor - @_tensor_operation - def __ne__(self, other): - return not torch.equal(self.tensor, other) - - @_tensor_operation - def __getitem__(self, slices): - return self.tensor.__getitem__(slices) - # if isinstance(slices, slice): - # slices = [slices] - # src_shape = self.shape - # dst_shape = [] - # curr = 0 - # for sl in slices: - # if isinstance(sl, constexpr) and sl.value is None: - # dst_shape.append(1) - # elif sl == slice(None, None, None): - # dst_shape.append(src_shape[curr].value) - # curr += 1 - # ret = torch.reshape(self.tensor, dst_shape, ) - # return ret - - @_tensor_operation - def to(self, dtype, bitcast=False): - return self.tensor.to(dtype) - # if isinstance(bitcast, constexpr): - # bitcast = bitcast.value - # if bitcast: - # return semantic.bitcast(self, dtype, ) - # return semantic.cast(self, dtype, ) - - -def _constexpr_to_value(v): - if isinstance(v, debugger_constexpr): - return v.value - return v - - -class TritonLangProxy: - _memory_map: MemoryMap - _context: ExecutionContext - - def __init__(self, memory_map: MemoryMap, context: ExecutionContext): - self._memory_map = memory_map - self._context = context - - # Types - # Removed void, int1, float8, uint16, uint32, uint64, pi32_t - - # constexpr = debugger_constexpr - - # Program functions - - @_tensor_operation - def load( - self, - pointer: torch.Tensor, - mask: torch.Tensor = None, - other=0.0, - cache_modifier="", - eviction_policy="", - volatile=False, - ): - return self._memory_map.load(pointer, mask, other) - - @_tensor_operation - def store(self, pointer: torch.Tensor, value: torch.Tensor, mask=None): - return self._memory_map.store(pointer, value, mask) - - @_tensor_operation - def program_id(self, axis): - assert axis < len(self._context.program_id) - return torch.tensor([self._context.program_id[axis]], dtype=torch.int32, device="cuda") - - @_tensor_operation - def num_programs(self, axis): - assert axis < len(self._context.program_size) - return torch.tensor([self._context.program_size[axis]], dtype=torch.int32, device="cuda") - - @_tensor_operation - def arange(self, start, end): - return torch.arange(start=start, end=end, dtype=torch.int32, device="cuda") - - @_tensor_operation - def zeros(self, shape, dtype): - for i, d in enumerate(shape): - if not isinstance(d, debugger_constexpr): - raise TypeError(f"Shape element {i} must have type `constexpr`") - if not isinstance(d.value, int): - raise TypeError(f"Shape element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") - shape = [x.value for x in shape] - if isinstance(dtype, lcore.dtype): - if dtype.is_fp32(): - dtype = torch.float32 - elif dtype.is_fp16(): - dtype = torch.float16 - elif dtype.is_bf16(): - dtype = torch.bfloat16 - elif dtype.is_int32(): - dtype = torch.int32 - elif dtype.is_int16(): - dtype = torch.int16 - elif dtype.is_int8(): - dtype = torch.int8 - else: - raise TypeError(f"Unsupported dtype {dtype}") - return torch.zeros(size=shape, dtype=dtype, device="cuda") - - @_tensor_operation - def dequantize(self, input, scale, shift, nbit, dst_ty=None): - if dst_ty is None: - dst_ty = torch.float16 - raise NotImplementedError() - - @_tensor_operation - def broadcast(self, input, other): - raise NotImplementedError() - - @_tensor_operation - def broadcast_to(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def cat(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - def reshape(self, input, shape): - raise NotImplementedError() - - @_tensor_operation - 
def dot(self, input, other, trans_a=False, trans_b=False, allow_tf32=True): - assert input.dtype == other.dtype - if trans_a: - input = input.T - if trans_b: - other = other.T - return torch.matmul(input=input, other=other) - - @_tensor_operation - def atomic_cas(self, pointer, cmp, val): - stored = self._memory_map.load(pointer, None, 0.0) - if not isinstance(cmp, torch.Tensor): - cmp = torch.tensor([cmp], dtype=stored.dtype, device="cuda") - if not isinstance(val, torch.Tensor): - val = torch.tensor([val], dtype=stored.dtype, device="cuda") - if stored == cmp: - self._memory_map.store(pointer, val, None) - return stored - - @_tensor_operation - def atomic_xchg(self, pointer, val, mask=None): - if isinstance(val, int): - val = torch.tensor([val], dtype=torch.int32, device="cuda") - stored = self._memory_map.load(pointer, mask, 0.0) - self._memory_map.store(pointer, val, mask) - return stored - - @_tensor_operation - def atomic_add(self, pointer, val, mask=None): - # arbitrary other value as it will masked during storing - stored = self._memory_map.load(pointer, mask, 0.0) - result = stored + val - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_max(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0.0) - result = torch.maximum(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_min(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0.0) - result = torch.minimum(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_and(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_and(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_or(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_or(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def atomic_xor(self, pointer, val, mask=None): - stored = self._memory_map.load(pointer, mask, 0) - result = torch.bitwise_xor(stored, val) - self._memory_map.store(pointer, result, mask) - return stored - - @_tensor_operation - def where(self, condition, x, y): - condition = _primitive_to_tensor(condition) - x = _primitive_to_tensor(x) - y = _primitive_to_tensor(y) - return torch.where(condition, x, y) - - @_tensor_operation - def umulhi(self, x, y): - raise NotImplementedError() - - @_tensor_operation - def fdiv(self, x, y, ieee_rounding=False): - raise NotImplementedError() - - @_tensor_operation - def exp(self, x): - return torch.exp(x) - - @_tensor_operation - def log(self, x): - return torch.log(x) - - @_tensor_operation - def cos(self, x): - return torch.cos(x) - - @_tensor_operation - def sin(self, x): - return torch.sin(x) - - @_tensor_operation - def sqrt(self, x): - return torch.sqrt(x) - - @_tensor_operation - def globaltimer(self): - raise NotImplementedError() - - @_tensor_operation - def clock(self): - raise NotImplementedError() - - @_tensor_operation - def debug_barrier(self): - raise NotImplementedError() - - @_tensor_operation - def multiple_of(self, input, values): - return input - - @_tensor_operation - def max_contiguous(self, input, values): - return input - - @_tensor_operation - def max_constancy(self, input, values): - return input - - @_tensor_operation - def abs(self, x): - return 
torch.abs(x) - - @_tensor_operation - def cdiv(self, x, div): - return (x + div - 1) // div - - @_tensor_operation - def minimum(self, x, y): - if isinstance(x, int): - x = torch.tensor(x, device="cuda") - if isinstance(y, int): - y = torch.tensor(y, device="cuda") - return torch.minimum(x, y) - - @_tensor_operation - def maximum(self, x, y): - return torch.maximum(x, y) - - @_tensor_operation - def sigmoid(self, x): - raise NotImplementedError() - - @_tensor_operation - def softmax(self, x, ieee_rounding=False): - raise NotImplementedError() - - @_tensor_operation - def ravel(self, x): - raise NotImplementedError() - - @_tensor_operation - def swizzle2d(self, i, j, size_i, size_j, size_g): - raise NotImplementedError() - - @_tensor_operation - def zeros_like(self, input): - raise NotImplementedError() - - @_tensor_operation - def max(self, input, axis=None): - if axis is None: - return torch.max(input) - return torch.max(input, dim=axis).values - - @_tensor_operation - def argmax(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def min(self, input, axis=None): - if axis is None: - return torch.min(input) - return torch.min(input, dim=axis).values - - @_tensor_operation - def argmin(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def sum(self, input, axis=None): - if axis is None: - return torch.sum(input) - return torch.sum(input, dim=axis) - - @_tensor_operation - def xor_sum(self, input, axis): - raise NotImplementedError() - - @_tensor_operation - def cumsum(self, input, axis=None): - if axis is None: - return torch.cumsum(input) - return torch.cumsum(input, dim=axis) - - @_tensor_operation - def cumprod(self, input, axis=None): - if axis is None: - return torch.cumprod(input) - return torch.cumprod(input, dim=axis) diff --git a/python/triton/interpreter/torch_wrapper.py b/python/triton/interpreter/torch_wrapper.py deleted file mode 100644 index 44aa17eb1..000000000 --- a/python/triton/interpreter/torch_wrapper.py +++ /dev/null @@ -1,18 +0,0 @@ -try: - import torch as _torch -except ImportError: - _torch = None - - -class TorchWrapper: - """ - Helps in making torch an optional dependency - """ - - def __getattr__(self, name): - if _torch is None: - raise ImportError("Triton requires PyTorch to be installed") - return getattr(_torch, name) - - -torch = TorchWrapper() diff --git a/python/triton/language/core.py b/python/triton/language/core.py index c452aedfd..5e2abf192 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -293,7 +293,7 @@ class dtype: return self.name def __repr__(self): - return f'triton.language.{self.name}' + return f'triton.language.{str(self)}' class pointer_type(dtype): @@ -551,9 +551,7 @@ class tensor: # IR handle self.handle = handle # Block shape - self.shape = (1, ) - if type.is_block(): - self.shape = type.shape + self.shape = type.shape if type.is_block() else () self.numel = 1 for s in self.shape: self.numel *= s @@ -564,14 +562,15 @@ class tensor: self.shape = [constexpr(s) for s in self.shape] def __str__(self) -> str: - # ex. "float32[3,4]" - return str(self.dtype) + '[' + ','.join(str(s) for s in self.shape) + ']' + # ex. 
"float32[16, 32]" + return str(self.dtype) + '[' + ', '.join(str(s) for s in self.shape) + ']' @builtin def __add__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.add(self, other, _builder) + @builtin def __radd__(self, other, _builder=None): return self.__add__(other, _builder=_builder) @@ -580,6 +579,7 @@ class tensor: other = _to_tensor(other, _builder) return semantic.sub(self, other, _builder) + @builtin def __rsub__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.sub(other, self, _builder) @@ -589,6 +589,7 @@ class tensor: other = _to_tensor(other, _builder) return semantic.mul(self, other, _builder) + @builtin def __rmul__(self, other, _builder=None): return self.__mul__(other, _builder=_builder) @@ -597,6 +598,7 @@ class tensor: other = _to_tensor(other, _builder) return semantic.truediv(self, other, _builder) + @builtin def __rtruediv__(self, other, _builder=None): other = _to_tensor(other, _builder) return semantic.truediv(other, self, _builder) @@ -688,8 +690,6 @@ class tensor: else: return semantic.lshr(other, self, _builder) - # comparison operators - # > @builtin def __gt__(self, other, _builder=None): @@ -763,11 +763,11 @@ class tensor: @builtin def __getitem__(self, slices, _builder=None): - if isinstance(slices, slice): + if isinstance(slices, (slice, constexpr)): slices = [slices] ret = self for dim, sl in enumerate(slices): - if isinstance(sl, constexpr) and sl.value is None: + if sl is None or isinstance(sl, constexpr) and sl.value is None: ret = semantic.expand_dims(ret, dim, _builder) elif isinstance(sl, slice) and sl.start is None and sl.stop is None and sl.step is None: pass @@ -852,6 +852,8 @@ def arange(start, end, _builder=None): def _shape_check_impl(shape): shape = _constexpr_to_value(shape) for i, d in enumerate(shape): + if isinstance(d, int): + d = constexpr(d) if not isinstance(d, constexpr): raise TypeError(f"Shape element {i} must have type `constexpr`") if not isinstance(d.value, int): @@ -930,6 +932,12 @@ def broadcast_to(input, shape, _builder=None): @builtin def trans(input, _builder=None): + """ + Returns a transposed tensor. + + :param input: The input tensor. + :type input: + """ return semantic.trans(input, _builder) @@ -968,6 +976,15 @@ def view(input, shape, _builder=None): @builtin def reshape(input, shape, _builder=None): + """ + Returns a tensor with the same number of elements as input but with the + provided shape. + + :param input: The input tensor. + :type input: + :param shape: The new shape. + :type shape: Tuple[int] + """ shape = _shape_check_impl(shape) return semantic.reshape(input, shape, _builder) @@ -1012,7 +1029,7 @@ def expand_dims(input, axis, _builder=None): @builtin -def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None): +def dot(input, other, acc=None, allow_tf32=True, max_num_imprecise_acc=None, out_dtype=float32, _builder=None): """ Returns the matrix product of two blocks. 
@@ -1025,7 +1042,8 @@ def dot(input, other, allow_tf32=True, out_dtype=float32, _builder=None): """ allow_tf32 = _constexpr_to_value(allow_tf32) out_dtype = _constexpr_to_value(out_dtype) - return semantic.dot(input, other, allow_tf32, out_dtype, _builder) + max_num_imprecise_acc = _constexpr_to_value(max_num_imprecise_acc) + return semantic.dot(input, other, acc, allow_tf32, max_num_imprecise_acc, out_dtype, _builder) # ----------------------- @@ -1266,6 +1284,14 @@ def where(condition, x, y, _builder=None): @builtin def umulhi(x, y, _builder=None): + """ + Returns the most significant 32 bits of the product of x and y. + + :param x: the input tensor + :type x: int32 + :param y: the input tensor + :type y: int32 + """ x = _to_tensor(x, _builder) y = _to_tensor(y, _builder) return semantic.umulhi(x, y, _builder) @@ -1273,6 +1299,15 @@ def umulhi(x, y, _builder=None): @builtin def fdiv(x, y, ieee_rounding=False, _builder=None): + """ + Returns a floating-point resultant tensor of dividing x by y. + + :param x: the input numerator value. + :param y: the input denominator value. + :param ieee_rounding: To follow IEEE-754 floating point number + rounding mechanism + :type ieee_rounding: bool + """ ieee_rounding = _constexpr_to_value(ieee_rounding) return semantic.fdiv(x, y, ieee_rounding, _builder) diff --git a/python/triton/language/extra/cuda.py b/python/triton/language/extra/cuda.py index d69120938..8c4114739 100644 --- a/python/triton/language/extra/cuda.py +++ b/python/triton/language/extra/cuda.py @@ -13,3 +13,8 @@ def smid(_builder=None): return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1, _builder=_builder) + + +@core.builtin +def num_threads(_builder=None): + return core.constexpr(_builder.target.num_warps * 32) diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 17b1b2920..b6194ad3b 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -12,6 +12,13 @@ import re T = TypeVar('T') +# TODO: redundant code -- remove after 3P backend refactor + + +def _is_cuda(target): + from ..compiler.compiler import CudaTargetDescriptor + return isinstance(target, CudaTargetDescriptor) + # Create custom exception that prints message "hello" @@ -28,10 +35,14 @@ class IncompatibleTypeErrorImpl(Exception): # ===----------------------------------------------------------------------===## def program_id(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") return tl.tensor(builder.create_get_program_id(axis), tl.int32) def num_programs(axis: int, builder: ir.builder) -> tl.tensor: + if axis not in (0, 1, 2): + raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") return tl.tensor(builder.create_get_num_programs(axis), tl.int32) # ===----------------------------------------------------------------------===// @@ -131,6 +142,8 @@ def add(input: tl.tensor, input, other = binary_op_type_checking_impl(input, other, builder, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar + if input_scalar_ty.is_ptr() and other_scalar_ty.is_ptr(): + raise ValueError("cannot add pointers together") # offset + ptr # ptr + offset @@ -504,19 +517,18 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te if isinstance(value, tl.tensor): assert value.numel.value == 1, "only accepts size-1 tensor" value = cast(value, dtype, builder) - 
ret_ty = tl.block_type(value.dtype, shape) - return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) else: # scalar + if dtype is None: + raise ValueError("dtype must be specified when value is not a tensor") if value == 0: value = builder.get_null_value(dtype.to_ir(builder)) else: get_value_fn = getattr(builder, f"get_{dtype.name}") value = get_value_fn(value) - if dtype is None: - raise ValueError("dtype must be specified when value is not a tensor") - ret_ty = tl.block_type(dtype, shape) - return tl.tensor(builder.create_splat(value, shape), ret_ty) + value = tl.tensor(value, dtype) + + return splat(value, shape, builder) @@ -529,6 +541,13 @@ def ones(shape: List[int], dtype: tl.dtype, builder: ir.builder) -> tl.tensor: # Shape Manipulation # ===----------------------------------------------------------------------===// +def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: + assert not value.type.is_block(), "Cannot splat a block tensor" + if len(shape) == 0: + return value + ret_ty = tl.block_type(value.dtype, shape) + return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) + def view(input: tl.tensor, dst_shape: List[int], @@ -553,8 +572,12 @@ def reshape(input: tl.tensor, def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor: - dst_shape = list(input.type.shape) + dst_shape = [tl._constexpr_to_value(x) for x in input.shape] dst_shape.insert(axis, 1) + + if not input.type.is_block(): + return splat(input, shape=dst_shape, builder=builder) + ret_ty = tl.block_type(input.type.scalar, dst_shape) return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty) @@ -674,11 +697,6 @@ def bitcast(input: tl.tensor, dst_ty) -# TODO: architecture descriptor class -def _is_cuda(arch): - return isinstance(arch, int) - - def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: @@ -693,7 +711,7 @@ def cast(input: tl.tensor, src_sca_ty = src_ty.scalar dst_sca_ty = dst_ty.scalar - if _is_cuda(builder.arch) and builder.arch < 89 and \ + if _is_cuda(builder.target) and builder.target.capability < 89 and \ (src_sca_ty.is_fp8e4nv() or dst_sca_ty.is_fp8e4nv()): assert False, "fp8e4nv data type is not supported on CUDA arch < 89" @@ -1139,13 +1157,20 @@ def atomic_max(ptr: tl.tensor, # for float # return atomic_smax(i_ptr, i_val) if val >= 0 # return atomic_umin(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) - neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_max not supported for dtype {sca_ty}") + + itype = tl.int32 if sca_ty == tl.float32 else tl.float64 + zero = full([], 0.0, sca_ty, builder) + + i_val = bitcast(val, itype, builder) + i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem), i_val.type) neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem), i_val.type) - return where(pos, pos_ret, neg_ret, builder) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) def atomic_min(ptr: tl.tensor, @@ -1175,10 +1200,16 @@ def 
atomic_min(ptr: tl.tensor, # for float # return atomic_smin(i_ptr, i_val) if val >= 0 # return atomic_umax(i_ptr, i_val) if val < 0 - i_val = bitcast(val, tl.int32, builder) - i_ptr = bitcast(ptr, tl.pointer_type(tl.int32, 1), builder) - pos = greater_equal(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) - neg = less_than(val, tl.tensor(builder.get_fp32(0), sca_ty), builder) + if sca_ty not in {tl.float32, tl.float64}: + raise TypeError(f"atomic_min not supported for dtype {sca_ty}") + + itype = tl.int32 if sca_ty == tl.float32 else tl.float64 + zero = full([], 0.0, sca_ty, builder) + + i_val = bitcast(val, itype, builder) + i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) + pos = greater_equal(val, zero, builder) + neg = less_than(val, zero, builder) pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, @@ -1191,7 +1222,8 @@ def atomic_min(ptr: tl.tensor, and_(mask, neg, builder).handle, sem), i_val.type) - return where(pos, pos_ret, neg_ret, builder) + ret = where(pos, pos_ret, neg_ret, builder) + return bitcast(ret, sca_ty, builder) def atomic_add(ptr: tl.tensor, @@ -1302,11 +1334,27 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: return False return True +def gpu_has_mfma() -> bool: + if not is_hip(): + return False + return True # mfma supported in ['gfx908', 'gfx90a'] + + +def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: + if not gpu_has_mfma(): + return False + # TODO: Add check for configurations and types. + return True + + def dot(lhs: tl.tensor, rhs: tl.tensor, + acc: tl.tensor, allow_tf32: bool, + max_num_imprecise_acc: int, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: +<<<<<<< HEAD def assert_dtypes_valid(lhs_dtype, rhs_dtype, arch): if is_hip(): assert lhs.dtype == rhs.dtype or (lhs.type.scalar.is_fp8() and rhs.type.scalar.is_fp16()) or \ @@ -1320,6 +1368,31 @@ def dot(lhs: tl.tensor, # Checks for cuda arch if arch < 90: assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" +======= + def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): + # Checks for non-cuda archs + if not _is_cuda(target): + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + return + # Checks for cuda arch + if target.capability < 90: + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" + if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): + return + assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" + else: + assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" + assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" + if lhs_dtype.is_int() or rhs_dtype.is_int(): + assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): + assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" + assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. 
Second operand ({rhs_dtype})" + else: + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}" +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" else: assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" @@ -1339,8 +1412,12 @@ def dot(lhs: tl.tensor, assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" return +<<<<<<< HEAD assert lhs.type.is_block() and rhs.type.is_block() assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.arch) +======= + assert_dtypes_valid(lhs.dtype, rhs.dtype, builder.target) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" @@ -1367,6 +1444,8 @@ def dot(lhs: tl.tensor, assert is_hip() or lhs.shape[1].value >= 32, "small blocks not supported!" _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 + elif out_dtype.is_bf16(): + raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): _0 = builder.get_fp32(0) ret_scalar_ty = tl.float32 @@ -1401,10 +1480,25 @@ def dot(lhs: tl.tensor, ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) return cast(ret, ret_scalar_ty, builder) +<<<<<<< HEAD _0 = builder.create_splat(_0, [M, N]) +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ret_ty = tl.block_type(ret_scalar_ty, [M, N]) - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), + if acc is None: + acc_handle = builder.create_splat(_0, [M, N]) + else: + acc_handle = acc.handle + assert acc.type == ret_ty + + # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 + if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()): + max_num_imprecise_acc = 0 + if max_num_imprecise_acc is None: + max_num_imprecise_acc = 2**30 + + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty) @@ -1574,7 +1668,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor: def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor: - if len(x.shape) != len(values): + if max(1, len(x.shape)) != len(values): raise ValueError("Shape of input to multiple_of does not match the length of values") x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context())) return x @@ -1614,6 +1708,8 @@ def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: def _convert_elem_to_ir_value(builder, elem, require_i64): + if isinstance(elem, int): + elem = tl.constexpr(elem) if isinstance(elem, tl.constexpr): return builder.get_int64(elem.value) if require_i64 else builder.get_int32(elem.value) elif isinstance(elem, tl.tensor): diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index 8acc42615..8ef52cb9c 100644 --- a/python/triton/language/standard.py +++ 
b/python/triton/language/standard.py @@ -160,7 +160,7 @@ def max(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr else: return core._reduce_with_indices(input, axis, _argmax_combine_tie_break_fast) else: - if core.constexpr(input.dtype.primitive_bitwidth) < 32: + if core.constexpr(input.dtype.primitive_bitwidth) < core.constexpr(32): if core.constexpr(input.dtype.is_floating()): input = input.to(core.float32) else: diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index a946fcf64..74bc240b1 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -21,9 +21,10 @@ def _fwd_kernel( Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, + Z_H_N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, @@ -31,27 +32,21 @@ def _fwd_kernel( start_m = tl.program_id(0) off_hz = tl.program_id(1) qvk_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) + vk_offset = qvk_offset // stride_qm + K_block_ptr = tl.make_block_ptr( - base=K + qvk_offset, - shape=(BLOCK_DMODEL, N_CTX), + base=K, + shape=(BLOCK_DMODEL, Z_H_N_CTX), strides=(stride_kk, stride_kn), - offsets=(0, 0), + offsets=(0, vk_offset), block_shape=(BLOCK_DMODEL, BLOCK_N), order=(0, 1) ) V_block_ptr = tl.make_block_ptr( - base=V + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), + base=V, + shape=(Z_H_N_CTX, BLOCK_DMODEL), + strides=(stride_vn, stride_vk), + offsets=(vk_offset, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0) ) @@ -68,7 +63,11 @@ def _fwd_kernel( # don't work as expected with `exp` in the loop qk_scale = sm_scale * 1.44269504 # load q: it will stay in SRAM throughout - q = tl.load(Q_block_ptr) + + offs_k = tl.arange(0, BLOCK_DMODEL) + Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk + q = tl.load(Q_ptrs) + q = (q * qk_scale).to(K.dtype.element_ty) lo = 0 hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX @@ -86,8 +85,7 @@ def _fwd_kernel( alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] + acc *= alpha[:, None] acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) # -- update m_i and l_i -- l_i = l_i * alpha + tl.sum(p, 1) @@ -101,13 +99,14 @@ def _fwd_kernel( tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O O_block_ptr = tl.make_block_ptr( - base=Out + qvk_offset, - shape=(N_CTX, BLOCK_DMODEL), + base=Out, + shape=(Z_H_N_CTX, BLOCK_DMODEL), strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), + offsets=(vk_offset + start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) + # O_ptrs = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) @@ -137,13 +136,14 @@ def _bwd_kernel_one_col_block( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, 
stride_vn, stride_vk, Z, H, N_CTX, off_hz, start_n, num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + MMA_V3: tl.constexpr ): if SEQUENCE_PARALLEL: DQ += stride_dqa.to(tl.int64) * start_n @@ -159,7 +159,7 @@ def _bwd_kernel_one_col_block( # initialize pointers to value-like data q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) - v_ptrs = V + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) # pointer to row-wise quantities in value-like data @@ -203,8 +203,11 @@ def _bwd_kernel_one_col_block( dq += tl.dot(ds, k, allow_tf32=True) tl.store(dq_ptrs, dq) elif SEQUENCE_PARALLEL: - # dq = tl.dot(ds, k, allow_tf32=True) - dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True)) + if MMA_V3: + dq = tl.dot(ds, k, allow_tf32=True) + else: + # not work with mma v3, becuase M % 64 != 0 + dq = tl.trans(tl.dot(tl.trans(k), tl.trans(ds), allow_tf32=True)) tl.store(dq_ptrs, dq) # increment pointers @@ -212,7 +215,7 @@ def _bwd_kernel_one_col_block( q_ptrs += BLOCK_M * stride_qm do_ptrs += BLOCK_M * stride_qm # write-back - dv_ptrs = DV + (offs_n[:, None] * stride_vk + offs_k[None, :] * stride_vn) + dv_ptrs = DV + (offs_n[:, None] * stride_vn + offs_k[None, :] * stride_vk) dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) tl.store(dv_ptrs, dv) tl.store(dk_ptrs, dk) @@ -228,12 +231,13 @@ def _bwd_kernel( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, Z, H, N_CTX, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SEQUENCE_PARALLEL: tl.constexpr, CAUSAL: tl.constexpr, + MMA_V3: tl.constexpr # fmt: on ): qk_scale = sm_scale * 1.44269504 @@ -259,13 +263,14 @@ def _bwd_kernel( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, Z, H, N_CTX, off_hz, start_n, num_block_n, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + MMA_V3=MMA_V3 ) else: start_n = tl.program_id(1) @@ -276,13 +281,14 @@ def _bwd_kernel( D, stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, + stride_vz, stride_vh, stride_vn, stride_vk, Z, H, N_CTX, off_hz, start_n, num_block_n, BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK_N, SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, CAUSAL=CAUSAL, + MMA_V3=MMA_V3 ) @@ -317,6 +323,7 @@ class _attention(torch.autograd.Function): v.stride(0), v.stride(1), v.stride(2), v.stride(3), o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], q.shape[2], + q.shape[0] * q.shape[1] * q.shape[2], BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, IS_CAUSAL=causal, num_warps=num_warps, @@ -332,6 +339,8 @@ class _attention(torch.autograd.Function): @staticmethod def backward(ctx, do): + capability = torch.cuda.get_device_capability() + MMA_V3 = capability[0] >= 9 BLOCK = 128 q, k, 
v, o, L = ctx.saved_tensors sequence_parallel = ctx.sequence_parallel @@ -365,6 +374,7 @@ class _attention(torch.autograd.Function): BLOCK_DMODEL=ctx.BLOCK_DMODEL, SEQUENCE_PARALLEL=sequence_parallel, CAUSAL=ctx.causal, + MMA_V3=MMA_V3, num_warps=8, num_stages=1, ) diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index e7c27ec40..63490504d 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -82,6 +82,7 @@ def _kernel(A, B, C, M, N, K, stride_cm, stride_cn, dot_out_dtype: tl.constexpr, allow_tf32: tl.constexpr, + fp8_fast_accum: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr ): @@ -118,7 +119,10 @@ def _kernel(A, B, C, M, N, K, if AB_DTYPE: a = a.to(C.dtype.element_ty) b = b.to(C.dtype.element_ty) - acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + if fp8_fast_accum: + acc = tl.dot(a, b, acc, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + else: + acc += tl.dot(a, b, out_dtype=dot_out_dtype, allow_tf32=allow_tf32) A += BLOCK_K * SPLIT_K * stride_ak B += BLOCK_K * SPLIT_K * stride_bk acc = acc.to(C.dtype.element_ty) @@ -140,7 +144,7 @@ class _matmul(torch.autograd.Function): _locks = {} @staticmethod - def _call(a, b, dot_out_dtype, allow_tf32): + def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): device = a.device # handle non-contiguous inputs if necessary if a.stride(0) > 1 and a.stride(1) > 1: @@ -155,6 +159,8 @@ class _matmul(torch.autograd.Function): if a.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5] or\ b.dtype in [tl.float8e4nv, tl.float8e4b15, tl.float8e5]: c_dtype = torch.float16 + elif a.dtype in [torch.int8] or b.dtype in [torch.int8]: + c_dtype = torch.int32 else: c_dtype = get_higher_dtype(a.dtype, b.dtype) c = torch.empty((M, N), device=device, dtype=c_dtype) @@ -174,6 +180,8 @@ class _matmul(torch.autograd.Function): ab_dtype = True if a.dtype in [tl.float8e4nv, tl.float8e5] and b.dtype in [tl.float8e4nv, tl.float8e5]: ab_dtype = False + if a.dtype in [torch.int8] and b.dtype in [torch.int8]: + ab_dtype = False # launch kernel grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) _kernel[grid](a, b, c, M, N, K, @@ -182,12 +190,13 @@ class _matmul(torch.autograd.Function): c.stride(0), c.stride(1), dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, + fp8_fast_accum=fp8_fast_accum, GROUP_M=8, AB_DTYPE=ab_dtype) return c @staticmethod - def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True): - return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32) + def forward(ctx, a, b, dot_out_dtype=None, allow_tf32=True, fp8_fast_accum=True): + return _matmul._call(a, b, dot_out_dtype=dot_out_dtype, allow_tf32=allow_tf32, fp8_fast_accum=fp8_fast_accum) matmul = _matmul.apply diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index cdb66bc9c..abe5325ee 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -5,14 +5,16 @@ import torch from .. 
import cdiv from .._C.libtriton.triton import runtime from ..runtime import driver -from ..testing import get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, + nvsmi) def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): ''' return compute throughput in TOPS ''' total_warps = num_ctas * min(num_warps, 4) num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, backend, device) + cur_sm_clock = nvsmi(['clocks.current.sm'])[0] + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device) return tflops @@ -20,7 +22,8 @@ def get_simd_tflops(backend, device, num_ctas, num_warps, dtype): ''' return compute throughput in TOPS ''' total_warps = num_ctas * min(num_warps, 4) num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs - tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, backend, device) + cur_sm_clock = nvsmi(['clocks.current.sm'])[0] + tflops = min(num_subcores, total_warps) / num_subcores * get_max_simd_tflops(dtype, cur_sm_clock, backend, device) return tflops diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index e90d2fcb2..1d78cd691 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -132,9 +132,11 @@ class Autotuner(KernelInterface): full_nargs = {**self.nargs, **kwargs, **self.best_config.kwargs} if config.pre_hook is not None: config.pre_hook(full_nargs) - return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, - num_ctas=config.num_ctas, - enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs) + ret = self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, + num_ctas=config.num_ctas, + enable_warp_specialization=config.enable_warp_specialization, **kwargs, **config.kwargs) + self.nargs = None + return ret def prune_configs(self, kwargs): pruned_configs = self.configs diff --git a/python/triton/runtime/backends/cuda.c b/python/triton/runtime/backends/cuda.c index 622588e8d..278310473 100644 --- a/python/triton/runtime/backends/cuda.c +++ b/python/triton/runtime/backends/cuda.c @@ -11,7 +11,10 @@ static inline void gpuAssert(CUresult code, const char *file, int line) { char err[1024] = {0}; strcat(err, prefix); strcat(err, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); } } @@ -327,7 +330,7 @@ static PyObject *memFree(PyObject *self, PyObject *args) { // Helper function to convert a Python list to a cuuint64_t array static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) { Py_ssize_t len = PyList_Size(listObj); - cuuint64_t *array = malloc(len * sizeof(cuuint64_t)); + cuuint64_t *array = (cuuint64_t *)malloc(len * sizeof(cuuint64_t)); for (Py_ssize_t i = 0; i < len; i++) { PyObject *item = PyList_GetItem(listObj, i); array[i] = (cuuint64_t)PyLong_AsUnsignedLongLong(item); @@ -338,7 +341,7 @@ static cuuint64_t *list_to_cuuint64_array(PyObject *listObj) { // Helper function to convert a Python list to a cuuint32_t array static cuuint32_t *list_to_cuuint32_array(PyObject *listObj) { Py_ssize_t len = PyList_Size(listObj); - cuuint32_t *array 
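The perf-model change above stops trusting the driver-reported clock and instead queries the live SM clock through `nvsmi(['clocks.current.sm'])` before scaling the peak throughput. The shape of that estimate, as a stand-alone sketch (the real lookup lives in `triton.testing.get_max_tensorcore_tflops`; the numbers below are placeholders):

    def estimated_tensorcore_tflops(num_subcores, total_warps, peak_tflops_at_cur_clock):
        # Scale the clock-adjusted peak by the fraction of sub-cores this launch can occupy.
        occupancy = min(num_subcores, total_warps) / num_subcores
        return occupancy * peak_tflops_at_cur_clock

    # e.g. 108 SMs * 4 sub-cores, a launch worth 256 warps, and a made-up 300 TFLOPS peak
    print(estimated_tensorcore_tflops(108 * 4, 256, 300.0))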
= malloc(len * sizeof(cuuint32_t)); + cuuint32_t *array = (cuuint32_t *)malloc(len * sizeof(cuuint32_t)); for (Py_ssize_t i = 0; i < len; i++) { PyObject *item = PyList_GetItem(listObj, i); array[i] = (cuuint32_t)PyLong_AsUnsignedLong(item); diff --git a/python/triton/runtime/backends/hip.c b/python/triton/runtime/backends/hip.c index 5ed5f19ce..c419132fe 100644 --- a/python/triton/runtime/backends/hip.c +++ b/python/triton/runtime/backends/hip.c @@ -13,7 +13,10 @@ static inline void gpuAssert(hipError_t code, const char *file, int line) { const char *str = hipGetErrorString(code); char err[1024] = {0}; snprintf(err, 1024, "%s Code: %d, Messsage: %s", prefix, code, str); + PyGILState_STATE gil_state; + gil_state = PyGILState_Ensure(); PyErr_SetString(PyExc_RuntimeError, err); + PyGILState_Release(gil_state); } } } diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index db8f6193e..e4721cbe3 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -10,6 +10,14 @@ def default_cache_dir(): return os.path.join(Path.home(), ".triton", "cache") +def default_override_dir(): + return os.path.join(Path.home(), ".triton", "override") + + +def default_dump_dir(): + return os.path.join(Path.home(), ".triton", "dump") + + class CacheManager(ABC): def __init__(self, key): pass @@ -36,17 +44,26 @@ class CacheManager(ABC): class FileCacheManager(CacheManager): - def __init__(self, key): + def __init__(self, key, override=False, dump=False): self.key = key self.lock_path = None - # create cache directory if it doesn't exist - self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir() - if self.cache_dir: + if (dump): + self.cache_dir = default_dump_dir() self.cache_dir = os.path.join(self.cache_dir, self.key) self.lock_path = os.path.join(self.cache_dir, "lock") os.makedirs(self.cache_dir, exist_ok=True) + elif (override): + self.cache_dir = default_override_dir() + self.cache_dir = os.path.join(self.cache_dir, self.key) else: - raise RuntimeError("Could not create or locate cache dir") + # create cache directory if it doesn't exist + self.cache_dir = os.getenv('TRITON_CACHE_DIR', "").strip() or default_cache_dir() + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") def _make_path(self, filename) -> str: return os.path.join(self.cache_dir, filename) @@ -131,3 +148,11 @@ def get_cache_manager(key) -> CacheManager: __cache_cls_nme = user_cache_manager return __cache_cls(key) + + +def get_override_manager(key) -> CacheManager: + return __cache_cls(key, override=True) + + +def get_dump_manager(key) -> CacheManager: + return __cache_cls(key, dump=True) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py new file mode 100644 index 000000000..e066ea082 --- /dev/null +++ b/python/triton/runtime/interpreter.py @@ -0,0 +1,527 @@ +import inspect + +import numpy as np + +import triton +import triton.language as tl +from .._C.libtriton.triton import interpreter as _interpreter + + +# TODO: duplicate +def str_to_ty(name): + language = tl + if name[0] == "*": + ty = str_to_ty(name[1:]) + return language.pointer_type(ty) + tys = { + "fp8e4nv": language.float8e4nv, + "fp8e5": language.float8e5, + "fp8e4b15": language.float8e4b15, + "fp8e4b15x4": language.float8e4b15x4, + "fp16": language.float16, + "bf16": 
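The cache.py hunk above adds two sibling roots next to the regular compilation cache: `~/.triton/override` (reached via `get_override_manager`) and `~/.triton/dump` (reached via `get_dump_manager`). A condensed sketch of the directory-selection logic, assuming, as in the diff, that only the regular cache honours `TRITON_CACHE_DIR`:

    import os
    from pathlib import Path

    def resolve_cache_dir(key: str, override: bool = False, dump: bool = False) -> Path:
        home = Path.home()
        if dump:
            root = home / ".triton" / "dump"
        elif override:
            root = home / ".triton" / "override"
        else:
            root = Path(os.getenv("TRITON_CACHE_DIR", "").strip() or home / ".triton" / "cache")
        return root / key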
language.bfloat16, + "fp32": language.float32, + "fp64": language.float64, + "i1": language.int1, + "i8": language.int8, + "i16": language.int16, + "i32": language.int32, + "i64": language.int64, + "u8": language.uint8, + "u16": language.uint16, + "u32": language.uint32, + "u64": language.uint64, + "B": language.int1, + } + return tys[name] + + +class TensorHandle: + + def __init__(self, data, dtype): + self.data = data + self.dtype = dtype + + def __bool__(self): + return bool(self.data.all()) + + +class BlockPointerHandle: + + def __init__(self, base, shape, strides, offsets, tensor_shape, order): + self.base = base + self.shape = shape + self.strides = strides + self.offsets = offsets + self.tensor_shape = tensor_shape + self.order = order + + def materialize_pointers(self, boundary_check): + dtype_tt = self.base.dtype.element_ty + n_bytes = dtype_tt.primitive_bitwidth // 8 + tensor_shape = self.tensor_shape + ptrs = np.broadcast_to(self.base.data, self.tensor_shape) + masks = np.ones(self.tensor_shape, dtype=bool) + for dim in range(len(tensor_shape)): + bcast_dims = [1] * len(tensor_shape) + bcast_dims[dim] = tensor_shape[dim] + off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims) + ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64) + if dim in boundary_check: + masks = np.logical_and(masks, off < self.shape[dim].data) + ptrs = TensorHandle(ptrs, self.base.dtype) + return ptrs, masks + + +def wrap_ret(compute_ret_ty): + def wrapper(fn): + def wrapped(*args, **kwargs): + ret = fn(*args, **kwargs) + return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs)) + return wrapped + return wrapper + + +class Builder: + + def __init__(self) -> None: + self.arch = None + # pass + + def set_grid_idx(self, x, y, z): + assert x < self.grid_dim[0] + assert y < self.grid_dim[1] + assert z < self.grid_dim[2] + self.grid_idx = (x, y, z) + + def set_grid_dim(self, nx, ny, nz): + self.grid_dim = (nx, ny, nz) + + def np_dtype(self, tt_dtype): + if isinstance(tt_dtype, tl.pointer_type): + return np.dtype(np.uint64) + np_types = { + tl.float16: np.dtype(np.float16), + tl.float32: np.dtype(np.float32), + tl.float64: np.dtype(np.float64), + tl.int8: np.dtype(np.int8), + tl.uint8: np.dtype(np.uint8), + tl.int16: np.dtype(np.int16), + tl.uint16: np.dtype(np.uint16), + tl.int32: np.dtype(np.int32), + tl.uint32: np.dtype(np.uint32), + tl.int64: np.dtype(np.int64), + tl.uint64: np.dtype(np.uint64), + } + return np_types[tt_dtype] + + # constants + def get_half_ty(self): + return tl.float16 + + def get_float_ty(self): + return tl.float32 + + def get_int64_ty(self): + return tl.int64 + + def get_ptr_ty(self, elt_ty, addr_space): + return tl.pointer_type(elt_ty, addr_space) + + def get_block_ty(self, dtype, shape): + return tl.tensor(shape, dtype) + + def get_int32(self, value): + return TensorHandle(np.array([value], dtype=np.int32), tl.int32) + + def get_int64(self, value): + return TensorHandle(np.array([value], dtype=np.int64), tl.int64) + + def get_fp16(self, value): + return TensorHandle(np.array([value], dtype=np.float16), tl.float16) + + def get_fp32(self, value): + return TensorHandle(np.array([value], dtype=np.float32), tl.float32) + + def get_null_value(self, type): + return TensorHandle(np.array([0], dtype=self.np_dtype(type)), type) + + # programming model + def create_get_program_id(self, axis): + assert self.grid_idx is not None + return TensorHandle(np.array([self.grid_idx[axis]], dtype=np.int32), tl.int32) + + def create_get_num_programs(self, 
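`materialize_pointers` above expands a block pointer into a dense grid of byte addresses plus an in-bounds mask: per dimension it takes `offsets[dim] + arange(tensor_shape[dim])`, scales by the stride and the element size, and masks positions that fall past `shape[dim]` when that dimension is boundary-checked. A NumPy sketch of the same address arithmetic for a 2-D block (base address, element size and extents are made up):

    import numpy as np

    base = np.uint64(0x1000)      # made-up base address
    elem_bytes = 4                # e.g. fp32
    shape = (5, 7)                # logical tensor extent
    strides = (7, 1)              # row-major strides, in elements
    offsets = (3, 4)              # block offset inside the tensor
    block = (4, 4)                # tensor_shape of the block pointer

    ptrs = np.full(block, base, dtype=np.uint64)
    mask = np.ones(block, dtype=bool)
    for dim in range(2):
        bcast = [1, 1]
        bcast[dim] = block[dim]
        off = (offsets[dim] + np.arange(block[dim])).reshape(bcast)
        ptrs = ptrs + np.uint64(elem_bytes) * (off * strides[dim]).astype(np.uint64)
        mask &= off < shape[dim]  # boundary check on this dimension

    # ptrs holds one byte address per block element; mask flags the out-of-bounds tail.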
axis): + return TensorHandle(np.array([self.grid_dim[axis]], dtype=np.int32), tl.int32) + + # memory ops + def create_load(self, ptr, _0, _1, is_volatile): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + other = None + return self.create_masked_load(ptr, mask, other, _0, _1, is_volatile) + + def create_store(self, ptr, val, _0, _1): + mask = TensorHandle(np.ones_like(ptr.data, dtype=bool), tl.int1) + return self.create_masked_store(ptr, val, mask, None, None) + + def create_masked_load(self, ptrs, mask, other, cache_modifier, eviction_policy, is_volatile): + dtype_tt = ptrs.dtype.element_ty + dtype_np = self.np_dtype(dtype_tt) + if other is None: + other = TensorHandle(np.ones_like(ptrs.data, dtype=dtype_np), dtype_tt) + ret = _interpreter.load(ptrs.data, mask.data, other.data, dtype_np) + return TensorHandle(ret, dtype_tt) + + def create_masked_store(self, ptrs, value, mask, cache_modifier, eviction_policy): + return _interpreter.store(ptrs.data, value.data, mask.data) + + # casting ops + def cast_impl(self, src, dst_type): + if isinstance(dst_type, tl.tensor): + dst_type = dst_type.dtype + return TensorHandle(src.data.astype(self.np_dtype(dst_type)), dst_type) + + create_si_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_ui_to_fp = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_si = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_to_ui = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_ext = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_fp_trunc = lambda self, src, dst_type: self.cast_impl(src, dst_type) + create_int_cast = lambda self, src, dst_type, is_signed: self.cast_impl(src, dst_type) + + def create_fp_to_fp(self, src, dst_type): + assert "float8 not NotImplemented yet" + + def create_bitcast(self, src, dst_type): + return TensorHandle(src.data.view(self.np_dtype(dst_type)), dst_type) + + # binary operators + def binary_op(self, lhs, rhs, op): + return TensorHandle(op(lhs.data, rhs.data), lhs.dtype) + + create_fadd = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_fmul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_fdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.divide) + create_frem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_fsub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_mul = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.multiply) + create_sdiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide) + create_udiv = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.floor_divide) + create_srem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_urem = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.remainder) + create_add = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.add) + create_sub = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.subtract) + create_shl = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.left_shift) + create_lshr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_ashr = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.right_shift) + create_minsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_minf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.minimum) + create_maxsi = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + 
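The interpreter's memory ops above delegate the actual gather/scatter to the `_interpreter` C extension, but their semantics are just a masked load and a masked store. A NumPy model of what `create_masked_load` / `create_masked_store` compute (this stands in for the extension, it is not the real implementation, and indices play the role of pointers):

    import numpy as np

    memory = np.arange(10, dtype=np.float32)   # pretend flat address space
    ptrs = np.array([1, 3, 5, 99])             # last "pointer" is out of bounds
    mask = np.array([True, True, True, False])
    other = np.zeros(4, dtype=np.float32)

    # masked load: read where the mask is set, fall back to `other` elsewhere
    safe_ptrs = np.where(mask, ptrs, 0)
    loaded = np.where(mask, memory[safe_ptrs], other)

    # masked store: write only where the mask is set
    values = np.array([10.0, 30.0, 50.0, 999.0], dtype=np.float32)
    memory[ptrs[mask]] = values[mask]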
create_maxui = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_maxf = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.maximum) + create_icmpSLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpSLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpSGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpSGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_icmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_icmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_icmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_icmpEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_icmpNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpOLT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpOGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpOLE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpOGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpOEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpONE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_fcmpULT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less) + create_fcmpUGT = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater) + create_fcmpULE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.less_equal) + create_fcmpUGE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.greater_equal) + create_fcmpUEQ = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.equal) + create_fcmpUNE = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.not_equal) + create_and = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_and) + create_xor = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_xor) + create_or = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.bitwise_or) + + # ternary functions + def ternary_op(self, lhs, rhs, other, op): + return TensorHandle(op(lhs.data, rhs.data, other.data), other.dtype) + create_select = lambda self, cond, lhs, rhs: self.ternary_op(cond, lhs, rhs, np.where) + + # unary functions + def unary_op(self, arg, op): + return TensorHandle(op(arg.data), arg.dtype) + create_exp = lambda self, arg: self.unary_op(arg, np.exp) + create_cos = lambda self, arg: self.unary_op(arg, np.cos) + create_sin = lambda self, arg: self.unary_op(arg, np.sin) + create_log = lambda self, arg: self.unary_op(arg, np.log) + create_sqrt = lambda self, arg: self.unary_op(arg, np.sqrt) + create_fabs = lambda self, arg: self.unary_op(arg, np.abs) + create_iabs = lambda self, arg: self.unary_op(arg, np.abs) + + # tensor operators + create_dot = lambda self, lhs, rhs: self.binary_op(lhs, rhs, np.dot) + create_view = lambda self, arg, shape: TensorHandle(arg.data.reshape(shape), arg.dtype) + create_trans = lambda self, arg: self.unary_op(arg, np.transpose) + + def create_dot(self, a, b, d, allow_tf32, maxNumImpreciseAcc): + return TensorHandle(np.dot(a.data, b.data) + d.data, a.dtype) + + def create_make_range(self, start, stop): + return TensorHandle(np.arange(start, stop, dtype=np.int32), tl.int32) + + # pointer arithmetic + + def create_addptr(self, ptr, offset): + dtype_tt = ptr.dtype.element_ty + return TensorHandle(ptr.data + 
(dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) + + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile): + ptrs, masks = ptr.materialize_pointers(boundary_check) + assert padding_option is None + other = None + return self.create_masked_load(ptrs, masks, other, cache_modifier, eviction_policy, is_volatile) + + def create_tensor_pointer_store(self, ptr, value, boundary_check, cache_modifier, eviction_policy): + ptrs, masks = ptr.materialize_pointers(boundary_check) + return self.create_masked_store(ptrs, value, masks, cache_modifier, eviction_policy) + + def create_expand_dims(self, arg, axis): + return TensorHandle(np.expand_dims(arg.data, axis), arg.dtype) + + def create_broadcast(self, arg, shape): + return TensorHandle(np.broadcast_to(arg.data, shape), arg.dtype) + + def create_int_to_ptr(self, val, dst_ty): + return TensorHandle(val.data.astype(np.uint64), dst_ty) + # def create_cat(self, lhs, rhs): + # pass + + # def create_broadcast(self, arg, shape): + # pass + + def create_splat(self, arg, shape): + return TensorHandle(np.full(shape, arg.data[0], dtype=self.np_dtype(arg.dtype)), arg.dtype) + + # def create_atomic_cas(self, ptr, cmp, val, sem): + # pass + + # def create_atomic_rmw(self, rmwOp, ptr, val, mask, sem): + # pass + + # def create_extern_elementwise(self, libName, libPath, symbol, argList, retType, isPure): + # pass + + # def create_reduce(self, operands, axis): + # pass + + # def create_reduce_ret(self, args): + # pass + + # def create_scan(self, operands, axis): + # pass + + # def create_scan_ret(self, args): + # pass + + # def create_ptr_to_int(self, val, type): + # pass + + # def create_int_to_ptr(self, val, type): + # pass + + # def create_inline_asm(self, inlineAsm, constraints, values, type, isPure, pack): + # pass + + # def create_print(self, prefix, values): + # pass + + # def create_assert(self, condition, message, fileName, funcName, lineNo): + # pass + + # def create_undef(self, type): + # pass + + # def create_barrier(self): + # pass + + def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order): + return BlockPointerHandle(base, shape, strides, np.array(offsets), tensor_shape, order) + + def create_advance(self, ptr, offsets): + assert len(ptr.offsets) == len(offsets) + ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, ptr.offsets, ptr.tensor_shape, ptr.order) + for i in range(len(offsets)): + ret.offsets[i].data += offsets[i].data + return ret + + +def patch_attr(obj, name, member, builder): + new_member = lambda *args, member=member, **kwargs: (member(*args, **{k: v for k, v in kwargs.items() if k != '_builder'}, _builder=builder)) + setattr(obj, name, new_member) + + +def _patch_lang_tensor(tensor, builder): + for name, member in inspect.getmembers(tensor): + if tl.core.is_builtin(member): + patch_attr(tensor, name, member, builder) + tensor.__index__ = lambda self: int(self.handle.data) + tensor.__bool__ = lambda self: True + tensor.__str__ = lambda self: str(self.handle.data) + tensor.__getitem__ = lambda self, slices: self.handle.data.__getitem__(slices) + + +def _patch_lang_core(lang, builder): + for name, member in inspect.getmembers(lang): + if tl.core.is_builtin(member): + patch_attr(lang, name, member, builder) + # reduce is better off with a separate patch due to how + # the builder currently interfaces with custom functions + + def _new_reduce(input, axis, combine_fn): + fn = combine_fn.fn.__name__ + mapping = { + 
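`patch_attr` and the `_patch_lang_*` helpers above re-bind every `tl.*` builtin so that the NumPy-backed `Builder` is injected as `_builder`, which is how unmodified `@triton.jit` source can execute under the interpreter. A stripped-down sketch of that monkey-patching pattern (all names here are illustrative, not the module's real surface):

    class FakeBuilder:
        def create_add(self, lhs, rhs):
            return lhs + rhs

    def builtin_add(x, y, _builder=None):
        # stands in for a tl.* builtin that needs a builder to do anything
        return _builder.create_add(x, y)

    class lang:  # stands in for the triton.language namespace
        add = staticmethod(builtin_add)

    def patch_attr(obj, name, member, builder):
        wrapped = lambda *args, member=member, **kw: member(*args, _builder=builder, **kw)
        setattr(obj, name, wrapped)

    patch_attr(lang, "add", lang.add, FakeBuilder())
    assert lang.add(2, 3) == 5  # the caller never mentions the builder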
'maximum': np.max, + '_sum_combine': np.sum, + } + ret = mapping[fn](input.handle.data, axis=axis) + ret_type = tl.block_type(input.dtype, ret.shape) + return tl.core.tensor(TensorHandle(ret, input.dtype), ret_type) + + lang.reduce = _new_reduce + + +def _patch_lang_math(lang, builder): + math = lang.math + mapping = { + 'abs': 'abs', + 'acos': 'arccos', + 'asin': 'arcsin', + 'exp2': 'exp2', + 'log2': 'log2', + 'max': 'maximum', + } + + def make_numpy(name): + def impl(*args, **kwargs): + ret_type = args[0].type # TODO: incorrect + ret_dtype = args[0].dtype # TODO: incorrect + args = [arg.handle.data for arg in args] + kwargs = {k: v.handle.data for k, v in kwargs.items()} + ret = getattr(np, mapping[name])(*args, **kwargs) + ret = tl.core.tensor(TensorHandle(ret, ret_dtype), ret_type) + return ret + return impl + + def make_fallback(name): + def fallback(*args, **kwargs): + raise NotImplementedError(f""" +{name} not supported in interpreter mode: no known numpy implementation. +If you think that {name} in fact does have a numpy implementation, please add it +to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math. +""") + return fallback + + for name, member in inspect.getmembers(math): + if name in mapping: + setattr(math, name, make_numpy(name)) + else: + setattr(math, name, make_fallback(name)) + + +# TODO: wrap everything in triton tensors +def _implicit_cvt(arg): + if isinstance(arg, int): + ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg], dtype=np.int32), ty) + return tl.tensor(handle, ty) + if hasattr(arg, 'data_ptr'): + ty = str_to_ty(triton.runtime.jit.JITFunction._type_of(triton.runtime.jit.JITFunction._key_of(arg))) + handle = TensorHandle(np.array([arg.data_ptr()], dtype=np.uint64), ty) + return tl.tensor(handle, ty) + return arg + + +def _unwrap(tensor): + if isinstance(tensor, triton.TensorWrapper): + return tensor.base + return tensor + + +builder = Builder() + +RESERVED_KWS = ['num_warps', 'num_stages', 'num_ctas', 'enable_warp_specialization'] + + +class GridExecutor: + + def __init__(self, fn, arg_names, grid): + from .jit import _normalize_ty # TODO: modularize + self.fn = fn + self.arg_names = arg_names + self.grid = grid + __annotations__ = {name: _normalize_ty(ty) for name, ty in fn.__annotations__.items()} + self.constexprs = [name for name in arg_names if __annotations__.get(name) == 'constexpr'] + + def _patch_lang(self, builder): + lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_core(lang[0], builder) + _patch_lang_math(lang[0], builder) + + def __call__(self, *args_dev, **kwargs): + args_hst = [_unwrap(arg).cpu() if hasattr(arg, 'data_ptr') else arg for arg in args_dev] + # removes reserved keywords from kwargs + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS} + # remaps core language functions to interpreted ones + self._patch_lang(builder) + # we need to copy arguments to the host for the interpreter + # implicitly convert tensor arguments to their base pointers + args = inspect.getcallargs(self.fn, *args_hst, **kwargs) + args = {name: arg if name in self.constexprs else _implicit_cvt(arg) for name, arg in args.items()} + # iterate through grid + grid = self.grid(args) if callable(self.grid) else self.grid + assert len(grid) <= 3 + 
grid = grid + (1,) * (3 - len(grid)) + builder.set_grid_dim(*grid) + for x in range(grid[0]): + for y in range(grid[1]): + for z in range(grid[2]): + builder.set_grid_idx(x, y, z) + self.fn(**args) + # copy arguments back to propagate side-effects + for arg_dev, arg_hst in zip(args_dev, args_hst): + if hasattr(arg_dev, 'data_ptr'): + _unwrap(arg_dev).copy_(arg_hst.to(arg_dev.device)) + + +class InterpretedFunction: + + def _patch_lang(self, builder): + lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] + assert len(lang) == 1, "triton.language must be visible from within jit'd function" + _patch_lang_tensor(getattr(lang[0], 'tensor'), builder) + _patch_lang_core(lang[0], builder) + + def __init__(self, fn) -> None: + self.fn = fn + + def run(*args, **kwargs): + grid = kwargs['grid'] + kwargs = {k: v for k, v in kwargs.items() if k not in RESERVED_KWS + ['grid']} + + return GridExecutor(self.fn, self.arg_names, grid)(*args, **kwargs) + self.run = run + signature = inspect.signature(fn) + self.arg_names = [v.name for v in signature.parameters.values()] + + def __getitem__(self, grid): + return GridExecutor(self.fn, self.arg_names, grid) + + def __call__(self, *args, **kwargs): + self._patch_lang(builder) + return self.fn(*args, **kwargs) diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index d346d51f1..72605e621 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -14,6 +14,7 @@ from typing import (Callable, Generic, Iterable, List, Optional, TypeVar, Union, from .._C.libtriton.triton import TMAInfos from ..common.backend import get_backend, path_to_ptxas from ..language.core import dtype +from .interpreter import InterpretedFunction TRITON_PATH = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) TRITON_VERSION = "2.1.0" @@ -256,6 +257,8 @@ class JITFunction(KernelInterface[T]): "float8_e5m2fnuz": "fp8e5b16", "float8e4b15": "fp8e4b15", "float8e4b15x4": "fp8e4b15x4", + "float8_e4m3fn": "fp8e4nv", + "float8_e5m2": "fp8e5", "float16": "fp16", "bfloat16": "bf16", "float32": "fp32", @@ -274,10 +277,6 @@ class JITFunction(KernelInterface[T]): tys[v] = v return key if isinstance(key, str) else f"*{tys[dtype_str]}" - def _make_signature(self, sig_key): - signature = ",".join([self._type_of(k) for i, k in enumerate(sig_key)]) - return signature - def _make_constants(self, constexpr_key): constants = dict(zip(self.constexprs, constexpr_key)) return constants @@ -304,29 +303,29 @@ class JITFunction(KernelInterface[T]): return JITFunction.cache_hook(key=key, repr=repr, fn=LegacyCompiler(module, name), compile={ "key": key, **kwargs}, is_manual_warmup=False, already_compiled=False) - def _get_arg_specialization_key(self, arg) -> str: - arg_annotation = self.__annotations__.get(arg, '') + def _get_arg_specialization_key(self, arg_name, arg): + arg_annotation = self.__annotations__.get(arg_name, '') if arg_annotation == '': - return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0) if hasattr({arg}, "data_ptr") \ - else ({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1) if isinstance({arg}, int) \ - else (False,)' + return (arg.data_ptr() % JITFunction.divisibility == 0) if hasattr(arg, "data_ptr") \ + else (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) if isinstance(arg, int) \ + else (False,) elif 'Tensor' in arg_annotation: - return f'({arg}.data_ptr() % {JITFunction.divisibility} == 0)' - elif arg_annotation == 'int': - 
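The `_get_arg_specialization_key` rewrite above replaces string templates with plain tuples: pointer-like arguments are keyed on 16-byte alignment of `data_ptr()`, integer-like arguments on divisibility by 16 and by 8 and on being exactly 1. A toy version of that keying, reusing the diff's divisibility constants:

    DIVISIBILITY, DIVISIBILITY_8 = 16, 8

    def specialization_key(arg):
        if hasattr(arg, "data_ptr"):  # tensor-like argument
            return (arg.data_ptr() % DIVISIBILITY == 0,)
        if isinstance(arg, int):
            return (arg % DIVISIBILITY == 0, arg % DIVISIBILITY_8 == 0, arg == 1)
        return (False,)

    assert specialization_key(1024) == (True, True, False)
    assert specialization_key(1) == (False, False, True)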
return f'({arg} % {JITFunction.divisibility} == 0, {arg} % {JITFunction.divisibility_8} == 0, {arg} == 1)' + return (arg.data_ptr() % JITFunction.divisibility == 0) + elif 'int' in arg_annotation or 'bool' in arg_annotation: + return (arg % JITFunction.divisibility == 0, arg % JITFunction.divisibility_8 == 0, arg == 1) else: - return '(False,)' + return (False,) - def _get_arg_sig_key(self, arg) -> str: - arg_annotation = self.__annotations__.get(arg, '') + def _get_arg_sig_key(self, arg_name, arg) -> str: + arg_annotation = self.__annotations__.get(arg_name, '') if 'Tensor' in arg_annotation: - return f'{arg}.dtype' + return arg.dtype elif arg_annotation == 'bool': return "i1" elif arg_annotation == 'float': return 'fp32' else: - return f'_key_of({arg})' + return self._key_of(arg) def _conclude_device_type(self, device_types: List[str], pinned_memory_flags: List[bool]) -> str: device_types = [device_type for device_type in device_types if device_type != ''] @@ -344,32 +343,110 @@ class JITFunction(KernelInterface[T]): return device_types[0] if len(device_types) > 0 else 'cuda' def _make_launcher(self): - regular_args = [f'{arg}' for i, arg in enumerate( + regular_args = [arg for i, arg in enumerate( self.arg_names) if i not in self.constexprs] - constexpr_args = [ - f'{arg}' for i, arg in enumerate( - self.arg_names) if i in self.constexprs] - args = ', '.join(regular_args) - # cache key for regular argument type - sig_keys = ', '.join([self._get_arg_sig_key(arg) for arg in regular_args]) - device_types = '[' + ', '.join([f'_device_of({arg})' for arg in regular_args]) + ']' - pinned_memory_flags = '[' + ', '.join([f'_pinned_memory_of({arg})' for arg in regular_args]) + ']' - # cache key for constexpr argument values - constexpr_keys = ', '.join(constexpr_args) - # cache key for argument specialization - specializations = [] - for i, arg in enumerate(regular_args): - if i in self.do_not_specialize: - continue - specializations += [self._get_arg_specialization_key(arg)] + constexpr_args = [arg for i, arg in enumerate( + self.arg_names) if i in self.constexprs] - spec_keys = ', '.join(specializations) - grid_args = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) + def regular_args_v(args_proxy): + return [args_proxy[arg_name] for arg_name in regular_args] + + def launcher_body(args_proxy, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type): + from ..compiler import (CompiledKernel, compile, + get_arch_default_num_stages, + get_arch_default_num_warps) + sig_key = tuple([self._get_arg_sig_key(arg_name, args_proxy[arg_name]) for arg_name in regular_args]) + constexpr_key = tuple([args_proxy[arg_name] for arg_name in constexpr_args]) + specializations = [] + for i, arg_name in enumerate(regular_args): + if i in self.do_not_specialize: + continue + specializations += [self._get_arg_specialization_key(arg_name, args_proxy[arg_name])] + + spec_key = tuple(specializations) + assert num_ctas > 0 + assert grid is not None + if callable(grid): + grid = grid(args_proxy) + grid_size = len(grid) + grid_0 = grid[0] + grid_1 = grid[1] if grid_size > 1 else 1 + grid_2 = grid[2] if grid_size > 2 else 1 + if device_type is None: + device_types = [self._device_of(arg) for arg in regular_args_v(args_proxy)] + device_types = [_device_type for _device_type in device_types if _device_type != ''] + device_type = self._conclude_device_type(device_types, [self._pinned_memory_of(arg) for arg in + regular_args_v(args_proxy)]) + + device_backend = 
None + if device_type not in ['cuda']: + device_backend = get_backend(device_type) + if device_backend is None: + raise ValueError('Cannot find backend for ' + device_type) + + if device is None: + if device_type in ['cuda']: + device = get_current_device() + set_current_device(device) + else: + device = device_backend.get_current_device() + device_backend.set_current_device(device) + if stream is None and not warmup: + if device_type in ['cuda']: + stream = get_cuda_stream(device) + else: + stream = device_backend.get_stream() + + if num_warps is None: + num_warps = get_arch_default_num_warps(device_type) + if num_stages is None: + num_stages = get_arch_default_num_stages(device_type) + + key = (version_key(), sig_key, constexpr_key, spec_key, num_warps, num_ctas, num_stages, enable_warp_specialization, self.debug) + if extern_libs is not None: + key = (key, tuple(extern_libs.items())) + + bin = self.cache[device].get(key, None) + if bin is not None: + # build dict of constant values + args = regular_args_v(args_proxy) + # Create tensormaps and append to args + args = bin.assemble_tensormap_to_arg(args) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) + return bin + # kernel not cached -- compile + else: + # build dict of constant values + args = regular_args_v(args_proxy) + all_args = tuple([args_proxy[arg_name] for arg_name in self.arg_names]) + configs = self._get_config(*all_args), + constants = self._make_constants(constexpr_key) + constants.update({i: None for i, arg in enumerate(all_args) if arg is None}) + constants.update({i: 1 for i in configs[0].equal_to_1}) + # build kernel signature -- doesn't include specialized arguments + signature = {i: self._type_of(self._key_of(arg)) for i, arg in enumerate(all_args) if i not in self.constexprs} + # build stub signature -- includes arguments that are specialized + for i, arg in constants.items(): + if callable(arg): + raise TypeError(f"Callable constexpr at index {i} is not supported") + if not self._call_hook(key, signature, device, constants, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, configs): + bin = compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_ctas=num_ctas, num_stages=num_stages, enable_warp_specialization=enable_warp_specialization, extern_libs=extern_libs, configs=configs, debug=self.debug, device_type=device_type) + # Create tensormaps and append to args + args = bin.assemble_tensormap_to_arg(args) + if not warmup: + bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.num_ctas, bin.clusterDims[0], bin.clusterDims[1], bin.clusterDims[2], bin.shared, stream, bin.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, bin, *args) + self.cache[device][key] = bin + return bin + return None + + # create a wrapper to call launcher_body + args_map = ','.join([f'"{arg}": {arg}' for arg in self.arg_names]) args_signature = ', '.join(name if dflt == inspect._empty else f'{name} = triton.language.dtype(\'{dflt}\')' if dtype.is_dtype(f'{dflt}') else f'{name} = {dflt}' for name, dflt in zip(self.arg_names, self.arg_defaults)) args_signature = args_signature + ', ' if len(args_signature) > 0 else '' - src = f""" import triton +<<<<<<< HEAD def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, 
num_stages=None, waves_per_eu=0, matrix_instr_nonkdim=0, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): from ..compiler import compile, CompiledKernel, get_arch_default_num_warps, get_arch_default_num_stages sig_key = {f'{sig_keys},' if len(sig_keys) > 0 else ()} @@ -449,19 +526,12 @@ def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, nu self.cache[device][key] = bin return bin return None +======= +def {self.fn.__name__}({args_signature}grid=None, num_warps=None, num_ctas=1, num_stages=None, enable_warp_specialization=False, extern_libs=None, stream=None, warmup=False, device=None, device_type=None): + return launcher_body({{{args_map}}}, grid, num_warps, num_ctas, num_stages, enable_warp_specialization, extern_libs, stream, warmup, device, device_type) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 """ - scope = {"version_key": version_key(), - "get_cuda_stream": get_cuda_stream, - "self": self, - "_spec_of": self._spec_of, - "_key_of": self._key_of, - "_device_of": self._device_of, - "_pinned_memory_of": self._pinned_memory_of, - "cache": self.cache, - "__spec__": __spec__, - "get_backend": get_backend, - "get_current_device": get_current_device, - "set_current_device": set_current_device} + scope = {"launcher_body": launcher_body} exec(src, scope) return scope[self.fn.__name__] @@ -572,7 +642,6 @@ def jit( do_not_specialize: Optional[Iterable[int]] = None, debug: Optional[bool] = None, noinline: Optional[bool] = None, - interpret: Optional[bool] = None, ) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]: """ Decorator for JIT-compiling a function using the Triton compiler. @@ -594,9 +663,8 @@ def jit( def decorator(fn: T) -> JITFunction[T]: assert callable(fn) - if interpret: - from ..interpreter.interpreter import GridSelector - return GridSelector(fn) + if os.getenv("TRITON_INTERPRET", "0") == "1": + return InterpretedFunction(fn) else: return JITFunction( fn, diff --git a/python/triton/testing.py b/python/triton/testing.py index c4357bd24..f01d4f8e3 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -32,8 +32,11 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None): """ if torch.cuda.current_stream() == torch.cuda.default_stream(): raise RuntimeError("Cannot capture graph in default stream. 
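With the `jit` decorator change above, interpreter mode is no longer a decorator argument: when `TRITON_INTERPRET=1` is set in the environment at decoration time, `triton.jit` returns an `InterpretedFunction` and the kernel runs on the host through `GridExecutor`. A usage sketch (the toy kernel and tensor sizes are mine, and whether every op it uses is covered by the interpreter is an assumption):

    import os
    os.environ["TRITON_INTERPRET"] = "1"  # must be set before the kernel is decorated

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def copy_kernel(src, dst, BLOCK: tl.constexpr):
        offs = tl.arange(0, BLOCK)
        tl.store(dst + offs, tl.load(src + offs))

    x = torch.arange(8, dtype=torch.float32, device="cuda")
    y = torch.empty_like(x)
    copy_kernel[(1,)](x, y, BLOCK=8)  # executes eagerly on host copies, then writes back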
Please use side stream in benchmark code.") - # record CUDAGraph + # warmup fn() + # step 1 - we estimate the amount of time the kernel call takes + # NOTE: this estimate isn't super accurate because the GPU isn't warmed up at this point + # but it is probably good enough if grad_to_none is not None: for x in grad_to_none: x.detach_() @@ -43,39 +46,35 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None): with torch.cuda.graph(g): fn() torch.cuda.synchronize() - fn = lambda: g.replay() - # Estimate the runtime of the function start_event = torch.cuda.Event(enable_timing=True) end_event = torch.cuda.Event(enable_timing=True) start_event.record() - fn() + g.replay() end_event.record() torch.cuda.synchronize() estimate_ms = start_event.elapsed_time(end_event) - # compute number of repetition to last `rep` ms n_repeat = max(1, int(rep / estimate_ms)) - # compute number of repetition to last `rep` ms - start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] - ret = [] - n_retries = 50 - for _ in range(n_retries): - # Benchmark - torch.cuda.synchronize() + # step 2 - construct a cuda graph with `n_repeat` unrolled function calls to minimize + # host overhead + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): for i in range(n_repeat): - # we don't want `fn` to accumulate gradient values - # if it contains a backward pass. So we clear the - # provided gradients if grad_to_none is not None: for x in grad_to_none: x.grad = None - # record time of `fn` - start_event[i].record() fn() - end_event[i].record() + torch.cuda.synchronize() + # measure time and return + ret = [] + n_retries = 10 + for i in range(n_retries): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + g.replay() + end_event.record() torch.cuda.synchronize() - times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)]) - ret.append(torch.min(times)) + ret += [start_event.elapsed_time(end_event) / n_repeat] return torch.mean(torch.tensor(ret)).item() @@ -266,7 +265,7 @@ class Mark: self.fn = fn self.benchmarks = benchmarks - def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool): + def _run(self, bench: Benchmark, save_path: str, show_plots: bool, print_data: bool, **kwrags): import os import matplotlib.pyplot as plt @@ -287,7 +286,7 @@ class Mark: row_mean, row_min, row_max = [], [], [] for y in bench.line_vals: - ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args) + ret = self.fn(**x_args, **{bench.line_arg: y}, **bench.args, **kwrags) try: y_mean, y_min, y_max = ret except TypeError: @@ -328,14 +327,14 @@ class Mark: if save_path: df.to_csv(os.path.join(save_path, f"{bench.plot_name}.csv"), float_format='%.1f', index=False) - def run(self, show_plots=False, print_data=False, save_path=''): + def run(self, show_plots=False, print_data=False, save_path='', **kwargs): has_single_bench = isinstance(self.benchmarks, Benchmark) benchmarks = [self.benchmarks] if has_single_bench else self.benchmarks if save_path: html = open(os.path.join(save_path, "results.html"), "w") html.write("\n") for bench in benchmarks: - self._run(bench, save_path, show_plots, print_data) + self._run(bench, save_path, show_plots, print_data, **kwargs) if save_path: html.write(f"\n") if save_path: @@ -368,7 +367,7 @@ def get_dram_gbps(backend=None, device=None): return bw_gbps -def get_max_tensorcore_tflops(dtype, 
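The rewritten `do_bench_cudagraph` above works in two steps: it times a single graph replay to estimate the kernel's runtime, then captures a second graph containing `n_repeat` unrolled calls and times whole-graph replays divided by `n_repeat`, so per-launch host overhead is amortized away. A condensed PyTorch sketch of that pattern (the retry count and millisecond budget are placeholders, and gradient clearing is omitted):

    import torch

    def bench_with_graph(fn, rep_ms=20, n_retries=10):
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):  # capture must not happen on the default stream
            fn()  # warmup
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                fn()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            g.replay()
            end.record()
            torch.cuda.synchronize()
            n_repeat = max(1, int(rep_ms / start.elapsed_time(end)))

            big = torch.cuda.CUDAGraph()  # one graph with n_repeat unrolled calls
            with torch.cuda.graph(big):
                for _ in range(n_repeat):
                    fn()
            times = []
            for _ in range(n_retries):
                start.record()
                big.replay()
                end.record()
                torch.cuda.synchronize()
                times.append(start.elapsed_time(end) / n_repeat)
        return sum(times) / len(times)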
backend=None, device=None, clock_rate=None): +def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None): import torch from .runtime import driver @@ -378,8 +377,6 @@ def get_max_tensorcore_tflops(dtype, backend=None, device=None, clock_rate=None) device = torch.cuda.current_device() num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 - if not clock_rate: - clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz capability = torch.cuda.get_device_capability(device) if capability[0] < 8: assert dtype == torch.float16 @@ -423,21 +420,6 @@ def cuda_memcheck(**target_kwargs): return decorator -def nvsmi_attr(attrs): - attrs = ",".join(attrs) - cmd = [ - "nvidia-smi", - "-i", - "0", - "--query-gpu=" + attrs, - "--format=csv,noheader,nounits", - ] - out = subprocess.check_output(cmd) - ret = out.decode(sys.stdout.encoding).split(",") - ret = [int(x) for x in ret] - return ret - - @contextmanager def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): try: @@ -458,8 +440,8 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", ] ) - cur_sm_clock = nvsmi_attr(["clocks.current.sm"])[0] - cur_mem_clock = nvsmi_attr(["clocks.current.memory"])[0] + cur_sm_clock = nvsmi(["clocks.current.sm"])[0] + cur_mem_clock = nvsmi(["clocks.current.memory"])[0] assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" assert abs(cur_mem_clock - ref_mem_clock) < 10, f"GPU SMs must run at {ref_mem_clock} MHz" tflops = 1e-6 * 2 * 108 * 4 * 256 * ref_sm_clock @@ -471,7 +453,7 @@ def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): subprocess.check_output(["nvidia-smi", "-i", "0", "-rmc"]) -def get_max_simd_tflops(dtype, backend=None, device=None): +def get_max_simd_tflops(dtype, clock_rate, backend=None, device=None): import torch from .runtime import driver @@ -481,7 +463,6 @@ def get_max_simd_tflops(dtype, backend=None, device=None): device = torch.cuda.current_device() num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 - clock_rate = driver.utils.get_device_properties(device)["sm_clock_rate"] # in kHz capability = torch.cuda.get_device_capability() if capability[0] < 8: if dtype == torch.float32: diff --git a/python/triton/tools/compile.py b/python/triton/tools/compile.py index 32138e874..d80f15e8a 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -20,7 +20,7 @@ data along with utilities to load, unload and launch the kernel. signature is provided as a list of (optionally divisibility-hinted) types or constexpr values, e.g. -`compile.py --kernel-name kernel --signature "*f32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` +`compile.py --kernel-name kernel --signature "*fp32:16, i32:16, 1024, i32" --out-name kernel /path/to/kernel.py` will compile triton.JITFunction of name `kernel` inside the file `/path/to/kernel.py`. 
Said kernel will be specialized such that argument 0, 1 are assumed to be multiple of 16, @@ -51,7 +51,7 @@ if __name__ == "__main__": args = parser.parse_args() out_name = args.out_name if args.out_name else args.kernel_name - out_path = args.out_path if args.out_path else out_name + out_path = args.out_path if args.out_path else Path(out_name) # execute python sources and extract functions wrapped in JITFunction arg_path = Path(args.path) diff --git a/python/triton/tools/disasm.py b/python/triton/tools/disasm.py index 24a0787c5..1e309a2e4 100644 --- a/python/triton/tools/disasm.py +++ b/python/triton/tools/disasm.py @@ -20,8 +20,13 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. +import functools +import os import re import subprocess +import tempfile + +from ..common.backend import path_to_cuobjdump, path_to_nvdisasm FLINE_RE = re.compile(r'\s*/\*\w{4}\*/\s*([^;]*;)\s*/\* 0x(\w{16}) \*/\s*') SLINE_RE = re.compile(r'\s*/\* 0x(\w{16}) \*/\s*') @@ -60,11 +65,26 @@ def processSassLines(fline, sline, labels): return (f'{ctrl}', f'{asm}') +@functools.lru_cache() +def get_sass(cubin_asm, fun=None): + fd, path = tempfile.mkstemp() + try: + with open(fd, 'wb') as cubin: + cubin.write(cubin_asm) + sass = extract(path, fun) + finally: + os.remove(path) + return sass + + def extract(file_path, fun): + cuobjdump, _ = path_to_cuobjdump() + nvdisasm, _ = path_to_nvdisasm() + os.environ["NVDISASM_PATH"] = nvdisasm if fun is None: - sass_str = subprocess.check_output(["cuobjdump", "-sass", file_path]) + sass_str = subprocess.check_output([cuobjdump, "-sass", file_path]) else: - sass_str = subprocess.check_output(["cuobjdump", "-fun", fun, "-sass", file_path]) + sass_str = subprocess.check_output([cuobjdump, "-fun", fun, "-sass", file_path]) sass_lines = sass_str.splitlines() line_idx = 0 while line_idx < len(sass_lines): diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index c94e77052..719ba3bf6 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -3,11 +3,11 @@ Fused Attention =============== This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team Extra Credits: - Original flash attention paper (https://arxiv.org/abs/2205.14135) - Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) -- Adam P. 
Goucher for simplified vector math """ @@ -22,8 +22,55 @@ torch_dtype:tl.constexpr = torch.float16 TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') @triton.jit -def max_fn(x, y): - return tl.math.max(x, y) +def _attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, qk_scale, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, + offs_m: tl.constexpr, + offs_n: tl.constexpr, +): + # range of values handled by this stage + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + else: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_block_ptr = tl.advance(K_block_ptr, (0, lo)) + V_block_ptr = tl.advance(V_block_ptr, (lo, 0)) + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk * qk_scale + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1) * qk_scale) + qk = qk * qk_scale - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + # -- update output accumulator -- + acc = acc * alpha[:, None] + # update acc + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(tl.float16), v) + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i @triton.jit def _attn_fwd_inner( @@ -101,6 +148,7 @@ def _attn_fwd( stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, +<<<<<<< HEAD N_CTX, BLOCK_DMODEL: tl.constexpr, STAGE: tl.constexpr, @@ -113,20 +161,50 @@ def _attn_fwd( qkv_offset = off_hz * stride_qh Q_block_ptr = tl.make_block_ptr( base=Q + qkv_offset, +======= + N_CTX: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + off_z = off_hz // H + off_h = off_hz % H + qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh + + # block pointers + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 shape=(N_CTX, BLOCK_DMODEL), strides=(stride_qm, stride_qk), offsets=(start_m * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), ) K_block_ptr = tl.make_block_ptr( +<<<<<<< HEAD base=K + qkv_offset, +======= + base=K + qvk_offset, +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 shape=(BLOCK_DMODEL, N_CTX), strides=(stride_kk, stride_kn), offsets=(0, 0), block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) + order=(0, 1), ) +<<<<<<< HEAD V_block_ptr = tl.make_block_ptr( base=V + qkv_offset, shape=(N_CTX, BLOCK_DMODEL), @@ -134,6 +212,15 @@ def _attn_fwd( offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_DMODEL), order=(1, 0) +======= + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + 
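The new `_attn_fwd_inner` above is the streaming ("online") softmax update: whenever the running row maximum changes, the accumulator and the normalizer are rescaled by `alpha = exp2(m_i - m_ij)`, and everything is kept in base 2 (folding `log2(e)` into `qk_scale`) so the hardware `exp2` can be used. A NumPy sketch of that update for a single query row, checked against a direct softmax-weighted sum (sizes are illustrative):

    import numpy as np

    np.random.seed(0)
    scores = np.random.randn(128).astype(np.float32)       # one query row of q.k^T (already scaled)
    values = np.random.randn(128, 16).astype(np.float32)   # matching value rows

    # reference: softmax(scores) @ values
    p_ref = np.exp(scores - scores.max())
    ref = (p_ref / p_ref.sum()) @ values

    # streaming version, processing BLOCK_N = 32 columns at a time with the exp2 trick
    LOG2E = 1.44269504
    m_i, l_i = -np.inf, 0.0
    acc = np.zeros(16, dtype=np.float32)
    for start in range(0, 128, 32):
        qk = scores[start:start + 32] * LOG2E
        m_ij = max(m_i, qk.max())
        p = np.exp2(qk - m_ij)
        alpha = np.exp2(m_i - m_ij)
        l_i = l_i * alpha + p.sum()
        acc = acc * alpha + p @ values[start:start + 32]
        m_i = m_ij

    assert np.allclose(acc / l_i, ref, atol=1e-5)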
block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) # initialize offsets offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) @@ -142,6 +229,7 @@ def _attn_fwd( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) +<<<<<<< HEAD # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM # don't work as expected with `exp` in the loop @@ -186,47 +274,200 @@ def _attn_fwd( block_shape=(BLOCK_M, BLOCK_DMODEL), order=(1, 0) ) +======= + # load scales + qk_scale = sm_scale + qk_scale *= 1.44269504 # 1/log(2) + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + # stage 1: off-band + if STAGE & 1: + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, qk_scale, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 1, offs_m, offs_n, + ) + # barrier makes it easier for compielr to schedule the + # two loops independently + tl.debug_barrier() + # stage 2: on-band + if STAGE & 2: + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, qk_scale, + BLOCK_M, BLOCK_DMODEL, BLOCK_N, + 2, offs_m, offs_n, + ) + # epilogue + m_i += tl.math.log2(l_i) + acc = acc / l_i[:, None] + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(m_ptrs, m_i) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tl.store(O_block_ptr, acc.to(Out.type.element_ty)) @triton.jit +<<<<<<< HEAD def _bwd_preprocess( Out, DO, NewDO, Delta, +======= +def _attn_bwd_preprocess( + O, DO, + Delta, + Z, H, N_CTX, +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_hz = tl.program_id(1) off_n = tl.arange(0, D_HEAD) # load - o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) - # compute + o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]) + do = tl.load(DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) delta = tl.sum(o * do, axis=1) # write-back +<<<<<<< HEAD tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) tl.store(Delta + off_m, delta) +======= + tl.store(Delta + off_hz * N_CTX + off_m, delta) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 + + +# The main inner-loop logic for computing dK and dV. +@triton.jit +def _attn_bwd_dkdv( + dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + offs_k = tl.arange(0, BLOCK_DMODEL) + qT_ptrs = Q + offs_m[None, :] * stride_tok + offs_k[:, None] * stride_d + do_ptrs = DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(qT_ptrs) + # Load m before computing qk to reduce pipeline stall. 
+ offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + qkT = tl.dot(k, qT) + pT = tl.math.exp2(qkT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(do_ptrs) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)).to(tl.float32) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + qT_ptrs += step_m * stride_tok + do_ptrs += step_m * stride_tok + return dk, dv + + +# the main inner-loop logic for computing dQ +@triton.jit +def _attn_bwd_dq( + dq, q, K, V, + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, + MASK: tl.constexpr, +): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + offs_k = tl.arange(0, BLOCK_DMODEL) + kT_ptrs = K + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + vT_ptrs = V + offs_n[None, :] * stride_tok + offs_k[:, None] * stride_d + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(kT_ptrs) + vT = tl.load(vT_ptrs) + qk = tl.dot(q, kT) + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + kT_ptrs += step_n * stride_tok + vT_ptrs += step_n * stride_tok + return dq @triton.jit -def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, +def _attn_bwd( + Q, K, V, sm_scale, + DO, DQ, DK, DV, - L, - D, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - Z, H, N_CTX, P_SEQ, - num_block_q, num_block_kv, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - CAUSAL: tl.constexpr, + M, D, + # shared by Q/K/V/DO. 
+ stride_z, stride_h, stride_tok, stride_d, + H, N_CTX, + BLOCK_M1: tl.constexpr, + BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, + BLOCK_N2: tl.constexpr, + BLK_SLICE_FACTOR: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, ): - off_hz = tl.program_id(0) - off_z = off_hz // H - off_h = off_hz % H - qk_scale = sm_scale * 1.44269504 + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + # offset pointers for batch/head +<<<<<<< HEAD Q += off_z * stride_qz + off_h * stride_qh K += off_z * stride_kz + off_h * stride_kh V += off_z * stride_vz + off_h * stride_vh @@ -298,6 +539,117 @@ def _bwd_kernel( dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) tl.store(dk_ptrs, dk) tl.store(dv_ptrs, dv) +======= + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # load scales + offs_k = tl.arange(0, BLOCK_DMODEL) + + # THIS BLOCK DOES DK/DV/DR: + + start_n = pid * BLOCK_N1 + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + # load K and V: they stay in SRAM throughout the inner loop. + k = tl.load(K + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + v = tl.load(V + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d) + + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=True, + ) + + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + # Compute dK and dV for non-masked blocks. + dk, dv = _attn_bwd_dkdv(dk, dv, + Q, k, v, sm_scale, + DO, + M, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, + start_n, start_m, num_steps, + MASK=False, + ) + + dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dv_ptrs, dv) + + # Write back dK. + dk *= sm_scale + dk_ptrs = DK + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d + tl.store(dk_ptrs, dk) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + q = tl.load(Q + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + do = tl.load(DO + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. 
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, + MASK=True, + ) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _attn_bwd_dq( + dq, q, K, V, + do, m, D, + stride_tok, stride_d, + H, N_CTX, + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, + start_m, end_n - num_steps * BLOCK_N2, num_steps, + MASK=False, + ) + # Write back dQ. + dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d + dq *= LN2 + tl.store(dq_ptrs, dq) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 @triton.jit def _bwd_kernel_dk_dv( @@ -515,13 +867,13 @@ empty = torch.empty(128, device="cuda") class _attention(torch.autograd.Function): - @staticmethod def forward(ctx, q, k, v, causal, sm_scale, split_kernel=False): # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv assert Lk in {16, 32, 64, 128} +<<<<<<< HEAD o = torch.empty_like(q, dtype=v.dtype) if torch.version.hip is None: BLOCK_M = 128 @@ -537,6 +889,15 @@ class _attention(torch.autograd.Function): ) M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) +======= + o = torch.empty_like(q) + BLOCK_M = 128 + BLOCK_N = 64 if Lk <= 64 else 32 + num_stages = 4 if Lk <= 64 else 3 + num_warps = 4 + grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 _attn_fwd[grid]( q, k, v, sm_scale, M, o, q.stride(0), q.stride(1), q.stride(2), q.stride(3), @@ -545,6 +906,7 @@ class _attention(torch.autograd.Function): o.stride(0), o.stride(1), o.stride(2), o.stride(3), q.shape[0], q.shape[1], N_CTX=q.shape[2], +<<<<<<< HEAD BLOCK_DMODEL=Lk, STAGE=stage, ) @@ -553,17 +915,30 @@ class _attention(torch.autograd.Function): best_config = _attn_fwd.get_best_config(Z = q.shape[0], H = q.shape[1], N_CTX = q.shape[2], STAGE = stage, BLOCK_DMODEL=Lk) block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1]) grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1) +======= + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + BLOCK_DMODEL=Lk, + STAGE=3, + num_warps=num_warps, + num_stages=num_stages, + ) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = Lk ctx.causal = causal +<<<<<<< HEAD ctx.split_kernel = split_kernel +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 return o @staticmethod def backward(ctx, do): +<<<<<<< HEAD # configuration is not supported assert(not (ctx.split_kernel and not ctx.causal)) if torch.version.hip is not None: @@ -636,10 +1011,53 @@ class _attention(torch.autograd.Function): ) # print(h.asm["ttgir"]) return dq, dk, dv, None, None, None +======= + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 128, 128, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + PRE_BLOCK = 128 + assert N_CTX % PRE_BLOCK == 0 + pre_grid = 
(N_CTX // PRE_BLOCK, BATCH * N_HEAD) + delta = torch.empty_like(M) + _attn_bwd_preprocess[pre_grid]( + o, do, + delta, + BATCH, N_HEAD, N_CTX, + BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, + M, delta, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + N_HEAD, N_CTX, + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=NUM_WARPS, + num_stages=NUM_STAGES, + ) + + return dq, dk, dv, None, None + +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 attention = _attention.apply +<<<<<<< HEAD @pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64), (4, 48, 2048, 64), @@ -687,6 +1105,27 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() +======= +@pytest.mark.parametrize("Z, H, N_CTX, D_HEAD", [(1, 2, 1024, 64)]) +@pytest.mark.parametrize("causal", [True]) +def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + torch.manual_seed(20) + q = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + k = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) + v = ( + torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") + .normal_(mean=0.0, std=0.5) + .requires_grad_() + ) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 sm_scale = 0.5 split_kernel = True dout = torch.randn_like(q) @@ -722,16 +1161,12 @@ def test_op_bwd(Z, H, N_CTX, D_HEAD, dtype=torch.float16): try: from flash_attn.flash_attn_interface import \ flash_attn_qkvpacked_func as flash_attn_func - FLASH_VER = 2 + HAS_FLASH = True except BaseException: - try: - from flash_attn.flash_attn_interface import flash_attn_func - FLASH_VER = 1 - except BaseException: - FLASH_VER = None -HAS_FLASH = FLASH_VER is not None + HAS_FLASH = False # vary seq length for fixed head and batch=4 +<<<<<<< HEAD configs = [] for mode in ['fwd', 'bwd']: for D_HEAD in [128, 64]: @@ -765,11 +1200,37 @@ for mode in ['fwd', 'bwd']: 'mode': mode, 'causal': causal}) ) +======= +configs = [ + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 15)], + line_arg="provider", + line_vals=["triton"] + (["flash"] if HAS_FLASH else []), + line_names=["Triton"] + (["Flash-2"] if HAS_FLASH else []), + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "D_HEAD": D_HEAD, + "dtype": torch.float16, + "mode": mode, + "causal": causal, + }, + ) + for mode in ["fwd", "bwd"] + for causal in [True] +] +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 @triton.testing.perf_report(configs) -def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): - assert mode in ['fwd', 'bwd'] +def bench_flash_attention( + BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda" +): + assert mode in ["fwd", "bwd"] warmup = 25 rep = 100 split_kernel = False @@ -785,37 +1246,35 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype q = 
q.to(torch_dtype) k = k.to(torch_dtype) sm_scale = 1.3 +<<<<<<< HEAD fn = lambda: attention(q, k, v, causal, sm_scale, split_kernel) if mode == 'bwd': +======= + fn = lambda: attention(q, k, v, causal, sm_scale) + if mode == "bwd": +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) - if FLASH_VER == 1: - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros((BATCH + 1,), device=device, dtype=torch.int32) - cu_seqlens[1:] = lengths.cumsum(0) - qkv = qkv.reshape(BATCH * N_CTX, 3, H, D_HEAD) - fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=causal) - elif FLASH_VER == 2: - fn = lambda: flash_attn_func(qkv, causal=causal) - else: - raise ValueError(f'unknown {FLASH_VER = }') - if mode == 'bwd': + qkv = torch.randn( + (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True + ) + fn = lambda: flash_attn_func(qkv, causal=causal) + if mode == "bwd": o = fn() do = torch.randn_like(o) fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) - flops_per_matmul = 2. * BATCH * H * N_CTX * N_CTX * D_HEAD + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD total_flops = 2 * flops_per_matmul if causal: total_flops *= 0.5 - if mode == 'bwd': + if mode == "bwd": total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) return total_flops / ms * 1e-9 # only works on post-Ampere GPUs right now -bench_flash_attention.run(save_path='.', print_data=True) +bench_flash_attention.run(save_path=".", print_data=True) diff --git a/python/tutorials/11-grouped-gemm.py b/python/tutorials/11-grouped-gemm.py index 2b20e0a01..5b4305b79 100644 --- a/python/tutorials/11-grouped-gemm.py +++ b/python/tutorials/11-grouped-gemm.py @@ -1,3 +1,14 @@ +<<<<<<< HEAD +======= + +""" +Group GEMM +============================ +This group gemm kernel launches a fixed number of CTA to compute a group +of gemms. The scheduling is static and we do it on device. +""" + +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # Copyright (c) 2023 NVIDIA Corporation & Affiliates. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining @@ -24,10 +35,16 @@ import torch import triton import triton.language as tl +<<<<<<< HEAD # This group gemm kernel launches a fixed number of CTA to compute a group # of gemms. 
The scheduling is static and we do it on device @triton.autotune( configs= [ +======= + +@triton.autotune( + configs=[ +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 triton.Config( { 'BLOCK_SIZE_M': 128, @@ -60,6 +77,7 @@ import triton.language as tl 'NUM_SM': 128, } ), +<<<<<<< HEAD ] if torch.version.hip is None else [ triton.Config( { @@ -113,6 +131,10 @@ import triton.language as tl ), ], key=['SUM_M', 'SUM_N', 'SUM_K'], +======= + ], + key=['group_size'], +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) @triton.jit def grouped_matmul_kernel( @@ -128,9 +150,12 @@ def grouped_matmul_kernel( g_lds, # number of gemms group_size, +<<<<<<< HEAD SUM_M: tl.constexpr, SUM_N: tl.constexpr, SUM_K: tl.constexpr, +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 # number of virtual SM NUM_SM: tl.constexpr, # tile sizes @@ -211,9 +236,12 @@ def group_gemm_fn(group_A, group_B): g_sizes = [] g_lds = [] group_C = [] +<<<<<<< HEAD SUM_M = 0 SUM_N = 0 SUM_K = 0 +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 for i in range(group_size): A = group_A[i] B = group_B[i] @@ -226,9 +254,12 @@ def group_gemm_fn(group_A, group_B): B_addrs.append(B.data_ptr()) C_addrs .append(C.data_ptr()) g_sizes += [M, N, K] +<<<<<<< HEAD SUM_M += M SUM_N += N SUM_K += K +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 g_lds += [A.stride(0), B.stride(0), C.stride(0)] # note these are device tensors @@ -250,9 +281,12 @@ def group_gemm_fn(group_A, group_B): d_g_sizes, d_g_lds, group_size, +<<<<<<< HEAD SUM_M=SUM_M, SUM_N=SUM_N, SUM_K=SUM_K, +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) return group_C @@ -277,6 +311,7 @@ for i in range(group_size): tri_out = group_gemm_fn(group_A, group_B) ref_out = [torch.matmul(a, b) for a, b in zip(group_A, group_B)] +<<<<<<< HEAD rtol = 0 if torch.version.hip is None else 1e-2 for i in range(group_size): assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=rtol) @@ -284,6 +319,14 @@ for i in range(group_size): # only launch the kernel, no tensor preparation here to remove all overhead def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, sum_m, sum_n, sum_k): +======= +for i in range(group_size): + assert torch.allclose(ref_out[i], tri_out[i], atol=1e-2, rtol=0) + + +# only launch the kernel, no tensor preparation here to remove all overhead +def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 grid = lambda META: (META['NUM_SM'],) grouped_matmul_kernel[grid]( a_ptrs, @@ -292,9 +335,12 @@ def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size, sum_m, sum_n, sizes, lds, group_size, +<<<<<<< HEAD sum_m, sum_n, sum_k, +======= +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 ) @@ -355,7 +401,11 @@ def benchmark(N, provider): if provider == 'cublas': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) if provider == 'triton': +<<<<<<< HEAD ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size, group_size*N, group_size*N, group_size*N), quantiles=quantiles) +======= + ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 return ms, max_ms, min_ms diff --git a/test/Analysis/test-alias.mlir b/test/Analysis/test-alias.mlir index 75740d929..1f3b1df65 100644 --- 
a/test/Analysis/test-alias.mlir +++ b/test/Analysis/test-alias.mlir @@ -26,7 +26,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/Analysis/test-alignment.mlir b/test/Analysis/test-alignment.mlir index 11ee22fe4..77eae5f22 100644 --- a/test/Analysis/test-alignment.mlir +++ b/test/Analysis/test-alignment.mlir @@ -184,13 +184,28 @@ tt.func @rem() { // ----- +// CHECK-LABEL: @expanddims +tt.func @expanddims() { + // CHECK: contiguity = [128], divisibility = [1073741824], constancy = [1], constant_value = + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [128], constant_value = 2 + %1 = arith.constant dense<2> : tensor<128xi32> + // CHECK-NEXT: contiguity = [1], divisibility = [2], constancy = [1], constant_value = + %2 = arith.muli %0, %1 : tensor<128xi32> + // CHECK-NEXT: contiguity = [1, 1], divisibility = [2, 2], constancy = [1, 1], constant_value = + %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> + tt.return +} + +// ----- + // CHECK-LABEL: @broadcast tt.func @broadcast() { // CHECK: contiguity = [1], divisibility = [64], constancy = [128], constant_value = 64 %0 = arith.constant dense<64> : tensor<128xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 1], constant_value = 64 + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 1], constant_value = 64 %1 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 1], constancy = [128, 128], constant_value = 64 + // CHECK-NEXT: contiguity = [1, 1], divisibility = [64, 64], constancy = [128, 128], constant_value = 64 %2 = tt.broadcast %1 : (tensor<128x1xi32>) -> tensor<128x128xi32> tt.return } @@ -290,9 +305,9 @@ tt.func @shift() { %1 = arith.constant dense<8> : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [4], constancy = [128], constant_value = 4 %2 = arith.constant dense<4> : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [274877906944], constancy = [1], constant_value = + // CHECK-NEXT: contiguity = [1], divisibility = [256], constancy = [1], constant_value = %3 = arith.shli %0, %1 : tensor<128xi32> - // CHECK-NEXT: contiguity = [1], divisibility = [67108864], constancy = [1], constant_value = + // CHECK-NEXT: contiguity = [1], divisibility = [1], constancy = [1], constant_value = %4 = arith.shrsi %0, %2 : tensor<128xi32> // CHECK-NEXT: contiguity = [1], divisibility = [128], constancy = [128], constant_value = 128 %5 = arith.shli %1, %2 : 
tensor<128xi32> @@ -362,7 +377,7 @@ tt.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %2 = tt.expand_dims %0 {axis = 1 : i32} : (tensor<128xi32>) -> tensor<128x1xi32> // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %3 = tt.splat %arg1 : (i32) -> tensor<128x1xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [17179869184, 16], constancy = [1, 1], constant_value = + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = %4 = arith.muli %2, %3 : tensor<128x1xi32> // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %5 = tt.splat %arg0 : (!tt.ptr) -> tensor<128x1x!tt.ptr> @@ -386,11 +401,11 @@ tt.func @permute_2d(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: i32 %14 = tt.expand_dims %1 {axis = 0 : i32} : (tensor<128xi32>) -> tensor<1x128xi32> // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 128], constant_value = %15 = tt.splat %arg3 : (i32) -> tensor<1x128xi32> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [1, 1], constant_value = + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [1, 1], constant_value = %16 = arith.muli %14, %15 : tensor<1x128xi32> // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 128], constant_value = %17 = tt.broadcast %13 : (tensor<128x1x!tt.ptr>) -> tensor<128x128x!tt.ptr> - // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 17179869184], constancy = [128, 1], constant_value = + // CHECK-NEXT: contiguity = [1, 1], divisibility = [16, 16], constancy = [128, 1], constant_value = %18 = tt.broadcast %16 : (tensor<1x128xi32>) -> tensor<128x128xi32> // CHECK-NEXT: contiguity = [128, 1], divisibility = [16, 4], constancy = [1, 1], constant_value = %19 = tt.addptr %17, %18 : tensor<128x128x!tt.ptr>, tensor<128x128xi32> diff --git a/test/Analysis/test-allocation.mlir b/test/Analysis/test-allocation.mlir index 93d80448c..c0175811f 100644 --- a/test/Analysis/test-allocation.mlir +++ b/test/Analysis/test-allocation.mlir @@ -34,7 +34,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, // CHECK-NEXT: offset = 0, size = 4224 %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -64,11 +64,11 @@ tt.func @reusable(%A : !tt.ptr) { %a3_ = tt.load %a_ptr, %cst1, %cst2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL> // CHECK-NEXT: offset = 0, size = 4608 %a3 = triton_gpu.convert_layout %a3_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> - %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a1, %a2, %c_init {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, 
#A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %a4_ = tt.load %b_ptr, %cst3, %cst4 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #AL> // CHECK-NEXT: offset = 0, size = 1152 %a4 = triton_gpu.convert_layout %a4_ : (tensor<32x128xf16, #AL>) -> tensor<32x128xf16, #B_DOT> - %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c1 = tt.dot %a3, %a4, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> tt.return // CHECK-NEXT: size = 4608 } @@ -265,14 +265,14 @@ tt.func @alloc_m_barrier_scalar() { // CHECK-LABEL: scratch tt.func @scratch() { %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf16, #AL> - // CHECK: scratch offset = 0, size = 512 + // CHECK: scratch offset = 0, size = 128 %b = "tt.reduce" (%cst0) ({ ^bb0(%arg0: f16, %arg1: f16): %add = arith.addf %arg0, %arg1 : f16 tt.reduce.return %add : f16 }) {axis = 0 : i32} : (tensor<16x16xf16, #AL>) -> tensor<16xf16, #sliceAd0> tt.return - // CHECK-NEXT: size = 512 + // CHECK-NEXT: size = 128 } // CHECK-LABEL: trans diff --git a/test/Analysis/test-membar.mlir b/test/Analysis/test-membar.mlir index 63b4ef5d2..961176cc6 100644 --- a/test/Analysis/test-membar.mlir +++ b/test/Analysis/test-membar.mlir @@ -32,7 +32,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A_DOT> %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B_DOT> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> diff --git a/test/Conversion/divide-by-0.mlir b/test/Conversion/divide-by-0.mlir new file mode 100644 index 000000000..8eca5a348 --- /dev/null +++ b/test/Conversion/divide-by-0.mlir @@ -0,0 +1,14 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-gpu-to-llvm="target=nvvm" --cse | FileCheck %s + +// CHECK-LABEL: dont_divide_0 +// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i32) : i32 +// CHECK-NOT: llvm.urem %{{.*}}, %[[C0]] +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 8]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @dont_divide_0() attributes {noinline = false} { + %zero = arith.constant dense<0.000000e+00> : tensor<16x1xf32, #mma> + %cvt = triton_gpu.convert_layout %zero : (tensor<16x1xf32, 
#mma>) -> tensor<16x1xf32, #blocked> + tt.return + } +} diff --git a/test/Conversion/invalid.mlir b/test/Conversion/invalid.mlir index 81b866502..178d5109f 100644 --- a/test/Conversion/invalid.mlir +++ b/test/Conversion/invalid.mlir @@ -6,7 +6,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf32, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{element types of operands A and B must have same bit width}} - %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf32, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -20,7 +20,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching encoding between A and B operands}} - %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } @@ -34,7 +34,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32} { module attributes {"triton_gpu.num-warps" = 1 : i32} { tt.func @convert_dot(%A: tensor<16x16xf16, #dot_operand_a>, %B: tensor<16x16xf16, #dot_operand_b>, %C: tensor<16x16xf32, #mma0>) { // expected-error@+1 {{mismatching kWidth between A and B operands}} - %D = tt.dot %A, %B, %C {allowTF32 = true, transA = false, transB = false} : + %D = tt.dot %A, %B, %C {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> tt.return } diff --git a/test/Conversion/triton_ops.mlir b/test/Conversion/triton_ops.mlir index bb4cba096..ce6505d72 100644 --- a/test/Conversion/triton_ops.mlir +++ b/test/Conversion/triton_ops.mlir @@ -161,13 +161,13 @@ tt.func @dot_ops_infer(%ptr: !tt.ptr, %v : f32) { %zero1x1 = arith.constant dense<0.00e+00> : tensor<1x1xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> + %r1 = tt.dot %v128x32, %v32x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf32> * tensor<32x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<32x32xf32> - %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> + %r2 = tt.dot %v32x128, %v128x32, %zero32x32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x128xf32> * tensor<128x32xf32> -> tensor<32x32xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<128x128xf32> - %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, transA = false, transB = false} : tensor<128x1xf32> * tensor<1x128xf32> -> tensor<128x128xf32> + %r3 = tt.dot %v128x1, %v1x128, %zero128x128 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x1xf32> * 
tensor<1x128xf32> -> tensor<128x128xf32> // CHECK: %{{.*}} = tt.dot %{{.*}} -> tensor<1x1xf32> - %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> + %r4 = tt.dot %v1x128, %v128x1, %zero1x1 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<1x128xf32> * tensor<128x1xf32> -> tensor<1x1xf32> %ptr128x128 = tt.splat %ptr : (!tt.ptr) -> tensor<128x128x!tt.ptr> %ptr32x32 = tt.splat %ptr : (!tt.ptr) -> tensor<32x32x!tt.ptr> diff --git a/test/Conversion/triton_to_tritongpu.mlir b/test/Conversion/triton_to_tritongpu.mlir index cc3b4dd5b..9d53e092c 100644 --- a/test/Conversion/triton_to_tritongpu.mlir +++ b/test/Conversion/triton_to_tritongpu.mlir @@ -5,7 +5,7 @@ tt.func @ops() { %a = arith.constant dense<1.00e+00> : tensor<128x32xf16> %b = arith.constant dense<2.00e+00> : tensor<32x128xf16> %c = arith.constant dense<3.00e+00> : tensor<128x128xf32> - %0 = tt.dot %a, %b, %c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> + %0 = tt.dot %a, %b, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16> * tensor<32x128xf16> -> tensor<128x128xf32> tt.return } diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index f937d5b78..9b6110fed 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -1283,11 +1283,19 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : %BB_DOT = triton_gpu.convert_layout %BB : (tensor<16x16xf16, #shared0>) -> tensor<16x16xf16, #dot_operand_b> %cst0 = arith.constant dense<0.000000e+00> : tensor<16x16xf32, #mma0> +<<<<<<< HEAD // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> +======= + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 + %D = tt.dot %AA_DOT, %BB_DOT, %cst0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #dot_operand_a> * tensor<16x16xf16, #dot_operand_b> -> tensor<16x16xf32, #mma0> +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tt.return } @@ -1466,7 +1474,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<128x32xf16, #shared>) -> tensor<128x32xf16, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<32x256xf16, #shared>) -> tensor<32x256xf16, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #dot_operand_a> * tensor<32x256xf16, #dot_operand_b> -> tensor<128x256xf32, #mma> %38 = triton_gpu.convert_layout %28 : (tensor<128x256xf32, #mma>) -> tensor<128x256xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<128x1x!tt.ptr, 
#blocked> @@ -1522,7 +1530,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<32x64xf16, #shared0>) -> tensor<32x64xf16, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<64x64xf16, #shared1>) -> tensor<64x64xf16, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<32x64xf32, #mma> %38 = triton_gpu.convert_layout %28 : (tensor<32x64xf32, #mma>) -> tensor<32x64xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x64x!tt.ptr, #blocked> @@ -1545,7 +1553,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b> - %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = false, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x32x!tt.ptr, #blocked> tt.store %36, %28 : tensor<32x32xf32, #blocked> @@ -1575,6 +1583,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %a_mat = triton_gpu.convert_layout %a : (tensor<32x16xf32, #shared>) -> tensor<32x16xf32, #dot_operand_a> %b_mat = triton_gpu.convert_layout %b : (tensor<16x32xf32, #shared>) -> tensor<16x32xf32, #dot_operand_b> +<<<<<<< HEAD // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 // PTX: llvm.inline_asm @@ -1584,6 +1593,17 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // PTX: llvm.inline_asm // PTX-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> +======= + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + // CHECK: llvm.inline_asm + // CHECK-SAME: mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 + %28 = tt.dot %a_mat, %b_mat, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> 
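The `maxNumImpreciseAcc = 0 : i32` attribute that these updated tests now carry on every `tt.dot` is the knob exercised by the Hopper fp8 wgmma tests added further down in this diff (`dot_high_precision_acc`, `dot_low_precision_acc`, `dot_mix_precision_acc`). Judging by those CHECK lines, the lowering folds the wgmma partial result into a separate fp32 register accumulator (the `llvm.fadd`s the tests look for) once roughly `maxNumImpreciseAcc` K elements have been accumulated, instead of accumulating the whole fp8 dot in the lower-precision hardware accumulator. The sketch below is only a rough numerical illustration of that splitting under assumed names and sizes; the fp16 stand-in for the hardware accumulator and the tile shapes are made up, not the actual lowering.

    import torch

    def split_accumulate(a, b, k_per_mma=32, max_num_imprecise_acc=32):
        # a: (M, K), b: (K, N). Emulate a reduced-precision MMA accumulator
        # with fp16 and periodically fold it into an fp32 accumulator.
        M, K = a.shape
        N = b.shape[1]
        acc_f32 = torch.zeros(M, N, dtype=torch.float32)
        acc_lowp = torch.zeros(M, N, dtype=torch.float16)
        since_flush = 0
        for k0 in range(0, K, k_per_mma):
            tile = (a[:, k0:k0 + k_per_mma] @ b[k0:k0 + k_per_mma]).to(torch.float16)
            acc_lowp += tile              # "wgmma" accumulation in reduced precision
            since_flush += k_per_mma
            if since_flush >= max_num_imprecise_acc:
                acc_f32 += acc_lowp.float()   # the periodic correction ("llvm.fadd")
                acc_lowp.zero_()
                since_flush = 0
        return acc_f32 + acc_lowp.float()

    # With K = 128, max_num_imprecise_acc = 32 flushes after every 32-wide step,
    # while 129 never flushes, mirroring the high- vs low-precision test cases.
    a, b = torch.randn(64, 128), torch.randn(128, 64)
    err_split = (split_accumulate(a, b, max_num_imprecise_acc=32) - a @ b).abs().max()
    err_nosplit = (split_accumulate(a, b, max_num_imprecise_acc=129) - a @ b).abs().max()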
@@ -1816,7 +1836,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // CHECK: %[[SI:.+]] = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> // CHECK: llvm.insertvalue %[[BC]], %[[SI]][0] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> %b_mat = arith.constant dense<1.000000e+00> : tensor<16x32xf32, #dot_operand_b> - %28 = tt.dot %a, %b_mat, %c {allowTF32 = true, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> + %28 = tt.dot %a, %b_mat, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<32x16xf32, #dot_operand_a> * tensor<16x32xf32, #dot_operand_b> -> tensor<32x32xf32, #mma> %38 = triton_gpu.convert_layout %28 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> %30 = tt.splat %ptr : (!tt.ptr) -> tensor<32x1x!tt.ptr, #blocked> %36 = tt.broadcast %30 : (tensor<32x1x!tt.ptr, #blocked>) -> tensor<32x32x!tt.ptr, #blocked> @@ -1846,7 +1866,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : %cst_0 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_1 = arith.constant dense<1.000000e+00> : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> %cst_2 = arith.constant dense<32> : tensor<32x1xi32, #blocked> - %0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> + %0 = tt.dot %cst_0, %cst_1, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma> %1 = triton_gpu.convert_layout %0 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked> %2 = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> %3 = tt.expand_dims %2 {axis = 1 : i32} : (tensor<32xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>>) -> tensor<32x1xi32, #blocked> @@ -1905,6 +1925,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // ----- // CHECK-LABEL: sum_reduction +<<<<<<< HEAD // PTX: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32 // PTX: nvvm.redux.sync add %{{.*}}, %[[M]] // PTX: nvvm.barrier0 @@ -1925,6 +1946,14 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : // GCN: rocdl.barrier // GCN: llvm.load // GCN: llvm.store +======= +// CHECK: %[[M:.+]] = llvm.mlir.constant(-1 : i32) : i32 +// CHECK: nvvm.redux.sync add %{{.*}}, %[[M]] +// CHECK: nvvm.barrier0 +// CHECK: nvvm.shfl.sync bfly +// CHECK: nvvm.shfl.sync bfly +// CHECK: nvvm.barrier0 +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> #blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { @@ -2012,6 
+2041,7 @@ module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-c // ----- +<<<<<<< HEAD #blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [16, 4], warpsPerCTA = [2, 1], order = [1, 0], CTAsPerCGA = [1,1], CTASplitNum = [1,1], CTAOrder = [1, 0]}> module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32} { // CHECK-LABEL: atomic_add_f16 @@ -2029,6 +2059,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : // PTX-SAME: @$3 atom.global.gpu.add.noftz.f16x2 // GCN-COUNT-8: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr, f16 %8 = "tt.atomic_rmw"(%5, %6, %7) {atomic_rmw_op = 5 : i32, sem = 1: i32} : (tensor<32x32x!tt.ptr, #blocked>, tensor<32x32xf16, #blocked>, tensor<32x32xi1, #blocked>) -> tensor<32x32xf16, #blocked> +======= +// CHECK-LABEL: copyitem +// CHECK: st.shared.b8 +// CHECK: ld.shared.b8 +// CHECK-NOT: st.shared.b1 +// CHECK-NOT: ld.shared.b1 +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"triton_gpu.compute-capability" = 80 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @copyitem() attributes {noinline = false} { + %cst = arith.constant dense : tensor<4x1xi1, #blocked> + %0 = "tt.reduce"(%cst) <{axis = 1 : i32}> ({ + ^bb0(%arg0: i1, %arg1: i1): + %1 = arith.ori %arg0, %arg1 : i1 + tt.reduce.return %1 : i1 + }) : (tensor<4x1xi1, #blocked>) -> tensor<4xi1, #triton_gpu.slice<{dim = 1, parent = #blocked}>> +>>>>>>> ac9fa68d18c777e421bd3f6fb1ddcfd60b6fda33 tt.return } } diff --git a/test/Conversion/tritongpu_to_llvm_hopper.mlir b/test/Conversion/tritongpu_to_llvm_hopper.mlir index 053330d47..21a7e29b0 100644 --- a/test/Conversion/tritongpu_to_llvm_hopper.mlir +++ b/test/Conversion/tritongpu_to_llvm_hopper.mlir @@ -78,3 +78,108 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : tt.return } } + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_high_precision_acc + tt.func @dot_high_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + %m = triton_nvidia_gpu.dot_async %a, %b, %c + {maxNumImpreciseAcc = 32 : i32, allowTF32 = true} : + tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> 
+#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_low_precision_acc + tt.func @dot_low_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: llvm.return + %m = triton_nvidia_gpu.dot_async %a, %b, %c + {maxNumImpreciseAcc = 129 : i32, allowTF32 = true} : + tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}> +#shared = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 16, perPhase = 4, maxPhase = 2, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32} { + // CHECK-LABEL: @dot_mix_precision_acc + tt.func @dot_mix_precision_acc(%a: tensor<128x128xf8E5M2, #shared>, %b: tensor<128x256xf8E5M2, #shared1>, %c: tensor<128x256xf32, #mma>) { + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-NOT: llvm.fadd + // CHECK: nvgpu.wgmma + // CHECK-COUNT-128: llvm.fadd + // CHECK: llvm.return + %m = triton_nvidia_gpu.dot_async %a, %b, %c + {maxNumImpreciseAcc = 64 : i32, allowTF32 = true} : + tensor<128x128xf8E5M2, #shared> * tensor<128x256xf8E5M2, #shared1> -> tensor<128x256xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @dot_zero_acc + // Generate a wgmma with 2 sources. 
+ // CHECK: nvgpu.wgmma %{{.*}}, %{{.*}} { + tt.func @dot_zero_acc(%a: tensor<128x64xf16, #shared>, %b: tensor<64x64xf16, #shared1>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %m = triton_nvidia_gpu.dot_async %a, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + tensor<128x64xf16, #shared> * tensor<64x64xf16, #shared1> -> tensor<128x64xf32, #mma> + tt.return + } +} + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: @dot_reg_operand_A + // Generate a wgmma where the first operand is a struct. + // CHECK: nvgpu.wgmma {{.*}} : (!llvm.struct<(i32, i32, i32, i32)>, i64, !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)>) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + tt.func @dot_reg_operand_A(%a: tensor<128x64xf16, #mma>, %b: tensor<64x64xf16, #shared>) { + %cst = arith.constant dense<0.000000e+00> : tensor<128x64xf32, #mma> + %opA = triton_gpu.convert_layout %a : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %m = tt.dot %opA, %b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : + tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return + } +} diff --git a/test/LLVMIR/break-phi-struct.ll b/test/LLVMIR/break-phi-struct.ll new file mode 100644 index 000000000..b27c87588 --- /dev/null +++ b/test/LLVMIR/break-phi-struct.ll @@ -0,0 +1,33 @@ +; RUN: triton-llvm-opt -break-struct-phi-nodes %s | FileCheck %s + +; CHECK-LABEL: struct +define {i32, i32} @struct(i1 %c) { +; CHECK: br i1 %{{.*}}, label [[TRUE:%.*]], label [[FALSE:%.*]] + br i1 %c, label %true, label %false + +true: + %s.1 = insertvalue {i32, i32} undef, i32 20, 0 + %s.2 = insertvalue {i32, i32} %s.1, i32 200, 1 + +; CHECK-DAG: [[E0:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E1:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +false: + %s.3 = insertvalue {i32, i32} undef, i32 30, 0 + %s.4 = insertvalue {i32, i32} %s.3, i32 300, 1 +; CHECK-DAG: [[E2:%.*]] = extractvalue { i32, i32 } %{{.*}}, 0 +; CHECK-DAG: [[E3:%.*]] = extractvalue { i32, i32 } %{{.*}}, 1 +; CHECK: br + br label %exit + +exit: +; CHECK-DAG: [[PHI0:%.*]] = phi i32 [ [[E0]], [[TRUE]] ], [ [[E2]], [[FALSE]] ] +; CHECK-DAG: [[PHI1:%.*]] = phi i32 [ [[E1]], [[TRUE]] ], [ [[E3]], [[FALSE]] ] +; CHECK: [[S0:%.*]] = insertvalue { i32, i32 } undef, i32 [[PHI0]], 0 +; CHECK: [[S1:%.*]] = insertvalue { i32, i32 } [[S0]], i32 [[PHI1]], 1 +; CHECK: ret { i32, i32 } [[S1]] + %r = phi {i32, i32} [ %s.2, %true], [ %s.4, %false ] + ret {i32, i32} %r +} diff --git a/test/NVGPU/test_cga.mlir b/test/NVGPU/test_cga.mlir index 0b72d92e1..8b9705db5 100644 --- a/test/NVGPU/test_cga.mlir +++ b/test/NVGPU/test_cga.mlir @@ -17,14 +17,6 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 
: %ptr = llvm.mlir.null : !llvm.ptr // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.inline_asm - // CHECK: llvm.mul - // CHECK: llvm.add - // CHECK: llvm.mul - // CHECK: llvm.add %v = nvgpu.cluster_id llvm.store %v, %ptr : !llvm.ptr diff --git a/test/NVGPU/test_wgmma.mlir b/test/NVGPU/test_wgmma.mlir index bb4844ab5..9ad7e606c 100644 --- a/test/NVGPU/test_wgmma.mlir +++ b/test/NVGPU/test_wgmma.mlir @@ -5,37 +5,11 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 %buffer = llvm.mlir.null : !llvm.ptr %height = arith.constant 16 : i32 // CHECK: llvm.ptrtoint - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.zext - // CHECK: llvm.mul - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.or - %descA = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32}: (!llvm.ptr, i32) -> (i64) + // CHECK: llvm.inline_asm + %descA = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32, swizzling = 64 : i64}: (!llvm.ptr, i32) -> (i64) // CHECK: llvm.ptrtoint - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.zext - // CHECK: llvm.mul - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.lshr - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.shl - // CHECK: llvm.or - // CHECK: llvm.or - %descB = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32}: (!llvm.ptr, i32) -> (i64) + // CHECK: llvm.inline_asm + %descB = nvgpu.wgmma_desc_create %buffer, %height {mode = 2 : i32, swizzling = 64 : i64}: (!llvm.ptr, i32) -> (i64) // CHECK-COUNT-32: llvm.extractvalue // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31}, $64, $65, 1, 1, 1, 0, 1;" @@ -43,3 +17,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 tt.return } } // end module + +// ----- + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 2 : i32} { + tt.func @wgmma_no_acc(%descA: i64, %descB: i64) { + // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 {$0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59,$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80,$81,$82,$83,$84,$85,$86,$87,$88,$89,$90,$91,$92,$93,$94,$95,$96,$97,$98,$99,$100,$101,$102,$103,$104,$105,$106,$107,$108,$109,$110,$111,$112,$113,$114,$115,$116,$117,$118,$119,$120,$121,$122,$123,$124,$125,$126,$127}, $128, $129, 0, 1, 1;", "=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,=f,l,l" %{{.*}}, %{{.*}} : (i64, i64) -> !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, 
f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + %acc0 = nvgpu.wgmma %descA, %descB + {eltTypeA = 3 : i32, eltTypeB = 3 : i32, eltTypeC = 7 : i32, k = 32 : i32, layoutA = 0 : i32, layoutB = 1 : i32, m = 64 : i32, n = 256 : i32} : + (i64, i64) -> + !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32)> + tt.return + } +} diff --git a/test/Triton/combine.mlir b/test/Triton/combine.mlir index deb85dc62..170bad012 100644 --- a/test/Triton/combine.mlir +++ b/test/Triton/combine.mlir @@ -10,12 +10,12 @@ tt.func @test_combine_dot_add_pattern() -> (tensor<128x128xf32>, tensor<128x128x %zero = arith.constant dense<0.0> : tensor<128x128xf32> %d = arith.constant dense<3.0> : tensor<128x128xf32> - %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + %dot_out = tt.dot %a, %b, %zero {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> - // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res0:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res0 = arith.addf %dot_out, %d : tensor<128x128xf32> - // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> + // CHECK-NEXT: %[[res1:.*]] = tt.dot %[[a]], %[[b]], %[[d]] {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf32> * tensor<128x128xf32> -> tensor<128x128xf32> %res1 = arith.addf %d, %dot_out : tensor<128x128xf32> tt.return %res0, %res1 : tensor<128x128xf32>, tensor<128x128xf32> diff --git a/test/Triton/reorder-broadcast.mlir b/test/Triton/reorder-broadcast.mlir index fbe44dace..201b81b1e 100644 --- a/test/Triton/reorder-broadcast.mlir +++ b/test/Triton/reorder-broadcast.mlir @@ -38,3 +38,30 @@ tt.func @test_broadcast_elementwise_pattern(%arg0: tensor<128x1xf32>) -> (tensor tt.return %abs, %add : tensor<128x128xf32>, tensor<128x32xf32> } + +// CHECK-LABEL: @test_broadcast_binary_op_pattern +tt.func @test_broadcast_binary_op_pattern(%arg0: tensor<128x1xf32>, %arg1: tensor<128x1xf32>, %arg2: tensor<1x128xf32>) -> 
(tensor<128x128xf32>, tensor<128x128xf32>) { + // CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[mul]] : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast1 = tt.broadcast %arg1 : (tensor<128x1xf32>) -> tensor<128x128xf32> + %mul = arith.mulf %broadcast0, %broadcast1 : tensor<128x128xf32> + + // CHECK: %[[mul:.*]] = arith.mulf %{{.*}}, %{{.*}} : tensor<128x128xf32> + %broadcast2 = tt.broadcast %arg2 : (tensor<1x128xf32>) -> tensor<128x128xf32> + %mul1 = arith.mulf %broadcast0, %broadcast2 : tensor<128x128xf32> + + tt.return %mul, %mul1 : tensor<128x128xf32>, tensor<128x128xf32> +} + +// CHECK-LABEL: @test_broadcast_mix_type_op_pattern +tt.func @test_broadcast_mix_type_op_pattern(%arg0: tensor<128x1xf32>, %arg1: f32, %arg2: tensor<1x128xf32>, %arg3: tensor<128x1xi1>) -> (tensor<128x128xf32>) { + // CHECK: %[[sel:.*]] = arith.select %{{.*}}, %{{.*}}, %{{.*}} : tensor<128x1xi1>, tensor<128x1xf32> + // CHECK-NEXT: %{{.*}} = tt.broadcast %[[sel]] : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast0 = tt.broadcast %arg0 : (tensor<128x1xf32>) -> tensor<128x128xf32> + %broadcast1 = tt.splat %arg1 : (f32) -> tensor<128x128xf32> + %cond = tt.broadcast %arg3 : (tensor<128x1xi1>) -> tensor<128x128xi1> + %sel = arith.select %cond, %broadcast0, %broadcast1 : tensor<128x128xi1>, tensor<128x128xf32> + + tt.return %sel : tensor<128x128xf32> +} diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir new file mode 100644 index 000000000..09fa73cd9 --- /dev/null +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -0,0 +1,46 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-accelerate-matmul=compute-capability=90 | FileCheck %s + +// CHECK: #[[MMA:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 16, 16]}> +// CHECK: #[[MMA1:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +// CHECK: #[[MMA2:.+]] = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 32, 16]}> +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + // CHECK: mma_chain_loop + tt.func public @mma_chain_loop( + %170: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>, + %171: tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>, + %179: tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, + %164: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>, + %165: tensor<64x32xf16, 
#triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>, + %173: tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>, + %153: tensor<128x64x!tt.ptr, #blocked1>) { + %c0_i32 = arith.constant 0 : i32 + %c8_i32 = arith.constant 8 : i32 + %c1_i32 = arith.constant 1 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<128x16xf16, #blocked> + %cst_0 = arith.constant dense<0.000000e+00> : tensor<128x64xf16, #blocked1> + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x32xf16, #blocked2> + // CHECK: scf.for + // CHECK: tt.dot {{.*}} -> tensor<128x16xf16, #[[MMA]]> + // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + %115 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %cst_0) -> (tensor<128x64xf16, #blocked1>) : i32 { + %172 = tt.dot %170, %171, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x16xf16, #blocked> + %178 = triton_gpu.convert_layout %172 : (tensor<128x16xf16, #blocked>) -> tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + %180 = tt.dot %178, %179, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<16x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + scf.yield %180 : tensor<128x64xf16, #blocked1> + } + // CHECK: scf.for + // CHECK: tt.dot {{.*}} -> tensor<128x32xf16, #[[MMA2]]> + // CHECK: tt.dot {{.*}} -> tensor<128x64xf16, #[[MMA1]]> + %149 = scf.for %arg15 = %c0_i32 to %c8_i32 step %c1_i32 iter_args(%arg16 = %115) -> (tensor<128x64xf16, #blocked1>) : i32 { + %166 = tt.dot %164, %165, %cst_2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x32xf16, #blocked2> + %172 = triton_gpu.convert_layout %166 : (tensor<128x32xf16, #blocked2>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> + %174 = tt.dot %172, %173, %arg16 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<32x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<128x64xf16, #blocked1> + scf.yield %174 : tensor<128x64xf16, #blocked1> + } + tt.store %153, %149 {cache = 1 : i32, evict = 1 : i32} : tensor<128x64xf16, #blocked1> + tt.return + } +} diff --git a/test/TritonGPU/coalesce.mlir b/test/TritonGPU/coalesce.mlir index 496fb848e..4979c77bd 100644 --- a/test/TritonGPU/coalesce.mlir +++ b/test/TritonGPU/coalesce.mlir @@ -69,3 +69,71 @@ tt.func @load_tensor(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: } } + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + + +// CHECK: [[NARROW_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK: [[WIDE_LAYOUT:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = 
[0]}> +tt.func public @load_tensors_two_types(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> + %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %7 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %10 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<1024x!tt.ptr<f16>, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> + %13 = arith.extf %12 : tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> + // CHECK: tt.store {{.*}} : tensor<1024xf32, [[WIDE_LAYOUT]]> + tt.store %16, %14, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf32, #blocked> + tt.return +} + +} + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + +// CHECK-NOT: sizePerThread = [4] +// CHECK: #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +// CHECK-NOT: sizePerThread = [4] +tt.func public @load_tensors_two_types(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked> + %3 = tt.splat %1 : (i32) -> tensor<1024xi32, #blocked> + %4 = arith.addi %3, %2 : tensor<1024xi32, #blocked> + %5 = tt.splat %arg3 : (i32) -> tensor<1024xi32, #blocked> + %6 = "triton_gpu.cmpi"(%4, %5) <{predicate = 2 : i64}> : (tensor<1024xi32, #blocked>, tensor<1024xi32, #blocked>) -> tensor<1024xi1, #blocked> + %7 = tt.splat %arg0 : (!tt.ptr<f32>) -> tensor<1024x!tt.ptr<f32>, #blocked> + %8 = tt.addptr %7, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked> + %9 = tt.load %8, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf32, #blocked> + %10 = tt.splat %arg1 : (!tt.ptr<f16>) -> tensor<1024x!tt.ptr<f16>, #blocked> + %11 = tt.addptr %10, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked> + %12 = tt.load %11, %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<1024xf16, #blocked> + %13 = arith.extf %12 : 
tensor<1024xf16, #blocked> to tensor<1024xf32, #blocked> + %14 = arith.addf %9, %13 : tensor<1024xf32, #blocked> + %15 = tt.splat %arg2 : (!tt.ptr<f16>) -> tensor<1024x!tt.ptr<f16>, #blocked> + %16 = tt.addptr %15, %4 : tensor<1024x!tt.ptr<f16>, #blocked>, tensor<1024xi32, #blocked> + %17 = arith.truncf %14 : tensor<1024xf32, #blocked> to tensor<1024xf16, #blocked> + tt.store %16, %17, %6 {cache = 1 : i32, evict = 1 : i32} : tensor<1024xf16, #blocked> + tt.return +} + +} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 4a1a92798..8f5685ae8 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -3,6 +3,10 @@ #layout0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> #layout1 = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +#layout2 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#layout3 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [2, 16], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}> + + module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: [[$target_layout:#.*]] = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> @@ -105,6 +109,19 @@ tt.func @hoist_above_ext2(%arg0: tensor<1024xf16, #layout0>, %arg1: f16) -> tens tt.return %4 : tensor<1024xf32, #layout1> } +// Hoist the convert on top of broadcast to make it cheaper. 
+// CHECK-LABEL: hoist_above_broadcast +tt.func @hoist_above_broadcast(%arg0: tensor<1024x1xf32, #layout2>, %arg1: f32) -> tensor<1024x128xf32, #layout3> { +// CHECK: %[[CVT:.+]] = triton_gpu.convert_layout +// CHECK: tt.broadcast %[[CVT]] +// CHECK-NOT: triton_gpu.convert_layout +// CHECK: tt.return + %0 = tt.broadcast %arg0 : (tensor<1024x1xf32, #layout2>) -> tensor<1024x128xf32, #layout2> + %1 = tt.splat %arg1 : (f32) -> tensor<1024x128xf32, #layout2> + %2 = arith.addf %0, %1 : tensor<1024x128xf32, #layout2> + %3 = triton_gpu.convert_layout %2 : (tensor<1024x128xf32, #layout2>) -> tensor<1024x128xf32, #layout3> + tt.return %3 : tensor<1024x128xf32, #layout3> +} // CHECK-LABEL: if @@ -1528,7 +1545,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war %26 = triton_gpu.convert_layout %19 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> %27 = triton_gpu.convert_layout %25 : (tensor<32x32xf16, #blocked>) -> tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> %28 = triton_gpu.convert_layout %cst : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #blocked5> - %29 = tt.dot %26, %27, %28 {allowTF32 = true} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> + %29 = tt.dot %26, %27, %28 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked5}>> * tensor<32x32xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked5}>> -> tensor<32x32xf32, #blocked5> %30 = triton_gpu.convert_layout %29 : (tensor<32x32xf32, #blocked5>) -> tensor<32x32xf32, #blocked> %31:2 = "tt.reduce"(%30, %11) <{axis = 1 : i32}> ({ ^bb0(%arg3: f32, %arg4: i32, %arg5: f32, %arg6: i32): @@ -1675,7 +1692,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %117 = tt.load %116 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x64xf16, #blocked3> %118 = triton_gpu.convert_layout %41 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %119 = triton_gpu.convert_layout %97 : (tensor<64x64xf16, #blocked6>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> - %120 = tt.dot %118, %119, %cst {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> + %120 = tt.dot %118, %119, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf16, #blocked> %121 = triton_gpu.convert_layout %120 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #blocked2> %122 = arith.extf %121 : tensor<128x64xf16, #blocked2> to tensor<128x64xf32, #blocked2> %123 = "tt.reduce"(%122) <{axis = 1 : i32}> ({ @@ -1704,7 +1721,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %142 = triton_gpu.convert_layout %141 : (tensor<128x64xf16, #blocked2>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> %143 = triton_gpu.convert_layout %117 : (tensor<64x64xf16, #blocked3>) -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> %144 = triton_gpu.convert_layout %140 : 
(tensor<128x64xf32, #blocked2>) -> tensor<128x64xf32, #blocked> - %145 = tt.dot %142, %143, %144 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> + %145 = tt.dot %142, %143, %144 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x64xf32, #blocked> %146 = triton_gpu.convert_layout %145 : (tensor<128x64xf32, #blocked>) -> tensor<128x64xf32, #blocked2> %147 = arith.mulf %arg24, %127 : tensor<128xf32, #blocked1> %148 = "tt.reduce"(%133) <{axis = 1 : i32}> ({ diff --git a/test/TritonGPU/dot-operands.mlir b/test/TritonGPU/dot-operands.mlir index 1fbfaa9d4..1cd3b772d 100644 --- a/test/TritonGPU/dot-operands.mlir +++ b/test/TritonGPU/dot-operands.mlir @@ -15,7 +15,7 @@ #BLR = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> #BLC = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [0, 1]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capability" = 80} { // CHECK: tt.func @push_elementwise // CHECK: %[[ALOAD:.*]] = tt.load %arg0 @@ -36,7 +36,7 @@ tt.func @push_elementwise( %a = tt.fp_to_fp %af8: tensor<16x16xf8E5M2, #ALR> -> tensor<16x16xf16, #ALR> %dota = triton_gpu.convert_layout %a : (tensor<16x16xf16, #ALR>) -> tensor<16x16xf16, #Av2k4> %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -58,7 +58,7 @@ tt.func @succeeds_if_arg_is_not_convert_layout( %dotaf8 = tt.bitcast %dotai8 : tensor<16x16xi8, #Av2k4> -> tensor<16x16xf8E5M2, #Av2k4> %dota = tt.fp_to_fp %dotaf8 : tensor<16x16xf8E5M2, #Av2k4> -> tensor<16x16xf16, #Av2k4> %dotb = triton_gpu.convert_layout %b : (tensor<16x16xf16, #BLC>) -> tensor<16x16xf16, #Bv2k4> - %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> + %newc = tt.dot %dota, %dotb, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<16x16xf16, #Av2k4> * tensor<16x16xf16, #Bv2k4> -> tensor<16x16xf32, #Cv2> tt.return %newc : tensor<16x16xf32, #Cv2> } @@ -69,7 +69,7 @@ tt.func @succeeds_if_arg_is_not_convert_layout( #blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capability" = 80} { // CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 
1], order = [1, 0], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> // CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> @@ -82,7 +82,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: %[[AEXT:.*]] = arith.extf %[[ACVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.dot %[[AEXT]], %[[BEXT]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @push_convert_both_operands( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -93,7 +93,7 @@ tt.func @push_convert_both_operands( %be = arith.extf %b : tensor<16x16xf16, #blockedB> to tensor<16x16xf32, #blockedB> %al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %bl = triton_gpu.convert_layout %be : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } @@ -104,7 +104,7 @@ tt.func @push_convert_both_operands( #blockedA = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> #blockedB = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}> #mma = #triton_gpu.mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4]}> -module attributes {"triton_gpu.num-warps" = 4 : i32} { +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.compute-capability" = 80} { // CHECK: #[[BA:.*]] = #triton_gpu.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [], CTASplitNum = [], CTAOrder = []}> // CHECK: #[[BB:.*]] = #triton_gpu.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1], CTAsPerCGA = [], CTASplitNum = 
[], CTAOrder = []}> @@ -119,7 +119,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} { // CHECK: %[[BCVT:.*]] = triton_gpu.convert_layout %[[BLOAD]] : (tensor<16x16xf16, #[[BB]]>) -> tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[BEXT:.*]] = arith.extf %[[BCVT]] : tensor<16x16xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> to tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> // CHECK: %[[ADD:.+]] = arith.addf %[[BEXT]], %[[CST]] : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> +// CHECK: tt.dot %[[AEXT]], %[[ADD]], %{{.*}} {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #[[MMA]], kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #[[MMA]], kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.func @update_kwidth_slice( %pa: tensor<16x16x!tt.ptr, #blockedA> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, %pb: tensor<16x16x!tt.ptr, #blockedB> {tt.divisibility=16: i32, tt.contiguity=2 : i32}, @@ -132,8 +132,24 @@ tt.func @update_kwidth_slice( %add = arith.addf %be, %cst : tensor<16x16xf32, #blockedB> %al = triton_gpu.convert_layout %ae : (tensor<16x16xf32, #blockedA>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %bl = triton_gpu.convert_layout %add : (tensor<16x16xf32, #blockedB>) -> tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> - %r = tt.dot %al, %bl, %c {allowTF32 = true} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> + %r = tt.dot %al, %bl, %c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<16x16xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<16x16xf32, #mma> tt.return %r : tensor<16x16xf32, #mma> } } + +// ----- + +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 64, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { +// CHECK: tt.func @mma_v3_reg_operand_A +// CHECK: %[[A:.+]] = triton_gpu.convert_layout %{{.*}} : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> +// CHECK: tt.dot %[[A]], {{.*}} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * tensor<64x64xf16, #shared> -> tensor<128x64xf32, #mma> +tt.func @mma_v3_reg_operand_A(%arg0: tensor<128x64xf16, #mma>, %arg1: tensor<64x64xf16, #shared>, %arg2: tensor<128x64xf32, #mma>) -> tensor<128x64xf32, 
#mma>{ + %A = triton_gpu.convert_layout %arg0 : (tensor<128x64xf16, #mma>) -> tensor<128x64xf16, #shared1> + %r = tt.dot %A, %arg1, %arg2 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #shared1> * tensor<64x64xf16, #shared> -> tensor<128x64xf32, #mma> + tt.return %r : tensor<128x64xf32, #mma> +} +} diff --git a/test/TritonGPU/fence-inserstion.mlir b/test/TritonGPU/fence-inserstion.mlir new file mode 100644 index 000000000..ff4a1fe1e --- /dev/null +++ b/test/TritonGPU/fence-inserstion.mlir @@ -0,0 +1,205 @@ +// RUN: triton-opt %s -split-input-file --triton-nvidia-gpu-fence-insertion | FileCheck %s +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_like_fence_1(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c128_i32 = arith.constant 128 : i32 + %c1_i32 = arith.constant 1 : i32 + %c3_i32 = arith.constant 3 : i32 + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %4 = arith.extsi %arg4 : i32 to i64 + %5 = arith.extsi %arg7 : i32 to i64 + %6 = tt.make_tensor_ptr %arg1, [%1, %4], [%5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %7 = arith.extsi %arg8 : i32 to i64 + %8 = tt.make_tensor_ptr %arg2, [%0, %4], [%7, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %9 = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : tensor<3xi64, #shared> + %10 = arith.cmpi sgt, %arg5, %c0_i32 : i32 + %11 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %12 = tt.splat %10 : (i1) -> tensor<128x128xi1, #blocked1> + %13 = triton_nvidia_gpu.extract_mbarrier %9[%c0_i32] : tensor<3xi64, #shared>, i32 -> + %14 = 
triton_nvidia_gpu.get_thread_id : i32 + %15 = arith.cmpi eq, %14, %c0_i32 : i32 + %16 = arith.andi %15, %10 : i1 + triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %20 = tt.advance %3, [%c0_i32, %c128_i32] : , 1> + %21 = tt.advance %6, [%c128_i32, %c0_i32] : , 1> + %22 = arith.cmpi sgt, %arg5, %c128_i32 : i32 + %23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1> + %24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> + %25 = arith.andi %15, %22 : i1 + triton_nvidia_gpu.mbarrier_arrive %24, %25 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1) : i32 { + %33 = triton_nvidia_gpu.extract_mbarrier %9[%arg21] : tensor<3xi64, #shared>, i32 -> + triton_nvidia_gpu.mbarrier_wait %33, %arg23 : + // CHECK: triton_nvidia_gpu.fence_async_shared + %34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma> + triton_nvidia_gpu.dot_wait {pendings = 1 : i32} + %35 = tt.advance %arg11, [%c0_i32, %c128_i32] : , 1> + %36 = tt.advance %arg12, [%c128_i32, %c0_i32] : , 1> + %37 = arith.addi %arg19, %c128_i32 : i32 + %38 = arith.cmpi slt, %37, %arg5 : i32 + %39 = arith.addi %arg21, %c1_i32 : i32 + %40 = arith.cmpi uge, %39, %c3_i32 : i32 + %41 = 
arith.select %40, %c0_i32, %39 : i32 + %42 = tt.advance %arg17, [%c0_i32, %c128_i32] : , 1> + %43 = tt.advance %arg18, [%c128_i32, %c0_i32] : , 1> + %44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1> + %45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> + %46 = arith.andi %15, %38 : i1 + triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %b_48 = triton_gpu.convert_layout %48 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1> + %s_48 = triton_gpu.convert_layout %b_48 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1> + %51 = arith.addi %arg20, %c1_i32 : i32 + %52 = arith.cmpi uge, %51, %c3_i32 : i32 + %53 = arith.select %52, %c0_i32, %51 : i32 + %54 = arith.addi %arg22, %c1_i32 : i32 + %55 = arith.xori %arg23, %true : i1 + %56 = arith.cmpi ult, %39, %c3_i32 : i32 + %57 = arith.andi %40, %55 : i1 + %58 = arith.andi %56, %arg23 : i1 + %59 = arith.ori %57, %58 : i1 + %60 = arith.xori %arg24, %true : i1 + %61 = arith.cmpi ult, %51, %c3_i32 : i32 + %62 = arith.andi %52, %60 : i1 + %63 = arith.andi %61, %arg24 : i1 + %64 = arith.ori %62, %63 : i1 + scf.yield %34, %35, %36, %47, %49, %s_48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1 + } + scf.if %10 { + triton_nvidia_gpu.dot_wait {pendings = 0 : i32} + } + %31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1> + triton_nvidia_gpu.store_async %8, %32 : !tt.ptr, 1>, tensor<128x128xf16, #shared1> + triton_gpu.async_bulk_commit_group + triton_gpu.async_bulk_wait {num = 0 : i32} + tt.return + } +} + + +// ----- + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [0, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +#mma = #triton_gpu.mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], 
CTAOrder = [0], hasLeadingOffset = false}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_like_fence_2(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg4: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg5: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg6: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg7: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}, %arg8: i32 {tt.divisibility = 16 : i32, tt.max_divisibility = 8 : i32}) attributes {noinline = false} { + %c2_i32 = arith.constant 2 : i32 + %c128_i32 = arith.constant 128 : i32 + %c1_i32 = arith.constant 1 : i32 + %c3_i32 = arith.constant 3 : i32 + %true = arith.constant true + %false = arith.constant false + %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i64 = arith.constant 1 : i64 + %0 = arith.extsi %arg3 : i32 to i64 + %1 = arith.extsi %arg5 : i32 to i64 + %2 = arith.extsi %arg6 : i32 to i64 + %3 = tt.make_tensor_ptr %arg0, [%0, %1], [%2, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %4 = arith.extsi %arg4 : i32 to i64 + %5 = arith.extsi %arg7 : i32 to i64 + %6 = tt.make_tensor_ptr %arg1, [%1, %4], [%5, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %7 = arith.extsi %arg8 : i32 to i64 + %8 = tt.make_tensor_ptr %arg2, [%0, %4], [%7, %c1_i64], [%c0_i32, %c0_i32] {order = array} : , 1> + %9 = triton_nvidia_gpu.alloc_mbarrier {count = 1 : i32} : tensor<3xi64, #shared> + %10 = arith.cmpi sgt, %arg5, %c0_i32 : i32 + %11 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %12 = tt.splat %10 : (i1) -> tensor<128x128xi1, #blocked1> + %13 = triton_nvidia_gpu.extract_mbarrier %9[%c0_i32] : tensor<3xi64, #shared>, i32 -> + %14 = triton_nvidia_gpu.get_thread_id : i32 + %15 = arith.cmpi eq, %14, %c0_i32 : i32 + %16 = arith.andi %15, %10 : i1 + triton_nvidia_gpu.mbarrier_arrive %13, %16 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %17 = triton_nvidia_gpu.insert_slice_async_v2 %3, %11, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %18 = triton_gpu.alloc_tensor : tensor<3x128x128xf16, #shared1> + %19 = triton_nvidia_gpu.insert_slice_async_v2 %6, %18, %c0_i32, %13, %12 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %20 = tt.advance %3, [%c0_i32, %c128_i32] : , 1> + %21 = tt.advance %6, [%c128_i32, %c0_i32] : , 1> + %22 = arith.cmpi sgt, %arg5, %c128_i32 : i32 + %23 = tt.splat %22 : (i1) -> tensor<128x128xi1, #blocked1> + %24 = triton_nvidia_gpu.extract_mbarrier %9[%c1_i32] : tensor<3xi64, #shared>, i32 -> + %25 = arith.andi %15, %22 : i1 + triton_nvidia_gpu.mbarrier_arrive %24, %25 
{operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %26 = triton_nvidia_gpu.insert_slice_async_v2 %20, %17, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %27 = triton_nvidia_gpu.insert_slice_async_v2 %21, %19, %c1_i32, %24, %23 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %28 = triton_gpu.extract_slice %26[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %29 = triton_gpu.extract_slice %27[0, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %b_29 = triton_gpu.convert_layout %29 : (tensor<128x128xf16, #shared1>) -> tensor<128x128xf16, #blocked1> + %s_29 = triton_gpu.convert_layout %b_29 : (tensor<128x128xf16, #blocked1>) -> tensor<128x128xf16, #shared1> + %30:15 = scf.for %arg9 = %c0_i32 to %arg5 step %c128_i32 iter_args(%arg10 = %cst, %arg11 = %3, %arg12 = %6, %arg13 = %26, %arg14 = %27, %arg15 = %28, %arg16 = %s_29, %arg17 = %20, %arg18 = %21, %arg19 = %c128_i32, %arg20 = %c2_i32, %arg21 = %c0_i32, %arg22 = %c0_i32, %arg23 = %false, %arg24 = %true) -> (tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1) : i32 { + %33 = triton_nvidia_gpu.extract_mbarrier %9[%arg21] : tensor<3xi64, #shared>, i32 -> + triton_nvidia_gpu.mbarrier_wait %33, %arg23 : + // CHECK: triton_nvidia_gpu.fence_async_shared + %34 = triton_nvidia_gpu.dot_async %arg15, %arg16, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x128xf16, #shared1> * tensor<128x128xf16, #shared1> -> tensor<128x128xf32, #mma> + triton_nvidia_gpu.dot_wait {pendings = 1 : i32} + %35 = tt.advance %arg11, [%c0_i32, %c128_i32] : , 1> + %36 = tt.advance %arg12, [%c128_i32, %c0_i32] : , 1> + %37 = arith.addi %arg19, %c128_i32 : i32 + %38 = arith.cmpi slt, %37, %arg5 : i32 + %39 = arith.addi %arg21, %c1_i32 : i32 + %40 = arith.cmpi uge, %39, %c3_i32 : i32 + %41 = arith.select %40, %c0_i32, %39 : i32 + %42 = tt.advance %arg17, [%c0_i32, %c128_i32] : , 1> + %43 = tt.advance %arg18, [%c128_i32, %c0_i32] : , 1> + %44 = tt.splat %38 : (i1) -> tensor<128x128xi1, #blocked1> + %45 = triton_nvidia_gpu.extract_mbarrier %9[%arg20] : tensor<3xi64, #shared>, i32 -> + %46 = arith.andi %15, %38 : i1 + triton_nvidia_gpu.mbarrier_arrive %45, %46 {operand_segment_sizes = array, trackAsyncOp = false, txCount = 65536 : i32} : !tt.ptr, i1 + %47 = triton_nvidia_gpu.insert_slice_async_v2 %42, %arg13, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} : !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %48 = triton_gpu.extract_slice %47[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %49 = triton_nvidia_gpu.insert_slice_async_v2 %43, %arg14, %arg20, %45, %44 {axis = 0 : i32, cache = 1 : i32, evict = 1 : i32, isVolatile = false, operand_segment_sizes = array} 
: !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, i32, !tt.ptr, tensor<128x128xi1, #blocked1> -> tensor<3x128x128xf16, #shared1> + %50 = triton_gpu.extract_slice %49[%41, 0, 0] [1, 128, 128] [1, 1, 1] : tensor<3x128x128xf16, #shared1> to tensor<128x128xf16, #shared1> + %51 = arith.addi %arg20, %c1_i32 : i32 + %52 = arith.cmpi uge, %51, %c3_i32 : i32 + %53 = arith.select %52, %c0_i32, %51 : i32 + %54 = arith.addi %arg22, %c1_i32 : i32 + %55 = arith.xori %arg23, %true : i1 + %56 = arith.cmpi ult, %39, %c3_i32 : i32 + %57 = arith.andi %40, %55 : i1 + %58 = arith.andi %56, %arg23 : i1 + %59 = arith.ori %57, %58 : i1 + %60 = arith.xori %arg24, %true : i1 + %61 = arith.cmpi ult, %51, %c3_i32 : i32 + %62 = arith.andi %52, %60 : i1 + %63 = arith.andi %61, %arg24 : i1 + %64 = arith.ori %62, %63 : i1 + scf.yield %34, %35, %36, %47, %49, %48, %50, %42, %43, %37, %53, %41, %54, %59, %64 : tensor<128x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>, tensor<3x128x128xf16, #shared1>, tensor<3x128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, tensor<128x128xf16, #shared1>, !tt.ptr, 1>, !tt.ptr, 1>, i32, i32, i32, i32, i1, i1 + } + scf.if %10 { + triton_nvidia_gpu.dot_wait {pendings = 0 : i32} + } + %31 = arith.truncf %30#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %32 = triton_gpu.convert_layout %31 : (tensor<128x128xf16, #mma>) -> tensor<128x128xf16, #shared1> + triton_nvidia_gpu.store_async %8, %32 : !tt.ptr, 1>, tensor<128x128xf16, #shared1> + triton_gpu.async_bulk_commit_group + triton_gpu.async_bulk_wait {num = 0 : i32} + tt.return + } +} diff --git a/test/TritonGPU/loop-pipeline-hopper.mlir b/test/TritonGPU/loop-pipeline-hopper.mlir index 7c54ce39b..8bfb7b576 100644 --- a/test/TritonGPU/loop-pipeline-hopper.mlir +++ b/test/TritonGPU/loop-pipeline-hopper.mlir @@ -74,7 +74,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -151,7 +151,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -220,7 +220,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, scf.for 
%iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } @@ -293,7 +293,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // // %sa = triton_gpu.convert_layout %a : (tensor<128x32xf16, #BA>) -> tensor<128x32xf16, #SA> // %sb = triton_gpu.convert_layout %b : (tensor<32x128xf16, #BB>) -> tensor<32x128xf16, #SB> -// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> +// %c = tt.dot %sa, %sb, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #SA> * tensor<32x128xf16, #SB> -> tensor<128x128xf32, #C> // // %a_tileptr_next = tt.advance %a_tileptr, [%c0, %c32_i32] : !tt.ptr, 1> // %b_tileptr_next = tt.advance %b_tileptr, [%c32_i32, %c0] : !tt.ptr, 1> diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 626e7bdb1..a5bb2f239 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -84,7 +84,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> %b = arith.mulf %b_, %b_scale: tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -157,7 +157,7 @@ tt.func @matmul_loop_nested(%lb : index, %ub : index, %step : index, %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> @@ -224,7 +224,7 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr 
= %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C>) { %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B> - %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> + %c = tt.dot %a, %b, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C> %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #C> } @@ -266,7 +266,7 @@ tt.func @lut_bmm_scalar(%77: i64 {tt.divisibility=16: i32}, %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32 : !tt.ptr, i32 scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, !tt.ptr @@ -312,7 +312,7 @@ tt.func @lut_bmm_vector(%77: tensor<16x16xi64, #BL> {tt.divisibility=16: i32, tt %87 = tt.load %86 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<16x16xf16, #BL> %88 = triton_gpu.convert_layout %82 : (tensor<16x16xf16, #AL>) -> tensor<16x16xf16, #A> %89 = triton_gpu.convert_layout %87 : (tensor<16x16xf16, #BL>) -> tensor<16x16xf16, #B> - %90 = tt.dot %88, %89, %arg19 {allowTF32 = true} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> + %90 = tt.dot %88, %89, %arg19 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<16x16xf16, #A> * tensor<16x16xf16, #B> -> tensor<16x16xf32, #C> %91 = tt.addptr %arg20, %78 : tensor<16x16x!tt.ptr, #AL>, tensor<16x16xi32, #AL> %92 = tt.addptr %arg21, %c1_i32_splat : tensor<16x!tt.ptr, #BLs1>, tensor<16xi32, #BLs1> scf.yield %90, %91, %92 : tensor<16x16xf32, #C>, tensor<16x16x!tt.ptr, #AL>, tensor<16x!tt.ptr, #BLs1> @@ -362,7 +362,7 @@ tt.func @post_load_inv(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %116 = tt.load %arg12, %115, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %117 = triton_gpu.convert_layout %112 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %118 = triton_gpu.convert_layout %116 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %119 = tt.dot %117, %118, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %119 = tt.dot %117, %118, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, 
kWidth = 1}>> -> tensor<32x32xf32, #C> %131 = arith.index_cast %arg9 : index to i32 %120 = arith.addi %131, %c1_i32 : i32 %121 = arith.muli %120, %c32_i32 : i32 @@ -425,7 +425,7 @@ tt.func @cross_iter_dep(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %150 = tt.load %arg12, %149, %cst_1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #AL> %151 = triton_gpu.convert_layout %146 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> %152 = triton_gpu.convert_layout %150 : (tensor<32x32xf32, #AL>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> - %153 = tt.dot %151, %152, %arg10 {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> + %153 = tt.dot %151, %152, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 1}>> -> tensor<32x32xf32, #C> %162 = arith.index_cast %arg9 : index to i32 %154 = arith.addi %162, %c2_i32 : i32 %155 = arith.muli %154, %c32_i32 : i32 @@ -497,7 +497,7 @@ tt.func @dep_arg_two_uses(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %199 = tt.load %arg24, %198, %88 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> %200 = triton_gpu.convert_layout %193 : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> %201 = triton_gpu.convert_layout %199 : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> - %202 = tt.dot %200, %201, %arg23 {allowTF32 = true} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> + %202 = tt.dot %200, %201, %arg23 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}>> * tensor<32x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}>> -> tensor<128x128xf32, #C> %203 = tt.addptr %arg24, %90 : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi64, #BL> scf.yield %190, %196, %197, %202, %203 : tensor<128x32x!tt.ptr, #AL>, tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #AL}>>, !tt.ptr, tensor<128x128xf32, #C>, tensor<32x128x!tt.ptr, #BL> } diff --git a/test/TritonGPU/materialize-load-store.mlir b/test/TritonGPU/materialize-load-store.mlir index 65ca0e6c6..58bc51514 100644 --- a/test/TritonGPU/materialize-load-store.mlir +++ b/test/TritonGPU/materialize-load-store.mlir @@ -52,7 +52,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c %8 = tt.load %6 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x16xf16, #blockedB1> %9 = triton_gpu.convert_layout %7 : (tensor<64x16xf16, #blockedA1>) -> tensor<64x16xf16, #sharedA> %10 = triton_gpu.convert_layout %8 : (tensor<16x16xf16, #blockedB1>) -> tensor<16x16xf16, #sharedB> - %11 = tt.dot %9, %10, %cst {allowTF32 = true} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma> + %11 = tt.dot %9, %10, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #sharedA> * tensor<16x16xf16, #sharedB> -> tensor<64x16xf32, #mma> %12 = 
triton_gpu.convert_layout %11 : (tensor<64x16xf32, #mma>) -> tensor<64x16xf32, #blockedA1>
 %13 = arith.truncf %12 : tensor<64x16xf32, #blockedA1> to tensor<64x16xf16, #blockedA1>
 %14 = arith.extsi %arg8 : i32 to i64
diff --git a/test/TritonGPU/matmul.mlir b/test/TritonGPU/matmul.mlir
index 81560db65..6c9264400 100644
--- a/test/TritonGPU/matmul.mlir
+++ b/test/TritonGPU/matmul.mlir
@@ -62,7 +62,7 @@ tt.func @matmul_kernel__Pfp32_Pfp32_Pfp32_i32_i32_i32_i32_i32_i32_i32_i32_i32__1
 %47:3 = scf.for %arg12 = %c0 to %46 step %c64 iter_args(%arg13 = %cst_0, %arg14 = %34, %arg15 = %45) -> (tensor<64x64xf32>, tensor<64x64x!tt.ptr>, tensor<64x64x!tt.ptr>) {
 %76 = tt.load %arg14, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
 %77 = tt.load %arg15, %cst, %cst_0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false, transA=false, transB=false} : tensor<64x64xf32>
- %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
+ %78 = tt.dot %76, %77, %cst_0 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<64x64xf32> * tensor<64x64xf32> -> tensor<64x64xf32>
 %79 = arith.addf %arg13, %78 : tensor<64x64xf32>
 %80 = arith.muli %arg7, %c64_i32 : i32
 %81 = tt.splat %80 : (i32) -> tensor<64x64xi32>
diff --git a/test/TritonGPU/prefetch.mlir b/test/TritonGPU/prefetch.mlir
index b820f4034..7104d8dc8 100644
--- a/test/TritonGPU/prefetch.mlir
+++ b/test/TritonGPU/prefetch.mlir
@@ -53,7 +53,7 @@ tt.func @matmul_loop_mixed(%lb : index, %ub : index, %step : index, %A : !tt.ptr
 %a_op_ = triton_gpu.convert_layout %a : (tensor<128x32xf8E5M2, #A>) -> tensor<128x32xf8E5M2, #A_OP>
 %a_op = tt.fp_to_fp %a_op_ : tensor<128x32xf8E5M2, #A_OP> -> tensor<128x32xf16, #A_OP>
 %b_op = triton_gpu.convert_layout %b : (tensor<32x128xf16, #B>) -> tensor<32x128xf16, #B_OP>
- %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
+ %c = tt.dot %a_op, %b_op, %prev_c {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, transA = false, transB = false} : tensor<128x32xf16, #A_OP> * tensor<32x128xf16, #B_OP> -> tensor<128x128xf32, #C>
 %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL>
 %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL>
diff --git a/test/TritonGPU/reorder-instructions.mlir b/test/TritonGPU/reorder-instructions.mlir
index a410771ec..18782376b 100644
--- a/test/TritonGPU/reorder-instructions.mlir
+++ b/test/TritonGPU/reorder-instructions.mlir
@@ -15,7 +15,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
 %9 = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked>
 %10 = triton_gpu.convert_layout %9 : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared>
 %11 = triton_gpu.convert_layout %10 : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
- %12 = tt.dot %11, %cst_0, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
+ %12 = tt.dot %11, %cst_0, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
 %13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
 tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked>
 tt.return
@@ -41,7 +41,7 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
 %A = tt.load %arg0 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x32xf32, #blocked>
 %AS = triton_gpu.convert_layout %A : (tensor<32x32xf32, #blocked>) -> tensor<32x32xf32, #shared>
 %AD = triton_gpu.convert_layout %AS : (tensor<32x32xf32, #shared>) -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>
- %12 = tt.dot %AD, %BD, %cst {allowTF32 = true} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
+ %12 = tt.dot %AD, %BD, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 1}>> -> tensor<32x32xf32, #mma>
 %13 = triton_gpu.convert_layout %12 : (tensor<32x32xf32, #mma>) -> tensor<32x32xf32, #blocked>
 tt.store %arg0, %13 {cache = 1 : i32, evict = 1 : i32} : tensor<32x32xf32, #blocked>
 tt.return
diff --git a/test/TritonGPU/rewrite-tensor-pointer.mlir b/test/TritonGPU/rewrite-tensor-pointer.mlir
index 23eddb24b..cfe46c787 100644
--- a/test/TritonGPU/rewrite-tensor-pointer.mlir
+++ b/test/TritonGPU/rewrite-tensor-pointer.mlir
@@ -46,7 +46,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %30 = triton_gpu.convert_layout %28 : (tensor<128x64xf16, #blocked>) -> tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>>
 %31 = triton_gpu.convert_layout %29 : (tensor<64x128xf16, #blocked1>) -> tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>>
 %32 = triton_gpu.convert_layout %arg12 : (tensor<128x128xf32, #blocked>) -> tensor<128x128xf32, #blocked2>
- %33 = tt.dot %30, %31, %32 {allowTF32 = true} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2>
+ %33 = tt.dot %30, %31, %32 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2>
 %34 = triton_gpu.convert_layout %33 : (tensor<128x128xf32, #blocked2>) -> tensor<128x128xf32, #blocked>
 // CHECK-NOT: tt.advance
 %35 = tt.advance %arg13, [%c0_i32, %c64_i32] : , 1>
diff --git a/test/TritonGPU/wsdecomposing.mlir b/test/TritonGPU/wsdecomposing.mlir
index 059554a59..7d89baa8d 100644
--- a/test/TritonGPU/wsdecomposing.mlir
+++ b/test/TritonGPU/wsdecomposing.mlir
@@ -97,7 +97,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
 %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
 %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
- %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
 %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
 scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>
@@ -208,7 +208,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: %90 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
 // CHECK-NEXT: %91 = triton_gpu.convert_layout %89 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
 // CHECK-NEXT: %92 = triton_gpu.convert_layout %90 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
- // CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ // CHECK-NEXT: %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 // CHECK-NEXT: %94 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
 // CHECK-NEXT: %95 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
 // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>
@@ -336,7 +336,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
 scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
 }
- %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
 %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
 scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>
@@ -452,7 +452,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: %96 = triton_gpu.convert_layout %94 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
 // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %95, %96 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
 // CHECK-NEXT: } {async_agent = dense<[0, 1]> : vector<2xi32>}
- // CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ // CHECK-NEXT: %90 = tt.dot %89#0, %89#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 // CHECK-NEXT: %91 = tt.addptr %arg11, %cst_1 {async_agent = dense<0> : vector<1xi32>} : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
 // CHECK-NEXT: %92 = tt.addptr %arg12, %cst_0 {async_agent = dense<0> : vector<1xi32>} : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
 // CHECK-NEXT: scf.yield {async_agent = dense<[0, 1]> : vector<2xi32>} %90, %91, %92 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>
@@ -587,7 +587,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
 %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
 %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
- %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 %base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr, #blocked1>) {
 %r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32
 %r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1>
@@ -717,7 +717,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-NEXT: %92 = tt.load %arg12 {async_agent = dense<0> : vector<1xi32>, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
 // CHECK-NEXT: %93 = triton_gpu.convert_layout %91 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
 // CHECK-NEXT: %94 = triton_gpu.convert_layout %92 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
- // CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ // CHECK-NEXT: %95 = tt.dot %93, %94, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 // CHECK-NEXT: %96 = scf.if %90 -> (tensor<128x32x!tt.ptr, #blocked1>) {
 // CHECK-NEXT: %99 = arith.select %90, %c31_i32, %c127_i32 {async_agent = dense<1> : vector<1xi32>} : i32
 // CHECK-NEXT: %100 = tt.splat %99 {async_agent = dense<1> : vector<1xi32>} : (i32) -> tensor<128x32xi32, #blocked1>
diff --git a/test/TritonGPU/wsmaterialization.mlir b/test/TritonGPU/wsmaterialization.mlir
index 4ab8be6c5..07ee80f9b 100644
--- a/test/TritonGPU/wsmaterialization.mlir
+++ b/test/TritonGPU/wsmaterialization.mlir
@@ -177,7 +177,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %63 = triton_gpu.extract_slice %0[%arg11, 0, 0] [1, 32, 128] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x32x128xf16, #shared> to tensor<32x128xf16, #shared>
 %64 = triton_gpu.convert_layout %62 {async_agent = dense<1> : vector<1xi32>} : (tensor<128x32xf16, #shared1>) -> tensor<128x32xf16, #shared1>
 %65 = triton_gpu.convert_layout %63 {async_agent = dense<1> : vector<1xi32>} : (tensor<32x128xf16, #shared>) -> tensor<32x128xf16, #shared>
- %66 = tt.dot %64, %65, %arg10 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
+ %66 = tt.dot %64, %65, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<128x32xf16, #shared1> * tensor<32x128xf16, #shared> -> tensor<128x128xf32, #mma>
 %c1_i32_2 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
 %c3_i32 = arith.constant {async_agent = dense<1> : vector<1xi32>} 3 : i32
 %67 = arith.addi %arg11, %c1_i32_2 {async_agent = dense<1> : vector<1xi32>} : i32
@@ -384,7 +384,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
 %50 = triton_gpu.convert_layout %49 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
 %51 = triton_gpu.extract_slice %1[%48, 0, 0] [1, 16, 64] [1, 1, 1] {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
 %52 = triton_gpu.convert_layout %51 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
- %53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+ %53 = tt.dot %50, %52, %arg12 {agent.mutex_role = 0 : i32, allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
 triton_nvidia_gpu.consumer_release %2, %48 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
 %c1_i32_6 = arith.constant {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} 1 : i32
 %54 = arith.addi %arg13, %c1_i32_6 {agent.mutex_role = 0 : i32, async_agent = dense<1> : vector<1xi32>} : i32
diff --git a/test/TritonGPU/wsmutex.mlir b/test/TritonGPU/wsmutex.mlir
index 78b9037c5..1c0ad7712 100644
--- a/test/TritonGPU/wsmutex.mlir
+++ b/test/TritonGPU/wsmutex.mlir
@@ -141,7 +141,7 @@ module attributes {"async.num-agents" = 2 : i32, "triton_gpu.compute-capability"
 %40 = triton_gpu.convert_layout %39 {async_agent = dense<1> : vector<1xi32>} : (tensor<64x16xf16, #shared>) -> tensor<64x16xf16, #shared>
 %41 = triton_gpu.extract_slice %1[%38, 0, 0] [1, 16, 64] [1, 1, 1] {async_agent = dense<1> : vector<1xi32>} : tensor<3x16x64xf16, #shared1> to tensor<16x64xf16, #shared1>
 %42 = triton_gpu.convert_layout %41 {async_agent = dense<1> : vector<1xi32>} : (tensor<16x64xf16, #shared1>) -> tensor<16x64xf16, #shared1>
- %43 = tt.dot %40, %42, %arg12 {allowTF32 = true, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+ %43 = tt.dot %40, %42, %arg12 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32, async_agent = dense<1> : vector<1xi32>} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
 triton_nvidia_gpu.consumer_release %2, %38 {async_agent = dense<1> : vector<1xi32>} : tensor<3x!triton_nvidia_gpu.token>, i32
 %c1_i32_5 = arith.constant {async_agent = dense<1> : vector<1xi32>} 1 : i32
 %44 = arith.addi %arg13, %c1_i32_5 {async_agent = dense<1> : vector<1xi32>} : i32
diff --git a/test/TritonGPU/wspipeline.mlir b/test/TritonGPU/wspipeline.mlir
index c2b0a1b70..a42ca46d0 100644
--- a/test/TritonGPU/wspipeline.mlir
+++ b/test/TritonGPU/wspipeline.mlir
@@ -21,7 +21,8 @@
 // CHECK: triton_nvidia_gpu.consumer_wait
 // CHECK: triton_gpu.extract_slice
 // CHECK: triton_gpu.extract_slice
-// CHECK: tt.dot
+// CHECK: triton_nvidia_gpu.dot_async
+// CHECK: triton_nvidia_gpu.dot_wait
 // CHECK: triton_nvidia_gpu.consumer_release
 // CHECK: scf.yield
 // CHECK: async_agent = dense<1> : vector<1xi32>
@@ -119,7 +120,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
 %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
 %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
- %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
 %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
 scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>
diff --git a/test/TritonNvidiaGPU/ws-feasibility-checking.mlir b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir
index 0eec6889f..981d4748d 100644
--- a/test/TritonNvidiaGPU/ws-feasibility-checking.mlir
+++ b/test/TritonNvidiaGPU/ws-feasibility-checking.mlir
@@ -96,7 +96,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
 %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
 %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
- %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
 %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
 scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>
@@ -226,7 +226,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
 scf.yield %91, %92 : tensor<128x32xf16, #shared>, tensor<32x128xf16, #shared1>
 }
- %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ %93 = tt.dot %96#0, %96#1, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 %94 = tt.addptr %arg11, %cst_1 : tensor<128x32x!tt.ptr, #blocked1>, tensor<128x32xi32, #blocked1>
 %95 = tt.addptr %arg12, %cst_0 : tensor<32x128x!tt.ptr, #blocked>, tensor<32x128xi32, #blocked>
 scf.yield %93, %94, %95 : tensor<128x128xf32, #mma>, tensor<128x32x!tt.ptr, #blocked1>, tensor<32x128x!tt.ptr, #blocked>
@@ -362,7 +362,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 %90 = tt.load %arg12 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #blocked>
 %91 = triton_gpu.convert_layout %89 : (tensor<128x32xf16, #blocked1>) -> tensor<128x32xf16, #shared>
 %92 = triton_gpu.convert_layout %90 : (tensor<32x128xf16, #blocked>) -> tensor<32x128xf16, #shared1>
- %93 = tt.dot %91, %92, %arg10 {allowTF32 = true} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
+ %93 = tt.dot %91, %92, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #shared> * tensor<32x128xf16, #shared1> -> tensor<128x128xf32, #mma>
 %base_94 = scf.if %96 -> (tensor<128x32x!tt.ptr, #blocked1>) {
 %r1_0 = arith.select %96, %c31_i32, %c127_i32 : i32
 %r1_1 = tt.splat %r1_0 : (i32) -> tensor<128x32xi32, #blocked1>
@@ -438,7 +438,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3>
 %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
 %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
- %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+ %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
 %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1>
 %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1>
 scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>
@@ -518,7 +518,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3>
 %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
 %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
- %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+ %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
 %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1>
 %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1>
 scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>
@@ -600,7 +600,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3>
 %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
 %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
- %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+ %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
 %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1>
 %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1>
 scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>
@@ -686,7 +686,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %47 = tt.load %arg12 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked3>
 %48 = triton_gpu.convert_layout %46 : (tensor<64x16xf16, #blocked2>) -> tensor<64x16xf16, #shared>
 %49 = triton_gpu.convert_layout %47 : (tensor<16x64xf16, #blocked3>) -> tensor<16x64xf16, #shared1>
- %50 = tt.dot %48, %49, %arg10 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+ %50 = tt.dot %48, %49, %arg10 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
 %51 = tt.advance %arg11, [%c0_i32, %c16_i32] : , 1>
 %52 = tt.advance %arg12, [%c16_i32, %c0_i32] : , 1>
 scf.yield %50, %51, %52 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>
@@ -799,7 +799,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %44 = tt.load %arg17 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked4>
 %45 = triton_gpu.convert_layout %43 : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
 %46 = triton_gpu.convert_layout %44 : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
- %47 = tt.dot %45, %46, %arg15 {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
+ %47 = tt.dot %45, %46, %arg15 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
 %48 = tt.advance %arg16, [%c0_i32, %c64_i32] : , 1>
 %49 = tt.advance %arg17, [%c64_i32, %c0_i32] : , 1>
 scf.yield %47, %48, %49 : tensor<256x128xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>
@@ -852,7 +852,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %b = tt.load %arg1 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<64x128xf16, #blocked4>
 %shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
 %shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
- %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
+ %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
 %out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2>
 tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2>
 }
@@ -887,7 +887,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %b = tt.load %arg1 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64x128x!tt.ptr, #blocked4> -> tensor<64x128xf16, #blocked4>
 %shm_a = triton_gpu.convert_layout %a : (tensor<256x64xf16, #blocked3>) -> tensor<256x64xf16, #shared>
 %shm_b = triton_gpu.convert_layout %b : (tensor<64x128xf16, #blocked4>) -> tensor<64x128xf16, #shared1>
- %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
+ %d = tt.dot %shm_a, %shm_b, %cst {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<256x64xf16, #shared> * tensor<64x128xf16, #shared1> -> tensor<256x128xf32, #mma>
 %out = triton_gpu.convert_layout %d : (tensor<256x128xf32, #mma>) -> tensor<256x128xf32, #blocked2>
 tt.store %arg2, %out {cache = 1 : i32, evict = 1 : i32} : tensor<256x128xf32, #blocked2>
 }
@@ -1000,7 +1000,7 @@ module attributes {"triton_gpu.compute-capability" = 90 : i32, "triton_gpu.num-c
 %92 = tt.load %arg19 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false} : !tt.ptr, 1> -> tensor<16x64xf16, #blocked4>
 %93 = triton_gpu.convert_layout %91 : (tensor<64x16xf16, #blocked3>) -> tensor<64x16xf16, #shared>
 %94 = triton_gpu.convert_layout %92 : (tensor<16x64xf16, #blocked4>) -> tensor<16x64xf16, #shared1>
- %95 = tt.dot %93, %94, %arg17 {allowTF32 = true} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
+ %95 = tt.dot %93, %94, %arg17 {allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<64x16xf16, #shared> * tensor<16x64xf16, #shared1> -> tensor<64x64xf32, #mma>
 %96 = tt.advance %arg18, [%c0_i32, %c16_i32] : , 1>
 %97 = tt.advance %arg19, [%c16_i32, %c0_i32] : , 1>
 scf.yield %95, %96, %97 : tensor<64x64xf32, #mma>, !tt.ptr, 1>, !tt.ptr, 1>
diff --git a/test/lit.cfg.py b/test/lit.cfg.py
index db65d3e4f..5ea9c458d 100644
--- a/test/lit.cfg.py
+++ b/test/lit.cfg.py
@@ -19,7 +19,7 @@ config.name = 'TRITON'
 config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell)
 # suffixes: A list of file extensions to treat as test files.
-config.suffixes = ['.mlir']
+config.suffixes = ['.mlir', '.ll']
 # test_source_root: The root path where tests are located.
 config.test_source_root = os.path.dirname(__file__)
@@ -62,6 +62,7 @@ for d in tool_dirs:
 llvm_config.with_environment('PATH', d, append_path=True)
 tools = [
 'triton-opt',
+ 'triton-llvm-opt',
 ToolSubst('%PYTHON', config.python_executable, unresolved='ignore'),
 ]
diff --git a/third_party/intel_xpu_backend b/third_party/intel_xpu_backend
index 0bcc485f8..d05dc79da 160000
--- a/third_party/intel_xpu_backend
+++ b/third_party/intel_xpu_backend
@@ -1 +1 @@
-Subproject commit 0bcc485f82b34d49494bd0264bacc24a20aafb7a
+Subproject commit d05dc79dad638b8ebbacfef44886f568b5885fc3
diff --git a/third_party/triton_shared b/third_party/triton_shared
new file mode 160000
index 000000000..d0ac5898f
--- /dev/null
+++ b/third_party/triton_shared
@@ -0,0 +1 @@
+Subproject commit d0ac5898ff97ab33c2839306ec10bfa4fab816f5
diff --git a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp
index 90e6ef8c3..20603cd2e 100644
--- a/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp
+++ b/unittest/Conversion/TritonGPUToLLVM/EmitIndicesTest.cpp
@@ -22,6 +22,7 @@
 */
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "DumpLayout.h"