MMA V3 supports taking operand A from a register. This helps with chained
matmul operations, such as in attention.
Add an optimization to use this mode when it helps, and add the lowering
for it.
Hi,
I'm adding some features to
`triton.runtime.jit.JITFunction_make_launcher` and found it hard to
debug:
1. The inlined Python code is hard to inspect in my editor.
2. My debugger fails to step into this inlined code.
In response, I've introduced some code to solve these issues. My
modifications include:
~~1. Refactoring the launcher's inline Python code, ensuring it only
relies on the "self" object.~~
~~2. Add a utility method that generates a temporary file to create a
launcher when debugging kernel in main module~~
Instead, the launcher's body is now held in a closure.
Because these features might be useful to others, I have opened this
pull request.
~~Tests are yet to be added; if this submission might be accepted, I
will add it later.~~
Since this change is a refactor, no new test was added.
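A minimal sketch of the temporary-file idea, with hypothetical names (this is not the actual `JITFunction` internals): writing the generated source to a real file and compiling it with that filename is what lets editors open it and debuggers step into it.

```python
import tempfile
import textwrap

def make_debuggable_launcher(src, entry_name):
    # Write the generated source to a real temporary file instead of
    # exec'ing a bare string, so the code has a backing file on disk.
    f = tempfile.NamedTemporaryFile(
        mode="w", suffix="_launcher.py", delete=False)
    f.write(textwrap.dedent(src))
    f.close()
    # Compiling with the real filename makes tracebacks, breakpoints,
    # and debugger stepping point at the file we just wrote.
    with open(f.name) as fh:
        code = compile(fh.read(), f.name, "exec")
    scope = {}
    exec(code, scope)
    return scope[entry_name]

# Stand-in for a generated launcher body.
launcher = make_debuggable_launcher("""
    def launch(grid, *args):
        return (grid, args)
""", "launch")
```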
This PR leverages the plug-in system to add a shared middle-layer to
Triton.
Currently the middle layer is not complete but has enough functionality
to demonstrate how it can work. The general idea is that Triton IR is
lowered into an MLIR core dialect to allow it to be both shared across
Triton targets as well as allow back-ends to be shared with other
languages.
The basic intended architecture looks like this:
[Triton IR] -> [Middle Layer] -> [HW specific IR]
The middle layer uses MLIR's Linalg and Tensor dialects for operations on
Triton block values. Operations on Triton pointers use the Memref
dialect.
## Usage
To include the shared middle layer in your Triton build, run `export
TRITON_CODEGEN_TRITON_SHARED=1` before invoking the build. Once it is
part of the build, it can be leveraged in two ways:
### Stand-Alone
The middle layer can be used as a stand-alone component to convert
Triton dialect to the middle layer dialects.
Stand-alone example:
```
triton-shared-opt --triton-to-linalg %file
```
### Backend Component
The middle layer can also be used as a component in a Triton back-end by
adding the cmake targets it produces and its header files to that
back-end. An example back-end will be published at a later date.
## Implementation details
Even though a valid Triton program can perform loads and stores at
arbitrary memory locations, the prototype only supports lowering
programs that have structured memory access patterns.
### Analyses
As part of the conversion process, there are three important analyses:
1. Pointer analysis:
+ This analysis is responsible for extracting structured memory access
patterns from a `triton` program's loads and stores; it walks the IR
and visits relevant instructions to build strided memory accesses in the
`memref` dialect. The analysis is still at an early stage and does not
support all scenarios.
2. Use analysis:
+ After "Pointer analysis", instructions that are part of memory address
calculation will no longer be necessary in a triton program because
their semantics have now been captured by `memref` operations
representing strided memory accesses. To aid with removing these
instructions safely, we perform `Use analysis` to mark which
instructions are used *only* in address calculation (called `MetaUse`)
or used in *both* address calculation and data manipulation (called
`MixedUse`) operations. Those that are `MixedUse` are cloned and have
their users adjusted accordingly with the goal of separating out the
`MetaUse` ops so that they can be safely deleted.
3. Mask analysis:
+ This analysis is responsible for handling masked loads and stores.
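To make the pointer analysis concrete, here is a pure-Python sketch (hypothetical structure, not the actual pass) of how walking a `splat` + `make_range` + `addptr` chain can recover the `(offset, size, stride)` triple that later becomes a `memref.reinterpret_cast`:

```python
# Each expression is a tuple: ("make_range", start, end),
# ("splat", scalar_offset), or ("addptr", lhs, rhs).
def analyze(expr):
    kind = expr[0]
    if kind == "make_range":
        # tt.make_range contributes offset=start, size=end-start, stride=1.
        _, start, end = expr
        return (start, end - start, 1)
    if kind == "splat":
        # A broadcast scalar offset: no size of its own, stride 0.
        return (expr[1], None, 0)
    if kind == "addptr":
        # Offsets and strides add; the size comes from whichever
        # side carries a ranged dimension.
        o1, s1, st1 = analyze(expr[1])
        o2, s2, st2 = analyze(expr[2])
        return (o1 + o2, s1 if s1 is not None else s2, st1 + st2)
    raise NotImplementedError(kind)

# Mirrors: %1 = tt.splat %base ; %2 = tt.addptr %1, tt.make_range(0, 128)
offset, size, stride = analyze(
    ("addptr", ("splat", 0), ("make_range", 0, 128)))
```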
### Conversion strategy
We introduce the `TritonToLinalg` pass that converts the `triton`
dialect to the `linalg` dialect on *tensors*. This means the resulting
IR is fully compatible with `linalg` tiling and fusion transformation
passes. As mentioned in the `Pointer analysis` description, we do,
however, have to deal with `memref` instructions at the load and store
boundaries and convert them to tensors using
`bufferization.to_tensor`. Here's a simple example of what the IR looks
like:
```mlir
tt.func @kernel(%afloat : !tt.ptr<bf16>, %res : !tt.ptr<bf16>) {
%0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32>
%1 = tt.splat %afloat : (!tt.ptr<bf16>) -> tensor<128x!tt.ptr<bf16>>
%2 = tt.addptr %1, %0 : tensor<128x!tt.ptr<bf16>>, tensor<128xi32>
%afm = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128xbf16>
%3 = "tt.reduce"(%afm) ({
^bb0(%arg5: bf16, %arg6: bf16):
%21 = arith.addf %arg5, %arg6 : bf16
tt.reduce.return %21 : bf16
}) {axis = 0 : i32} : (tensor<128xbf16>) -> bf16
tt.store %res, %3 : bf16
tt.return
}
```
after conversion:
```mlir
func.func @kernel(%arg0: memref<*xbf16>, %arg1: memref<*xbf16>, %arg2: i32, %arg3: i32, %arg4: i32) {
%cst = arith.constant 0.000000e+00 : f32
%reinterpret_cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128], strides: [1] :
memref<*xbf16> to memref<128xbf16, strided<[1]>>
%alloc = memref.alloc() : memref<128xbf16>
memref.copy %reinterpret_cast, %alloc : memref<128xbf16, strided<[1]>> to memref<128xbf16>
%0 = bufferization.to_tensor %alloc restrict writable : memref<128xbf16>
%1 = bufferization.alloc_tensor() : tensor<f32>
%inserted = tensor.insert %cst into %1[] : tensor<f32>
%reduced = linalg.reduce ins(%0 : tensor<128xbf16>) outs(%inserted : tensor<f32>) dimensions = [0]
(%in: bf16, %init: f32) {
%3 = arith.extf %in : bf16 to f32
%4 = arith.addf %3, %init : f32
linalg.yield %4 : f32
}
%extracted = tensor.extract %reduced[] : tensor<f32>
%2 = arith.truncf %extracted : f32 to bf16
%reinterpret_cast_0 = memref.reinterpret_cast %arg1 to offset: [0], sizes: [1], strides: [1] :
memref<*xbf16> to memref<1xbf16, strided<[1]>>
affine.store %2, %reinterpret_cast_0[0] : memref<1xbf16, strided<[1]>>
return
}
```
Important details to note:
+ `tt.load` (together with all of its related address calculation
instructions such as `tt.addptr` and `tt.splat`) is lowered to a
combination of `memref.reinterpret_cast`, `memref.alloc`, and
`memref.copy`. After the local buffer is initialized, we convert
the memref back to a tensor using `bufferization.to_tensor`; this op is
automatically removed during bufferization.
+ `tt.store` lowers to a combination of `memref.reinterpret_cast` and
either `affine.store` or `memref.tensor_store`:
```mlir
%reinterpret_cast = memref.reinterpret_cast %arg2 to offset: [...] memref<*xf32> to memref<1024xf32>
%extracted_slice = tensor.extract_slice %15[0] [%21] [1] : tensor<1024xf32> to tensor<?xf32>
%subview = memref.subview %reinterpret_cast[0] [%21] [1] : memref<1024xf32> to memref<?xf32>
memref.tensor_store %extracted_slice, %subview : memref<?xf32>
```
+ element-wise `arith` and `math` operators are converted to their
corresponding `linalg.generic` versions.
+ `tt.dot` becomes `linalg.matmul`.
+ `tt.reduce` becomes `linalg.reduce`; known limitation: only `addf`
and `maxf` are supported in the reduction body for now.
### Testing
The prototype was tested on the following Triton kernel examples:
1. vector addition
2. fused softmax
3. matrix multiplication
4. layer normalization
5. fused attention
In addition to testing on the tutorial kernels, I have also added many
lit tests covering various scenarios.
## Recognition
The work here represents contributions from myself as well as many of my
colleagues at Microsoft. I especially want to call out @nhat-nguyen and
@haishanzzz who were major contributors to this work.
When there is a chain of MMA ops, we want to pick the same shape to avoid
conversions. This improves the detection when going through `for` loops.
This fixes a crash in the tutorial backward attention kernel.
We might want to change this logic and convert the format to allow more
efficient MMA at some point.
Replace a single
`mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32` instruction,
which is used on Ampere, with 4 x
`mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32` instructions
for Turing.
Extracted the Turing-int8, Turing-fp16, and Ampere paths into separate
functions.
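The decomposition can be checked numerically: a 16x8x32 tile splits into four 8x8x16 tiles by halving M and K, with the two K halves accumulated into the same output quadrant. A pure-Python sketch (illustrative only, not the PTX lowering):

```python
import random

def matmul(A, B):
    # Naive int matmul over nested lists (stands in for one MMA step).
    M, K, N = len(A), len(B), len(B[0])
    C = [[0] * N for _ in range(M)]
    for i in range(M):
        for k in range(K):
            for j in range(N):
                C[i][j] += A[i][k] * B[k][j]
    return C

random.seed(0)
A = [[random.randint(-128, 127) for _ in range(32)] for _ in range(16)]
B = [[random.randint(-128, 127) for _ in range(8)] for _ in range(32)]

ref = matmul(A, B)  # one m16n8k32 step

# Four m8n8k16 steps: split M in two, split K in two and accumulate.
C = [[0] * 8 for _ in range(16)]
for mi in (0, 8):
    for ki in (0, 16):
        sub = matmul([row[ki:ki + 16] for row in A[mi:mi + 8]],
                     B[ki:ki + 16])
        for i in range(8):
            for j in range(8):
                C[mi + i][j] += sub[i][j]
```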
Somehow I messed up my previous PR, so I'm opening a new one.
---------
Co-authored-by: Philippe Tillet <phil@openai.com>
This fixes a few bugs I've encountered:
- `atomic_add` with int64/uint64 fails with `Operation .add requires .u32 or .s32
or .u64 [...] for instruction 'atom'`
- `atomic_min/max` with float64 fails with `ValueError('Cannot bitcast data-type
of size 64 to data-type of size 32')`
- `atomic_min/max` with float32 returns the old value as int32
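For background on why float min/max atomics go through integer bitcasts at all: IEEE-754 bit patterns can be mapped to integers in an order-preserving way, so integer atomics can implement float comparisons. A small sketch of one such mapping (an assumption for illustration, not necessarily the exact scheme used in this fix):

```python
import struct

def float_key(f):
    # Reinterpret the float32 bit pattern as an unsigned int.
    (u,) = struct.unpack("<I", struct.pack("<f", f))
    # Negative floats have their magnitude ordering reversed, so flip
    # all bits; for non-negative floats, just set the sign bit so they
    # sort above every negative value.
    return (~u & 0xFFFFFFFF) if u & 0x80000000 else (u | 0x80000000)

vals = [-3.5, -0.0, 0.0, 1.25, 2.0, -100.0]
```

With this key, integer comparisons order the values exactly as float comparisons would.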
Google uses clang-format at LLVM HEAD. clang-format's formatting is not
stable, so we want to keep the pre-commit clang-format close to HEAD to
minimize differences with Google's formatter.
In practice, it appears that there are no relevant changes to the
formatting, so this is a nop. 🤷
Tested by running `pre-commit run --all-files`.
llvm/llvm-project#66754 extends the `LoopLikeOpInterface`: the signature
of `getLoopBody` has changed. `ForOp::getRegion` can be used instead.
This change works both with and without llvm/llvm-project#66754.
This was regressed by #2185 because we didn't realise the CUDA_CHECK
macro could make Python calls (similar to what led to #2225). I think
the `PyErr_Occurred` check got removed in that PR because there was
missing error handling before the call to `_launch`, so it looked like
it was just in the wrong place.
It looks like there are also potentially a couple of places in cuda.c
that can return with an error set, e.g. getDeviceProperties, memAlloc,
memcpyHtoD, memFree, tensorMapEncodeTiled, etc., but those are all
pre-existing and not affected by recent changes.
Improve patterns that sink broadcasts to reduce arithmetic density,
and also hoist convert above expand_dims to do less work.
This addresses comments in https://github.com/openai/triton/pull/2274:
1. On the axis, use `getAxisNumWarpsWithUniqueData` instead of the raw
number of warps, to avoid communication among warps that handle the
same piece of data.
2. When there's a single warp on the axis, use warp intrinsics for
communication and skip shared memory.
A follow-up PR is needed for code cleanup.
On my machine, when I try to `pip install cmake` outside a virtualenv,
pip complains and tells me to use apt, which doesn't quite work for
some reason. Maybe this is obvious to Python people, but it seems worth
mentioning, especially because we already have `.venv` in .gitignore.
Change the dot op to allow taking an initial accumulator, and add a
flag that allows the compiler to accumulate in a lower precision than
the output type.
On Hopper this flag is on by default, which allows accumulating with
lower precision.
This only affects Hopper fp8 dot.
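The numerical effect of the flag can be illustrated in pure Python (an analogy only, using float16 via `struct` rather than the actual Hopper fp8 codepath): rounding the accumulator to a narrow type after every add loses precision compared to accumulating wide and rounding once.

```python
import struct

def to_f16(x):
    # Round a Python float to the nearest IEEE-754 half-precision value.
    return struct.unpack("<e", struct.pack("<e", x))[0]

xs = [1e-3] * 4096

# Low-precision accumulation: round the running sum after every add.
acc_low = 0.0
for x in xs:
    acc_low = to_f16(acc_low + x)

# High-precision accumulation: sum in double, round once at the end.
acc_high = to_f16(sum(xs))
```

The low-precision accumulator stalls once the addend falls below half the spacing between representable values, so it ends up noticeably smaller than the wide-accumulate result.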