This is a very quick change that allows the configs' pre-hooks to see
the values in the config itself. This is useful if we'd like to allocate
intermediate tensor and the shape depends on tile size.
This is a revival of @gaxler initial ahead-of-time compiler proposal.
Code was simplified and some constraints were relaxed (i.e., we now
execute the entire file provided vs just the kernel AST) to promote
maintainability. A basic unit test was added, though it does not test
specialization right now.
co-authored by: Gregory Axler, thexler <g.axler@gmail.com>
Example:
```
if static_a == 0 and static_b == 1:
...
```
The return value of `static_a == 0` should be `constexpr(True)` but not
`True`, otherwise the bool object (True/False) doesn't have the
`logical_and` method.
* [MFMA] Activated Fused Attention Forward Pass
Patch contains following changes:
1) make_range operator now works with MFMA layout.
2) Reduce operation is forced to run in block layout:
inputs converted to block layouts, outputs returned to MFMA layout
* Use simple module walk instead of pattern rewritter.
* Remove pattern rewritter header.
* Enable basic reduce algorithm for MFMA layout
* Add TODO comment for fused attention backward pass
* Fix bug in fast codegen algorithm for reduce op
* Fix input type bug
* Increase block size to 128 since out of memory issue is not seen on MI210
* Fix block_size error
* Add mfma support in DecomposeDotOperand pattern.
`tl.cat(tensor<64>, tensor<64>) -> tensor(128)`, because it concatenates
elements into a single thread, if number of threads is 128, each thread
should own at least 2 elements.
With this PR, we also disable remat of the cat op in some cases.
Add a configurable parameter for the number of threads per warp for
other GPU. Like: Intel GPU.
Make it default to be 32 not change code logic on the CUDA/AMD GPU.
Note: The Intel GPU GenX ISA is explicit SIMD and can support variant
number of threads lane per HW execution unit.
- Applied to `loadA()` the same fix as 2c88ed6aab9ace22ccde1f0e443a1579727ee501.
- Minor cleanup of `mfmaLayout.getWarpsPerCTA()` usage.
Partially fixesROCmSoftwarePlatform/frameworks-internal#4545
Support the following three cases:
1. Operands of `load` depend on induction variables before `load`s.
2. Mixed use of induction variables and offset to update the `ptr`.
3. Cross iteration (>1) dependency values.
Also catch out-of-bounds indices at constructio and throw a proper error
in the frontend.
Finally, let's make the IR a bit prettier:
%0 = tt.get_program_id {axis = 0 : i32} : i32
becomes:
%0 = tt.get_program_id x : i32
Fixes#1718
When comparing to the expected slides, using the `==` operator will
dispatch to the component of the slice. If the user writes `a[10:20]`
these are `triton.constexpr` instances, and the `__eq__` operator which
is implemented as: `return constexpr(self.value == other.value)`. At
this point the access to `.value` on the provided `None` yields an
exception that isn't very friendly to the user.
I am not sure if the implementation of `constexpr` should be hardened
instead?
Co-authored-by: Philippe Tillet <phil@openai.com>
* [MFMA] Implementation of MFMA DotOp pipeline
* Added MFMA test_dot unit tests
* Added missing ifdefs
* Update offline tests
* Removing duplicate parts
* fix build after rebase
* remove redundant stuff
* simplify MMAv3.cpp
* move reps function into operand attr description,
remove coreMatrixType type from layout conversion,
refactored type conversion
* remove duplication of mfma intruction shape computation
* move all MFMA instruction shape details into layout attribute
* fix formatting
* reenable matmul acceleration
* fix dot operator type conversion
* add offline test for dotop
* add missing ifdef wrappers
* run clang format on changes
* review and rebase fix
* add switch for MFMA instructions
* change check precision for float16 test
* disable redundant check for allowTF32
* - skip unsupported block size in matmul autotuner
- support transposed inputs of dot
* reenable matmul acceleration
* Add first part to FMA for dot operation on HW without MFMA support.
* Fix offline tests.
* Fix lit tests
* refactor mmav3 to mfma
* fix rebase issues
* fix detection of mfma support and wrong assert
* remove unnecessary macros
* Add documentation for MFMA layout.
* fix line size computation for B argument
* Fix getElemsPerThread() and getSizePerThread() functions for MFMA layout.
---------
Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>
Co-authored-by: dfukalov <1671137+dfukalov@users.noreply.github.com>
Co-authored-by: weihan13 <weihan13@amd.com>
Co-authored-by: Ognjen Plavsic <ognjen.plavsic@dxc.com>
Right now, `do_bench` estimates the runtime of the kernel and then uses
that to run it a number of times approximately equal to 100ms (by
default).
However, when actually running the kernel, it also issues a `zero_()`
call to clear the L2 cache. For small kernels, the `zero_()` kernel can
be slower than the actual kernel we're benchmarking, causing us to badly
overshoot our target latency.
This has the perverse effect that very small invocations may take much
longer to autotune than larger ones. By way of concrete example, before
this PR, I tested the wall-clock time for the first call to
`triton.ops.matmul(A, B.T)` in a process, on two `(N, N)` matrices in
float32. I found that a 4k x 4k x 4k matmul warmed up in about 2.5s, but
a 64 x 64 x 64 matmul took over 5 seconds!
This PR fixes this issue by including the same call to `zero_()` inside
our measurement loop.
With this change, I find that the 4kx4kx4k and 64x64x64 matmuls warm up
in very similar amounts of time, both around 2.5s.
I noticed this because we tend to run tests on very small models in CI
just to test code paths without regard to numerics, and found those
tests were perversely taking longer than "real" models in some cases. It
seems plausible that a better solution would be a pragma to disable
autotuning entirely for such tests, but I think this change is a clear
improvement as-is.
- If `TRITON_DEBUG=True`, all triton functions will be compiled in the
debug mode.
- Otherwise, a triton function `f`'s debug flag is either `True`,
`False` or `None` (default).
- If `True`, `f` is compiled in the debug mode.
- If `False`, `f` is compiled in the normal mode.
- If `None`, `f` is compiled based on its caller's debug flag. The root
(kernel) function's debug flag can also be set through the `compile`
function.
cc @ngimel , @Chillee