Commit Graph

1894 Commits

Alexander Efimov
6691de65db [MFMA] Support BFloat16 on MI100 (#295)
* [MFMA] Support BFloat16 on MI100

This PR makes use of the mfma_f32_32x32x4bf16 instruction, available on MI100.

* fix tests, fix mfma encoding comment, fix switch between mfma versions.

* replace kDim from mfma layout with kWidth from dotOp layout

* rebase fix

* fix mfma to dot op shortcut for bfloat16

* fix review comments
2023-09-08 15:08:34 -05:00
SJW
491eb9ddfe [MLIR] Added tritongpu-stream-pipeline pass (#305)
* [MLIR] Added tritongpu-stream-pipeline pass
     - Prologue: Hoist the pipelinable load operations and shared memory store
       for the ramp up stage
     - Pipelined Loop: Assemble the loop body minus last iteration
       - Prefetch next tile from global into regs (while computing from previous)
       - Non-load loop body
       - Store next tile into shared mem
     - Epilogue: Peeled non-load loop body for last iteration

* updated comment
2023-09-07 15:24:59 -05:00
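The prologue / pipelined-loop / epilogue structure described in the commit above can be illustrated with a plain-Python sketch of two-stage software pipelining (the names and the accumulation are illustrative only, not the pass's actual IR):

```python
def pipelined_sum(tiles, load, compute):
    """Two-stage software pipelining sketch: prefetch the next tile
    while computing on the previously loaded one."""
    if not tiles:
        return 0
    # Prologue: ramp up by loading the first tile.
    current = load(tiles[0])
    acc = 0
    # Pipelined loop: loop body minus the last iteration.
    for i in range(1, len(tiles)):
        nxt = load(tiles[i])      # prefetch next tile ("global -> regs")
        acc += compute(current)   # non-load loop body
        current = nxt             # hand the prefetched tile to the next round
    # Epilogue: peeled non-load loop body for the last iteration.
    acc += compute(current)
    return acc

# Usage: identity load, doubling compute.
print(pipelined_sum([1, 2, 3], lambda t: t, lambda t: 2 * t))  # 12
```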
jayfurmanek
83a0958566 Merge pull request #322 from ROCmSoftwarePlatform/f8_and_bf16_conversions
Enable fp8 conversions and fix bf16 conversions
2023-09-07 14:38:16 -05:00
Wen Chen
076a04d5eb [ROCM] Optimized int8 to bf16 conversion by not reusing FpToFpOpConversion::convertFp32ToBf16.
Changed the lit test rules for vectorized int8 to bf16 conversion on
ROCm as ROCm has a different implementation.
2023-09-07 17:26:43 +00:00
Shucai Xiao
fb3f2d6feb refine gemm tuning scripts (#309)
* refine the gemm tuning scripts to reduce the tuning space and get better perf numbers

* added code to support tuning in full tuning space

* add a function to get best tuning config

* refine the matmul tutorial example to print out best tuning config for each input

* added even_k to gemm kernel heuristic for better performance

* address review comments
2023-09-07 08:09:11 -05:00
Wen Chen
ffc230ebfe [ROCM] Fixed implementation of fp32 to bf16 conversion on ROCm. 2023-09-06 18:10:54 -05:00
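For context on the conversion these commits fix: fp32 -> bf16 keeps the top 16 bits of the fp32 pattern, rounding the dropped mantissa bits to nearest-even. A minimal pure-Python sketch (NaN handling omitted; this is not the ROCm LLVM lowering itself):

```python
import struct

def fp32_to_bf16_bits(x: float) -> int:
    """Round-to-nearest-even fp32 -> bf16; returns the bf16 bit pattern
    (the 16 high bits of the rounded fp32 pattern)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Rounding bias: 0x7FFF plus the LSB of the kept half gives ties-to-even.
    bits += 0x7FFF + ((bits >> 16) & 1)
    return (bits >> 16) & 0xFFFF

def bf16_bits_to_fp32(b: int) -> float:
    """Widen a bf16 bit pattern back to fp32 by zero-filling the low bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]

# Values exactly representable in bf16 round-trip unchanged.
assert bf16_bits_to_fp32(fp32_to_bf16_bits(1.0)) == 1.0
```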
Wen Chen
2d3e38e182 [ROCM] Added ROCm support for int8 to bfloat16 conversion. 2023-09-06 18:10:54 -05:00
Wen Chen
59a40d3f72 [ROCM] Added ROCm support for the conversions of following data types:
[float8e4m3, float8e4m3b15, float8e5m2] <-> [float16, bfloat16]
2023-09-06 18:10:54 -05:00
jayfurmanek
00393d0bc0 Merge pull request #321 from binarman/ifu230810_fp16_upstream_fix
Align FP16 FMA dot fix with upstream
2023-09-06 12:51:16 -05:00
jayfurmanek
9f1a5f42df Merge pull request #317 from ROCmSoftwarePlatform/ifu230810-2
IFU 230810 (Rebase)
2023-09-06 09:35:06 -05:00
Aleksandr Efimov
751edfb3b9 [BACKEND] Fix fma mixed-precision
This is a partial cherry-pick of https://github.com/openai/triton/pull/2184

Dropped code unrelated to the dot fix.
2023-09-05 21:16:50 +00:00
Aleksandr Efimov
591681d36e Revert "[Dot] Fix FMA fp16xfp16 dot (#315)"
This reverts commit 11752a6993.
2023-09-05 21:12:56 +00:00
Alexander Efimov
11752a6993 [Dot] Fix FMA fp16xfp16 dot (#315)
Disable reorder of FMA dot arguments for AMD GPUs.
2023-09-05 20:52:08 +00:00
Keren Zhou
c0f418bcdd [BACKEND] Fix BF16 dot operand type mismatch (#2162)
https://github.com/openai/triton/issues/2156
2023-09-05 20:46:31 +00:00
Aleksandr Efimov
2f7ead6f3b Fix subprocess tests for IFU
This PR changes the printf ttgir -> llvm conversion:
an unknown location is assigned to the global constant holding the format string.
This fixes a problem in the test_subprocess.py tests,
which failed during construction of the file location for format-string constants.
2023-09-05 20:46:04 +00:00
Corbin Robeck
007bea9994 Add bitcode writer to AMDGCN hsaco output 2023-09-01 04:02:29 +00:00
Jason Furmanek
320b1029da Temporarily disable F8 tests on ROCm 2023-09-01 04:02:14 +00:00
Jason Furmanek
7954cc8e16 Fix reorder-instructions on ROCm 2023-09-01 04:02:06 +00:00
Jason Furmanek
df5c263a19 Fix merge conflicts 2023-09-01 04:01:32 +00:00
Jason Furmanek
3eaeb89d18 Merge commit '5df904233c11a65bd131ead7268f84cca7804275' into ifu230810-2
Conflicts:
	include/triton/Dialect/Triton/Transforms/Passes.h
	include/triton/Dialect/TritonGPU/IR/Dialect.h
	include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td
	lib/Analysis/Allocation.cpp
	lib/Analysis/Utility.cpp
	lib/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVM.cpp
	lib/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.cpp
	lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp
	lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
	lib/Dialect/TritonGPU/Transforms/ReorderInstructions.cpp
	lib/Target/LLVMIR/LLVMIRTranslation.cpp
	python/src/triton.cc
	python/triton/compiler/compiler.py
	python/triton/ops/flash_attention.py
	python/triton/runtime/autotuner.py
	python/triton/runtime/jit.py
	python/triton/tools/aot.py
	python/tutorials/06-fused-attention.py
	test/Conversion/tritongpu_to_llvm.mlir
	test/Target/tritongpu_to_llvmir.mlir
	test/Target/tritongpu_to_llvmir_noinline.mlir
2023-09-01 03:25:33 +00:00
Vinayak Gokhale
9cdf3a58c3 Enable split kernel in bwd pass (#303)
* Add fwd and bwd v2

Changes are largely from upstream.

* Split bwd kernel in dq and dk+dv

Only adds the split kernels. They are not enabled yet.

* Pull scalar multiplies out of the loop

* Enable split kernel for bwd pass

* Put back P_SEQ=128 in fwd test

Not used for bwd test

* Address review comments

* Address comments

Conditionally set causal/split kernel to False for bwd.

* Add block pointer semantics to bwd pass

This significantly increases perf for bwd, similar to fwd.
2023-08-29 13:51:29 -05:00
Lixun Zhang
b834f42ae4 [autotuner] Add an option to print best_config for each key 2023-08-28 14:45:54 -05:00
jayfurmanek
ff7e707f87 Enable usage of block pointer semantics for AMD gpus (#301)
* Enable usage of block pointer semantics for AMD gpus

This commit enables usage of block pointer semantics by enabling
rewrite_tensor_pointer_pass that rewrites block pointer loads/stores
to legacy loads/stores.

* Update FA fwd in tutorial to use the block pointers

* use 90 compute capability for amd gpus in python/triton/compiler/compiler.py

Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>

---------

Co-authored-by: Ognjen Plavsic <ognjen.plavsic@dxc.com>
Co-authored-by: Lixun Zhang <lixun.zhang@amd.com>
Co-authored-by: Aleksandr Efimov <130555951+alefimov-amd@users.noreply.github.com>
Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>
2023-08-24 13:05:12 -05:00
jayfurmanek
fa429316d4 Merge pull request #268 from ROCmSoftwarePlatform/improve_reduce_for_fa
[CHERRY-PICKED FROM UPSTREAM][BACKEND] no longer uses shared mem or barriers for single-warp reductions (openai#1915)
2023-08-21 13:29:11 -05:00
Alexander Efimov
d86b19f7a3 [CI] [Dot] Reduced test suite (#302)
Use the upstream list of tests for the dot op on machines with no MFMA support.
This is needed to reduce the time required for PR testing.
2023-08-18 07:47:14 -05:00
Alexander Efimov
01b0108c94 [MFMA] [FA] Keep bf16 results of FA dot operations in registers (#298)
This PR enables optimization for keeping bf16 values in registers between dot operations.
2023-08-18 07:33:00 -05:00
Alexander Efimov
9ab335196f [MFMA] More optimal offset computation (#286)
This PR replaces expensive operations with simpler ones:
mul and div are replaced with select and compare.

This is a minor change; it decreases the number of required registers
by one when dot operand loading is a bottleneck.
2023-08-18 07:32:38 -05:00
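The strength reduction the commit above describes — a div/mul offset replaced with a compare and a select — can be sketched with a hypothetical lane-offset computation (the 64-lane wavefront split and the stride are illustrative, not the actual MFMA code):

```python
def offset_div_mul(lane: int, row_stride: int) -> int:
    # Expensive form: integer divide then multiply.
    return (lane // 32) * row_stride

def offset_select(lane: int, row_stride: int) -> int:
    # Cheaper form when lane < 64: one compare plus a select.
    return row_stride if lane >= 32 else 0

# Both forms agree for every lane of a 64-lane wavefront.
assert all(offset_div_mul(l, 128) == offset_select(l, 128) for l in range(64))
```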
Alexander Efimov
23979098c8 [MFMA] MI200 bfloat16 support (#294)
This PR enables bfloat16 support in the MFMA dot on MI200,
using the mfma_f32_32x32x8bf16_1k instruction.
2023-08-18 07:28:18 -05:00
Shucai Xiao
f7cf2c032b Changes of the tutorial matmul scripts to get good performance (#297)
* simple changes to the matmul scripts to get good performance. The specific reason for the performance boost needs further investigation and is tracked

* fix review comments

* change the num_warps in the autotuning config for hip to work around an error and change the rtol so the correctness check passes
2023-08-17 13:24:49 -05:00
Lixun Zhang
e924e0ecf6 [NFC] Rename MMAv2kWidth to kWidth (#299) 2023-08-17 07:45:22 -05:00
Lixun Zhang
87e45cb011 Set vecSize and maxPhase more generically 2023-08-16 08:30:32 -05:00
Lixun Zhang
7156fcb0ef Set vecSize = 4 and maxPhase = BLOCK_K/4 2023-08-16 08:30:32 -05:00
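For context on the vecSize/maxPhase commits above: Triton's shared-memory swizzle XORs vec-sized groups of columns with a per-row phase to avoid bank conflicts. A sketch under the commit's parameters (vec = 4, maxPhase = BLOCK_K/4 with BLOCK_K assumed 32, perPhase assumed 1 — all illustrative defaults, not the layout code itself):

```python
def swizzled_col(row: int, col: int, vec: int = 4,
                 max_phase: int = 8, per_phase: int = 1) -> int:
    """XOR-swizzle a column index in vec-sized groups so that
    different rows hit different banks."""
    phase = (row // per_phase) % max_phase
    group = col // vec                      # which vec-wide group col is in
    return (group ^ phase) * vec + (col % vec)

# Each row's swizzle is a permutation of the columns.
assert sorted(swizzled_col(1, c) for c in range(32)) == list(range(32))
```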
Ognjen Plavsic
d0b7793935 Add hasSideEffects flag to intra-warp shuffle inline assembly 2023-08-15 13:21:44 +00:00
Philippe Tillet
4215086931 [BACKEND] no longer uses shared mem or barriers for single-warp reductions (#1915)
0-byte shared mem buffers don't materialize empty allocation buffers;
this could lead to unnecessary barriers.

note: reduceop code has become quite messy and will require some cleanup
2023-08-15 11:51:20 +00:00
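The single-warp reduction the upstream commit relies on can be emulated in plain Python: a butterfly (xor-shuffle) exchange reduces across lanes without touching shared memory (a simulation of the `__shfl_xor`-style pattern, not GPU code):

```python
def warp_reduce_sum(vals):
    """Simulate a butterfly xor-shuffle sum over a 32-lane warp."""
    assert len(vals) == 32
    lanes = list(vals)
    offset = 16
    while offset >= 1:
        # Each lane adds the value currently held by lane ^ offset.
        lanes = [lanes[i] + lanes[i ^ offset] for i in range(32)]
        offset //= 2
    return lanes[0]  # every lane now holds the full sum

assert warp_reduce_sum(list(range(32))) == sum(range(32))  # 496
```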
oplavsic
398d2c7dd0 Fix FA v2 hanging issue when BLOCK_N=32 (#274)
* Fix FA v2 hanging issue when BLOCK_N=32

* Fix broken tests
2023-08-10 12:52:09 -05:00
Alexander Efimov
a1f4ee6a6b [MFMA] Swizzled operands (#285)
This PR enables generation of swizzled tensors for mfma dot operands.
2023-08-09 14:53:15 -05:00
Alexander Efimov
af05f01218 [Tests] Fix some tests in test_core_amd.py (#288)
This PR:
- enables test_dot_mfma_vector_load for the fast path in the mfma dot op pipeline
- fixes kernel execution for mfma-enabled GPUs
- disables mfma layout conversion tests on architectures that cannot run these tests
2023-08-08 20:12:32 +02:00
Alexander Efimov
ff95dddf18 [tutorial] Amd specific tuning configs (#287)
This PR adds AMD-specific tuning configs for the matmul tutorial with num_stages == 1.
2023-08-08 20:11:30 +02:00
jayfurmanek
32d7c6d646 Fix runtime/test_subproc.py for hip devices (#284)
* Fix runtime/test_subproc.py for hip devices

* address review comments
2023-08-07 10:30:36 -05:00
oplavsic
138844568d Use optimal number of VGPRs (#281)
* Use optimal number of VGPRs

* Fix tritongpu_to_hsaco test
2023-08-04 10:46:53 -05:00
Alexander Efimov
e1de24cd5c [MFMA][Dot] Vector loads for fast path (#278)
This PR adds generation of vector loads in
fast path of `shared`->`mfma dot op` layout conversion

Co-authored-by: jayfurmanek <Jason.Furmanek@amd.com>
2023-08-03 16:45:25 -05:00
Alexander Efimov
7158ec286a [MFMA] [Dot] Support vector loads in normal path (#275)
* [MFMA] [Dot] Support vector loads in normal path

This PR adds generation of vector loads in normal path of
MFMA dot operand loading.
This requires shared layout to have contiguous elements
which should be loaded by one lane.

* remove redundant refactoring

* fix tests

* extend test with transposed A/B tensors
2023-08-03 14:57:39 -05:00
Alexander Efimov
86f8b64ae0 [Dot] [MFMA] [FMA] Update Dot implementation to support upstream tests (#260)
* [Dot] [MFMA] Support FP16 output of MFMA dot

This PR adds cast of output tensor to requested data type.

* add tests

* fix test for FMA implementation

* loosen fp16xfp16->fp16 tolerance

* enable FMA fallback for unsupported sizes of dot operation

* rework granularity check

* add constant modifier to granularity
2023-08-03 13:47:18 -05:00
Vinayak Gokhale
f1063bb33c Enable backward pass in FA tutorial test (#282)
Enabled the backward pass in the fused attention tutorial.
The tolerance when comparing to the naive implementation
had to be changed. The block size is forced to be 64x64
due to the 64 KiB LDS; the default is block 128 for A100's
larger SMEM. This creates differences in the order of computation
and results in a larger gap between the naive and FA
implementations.
2023-08-03 10:12:46 -05:00
Shucai Xiao
31cfda8f0e enable more gemm tests corresponding to PR#273 (#279) 2023-08-02 16:45:31 -05:00
Alexander Efimov
a318bc5777 Disable pipelining (#276)
This PR sets the default value of pipeline stages to 1 for AMD
and adds explicit num_stages in the test_gemm test.
2023-08-01 14:54:17 -05:00
Shucai Xiao
82db24cde7 add gfx940/941 to archs that support mfma (#277) 2023-08-01 10:57:16 -05:00
Philippe Tillet
5df904233c [BACKEND] now clamping fp8e4b15 in conversion from fp16 (#2017) 2023-07-31 18:50:42 -07:00
Yan Chunwei
89b0b79d75 [FRONTEND] fix the silent return issue in AOT launcher (#2013)
In the current link.py, it produces the launcher code as below:

```python
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_am, int32_t stride_bk){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm % 16 == 0))
    return matmul_fp16xfp16_16x16x16_688cc413_0d1d2d3d45d(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
  // ...
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0))
    return matmul_fp16xfp16_16x16x16_7c0255bf_0d1d2d345(stream, gX, gY, gZ, C, A, B, stride_cm, stride_am, stride_bk);
}
```
Note that when the input does not match any of the if branches, it
does nothing, and the compiler makes it return 0 as the default
behavior, which equals `CUDA_SUCCESS`; this doesn't match the
expectation.

This PR adds a `return CUDA_ERROR_INVALID_VALUE;` to the tail of the launchers, and
it produces code like:

```c++
CUresult matmul_fp16xfp16_16x16x16(CUstream stream, unsigned int gX, unsigned int gY, unsigned int gZ, CUdeviceptr C, CUdeviceptr A, CUdeviceptr B, int32_t stride_cm, int32_t stride_cn, int32_t stride_am, int32_t stride_ak, int32_t stride_bk, int32_t stride_bn){
  if ((C % 16 == 0) && (A % 16 == 0) && (B % 16 == 0) && (stride_cm == 1) && (stride_cn == 1) && (stride_am == 1) && (stride_ak == 1) && (stride_bk % 16 == 0) && (stride_bn == 1))
    return matmul_fp16xfp16_16x16x16_1f18a6da_0d1d2d3c4c5c6c7d8c(stream, gX, gY, gZ, C, A, B, stride_bk);

  return CUDA_ERROR_INVALID_VALUE;
}
```

And it requires users to check the result in their application, which I
think should match the initial AOT ideas.
2023-07-31 09:59:28 -07:00
Lixun Zhang
0cc3f8b84a Favor dim 1 (N) when setting warpsPerCTA for non-FA mfma (#273)
* Set a more squared warpsPerCTA for non-FA mfma

Fix https://github.com/ROCmSoftwarePlatform/frameworks-internal/issues/4790

* Set warpsPerCTA along the N dim first

* remove the swap workaround since it harms gemm tutorial perf
2023-07-28 16:02:48 -05:00