From f12e2a75db1224793d8678552d1982713378a23b Mon Sep 17 00:00:00 2001 From: wozeparrot Date: Fri, 10 Oct 2025 00:32:33 -0700 Subject: [PATCH] feat: add thunderkittens (#12590) --- .../thunder/cuda/include/common/base_ops.cuh | 400 ++++++ .../cuda/include/common/base_types.cuh | 519 +++++++ extra/thunder/cuda/include/common/common.cuh | 11 + extra/thunder/cuda/include/common/debug.cuh | 56 + extra/thunder/cuda/include/common/util.cuh | 314 ++++ extra/thunder/cuda/include/kittens.cuh | 12 + .../cuda/include/ops/device/device.cuh | 51 + .../thunder/cuda/include/ops/group/group.cuh | 96 ++ .../cuda/include/ops/group/memory/memory.cuh | 21 + .../complex/complex_global_to_register.cuh | 42 + .../tile/complex/complex_global_to_shared.cuh | 37 + .../complex/complex_shared_to_register.cuh | 34 + .../group/memory/tile/global_to_register.cuh | 207 +++ .../group/memory/tile/global_to_shared.cuh | 168 +++ .../group/memory/tile/shared_to_register.cuh | 323 +++++ .../group/memory/tile/tensor_to_register.cuh | 325 +++++ .../include/ops/group/memory/tile/tile.cuh | 16 + .../include/ops/group/memory/tile/tma.cuh | 134 ++ .../ops/group/memory/tile/tma_cluster.cuh | 33 + .../include/ops/group/memory/util/tma.cuh | 68 + .../ops/group/memory/util/tma_cluster.cuh | 90 ++ .../include/ops/group/memory/util/util.cuh | 168 +++ .../group/memory/vec/global_to_register.cuh | 138 ++ .../ops/group/memory/vec/global_to_shared.cuh | 77 + .../group/memory/vec/shared_to_register.cuh | 159 +++ .../cuda/include/ops/group/memory/vec/tma.cuh | 221 +++ .../ops/group/memory/vec/tma_cluster.cuh | 31 + .../cuda/include/ops/group/memory/vec/vec.cuh | 8 + .../cuda/include/ops/group/mma/mma.cuh | 17 + .../include/ops/group/mma/tensor/tensor.cuh | 172 +++ .../cuda/include/ops/group/mma/warp/warp.cuh | 947 +++++++++++++ .../ops/group/mma/warpgroup/base/64x112.impl | 334 +++++ .../ops/group/mma/warpgroup/base/64x128.impl | 813 +++++++++++ .../ops/group/mma/warpgroup/base/64x144.impl | 382 +++++ .../ops/group/mma/warpgroup/base/64x16.impl | 190 +++ .../ops/group/mma/warpgroup/base/64x160.impl | 666 +++++++++ .../ops/group/mma/warpgroup/base/64x176.impl | 430 ++++++ .../ops/group/mma/warpgroup/base/64x192.impl | 674 +++++++++ .../ops/group/mma/warpgroup/base/64x208.impl | 478 +++++++ .../ops/group/mma/warpgroup/base/64x224.impl | 826 +++++++++++ .../ops/group/mma/warpgroup/base/64x240.impl | 526 +++++++ .../ops/group/mma/warpgroup/base/64x256.impl | 1260 +++++++++++++++++ .../ops/group/mma/warpgroup/base/64x32.impl | 446 ++++++ .../ops/group/mma/warpgroup/base/64x48.impl | 238 ++++ .../ops/group/mma/warpgroup/base/64x64.impl | 587 ++++++++ .../ops/group/mma/warpgroup/base/64x80.impl | 286 ++++ .../ops/group/mma/warpgroup/base/64x96.impl | 703 +++++++++ .../ops/group/mma/warpgroup/base/base.cuh | 47 + .../ops/group/mma/warpgroup/warpgroup.cuh | 1170 +++++++++++++++ .../include/ops/group/register/register.cuh | 7 + .../tile/complex/complex_conversions.cuh | 98 ++ .../register/tile/complex/complex_maps.cuh | 137 ++ .../ops/group/register/tile/conversions.cuh | 415 ++++++ .../include/ops/group/register/tile/maps.cuh | 836 +++++++++++ .../ops/group/register/tile/reductions.cuh | 554 ++++++++ .../include/ops/group/register/tile/tile.cuh | 47 + .../ops/group/register/vec/conversions.cuh | 153 ++ .../include/ops/group/register/vec/maps.cuh | 374 +++++ .../ops/group/register/vec/reductions.cuh | 233 +++ .../include/ops/group/register/vec/vec.cuh | 59 + .../cuda/include/ops/group/shared/shared.cuh | 7 + .../ops/group/shared/tile/conversions.cuh | 16 
+ .../include/ops/group/shared/tile/maps.cuh | 236 +++ .../ops/group/shared/tile/reductions.cuh | 372 +++++ .../include/ops/group/shared/tile/tile.cuh | 37 + .../ops/group/shared/vec/conversions.cuh | 27 + .../include/ops/group/shared/vec/maps.cuh | 259 ++++ .../ops/group/shared/vec/reductions.cuh | 193 +++ .../cuda/include/ops/group/shared/vec/vec.cuh | 38 + extra/thunder/cuda/include/ops/ops.cuh | 262 ++++ .../cuda/include/ops/thread/memory/memory.cuh | 10 + .../include/ops/thread/memory/tile/tile.cuh | 10 + .../include/ops/thread/memory/tile/tma.cuh | 564 ++++++++ .../ops/thread/memory/util/multimem.cuh | 405 ++++++ .../include/ops/thread/memory/util/tensor.cuh | 30 + .../include/ops/thread/memory/util/tma.cuh | 249 ++++ .../include/ops/thread/memory/util/util.cuh | 443 ++++++ .../include/ops/thread/memory/vec/tma.cuh | 416 ++++++ .../include/ops/thread/memory/vec/vec.cuh | 10 + .../cuda/include/ops/thread/mma/mma.cuh | 8 + .../include/ops/thread/mma/tensor/tensor.cuh | 523 +++++++ .../cuda/include/ops/thread/thread.cuh | 13 + extra/thunder/cuda/include/pyutils/broker.cuh | 551 +++++++ extra/thunder/cuda/include/pyutils/club.cuh | 122 ++ .../cuda/include/pyutils/parallel_tensor.cuh | 336 +++++ .../thunder/cuda/include/pyutils/pyutils.cuh | 235 +++ .../cuda/include/pyutils/torch_helpers.cuh | 7 + .../cuda/include/pyutils/torchutils.cuh | 180 +++ extra/thunder/cuda/include/pyutils/util.cuh | 19 + .../cuda/include/types/device/device.cuh | 12 + .../thunder/cuda/include/types/device/ipc.cuh | 195 +++ .../thunder/cuda/include/types/device/pgl.cuh | 173 +++ .../thunder/cuda/include/types/device/vmm.cuh | 180 +++ .../thunder/cuda/include/types/global/cgl.cuh | 56 + .../thunder/cuda/include/types/global/gl.cuh | 225 +++ .../cuda/include/types/global/global.cuh | 13 + .../thunder/cuda/include/types/global/tma.cuh | 428 ++++++ .../cuda/include/types/global/util.cuh | 99 ++ .../cuda/include/types/register/crt.cuh | 95 ++ .../cuda/include/types/register/crv.cuh | 88 ++ .../cuda/include/types/register/register.cuh | 15 + .../cuda/include/types/register/rt.cuh | 155 ++ .../cuda/include/types/register/rt_base.cuh | 112 ++ .../cuda/include/types/register/rt_layout.cuh | 42 + .../cuda/include/types/register/rv.cuh | 122 ++ .../cuda/include/types/register/rv_layout.cuh | 40 + .../thunder/cuda/include/types/shared/cst.cuh | 82 ++ .../thunder/cuda/include/types/shared/csv.cuh | 74 + .../cuda/include/types/shared/shared.cuh | 14 + .../thunder/cuda/include/types/shared/st.cuh | 349 +++++ .../include/types/shared/st_descriptor.cuh | 118 ++ .../thunder/cuda/include/types/shared/sv.cuh | 130 ++ .../cuda/include/types/tensor/tensor.cuh | 112 ++ .../thunder/cuda/include/types/tensor/tt.cuh | 97 ++ extra/thunder/cuda/include/types/types.cuh | 68 + extra/thunder/{ => metal}/gemm.py | 0 .../{ => metal}/include/common/base_ops.metal | 0 .../include/common/base_types.metal | 0 .../{ => metal}/include/common/common.metal | 0 .../{ => metal}/include/common/utils.metal | 0 .../{ => metal}/include/ops/group/group.metal | 0 .../include/ops/group/memory/memory.metal | 0 .../memory/tile/global_to_register.metal | 0 .../group/memory/tile/global_to_shared.metal | 0 .../memory/tile/shared_to_register.metal | 0 .../include/ops/group/memory/tile/tile.metal | 0 .../group/memory/vec/global_to_register.metal | 0 .../group/memory/vec/global_to_shared.metal | 0 .../group/memory/vec/shared_to_register.metal | 0 .../include/ops/group/memory/vec/vec.metal | 0 .../include/ops/group/shared/shared.metal | 0 
.../ops/group/shared/tile/conversions.metal | 0 .../include/ops/group/shared/tile/maps.metal | 0 .../ops/group/shared/tile/reductions.metal | 0 .../include/ops/group/shared/tile/tile.metal | 0 .../ops/group/shared/vec/conversions.metal | 0 .../include/ops/group/shared/vec/maps.metal | 0 .../include/ops/group/shared/vec/vec.metal | 0 .../thunder/{ => metal}/include/ops/ops.metal | 0 .../include/ops/warp/memory/memory.metal | 0 .../complex/complex_global_to_register.metal | 0 .../complex/complex_global_to_shared.metal | 0 .../complex/complex_shared_to_register.metal | 0 .../warp/memory/tile/global_to_register.metal | 0 .../warp/memory/tile/global_to_shared.metal | 0 .../warp/memory/tile/shared_to_register.metal | 0 .../include/ops/warp/memory/tile/tile.metal | 0 .../include/ops/warp/memory/util/util.metal | 0 .../warp/memory/vec/global_to_register.metal | 0 .../warp/memory/vec/global_to_shared.metal | 0 .../warp/memory/vec/shared_to_register.metal | 0 .../include/ops/warp/memory/vec/vec.metal | 0 .../include/ops/warp/register/register.metal | 0 .../ops/warp/register/tile/conversions.metal | 0 .../include/ops/warp/register/tile/maps.metal | 0 .../include/ops/warp/register/tile/mma.metal | 0 .../ops/warp/register/tile/reductions.metal | 0 .../include/ops/warp/register/tile/tile.metal | 0 .../ops/warp/register/vec/conversions.metal | 0 .../include/ops/warp/register/vec/maps.metal | 0 .../ops/warp/register/vec/reductions.metal | 0 .../include/ops/warp/register/vec/vec.metal | 0 .../include/ops/warp/shared/shared.metal | 0 .../ops/warp/shared/tile/conversions.metal | 0 .../include/ops/warp/shared/tile/maps.metal | 0 .../ops/warp/shared/tile/reductions.metal | 0 .../include/ops/warp/shared/tile/tile.metal | 0 .../ops/warp/shared/vec/conversions.metal | 0 .../include/ops/warp/shared/vec/maps.metal | 0 .../ops/warp/shared/vec/reductions.metal | 0 .../include/ops/warp/shared/vec/vec.metal | 0 .../{ => metal}/include/ops/warp/warp.metal | 0 extra/thunder/{ => metal}/include/tk.metal | 0 .../include/types/global/cgl.metal | 0 .../{ => metal}/include/types/global/gl.metal | 0 .../include/types/global/global.metal | 0 .../include/types/global/util.metal | 0 .../include/types/register/crt.metal | 0 .../include/types/register/crv.metal | 0 .../include/types/register/register.metal | 0 .../include/types/register/rt.metal | 0 .../include/types/register/rt_base.metal | 0 .../include/types/register/rt_layout.metal | 0 .../include/types/register/rv.metal | 0 .../include/types/register/rv_layout.metal | 0 .../include/types/shared/cst.metal | 0 .../include/types/shared/csv.metal | 0 .../include/types/shared/shared.metal | 0 .../{ => metal}/include/types/shared/st.metal | 0 .../{ => metal}/include/types/shared/sv.metal | 0 .../{ => metal}/include/types/types.metal | 0 191 files changed, 26536 insertions(+) create mode 100644 extra/thunder/cuda/include/common/base_ops.cuh create mode 100644 extra/thunder/cuda/include/common/base_types.cuh create mode 100644 extra/thunder/cuda/include/common/common.cuh create mode 100644 extra/thunder/cuda/include/common/debug.cuh create mode 100644 extra/thunder/cuda/include/common/util.cuh create mode 100644 extra/thunder/cuda/include/kittens.cuh create mode 100644 extra/thunder/cuda/include/ops/device/device.cuh create mode 100644 extra/thunder/cuda/include/ops/group/group.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/memory.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_register.cuh create mode 100644 
extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_shared.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_shared_to_register.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/global_to_register.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/global_to_shared.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/shared_to_register.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/tensor_to_register.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/tile.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/tma.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/tile/tma_cluster.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/util/tma.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/util/tma_cluster.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/util/util.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/vec/global_to_register.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/vec/global_to_shared.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/vec/shared_to_register.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/vec/tma.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/vec/tma_cluster.cuh create mode 100644 extra/thunder/cuda/include/ops/group/memory/vec/vec.cuh create mode 100644 extra/thunder/cuda/include/ops/group/mma/mma.cuh create mode 100644 extra/thunder/cuda/include/ops/group/mma/tensor/tensor.cuh create mode 100644 extra/thunder/cuda/include/ops/group/mma/warp/warp.cuh create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x112.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x128.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x144.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x16.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x160.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x176.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x192.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x208.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x224.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x240.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x256.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x32.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x48.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x64.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x80.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x96.impl create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/base/base.cuh create mode 100644 extra/thunder/cuda/include/ops/group/mma/warpgroup/warpgroup.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/register.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/tile/complex/complex_conversions.cuh create mode 100644 
extra/thunder/cuda/include/ops/group/register/tile/complex/complex_maps.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/tile/conversions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/tile/maps.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/tile/reductions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/tile/tile.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/vec/conversions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/vec/maps.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/vec/reductions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/register/vec/vec.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/shared.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/tile/conversions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/tile/maps.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/tile/reductions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/tile/tile.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/vec/conversions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/vec/maps.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/vec/reductions.cuh create mode 100644 extra/thunder/cuda/include/ops/group/shared/vec/vec.cuh create mode 100644 extra/thunder/cuda/include/ops/ops.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/memory.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/tile/tile.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/tile/tma.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/util/multimem.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/util/tensor.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/util/tma.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/util/util.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/vec/tma.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/memory/vec/vec.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/mma/mma.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/mma/tensor/tensor.cuh create mode 100644 extra/thunder/cuda/include/ops/thread/thread.cuh create mode 100644 extra/thunder/cuda/include/pyutils/broker.cuh create mode 100644 extra/thunder/cuda/include/pyutils/club.cuh create mode 100644 extra/thunder/cuda/include/pyutils/parallel_tensor.cuh create mode 100644 extra/thunder/cuda/include/pyutils/pyutils.cuh create mode 100644 extra/thunder/cuda/include/pyutils/torch_helpers.cuh create mode 100644 extra/thunder/cuda/include/pyutils/torchutils.cuh create mode 100644 extra/thunder/cuda/include/pyutils/util.cuh create mode 100644 extra/thunder/cuda/include/types/device/device.cuh create mode 100644 extra/thunder/cuda/include/types/device/ipc.cuh create mode 100644 extra/thunder/cuda/include/types/device/pgl.cuh create mode 100644 extra/thunder/cuda/include/types/device/vmm.cuh create mode 100644 extra/thunder/cuda/include/types/global/cgl.cuh create mode 100644 extra/thunder/cuda/include/types/global/gl.cuh create mode 100644 extra/thunder/cuda/include/types/global/global.cuh create mode 100644 extra/thunder/cuda/include/types/global/tma.cuh create mode 100644 extra/thunder/cuda/include/types/global/util.cuh create mode 100644 
extra/thunder/cuda/include/types/register/crt.cuh create mode 100644 extra/thunder/cuda/include/types/register/crv.cuh create mode 100644 extra/thunder/cuda/include/types/register/register.cuh create mode 100644 extra/thunder/cuda/include/types/register/rt.cuh create mode 100644 extra/thunder/cuda/include/types/register/rt_base.cuh create mode 100644 extra/thunder/cuda/include/types/register/rt_layout.cuh create mode 100644 extra/thunder/cuda/include/types/register/rv.cuh create mode 100644 extra/thunder/cuda/include/types/register/rv_layout.cuh create mode 100644 extra/thunder/cuda/include/types/shared/cst.cuh create mode 100644 extra/thunder/cuda/include/types/shared/csv.cuh create mode 100644 extra/thunder/cuda/include/types/shared/shared.cuh create mode 100644 extra/thunder/cuda/include/types/shared/st.cuh create mode 100644 extra/thunder/cuda/include/types/shared/st_descriptor.cuh create mode 100644 extra/thunder/cuda/include/types/shared/sv.cuh create mode 100644 extra/thunder/cuda/include/types/tensor/tensor.cuh create mode 100644 extra/thunder/cuda/include/types/tensor/tt.cuh create mode 100644 extra/thunder/cuda/include/types/types.cuh rename extra/thunder/{ => metal}/gemm.py (100%) rename extra/thunder/{ => metal}/include/common/base_ops.metal (100%) rename extra/thunder/{ => metal}/include/common/base_types.metal (100%) rename extra/thunder/{ => metal}/include/common/common.metal (100%) rename extra/thunder/{ => metal}/include/common/utils.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/group.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/memory.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/tile/global_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/tile/global_to_shared.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/tile/shared_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/tile/tile.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/vec/global_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/vec/global_to_shared.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/vec/shared_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/memory/vec/vec.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/shared.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/tile/conversions.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/tile/maps.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/tile/reductions.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/tile/tile.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/vec/conversions.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/vec/maps.metal (100%) rename extra/thunder/{ => metal}/include/ops/group/shared/vec/vec.metal (100%) rename extra/thunder/{ => metal}/include/ops/ops.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/memory.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/tile/complex/complex_global_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/tile/complex/complex_global_to_shared.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/tile/complex/complex_shared_to_register.metal (100%) rename extra/thunder/{ => 
metal}/include/ops/warp/memory/tile/global_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/tile/global_to_shared.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/tile/shared_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/tile/tile.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/util/util.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/vec/global_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/vec/global_to_shared.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/vec/shared_to_register.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/memory/vec/vec.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/register.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/tile/conversions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/tile/maps.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/tile/mma.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/tile/reductions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/tile/tile.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/vec/conversions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/vec/maps.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/vec/reductions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/register/vec/vec.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/shared.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/tile/conversions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/tile/maps.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/tile/reductions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/tile/tile.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/vec/conversions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/vec/maps.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/vec/reductions.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/shared/vec/vec.metal (100%) rename extra/thunder/{ => metal}/include/ops/warp/warp.metal (100%) rename extra/thunder/{ => metal}/include/tk.metal (100%) rename extra/thunder/{ => metal}/include/types/global/cgl.metal (100%) rename extra/thunder/{ => metal}/include/types/global/gl.metal (100%) rename extra/thunder/{ => metal}/include/types/global/global.metal (100%) rename extra/thunder/{ => metal}/include/types/global/util.metal (100%) rename extra/thunder/{ => metal}/include/types/register/crt.metal (100%) rename extra/thunder/{ => metal}/include/types/register/crv.metal (100%) rename extra/thunder/{ => metal}/include/types/register/register.metal (100%) rename extra/thunder/{ => metal}/include/types/register/rt.metal (100%) rename extra/thunder/{ => metal}/include/types/register/rt_base.metal (100%) rename extra/thunder/{ => metal}/include/types/register/rt_layout.metal (100%) rename extra/thunder/{ => metal}/include/types/register/rv.metal (100%) rename extra/thunder/{ => metal}/include/types/register/rv_layout.metal (100%) rename extra/thunder/{ => metal}/include/types/shared/cst.metal (100%) rename extra/thunder/{ => metal}/include/types/shared/csv.metal (100%) rename extra/thunder/{ => 
metal}/include/types/shared/shared.metal (100%)
 rename extra/thunder/{ => metal}/include/types/shared/st.metal (100%)
 rename extra/thunder/{ => metal}/include/types/shared/sv.metal (100%)
 rename extra/thunder/{ => metal}/include/types/types.metal (100%)

diff --git a/extra/thunder/cuda/include/common/base_ops.cuh b/extra/thunder/cuda/include/common/base_ops.cuh
new file mode 100644
index 0000000000..cbade07113
--- /dev/null
+++ b/extra/thunder/cuda/include/common/base_ops.cuh
@@ -0,0 +1,400 @@
+/**
+ * @file
+ * @brief Basic operations on generic types.
+ */
+
+#pragma once
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+#include "base_types.cuh"
+
+namespace kittens {
+
+/**
+ * @namespace base_ops
+ *
+ * @brief A namespace for operations on basic data types.
+ */
+namespace base_ops {
+
+/* ---------- CONST OPS ---------- */
+
+/**
+ * @brief Represents the zero constant operation.
+ *
+ * This operation returns the zero value of the specified type.
+ *
+ * @tparam T The data type for which to return the zero value.
+ * @return The zero value of type T.
+ */
+struct zero {
+    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::zero(); }
+};
+/**
+ * @brief Represents the one constant operation.
+ *
+ * This operation returns the one value of the specified type.
+ *
+ * @tparam T The data type for which to return the one value.
+ * @return The one value of type T.
+ */
+struct one {
+    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::one(); }
+};
+/**
+ * @brief Represents the positive infinity constant operation.
+ *
+ * This operation returns the positive infinity value of the specified type.
+ *
+ * @tparam T The data type for which to return the positive infinity value.
+ * @return The positive infinity value of type T.
+ */
+struct pos_infty {
+    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::pos_infty(); }
+};
+/**
+ * @brief Represents the negative infinity constant operation.
+ *
+ * This operation returns the negative infinity value of the specified type.
+ *
+ * @tparam T The data type for which to return the negative infinity value.
+ * @return The negative infinity value of type T.
+ */
+struct neg_infty {
+    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::neg_infty(); }
+};
+
+
+/* ---------- UNARY OPS ---------- */
+
+/**
+ * @brief Exponential function operation.
+ *
+ * This operation calculates the exponential of the input value.
+ *
+ * @tparam T The data type of the input and output values.
+ * @param x[in] The input value.
+ * @return The exponential of the input value.
+ */
+struct exp {
+    template<typename T> static __device__ inline T op(const T &x) { return exp(x); }
+};
+template<> __device__ inline float  exp::op<float> (const float &x ) { return __expf(x); }
+template<> __device__ inline float2 exp::op<float2>(const float2 &x) { return float2{__expf(x.x), __expf(x.y)}; }
+template<> __device__ inline bf16   exp::op<bf16>  (const bf16 &x  ) { return hexp(x); }
+template<> __device__ inline bf16_2 exp::op<bf16_2>(const bf16_2 &x) { return h2exp(x); }
+template<> __device__ inline half   exp::op<half>  (const half &x  ) { return hexp(x); }
+template<> __device__ inline half_2 exp::op<half_2>(const half_2 &x) { return h2exp(x); }
+
+/**
+ * @brief Exponential function operation, in base 2
+ *
+ * This operation calculates the exponential of the input value, in base 2.
+ *
+ * @tparam T The data type of the input and output values.
+ * @param x[in] The input value.
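+ *
+ * Illustrative usage (assuming device code and the specializations defined just below):
+ * for a float2 x, exp2::op<float2>(x) returns float2{exp2f(x.x), exp2f(x.y)}.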
+ * @return The exponential of the input value. + */ +struct exp2 { + template static __device__ inline T op(const T &x) { return exp2f(x); } +}; +template<> __device__ inline float exp2::op (const float &x ) { return exp2f(x); } +template<> __device__ inline float2 exp2::op(const float2 &x) { return float2{exp2f(x.x), exp2f(x.y)}; } +template<> __device__ inline bf16 exp2::op (const bf16 &x ) { return hexp2(x); } +template<> __device__ inline bf16_2 exp2::op(const bf16_2 &x) { return h2exp2(x); } +template<> __device__ inline half exp2::op (const half &x ) { return hexp2(x); } +template<> __device__ inline half_2 exp2::op(const half_2 &x) { return h2exp2(x); } +/** + * @brief Natural log function operation. + * + * This operation calculates the natural logarithm of the input value. + * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The natural logarithm of the input value. + */ +struct log { + template static __device__ inline T op(const T &x) { return log(x); } +}; +template<> __device__ inline float log::op (const float &x ) { return __logf(x); } +template<> __device__ inline float2 log::op(const float2 &x) { return float2{__logf(x.x), __logf(x.y)}; } +template<> __device__ inline bf16 log::op (const bf16 &x ) { return hlog(x); } +template<> __device__ inline bf16_2 log::op(const bf16_2 &x) { return h2log(x); } +template<> __device__ inline half log::op (const half &x ) { return hlog(x); } +template<> __device__ inline half_2 log::op(const half_2 &x) { return h2log(x); } +/** + * @brief Logarithm base 2 operation. + * + * This operation calculates the logarithm base 2 of the input value. + * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The logarithm base 2 of the input value. + */ +struct log2 { + template static __device__ inline T op(const T &x) { return log2(x); } +}; +template<> __device__ inline float log2::op (const float &x ) { return __log2f(x); } +template<> __device__ inline float2 log2::op(const float2 &x) { return float2{__log2f(x.x), __log2f(x.y)}; } +template<> __device__ inline bf16 log2::op (const bf16 &x ) { return hlog2(x); } +template<> __device__ inline bf16_2 log2::op(const bf16_2 &x) { return h2log2(x); } +template<> __device__ inline half log2::op (const half &x ) { return hlog2(x); } +template<> __device__ inline half_2 log2::op(const half_2 &x) { return h2log2(x); } +/** + * @brief Absolute value operation. + * + * This operation calculates the absolute value of the input. + * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The absolute value of the input. + */ +struct abs { + template static __device__ inline T op(const T &x) { return abs(x); } +}; +template<> __device__ inline float abs::op (const float &x ) { return fabsf(x); } +template<> __device__ inline float2 abs::op(const float2 &x) { return float2{fabsf(x.x), fabsf(x.y)}; } +template<> __device__ inline bf16 abs::op (const bf16 &x ) { return __habs(x); } +template<> __device__ inline bf16_2 abs::op(const bf16_2 &x) { return __habs2(x); } +template<> __device__ inline half abs::op (const half &x ) { return __habs(x); } +template<> __device__ inline half_2 abs::op(const half_2 &x) { return __habs2(x); } +/** + * @brief Rectified Linear Unit (ReLU) operation. + * + * This operation applies the ReLU function to the input, which is the + * maximum of zero and the input value. 
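+ * Illustrative usage (assuming device code): relu::op(-3.0f) returns 0.0f and
+ * relu::op(2.0f) returns 2.0f.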
+ * + * @tparam T The data type of the input and output values. + * @param x[in] The input value. + * @return The result of ReLU function applied to the input. + */ +struct relu { + template static __device__ inline T op(const T &x) { return max(x, base_types::constants::zero()); } +}; +template<> __device__ inline float relu::op (const float &x ) { return max(x, 0.f); } +template<> __device__ inline float2 relu::op(const float2 &x) { return float2{max(x.x, 0.f), max(x.y, 0.f)}; } +template<> __device__ inline bf16 relu::op (const bf16 &x ) { return __hmax(x, base_types::constants::zero()); } +template<> __device__ inline bf16_2 relu::op(const bf16_2 &x) { return __hmax2(x, base_types::constants::zero()); } +template<> __device__ inline half relu::op (const half &x ) { return __hmax(x, base_types::constants::zero()); } +template<> __device__ inline half_2 relu::op(const half_2 &x) { return __hmax2(x, base_types::constants::zero()); } +/** + * @brief Copy operation. + * + * This operation returns the input value unchanged. + * + * @tparam T The data type of the input and output values. + * @param a[in] The input value. + * @return The same value as the input. + */ +struct copy { // for non-compile-time setters. + template static __device__ inline T op(const T &a) { return a; } +}; + + +/* ---------- BINARY OPS ---------- */ + +/** + * @brief Copy2 operation. + * + * This operation returns the second input value unchanged. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value (ignored). + * @param b[in] The second input value. + * @return The same value as the second input. + */ +struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/ + template static __device__ inline T op(const T &a, const T &b) { return b; } +}; +/** + * @brief Sum operation. + * + * This operation calculates the sum of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The sum of the input values. + */ +struct sum { + template static __device__ inline T op(const T &a, const T &b) { return a+b; } +}; +template<> __device__ inline float2 sum::op(const float2 &a, const float2 &b) { +#ifdef KITTENS_BLACKWELL + float2 c; + asm volatile("add.f32x2 %0, %1, %2;" : "=l"(*(uint64_t*)&c) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b)); + return c; +#else + return float2{a.x+b.x, a.y+b.y}; +#endif +} +template<> __device__ inline bf16 sum::op (const bf16 &a, const bf16 &b) { return __hadd(a, b); } +template<> __device__ inline bf16_2 sum::op(const bf16_2 &a, const bf16_2 &b) { return __hadd2(a, b); } +template<> __device__ inline half sum::op (const half &a, const half &b) { return __hadd(a, b); } +template<> __device__ inline half_2 sum::op(const half_2 &a, const half_2 &b) { return __hadd2(a, b); } +/** + * @brief Subtraction operation. + * + * This operation calculates the difference between two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The difference between the input values. 
+ */ +struct sub { + template static __device__ inline T op(const T &a, const T &b) { return a-b; } +}; +template<> __device__ inline float2 sub::op(const float2 &a, const float2 &b) { +#ifdef KITTENS_BLACKWELL + float2 c; + asm volatile("sub.f32x2 %0, %1, %2;" : "=l"(*(uint64_t*)&c) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b)); + return c; +#else + return float2{a.x-b.x, a.y-b.y}; +#endif +} +template<> __device__ inline bf16 sub::op (const bf16 &a, const bf16 &b) { return __hsub(a, b); } +template<> __device__ inline bf16_2 sub::op(const bf16_2 &a, const bf16_2 &b) { return __hsub2(a, b); } +template<> __device__ inline half sub::op (const half &a, const half &b) { return __hsub(a, b); } +template<> __device__ inline half_2 sub::op(const half_2 &a, const half_2 &b) { return __hsub2(a, b); } +/** + * @brief Multiplication operation. + * + * This operation calculates the product of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The product of the input values. + */ +struct mul { + template static __device__ inline T op(const T &a, const T &b) { return a*b; } +}; +template<> __device__ inline float2 mul::op(const float2 &a, const float2 &b) { +#ifdef KITTENS_BLACKWELL + float2 c; + asm volatile("mul.f32x2 %0, %1, %2;" : "=l"(*(uint64_t*)&c) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b)); + return c; +#else + return float2{a.x*b.x, a.y*b.y}; +#endif +} +template<> __device__ inline bf16 mul::op (const bf16 &a, const bf16 &b) { return __hmul(a, b); } +template<> __device__ inline bf16_2 mul::op(const bf16_2 &a, const bf16_2 &b) { return __hmul2(a, b); } +template<> __device__ inline half mul::op (const half &a, const half &b) { return __hmul(a, b); } +template<> __device__ inline half_2 mul::op(const half_2 &a, const half_2 &b) { return __hmul2(a, b); } +/** + * @brief Division operation. + * + * This operation calculates the quotient of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The quotient of the input values. + */ +struct div { + template static __device__ inline T op(const T &a, const T &b) { return a/b; } +}; +template<> __device__ inline float2 div::op(const float2 &a, const float2 &b) { return float2{a.x/b.x, a.y/b.y}; } +template<> __device__ inline bf16 div::op (const bf16 &a, const bf16 &b) { return __hdiv(a, b); } +template<> __device__ inline bf16_2 div::op(const bf16_2 &a, const bf16_2 &b) { return __h2div(a, b); } // this op is a special snowflake +template<> __device__ inline half div::op (const half &a, const half &b) { return __hdiv(a, b); } +template<> __device__ inline half_2 div::op(const half_2 &a, const half_2 &b) { return __h2div(a, b); } +/** + * @brief Maximum operation. + * + * This operation calculates the maximum of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The maximum of the input values. 
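+ *
+ * Illustrative pattern (assuming device code): seeding an accumulator with
+ * base_types::constants<T>::neg_infty() and folding with acc = max::op(acc, x)
+ * gives a running maximum over a sequence of values.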
+ */ + struct max { + template static __device__ inline T op(const T &a, const T &b) { return ::max(a, b); } +}; +template<> __device__ inline float2 max::op(const float2 &a, const float2 &b) { return float2{::max(a.x, b.x), ::max(a.y, b.y)}; } +template<> __device__ inline bf16 max::op (const bf16 &a, const bf16 &b) { return __hmax(a, b); } +template<> __device__ inline bf16_2 max::op(const bf16_2 &a, const bf16_2 &b) { return __hmax2(a, b); } +template<> __device__ inline half max::op (const half &a, const half &b) { return __hmax(a, b); } +template<> __device__ inline half_2 max::op(const half_2 &a, const half_2 &b) { return __hmax2(a, b); } +/** + * @brief Minimum operation. + * + * This operation calculates the minimum of two input values. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @return The minimum of the input values. + */ +struct min { + template static __device__ inline T op(const T &a, const T &b) { return ::min(a, b); } +}; +template<> __device__ inline float2 min::op(const float2 &a, const float2 &b) { return float2{::min(a.x, b.x), ::min(a.y, b.y)}; } +template<> __device__ inline bf16 min::op (const bf16 &a, const bf16 &b) { return __hmin(a, b); } +template<> __device__ inline bf16_2 min::op(const bf16_2 &a, const bf16_2 &b) { return __hmin2(a, b); } +template<> __device__ inline half min::op (const half &a, const half &b) { return __hmin(a, b); } +template<> __device__ inline half_2 min::op(const half_2 &a, const half_2 &b) { return __hmin2(a, b); } + + +/* ---------- TERNARY OPS ---------- */ + +/** + * @brief Fused multiply-add operation A * B + C. + * + * This operation performs a fused multiply-add, computing (A * B) + C with only one rounding. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The second input value. + * @param c[in] The third input value to be added. + * @return The result of the fused multiply-add operation. + */ +struct fma_AxBtC { + template static __device__ inline T op(const T &a, const T &b, const T &c) { + return sum::op(mul::op(a, b), c); + } +}; +template<> __device__ inline float2 fma_AxBtC::op(const float2 &a, const float2 &b, const float2 &c) { +#ifdef KITTENS_BLACKWELL + float2 d; + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;" : "=l"(*(uint64_t*)&d) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b), "l"(*(uint64_t*)&c)); + return d; +#else + return float2{a.x*b.x+c.x, a.y*b.y+c.y}; +#endif +} +/** + * @brief Fused multiply-add operation A * C + B. + * + * This operation performs a fused multiply-add, computing (A * C) + B with only one rounding. + * This is particularly useful for attention mechanisms in neural networks. + * + * @tparam T The data type of the input and output values. + * @param a[in] The first input value. + * @param b[in] The third input value to be added. + * @param c[in] The second input value. + * @return The result of the fused multiply-add operation. 
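+ *
+ * Note the argument order: op(a, b, c) computes a*c + b, i.e. the second argument is the addend.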
+ */ +struct fma_AxCtB { // this is the one needed for attention + template static __device__ inline T op(const T &a, const T &b, const T &c) { + return sum::op(mul::op(a, c), b); + } +}; +template<> __device__ inline float2 fma_AxCtB::op(const float2 &a, const float2 &b, const float2 &c) { +#ifdef KITTENS_BLACKWELL + float2 d; + asm volatile("fma.rn.f32x2 %0, %1, %2, %3;" : "=l"(*(uint64_t*)&d) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&c), "l"(*(uint64_t*)&b)); + return d; +#else + return float2{a.x*c.x+b.x, a.y*c.y+b.y}; +#endif +} + +} // namespace base_ops + +} // namespace kittens diff --git a/extra/thunder/cuda/include/common/base_types.cuh b/extra/thunder/cuda/include/common/base_types.cuh new file mode 100644 index 0000000000..bd1109c40c --- /dev/null +++ b/extra/thunder/cuda/include/common/base_types.cuh @@ -0,0 +1,519 @@ +/** + * @file + * @brief Declarations, manipulations, and wrappers for basic types. + * + * This file is a bunch of utilities for going back and forth between different types. + * + * Many of them are for the compiler, so as to clean up the code. It unfortunately + * seems necessary when we have types we really care about that are less than word width. + */ + +#pragma once + +#ifdef KITTENS_HOPPER +#include +#endif + +#include +#include +#include +#include + + +namespace kittens { + +/** + * @brief Bfloat16 floating-point type. + */ +using bf16 = __nv_bfloat16; +/** + * @brief Half-precision floating-point type. + */ +using half = __half; +/** + * @brief Packed word of two bfloat16 floating-point values. + */ +using bf16_2 = __nv_bfloat162; +/** + * @brief Packed word of two half-precision floating-point values. + */ +using half_2 = __half2; +#ifdef KITTENS_HOPPER +/** + * @brief float8 floating-point type. + */ +using fp8e4m3 = __nv_fp8_e4m3; +using fp8e5m2 = __nv_fp8_e5m2; +#ifdef KITTENS_BLACKWELL +using fp8e8m0 = __nv_fp8_e8m0; +#endif +/** + * @brief 2-packed float8 floating-point type. + */ +using fp8e4m3_2 = __nv_fp8x2_e4m3; +using fp8e5m2_2 = __nv_fp8x2_e5m2; +#ifdef KITTENS_BLACKWELL +using fp8e8m0_2 = __nv_fp8x2_e8m0; +#endif +/** + * @brief 4-packed float8 floating-point type. + */ +using fp8e4m3_4 = __nv_fp8x4_e4m3; +using fp8e5m2_4 = __nv_fp8x4_e5m2; +#ifdef KITTENS_BLACKWELL +using fp8e8m0_4 = __nv_fp8x4_e8m0; +#endif +#endif + +namespace ducks { +/** + * @namespace base_types + * + * @brief A namespace for concepts for basic data types. + */ +namespace base_types { + +#ifdef KITTENS_HOPPER +#ifdef KITTENS_BLACKWELL +template +concept T2 = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v; // could add half_2 later if implemented. +template +concept T1 = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v; // could add half_2 later if implemented. +#else +template +concept T2 = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v; +template +concept T1 = std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v; +#endif +#else +template +concept T2 = std::is_same_v || std::is_same_v || std::is_same_v; +template +concept T1 = std::is_same_v || std::is_same_v || std::is_same_v; +#endif + +} // namespace base_types +} // namespace ducks + +/** + * @namespace base_types + * + * @brief A namespace for ThunderKittens basic data types. + */ +namespace base_types { + +/** + * @brief Provides compile-time constants for different types. 
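+ *
+ * Illustrative usage (assuming device code): constants<float>::one() is 1.0f, and
+ * constants<bf16>::neg_infty() is a natural seed value for a max reduction.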
+ *
+ * @tparam T The type for which to provide constants.
+ */
+template<typename T> struct constants {
+    /**
+     * @brief Zero
+     * @return Constexpr zero with type T
+     */
+    static __device__ inline constexpr T zero()      { return T{0}; }
+    /**
+     * @brief One
+     * @return Constexpr one with type T
+     */
+    static __device__ inline constexpr T one()       { return T{1}; }
+    /**
+     * @brief Positive infinity. Particularly useful for initializing before a min op.
+     * @return Constexpr positive infinity with type T
+     */
+    static __device__ inline constexpr T pos_infty() { return T{INFINITY}; } // I'll find a better way at some point but this appears to work.
+    /**
+     * @brief Negative infinity. Particularly useful for initializing before a max op.
+     * @return Constexpr negative infinity with type T
+     */
+    static __device__ inline constexpr T neg_infty() { return T{-INFINITY}; }
+};
+template<> struct constants<float2> {
+    static __device__ inline constexpr float2 zero()      { return float2{0.f, 0.f}; }
+    static __device__ inline constexpr float2 one()       { return float2{1.f, 1.f}; }
+    static __device__ inline constexpr float2 pos_infty() { return float2{constants<float>::pos_infty(), constants<float>::pos_infty()}; }
+    static __device__ inline constexpr float2 neg_infty() { return float2{constants<float>::neg_infty(), constants<float>::neg_infty()}; }
+};
+template<> struct constants<bf16> {
+    static __device__ inline constexpr bf16 zero()      { return std::bit_cast<__nv_bfloat16>(uint16_t(0x0000)); } // unfortunately __float2bf16_rn is not constexpr
+    static __device__ inline constexpr bf16 one()       { return std::bit_cast<__nv_bfloat16>(uint16_t(0x3F80)); }
+    static __device__ inline constexpr bf16 pos_infty() { return std::bit_cast<__nv_bfloat16>(uint16_t(0x7F80)); }
+    static __device__ inline constexpr bf16 neg_infty() { return std::bit_cast<__nv_bfloat16>(uint16_t(0xFF80)); }
+};
+template<> struct constants<bf16_2> {
+    static __device__ inline constexpr bf16_2 zero()      { return bf16_2{constants<bf16>::zero(),      constants<bf16>::zero()};      }
+    static __device__ inline constexpr bf16_2 one()       { return bf16_2{constants<bf16>::one(),       constants<bf16>::one()};       }
+    static __device__ inline constexpr bf16_2 pos_infty() { return bf16_2{constants<bf16>::pos_infty(), constants<bf16>::pos_infty()}; }
+    static __device__ inline constexpr bf16_2 neg_infty() { return bf16_2{constants<bf16>::neg_infty(), constants<bf16>::neg_infty()}; }
+};
+template<> struct constants<half> {
+    static __device__ inline constexpr half zero()      { return std::bit_cast<__half>(uint16_t(0x0000)); }
+    static __device__ inline constexpr half one()       { return std::bit_cast<__half>(uint16_t(0x3C00)); }
+    static __device__ inline constexpr half pos_infty() { return std::bit_cast<__half>(uint16_t(0x7C00)); }
+    static __device__ inline constexpr half neg_infty() { return std::bit_cast<__half>(uint16_t(0xFC00)); }
+};
+template<> struct constants<half_2> {
+    static __device__ inline constexpr half_2 zero()      { return half_2{constants<half>::zero(),      constants<half>::zero()};      }
+    static __device__ inline constexpr half_2 one()       { return half_2{constants<half>::one(),       constants<half>::one()};       }
+    static __device__ inline constexpr half_2 pos_infty() { return half_2{constants<half>::pos_infty(), constants<half>::pos_infty()}; }
+    static __device__ inline constexpr half_2 neg_infty() { return half_2{constants<half>::neg_infty(), constants<half>::neg_infty()}; }
+};
+#ifdef KITTENS_HOPPER
+template<> struct constants<fp8e4m3> {
+    static __device__ inline constexpr fp8e4m3 zero() { return std::bit_cast<__nv_fp8_e4m3>(uint8_t(0x00)); }
+    static __device__ inline constexpr fp8e4m3 one()  { return std::bit_cast<__nv_fp8_e4m3>(uint8_t(0x38)); }
+};
+template<>
struct constants { + static __device__ inline constexpr fp8e4m3_2 zero() { return std::bit_cast(uint16_t(0x0000)); } + static __device__ inline constexpr fp8e4m3_2 one() { return std::bit_cast(uint16_t(0x3838)); } +}; +template<> struct constants { + static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast(uint32_t(0x00000000)); } + static __device__ inline constexpr fp8e4m3_4 one() { return std::bit_cast(uint32_t(0x38383838)); } +}; +template<> struct constants { + static __device__ inline constexpr fp8e5m2 zero() { return std::bit_cast<__nv_fp8_e5m2>(uint8_t(0x00)); } + static __device__ inline constexpr fp8e5m2 one() { return std::bit_cast<__nv_fp8_e5m2>(uint8_t(0x3C)); } +}; +template<> struct constants { + static __device__ inline constexpr fp8e5m2_2 zero() { return std::bit_cast(uint16_t(0x0000)); } + static __device__ inline constexpr fp8e5m2_2 one() { return std::bit_cast(uint16_t(0x3C3C)); } +}; +template<> struct constants { + static __device__ inline constexpr fp8e5m2_4 zero() { return std::bit_cast(uint32_t(0x00000000)); } + static __device__ inline constexpr fp8e5m2_4 one() { return std::bit_cast(uint32_t(0x3C3C3C3C)); } +}; +#endif + +template<> struct constants { + static __device__ inline constexpr int zero() { return 0; } + static __device__ inline constexpr int one() { return 1; } +}; +template<> struct constants { + static __device__ inline constexpr int2 zero() { return int2{0, 0}; } + static __device__ inline constexpr int2 one() { return int2{1, 1}; } +}; + +/** + * @brief Provides information about packing of elements for a given type. + * + * @tparam T The type for which to provide packing information. + */ +template struct packing { + /** + * @brief The number of elements packed together. + * + * @return constexpr int representing number of elements within the type. + */ + static __device__ inline constexpr int num() { return 1; } + /** + * @brief Packs a single T element twice (replicated) into its packed type. + * + * @param i[in] The element to pack. + * @return The packed type. + */ + static __device__ inline constexpr T pack(const bf16 &i); +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = bf16; + using packed_type = bf16_2; + static __device__ inline constexpr bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = bf16; + using packed_type = bf16_2; + static __device__ inline constexpr bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = half; + using packed_type = half_2; + static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = half; + using packed_type = half_2; + static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; } // this replication makes code cleaner later. 
+}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = float; + using packed_type = float2; + static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = float; + using packed_type = float2; + static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = char; + using packed_type = char2; + static __device__ inline constexpr char2 pack(const char &i) { return char2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = char; + using packed_type = char2; + static __device__ inline constexpr char2 pack(const char &i) { return char2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = int; + using packed_type = int2; + static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = int; + using packed_type = int2; + static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = uint; + using packed_type = uint2; + static __device__ inline constexpr uint2 pack(const uint &i) { return uint2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = uint; + using packed_type = uint2; + static __device__ inline constexpr uint2 pack(const uint &i) { return uint2{i, i}; } // this replication makes code cleaner later. +}; +struct uint64_2 { uint64_t x, y; }; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = uint64_t; + using packed_type = uint64_2; + static __device__ inline constexpr uint64_2 pack(const uint64_t &i) { return uint64_2{i, i}; } // this replication makes code cleaner later. +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 2; } + using unpacked_type = uint64_t; + using packed_type = uint64_2; + static __device__ inline constexpr uint64_2 pack(const uint64_t &i) { return uint64_2{i, i}; } // this replication makes code cleaner later. 
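+    // Illustrative note on the packing helpers above: packing<bf16>::num() == 1,
+    // packing<bf16_2>::num() == 2, and packing<bf16>::pack(x) broadcasts x into both lanes of a bf16_2.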
+}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } +}; +#ifdef KITTENS_HOPPER +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = fp8e4m3; + using packed_type = fp8e4m3_4; +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } + using unpacked_type = fp8e4m3; + using packed_type = fp8e4m3_4; +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = fp8e5m2; + using packed_type = fp8e5m2_4; +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } + using unpacked_type = fp8e5m2; + using packed_type = fp8e5m2_4; +}; +#ifdef KITTENS_BLACKWELL +template<> struct packing { + static __device__ inline constexpr int num() { return 1; } + using unpacked_type = fp8e8m0; + using packed_type = fp8e8m0_4; +}; +template<> struct packing { + static __device__ inline constexpr int num() { return 4; } + using unpacked_type = fp8e8m0; + using packed_type = fp8e8m0_4; +}; +#endif +#endif + + +/** + * @brief Provides templated functionality to convert between different types. + * + * @tparam T The target type for conversion. + * @tparam U The source type for conversion. + */ +template struct convertor { + /** + * @brief Converts a value of type U to type T. + * + * @param u[in] The value of type U to convert. + * @return T The converted value of type T. + */ + static __host__ __device__ inline T convert(const U & u) { + return (T)u; + } +}; +template<> struct convertor { + static __host__ __device__ inline float convert(const bf16 & u) { + return __bfloat162float(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline bf16 convert(const float & u) { + return __float2bfloat16_rn(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float2 convert(const bf16_2 & u) { + return __bfloat1622float2(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline bf16_2 convert(const float2 & u) { + return __float22bfloat162_rn(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float convert(const half & u) { + return __half2float(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline half convert(const float & u) { + return __float2half(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float2 convert(const half_2 & u) { + return __half22float2(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline half_2 convert(const float2 & u) { + return __float22half2_rn(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline bf16 convert(const half & u) { + return __float2bfloat16_rn(__half2float(u)); + } +}; +template<> struct convertor { + static __host__ __device__ inline half convert(const bf16 & u) { + return __float2half(__bfloat162float(u)); + } +}; +template<> struct convertor { + static __host__ __device__ inline bf16_2 convert(const half_2 & u) { + return __float22bfloat162_rn(__half22float2(u)); + } +}; +template<> struct convertor { + static __host__ __device__ inline half_2 convert(const bf16_2 & u) { + return __float22half2_rn(__bfloat1622float2(u)); + } +}; +#ifdef KITTENS_HOPPER +// fp8e4m3 +template<> struct convertor { + static __host__ __device__ inline fp8e4m3_4 convert(const float4& u) { + 
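+        // The __nv_fp8x4_e4m3 constructor converts each of the four float components to fp8 e4m3.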
return __nv_fp8x4_e4m3(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float4 convert(const fp8e4m3_4& u) { + __nv_fp8_e4m3 *vals = reinterpret_cast<__nv_fp8_e4m3*>(const_cast<__nv_fp8x4_e4m3*>(&u)); + return make_float4(float(vals[0]), float(vals[1]), float(vals[2]), float(vals[3])); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e4m3_2 convert(const float2& u) { + return __nv_fp8x2_e4m3(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float2 convert(const fp8e4m3_2& u) { + __nv_fp8_e4m3 *vals = reinterpret_cast<__nv_fp8_e4m3*>(const_cast<__nv_fp8x2_e4m3*>(&u)); + return make_float2(float(vals[0]), float(vals[1])); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e4m3 convert(const float & u) { + return __nv_fp8_e4m3(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float convert(const fp8e4m3 & u) { + return float(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline bf16_2 convert(const fp8e4m3_4 & u) { + float4 f4 = convertor::convert(u); + float2 f2 = make_float2(f4.x, f4.y); + return __float22bfloat162_rn(f2); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e4m3_4 convert(const bf16_2 & u) { + float2 f2 = __bfloat1622float2(u); + float4 f4 = make_float4(f2.x, f2.y, 0.0f, 0.0f); + return __nv_fp8x4_e4m3(f4); + } +}; +// fp8e5m2 +template<> struct convertor { + static __host__ __device__ inline fp8e5m2_4 convert(const float4& u) { + return __nv_fp8x4_e5m2(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float4 convert(const fp8e5m2_4& u) { + __nv_fp8_e5m2 *vals = reinterpret_cast<__nv_fp8_e5m2*>(const_cast<__nv_fp8x4_e5m2*>(&u)); + return make_float4(float(vals[0]), float(vals[1]), float(vals[2]), float(vals[3])); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e5m2_2 convert(const float2& u) { + return __nv_fp8x2_e5m2(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float2 convert(const fp8e5m2_2& u) { + __nv_fp8_e5m2 *vals = reinterpret_cast<__nv_fp8_e5m2*>(const_cast<__nv_fp8x2_e5m2*>(&u)); + return make_float2(float(vals[0]), float(vals[1])); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e5m2 convert(const float & u) { + return __nv_fp8_e5m2(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline float convert(const fp8e5m2 & u) { + return float(u); + } +}; +template<> struct convertor { + static __host__ __device__ inline bf16_2 convert(const fp8e5m2_4 & u) { + float4 f4 = convertor::convert(u); + float2 f2 = make_float2(f4.x, f4.y); + return __float22bfloat162_rn(f2); + } +}; +template<> struct convertor { + static __host__ __device__ inline fp8e5m2_4 convert(const bf16_2 & u) { + float2 f2 = __bfloat1622float2(u); + float4 f4 = make_float4(f2.x, f2.y, 0.0f, 0.0f); + return __nv_fp8x4_e5m2(f4); + } +}; +#endif +} +} diff --git a/extra/thunder/cuda/include/common/common.cuh b/extra/thunder/cuda/include/common/common.cuh new file mode 100644 index 0000000000..7a95a713cf --- /dev/null +++ b/extra/thunder/cuda/include/common/common.cuh @@ -0,0 +1,11 @@ +/** + * @file + * @brief A collection of common resources on which ThunderKittens depends. 
+ */ + + +#pragma once + +#include "util.cuh" +#include "base_types.cuh" +#include "base_ops.cuh" \ No newline at end of file diff --git a/extra/thunder/cuda/include/common/debug.cuh b/extra/thunder/cuda/include/common/debug.cuh new file mode 100644 index 0000000000..586cbd3ce2 --- /dev/null +++ b/extra/thunder/cuda/include/common/debug.cuh @@ -0,0 +1,56 @@ +#pragma once + +// Reset +#define TK_RESET "\033[0m" + +// Foreground colors +#define TK_FG_BLACK "\033[30m" +#define TK_FG_RED "\033[31m" +#define TK_FG_GREEN "\033[32m" +#define TK_FG_YELLOW "\033[33m" +#define TK_FG_BLUE "\033[34m" +#define TK_FG_MAGENTA "\033[35m" +#define TK_FG_CYAN "\033[36m" +#define TK_FG_WHITE "\033[37m" + +// Background colors +#define TK_BG_BLACK "\033[40m" +#define TK_BG_RED "\033[41m" +#define TK_BG_GREEN "\033[42m" +#define TK_BG_YELLOW "\033[43m" +#define TK_BG_BLUE "\033[44m" +#define TK_BG_MAGENTA "\033[45m" +#define TK_BG_CYAN "\033[46m" +#define TK_BG_WHITE "\033[47m" + +// Bright foreground colors +#define TK_FG_BRIGHT_BLACK "\033[90m" +#define TK_FG_BRIGHT_RED "\033[91m" +#define TK_FG_BRIGHT_GREEN "\033[92m" +#define TK_FG_BRIGHT_YELLOW "\033[93m" +#define TK_FG_BRIGHT_BLUE "\033[94m" +#define TK_FG_BRIGHT_MAGENTA "\033[95m" +#define TK_FG_BRIGHT_CYAN "\033[96m" +#define TK_FG_BRIGHT_WHITE "\033[97m" + +// Bright background colors +#define TK_BG_BRIGHT_BLACK "\033[100m" +#define TK_BG_BRIGHT_RED "\033[101m" +#define TK_BG_BRIGHT_GREEN "\033[102m" +#define TK_BG_BRIGHT_YELLOW "\033[103m" +#define TK_BG_BRIGHT_BLUE "\033[104m" +#define TK_BG_BRIGHT_MAGENTA "\033[105m" +#define TK_BG_BRIGHT_CYAN "\033[106m" +#define TK_BG_BRIGHT_WHITE "\033[107m" + +// Text styles +#define TK_BOLD "\033[1m" +#define TK_DIM "\033[2m" +#define TK_ITALIC "\033[3m" +#define TK_UNDERLINE "\033[4m" +#define TK_BLINK "\033[5m" +#define TK_REVERSE "\033[7m" +#define TK_HIDDEN "\033[8m" + +// Macro to combine styles +#define TK_STYLE(...) "\033[" #__VA_ARGS__ "m" \ No newline at end of file diff --git a/extra/thunder/cuda/include/common/util.cuh b/extra/thunder/cuda/include/common/util.cuh new file mode 100644 index 0000000000..d1f06f9cbc --- /dev/null +++ b/extra/thunder/cuda/include/common/util.cuh @@ -0,0 +1,314 @@ +/** + * @file + * @brief General utilities for ThunderKittens. + */ + +#pragma once + +#include +#include +#include +#include + +// CUDA driver API +#define CUCHECK(cmd) do { \ + CUresult err = cmd; \ + if (err != CUDA_SUCCESS) { \ + const char *errStr; \ + cuGetErrorString(err, &errStr); \ + fprintf(stderr, "Failed: CUDA error %s:%d '%s'\n", \ + __FILE__, __LINE__, errStr); \ + exit(EXIT_FAILURE); \ + } \ +} while(0) + +// CUDA runtime API +#define CUDACHECK(cmd) do { \ + cudaError_t err = cmd; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "Failed: CUDA error %s:%d '%s'\n", \ + __FILE__, __LINE__, cudaGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ +} while(0) + +/** + * @namespace kittens + * + * @brief The main namespace of ThunderKittens. + */ +namespace kittens { + +/* ---------- GENERAL CONSTANTS FOR KITTENS ---------- */ + +/** + * @brief Tile dimension constant. + */ +template constexpr int TILE_COL_DIM = sizeof(T) == 1 ? 32 : 16; +template constexpr int TILE_ROW_DIM = 16; +/** + * @brief Tile num elements constant calculated as TILE_DIM squared. + */ +template constexpr int TILE_ELEMENTS{TILE_COL_DIM*TILE_ROW_DIM}; +/** + * @brief Constant representing number of threads in a warp. 
+ */ +constexpr int WARP_THREADS{32}; +/** + * @brief Constant representing number of threads in a warpgroup of four warps. + */ +constexpr int WARPGROUP_THREADS{128}; +/** + + * @brief Constant representing number of warps in a warpgroup of four warps. + */ +constexpr int WARPGROUP_WARPS{4}; +/** + + * @brief Get the warp ID of the current thread. + * @return The warp ID. + */ +__device__ static __forceinline__ int warpid() { + // uint32_t wid; + // asm volatile("mov.u32 %0, %warpid;" : "=r"(wid)); + // return wid; + return threadIdx.x >> 5; +} +/** + * @brief Get the warpgroup ID of the current thread. + * @return The warpgroup ID. + */ +__device__ static __forceinline__ int warpgroupid() { return warpid() >> 2; } +/** + * @brief Get the lane ID of the current thread within its warp. + * @return The lane ID. + */ +__device__ static __forceinline__ int laneid() { + // uint32_t lid; + // asm volatile("mov.u32 %0, %laneid;" : "=r"(lid)); + // return lid; + return threadIdx.x & 31; +} + +#if defined(KITTENS_HOPPER) +constexpr int MAX_SHARED_MEMORY = 227000; +#elif defined(KITTENS_A100) +constexpr int MAX_SHARED_MEMORY = 164000; +#elif defined(KITTENS_4090) +constexpr int MAX_SHARED_MEMORY = 100000; +#endif + +struct transpose { + static constexpr int N = 0; // not transposed + static constexpr int T = 1; // transposed +}; +struct axis { + static constexpr int ROW = 0; // row axis of a tile + static constexpr int COL = 1; // column axis of a tile +}; + +/* ---------- TYPE HELPERS ---------- */ + +/** + * @namespace ducks + * + * @brief ThunderKittens' namespace for template metaprogramming.. + * + * This includes primarily dummy types and concept wrappers, along + * with a few additional utilities. + */ +namespace ducks { + +/** + * @brief A type representing an empty default for a template. + */ +struct default_type {}; + +// This macro can't be done as a template, so it doesn't really have a location in kittens. +#define typeof(A) typename std::remove_const::type>::type + +} + +/* ---------- SHUFFLE UTILS ---------- */ + +/** + * @brief Mask constant for all active threads in a warp. + */ +static constexpr uint32_t MASK_ALL = 0xFFFFFFFF; + +/** + * @brief Perform a shuffle down operation on a packed type synchronously across a warp. + * @tparam T The type of the value to be shuffled. + * @param mask[in] The mask of active threads. + * @param f[in] The value to be shuffled. + * @param delta[in] The number of positions to shuffle down. + * @return The result of the shuffle operation. + */ +template +__device__ static inline T packed_shfl_down_sync(uint32_t mask, const T &f, int delta) { + return __shfl_down_sync(mask, f, delta); +} +template<> +__device__ inline float2 packed_shfl_down_sync(uint32_t mask, const float2 &f, int delta) { + float2 r; + r.x = __shfl_down_sync(mask, f.x, delta); + r.y = __shfl_down_sync(mask, f.y, delta); + return r; +} +/** + * @brief Perform a packed shuffle operation synchronously across a warp. + * @tparam T The type of the value to be shuffled. + * @param mask[in] The mask of active threads. + * @param f[in] The value to be shuffled. + * @param src[in] The source lane from which to shuffle. + * @return The result of the shuffle operation. 
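 *
 * A minimal usage sketch, assuming a warp-wide float2 sum is wanted (both packed shuffle
 * helpers used here are the ones defined in this header):
 * @code{.cpp}
 * float2 v = make_float2(1.0f, 2.0f);
 * for (int offset = 16; offset > 0; offset >>= 1) {          // tree reduction within the warp
 *     float2 other = kittens::packed_shfl_down_sync(kittens::MASK_ALL, v, offset);
 *     v.x += other.x; v.y += other.y;
 * }
 * v = kittens::packed_shfl_sync(kittens::MASK_ALL, v, 0);    // broadcast lane 0's packed sums
 * @endcode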
+ */ +template +__device__ static inline T packed_shfl_sync(uint32_t mask, const T &f, int src) { + return __shfl_sync(mask, f, src); +} +template<> +__device__ inline float2 packed_shfl_sync(uint32_t mask, const float2 &f, int src) { + float2 r; + r.x = __shfl_sync(mask, f.x, src); + r.y = __shfl_sync(mask, f.y, src); + return r; +} + +/* ---------- SHARED MEMORY UTILS ---------- */ + +// namespace ducks { +// namespace sb { +// struct identifier {}; +// } +// } + +// template +// struct sb { +// using identifier = ducks::sb::identifier; +// Args... args; +// }; + +// namespace ducks { +// namespace sb { +// template concept all = requires { +// typename T::identifier; +// } && std::is_same_v; +// } +// } + +// Joyously stolen from https://github.com/NVIDIA/cutlass/blob/5c447dd84f8ae0e1d48ff9a2eae26ce8c4958101/include/cute/container/alignment.hpp#L51 +#if defined(__CUDACC__) +#define KITTENS_ALIGN_AS(n) __align__(n) +#else +#define KITTENS_ALIGN_AS(n) alignas(n) +#endif + +#ifdef KITTENS_HOPPER +#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(128) +#else +#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(16) +#endif + +/** + * @brief Dummy structure for alignment purposes. Needed for WGMMA and TMA calls. + */ +struct KITTENS_DEFAULT_ALIGN alignment_dummy { int dummy; }; +/** + * @brief Very simple allocator for dynamic shared memory. Advances pointer and tracks alignments. + * @tparam default_alignment The default alignment this allocator will enforce. If <=0 (default -1) it will not align. + */ +#ifdef KITTENS_HOPPER +template +#else +template +#endif +struct shared_allocator { + int *ptr; + + private: + // Recursive template to generate N-dimensional array type + template + struct variadic_array; + template + struct variadic_array { + using type = typename variadic_array::type[first_dim]; + }; + template + struct variadic_array { + using type = A; + }; + template + using variadic_array_t = typename variadic_array::type; + + template + __device__ inline void align_ptr() { + if constexpr (alignment > 0) { + uint64_t p = reinterpret_cast(ptr); + if(p % alignment != 0) { + ptr = (int*)(p + (alignment-(p%alignment))); + } + } + } + + public: + /** + * @brief Construct a new shared allocator using a pointer to extern shared memory. + * @param[in] _ptr Pointer to the start of the extern shared memory. + */ + __device__ shared_allocator(int *_ptr): ptr(_ptr) {} + /** + * @brief Allocate shared memory for a single instance or N-dimensional array of type A. + * @tparam A The type of the object to allocate. + * @tparam dims... A list of dimensions for the N-dimensional array. + * @return Reference to the allocated object. + */ + template + __device__ inline variadic_array_t& allocate() { + // static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation"); + align_ptr(); + using at = variadic_array_t; + at*p = reinterpret_cast(ptr); + ptr += sizeof(at)/sizeof(int); + return *p; + } + /** + * @brief Allocate shared memory for a single instance or N-dimensional array of type A. + * @tparam alignment An alignment to enforce for this particular object. + * @tparam A The type of the object to allocate. + * @tparam dims... A list of dimensions for the N-dimensional array. + * @return Reference to the allocated object. 
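 *
 * A minimal usage sketch, assuming dynamic shared memory was requested at launch and that
 * st_bf is the bf16 shared-tile type from the type headers elsewhere in this patch:
 * @code{.cpp}
 * extern __shared__ kittens::alignment_dummy __shm[];
 * kittens::shared_allocator al((int*)&__shm[0]);
 * auto &a_smem  = al.allocate<st_bf<64, 64>>();          // one 64x64 bf16 tile
 * auto &kv_smem = al.allocate<st_bf<64, 64>, 2>();       // array of two tiles
 * auto &q_smem  = al.allocate<1024, st_bf<64, 64>>();    // force 1024-byte alignment (e.g. for TMA)
 * @endcode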
+ */ + template + __device__ inline variadic_array_t& allocate() { + // static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation"); + align_ptr(); + using at = variadic_array_t; + at*p = reinterpret_cast(ptr); + ptr += sizeof(at)/sizeof(int); + return *p; + } +}; +#if (defined(KITTENS_HOPPER) || defined(KITTENS_BLACKWELL)) +/** + * @brief A wrapper for an allocator that enforces sufficient alignment to be used for TMA loads and stores. + */ +using tma_allocator = shared_allocator<1024>; +using tma_swizzle_allocator = tma_allocator; // swizzled TMA modes require up to 1024 byte alignments :/ + +/* Get CTA ID within a cluster */ +__device__ static inline int3 clusterIdx() { + int3 cluster_idx; + asm volatile("mov.u32 %0, %clusterid.x;\n" : "=r"(cluster_idx.x)); + asm volatile("mov.u32 %0, %clusterid.y;\n" : "=r"(cluster_idx.y)); + asm volatile("mov.u32 %0, %clusterid.z;\n" : "=r"(cluster_idx.z)); + return cluster_idx; +} +__device__ static inline int cluster_ctarank() { + uint32_t ctarank; + asm volatile("mov.u32 %0, %cluster_ctarank;\n" : "=r"(ctarank)); + return ctarank; +} +#endif + +} // namespace kittens diff --git a/extra/thunder/cuda/include/kittens.cuh b/extra/thunder/cuda/include/kittens.cuh new file mode 100644 index 0000000000..974a896f1c --- /dev/null +++ b/extra/thunder/cuda/include/kittens.cuh @@ -0,0 +1,12 @@ +/** + * @file + * @brief The master header file of ThunderKittens. This file includes everything you need! + */ + +#pragma once + +#include "common/common.cuh" +#include "types/types.cuh" +#include "ops/ops.cuh" +#include "pyutils/util.cuh" +// #include "pyutils/pyutils.cuh" // for simple binding without including torch \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/device/device.cuh b/extra/thunder/cuda/include/ops/device/device.cuh new file mode 100644 index 0000000000..412791ea45 --- /dev/null +++ b/extra/thunder/cuda/include/ops/device/device.cuh @@ -0,0 +1,51 @@ +/** + * @file + * @brief An aggregate header of all device (multi-GPU) operations defined by ThunderKittens + */ + +#pragma once + +#include "../../types/types.cuh" + +namespace kittens { + +template +struct device { + +static_assert(_NUM_DEVICES >= 0 && _NUM_DEVICES <= 72, "Invalid number of devices"); +static constexpr int NUM_DEVICES = _NUM_DEVICES; + +#ifdef KITTENS_HOPPER + +using barrier_t = pgl, NUM_DEVICES, true>; + +/** + * @brief Multi-GPU synchronization barrier for coordinated kernel exit + * + * Performs a synchronization across all devices to ensure all GPUs complete + * their work before any kernel exits. Does not synchronize intra-node threads + * or threadblocks. + * + * @param barrier Pre-allocated barrier structure, must be initialized to 0 + * @param dev_idx Current device index (0 to NUM_DEVICES - 1) + * @param id Synchronization point identifier (default: 0). 
0 is fine for most cases + * + */ +__device__ static inline void sync_on_exit(const barrier_t &barrier, const int dev_idx, const int id = 0) { + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) { + cuda::atomic_ref barrier_uc(barrier[dev_idx][{id}]); + + // Inter-note check-in + multimem::red(barrier.mc_ptr_at({id}), 1); + asm volatile ("{fence.proxy.alias;}" ::: "memory"); + while (barrier_uc.load(cuda::memory_order_acquire) < NUM_DEVICES); + barrier_uc.fetch_sub(NUM_DEVICES, cuda::memory_order_release); + } +} + +#endif + +}; + +} // namespace kittens diff --git a/extra/thunder/cuda/include/ops/group/group.cuh b/extra/thunder/cuda/include/ops/group/group.cuh new file mode 100644 index 0000000000..1a9d69971c --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/group.cuh @@ -0,0 +1,96 @@ +/** + * @file + * @brief An aggregate header of all group (multi-warp) operations defined by ThunderKittens + */ + +#pragma once + +#include + +#include "../../common/common.cuh" +#include "../../types/types.cuh" +#include "../thread/thread.cuh" // several group memory ops rely on underlying warp-scope ops + +#define KITTENS_CHECK_WARP static_assert(GROUP_WARPS==1, "Warp (GROUP_WARPS=1) function called from a non-warp group."); +// A "warpgroup" is a special group of 4 consecutive warps defined by NVIDIA for certain SM_90+ operations. +#define KITTENS_CHECK_WARPGROUP static_assert(GROUP_WARPS==4, "Warpgroup (GROUP_WARPS=4) function called from a non-warpgroup group."); + +// WGMMA relies on some template structures that cannot be specialized within the group struct, so we declare them in advance. +#ifdef KITTENS_HOPPER +#include "mma/warpgroup/base/base.cuh" +#endif + +namespace kittens { +/* +This is meant to be used with a `using group_N = kittens::group;` at the start of every kernel. +*/ +template +struct group { +static constexpr int GROUP_WARPS = _GROUP_WARPS; // This alias produces nice parallelism. +static constexpr int GROUP_THREADS = GROUP_WARPS * kittens::WARP_THREADS; // This alias produces nice parallelism. 
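// Worked example of the group-local indexing defined just below (illustrative): with
// GROUP_WARPS == 4 (GROUP_THREADS == 128), the thread with threadIdx.x == 197 gets
// laneid() == 197 % 128 == 69, warpid() == 69 / 32 == 2, and groupid() == 197 / 128 == 1,
// i.e. it is group-local lane 69, warp 2 of the second warpgroup in the block.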
+__device__ static inline int laneid() { return threadIdx.x % GROUP_THREADS; } +__device__ static inline int warpid() { return laneid() / kittens::WARP_THREADS; } +__device__ static inline int groupid() { return threadIdx.x / GROUP_THREADS; } + +__device__ static inline void sync(int id) { + asm volatile("bar.sync %0, %1;\n" :: "r"(id), "n"(GROUP_THREADS)); +} +template __device__ static inline void sync() { + static_assert(GROUP_WARPS==1, "barrier-less sync() can only be called by a single warp!"); + asm volatile("bar.warp.sync %0;\n" :: "n"(MASK)); +} +__device__ static inline void arrive(int id) { + asm volatile("bar.arrive %0, %1;\n" :: "r"(id), "n"(GROUP_THREADS)); +} + +#include "memory/memory.cuh" +#include "shared/shared.cuh" +#include "register/register.cuh" + +#ifdef KITTENS_HOPPER +#include "mma/mma.cuh" + +template __device__ static inline void increase_registers() { + static_assert(n_reg % 8 == 0, "n_reg must be a multiple of 8"); + asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" :: "n"(n_reg)); +} +template __device__ static inline void decrease_registers() { + static_assert(n_reg % 8 == 0, "n_reg must be a multiple of 8"); + asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" :: "n"(n_reg)); +} +__device__ static inline void producer_registers() { decrease_registers<24>(); } +template __device__ static inline void consumer_registers() { increase_registers<480/NCWG - 8*(NCWG>3) - 224*(NCWG==1)>(); } + +#endif + +}; + +namespace everyone { + +// Block-level synchronization +__device__ static inline void sync(int id) { + asm volatile("bar.sync %0;\n" :: "r"(id)); +} + +// Cluster-level synchronization functions +namespace tma { +namespace cluster { +__device__ static inline void arrive_aligned() { // All threads in the cluster must call this + asm volatile ("barrier.cluster.arrive.release.aligned;\n"); +} +__device__ static inline void wait_aligned() { + asm volatile ("barrier.cluster.wait.acquire.aligned;\n"); +} +__device__ static inline void sync() { + arrive_aligned(); + wait_aligned(); +} +} +} + +}; + +using warp = group<1>; // scope used by most pre-Hopper GPUs, and also for most register operations. +using warpgroup = group<4>; // special scope commonly used by Hopper and later. + +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/memory.cuh b/extra/thunder/cuda/include/ops/group/memory/memory.cuh new file mode 100644 index 0000000000..2607a11327 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/memory.cuh @@ -0,0 +1,21 @@ +/** + * @file + * @brief An aggregate header of colaborative group memory movement operations + */ + +#include "util/util.cuh" +#include "tile/tile.cuh" +#include "vec/vec.cuh" + +#ifdef KITTENS_HOPPER +struct tma { +#include "util/tma.cuh" +#include "tile/tma.cuh" +#include "vec/tma.cuh" +struct cluster { +#include "util/tma_cluster.cuh" +#include "tile/tma_cluster.cuh" +#include "vec/tma_cluster.cuh" +}; +}; +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_register.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_register.cuh new file mode 100644 index 0000000000..fb35caa1b2 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_register.cuh @@ -0,0 +1,42 @@ +/** + * @file + * @brief Functions for a group to collaboratively transfer data directly between global memory and registers and back. 
+ */ + +/** + * @brief Collaboratively loads data from a source array into register tiles. + * + * @tparam RT The register tile type. + * @tparam U The data type of the source array. + * @param dst[out] The destination tile to load data into. + * @param src[in] The source array to load data from. + * @param row_stride[in] The stride in elements between rows in the source array. + */ +template>> +__device__ inline static void load(CRT &dst, const CGL &src, const COORD &idx) { + load(dst.real, src.real, idx); + load(dst.imag, src.imag, idx); +} +template>> +__device__ inline static void load(CRT &dst, const CGL &src, const COORD &idx) { + load<2, CRT, CGL>(dst, src, idx); +} + +/** + * @brief Collaboratively stores data from register tiles to a destination array in global memory. + * + * @tparam RT The register tile type. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register tile to store data from. + * @param row_stride[in] The stride in elements between rows in the destination array. + */ +template>> +__device__ inline static void store(CGL &dst, const CRT &src, const COORD &idx) { + store(dst.real, src.real, idx); + store(dst.imag, src.imag, idx); +} +template>> +__device__ inline static void store(CGL &dst, const CRT &src, const COORD &idx) { + store<2, CRT, CGL>(dst, src, idx); +} diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_shared.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_shared.cuh new file mode 100644 index 0000000000..789bea84d2 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_global_to_shared.cuh @@ -0,0 +1,37 @@ +/** + * @file + * @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory. 
+ */ + +template> +__device__ static inline void load(CST &dst, const CGL &src, const COORD &idx) { + load(dst.real, src.real, idx); + load(dst.imag, src.imag, idx); +} +template> +__device__ static inline void load(CST &dst, const CGL &src, const COORD &idx) { + load<2, false, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx); + load<2, false, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx); +} + +template> +__device__ static inline void store(CGL &dst, const CST &src, const COORD &idx) { + store(dst.real, src.real, idx); + store(dst.imag, src.imag, idx); +} +template> +__device__ static inline void store(CGL &dst, const CST &src, const COORD &idx) { + store<2, false, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx); + store<2, false, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx); +} + +template> +__device__ static inline void load_async(CST &dst, const CGL &src, const COORD &idx) { + load_async(dst.real, src.real, idx); + load_async(dst.imag, src.imag, idx); +} +template> +__device__ static inline void load_async(CST &dst, const CGL &src, const COORD &idx) { + load_async<2, false, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx); + load_async<2, false, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_shared_to_register.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_shared_to_register.cuh new file mode 100644 index 0000000000..85b2d0437e --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/complex/complex_shared_to_register.cuh @@ -0,0 +1,34 @@ +/** + * @file + * @brief Functions for a warpgroup to collaboratively transfer data directly between shared memory and registers and back. + */ + +/** + * @brief Collaboratively load data from a shared tile into register tiles split across a warpgroup. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination register tile. + * @param src[in] The source shared tile. + */ +template +__device__ inline static void load(RT &dst, const ST &src) { + load(dst.real, src.real); + load(dst.imag, src.imag); +} + + +/** + * @brief Collaboratively store data into a shared tile from register tiles split across a warpgroup. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination shared tile. + * @param src[in] The source register tile. + */ +template +__device__ inline static void store(ST &dst, const RT &src) { + store(dst.real, src.real); + store(dst.imag, src.imag); +} + diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/global_to_register.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/global_to_register.cuh new file mode 100644 index 0000000000..e22570d116 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/global_to_register.cuh @@ -0,0 +1,207 @@ +/** + * @file + * @brief Functions for a group to collaboratively transfer data directly between global memory and registers and back. + */ + +/** + * @brief Collaboratively loads data from a source array into row-major layout tiles. + * + * @tparam RT The row-major layout tile type. + * @tparam U The data type of the source array. + * @param dst[out] The destination tile to load data into. 
+ * @param src[in] The source array to load data from. + * @param row_stride[in] The stride in elements between rows in the source array. + */ +template>> +__device__ inline static void load(RT &dst, const GL &src, const COORD &idx) { + using T2 = RT::dtype; + using U = typename GL::dtype; + + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for load/store"); + #endif + + U *src_ptr = (U*)&src[(idx.template unit_coord())]; + const int row_stride = src.template stride(); + using U2 = base_types::packing::packed_type; + int warp_laneid = threadIdx.x % WARP_THREADS; + int local_warpid; + if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4)); + else local_warpid = warpid(); + const int row_offset = dst.rows*local_warpid; + #pragma unroll + for(int i = 0; i < dst.height; i++) { + int row = row_offset + i*dst.tile_size_row + (warp_laneid / 4); + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size_col + 2*(warp_laneid % 4); + dst.tiles[i][j].data[0] = base_types::convertor::convert(*(U2*)(&src_ptr[(row+0)*row_stride + (col+0)])); + dst.tiles[i][j].data[2] = base_types::convertor::convert(*(U2*)(&src_ptr[(row+0)*row_stride + (col+8)])); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size_col + 2*(warp_laneid % 4); + dst.tiles[i][j].data[1] = base_types::convertor::convert(*(U2*)(&src_ptr[(row+8)*row_stride + (col+0)])); + dst.tiles[i][j].data[3] = base_types::convertor::convert(*(U2*)(&src_ptr[(row+8)*row_stride + (col+8)])); + } + } +} +/** + * @brief Collaboratively loads data from a source array into column-major layout tiles. + * + * @tparam RT The column-major layout tile type. + * @tparam U The data type of the source array. + * @param dst[out] The destination tile to load data into. + * @param src[in] The source array to load data from. + * @param row_stride[in] The stride in elements between rows in the source array. 
+ */ +template>> +__device__ inline static void load(RT &dst, const GL &src, const COORD &idx) { + using T = typename RT::T; + using U = typename GL::dtype; + + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for load/store"); + #endif + + U *src_ptr = (U*)&src[(idx.template unit_coord())]; + const int row_stride = src.template stride(); + int warp_laneid = threadIdx.x % WARP_THREADS; + int local_warpid; + if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4)); + else local_warpid = warpid(); + const int row_offset = dst.rows*local_warpid; + #pragma unroll + for(int i = 0; i < dst.height; i++) { + int row = row_offset + i*dst.tile_size_row + 2*(warp_laneid % 4); + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size_col + (warp_laneid / 4); + dst.tiles[i][j].data[0].x = base_types::convertor::convert(src_ptr[(row+0)*row_stride + (col+0)]); + dst.tiles[i][j].data[1].x = base_types::convertor::convert(src_ptr[(row+0)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size_col + (warp_laneid / 4); + dst.tiles[i][j].data[0].y = base_types::convertor::convert(src_ptr[(row+1)*row_stride + (col+0)]); + dst.tiles[i][j].data[1].y = base_types::convertor::convert(src_ptr[(row+1)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size_col + (warp_laneid / 4); + dst.tiles[i][j].data[2].x = base_types::convertor::convert(src_ptr[(row+8)*row_stride + (col+0)]); + dst.tiles[i][j].data[3].x = base_types::convertor::convert(src_ptr[(row+8)*row_stride + (col+8)]); + } + #pragma unroll + for(int j = 0; j < dst.width; j++) { + int col = j*dst.tile_size_col + (warp_laneid / 4); + dst.tiles[i][j].data[2].y = base_types::convertor::convert(src_ptr[(row+9)*row_stride + (col+0)]); + dst.tiles[i][j].data[3].y = base_types::convertor::convert(src_ptr[(row+9)*row_stride + (col+8)]); + } + } +} +template>> +__device__ inline static void load(RT &dst, const GL &src, const COORD &idx) { + load<2>(dst, src, idx); +} +/** + * @brief Collaboratively stores data from register tiles to a destination array in global memory with a row-major layout. + * + * @tparam RT The register tile type with a row-major layout. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register tile to store data from. + * @param row_stride[in] The stride in elements between rows in the destination array. 
+ */ +template>> +__device__ inline static void store(const GL &dst, const RT &src, const COORD &idx) { + using T2 = RT::dtype; + using U = typename GL::dtype; + + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for load/store"); + #endif + + U *dst_ptr = (U*)&dst[(idx.template unit_coord())]; + const int row_stride = dst.template stride(); + using U2 = base_types::packing::packed_type; + int warp_laneid = threadIdx.x % WARP_THREADS; + int local_warpid; + if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4)); + else local_warpid = warpid(); + const int row_offset = src.rows*local_warpid; + #pragma unroll + for(int i = 0; i < src.height; i++) { + int row = row_offset + i*src.tile_size_row + (warp_laneid / 4); + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size_col + 2*(warp_laneid % 4); + *(U2*)(&dst_ptr[(row+0)*row_stride + (col+0)]) = base_types::convertor::convert(src.tiles[i][j].data[0]); + *(U2*)(&dst_ptr[(row+0)*row_stride + (col+8)]) = base_types::convertor::convert(src.tiles[i][j].data[2]); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size_col + 2*(warp_laneid % 4); + *(U2*)(&dst_ptr[(row+8)*row_stride + (col+0)]) = base_types::convertor::convert(src.tiles[i][j].data[1]); + *(U2*)(&dst_ptr[(row+8)*row_stride + (col+8)]) = base_types::convertor::convert(src.tiles[i][j].data[3]); + } + } +} +/** + * @brief Collaboratively stores data from register tiles to a destination array in global memory with a column-major layout. + * + * @tparam RT The register tile type with a column-major layout. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register tile to store data from. + * @param row_stride[in] The stride in elements between rows in the destination array. 
+ */ +template>> +__device__ inline static void store(const GL &dst, const RT &src, const COORD &idx) { + using T = base_types::packing::unpacked_type; + using U = typename GL::dtype; + + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for load/store"); + #endif + + U *dst_ptr = (U*)&dst[(idx.template unit_coord())]; + const int row_stride = dst.template stride(); + int warp_laneid = threadIdx.x % WARP_THREADS; + int local_warpid; + if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4)); + else local_warpid = warpid(); + const int row_offset = src.rows*local_warpid; + #pragma unroll + for(int i = 0; i < src.height; i++) { + int row = row_offset + i*src.tile_size_row + 2*(warp_laneid % 4); + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size_col + (warp_laneid / 4); + dst_ptr[(row+0)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[0].x); + dst_ptr[(row+0)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[1].x); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size_col + (warp_laneid / 4); + dst_ptr[(row+1)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[0].y); + dst_ptr[(row+1)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[1].y); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size_col + (warp_laneid / 4); + dst_ptr[(row+8)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[2].x); + dst_ptr[(row+8)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[3].x); + } + #pragma unroll + for(int j = 0; j < src.width; j++) { + int col = j*src.tile_size_col + (warp_laneid / 4); + dst_ptr[(row+9)*row_stride + (col+0)] = base_types::convertor::convert(src.tiles[i][j].data[2].y); + dst_ptr[(row+9)*row_stride + (col+8)] = base_types::convertor::convert(src.tiles[i][j].data[3].y); + } + } +} +template>> +__device__ inline static void store(const GL &dst, const RT &src, const COORD &idx) { + store<2>(dst, src, idx); +} diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/global_to_shared.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/global_to_shared.cuh new file mode 100644 index 0000000000..831f2298bf --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/global_to_shared.cuh @@ -0,0 +1,168 @@ +/** + * @file + * @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory. + */ + + +/** + * @brief Loads data from global memory into a shared memory tile. + * + * @tparam ST The type of the shared tile. + * @param[out] dst The destination shared memory tile. + * @param[in] src The source global memory array. + * @param[in] idx The coordinate of the tile in the global memory array. 
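 *
 * A minimal usage sketch, assuming st_bf and a matching kittens::gl global layout from the
 * type headers elsewhere in this patch (a_gl is an assumed global-layout handle):
 * @code{.cpp}
 * using loaders = kittens::group<4>;                   // four warps cooperate on the copy
 * __shared__ st_bf<64, 64> a_smem;
 * loaders::load(a_smem, a_gl, {blockIdx.x, 0, 0, 0});  // tile coordinate {b, d, r, c}
 * loaders::sync(0);                                    // named barrier before any warp reads the tile
 * @endcode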
+ */ +template> +__device__ static inline void load(ST &dst, const GL &src, const COORD &idx) { + using T = typename ST::dtype; + const int row_stride = src.template stride(); + // we can handle this many rows each time we run a memcpy_async + constexpr int elem_per_memcpy = sizeof(float4)/sizeof(typename ST::dtype); + constexpr int memcpy_per_row = dst.cols / elem_per_memcpy; + constexpr int total_calls = (dst.height*dst.width * kittens::TILE_ROW_DIM*kittens::TILE_COL_DIM + GROUP_THREADS*elem_per_memcpy-1) / (GROUP_THREADS*elem_per_memcpy); // round up + constexpr int total_rows = dst.height*dst.width; + + coord<> unit_coord = idx.template unit_coord(); + typename GL::dtype *src_ptr = (typename GL::dtype*)&src[unit_coord]; + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst.data[0])); + int laneid = threadIdx.x % GROUP_THREADS; + + #pragma unroll + for(int i = 0; i < total_calls; i++) { + + int load_idx = i * GROUP_THREADS + laneid; + + int row = load_idx / memcpy_per_row; + int col = (load_idx*elem_per_memcpy) % dst.cols; + + if constexpr (assume_aligned) { + float4 tmp; + move::ldg(tmp, (float4*)&src_ptr[row*row_stride + col]); + move::sts(dst.idx(dst_ptr, {row, col}), tmp); + } + else { + if (row + unit_coord.template dim() < src.template shape()) { + float4 tmp; + move::ldg(tmp, (float4*)&src_ptr[row*row_stride + col]); + move::sts(dst.idx(dst_ptr, {row, col}), tmp); + } + else { + float4 zeros = {0.f,0.f,0.f,0.f}; + move::sts(dst.idx(dst_ptr, {row, col}), zeros); // use the default value + } + } + } +} +template> +__device__ static inline void load(ST &dst, const GL &src, const COORD &idx) { + load<2, false, ST, GL, COORD>(dst, src, idx); +} + +/** + * @brief Stores data from a shared memory tile into global memory. + * + * @tparam ST The type of the shared tile. + * @param[out] dst The destination global memory array. + * @param[in] src The source shared memory tile. + * @param row_stride[in] The stride between rows in the destination array. 
+ */ +template> +__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) { + using T = typename ST::dtype; + const int row_stride = dst.template stride(); + // we can handle this many rows each time we run a memcpy_async + constexpr int elem_per_memcpy = sizeof(float4)/sizeof(typename ST::dtype); + constexpr int memcpy_per_row = src.cols / elem_per_memcpy; + constexpr int total_calls = (src.height*src.width * kittens::TILE_ROW_DIM*kittens::TILE_COL_DIM + GROUP_THREADS*elem_per_memcpy-1) / (GROUP_THREADS*elem_per_memcpy); // round up + + coord<> unit_coord = idx.template unit_coord(); + typename GL::dtype *dst_ptr = (typename GL::dtype*)&dst[unit_coord]; + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src.data[0])); + int laneid = threadIdx.x % GROUP_THREADS; + + #pragma unroll + for(int i = 0; i < total_calls; i++) { + + int load_idx = i * GROUP_THREADS + laneid; + + int row = load_idx / memcpy_per_row; + int col = (load_idx*elem_per_memcpy) % src.cols; + + if constexpr (assume_aligned) { + float4 tmp; + move::lds(tmp, src.idx(src_ptr, {row, col})); + move::stg((float4*)&dst_ptr[row*row_stride + col], tmp); + } + else { + if (row + unit_coord.template dim() < dst.template shape()) { + float4 tmp; + move::lds(tmp, src.idx(src_ptr, {row, col})); + move::stg((float4*)&dst_ptr[row*row_stride + col], tmp); + } + } + } +} +template> +__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) { + store<2, false, ST, GL, COORD>(dst, src, idx); +} + +/** + * @brief Asynchronously loads data from global memory into a shared memory tile. + * + * @tparam ST The type of the shared tile. + * @param[out] dst The destination shared memory tile. + * @param[in] src The source global memory array. + * + * @note This function expects 16-byte alignments. Otherwise, behavior is undefined. 
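 *
 * A minimal usage sketch, assuming the same tile and layout types as the synchronous load
 * above; the wait is written as raw PTX because this header only issues and commits the copy:
 * @code{.cpp}
 * using loaders = kittens::group<4>;
 * loaders::load_async(a_smem, a_gl, {blockIdx.x, 0, 0, 0}); // issue cp.async and commit a group
 * asm volatile("cp.async.wait_group 0;\n" ::: "memory");    // wait for this thread's copies to land
 * loaders::sync(0);                                         // then make the tile visible group-wide
 * @endcode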
+ */ +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx) { + using T = typename ST::dtype; + const int row_stride = src.template stride(); + // we can handle this many rows each time we run a memcpy_async + constexpr int elem_per_memcpy = sizeof(float4)/sizeof(typename ST::dtype); + constexpr int memcpy_per_row = dst.cols / elem_per_memcpy; + constexpr int total_calls = (dst.height*dst.width * kittens::TILE_ROW_DIM*kittens::TILE_COL_DIM + GROUP_THREADS*elem_per_memcpy-1) / (GROUP_THREADS*elem_per_memcpy); // round up + + coord<> unit_coord = idx.template unit_coord(); + typename GL::dtype *src_ptr = (typename GL::dtype*)&src[unit_coord]; + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst.data[0])); + int laneid = threadIdx.x % GROUP_THREADS; + + #pragma unroll + for(int i = 0; i < total_calls; i++) { + + int load_idx = i * GROUP_THREADS + laneid; + + int row = load_idx / memcpy_per_row; + int col = (load_idx*elem_per_memcpy) % dst.cols; + + if constexpr (assume_aligned) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\n" + :: "r"(dst.idx(dst_ptr, {row, col})), "l"(&src_ptr[row*row_stride + col]) + : "memory" + ); + } + else { + if (row + unit_coord.template dim() < src.template shape()) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\n" + :: "r"(dst.idx(dst_ptr, {row, col})), "l"(&src_ptr[row*row_stride + col]) + : "memory" + ); + } + else { + // printf("thread %d skipping async load on row %d, col %d\n", threadIdx.x, row + unit_coord.template dim(), col); + float4 zeros = {0.f,0.f,0.f,0.f}; + move::sts(dst.idx(dst_ptr, {row, col}), zeros); // use the default value + } + } + } + asm volatile("cp.async.commit_group;\n" ::: "memory"); +} +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx) { + load_async<2, false, ST, GL, COORD>(dst, src, idx); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/shared_to_register.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/shared_to_register.cuh new file mode 100644 index 0000000000..06f5b2e076 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/shared_to_register.cuh @@ -0,0 +1,323 @@ +/** + * @file + * @brief Functions for a warpgroup to collaboratively transfer data directly between shared memory and registers and back. + */ + +/** + * @brief Collaboratively load data from a shared tile into register tiles split across a warpgroup. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination register tile. + * @param src[in] The source shared tile. 
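 *
 * A minimal usage sketch, assuming a 4-warp group and the rt_bf / st_bf tile types from the
 * type headers elsewhere in this patch. The shared tile's rows are split evenly across the
 * group, so each warp receives a 16-row slice of a 64-row tile:
 * @code{.cpp}
 * __shared__ st_bf<64, 64> a_smem;
 * kittens::rt_bf<16, 64> a_reg;       // per-warp slice: 64 rows / GROUP_WARPS
 * kittens::warpgroup::load(a_reg, a_smem);
 * @endcode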
+ */ +template +__device__ inline static void load(RT &dst, const ST &src) { + constexpr int height = ST::height; + constexpr int warp_height = RT::height; + static_assert(height%GROUP_WARPS == 0, "Group load / store requires tile height to be a multiple of GROUP_WARPS."); + static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height."); + static_assert(ST::width==RT::width, "Group load / store requires tile widths to match."); + int local_warpid; + if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4)); + else local_warpid = warpid(); + using T2 = RT::dtype; + using U = ST::dtype; + using T = base_types::packing::unpacked_type; + using U2 = base_types::packing::packed_type; + int warp_laneid = ::kittens::laneid(); + + // convert to shared state space + uint32_t shared_addr = static_cast(__cvta_generic_to_shared(&src.data[0])); + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + if constexpr (sizeof(typename ST::dtype) == 2) { + // handle the row-major layout for 16-bit types + U2 tmp[4]; + int row = (local_warpid*warp_height + i)*dst.tile_size_row + (warp_laneid % 16); + int col = j*dst.tile_size_col + (warp_laneid / 16) * 8; + if constexpr (std::is_same_v) { + move::ldsm4(tmp[0], tmp[1], tmp[2], tmp[3], src.idx(shared_addr, {row, col})); + } + else { + move::ldsm4t(tmp[0], tmp[2], tmp[1], tmp[3], src.idx(shared_addr, {row, col})); + } + dst.tiles[i][j].data[0] = base_types::convertor::convert(tmp[0]); + dst.tiles[i][j].data[1] = base_types::convertor::convert(tmp[1]); + dst.tiles[i][j].data[2] = base_types::convertor::convert(tmp[2]); + dst.tiles[i][j].data[3] = base_types::convertor::convert(tmp[3]); + } + else if constexpr (std::is_same_v && sizeof(typename ST::dtype) == 1) { + // handle the row-major layout for 8-bit types + int warp_group_16 = (warp_laneid / 16); // divide each warp into two groups of 16 threads + int lane_in_16 = warp_laneid % 16; // position in group of 16 threads + int row = (local_warpid*warp_height + i)*dst.tile_size_row + (lane_in_16 % 16); // find base row for warp in warpgroup and then distribute the 16 threads in the warp across the rows + int col = j*dst.tile_size_col + warp_group_16 * 16; // find base column and then *16 for second half of the warp + + U2 tmp[4]; + if constexpr (std::is_same_v) { + move::ldsm4(tmp[0], tmp[1], tmp[2], tmp[3], src.idx(shared_addr, {row, col})); + } + else { + move::ldsm4t(tmp[0], tmp[2], tmp[1], tmp[3], src.idx(shared_addr, {row, col})); + } + dst.tiles[i][j].data[0] = base_types::convertor::convert(tmp[0]); + dst.tiles[i][j].data[1] = base_types::convertor::convert(tmp[1]); + dst.tiles[i][j].data[2] = base_types::convertor::convert(tmp[2]); + dst.tiles[i][j].data[3] = base_types::convertor::convert(tmp[3]); + } + else if constexpr (std::is_same_v && sizeof(typename ST::dtype) == 4) { + // handle the row-major layout for 32-bit types + int row = (local_warpid*warp_height + i)*dst.tile_size_row + (warp_laneid / 4); + int col = j*dst.tile_size_col + 2*(warp_laneid % 4); + if constexpr (ST::rows != ST::underlying_rows || ST::cols != ST::underlying_cols) { // subtile case + row += src.row_offset; + col += src.col_offset; + } + int blit = sizeof(typename ST::dtype) * ((warp_laneid%4) / 2); + U2 tmp[4]; + static constexpr int swizzle_repeat = ST::swizzle_bytes * 8; + static constexpr int subtile_cols = ST::swizzle_bytes / sizeof(U); + const int outer_idx = col/subtile_cols; + 
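// The two addresses below locate this lane's 8-byte element pairs in rows `row` and `row+8`
// of the swizzled layout: the tile is stored as column panels of swizzle_bytes each
// (subtile_cols elements wide), and within a panel the byte address is XOR-swizzled by
// folding bits 7 and up of (addr % swizzle_repeat) into bits 4..6, so consecutive 16-byte
// chunks of a row land in different shared-memory banks. `blit` (set for lanes with
// laneid % 4 >= 2) additionally flips bit 2 of the address, swapping the .x/.y halves of
// each pair; that swap is undone after the loads.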
const uint32_t addr_1 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+0)*subtile_cols + col%subtile_cols); + const uint32_t addr_2 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+8)*subtile_cols + col%subtile_cols); + const int swizzle_1 = blit ^ ((addr_1 % swizzle_repeat) >> 7) << 4; + const int swizzle_2 = blit ^ ((addr_2 % swizzle_repeat) >> 7) << 4; + move::lds(tmp[0].x, (addr_1+ 0)^swizzle_1); + move::lds(tmp[0].y, (addr_1+ 4)^swizzle_1); + move::lds(tmp[2].x, (addr_1+32)^swizzle_1); + move::lds(tmp[2].y, (addr_1+36)^swizzle_1); + move::lds(tmp[1].x, (addr_2+ 0)^swizzle_2); + move::lds(tmp[1].y, (addr_2+ 4)^swizzle_2); + move::lds(tmp[3].x, (addr_2+32)^swizzle_2); + move::lds(tmp[3].y, (addr_2+36)^swizzle_2); + dst.tiles[i][j].data[0] = base_types::convertor::convert(tmp[0]); + dst.tiles[i][j].data[1] = base_types::convertor::convert(tmp[1]); + dst.tiles[i][j].data[2] = base_types::convertor::convert(tmp[2]); + dst.tiles[i][j].data[3] = base_types::convertor::convert(tmp[3]); + if(blit) { + #pragma unroll + for(int k = 0; k < 4; k++) { + dst.tiles[i][j].data[k] = T2{dst.tiles[i][j].data[k].y, dst.tiles[i][j].data[k].x}; + } + } + } + else { + // handle the column-major layout + int row = (local_warpid*warp_height + i)*dst.tile_size_row + 2*(warp_laneid % 4); + int col = j*dst.tile_size_col + (warp_laneid / 4); + U2 tmp[4]; + move::lds(tmp[0].x, src.idx(shared_addr, {row+0, col+0})); + move::lds(tmp[0].y, src.idx(shared_addr, {row+1, col+0})); + move::lds(tmp[1].x, src.idx(shared_addr, {row+0, col+8})); + move::lds(tmp[1].y, src.idx(shared_addr, {row+1, col+8})); + move::lds(tmp[2].x, src.idx(shared_addr, {row+8, col+0})); + move::lds(tmp[2].y, src.idx(shared_addr, {row+9, col+0})); + move::lds(tmp[3].x, src.idx(shared_addr, {row+8, col+8})); + move::lds(tmp[3].y, src.idx(shared_addr, {row+9, col+8})); + dst.tiles[i][j].data[0] = base_types::convertor::convert(tmp[0]); + dst.tiles[i][j].data[1] = base_types::convertor::convert(tmp[1]); + dst.tiles[i][j].data[2] = base_types::convertor::convert(tmp[2]); + dst.tiles[i][j].data[3] = base_types::convertor::convert(tmp[3]); + } + } + } +} + + +/** + * @brief Collaboratively store data into a shared tile from register tiles split across a warpgroup. + * + * @tparam RT The register tile type + * @tparam ST The shared tile type + * @param dst[out] The destination shared tile. + * @param src[in] The source register tile. 
+ */ +template +__device__ inline static void store(ST &dst, const RT &src) { + constexpr int height = ST::height; + constexpr int warp_height = RT::height; + static_assert(height%GROUP_WARPS == 0, "Group load / store requires tile height to be a multiple of GROUP_WARPS."); + static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height."); + static_assert(ST::width==RT::width, "Group load / store requires tile widths to match."); + int local_warpid; + if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4)); + else local_warpid = warpid(); + using T2 = RT::dtype; + using U = ST::dtype; + using T = base_types::packing::unpacked_type; + using U2 = base_types::packing::packed_type; + int warp_laneid = ::kittens::laneid(); + + // convert to shared state space + uint32_t shared_addr = static_cast(__cvta_generic_to_shared(&dst.data[0])); + + #pragma unroll + for(int i = 0; i < warp_height; i++) { + #pragma unroll + for(int j = 0; j < src.width; j++) { + if constexpr (sizeof(typename ST::dtype) == 2) { + // handle the row-major layout + U2 tmp[4]; + tmp[0] = base_types::convertor::convert(src.tiles[i][j].data[0]); + tmp[1] = base_types::convertor::convert(src.tiles[i][j].data[1]); + tmp[2] = base_types::convertor::convert(src.tiles[i][j].data[2]); + tmp[3] = base_types::convertor::convert(src.tiles[i][j].data[3]); +#ifdef KITTENS_HOPPER + int row = (local_warpid*warp_height + i)*src.tile_size_row + (warp_laneid % 16); + int col = j*src.tile_size_col + (warp_laneid / 16) * 8; + if constexpr (std::is_same_v) { + move::stsm4(dst.idx(shared_addr, {row, col}), tmp[0], tmp[1], tmp[2], tmp[3]); + } + else { + move::stsm4t(dst.idx(shared_addr, {row, col}), tmp[0], tmp[2], tmp[1], tmp[3]); + } +#else + if constexpr (std::is_same_v) { + int row = (local_warpid*warp_height + i)*src.tile_size_row + (warp_laneid / 4); + int col = j*src.tile_size_col + 2*(warp_laneid % 4); + move::sts(dst.idx(shared_addr, {row+0, col+0}), tmp[0]); + move::sts(dst.idx(shared_addr, {row+8, col+0}), tmp[1]); + move::sts(dst.idx(shared_addr, {row+0, col+8}), tmp[2]); + move::sts(dst.idx(shared_addr, {row+8, col+8}), tmp[3]); + } + else { + int row = (local_warpid*warp_height + i)*src.tile_size_row + 2*(warp_laneid % 4); + int col = j*src.tile_size_col + (warp_laneid / 4); + move::sts(dst.idx(shared_addr, {row+0, col+0}), tmp[0].x); + move::sts(dst.idx(shared_addr, {row+1, col+0}), tmp[0].y); + move::sts(dst.idx(shared_addr, {row+0, col+8}), tmp[1].x); + move::sts(dst.idx(shared_addr, {row+1, col+8}), tmp[1].y); + move::sts(dst.idx(shared_addr, {row+8, col+0}), tmp[2].x); + move::sts(dst.idx(shared_addr, {row+9, col+0}), tmp[2].y); + move::sts(dst.idx(shared_addr, {row+8, col+8}), tmp[3].x); + move::sts(dst.idx(shared_addr, {row+9, col+8}), tmp[3].y); + } +#endif + } + else if constexpr (std::is_same_v && sizeof(typename ST::dtype) == 1) { + // handle the row-major layout for 8-bit types + + int warp_group_16 = (warp_laneid / 16); // divide each warp into two groups of 16 threads + int lane_in_16 = warp_laneid % 16; // position in group of 16 threads + int row = (local_warpid*warp_height + i)*src.tile_size_row + (lane_in_16 % 16); // find base row for warp in warpgroup and then distribute the 16 threads in the warp across the rows + int col = j*src.tile_size_col + warp_group_16 * 16; // find base column and then *16 for second half of the warp + + U2 tmp[4]; + tmp[0] = base_types::convertor::convert(src.tiles[i][j].data[0]); + tmp[1] = 
base_types::convertor::convert(src.tiles[i][j].data[1]); + tmp[2] = base_types::convertor::convert(src.tiles[i][j].data[2]); + tmp[3] = base_types::convertor::convert(src.tiles[i][j].data[3]); + if constexpr (std::is_same_v) { + move::stsm4(dst.idx(shared_addr, {row, col}), tmp[0], tmp[1], tmp[2], tmp[3]); + } + else { + move::stsm4t(dst.idx(shared_addr, {row, col}), tmp[0], tmp[2], tmp[1], tmp[3]); + } + } + else if constexpr (std::is_same_v && sizeof(typename ST::dtype) == 4) { + // handle the row-major layout for 32-bit types + int row = (local_warpid*warp_height + i)*src.tile_size_row + (warp_laneid / 4); + int col = j*src.tile_size_col + 2*(warp_laneid % 4); + if constexpr (ST::rows != ST::underlying_rows || ST::cols != ST::underlying_cols) { // subtile case + row += dst.row_offset; + col += dst.col_offset; + } + int blit = sizeof(typename ST::dtype) * ((warp_laneid%4) / 2); + T2 reg_tmp[4]; + if(blit) { + #pragma unroll + for(int k = 0; k < 4; k++) { + reg_tmp[k] = T2{src.tiles[i][j].data[k].y, src.tiles[i][j].data[k].x}; + } + } + else { + #pragma unroll + for(int k = 0; k < 4; k++) { + reg_tmp[k] = src.tiles[i][j].data[k]; + } + } + U2 tmp[4]; + tmp[0] = base_types::convertor::convert(reg_tmp[0]); + tmp[1] = base_types::convertor::convert(reg_tmp[1]); + tmp[2] = base_types::convertor::convert(reg_tmp[2]); + tmp[3] = base_types::convertor::convert(reg_tmp[3]); + static constexpr int swizzle_repeat = ST::swizzle_bytes * 8; + static constexpr int subtile_cols = ST::swizzle_bytes / sizeof(U); + const int outer_idx = col/subtile_cols; + const uint32_t addr_1 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+0)*subtile_cols + col%subtile_cols); + const uint32_t addr_2 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+8)*subtile_cols + col%subtile_cols); + const int swizzle_1 = blit ^ ((addr_1 % swizzle_repeat) >> 7) << 4; + const int swizzle_2 = blit ^ ((addr_2 % swizzle_repeat) >> 7) << 4; + move::sts((addr_1+ 0)^swizzle_1, tmp[0].x); + move::sts((addr_1+ 4)^swizzle_1, tmp[0].y); + move::sts((addr_1+32)^swizzle_1, tmp[2].x); + move::sts((addr_1+36)^swizzle_1, tmp[2].y); + move::sts((addr_2+ 0)^swizzle_2, tmp[1].x); + move::sts((addr_2+ 4)^swizzle_2, tmp[1].y); + move::sts((addr_2+32)^swizzle_2, tmp[3].x); + move::sts((addr_2+36)^swizzle_2, tmp[3].y); + } + else { + // handle the column-major layout + int row = (local_warpid*warp_height + i)*src.tile_size_row + 2*(warp_laneid % 4); + int col = j*src.tile_size_col + (warp_laneid / 4); + U2 tmp[4]; + tmp[0] = base_types::convertor::convert(src.tiles[i][j].data[0]); + tmp[1] = base_types::convertor::convert(src.tiles[i][j].data[1]); + tmp[2] = base_types::convertor::convert(src.tiles[i][j].data[2]); + tmp[3] = base_types::convertor::convert(src.tiles[i][j].data[3]); + move::sts(dst.idx(shared_addr, {row+0, col+0}), tmp[0].x); + move::sts(dst.idx(shared_addr, {row+1, col+0}), tmp[0].y); + move::sts(dst.idx(shared_addr, {row+0, col+8}), tmp[1].x); + move::sts(dst.idx(shared_addr, {row+1, col+8}), tmp[1].y); + move::sts(dst.idx(shared_addr, {row+8, col+0}), tmp[2].x); + move::sts(dst.idx(shared_addr, {row+9, col+0}), tmp[2].y); + move::sts(dst.idx(shared_addr, {row+8, col+8}), tmp[3].x); + move::sts(dst.idx(shared_addr, {row+9, col+8}), tmp[3].y); + } + } + } +} + +// Load and store of vectors from/to shared tiles. 
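// In the two helpers below, each lane of the (single) warp handles columns laneid, laneid+32,
// laneid+64, ... of the requested row, so a length-L vector occupies L/32 register slots per
// lane. A minimal usage sketch (rv_fl is the assumed register-vector type from the vector
// headers elsewhere in this patch, and a_smem an assumed shared tile):
//
//   kittens::rv_fl<64> row;                      // 64 floats, 2 per lane
//   kittens::warp::load(row, a_smem, {4, 0});    // read row 4, columns 0..63
//   kittens::warp::store(a_smem, row, {4, 0});   // write them back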
+ +template +__device__ inline static auto load(RV &dst, const ST &src, int2 row_col) { + KITTENS_CHECK_WARP; + static_assert(ST::cols>=RV::length, "Shared tile must be at least as wide as the vector."); + using T = RV::T; + using U = ST::T; + int warp_laneid = ::kittens::laneid(); + + // convert to shared state space + uint32_t shared_addr = static_cast(__cvta_generic_to_shared(&src.data[0])); + + #pragma unroll + for(int col = warp_laneid; col < dst.length; col+=WARP_THREADS) { + U tmp; + move::lds(tmp, src.idx(shared_addr, {row_col.x, row_col.y + col})); + dst.data[col/WARP_THREADS][0] = base_types::convertor::convert(tmp); + } +} + +template +__device__ inline static auto store(ST &dst, const RV &src, int2 row_col) { + KITTENS_CHECK_WARP; + static_assert(ST::cols>=RV::length, "Shared tile must be at least as wide as the vector."); + using T = RV::T; + using U = ST::T; + int warp_laneid = ::kittens::laneid(); + + // convert to shared state space + uint32_t shared_addr = static_cast(__cvta_generic_to_shared(&dst.data[0])); + + #pragma unroll + for(int col = warp_laneid; col < src.length; col+=WARP_THREADS) { + U tmp = base_types::convertor::convert(src.data[col/WARP_THREADS][0]); + move::sts(dst.idx(shared_addr, {row_col.x, row_col.y + col}), tmp); + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/tensor_to_register.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/tensor_to_register.cuh new file mode 100644 index 0000000000..c2ff7f30a8 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/tensor_to_register.cuh @@ -0,0 +1,325 @@ +/** + * @file + * @brief Group (collaborative warp) ops for loading tensor tiles into register tiles. + */ + +/** + * @brief Load data from a tensor tile into a register tile. + * + * @tparam RT The register tile type + * @tparam TM The tensor memory tile type + * @param dst[out] The destination register tile. + * @param src[in] The source tensor tile. 
+ */ +template +__device__ inline static void load_async(RT &dst, const TM &src) { + if constexpr (GROUP_WARPS == 1) { + static_assert(RT::height == TM::height, "register tile and tensor tile must match height"); + static_assert(RT::width == TM::width, "register tile and tensor tile must match width"); + + using T2 = RT::dtype; + using U = typename TM::dtype; + using U2 = base_types::packing::packed_type; + + if constexpr (sizeof(typename TM::dtype) == 1) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x2.pack::16b.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(*(uint32_t*) &dst.tiles[i][j].data[0]), + "=r"(*(uint32_t*) &dst.tiles[i][j].data[1]), + "=r"(*(uint32_t*) &dst.tiles[i][j].data[2]), + "=r"(*(uint32_t*) &dst.tiles[i][j].data[3]) + : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U))) + ); + } + } + } else if constexpr (sizeof(typename TM::dtype) == 2) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + asm volatile( + "tcgen05.ld.sync.aligned.16x128b.x2.pack::16b.b32 {%0, %1, %2, %3}, [%4];\n" + : "=r"(*(uint32_t*) &dst.tiles[i][j].data[0]), + "=r"(*(uint32_t*) &dst.tiles[i][j].data[1]), + "=r"(*(uint32_t*) &dst.tiles[i][j].data[2]), + "=r"(*(uint32_t*) &dst.tiles[i][j].data[3]) + : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)) + ); + } + } + } + else if constexpr (sizeof(typename TM::dtype) == 4) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + if constexpr (dst.width%4 == 0) { + #pragma unroll + for(int j = 0; j < dst.width; j+=4) { + U2 data[16]; + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x8.b32 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, [%32];\n" + : "=f"(data[0].x), "=f"(data[0].y), + "=f"(data[1].x), "=f"(data[1].y), + "=f"(data[2].x), "=f"(data[2].y), + "=f"(data[3].x), "=f"(data[3].y), + "=f"(data[4].x), "=f"(data[4].y), + "=f"(data[5].x), "=f"(data[5].y), + "=f"(data[6].x), "=f"(data[6].y), + "=f"(data[7].x), "=f"(data[7].y), + "=f"(data[8].x), "=f"(data[8].y), + "=f"(data[9].x), "=f"(data[9].y), + "=f"(data[10].x), "=f"(data[10].y), + "=f"(data[11].x), "=f"(data[11].y), + "=f"(data[12].x), "=f"(data[12].y), + "=f"(data[13].x), "=f"(data[13].y), + "=f"(data[14].x), "=f"(data[14].y), + "=f"(data[15].x), "=f"(data[15].y) + : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U))) + ); + #pragma unroll + for(int k = 0; k < 4; k++) { + dst.tiles[i][j+0].data[k] = base_types::convertor::convert(data[k]); + dst.tiles[i][j+1].data[k] = base_types::convertor::convert(data[k+4]); + dst.tiles[i][j+2].data[k] = base_types::convertor::convert(data[k+8]); + dst.tiles[i][j+3].data[k] = base_types::convertor::convert(data[k+12]); + } + } + } + else if constexpr (dst.width%2 == 0) { + #pragma unroll + for(int j = 0; j < dst.width; j+=2) { + U2 data[8]; + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x4.b32 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, [%16];\n" + : "=f"(data[0].x), "=f"(data[0].y), + "=f"(data[1].x), "=f"(data[1].y), + "=f"(data[2].x), "=f"(data[2].y), + "=f"(data[3].x), "=f"(data[3].y), + "=f"(data[4].x), "=f"(data[4].y), + "=f"(data[5].x), "=f"(data[5].y), + "=f"(data[6].x), "=f"(data[6].y), + "=f"(data[7].x), "=f"(data[7].y) + 
: "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U))) + ); + #pragma unroll + for(int k = 0; k < 4; k++) { + dst.tiles[i][j+0].data[k] = base_types::convertor::convert(data[k]); + dst.tiles[i][j+1].data[k] = base_types::convertor::convert(data[k+4]); + } + } + } + else { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + U2 data[4]; + asm volatile( + "tcgen05.ld.sync.aligned.16x256b.x2.b32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n" + : "=f"(data[0].x), "=f"(data[0].y), + "=f"(data[1].x), "=f"(data[1].y), + "=f"(data[2].x), "=f"(data[2].y), + "=f"(data[3].x), "=f"(data[3].y) + : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U))) + ); + #pragma unroll + for(int k = 0; k < 4; k++) { + dst.tiles[i][j].data[k] = base_types::convertor::convert(data[k]); + } + } + } + } + } + } + else { + static_assert(GROUP_WARPS==4 || GROUP_WARPS==8); + constexpr int warp_rows = TM::rows/GROUP_WARPS; + static_assert(TM::cols==RT::cols); + static_assert(warp_rows==RT::rows); + if constexpr (GROUP_WARPS == 4) { + auto src_subtile = src.template subtile>(32*warpid(), 0); + ::kittens::group<1>::load_async(dst, src_subtile); + } + else { + auto src_subtile = src.template subtile>(32*(warpid()%4)+16*(warpid()/4), 0); + ::kittens::group<1>::load_async(dst, src_subtile); + } + } +} + + +/** + * @brief Store data into a tensor tile from a register tile. + * + * @tparam RT The register tile type + * @tparam TM The tensor memory tile type + * @param dst[out] The destination tensor tile. + * @param src[in] The source register tile. + */ +template +__device__ inline static void store_async(TM &dst, const RT &src) { + if constexpr (GROUP_WARPS == 1) { + static_assert(RT::height == TM::height, "register tile and tensor tile must match height"); + static_assert(RT::width == TM::width, "register tile and tensor tile must match width"); + + using T2 = RT::dtype; + using T = base_types::packing::unpacked_type; + using U = TM::dtype; + using U2 = base_types::packing::packed_type; + + if constexpr (sizeof(typename TM::dtype) == 2) { + #pragma unroll + for(int i = 0; i < src.height; i++) { + if constexpr (src.width%4 == 0) { + #pragma unroll + for(int j = 0; j < src.width; j+=4) { + asm volatile( + "tcgen05.st.sync.aligned.16x128b.x8.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16};\n" + :: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[0]), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[1]), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[2]), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[3]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[0]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[1]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[2]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[3]), + "r"(*(uint32_t*)&src.tiles[i][j+2].data[0]), + "r"(*(uint32_t*)&src.tiles[i][j+2].data[1]), + "r"(*(uint32_t*)&src.tiles[i][j+2].data[2]), + "r"(*(uint32_t*)&src.tiles[i][j+2].data[3]), + "r"(*(uint32_t*)&src.tiles[i][j+3].data[0]), + "r"(*(uint32_t*)&src.tiles[i][j+3].data[1]), + "r"(*(uint32_t*)&src.tiles[i][j+3].data[2]), + "r"(*(uint32_t*)&src.tiles[i][j+3].data[3]) + ); + } + } + else if constexpr (src.width%2 == 0) { + #pragma unroll + for(int j = 0; j < src.width; j+=2) { + asm volatile( + "tcgen05.st.sync.aligned.16x128b.x4.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n" + :: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * 
src.tile_size_col)/(4/(uint32_t)sizeof(U))), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[0]), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[1]), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[2]), + "r"(*(uint32_t*)&src.tiles[i][j+0].data[3]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[0]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[1]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[2]), + "r"(*(uint32_t*)&src.tiles[i][j+1].data[3]) + ); + } + } + else { + #pragma unroll + for(int j = 0; j < src.width; j++) { + asm volatile( + "tcgen05.st.sync.aligned.16x128b.x2.b32 [%0], {%1, %2, %3, %4};\n" + :: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))), + "r"(*(uint32_t*)&src.tiles[i][j].data[0]), + "r"(*(uint32_t*)&src.tiles[i][j].data[1]), + "r"(*(uint32_t*)&src.tiles[i][j].data[2]), + "r"(*(uint32_t*)&src.tiles[i][j].data[3]) + ); + } + } + } + } + else if constexpr (sizeof(typename TM::dtype) == 4) { + #pragma unroll + for(int i = 0; i < src.height; i++) { + if constexpr(src.width%4 == 0) { + #pragma unroll + for(int j = 0; j < src.width; j+=4) { + U2 data[16]; + #pragma unroll + for(int k = 0; k < 4; k++) { + data[k] = base_types::convertor::convert(src.tiles[i][j].data[k]); + data[k+4] = base_types::convertor::convert(src.tiles[i][j+1].data[k]); + data[k+8] = base_types::convertor::convert(src.tiles[i][j+2].data[k]); + data[k+12] = base_types::convertor::convert(src.tiles[i][j+3].data[k]); + } + asm volatile( + "tcgen05.st.sync.aligned.16x256b.x8.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32};\n" + :: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))), + "f"(data[0].x), "f"(data[0].y), + "f"(data[1].x), "f"(data[1].y), + "f"(data[2].x), "f"(data[2].y), + "f"(data[3].x), "f"(data[3].y), + "f"(data[4].x), "f"(data[4].y), + "f"(data[5].x), "f"(data[5].y), + "f"(data[6].x), "f"(data[6].y), + "f"(data[7].x), "f"(data[7].y), + "f"(data[8].x), "f"(data[8].y), + "f"(data[9].x), "f"(data[9].y), + "f"(data[10].x), "f"(data[10].y), + "f"(data[11].x), "f"(data[11].y), + "f"(data[12].x), "f"(data[12].y), + "f"(data[13].x), "f"(data[13].y), + "f"(data[14].x), "f"(data[14].y), + "f"(data[15].x), "f"(data[15].y) + ); + } + } + else if constexpr(src.width%2 == 0) { + #pragma unroll + for(int j = 0; j < src.width; j+=2) { + U2 data[8]; + #pragma unroll + for(int k = 0; k < 4; k++) { + data[k] = base_types::convertor::convert(src.tiles[i][j].data[k]); + data[k+4] = base_types::convertor::convert(src.tiles[i][j+1].data[k]); + } + asm volatile( + "tcgen05.st.sync.aligned.16x256b.x4.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16};\n" + :: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))), + "f"(data[0].x), "f"(data[0].y), + "f"(data[1].x), "f"(data[1].y), + "f"(data[2].x), "f"(data[2].y), + "f"(data[3].x), "f"(data[3].y), + "f"(data[4].x), "f"(data[4].y), + "f"(data[5].x), "f"(data[5].y), + "f"(data[6].x), "f"(data[6].y), + "f"(data[7].x), "f"(data[7].y) + ); + } + } + else { + #pragma unroll + for(int j = 0; j < src.width; j++) { + U2 data[4]; + #pragma unroll + for(int k = 0; k < 4; k++) { + data[k] = base_types::convertor::convert(src.tiles[i][j].data[k]); + } + asm volatile( + "tcgen05.st.sync.aligned.16x256b.x2.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n" + :: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * 
src.tile_size_col)/(4/(uint32_t)sizeof(U))), + "f"(data[0].x), "f"(data[0].y), + "f"(data[1].x), "f"(data[1].y), + "f"(data[2].x), "f"(data[2].y), + "f"(data[3].x), "f"(data[3].y) + ); + } + } + } + } + } + else { + static_assert(GROUP_WARPS==4 || GROUP_WARPS==8); + constexpr int warp_rows = TM::rows/GROUP_WARPS; + static_assert(TM::cols==RT::cols); + static_assert(warp_rows==RT::rows); + if constexpr (GROUP_WARPS == 4) { + auto dst_subtile = dst.template subtile>(32*warpid(), 0); + ::kittens::group<1>::store_async(dst_subtile, src); + } + else { + auto dst_subtile = dst.template subtile>(32*(warpid()%4)+16*(warpid()/4), 0); + ::kittens::group<1>::store_async(dst_subtile, src); + } + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/tile.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/tile.cuh new file mode 100644 index 0000000000..da6125811a --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/tile.cuh @@ -0,0 +1,16 @@ +/** + * @file + * @brief An aggregate header of group memory operations on tiles. + */ + +#include "shared_to_register.cuh" +#include "global_to_register.cuh" +#include "global_to_shared.cuh" +#ifdef KITTENS_BLACKWELL +#include "tensor_to_register.cuh" +#endif + +#include "complex/complex_shared_to_register.cuh" +#include "complex/complex_global_to_register.cuh" +#include "complex/complex_global_to_shared.cuh" + diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/tma.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/tma.cuh new file mode 100644 index 0000000000..c8d14de735 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/tma.cuh @@ -0,0 +1,134 @@ +/** + * @file + * @brief Functions for a group scope to call tile TMA functions. + */ + +template> +__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::prefetch(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::prefetch(dst, src, idx); // Don't do the mask + } +} + +template> +__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_async(dst, src, idx); + } +} + +template> +__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_async(dst, src, idx); + } +} + +template> +__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_add_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_add_async(dst, src, idx); + } +} + +template> +__device__ static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_add_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ 
static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_add_async(dst, src, idx); + } +} + +template> +__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_min_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_min_async(dst, src, idx); + } +} + +template> +__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_min_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_min_async(dst, src, idx); + } +} + +template> +__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_max_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_max_async(dst, src, idx); + } +} + +template> +__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_max_async(dst, src, idx); // Don't do the mask + } +} +template> +__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) { + if(laneid() == 0) { + ::kittens::tma::store_max_async(dst, src, idx); + } +} + +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) { + if(laneid() == 0) { + ::kittens::tma::load_async(dst, src, idx, bar); // Don't do the mask + } +} +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) { + if(laneid() == 0) { + ::kittens::tma::load_async(dst, src, idx, bar); + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/tile/tma_cluster.cuh b/extra/thunder/cuda/include/ops/group/memory/tile/tma_cluster.cuh new file mode 100644 index 0000000000..4b2ae0ff22 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/tile/tma_cluster.cuh @@ -0,0 +1,33 @@ +/** + * @file + * @brief Functions for a group scope to call tile TMA cluster functions. 
+ */ + + +#ifdef KITTENS_BLACKWELL +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { + if(laneid() == 0) { + ::kittens::tma::cluster::load_async(dst, src, idx, bar, cluster_mask, dst_mbar_cta); + } +} +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { + if(laneid() == 0) { + ::kittens::tma::cluster::load_async(dst, src, idx, bar, cluster_mask, dst_mbar_cta); + } +} +#else +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask) { + if(laneid() == 0) { + ::kittens::tma::cluster::load_async(dst, src, idx, bar, cluster_mask); + } +} +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask) { + if(laneid() == 0) { + ::kittens::tma::cluster::load_async(dst, src, idx, bar, cluster_mask); + } +} +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/util/tma.cuh b/extra/thunder/cuda/include/ops/group/memory/util/tma.cuh new file mode 100644 index 0000000000..f51afa5a03 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/util/tma.cuh @@ -0,0 +1,68 @@ +/** + * @file + * @brief Various utilities for group TMA memory operations. + */ + +/* ---------- Barrier functions for async load ---------- */ + +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the semaphore for the first thread in the warp. +* It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly +* instruction to set the expected number of bytes. +* +* @param semaphore Reference to the semaphore variable. +* @param bytes The number of bytes expected at the semaphore. +*/ +__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes) { + if(laneid() == 0) { + ::kittens::tma::expect_bytes(bar, bytes); + } +} +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the mbarrier before the transaction arrives. +*/ +template +__device__ static inline void expect(semaphore& bar, const T& _1, const args&... _2) { + expect_bytes(bar, size_bytes); +} + +/* ---------- Synchronization functions for async store ---------- */ + +/** + * @brief Commits previous asynchronous TMA stores to a group and performs them. +*/ +__device__ static inline void store_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} +/** + * @brief Waits for previous committed TMA store groups to complete. + * + * @tparam N The maximum number of remaining TMA store groups. Defaults to 0. +*/ +template +__device__ static inline void store_async_wait() { + asm volatile ( + "cp.async.bulk.wait_group %0;" + : + : "n"(N) + : "memory" + ); +} +/** + * @brief Waits for previous committed TMA store groups to finish reading from shared memory. + * + * @tparam N The maximum number of remaining TMA store groups. Defaults to 0. 
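+ *
+ * Once this returns, the shared-memory source buffers of those store groups may be
+ * reused, even though the corresponding writes to global memory may not yet be visible.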
+*/ +template +__device__ static inline void store_async_read_wait() { + asm volatile ( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(N) + : "memory" + ); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/util/tma_cluster.cuh b/extra/thunder/cuda/include/ops/group/memory/util/tma_cluster.cuh new file mode 100644 index 0000000000..30db9aa6de --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/util/tma_cluster.cuh @@ -0,0 +1,90 @@ + +/** +* @brief Waits for the requested semaphore phase, at cluster scope +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline void wait(semaphore& bar, int kPhaseBit) { + void const* const ptr = &bar; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + +/** +* @brief Sets the number of bytes expected at the semaphore, assuming a multicast instruction. +* +* This function sets the number of bytes expected at the semaphore for the first thread in the warp. +* It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly +* instruction to set the expected number of bytes. +* +* It's worth being aware that this function is particularly necessary for multicast loads, and +* distributed shared memory can actually be done with a normal tma::expect followed by wait. See +* the unit tests of dsmem for an example. +* +* @param semaphore Reference to the semaphore variable. +* @param bytes The number of bytes expected at the semaphore. +*/ +__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes, int dst_cta) { + if(laneid() == 0) { + ::kittens::tma::cluster::expect_bytes(bar, bytes, dst_cta); + } +} +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the semaphore for the first thread in the warp. +* It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly +* instruction to set the expected number of bytes. +* +* @tparam T The type of the data to be stored at the semaphore. +* @param semaphore Reference to the semaphore variable. +*/ +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the mbarrier before the transaction arrives. +*/ +template +__device__ static inline void expect(semaphore& bar, int dst_cta, const T& _1, const args&... _2) { + expect_bytes(bar, size_bytes, dst_cta); +} + +/** +* @brief Arrives at a semaphore in cluster scope. +* +* Marks a thread arrival at an mbarrier +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. 
+*/ +__device__ static inline void arrive(semaphore& bar, int dst_cta, uint32_t count=1) { + if(laneid() == 0) { + ::kittens::tma::cluster::arrive(bar, dst_cta, count); + } +} + +// Generic transfer +__device__ static inline void store_async(void *dst, void *src, int dst_cta, uint32_t size_bytes, semaphore& bar) { + if(laneid() == 0) { + ::kittens::tma::cluster::store_async(dst, src, dst_cta, size_bytes, bar); + } +} + +// Templated transfer for convenience +template +__device__ static inline void store_async(T &dst_, T &src_, int dst_cta, semaphore& bar) { + store_async((void*)&dst_, (void*)&src_, dst_cta, size_bytes, bar); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/util/util.cuh b/extra/thunder/cuda/include/ops/group/memory/util/util.cuh new file mode 100644 index 0000000000..cf5d4b4a8d --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/util/util.cuh @@ -0,0 +1,168 @@ +/** + * @file + * @brief Various utilities for group memory operations. + */ + + +template __device__ static inline void load_async_wait(int bar_id) { // for completing (non-TMA) async loads + asm volatile("cp.async.wait_group %0;\n" : : "n"(N) : "memory"); + sync(bar_id); +} +template __device__ static inline void load_async_wait() { // for completing (non-TMA) async loads + KITTENS_CHECK_WARP + asm volatile("cp.async.wait_group %0;\n" : : "n"(N) : "memory"); + __syncwarp(); +} + +__device__ static inline void arrive(barrier bar) { + asm volatile("bar.arrive %0, %1;\n" :: "r"(bar.barrier_id), "n"(GROUP_WARPS*WARP_THREADS) : "memory"); +} +__device__ static inline void arrive_and_wait(barrier bar) { + asm volatile("bar.sync %0, %1;\n" :: "r"(bar.barrier_id), "n"(GROUP_WARPS*WARP_THREADS) : "memory"); +} + +/** + * @brief Initializes a synchronization semaphore with a transaction count and sets the expected number of bytes. + * + * This function sets up a semaphore that is used to synchronize threads within a block during asynchronous operations. + * It initializes the semaphore with a thread count semaphore. + * + * Additionally, if it is given a shared tile type, it will also call `set_bytes` to prepare for the memory transaction. + * + * @param[out] semaphore The semaphore variable to initialize. + * @param[in] tc The thread counter for the semaphore. + */ +__device__ static inline void init_semaphore(semaphore& bar, int thread_count, int transaction_count=0) { + if (laneid() == 0) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(thread_count+transaction_count) + ); + } +} +/** + * @brief Invalidate an mbarrier + * + * @param[out] semaphore The semaphore variable to initialize. + * @param[in] tc The thread counter for the semaphore. + */ +__device__ static inline void invalidate_semaphore(semaphore& bar) { + if (laneid() == 0) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + asm volatile ( + "mbarrier.inval.shared::cta.b64 [%0];\n" + :: "r"(bar_ptr) + ); + } +} + +/** +* @brief Arrives at a semaphore. +* +* Marks a warp arrival at an mbarrier +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. 
+*/ +__device__ static inline void arrive(semaphore& sem) { + if(laneid() == 0) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&sem)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];\n" + : + : "r"(mbar_ptr) + : "memory" + ); + } +} +template __device__ static inline void arrive(barrier bar) { + asm volatile("bar.arrive %0, %1;\n" :: "r"(bar.barrier_id), "n"(num_warps*WARP_THREADS) : "memory"); +} + +#ifdef KITTENS_HOPPER +/** +* @brief Arrives at a semaphore. +* +* Marks a warp arrival at an mbarrier +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline void arrive(semaphore& sem, uint32_t count) { + if(laneid() == 0) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&sem)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" + : + : "r"(mbar_ptr), "r"(count) + : "memory" + ); + } +} +#endif + +/** +* @brief Waits for the requested semaphore phase. +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline void wait(semaphore& sem, int kPhaseBit) { + void const* const ptr = &sem; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + +#ifdef KITTENS_HOPPER + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +#else + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures to save instruction issue slots + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +#endif +} + +/** +* @brief Checks if the requested semaphore phase is ready. +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline int test_wait(semaphore& sem, int kPhaseBit) { + void const* const ptr = &sem; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + int result; + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2;\n" + "selp.u32 %0,1,0,P1;" + "}\n" + : "=r"(result) + : "r"(mbar_ptr), "r"(kPhaseBit) + ); + return result; +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/vec/global_to_register.cuh b/extra/thunder/cuda/include/ops/group/memory/vec/global_to_register.cuh new file mode 100644 index 0000000000..2567150ce1 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/vec/global_to_register.cuh @@ -0,0 +1,138 @@ +/** + * @file + * @brief Functions for a warpgroup to collaboratively transfer data directly between global memory and registers and back. + */ + +/** + * @brief Collaboratively loads data into register vectors from a source array in global memory. + * + * @tparam RV The register vector type. + * @tparam U The data type of the source array. + * @param[out] dst The destination register vector to load data into. + * @param[in] src The source array in global memory to load data from. 
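+ * @param[in] idx The coord of the requested vector within the source array.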
+ */ +template +__device__ inline static void load(RV &dst, const GL &src, const coord> &idx) { + if constexpr (GROUP_WARPS == 1) { + using T2 = RV::dtype; + using U = typename GL::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + U *src_ptr = (U*)&src[(idx.template unit_coord<-1, 3>())]; + int laneid = ::kittens::laneid(); + + if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+3)/4; w++) { + int idx = w*64 + (laneid/4)*8 + 2*(laneid%4); + int o_dim = w*4 + (laneid/4) / 2; + int i_dim = (laneid/4) % 2; + // this should be a maximally coalesced load. + if(idx < dst.outer_dim*16) + dst[o_dim][i_dim] = base_types::convertor::convert(*(U2*)&src_ptr[idx]); + } + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = 8*(w%4) + (laneid%4); // repeats every 64 columns + dst[w][0] = packed_shfl_sync(MASK_ALL, dst[w][0], leader); + dst[w][1] = packed_shfl_sync(MASK_ALL, dst[w][1], leader+4); + } + } + else if constexpr (std::is_same_v) { + // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true + // otherwise there will be some pain :/ + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+1)/2; w++) { + int idx = w*32 + (laneid%4)*8 + (laneid/4); + int o_dim = w*2 + (laneid%4) / 2; + // this should be a maximally coalesced load. + if(idx < dst.outer_dim*16) { + T tmp = base_types::convertor::convert(src_ptr[idx]); + if(laneid%2==0) dst[o_dim][0].x = tmp; + else dst[o_dim][0].y = tmp; + } + } + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = (laneid/4)*4 + 2*(w%2); // repeats every 64 columns + dst[w][0].x = __shfl_sync(MASK_ALL, dst[w][0].x, leader); + dst[w][0].y = __shfl_sync(MASK_ALL, dst[w][0].y, leader+1); + } + } + else if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + if(w < dst.outer_dim-1 || dst.length%32 == 0 || laneid<16) { + dst[w][0] = base_types::convertor::convert(src_ptr[w*32 + laneid]); + } + } + } + } + else { + // Call warp level load + ::kittens::group<1>::load(dst, src, coord(idx.b, idx.d, idx.r, idx.c*GROUP_WARPS+warpid())); + } +} +/** + * @brief Collaboratively stores data from register vectors to a destination array in global memory. + * + * @tparam RV The register vector type. + * @tparam U The data type of the destination array. + * @param[out] dst The destination array in global memory to store data into. + * @param[in] src The source register vector to store data from. + */ +template +__device__ inline static void store(GL &dst, const RV &src, const coord> &idx) { + if constexpr (GROUP_WARPS == 1) { + using T2 = RV::dtype; + using U = typename GL::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + U *dst_ptr = (U*)&dst[(idx.template unit_coord<-1, 3>())]; + int laneid = ::kittens::laneid(); + + if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < (src.outer_dim+3)/4; w++) { + int idx = w*64 + (laneid/4)*8 + 2*(laneid%4); + int o_dim = w*4 + (laneid/4) / 2; + int i_dim = (laneid/4) % 2; + // this should be a maximally coalesced store. I hope! 
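+                // worked example of the indexing (illustrative, assuming a 16-bit dtype here):
+                // lane 5 has laneid/4 == 1 and laneid%4 == 1, so for w == 0 it writes idx == 10,
+                // i.e. elements 10 and 11 as one packed U2. Across the warp, the 32 lanes cover
+                // 64 consecutive elements (128 contiguous bytes) per iteration, which is what
+                // makes the store coalesce.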
+ if(idx < src.outer_dim*16) + *(U2*)&dst_ptr[idx] = base_types::convertor::convert(src[o_dim][i_dim]); + } + } + else if constexpr (std::is_same_v) { + // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true + // otherwise there will be some pain :/ + #pragma unroll + for(auto w = 0; w < (src.outer_dim+1)/2; w++) { + int idx = w*32 + (laneid%4)*8 + (laneid/4); + int o_dim = w*2 + (laneid%4) / 2; + // this should be a maximally coalesced load. + if(idx < src.outer_dim*16) { + U tmp; + if(laneid%2==0) tmp = base_types::convertor::convert(src[o_dim][0].x); + else tmp = base_types::convertor::convert(src[o_dim][0].y); + dst_ptr[idx] = tmp; + } + } + } + else if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < src.outer_dim; w++) { + if(w < src.outer_dim-1 || src.length%32 == 0 || laneid<16) { + dst_ptr[w*32 + laneid] = base_types::convertor::convert(src[w][0]); + } + } + } + } + else { + // Call warp level store + ::kittens::group<1>::store(dst, src, coord(idx.b, idx.d, idx.r, idx.c*GROUP_WARPS+warpid())); + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/vec/global_to_shared.cuh b/extra/thunder/cuda/include/ops/group/memory/vec/global_to_shared.cuh new file mode 100644 index 0000000000..73f0826cd3 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/vec/global_to_shared.cuh @@ -0,0 +1,77 @@ +/** + * @file + * @brief Group (collaborative warp) ops for loading shared vectors from and storing to global memory. + */ + +/** + * @brief Loads data from global memory into shared memory vector. + * + * This function loads data from a global memory location pointed to by `src` into a shared memory vector `dst`. + * It calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`. + * The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among threads in a warp. + * + * @tparam SV Shared vector type, must satisfy ducks::sv::all concept. + * @param dst Reference to the shared vector where the data will be loaded. + * @param src Pointer to the global memory location from where the data will be loaded. + */ +template> +__device__ static inline void load(SV &dst, const GL &src, const COORD &idx) { + constexpr uint32_t elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype); + constexpr uint32_t total_calls = SV::length / elem_per_transfer; // guaranteed to divide + typename GL::dtype *src_ptr = (typename GL::dtype*)&src[(idx.template unit_coord<-1, 3>())]; + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst.data[0])); + #pragma unroll + for(uint32_t i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) { + if(i * elem_per_transfer < dst.length) { + float4 tmp; + move::ldg(tmp, (float4*)&src_ptr[i*elem_per_transfer]); + move::sts(dst_ptr + sizeof(typename SV::dtype)*i*elem_per_transfer, tmp); + } + } +} + +/** + * @brief Stores data from a shared memory vector to global memory. + * + * This function stores data from a shared memory vector `src` to a global memory location pointed to by `dst`. + * Similar to the load function, it calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`. + * The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among threads in a warp. 
+ * + * @tparam SV Shared vector type, must satisfy ducks::sv::all concept. + * @param dst Pointer to the global memory location where the data will be stored. + * @param src Reference to the shared vector from where the data will be stored. + */ +template> +__device__ static inline void store(GL &dst, const SV &src, const COORD &idx) { + constexpr uint32_t elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype); + constexpr uint32_t total_calls = SV::length / elem_per_transfer; // guaranteed to divide + typename GL::dtype *dst_ptr = (typename GL::dtype*)&dst[(idx.template unit_coord<-1, 3>())]; + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src.data[0])); + #pragma unroll + for(uint32_t i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) { + if(i * elem_per_transfer < src.length) { + float4 tmp; + move::lds(tmp, src_ptr + sizeof(typename SV::dtype)*i*elem_per_transfer); + move::stg((float4*)&dst_ptr[i*elem_per_transfer], tmp); + } + } +} + +template> +__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx) { + constexpr uint32_t elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype); + constexpr uint32_t total_calls = SV::length / elem_per_transfer; // guaranteed to divide + typename GL::dtype *src_ptr = (typename GL::dtype*)&src[(idx.template unit_coord<-1, 3>())]; + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst.data[0])); + #pragma unroll + for(uint32_t i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) { + if(i * elem_per_transfer < dst.length) { + asm volatile( + "cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\n" + :: "r"(dst_ptr + (uint32_t)sizeof(typename SV::dtype)*i*elem_per_transfer), "l"((uint64_t)&src_ptr[i*elem_per_transfer]) + : "memory" + ); + } + } + asm volatile("cp.async.commit_group;\n" ::: "memory"); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/vec/shared_to_register.cuh b/extra/thunder/cuda/include/ops/group/memory/vec/shared_to_register.cuh new file mode 100644 index 0000000000..12152597dd --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/vec/shared_to_register.cuh @@ -0,0 +1,159 @@ +/** + * @file + * @brief Functions for a group to collaboratively transfer data directly between shared memory and registers and back. + */ + +/** + * @brief Collaboratively load data from a shared vector into register vectors split across a warpgroup. + * + * @tparam RV The register vector type + * @tparam SV The shared vector type + * @param dst[out] The destination register vector. + * @param src[in] The source shared vector. + */ +template +__device__ inline static void load(RV &dst, const SV &src) { + using T2 = RV::dtype; + using U = SV::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + if constexpr (GROUP_WARPS == 1) { + static_assert(SV::length == RV::length); + + int laneid = ::kittens::laneid(); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src.data[0])); + + __syncwarp(); + if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+3)/4; w++) { + int idx = w*64 + (laneid/4)*8 + 2*(laneid%4); + int o_dim = w*4 + (laneid/4) / 2; + int i_dim = (laneid/4) % 2; + // this should be a maximally coalesced load. 
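+                // each active lane issues one 4-byte lds (a packed pair, assuming a 16-bit dtype
+                // here), and adjacent lanes hit adjacent 4-byte words, so the warp's accesses fall
+                // in distinct shared-memory banks; the shuffles below then redistribute the loaded
+                // values to the lanes that own them in the register-vector layout.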
+ if(idx < dst.outer_dim*16) { + U2 tmp; + move::lds(tmp, src_ptr + sizeof(typename SV::dtype)*idx); + dst[o_dim][i_dim] = base_types::convertor::convert(tmp); + } + } + __syncwarp(); + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = 8*(w%4) + (laneid%4); // repeats every 64 columns + dst[w][0] = packed_shfl_sync(MASK_ALL, dst[w][0], leader); + dst[w][1] = packed_shfl_sync(MASK_ALL, dst[w][1], leader+4); + } + } + else if constexpr (std::is_same_v) { + // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true + // otherwise there will be some pain :/ + #pragma unroll + for(auto w = 0; w < (dst.outer_dim+1)/2; w++) { + int idx = w*32 + (laneid%4)*8 + (laneid/4); + int o_dim = w*2 + (laneid%4) / 2; + // this should be a maximally coalesced load. + if(idx < dst.outer_dim*16) { + U tmp; + move::lds(tmp, src_ptr + sizeof(typename SV::dtype)*idx); + if(laneid%2==0) dst[o_dim][0].x = base_types::convertor::convert(tmp); + else dst[o_dim][0].y = base_types::convertor::convert(tmp); + } + } + __syncwarp(); + // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need. + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + int leader = (laneid/4)*4 + 2*(w%2); // repeats every 64 columns + dst[w][0].x = __shfl_sync(MASK_ALL, dst[w][0].x, leader); + dst[w][0].y = __shfl_sync(MASK_ALL, dst[w][0].y, leader+1); + } + } + else if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < dst.outer_dim; w++) { + if(w < dst.outer_dim-1 || RV::length%32 == 0 || laneid<16) { + U tmp; + move::lds(tmp, src_ptr + sizeof(typename SV::dtype)*(w*32 + laneid)); + dst[w][0] = base_types::convertor::convert(tmp); + } + } + } + } + else { + static_assert(SV::length == RV::length*GROUP_WARPS);// confirm size correct + auto &_src = src.template subvec(warpid()); // pretend it's smaller and do warp-level load + + ::kittens::group<1>::load(dst, _src); // warp-level + } +} + +/** + * @brief Collaboratively store data into a shared vector from register vectors split across a warpgroup. + * + * @tparam RV The register vector type + * @tparam SV The shared vector type + * @param dst[out] The destination shared vector. + * @param src[in] The source register vector. + */ +template +__device__ inline static void store(SV &dst, const RV &src) { + using T2 = RV::dtype; + using U = SV::dtype; + using U2 = base_types::packing::packed_type; + using T = base_types::packing::unpacked_type; + + if constexpr (GROUP_WARPS == 1) { + static_assert(SV::length == RV::length); + + int laneid = ::kittens::laneid(); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst.data[0])); + + __syncwarp(); + if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < (src.outer_dim+3)/4; w++) { + int idx = w*64 + (laneid/4)*8 + 2*(laneid%4); + int o_dim = w*4 + (laneid/4) / 2; + int i_dim = (laneid/4) % 2; + // this should be a maximally coalesced store. I hope! 
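+                // mirror of the load above: one 4-byte sts per lane at a distinct 4-byte offset,
+                // so the warp's stores are likewise spread across the shared-memory banks.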
+ if(idx < src.outer_dim*16) { + U2 tmp = base_types::convertor::convert(src[o_dim][i_dim]); + move::sts(dst_ptr + sizeof(typename SV::dtype)*idx, tmp); + } + } + } + else if constexpr (std::is_same_v) { + // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true + // otherwise there will be some pain :/ + #pragma unroll + for(auto w = 0; w < (src.outer_dim+1)/2; w++) { + int idx = w*32 + (laneid%4)*8 + (laneid/4); + int o_dim = w*2 + (laneid%4) / 2; + // this should be a maximally coalesced load. + if(idx < src.outer_dim*16) { + U tmp; + if(laneid%2==0) tmp = base_types::convertor::convert(src[o_dim][0].x); + else tmp = base_types::convertor::convert(src[o_dim][0].y); + move::sts(dst_ptr + sizeof(typename SV::dtype)*idx, tmp); + } + } + } + else if constexpr (std::is_same_v) { + #pragma unroll + for(auto w = 0; w < src.outer_dim; w++) { + if(w < src.outer_dim-1 || RV::length%32 == 0 || laneid<16) { + U tmp = base_types::convertor::convert(src[w][0]); + move::sts(dst_ptr + sizeof(typename SV::dtype)*(w*32 + laneid), tmp); + } + } + } + } + else { + static_assert(SV::length == RV::length*GROUP_WARPS);// confirm size correct + auto &_dst = dst.template subvec(warpid()); // pretend it's smaller and do warp-level load + + ::kittens::group<1>::store(_dst, src); // warp-level + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/memory/vec/tma.cuh b/extra/thunder/cuda/include/ops/group/memory/vec/tma.cuh new file mode 100644 index 0000000000..8f88ccf4ae --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/vec/tma.cuh @@ -0,0 +1,221 @@ +/** + * @file + * @brief Functions for a group scope to call vec TMA functions. + */ + +/* ---------- Prefetch Tensor Map ---------- */ + +/** + * @brief Prefetches data from global memory into a shared memory vector, along with the tensormap. + * + * @tparam SV A shared vector type with a TMA-compatible layout + * @param[out] dst The destination shared memory vector. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in] vec_idx The coord of the requested vector. + */ +template> +__device__ static inline void prefetch(SV &dst, const GL &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + ::kittens::detail::tma::vec_prefetch_tma_internal(tma_ptr, tma_coord); + } +} +__KITTENS_TMA_DEFINE_DEFAULT_LOAD_CACHE_VEC__(prefetch) + + +/* ---------- Async load and store data from gmem/smem ---------- */ + +/** + * @brief Asynchronously stores data into global memory from a shared memory vector. + * + * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction. + * + * @tparam SV A shared vector type with a TMA-compatible layout + * @param[out] dst_tma_map The destination tensormap address in global memory + * @param[in] src The source shared memory vector. + * @param[in] vec_idx The coord of the vector destination. 
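+ *
+ * The lanes of the calling warp each issue a slice of the copy, and the stores are then
+ * committed as a TMA store group; use store_async_wait (or store_async_read_wait) to
+ * know when they have finished.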
+ */ +template> +__device__ static inline void store_async(const GL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_async) + +template> +__device__ static inline void store_async(const PGL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_async) + + +/** +* @brief Asynchronously performs an add reduction and stores the result into global memory. +* +* This function performs an asynchronous add reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction. +* +* @tparam SV A shared vector type with a TMA-compatible layout +* @param[out] dst_tma_map The destination tensormap address in global memory +* @param[in] src The source shared memory vector. +* @param[in] vec_idx The coord of the vector destination. 
+*/ +template> +__device__ static inline void store_add_async(const GL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_add_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_add_async) + +template> +__device__ static inline void store_add_async(const PGL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_add_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_add_async) + + +/** +* @brief Asynchronously performs an min reduction and stores the result into global memory. +* +* This function performs an asynchronous min reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction. +* +* @tparam SV A shared vector type with a TMA-compatible layout +* @param[out] dst_tma_map The destination tensormap address in global memory +* @param[in] src The source shared memory vector. +* @param[in] vec_idx The coord of the vector destination. 
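+*
+* Note that fp32 shared vectors are rejected at compile time, since TMA has no
+* asynchronous min/max reduction for fp32.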
+*/ +template> +__device__ static inline void store_min_async(const GL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_min_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_min_async) + +template> +__device__ static inline void store_min_async(const PGL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_min_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_min_async) + +/** +* @brief Asynchronously performs an max reduction and stores the result into global memory. +* +* This function performs an asynchronous max reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction. +* +* @tparam SV A shared vector type with a TMA-compatible layout +* @param[out] dst_tma_map The destination tensormap address in global memory +* @param[in] src The source shared memory vector. +* @param[in] vec_idx The coord of the vector destination. 
+*/ +template> +__device__ static inline void store_max_async(const GL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_max_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_max_async) + +template> +__device__ static inline void store_max_async(const PGL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_max_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_max_async) + +/** + * @brief Asynchronously loads data from global memory into a shared memory vector. + * + * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction. + * + * @tparam SV A shared vector type with a TMA-compatible layout + * @param[out] dst The destination shared memory vector. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in] vec_idx The coord of the requested vector. + * @param[in,out] bar The semaphore used for synchronization of the asynchronous copy. + */ +template> +__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_load_async_tma_internal(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord); + } +} +__KITTENS_TMA_DEFINE_SEMAPHORE_CACHE_VEC__(load_async) diff --git a/extra/thunder/cuda/include/ops/group/memory/vec/tma_cluster.cuh b/extra/thunder/cuda/include/ops/group/memory/vec/tma_cluster.cuh new file mode 100644 index 0000000000..3ae848c404 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/vec/tma_cluster.cuh @@ -0,0 +1,31 @@ +/** + * @file + * @brief Functions for a group scope to call vec TMA cluster functions. 
+ */ + +/** + * @brief Asynchronously loads data from global memory into a shared memory vector, broadcast across a cluster + * + * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction. + * + * @tparam SV A shared vector type with a TMA-compatible layout + * @param[out] dst The destination shared memory vector. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in,out] bar The semaphore used for synchronization of the asynchronous copy. + * @param[in] vec_idx The coord of the requested vector. + * @param[in] cluster_mask The mask of the clusters to broadcast to. + */ +template> +__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst)); + for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2; i += WARP_THREADS) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::cluster::vec_load_async_tma_internal(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord, cluster_mask, dst_mbar_cta); + } +} +__KITTENS_TMA_DEFINE_CLUSTER_SEMAPHORE_CACHE_VEC__(load_async) diff --git a/extra/thunder/cuda/include/ops/group/memory/vec/vec.cuh b/extra/thunder/cuda/include/ops/group/memory/vec/vec.cuh new file mode 100644 index 0000000000..ac4229e27b --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/memory/vec/vec.cuh @@ -0,0 +1,8 @@ +/** + * @file + * @brief An aggregate header of group memory operations on vectors. + */ + +#include "shared_to_register.cuh" +#include "global_to_register.cuh" +#include "global_to_shared.cuh" diff --git a/extra/thunder/cuda/include/ops/group/mma/mma.cuh b/extra/thunder/cuda/include/ops/group/mma/mma.cuh new file mode 100644 index 0000000000..cc4fc63dfd --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/mma.cuh @@ -0,0 +1,17 @@ +/** + * @file + * @brief An aggregate header for all group-scope MMA operations. + */ + +// All compilation targets can use the warp-scope MMA operations. +#include "warp/warp.cuh" + +// Hopper has its own warpgroup-scope MMA operations. +#ifdef KITTENS_HOPPER +#include "warpgroup/warpgroup.cuh" +#endif + +// Blackwell has its own tensor-scope MMA operations. +#ifdef KITTENS_BLACKWELL +#include "tensor/tensor.cuh" +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/tensor/tensor.cuh b/extra/thunder/cuda/include/ops/group/mma/tensor/tensor.cuh new file mode 100644 index 0000000000..bad42c9a2e --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/tensor/tensor.cuh @@ -0,0 +1,172 @@ +/** + * @file Group-level tcgen05 MMA operations. 
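+ *
+ * Each wrapper issues the underlying tcgen05 MMA from a single lane (lane 0) of the
+ * calling group. The overloads taking a semaphore pass it through so the caller can
+ * wait on the MMA's completion; in ThunderKittens' naming convention the mm* variants
+ * overwrite the accumulator while the mma* variants accumulate into it.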
+*/ + +template +__device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem) { + if(laneid() == 0) ::kittens::mma(d, a, b, sem); +} +template +__device__ static inline void mma2(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm2(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} + +template +__device__ static inline void mma_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma2_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mma_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma2_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mma_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma2_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mma_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} + +template +__device__ static inline void mm_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm2_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm2_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm2_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} + +// no sem versions + + +template +__device__ static inline void mma(D &d, const A &a, const B &b) { + if(laneid() == 0) ::kittens::mma(d, a, b); +} +template +__device__ static inline void mma2(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2(D &d, const A &a, const B &b) { + mma2(d, a, b); +} + +template +__device__ static inline void mma_AB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_AB(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mma_ABt(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_ABt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mma_AtB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_AtB(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mma_AtBt(D &d, 
const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} + +template +__device__ static inline void mm_AB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_AB(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mm_ABt(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_ABt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mm_AtB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_AtB(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mm_AtBt(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warp/warp.cuh b/extra/thunder/cuda/include/ops/group/mma/warp/warp.cuh new file mode 100644 index 0000000000..90692c35d9 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warp/warp.cuh @@ -0,0 +1,947 @@ +/** + * @file + * @brief Matrix multiply-accumulate operations for tiles stored in registers. + */ + +/** + * @brief Perform the HMMA.16816 operation. + * + * This function performs the half-precision matrix multiply-accumulate operation + * using the `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32` instruction. + * + * @param[out] d0 The first half of the output float2 accumulator. + * @param[out] d1 The second half of the output float2 accumulator. + * @param[in] a0 The first half of the first input bf16_2 matrix. + * @param[in] a1 The second half of the first input bf16_2 matrix. + * @param[in] a2 The first half of the second input bf16_2 matrix. + * @param[in] a3 The second half of the second input bf16_2 matrix. + * @param[in] b0 The first half of the bf16_2 matrix B. + * @param[in] b1 The second half of the bf16_2 matrix B. + * @param[in] c0 The first half of the float2 accumulator matrix C. + * @param[in] c1 The second half of the float2 accumulator matrix C. + */ +__device__ static inline void hmma16816( float2 &d0, float2 &d1, + const bf16_2 &a0, const bf16_2 &a1, const bf16_2 &a2, const bf16_2 &a3, + const bf16_2 &b0, const bf16_2 &b1, + const float2 &c0, const float2 &c1 ) { + asm volatile( + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "{%8, %9}, " \ + "{%10, %11, %12, %13};" + + // D matrix + : "+f"(d0.x), "+f"(d0.y), + "+f"(d1.x), "+f"(d1.y) + + // A matrix + : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)), + "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)), + + // B matrix + "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)), + + // C matrix + "f"(c0.x), "f"(c0.y), + "f"(c1.x), "f"(c1.y) + ); +} +/** + * @brief Perform the HMMA.16816 operation with inputs as fp16 and fp32 accumulators + * + * This function performs the half-precision matrix multiply-accumulate operation + * using the `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32` instruction. + * + * @param[out] d0 The first half of the output float2 accumulator. + * @param[out] d1 The second half of the output float2 accumulator. + * @param[in] a0 The first half of the first input half_2 matrix. 
+ * @param[in] a1 The second half of the first input half_2 matrix. + * @param[in] a2 The first half of the second input half_2 matrix. + * @param[in] a3 The second half of the second input half_2 matrix. + * @param[in] b0 The first half of the half_2 matrix B. + * @param[in] b1 The second half of the half_2 matrix B. + * @param[in] c0 The first half of the float2 accumulator matrix C. + * @param[in] c1 The second half of the float2 accumulator matrix C. + */ +__device__ static inline void hmma16816( float2 &d0, float2 &d1, + const half_2 &a0, const half_2 &a1, const half_2 &a2, const half_2 &a3, + const half_2 &b0, const half_2 &b1, + const float2 &c0, const float2 &c1 ) { + asm volatile( + // https://docs.nvidia.com/cuda/parallel-thread-execution/#multiply-and-accumulate-instruction-mma + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "{%8, %9}, " \ + "{%10, %11, %12, %13};" + + // D matrix + : "+f"(d0.x), "+f"(d0.y), + "+f"(d1.x), "+f"(d1.y) + + // A matrix + : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)), + "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)), + + // B matrix + "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)), + + // C matrix + "f"(c0.x), "f"(c0.y), + "f"(c1.x), "f"(c1.y) + ); +} +/** + * @brief Perform the HMMA.16816 operation. + * + * This function performs the half-precision matrix multiply-accumulate operation + * using the `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16` instruction. + * + * @param[out] d0 The first half of the output half_2 accumulator. + * @param[out] d1 The second half of the output half_2 accumulator. + * @param[in] a0 The first half of the first input half_2 matrix. + * @param[in] a1 The second half of the first input half_2 matrix. + * @param[in] a2 The first half of the second input half_2 matrix. + * @param[in] a3 The second half of the second input half_2 matrix. + * @param[in] b0 The first half of the half_2 matrix B. + * @param[in] b1 The second half of the half_2 matrix B. + * @param[in] c0 The first half of the half_2 accumulator matrix C. + * @param[in] c1 The second half of the half_2 accumulator matrix C. + */ +__device__ static inline void hmma16816( half_2 &d0, half_2 &d1, + const half_2 &a0, const half_2 &a1, const half_2 &a2, const half_2 &a3, + const half_2 &b0, const half_2 &b1, + const half_2 &c0, const half_2 &c1 ) { + asm volatile( + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \ + "{%0, %1}, " \ + "{%2, %3, %4, %5}, " \ + "{%6, %7}, " \ + "{%8, %9};" + + // D matrix + : "=r"(*(uint32_t*)(&d0)), "=r"(*(uint32_t*)(&d1)) + + // A matrix + : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)), + "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)), + + // B matrix + "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)), + + // C matrix + "r"(*(uint32_t*)(&c0)), "r"(*(uint32_t*)(&c1)) + ); +} + +#ifdef KITTENS_HOPPER +/** +* @brief Perform the HMMA.16816 operation for FP8 using fp8e4m3_2. +* +* Using mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 instruction +* but with fp8e4m3_2 (2 FP8 values) instead of fp8e4m3_4 +*/ +/** + * @brief Perform the HMMA.16816 operation for FP8. + * + * This function performs the fp8-precision matrix multiply-accumulate operation + * using the `mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32` instruction. + * + * @param[out] d0 The first half of the output float2 accumulator. 
+ * @param[out] d1 The second half of the output float2 accumulator. + * @param[in] a0,a1,a2,a3 Input FP8 matrix A values + * @param[in] b0,b1 Input FP8 matrix B values + * @param[in] c0,c1 Input float2 accumulator matrix C values + */ +__device__ static inline void hmma16816( float2 &d0, float2 &d1, + const fp8e4m3_4 &a0, const fp8e4m3_4 &a1, + const fp8e4m3_4 &a2, const fp8e4m3_4 &a3, + const fp8e4m3_4 &b0, const fp8e4m3_4 &b1, + const float2 &c0, const float2 &c1) { + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0, %1, %2, %3}, " + "{%4, %5, %6, %7}, " + "{%8, %9}, " + "{%10, %11, %12, %13};" + + // D matrix (output) + : "+f"(d0.x), "+f"(d0.y), + "+f"(d1.x), "+f"(d1.y) + + // A matrix + : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)), + "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)), + + // B matrix + "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)), + + // C matrix + "f"(c0.x), "f"(c0.y), + "f"(c1.x), "f"(c1.y) + ); +} +#endif + +/** + * @brief Base matrix multiply-accumulate operation for row layout. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +/** + * @brief Base matrix multiply-accumulate operation for row layout + * with fp16 inputs and fp32 accumulators. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +#ifdef KITTENS_HOPPER +/** + * @brief Base matrix multiply-accumulate operation for row layout. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. 
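+ *
+ * As with the bf16/fp16 variants above, the output base tile is produced by
+ * two m16n8 issues, one per 8-column half of the accumulator; the body below
+ * is essentially:
+ * @code
+ *   hmma16816(d.data[0], d.data[1], a.data[0], a.data[1], a.data[2], a.data[3],
+ *             b.data[0], b.data[2], c.data[0], c.data[1]); // columns 0-7
+ *   hmma16816(d.data[2], d.data[3], a.data[0], a.data[1], a.data[2], a.data[3],
+ *             b.data[1], b.data[3], c.data[2], c.data[3]); // columns 8-15
+ * @endcode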
+ */ +__device__ static inline void mma_AB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +#endif +/** + * @brief Base matrix multiply-accumulate operation for row layout. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +/** + * @brief Base dot product operation for row layout. + * + * This function performs the base dot product operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in row-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_ABt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in row-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], // for some reason this one seems to need to be backwards + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], // for some reason this one seems to need to be backwards + c.data[2], c.data[3] + ); +} +/** + * @brief Base dot product operation for row layout + * with fp16 inputs and fp32 accumulators. + * + * This function performs the base dot product operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in row-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_ABt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in row-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], // for some reason this one seems to need to be backwards + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], // for some reason this one seems to need to be backwards + c.data[2], c.data[3] + ); +} +#ifdef KITTENS_HOPPER +/** + * @brief Base dot product operation for row layout. + * + * This function performs the base dot product operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. 
+ * @param[in] b The second input rt_base matrix in row-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_ABt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in row-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], // for some reason this one seems to need to be backwards + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], // for some reason this one seems to need to be backwards + c.data[2], c.data[3] + ); +} +#endif + + +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AtB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A + * with fp16 inputs and fp32 accumulators. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AtB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +#ifdef KITTENS_HOPPER +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AtB_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +#endif + +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B. 
+ * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AtBt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B + * with fp16 inputs and fp32 accumulators. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in row-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AtBt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in row-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +#ifdef KITTENS_HOPPER +/** + * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B. + * + * This function performs the base matrix multiply-accumulate operation + * using the `hmma16816` function for matrices in row layout. + * + * @param[out] d The output rt_base accumulator. + * @param[in] a The first input rt_base matrix. + * @param[in] b The second input rt_base matrix in column-major mode. + * @param[in] c The input rt_base accumulator matrix. + */ +__device__ static inline void mma_AtBt_base(rt_base &d, + const rt_base &a, + const rt_base &b, // in col-major mode + const rt_base &c) { + hmma16816( + d.data[0], d.data[1], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[0], b.data[2], + c.data[0], c.data[1] + ); + hmma16816( + d.data[2], d.data[3], + a.data[0], a.data[1], a.data[2], a.data[3], + b.data[1], b.data[3], + c.data[2], c.data[3] + ); +} +#endif + +/** + * @brief Matrix multiply-accumulate operation. + * + * This function performs the matrix multiply-accumulate operation + * using the `hmma16816` function. + * + * @tparam N The number of row tiles. + * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix. + * @tparam M The number of column tiles for the B matrix. + * @param[out] d The output rt_hf accumulator. + * @param[in] a The first input rt_hf matrix. + * @param[in] b The second input rt_hf matrix in column-major mode. + * @param[in] c The input rt_hf accumulator matrix. 
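+ *
+ * A minimal usage sketch (hypothetical 16x16 tile shapes; exact rt template
+ * parameters follow the surrounding register-tile conventions):
+ * @code
+ *   rt_bf<16, 16> a;                         // A operand, row layout
+ *   rt_bf<16, 16, ducks::rt_layout::col> b;  // B operand, column layout
+ *   rt_fl<16, 16> c, d;                      // fp32 accumulators
+ *   mma_AB(d, a, b, c);                      // d = a * b + c
+ * @endcode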
+ */ +template +__device__ static inline void mma_AB(D &d, + const A &a, + const B &b, + const C &c) { + KITTENS_CHECK_WARP + static_assert(D::rows == A::rows && D::cols == B::cols); // Check D matches A, B + static_assert(A::cols == B::rows); // Check reduction dim is same + static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C + #ifdef KITTENS_HOPPER + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #else + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #endif + #pragma unroll + for(int n = 0; n < D::height; n++) { + #pragma unroll + for(int m = 0; m < D::width; m++) { + mma_AB_base( + d.tiles[n][m], + a.tiles[n][0], + b.tiles[0][m], + c.tiles[n][m] + ); + #pragma unroll + for(int k = 1; k < A::width; k++) { + mma_AB_base( + d.tiles[n][m], + a.tiles[n][k], + b.tiles[k][m], + d.tiles[n][m] + ); + } + } + } +} +/** + * @brief Dot product operation for row layout. + * + * This function performs the dot product operation + * using the `hmma16816` function. + * + * @tparam N The number of row tiles. + * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix. + * @tparam M The number of column tiles for the B matrix. + * @param[out] d The output rt_fl accumulator. + * @param[in] a The first input rt_bf matrix. + * @param[in] b The second input rt_bf matrix in row-major mode. + * @param[in] c The input rt_fl accumulator matrix. + */ +template +__device__ static inline void mma_ABt(D &d, + const A &a, + const B &b, // notice row and (M, K) instead of col and (K, M) + const C &c) { + KITTENS_CHECK_WARP + static_assert(D::rows == A::rows && D::cols == B::rows); // Check D matches A, B + static_assert(A::cols == B::cols); // Check reduction dim is same + static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C + #ifdef KITTENS_HOPPER + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #else + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #endif + #pragma unroll + for(int n = 0; n < D::height; n++) { + #pragma unroll + for(int m = 0; m < D::width; m++) { + mma_ABt_base( + d.tiles[n][m], + a.tiles[n][0], + b.tiles[m][0], + c.tiles[n][m] + ); + #pragma unroll + for(int k = 1; k < A::width; k++) { + mma_ABt_base( + d.tiles[n][m], + a.tiles[n][k], + b.tiles[m][k], + d.tiles[n][m] + ); + } + } + } +} +/** + * @brief Matrix multiply-accumulate operation with transposed A. + * + * This function performs the matrix multiply-accumulate operation + * using the `hmma16816` instruction. + * + * @tparam N The number of row tiles. + * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix. + * @tparam M The number of column tiles for the B matrix. + * @param[out] d The output rt_fl accumulator. 
+ * @param[in] a The first input rt_bf matrix. + * @param[in] b The second input rt_bf matrix in column-major mode. + * @param[in] c The input rt_fl accumulator matrix. + */ +template +__device__ static inline void mma_AtB(D &d, + const A &a, + const B &b, + const C &c) { + KITTENS_CHECK_WARP + static_assert(D::rows == A::cols && D::cols == B::cols); // Check D matches A, B + static_assert(A::rows == B::rows); // Check reduction dim is same + static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C + #ifdef KITTENS_HOPPER + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #else + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #endif + #pragma unroll + for(int n = 0; n < D::height; n++) { + #pragma unroll + for(int m = 0; m < D::width; m++) { + mma_AtB_base( + d.tiles[n][m], + a.tiles[0][n], + b.tiles[0][m], + c.tiles[n][m] + ); + #pragma unroll + for(int k = 1; k < A::height; k++) { + mma_AtB_base( + d.tiles[n][m], + a.tiles[k][n], + b.tiles[k][m], + d.tiles[n][m] + ); + } + } + } +} +/** + * @brief Matrix multiply-accumulate operation with transposed A and B. + * + * This function performs the matrix multiply-accumulate operation + * using the `hmma16816` instruction. + * + * @tparam N The number of row tiles. + * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix. + * @tparam M The number of column tiles for the B matrix. + * @param[out] d The output rt_fl accumulator. + * @param[in] a The first input rt_bf matrix. + * @param[in] b The second input rt_bf matrix in column-major mode. + * @param[in] c The input rt_fl accumulator matrix. 
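+ *
+ * All four transpose variants can also be reached through the generic
+ * dispatcher defined further below; assuming compatible tiles a, b, c, d and
+ * that transpose::N names the non-transposed case:
+ * @code
+ *   mma<transpose::T, transpose::T>(d, a, b, c);  // same as mma_AtBt(d, a, b, c)
+ *   mma<transpose::N, transpose::T>(d, a, b, c);  // same as mma_ABt(d, a, b, c)
+ * @endcode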
+ */ +template +__device__ static inline void mma_AtBt(D &d, + const A &a, + const B &b, + const C &c) { + KITTENS_CHECK_WARP + static_assert(D::rows == A::cols && D::cols == B::rows); // Check D matches A, B + static_assert(A::rows == B::cols); // Check reduction dim is same + static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C + #ifdef KITTENS_HOPPER + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #else + static_assert( + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v && + std::is_same_v && std::is_same_v) + ); + #endif + #pragma unroll + for(int n = 0; n < D::height; n++) { + #pragma unroll + for(int m = 0; m < D::width; m++) { + mma_AtBt_base( + d.tiles[n][m], + a.tiles[0][n], + b.tiles[m][0], + c.tiles[n][m] + ); + #pragma unroll + for(int k = 1; k < A::height; k++) { + mma_AtBt_base( + d.tiles[n][m], + a.tiles[k][n], + b.tiles[m][k], + d.tiles[n][m] + ); + } + } + } +} + +template +__device__ static inline void mma(D &d, + const A &a, + const B &b, + const C &c) { + KITTENS_CHECK_WARP + if constexpr(trans_A == transpose::T) { + if constexpr(trans_B == transpose::T) { + mma_AtBt(d, a, b, c); + } else { + mma_AtB(d, a, b, c); + } + } else { + if constexpr(trans_B == transpose::T) { + mma_ABt(d, a, b, c); + } else { + mma_AB(d, a, b, c); + } + } +} +template +__device__ static inline C mma(const A &a, + const B &b, + const C &c) { + KITTENS_CHECK_WARP + C d; + if constexpr(trans_A == transpose::T) { + if constexpr(trans_B == transpose::T) { + mma_AtBt(d, a, b, c); + } else { + mma_AtB(d, a, b, c); + } + } else { + if constexpr(trans_B == transpose::T) { + mma_ABt(d, a, b, c); + } else { + mma_AB(d, a, b, c); + } + } + return d; +} + + +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------- COMPLEX INPUTS -------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- + + + +/** + * @brief Matrix multiply-accumulate operation for complex tiles + * + * This function calls mma_AB with hf arguments + * + * @tparam N The number of row tiles. + * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix. + * @tparam M The number of column tiles for the B matrix. + * @param[out] d The output rt_cmplx_hf accumulator. + * @param[in] a The first input rt_cmplx_hf matrix. + * @param[in] b The second input rt_cmplx_hf matrix in column-major mode. + * @param[in] c The input rt_cmplx_hf accumulator matrix. 
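+ *
+ * The complex product is decomposed into four real MMAs, with a.imag
+ * pre-scaled by -1 so the real part can accumulate in place:
+ *   d.real = a.real * b.real + (-a.imag) * b.imag + c.real
+ *   d.imag = a.real * b.imag +   a.imag  * b.real + c.imag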
+ */ +template +__device__ static inline void mma_AB(crt_hf &d, + const crt_hf &a, + const crt_hf &b, + const crt_hf &c) { + KITTENS_CHECK_WARP + + // Copy data from input accumulate register into output + ::kittens::group<1>::copy(d.real, c.real); + ::kittens::group<1>::copy(d.imag, c.imag); + + // Negative on B matrix so we can use single accum register + rt_hf tmp; + // Hex value for -1 in float16 + constexpr half factor = std::bit_cast<__half>(uint16_t(0xFB80)); + ::kittens::group<1>::mul(tmp, a.imag, factor); + mma_AB(d.real, a.real, b.real, d.real); + mma_AB(d.real, tmp, b.imag, d.real); + + mma_AB(d.imag, a.real, b.imag, d.imag); + mma_AB(d.imag, a.imag, b.real, d.imag); +} +/** + * @brief Matrix multiply-accumulate operation for complex tiles + * + * This function calls mma_AB with bf16 arguments + * + * @tparam N The number of row tiles. + * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix. + * @tparam M The number of column tiles for the B matrix. + * @param[out] d The output rt_cmplx_fl accumulator. + * @param[in] a The first input rt_cmplx_bf matrix. + * @param[in] b The second input rt_cmplx_bf matrix in column-major mode. + * @param[in] c The input rt_cmplx_fl accumulator matrix. + */ + +template +__device__ static inline void mma_AB(crt_fl &d, + const crt_bf &a, + const crt_bf &b, + const crt_fl &c) { + KITTENS_CHECK_WARP + + // Copy data from input accumulate register into output + ::kittens::group<1>::copy(d.real, c.real); + ::kittens::group<1>::copy(d.imag, c.imag); + + // Negative on B matrix so we can use single accum register + kittens::rt_bf tmp; + // Hex value for -1 in bf16 + constexpr bf16 factor = std::bit_cast<__nv_bfloat16>(uint16_t(0xBF80)); + ::kittens::group<1>::mul(tmp, a.imag, factor); + mma_AB(d.real, a.real, b.real, d.real); + mma_AB(d.real, tmp, b.imag, d.real); + + mma_AB(d.imag, a.real, b.imag, d.imag); + mma_AB(d.imag, a.imag, b.real, d.imag); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x112.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x112.impl new file mode 100644 index 0000000000..3cba73022c --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x112.impl @@ -0,0 +1,334 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %61, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \ + "{%56, %57, %58, %59}, " \ + "%60, " \ + "p, 1, %63, %62;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %61, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \ + "{%56, %57, %58, %59}, " \ + "%60, " \ + "p, 1, %63, %62;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), 
"+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %33, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27}, " \ + "{%28, %29, %30, %31}, " \ + "%32, " \ + "p, 1, %35, %34;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + 
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %58, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \ + "%56, " \ + "%57, " \ + "p, 1, %61, %59, %60;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y) + + : 
"l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %58, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \ + "%56, " \ + "%57, " \ + "p, 1, %61, %59, %60;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %30, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27}, " \ + "%28, " \ + "%29, " \ + "p, 1, %33, %31, %32;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + 
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x128.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x128.impl new file mode 100644 index 0000000000..8e1cf35a9a --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x128.impl @@ -0,0 +1,813 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, %71, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, 
%32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, %71, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, %39, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + 
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + 
"+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), 
"+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + 
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, %69, %67, %68;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, 
%51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, %69, %67, %68;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, %37, %35, %36;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + 
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, %67;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), 
"+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, %67;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y), + "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y), + "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y), + "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y), + "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y), + "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y), + "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y), + "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y), + "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // 
----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, %35;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, %35;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + 
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][7].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x144.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x144.impl new file mode 100644 index 0000000000..0616a66a99 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x144.impl @@ -0,0 +1,382 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %77, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \ + "{%72, %73, %74, %75}, " \ + "%76, " \ + "p, 1, %79, %78;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + 
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %77, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \ + "{%72, %73, %74, %75}, " \ + "%76, " \ + "p, 1, %79, %78;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + 
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %41, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35}, " \ + "{%36, %37, %38, %39}, " \ + "%40, " \ + "p, 1, %43, %42;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + 
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %74, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \ + "%72, " \ + "%73, " \ + "p, 1, %77, %75, %76;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 
7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %74, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \ + "%72, " \ + "%73, " \ + "p, 1, %77, %75, %76;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 
8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %38, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35}, " \ + "%36, " \ + "%37, " \ + "p, 1, %41, %39, %40;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x16.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x16.impl new file mode 100644 index 0000000000..578991c127 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x16.impl @@ -0,0 +1,190 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %13, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "{%8, %9, %10, %11}, " \ + "%12, " \ + "p, 1, %15, %14;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %13, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "{%8, %9, %10, %11}, " \ + "%12, " \ + "p, 1, %15, %14;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %9, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " \ + "{%0, %1, %2, %3}, " \ + "{%4, %5, %6, %7}, " \ + "%8, " \ + "p, 1, %11, %10;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
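+        // Editorial note, not part of the upstream header: the inline-asm constraints in
+        // the cases below encode which knobs are runtime and which are compile-time.
+        // "l" passes the 64-bit A/B matrix descriptors, "r" passes the runtime scale_d
+        // that feeds the setp predicate, and "n" emits trans_a / trans_b / scale_b as
+        // literal immediates baked in from template parameters. scale_b == -1 negates the
+        // B operand (the "invert" option checked by the static_assert above), which lets
+        // the library compute D = A*(-B) + D without touching the data in shared memory.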
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %10, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "%8, " \ + "%9, " \ + "p, 1, %13, %11, %12;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %10, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "%8, " \ + "%9, " \ + "p, 1, %13, %11, %12;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %6, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " \ + "{%0, %1, %2, %3}, " \ + "%4, " \ + "%5, " \ + "p, 1, %9, %7, %8;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x160.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x160.impl new file mode 100644 index 0000000000..533dd02157 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x160.impl @@ -0,0 +1,666 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
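+        // Editorial note, not part of the upstream header: the per-thread register count
+        // is the main thing that changes between these otherwise identical .impl files.
+        // A 64xN f32 accumulator spread over the 128-thread warpgroup is N/2 floats per
+        // thread, so n=160 below binds %0..%79 to D, {%80..%83} to the A fragment and
+        // %84 to the B descriptor, versus %0..%63 / {%64..%67} / %68 in the 64x128 case
+        // earlier in this patch. The 16-bit-input shapes use k16, while the FP8
+        // (e4m3/e5m2) shapes elsewhere in this patch use k32, matching wgmma's native
+        // K-step for those element types.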
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %85, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "{%80, %81, %82, %83}, " \ + "%84, " \ + "p, 1, %87, %86;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 
9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %85, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "{%80, %81, %82, %83}, " \ + "%84, " \ + "p, 1, %87, %86;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 
8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %45, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \ + "{%40, %41, %42, %43}, " \ + "%44, " \ + "p, 1, %47, %46;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %85, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, 
%17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "{%80, %81, %82, %83}, " \ + "%84, " \ + "p, 1, %86;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v 
&& std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %85, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "{%80, %81, %82, %83}, " \ + "%84, " \ + "p, 1, %86;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), 
"r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %82, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "%80, " \ + "%81, " \ + "p, 1, %85, %83, %84;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + 
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %82, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "%80, " \ + "%81, " \ + "p, 1, %85, %83, %84;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), 
+ "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %42, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \ + "%40, " \ + "%41, " \ + "p, 1, %45, %43, %44;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> 
FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %82, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "%80, " \ + "%81, " \ + "p, 1, %83;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + 
: "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %82, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \ + "%80, " \ + "%81, " \ + "p, 1, %83;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + 
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x176.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x176.impl new file mode 100644 index 0000000000..4a8f355cdb --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x176.impl @@ -0,0 +1,430 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %93, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \ + "{%88, %89, %90, %91}, " \ + "%92, " \ + "p, 1, %95, %94;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 
6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %93, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \ + "{%88, %89, %90, %91}, " \ + "%92, " \ + "p, 1, %95, %94;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 
4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %49, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43}, " \ + "{%44, %45, %46, %47}, " \ + "%48, " \ + "p, 1, %51, %50;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + 
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %90, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \ + "%88, " \ + "%89, " \ + "p, 1, %93, %91, %92;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 
9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %90, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \ + "%88, " \ + "%89, " \ + "p, 1, %93, %91, %92;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + 
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %46, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43}, " \ + "%44, " \ + "%45, " \ + "p, 1, %49, %47, %48;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + 
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x192.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x192.impl new file mode 100644 index 0000000000..e5e73f3459 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x192.impl @@ -0,0 +1,674 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %101, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \ + "{%96, %97, %98, %99}, " \ + "%100, " \ + "p, 1, %103, %102;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 
5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %101, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \ + "{%96, %97, %98, %99}, " \ + "%100, " \ + "p, 1, %103, %102;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + 
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %53, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "{%48, %49, %50, %51}, " \ + "%52, " \ + "p, 1, %55, %54;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + 
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %101, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \ + "{%96, %97, %98, %99}, " \ + "%100, " \ + "p, 1, %102;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), 
+ "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %101, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, 
%6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \ + "{%96, %97, %98, %99}, " \ + "%100, " \ + "p, 1, %102;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + 
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %98, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \ + "%96, " \ + "%97, " \ + "p, 1, %101, %99, %100;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), 
"+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %98, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \ + "%96, " \ + "%97, " \ + "p, 1, %101, %99, %100;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 
2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %50, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "%48, " \ + "%49, " \ + "p, 1, %53, %51, %52;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + 
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %98, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \ + "%96, " \ + "%97, " \ + "p, 1, %99;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + 
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x208.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x208.impl new file mode 100644 index 0000000000..92325639ac --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x208.impl @@ -0,0 +1,478 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), 
+ "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %109, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \ + "{%104, %105, %106, %107}, " \ + "%108, " \ + "p, 1, %111, %110;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 
8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %109, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \ + "{%104, %105, %106, %107}, " \ + "%108, " \ + "p, 1, %111, %110;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), 
"+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %57, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51}, " \ + "{%52, %53, %54, %55}, " \ + "%56, " \ + "p, 1, %59, %58;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + 
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %106, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \ + "%104, " \ + "%105, " \ + "p, 1, %109, %107, %108;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), 
"+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %106, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \ + "%104, " \ + "%105, " \ + "p, 1, %109, %107, %108;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), 
"+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %54, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51}, " \ + "%52, " \ + "%53, " \ + "p, 1, %57, %55, %56;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + 
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x224.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x224.impl new file mode 100644 index 0000000000..6405bf7cad --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x224.impl @@ -0,0 +1,826 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %117, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "{%112, %113, %114, %115}, " \ + "%116, " \ + "p, 1, %119, %118;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 
8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %117, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "{%112, %113, %114, %115}, " \ + "%116, " \ + "p, 1, %119, %118;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 
3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %61, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, 
%33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \ + "{%56, %57, %58, %59}, " \ + "%60, " \ + "p, 1, %63, %62;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %117, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, 
%43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "{%112, %113, %114, %115}, " \ + "%116, " \ + "p, 1, %118;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + 
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %117, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "{%112, %113, %114, %115}, " \ + "%116, " \ + "p, 1, %118;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + 
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %114, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "%112, " \ + "%113, " \ + "p, 1, %117, %115, %116;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 
8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %114, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "%112, " \ + "%113, " \ + "p, 1, %117, %115, %116;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + 
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %58, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \ + "%56, " \ + "%57, " \ + "p, 1, %61, %59, %60;\n" \ + "}\n" + 
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %114, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, 
%98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "%112, " \ + "%113, " \ + "p, 1, %117, %115, %116;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + 
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %114, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \ + "%112, " \ + "%113, " \ + "p, 1, %117, %115, %116;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 
6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
+          "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
+          "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
+          "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
+          "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
+          "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
+          "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
+          "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
+          "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
+          "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
+          "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
+          "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
+          "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
+          "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
+          "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
+          "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
+          "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
+          "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
+          "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
+          "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
+          "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
+          "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
+          "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
+          "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
+          "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
+          "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
+          "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
+          "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
+          "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
+          "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)
+
+        : "l"(a_st_desc),
+          "l"(b_st_desc),
+
+          "r"(scale_d),
+          // "n"(trans_a),
+          // "n"(trans_b),
+          "n"(scale_b)
+        );
+        }
+    }
+};
\ No newline at end of file
diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x240.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x240.impl
new file mode 100644
index 0000000000..7a3246d2e6
--- /dev/null
+++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x240.impl
@@ -0,0 +1,526 @@
+template
+struct base {
+    template __device__ static inline void rt_st(
+        rt &dst,
+        const rt_base & a_rt,
+        const uint64_t b_st_desc,
+        int scale_d = 1
+    ) {
+        static_assert(
+            (std::is_same_v && std::is_same_v) ||
+            (std::is_same_v && std::is_same_v) ||
+            (std::is_same_v && std::is_same_v),
+            "Invalid type combination for WGMMA."
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %125, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \ + "{%120, %121, %122, %123}, " \ + "%124, " \ + "p, 1, %127, %126;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + 
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %125, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \ + "{%120, %121, %122, %123}, " \ + "%124, " \ + "p, 1, %127, %126;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 
2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + 
"+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %65, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59}, " \ + "{%60, %61, %62, %63}, " \ + "%64, " \ + "p, 1, %67, %66;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + 
"+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %122, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \ + "%120, " \ + "%121, " \ + "p, 1, %125, %123, %124;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 
6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %122, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \ + "%120, " \ + "%121, " \ + "p, 1, %125, %123, %124;\n" \ + "}\n" + // a_regs, b_mat 
descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), 
"+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %62, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59}, " \ + "%60, " \ + "%61, " \ + "p, 1, %65, %63, %64;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + 
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x256.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x256.impl new file mode 100644 index 0000000000..cc46f822fb --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x256.impl @@ -0,0 +1,1260 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %133, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "{%128, %129, %130, %131}, " \ + "%132, " \ + "p, 1, %135, %134;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 
2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + 
"+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %133, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "{%128, %129, %130, %131}, " \ + "%132, " \ + "p, 1, %135, %134;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), 
"+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, %71, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + 
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %133, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, 
%51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "{%128, %129, %130, %131}, " \ + "%132, " \ + "p, 1, %134;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), 
"+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %133, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "{%128, %129, %130, %131}, " \ + "%132, " \ + "p, 1, %134;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + 
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), 
"+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + 
"+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %69, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "{%64, %65, %66, %67}, " \ + "%68, " \ + "p, 1, %70;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + 
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %130, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "%128, " \ + "%129, " \ + "p, 1, %133, %131, %132;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + 
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), 
"+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %130, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "%128, " \ + "%129, " \ + "p, 1, %133, %131, %132;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), 
"+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, %69, %67, %68;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 
2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %130, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, 
%127}, " \ + "%128, " \ + "%129, " \ + "p, 1, %131;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), 
"+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %130, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m.e5m " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120, %121, %122, %123, %124, %125, %126, %127}, " \ + "%128, " \ + "%129, " \ + "p, 1, %131;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y), + "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y), + "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y), + "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y), + "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y), + "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y), + "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y), + "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y), + "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y), + "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y), + "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y), + "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y), + "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y), + "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y), + "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y), + "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y), + "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y), + "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y), + "+f"(dst.tiles[0][ 4].data[2].x), 
"+f"(dst.tiles[0][ 4].data[2].y), + "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y), + "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y), + "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y), + "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y), + "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y), + "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y), + "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y), + "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y), + "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y), + "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y), + "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y), + "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y), + "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y), + "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y), + "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y), + "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y), + "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y), + "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y), + "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y), + "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y), + "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y), + "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y), + "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y), + "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y), + "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y), + "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y), + "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y), + "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y), + "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y), + "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y), + "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y), + "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y), + "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y), + "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y), + "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y), + "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y), + "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y), + "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y), + "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y), + "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y), + "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y), + "+f"(dst.tiles[0][15].data[0].x), "+f"(dst.tiles[0][15].data[0].y), + "+f"(dst.tiles[0][15].data[1].x), "+f"(dst.tiles[0][15].data[1].y), + "+f"(dst.tiles[0][15].data[2].x), "+f"(dst.tiles[0][15].data[2].y), + "+f"(dst.tiles[0][15].data[3].x), "+f"(dst.tiles[0][15].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, 
%4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, %67;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + 
"n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %66, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \ + "%64, " \ + "%65, " \ + "p, 1, %67;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][14].data[3]), + 
"+r"(*(uint32_t*)&dst.tiles[0][15].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][15].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x32.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x32.impl new file mode 100644 index 0000000000..12507de27c --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x32.impl @@ -0,0 +1,446 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, %23, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, %23, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), 
"r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %13, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "{%8, %9, %10, %11}, " \ + "%12, " \ + "p, 1, %15, %14;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> 
FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %13, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "{%8, %9, %10, %11}, " \ + "%12, " \ + "p, 1, %14;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, %21, %19, %20;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, %21, %19, %20;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + 
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %10, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "%8, " \ + "%9, " \ + "p, 1, %13, %11, %12;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, %19;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, %19;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred 
p;\n" \ + "setp.ne.b32 p, %10, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "%8, " \ + "%9, " \ + "p, 1, %11;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %10, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7}, " \ + "%8, " \ + "%9, " \ + "p, 1, %11;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x48.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x48.impl new file mode 100644 index 0000000000..d573d922d8 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x48.impl @@ -0,0 +1,238 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %29, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "{%24, %25, %26, %27}, " \ + "%28, " \ + "p, 1, %31, %30;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %29, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "{%24, %25, %26, %27}, " \ + "%28, " \ + "p, 1, %31, %30;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %17, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11}, " \ + "{%12, %13, %14, %15}, " \ + "%16, " \ + "p, 1, %19, %18;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, 
imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %26, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "%24, " \ + "%25, " \ + "p, 1, %29, %27, %28;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %26, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "%24, " \ + "%25, " \ + "p, 1, %29, %27, %28;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + 
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %14, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11}, " \ + "%12, " \ + "%13, " \ + "p, 1, %17, %15, %16;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x64.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x64.impl new file mode 100644 index 0000000000..59605361fb --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x64.impl @@ -0,0 +1,587 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, %39, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, %39, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : 
"r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, %23, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile 
( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %37, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "{%32, %33, %34, %35}, " \ + "%36, " \ + "p, 1, %38;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %21, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "{%16, %17, %18, %19}, " \ + "%20, " \ + "p, 1, %22;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : 
"+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, %37, %35, %36;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, 
%13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, %37, %35, %36;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, %21, %19, %20;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, %35;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + 
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), // transpose is not supported for FP8 + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %34, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \ + "%32, " \ + "%33, " \ + "p, 1, %35;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), // transpose is not supported for FP8 + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, %19;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + 
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %18, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \ + "%16, " \ + "%17, " \ + "p, 1, %19;\n" \ + "}\n" + // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a, imm-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x80.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x80.impl new file mode 100644 index 0000000000..c813c82246 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x80.impl @@ -0,0 +1,286 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %45, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \ + "{%40, %41, %42, %43}, " \ + "%44, " \ + "p, 1, %47, %46;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %45, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \ + "{%40, %41, %42, %43}, " \ + "%44, " \ + "p, 1, %47, %46;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), 
"+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %25, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19}, " \ + "{%20, %21, %22, %23}, " \ + "%24, " \ + "p, 1, %27, %26;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %42, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \ + "%40, " \ + "%41, " \ + "p, 1, %45, %43, %44;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %42, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \ + "%40, " \ + "%41, " \ + "p, 1, %45, %43, %44;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + 
"+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %22, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19}, " \ + "%20, " \ + "%21, " \ + "p, 1, %25, %23, %24;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x96.impl b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x96.impl new file mode 100644 index 0000000000..29ca752f17 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x96.impl @@ -0,0 +1,703 @@ +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt_base & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." 
+ ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %53, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "{%48, %49, %50, %51}, " \ + "%52, " \ + "p, 1, %55, %54;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %53, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "{%48, %49, %50, %51}, " \ + "%52, " \ + "p, 1, %55, %54;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), 
"+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %29, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "{%24, %25, %26, %27}, " \ + "%28, " \ + "p, 1, %31, %30;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %53, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, 
%11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "{%48, %49, %50, %51}, " \ + "%52, " \ + "p, 1, %54;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %53, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "{%48, %49, %50, %51}, " \ + "%52, " \ + "p, 1, %54;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + 
"+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %29, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "{%24, %25, %26, %27}, " \ + "%28, " \ + "p, 1, %30;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %29, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "{%24, %25, %26, %27}, " \ + "%28, " \ + "p, 1, %30;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + 
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]) + + : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]), + "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]), + + "l"(b_st_desc), + "r"(scale_d), + // "n"(trans_b), + "n"(scale_b) + ); + } + + } + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ) { + static_assert( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v), + "Invalid type combination for WGMMA." + ); + static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option"); + // ----- BF16,BF16 -> FP32 ----- // + if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %50, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "%48, " \ + "%49, " \ + "p, 1, %53, %51, %52;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), 
"+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %50, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "%48, " \ + "%49, " \ + "p, 1, %53, %51, %52;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP16,FP16 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %26, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "%24, " \ + "%25, " \ + "p, 1, %29, %27, %28;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + 
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + "n"(trans_a), + "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %50, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "%48, " \ + "%49, " \ + "p, 1, %51;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP32 ----- // + else if constexpr 
(std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %50, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \ + "%48, " \ + "%49, " \ + "p, 1, %51;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y), + "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y), + "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y), + "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y), + "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y), + "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y), + "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y), + "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y), + "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y), + "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y), + "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y), + "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y), + "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y), + "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y), + "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y), + "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y), + "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y), + "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y), + "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y), + "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y), + "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y), + "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y), + "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y), + "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %26, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "%24, " \ + "%25, " \ + "p, 1, %27;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + 
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + // ----- FP8,FP8 -> FP16 ----- // + else if constexpr (std::is_same_v && std::is_same_v) { + asm volatile ( + "{\n" + ".reg .pred p;\n" \ + "setp.ne.b32 p, %26, 0;\n" \ + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " \ + "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \ + "%24, " \ + "%25, " \ + "p, 1, %27;\n" \ + "}\n" + // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b + + : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]), + "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]) + + : "l"(a_st_desc), + "l"(b_st_desc), + + "r"(scale_d), + // "n"(trans_a), + // "n"(trans_b), + "n"(scale_b) + ); + } + } +}; \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/base.cuh b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/base.cuh new file mode 100644 index 0000000000..6bd09c6816 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/base/base.cuh @@ -0,0 +1,47 @@ +#pragma once + +#include "../../../../../common/common.cuh" +#include "../../../../../types/types.cuh" + +namespace kittens { +namespace detail { +namespace wgmma { + +// templated wrapper for PTX +template +struct base { + template __device__ static inline void rt_st( + rt &dst, + const rt & a_rt, + const uint64_t b_st_desc, + int scale_d = 1 + ); + template __device__ static inline void st_st( + rt &dst, + const uint64_t a_st_desc, + const uint64_t b_st_desc, + int scale_d = 1 + ); +}; + +// all the ptx's +#include "64x16.impl" +#include "64x32.impl" +#include "64x48.impl" +#include "64x64.impl" +#include "64x80.impl" +#include "64x96.impl" +#include "64x112.impl" +#include "64x128.impl" +#include "64x144.impl" +#include "64x160.impl" +#include "64x176.impl" +#include "64x192.impl" +#include "64x208.impl" +#include "64x224.impl" +#include "64x240.impl" +#include "64x256.impl" + +} // namespace wgmma +} // namespace detail +} // namespace kittens \ No newline at end of file diff --git 
a/extra/thunder/cuda/include/ops/group/mma/warpgroup/warpgroup.cuh b/extra/thunder/cuda/include/ops/group/mma/warpgroup/warpgroup.cuh new file mode 100644 index 0000000000..ae4c315281 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/mma/warpgroup/warpgroup.cuh @@ -0,0 +1,1170 @@ +/** + * @file + * @brief Warpgroup matrix-multiply accumulate operations. These ops are necessary to achieve full utilization on H100 GPUs. + */ + + + +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- +// ------------------------------------------------------ FENCES ------------------------------------------------------ +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- + + +/** + * @brief Synchronize the warp group and ensure that all writes to shared memory are visible to all threads in the warp group. + * + * This function acts as a fence for shared memory operations, ensuring that all previous writes are visible before proceeding. + * This function should be called before running wgmma::mma or wgmma::dot instructions. + * + * @tparam height The height of the matrix `dst`. + * @tparam width The width of the matrix `dst`. + * @param dst[in,out] The destination register-tile matrix to be synchronized. + */ +template +__device__ static inline void mma_fence(D &dst) { + KITTENS_CHECK_WARPGROUP + #pragma unroll + for(int i = 0; i < D::height; i++) { + #pragma unroll + for(int j = 0; j < D::width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + if constexpr(std::is_same_v) { + asm volatile("" : "+f"(dst.tiles[i][j].data[k].x) :: "memory"); + asm volatile("" : "+f"(dst.tiles[i][j].data[k].y) :: "memory"); + } else { + asm volatile("" : "+r"(*(uint32_t*)&dst.tiles[i][j].data[k]) :: "memory"); + } + } + } + } + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); +} +template +__device__ static inline void mma_fence(D &dst) { + KITTENS_CHECK_WARPGROUP + #pragma unroll + for(int i = 0; i < D::height; i++) { + #pragma unroll + for(int j = 0; j < D::width; j++) { + #pragma unroll + for(int k = 0; k < dst.real.packed_per_tile; k++) { + if constexpr(std::is_same_v) { + asm volatile("" : "+f"(dst.real.tiles[i][j].data[k].x) :: "memory"); + asm volatile("" : "+f"(dst.real.tiles[i][j].data[k].y) :: "memory"); + asm volatile("" : "+f"(dst.imag.tiles[i][j].data[k].x) :: "memory"); + asm volatile("" : "+f"(dst.imag.tiles[i][j].data[k].y) :: "memory"); + } else { + asm volatile("" : "+r"(*(uint32_t*)&dst.real.tiles[i][j].data[k]) :: "memory"); + asm volatile("" : "+r"(*(uint32_t*)&dst.imag.tiles[i][j].data[k]) :: "memory"); + } + } + } + } + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); +} +template // prevents static assert being instantiated unless called. +__device__ static inline void mma_fence() { + KITTENS_CHECK_WARPGROUP + asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory"); + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); +} + +/** + * @brief Commit the current set of warp group matrix multiply accumulate calls. 
+ */ +template // prevents static assert being instantiated unless called. +__device__ static inline void mma_commit_group() { + KITTENS_CHECK_WARPGROUP + asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory"); +} + +/** + * @brief Wait for the warp group to reach a synchronization point. + * + * This function stalls the current warpgroup until enough WGMMA committed groups have been completed. + * + * @tparam N The number of remaining active WGMMA committed groups allowed. This will stall until the number of active groups is less than or equal to N. Defaults to 0. + */ +template +__device__ static inline void mma_async_wait() { + KITTENS_CHECK_WARPGROUP + asm volatile ("wgmma.wait_group.sync.aligned %0;" : : "n"(N) : "memory"); +} + + +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- +// ------------------------------------------------------ NORMAL ------------------------------------------------------ +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- + +/* + ### OPTIONS: + + REG+SMEM -> REG + - mma_AB (accum) [DONE] + - mm_AB (reset) [DONE] + - mma_ABt (accum) [DONE] + - mm_ABt (reset) [DONE] + + SMEM+SMEM -> REG + - mma_AB (accum) [DONE] + - mm_AB (reset) [DONE] + - mma_ABt (accum) [DONE] + - mm_ABt (reset) [DONE] + - mma_AtB (accum) [DONE] + - mm_AtB (reset) [DONE] + - mma_AtBt (accum) [DONE] + - mm_AtBt (reset) [DONE] + +Note: mma is an alias for mma_AB and dot is an alias for mma_ABt +*/ + +// [(register, shared) -> register] edition +/** + * @brief Perform matrix multiply-accumulate operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function multiplies a register tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam N_DIV_4 The height of the matrix `a` divided by 4. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The width of the matrices `b` and `d`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source register tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_AB(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M_DIV_4 = A::height; + static_assert(D::height == M_DIV_4); // output register is correctly sized + constexpr int N = B::width; + constexpr int K = A::width; + static_assert(B::height == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. + + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 0, 1>; + kittens::st_descriptor, 1> b_desc(b); // apologies for this hack -- it either calls ST constructor or copy constructor. 
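Taken together, these primitives bracket every batch of WGMMA instructions: fence before issuing, commit after issuing, wait before reading the accumulator. Since the mm_*/mma_* wrappers in this file issue the wgmma fence when requested and always commit the group, the caller-side pattern reduces to the sketch below; the warpgroup = group<4> alias, the example_ab name and the tile shapes are assumptions for illustration, not definitions from this header.

    __device__ static void example_ab(rt_fl<16, 64> &acc,
                                      const rt_bf<16, 64> &a_reg,    // per-warp A fragment
                                      const st_bf<64, 64> &b_smem) { // B tile in shared memory
        warpgroup::mm_AB(acc, a_reg, b_smem); // fences, issues the WGMMAs, commits the group
        warpgroup::mma_async_wait<0>();       // stall until acc is safe to read
    }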
+ + if constexpr (fence) { mma_fence(d); } + + // Do it + #pragma unroll + for(int m = 0; m < M_DIV_4; m++) { + rt, TILE_COL_DIM*N, ducks::rt_layout::row> &d_ref = group<1>::subtile_inplace>(d, m); + base::rt_st( + d_ref, + a.tiles[m][0], + b_desc.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.tiles[m][k], + b_desc.chunk_descriptor(k), + 1 + ); + } + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_AB(D &d, + const A &a, + const B &b) { + mma_AB(d, a, b); +} + +template +__device__ static inline void mma_AB(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::height; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::width; + constexpr int K = A::width; + static_assert(B::height == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. + + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 0, 1>; + kittens::st_descriptor, 0> a_desc(a); + kittens::st_descriptor, 1> b_desc(b); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d, + a_desc.chunk_descriptor(0), + b_desc.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + a_desc.chunk_descriptor(k), + b_desc.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_AB(D &d, + const A &a, + const B &b) { + mma_AB(d, a, b); +} + +// [(register, shared) -> register] edition +/** + * @brief Perform matrix outer product operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function computes an outer product of a register tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam N_DIV_4 The height of the matrix `a` divided by 4. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source register tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_ABt(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M_DIV_4 = A::height; + static_assert(D::height == M_DIV_4); // output register is correctly sized + constexpr int N = B::height; + constexpr int K = A::width; + static_assert(B::width == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
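Because the first call may overwrite the accumulator (mm_AB) while later calls accumulate (mma_AB), a K-chunked mainloop needs no separate zeroing pass. A minimal sketch over statically staged shared tiles; the chunk count, tile shapes and function name are illustrative assumptions:

    __device__ static void example_k_chunked(rt_fl<16, 64> &c,
                                             const st_bf<64, 64> (&a_s)[4],
                                             const st_bf<64, 64> (&b_s)[4]) {
        warpgroup::mm_AB(c, a_s[0], b_s[0]);      // chunk 0 overwrites c
        #pragma unroll
        for (int k = 1; k < 4; ++k)
            warpgroup::mma_AB(c, a_s[k], b_s[k]); // remaining chunks accumulate into c
        warpgroup::mma_async_wait<0>();           // drain before c is consumed
    }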
+ + // Usings + using T_AB = A::T; + using T_D = D::T; + using base = kittens::detail::wgmma::base*N, 0, 0>; + kittens::st_descriptor, 0> b_desc(b); + + if constexpr (fence) { mma_fence(d); } + + // Do it + #pragma unroll + for(int m = 0; m < M_DIV_4; m++) { + rt, TILE_COL_DIM*N, ducks::rt_layout::row> &d_ref = group<1>::subtile_inplace>(d, m); + base::rt_st( + d_ref, + a.tiles[m][0], + b_desc.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.tiles[m][k], + b_desc.chunk_descriptor(k), + 1 + ); + } + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_ABt(D &d, + const A &a, + const B &b) { + mma_ABt(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix outer product operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_A The layout of the matrix `a`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source shared tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_ABt(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::height; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::height; + constexpr int K = A::width; + static_assert(B::width == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. + + // Usings + using T_AB = A::T; + using T_D = D::T; + using base = kittens::detail::wgmma::base*N, 0, 0>; + kittens::st_descriptor, 0> a_desc(a); + kittens::st_descriptor, 0> b_desc(b); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d, + a_desc.chunk_descriptor(0), + b_desc.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + a_desc.chunk_descriptor(k), + b_desc.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_ABt(D &d, + const A &a, + const B &b) { + mma_ABt(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix multiply using warp group matrix multiply-accumulate (WGMMA) primitives, with A transposed. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_A The layout of the matrix `a`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source shared tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. 
+ */ +template +__device__ static inline void mma_AtB(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::width; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::width; + constexpr int K = A::height; + static_assert(B::height == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. + + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 1, 1>; + kittens::st_descriptor, 1> a_desc(a); + kittens::st_descriptor, 1> b_desc(b); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d, + a_desc.chunk_descriptor(0), + b_desc.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + a_desc.chunk_descriptor(k), + b_desc.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_AtB(D &d, + const A &a, + const B &b) { + mma_AtB(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix multiply using warp group matrix multiply-accumulate (WGMMA) primitives, with A and B transposed. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam D The destination register tile type. + * @tparam A The source shared tile type. + * @tparam B The source shared tile type. + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + */ +template +__device__ static inline void mma_AtBt(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::width; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::height; + constexpr int K = A::height; + static_assert(B::width == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. + + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 1, 0>; + kittens::st_descriptor, 1> a_desc(a); + kittens::st_descriptor, 0> b_desc(b); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d, + a_desc.chunk_descriptor(0), + b_desc.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d, + a_desc.chunk_descriptor(k), + b_desc.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. 
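The _ABt form keeps B in its natural row-major shared layout, which is why attention-style kernels use it for Q·Kᵀ; the transposed-A variants above take the same call shape with a shared-memory A operand. A sketch with assumed tile shapes and names:

    __device__ static void example_qk_scores(rt_fl<16, 64> &scores,
                                             const rt_bf<16, 64> &q,        // per-warp Q fragment
                                             const st_bf<64, 64> &k_smem) { // K tile, not pre-transposed
        warpgroup::mm_ABt(scores, q, k_smem); // scores = Q @ K^T
        warpgroup::mma_async_wait<0>();
    }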
+} +template +__device__ static inline void mm_AtBt(D &d, + const A &a, + const B &b) { + mma_AtBt(d, a, b); +} + + + +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------- COMPLEX INPUTS -------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- +// -------------------------------------------------------------------------------------------------------------------- + + +/* + ### OPTIONS: + + REG+SMEM -> REG + - mma_AB (accum) [TODO] + - mm_AB (reset) [TODO] + - mma_ABt (accum) [TODO] + - mm_ABt (reset) [TODO] + + SMEM+SMEM -> REG + - mma_AB (accum) [TODO] + - mm_AB (reset) [TODO] + - mma_ABt (accum) [TODO] + - mm_ABt (reset) [TODO] + - mma_AtB (accum) [TODO] + - mm_AtB (reset) [TODO] + - mma_AtBt (accum) [TODO] + - mm_AtBt (reset) [TODO] + +Note: mma is an alias for mma_AB and dot is an alias for mma_ABt +*/ + +// [(register, shared) -> register] edition +/** + * @brief Perform matrix multiply-accumulate operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function multiplies a register tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam N_DIV_4 The height of the matrix `a` divided by 4. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The width of the matrices `b` and `d`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source register tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_AB(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M_DIV_4 = A::height; + static_assert(D::height == M_DIV_4); // output register is correctly sized + constexpr int N = B::width; + constexpr int K = A::width; + static_assert(B::height == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
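    // Orientation note: the complex overloads below expand each complex product
    // into four real WGMMA streams over the .real / .imag component tiles,
    //   D.real += A.real @ B.real - A.imag @ B.imag
    //   D.imag += A.real @ B.imag + A.imag @ B.real
    // (with the first product overwriting D when accumulation is not requested).
    // The subtraction is obtained by issuing the A.imag @ B.imag product through
    // the <-1> specialization of rt_st / st_st, which flips the sign of that
    // partial product so it can still be accumulated in place.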
+ + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 0, 1>; + kittens::st_descriptor, 1> b_desc_real(b.real); + kittens::st_descriptor, 1> b_desc_imag(b.imag); + + if constexpr (fence) { mma_fence(d); } + + // Do it + #pragma unroll // Do real part + for(int m = 0; m < M_DIV_4; m++) { + rt, TILE_COL_DIM*N, ducks::rt_layout::row> &d_ref = group<1>::subtile_inplace>(d.real, m); + base::rt_st( + d_ref, + a.real.tiles[m][0], + b_desc_real.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.real.tiles[m][k], + b_desc_real.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::rt_st<-1>( // INVERT THE SIGN OF THE IMAGINARY PART + d_ref, + a.imag.tiles[m][k], + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + } + #pragma unroll // Do imaginary part + for(int m = 0; m < M_DIV_4; m++) { + rt, TILE_COL_DIM*N, ducks::rt_layout::row> &d_ref = group<1>::subtile_inplace>(d.imag, m); + base::rt_st( + d_ref, + a.real.tiles[m][0], + b_desc_imag.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.real.tiles[m][k], + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::rt_st( + d_ref, + a.imag.tiles[m][k], + b_desc_real.chunk_descriptor(k), + 1 + ); + } + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_AB(D &d, + const A &a, + const B &b) { + mma_AB(d, a, b); +} + +template +__device__ static inline void mma_AB(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::height; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::width; + constexpr int K = A::width; + static_assert(B::height == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
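A usage sketch for the complex path; the crt_fl / crt_bf / cst_bf complex-tile aliases, the shapes and the function name are assumptions for illustration only:

    __device__ static void example_complex_ab(crt_fl<16, 64> &d,
                                              const crt_bf<16, 64> &a,
                                              const cst_bf<64, 64> &b_smem) {
        warpgroup::mm_AB(d, a, b_smem); // four real WGMMA streams per K-chunk
        warpgroup::mma_async_wait<0>();
    }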
+ + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 0, 1>; + kittens::st_descriptor, 0> a_desc_real(a.real); + kittens::st_descriptor, 0> a_desc_imag(a.imag); + kittens::st_descriptor, 1> b_desc_real(b.real); + kittens::st_descriptor, 1> b_desc_imag(b.imag); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d.real, + a_desc_real.chunk_descriptor(0), + b_desc_real.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.real, + a_desc_real.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st<-1>( // INVERT THE SIGN OF THE IMAGINARY PART + d.real, + a_desc_imag.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(0), + b_desc_imag.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st( + d.imag, + a_desc_imag.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_AB(D &d, + const A &a, + const B &b) { + mma_AB(d, a, b); +} + +// [(register, shared) -> register] edition +/** + * @brief Perform matrix outer product operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function computes an outer product of a register tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam N_DIV_4 The height of the matrix `a` divided by 4. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source register tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_ABt(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M_DIV_4 = A::height; + static_assert(D::height == M_DIV_4); // output register is correctly sized + constexpr int N = B::height; + constexpr int K = A::width; + static_assert(B::width == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
+ + // Usings + using T_AB = A::T; + using T_D = D::T; + using base = kittens::detail::wgmma::base*N, 0, 0>; + kittens::st_descriptor, 0> b_desc_real(b.real); + kittens::st_descriptor, 0> b_desc_imag(b.imag); + + if constexpr (fence) { mma_fence(d); } + + // Do it + #pragma unroll + for(int m = 0; m < M_DIV_4; m++) { + rt, TILE_ROW_DIM*N, ducks::rt_layout::row> &d_ref = group<1>::subtile_inplace>(d.real, m); + base::rt_st( + d_ref, + a.real.tiles[m][0], + b_desc_real.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.real.tiles[m][k], + b_desc_real.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::rt_st<-1>( // INVERT THE SIGN OF THE IMAGINARY PART + d_ref, + a.imag.tiles[m][k], + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + } + #pragma unroll + for(int m = 0; m < M_DIV_4; m++) { + rt, TILE_ROW_DIM*N, ducks::rt_layout::row> &d_ref = group<1>::subtile_inplace>(d.imag, m); + base::rt_st( + d_ref, + a.real.tiles[m][0], + b_desc_imag.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::rt_st( + d_ref, + a.real.tiles[m][k], + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::rt_st( + d_ref, + a.imag.tiles[m][k], + b_desc_real.chunk_descriptor(k), + 1 + ); + } + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_ABt(D &d, + const A &a, + const B &b) { + mma_ABt(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix outer product operation using warp group matrix multiply-accumulate (WGMMA) primitives. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_A The layout of the matrix `a`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source shared tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_ABt(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::height; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::height; + constexpr int K = A::width; + static_assert(B::width == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
+ + // Usings + using T_AB = A::T; + using T_D = D::T; + using base = kittens::detail::wgmma::base*N, 0, 0>; + kittens::st_descriptor, 0> a_desc_real(a.real); + kittens::st_descriptor, 0> a_desc_imag(a.imag); + kittens::st_descriptor, 0> b_desc_real(b.real); + kittens::st_descriptor, 0> b_desc_imag(b.imag); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d.real, + a_desc_real.chunk_descriptor(0), + b_desc_real.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.real, + a_desc_real.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st<-1>( // INVERT THE SIGN OF THE IMAGINARY PART + d.real, + a_desc_imag.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(0), + b_desc_imag.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st( + d.imag, + a_desc_imag.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_ABt(D &d, + const A &a, + const B &b) { + mma_ABt(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix multiply using warp group matrix multiply-accumulate (WGMMA) primitives, with A transposed. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + * @tparam K The common dimension of matrices `a` and `b`. + * @tparam M The height of the matrices `b` and `d`. + * @tparam L_A The layout of the matrix `a`. + * @tparam L_B The layout of the matrix `b`. + * @param d[out] The destination register tile where the result is accumulated or written. + * @param a[in] The source shared tile to be multiplied. + * @param b[in] The source shared tile to be multiplied. + */ +template +__device__ static inline void mma_AtB(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::width; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::width; + constexpr int K = A::height; + static_assert(B::height == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
+ + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 1, 1>; + kittens::st_descriptor, 1> a_desc_real(a.real); + kittens::st_descriptor, 1> a_desc_imag(a.imag); + kittens::st_descriptor, 1> b_desc_real(b.real); + kittens::st_descriptor, 1> b_desc_imag(b.imag); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d.real, + a_desc_real.chunk_descriptor(0), + b_desc_real.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.real, + a_desc_real.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st<-1>( // INVERT THE SIGN OF THE IMAGINARY PART + d.real, + a_desc_imag.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(0), + b_desc_imag.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st( + d.imag, + a_desc_imag.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_AtB(D &d, + const A &a, + const B &b) { + mma_AtB(d, a, b); +} + +// [(shared, shared) -> register] edition +/** + * @brief Perform matrix multiply using warp group matrix multiply-accumulate (WGMMA) primitives, with A and B transposed. + * + * This function computes an outer product of a shared tile `a` with a shared tile `b` and writes the result into a register tile `d`. + * + * @tparam D The destination register tile type. + * @tparam A The source shared tile type. + * @tparam B The source shared tile type. + * @tparam accumulate Whether to accumulate the result into `d` or overwrite `d`. + */ +template +__device__ static inline void mma_AtBt(D &d, + const A &a, + const B &b) { + // Checks + KITTENS_CHECK_WARPGROUP + constexpr int M = A::width; + static_assert(M == 4); + static_assert(D::height == 1); // output register is correctly sized + constexpr int N = B::height; + constexpr int K = A::height; + static_assert(B::width == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
+ + // Usings + using T_AB = A::T; + using T_D = D::T; + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + static_assert(!std::is_same_v && !std::is_same_v, "Currently unsupported type"); + #endif + using base = kittens::detail::wgmma::base*N, 1, 0>; + kittens::st_descriptor, 1> a_desc_real(a.real); + kittens::st_descriptor, 1> a_desc_imag(a.imag); + kittens::st_descriptor, 0> b_desc_real(b.real); + kittens::st_descriptor, 0> b_desc_imag(b.imag); + + if constexpr (fence) { mma_fence(d); } + + // Do it + base::st_st( + d.real, + a_desc_real.chunk_descriptor(0), + b_desc_real.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.real, + a_desc_real.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st<-1>( // INVERT THE SIGN OF THE IMAGINARY PART + d.real, + a_desc_imag.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(0), + b_desc_imag.chunk_descriptor(0), + accumulate + ); + #pragma unroll + for(int k = 1; k < K; k++) { + base::st_st( + d.imag, + a_desc_real.chunk_descriptor(k), + b_desc_imag.chunk_descriptor(k), + 1 + ); + } + #pragma unroll + for(int k = 0; k < K; k++) { + base::st_st( + d.imag, + a_desc_imag.chunk_descriptor(k), + b_desc_real.chunk_descriptor(k), + 1 + ); + } + mma_commit_group(); // commit the group of these WGMMA calls. +} +template +__device__ static inline void mm_AtBt(D &d, + const A &a, + const B &b) { + mma_AtBt(d, a, b); +} + +// Some extra wrappers for prettiness + +template +__device__ static inline void mma(D &d, + const A &a, + const B &b) { + if constexpr(trans_A == transpose::T) { + if constexpr(trans_B == transpose::T) { + mma_AtBt(d, a, b); + } else { + mma_AtB(d, a, b); + } + } else { + if constexpr(trans_B == transpose::T) { + mma_ABt(d, a, b); + } else { + mma_AB(d, a, b); + } + } +} +template +__device__ static inline void mm(D &d, + const A &a, + const B &b) { + if constexpr(trans_A == transpose::T) { + if constexpr(trans_B == transpose::T) { + mm_AtBt(d, a, b); + } else { + mm_AtB(d, a, b); + } + } else { + if constexpr(trans_B == transpose::T) { + mm_ABt(d, a, b); + } else { + mm_AB(d, a, b); + } + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/register.cuh b/extra/thunder/cuda/include/ops/group/register/register.cuh new file mode 100644 index 0000000000..f87cfe017a --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/register.cuh @@ -0,0 +1,7 @@ +/** + * @file + * @brief An aggregate header for warp operations on data stored in registers. + */ + +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/tile/complex/complex_conversions.cuh b/extra/thunder/cuda/include/ops/group/register/tile/complex/complex_conversions.cuh new file mode 100644 index 0000000000..6430a0a381 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/tile/complex/complex_conversions.cuh @@ -0,0 +1,98 @@ +/** + * @file + * @brief Conversions between data layouts and types for complex register tiles. + */ + +/* ---------- LAYOUT SWAPS ---------- */ + +/** + * @brief Swaps the layout of a complex register tile. 
+ * + * This function swaps the layout of a complex register tile by + * swapping the real and imaginary component tiles' layouts + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the register tile. + * @tparam _width The width of the register tile. + * @tparam layout The current layout of the register tile. + * @param dst[out] Reference to the destination register tile where the result will be stored. + * @param src[in] Reference to the source register tile to be swapped. + */ +template +__device__ static inline void swap_layout(crt::type> &dst, const crt &src) { + swap_layout(dst.real, src.real); + swap_layout(dst.real, src.real); +} +/** + * @brief Swaps the layout of a complex register tile in place. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the register tile. + * @tparam _width The width of the register tile. + * @tparam layout The current layout of the register tile. + * @param tile[in,out] Reference to the register tile to be swapped in place. + * @return A reference to the swapped register tile. + */ +template +__device__ static inline crt::type>& swap_layout_inplace(crt &tile) { + tile.real = swap_layout_inplace(tile.real); + tile.imag = swap_layout_inplace(tile.imag); + return tile; +} + +/* ---------- TRANSPOSE ---------- */ + +/** + * @brief Transposes a complex register tile. + * + * This function is marked "sep", which means that the registers underlying dst MUST be separate + * from the registers underlying src. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the src register tile, and the width of the dst tile. + * @tparam _width The width of the src register tile, and the height of the dst tile. + * @tparam layout The layout of the register tile. + * @param dst[out] Reference to the register tile in which to store the transposed src. + * @param src[in] Reference to the register tile to be transposed. + */ +template +__device__ static inline void transpose_sep(crt &dst, const crt &src) { + transpose_sep(dst.real, src.real); + transpose_sep(dst.imag, src.imag); +} +/** + * @brief Transposes a square complex register tile in-place. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height (in units of 16) of the src register tile, and the width of the dst tile. (Must be the same as _width.) + * @tparam _width The width (in units of 16) of the src register tile, and the height of the dst tile. (Must be the same as _height.) + * @tparam layout The current layout of the register tile. + * @param src[in] Reference to the register tile to be transposed. + * @return A reference to the transposed register tile. + */ +template +__device__ static inline crt& transpose_inplace(crt &tile) { + tile.real = transpose_inplace(tile.real); + tile.imag = transpose_inplace(tile.imag); + + return tile; +} + +/* ---------- TYPE SWAPS ---------- */ + +/** + * @brief Copies a complex register tile, converting the underlying type if necessary. + * + * @tparam T2 The data type of the destination register elements. + * @tparam U2 The data type of the source register elements. + * @tparam _height The height (in units of 16) of the register tiles. + * @tparam _width The width (in units of 16) of the register tiles. + * @tparam layout The current layout of the register tile. + * @param[out] dst A reference to the destination register tile. + * @param[in] src A reference to the source register tile. 
+ */ +template +__device__ static inline void copy(crt &dst, const crt &src) { + copy(dst.real, src.real); + copy(dst.imag, src.imag); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/tile/complex/complex_maps.cuh b/extra/thunder/cuda/include/ops/group/register/tile/complex/complex_maps.cuh new file mode 100644 index 0000000000..46fce709f0 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/tile/complex/complex_maps.cuh @@ -0,0 +1,137 @@ +/** + * @file + * @brief Map operations between complex tiles. + */ + +/** + * @brief Sets all elements of a complex tile to zero. + * + * @tparam T Complex tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void zero(T &dst) { + zero(dst.real); + zero(dst.imag); +} +/** + * @brief Applies the exponential function to each element of a complex tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the exponential function on. + */ +template +__device__ static inline void exp(T &dst, const T &src) { + using dtype = T::dtype; + dtype tmp; + // out of place storage + dtype rdst; + dtype idst; + + // exp(a) + exp(rdst, src.real); + copy(idst, rdst); + // exp(a)cos(b) + exp(a)sin(b)i + cos(tmp, src.imag); + mul(rdst, rdst, tmp); + sin(tmp, src.imag); + mul(idst, idst, tmp); + + copy(dst.real, rdst); + copy(dst.imag, idst); +} +/** + * @brief Adds two complex tiles element-wise. + * + * @tparam T Complex Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the addition. + * @param rhs[in] Right-hand side source tile for the addition. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const T &rhs) { + add(dst.real, lhs.real, rhs.real); + add(dst.imag, lhs.imag, rhs.imag); +} +/** + * @brief Subtracts two tiles element-wise. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the subtraction. + * @param rhs[in] Right-hand side source tile for the subtraction. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const T &rhs) { + sub(dst.real, lhs.real, rhs.real); + sub(dst.imag, lhs.imag, rhs.imag); +} +/** + * @brief Multiplies two tiles element-wise. + * + * @tparam T Complex tile type. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the multiplication. + * @param rhs[in] Right-hand side source tile for the multiplication. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const T &rhs) { + using dtype = T::component; + dtype tmp; + // out of place storage regs + dtype rdst; + dtype idst; + + // (a + bi) * (c + di) --> (ac - bd) + (ad + bc)i + // Real component + mul(rdst, lhs.real, rhs.real); + mul(tmp, lhs.imag, rhs.imag); + sub(rdst, rdst, tmp); + + // Imag component + mul(idst, lhs.imag, rhs.real); + mul(tmp, lhs.real, rhs.imag); + add(idst, idst, tmp); + + copy(dst.real, rdst); + copy(dst.imag, idst); +} +/** + * @brief Divides two tiles element-wise. + * + * @tparam T Complex tile type. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the division. + * @param rhs[in] Right-hand side source tile or scalar for the division. 
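 *
 * For lhs = a + bi and rhs = c + di the quotient expands, after multiplying by
 * the conjugate of rhs, to
 *   lhs / rhs = ((a*c + b*d) + (b*c - a*d)*i) / (c*c + d*d),
 * which is the expansion the component-wise mul/add/sub/div calls below follow.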
+ */ +template +__device__ static inline void div(T &dst, const T &lhs, const T &rhs) { + using dtype = T::dtype; + dtype tmp; + dtype denom; + // out of place storage regs + dtype rdst; + dtype idst; + + // Calculate denom - square of b terms + mul(tmp, rhs.real, rhs.real); + mul(denom, rhs.imag, rhs.imag); + add(denom, tmp, denom); + // Real component + mul(rdst, lhs.real, rhs.real); + mul(tmp, lhs.imag, rhs.imag); + add(rdst, rdst, tmp); + // Imag component + mul(dst.imag, lhs.imag, rhs.real); + mul(tmp, lhs.real, rhs.imag); + sub(idst, idst, tmp); + // Divide components by denom + div(rdst, rdst, denom); + div(idst, idst, denom); + copy(dst.real, rdst); + copy(dst.imag, idst); +} + + diff --git a/extra/thunder/cuda/include/ops/group/register/tile/conversions.cuh b/extra/thunder/cuda/include/ops/group/register/tile/conversions.cuh new file mode 100644 index 0000000000..f3a7d74345 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/tile/conversions.cuh @@ -0,0 +1,415 @@ +/** + * @file + * @brief Conversions between data layouts and types for register tiles. + */ + +/* ---------- LAYOUT SWAPS ---------- */ + +/** + * @brief Perform a matrix transpose on a block of 8 bf16_2 elements using inline assembly. + * + * This low-level operation is utilized by higher-level layout swap functions to transpose + * the layout of bf16_2 elements within a register tile. The function leverages inline PTX + * assembly to efficiently swap the layout of the given block. + * + * @param[out] dst A reference to the destination bf16_2 element where the transposed result is stored. + * @param[in] src A reference to the source bf16_2 element to be transposed. + */ +__device__ static inline void swap_layout_8(bf16_2 &dst, const bf16_2 &src) { + KITTENS_CHECK_WARP + asm volatile ( + "movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n" + : "+r"(*(uint32_t*)(&dst)) + : "r"(*(uint32_t*)(&src)) + ); +} +/** + * @brief Swaps the layout of a register base tile. + * + * This function swaps the layout of a register base tile by performing a series of layout swaps + * on its constituent bf16_2 elements. It is used to change the data layout within a register tile. + * + * @tparam T2 The data type of the register tile elements. + * @tparam layout The current layout of the register tile. + * @param dst[out] Reference to the destination register base tile where the result will be stored. + * @param src[in] Reference to the source register base tile to be swapped. + */ +template +__device__ static inline void swap_layout(rt_base::type> &dst, const rt_base &src) { + swap_layout_8(dst.data[0], src.data[0]); + // technically this swap can be eliminated if we simply reinterpret the layout of the registers + // everywhere else in the code, but that feels... very likely to cause bugs and not worth it. + typename rt_base::T2 data1_cache = src.data[1]; // important for swap! + swap_layout_8(dst.data[1], src.data[2]); + swap_layout_8(dst.data[2], data1_cache); + swap_layout_8(dst.data[3], src.data[3]); +} +/** + * @brief Swaps the layout of a register tile. + * + * This function swaps the layout of a register tile by iterating over its height and width + * and performing layout swaps on each of its base elements. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the register tile. + * @tparam _width The width of the register tile. + * @tparam layout The current layout of the register tile. 
+ * @param dst[out] Reference to the destination register tile where the result will be stored. + * @param src[in] Reference to the source register tile to be swapped. + */ +template +__device__ static inline void swap_layout(rt::type> &dst, const rt &src) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + swap_layout(dst.tiles[i][j], src.tiles[i][j]); + } + } +} + +/** + * @brief Swaps the layout of a register base tile in place. + * + * This function swaps the layout of a register base tile in place by casting it to the + * transposed layout type and then performing the layout swap. + * + * @tparam T2 The data type of the register tile elements. + * @tparam layout The current layout of the register tile. + * @param src[in] Reference to the register base tile to be swapped in place. + * @return A reference to the swapped register base tile. + */ +template +__device__ static inline rt_base::type>& swap_layout_inplace(const rt_base &src) { + rt_base::type> &dst = *(rt_base::type>*)(&src); + swap_layout(dst, src); + return dst; +} +/** + * @brief Swaps the layout of a register tile in place. + * + * This function swaps the layout of a register tile in place by iterating over its height and width + * and performing in-place layout swaps on each of its base elements. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the register tile. + * @tparam _width The width of the register tile. + * @tparam layout The current layout of the register tile. + * @param tile[in,out] Reference to the register tile to be swapped in place. + * @return A reference to the swapped register tile. + */ +template +__device__ static inline rt::type>& swap_layout_inplace(rt &tile) { + #pragma unroll + for(int i = 0; i < tile.height; i++) { + #pragma unroll + for(int j = 0; j < tile.width; j++) { + swap_layout_inplace(tile.tiles[i][j]); + } + } + return *(rt::type>*)(&tile); +} + +/* ---------- TRANSPOSE ---------- */ + +/** + * @brief Transposes a register base tile. + * + * @tparam T2 The data type of the register tile elements. + * @tparam layout The current layout of the register tile. + * @param dst[out] Reference to the register tile in which to store the transposed src. + * @param src[in] Reference to the register base tile to be transposed. + */ +template +__device__ static inline void transpose(rt_base &dst, const rt_base &src) { + swap_layout_8(dst.data[0], src.data[0]); + // technically this swap can be eliminated if we simply reinterpret the layout of the registers + // everywhere else in the code, but that feels... very likely to cause bugs and not worth it. + typename rt_base::T2 data1_cache = src.data[1]; // important for swap! + swap_layout_8(dst.data[1], src.data[2]); + swap_layout_8(dst.data[2], data1_cache); + swap_layout_8(dst.data[3], src.data[3]); +} +/** + * @brief Transposes a register tile. + * + * This function is marked "sep", which means that the registers underlying dst MUST be separate + * from the registers underlying src. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height of the src register tile, and the width of the dst tile. + * @tparam _width The width of the src register tile, and the height of the dst tile. + * @tparam layout The layout of the register tile. + * @param dst[out] Reference to the register tile in which to store the transposed src. + * @param src[in] Reference to the register tile to be transposed. 
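transpose_sep maps source block (j, i) onto destination block (i, j) while transposing each block, which is why dst and src must not alias: block (i, j) of dst is written before block (j, i) of src has necessarily been read. A host-side sketch of the same blocked-transpose indexing, using plain arrays and an assumed 16x16 block edge rather than register tiles:

// out is (h_blocks*B) x (w_blocks*B); in is its transpose. out must be a separate buffer.
void blocked_transpose(float *out, const float *in, int h_blocks, int w_blocks) {
    const int B = 16;  // assumed block edge, standing in for TILE_ROW_DIM/TILE_COL_DIM
    for (int i = 0; i < h_blocks; i++)
        for (int j = 0; j < w_blocks; j++)
            for (int r = 0; r < B; r++)
                for (int c = 0; c < B; c++)
                    // dst block (i,j), element (r,c)  <-  src block (j,i), element (c,r)
                    out[(i*B + r) * (w_blocks*B) + j*B + c] =
                        in[(j*B + c) * (h_blocks*B) + i*B + r];
}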
+ */ +template +__device__ static inline void transpose_sep(RT &dst, const rt &src) { + #pragma unroll + for(int i = 0; i < RT::height; i++) { + #pragma unroll + for(int j = 0; j < RT::width; j++) { + transpose(dst.tiles[i][j], src.tiles[j][i]); + } + } +} + +/** + * @brief Transposes a register base tile in-place. + * + * @tparam T2 The data type of the register base tile elements. + * @tparam layout The current layout of the register base tile. + * @param src[in] Reference to the register tile to be transposed. + * @return A reference to the transposed register base tile. + */ +template +__device__ static inline rt_base& transpose_inplace(rt_base &src) { + transpose(src, src); + return src; +} +/** + * @brief Transposes a square register tile in-place. + * + * @tparam T2 The data type of the register tile elements. + * @tparam _height The height (in units of 16) of the src register tile, and the width of the dst tile. (Must be the same as _width.) + * @tparam _width The width (in units of 16) of the src register tile, and the height of the dst tile. (Must be the same as _height.) + * @tparam layout The current layout of the register tile. + * @param src[in] Reference to the register tile to be transposed. + * @return A reference to the transposed register tile. + */ +template +__device__ static inline rt& transpose_inplace(rt &tile) { + static_assert(_cols == _rows, "in-place register tile transpose is only allowed for square tiles."); + #pragma unroll + for(int i = 0; i < tile.height; i++) { + #pragma unroll + for(int j = 0; j < i; j++) { + rt_base tmp; + copy(tmp, tile.tiles[i][j]); + transpose(tile.tiles[i][j], tile.tiles[j][i]); + transpose(tile.tiles[j][i], tmp); + } + transpose_inplace(tile.tiles[i][i]); + } + return tile; +} + +/* ---------- TYPE SWAPS ---------- */ + +/** + * @brief Copies a register base tile, converting the underlying type if necessary. + * + * @tparam T2 The data type of the destination register elements. + * @tparam U2 The data type of the source register elements. + * @tparam layout The current layout of the register base tile. + * @param[out] dst A reference to the destination register base tile. + * @param[in] src A reference to the source register base tile. + */ +template +__device__ static inline void copy(rt_base &dst, const rt_base &src) { + using T2 = typename base_types::packing::packed_type; + using U2 = typename base_types::packing::packed_type; + #pragma unroll + for(int k = 0; k < dst.packed_per_thread; k++) { + dst.data[k] = base_types::convertor::convert(src.data[k]); + } +} +#ifdef KITTENS_HOPPER +/** + * @brief Copies a register tile, converting the underlying type if necessary. + * + * @tparam T2 The data type of the destination register elements. + * @tparam U2 The data type of the source register elements. + * @tparam _height The height (in units of 16) of the register tiles. + * @tparam _width The width (in units of 16) of the register tiles. + * @tparam layout The current layout of the register tile. + * @param[out] dst A reference to the destination register tile. + * @param[in] src A reference to the source register tile. 
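The base-tile copy above converts one packed element at a time through base_types::convertor. For the common float2 to half2 case, that per-element conversion boils down to a standard CUDA intrinsic; a stripped-down sketch of the same loop shape (illustrative, not the convertor implementation):

#include <cuda_fp16.h>

// Convert an array of packed float2 values to half2, mirroring the
// per-packed-element loop used by the tile copy above.
__device__ void convert_packed(half2 *dst, const float2 *src, int packed_per_thread) {
    for (int k = 0; k < packed_per_thread; k++)
        dst[k] = __float22half2_rn(src[k]);   // rounds both lanes to nearest-even
}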
+ */ +template +__device__ static inline void copy(rt &dst, const rt &src) { + + if constexpr ( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) + ) { + // FLOAT (SRC -- 1H x 2W) to FP8 (DST -- 1H x 1W) + int laneid = threadIdx.x % 32; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.tiles[0][0].packed_per_thread; k++) { + + // check for half, float, bf16 + using src_t = std::conditional_t, float2, std::conditional_t, bf16_2, half2>>; + src_t val1, val2; + + // Put something up for adoption + if (laneid % 2 == 0) { + // put up src left core matrix first as 0, 2 + val1 = src.tiles[i][2*j + k/2].data[(k%2)+0]; + val2 = src.tiles[i][2*j + k/2].data[(k%2)+2]; + } else { + // put up src right core matrix first as 1, 3 + val1 = src.tiles[i][2*j + k/2].data[(k%2)+2]; + val2 = src.tiles[i][2*j + k/2].data[(k%2)+0]; + } + + // Shuffle first 4 floats + int row_mask = 4 * ( laneid / 4 ); + int row_offset = row_mask + ( (laneid-row_mask) / 2 ) + ( laneid % 2 ); + int src_offset = (laneid % 2 == 0 ) ? row_offset + 0 : ( row_offset + 1 ); + src_t val01 = packed_shfl_sync(MASK_ALL, val1, src_offset); // Get from even thread + + int src_offset2 = (laneid % 4 < 2 ) ? src_offset + 1 : (src_offset - 1); + src_t val23 = packed_shfl_sync(MASK_ALL, val2, src_offset2); // Get from odd thread + + // Convert to fp8e4m3_4 + float4 f4; + using fp8_4_t = std::conditional_t, fp8e4m3_4, fp8e5m2_4>; + fp8_4_t f4_fp8; + if ( laneid % 4 < 2 ) { + f4.x = val01.x; // Thread 2N's first value + f4.y = val01.y; // Thread 2N's second value + f4.z = val23.x; // Thread 2N+1's first value + f4.w = val23.y; // Thread 2N+1's second value + f4_fp8 = base_types::convertor::convert(f4); + dst.tiles[i][j].data[k] = f4_fp8; + } else { + f4.x = val23.x; // Thread 2N+1's first value + f4.y = val23.y; // Thread 2N+1's second value + f4.z = val01.x; // Thread 2N's first value + f4.w = val01.y; // Thread 2N's second value + f4_fp8 = base_types::convertor::convert(f4); + dst.tiles[i][j].data[k] = f4_fp8; + } + } + } + } + } + else if constexpr ( + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) || + (std::is_same_v && std::is_same_v) + ) { + // FP8 (SRC -- 1H x 1W) to FLOAT (DST -- 1H x 2W) + int laneid = threadIdx.x % 32; + + #pragma unroll + for(int i = 0; i < src.height; i++) { + #pragma unroll + for(int j = 0; j < src.width; j++) { + #pragma unroll + for(int k = 0; k < src.tiles[0][0].packed_per_thread; k++) { + int dst_j = 2*j + k/2; + + // Put something up for adoption + using fp8_4_t = std::conditional_t, fp8e4m3_4, fp8e5m2_4>; + fp8_4_t val = src.tiles[i][j].data[k]; + float4 f4 = base_types::convertor::convert(val); + float2 f2_0, f2_1; + if ( laneid % 4 < 2 ) { // src 0 and 1 should put up .x and .y first + f2_0 = make_float2(f4.x, f4.y); + f2_1 = make_float2(f4.z, f4.w); + } + else { // src 2 and 3 should put up .z and .w first + f2_0 = make_float2(f4.z, f4.w); + f2_1 = make_float2(f4.x, f4.y); + } + + int row_offset = 4 * (laneid/4) + (laneid%2) * 2 + (laneid%4) / 2; + float2 f2_0_shfl = packed_shfl_sync(MASK_ALL, f2_0, row_offset); + float2 f2_1_shfl = packed_shfl_sync(MASK_ALL, f2_1, 
row_offset^2); + + // convert to dst type if needed + using dst_t = std::conditional_t, float2, std::conditional_t, bf16_2, half2>>; + if constexpr (!(std::is_same_v)) { + dst_t f2_0_shfl_t = base_types::convertor::convert(f2_0_shfl); + dst_t f2_1_shfl_t = base_types::convertor::convert(f2_1_shfl); + if (laneid % 2 == 0) { + dst.tiles[i][dst_j].data[(k%2)+0] = f2_0_shfl_t; + dst.tiles[i][dst_j].data[(k%2)+2] = f2_1_shfl_t; + } else { + dst.tiles[i][dst_j].data[(k%2)+0] = f2_1_shfl_t; + dst.tiles[i][dst_j].data[(k%2)+2] = f2_0_shfl_t; + } + } else { + if (laneid % 2 == 0) { + dst.tiles[i][dst_j].data[(k%2)+0] = f2_0_shfl; + dst.tiles[i][dst_j].data[(k%2)+2] = f2_1_shfl; + } else { + dst.tiles[i][dst_j].data[(k%2)+0] = f2_1_shfl; + dst.tiles[i][dst_j].data[(k%2)+2] = f2_0_shfl; + } + } + } + } + } + } + // default case where the layouts map 1:1 in thread ownership logic + else { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + copy(dst.tiles[i][j], src.tiles[i][j]); + } + } + } +} +#else +/** + * @brief Copies a register tile, converting the underlying type if necessary. + * + * @tparam T2 The data type of the destination register elements. + * @tparam U2 The data type of the source register elements. + * @tparam _height The height (in units of 16) of the register tiles. + * @tparam _width The width (in units of 16) of the register tiles. + * @tparam layout The current layout of the register tile. + * @param[out] dst A reference to the destination register tile. + * @param[in] src A reference to the source register tile. + */ +template +__device__ static inline void copy(rt &dst, const rt &src) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + copy(dst.tiles[i][j], src.tiles[i][j]); + } + } +} +#endif + +/* ---------- SUBTILE ---------- */ + +/** +* @brief Returns a reference to a subtile of the given tile. +* +* @tparam subtile_height The height of the subtile. +* @tparam RT The type of the input tile, which must satisfy the ducks::rt::all concept. +* @param src The input tile. +* @param idx The coord of the subtile. +* @return A reference to the subtile. +* +* @note The subtile height must evenly divide the tile height. +*/ +template +__device__ static inline rt &subtile_inplace(RT & src, int idx) { + KITTENS_CHECK_WARP + using T = typename RT::T; + static_assert(RT::height % (subtile_rows / TILE_ROW_DIM) == 0, "subtile height should evenly divide tile height."); + return reinterpret_cast&>( + src.tiles[idx*(subtile_rows / TILE_ROW_DIM)] + ); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/tile/maps.cuh b/extra/thunder/cuda/include/ops/group/register/tile/maps.cuh new file mode 100644 index 0000000000..c623aa6b67 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/tile/maps.cuh @@ -0,0 +1,836 @@ +/** + * @file + * @brief Map operations: between tiles, and those which apply vectors to tiles. + */ + +/* ---------- Uniform tile maps (independent of layout) ---------- */ + +/** + * @brief Applies a unary operation to each element of a tile. + * + * @tparam op Unary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. 
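The fp8 path above pairs lanes up and then fetches a second set of values from a partner lane whose shuffle index is obtained by an XOR (row_offset^2). The analogous plain-CUDA idiom for that kind of pairwise exchange is a butterfly with __shfl_xor_sync; a minimal sketch, independent of the library's packed_shfl helpers:

// Each lane exchanges a float2 with the lane whose index differs in bit 1 (XOR 2).
__device__ float2 exchange_with_partner(float2 v) {
    float2 out;
    out.x = __shfl_xor_sync(0xFFFFFFFFu, v.x, 2);   // partner's first value
    out.y = __shfl_xor_sync(0xFFFFFFFFu, v.y, 2);   // partner's second value
    return out;
}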
+ */ +template +__device__ static inline void unary_map(T &dst, const T &src) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = op::template op(src.tiles[i][j].data[k]); + } + } + } +} + +/** + * @brief Applies a binary operation to each element of a tile with a scalar parameter. + * + * @tparam op Binary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param param[in] Scalar parameter for the binary operation. + */ +template +__device__ static inline void bin_map(T &dst, const T &src, const typename T::dtype ¶m) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = op::template op(src.tiles[i][j].data[k], param); + } + } + } +} +/** + * @brief Applies a binary operation to each element of a tile with an unpacked scalar parameter. + * + * @tparam op Binary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param param[in] Unpacked scalar parameter for the binary operation. + */ +template +__device__ static inline void bin_map(T &dst, const T &src, const typename base_types::packing::unpacked_type ¶m) { + // The optimizing compiler should eliminate this pack in the 32-bit case but not in the 16-bit case + bin_map(dst, src, base_types::packing::pack(param)); +} +/** + * @brief Applies a binary operation element-wise between two tiles. + * + * @tparam op Binary operation to apply. + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the operation. + * @param rhs[in] Right-hand side source tile for the operation. 
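unary_map and bin_map defer the per-element arithmetic to a static op functor, so one unrolled loop nest serves every operation. A stripped-down sketch of that pattern over a flat packed array, with a hypothetical functor (not one of the kittens base_ops types):

// Hypothetical functor in the style used above: a static, templated op().
struct scale_by_two {
    template<typename T> __device__ static inline T op(const T &x) { return x + x; }
};

// Apply an op functor to every packed element, as unary_map does per tile.
template<typename op_t, typename T, int N>
__device__ inline void map_packed(T (&dst)[N], const T (&src)[N]) {
    #pragma unroll
    for (int k = 0; k < N; k++) dst[k] = op_t::template op<T>(src[k]);
}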
+ */ +template +__device__ static inline void bin_map(T &dst, const T &lhs, const T &rhs) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + dst.tiles[i][j].data[k] = op::template op(lhs.tiles[i][j].data[k], rhs.tiles[i][j].data[k]); + } + } + } +} + +template +__device__ static inline void apply(RT &dst, const RT &src, Lambda &&lambda) { + int row_offset = 0; + if constexpr(GROUP_WARPS > 1) { + row_offset = warpid()*RT::height; + } + static_assert(sizeof(RT::T) != 1, "Cannot apply lambda to 8-bit types"); + if constexpr (ducks::rt::row_layout) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + int row = row_offset + i*TILE_ROW_DIM + (k%2) * (TILE_ROW_DIM/2) + ::kittens::laneid()/4; + int col = j*TILE_COL_DIM + (k/2) * (TILE_COL_DIM/2) + (::kittens::laneid()%4)*2; + dst.tiles[i][j].data[k].x = lambda(row, col+0, src.tiles[i][j].data[k].x); + dst.tiles[i][j].data[k].y = lambda(row, col+1, src.tiles[i][j].data[k].y); + } + } + } + } + else { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k++) { + int row = row_offset + i*TILE_ROW_DIM + (k/2) * (TILE_ROW_DIM/2) + (::kittens::laneid()%4)*2; + int col = j*TILE_COL_DIM + (k%2) * (TILE_COL_DIM/2) + ::kittens::laneid()/4; + dst.tiles[i][j].data[k].x = lambda(row+0, col, src.tiles[i][j].data[k].x); + dst.tiles[i][j].data[k].y = lambda(row+1, col, src.tiles[i][j].data[k].y); + } + } + } + } +} +template +__device__ static inline RT apply(const RT &src, Lambda &&lambda) { + RT dst; + apply(dst, src, std::forward(lambda)); + return dst; +} + +/* ---------- Row tile maps ----------*/ + +/** + * @brief Applies an operation across the rows of a tile in a row-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. + */ +template +__device__ static inline void row_map(T &dst, const T &src, const V &row_values) { + + static_assert(std::is_same_v::col_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + dtype packed_top_row = base_types::packing::pack(row_values[i][0].x); // first value in eager mode + dtype packed_bottom_row = base_types::packing::pack(row_values[i][0].y); // second value in eager mode + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], packed_top_row); + dst.tiles[i][j].data[k+1] = op::template op(src.tiles[i][j].data[k+1], packed_bottom_row); + } + } + } +} +/** + * @brief Applies an operation across the rows of a tile in a column-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. 
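The lambda-based apply above reconstructs each value's logical (row, col) from the warp lane and the packed index k, following the standard m16n16 MMA fragment ownership pattern. Assuming TILE_ROW_DIM == TILE_COL_DIM == 16 and ignoring the multi-warp group row offset, the row-major index math reduces to this sketch:

// Logical coordinates of the two values held in packed register k of base tile (i, j),
// for a row-major register layout. laneid is the thread's lane within its warp (0..31).
__device__ inline void rowmajor_coords(int i, int j, int k, int laneid,
                                       int &row, int &col0, int &col1) {
    row  = i*16 + (k % 2) * 8 + laneid / 4;        // which of the tile's 16 rows
    col0 = j*16 + (k / 2) * 8 + (laneid % 4) * 2;  // column of the packed .x element
    col1 = col0 + 1;                               // column of the packed .y element
}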
+ * @param src[in] Source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. + */ +template +__device__ static inline void row_map(T &dst, const T &src, const V &row_values) { + + static_assert(std::is_same_v); // compatible type + static_assert(std::is_same_v::col_vec_layout>); // compatible layout + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], row_values[i][0]); + dst.tiles[i][j].data[k+2] = op::template op(src.tiles[i][j].data[k+2], row_values[i][1]); + } + } + } +} + + +// Three-operand row map. Mostly useful for FMA instructions. + +/** + * @brief Applies an operation across the rows of two tiles in a row-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. + */ +template +__device__ static inline void row_map(T &dst, const T &a, const T &b, const V &row_values) { + + static_assert(std::is_same_v::col_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + dtype packed_top_row = base_types::packing::pack(row_values[i][0].x); // first value in eager mode + dtype packed_bottom_row = base_types::packing::pack(row_values[i][0].y); // second value in eager mode + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], packed_top_row); + dst.tiles[i][j].data[k+1] = op::template op(a.tiles[i][j].data[k+1], b.tiles[i][j].data[k+1], packed_bottom_row); + } + } + } +} +/** + * @brief Applies an operation across the rows of two tiles in a column-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param row_values[in] Column vector containing values to apply across each row. 
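In index terms, the three-operand row maps compute dst[r][c] = op(a[r][c], b[r][c], rowv[r]); with op = fma (a*b + c) that is a per-row scaled accumulate. A plain-loop reference of the same semantics over ordinary row-major arrays:

// dst = a * b + rowv broadcast down each row.
void fma_row_map_ref(float *dst, const float *a, const float *b,
                     const float *rowv, int rows, int cols) {
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < cols; c++)
            dst[r*cols + c] = a[r*cols + c] * b[r*cols + c] + rowv[r];
}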
+ */ +template +__device__ static inline void row_map(T &dst, const T &a, const T &b, const V &row_values) { + + static_assert(std::is_same_v::col_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], row_values[i][0]); + dst.tiles[i][j].data[k+2] = op::template op(a.tiles[i][j].data[k+2], b.tiles[i][j].data[k+2], row_values[i][1]); + } + } + } +} + +/* ---------- Col major tile maps ----------*/ + +/** + * @brief Applies an operation across the columns of a tile in a row-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &src, const V &col_values) { + KITTENS_CHECK_WARP + + static_assert(std::is_same_v::row_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], col_values[j][0]); + dst.tiles[i][j].data[k+2] = op::template op(src.tiles[i][j].data[k+2], col_values[j][1]); + } + } + } +} +/** + * @brief Applies an operation across the columns of a tile in a column-major layout. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &src, const V &col_values) { + KITTENS_CHECK_WARP + + static_assert(std::is_same_v::row_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int j = 0; j < dst.width; j++) { + dtype packed_left_col = base_types::packing::pack(col_values[j][0].x); // first value in eager mode + dtype packed_right_col = base_types::packing::pack(col_values[j][0].y); // second value in eager mode + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(src.tiles[i][j].data[k+0], packed_left_col); + dst.tiles[i][j].data[k+1] = op::template op(src.tiles[i][j].data[k+1], packed_right_col); + } + } + } +} + +// Three-operand col map +/** + * @brief Applies an operation across the columns of two tiles in a row-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with row-major layout. + * @tparam V Row vector type. 
+ * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &a, const T &b, const V &col_values) { + KITTENS_CHECK_WARP + + static_assert(std::is_same_v::row_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + + #pragma unroll + for(int j = 0; j < dst.width; j++) { + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile/2; k++) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], col_values[j][0]); + dst.tiles[i][j].data[k+2] = op::template op(a.tiles[i][j].data[k+2], b.tiles[i][j].data[k+2], col_values[j][1]); + } + } + } +} +/** + * @brief Applies an operation across the columns of two tiles in a column-major layout, using a third operand. + * + * @tparam op Operation to apply. + * @tparam T Tile type with column-major layout. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param a[in] First source tile to apply the operation on. + * @param b[in] Second source tile to apply the operation on. + * @param col_values[in] Row vector containing values to apply across each column. + */ +template +__device__ static inline void col_map(T &dst, const T &a, const T &b, const V &col_values) { + KITTENS_CHECK_WARP + + static_assert(std::is_same_v); // compatible type + static_assert(std::is_same_v::row_vec_layout>); // compatible layout + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = T::dtype; + #pragma unroll + for(int j = 0; j < dst.width; j++) { + dtype packed_left_col = base_types::packing::pack(col_values[j][0].x); // first value in eager mode + dtype packed_right_col = base_types::packing::pack(col_values[j][0].y); // second value in eager mode + #pragma unroll + for(int i = 0; i < dst.height; i++) { + #pragma unroll + for(int k = 0; k < dst.packed_per_tile; k+=2) { + dst.tiles[i][j].data[k+0] = op::template op(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], packed_left_col); + dst.tiles[i][j].data[k+1] = op::template op(a.tiles[i][j].data[k+1], b.tiles[i][j].data[k+1], packed_right_col); + } + } + } +} + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// All of the annoying qualifiers *should* be automatically inferred during compile-time. +// So, syntax should just be kittens::add_row(tile, colvec); + +/** + * @brief Sets all elements of a tile to zero. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void zero(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of a tile to one. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void one(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of a tile to positive infinity. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void pos_infty(T &dst) { + unary_map(dst, dst); +} +/** + * @brief Sets all elements of a tile to negative infinity. 
+ * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + */ +template +__device__ static inline void neg_infty(T &dst) { + unary_map(dst, dst); +} + +/** + * @brief Applies the exponential function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the exponential function on. + */ +template +__device__ static inline void exp(T &dst, const T &src) { + unary_map(dst, src); +} +template +__device__ static inline T exp(const T &src) { + T dst; + exp(dst, src); + return dst; +} + +/** + * @brief Applies the exponential function to each element of a tile, in base 2. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the exponential function on. + */ +template +__device__ static inline void exp2(T &dst, const T &src) { + unary_map(dst, src); +} +template +__device__ static inline T exp2(const T &src) { + T dst; + exp2(dst, src); + return dst; +} + +/** + * @brief Applies the natural logarithm function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the natural logarithm function on. + */ +template +__device__ static inline void log(T &dst, const T &src) { + unary_map(dst, src); +} +template +__device__ static inline T log(const T &src) { + T dst; + log(dst, src); + return dst; +} + +/** + * @brief Applies the logarithm base 2 function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the logarithm base 2 function on. + */ +template +__device__ static inline void log2(T &dst, const T &src) { + unary_map(dst, src); +} +template +__device__ static inline T log2(const T &src) { + T dst; + log2(dst, src); + return dst; +} + +/** + * @brief Applies the absolute value function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the absolute value function on. + */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_map(dst, src); +} +template +__device__ static inline T abs(const T &src) { + T dst; + abs(dst, src); + return dst; +} + +/** + * @brief Applies the rectified linear unit (ReLU) function to each element of a tile. + * + * @tparam T Tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the ReLU function on. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_map(dst, src); +} +template +__device__ static inline T relu(const T &src) { + T dst; + relu(dst, src); + return dst; +} + +/** + * @brief Copies the elements from one tile to another. + * + * @tparam T Destination tile type. + * @tparam U Source tile type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to copy from. + */ +template +__device__ static inline void copy(T &dst, const U &src) { + bin_map(dst, src); +} + +/** + * @brief Applies the max operation element-wise between two tiles or a tile and a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. 
+ * @param lhs[in] Left-hand side source tile for the operation. + * @param rhs[in] Right-hand side source tile or scalar for the operation. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +template +__device__ static inline T max(const T &lhs, const U &rhs) { + T dst; + max(dst, lhs, rhs); + return dst; +} + +/** + * @brief Applies the min operation element-wise between two tiles or a tile and a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the operation. + * @param rhs[in] Right-hand side source tile or scalar for the operation. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +template +__device__ static inline T min(const T &lhs, const U &rhs) { + T dst; + min(dst, lhs, rhs); + return dst; +} + +/** + * @brief Adds two tiles element-wise or adds a scalar to each element of a tile. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the addition. + * @param rhs[in] Right-hand side source tile or scalar for the addition. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +/** + * @brief Subtracts two tiles element-wise or subtracts a scalar from each element of a tile. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the subtraction. + * @param rhs[in] Right-hand side source tile or scalar for the subtraction. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} +/** + * @brief Multiplies two tiles element-wise or multiplies each element of a tile by a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the multiplication. + * @param rhs[in] Right-hand side source tile or scalar for the multiplication. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +/** + * @brief Divides two tiles element-wise or divides each element of a tile by a scalar. + * + * @tparam T Tile type. + * @tparam U Second operand type, which can be a tile or a scalar. + * @param dst[out] Destination tile where the result is stored. + * @param lhs[in] Left-hand side source tile for the division. + * @param rhs[in] Right-hand side source tile or scalar for the division. + */ +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +/** + * @brief Adds row values to each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the addition on. + * @param row_values[in] Column vector containing values to add to each row. 
+ */ +template +__device__ static inline void add_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +/** + * @brief Subtracts row values from each row of a tile. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the subtraction on. + * @param row_values[in] Column vector containing values to subtract from each row. + */ +template +__device__ static inline void sub_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +/** + * @brief Multiplies each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the multiplication on. + * @param row_values[in] Column vector containing values to multiply each row by. + */ +template +__device__ static inline void mul_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +/** + * @brief Divides each row of a tile by row values. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the division on. + * @param row_values[in] Column vector containing values to divide each row by. + */ +template +__device__ static inline void div_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +/** + * @brief Broadcast a vector into into a tile's rows. + * + * @tparam T Tile type. + * @tparam V Column vector type. + * @param dst[out] Destination tile where the result is stored. + * @param row_values[in] Column vector containing values to broadcast into rows. + */ +template +__device__ static inline void broadcast_row(T &dst, const V &row_values) { + row_map(dst, dst, row_values); +} +template +__device__ static inline T broadcast_row(const V &row_values) { + T dst; + broadcast_row(dst, row_values); + return dst; +} + + +// col maps +/** + * @brief Adds column values to each column of a tile. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the addition on. + * @param col_values[in] Row vector containing values to add to each column. + */ +template +__device__ static inline void add_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +/** + * @brief Subtracts column values from each column of a tile. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the subtraction on. + * @param col_values[in] Row vector containing values to subtract from each column. + */ +template +__device__ static inline void sub_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +/** + * @brief Multiplies each column of a tile by column values. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the multiplication on. + * @param col_values[in] Row vector containing values to multiply each column by. 
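The _row and _col wrappers differ only in which vector index is broadcast: row variants combine element (r, c) with rowv[r], column variants with colv[c]. A reference sketch of add_row / add_col semantics over plain row-major arrays:

void add_row_ref(float *dst, const float *src, const float *rowv, int rows, int cols) {
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < cols; c++)
            dst[r*cols + c] = src[r*cols + c] + rowv[r];   // one value per row
}
void add_col_ref(float *dst, const float *src, const float *colv, int rows, int cols) {
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < cols; c++)
            dst[r*cols + c] = src[r*cols + c] + colv[c];   // one value per column
}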
+ */ +template +__device__ static inline void mul_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +/** + * @brief Divides each column of a tile by column values. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param src[in] Source tile to apply the division on. + * @param col_values[in] Row vector containing values to divide each column by. + */ +template +__device__ static inline void div_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +/** + * @brief Broadcast a vector into into a tile's columns. + * + * @tparam T Tile type. + * @tparam V Row vector type. + * @param dst[out] Destination tile where the result is stored. + * @param row_values[in] Row vector containing values to broadcast into cols. + */ +template +__device__ static inline void broadcast_col(T &dst, const V &col_values) { + col_map(dst, dst, col_values); +} +template +__device__ static inline T broadcast_col(const V &col_values) { + T dst; + broadcast_col(dst, col_values); + return dst; +} + +// Triangular masks +template +__device__ static inline void tril(RT &dst, const RT &src, int diagonal=0, const typename base_types::packing::unpacked_type &val=0) { + apply(dst, src, [val, diagonal]__device__(int row, int col, auto &src_val) { + return col <= row + diagonal ? src_val : val; + }); +} +template +__device__ static inline void triu(RT &dst, const RT &src, int diagonal=0, const typename base_types::packing::unpacked_type &val=0) { + apply(dst, src, [val, diagonal]__device__(int row, int col, auto &src_val) { + return col >= row + diagonal ? src_val : val; + }); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/tile/reductions.cuh b/extra/thunder/cuda/include/ops/group/register/tile/reductions.cuh new file mode 100644 index 0000000000..49efa39ab5 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/tile/reductions.cuh @@ -0,0 +1,554 @@ +/** + * @file + * @brief Reduction operations mapping tiles to vectors. + */ + +/** + * @brief Perform a row-wise reduction on a matrix in row-major layout. + * + * This function template performs a parallel reduction across the rows of a matrix using a specified operation. + * It leverages warp shuffle functions for efficient intra-warp communication. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type with row layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. 
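tril and triu use the coordinate-aware apply lambda to keep elements on one side of a (possibly shifted) diagonal and overwrite the rest with val, the pattern typically used for causal attention masks. Reference semantics over plain arrays:

// Keep the lower triangle (col <= row + diagonal); elsewhere write `val`.
// triu is the same with the comparison flipped to col >= row + diagonal.
void tril_ref(float *dst, const float *src, int rows, int cols, int diagonal, float val) {
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < cols; c++)
            dst[r*cols + c] = (c <= r + diagonal) ? src[r*cols + c] : val;
}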
+ static_assert(std::is_same_v::col_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = V::dtype; + + const int leader = threadIdx.x & 0x1C; // 11100 in binary + #pragma unroll + for(int i = 0; i < src.height; i++) { + dtype accum_top_row = op::template op(src.tiles[i][0].data[0], src.tiles[i][0].data[2]); + dtype accum_bottom_row = op::template op(src.tiles[i][0].data[1], src.tiles[i][0].data[3]); + #pragma unroll + for(int j = 1; j < src.width; j++) { + #pragma unroll + for(int k = 0; k < src.packed_per_tile; k+=2) { + accum_top_row = op::template op(accum_top_row, src.tiles[i][j].data[k+0]); + accum_bottom_row = op::template op(accum_bottom_row, src.tiles[i][j].data[k+1]); + } + } + dtype accum_packed; + accum_packed.x = op::template op::unpacked_type>(accum_top_row.x, accum_top_row.y); + accum_packed.y = op::template op::unpacked_type>(accum_bottom_row.x, accum_bottom_row.y); + + // Now we need to do a lil shuffle to make everyone happy. + + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2)); + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1)); + + accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader); + + if(reset) { + row_accum[i][0] = accum_packed; + } + else { + row_accum[i][0] = op::template op(src_accum[i][0], accum_packed); + } + } +} +/** + * @brief Perform a row-wise reduction on a matrix in column-major layout. + * + * This function template performs a parallel reduction across the rows of a matrix using a specified operation. + * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for column-major matrices. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type with column layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. + static_assert(std::is_same_v::col_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::height); // compatible size + + using dtype = V::dtype; + + const int leader = threadIdx.x & 0x3; // 00011 in binary + #pragma unroll + for(int i = 0; i < src.height; i++) { + dtype accum_top_rows = op::template op(src.tiles[i][0].data[0], src.tiles[i][0].data[1]); + dtype accum_bottom_rows = op::template op(src.tiles[i][0].data[2], src.tiles[i][0].data[3]); + #pragma unroll + for(int j = 1; j < src.width; j++) { + #pragma unroll + for(int k = 0; k < src.packed_per_tile/2; k++) { + accum_top_rows = op::template op(accum_top_rows, src.tiles[i][j].data[k+0]); + accum_bottom_rows = op::template op(accum_bottom_rows, src.tiles[i][j].data[k+2]); + } + } + + // Now we need to do a lil shuffle to make everyone happy. 
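The reductions above finish with a short shuffle tree: partial results move down by decreasing offsets, then the final value is re-broadcast from a leader lane so every participating thread holds it. A standalone sketch of that idiom for a group of 4 consecutive lanes (offsets 2 then 1, matching the row-major path above):

// Reduce a float across each aligned group of 4 lanes, then broadcast the result back.
__device__ float quad_sum(float v) {
    const unsigned mask = 0xFFFFFFFFu;
    v += __shfl_down_sync(mask, v, 2);      // lanes {0,1} of the quad pick up lanes {2,3}
    v += __shfl_down_sync(mask, v, 1);      // first lane of the quad now holds the full sum
    int leader = threadIdx.x & 0x1C;        // first lane of this quad, as in the code above
    return __shfl_sync(mask, v, leader);    // broadcast the quad's sum to all 4 lanes
}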
+ + accum_top_rows = op::template op(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 16)); + accum_top_rows = op::template op(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 8)); + accum_top_rows = op::template op(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 4)); + + accum_bottom_rows = op::template op(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 16)); + accum_bottom_rows = op::template op(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 8)); + accum_bottom_rows = op::template op(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 4)); + + accum_top_rows = packed_shfl_sync(MASK_ALL, accum_top_rows, leader); + accum_bottom_rows = packed_shfl_sync(MASK_ALL, accum_bottom_rows, leader); + + if(reset) { + row_accum[i][0] = accum_top_rows; + row_accum[i][1] = accum_bottom_rows; + } + else { + row_accum[i][0] = op::template op(src_accum[i][0], accum_top_rows); + row_accum[i][1] = op::template op(src_accum[i][1], accum_bottom_rows); + } + } +} + +// Col reduction. +/** + * @brief Perform a column-wise reduction on a matrix in row-major layout. + * + * This function template performs a parallel reduction across the columns of a matrix using a specified operation. + * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for row-major matrices. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The vector type for the column accumulator. + * @tparam T The matrix type with row layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. + KITTENS_CHECK_WARP + static_assert(std::is_same_v::row_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = V::dtype; + + const int leader = threadIdx.x & 0x3; // 00011 in binary + #pragma unroll + for(int j = 0; j < src.width; j++) { + dtype accum_left_cols = op::template op(src.tiles[0][j].data[0], src.tiles[0][j].data[1]); + dtype accum_right_cols = op::template op(src.tiles[0][j].data[2], src.tiles[0][j].data[3]); + #pragma unroll + for(int i = 1; i < src.height; i++) { + #pragma unroll + for(int k = 0; k < src.packed_per_tile/2; k++) { + accum_left_cols = op::template op(accum_left_cols, src.tiles[i][j].data[k+0]); + accum_right_cols = op::template op(accum_right_cols, src.tiles[i][j].data[k+2]); + } + } + + // Now we need to do a lil shuffle to make everyone happy. 
+ + accum_left_cols = op::template op(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 16)); + accum_left_cols = op::template op(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 8)); + accum_left_cols = op::template op(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 4)); + + accum_right_cols = op::template op(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 16)); + accum_right_cols = op::template op(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 8)); + accum_right_cols = op::template op(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 4)); + + accum_left_cols = packed_shfl_sync(MASK_ALL, accum_left_cols, leader); + accum_right_cols = packed_shfl_sync(MASK_ALL, accum_right_cols, leader); + + if(reset) { + col_accum[j][0] = accum_left_cols; + col_accum[j][1] = accum_right_cols; + } + else { + col_accum[j][0] = op::template op(src_accum[j][0], accum_left_cols); + col_accum[j][1] = op::template op(src_accum[j][1], accum_right_cols); + } + } +} +/** + * @brief Perform a column-wise reduction on a matrix in column-major layout. + * + * This function template performs a parallel reduction across the columns of a matrix using a specified operation. + * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for column-major matrices. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The vector type for the column accumulator. + * @tparam T The matrix type with column layout. + * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when reset is false. + */ +template +__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) { + // I actually like these static asserts because they give more verbose errors when things go wrong. + KITTENS_CHECK_WARP + static_assert(std::is_same_v::row_vec_layout>); // compatible layout + static_assert(std::is_same_v); // compatible type + static_assert(V::outer_dim == T::width); // compatible size + + using dtype = V::dtype; + const int leader = threadIdx.x & 0x1C; // 11100 in binary + #pragma unroll + for(int j = 0; j < src.width; j++) { // note now width is the outer loop + dtype accum_left_col = op::template op(src.tiles[0][j].data[0], src.tiles[0][j].data[2]); + dtype accum_right_col = op::template op(src.tiles[0][j].data[1], src.tiles[0][j].data[3]); + #pragma unroll + for(int i = 1; i < src.height; i++) { // and height is the inner loop + #pragma unroll + for(int k = 0; k < src.packed_per_tile; k+=2) { + accum_left_col = op::template op(accum_left_col, src.tiles[i][j].data[k+0]); + accum_right_col = op::template op(accum_right_col, src.tiles[i][j].data[k+1]); + } + } + dtype accum_packed; + accum_packed.x = op::template op::unpacked_type>(accum_left_col.x, accum_left_col.y); + accum_packed.y = op::template op::unpacked_type>(accum_right_col.x, accum_right_col.y); + + // Now we need to do a lil shuffle to make everyone happy. 
+ + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2)); + accum_packed = op::template op(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1)); + + accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader); + + if(reset) { + col_accum[j][0] = accum_packed; + } + else { + col_accum[j][0] = op::template op(src_accum[j][0], accum_packed); + } + } +} + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// two-operand row reductions. (Accumulate and REPLACE.) +/** + * @brief Store the maximum of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the minimum of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the sum of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_sum(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the product of each row of the src register tile in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_prod(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +// three-operand row reductions. (Accumulate ONTO.) +/** + * @brief Store the maximum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the minimum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. 
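The accumulate-onto forms (the three-operand row_max / row_sum below) exist so a reduction can be carried across tiles processed one chunk at a time, the running-max and rescaled-sum pattern used by online softmax. A scalar sketch of that pattern, independent of the tile types above (state should start at m = -INFINITY, s = 0):

#include <math.h>

// Fold one more chunk of a row into a running max m and running sum s of exp(x - m).
void online_softmax_update(float &m, float &s, const float *chunk, int n) {
    float chunk_max = chunk[0];
    for (int i = 1; i < n; i++) chunk_max = fmaxf(chunk_max, chunk[i]);
    float new_m = fmaxf(m, chunk_max);                         // cf. three-operand row_max
    s *= expf(m - new_m);                                      // rescale the old sum to the new max
    for (int i = 0; i < n; i++) s += expf(chunk[i] - new_m);   // cf. three-operand row_sum
    m = new_m;
}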
+ * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the sum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_sum(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the product of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_prod(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} + +// two-operand col reductions. (Accumulate and REPLACE.) + +/** + * @brief Store the maximum of each column of the src register tile in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_max(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the minimum of each column of the src register tile in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_min(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the sum of each column of the src register tile in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_sum(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the product of each column of the src register tile in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. 
+ */ +template +__device__ static inline void col_prod(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +// three-operand col reductions. (Accumulate ONTO.) +/** + * @brief Store the maximum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_max(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the minimum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_min(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the sum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_sum(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the product of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector. + * + * @tparam V The vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. 
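+ *
+ * A minimal usage sketch; the concrete tile type and its accumulator-vector
+ * alias below are assumptions for illustration, not definitions from this header:
+ *
+ *   rt_fl<16, 64> tile;              // hypothetical float register tile
+ *   decltype(tile)::row_vec prods;   // assumed alias: one entry per column
+ *   col_prod(prods, tile);           // two-operand form: overwrite prods
+ *   col_prod(prods, tile, prods);    // three-operand form: accumulate onto prods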
+ */ +template +__device__ static inline void col_prod(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} + +// templated versions of each + +template +__device__ static inline void max(RV &dst, const T &src, const RV &src_accum) { + if constexpr (ax == axis::COL) row_max(dst, src, src_accum); + else col_max(dst, src, src_accum); +} +template +__device__ static inline auto max(const T &src, const RV &src_accum) { + RV dst; + if constexpr (ax == axis::COL) row_max(dst, src, src_accum); + else col_max(dst, src, src_accum); + return dst; +} +template +__device__ static inline void max(RV &dst, const T &src) { + if constexpr (ax == axis::COL) row_max(dst, src); + else col_max(dst, src); +} +template +__device__ static inline auto max(const T &src) { + using RV = std::conditional_t; + RV dst; + if constexpr (ax == axis::COL) row_max(dst, src); + else col_max(dst, src); + return dst; +} + +template +__device__ static inline void min(RV &dst, const T &src, const RV &src_accum) { + if constexpr (ax == axis::COL) row_min(dst, src, src_accum); + else col_min(dst, src, src_accum); +} +template +__device__ static inline auto min(const T &src, const RV &src_accum) { + RV dst; + if constexpr (ax == axis::COL) row_min(dst, src, src_accum); + else col_min(dst, src, src_accum); + return dst; +} +template +__device__ static inline void min(RV &dst, const T &src) { + if constexpr (ax == axis::COL) row_min(dst, src); + else col_min(dst, src); +} +template +__device__ static inline auto min(const T &src) { + using RV = std::conditional_t; + RV dst; + if constexpr (ax == axis::COL) row_min(dst, src); + else col_min(dst, src); + return dst; +} + +template +__device__ static inline void sum(RV &dst, const T &src, const RV &src_accum) { + if constexpr (ax == axis::COL) row_sum(dst, src, src_accum); + else col_sum(dst, src, src_accum); +} +template +__device__ static inline auto sum(const T &src, const RV &src_accum) { + RV dst; + if constexpr (ax == axis::COL) row_sum(dst, src, src_accum); + else col_sum(dst, src, src_accum); + return dst; +} +template +__device__ static inline void sum(RV &dst, const T &src) { + if constexpr (ax == axis::COL) row_sum(dst, src); + else col_sum(dst, src); +} +template +__device__ static inline auto sum(const T &src) { + using RV = std::conditional_t; + RV dst; + if constexpr (ax == axis::COL) row_sum(dst, src); + else col_sum(dst, src); + return dst; +} + +template +__device__ static inline void prod(RV &dst, const T &src, const RV &src_accum) { + if constexpr (ax == axis::COL) row_prod(dst, src, src_accum); + else col_prod(dst, src, src_accum); +} +template +__device__ static inline auto prod(const T &src, const RV &src_accum) { + RV dst; + if constexpr (ax == axis::COL) row_prod(dst, src, src_accum); + else col_prod(dst, src, src_accum); + return dst; +} +template +__device__ static inline void prod(RV &dst, const T &src) { + if constexpr (ax == axis::COL) row_prod(dst, src); + else col_prod(dst, src); +} +template +__device__ static inline auto prod(const T &src) { + using RV = std::conditional_t; + RV dst; + if constexpr (ax == axis::COL) row_prod(dst, src); + else col_prod(dst, src); + return dst; +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/tile/tile.cuh b/extra/thunder/cuda/include/ops/group/register/tile/tile.cuh new file mode 100644 index 0000000000..1ddbaf4380 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/tile/tile.cuh @@ -0,0 +1,47 @@ +/** + * @file + * @brief An 
aggregate header for warp operations on register tiles. + */ + +#include "conversions.cuh" +#include "maps.cuh" +#include "reductions.cuh" + +template +__device__ static inline bool hasnan(const RT &src) { + KITTENS_CHECK_WARP + bool nan_detected = false; + #pragma unroll + for(int i = 0; i < RT::height; i++) { + #pragma unroll + for(int j = 0; j < RT::width; j++) { + #pragma unroll + for(int k = 0; k < RT::packed_per_tile; k++) { + if constexpr (std::is_same_v) { + if(isnan(src.tiles[i][j].data[k].x) || isnan(src.tiles[i][j].data[k].y)) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__bfloat162float(src.tiles[i][j].data[k].x)) || isnan(__bfloat162float(src.tiles[i][j].data[k].y))) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__half2float(src.tiles[i][j].data[k].x)) || isnan(__half2float(src.tiles[i][j].data[k].y))) { + nan_detected = true; + } + } + else { + static_assert(sizeof(typename RT::T) == 999, "Unsupported dtype"); + } + } + } + } + // Ballot across the warp to see if any lane detected a nan + return (__ballot_sync(0xffffffff, nan_detected) != 0); +} + +#include "complex/complex_conversions.cuh" +#include "complex/complex_maps.cuh" + diff --git a/extra/thunder/cuda/include/ops/group/register/vec/conversions.cuh b/extra/thunder/cuda/include/ops/group/register/vec/conversions.cuh new file mode 100644 index 0000000000..3bc7177e17 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/vec/conversions.cuh @@ -0,0 +1,153 @@ +/** + * @file + * @brief Conversions on vectors stored in registers. + */ + +struct vec_conversion_detail { + +// i am not smart enough to figure out these indices without these helpers :/ +// again, blame nvidia for these stupid, stupid layouts +__device__ static inline int row_from_indices_dim2(int laneid, int inner_dim, int x_or_y) { + return 8*inner_dim + (laneid%4)*2 + x_or_y; +} +__device__ static inline int row_from_indices_dim1(int laneid, int x_or_y) { + return 8*x_or_y + (laneid/4); +} +__device__ static inline int canonical_src_lane_dim2(int row) { + return (row/2)%4 + 4*(row%2); // draw even rows from 0...3 and odds from 4...7 +} +__device__ static inline int canonical_src_lane_dim1(int row) { + return (row*4)%32; +} + +}; + +/** + * @brief Copies data from one register vector to another. + * + * @tparam RV1 The type of the destination register vector. + * @tparam RV2 The type of the source register vector. + * @param dst[out] The destination register vector. + * @param src[in] The source register vector to copy from. + */ +template +__device__ static inline void copy(RV1 &dst, const RV2 &src) { + KITTENS_CHECK_WARP + static_assert(RV1::length == RV2::length, "Register vectors must be the same length."); + using D1 = RV1::dtype; + using D2 = RV2::dtype; + if constexpr (std::is_same_v) { // just a simple copy / typecast + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + #pragma unroll + for(int j = 0; j < RV1::inner_dim; j++) { + dst[i][j] = base_types::convertor::convert(src[i][j]); + } + } + } + else { // Inner dimensions are not the same, this is really a layout conversion. + int laneid = ::kittens::laneid(); + if constexpr (std::is_same_v && std::is_same_v) { // align -> ortho layout + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + dst[i][0].x = packed_shfl_sync( + kittens::MASK_ALL, + laneid < 4 ? 
src[i][0].x : src[i][0].y, // mirrors canonical_src_lane_dim2 + vec_conversion_detail::canonical_src_lane_dim2(vec_conversion_detail::row_from_indices_dim1(laneid, 0)) + ); + dst[i][0].y = packed_shfl_sync( + kittens::MASK_ALL, + laneid < 4 ? src[i][1].x : src[i][1].y, // mirrors canonical_src_lane_dim2 + vec_conversion_detail::canonical_src_lane_dim2(vec_conversion_detail::row_from_indices_dim1(laneid, 1)) + ); + } + } + else if constexpr (std::is_same_v && std::is_same_v) { // ortho -> align layout + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + dst[i][0].x = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].x, // first 8 rows + vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 0, 0)) + ); + dst[i][0].y = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].x, // first 8 rows + vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 0, 1)) + ); + dst[i][1].x = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].y, // last 8 rows + vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 1, 0)) + ); + dst[i][1].y = packed_shfl_sync( + kittens::MASK_ALL, + src[i][0].y, // last 8 rows + vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 1, 1)) + ); + } + } + else if constexpr (std::is_same_v && std::is_same_v) { // naive -> ortho layout + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + dst[i][0].x = packed_shfl_sync( + kittens::MASK_ALL, src[i/2][0], + 16*(i%2) + 0 + (laneid/4) + ); + dst[i][0].y = packed_shfl_sync( + kittens::MASK_ALL, src[i/2][0], + 16*(i%2) + 8 + (laneid/4) + ); + } + } + else if constexpr (std::is_same_v && std::is_same_v) { // ortho -> naive layout + int lane_replication = laneid%4; // 0...3 + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + D1 tmp = 0; + if(RV1::length%32==0 || i < RV1::outer_dim-1 || lane_replication<2) { + tmp = lane_replication%2 ? src[2*i + (lane_replication>=2)][0].y : src[2*i + (lane_replication>=2)][0].x; + } + dst[i][0] = packed_shfl_sync( + kittens::MASK_ALL, tmp, + (laneid%8)*4 + (laneid/8) + ); + } + } + else if constexpr (std::is_same_v && std::is_same_v) { // naive -> align layout + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + dst[i][0].x = packed_shfl_sync( + kittens::MASK_ALL, src[i/2][0], + 16*(i%2) + 0 + 2*(laneid%4) + 0 + ); + dst[i][0].y = packed_shfl_sync( + kittens::MASK_ALL, src[i/2][0], + 16*(i%2) + 0 + 2*(laneid%4) + 1 + ); + dst[i][1].x = packed_shfl_sync( + kittens::MASK_ALL, src[i/2][0], + 16*(i%2) + 8 + 2*(laneid%4) + 0 + ); + dst[i][1].y = packed_shfl_sync( + kittens::MASK_ALL, src[i/2][0], + 16*(i%2) + 8 + 2*(laneid%4) + 1 + ); + } + } + else if constexpr (std::is_same_v && std::is_same_v) { // align -> naive layout + int lane_replication = laneid/8; // 0...3 + #pragma unroll + for(int i = 0; i < RV1::outer_dim; i++) { + D1 tmp = 0; + if(RV1::length%32==0 || i < RV1::outer_dim-1 || laneid<16) { + tmp = (laneid%8)<4 ? 
src[2*i + (lane_replication>=2)][lane_replication%2].x : src[2*i + (lane_replication>=2)][lane_replication%2].y; + } + dst[i][0] = packed_shfl_sync( + kittens::MASK_ALL, tmp, + 4*(laneid%2) + (laneid%8)/2 + (laneid&0b11000) + ); + } + } + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/vec/maps.cuh b/extra/thunder/cuda/include/ops/group/register/vec/maps.cuh new file mode 100644 index 0000000000..fb0ccbc3c2 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/vec/maps.cuh @@ -0,0 +1,374 @@ +/** + * @file + * @brief Maps on vectors stored in registers. + */ + +/* ---------- Vector Maps ---------- */ + +/** + * @brief Perform a unary operation on a vector. + * + * @tparam op The unary operation to perform. + * @tparam T The type of the vector. + * @param dst[out] The destination vector where the result is stored. + * @param src[in] The source vector to perform the operation on. + */ +template +__device__ static inline void unary_op(T &dst, const T &src) { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + #pragma unroll + for(int j = 0; j < dst.inner_dim; j++) { + dst[i][j] = op::template op(src[i][j]); + } + } +} +/** + * @brief Perform a binary operation on two vectors. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vectors. + * @param dst[out] The destination vector where the result is stored. + * @param lhs[in] The left-hand side vector for the operation. + * @param rhs[in] The right-hand side vector for the operation. + */ +template +__device__ static inline void bin_op(T &dst, const T &lhs, const T &rhs) { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + #pragma unroll + for(int j = 0; j < dst.inner_dim; j++) { + dst[i][j] = op::template op(lhs[i][j], rhs[i][j]); + } + } +} +/** + * @brief Perform a binary operation on a vector and a scalar. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vector. + * @param dst[out] The destination vector where the result is stored. + * @param src[in] The source vector for the operation. + * @param param[in] The scalar parameter for the operation. + */ +template +__device__ static inline void bin_op(T &dst, const T &src, const typename T::dtype ¶m) { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + #pragma unroll + for(int j = 0; j < dst.inner_dim; j++) { + dst[i][j] = op::template op(src[i][j], param); + } + } +} +/** + * @brief Perform a binary operation on a vector and an unpacked scalar. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vector. + * @param dst[out] The destination vector where the result is stored. + * @param src[in] The source vector for the operation. + * @param param[in] The unpacked scalar parameter for the operation. 
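+ *
+ * For example, scaling a vector by an unpacked float (the vector type and the
+ * base_ops::mul functor named here are assumptions for illustration):
+ *
+ *   rv_fl<64> v;                        // hypothetical register vector
+ *   bin_op<base_ops::mul>(v, v, 2.0f);  // 2.0f is packed, then each packed lane is scaled
+ *   mul(v, v, 2.0f);                    // the wrapper defined further below does the same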
+ */ +template +__device__ static inline void bin_op(T &dst, const T &src, const typename base_types::packing::unpacked_type ¶m) { + bin_op(dst, src, base_types::packing::pack(param)); +} + + +template +__device__ static inline void apply(RV &dst, const RV &src, Lambda &&lambda) { + int group_offset = 0; + if constexpr(GROUP_WARPS > 1) { + group_offset = warpid()*RV::length; + } + static_assert(sizeof(RV::T) != 1, "Cannot apply lambda to 8-bit types"); + if constexpr (ducks::rv::ortho_layout) { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + int base_idx = group_offset + i*16 + ::kittens::laneid()/4; + dst[i][0].x = lambda(base_idx+0, src[i][0].x); + dst[i][0].y = lambda(base_idx+8, src[i][0].y); + } + } + else if constexpr (ducks::rv::align_layout) { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + int base_idx = group_offset + i*16 + 2*(::kittens::laneid()%4); + dst[i][0].x = lambda(base_idx+0, src[i][0].x); + dst[i][0].y = lambda(base_idx+1, src[i][0].y); + dst[i][1].x = lambda(base_idx+8, src[i][1].x); + dst[i][1].y = lambda(base_idx+9, src[i][1].y); + } + } + else { + #pragma unroll + for(int i = 0; i < dst.outer_dim; i++) { + int base_idx = group_offset + i*32 + ::kittens::laneid(); + if (i < dst.outer_dim-1 || dst.length%32 == 0 || ::kittens::laneid()<16) { + dst[i][0] = lambda(base_idx, src[i][0]); + } + } + } +} +template +__device__ static inline RV apply(const RV &src, Lambda &&lambda) { + RV dst; + apply(dst, src, std::forward(lambda)); + return dst; +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// ---- const ops ---- + +/** + * @brief Sets all elements of a register vector to zero. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector to be set to zero. + */ +template +__device__ static inline void zero(T &dst) { + unary_op(dst, dst); +} +/** + * @brief Sets all elements of a register vector to one. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector to be set to one. + */ +template +__device__ static inline void one(T &dst) { + unary_op(dst, dst); +} +/** + * @brief Sets all elements of a register vector to positive infinity. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector to be set to positive infinity. + */ +template +__device__ static inline void pos_infty(T &dst) { + unary_op(dst, dst); +} +/** + * @brief Sets all elements of a register vector to negative infinity. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector to be set to negative infinity. + */ +template +__device__ static inline void neg_infty(T &dst) { + unary_op(dst, dst); +} + +// ---- unary ops ---- + +/** + * @brief Copies the elements from one register vector to another. + * + * @tparam T Register vector type. + * @tparam U Type of the source vector. + * @param dst[out] Destination vector where the elements will be copied to. + * @param src[in] Source vector to copy the elements from. + */ +template +__device__ static inline void copy(T &dst, const U &src) { + bin_op(dst, dst, src); // the second arg is ignored here. +} +/** + * @brief Applies the exponential function element-wise to a register vector. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector where the exponential values will be stored. + * @param src[in] Source vector to apply the exponential function to. 
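+ *
+ * Both an in-place form and a value-returning form are provided, e.g.
+ * (the vector type is an assumption for illustration):
+ *
+ *   rv_fl<64> x;
+ *   exp(x, x);          // overwrite x with exp(x), element-wise
+ *   auto y = exp(x);    // or materialize the result in a fresh vector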
+ */ +template +__device__ static inline void exp(T &dst, const T &src) { + unary_op(dst, src); +} +template +__device__ static inline T exp(const T &src) { + T dst; + exp(dst, src); + return dst; +} +/** + * @brief Applies the exponential function element-wise to a register vector, in base 2. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector where the exponential values will be stored. + * @param src[in] Source vector to apply the exponential function to. + */ +template +__device__ static inline void exp2(T &dst, const T &src) { + unary_op(dst, src); +} +template +__device__ static inline T exp2(const T &src) { + T dst; + exp2(dst, src); + return dst; +} +/** + * @brief Applies the natural logarithm function element-wise to a register vector. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector where the exponential values will be stored. + * @param src[in] Source vector to apply the exponential function to. + */ +template +__device__ static inline void log(T &dst, const T &src) { + unary_op(dst, src); +} +template +__device__ static inline T log(const T &src) { + T dst; + log(dst, src); + return dst; +} +/** + * @brief Applies the logarithm base 2 function element-wise to a register vector. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector where the exponential values will be stored. + * @param src[in] Source vector to apply the logarithm base 2 function to. + */ +template +__device__ static inline void log2(T &dst, const T &src) { + unary_op(dst, src); +} +template +__device__ static inline T log2(const T &src) { + T dst; + log2(dst, src); + return dst; +} +/** + * @brief Applies the absolute value function element-wise to a register vector. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector where the absolute values will be stored. + * @param src[in] Source vector to apply the absolute value function to. + */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_op(dst, src); +} +template +__device__ static inline T abs(const T &src) { + T dst; + abs(dst, src); + return dst; +} +/** + * @brief Applies the rectified linear unit (ReLU) function element-wise to a register vector. + * + * @tparam T Register vector type. + * @param dst[out] Destination vector where the ReLU values will be stored. + * @param src[in] Source vector to apply the ReLU function to. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_op(dst, src); +} +template +__device__ static inline T relu(const T &src) { + T dst; + relu(dst, src); + return dst; +} + +// ---- binary ops ---- + +/** + * @brief Computes the element-wise maximum of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the maximum values will be stored. + * @param lhs[in] First vector for the maximum operation. + * @param rhs[in] Second vector for the maximum operation. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +template +__device__ static inline T max(const T &lhs, const U &rhs) { + T dst; + max(dst, lhs, rhs); + return dst; +} +/** + * @brief Computes the element-wise minimum of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the minimum values will be stored. + * @param lhs[in] First vector for the minimum operation. 
+ * @param rhs[in] Second vector for the minimum operation. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +template +__device__ static inline T min(const T &lhs, const U &rhs) { + T dst; + min(dst, lhs, rhs); + return dst; +} +/** + * @brief Computes the element-wise sum of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the sum values will be stored. + * @param lhs[in] First vector for the sum operation. + * @param rhs[in] Second vector for the sum operation. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise difference of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the difference values will be stored. + * @param lhs[in] First vector for the difference operation. + * @param rhs[in] Second vector for the difference operation. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise product of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the product values will be stored. + * @param lhs[in] First vector for the product operation. + * @param rhs[in] Second vector for the product operation. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise division of two register vectors. + * + * @tparam T Register vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the division values will be stored. + * @param lhs[in] First vector for the division operation. + * @param rhs[in] Second vector for the division operation. + */ +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/vec/reductions.cuh b/extra/thunder/cuda/include/ops/group/register/vec/reductions.cuh new file mode 100644 index 0000000000..f9ae971924 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/vec/reductions.cuh @@ -0,0 +1,233 @@ +/** + * @file + * @brief Reductions on vectors stored in registers. + */ + +/* ---------- Vector Reductions ---------- */ + +/** + * @brief Performs a reduction operation on elements of a register vector within a warp. + * + * This function applies a specified operation to reduce the elements of a register vector `src` to a single value. + * The result is stored in `accum`. If the `reset` parameter is true, the reduction includes an initial value `src_accum`. + * The reduction operation is performed in a warp-wide context, ensuring synchronization between threads in the warp. + * + * @tparam op The operation to perform on the elements. Must provide a static `op` method. + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @tparam reset A boolean flag indicating whether to include an initial value in the reduction. + * @param[out] accum The result of the reduction operation. + * @param[in] src The register vector to reduce. 
+ * @param[in] src_accum The initial value to include in the reduction if `reset` is false. + */ +template +__device__ static inline void reduce( + typename base_types::packing::unpacked_type &dst_accum, + const RV &src, + const typename base_types::packing::unpacked_type &src_accum) { + KITTENS_CHECK_WARP + using T = base_types::packing::unpacked_type; + int laneid = kittens::laneid(); + if constexpr (std::is_same_v) { + T accum = op::template op(src[0][0].x, src[0][0].y); + #pragma unroll + for(int i = 1; i < src.outer_dim; i++) { + accum = op::template op(accum, src[i][0].x); + accum = op::template op(accum, src[i][0].y); + } + // we've now reduced everything into 8 distinct values, replicated across lanes x, x+1, x+2, x+3 for x≡0(mod4) + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16)); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8)); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4)); + // we've now reduced everything into 1 distinct value, replicated across lanes 0, 1, 2, 3 + if constexpr (!reset) accum = op::template op(accum, src_accum); + // final result has now been achieved (incorporating src_accum if necessary), finally broadcast back to all threads. + dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0); + } + else if constexpr (std::is_same_v) { + T accum = op::template op(src[0][0].x, src[0][0].y); + accum = op::template op(accum, src[0][1].x); + accum = op::template op(accum, src[0][1].y); + #pragma unroll + for(int i = 1; i < src.outer_dim; i++) { + // it is possible that shfl_sync's would be faster but I doubt it, replication is likely better. Certainly simpler. + accum = op::template op(accum, src[i][0].x); + accum = op::template op(accum, src[i][0].y); + accum = op::template op(accum, src[i][1].x); + accum = op::template op(accum, src[i][1].y); + } + // we've now reduced everything into 4 distinct values, replicated across lanes x, x+4, x+8, ..., x+28 for x<4 + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2)); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1)); + // we've now reduced everything into 1 distinct value, replicated across lanes 0, 4, 8, 12, ..., 28 + if constexpr (!reset) accum = op::template op(accum, src_accum); + // final result has now been achieved (incorporating src_accum if necessary), finally broadcast back to all threads from lane 0 + dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0); + } + else if constexpr (std::is_same_v) { + T accum = src[0][0]; + #pragma unroll + for(int i = 1; i < src.outer_dim; i++) { + if (i < src.outer_dim-1 || i*kittens::TILE_ROW_DIM*2 + laneid < src.length) { + accum = op::template op(accum, src[i][0]); + } + } + if(src.length > 16) accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16)); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8)); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4)); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2)); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1)); + if constexpr (!reset) accum = op::template op(accum, src_accum); + dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0); + } +} + + +/** + * @brief Finds the maximum element in a register vector. + * + * @tparam RV The type of the register vector. 
Must satisfy the `ducks::rv::all` concept. + * @param[out] max_val The maximum value found in the vector. + * @param[in] src The register vector to find the maximum in. + */ +template +__device__ static inline void max(typename base_types::packing::unpacked_type &max_val, const RV &src) { + reduce(max_val, src, max_val); +} +template +__device__ static inline typename base_types::packing::unpacked_type max(const RV &src) { + typename base_types::packing::unpacked_type max_val; + reduce(max_val, src, max_val); + return max_val; +} + +/** + * @brief Finds the minimum element in a register vector. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] min_val The minimum value found in the vector. + * @param[in] src The register vector to find the minimum in. + */ +template +__device__ static inline void min(typename base_types::packing::unpacked_type &min_val, const RV &src) { + reduce(min_val, src, min_val); +} +template +__device__ static inline typename base_types::packing::unpacked_type min(const RV &src) { + typename base_types::packing::unpacked_type min_val; + reduce(min_val, src, min_val); + return min_val; +} + +/** + * @brief Calculates the sum of elements in a register vector. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] sum_val The sum of the values in the vector. + * @param[in] src The register vector to sum. + */ +template +__device__ static inline void sum(typename base_types::packing::unpacked_type &sum_val, const RV &src) { + reduce(sum_val, src, sum_val); +} +template +__device__ static inline typename base_types::packing::unpacked_type sum(const RV &src) { + typename base_types::packing::unpacked_type sum_val; + reduce(sum_val, src, sum_val); + return sum_val; +} + +/** + * @brief Calculates the product of elements in a register vector. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] prod_val The product of the values in the vector. + * @param[in] src The register vector to multiply. + */ +template +__device__ static inline void prod(typename base_types::packing::unpacked_type &prod_val, const RV &src) { + reduce(prod_val, src, prod_val); +} +template +__device__ static inline typename base_types::packing::unpacked_type prod(const RV &src) { + typename base_types::packing::unpacked_type prod_val; + reduce(prod_val, src, prod_val); + return prod_val; +} + +// Three operand versions. + +/** + * @brief Finds the maximum element in a register vector and accumulates it with src_accum. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] max_val The maximum value found in the vector, accumulated with src_accum. + * @param[in] src The register vector to find the maximum in. + * @param[in] src_accum The initial value to accumulate with the maximum value found. 
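+ *
+ * The usual pattern is a running reduction over several vectors, e.g.
+ * (the vector type is an assumption for illustration):
+ *
+ *   rv_fl<64> x0, x1;
+ *   float m = max(x0);    // seed with the max of the first vector
+ *   max(m, x1, m);        // fold the next vector's max into the running value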
+ */ +template +__device__ static inline void max(typename base_types::packing::unpacked_type &max_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + reduce(max_val, src, src_accum); +} +template +__device__ static inline typename base_types::packing::unpacked_type max(const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + typename base_types::packing::unpacked_type max_val; + reduce(max_val, src, src_accum); + return max_val; +} + +/** + * @brief Finds the minimum element in a register vector and accumulates it with src_accum. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] min_val The minimum value found in the vector, accumulated with src_accum. + * @param[in] src The register vector to find the minimum in. + * @param[in] src_accum The initial value to accumulate with the minimum value found. + */ +template +__device__ static inline void min(typename base_types::packing::unpacked_type &min_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + reduce(min_val, src, src_accum); +} +template +__device__ static inline typename base_types::packing::unpacked_type min(const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + typename base_types::packing::unpacked_type min_val; + reduce(min_val, src, src_accum); + return min_val; +} + +/** + * @brief Calculates the sum of elements in a register vector and accumulates it with src_accum. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] sum_val The sum of the values in the vector, accumulated with src_accum. + * @param[in] src The register vector to sum. + * @param[in] src_accum The initial value to accumulate with the sum of the vector. + */ +template +__device__ static inline void sum(typename base_types::packing::unpacked_type &sum_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + reduce(sum_val, src, src_accum); +} +template +__device__ static inline typename base_types::packing::unpacked_type sum(const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + typename base_types::packing::unpacked_type sum_val; + reduce(sum_val, src, src_accum); + return sum_val; +} + +/** + * @brief Calculates the product of elements in a register vector and accumulates it with src_accum. + * + * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept. + * @param[out] prod_val The product of the values in the vector, accumulated with src_accum. + * @param[in] src The register vector to multiply. + * @param[in] src_accum The initial value to accumulate with the product of the vector. 
+ */ +template +__device__ static inline void prod(typename base_types::packing::unpacked_type &prod_val, const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + reduce(prod_val, src, src_accum); +} +template +__device__ static inline typename base_types::packing::unpacked_type prod(const RV &src, const typename base_types::packing::unpacked_type &src_accum) { + typename base_types::packing::unpacked_type prod_val; + reduce(prod_val, src, src_accum); + return prod_val; +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/register/vec/vec.cuh b/extra/thunder/cuda/include/ops/group/register/vec/vec.cuh new file mode 100644 index 0000000000..cd5f9d35ee --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/register/vec/vec.cuh @@ -0,0 +1,59 @@ +/** + * @file + * @brief An aggregate header for warp operations on register vectors. + */ + +#include "conversions.cuh" +#include "maps.cuh" +#include "reductions.cuh" + +template +__device__ static inline bool hasnan(const RV &src) { + KITTENS_CHECK_WARP + bool nan_detected = false; + #pragma unroll + for(int i = 0; i < RV::outer_dim; i++) { + #pragma unroll + for(int j = 0; j < RV::inner_dim; j++) { + if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + if(isnan(src[i][j].x) || isnan(src[i][j].y)) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__bfloat162float(src[i][j].x)) || isnan(__bfloat162float(src[i][j].y))) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__half2float(src[i][j].x)) || isnan(__half2float(src[i][j].y))) { + nan_detected = true; + } + } + } + else if constexpr (std::is_same_v) { + if constexpr (std::is_same_v) { + if(isnan(src[i][j])) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__bfloat162float(src[i][j]))) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__half2float(src[i][j]))) { + nan_detected = true; + } + } + } + else { + static_assert(sizeof(typename RV::dtype) == 999, "Unsupported dtype"); + } + } + } + // Ballot across the warp to see if any lane detected a nan + return (__ballot_sync(0xffffffff, nan_detected) != 0); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/shared.cuh b/extra/thunder/cuda/include/ops/group/shared/shared.cuh new file mode 100644 index 0000000000..6558b07f15 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/shared.cuh @@ -0,0 +1,7 @@ +/** + * @file + * @brief An aggregate header of group operations on data in shared memory + */ + +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/tile/conversions.cuh b/extra/thunder/cuda/include/ops/group/shared/tile/conversions.cuh new file mode 100644 index 0000000000..95c614b23a --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/tile/conversions.cuh @@ -0,0 +1,16 @@ +/** + * @file + * @brief Group conversions between different shared memory tile types. 
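+ *
+ * For example, a group-wide dtype-converting copy between two same-shape
+ * shared tiles (the tile types and their placement in shared memory are
+ * assumptions for illustration):
+ *
+ *   // given st_fl<64, 64> src and st_bf<64, 64> dst resident in shared memory
+ *   copy(dst, src);   // the group cooperatively converts each element float -> bf16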
+ */ + +/* ---------- COPIES ---------- */ + +template +__device__ static inline void copy(ST1 &dst, const ST2 &src) { + static_assert(ST1::height == ST2::height && ST1::width == ST2::width, "Tiles must have the same height and width"); + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i+=GROUP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = base_types::convertor::convert(src[{row, col}]); + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/tile/maps.cuh b/extra/thunder/cuda/include/ops/group/shared/tile/maps.cuh new file mode 100644 index 0000000000..b0b7330944 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/tile/maps.cuh @@ -0,0 +1,236 @@ +/** + * @file + * @brief Group maps on shared tiles. + */ + + +template // T2, w, h can be inferred from dst as long as op is specialized +__device__ static inline void unary_map(T &dst, const T &src) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + dst.data[i] = op::template op(src.data[i]); + } +} + +template +__device__ static inline void bin_map(T &dst, const T &src, const typename T::dtype ¶m) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + dst.data[i] = op::template op(src.data[i], param); + } +} + +template +__device__ static inline void bin_map(T &dst, const T &lhs, const T &rhs) { + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + dst.data[i] = op::template op(lhs.data[i], rhs.data[i]); + } +} + +template +__device__ static inline void row_map(T &dst, const T &src, const V &vec) { + static_assert(std::is_same::value, "Tile and vector must have the same data type"); + static_assert(V::length == T::rows, "Vector length must match the number of rows in the tile"); + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = op::template op(src[{row, col}], vec[row]); + } +} + +template +__device__ static inline void col_map(T &dst, const T &src, const V &vec) { + static_assert(std::is_same::value, "Tile and vector must have the same data type"); + static_assert(V::length == T::cols, "Vector length must match the number of columns in the tile"); + #pragma unroll + for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) { + int row = i/dst.cols, col = i%dst.cols; + dst[{row, col}] = op::template op(src[{row, col}], vec[col]); + } +} + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// All of the annoying qualifiers *should* be automatically inferred during compile-time. 
+// So, syntax should just be kittens::add_row(tile, colvec); + +// const maps + +template +__device__ static inline void zero(T &dst) { + unary_map(dst, dst); +} + +template +__device__ static inline void one(T &dst) { + unary_map(dst, dst); +} + +template +__device__ static inline void pos_infty(T &dst) { + unary_map(dst, dst); +} + +template +__device__ static inline void neg_infty(T &dst) { + unary_map(dst, dst); +} + +// unary maps + +template +__device__ static inline void exp(T &dst, const T &src) { + unary_map(dst, src); +} + +template +__device__ static inline void exp2(T &dst, const T &src) { + unary_map(dst, src); +} + +template +__device__ static inline void log(T &dst, const T &src) { + unary_map(dst, src); +} + +template +__device__ static inline void log2(T &dst, const T &src) { + unary_map(dst, src); +} + +template +__device__ static inline void abs(T &dst, const T &src) { + unary_map(dst, src); +} + +template +__device__ static inline void relu(T &dst, const T &src) { + unary_map(dst, src); +} + +template +__device__ static inline void copy(T &dst, const U &src) { + bin_map(dst, src); +} + +// uniform binary maps + +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_map(dst, lhs, rhs); +} + +// Row and col maps + + +template +__device__ static inline void add_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +template +__device__ static inline void sub_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +template +__device__ static inline void mul_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +template +__device__ static inline void div_row(T &dst, const T &src, const V &row_values) { + row_map(dst, src, row_values); +} + +template +__device__ static inline void broadcast_row(T &dst, const V &row_values) { + row_map(dst, dst, row_values); +} + + +// col maps + +template +__device__ static inline void add_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +template +__device__ static inline void sub_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +template +__device__ static inline void mul_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +template +__device__ static inline void div_col(T &dst, const T &src, const V &col_values) { + col_map(dst, src, col_values); +} + +template +__device__ static inline void broadcast_col(T &dst, const V &col_values) { + col_map(dst, dst, col_values); +} + +// Templated versions of each + +template +__device__ static inline void add(T &dst, const T &src, const V &col_values) { + if constexpr (axis == axis::COL) add_col(dst, src, col_values); + else add_row(dst, src, col_values); +} + +template +__device__ static inline void sub(T &dst, const T &src, const V &col_values) { + if constexpr (axis == axis::COL) sub_col(dst, src, 
col_values); + else sub_row(dst, src, col_values); +} + +template +__device__ static inline void mul(T &dst, const T &src, const V &col_values) { + if constexpr (axis == axis::COL) mul_col(dst, src, col_values); + else mul_row(dst, src, col_values); +} + +template +__device__ static inline void div(T &dst, const T &src, const V &col_values) { + if constexpr (axis == axis::COL) div_col(dst, src, col_values); + else div_row(dst, src, col_values); +} + +template +__device__ static inline void broadcast(T &dst, const V &col_values) { + if constexpr (axis == axis::COL) broadcast_col(dst, col_values); + else broadcast_row(dst, col_values); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/tile/reductions.cuh b/extra/thunder/cuda/include/ops/group/shared/tile/reductions.cuh new file mode 100644 index 0000000000..f237f93247 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/tile/reductions.cuh @@ -0,0 +1,372 @@ +/** + * @file + * @brief Group reductions on shared tiles. + */ + +/** + * Performs row-wise reduction on a matrix using a specified operation. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type with row layout. + * @param row_accum The accumulator where the result of the reduction is stored. + * @param src The source matrix on which to perform the reduction. + * @param src_accum The initial value of the accumulator, used when reset is false. + * @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + */ +template +__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) { + using dtype = typename V::dtype; + for (int row = laneid(); row < src.rows; row += GROUP_THREADS) { + dtype accum = src[{row, 0}]; + #pragma unroll + for (int col = 1; col < src.cols; col++) { + accum = op::template op(accum, src[{row, col}]); + } + if (reset) { + row_accum[row] = accum; + } else { + row_accum[row] = op::template op(src_accum[row], accum); + } + } +} + +/** + * Performs column-wise reduction on a matrix using a specified operation. + * + * @tparam op The operation to be applied for reduction. + * @tparam V The shared vector type for the column accumulator. + * @tparam T The shared matrix type with column layout. + * @param col_accum The accumulator where the result of the reduction is stored. + * @param src The source matrix on which to perform the reduction. + * @param src_accum The initial value of the accumulator, used when reset is false. + * @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not. + */ +template +__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) { + using dtype = typename V::dtype; + for (int col = laneid(); col < src.cols; col += GROUP_THREADS) { + dtype accum = src[{0, col}]; + #pragma unroll + for (int row = 1; row < src.rows; row++) { + accum = op::template op(accum, src[{row, col}]); + } + if (reset) { + col_accum[col] = accum; + } else { + col_accum[col] = op::template op(src_accum[col], accum); + } + } +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +/** + * @brief Store the maximum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. 
+ * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the minimum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the sum of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_sum(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} +/** + * @brief Store the product of each row of the src shared matrix in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void row_prod(V &row_accum, const T &src) { + row_reduce(row_accum, src, row_accum); +} + +/** + * @brief Store the maximum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_max(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the minimum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_min(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the sum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. 
+ */ +template +__device__ static inline void row_sum(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} +/** + * @brief Store the product of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] row_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void row_prod(V &row_accum, const T &src, const V &src_accum) { + row_reduce(row_accum, src, src_accum); +} + +/** + * @brief Store the maximum of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_max(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the minimum of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_min(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the sum of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_sum(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} +/** + * @brief Store the product of each column of the src shared matrix in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + */ +template +__device__ static inline void col_prod(V &col_accum, const T &src) { + col_reduce(col_accum, src, col_accum); +} + +/** + * @brief Store the maximum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. 
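+ *
+ * A minimal sketch; the shared tile and vector types are assumptions for
+ * illustration:
+ *
+ *   // given st_fl<64, 64> tile and sv_fl<64> maxes in shared memory
+ *   col_max(maxes, tile);          // overwrite: max over the rows of each column
+ *   col_max(maxes, tile, maxes);   // accumulate onto the previously stored maxes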
+ */ +template +__device__ static inline void col_max(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the minimum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_min(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the sum of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. + */ +template +__device__ static inline void col_sum(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} +/** + * @brief Store the product of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector. + * + * @tparam V The shared vector type for the row accumulator. + * @tparam T The shared matrix type. + * @param[out] col_accum The accumulator where the result of the reduction is stored. + * @param[in] src The source matrix on which to perform the reduction. + * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value. 
+ */ +template +__device__ static inline void col_prod(V &col_accum, const T &src, const V &src_accum) { + col_reduce(col_accum, src, src_accum); +} + +// templated versions of each + +template +__device__ static inline void max(V &dst, const T &src, const V &src_accum) { + if constexpr (ax == axis::COL) row_max(dst, src, src_accum); + else col_max(dst, src, src_accum); +} +template +__device__ static inline auto max(const T &src, const V &src_accum) { + V dst; + if constexpr (ax == axis::COL) row_max(dst, src, src_accum); + else col_max(dst, src, src_accum); + return dst; +} +template +__device__ static inline void max(V &dst, const T &src) { + if constexpr (ax == axis::COL) row_max(dst, src); + else col_max(dst, src); +} +template +__device__ static inline auto max(const T &src) { + using V = std::conditional_t; + V dst; + if constexpr (ax == axis::COL) row_max(dst, src); + else col_max(dst, src); + return dst; +} + +template +__device__ static inline void min(V &dst, const T &src, const V &src_accum) { + if constexpr (ax == axis::COL) row_min(dst, src, src_accum); + else col_min(dst, src, src_accum); +} +template +__device__ static inline auto min(const T &src, const V &src_accum) { + V dst; + if constexpr (ax == axis::COL) row_min(dst, src, src_accum); + else col_min(dst, src, src_accum); + return dst; +} +template +__device__ static inline void min(V &dst, const T &src) { + if constexpr (ax == axis::COL) row_min(dst, src); + else col_min(dst, src); +} +template +__device__ static inline auto min(const T &src) { + using V = std::conditional_t; + V dst; + if constexpr (ax == axis::COL) row_min(dst, src); + else col_min(dst, src); + return dst; +} + +template +__device__ static inline void sum(V &dst, const T &src, const V &src_accum) { + if constexpr (ax == axis::COL) row_sum(dst, src, src_accum); + else col_sum(dst, src, src_accum); +} +template +__device__ static inline auto sum(const T &src, const V &src_accum) { + V dst; + if constexpr (ax == axis::COL) row_sum(dst, src, src_accum); + else col_sum(dst, src, src_accum); + return dst; +} +template +__device__ static inline void sum(V &dst, const T &src) { + if constexpr (ax == axis::COL) row_sum(dst, src); + else col_sum(dst, src); +} +template +__device__ static inline auto sum(const T &src) { + using V = std::conditional_t; + V dst; + if constexpr (ax == axis::COL) row_sum(dst, src); + else col_sum(dst, src); + return dst; +} + +template +__device__ static inline void prod(V &dst, const T &src, const V &src_accum) { + if constexpr (ax == axis::COL) row_prod(dst, src, src_accum); + else col_prod(dst, src, src_accum); +} +template +__device__ static inline auto prod(const T &src, const V &src_accum) { + V dst; + if constexpr (ax == axis::COL) row_prod(dst, src, src_accum); + else col_prod(dst, src, src_accum); + return dst; +} +template +__device__ static inline void prod(V &dst, const T &src) { + if constexpr (ax == axis::COL) row_prod(dst, src); + else col_prod(dst, src); +} +template +__device__ static inline auto prod(const T &src) { + using V = std::conditional_t; + V dst; + if constexpr (ax == axis::COL) row_prod(dst, src); + else col_prod(dst, src); + return dst; +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/tile/tile.cuh b/extra/thunder/cuda/include/ops/group/shared/tile/tile.cuh new file mode 100644 index 0000000000..e1bb87bba4 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/tile/tile.cuh @@ -0,0 +1,37 @@ +/** + * @file + * @brief An aggregate header for group 
operations on shared tiles. + */ + +#include "conversions.cuh" +#include "maps.cuh" +#include "reductions.cuh" + +template +__device__ static inline bool hasnan(const ST &src) { + KITTENS_CHECK_WARP + bool nan_detected = false; + #pragma unroll + for(int i = laneid(); i < ST::num_elements; i+=GROUP_THREADS) { + if constexpr (std::is_same_v) { + if(isnan(src[i])) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__bfloat162float(src[i]))) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__half2float(src[i]))) { + nan_detected = true; + } + } + else { + static_assert(sizeof(typename ST::T) == 999, "Unsupported dtype"); + } + } + // Ballot across the warp to see if any lane detected a nan + return (__ballot_sync(0xffffffff, nan_detected) != 0); +} diff --git a/extra/thunder/cuda/include/ops/group/shared/vec/conversions.cuh b/extra/thunder/cuda/include/ops/group/shared/vec/conversions.cuh new file mode 100644 index 0000000000..4d4c7d3635 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/vec/conversions.cuh @@ -0,0 +1,27 @@ +/** + * @file + * @brief Group conversions on shared vectors. + */ + +/** + * @brief Copies data from one shared vector to another, converting data types if necessary. + * + * This function copies data from the source shared vector `src` to the destination shared vector `dst`. + * If the data types of `src` and `dst` are the same, it performs a direct memory copy. Otherwise, it + * converts each element from the source data type to the destination data type using the appropriate + * converter before copying. + * + * @tparam SV1 The type of the destination shared vector, must satisfy the ducks::sv::all concept. + * @tparam SV2 The type of the source shared vector, must satisfy the ducks::sv::all concept. + * @param[out] dst The destination shared vector. + * @param[in] src The source shared vector. + * @note The lengths of `src` and `dst` must be equal. This is enforced at compile time. + */ +template +__device__ static inline void copy(SV1 &dst, const SV2 &src) { + static_assert(SV1::length == SV2::length, "Source and destination vectors must have the same length."); + #pragma unroll + for(int i = laneid(); i < dst.length; i+=GROUP_THREADS) { + dst[i] = base_types::convertor::convert(src[i]); + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/vec/maps.cuh b/extra/thunder/cuda/include/ops/group/shared/vec/maps.cuh new file mode 100644 index 0000000000..987a2cb00b --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/vec/maps.cuh @@ -0,0 +1,259 @@ +/** + * @file + * @brief Group maps on shared vectors. + */ + +/** + * @brief Applies a unary operation to each element of a shared memory vector. + * + * @tparam op Unary operation type. + * @tparam T Shared memory vector type. + * @param dst[out] Destination vector in which to store the result. + * @param src[in] Source vector to apply the unary operation. + */ +template +__device__ static inline void unary_op(T &dst, const T &src) { + #pragma unroll + for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) { + dst[cur] = op::template op(src[cur]); + } +} +/** + * @brief Perform a binary operation on two shared vectors. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vectors. + * @param dst[out] The destination vector where the result is stored. + * @param lhs[in] The left-hand side vector for the operation. 
+ * @param rhs[in] The right-hand side vector for the operation. + */ +template +__device__ static inline void bin_op(T &dst, const T &lhs, const T &rhs) { + #pragma unroll + for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) { + dst[cur] = op::template op(lhs[cur], rhs[cur]); + } +} +/** + * @brief Perform a binary operation on a shared vector and a scalar. + * + * @tparam op The binary operation to perform. + * @tparam T The type of the vector. + * @param dst[out] The destination vector where the result is stored. + * @param src[in] The source vector for the operation. + * @param param[in] The scalar parameter for the operation. + */ +template +__device__ static inline void bin_op(T &dst, const T &src, const typename T::dtype ¶m) { + #pragma unroll + for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) { + dst[cur] = op::template op(src[cur], param); + } +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// ---- const ops ---- + +/** + * @brief Sets all elements of a shared memory vector to zero. + * + * @tparam T Shared memory vector type. + * @param dst[out] Destination vector to be set to zero. + */ +template +__device__ static inline void zero(T &dst) { + unary_op(dst, dst); +} +/** + * @brief Sets all elements of a shared memory vector to one. + * + * @tparam T Shared memory vector type. + * @param dst[out] Destination vector to be set to one. + */ +template +__device__ static inline void one(T &dst) { + unary_op(dst, dst); +} +/** + * @brief Sets all elements of a shared memory vector to positive infinity. + * + * @tparam T Shared memory vector type. + * @param dst[out] Destination vector to be set to positive infinity. + */ +template +__device__ static inline void pos_infty(T &dst) { + unary_op(dst, dst); +} +/** + * @brief Sets all elements of a shared memory vector to negative infinity. + * + * @tparam T Shared memory vector type. + * @param dst[out] Destination vector to be set to negative infinity. + */ +template +__device__ static inline void neg_infty(T &dst) { + unary_op(dst, dst); +} + +// ---- unary ops ---- + +/** + * @brief Copies the elements from one shared vector to another. + * + * @tparam T Shared vector type. + * @tparam U Type of the source vector. + * @param dst[out] Destination vector where the elements will be copied to. + * @param src[in] Source vector to copy the elements from. + */ +template +__device__ static inline void copy(T &dst, const U &src) { + bin_op(dst, dst, src); // the second arg is ignored here. +} +/** + * @brief Applies the exponential function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the exponential values will be stored. + * @param src[in] Source vector to apply the exponential function to. + */ +template +__device__ static inline void exp(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the exponential function element-wise to a shared vector, in base 2. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the exponential values will be stored. + * @param src[in] Source vector to apply the exponential function to. + */ +template +__device__ static inline void exp2(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the natural logarithm function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the logarithm values will be stored. 
+ * @param src[in] Source vector to apply the logarithm function to. + */ +template +__device__ static inline void log(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the logarithm base 2 function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the logarithm base 2 values will be stored. + * @param src[in] Source vector to apply the logarithm base 2 function to. + */ +template +__device__ static inline void log2(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the absolute value function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the absolute values will be stored. + * @param src[in] Source vector to apply the absolute value function to. + */ +template +__device__ static inline void abs(T &dst, const T &src) { + unary_op(dst, src); +} +/** + * @brief Applies the rectified linear unit (ReLU) function element-wise to a shared vector. + * + * @tparam T Shared vector type. + * @param dst[out] Destination vector where the ReLU values will be stored. + * @param src[in] Source vector to apply the ReLU function to. + */ +template +__device__ static inline void relu(T &dst, const T &src) { + unary_op(dst, src); +} + +// ---- binary ops ---- + +/** + * @brief Computes the element-wise maximum of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the maximum values will be stored. + * @param lhs[in] First vector for the maximum operation. + * @param rhs[in] Second vector for the maximum operation. + */ +template +__device__ static inline void max(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise minimum of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the minimum values will be stored. + * @param lhs[in] First vector for the minimum operation. + * @param rhs[in] Second vector for the minimum operation. + */ +template +__device__ static inline void min(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise sum of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the sum values will be stored. + * @param lhs[in] First vector for the sum operation. + * @param rhs[in] Second vector for the sum operation. + */ +template +__device__ static inline void add(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise difference of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the difference values will be stored. + * @param lhs[in] First vector for the difference operation. + * @param rhs[in] Second vector for the difference operation. + */ +template +__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise product of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the product values will be stored. + * @param lhs[in] First vector for the product operation. 
+ * @param rhs[in] Second vector for the product operation. + */ +template +__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} +/** + * @brief Computes the element-wise division of two shared vectors. + * + * @tparam T Shared vector type. + * @tparam U Type of the second vector. + * @param dst[out] Destination vector where the division values will be stored. + * @param lhs[in] First vector for the division operation. + * @param rhs[in] Second vector for the division operation. + */ +template +__device__ static inline void div(T &dst, const T &lhs, const U &rhs) { + bin_op(dst, lhs, rhs); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/vec/reductions.cuh b/extra/thunder/cuda/include/ops/group/shared/vec/reductions.cuh new file mode 100644 index 0000000000..8cf858e8dc --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/vec/reductions.cuh @@ -0,0 +1,193 @@ +/** + * @file + * @brief Group reductions on shared vectors. + */ + +// The fastest way to do this, under most circumstances, is actually to just have each warp replicate it. +// This is not true for enormous shared vectors, but doing that efficiently actually requires some extra scratch shared memory. +// So, this is sufficient for the time being. +template +__device__ static inline void reduce(typename SV::dtype &dst_accum, const SV &src, const typename SV::dtype &src_accum) { + if constexpr (GROUP_WARPS == 1) { + using T = SV::dtype; + int lane = laneid(); + T accum; + if(lane < src.length) accum = src[lane]; // initialize a register accumulator + __syncwarp(); + for(int i = lane+kittens::WARP_THREADS; i < src.length; i+=kittens::WARP_THREADS) { + accum = op::template op(accum, src[i]); + } + __syncwarp(); + // We can now reduce within the warp. + if constexpr (src.length > 16) { + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16)); + __syncwarp(); + } + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8)); + __syncwarp(); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4)); + __syncwarp(); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2)); + __syncwarp(); + accum = op::template op(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1)); + __syncwarp(); + if constexpr (!reset) accum = op::template op(accum, src_accum); + // broadcast to all threads in the warp. + dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0); // everyone takes from warp leader + } + else { + ::kittens::group<1>::reduce(dst_accum, src, src_accum); + } +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +/** + * @brief Finds the maximum element in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] max_val The maximum value found in the vector. + * @param[in] src The shared memory vector to find the maximum in. + */ +template +__device__ static inline void max(typename SV::dtype &max_val, const SV &src) { + reduce(max_val, src, max_val); +} +template +__device__ static inline typename SV::dtype max(const SV &src) { + typename SV::dtype max_val; + reduce(max_val, src, max_val); + return max_val; +} + +/** + * @brief Finds the minimum element in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] min_val The minimum value found in the vector. 
+ * @param[in] src The shared memory vector to find the minimum in. + */ +template +__device__ static inline void min(typename SV::dtype &min_val, const SV &src) { + reduce(min_val, src, min_val); +} +template +__device__ static inline typename SV::dtype min(const SV &src) { + typename SV::dtype min_val; + reduce(min_val, src, min_val); + return min_val; +} + +/** + * @brief Calculates the sum of elements in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] sum_val The sum of the values in the vector. + * @param[in] src The shared memory vector to sum. + */ +template +__device__ static inline void sum(typename SV::dtype &sum_val, const SV &src) { + reduce(sum_val, src, sum_val); +} +template +__device__ static inline typename SV::dtype sum(const SV &src) { + typename SV::dtype sum_val; + reduce(sum_val, src, sum_val); + return sum_val; +} + +/** + * @brief Calculates the product of elements in a shared memory vector. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] prod_val The product of the values in the vector. + * @param[in] src The shared memory vector to multiply. + */ +template +__device__ static inline void prod(typename SV::dtype &prod_val, const SV &src) { + reduce(prod_val, src, prod_val); +} +template +__device__ static inline typename SV::dtype prod(const SV &src) { + typename SV::dtype prod_val; + reduce(prod_val, src, prod_val); + return prod_val; +} + +// Three operand versions. + +/** + * @brief Finds the maximum element in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] max_val The maximum value found in the vector, accumulated with src_accum. + * @param[in] src The shared memory vector to find the maximum in. + * @param[in] src_accum The initial value to accumulate with the maximum value found. + */ +template +__device__ static inline void max(typename SV::dtype &max_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(max_val, src, src_accum); +} +template +__device__ static inline typename SV::dtype max(const SV &src, const typename SV::dtype &src_accum) { + typename SV::dtype max_val; + reduce(max_val, src, src_accum); + return max_val; +} + +/** + * @brief Finds the minimum element in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] min_val The minimum value found in the vector, accumulated with src_accum. + * @param[in] src The shared memory vector to find the minimum in. + * @param[in] src_accum The initial value to accumulate with the minimum value found. + */ +template +__device__ static inline void min(typename SV::dtype &min_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(min_val, src, src_accum); +} +template +__device__ static inline typename SV::dtype min(const SV &src, const typename SV::dtype &src_accum) { + typename SV::dtype min_val; + reduce(min_val, src, src_accum); + return min_val; +} + +/** + * @brief Calculates the sum of elements in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] sum_val The sum of the values in the vector, accumulated with src_accum. 
+ * @param[in] src The shared memory vector to sum. + * @param[in] src_accum The initial value to accumulate with the sum of the vector. + */ +template +__device__ static inline void sum(typename SV::dtype &sum_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(sum_val, src, src_accum); +} +template +__device__ static inline typename SV::dtype sum(const SV &src, const typename SV::dtype &src_accum) { + typename SV::dtype sum_val; + reduce(sum_val, src, src_accum); + return sum_val; +} + +/** + * @brief Calculates the product of elements in a shared memory vector and accumulates it with src_accum. + * + * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept. + * @param[out] prod_val The product of the values in the vector, accumulated with src_accum. + * @param[in] src The shared memory vector to multiply. + * @param[in] src_accum The initial value to accumulate with the product of the vector. + */ +template +__device__ static inline void prod(typename SV::dtype &prod_val, const SV &src, const typename SV::dtype &src_accum) { + reduce(prod_val, src, src_accum); +} +template +__device__ static inline typename SV::dtype prod(const SV &src, const typename SV::dtype &src_accum) { + typename SV::dtype prod_val; + reduce(prod_val, src, src_accum); + return prod_val; +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/group/shared/vec/vec.cuh b/extra/thunder/cuda/include/ops/group/shared/vec/vec.cuh new file mode 100644 index 0000000000..883ad52700 --- /dev/null +++ b/extra/thunder/cuda/include/ops/group/shared/vec/vec.cuh @@ -0,0 +1,38 @@ +/** + * @file + * @brief An aggregate header for group operations on shared vectors. + */ + +#include "conversions.cuh" +#include "maps.cuh" +// no group vector reductions as they would require additional shared memory and synchronization, and those side effects just aren't worth it. +// warp vector reductions should be plenty fast in 99.9% of situations. + +template +__device__ static inline bool hasnan(const SV &src) { + KITTENS_CHECK_WARP + bool nan_detected = false; + #pragma unroll + for(int i = laneid(); i < SV::length; i+=GROUP_THREADS) { + if constexpr (std::is_same_v) { + if(isnan(src[i])) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__bfloat162float(src[i]))) { + nan_detected = true; + } + } + else if constexpr (std::is_same_v) { + if(isnan(__half2float(src[i]))) { + nan_detected = true; + } + } + else { + static_assert(sizeof(typename SV::T) == 999, "Unsupported dtype"); + } + } + // Ballot across the warp to see if any lane detected a nan + return (__ballot_sync(0xffffffff, nan_detected) != 0); +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/ops.cuh b/extra/thunder/cuda/include/ops/ops.cuh new file mode 100644 index 0000000000..dfc075dc9a --- /dev/null +++ b/extra/thunder/cuda/include/ops/ops.cuh @@ -0,0 +1,262 @@ +/** + * @file + * @brief A collection of all of the operations that ThunderKittens defines. + */ + +#pragma once + +#include "thread/thread.cuh" +#include "group/group.cuh" +#include "device/device.cuh" + +namespace kittens { + +// Operator overloading, which defaults to warp scope. 
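+//
+// Illustrative sketch (not part of the library): with the overloads below, element-wise
+// tile and vector math can be written infix. The type spelling rt_fl<16, 64> and the
+// warp-scope initializer one() are assumptions about the surrounding ThunderKittens
+// types, used here only to show which warp-scope calls the operators expand to.
+//
+//     kittens::rt_fl<16, 64> a, b;              // per-warp register tiles (assumed spelling)
+//     kittens::warp::one(a);
+//     kittens::warp::one(b);
+//     auto c = a * b + a;                       // expands to warp::mul then warp::add
+//     c /= b;                                   // expands to warp::div in place
+//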
+ +// Tile operators + +template +__device__ static inline T operator+(const T &lhs, const U &rhs) { + T dst; + warp::add(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator+=(T &lhs, const U &rhs) { + warp::add(lhs, lhs, rhs); +} +template +__device__ static inline T operator-(const T &lhs, const U &rhs) { + T dst; + warp::sub(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator-=(T &lhs, const U &rhs) { + warp::sub(lhs, lhs, rhs); +} +template +__device__ static inline T operator*(const T &lhs, const U &rhs) { + T dst; + warp::mul(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator*=(T &lhs, const U &rhs) { + warp::mul(lhs, lhs, rhs); +} +template +__device__ static inline T operator/(const T &lhs, const U &rhs) { + T dst; + warp::div(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator/=(T &lhs, const U &rhs) { + warp::div(lhs, lhs, rhs); +} +template +__device__ static inline T operator+(const T &src, const V &row_values) { + T dst; + warp::add_row(dst, src, row_values); + return dst; +} +template +__device__ static inline T operator+(const T &src, const V &row_values) { + T dst; + warp::add_row(dst, src, row_values); + return dst; +} +template +__device__ static inline void operator+=(T &lhs, const V &row_values) { + warp::add_row(lhs, lhs, row_values); +} +template +__device__ static inline void operator+=(T &lhs, const V &row_values) { + warp::add_row(lhs, lhs, row_values); +} +template +__device__ static inline T operator-(const T &src, const V &row_values) { + T dst; + warp::sub_row(dst, src, row_values); + return dst; +} +template +__device__ static inline T operator-(const T &src, const V &row_values) { + T dst; + warp::sub_row(dst, src, row_values); + return dst; +} +template +__device__ static inline void operator-=(T &lhs, const V &row_values) { + warp::sub_row(lhs, lhs, row_values); +} +template +__device__ static inline void operator-=(T &lhs, const V &row_values) { + warp::sub_row(lhs, lhs, row_values); +} +template +__device__ static inline T operator*(const T &src, const V &row_values) { + T dst; + warp::mul_row(dst, src, row_values); + return dst; +} +template +__device__ static inline T operator*(const T &src, const V &row_values) { + T dst; + warp::mul_row(dst, src, row_values); + return dst; +} +template +__device__ static inline void operator*=(T &lhs, const V &row_values) { + warp::mul_row(lhs, lhs, row_values); +} +template +__device__ static inline void operator*=(T &lhs, const V &row_values) { + warp::mul_row(lhs, lhs, row_values); +} +template +__device__ static inline T operator/(const T &src, const V &row_values) { + T dst; + warp::div_row(dst, src, row_values); + return dst; +} +template +__device__ static inline T operator/(const T &src, const V &row_values) { + T dst; + warp::div_row(dst, src, row_values); + return dst; +} +template +__device__ static inline void operator/=(T &lhs, const V &row_values) { + warp::div_row(lhs, lhs, row_values); +} +template +__device__ static inline void operator/=(T &lhs, const V &row_values) { + warp::div_row(lhs, lhs, row_values); +} +template +__device__ static inline T operator+(const T &src, const V &col_values) { + T dst; + warp::add_col(dst, src, col_values); + return dst; +} +template +__device__ static inline T operator+(const T &src, const V &col_values) { + T dst; + warp::add_col(dst, src, col_values); + return dst; +} +template +__device__ static inline void operator+=(T &lhs, const V 
&col_values) { + warp::add_col(lhs, lhs, col_values); +} +template +__device__ static inline void operator+=(T &lhs, const V &col_values) { + warp::add_col(lhs, lhs, col_values); +} +template +__device__ static inline T operator-(const T &src, const V &col_values) { + T dst; + warp::sub_col(dst, src, col_values); + return dst; +} +template +__device__ static inline T operator-(const T &src, const V &col_values) { + T dst; + warp::sub_col(dst, src, col_values); + return dst; +} +template +__device__ static inline void operator-=(T &lhs, const V &col_values) { + warp::sub_col(lhs, lhs, col_values); +} +template +__device__ static inline void operator-=(T &lhs, const V &col_values) { + warp::sub_col(lhs, lhs, col_values); +} +template +__device__ static inline T operator*(const T &src, const V &col_values) { + T dst; + warp::mul_col(dst, src, col_values); + return dst; +} +template +__device__ static inline T operator*(const T &src, const V &col_values) { + T dst; + warp::mul_col(dst, src, col_values); + return dst; +} +template +__device__ static inline void operator*=(T &lhs, const V &col_values) { + warp::mul_col(lhs, lhs, col_values); +} +template +__device__ static inline void operator*=(T &lhs, const V &col_values) { + warp::mul_col(lhs, lhs, col_values); +} +template +__device__ static inline T operator/(const T &src, const V &col_values) { + T dst; + warp::div_col(dst, src, col_values); + return dst; +} +template +__device__ static inline T operator/(const T &src, const V &col_values) { + T dst; + warp::div_col(dst, src, col_values); + return dst; +} +template +__device__ static inline void operator/=(T &lhs, const V &col_values) { + warp::div_col(lhs, lhs, col_values); +} +template +__device__ static inline void operator/=(T &lhs, const V &col_values) { + warp::div_col(lhs, lhs, col_values); +} + +// Vector operators + +template +__device__ static inline T operator+(const T &lhs, const U &rhs) { + T dst; + warp::add(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator+=(T &lhs, const U &rhs) { + warp::add(lhs, lhs, rhs); +} +template +__device__ static inline T operator-(const T &lhs, const U &rhs) { + T dst; + warp::sub(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator-=(T &lhs, const U &rhs) { + warp::sub(lhs, lhs, rhs); +} +template +__device__ static inline T operator*(const T &lhs, const U &rhs) { + T dst; + warp::mul(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator*=(T &lhs, const U &rhs) { + warp::mul(lhs, lhs, rhs); +} +template +__device__ static inline T operator/(const T &lhs, const U &rhs) { + T dst; + warp::div(dst, lhs, rhs); + return dst; +} +template +__device__ static inline void operator/=(T &lhs, const U &rhs) { + warp::div(lhs, lhs, rhs); +} + +} diff --git a/extra/thunder/cuda/include/ops/thread/memory/memory.cuh b/extra/thunder/cuda/include/ops/thread/memory/memory.cuh new file mode 100644 index 0000000000..dc151ce49f --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/memory.cuh @@ -0,0 +1,10 @@ +/** + * @file + * @brief An aggregate header of warp memory operations, where a single warp loads or stores data on its own. 
+ */ + +#pragma once + +#include "util/util.cuh" +#include "tile/tile.cuh" +#include "vec/vec.cuh" \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/memory/tile/tile.cuh b/extra/thunder/cuda/include/ops/thread/memory/tile/tile.cuh new file mode 100644 index 0000000000..f2bbdcc9be --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/tile/tile.cuh @@ -0,0 +1,10 @@ +/** + * @file + * @brief An aggregate header of warp memory operations on tiles, where a single warp loads or stores data on its own. + */ + +#pragma once + +#ifdef KITTENS_HOPPER +#include "tma.cuh" +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/memory/tile/tma.cuh b/extra/thunder/cuda/include/ops/thread/memory/tile/tma.cuh new file mode 100644 index 0000000000..3b1d543771 --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/tile/tma.cuh @@ -0,0 +1,564 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" +#include "../util/util.cuh" + +#include +#include + +namespace kittens { +namespace tma { + +namespace detail { +template __device__ inline int4 tma_coords(const coord &unit_coord) { + constexpr int swizzle_elements = ST::swizzle_bytes / sizeof(typename ST::dtype); + if constexpr (axis == 2) return {unit_coord.r, unit_coord.c / swizzle_elements, unit_coord.d, unit_coord.b}; + else if constexpr (axis == 1) return {unit_coord.d, unit_coord.c / swizzle_elements, unit_coord.r, unit_coord.b}; + else if constexpr (axis == 0) return {unit_coord.b, unit_coord.c / swizzle_elements, unit_coord.r, unit_coord.d}; +} +} + +/* ---------- Prefetch Tensor Map ---------- */ + +/** + * @brief Prefetches data from global memory into a shared memory tile, along with the tensormap. + * + * @tparam ST A shared tile type with a TMA-compatible layout + * @param[out] dst The destination shared memory tile. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in] tile_row_idx The row coord of the requested tile. This is in units of complete tiles. + * @param[in] tile_col_idx The column coord of the requested tile. This is in units of complete tiles. + */ +template> +__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) { + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global.tile" + " [%0, {%1, %2, %3, %4, %5}];" + : + : "l"(tma_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global.tile.L2::cache_hint" + " [%0, {%1, %2, %3, %4, %5}], %6;" + : + : "l"(tma_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } +} +template> +__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) { + prefetch(dst, src, idx); +} + +/* ---------- Async load and store data from gmem/smem ---------- */ + +/** + * @brief Asynchronously stores data into global memory from a shared memory tile. + * + * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction. 
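+ *
+ * Rough calling-pattern sketch (hedged: the coordinate form and the wait helper named
+ * below are assumptions about the surrounding library, not definitions made here):
+ *
+ * @code
+ * // given a global layout `g` and a shared tile `s` of matching dtype and shape:
+ * kittens::tma::store_async(g, s, {b, d, r, c});   // issue the bulk-tensor store
+ * kittens::tma::store_async_wait();                // assumed: wait before reusing `s`
+ * @endcode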
+ * + * @tparam ST A shared tile type with a TMA-compatible layout + * @param[out] dst The destination tensormap address in global memory + * @param[in] src_tma_map The source shared memory tile. + * @param[in] tile_row_idx The row coord of the tile destination. This is in units of complete tiles. + * @param[in] tile_col_idx The column coord of the tile destination. This is in units of complete tiles. + */ +template> +__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) { + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) { + store_async(dst, src, idx); +} +template> +__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) { + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) { + store_async(dst, src, idx); +} + +/* ---------- Async reduction + store data from gmem/smem ---------- */ + +/** + * @brief Asynchronously performs an add reduction and stores the result into global memory from a shared memory tile. + * + * This function performs an asynchronous add reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction. + * + * @tparam ST A shared tile type with a TMA-compatible layout + * @param[out] dst The destination tensormap address in global memory + * @param[in] src_tma_map The source shared memory tile. + * @param[in] tile_row_idx The row coord of the tile destination. This is in units of complete tiles. 
+ * @param[in] tile_col_idx The column coord of the tile destination. This is in units of complete tiles. + */ +template> +__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) { + + static_assert(!(std::is_same_v || + std::is_same_v), + "TMA does not support async add reductions for fp8 types."); + + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) { + store_add_async(dst, src, idx); +} +template> +__device__ static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) { + + static_assert(!(std::is_same_v || + std::is_same_v), + "TMA does not support async add reductions for fp8 types."); + + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) { + store_add_async(dst, src, idx); +} + +/** + * @brief Asynchronously performs an min reduction and stores the result into global memory from a shared memory tile. + * + * This function performs an asynchronous min reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction. + * + * @tparam ST A shared tile type with a TMA-compatible layout + * @param[out] dst The destination tensormap address in global memory + * @param[in] src_tma_map The source shared memory tile. + * @param[in] tile_row_idx The row coord of the tile destination. This is in units of complete tiles. + * @param[in] tile_col_idx The column coord of the tile destination. This is in units of complete tiles. 
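+ *
+ * Hedged sketch of the calling pattern (type spellings and the wait helper are assumptions;
+ * note that fp32 tiles are rejected by the static_assert in the function body below):
+ *
+ * @code
+ * // running element-wise minimum of shared tile `s` (bf16 assumed) into global layout `g`:
+ * kittens::tma::store_min_async(g, s, {b, d, r, c});
+ * kittens::tma::store_async_wait();                // assumed: wait before reusing `s`
+ * @endcode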
+ */ +template> +__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + + static_assert(!(std::is_same_v || + std::is_same_v), + "TMA does not support async add reductions for fp8 types."); + + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) { + store_min_async(dst, src, idx); +} +template> +__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + + static_assert(!(std::is_same_v || + std::is_same_v), + "TMA does not support async add reductions for fp8 types."); + + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) { + store_min_async(dst, src, idx); +} + +/** + * @brief Asynchronously performs an max reduction and stores the result into global memory from a shared memory tile. + * + * This function performs an asynchronous max reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction. + * + * @tparam ST A shared tile type with a TMA-compatible layout + * @param[out] dst The destination tensormap address in global memory + * @param[in] src_tma_map The source shared memory tile. + * @param[in] tile_row_idx The row coord of the tile destination. This is in units of complete tiles. 
+ * @param[in] tile_col_idx The column coord of the tile destination. This is in units of complete tiles. + */ +template> +__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + + static_assert(!(std::is_same_v || + std::is_same_v), + "TMA does not support async add reductions for fp8 types."); + + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) { + store_max_async(dst, src, idx); +} +template> +__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + + static_assert(!(std::is_same_v || + std::is_same_v), + "TMA does not support async add reductions for fp8 types."); + + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5, %6}], [%1], %7;" + : + : "l"(tma_ptr), "r"(src_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } + store_commit_group(); +} +template> +__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) { + store_max_async(dst, src, idx); +} + +/** + * @brief Asynchronously loads data from global memory into a shared memory tile. + * + * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction. + * + * @tparam ST A shared tile type with a TMA-compatible layout + * @param[out] dst The destination shared memory tile. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in,out] bar The semaphore used for synchronization of the asynchronous copy. 
+ * @param[in] tile_row_idx The row coord of the requested tile. This is in units of complete tiles. + * @param[in] tile_col_idx The column coord of the requested tile. This is in units of complete tiles. + */ +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) { + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + + if constexpr (policy == cache_policy::NORMAL) { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2];" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w) + : "memory" + ); + } + else { + asm volatile( + "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy()) + : "memory" + ); + } +} +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) { + load_async(dst, src, idx, bar); +} + +namespace cluster { + +/** + * @brief Asynchronously loads data from global memory into a shared memory tile, across a threadblock cluster + * + * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction. + * + * @tparam ST A shared tile type with a TMA-compatible layout + * @param[out] dst The destination shared memory tile. + * @param[in] src_tma_map The source tensormap address in global memory + * @param[in,out] bar The semaphore used for synchronization of the asynchronous copy. + * @param[in] tile_row_idx The row coord of the requested tile. This is in units of complete tiles. + * @param[in] tile_col_idx The column coord of the requested tile. This is in units of complete tiles. + * @param[in] cluster_mask The mask of the clusters to broadcast to. 
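+ *
+ * Hedged sketch of the multicast calling pattern (the semaphore setup and wait steps are
+ * assumptions about helpers defined elsewhere in the library, indicated only in comments):
+ *
+ * @code
+ * __shared__ kittens::semaphore bar;
+ * // ... a single thread initializes `bar` and registers the expected transaction bytes ...
+ * uint16_t mask = 0b11;                            // broadcast to CTAs 0 and 1 of the cluster
+ * kittens::tma::cluster::load_async(s, g, {b, d, r, c}, bar, mask);
+ * // ... wait on `bar` before any thread reads from `s` ...
+ * @endcode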
+ */ +#ifdef KITTENS_BLACKWELL +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) +#else +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask) +#endif +{ + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst)); + coord unit_coord = idx.template unit_coord(); // convert to unit coordinates + int4 tma_coords = detail::tma_coords(unit_coord); + +#ifdef KITTENS_BLACKWELL + if(dst_mbar_cta != -1) { + uint32_t neighbor_mbar_ptr; + asm volatile ( + "mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_mbar_ptr) + : "r"(mbar_ptr), "r"(dst_mbar_cta) + ); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8, %9;" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask), "l"(make_cache_policy()) + : "memory" + ); + } + } else +#endif + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8, %9;" + : + : "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask), "l"(make_cache_policy()) + : "memory" + ); + } +} +#ifdef KITTENS_BLACKWELL +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { + load_async(dst, src, idx, bar, cluster_mask, dst_mbar_cta); +} +#else +template> +__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask) { + load_async(dst, src, idx, bar, cluster_mask); +} +#endif + +} // namespace cluster +} // namespace tma + +} // namespace kittens diff --git a/extra/thunder/cuda/include/ops/thread/memory/util/multimem.cuh b/extra/thunder/cuda/include/ops/thread/memory/util/multimem.cuh new file mode 100644 index 0000000000..e308428dc5 --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/util/multimem.cuh @@ -0,0 +1,405 @@ +/** + * @file + * @brief Wrappers for multimem operations + */ + +#pragma once + +namespace kittens { + +enum class reduce_op { + ADD = 0, + MIN = 1, + MAX = 2 +}; + +enum class memory_model { + WEAK 
= 0, + STRONG = 1 +}; + +template +struct multimem; + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(int &dst, const int *src) { + if constexpr (Op == reduce_op::ADD) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.s32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.s32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MIN) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.min.s32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.min.s32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MAX) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.max.s32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.max.s32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(int *dst, const int &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.s32 [%0], %1;" + :: "l"(dst), "r"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.s32 [%0], %1;" + :: "l"(dst), "r"(src) : "memory"); + } + } + template + __device__ static inline void red(int *dst, const int &src) { + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.s32 [%0], %1;" + : : "l"(dst), "r"(src) : "memory"); + } else if constexpr (Op == reduce_op::MIN) { + asm volatile("multimem.red.release.sys.global.min.s32 [%0], %1;" + : : "l"(dst), "r"(src) : "memory"); + } else if constexpr (Op == reduce_op::MAX) { + asm volatile("multimem.red.release.sys.global.max.s32 [%0], %1;" + : : "l"(dst), "r"(src) : "memory"); + } + } +}; + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(uint &dst, const uint *src) { + if constexpr (Op == reduce_op::ADD) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.u32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.u32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MIN) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.min.u32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.min.u32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MAX) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.max.u32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.max.u32 %0, [%1];" + : "=r"(dst) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(uint *dst, const uint &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.u32 [%0], %1;" + :: "l"(dst), "r"(src) 
: "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.u32 [%0], %1;" + :: "l"(dst), "r"(src) : "memory"); + } + } + template + __device__ static inline void red(uint *dst, const uint &src) { + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;" + : : "l"(dst), "r"(src) : "memory"); + } else if constexpr (Op == reduce_op::MIN) { + asm volatile("multimem.red.release.sys.global.min.u32 [%0], %1;" + : : "l"(dst), "r"(src) : "memory"); + } else if constexpr (Op == reduce_op::MAX) { + asm volatile("multimem.red.release.sys.global.max.u32 [%0], %1;" + : : "l"(dst), "r"(src) : "memory"); + } + } +}; + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(float &dst, const float *src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 ld_reduce operations"); + if constexpr (Op == reduce_op::ADD) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.f32 %0, [%1];" + : "=f"(dst) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.f32 %0, [%1];" + : "=f"(dst) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(float *dst, const float &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.f32 [%0], %1;" + :: "l"(dst), "f"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.f32 [%0], %1;" + :: "l"(dst), "f"(src) : "memory"); + } + } + template + __device__ static inline void red(float *dst, const float &src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 red operations"); + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.f32 [%0], %1;" + : : "l"(dst), "f"(src) : "memory"); + } + } +}; + + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(float2 &dst, const float2 *src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 ld_reduce operations"); + if constexpr (Op == reduce_op::ADD) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.v2.f32 {%0, %1}, [%2];" + : "=f"(dst.x), "=f"(dst.y) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.v2.f32 {%0, %1}, [%2];" + : "=f"(dst.x), "=f"(dst.y) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(float2 *dst, const float2 &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.v2.f32 [%0], {%1, %2};" + :: "l"(dst), "f"(src.x), "f"(src.y) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.v2.f32 [%0], {%1, %2};" + :: "l"(dst), "f"(src.x), "f"(src.y) : "memory"); + } + } + template + __device__ static inline void red(float2 *dst, const float2 &src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 red operations"); + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.v2.f32 [%0], {%1, %2};" + : : "l"(dst), "f"(src.x), "f"(src.y) : "memory"); + } + } +}; + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(bf16 &dst, const bf16 *src) { + if constexpr (Op == reduce_op::ADD) { + if 
constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.bf16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.bf16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MIN) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.min.bf16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.min.bf16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MAX) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.max.bf16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.max.bf16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(bf16 *dst, const bf16 &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.bf16 [%0], %1;" + :: "l"(dst), "h"(*reinterpret_cast(&src)) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.bf16 [%0], %1;" + :: "l"(dst), "h"(*reinterpret_cast(&src)) : "memory"); + } + } + template + __device__ static inline void red(bf16 *dst, const bf16 &src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for bf16 red operations"); + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.bf16 [%0], %1;" + : : "l"(dst), "h"(*reinterpret_cast(&src)) : "memory"); + } + } +}; + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(bf16_2 &dst, const bf16_2 *src) { + if constexpr (Op == reduce_op::ADD) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.bf16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.bf16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MIN) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.min.bf16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.min.bf16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MAX) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.max.bf16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.max.bf16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(bf16_2 *dst, const bf16_2 &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.bf16x2 [%0], %1;" + :: "l"(dst), "r"(*reinterpret_cast(&src)) : "memory"); + } else if 
constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.bf16x2 [%0], %1;" + :: "l"(dst), "r"(*reinterpret_cast(&src)) : "memory"); + } + } + template + __device__ static inline void red(bf16_2 *dst, const bf16_2 &src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for bf16_2 red operations"); + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.bf16x2 [%0], %1;" + : : "l"(dst), "r"(*reinterpret_cast(&src)) : "memory"); + } + } +}; + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(half &dst, const half *src) { + if constexpr (Op == reduce_op::ADD) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.f16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.f16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MIN) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.min.f16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.min.f16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MAX) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.max.f16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.max.f16 %0, [%1];" + : "=h"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(half *dst, const half &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.f16 [%0], %1;" + :: "l"(dst), "h"(*reinterpret_cast(&src)) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.f16 [%0], %1;" + :: "l"(dst), "h"(*reinterpret_cast(&src)) : "memory"); + } + } + template + __device__ static inline void red(half *dst, const half &src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f16 red operations"); + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.f16 [%0], %1;" + : : "l"(dst), "h"(*reinterpret_cast(&src)) : "memory"); + } + } +}; + +template <> +struct multimem { + template + __device__ static inline void ld_reduce(half_2 &dst, const half_2 *src) { + if constexpr (Op == reduce_op::ADD) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.f16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.f16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } else if constexpr (Op == reduce_op::MIN) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.min.f16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.min.f16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : 
"memory"); + } + } else if constexpr (Op == reduce_op::MAX) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.ld_reduce.weak.global.max.f16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.ld_reduce.acquire.sys.global.max.f16x2 %0, [%1];" + : "=r"(*reinterpret_cast(&dst)) : "l"(src) : "memory"); + } + } + } + template + __device__ static inline void st(half_2 *dst, const half_2 &src) { + if constexpr (M == memory_model::WEAK) { + asm volatile("multimem.st.weak.global.f16x2 [%0], %1;" + :: "l"(dst), "r"(*reinterpret_cast(&src)) : "memory"); + } else if constexpr (M == memory_model::STRONG) { + asm volatile("multimem.st.release.sys.global.f16x2 [%0], %1;" + :: "l"(dst), "r"(*reinterpret_cast(&src)) : "memory"); + } + } + template + __device__ static inline void red(half_2 *dst, const half_2 &src) { + static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f16_2 red operations"); + if constexpr (Op == reduce_op::ADD) { + asm volatile("multimem.red.release.sys.global.add.f16x2 [%0], %1;" + : : "l"(dst), "r"(*reinterpret_cast(&src)) : "memory"); + } + } +}; + +} // namespace kittens diff --git a/extra/thunder/cuda/include/ops/thread/memory/util/tensor.cuh b/extra/thunder/cuda/include/ops/thread/memory/util/tensor.cuh new file mode 100644 index 0000000000..657f596302 --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/util/tensor.cuh @@ -0,0 +1,30 @@ +/** + * @file + * @brief Functions for transferring data directly between tensor memory and register memory. + */ + +#pragma once + +#include + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" +#include "util.cuh" + +namespace kittens { + +__device__ static inline void tensor_before_thread_sync() { + asm volatile("tcgen05.fence::before_thread_sync;\n"); +} +__device__ static inline void tensor_after_thread_sync() { + asm volatile("tcgen05.fence::after_thread_sync;\n"); +} + +__device__ inline static void tensor_load_wait() { + asm volatile("tcgen05.wait::ld.sync.aligned;"); +} +__device__ inline static void tensor_store_wait() { + asm volatile("tcgen05.wait::st.sync.aligned;"); +} + +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/memory/util/tma.cuh b/extra/thunder/cuda/include/ops/thread/memory/util/tma.cuh new file mode 100644 index 0000000000..82c6210bd8 --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/util/tma.cuh @@ -0,0 +1,249 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +#include +#include + +namespace kittens { +/** + * @brief A namespace for all of ThunderKittens' TMA functionality. +*/ +namespace tma { + +/* ---------- Barrier functions for async load ---------- */ + +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the semaphore for the first thread in the warp. +* It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly +* instruction to set the expected number of bytes. +* +* @param semaphore Reference to the semaphore variable. +* @param bytes The number of bytes expected at the semaphore. 
+*/ +__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n" + :: "r"(bar_ptr), "r"(bytes)); +} +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the mbarrier before the transaction arrives. +*/ +template +__device__ static inline void expect(semaphore& bar, const T& _1, const args&... _2) { + expect_bytes(bar, size_bytes); +} + +/* ---------- Synchronization functions for async store ---------- */ + +/** + * @brief Commits previous asynchronous TMA stores to a group and performs them. +*/ +__device__ static inline void store_commit_group() { + asm volatile("cp.async.bulk.commit_group;"); +} +/** + * @brief Waits for previous committed TMA store groups to complete. + * + * @tparam N The maximum number of remaining TMA store groups. Defaults to 0. +*/ +template +__device__ static inline void store_async_wait() { + asm volatile ( + "cp.async.bulk.wait_group %0;" + : + : "n"(N) + : "memory" + ); +} +/** + * @brief Waits for previous committed TMA store groups to finish reading from shared memory. + * + * @tparam N The maximum number of remaining TMA store groups. Defaults to 0. +*/ +template +__device__ static inline void store_async_read_wait() { + asm volatile ( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(N) + : "memory" + ); +} + +/* ---------- Cluster-scope operations ---------- */ + +namespace cluster { + +/** +* @brief Waits for the requested semaphore phase, at cluster scope +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline void wait(semaphore& bar, int kPhaseBit) { + void const* const ptr = &bar; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + +__device__ static inline void careful_wait(semaphore& bar, int kPhaseBit) { + void const* const ptr = &bar; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "{\n" + ".reg .b64 start_clock, current_clock;\n" + "mov.b64 start_clock, %clock64;\n" + ".reg .pred P_CLOCK;\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "mov.b64 current_clock, %clock64;\n" + "sub.u64 current_clock, current_clock, start_clock;\n" + "setp.ge.u64 P_CLOCK, current_clock, 1000000;\n" + "@P_CLOCK trap;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +} + +/** +* @brief Sets the number of bytes expected at the semaphore, assuming a multicast instruction. +* +* This function sets the number of bytes expected at the semaphore for the first thread in the warp. +* It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly +* instruction to set the expected number of bytes. +* +* It's worth being aware that this function is particularly necessary for multicast loads, and +* distributed shared memory can actually be done with a normal tma::expect followed by wait. See +* the unit tests of dsmem for an example. 
+* +* @param semaphore Reference to the semaphore variable. +* @param bytes The number of bytes expected at the semaphore. +*/ +__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes, int dst_cta) { + uint32_t mbar_addr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t neighbor_mbar_addr; + asm volatile ( + "mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_mbar_addr) + : "r"(mbar_addr), "r"(dst_cta) + ); + + asm volatile ("mbarrier.arrive.expect_tx.shared::cluster.b64 _, [%0], %1;\n" + :: "r"(neighbor_mbar_addr), "r"(bytes)); +} +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the semaphore for the first thread in the warp. +* It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly +* instruction to set the expected number of bytes. +* +* @tparam T The type of the data to be stored at the semaphore. +* @param semaphore Reference to the semaphore variable. +*/ +/** +* @brief Sets the number of bytes expected at the semaphore. +* +* This function sets the number of bytes expected at the mbarrier before the transaction arrives. +*/ +template +__device__ static inline void expect(semaphore& bar, int dst_cta, const T& _1, const args&... _2) { + expect_bytes(bar, size_bytes, dst_cta); +} + +/** +* @brief Arrives at a semaphore in cluster scope. +* +* Marks a thread arrival at an mbarrier +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline void arrive(semaphore& bar, int dst_cta, uint32_t count=1) { + uint32_t mbar_addr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t neighbor_mbar_addr; + asm volatile ( + "mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_mbar_addr) + : "r"(mbar_addr), "r"(dst_cta) + ); + asm volatile ( + "mbarrier.arrive.shared::cluster.b64 _, [%0], %1;\n" + : + : "r"(neighbor_mbar_addr), "r" (count) + : "memory" + ); +} + +// Generic transfer +__device__ static inline void store_async(void *dst, void *src, int dst_cta, uint32_t size_bytes, semaphore& bar) { + void const* const ptr = &bar; + uint32_t mbarrier_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + // ************************************************** + // load from src to dst in different threadblocks + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(src)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(dst)); + + // mapa instr = https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mapa + // find dst addr in neighbor's cta + uint32_t neighbor_addr_dst; + asm volatile ( + "mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_dst) + : "r"(dst_ptr), "r"(dst_cta) + ); + + uint32_t neighbor_addr_mbarrier = mbarrier_ptr; + asm volatile ( + "mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(neighbor_addr_mbarrier) + : "r"(mbarrier_ptr), "r"(dst_cta) + ); + + // cp.async instr = https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk + // copy src into dst in neighbor's cta + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + asm volatile ( + "cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(neighbor_addr_dst), "r"(src_ptr), "r"(size_bytes), "r"(neighbor_addr_mbarrier) + : "memory" + ); +} + +// Templated transfer for convenience +template 
+__device__ static inline void store_async(T &dst_, T &src_, int dst_cta, semaphore& bar) { + store_async((void*)&dst_, (void*)&src_, dst_cta, size_bytes, bar); +} + +} // namespace cluster +} // namespace tma +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/memory/util/util.cuh b/extra/thunder/cuda/include/ops/thread/memory/util/util.cuh new file mode 100644 index 0000000000..cd3c760dbd --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/util/util.cuh @@ -0,0 +1,443 @@ +/** + * @file + * @brief General memory utilities not specialized for either tiles or vectors. + */ + +#pragma once + +namespace kittens { + +/* ---------- To prevent generic addressing, PTX ---------- */ + +template struct move { + __device__ static inline void lds(T& dst, uint32_t src); + __device__ static inline void sts(uint32_t dst, const T& src); + __device__ static inline void ldg(T& dst, T* src); + __device__ static inline void stg(T* dst, const T& src); +}; +// unpacked types +template<> struct move { + __device__ static inline void lds(bf16& dst, uint32_t src) { + asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const bf16& src) { + asm volatile("st.shared.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "r"(dst)); + } + __device__ static inline void ldg(bf16& dst, bf16* src) { + asm volatile("ld.global.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "l"(src)); + } + __device__ static inline void stg(bf16* dst, const bf16& src) { + asm volatile("st.global.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "l"(dst)); + } +}; +template<> struct move { + __device__ static inline void lds(half& dst, uint32_t src) { + asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const half& src) { + asm volatile("st.shared.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "r"(dst)); + } + __device__ static inline void ldg(half& dst, half* src) { + asm volatile("ld.global.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "l"(src)); + } + __device__ static inline void stg(half* dst, const half& src) { + asm volatile("st.global.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "l"(dst)); + } +}; +template<> struct move { + __device__ static inline void lds(float& dst, uint32_t src) { + asm volatile("ld.shared.f32 %0, [%1];\n" : "=f"(dst) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const float& src) { + asm volatile("st.shared.f32 [%1], %0;\n" : : "f"(src), "r"(dst)); + } + __device__ static inline void ldg(float& dst, float* src) { + asm volatile("ld.global.f32 %0, [%1];\n" : "=f"(dst) : "l"(src)); + } + __device__ static inline void stg(float* dst, const float& src) { + asm volatile("st.global.f32 [%1], %0;\n" : : "f"(src), "l"(dst)); + } +}; +template<> struct move { + __device__ static inline void lds(int& dst, uint32_t src) { + asm volatile("ld.shared.u32 %0, [%1];\n" : "=r"(dst) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const int& src) { + asm volatile("st.shared.u32 [%1], %0;\n" : : "r"(src), "r"(dst)); + } + __device__ static inline void ldg(int& dst, int* src) { + asm volatile("ld.global.u32 %0, [%1];\n" : "=r"(dst) : "l"(src)); + } + __device__ static inline void stg(int* dst, const int& src) { + asm volatile("st.global.u32 [%1], %0;\n" : : "r"(src), "l"(dst)); + } +}; +// packed types +template<> struct move { + __device__ static inline void lds(bf16_2& dst, uint32_t src) { + asm 
volatile("ld.shared.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const bf16_2& src) { + asm volatile("st.shared.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "r"(dst)); + } + __device__ static inline void ldg(bf16_2& dst, bf16_2* src) { + asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "l"(src)); + } + __device__ static inline void stg(bf16_2* dst, const bf16_2& src) { + asm volatile("st.global.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "l"(dst)); + } + __device__ static inline void ldsm4(bf16_2& dst1, bf16_2& dst2, bf16_2& dst3, bf16_2& dst4, uint32_t src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" : + "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src)); + } + __device__ static inline void ldsm4t(bf16_2& dst1, bf16_2& dst2, bf16_2& dst3, bf16_2& dst4, uint32_t src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" : + "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src)); + } + __device__ static inline void stsm4(uint32_t dst, bf16_2& src1, bf16_2& src2, bf16_2& src3, bf16_2& src4) { + asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" :: + "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst)); + } + __device__ static inline void stsm4t(uint32_t dst, bf16_2& src1, bf16_2& src2, bf16_2& src3, bf16_2& src4) { + asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" :: + "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst)); + } +}; +template<> struct move { + __device__ static inline void lds(half_2& dst, uint32_t src) { + asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const half_2& src) { + asm volatile("st.shared.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "r"(dst)); + } + __device__ static inline void ldg(half_2& dst, half_2* src) { + asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "l"(src)); + } + __device__ static inline void stg(half_2* dst, const half_2& src) { + asm volatile("st.global.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "l"(dst)); + } + __device__ static inline void ldsm4(half_2& dst1, half_2& dst2, half_2& dst3, half_2& dst4, uint32_t src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" : + "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src)); + } + __device__ static inline void ldsm4t(half_2& dst1, half_2& dst2, half_2& dst3, half_2& dst4, uint32_t src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" : + "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src)); + } + __device__ static inline void stsm4(uint32_t dst, half_2& src1, half_2& src2, half_2& src3, half_2& src4) { + asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" :: + "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst)); + } + __device__ static inline void stsm4t(uint32_t dst, half_2& src1, half_2& src2, half_2& src3, half_2& src4) { + asm 
volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" :: + "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst)); + } +}; +template<> struct move { + __device__ static inline void lds(float2& dst, uint32_t src) { + asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];\n" : "=f"(dst.x), "=f"(dst.y) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const float2& src) { + asm volatile("st.shared.v2.f32 [%2], {%0, %1};\n" : : "f"(src.x), "f"(src.y), "r"(dst)); + } + __device__ static inline void ldg(float2& dst, float2* src) { + asm volatile("ld.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(dst.x), "=f"(dst.y) : "l"(src)); + } + __device__ static inline void stg(float2* dst, const float2& src) { + asm volatile("st.global.v2.f32 [%2], {%0, %1};\n" : : "f"(src.x), "f"(src.y), "l"(dst)); + } +}; +template<> struct move { + __device__ static inline void lds(float4& dst, uint32_t src) { + asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];\n" : "=f"(dst.x), "=f"(dst.y), "=f"(dst.z), "=f"(dst.w) : "r"(src)); + } + __device__ static inline void sts(uint32_t dst, const float4& src) { + asm volatile("st.shared.v4.f32 [%4], {%0, %1, %2, %3};\n" : : "f"(src.x), "f"(src.y), "f"(src.z), "f"(src.w), "r"(dst)); + } + __device__ static inline void ldg(float4& dst, float4* src) { + asm volatile("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" : "=f"(dst.x), "=f"(dst.y), "=f"(dst.z), "=f"(dst.w) : "l"(src)); + } + __device__ static inline void stg(float4* dst, const float4& src) { + asm volatile("st.global.v4.f32 [%4], {%0, %1, %2, %3};\n" : : "f"(src.x), "f"(src.y), "f"(src.z), "f"(src.w), "l"(dst)); + } +}; +#ifdef KITTENS_HOPPER +template<> struct move { + __device__ static inline void ldsm4(fp8e4m3_4& dst1, fp8e4m3_4& dst2, fp8e4m3_4& dst3, fp8e4m3_4& dst4, uint32_t src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" : + "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src)); + } + __device__ static inline void stsm4(uint32_t dst, fp8e4m3_4& src1, fp8e4m3_4& src2, fp8e4m3_4& src3, fp8e4m3_4& src4) { + asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" :: + "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst)); + } + +}; +template<> struct move { + __device__ static inline void ldsm4(fp8e5m2_4& dst1, fp8e5m2_4& dst2, fp8e5m2_4& dst3, fp8e5m2_4& dst4, uint32_t src) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" : + "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src)); + } + __device__ static inline void stsm4(uint32_t dst, fp8e5m2_4& src1, fp8e5m2_4& src2, fp8e5m2_4& src3, fp8e5m2_4& src4) { + asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" :: + "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst)); + } +}; +#endif + +/* ---------- Constants for Cache policies ---------- */ + +enum cache_policy { + NORMAL = 0, + EVICT_FIRST = 1, + EVICT_LAST = 2 +}; +template __device__ inline uint64_t make_cache_policy() { + uint64_t cache_policy_val; + constexpr float fraction = 1.0f; + static_assert(policy == cache_policy::EVICT_FIRST || policy == cache_policy::EVICT_LAST, "Unexpected cache policy"); + if constexpr (policy == cache_policy::EVICT_FIRST) { + asm 
volatile("createpolicy.fractional.L2::evict_first.b64 %0, %1;\n" : "=l"(cache_policy_val) : "f"(fraction)); + } + else { + asm volatile("createpolicy.fractional.L2::evict_last.b64 %0, %1;\n" : "=l"(cache_policy_val) : "f"(fraction)); + } + return cache_policy_val; +} +/* ---------- Generic (non-Hopper specific) semaphore functions ---------- */ + +struct semaphore { +private: + uint64_t value; +}; // note that this is an opaque type, so the value should not be accessed directly. +template struct barrier { + int barrier_id; + __device__ __forceinline__ barrier(int _id) : barrier_id(_id) {} + __device__ __forceinline__ barrier operator[](int i) { + return barrier(barrier_id + i); + } +}; + +/** + * @brief Initializes a synchronization semaphore with a transaction count and sets the expected number of bytes. + * + * This function sets up a semaphore that is used to synchronize threads within a block during asynchronous operations. + * It initializes the semaphore with a thread count semaphore. + * + * Additionally, if it is given a shared tile type, it will also call `set_bytes` to prepare for the memory transaction. + * + * @param[out] semaphore The semaphore variable to initialize. + * @param[in] tc The thread counter for the semaphore. + */ +__device__ static inline void init_semaphore(semaphore& bar, int thread_count, int transaction_count=0) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + + asm volatile ( + "mbarrier.init.shared::cta.b64 [%0], %1;\n" + :: "r"(bar_ptr), "r"(thread_count+transaction_count) + ); +} +/** + * @brief Invalidate an mbarrier + * + * @param[out] semaphore The semaphore variable to initialize. + * @param[in] tc The thread counter for the semaphore. + */ +__device__ static inline void invalidate_semaphore(semaphore& bar) { + void const* const ptr = &bar; + uint32_t bar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + asm volatile ( + "mbarrier.inval.shared::cta.b64 [%0];\n" + :: "r"(bar_ptr) + ); +} + +/** +* @brief Arrives at a semaphore. +* +* Marks a warp arrival at an mbarrier +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline void arrive(semaphore& sem) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&sem)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];\n" + : + : "r"(mbar_ptr) + : "memory" + ); +} +template __device__ static inline void arrive(barrier bar) { + asm volatile("bar.arrive %0, %1;\n" :: "r"(bar.barrier_id), "n"(num_warps*WARP_THREADS) : "memory"); +} + +#ifdef KITTENS_HOPPER +/** +* @brief Arrives at a semaphore. +* +* Marks a warp arrival at an mbarrier +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. +*/ +__device__ static inline void arrive(semaphore& sem, uint32_t count) { + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&sem)); + asm volatile ( + "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n" + : + : "r"(mbar_ptr), "r"(count) + : "memory" + ); +} +#endif + +/** +* @brief Waits for the requested semaphore phase. +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. 
+*/ +__device__ static inline void wait(semaphore& sem, int kPhaseBit) { + void const* const ptr = &sem; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + +#ifdef KITTENS_HOPPER + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +#else + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures to save instruction issue slots + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +#endif +} + +__device__ static inline void careful_wait(semaphore& sem, int kPhaseBit) { + void const* const ptr = &sem; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + +#ifdef KITTENS_HOPPER + asm volatile ( + "{\n" + ".reg .b64 start_clock, current_clock;\n" + "mov.b64 start_clock, %clock64;\n" + ".reg .pred P_CLOCK;\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "mov.b64 current_clock, %clock64;\n" + "sub.u64 current_clock, current_clock, start_clock;\n" + "setp.ge.u64 P_CLOCK, current_clock, 1000000;\n" + "@P_CLOCK trap;\n" + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +#else + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "LAB_WAIT:\n" + "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n" + "@P1 bra.uni DONE;\n" + "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures to save instruction issue slots + "bra.uni LAB_WAIT;\n" + "DONE:\n" + "}\n" + :: "r"(mbar_ptr), + "r"(kPhaseBit) + ); +#endif +} + +/** +* @brief Checks if the requested semaphore phase is ready. +* +* @param semaphore Reference to the semaphore variable. +* @param kPhaseBit The phase bit used for the semaphore. 
+*/ +__device__ static inline int test_wait(semaphore& sem, int kPhaseBit) { + void const* const ptr = &sem; + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(ptr)); + int result; + asm volatile ( + "{\n" + ".reg .pred P1;\n" + "mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2;\n" + "selp.u32 %0,1,0,P1;" + "}\n" + : "=r"(result) + : "r"(mbar_ptr), "r"(kPhaseBit) + ); + return result; +} + +__device__ static inline void arrive_and_wait(semaphore& sem, int kPhaseBit) { + arrive(sem); + wait(sem, kPhaseBit); +} +template __device__ static inline void arrive_and_wait(barrier bar) { + asm volatile("bar.sync %0, %1;\n" :: "r"(bar.barrier_id), "n"(num_warps*WARP_THREADS) : "memory"); +} + +template __device__ static inline void load_async_wait() { // for completing (non-TMA) async loads + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); + } + __syncwarp(); +} + +// meant to be used only with shared tiles and shared vectors +namespace detail { +template struct size_info { + static constexpr uint32_t bytes = sizeof(std::remove_reference_t); +}; +template struct size_info { + static constexpr uint32_t elements = ST::num_elements; + static constexpr uint32_t bytes = ST::num_elements * sizeof(typename ST::dtype); +}; +template struct size_info { + static constexpr uint32_t elements = SV::length; + static constexpr uint32_t bytes = SV::length * sizeof(typename SV::dtype); +}; +} +template inline constexpr uint32_t size_bytes = 0; // base case +template inline constexpr uint32_t size_bytes = detail::size_info::bytes + size_bytes; // recursive case + +} // namespace kittens + +#ifdef KITTENS_HOPPER +#include "multimem.cuh" +#include "tma.cuh" +#endif + +#ifdef KITTENS_BLACKWELL +#include "tensor.cuh" +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/memory/vec/tma.cuh b/extra/thunder/cuda/include/ops/thread/memory/vec/tma.cuh new file mode 100644 index 0000000000..dd92ccab44 --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/vec/tma.cuh @@ -0,0 +1,416 @@ +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" +#include "../util/util.cuh" + +#include +#include + +// This is a macro that helps us define default cache policy versions of each function. 
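// A minimal sketch of what one of these wrappers looks like after expansion (the exact
// template constraints are assumptions here, since the parameter lists are abbreviated in
// this patch text): each macro emits an overload without an explicit cache_policy template
// argument that simply forwards to the cache_policy::NORMAL instantiation, e.g. for prefetch:
//
//   template<typename SV, typename GL, typename COORD = coord<SV>>
//   __device__ static inline void prefetch(SV &dst, const GL &src, const COORD &idx) {
//       prefetch<cache_policy::NORMAL>(dst, src, idx); // default: no L2 eviction hint
//   }
//
// so call sites that do not care about L2 eviction hints can omit the policy parameter entirely.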
+#define __KITTENS_TMA_DEFINE_DEFAULT_LOAD_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(SV &dst, const GL &src, const COORD &idx) { \ + function_name(dst, src, idx); \ +} +#define __KITTENS_TMA_DEFINE_PGL_DEFAULT_LOAD_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(SV &dst, const PGL &src, const COORD &idx) { \ + function_name(dst, src, idx); \ +} +#define __KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(const GL &dst, const SV &src, const COORD &idx) { \ + function_name(dst, src, idx); \ +} +#define __KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(const PGL &dst, const SV &src, const COORD &idx) { \ + function_name(dst, src, idx); \ +} +#define __KITTENS_TMA_DEFINE_SEMAPHORE_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(SV &dst, const GL &src, const COORD &idx, semaphore& bar) { \ + function_name(dst, src, idx, bar); \ +} +#define __KITTENS_TMA_DEFINE_PGL_SEMAPHORE_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(SV &dst, const PGL &src, const COORD &idx, semaphore& bar) { \ + function_name(dst, src, idx, bar); \ +} +#define __KITTENS_TMA_DEFINE_CLUSTER_SEMAPHORE_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(SV &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { \ + function_name(dst, src, idx, bar, cluster_mask, dst_mbar_cta); \ +} +#define __KITTENS_TMA_DEFINE_PGL_CLUSTER_SEMAPHORE_CACHE_VEC__(function_name) \ +template> \ +__device__ static inline void function_name(SV &dst, const PGL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { \ + function_name(dst, src, idx, bar, cluster_mask, dst_mbar_cta); \ +} + + +namespace kittens { + +namespace detail { +namespace tma { + +template __device__ static inline void vec_prefetch_tma_internal(uint64_t tma_ptr, coord<> tma_coord) { + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global.tile" + " [%0, {%1, %2, %3, %4}];" + : + : "l"(tma_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global.tile.L2::cache_hint" + " [%0, {%1, %2, %3, %4}], %5;" + : + : "l"(tma_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy()) + : "memory" + ); + } +} + +template __device__ static inline void vec_store_async_tma_internal(uint64_t tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) { + asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5}], [%1], %6;" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy()) + : "memory" + ); + } +} + +template __device__ static inline void vec_store_add_async_tma_internal(uint64_t 
tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) { + asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5}], [%1], %6;" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy()) + : "memory" + ); + } +} + +template __device__ static inline void vec_store_min_async_tma_internal(uint64_t tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) { + asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.min.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.min.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5}], [%1], %6;" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy()) + : "memory" + ); + } +} + +template __device__ static inline void vec_store_max_async_tma_internal(uint64_t tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) { + asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory"); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.max.tile.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b) + : "memory" + ); + } + else { + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.max.tile.bulk_group.L2::cache_hint" + " [%0, {%2, %3, %4, %5}], [%1], %6;" + : + : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy()) + : "memory" + ); + } +} + +template __device__ static inline void vec_load_async_tma_internal(uint64_t tma_ptr, uint32_t dst_i_ptr, uint32_t mbar_ptr, coord<> tma_coord) { + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2];" + : + : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy()) + : "memory" + ); + } +} + +namespace cluster { +template __device__ static inline void vec_load_async_tma_internal(uint64_t tma_ptr, uint32_t dst_i_ptr, uint32_t mbar_ptr, coord<> tma_coord, uint16_t cluster_mask, int dst_mbar_cta=-1) { +#ifdef KITTENS_BLACKWELL + if(dst_mbar_cta != -1) { + uint32_t neighbor_mbar_ptr; + asm volatile ( + "mapa.shared::cluster.u32 %0, 
%1, %2;\n" + : "=r"(neighbor_mbar_ptr) + : "r"(mbar_ptr), "r"(dst_mbar_cta) + ); + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(dst_i_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr), + "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7, %8;" + : + : "r"(dst_i_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr), + "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask), "l"(make_cache_policy()) + : "memory" + ); + } + } else +#endif + if constexpr (policy == cache_policy::NORMAL) { + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask) + : "memory" + ); + } + else { + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7, %8;" + : + : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr), + "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask), "l"(make_cache_policy()) + : "memory" + ); + } +} +} // namespace cluster + +} // namespace tma +} // namespace detail + +namespace tma { + +template> +__device__ static inline void prefetch(SV &dst, const GL &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + ::kittens::detail::tma::vec_prefetch_tma_internal(tma_ptr, tma_coord); + } +} +__KITTENS_TMA_DEFINE_DEFAULT_LOAD_CACHE_VEC__(prefetch) + +template> +__device__ static inline void store_async(const GL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_async) + +template> +__device__ static inline void store_async(const PGL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + 
i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_async) + +template> +__device__ static inline void store_add_async(const GL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_add_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_add_async) + +template> +__device__ static inline void store_add_async(const PGL &dst, const SV &src, const COORD &idx) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_add_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_add_async) + +template> +__device__ static inline void store_min_async(const GL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_min_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_min_async) + +template> +__device__ static inline void store_min_async(const PGL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_min_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_min_async) + +template> +__device__ static inline void 
store_max_async(const GL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_max_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_max_async) + +template> +__device__ static inline void store_max_async(const PGL &dst, const SV &src, const COORD &idx) { + static_assert(!std::is_same_v, "TMA does not support async min/max reductions for fp32 types."); + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(dst.template get_tma()); + uint32_t src_ptr = static_cast(__cvta_generic_to_shared(&src)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_store_max_async_tma_internal(tma_ptr, src_i_ptr, tma_coord); + } + ::kittens::tma::store_commit_group(); +} +__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_max_async) + +template> +__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::vec_load_async_tma_internal(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord); + } +} +__KITTENS_TMA_DEFINE_SEMAPHORE_CACHE_VEC__(load_async) + +namespace cluster { +template> +__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { + coord<> unit_coord = idx.template unit_coord<-1, 3>(); + uint64_t tma_ptr = reinterpret_cast(src.template get_tma()); + uint32_t mbar_ptr = static_cast(__cvta_generic_to_shared(&bar)); + uint32_t dst_ptr = static_cast(__cvta_generic_to_shared(&dst)); + for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2; i++) { + coord<> tma_coord = unit_coord; + tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1; + uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1*sizeof(typename SV::dtype); + ::kittens::detail::tma::cluster::vec_load_async_tma_internal(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord, cluster_mask, dst_mbar_cta); + } +} +__KITTENS_TMA_DEFINE_CLUSTER_SEMAPHORE_CACHE_VEC__(load_async) +} // namespace cluster +} // namespace tma +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/memory/vec/vec.cuh 
b/extra/thunder/cuda/include/ops/thread/memory/vec/vec.cuh new file mode 100644 index 0000000000..7a42c6790f --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/memory/vec/vec.cuh @@ -0,0 +1,10 @@ +/** + * @file + * @brief An aggregate header of warp memory operations on vectors, where a single warp loads or stores data on its own. + */ + +#pragma once + +#ifdef KITTENS_HOPPER +#include "tma.cuh" +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/mma/mma.cuh b/extra/thunder/cuda/include/ops/thread/mma/mma.cuh new file mode 100644 index 0000000000..67eec84383 --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/mma/mma.cuh @@ -0,0 +1,8 @@ +/** + * @file + * @brief An aggregate header for warp operations on data stored in tensor memory. + */ + +#pragma once + +#include "tensor/tensor.cuh" \ No newline at end of file diff --git a/extra/thunder/cuda/include/ops/thread/mma/tensor/tensor.cuh b/extra/thunder/cuda/include/ops/thread/mma/tensor/tensor.cuh new file mode 100644 index 0000000000..72911ca6c5 --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/mma/tensor/tensor.cuh @@ -0,0 +1,523 @@ +/** + * @file + * @brief Matrix multiply-accumulate operations for tiles stored in tensor memory. + */ + +#pragma once + +#include "../../../../common/common.cuh" +#include "../../../../types/types.cuh" + +namespace kittens { +namespace detail { +namespace tcgen05 { +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#instruction-descriptor +template +__device__ static inline uint32_t instruction_descriptor() { + uint32_t desc = 0; + if constexpr (sizeof(AB) == 2) { // kind::f16 + // either accumulate to float, or the input is half and the output is half + static_assert(std::is_same_v || std::is_same_v); + desc |= 0b00 << 0; // sparsity bits unneeded + desc |= 0b0 << 2; // dense + desc |= 0b0 << 3; // no saturate on fp types + if constexpr (std::is_same_v) { + desc |= 0b01 << 4; // D matrix is FP32 + } + else { + desc |= 0b00 << 4; // D matrix is FP16 + } + desc |= 0b0 << 6; // reserved + if constexpr (std::is_same_v) { + desc |= 0b000 << 7; // 16-bit A input type as FP16 + desc |= 0b000 << 10; // 16-bit B input type as FP16 + } else if constexpr (std::is_same_v) { + desc |= 0b001 << 7; // 16-bit A input type as BF16 + desc |= 0b001 << 10; // 16-bit B input type as BF16 + } else if constexpr (std::is_same_v) { + desc |= 0b000 << 7; // 8-bit A input type as FP8 e4m3 + desc |= 0b000 << 10; // 8-bit B input type as FP8 e4m3 + } else if constexpr (std::is_same_v) { + desc |= 0b001 << 7; // 8-bit A input type as FP8 e5m2 + desc |= 0b001 << 10; // 8-bit B input type as FP8 e5m2 + } + /* fp6 and fp4 + else if constexpr (std::is_same_v) { + desc |= 0b011 << 7; // 6-bit A input type as FP6 e2m3 + desc |= 0b011 << 10; // 6-bit B input type as FP6 e2m3 + } + else if constexpr (std::is_same_v) { + desc |= 0b100 << 7; // 6-bit A input type as FP6 e3m2 + desc |= 0b100 << 10; // 6-bit B input type as FP6 e3m2 + } + else if constexpr (std::is_same_v) { + desc |= 0b101 << 7; // 4-bit A input type as FP4 e3m1 + desc |= 0b101 << 10; // 4-bit B input type as FP4 e3m1 + } + */ + if constexpr (neg) { + desc |= 0b1 << 13; // Do negate A matrix + } + else { + desc |= 0b0 << 13; // Don't negate A matrix + } + desc |= 0b0 << 14; // Don't negate B matrix (in all cases) + if constexpr (trans_a) { + desc |= 0b1 << 15; // Transpose A matrix + } + else { + desc |= 0b0 << 15; // Don't transpose A matrix + } + if constexpr (trans_b) { + desc |= 0b1 << 16; // Transpose B 
matrix + } + else { + desc |= 0b0 << 16; // Don't transpose B matrix + } + desc |= (N >> 3) << 17; // B matrix has dimension N, encoded + desc |= 0b0 << 23; // reserved + desc |= (M >> 4) << 24; // A matrix has dimension M, encoded + desc |= 0b0 << 29; // reserved + desc |= 0b00 << 30; // no shift for B-matrix reuse + } else if constexpr (sizeof(AB) == 1) { // kind::f8f6f4 + static_assert(std::is_same_v || std::is_same_v); // FP8/6/4 has to accumulate to float or half + desc |= 0b00 << 0; // sparsity bits unneeded + desc |= 0b0 << 2; // dense + desc |= 0b0 << 3; // no saturate on fp types + if constexpr (std::is_same_v) { + desc |= 0b01 << 4; // D matrix is FP32 + } + else { + desc |= 0b00 << 4; // D matrix is FP16 + } + desc |= 0b0 << 6; // reserved + if constexpr (std::is_same_v) { + desc |= 0b000 << 7; // 8-bit A input type as FP8 e4m3 + desc |= 0b000 << 10; // 8-bit B input type as FP8 e4m3 + } else if constexpr (std::is_same_v) { + desc |= 0b001 << 7; // 8-bit A input type as FP8 e5m2 + desc |= 0b001 << 10; // 8-bit B input type as FP8 e5m2 + } + /* fp6 and fp4 + else if constexpr (std::is_same_v) { + desc |= 0b011 << 7; // 6-bit A input type as FP6 e2m3 + desc |= 0b011 << 10; // 6-bit B input type as FP6 e2m3 + } + else if constexpr (std::is_same_v) { + desc |= 0b100 << 7; // 6-bit A input type as FP6 e3m2 + desc |= 0b100 << 10; // 6-bit B input type as FP6 e3m2 + } + else if constexpr (std::is_same_v) { + desc |= 0b101 << 7; // 4-bit A input type as FP4 e3m1 + desc |= 0b101 << 10; // 4-bit B input type as FP4 e3m1 + } + */ + if constexpr (neg) { + desc |= 0b1 << 13; // Do negate A matrix + } + else { + desc |= 0b0 << 13; // Don't negate A matrix + } + desc |= 0b0 << 14; // Don't negate B matrix (in all cases) + if constexpr (trans_a) { + desc |= 0b1 << 15; // Transpose A matrix + } + else { + desc |= 0b0 << 15; // Don't transpose A matrix + } + if constexpr (trans_b) { + desc |= 0b1 << 16; // Transpose B matrix + } + else { + desc |= 0b0 << 16; // Don't transpose B matrix + } + desc |= (N >> 3) << 17; // B matrix has dimension N, encoded + desc |= 0b0 << 23; // reserved + desc |= (M >> 4) << 24; // A matrix has dimension M, encoded + desc |= 0b0 << 29; // reserved + desc |= 0b00 << 30; // no shift for B-matrix reuse + } + else { + static_assert(sizeof(AB) == 999, "Invalid AB type size; not implemented yet."); + } + return desc; +}; + +template +__device__ static inline void tt_st(uint32_t d_tt_addr, uint32_t a_tt_addr, uint64_t b_desc, uint32_t idesc) { + if constexpr (std::is_same_v || std::is_same_v) { + // TODO(danfu): is there a better way to do this with string manipulation that the compiler likes? 
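+        // Note (added for readability; the PTX ISA "tcgen05.mma" section linked above is the
+        // authoritative reference): the operand layout of the instructions issued below is
+        //   [%0] d_tt_addr : accumulator tile address in tensor memory
+        //   [%1] a_tt_addr : A tile address in tensor memory
+        //    %2  b_desc    : shared-memory matrix descriptor for the B tile
+        //    %3  idesc     : 32-bit instruction descriptor built by instruction_descriptor()
+        //    %4  acc       : compile-time flag; the predicate p = (acc == 1) enables accumulation
+        //                    into D, while acc == 0 makes the instruction overwrite D.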
+ if constexpr (ncta == 1) { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, p;}\n" + :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + else { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], [%1], %2, %3, p;}\n" + :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + } else { + if constexpr (ncta == 1) { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, p;}\n" + :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + else { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, p;}\n" + :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + } +} + +template +__device__ static inline void st_st(uint32_t d_tt_addr, uint64_t a_desc, uint64_t b_desc, uint32_t idesc) { + if constexpr (std::is_same_v || std::is_same_v) { + // TODO(danfu): is there a better way to do this with string manipulation that the compiler likes? + if constexpr (ncta == 1) { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;}\n" + :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + else { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p;}\n" + :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + } else { + if constexpr (ncta == 1) { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p;}\n" + :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + else { + asm volatile( + "{.reg .pred p;\n" \ + "setp.eq.u32 p, 1, %4;\n" \ + "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p;}\n" + :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc) + ); + } + } +} + +template __device__ static inline void commit(kittens::semaphore &sem) { + if constexpr (ncta == 1) { + asm volatile( + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [%0];\n" + :: "l"(&sem) + ); + } + else { + asm volatile( + "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n" + :: "l"(&sem), "h"((uint16_t)(0b11)) + ); + } +} + +} // namespace tcgen05 +} // namespace detail + +template constexpr int reduction_dimension = sizeof(T_AB) == 2 ? 16 : sizeof(T_AB) == 4 ? 8 : 32; // haven't added fp4 yet. +// RS matmul equivalent +template +__device__ static inline void mma(D &d, const A &a, const B &b) { + constexpr int trans_b = 1 - n_trans_b; + + // Do everything here. + constexpr int M = (trans_a ? A::cols : A::rows) * ncta; + static_assert(M == D::rows*ncta && ((ncta == 1 && (M == 64 || M == 128)) || (ncta == 2 && (M == 128 || M == 256)))); // output register is correctly sized + + constexpr int N = (trans_b ? B::cols : B::rows) * ncta; + static_assert(N == D::cols); // output register is correctly sized + + constexpr int K = trans_a ? A::rows : A::cols; + static_assert((trans_b ? B::rows : B::cols) == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. 
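+    // Quick sanity check (illustrative values only): for the AB case (trans_a == 0, n_trans_b == 0)
+    // with ncta == 1, a 128x64 bf16 A tile and a 64x256 bf16 B tile give M = 128, N = 256, K = 64.
+    // Since red_dim is 16 for 2-byte input types (see reduction_dimension above), the loop below
+    // issues K / red_dim = 4 chained tcgen05.mma instructions along the K dimension.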
+ + // Usings + using T_AB = A::T; static_assert(std::is_same_v); + using T_D = D::T; + + constexpr int red_dim = reduction_dimension; + static_assert(K%red_dim == 0, "K dimension must be divisible by red_dim."); + + static_assert( + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v), + "Currently unsupported type combination for matrix multiply." + ); + uint32_t idesc = detail::tcgen05::instruction_descriptor(); + kittens::st_descriptor, trans_b> b_desc(b); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + + detail::tcgen05::template tt_st( + d.addr, + a.template chunk_addr(0), + b_desc.chunk_descriptor(0), + idesc + ); + #pragma unroll + for(int i = 1; i < K/red_dim; i++) { + detail::tcgen05::template tt_st( + d.addr, + a.template chunk_addr(i), + b_desc.chunk_descriptor(i), + idesc + ); + } +} +template +__device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b); + detail::tcgen05::commit(sem); +} +// SS matmul equivalent +template +__device__ static inline void mma(D &d, const A &a, const B &b) { + constexpr int trans_b = 1 - n_trans_b; + + // Do everything here. + constexpr int M = (trans_a ? A::cols : A::rows) * ncta; + static_assert(M == D::rows*ncta && ((ncta == 1 && (M == 64 || M == 128)) || (ncta == 2 && (M == 128 || M == 256)))); // output register is correctly sized + + constexpr int N = (trans_b ? B::cols : B::rows) * ncta; + static_assert(N == D::cols); // output register is correctly sized + + constexpr int K = trans_a ? A::rows : A::cols; + static_assert((trans_b ? B::rows : B::cols) == K); // K dimension must match + static_assert(std::is_same_v); // A and B must match type. + + // Usings + using T_AB = A::T; static_assert(std::is_same_v); + using T_D = D::T; + + constexpr int red_dim = reduction_dimension; + static_assert(K%red_dim == 0, "K dimension must be divisible by red_dim."); + + static_assert( + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v) || + (std::is_same_v && !std::is_same_v), + "Currently unsupported type combination for matrix multiply." 
+ ); + uint32_t idesc = detail::tcgen05::instruction_descriptor(); + kittens::st_descriptor, trans_a> a_desc(a); + kittens::st_descriptor, trans_b> b_desc(b); + + asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory"); + + detail::tcgen05::template st_st( + d.addr, + a_desc.chunk_descriptor(0), + b_desc.chunk_descriptor(0), + idesc + ); + #pragma unroll + for(int i = 1; i < K/red_dim; i++) { + detail::tcgen05::template st_st( + d.addr, + a_desc.chunk_descriptor(i), + b_desc.chunk_descriptor(i), + idesc + ); + } +} +template +__device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b); + detail::tcgen05::commit(sem); +} +// Accumulator / numcta wrappers +template +__device__ static inline void mma2(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma2(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm2(D &d, const A &a, const B &b) { + mma2(d, a, b); +} + +// Transpose wrappers +template +__device__ static inline void mma_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma_AB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mma2_AB(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mma_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma_ABt(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mma2_ABt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mma_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma_AtB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mma2_AtB(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mma_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mma_AtBt(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} + +template +__device__ static inline void mm_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm_AB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_AB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm2_AB(D &d, 
const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mm_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm_ABt(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_ABt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm2_ABt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mm_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm_AtB(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_AtB(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm2_AtB(D &d, const A &a, const B &b) { + mma2(d, a, b); +} +template +__device__ static inline void mm_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma(d, a, b, sem); +} +template +__device__ static inline void mm_AtBt(D &d, const A &a, const B &b) { + mma(d, a, b); +} +template +__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b, semaphore &sem) { + mma2(d, a, b, sem); +} +template +__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b) { + mma2(d, a, b); +} + + +} // namespace kittens + diff --git a/extra/thunder/cuda/include/ops/thread/thread.cuh b/extra/thunder/cuda/include/ops/thread/thread.cuh new file mode 100644 index 0000000000..d6de48003b --- /dev/null +++ b/extra/thunder/cuda/include/ops/thread/thread.cuh @@ -0,0 +1,13 @@ +/** + * @file + * @brief An aggregate header of all warp (worker) operations defined by ThunderKittens + */ + +#pragma once + +// no namespace wrapper needed here + +#include "memory/memory.cuh" +#ifdef KITTENS_BLACKWELL +#include "mma/mma.cuh" +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/pyutils/broker.cuh b/extra/thunder/cuda/include/pyutils/broker.cuh new file mode 100644 index 0000000000..a045679754 --- /dev/null +++ b/extra/thunder/cuda/include/pyutils/broker.cuh @@ -0,0 +1,551 @@ +/** + * @file broker.cuh + * @brief Utility for multiprocess data exchange and synchronization. + * + * This file provides the KittensBroker class, which enables efficient inter-process + * communication and synchronization using POSIX shared memory, semaphores, and sockets. + * The broker is designed to work in multi-GPU environments where processes need to + * exchange data and synchronize execution across different local ranks. + * + * @note This implementation relies on POSIX IPC mechanisms and is intended for + * Unix-like systems. All processes must be running on the same node. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64) + #error "KittensBroker is not supported on Windows" +#endif + +namespace kittens { + +namespace detail { +namespace broker { + +static constexpr int MAX_LOCAL_WORLD_SIZE = 72; +static constexpr int VAULT_SIZE_PER_RANK = 64; // sizeof(cudaIpcMemHandle_t) + +struct KittensVault { + static constexpr int INIT_CODE = 0x43617473; // "Cats" + int init; + int barrier; + int sense; + uint8_t data[MAX_LOCAL_WORLD_SIZE * VAULT_SIZE_PER_RANK]; +}; + +static constexpr int SHM_SIZE = (sizeof(KittensVault) + 4095) / 4096 * 4096; + +__host__ inline static void init_sync( + int local_rank, + volatile KittensVault *vault +) { + if (local_rank == 0) { + // initialize barrier resources + vault->barrier = 0; + vault->sense = 0; + __sync_synchronize(); // make previous writes visible + vault->init = KittensVault::INIT_CODE; + } else { + while (vault->init != KittensVault::INIT_CODE) usleep(1); + __sync_synchronize(); // see leader's previous writes + } +} + +__host__ inline static void sync( + int local_world_size, + volatile KittensVault *vault +) { + if (vault->init != KittensVault::INIT_CODE) + throw std::runtime_error("KittensBroker: KittensVault not initialized"); + + // Phase 1 + int arrived = __sync_add_and_fetch(&vault->barrier, 1); + if (arrived == local_world_size) vault->sense = 1; + while (!vault->sense) usleep(1); + + // Make previous writes visible + __sync_synchronize(); + + // Phase 2 + arrived = __sync_add_and_fetch(&vault->barrier, -1); + if (arrived == 0) vault->sense = 0; + while (vault->sense) usleep(1); +} + +__host__ inline void *create_shm(const char *key, size_t size) { + int shm_fd; + shm_fd = shm_open(key, O_RDWR | O_CREAT | O_EXCL | O_CLOEXEC, 0600); + + if (shm_fd < 0) { + if (errno == EEXIST) + throw std::runtime_error("KittensBroker: Named shared memory already exists"); + throw std::runtime_error("KittensBroker: Failed to create shared memory"); + } + + if (ftruncate(shm_fd, size) != 0) { + shm_unlink(key); + close(shm_fd); + throw std::runtime_error("KittensBroker: Failed to truncate shared memory"); + } + + void *addr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + close(shm_fd); + if (addr == MAP_FAILED) { + shm_unlink(key); + throw std::runtime_error("KittensBroker: Failed to map to shared memory"); + } + + return addr; +} + +__host__ inline void *open_shm(const char *key, size_t size) { + int shm_fd; + while (true) { + shm_fd = shm_open(key, O_RDWR | O_CLOEXEC, 0); + if (shm_fd >= 0) + break; + if (errno != ENOENT) + throw std::runtime_error("KittensBroker: Failed to open shared memory"); + usleep(1); + } + + struct stat shm_st; + do { + if (fstat(shm_fd, &shm_st) != 0) { + shm_unlink(key); + close(shm_fd); + throw std::runtime_error("KittensBroker: Failed to open shared memory stats"); + } + usleep(1); + } while ((size_t)shm_st.st_size < size); + + void *addr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0); + close(shm_fd); + if (addr == MAP_FAILED) { + shm_unlink(key); + throw std::runtime_error("KittensBroker: Failed to map to shared memory"); + } + + return addr; +} + +__host__ inline void unlink_shm(const char *key) { + shm_unlink(key); +} + +__host__ inline void unmap_shm(void *addr, size_t size) { + munmap(addr, size); +} + +__host__ inline int create_socket(const char *key, int local_rank) { + int 
sock_fd; + if ((sock_fd = socket(AF_UNIX, SOCK_DGRAM | SOCK_CLOEXEC, 0)) < 0) + throw std::runtime_error("KittensBroker: Socket creation error"); + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(addr)); + addr.sun_family = AF_UNIX; + + char unique_key[64]; + int n = snprintf(unique_key, sizeof(unique_key), "%s%d", key, local_rank); + if (n < 0 || n >= (int)sizeof(unique_key)) { + close(sock_fd); + throw std::runtime_error("KittensBroker: Socket name too long"); + } + + size_t len = strnlen(unique_key, sizeof(addr.sun_path)); + if (len > (sizeof(addr.sun_path) - 1)) { + close(sock_fd); + throw std::runtime_error("KittensBroker: Socket name too long"); + } + strcpy(addr.sun_path, unique_key); + unlink(unique_key); + + if (bind(sock_fd, (struct sockaddr *)&addr, SUN_LEN(&addr)) < 0) { + close(sock_fd); + throw std::runtime_error("KittensBroker: Failed to bind socket"); + } + + return sock_fd; +} + +__host__ inline void send_fd( + int sock_fd, + int data_fd, + const char *dst_key, + int dst_local_rank, + int src_local_rank +) { + union { + struct cmsghdr cm; + char* control; + } control_un; + + size_t sizeof_control = CMSG_SPACE(sizeof(int)); + control_un.control = reinterpret_cast(malloc(sizeof_control)); + if (!control_un.control) { + close(sock_fd); + close(data_fd); + throw std::runtime_error("KittensBroker: Failed to allocate a control buffer"); + } + + struct msghdr msg {}; + msg.msg_control = control_un.control; + msg.msg_controllen = sizeof_control; + + struct cmsghdr *cmptr = CMSG_FIRSTHDR(&msg); + cmptr->cmsg_len = CMSG_LEN(sizeof(int)); + cmptr->cmsg_level = SOL_SOCKET; + cmptr->cmsg_type = SCM_RIGHTS; + memmove(CMSG_DATA(cmptr), &data_fd, sizeof(data_fd)); + + struct sockaddr_un addr {}; + addr.sun_family = AF_UNIX; + char dst_unique_key[64]; + int n = snprintf(dst_unique_key, sizeof(dst_unique_key), "%s%d", dst_key, dst_local_rank); + if (n < 0 || n >= (int)sizeof(dst_unique_key)) { + free(control_un.control); + close(sock_fd); + close(data_fd); + throw std::runtime_error("KittensBroker: dst path too long"); + } + strcpy(addr.sun_path, dst_unique_key); + msg.msg_name = (void *)&addr; + msg.msg_namelen = sizeof(struct sockaddr_un); + + int payload = src_local_rank; + struct iovec iov[1]; + iov[0].iov_base = &payload; + iov[0].iov_len = sizeof(payload); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + + while (true) { + ssize_t sent = sendmsg(sock_fd, &msg, 0); + if (sent <= 0) { + if (errno == EINTR) continue; + close(sock_fd); + close(data_fd); + free(control_un.control); + throw std::runtime_error("KittensBroker: Failed to send FD over socket"); + } + break; + } + + free(control_un.control); +} + +__host__ inline void recv_fd(int sock_fd, int *data_fd, int *src_local_rank) { + union { + struct cmsghdr cm; + char* control; + } control_un; + + size_t sizeof_control = CMSG_SPACE(sizeof(int)); + control_un.control = reinterpret_cast(malloc(sizeof_control)); + if (!control_un.control) { + close(sock_fd); + throw std::runtime_error("KittensBroker: Failed to allocate a control buffer"); + } + + struct msghdr msg {}; + msg.msg_control = control_un.control; + msg.msg_controllen = sizeof_control; + + int payload = -1; + struct iovec iov[1]; + iov[0].iov_base = &payload; + iov[0].iov_len = sizeof(payload); + msg.msg_iov = iov; + msg.msg_iovlen = 1; + + while (true) { + ssize_t received = recvmsg(sock_fd, &msg, 0); + if (received < 0 && errno == EINTR) { + msg.msg_controllen = sizeof_control; + msg.msg_iovlen = 1; + continue; + } + if (received < static_cast(sizeof(*data_fd))) { + 
free(control_un.control); + close(sock_fd); + throw std::runtime_error("KittensBroker: Failed to receive data over socket"); + } + break; + } + + if (msg.msg_flags & MSG_CTRUNC) { + free(control_un.control); + close(sock_fd); + throw std::runtime_error("KittensBroker: Control data truncated"); + } + + struct cmsghdr *cmptr = CMSG_FIRSTHDR(&msg); + if (!cmptr || + cmptr->cmsg_len != CMSG_LEN(sizeof(int)) || + cmptr->cmsg_level != SOL_SOCKET || + cmptr->cmsg_type != SCM_RIGHTS) { + free(control_un.control); + close(sock_fd); + throw std::runtime_error("KittensBroker: Failed to receive data over socket"); + } + + memmove(data_fd, CMSG_DATA(cmptr), sizeof(*data_fd)); + free(control_un.control); + *src_local_rank = payload; +} + +__host__ inline void unlink_socket(const char *key, int local_rank) { + char unique_key[64]; + int n = snprintf(unique_key, sizeof(unique_key), "%s%d", key, local_rank); + if (n < 0 || n >= (int)sizeof(unique_key)) + throw std::runtime_error("KittensBroker: Socket name too long"); + unlink(unique_key); +} + +__host__ inline void close_socket(int sock_fd) { + close(sock_fd); +} + +} // namespace broker +} // namespace detail + +/** + @brief KittensBroker utility for multiprocess data exchange. + + Note that the code relies on POSIX sockets/shared memory/semaphores for + inter-process communication and synchronization. + + The main functions meant to be used by the user are: + + KittensBroker broker(local_rank, local_world_size); + broker.exchange_data(dst, src, size); // exchange data between all processes + broker.exchange_fds(dst, src_fd); // exchange file descriptors between all processes + broker.broadcast_fd(dst, src_fd, src_rank); // broadcast file descriptor from src_rank to all processes + broker.sync(); // wait until all processes reach here + */ +struct KittensBroker { + // TODO: make unique per process group + static inline constexpr const char *SHM_KEY_ = "/kittens_broker_shm"; + static inline constexpr const char *SOCK_KEY_ = "/tmp/kittens_broker.sock"; + + int local_rank_; + int local_world_size_; + + void *shm_raw_; + volatile detail::broker::KittensVault *shm_; + int sock_; + + __host__ inline KittensBroker(int local_rank, int local_world_size) + : local_rank_(local_rank), + local_world_size_(local_world_size), + shm_raw_(nullptr), + shm_(nullptr), + sock_(-1) { + if (local_rank_ < 0) + throw std::runtime_error("KittensBroker: Local rank must be non-negative"); + if (local_rank_ >= local_world_size_) + throw std::runtime_error("KittensBroker: Local rank is greater than local world size"); + if (local_world_size_ > detail::broker::MAX_LOCAL_WORLD_SIZE) + throw std::runtime_error("KittensBroker: Local world size is greater than MAX_LOCAL_WORLD_SIZE"); + + if (local_rank_ == 0) { + shm_raw_ = detail::broker::create_shm(SHM_KEY_, sizeof(detail::broker::KittensVault)); + shm_ = reinterpret_cast(shm_raw_); + memset(shm_raw_, 0, sizeof(detail::broker::KittensVault)); + } else { + shm_raw_ = detail::broker::open_shm(SHM_KEY_, sizeof(detail::broker::KittensVault)); + shm_ = reinterpret_cast(shm_raw_); + } + detail::broker::init_sync(local_rank_, shm_); + detail::broker::sync(local_world_size_, shm_); + + if (local_rank_ ==0) + detail::broker::unlink_shm(SHM_KEY_); + detail::broker::sync(local_world_size_, shm_); + + sock_ = detail::broker::create_socket(SOCK_KEY_, local_rank_); + detail::broker::sync(local_world_size_, shm_); + } + + KittensBroker(const KittensBroker&) = delete; + KittensBroker& operator=(const KittensBroker&) = delete; + + __host__ inline 
KittensBroker(KittensBroker&& other) noexcept + : local_rank_(other.local_rank_), + local_world_size_(other.local_world_size_), + shm_raw_(other.shm_raw_), + shm_(other.shm_), + sock_(other.sock_) { + other.local_rank_ = -1; + other.local_world_size_ = -1; + other.shm_raw_ = nullptr; + other.shm_ = nullptr; + other.sock_ = -1; + } + + __host__ inline void destroy() { + if (shm_raw_) { + detail::broker::unmap_shm(shm_raw_, sizeof(detail::broker::KittensVault)); + shm_raw_ = nullptr; + shm_ = nullptr; + } + if (sock_ >= 0) { + detail::broker::unlink_socket(SOCK_KEY_, local_rank_); + detail::broker::close_socket(sock_); + sock_ = -1; + } + local_rank_ = -1; + local_world_size_ = -1; + } + + __host__ inline KittensBroker& operator=(KittensBroker&& other) noexcept { + if (this != &other) { + destroy(); + local_rank_ = other.local_rank_; + local_world_size_ = other.local_world_size_; + shm_raw_ = other.shm_raw_; + shm_ = other.shm_; + sock_ = other.sock_; + other.local_rank_ = -1; + other.local_world_size_ = -1; + other.shm_raw_ = nullptr; + other.shm_ = nullptr; + other.sock_ = -1; + } + return *this; + } + + __host__ inline ~KittensBroker() { + destroy(); + } + + __host__ inline void sync(int num_ranks = -1) { + if (num_ranks == -1) + num_ranks = local_world_size_; + else if (num_ranks < 0 || num_ranks > local_world_size_) + throw std::runtime_error("KittensBroker: Invalid number of ranks"); + + detail::broker::sync(num_ranks, shm_); + } + + __host__ inline void exchange_data(void *dst_, const void *src_, size_t size) { + if (size > detail::broker::VAULT_SIZE_PER_RANK) + throw std::runtime_error("KittensBroker: Size is greater than VAULT_SIZE_PER_RANK"); + + uint8_t *dst = reinterpret_cast(dst_); + const uint8_t *src = reinterpret_cast(src_); + + // Exchange data + sync(); // ensure all processes enter together + memcpy(const_cast(shm_->data) + local_rank_ * detail::broker::VAULT_SIZE_PER_RANK, src, size); + sync(); // ensure all processes exit together + + // Pack and copy back to destination + for (int i = 0; i < local_world_size_; i++) + memcpy(dst + i * size, const_cast(shm_->data) + i * detail::broker::VAULT_SIZE_PER_RANK, size); + } + + __host__ inline void exchange_fds(int *dst, const int data_fd) { + if (dst == nullptr) + throw std::runtime_error("KittensBroker: dst is null"); + if (data_fd < 0) + throw std::runtime_error("KittensBroker: source fd is negative"); + + // Initialize dst buffer + for (int i = 0; i < local_world_size_; ++i) + dst[i] = -1; + + // Ensure all processes enter together + sync(); + + if (local_rank_ == 0) { + // Rank 0 receives all FDs from and distributes them to other ranks + dst[0] = data_fd; + for (int i = 0; i < local_world_size_ - 1; i++) { + int received_fd; + int src_local_rank; + detail::broker::recv_fd(sock_, &received_fd, &src_local_rank); + if (received_fd < 0) + throw std::runtime_error("KittensBroker: Failed to receive FD over socket"); + if (src_local_rank == local_rank_) + throw std::runtime_error("KittensBroker: Invalid source rank"); + dst[src_local_rank] = received_fd; + } + for (int dst_local_rank = 1; dst_local_rank < local_world_size_; dst_local_rank++) { + for (int src_local_rank = 0; src_local_rank < local_world_size_; src_local_rank++) { + if (dst_local_rank == src_local_rank) + continue; + detail::broker::send_fd(sock_, dst[src_local_rank], SOCK_KEY_, dst_local_rank, src_local_rank); + } + } + close(dst[0]); // no longer needed + dst[0] = -1; + } else { + // The rest sends its FD to and receives the other FDs from rank 0 + 
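+            // Note: descriptors passed via SCM_RIGHTS are duplicated into the receiving process by
+            // the kernel, so closing the local data_fd right after send_fd() below does not
+            // invalidate the copies the other ranks receive.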
detail::broker::send_fd(sock_, data_fd, SOCK_KEY_, 0, local_rank_); + close(data_fd); // no longer needed + for (int i = 0; i < local_world_size_ - 1; i++) { + int received_fd; + int src_local_rank; + detail::broker::recv_fd(sock_, &received_fd, &src_local_rank); + if (received_fd < 0) + throw std::runtime_error("KittensBroker: Failed to receive FD over socket"); + if (src_local_rank == local_rank_) + throw std::runtime_error("KittensBroker: Invalid source rank"); + dst[src_local_rank] = received_fd; + } + } + + // Ensure all processes exit together + sync(); + } + + __host__ inline void broadcast_fd(int *dst, const int data_fd, const int src_local_rank) { + if (src_local_rank < 0 || src_local_rank >= local_world_size_) + throw std::runtime_error("KittensBroker: Invalid source rank"); + + // Ensure all processes enter together + sync(); + + if (local_rank_ == src_local_rank) { + if (data_fd < 0) + throw std::runtime_error("KittensBroker: Source rank has invalid FD"); + for (int dst_local_rank = 0; dst_local_rank < local_world_size_; dst_local_rank++) { + if (dst_local_rank == src_local_rank) + continue; + detail::broker::send_fd(sock_, data_fd, SOCK_KEY_, dst_local_rank, src_local_rank); + } + close(data_fd); // no longer needed + } else { + if (!dst) + throw std::runtime_error("KittensBroker: Destination rank has invalid buffer"); + int _src_local_rank; + detail::broker::recv_fd(sock_, dst, &_src_local_rank); + if (*dst < 0) + throw std::runtime_error("KittensBroker: Failed to receive valid FD over socket"); + if (_src_local_rank != src_local_rank) + throw std::runtime_error("KittensBroker: Invalid source rank"); + } + + // Ensure all processes exit together + sync(); + } +}; + +} // namespace kittens diff --git a/extra/thunder/cuda/include/pyutils/club.cuh b/extra/thunder/cuda/include/pyutils/club.cuh new file mode 100644 index 0000000000..9d5580fca9 --- /dev/null +++ b/extra/thunder/cuda/include/pyutils/club.cuh @@ -0,0 +1,122 @@ +#include +#include +#include +#include +#include + +/* + CUDA-specific ThreadPool + + Example usage + + // Construction + KittensClub club(device_ids, NUM_DEVICES); + + // Dispatch work to all threads (no need to set device) + club.execute([&](int dev_idx) { + int dev; + CUDACHECK(cudaGetDevice(&dev)); + if (dev != dev_idx) { + fprintf(stderr, "Device mismatch: expected %d, got %d\n", dev_idx, dev); + exit(1); + } + }); +*/ +class KittensClub { +public: + __host__ inline KittensClub(const int *device_ids, const int num_devices); + __host__ inline KittensClub(const int *device_ids, const cudaStream_t *streams, const int num_devices); + __host__ inline ~KittensClub(); + + // Dispatches `task` to all threads, and waits for all threads to finish (using cv) + __host__ inline void execute(std::function task); + +private: + // Condition indicators + bool stop; + std::vector task_available; + int n_task_done; + + // Threadpool + std::vector workers; + + // Streams for each device + std::vector streams; + + // Main entry point for each thread + __host__ inline void worker(int worker_id, int device_id); + + // Used to dispatch work to all threads + std::function current_task; + + // Synchronization + std::mutex mutex; + std::condition_variable cond_task_available; + std::condition_variable cond_task_done; +}; + +__host__ inline KittensClub::KittensClub(const int *device_ids, const int num_devices) : stop(false), n_task_done(0) { + for (size_t dev_idx = 0; dev_idx < num_devices; ++dev_idx) { + task_available.push_back(false); + streams.push_back(0); // Use default stream 
(null stream) + workers.emplace_back([this, dev_idx, device_ids] { worker(dev_idx, device_ids[dev_idx]); }); + } +} + +__host__ inline KittensClub::KittensClub(const int *device_ids, const cudaStream_t *streams_in, const int num_devices) : stop(false), n_task_done(0) { + for (size_t dev_idx = 0; dev_idx < num_devices; ++dev_idx) { + task_available.push_back(false); + streams.push_back(streams_in[dev_idx]); + workers.emplace_back([this, dev_idx, device_ids] { worker(dev_idx, device_ids[dev_idx]); }); + } +} + +__host__ inline KittensClub::~KittensClub() { + { + std::lock_guard lock(mutex); + stop = true; + } + cond_task_available.notify_all(); + for (std::thread &worker : workers) { + worker.join(); + } +} + +__host__ inline void KittensClub::execute(std::function task) { + { + std::lock_guard lock(mutex); + current_task = task; + for (size_t i = 0; i < task_available.size(); ++i) + task_available[i] = true; + } + cond_task_available.notify_all(); + { + std::unique_lock lock(mutex); + cond_task_done.wait(lock, [this] { return n_task_done == workers.size(); }); + n_task_done = 0; + } +} + +__host__ inline void KittensClub::worker(int worker_id, int device_id) { + cudaSetDevice(device_id); // done once and never again! This saves a LOT of time + while (true) { + std::function task; + { + std::unique_lock lock(mutex); + cond_task_available.wait(lock, [this, worker_id] { return stop || task_available[worker_id]; }); + + if (stop) + return; + + task = current_task; + task_available[worker_id] = false; + } + task(worker_id, streams[worker_id]); + { + std::lock_guard lock(mutex); // adds about 10 microseconds overhead + ++n_task_done; + if (n_task_done == workers.size()) + cond_task_done.notify_one(); + } + } +} diff --git a/extra/thunder/cuda/include/pyutils/parallel_tensor.cuh b/extra/thunder/cuda/include/pyutils/parallel_tensor.cuh new file mode 100644 index 0000000000..ecbcca5c60 --- /dev/null +++ b/extra/thunder/cuda/include/pyutils/parallel_tensor.cuh @@ -0,0 +1,336 @@ +#pragma once + +#include +#include +#include + +#include +#include +#include + +#include "../types/device/vmm.cuh" +#include "../types/device/ipc.cuh" +#include "broker.cuh" + +namespace kittens { +namespace py { + +/** + * @brief Distributed tensor wrapper for multi-GPU IPC sharing and multicast. + * Can be later used for easy PGL creation right before a kernel call. + * Meant to be used as a single object per thread/process. 
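+ *
+ * Rough usage sketch (the PGL type name and the rank variables below are placeholders; see
+ * parallel_tensor_to_pgl in torchutils.cuh from this patch):
+ *
+ *   // allocate a shareable buffer on this rank and, optionally, a multicast mapping
+ *   kittens::py::TKParallelTensor t({1024, 1024}, at::kBFloat16, local_rank, local_world_size, true);
+ *   // right before a kernel call, convert it into a PGL
+ *   auto pgl = kittens::py::parallel_tensor_to_pgl<MyPGL>(t);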
+ */ +struct TKParallelTensor { + inline static std::map, KittensBroker> brokers_; // lazily initialized + + at::Tensor data_; // for direct access from PyTorch + std::vector shape_; + at::ScalarType dtype_; + + std::vector raw_ptrs_; + size_t allocated_size_; + + int local_rank_; // identical to device index + int local_world_size_; + + bool multicast_; + void *multicast_ptr_; + size_t multicast_allocated_size_; + + detail::ipc::flavor ipc_flavor_; + + __host__ inline TKParallelTensor( + const at::Tensor &tensor, + int local_rank, + int local_world_size, + bool multicast + ) : data_(tensor), + shape_(tensor.sizes().vec()), + dtype_(tensor.scalar_type()), + raw_ptrs_(local_world_size, nullptr), + allocated_size_(tensor.nbytes()), + local_rank_(local_rank), + local_world_size_(local_world_size), + multicast_(multicast), + multicast_ptr_(nullptr), + multicast_allocated_size_(0), + ipc_flavor_(detail::ipc::flavor::LEGACY) { + + TORCH_CHECK(tensor.is_cuda(), "Tensor must be on CUDA device"); + TORCH_CHECK(tensor.is_contiguous(), "Tensor must be contiguous"); + TORCH_CHECK(tensor.dim() <= 4, "Only tensors with dim <= 4 are supported for TKParallelTensor"); + TORCH_CHECK(tensor.device().index() == local_rank_, "Tensor device index must match local_rank"); + TORCH_CHECK(local_rank_ >= 0, "local_rank must be non-negative"); + TORCH_CHECK(local_rank_ < local_world_size_, "local_rank must be less than local_world_size"); + TORCH_CHECK(!multicast, "Multicast is not supported for pre-allocated tensors"); + + brokers_.try_emplace( + {local_rank_, local_world_size_}, + local_rank_, local_world_size_ + ); + + if (brokers_.size() > 1) + std::cerr << "WARNING: 2 KittensBroker instances created in the same process. This is not safe." << std::endl; + + c10::cuda::CUDAGuard device_guard(local_rank_); + exchange_ipc_handles(); + } + + __host__ inline TKParallelTensor( + const std::vector &shape, + const at::ScalarType dtype, + int local_rank, + int local_world_size, + bool multicast + ) : shape_(shape), + dtype_(dtype), + raw_ptrs_(local_world_size, nullptr), + allocated_size_(0), + local_rank_(local_rank), + local_world_size_(local_world_size), + multicast_(multicast), + multicast_ptr_(nullptr), + multicast_allocated_size_(0), + ipc_flavor_(detail::ipc::flavor::VMM) { + + TORCH_CHECK(local_rank_ >= 0, "local_rank must be non-negative"); + TORCH_CHECK(local_rank_ < local_world_size_, "local_rank must be less than local_world_size"); + + brokers_.try_emplace( + {local_rank_, local_world_size_}, + local_rank_, local_world_size_ + ); + + if (brokers_.size() > 1) + std::cerr << "WARNING: 2 KittensBroker instances created in the same process. This is not safe." 
<< std::endl; + + c10::cuda::CUDAGuard device_guard(local_rank_); + create_shareable_cuda_tensor(); + exchange_ipc_handles(); + + if (multicast_) + initialize_multicast(); + } + + TKParallelTensor(const TKParallelTensor&) = delete; + TKParallelTensor& operator=(const TKParallelTensor&) = delete; + TKParallelTensor& operator=(TKParallelTensor&& other) = delete; + + __host__ inline TKParallelTensor(TKParallelTensor&& other) : + data_(std::move(other.data_)), + shape_(std::move(other.shape_)), + dtype_(std::move(other.dtype_)), + raw_ptrs_(std::move(other.raw_ptrs_)), + allocated_size_(other.allocated_size_), + local_rank_(other.local_rank_), + local_world_size_(other.local_world_size_), + multicast_(other.multicast_), + multicast_ptr_(other.multicast_ptr_), + multicast_allocated_size_(other.multicast_allocated_size_), + ipc_flavor_(other.ipc_flavor_) { + other.data_ = at::Tensor(); + other.shape_.clear(); + other.dtype_ = at::ScalarType::Undefined; + other.raw_ptrs_.clear(); + other.allocated_size_ = 0; + other.local_rank_ = -1; + other.local_world_size_ = -1; + other.multicast_ = false; + other.multicast_ptr_ = nullptr; + other.multicast_allocated_size_ = 0; + } + + __host__ inline ~TKParallelTensor() { + destroy(); + } + + __host__ inline at::Tensor data() const { + return data_; + } + + __host__ inline void create_shareable_cuda_tensor() { + c10::cuda::CUDAGuard device_guard(local_rank_); + + TORCH_CHECK(!shape_.empty(), "Shape must be non-empty"); + TORCH_CHECK(shape_.size() <= 4, "Shape must have at most 4 dimensions for TKParallelTensor"); + size_t size = c10::elementSize(dtype_); + for (auto dim : shape_) { + TORCH_CHECK(dim > 0, "Size dimensions must be positive"); + size *= static_cast(dim); + } + + void *raw_ptr; + detail::vmm::vm_alloc_map_set_access( + &raw_ptr, &allocated_size_, size, local_rank_, local_world_size_); + + // Create local copies for capture + int local_rank = local_rank_; + size_t allocated_size = allocated_size_; + + auto deleter = [local_rank, raw_ptr, allocated_size](void* p) mutable { + if (!p) return; + c10::cuda::CUDAGuard device_guard(local_rank); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + CUDACHECK(cudaStreamSynchronize(stream)); + detail::vmm::vm_unmap(raw_ptr, allocated_size); + }; + + at::TensorOptions options = at::TensorOptions() + .dtype(dtype_) + .device(at::kCUDA, local_rank_); + + data_ = at::from_blob(raw_ptr, shape_, std::move(deleter), options); + } + + template + __host__ inline void exchange_ipc_handles() { + using handle_t = detail::ipc::handle; + + // Get IPC handle + detail::ipc::check_support(local_rank_); + void *raw_ptr = reinterpret_cast(data_.data_ptr()); + handle_t ipc_handle; + detail::ipc::export_handle(&ipc_handle, raw_ptr); + + // Exchange IPC handles + std::vector all_ipc_handles(local_world_size_); + if constexpr (IPC_FLAVOR == detail::ipc::flavor::LEGACY) { + brokers_.at({local_rank_, local_world_size_}).exchange_data( + reinterpret_cast(all_ipc_handles.data()), + reinterpret_cast(&ipc_handle), + sizeof(handle_t) + ); + } else if constexpr (IPC_FLAVOR == detail::ipc::flavor::VMM) { + brokers_.at({local_rank_, local_world_size_}).exchange_fds( + reinterpret_cast(all_ipc_handles.data()), + ipc_handle.handle_ + ); + } else { + throw std::runtime_error("Invalid IPC flavor"); + } + + // Import IPC handles + for (int i = 0; i < local_world_size_; i++) { + if (i == local_rank_) + raw_ptrs_[i] = raw_ptr; + else + detail::ipc::import_handle(&raw_ptrs_[i], all_ipc_handles[i], allocated_size_, local_world_size_); + } + 
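+
+        // At this point raw_ptrs_ holds one device pointer per local rank: this process's own
+        // allocation plus every peer's allocation mapped into this address space via IPC.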
} + + __host__ inline void initialize_multicast() { + using handle_t = detail::ipc::handle; + + detail::vmm::multicast_check(local_rank_); + detail::ipc::check_support(local_rank_); + detail::vmm::handle multicast_handle; + + if (local_rank_ == 0) { + // Create multicast handle; only a single rank should create MC handle + detail::vmm::multicast_create_handle( + &multicast_handle, + &multicast_allocated_size_, + allocated_size_, + local_world_size_ + ); + + // Currently, non-rank-0 path assumes allocated_size_ == multicast_allocated_size_ + if (allocated_size_ != multicast_allocated_size_) + throw std::runtime_error("Multicast allocated size does not match memory allocated size"); + + // Get IPC handle + handle_t ipc_handle; + detail::ipc::export_handle(&ipc_handle, multicast_handle); + + // Broadcast the IPC multicast handle + brokers_.at({local_rank_, local_world_size_}).broadcast_fd(nullptr, ipc_handle.handle_, 0); + } else { + // Receive the IPC multicast handle from rank 0 + handle_t ipc_handle; + brokers_.at({local_rank_, local_world_size_}).broadcast_fd(&ipc_handle.handle_, -1, 0); + multicast_allocated_size_ = allocated_size_; + detail::ipc::import_handle(&multicast_handle, ipc_handle, multicast_allocated_size_, local_world_size_); + } + + // Add all devices to the MC handle. Must sync + detail::vmm::multicast_bind_device(multicast_handle, local_rank_); + brokers_.at({local_rank_, local_world_size_}).sync(); // must ensure all devices are added + + // Bind all memory to the MC handle and map to a virtual address; must be done after adding all devices + detail::vmm::handle memory_handle; + detail::vmm::vm_retrieve_handle(&memory_handle, raw_ptrs_[local_rank_]); + detail::vmm::multicast_bind_memory(multicast_handle, memory_handle, allocated_size_); + brokers_.at({local_rank_, local_world_size_}).sync(); + + // Map virtual address to multicast handle and set access; must be done after adding all devices + detail::vmm::vm_map(&multicast_ptr_, multicast_handle, multicast_allocated_size_); + detail::vmm::vm_set_access(multicast_ptr_, multicast_allocated_size_, local_world_size_); + + // Free the handles immediately + detail::vmm::vm_free(multicast_handle); + detail::vmm::vm_free(memory_handle); + } + + __host__ inline void destroy() { + // 1. Multicast cleanup + if (multicast_ && multicast_ptr_) { + brokers_.at({local_rank_, local_world_size_}).sync(); + detail::vmm::handle multicast_handle; + detail::vmm::vm_retrieve_handle(&multicast_handle, multicast_ptr_); + detail::vmm::vm_unmap(multicast_ptr_, multicast_allocated_size_); + detail::vmm::multicast_unbind_device(multicast_handle, multicast_allocated_size_, local_rank_); + brokers_.at({local_rank_, local_world_size_}).sync(); + detail::vmm::vm_free(multicast_handle); + } + + // 2. Imported handle cleanup + for (int i = 0; i < local_world_size_; i++) { + if (i != local_rank_ && i < raw_ptrs_.size()) { + if (ipc_flavor_ == detail::ipc::flavor::LEGACY) { + detail::ipc::free_handle(raw_ptrs_[i], allocated_size_); + } else if (ipc_flavor_ == detail::ipc::flavor::VMM) { + detail::ipc::free_handle(raw_ptrs_[i], allocated_size_); + } else { + throw std::runtime_error("Invalid IPC flavor"); + } + } + } + brokers_.at({local_rank_, local_world_size_}).sync(); // must sync before destroying the tensor + + // 3. Tensor cleanup + if (data_.defined()) + data_.reset(); // properly decreases the ref count + + // 4. 
Member variables cleanup + shape_.clear(); + dtype_ = at::ScalarType::Undefined; + raw_ptrs_.clear(); + allocated_size_ = 0; + local_rank_ = -1; + local_world_size_ = -1; + multicast_ = false; + multicast_ptr_ = nullptr; + multicast_allocated_size_ = 0; + } +}; + +} // namespace py +} // namespace kittens + +#define BIND_TK_PARALLEL_TENSOR(m) \ + pybind11::class_(m, "TKParallelTensor") \ + .def(pybind11::init(), \ + pybind11::arg("tensor"), \ + pybind11::arg("local_rank"), \ + pybind11::arg("local_world_size"), \ + pybind11::arg("multicast") = false) \ + .def(pybind11::init&, const at::ScalarType&, int, int, bool>(), \ + pybind11::arg("shape"), \ + pybind11::arg("dtype"), \ + pybind11::arg("local_rank"), \ + pybind11::arg("local_world_size"), \ + pybind11::arg("multicast") = false) \ + .def("data", &kittens::py::TKParallelTensor::data) \ + .def_readonly("data_", &kittens::py::TKParallelTensor::data_) \ + .def_readonly("local_rank_", &kittens::py::TKParallelTensor::local_rank_) \ + .def_readonly("local_world_size_", &kittens::py::TKParallelTensor::local_world_size_) diff --git a/extra/thunder/cuda/include/pyutils/pyutils.cuh b/extra/thunder/cuda/include/pyutils/pyutils.cuh new file mode 100644 index 0000000000..0f3101927d --- /dev/null +++ b/extra/thunder/cuda/include/pyutils/pyutils.cuh @@ -0,0 +1,235 @@ +#pragma once + +#include "util.cuh" +#include +#include // for automatic Python list -> std::vector conversion + +namespace kittens { +namespace py { + +template struct from_object { + static T make(pybind11::object obj) { + return obj.cast(); + } + static T unwrap(pybind11::object obj, int dev_idx) { + return make(obj); // Scalars should be passed in as a scalar + } +}; +template struct from_object { + static GL make(pybind11::object obj) { + // Check if argument is a torch.Tensor + if (pybind11::hasattr(obj, "__class__") && + obj.attr("__class__").attr("__name__").cast() == "Tensor") { + + // Check if tensor is contiguous + if (!obj.attr("is_contiguous")().cast()) { + throw std::runtime_error("Tensor must be contiguous"); + } + if (obj.attr("device").attr("type").cast() == "cpu") { + throw std::runtime_error("Tensor must be on CUDA device"); + } + + // Get shape, pad with 1s if needed + std::array shape = {1, 1, 1, 1}; + auto py_shape = obj.attr("shape").cast(); + size_t dims = py_shape.size(); + if (dims > 4) { + throw std::runtime_error("Expected Tensor.ndim <= 4"); + } + for (size_t i = 0; i < dims; ++i) { + shape[4 - dims + i] = pybind11::cast(py_shape[i]); + } + + // Get data pointer using data_ptr() + uint64_t data_ptr = obj.attr("data_ptr")().cast(); + + // Create GL object using make_gl + return make_gl(data_ptr, shape[0], shape[1], shape[2], shape[3]); + } + throw std::runtime_error("Expected a torch.Tensor"); + } + static GL unwrap(pybind11::object obj, int dev_idx) { + if (!pybind11::isinstance(obj)) + throw std::runtime_error("GL unwrap expected a Python list."); + pybind11::list lst = pybind11::cast(obj); + if (dev_idx >= lst.size()) + throw std::runtime_error("Device index out of bounds."); + return *lst[dev_idx].cast>(); + } +}; +template struct from_object { + static PGL make(pybind11::object obj) { + static_assert(!PGL::MULTICAST, "Multicast not yet supported on pyutils. 
Please initialize the multicast pointer manually."); + if (!pybind11::isinstance(obj)) + throw std::runtime_error("PGL from_object expected a Python list."); + pybind11::list tensors = pybind11::cast(obj); + if (tensors.size() != PGL::num_devices) + throw std::runtime_error("Expected a list of " + std::to_string(PGL::num_devices) + " tensors"); + std::array shape = {1, 1, 1, 1}; + uint64_t data_ptrs[PGL::num_devices]; + for (int i = 0; i < PGL::num_devices; i++) { + auto tensor = tensors[i]; + if (!pybind11::hasattr(tensor, "__class__") || + tensor.attr("__class__").attr("__name__").cast() != "Tensor") + throw std::runtime_error("Expected a list of torch.Tensor"); + if (!tensor.attr("is_contiguous")().cast()) + throw std::runtime_error("Tensor must be contiguous"); + if (tensor.attr("device").attr("type").cast() == "cpu") + throw std::runtime_error("Tensor must be on CUDA device"); + auto py_shape = tensor.attr("shape").cast(); + size_t dims = py_shape.size(); + if (dims > 4) + throw std::runtime_error("Expected Tensor.ndim <= 4"); + for (size_t j = 0; j < dims; ++j) { + if (i == 0) + shape[4 - dims + j] = pybind11::cast(py_shape[j]); + else if (shape[4 - dims + j] != pybind11::cast(py_shape[j])) + throw std::runtime_error("All tensors must have the same shape"); + } + data_ptrs[i] = tensor.attr("data_ptr")().cast(); + } + return make_pgl(data_ptrs, shape[0], shape[1], shape[2], shape[3]); + } + static PGL unwrap(pybind11::object obj, int dev_idx) { + return *obj.cast>(); + } +}; + +static std::unordered_set registered; +template static void register_pyclass(pybind11::module &m) { + if constexpr (ducks::gl::all || ducks::pgl::all) { + std::string _typename = typeid(T).name(); + if (registered.find(_typename) == registered.end()) { + pybind11::class_>(m, _typename.c_str()); + registered.insert(_typename); + } + } +} +template static pybind11::object multigpu_make(pybind11::object obj) { + if constexpr (ducks::gl::all) { + if (!pybind11::isinstance(obj)) + throw std::runtime_error("multigpu_make [GL] expected a Python list."); + pybind11::list lst = pybind11::cast(obj); + std::vector> gls; + for (int i = 0; i < lst.size(); i++) + gls.push_back(std::make_shared(from_object::make(lst[i]))); + return pybind11::cast(gls); + } else if constexpr (ducks::pgl::all) { + return pybind11::cast(std::make_shared(from_object::make(obj))); + } else { + return pybind11::cast(from_object::make(obj)); + } +} + +template concept has_dynamic_shared_memory = requires(T t) { { t.dynamic_shared_memory() } -> std::convertible_to; }; +template concept is_multigpu_globals = requires { + { T::num_devices } -> std::convertible_to; + { T::dev_idx } -> std::convertible_to; +} && T::num_devices >= 1; + +template struct trait; +template struct trait { using member_type = MT; using type = T; }; +template using object = pybind11::object; +template static void bind_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) { + m.def(name, [](object... 
args, pybind11::kwargs kwargs) { + TGlobal __g__ {from_object::member_type>::make(args)...}; + cudaStream_t raw_stream = nullptr; + if (kwargs.contains("stream")) { + // Extract stream pointer + uintptr_t stream_ptr = kwargs["stream"].attr("cuda_stream").cast(); + raw_stream = reinterpret_cast(stream_ptr); + } + if constexpr (has_dynamic_shared_memory) { + int __dynamic_shared_memory__ = (int)__g__.dynamic_shared_memory(); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__); + kernel<<<__g__.grid(), __g__.block(), __dynamic_shared_memory__, raw_stream>>>(__g__); + } else { + kernel<<<__g__.grid(), __g__.block(), 0, raw_stream>>>(__g__); + } + }); +} +template static void bind_function(auto m, auto name, auto TGlobal::*... member_ptrs) { + m.def(name, [](object... args) { + TGlobal __g__ {from_object::member_type>::make(args)...}; + function(__g__); + }); +} +static void bind_multigpu_boilerplate(auto m) { + m.def("enable_all_p2p_access", [](const std::vector& device_ids) { + int device_count; + CUDACHECK(cudaGetDeviceCount(&device_count)); + if (device_count < device_ids.size()) + throw std::runtime_error("Not enough CUDA devices available"); + for (int i = 0; i < device_ids.size(); i++) { + CUDACHECK(cudaSetDevice(device_ids[i])); + for (int j = 0; j < device_ids.size(); j++) { + if (i == j) continue; + int can_access = 0; + CUDACHECK(cudaDeviceCanAccessPeer(&can_access, device_ids[i], device_ids[j])); + if (!can_access) + throw std::runtime_error("Device " + std::to_string(device_ids[i]) + " cannot access device " + std::to_string(device_ids[j])); + cudaError_t res = cudaDeviceEnablePeerAccess(device_ids[j], 0); + if (res != cudaSuccess && res != cudaErrorPeerAccessAlreadyEnabled) { + CUDACHECK(res); + } + } + } + }); + pybind11::class_>(m, "KittensClub") + .def(pybind11::init([](const std::vector& device_ids) { + int device_count; + CUDACHECK(cudaGetDeviceCount(&device_count)); + if (device_count < device_ids.size()) + throw std::runtime_error("Not enough CUDA devices available"); + auto club = std::make_shared(device_ids.data(), device_ids.size()); + club->execute([&](int dev_idx, cudaStream_t stream) {}); // warmup + return club; + }), pybind11::arg("device_ids")) + .def(pybind11::init([](const std::vector& device_ids, const std::vector& streams) { + int device_count; + CUDACHECK(cudaGetDeviceCount(&device_count)); + if (device_count < device_ids.size()) + throw std::runtime_error("Not enough CUDA devices available"); + if (streams.size() != device_ids.size()) + throw std::runtime_error("Number of streams must match number of devices"); + + std::vector raw_streams(streams.size()); + for (size_t i = 0; i < streams.size(); ++i) { + uintptr_t stream_ptr = streams[i].attr("cuda_stream").cast(); + raw_streams[i] = reinterpret_cast(stream_ptr); + } + + auto club = std::make_shared(device_ids.data(), raw_streams.data(), device_ids.size()); + club->execute([&](int dev_idx, cudaStream_t stream) {}); // warmup + return club; + }), pybind11::arg("device_ids"), pybind11::arg("streams")); +} +template static void bind_multigpu_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) { + static_assert(is_multigpu_globals, "Multigpu globals must have a member num_devices >= 1 and dev_idx"); + (register_pyclass::member_type>(m), ...); + m.def((std::string("make_globals_")+name).c_str(), [](object... args) -> std::vector { + return {multigpu_make::member_type>(args)...}; + }); + m.def(name, [](std::shared_ptr club, object... 
args) { + std::vector __g__; + for (int i = 0; i < TGlobal::num_devices; i++) { + __g__.emplace_back(from_object::member_type>::unwrap(args, i)...); + __g__.back().dev_idx = i; + } + if constexpr (has_dynamic_shared_memory) { + club->execute([&](int dev_idx, cudaStream_t stream) { + int __dynamic_shared_memory__ = (int)__g__[dev_idx].dynamic_shared_memory(); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__); + kernel<<<__g__[dev_idx].grid(), __g__[dev_idx].block(), __dynamic_shared_memory__, stream>>>(__g__[dev_idx]); + }); + } else { + club->execute([&](int dev_idx, cudaStream_t stream) { + kernel<<<__g__[dev_idx].grid(), __g__[dev_idx].block(), 0, stream>>>(__g__[dev_idx]); + }); + } + }); + // TODO: PGL destructor binding +} + +} // namespace py +} // namespace kittens diff --git a/extra/thunder/cuda/include/pyutils/torch_helpers.cuh b/extra/thunder/cuda/include/pyutils/torch_helpers.cuh new file mode 100644 index 0000000000..4b0f6b34d2 --- /dev/null +++ b/extra/thunder/cuda/include/pyutils/torch_helpers.cuh @@ -0,0 +1,7 @@ +#pragma once + +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) \ No newline at end of file diff --git a/extra/thunder/cuda/include/pyutils/torchutils.cuh b/extra/thunder/cuda/include/pyutils/torchutils.cuh new file mode 100644 index 0000000000..e2c4ca299c --- /dev/null +++ b/extra/thunder/cuda/include/pyutils/torchutils.cuh @@ -0,0 +1,180 @@ +#pragma once + +#include +#include + +#include "kittens.cuh" +#include "parallel_tensor.cuh" + +namespace kittens { +namespace py { + +template +concept has_min_blocks_per_sm = requires { std::integral_constant{}; }; + +template +consteval int min_blocks_per_sm() { + if constexpr(has_min_blocks_per_sm) + return Config::MIN_BLOCKS_PER_SM; + else + return 1; +} + +template +__global__ +__launch_bounds__(Config::NUM_THREADS, min_blocks_per_sm()) +void global_kernel_unclustered(const __grid_constant__ Globals G) { + Kernel(G); +} + +template +__global__ +__launch_bounds__(Config::NUM_THREADS, min_blocks_per_sm()) +__cluster_dims__(Config::CLUSTER_SIZE) +void global_kernel_clustered(const __grid_constant__ Globals G) { + Kernel(G); +} + +template +static inline void tensor_check(const at::Tensor &t) { + TORCH_CHECK(t.is_cuda(), "Tensor must be on CUDA device") + TORCH_CHECK(t.is_contiguous(), "Tensor must be contiguous") + TORCH_CHECK(t.dim() <= 4, "Expected Tensor.dim() <= 4"); + + if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Char, "Tensor has invalid dtype (expected int8)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Short, "Tensor has invalid dtype (expected int16)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Int, "Tensor has invalid dtype (expected int32)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Long, "Tensor has invalid dtype (expected int64)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Float8_e4m3fn, "Tensor has invalid dtype (expected fp8e4m3)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Float8_e5m2, "Tensor has invalid dtype (expected fp8e5m2)"); +#ifdef KITTENS_BLACKWELL + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == 
at::ScalarType::Byte, "Tensor has invalid dtype (expected fp8e8m0 represented as uint8)"); +#endif + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::BFloat16, "Tensor has invalid dtype (expected bfloat16)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Half, "Tensor has invalid dtype (expected float16)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Float, "Tensor has invalid dtype (expected float32)"); + } else if constexpr (std::is_same_v) { + TORCH_CHECK(t.dtype() == at::ScalarType::Double, "Tensor has invalid dtype (expected float64)"); + } else { + TORCH_CHECK(false, "Unsupported dtype"); + } +} + +template +static inline void parallel_tensor_check(const TKParallelTensor& t) { + tensor_check(t.data_); + TORCH_CHECK(t.data_.sizes().vec() == t.shape_, "Shape mismatch between TKParallelTensor and the underlying tensor"); + TORCH_CHECK(t.data_.dtype() == t.dtype_, "Dtype mismatch between TKParallelTensor and the underlying tensor"); + TORCH_CHECK(t.raw_ptrs_.size() == PGL::num_devices, "Number of devices mismatch between PGL and TKParallelTensor"); + TORCH_CHECK(t.local_rank_ == t.data_.device().index(), "Current tensor device index mismatch within TKParallelTensor"); + TORCH_CHECK(t.local_world_size_ == PGL::num_devices, "Number of devices mismatch between PGL and TKParallelTensor"); + TORCH_CHECK(t.multicast_ == PGL::multicast, "Multicast mismatch between PGL and TKParallelTensor"); + TORCH_CHECK(t.raw_ptrs_[t.local_rank_] == reinterpret_cast(t.data_.data_ptr()), "Current tensor data pointer not found in TKParallelTensor's raw_ptrs_"); +} + +template +static inline GL tensor_to_gl(const at::Tensor &t) { + tensor_check(t); + + std::array shape = {1, 1, 1, 1}; + for (int i = 0; i < static_cast(t.dim()); ++i) + shape[4 - t.dim() + i] = static_cast(t.size(i)); + + uint64_t data_ptr = reinterpret_cast(t.data_ptr()); + + return ::kittens::make_gl(data_ptr, shape[0], shape[1], shape[2], shape[3]); +} + +template +static inline PGL parallel_tensor_to_pgl(TKParallelTensor &t) { + parallel_tensor_check(t); + + std::array shape = {1, 1, 1, 1}; + for (int i = 0; i < static_cast(t.data_.dim()); ++i) { + shape[4 - t.data_.dim() + i] = static_cast(t.data_.size(i)); + } + + if constexpr (PGL::multicast) + return ::kittens::make_pgl( + reinterpret_cast(t.multicast_ptr_), reinterpret_cast(t.raw_ptrs_.data()), shape[0], shape[1], shape[2], shape[3]); + else + return ::kittens::make_pgl( + reinterpret_cast(t.raw_ptrs_.data()), shape[0], shape[1], shape[2], shape[3]); +} + +template +static inline GL make_fake_gl(const int batch, const int depth, const int rows, const int cols) { + return ::kittens::make_gl(reinterpret_cast(nullptr), batch, depth, rows, cols); +} + +static inline void _device_check(const at::Tensor& first, const at::Tensor& second) { + TORCH_CHECK(first.device() == second.device(), "All tensors must be on the same device"); +} + +template +static inline void device_check(const T1& first, const Ts&... rest) { + (_device_check(first, rest), ...); +} + +static inline void _parallel_tensor_check(const TKParallelTensor& first, const TKParallelTensor& second) { + TORCH_CHECK(first.local_rank_ == second.local_rank_, "All parallel tensors must have the same local_rank"); + TORCH_CHECK(first.local_world_size_ == second.local_world_size_, "All parallel tensors must have the same local_world_size"); +} + +template +static inline void parallel_tensor_check(const T1& first, const Ts&... 
rest) { + (_parallel_tensor_check(first, rest), ...); +} + +template +concept static_grid = requires { Config::NUM_BLOCKS; }; + +template +concept static_block = requires { Config::NUM_THREADS; }; + +template +concept static_dynamic_shared_memory = requires { Config::DYNAMIC_SHARED_MEMORY; }; + +template +static inline void launch_kernel(const Globals &G) { + dim3 grid; + if constexpr (static_grid) + grid = dim3{Config::NUM_BLOCKS, 1, 1}; + else + grid = G.grid(); + + dim3 block; + if constexpr (static_block) + block = dim3{Config::NUM_THREADS, 1, 1}; + else + block = G.block(); + + int dynamic_shared_memory; + if constexpr (static_dynamic_shared_memory) + dynamic_shared_memory = static_cast(Config::DYNAMIC_SHARED_MEMORY); + else + dynamic_shared_memory = G.dynamic_shared_memory(); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if constexpr (Config::CLUSTER_SIZE <= 1) { + CUDACHECK(cudaFuncSetAttribute(global_kernel_unclustered, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_shared_memory)); + global_kernel_unclustered<<>>(G); + } else { + CUDACHECK(cudaFuncSetAttribute(global_kernel_clustered, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_shared_memory)); + global_kernel_clustered<<>>(G); + } +} + +} // namespace py +} // namespace kittens diff --git a/extra/thunder/cuda/include/pyutils/util.cuh b/extra/thunder/cuda/include/pyutils/util.cuh new file mode 100644 index 0000000000..0f92c6d905 --- /dev/null +++ b/extra/thunder/cuda/include/pyutils/util.cuh @@ -0,0 +1,19 @@ +#pragma once + +#include "../ops/ops.cuh" +#include "club.cuh" +#include + +#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__) +template +void check(T err, char const* const func, char const* const file, + int const line) +{ + if (err != cudaSuccess) + { + std::cerr << "CUDA Runtime Error at: " << file << ":" << line + << std::endl; + std::cerr << cudaGetErrorString(err) << " " << func << std::endl; + //std::exit(EXIT_FAILURE); + } +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/device/device.cuh b/extra/thunder/cuda/include/types/device/device.cuh new file mode 100644 index 0000000000..069d4af899 --- /dev/null +++ b/extra/thunder/cuda/include/types/device/device.cuh @@ -0,0 +1,12 @@ +/** + * @file + * @brief An aggregate header file for all the device types defined by ThunderKittens. 
+ */ + +#pragma once + +#if defined(KITTENS_HOPPER) || defined(KITTENS_BLACKWELL) +#include "ipc.cuh" +#include "pgl.cuh" +#include "vmm.cuh" +#endif diff --git a/extra/thunder/cuda/include/types/device/ipc.cuh b/extra/thunder/cuda/include/types/device/ipc.cuh new file mode 100644 index 0000000000..6c3f09a8d5 --- /dev/null +++ b/extra/thunder/cuda/include/types/device/ipc.cuh @@ -0,0 +1,195 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "../../common/common.cuh" +#include "vmm.cuh" + +namespace kittens { +namespace ducks { +namespace ipc { +namespace handle { + +struct identifier {}; + +template concept all = requires { + typename T::identifier; +} && std::is_same_v; + +} // namespace handle +} // namespace ipc +} // namespace ducks + +namespace detail { +namespace ipc { + +enum flavor { + LEGACY = 0, + VMM = 1 +}; + +template +struct handle; + +template<> +struct handle { + using identifier = ducks::ipc::handle::identifier; + static constexpr flavor flavor_ = flavor::LEGACY; + cudaIpcMemHandle_t handle_ {}; +}; + +template<> +struct handle { + using identifier = ducks::ipc::handle::identifier; + static constexpr flavor flavor_ = flavor::VMM; + int handle_; +}; + +__host__ inline static void check_support(const int device_id) { + CUdevice device; + CUCHECK(cuDeviceGet(&device, device_id)); + + int ipc_supported = 0; + CUDACHECK(cudaDeviceGetAttribute(&ipc_supported, cudaDevAttrIpcEventSupport, device_id)); + int ipc_handle_supported = 0; + CUCHECK(cuDeviceGetAttribute(&ipc_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, device)); + + if (!ipc_supported || !ipc_handle_supported) + throw std::runtime_error("CUDA IPC is not supported on this device"); +} + +template +__host__ inline static void export_handle( + IPC_HANDLE *ipc_handle, + void *ptr +) { + if constexpr (IPC_HANDLE::flavor_ == flavor::LEGACY) { + CUDACHECK(cudaIpcGetMemHandle(&ipc_handle->handle_, ptr)); + } else if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) { + CUmemGenericAllocationHandle memory_handle; + detail::vmm::vm_retrieve_handle(&memory_handle, ptr); + // ** Important: this handle (FD) must be manually closed by the user ** + CUCHECK(cuMemExportToShareableHandle(&ipc_handle->handle_, memory_handle, detail::vmm::HANDLE_TYPE, 0)); + detail::vmm::vm_free(memory_handle); + } else { + throw std::runtime_error("Invalid IPC handle type"); + } +} + +template +__host__ inline static void export_handle( + IPC_HANDLE *ipc_handle, + CUmemGenericAllocationHandle &memory_handle +) { + if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) { + CUCHECK(cuMemExportToShareableHandle(&ipc_handle->handle_, memory_handle, detail::vmm::HANDLE_TYPE, 0)); + } else { + throw std::runtime_error("Invalid IPC handle type"); + } +} + +template +__host__ inline static void import_handle ( + void **ptr, + IPC_HANDLE &ipc_handle, + const size_t size, + int local_world_size +) { + if constexpr (IPC_HANDLE::flavor_ == flavor::LEGACY) { + CUDACHECK(cudaIpcOpenMemHandle(ptr, ipc_handle.handle_, cudaIpcMemLazyEnablePeerAccess)); // this is the only flag supported + } else if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) { + CUmemGenericAllocationHandle memory_handle; + CUCHECK(cuMemImportFromShareableHandle(&memory_handle, reinterpret_cast(static_cast(ipc_handle.handle_)), detail::vmm::HANDLE_TYPE)); + detail::vmm::vm_map(ptr, memory_handle, size); + detail::vmm::vm_set_access(*ptr, size, local_world_size); + detail::vmm::vm_free(memory_handle); + close(ipc_handle.handle_); // 
close fd immediately + ipc_handle.handle_ = -1; + } else { + throw std::runtime_error("Invalid IPC handle type"); + } +} + +template +__host__ inline static void import_handle ( + CUmemGenericAllocationHandle *memory_handle, + IPC_HANDLE &ipc_handle, + const size_t size, + int local_world_size +) { + if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) { + CUCHECK(cuMemImportFromShareableHandle(memory_handle, reinterpret_cast(static_cast(ipc_handle.handle_)), detail::vmm::HANDLE_TYPE)); + close(ipc_handle.handle_); // close fd immediately + ipc_handle.handle_ = -1; + } else { + throw std::runtime_error("Invalid IPC handle type"); + } +} + +template +__host__ inline static void free_handle( + void *ptr, + const size_t size +) { + if constexpr (_flavor == flavor::LEGACY) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } else if constexpr (_flavor == flavor::VMM) { + detail::vmm::vm_unmap(ptr, size); + } else { + throw std::runtime_error("Invalid IPC handle type"); + } +} + +__host__ inline static void enable_all_peer_access(int num_devices) { + int num_available_devices; + CUCHECK(cuDeviceGetCount(&num_available_devices)); + if (num_available_devices < num_devices) + throw std::runtime_error("Not enough GPUs available"); + + std::vector devices(num_devices); + std::vector contexts(num_devices); + + for (int i = 0; i < num_devices; i++) { + CUCHECK(cuDeviceGet(&devices[i], i)); + CUCHECK(cuCtxCreate(&contexts[i], 0, devices[i])); + } + + for (int i = 0; i < num_devices; i++) { + int device_compute_mode; + CUCHECK(cuDeviceGetAttribute(&device_compute_mode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, devices[i])); + if (device_compute_mode != CU_COMPUTEMODE_DEFAULT) + throw std::runtime_error("Device is in an unsupported compute mode"); + + int vmm_supported = 0; + CUCHECK(cuDeviceGetAttribute(&vmm_supported, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, devices[i])); + if (!vmm_supported) + throw std::runtime_error("Device does not support CUDA VMM"); + + int ipc_handle_supported; + CUCHECK(cuDeviceGetAttribute(&ipc_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, devices[i])); + if (!ipc_handle_supported) + throw std::runtime_error("Device does not support IPC handles"); + + for (int j = 0; j < num_devices; j++) { + if (i == j) continue; + int can_access_peer; + CUCHECK(cuDeviceCanAccessPeer(&can_access_peer, devices[i], devices[j])); + if (!can_access_peer) + throw std::runtime_error("Device cannot access peer device"); + CUCHECK(cuCtxSetCurrent(contexts[i])); + CUCHECK(cuCtxEnablePeerAccess(contexts[j], 0)); + } + } + + for (size_t i = 0; i < contexts.size(); ++i) + CUCHECK(cuCtxDestroy(contexts[i])); +} + +} // namespace ipc +} // namespace detail +} // namespace kittens diff --git a/extra/thunder/cuda/include/types/device/pgl.cuh b/extra/thunder/cuda/include/types/device/pgl.cuh new file mode 100644 index 0000000000..f0f7910603 --- /dev/null +++ b/extra/thunder/cuda/include/types/device/pgl.cuh @@ -0,0 +1,173 @@ +/** + * @file + * @brief Templated layouts for parallel global memory. + */ + +#pragma once + +#include "../../common/common.cuh" +#include "../shared/shared.cuh" +#include "../global/global.cuh" + +namespace kittens { + +/* ---------- Parallel global layout descriptor ---------- */ + +namespace ducks { +namespace pgl { + +struct identifier {}; + +/** + * @brief Concept for all parallel global layouts. + * @tparam T The type to check against the concept requirements. 
+ * + * Requires: + * - T has a nested type identifier that is the same as ducks::pgl::identifier. + */ +template concept all = requires { + typename T::identifier; +} && std::is_same_v; + +} // namespace pgl +} // namespace ducks + +/** + * @brief Parallel global layout. Represents a region of data spread across multiple devices. + * @tparam GL The underlying global layout on each device. + * @tparam NUM_DEVICES The number of GPU devices. + * @tparam MULTICAST Whether the multicast object should be initialized by the caller. + * @tparam TMA_Types The types of TMA descriptors to use for the multicast locations. + Only valid if MULTICAST is true. + */ +template +struct pgl { + using identifier = ducks::pgl::identifier; + using GL = _GL; + using T = GL::dtype; + using dtype = T; + + static constexpr int num_devices = NUM_DEVICES; + static constexpr bool multicast = MULTICAST; + + T *mc_ptr; // multicast pointer; nullptr if MULTICAST is false + GL gls[NUM_DEVICES]; + + detail::descriptor_dict tma_descs; + + __host__ __device__ const GL &operator[](int idx) const { return gls[idx]; } + __device__ inline T* mc_ptr_at(const coord &idx) const { + static_assert(MULTICAST, "Multicast is not enabled for this PGL."); + const GL &gl = gls[0]; // all gls have the same shape + return &mc_ptr[((idx.b * gl.depth() + idx.d) * gl.rows() + idx.r) * gl.cols() + idx.c]; + } + + __host__ inline pgl(T **_data, // an array of NUM_DEVICES pointers to the data on each device + ducks::gl::make_arg_t _batch, + ducks::gl::make_arg_t _depth, + ducks::gl::make_arg_t _rows, + ducks::gl::make_arg_t _cols) : + pgl(std::make_index_sequence{}, _data, _batch, _depth, _rows, _cols) { } + + __host__ inline pgl(T *_mc_ptr, // multicast pointer, initialized by the caller + T **_data, // an array of NUM_DEVICES pointers to the data on each device + ducks::gl::make_arg_t _batch, + ducks::gl::make_arg_t _depth, + ducks::gl::make_arg_t _rows, + ducks::gl::make_arg_t _cols) : + pgl(std::make_index_sequence{}, _mc_ptr, _data, _batch, _depth, _rows, _cols) { } + + template + __host__ inline pgl(std::index_sequence, + T **_data, + ducks::gl::make_arg_t _batch, + ducks::gl::make_arg_t _depth, + ducks::gl::make_arg_t _rows, + ducks::gl::make_arg_t _cols) : + mc_ptr(nullptr), gls{GL(_data[I], _batch, _depth, _rows, _cols)...} { + static_assert(!MULTICAST, "Multicast pointer not passed to multicast-enabled PGL."); + } + + template + __host__ inline pgl(std::index_sequence, + T *_mc_ptr, + T **_data, + ducks::gl::make_arg_t _batch, + ducks::gl::make_arg_t _depth, + ducks::gl::make_arg_t _rows, + ducks::gl::make_arg_t _cols) : + mc_ptr(_mc_ptr), gls{GL(_data[I], _batch, _depth, _rows, _cols)...} { + static_assert(MULTICAST, "Multicast pointer passed to multicast-disabled PGL."); + tma_descs = detail::descriptor_dict( + mc_ptr, gls[0].batch_internal, gls[0].depth_internal, gls[0].rows_internal, gls[0].cols_internal); + } + + template + __device__ inline const CUtensorMap* get_tma() const { + return tma_descs.template get(); + } + + __host__ __device__ inline auto batch() const { return gls[0].batch(); } + __host__ __device__ inline auto depth() const { return gls[0].depth(); } + __host__ __device__ inline auto rows() const { return gls[0].rows(); } + __host__ __device__ inline auto cols() const { return gls[0].cols(); } + __host__ __device__ inline size_t numel() const { return static_cast(batch()) * depth() * rows() * cols(); } + + template __device__ inline size_t shape() const { return gls[0].template shape(); } + template __device__ inline 
size_t stride() const { return gls[0].template stride(); } +}; + +template __host__ inline PGL make_pgl( + uint64_t *data, int b, int d, int r, int c +) { + if constexpr (safe) { + if (PGL::GL::__b__ > 0 && b != PGL::GL::__b__) { + throw std::runtime_error("Batch dimension mismatch. Expected: " + std::to_string(PGL::GL::__b__) + ", Got: " + std::to_string(b)); + } + if (PGL::GL::__d__ > 0 && d != PGL::GL::__d__) { + throw std::runtime_error("Depth dimension mismatch. Expected: " + std::to_string(PGL::GL::__d__) + ", Got: " + std::to_string(d)); + } + if (PGL::GL::__r__ > 0 && r != PGL::GL::__r__) { + throw std::runtime_error("Row dimension mismatch. Expected: " + std::to_string(PGL::GL::__r__) + ", Got: " + std::to_string(r)); + } + if (PGL::GL::__c__ > 0 && c != PGL::GL::__c__) { + throw std::runtime_error("Column dimension mismatch. Expected: " + std::to_string(PGL::GL::__c__) + ", Got: " + std::to_string(c)); + } + } + return PGL( + reinterpret_cast(data), + make_unsafe_gl_arg(b), + make_unsafe_gl_arg(d), + make_unsafe_gl_arg(r), + make_unsafe_gl_arg(c) + ); +} + +template __host__ inline PGL make_pgl( + uint64_t mc_ptr, uint64_t *data, int b, int d, int r, int c +) { + if constexpr (safe) { + if (PGL::GL::__b__ > 0 && b != PGL::GL::__b__) { + throw std::runtime_error("Batch dimension mismatch. Expected: " + std::to_string(PGL::GL::__b__) + ", Got: " + std::to_string(b)); + } + if (PGL::GL::__d__ > 0 && d != PGL::GL::__d__) { + throw std::runtime_error("Depth dimension mismatch. Expected: " + std::to_string(PGL::GL::__d__) + ", Got: " + std::to_string(d)); + } + if (PGL::GL::__r__ > 0 && r != PGL::GL::__r__) { + throw std::runtime_error("Row dimension mismatch. Expected: " + std::to_string(PGL::GL::__r__) + ", Got: " + std::to_string(r)); + } + if (PGL::GL::__c__ > 0 && c != PGL::GL::__c__) { + throw std::runtime_error("Column dimension mismatch. 
Expected: " + std::to_string(PGL::GL::__c__) + ", Got: " + std::to_string(c)); + } + } + return PGL( + reinterpret_cast(mc_ptr), + reinterpret_cast(data), + make_unsafe_gl_arg(b), + make_unsafe_gl_arg(d), + make_unsafe_gl_arg(r), + make_unsafe_gl_arg(c) + ); +} + +} // namespace kittens diff --git a/extra/thunder/cuda/include/types/device/vmm.cuh b/extra/thunder/cuda/include/types/device/vmm.cuh new file mode 100644 index 0000000000..8b8d274ec2 --- /dev/null +++ b/extra/thunder/cuda/include/types/device/vmm.cuh @@ -0,0 +1,180 @@ +#pragma once + +#include +#include +#include + +#include "../../common/common.cuh" + +namespace kittens { +namespace detail { +namespace vmm { + +// Intra-node shareable handle type +// This makes the handle shareable with cuMemExportToShareableHandle/cuMemImportFromShareableHandle +static constexpr CUmemAllocationHandleType HANDLE_TYPE = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; + +typedef CUmemGenericAllocationHandle handle; + +__host__ inline static void vm_alloc( + CUmemGenericAllocationHandle *handle, + size_t *allocated_size, + const size_t size, + const int device_id +) { + CUmemAllocationProp prop = {}; + prop.location.id = device_id; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.requestedHandleTypes = HANDLE_TYPE; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + + size_t granularity; + CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED)); + *allocated_size = (size + granularity - 1) / granularity * granularity; // round-up + + CUCHECK(cuMemCreate(handle, *allocated_size, &prop, 0)); +} + +__host__ inline static void vm_map( + void **ptr, + const CUmemGenericAllocationHandle &handle, + const size_t size +) { + CUdeviceptr device_ptr; + CUCHECK(cuMemAddressReserve(&device_ptr, size, 0, 0, 0)); + CUCHECK(cuMemMap(device_ptr, size, 0, handle, 0)); + *ptr = (void *)device_ptr; +} + +__host__ inline static void vm_set_access( + void *ptr, + const size_t size, + const int num_devices +) { + std::vector descs(num_devices); + for (int i = 0; i < num_devices; i++) { + descs[i].location.id = i; + descs[i].location.type = CU_MEM_LOCATION_TYPE_DEVICE; + descs[i].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + } + CUCHECK(cuMemSetAccess(reinterpret_cast(ptr), size, descs.data(), num_devices)); +} + +__host__ inline static void vm_retrieve_handle( + CUmemGenericAllocationHandle *handle, + void *ptr +) { + // Every call to this requires a corresponding call to cuMemRelease + CUCHECK(cuMemRetainAllocationHandle(handle, ptr)); +} + +__host__ inline static void vm_unmap( + void *ptr, + const size_t size +) { + CUCHECK(cuMemUnmap(reinterpret_cast(ptr), size)); + CUCHECK(cuMemAddressFree(reinterpret_cast(ptr), size)); +} + +__host__ inline static void vm_free(CUmemGenericAllocationHandle &handle) { + // It is recommended to free the handle ASAP; the backing memory will + // only be freed when all handles AND address mappings are released + CUCHECK(cuMemRelease(handle)); +} + +__host__ inline static void vm_alloc_map_set_access( + void **ptr, + size_t *allocated_size, + const size_t size, + const int device_id, + const int num_devices +) { + CUmemGenericAllocationHandle handle; + vm_alloc(&handle, allocated_size, size, device_id); + vm_map(ptr, handle, *allocated_size); + vm_set_access(*ptr, *allocated_size, num_devices); + vm_free(handle); // release the handle ASAP +} + +__host__ inline static void multicast_check(const int device_id) { + CUdevice device; + CUCHECK(cuDeviceGet(&device, device_id)); + + int 
multicast_supported; + CUresult result = cuDeviceGetAttribute( + &multicast_supported, + CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, + device + ); + + if (!multicast_supported) + throw std::runtime_error("Device does not support multicast"); +} + +__host__ inline static void multicast_create_handle( + CUmemGenericAllocationHandle *handle, + size_t *allocated_size, + const size_t size, + const int num_devices +) { + if (num_devices <= 1) + throw std::runtime_error("Multicast requires at least 2 devices"); + + CUmulticastObjectProp prop = {}; + prop.numDevices = num_devices; + prop.handleTypes = HANDLE_TYPE; + + size_t granularity; + CUCHECK(cuMulticastGetGranularity(&granularity, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED)); + *allocated_size = (size + granularity - 1) / granularity * granularity; + prop.size = *allocated_size; + + // After this, the handle must be shared with all processes through MPI, KittensBroker, etc. + cuMulticastCreate(handle, &prop); +} + +__host__ inline static void multicast_bind_device( + const CUmemGenericAllocationHandle &handle, + const int device_id +) { + // All processes must sync after this, before binding any memory + CUdevice device; + CUCHECK(cuDeviceGet(&device, device_id)); + CUCHECK(cuMulticastAddDevice(handle, device)); +} + +__host__ inline static void multicast_bind_memory( + const CUmemGenericAllocationHandle &multicast_handle, + const CUmemGenericAllocationHandle &memory_handle, + const size_t size +) { + // All processes should finish adding device before calling this function + CUCHECK(cuMulticastBindMem(multicast_handle, 0, memory_handle, 0, size, 0)); +} + +__host__ inline static void multicast_bind_address( + const CUmemGenericAllocationHandle &multicast_handle, + void *ptr, + const size_t size +) { + // All processes should finish adding device before calling this function + CUmemGenericAllocationHandle memory_handle; + vm_retrieve_handle(&memory_handle, ptr); + multicast_bind_memory(multicast_handle, memory_handle, size); + vm_free(memory_handle); +} + +__host__ inline static void multicast_unbind_device( + const CUmemGenericAllocationHandle &handle, + const size_t size, + const int device_id +) { + // Unbinding memory is not needed + CUdevice device; + CUCHECK(cuDeviceGet(&device, device_id)); + CUCHECK(cuMulticastUnbind(handle, device, 0, size)); +} + +} // namespace vmm +} // namespace detail +} // namespace kittens diff --git a/extra/thunder/cuda/include/types/global/cgl.cuh b/extra/thunder/cuda/include/types/global/cgl.cuh new file mode 100644 index 0000000000..67565c02c6 --- /dev/null +++ b/extra/thunder/cuda/include/types/global/cgl.cuh @@ -0,0 +1,56 @@ +/** + * @file + * @brief Templated layouts for complex global memory. + */ + +#pragma once + +#include "../../common/common.cuh" +#include "../shared/cst.cuh" +#include "gl.cuh" +#include "util.cuh" +#ifdef KITTENS_HOPPER +#include "tma.cuh" +#endif + +namespace kittens { + +/* ---------- Global layout descriptor ---------- */ + +namespace ducks { +namespace cgl { +struct identifier {}; +} +} + +// namespace detail { +// template concept tile = ducks::cst::all || ducks::crt::all; +// template concept vec = ducks::csv::all || ducks::crv::all; +// } + +template +struct cgl { + using identifier = ducks::cgl::identifier; + using component = _GL; + using T = component::T; + using T2 = component::T2; + using dtype = component::dtype; + component real, imag; +}; + +namespace ducks { +namespace cgl { +/** +* @brief Concept for all complex global layouts. 
+* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as ducks::cgl::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::cgl::identifier +} +} + +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/global/gl.cuh b/extra/thunder/cuda/include/types/global/gl.cuh new file mode 100644 index 0000000000..d7eceae2ee --- /dev/null +++ b/extra/thunder/cuda/include/types/global/gl.cuh @@ -0,0 +1,225 @@ +/** + * @file + * @brief Templated layouts for global memory. + */ + +#pragma once + +#include "../../common/common.cuh" +#include "../shared/shared.cuh" +#include "util.cuh" +#ifdef KITTENS_HOPPER +#include +#include "tma.cuh" +#endif + +namespace kittens { + +/* ---------- Global layout axes ---------- */ + +struct dim { + static constexpr int BATCH = 0; + static constexpr int DEPTH = 1; + static constexpr int ROW = 2; + static constexpr int COL = 3; +}; + +/* ---------- Associative dictionary for global layouts ---------- */ + +#ifdef KITTENS_HOPPER +namespace ducks { +namespace tma { +namespace descriptor { +struct identifier {}; +template concept all = requires { + typename T::identifier; +} && std::is_same_v; +} // namespace descriptor +} // namespace tma +} // namespace ducks +namespace detail { +namespace tma { +template struct descriptor_copy_helper {}; +template struct descriptor_copy_helper<_T> { static constexpr int value = _T::axis; using T = _T::T; static constexpr bool swizzle_flag = _T::swizzle_flag; }; +template struct descriptor_copy_helper<_T> { static constexpr int value = 2; using T = _T; static constexpr bool swizzle_flag = true; }; +template struct descriptor_copy_helper<_T> { static constexpr int value = -1; using T = _T; static constexpr bool swizzle_flag = true; }; +template using descriptor_copy_helper_t = descriptor_copy_helper::T; +template static constexpr int descriptor_copy_helper_v = descriptor_copy_helper::value; +template static constexpr bool descriptor_copy_helper_swizzle_flag = descriptor_copy_helper::swizzle_flag; +} // namespace tma +} // namespace detail +namespace tma { +template struct descriptor { + using identifier = ducks::tma::descriptor::identifier; + using T = detail::tma::descriptor_copy_helper_t<_T>; + static_assert(ducks::st::all || ducks::sv::all || ducks::tma::descriptor::all, "Must be a shared TK type to generate a TMA descriptor."); + static constexpr int axis = ( + ducks::tma::descriptor::all<_T> ? detail::tma::descriptor_copy_helper_v<_T> : // if a copy, inherit the axis from the original descriptor. + (_axis != -9999) ? _axis : detail::tma::descriptor_copy_helper_v<_T>); // if a default value was provided, use it. + static_assert((kittens::ducks::st::all && axis >= 0 && axis <= 2) || (kittens::ducks::sv::all && axis == -1), "Internal template error detected."); + static constexpr bool swizzle_flag = ducks::tma::descriptor::all<_T> ? 
detail::tma::descriptor_copy_helper_swizzle_flag<_T> : _swizzle_flag; +}; +} // namespace tma +#endif + +namespace detail { +template +struct descriptor_dict { + __host__ descriptor_dict() {} + template __host__ descriptor_dict(T _, int b, int d, int r, int c) {} + __host__ __device__ descriptor_dict(const descriptor_dict &other) {} +#ifdef KITTENS_HOPPER + template __device__ const CUtensorMap* get() const { + static_assert( + std::is_same_v && std::is_same_v, + "SKILL ISSUE: Requested a TMA descriptor for a type not initialized in the global layout." + ); + } +#endif +}; + +#ifdef KITTENS_HOPPER +template +struct descriptor_dict<_T, Args...> { + static_assert(ducks::sv::all<_T> || ducks::st::all<_T> || ducks::tma::descriptor::all<_T>, "Must be a shared TK type to generate a TMA descriptor."); + using DESC = kittens::tma::descriptor<_T>; // copy or initialize with a default value + CUtensorMap tma_desc; + descriptor_dict other_descs; + __host__ descriptor_dict() {} + __host__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) { + kittens::detail::tma::create_tensor_map(&tma_desc, data, b, d, r, c); + } + __host__ __device__ inline descriptor_dict(const descriptor_dict &other) : + tma_desc(other.tma_desc), other_descs(other.other_descs) {} + template __device__ inline const CUtensorMap* get() const { + if constexpr (std::is_same_v && DESC::axis == axis) { return &tma_desc; } + else { return other_descs.template get(); } + } +}; +#endif +} + +/* ---------- Global layout descriptor ---------- */ + +namespace ducks { +namespace gl { +struct identifier {}; +} +} + +template +struct gl { + using identifier = ducks::gl::identifier; + + using T = base_types::packing<_T>::unpacked_type; + using T2 = base_types::packing<_T>::packed_type; + using dtype = T; + + T* raw_ptr; + + static constexpr int __b__ = b, __d__ = d, __r__ = r, __c__ = c; // Not to be touched by the user. 
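+    // Each dimension fixed at compile time (b, d, r, c > 0) is carried as an empty compiled_dim
+    // and resolved statically by the batch()/depth()/rows()/cols() overloads below; a dimension
+    // given as -1 is stored at runtime in the corresponding *_internal member instead.
+    // Illustrative example (the element type and sizes are placeholders, not part of this header):
+    // gl<float, -1, 1, 64, 64> fixes depth, rows, and cols and keeps only the batch size dynamic.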
+ + ducks::gl::make_dim_t batch_internal; + ducks::gl::make_dim_t depth_internal; + ducks::gl::make_dim_t rows_internal; + ducks::gl::make_dim_t cols_internal; + + template __device__ __host__ static constexpr std::enable_if_t<(B > 0), int> batch() { return B; } + template __device__ __host__ std::enable_if_t<(B == -1), int> batch() const { return batch_internal; } + template __device__ __host__ static constexpr std::enable_if_t<(D > 0), int> depth() { return D; } + template __device__ __host__ std::enable_if_t<(D == -1), int> depth() const { return depth_internal; } + template __device__ __host__ static constexpr std::enable_if_t<(R > 0), int> rows() { return R; } + template __device__ __host__ std::enable_if_t<(R == -1), int> rows() const { return rows_internal; } + template __device__ __host__ static constexpr std::enable_if_t<(C > 0), int> cols() { return C; } + template __device__ __host__ std::enable_if_t<(C == -1), int> cols() const { return cols_internal; } + + detail::descriptor_dict tma_descs; + + __host__ inline gl(T *_data, + ducks::gl::make_arg_t _batch, + ducks::gl::make_arg_t _depth, + ducks::gl::make_arg_t _rows, + ducks::gl::make_arg_t _cols) : + raw_ptr(_data), batch_internal(_batch), depth_internal(_depth), rows_internal(_rows), cols_internal(_cols) { + tma_descs = detail::descriptor_dict(raw_ptr, batch_internal, depth_internal, rows_internal, cols_internal); + } + __host__ __device__ inline gl(const gl &other) : + raw_ptr(other.raw_ptr), batch_internal(other.batch_internal), depth_internal(other.depth_internal), rows_internal(other.rows_internal), cols_internal(other.cols_internal), tma_descs(other.tma_descs) {} +#ifdef KITTENS_HOPPER + template __device__ inline const CUtensorMap* get_tma() const { + return tma_descs.template get(); + } +#endif + __device__ inline T& operator[](const coord &idx) const { // yes I am abusing the const qualifier here a bit. + return raw_ptr[((idx.b*depth() + idx.d)*rows() + idx.r)*cols() + idx.c]; + } + template __device__ inline size_t shape() const { + static_assert(axis==0 || axis==1 || axis==2 || axis==3, "Axis must be 0, 1, 2, or 3."); + if constexpr (axis==0) { return size_t(batch()); } + else if constexpr (axis==1) { return size_t(depth()); } + else if constexpr (axis==2) { return size_t(rows()); } + else if constexpr (axis==3) { return size_t(cols()); } + } + template __device__ inline size_t stride() const { + static_assert(axis==0 || axis==1 || axis==2 || axis==3, "Axis must be 0, 1, 2, or 3."); + if constexpr (axis==0) { return depth()*rows()*cols(); } + else if constexpr (axis==1) { return rows()*cols(); } + else if constexpr (axis==2) { return cols(); } + else if constexpr (axis==3) { return 1; } + } +}; + +template using gl3 = gl<_T, 1, d, r, c, TMA_Types...>; +template using gl2 = gl<_T, 1, 1, r, c, TMA_Types...>; +template using gl1 = gl<_T, 1, 1, 1, c, TMA_Types...>; + +namespace ducks { +namespace gl { +/** +* @brief Concept for all global layouts. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as ducks::gl::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::gl::identifier +} +} + +// Structs for initializing global layouts automatically. 
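+// A minimal usage sketch for the helpers below (illustrative only; the element type, layout
+// sizes, and pointer name are placeholders rather than anything defined in this header):
+//
+//   using my_gl = gl<float, -1, 1, 64, 64>;   // runtime batch, fixed depth/rows/cols
+//   my_gl g = make_gl<my_gl>(reinterpret_cast<uint64_t>(dev_ptr), /*b=*/8, 1, 64, 64);
+//
+// With safe checking enabled, make_gl throws if a runtime size disagrees with a dimension
+// that the layout fixed at compile time.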
+// struct unsafe_gl { +// uint64_t data; +// int b, d, r, c; +// unsafe_gl(uint64_t data, int b, int d, int r, int c) : data(data), b(b), d(d), r(r), c(c) {} +// }; +template auto make_unsafe_gl_arg(int param) { // typename std::conditional_t<(N < 0), std::nullptr_t, int> + if constexpr (N > 0) { return nullptr; } + else { return param; } +} +template __host__ inline GL make_gl(uint64_t data, int b, int d, int r, int c) { + if constexpr (safe) { + if(GL::__b__ > 0 && b != GL::__b__) { + throw std::runtime_error("Batch dimension mismatch. Expected: " + std::to_string(GL::__b__) + ", Got: " + std::to_string(b)); + } + if(GL::__d__ > 0 && d != GL::__d__) { + throw std::runtime_error("Depth dimension mismatch. Expected: " + std::to_string(GL::__d__) + ", Got: " + std::to_string(d)); + } + if(GL::__r__ > 0 && r != GL::__r__) { + throw std::runtime_error("Row dimension mismatch. Expected: " + std::to_string(GL::__r__) + ", Got: " + std::to_string(r)); + } + if(GL::__c__ > 0 && c != GL::__c__) { + throw std::runtime_error("Column dimension mismatch. Expected: " + std::to_string(GL::__c__) + ", Got: " + std::to_string(c)); + } + } + return GL( + reinterpret_cast(data), + make_unsafe_gl_arg(b), + make_unsafe_gl_arg(d), + make_unsafe_gl_arg(r), + make_unsafe_gl_arg(c) + ); +} + +} // namespace kittens diff --git a/extra/thunder/cuda/include/types/global/global.cuh b/extra/thunder/cuda/include/types/global/global.cuh new file mode 100644 index 0000000000..00d894626b --- /dev/null +++ b/extra/thunder/cuda/include/types/global/global.cuh @@ -0,0 +1,13 @@ +/** + * @file + * @brief An aggregate header file for all the global types defined by ThunderKittens. + */ + +#pragma once + +#ifdef KITTENS_HOPPER +#include "tma.cuh" +#endif +#include "util.cuh" +#include "gl.cuh" +#include "cgl.cuh" diff --git a/extra/thunder/cuda/include/types/global/tma.cuh b/extra/thunder/cuda/include/types/global/tma.cuh new file mode 100644 index 0000000000..c52c266d80 --- /dev/null +++ b/extra/thunder/cuda/include/types/global/tma.cuh @@ -0,0 +1,428 @@ +#pragma once + +#include +#include +#include +#include // for std::hash +#include +#include +#include "../../common/common.cuh" +#include "../shared/shared.cuh" + +namespace kittens { +namespace detail { +namespace tma { + +__host__ static inline std::string format_tma_error( + const char* error_type, + const char* error_string, + int batch, int depth, int rows, int cols, + CUtensorMap* tma_map, + CUtensorMapDataType tma_format, + uint32_t tma_dim, + void* global_addr, + const uint64_t* gmem_shape, + const uint64_t* gmem_stride, + const uint32_t* smem_shape, + const uint32_t* smem_stride, + size_t gmem_shape_size, + size_t gmem_stride_size, + size_t smem_shape_size, + size_t smem_stride_size, + CUtensorMapInterleave tma_interleave, + CUtensorMapSwizzle tma_swizzle, + CUtensorMapL2promotion tma_l2Promotion, + CUtensorMapFloatOOBfill tma_oobFill, + const std::string& extra_info = "" +) { + std::ostringstream oss; + oss << "Error in " << error_type << " TMA descriptor creation: "; + oss << (error_string ? 
error_string : "Unknown CUDA error"); + oss << "\nParameters:"; + oss << "\n batch: " << batch; + oss << "\n depth: " << depth; + oss << "\n rows: " << rows; + oss << "\n cols: " << cols; + if (!extra_info.empty()) + oss << "\n " << extra_info; + + oss << "\ncuTensorMapEncodeTiled arguments:"; + oss << "\n tma_map: " << reinterpret_cast(tma_map); + oss << "\n tma_format: " << tma_format; + oss << "\n tma_dim: " << tma_dim; + oss << "\n global_addr: " << reinterpret_cast(global_addr); + + // Check if global_addr is valid device memory + cudaPointerAttributes attributes; + cudaError_t err = cudaPointerGetAttributes(&attributes, global_addr); + if (err == cudaSuccess) { + oss << "\n global_addr memory type: "; + if (attributes.type == cudaMemoryTypeDevice) { + oss << "valid device memory"; + } else if (attributes.type == cudaMemoryTypeHost) { + oss << "host memory (invalid for TMA)"; + } else if (attributes.type == cudaMemoryTypeManaged) { + oss << "managed memory"; + } else { + oss << "unknown memory type"; + } + } else { + oss << "\n global_addr memory type: unable to determine (error: " << cudaGetErrorString(err) << ")"; + } + + oss << "\n gmem_shape: " << reinterpret_cast(gmem_shape) << " ["; + for (size_t i = 0; i < gmem_shape_size; ++i) + oss << gmem_shape[i] << (i < gmem_shape_size - 1 ? ", " : ""); + oss << "]"; + + oss << "\n gmem_stride: " << reinterpret_cast(gmem_stride) << " ["; + for (size_t i = 0; i < gmem_stride_size; ++i) + oss << gmem_stride[i] << (i < gmem_stride_size - 1 ? ", " : ""); + oss << "]"; + + oss << "\n smem_shape: " << reinterpret_cast(smem_shape) << " ["; + for (size_t i = 0; i < smem_shape_size; ++i) + oss << smem_shape[i] << (i < smem_shape_size - 1 ? ", " : ""); + oss << "]"; + + oss << "\n smem_stride: " << reinterpret_cast(smem_stride) << " ["; + for (size_t i = 0; i < smem_stride_size; ++i) + oss << smem_stride[i] << (i < smem_stride_size - 1 ? ", " : ""); + oss << "]"; + + oss << "\n tma_interleave: " << tma_interleave; + oss << "\n tma_swizzle: " << tma_swizzle; + oss << "\n tma_l2Promotion: " << tma_l2Promotion; + oss << "\n tma_oobFill: " << tma_oobFill; + + return oss.str(); +} + +/* ---------- Create tile tensor map descriptor (HOST) ---------- */ + +/** +* @brief Creates a tensor map for the given source tensor. +* +* This function creates a tensor map (CUtensorMap) for the specified source shared tile type. The tensor map +* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor +* map based on the provided source tensor pointer and the layout specified by the ST template parameter. +* +* @tparam ST The source tensor type, which must be TMA-compatible. +* @tparam blocks_height The number of tiles present on the height axis in global memory. +* @tparam blocks_width The number of tiles present on the width axis in global memory. Defaults to 1. +* @param tma_map Pointer to the CUtensorMap object to be initialized. +* @param src Pointer to the source tensor data in global memory. +*/ +template +__host__ static inline void create_tensor_map(CUtensorMap *tma_map, const typename ST::dtype *src, int batch, int depth, int rows, int cols) { + using dtype = typename ST::dtype; + static_assert(axis==0 || axis==1 || axis==2, "axis must be 0, 1, or 2"); + + constexpr uint32_t tma_dim = enable_swizzle ? 5 : 4; + void *global_addr = (void*)(src); + + constexpr CUtensorMapDataType tma_format = ( + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_FLOAT16 : + std::is_same_v ? 
CU_TENSOR_MAP_DATA_TYPE_FLOAT32 : + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_UINT8 : + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_UINT8 : +#ifdef KITTENS_BLACKWELL + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_UINT8 : +#endif + CUtensorMapDataType(-1) + ); + constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + constexpr CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + constexpr CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + constexpr CUtensorMapSwizzle tma_swizzle = enable_swizzle ? ( + ST::swizzle_bytes == 32 ? CU_TENSOR_MAP_SWIZZLE_32B : + ST::swizzle_bytes == 64 ? CU_TENSOR_MAP_SWIZZLE_64B : + ST::swizzle_bytes == 128 ? CU_TENSOR_MAP_SWIZZLE_128B : + CU_TENSOR_MAP_SWIZZLE_NONE + ) : CU_TENSOR_MAP_SWIZZLE_NONE; + + // Works for tma_dim = 4 too + uint64_t gmem_shape [5] = {0, 0, 0, 0, 0}; + uint64_t gmem_stride[4] = {0, 0, 0, 0}; + uint32_t smem_shape [5] = {0, 0, 0, 0, 0}; + uint32_t smem_stride[5] = {1, 1, 1, 1, 1}; + + constexpr uint64_t shared_tile_height = ST::rows; + constexpr uint64_t shared_tile_width = ST::cols; + + constexpr int swizzle_elements = ST::swizzle_bytes / sizeof(dtype); + + if constexpr (enable_swizzle) { + if constexpr (axis == 2) { + gmem_shape[0] = swizzle_elements; + gmem_shape[1] = (uint64_t)rows; + gmem_shape[2] = (uint64_t)(cols+swizzle_elements-1) / swizzle_elements; // round up, note this can potentially screw up out of bounds access handling :/ + gmem_shape[3] = (uint64_t)depth; + gmem_shape[4] = (uint64_t)batch; + + gmem_stride[0] = (uint64_t)cols * sizeof(dtype); + gmem_stride[1] = ST::swizzle_bytes; + gmem_stride[2] = (uint64_t)rows * cols * sizeof(dtype); + gmem_stride[3] = (uint64_t)depth * rows * cols * sizeof(dtype); + } + else if constexpr (axis == 1) { + gmem_shape[0] = swizzle_elements; + gmem_shape[1] = (uint64_t)depth; + gmem_shape[2] = (uint64_t)(cols+swizzle_elements-1) / swizzle_elements; // round up, note this can potentially screw up out of bounds access handling :/ + gmem_shape[3] = (uint64_t)rows; + gmem_shape[4] = (uint64_t)batch; + + gmem_stride[0] = (uint64_t)rows * cols * sizeof(dtype); + gmem_stride[1] = ST::swizzle_bytes; + gmem_stride[2] = (uint64_t)cols * sizeof(dtype); + gmem_stride[3] = (uint64_t)depth * rows * cols * sizeof(dtype); + + } + else { + gmem_shape[0] = swizzle_elements; + gmem_shape[1] = (uint64_t)batch; + gmem_shape[2] = (uint64_t)(cols+swizzle_elements-1) / swizzle_elements; // round up, note this can potentially screw up out of bounds access handling :/ + gmem_shape[3] = (uint64_t)rows; + gmem_shape[4] = (uint64_t)depth; + + gmem_stride[0] = (uint64_t)depth * rows * cols * sizeof(dtype); + gmem_stride[1] = ST::swizzle_bytes; + gmem_stride[2] = (uint64_t)cols * sizeof(dtype); + gmem_stride[3] = (uint64_t)rows * cols * sizeof(dtype); + } + smem_shape[0] = swizzle_elements; + smem_shape[1] = shared_tile_height; + smem_shape[2] = shared_tile_width / swizzle_elements; + smem_shape[3] = 1; + smem_shape[4] = 1; + } else { + gmem_shape[0] = (uint64_t)cols; + gmem_shape[1] = (uint64_t)rows; + gmem_shape[2] = (uint64_t)depth; + gmem_shape[3] = (uint64_t)batch; + + gmem_stride[0] = (uint64_t)cols * sizeof(dtype); + gmem_stride[1] = (uint64_t)rows * cols * sizeof(dtype); + gmem_stride[2] = (uint64_t)depth * rows * cols * sizeof(dtype); + + smem_shape[0] = shared_tile_width; + smem_shape[1] = shared_tile_height; + smem_shape[2] = 1; + smem_shape[3] = 1; + } + + // ensure that the global address is always 16-byte aligned + 
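+    // (cuTensorMapEncodeTiled requires the base global address to be at least 16-byte aligned
+    // for non-interleaved layouts, so a misaligned pointer is rejected up front.)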
assert((reinterpret_cast(global_addr) & 0b1111) == 0); + + assert(gmem_stride[0] % 16 == 0); // gmem_stride[0] elements must be a multiple of 16B + assert(gmem_stride[1] % 16 == 0); // gmem_stride[1] elements must be a multiple of 16B + assert(gmem_stride[2] % 16 == 0); // gmem_stride[2] elements must be a multiple of 16B + assert(gmem_stride[3] % 16 == 0); // gmem_stride[2] elements must be a multiple of 16B + + assert(smem_shape[0] <= 256); // smem_shape[0] elements must be <= 256 + assert(smem_shape[1] <= 256); // smem_shape[1] elements must be <= 256 + assert(smem_shape[2] <= 256); // smem_shape[2] elements must be <= 256 + + assert((smem_shape[0]*sizeof(dtype)) % 16 == 0); // if wgmma_interleave is none, then smem_shape[0] * sizeof(dtype) must be a multiple of 16B + + assert(smem_stride[0] <= 8); // smem_stride[0] must be less <= 8 + assert(smem_stride[1] <= 8); // smem_stride[1] must be less <= 8 + assert(smem_stride[2] <= 8); // smem_stride[2] must be less <= 8 + assert(smem_stride[3] <= 8); // smem_stride[3] must be less <= 8 + assert(smem_stride[4] <= 8); // smem_stride[3] must be less <= 8 + + assert(smem_stride[0] == 1); // smem_stride[0] is ignored when wgmma_interleave is none + + if constexpr (tma_interleave == CU_TENSOR_MAP_INTERLEAVE_NONE && tma_swizzle != CU_TENSOR_MAP_SWIZZLE_NONE) { + assert(smem_shape[0] * sizeof(dtype) <= ST::swizzle_bytes); + } + + const uint64_t *gmem_shape_ptr = &gmem_shape[0]; + const uint64_t *gmem_stride_ptr = &gmem_stride[0]; + const uint32_t *smem_shape_ptr = &smem_shape[0]; + const uint32_t *smem_stride_ptr = &smem_stride[0]; + + CUresult result = cuTensorMapEncodeTiled( + tma_map, + tma_format, + tma_dim, + global_addr, + gmem_shape_ptr, + gmem_stride_ptr, + smem_shape_ptr, + smem_stride_ptr, + tma_interleave, + tma_swizzle, + tma_l2Promotion, + tma_oobFill); + + const char *error_string; + CUresult res = cuGetErrorString(result, &error_string); + if (result != CUDA_SUCCESS) { + std::string error_msg = format_tma_error( + "tile", error_string, + batch, depth, rows, cols, + tma_map, tma_format, tma_dim, global_addr, + gmem_shape_ptr, gmem_stride_ptr, + smem_shape_ptr, smem_stride_ptr, + 5, 4, 5, 5, + tma_interleave, tma_swizzle, tma_l2Promotion, tma_oobFill, + "ST::rows: " + std::to_string(ST::rows) + "\n ST::cols: " + std::to_string(ST::cols) + ); + throw std::runtime_error(error_msg); + } +} + +/** +* @brief Allocates on the GPU and initializes a tensor map for the given source tensor. +* +* This function creates a tensor map (CUtensorMap) for the specified source shared tile type. The tensor map +* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor +* map based on the provided source tensor pointer and the layout specified by the ST template parameter. +* +* @tparam ST The source tensor type, which must be TMA-compatible. +* @tparam blocks_height The number of tiles present on the height axis in global memory. +* @tparam blocks_width The number of tiles present on the width axis in global memory. Defaults to 1. +* @param src Pointer to the source tensor data in global memory. +* @returns Pointer to the CUtensorMap object to be initialized. +*/ +template +__host__ static inline CUtensorMap* allocate_and_create_tensor_map(const typename ST::dtype *src, int batch, int depth, int rows, int cols) { + CUtensorMap *tma_map_d; + cudaMalloc(&tma_map_d, sizeof(CUtensorMap)); + CUtensorMap tma_map_host; // put it on the stack, why not. 
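+    // The descriptor is filled in on the host stack and then copied into the cudaMalloc'd
+    // buffer; ownership of tma_map_d passes to the caller, who should cudaFree it once the
+    // descriptor is no longer needed.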
+ create_tensor_map(&tma_map_host, src, batch, depth, rows, cols); + cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + return tma_map_d; +} + +/* ---------- Create vector tensor map descriptor (HOST) ---------- */ + +// First, we need a template system to determine how to divide up a long shared vector into multiple subvectors. +// We have to do this because the first dimension for TMA is limited to 256 elements. +// Our goal is to find the largest multiple of 16 that is <= 256 and divides the vector length evenly. + +template struct find_vector_divider { + static constexpr int value = (SV::length % (16*D) == 0 && (SV::length < 256 || ((16*D)*sizeof(typename SV::dtype)) % 128 == 0)) ? + 16*D : find_vector_divider::value; +}; +template struct find_vector_divider { static constexpr int value = 16; }; // base case +template constexpr int sv_tma_dim1 = find_vector_divider::value; // inner dim +template constexpr int sv_tma_dim2 = (SV::length / sv_tma_dim1); + +/** +* @brief Creates a tensor map for the given source vector. +* +* This function creates a tensor map (CUtensorMap) for the specified source shared vector type. The tensor map +* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor +* map based on the provided source tensor pointer and the layout specified by the SV template parameter. +* +* @tparam SV The source tensor type, which must be TMA-compatible. +* @tparam num_vectors The number of vectors present in global memory. +* @param tma_map Pointer to the CUtensorMap object to be initialized. +* @param src Pointer to the source tensor data in global memory. +*/ +template +__host__ static inline void create_tensor_map(CUtensorMap *tma_map, const typename SV::dtype *src, int batch, int depth, int rows, int cols) { + using dtype = typename SV::dtype; + static_assert(axis == -1, "for vector TMA, row axis must be -1 as it's unused"); + static_assert(SV::length <= 256 || (SV::length*sizeof(dtype)) % 128 == 0); + // There is technically a way around ^ that involves instantiating two separate TMA descriptors, one of size 256 + // and the other of size %256, but this is a fairly mild restriction and the other approach is a real PITA and incurs other costs. + static_assert(disable_swizzle, "for vector TMA, swizzle should be disabled"); + + constexpr uint32_t tma_dim = 4; + void *global_addr = (void*)(src); + + constexpr CUtensorMapDataType tma_format = ( + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 : + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_FLOAT16 : + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_FLOAT32 : + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_UINT8 : + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_UINT8 : +#ifdef KITTENS_BLACKWELL + std::is_same_v ? CU_TENSOR_MAP_DATA_TYPE_UINT8 : +#endif + CUtensorMapDataType(-1) + ); + constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + constexpr CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE; + constexpr CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + constexpr CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE; + + constexpr uint64_t dim1 = sv_tma_dim1; // inner dim + // constexpr uint64_t dim2 = sv_tma_dim2; outer dim, not used here. 
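+    // Worked example of the divider above (illustrative): a 1024-element float vector would give
+    // dim1 = 256 and dim2 = 4, since 256 is the largest multiple of 16 that is <= 256, divides
+    // 1024 evenly, and keeps dim1 * sizeof(float) a multiple of 128 bytes.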
+ + uint64_t gmem_shape [4] = {(uint64_t)cols, (uint64_t)rows, (uint64_t)depth, (uint64_t)batch}; + uint64_t gmem_stride[3] = {(uint64_t)cols*sizeof(dtype), (uint64_t)cols*rows*sizeof(dtype), (uint64_t)cols*rows*depth*sizeof(dtype)}; + uint32_t smem_shape [4] = {(uint32_t)dim1, 1, 1, 1}; + uint32_t smem_stride[4] = {1, 1, 1, 1}; + + // ensure that the global address is always 16-byte aligned + assert((reinterpret_cast(global_addr) & 0b1111) == 0); + + assert(smem_shape[0] <= 256); // smem_shape[0] elements must be <= 256. + + const uint64_t *gmem_shape_ptr = &gmem_shape[0]; + const uint64_t *gmem_stride_ptr = &gmem_stride[0]; + const uint32_t *smem_shape_ptr = &smem_shape[0]; + const uint32_t *smem_stride_ptr = &smem_stride[0]; + + CUresult result = cuTensorMapEncodeTiled( + tma_map, + tma_format, + tma_dim, + global_addr, + gmem_shape_ptr, + gmem_stride_ptr, + smem_shape_ptr, + smem_stride_ptr, + tma_interleave, + swizzle, + tma_l2Promotion, + tma_oobFill + ); + + const char *error_string; + CUresult res = cuGetErrorString(result, &error_string); + if (result != CUDA_SUCCESS) { + std::string error_msg = format_tma_error( + "vector", error_string, + batch, depth, rows, cols, + tma_map, tma_format, tma_dim, global_addr, + gmem_shape_ptr, gmem_stride_ptr, + smem_shape_ptr, smem_stride_ptr, + 4, 3, 4, 4, + tma_interleave, swizzle, tma_l2Promotion, tma_oobFill, + "SV::length: " + std::to_string(SV::length) + ); + throw std::runtime_error(error_msg); + } +}; + +/** +* @brief Allocates on the GPU and initializes a tensor map for the given source tensor. +* +* This function creates a tensor map (CUtensorMap) for the specified source shared vector type. The tensor map +* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor +* map based on the provided source tensor pointer and the layout specified by the SV template parameter. +* +* @tparam SV The source tensor type, which must be TMA-compatible. +* @tparam num_vectors The number of vectors present in global memory. +* @param src Pointer to the source tensor data in global memory. +* @returns Pointer to the CUtensorMap object to be initialized. +*/ +template +__host__ static inline CUtensorMap* allocate_and_create_tensor_map(const typename SV::dtype *src, int batch, int depth, int rows, int cols) { + CUtensorMap *tma_map_d; + cudaMalloc(&tma_map_d, sizeof(CUtensorMap)); + CUtensorMap tma_map_host; // put it on the stack, why not. 
+ create_tensor_map(&tma_map_host, src, batch, depth, rows, cols); + cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice); + return tma_map_d; +} + +} // namespace tma +} // namespace detail +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/global/util.cuh b/extra/thunder/cuda/include/types/global/util.cuh new file mode 100644 index 0000000000..3490286113 --- /dev/null +++ b/extra/thunder/cuda/include/types/global/util.cuh @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include "../register/register.cuh" + +namespace kittens { +namespace ducks { +namespace gl { + +template concept cdim = (d > 0); // represents a compile-time dimension +template concept rdim = (d == -1); // represents a runtime dimension +template struct compiled_dim { + static_assert(cdim<_v>, "Invalid compile-time dimension value"); + static constexpr size_t v = _v; + __host__ __device__ inline compiled_dim(const std::nullptr_t &_) {} + __host__ __device__ inline constexpr operator size_t() const { return v; } +}; +struct runtime_dim { + size_t v; + __host__ __device__ inline runtime_dim(const size_t &_v) : v(_v) {} + __host__ __device__ inline operator size_t() const { return v; } +}; +template using make_dim_t = std::conditional_t, runtime_dim, compiled_dim>; +template using make_arg_t = std::conditional_t, size_t, std::nullptr_t>; // we pass runtime dims as size_t, comptime dims as nullptr_t +} +} + +namespace detail { +template concept tile = ducks::st::all || ducks::rt::all || ducks::cst::all || ducks::crt::all; +template concept vec = ducks::sv::all || ducks::rv::all || ducks::csv::all || ducks::crv::all; +} + +namespace ducks { +namespace coord { +struct identifier {}; +} +} +template struct coord { // essentially a named int4 for tensor coordinates. + using identifier = ducks::coord::identifier; + using BASE = _T; // in units of what type? + static_assert(std::is_same_v || detail::tile || detail::vec); // ensure BASE is a valid type + int b, d, r, c; + __device__ inline coord(int _b, int _d, int _r, int _c) : b(_b), d(_d), r(_r), c(_c) {} + __device__ inline coord( int _d, int _r, int _c) : b( 0), d(_d), r(_r), c(_c) {} + __device__ inline coord( int _r, int _c) : b( 0), d( 0), r(_r), c(_c) {} + __device__ inline coord( int _c) : b( 0), d( 0), r( 0), c(_c) {} + __device__ inline coord( ) : b( 0), d( 0), r( 0), c( 0) {} + template __device__ inline coord(const coord &other) : b(other.b), d(other.d), r(other.r), c(other.c) {} + __device__ inline coord(const int4 &other) : b(other.x), d(other.y), r(other.z), c(other.w) {} + __device__ inline operator int4() const { return int4(b, d, r, c); } + template __device__ inline coord unit_coord() const { + if constexpr (detail::tile) { + static_assert(row_axis != col_axis, "row and column axes must be different"); + static_assert(row_axis >= 0 && row_axis <= 3, "row axis must be between 0 and 3"); + static_assert(col_axis >= 0 && col_axis <= 3, "column axis must be between 0 and 3"); + static_assert(col_axis == 3, "for now, column axis must be 3"); + return coord( + row_axis == 0 ? b*BASE::rows : b, + row_axis == 1 ? d*BASE::rows : d, + row_axis == 2 ? 
r*BASE::rows : r, + c*BASE::cols + ); + } + else if constexpr (detail::vec) { + static_assert(row_axis == -1, "row axis must be be -1 for a vector coordinate to be converted to a unit coordinate"); + static_assert(col_axis >= 0 && col_axis <= 3, "column axis must be between 0 and 3"); + static_assert(col_axis == 3, "for now, column axis must be 3"); + return coord(b, d, r, c*BASE::length); + } + else { + return coord(*this); + } + } + template __device__ inline int dim() const { + static_assert(axis >= 0 && axis <= 3, "axis must be between 0 and 3"); + if constexpr (axis == 0) { return b; } + else if constexpr (axis == 1) { return d; } + else if constexpr (axis == 2) { return r; } + else { return c; } + } +}; +namespace ducks { +namespace coord { +/** +* @brief Concept for all coordinate types. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as ducks::coord::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::coord::identifier +template concept tile = all && (std::is_same_v || detail::tile); +template concept vec = all && (std::is_same_v || detail::vec); +} +} +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/register/crt.cuh b/extra/thunder/cuda/include/types/register/crt.cuh new file mode 100644 index 0000000000..023ae713a2 --- /dev/null +++ b/extra/thunder/cuda/include/types/register/crt.cuh @@ -0,0 +1,95 @@ +/** + * @file + * @brief Abstraction for a complex register tile composed of real and imaginary tiles + */ + +#pragma once + +#include "rt.cuh" +#include "crv.cuh" + +namespace kittens { + +namespace ducks { +namespace crt { +/** + * @brief A dummy type used to identify complex register tiles. + * + * For a type to quack like an rt_cmplx, it should define its identifier as ducks::rt::cmplx_identifier. + * If a type quacks like ducks::rt::cmplx_identifier, it will be treated as an rt_cmplx by compiler checks. + */ +struct identifier {}; +/** +* @brief Concept for register tiles that are complex. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a register tile. +* - T has a complex tile identifier. +*/ +template concept all = requires { + typename T::identifier; +} && std::is_same_v && ducks::rt::all; + +/* +* Requires: +* - T is a register tile. +* - T has an internal type layout that is ducks::rt_layout::row. +*/ +template +concept row_layout = all && std::is_same_v; +/** +* @brief Concept for register tiles with col layout. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a register tile. +* - T has an internal type layout that is ducks::rt_layout::col. +*/ +template +concept col_layout = all && std::is_same_v; +} // namespace rt +} // namespace ducks + +/** + * @brief Complex tile structure + * + * @tparam T2 The packed data type used for the matrix elements. + * @tparam _rows The height of the tile in terms of the number of subtiles. + * @tparam _cols The width of the tile in terms of the number of subtiles. + * @tparam _layout The layout of the internal register tiles, either row-major or column-major. 
+ * + * This structure is designed to abstract complex number operations internally to the real and imaginary + * register tiles, respectively + * + * In general, you probably want a row-major tile, unless you specifically want to call mma + */ +template +struct crt { + using identifier = ducks::crt::identifier; + using component = rt<_T, _rows, _cols, _layout>; /// Data type of each internal tile. + using layout = component::layout; ///< Layout of the matrix tile, ensures compatibility with the rt concepts + using T = component::T; + using T2 = component::T2; + using dtype = component::dtype; ///< Data type of the elements in the tile. + + static constexpr int rows = component::rows; + static constexpr int cols = component::cols; + static constexpr int height = component::height; + static constexpr int width = component::width; + + // Real/imag tiles have same internal layout and size + component real; + component imag; + + using row_vec = crv::row_vec_layout>; ///< A type representing a column vector for this tile. + using col_vec = crv::col_vec_layout>; ///< A type representing a column vector for this tile. +}; + +template using crt_fl = crt; +template using crt_bf = crt; +template using crt_hf = crt; + + + +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/register/crv.cuh b/extra/thunder/cuda/include/types/register/crv.cuh new file mode 100644 index 0000000000..e688f723a8 --- /dev/null +++ b/extra/thunder/cuda/include/types/register/crv.cuh @@ -0,0 +1,88 @@ +/** + * @file + * @brief Register vectors for computations on axes. + */ + +#pragma once + +#include +#include + +#include "../../common/common.cuh" +#include "rv_layout.cuh" + +namespace kittens { + +/* ---------- MAIN VECTOR STRUCT ---------- */ + +// helper struct for type inference +namespace ducks { +/** + * @namespace rt + * + * @brief The namespace where concepts and abstract types for register vectors live. + */ +namespace crv { +/** + * @brief A dummy type used to identify register vectors. + * + * For a type to quack like an rv, it should define its identifier as ducks::rv::identifier. + * If a type quacks like ducks::rv::identifier, it will be treated as an rv by compiler checks. + */ +struct identifier {}; +/** +* @brief Concept for all register vectors. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as rv::identifier. +*/ +template +concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::rv::identifier. + +template concept naive_layout = all && std::is_same_v; +template concept align_layout = all && std::is_same_v; +template concept ortho_layout = all && std::is_same_v; +template concept tile_layout = align_layout || ortho_layout; // vector layouts for interacting with tiles. +} +} +/** + * @brief Register vector structure. + * + * @tparam _T The packed data type used for the vector elements. + * @tparam _outer_dim The size of the tile, in units of TILE_DIM (16). + * @tparam _inner_dim This controls the layout of the tile in terms of which axis it maps on the register tile layout. + * + * Register vectors are used to accumulate and map values across tiles. You can do computation + * on them directly if you want, but they're not designed to be maximally efficient vectors + * as they have substantial duplication and strange layouts to help them work efficiently with + * the register layouts used by the tensor cores. 
ThunderKittens wants you working with tiles + * where possible! + */ + +template +struct crv { + using identifier = ducks::crv::identifier; + using component = rv<_T, _length, _layout>; /// Data type of each internal tile. + using layout = component::layout; ///< Layout of the matrix tile, ensures compatibility with the rv concepts + + using T = component::T; + using T2 = component::T2; + using dtype = component::dtype; ///< Data type of the elements in the tile. + + static constexpr int length = component::length; + static constexpr int tiles = component::tiles; + + // Real/imag tiles have same internal layout and size + component real; + component imag; +}; + + +template using crv_fl = crv; +template using crv_bf = crv; +template using crv_hf = crv; + +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/register/register.cuh b/extra/thunder/cuda/include/types/register/register.cuh new file mode 100644 index 0000000000..f3525a0416 --- /dev/null +++ b/extra/thunder/cuda/include/types/register/register.cuh @@ -0,0 +1,15 @@ +/** + * @file + * @brief An aggregate header file for all the register types defined by ThunderKittens. + */ + +#pragma once + +#include "rv_layout.cuh" +#include "rt_base.cuh" +#include "rv.cuh" +#include "rt.cuh" + +#include "crv.cuh" +#include "crt.cuh" + diff --git a/extra/thunder/cuda/include/types/register/rt.cuh b/extra/thunder/cuda/include/types/register/rt.cuh new file mode 100644 index 0000000000..b5765d570a --- /dev/null +++ b/extra/thunder/cuda/include/types/register/rt.cuh @@ -0,0 +1,155 @@ +/** + * @file + * @brief The main ThunderKittens register tile struct, where most computation happens. + */ + +#pragma once + +#include +#include + +#include "../../common/common.cuh" + +#include "rt_layout.cuh" +#include "rt_base.cuh" +#include "rv.cuh" + +namespace kittens { + +/* ---------- MAIN TILE STRUCT ---------- */ + +// helper struct for type inference +namespace ducks { +/** + * @namespace rt + * + * @brief The namespace where concepts and abstract types for register tiles live. + */ +namespace rt { +/** + * @brief A dummy type used to identify register tiles. + * + * For a type to quack like an rt, it should define its identifier as ducks::rt::identifier. + * If a type quacks like ducks::rt::identifier, it will be treated as an rt by compiler checks. + */ +struct identifier {}; +/** +* @brief Concept for all register tiles. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as rt::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::rt::identifier +/** +* @brief Concept for register tiles with row layout. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a register tile. +* - T has an internal type layout that is ducks::rt_layout::row. +*/ +template +concept row_layout = all && std::is_same_v; +/** +* @brief Concept for register tiles with col layout. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a register tile. +* - T has an internal type layout that is ducks::rt_layout::col. +*/ +template +concept col_layout = all && std::is_same_v; +} // namespace rt +} // namespace ducks + +/** + * @brief Main tile structure for manipulating data in registers. + * + * @tparam T2 The packed data type used for the matrix elements. 
+ * @tparam _height The height of the tile in terms of the number of subtiles. + * @tparam _width The width of the tile in terms of the number of subtiles. + * @tparam _layout The layout of the internal base tiles, either row-major or column-major. + * + * This structure is designed to handle matrix tiles in a flexible manner, allowing + * for operations on tiles that are composed of smaller subtiles. It supports both + * row-major and column-major layouts and includes helper structs for type inference + * in vector maps. + * + * In general, you probably want a row-major tile, unless you specifically want to call mma + */ +template +struct rt { + using identifier = ducks::rt::identifier; ///< Type identifier for the rt structure. + using layout = _layout; ///< Layout of the matrix tile. + static_assert(kittens::ducks::base_types::T1<_T>); // confirm it's a supported type + using T = kittens::base_types::packing<_T>::unpacked_type; + using T2 = kittens::base_types::packing<_T>::packed_type; + using dtype = T2; ///< Data type of the matrix elements + + static constexpr int rows = _rows; ///< Total number of rows. + static_assert(rows % rt_base::tile_size_row == 0, "Rows must be divisible by the tile size"); + static constexpr int cols = _cols; ///< Total number of columns. + static_assert(cols % rt_base::tile_size_col == 0, "Columns must be divisible by the tile size"); + static constexpr int height = rows / rt_base::tile_size_row; ///< Height in subtiles. + static constexpr int width = cols / rt_base::tile_size_col; ///< Width in subtiles. + static constexpr int tile_size_row = rt_base::tile_size_row; ///< Size of the base tile. + static constexpr int tile_size_col = rt_base::tile_size_col; ///< Size of the base tile. + static constexpr int num_elements = rt_base::num_elements * width * height; ///< Total number of elements. + static constexpr int elements_per_thread = rt_base::elements_per_thread * width * height; ///< Elements handled per thread. + static constexpr int packed_per_thread = rt_base::packed_per_thread * width * height; ///< Packed elements per thread. + static constexpr int packed_per_tile = rt_base::packed_per_thread; ///< Packed elements per tile. + + rt_base tiles[height][width]; ///< The actual storage for the matrix tile, organized in subtiles. + + using row_vec = rv::row_vec_layout>; ///< A type representing a column vector for this tile. + using col_vec = rv::col_vec_layout>; ///< A type representing a column vector for this tile. 
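+    // Usage sketch (illustrative only; the 16x64 fp32 shape and the name `acc` are assumptions, not part of the header):
+    //     rt<float, 16, 64> acc;   // a 16x64 fp32 register tile held by one warp
+    //     acc = 0.0f;              // scalar broadcast via the operator= defined just below
+    // The scalar is packed into the tile's packed dtype and written into every register of every subtile.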
+ + __device__ inline void operator=(const T &value) { + T2 value2 = base_types::packing::pack(value); + #pragma unroll + for(int i = 0; i < height; i++) { + #pragma unroll + for(int j = 0; j < width; j++) { + #pragma unroll + for(int k = 0; k < packed_per_tile; k++) { + tiles[i][j].data[k] = value2; + } + } + } + } + template + __device__ inline void operator=(const rt &other) { + using U2 = base_types::packing::packed_type; + #pragma unroll + for(int i = 0; i < height; i++) { + #pragma unroll + for(int j = 0; j < width; j++) { + #pragma unroll + for(int k = 0; k < packed_per_tile; k++) { + tiles[i][j].data[k] = base_types::convertor::convert(other.tiles[i][j].data[k]); + } + } + } + } +}; + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// layout and type wrappers + +template using rt_fl = rt; +template using rt_bf = rt; +template using rt_hf = rt; +#ifdef KITTENS_HOPPER +template using rt_fp8e4m3 = rt; +template using rt_fp8e5m2 = rt; +#ifdef KITTENS_BLACKWELL +template using rt_fp8e8m0 = rt; +#endif +#endif +} // namespace kittens diff --git a/extra/thunder/cuda/include/types/register/rt_base.cuh b/extra/thunder/cuda/include/types/register/rt_base.cuh new file mode 100644 index 0000000000..c15f1c5910 --- /dev/null +++ b/extra/thunder/cuda/include/types/register/rt_base.cuh @@ -0,0 +1,112 @@ +/** + * @file + * @brief The basic 16x16 register tile on which larger register tiles are built. + */ + +#pragma once + +#include + +#include "../../common/common.cuh" +#include "rt_layout.cuh" +#include "rv_layout.cuh" + +namespace kittens { + +/* ---------- BASE 16x16 SUBTILE STRUCT ---------- */ + +namespace ducks { +/** + * @namespace rt_base + * + * @brief The namespace where concepts and abstract types for register base (16x16) tiles live. + */ +namespace rt_base { +/** + * @brief A dummy type used to identify register base tiles. + * + * For a type to quack like an rt_base, it should define its identifier as ducks::rt_base::identifier. + * If a type quacks like ducks::rt_base::identifier, it will be treated as an rt_base by compiler checks. + */ +struct identifier {}; +} +} // namespace ducks + +/** + * @brief Basic tile structure for computation in registers. + * + * @tparam T2 The packed data type used for the matrix elements. + * @tparam _layout The layout of the base tile, either row-major or column-major. + * + * This type is a primarily utility for building larger inline templates + * out of PTX primitives and managing layouts. + * + * In general, you probably want a row-major tile, unless you specifically want to call mma + */ +template struct rt_base { + using identifier = ducks::rt_base::identifier; ///< Type identifier for the rt_base structure. + using layout = _layout; ///< Layout of the matrix tile. + static_assert(kittens::ducks::base_types::T1<_T>); // confirm it's a supported type + using T = kittens::base_types::packing<_T>::unpacked_type; + using T2 = kittens::base_types::packing<_T>::packed_type; + using dtype = T2; ///< Data type of the matrix elements + + #ifdef KITTENS_HOPPER + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v, + "rt_base was provided an unsupported type." + ); + #else + static_assert( + std::is_same_v || std::is_same_v || std::is_same_v, + "rt_base was provided an unsupported type." 
+ ); + #endif + + static constexpr int tile_size_row = kittens::TILE_ROW_DIM; // < Tile size is a constant 16 for everyone + static constexpr int tile_size_col = kittens::TILE_COL_DIM; + static constexpr int rows = tile_size_row; ///< Number of rows. + static constexpr int cols = tile_size_col; ///< Number of cols. + static constexpr int num_elements = rows*cols; // 256 (64 for fp8e4m3) + static constexpr int elements_per_thread = num_elements / 32; // 8 (2 for fp8e4m3) + + static constexpr int packed_per_thread = (elements_per_thread / base_types::packing::num()) ; // 4 + static constexpr int registers_per_thread = packed_per_thread * sizeof(dtype) / 4; // 4 or 8, registers are 32-bit words + + using row_vec_layout = std::conditional_t, ducks::rv_layout::align, ducks::rv_layout::ortho>; // for holding column reductions + using col_vec_layout = std::conditional_t, ducks::rv_layout::ortho, ducks::rv_layout::align>; // for holding row reductions + + dtype data[packed_per_thread]; ///< The actual storage for the base tile +}; + +// rt_base is 2x the number of elements for fp8e4m3 +// then when we convert a 16x16 of float2, we have 512 elements in the tile +// and with fp8e4m3x4 packed type, we have 16x32x4=2048 elements in the tile + +/* ---------- CONCEPTS ---------- */ + +namespace ducks { +namespace rt_base { +/** +* @brief Concept for all register base tiles. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as rt_base::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::rt::identifier +} // namespace rt +} // namespace ducks + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +template using rt_base_fl = rt_base; +template using rt_base_bf = rt_base; +template using rt_base_hf = rt_base; +#ifdef KITTENS_HOPPER +template using rt_base_fp8e4m3 = rt_base; +template using rt_base_fp8e5m2 = rt_base; +#endif +} diff --git a/extra/thunder/cuda/include/types/register/rt_layout.cuh b/extra/thunder/cuda/include/types/register/rt_layout.cuh new file mode 100644 index 0000000000..a9f9f337cf --- /dev/null +++ b/extra/thunder/cuda/include/types/register/rt_layout.cuh @@ -0,0 +1,42 @@ +/** + * @file + * @brief Layouts and their manipulations for register tiles. + */ + +#pragma once + +#include + +namespace kittens { +namespace ducks { +/** + * @namespace rt_layout + * + * @brief A namespace for template metaprogramming with register tile layouts. + */ +namespace rt_layout { + +/** + * @brief A dummy type used to identify a row-major layout for a register tile. + */ +struct row {}; // for most matrices +/** + * @brief A dummy type used to identify a col-major layout for a register tile. + */ +struct col {}; // for the B-matrix of MMA ops. + +/** + * @brief A concept to check if a type is a register tile layout. + */ +template +concept all = std::is_same_v || std::is_same_v; + +/** + * @brief A struct to generate a transposed layout. 
+ */ +template struct transpose { using type = col; }; +template<> struct transpose { using type = row; }; + +} // namespace rt_layout +} // namespace ducks +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/register/rv.cuh b/extra/thunder/cuda/include/types/register/rv.cuh new file mode 100644 index 0000000000..21af8ffabc --- /dev/null +++ b/extra/thunder/cuda/include/types/register/rv.cuh @@ -0,0 +1,122 @@ +/** + * @file + * @brief Register vectors for computations on axes. + */ + +#pragma once + +#include +#include + +#include "../../common/common.cuh" +#include "rv_layout.cuh" + +namespace kittens { + +/* ---------- MAIN VECTOR STRUCT ---------- */ + +// helper struct for type inference +namespace ducks { +/** + * @namespace rt + * + * @brief The namespace where concepts and abstract types for register vectors live. + */ +namespace rv { +/** + * @brief A dummy type used to identify register vectors. + * + * For a type to quack like an rv, it should define its identifier as ducks::rv::identifier. + * If a type quacks like ducks::rv::identifier, it will be treated as an rv by compiler checks. + */ +struct identifier {}; +/** +* @brief Concept for all register vectors. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as rv::identifier. +*/ +template +concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::rv::identifier. + +template concept naive_layout = all && std::is_same_v; +template concept align_layout = all && std::is_same_v; +template concept ortho_layout = all && std::is_same_v; +template concept tile_layout = align_layout || ortho_layout; // vector layouts for interacting with tiles. +} +} +/** + * @brief Register vector structure. + * + * @tparam _T The packed data type used for the vector elements. + * @tparam _outer_dim The size of the tile, in units of TILE_DIM (16). + * @tparam _inner_dim This controls the layout of the tile in terms of which axis it maps on the register tile layout. + * + * Register vectors are used to accumulate and map values across tiles. You can do computation + * on them directly if you want, but they're not designed to be maximally efficient vectors + * as they have substantial duplication and strange layouts to help them work efficiently with + * the register layouts used by the tensor cores. ThunderKittens wants you working with tiles + * where possible! + */ +template +struct rv { + using identifier = ducks::rv::identifier; ///< Type identifier for the rv structure. + static_assert(kittens::ducks::base_types::T1<_T>); // confirm it's a supported type + using layout = _layout; + static constexpr bool is_naive = std::is_same_v; + using T = kittens::base_types::packing<_T>::unpacked_type; + using T2 = kittens::base_types::packing<_T>::packed_type; + using dtype = std::conditional_t; ///< Data type of the vector elements + + static constexpr int length = _length; ///< Length in elements. + static_assert(length % kittens::TILE_ROW_DIM == 0, "Length must be divisible by the tile dimension"); + static constexpr int tiles = _length / kittens::TILE_ROW_DIM; ///< Length in subtiles, aliased for consistency with sv type + static constexpr int inner_dim = layout::inner_dim; ///< Internal layout within a subtile. Either 1 or 2. + static constexpr int outer_dim = is_naive ? 
(tiles+1)/2 : tiles; ///< Outer dim (also length in tiles) + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for fp8"); + #endif + + dtype data[outer_dim][inner_dim]; ///< The actual register vector data. + + __device__ inline dtype* operator[](size_t idx) { return &data[idx][0]; } ///< A wrapper for indexing into vector data. + __device__ inline const dtype* operator[](size_t idx) const { return &data[idx][0]; } ///< A wrapper for indexing into vector data. + __device__ inline dtype& operator[](int2 outin) { return data[outin.x][outin.y]; } ///< A wrapper for indexing into vector data. + __device__ inline const dtype& operator[](int2 outin) const { return data[outin.x][outin.y]; } ///< A wrapper for indexing into vector data. + + __device__ inline void operator=(const T &value) { + dtype value2; + if constexpr(is_naive) { + value2 = value; + } else { + value2 = base_types::packing::pack(value); + } + #pragma unroll + for(int i = 0; i < outer_dim; i++) { + #pragma unroll + for(int j = 0; j < inner_dim; j++) { + data[i][j] = value2; + } + } + } + template + __device__ inline void operator=(const rv &other) { + using U2 = base_types::packing::packed_type; + #pragma unroll + for(int i = 0; i < outer_dim; i++) { + #pragma unroll + for(int j = 0; j < inner_dim; j++) { + data[i][j] = base_types::convertor::convert(other.data[i][j]); + } + } + } +}; + +template using rv_fl = rv; +template using rv_bf = rv; +template using rv_hf = rv; + +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/register/rv_layout.cuh b/extra/thunder/cuda/include/types/register/rv_layout.cuh new file mode 100644 index 0000000000..0165a86f76 --- /dev/null +++ b/extra/thunder/cuda/include/types/register/rv_layout.cuh @@ -0,0 +1,40 @@ +/** + * @file + * @brief Layouts and their manipulations for register tiles. + */ + +#pragma once + +#include + +namespace kittens { +namespace ducks { +/** + * @namespace rv_layout + * + * @brief A namespace for template metaprogramming with register vector layouts. + */ +namespace rv_layout { + +/** + * @brief A dummy type used to identify an aligned (8x replicated) layout. + */ +struct align { constexpr static int inner_dim = 2; }; +/** + * @brief A dummy type used to identify an orthogonal (4x replicated) layout. + */ +struct ortho { constexpr static int inner_dim = 1; }; +/** + * @brief A dummy type used to identify an unreplicated layout, for better coalesced loads and vector operations like layernorm. + */ +struct naive { constexpr static int inner_dim = 1; }; + +/** + * @brief A concept to check if a type is a register tile layout. + */ +template +concept all = std::is_same_v || std::is_same_v || std::is_same_v; + +} // namespace rv_layout +} // namespace ducks +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/shared/cst.cuh b/extra/thunder/cuda/include/types/shared/cst.cuh new file mode 100644 index 0000000000..98c63f9a3a --- /dev/null +++ b/extra/thunder/cuda/include/types/shared/cst.cuh @@ -0,0 +1,82 @@ +/** + * @file + * @brief Abstraction for a complex register tile composed of real and imaginary tiles + */ + +#pragma once + +#include "st.cuh" + +namespace kittens { + +namespace ducks { +namespace cst { +/** + * @brief A dummy type used to identify complex register tiles. + * + * For a type to quack like an st_cmplx, it should define its identifier as ducks::st::cmplx_identifier. 
+ * If a type quacks like ducks::st::cmplx_identifier, it will be treated as an st_cmplx by compiler checks. + */ +struct identifier {}; + +/** +* @brief Concept for shared tiles that are complex. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a shared tile. +* - T has a complex tile identifier. +*/ +template concept all = requires { + typename T::identifier; +} && std::is_same_v && ducks::st::all; + +} // namespace st +} // namespace ducks + +/** + * @brief Complex tile structure + * + * @tparam T2 The packed data type used for the matrix elements. + * @tparam _rows The height of the tile in terms of the number of subtiles. + * @tparam _cols The width of the tile in terms of the number of subtiles. + * @tparam _layout The layout of the internal register tiles + * + * This structure is designed to abstract complex number operations internally to the real and imaginary + * shared tiles, respectively + * + * + */ +template +struct cst { + using identifier = ducks::cst::identifier; + using component = st<_T, _rows, _cols>; /// Data type of each internal tile. + using T = component::T; + using T2 = component::T2; + using dtype = component::dtype; ///< Data type of the elements in the tile. + + static constexpr int rows = component::rows; + static constexpr int cols = component::cols; + static constexpr int height = component::height; + static constexpr int width = component::width; + + // todo: fill in the rest for convenience, but they're all accessible via component so it's not urgent. + + // Real/imag tiles have same internal layout and size + component real; + component imag; + + // vector types + using col_vec = csv; + using row_vec = csv; +}; + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +template using cst_bf = cst; +template using cst_hf = cst; +template using cst_fl = cst; + + + +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/shared/csv.cuh b/extra/thunder/cuda/include/types/shared/csv.cuh new file mode 100644 index 0000000000..dab205fa23 --- /dev/null +++ b/extra/thunder/cuda/include/types/shared/csv.cuh @@ -0,0 +1,74 @@ +/** + * @file + * @brief Abstraction for a complex register tile composed of real and imaginary tiles + */ + +#pragma once + +#include "st.cuh" + +namespace kittens { + +namespace ducks { +namespace csv { +/** + * @brief A dummy type used to identify complex register tiles. + * + * For a type to quack like an st_cmplx, it should define its identifier as ducks::st::cmplx_identifier. + * If a type quacks like ducks::st::cmplx_identifier, it will be treated as an st_cmplx by compiler checks. + */ +struct identifier {}; +/** +* @brief Concept for shared vectors that are complex. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T is a shared tile. +* - T has a complex tile identifier. +*/ +template concept all = requires { + typename T::identifier; +} && std::is_same_v && ducks::sv::all; + +} // namespace st +} // namespace ducks + +/** + * @brief Complex tile structure + * + * @tparam T2 The packed data type used for the matrix elements. + * @tparam _height The height of the tile in terms of the number of subtiles. + * @tparam _width The width of the tile in terms of the number of subtiles. 
+ * @tparam _layout The layout of the internal register tiles + * + * This structure is designed to abstract complex number operations internally to the real and imaginary + * shared tiles, respectively + * + * + */ +template +struct csv { + using identifier = ducks::csv::identifier; + using component = sv<_T, _length>; /// Data type of each internal tile. + using T = component::T; + using T2 = component::T2; + using dtype = component::dtype; ///< Data type of the elements in the tile. + + static constexpr int length = component::length; + static constexpr int tiles = component::tiles; + + // todo: fill in the rest for convenience, but they're all accessible via component so it's not urgent. + + // Real/imag tiles have same internal layout and size + component real; + component imag; +}; + + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +template using csv_bf = csv; +template using csv_hf = csv; +template using csv_fl = csv; + +} \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/shared/shared.cuh b/extra/thunder/cuda/include/types/shared/shared.cuh new file mode 100644 index 0000000000..773011b07c --- /dev/null +++ b/extra/thunder/cuda/include/types/shared/shared.cuh @@ -0,0 +1,14 @@ +/** + * @file + * @brief An aggregate header file for all the shared types defined by ThunderKittens. + */ + +#pragma once + +#include "sv.cuh" +#include "st.cuh" + +#include "csv.cuh" +#include "cst.cuh" + +#include "st_descriptor.cuh" \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/shared/st.cuh b/extra/thunder/cuda/include/types/shared/st.cuh new file mode 100644 index 0000000000..8176438382 --- /dev/null +++ b/extra/thunder/cuda/include/types/shared/st.cuh @@ -0,0 +1,349 @@ +/** + * @file + * @brief The ThunderKittens shared tile struct. + */ + +#pragma once + +#include "../../common/common.cuh" +#include "sv.cuh" + +/* ---------- MAIN TILE STRUCT ---------- */ + +// these are helper structs for type inference +namespace kittens { +namespace ducks { +/** + * @namespace rt + * + * @brief The namespace where concepts and abstract types for shared tiles live. + */ +namespace st { +/** + * @brief A dummy type used to identify shared tiles. + * + * For a type to quack like an st, it should define its identifier as ducks::st::identifier. + * If a type quacks like ducks::st::identifier, it will be treated as an st by compiler checks. + * This is particularly useful for subtiles. + */ +struct identifier {}; +/** +* @brief Concept for all shared tiles. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as st::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::st::identifier +} +} // namespace ducks + +// Forward declaration of subtile +template< + typename ST, + int _subtile_height, + int _subtile_width +> +struct st_subtile; + +/** + * @brief Shared memory tile structure for various data types and layouts. + * + * @tparam T The data type of the elements in the tile. Not packed! + * @tparam _rows The height of the tile. + * @tparam _cols The width of the tile. + */ +template +struct KITTENS_DEFAULT_ALIGN st { + using identifier = ducks::st::identifier; ///< Type identifier for shared memory tile. 
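+    // Usage sketch (illustrative; the 64x64 bf16 tile and index names are assumptions):
+    //     __shared__ st_bf<64, 64> tile;            // st_bf alias is defined near the bottom of this file
+    //     tile[{row, col}] = __float2bfloat16(1.f); // operator[] resolves the swizzled address
+    // Indices are given as plain row-major coordinates; the idx() helpers below apply the swizzle internally.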
+ using T = base_types::packing<_T>::unpacked_type; + using T2 = base_types::packing<_T>::packed_type; + using dtype = T; ///< Data type of the elements in the tile. + + // define underlying data as same as that projected, to make clear that this is *not* a subtile. + static constexpr int underlying_rows = _rows; + static constexpr int underlying_cols = _cols; + static constexpr int underlying_height = _rows / kittens::TILE_ROW_DIM; + static constexpr int underlying_width = _cols / kittens::TILE_COL_DIM; + static constexpr int underlying_num_elements = underlying_rows * underlying_cols; + + static constexpr int rows = _rows; ///< Total number of rows in the tile. + static_assert(rows % kittens::TILE_ROW_DIM == 0, "Rows must be divisible by the tile dimension"); + static constexpr int cols = _cols; ///< Total number of cols in the tile. + static_assert(cols % kittens::TILE_COL_DIM == 0, "Cols must be divisible by the tile dimension"); + static constexpr int height = _rows / kittens::TILE_ROW_DIM; ///< Height of the tile in terms of 16-element subtiles. + static constexpr int width = _cols / kittens::TILE_COL_DIM; ///< Width of the tile in terms of 16-element subtiles. + static constexpr int num_elements = rows * cols; ///< Total number of elements in the tile. + + static_assert(base_types::packing::num() == 1); // must be a 1-packed type (e.g. float, bf16, etc) + + static constexpr int swizzle_bytes = ( + sizeof(dtype) == 1 ? ( // Add FP8 case + underlying_width%4 == 0 ? 128 : + underlying_width%2 == 0 ? 64 : 32 + ) : + sizeof(dtype) == 2 ? ( + underlying_width%4 == 0 ? 128 : + underlying_width%2 == 0 ? 64 : 32 + ) : + sizeof(dtype) == 4 ? ( + underlying_width%2 == 0 ? 128 : 64 + ) : -1 + ); + + // wgmma layout with swizzling + dtype data[rows*cols]; ///< Raw data storage for the tile. + + __device__ static inline T* idx(T *ptr, int2 coord) { // naive row-major coord default + int r = coord.x, c = coord.y; // alias + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = c/subtile_cols; + const uint64_t addr = (uint64_t)(&ptr[outer_idx*rows*subtile_cols + r*subtile_cols + c%subtile_cols]); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (T*)(addr ^ swizzle); + } + __device__ static inline uint32_t idx(uint32_t ptr, int2 coord) { + int r = coord.x, c = coord.y; // alias + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = c/subtile_cols; + const uint32_t addr = ptr + sizeof(T)*(outer_idx*rows*subtile_cols + r*subtile_cols + c%subtile_cols); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (addr ^ swizzle); + } + /** + * @brief Access a shared tile element using a row and column, as if the tile were row-major. + * + * This is the preferred way to access memory within a shared tile, which abstracts + * indexing calculations for swizzled layouts. 
+ */ + __device__ inline dtype& operator[](const int2 &rowcol) { + return *idx(data, rowcol); + } + __device__ inline const dtype& operator[](const int2 &rowcol) const { + return *(const dtype*)idx((dtype*)data, rowcol); + } + __device__ inline dtype& operator[](int idx) { + return data[idx]; + } + __device__ inline const dtype& operator[](int idx) const { + return data[idx]; + } + + template + __device__ inline st_subtile, subtile_rows, subtile_cols> subtile(int2 rowcol); + + // vector types + using col_vec = sv; ///< Column vector type for this tile + using row_vec = sv; ///< Row vector type for this tile +}; + + + +/** + * @brief A reference into a chunk of shared tile memory. + * + * The st_subtile is a drop-in replacement for an st which internally + * references the appropriate memory while performing minimal address + * calculations. You should never create this directly, but instead + * have subtile_inplace return it for you instead. (`auto` is nice.) + * + * You can generally just pretend this is an st. But not for wgmma's. + */ +template< + typename _ST, + int _subtile_rows, + int _subtile_cols +> +struct st_subtile { + using identifier = ducks::st::identifier; // i quack like an st, gcc will never know the difference + using ST = _ST; + using T = ST::T; + using T2 = ST::T2; + using dtype = T; ///< Data type of the elements in the tile. + + static constexpr int underlying_rows = ST::underlying_rows; + static_assert(underlying_rows % kittens::TILE_ROW_DIM == 0, "Underlying rows must be divisible by the tile dimension"); + static constexpr int underlying_cols = ST::underlying_cols; + static_assert(underlying_cols % kittens::TILE_COL_DIM == 0, "Underlying cols must be divisible by the tile dimension"); + static constexpr int underlying_height = ST::underlying_height; + static constexpr int underlying_width = ST::underlying_width; + static constexpr int underlying_num_elements = ST::underlying_num_elements; + + static constexpr int rows = _subtile_rows; + static_assert(rows % kittens::TILE_ROW_DIM == 0, "Rows must be divisible by the tile dimension"); + static constexpr int cols = _subtile_cols; + static_assert(cols % kittens::TILE_COL_DIM == 0, "Cols must be divisible by the tile dimension"); + static constexpr int height = rows / kittens::TILE_ROW_DIM; + static constexpr int width = cols / kittens::TILE_COL_DIM; + static constexpr int num_elements = rows * cols; + + static constexpr int swizzle_bytes = ST::swizzle_bytes; + + dtype *data; + int row_offset, col_offset; + + __device__ st_subtile(ST &src, int2 rowcol) { + data = &src.data[0]; + row_offset = rowcol.x * rows; + col_offset = rowcol.y * cols; + } + + __device__ inline T* idx(T *ptr, const int2 coord) { // naive row-major coord default + int r = coord.x+row_offset, c = coord.y+col_offset; // alias + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = c/subtile_cols; + const uint64_t addr = (uint64_t)(&ptr[outer_idx*underlying_rows*subtile_cols + r*subtile_cols + c%subtile_cols]); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (T*)(addr ^ swizzle); + } + __device__ inline uint32_t idx(uint32_t ptr, const int2 coord) const { // naive row-major coord default + int r = coord.x+row_offset, c = coord.y+col_offset; // alias + static constexpr int swizzle_repeat = swizzle_bytes * 8; + static constexpr int subtile_cols = swizzle_bytes / sizeof(T); + const int outer_idx = c/subtile_cols; + const uint32_t addr = ptr + 
sizeof(T)*(outer_idx*underlying_rows*subtile_cols + r*subtile_cols + c%subtile_cols); + const int swizzle = ((addr % swizzle_repeat) >> 7) << 4; + return (addr ^ swizzle); + } + /** + * @brief Access a shared tile element using a row and column, as if the tile were row-major. + * + * This is the preferred way to access memory within a shared tile, which abstracts + * indexing calculations for swizzled layouts. + */ + __device__ inline dtype& operator[](const int2 &rowcol) { + return *idx(data, rowcol); + } + __device__ inline const dtype& operator[](const int2 &rowcol) const { + return *(const dtype*)idx((dtype*)data, rowcol); + } + + // single-coord operator[] is left undefined as it would likely be an improper use of st_subtile type. + // can of course be end-run by just accessing .data directly. + + // vector types + using col_vec = sv; + using row_vec = sv; + + __device__ inline void operator=(const dtype &value) { // runs at warp scope by default + #pragma unroll + for(int i = kittens::laneid(); i < num_elements; i += WARP_THREADS) { + data[i] = value; + } + } +}; + +template // Class template parameters +template // Function template parameters +__device__ inline st_subtile, subtile_rows, subtile_cols> // Return type +st<_T, _rows, _cols>::subtile(int2 rowcol) // Qualified function name and parameters +{ + // Type aliases for convenience within the function body + using ST_t = st<_T, _rows, _cols>; // Alias for the parent tile type + using dtype = typename ST_t::dtype; // Alias for the data type + + // Static assertions (as provided in the initial request) + static_assert(subtile_rows > 0 && subtile_cols > 0, "Subtile dimensions must be positive."); + static_assert(subtile_rows % kittens::TILE_ROW_DIM == 0, + "Subtile rows must be divisible by the base tile row dimension."); + static_assert(subtile_cols % kittens::TILE_COL_DIM == 0, + "Subtile cols must be divisible by the base tile col dimension."); + + // Calculate height/width in terms of base tiles for further checks + constexpr int subtile_height = subtile_rows / kittens::TILE_ROW_DIM; + constexpr int subtile_width = subtile_cols / kittens::TILE_COL_DIM; + static_assert(subtile_height > 0 && subtile_width > 0, "Subtile height/width in base tiles must be positive."); + + // Check divisibility of parent height/width by subtile height/width + static_assert(ST_t::height % subtile_height == 0, + "Parent tile height (in base tiles) must be divisible by subtile height (in base tiles)."); + static_assert(ST_t::width % subtile_width == 0, + "Parent tile width (in base tiles) must be divisible by subtile width (in base tiles)."); + + // Ensure the parent st object is not itself a subtile view by comparing its + // dimensions to its underlying dimensions. + static_assert(ST_t::height == ST_t::underlying_height && ST_t::width == ST_t::underlying_width, + "Cannot create a subtile from an object that appears to be a subtile view (height/width mismatch underlying)."); + // Also check rows/cols directly for robustness, though height/width check might suffice. 
+ static_assert(ST_t::rows == ST_t::underlying_rows && ST_t::cols == ST_t::underlying_cols, + "Cannot create a subtile from an object that appears to be a subtile view (rows/cols mismatch underlying)."); + + + // Construct and return the st_subtile object using its constructor: + // st_subtile(ST &src, int2 rowcol) + // Here, 'src' is the current 'st' object (*this) + return st_subtile(*this, rowcol); +} + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +template using st_bf = st; +template using st_hf = st; +template using st_fl = st; +#ifdef KITTENS_HOPPER +template using st_fp8e4m3 = st; +template using st_fp8e5m2 = st; +#ifdef KITTENS_BLACKWELL +template using st_fp8e8m0 = st; +#endif +#endif + +/* ---------- PRINTOUTS ---------- */ + +/** + * @brief Print the contents of a shared tile as a formatted table. + * + * This function should be called by a single thread in the warp. + * It will print the entire tile atomically to avoid interleaved output. + * + * @param tile The shared tile to print + */ +template +__device__ inline void print(const ST& tile) { + printf("Shared Tile %dx%d:\n", ST::rows, ST::cols); + + // Print column headers + printf(" "); // Padding for row indices + for (int c = 0; c < ST::cols; c++) { + printf("%8d ", c); + } + printf("\n"); + + // Print separator line + printf(" "); + for (int c = 0; c < ST::cols; c++) { + printf("--------+"); + } + printf("\n"); + + // Print data rows + for (int r = 0; r < ST::rows; r++) { + printf("%3d |", r); // Row index + for (int c = 0; c < ST::cols; c++) { + if constexpr (std::is_same_v) { + printf("%8.3f ", static_cast(tile[{r,c}])); +#ifdef KITTENS_BLACKWELL + } else if constexpr (std::is_same_v) { + printf("%8.3f ", static_cast(tile[{r,c}])); +#endif + } else if constexpr (std::is_same_v) { + printf("%8.3f ", tile[{r,c}]); + } else if constexpr (std::is_same_v) { + printf("%8.3f ", __bfloat162float(tile[{r,c}])); + } else if constexpr (std::is_integral_v) { + printf("%8d ", (int)tile[{r,c}]); + } else { + printf("%8.3f ", (float)tile[{r,c}]); + } + } + printf("\n"); + } + printf("\n"); +} + +} diff --git a/extra/thunder/cuda/include/types/shared/st_descriptor.cuh b/extra/thunder/cuda/include/types/shared/st_descriptor.cuh new file mode 100644 index 0000000000..d9cc24e111 --- /dev/null +++ b/extra/thunder/cuda/include/types/shared/st_descriptor.cuh @@ -0,0 +1,118 @@ +/** + * @file + * @brief The ThunderKittens shared tile descriptors, used for Hopper and Blackwell tensor cores. + */ + +#pragma once + +#if defined(KITTENS_HOPPER) || defined(KITTENS_BLACKWELL) + +#include "../../common/common.cuh" +#include "st.cuh" +#include "cst.cuh" + +namespace kittens { +namespace ducks { +namespace st_descriptor { +struct identifier {}; +} +} + +namespace detail { +// see https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor +__device__ static inline uint64_t matrix_descriptor_encode(uint64_t x) { return (((x) & 0x3FFFF) >> 0x4); } +} + +template +struct st_descriptor { + using identifier = ducks::st_descriptor::identifier; + using ST = _ST; + static constexpr int height = ST::height; + static constexpr int width = ST::width; + using T = ST::T; + uint64_t base_desc; + __device__ inline st_descriptor(const ST &tile) { +#ifdef KITTENS_BLACKWELL + base_desc = detail::matrix_descriptor_encode((uint64_t)(&tile.data[0])) | (1llu<<46); // needed for blackwell shared memory descriptors. 
+#else + base_desc = detail::matrix_descriptor_encode((uint64_t)(&tile.data[0])); +#endif + if constexpr (transpose) { // transpose mode + if constexpr (ST::width%4 == 0) { + base_desc |= detail::matrix_descriptor_encode((uint64_t)2048*ST::height) << 16; + base_desc |= detail::matrix_descriptor_encode((uint64_t)1024) << 32; + base_desc |= 1llu << 62; // set wgmma_swizzle mode + } + else if constexpr (ST::width%2 == 0) { + base_desc |= detail::matrix_descriptor_encode((uint64_t)1024*ST::height) << 16; + base_desc |= detail::matrix_descriptor_encode((uint64_t)512) << 32; + base_desc |= 2llu << 62; // set wgmma_swizzle mode + } + else { + base_desc |= detail::matrix_descriptor_encode((uint64_t)512*ST::height) << 16; + base_desc |= detail::matrix_descriptor_encode((uint64_t)256) << 32; + base_desc |= 3llu << 62; // set wgmma_swizzle mode + } + } + else { // normal mode + if constexpr (ST::width%4 == 0) { + base_desc |= detail::matrix_descriptor_encode((uint64_t)16) << 16; // this line doesn't matter + base_desc |= detail::matrix_descriptor_encode((uint64_t)1024) << 32; // 128 byte swizzle x 8 for core matrix rows + base_desc |= 1llu << 62; // set wgmma_swizzle mode + } + else if constexpr (ST::width%2 == 0) { + base_desc |= detail::matrix_descriptor_encode((uint64_t)16) << 16; // this line doesn't matter + base_desc |= detail::matrix_descriptor_encode((uint64_t)512) << 32; // 64 byte swizzle x 8 for core matrix rows + base_desc |= 2llu << 62; // set wgmma_swizzle mode + } + else { + base_desc |= detail::matrix_descriptor_encode((uint64_t)16) << 16; // this line doesn't matter + base_desc |= detail::matrix_descriptor_encode((uint64_t)256) << 32; // 32 byte swizzle x 8 for core matrix rows + base_desc |= 3llu << 62; // set wgmma_swizzle mode + } + } + } + __device__ inline st_descriptor(const st_descriptor &other) : base_desc(other.base_desc) {} // copy constructor + __device__ inline uint64_t chunk_descriptor(int chunk_idx) { + if constexpr (transpose) { // transpose mode + if constexpr (ST::width%4 == 0) { + return base_desc + detail::matrix_descriptor_encode(chunk_idx*2048); + } + else if constexpr (ST::width%2 == 0) { + return base_desc + detail::matrix_descriptor_encode(chunk_idx*1024); + } + else { + return base_desc + detail::matrix_descriptor_encode(chunk_idx*512); + } + } + else { // normal mode + if constexpr (ST::width%4 == 0) { + return base_desc + detail::matrix_descriptor_encode((chunk_idx%4)*32 + (chunk_idx/4)*ST::height*2048); + } + else if constexpr (ST::width%2 == 0) { + return base_desc + detail::matrix_descriptor_encode((chunk_idx%2)*32 + (chunk_idx/2)*ST::height*1024); + } + else { + return base_desc + detail::matrix_descriptor_encode(chunk_idx*ST::height*512); + } + } + } +}; + +namespace ducks { +namespace st_descriptor { +// input refers to either an ST directly or to a pre-generated descriptor, which can save cycles in certain situations. 
+template concept input = ducks::st::all || (requires {typename T::identifier;} && std::is_same_v); +template concept complex_input = ducks::cst::all; +namespace detail { +template struct st_getter { using type = typename T::ST; }; +template struct st_getter { using type = T; }; +template struct st_getter { using type = T::component; }; +template using get_st = typename st_getter::type; +} // namespace detail +} // namespace st_descriptor +} // namespace ducks + +} // namespace kittens + +#endif \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/shared/sv.cuh b/extra/thunder/cuda/include/types/shared/sv.cuh new file mode 100644 index 0000000000..475c51b777 --- /dev/null +++ b/extra/thunder/cuda/include/types/shared/sv.cuh @@ -0,0 +1,130 @@ +/** + * @file + * @brief The ThunderKittens shared vector struct. + */ + +#pragma once + +#include +#include + +#include "../../common/common.cuh" + +namespace kittens { + +/* ---------- MAIN VECTOR STRUCT ---------- */ + +namespace ducks { +/** + * @namespace sv + * + * @brief The namespace where concepts and abstract types for shared vectors live. + */ +namespace sv { +/** + * @brief A dummy type used to identify shared vectors. + * + * For a type to quack like an sv, it should define its identifier as ducks::sv::identifier. + * If a type quacks like ducks::sv::identifier, it will be treated as an sv by compiler checks. + */ +struct identifier {}; +/** +* @brief Concept for all shared vectors. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as sv::identifier. +*/ +template +concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::sv::identifier +} +} + +/** + * @brief Shared vector structure. + * + * @tparam _T The packed data type used for the vector elements. + * @tparam _tiles The size of the tile, in units of TILE_ROW_DIM (16 for fp16, bf16, fp32). + * + * Shared vectors are used to accumulate and map values across shared tiles. + * Unlike every other structure present in ThunderKittens, these have a simple + * uniform layout which is just an array in memory. EZ! + */ +template +struct KITTENS_DEFAULT_ALIGN sv { + using identifier = ducks::sv::identifier; + using T = base_types::packing<_T>::unpacked_type; + using T2 = base_types::packing<_T>::packed_type; + using dtype = T; ///< Data type of the elements in the tile. + + static constexpr int length = _length; ///< Length in elements. + static_assert(length % TILE_ROW_DIM == 0, "Length must be divisible by the tile dimension"); + static constexpr int tiles = length / TILE_ROW_DIM; ///< Length in subtiles.' + #ifdef KITTENS_HOPPER + static_assert(!std::is_same_v && !std::is_same_v, "Unsupported type for fp8"); + #endif + +#ifdef KITTENS_HOPPER + static constexpr int num_alloc_elements = ((length * sizeof(dtype) + 127) / 128) * (128 / sizeof(dtype)); // round up to the nearest 128-byte boundary +#else + static constexpr int num_alloc_elements = length; +#endif + dtype data[num_alloc_elements]; ///< The actual shared vector data. + + __device__ static inline T* idx(T *ptr, int idx) { // useful for computations in shared address space, as silly as it sounds. 
+ return ptr[idx]; + } + + __device__ inline dtype& operator[](size_t idx) { return data[idx]; } + __device__ inline const dtype& operator[](size_t idx) const { return data[idx]; } + + template __device__ inline sv<_T, sub_length> &subvec(int idx) { + return *(sv*)&data[idx * sub_length]; + } + template __device__ inline const sv<_T, sub_length> &subvec(int idx) const { + return *(sv*)&data[idx * sub_length]; + } + + __device__ inline void operator=(const dtype &value) { // runs at warp scope by default + #pragma unroll + for(int i = kittens::laneid(); i < length; i += WARP_THREADS) { + data[i] = value; + } + } +}; + +/* ---------- WRAPPERS FOR PRETTINESS ---------- */ + +// vector types +template using sv_bf = sv; +template using sv_hf = sv; +template using sv_fl = sv; + +/* ---------- PRINTOUTS ---------- */ + +template +__device__ inline void print(const SV& sv) { + printf("Shared Vector %d:\n", SV::length); + for(int i = 0; i < SV::length; i++) { + if constexpr (std::is_same_v) { + printf("%f ", static_cast(sv[i])); +#ifdef KITTENS_BLACKWELL + } else if constexpr (std::is_same_v) { + printf("%f ", static_cast(sv[i])); +#endif + } else if constexpr (std::is_same_v) { + printf("%f ", __bfloat162float(sv[i])); + } else if constexpr (std::is_same_v) { + printf("%f ", __half2float(sv[i])); + } else if constexpr (std::is_same_v) { + printf("%f ", sv[i]); + } else { + printf("%d ", (int)(sv[i])); + } + } + printf("\n"); +} + +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/tensor/tensor.cuh b/extra/thunder/cuda/include/types/tensor/tensor.cuh new file mode 100644 index 0000000000..fd26274b0a --- /dev/null +++ b/extra/thunder/cuda/include/types/tensor/tensor.cuh @@ -0,0 +1,112 @@ +/** + * @file + * @brief An aggregate header file for all the tensor types defined by ThunderKittens. + */ + +#pragma once + +#include "tt.cuh" + +// A thin wrapper that allows for certain compile-time checks to be performed when allocating tensor memory. +namespace kittens { +namespace ducks { +/** + * @namespace tensor_allocator + * + * @brief The namespace where concepts and abstract types for tensor memory allocation live. + */ +namespace tensor_allocator { +/** + * @brief A dummy type used to identify tensor memory. + */ +struct identifier {}; +/** +* @brief Concept for all tensor_allocator types. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as tensor_allocator::identifier. 
+*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::tt::identifier +} // namespace tensor_allocator +} // namespace ducks + +template struct tensor_allocator { + using identifier = ducks::tensor_allocator::identifier; + static constexpr int nblocks = _nblocks; + static constexpr int cols =((512/nblocks) / 32) * 32; + static constexpr int ncta = _ncta; + uint32_t addr; + template __device__ inline void check_bounds() { + static_assert(col_offset >= 0 && col_offset + TT::cols <= cols, "Tile allocation extends out of bounds of the tensor allocator!"); + } + __device__ inline tensor_allocator() { + __shared__ uint32_t shared_addr; + static_assert(cols>0 && cols%32==0, "cols must be a multiple of 32"); + if constexpr (ncta == 1) { + if(warpid() == 0) { + asm volatile( + "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;\n" + :: "l"((uint64_t)&shared_addr), "n"(cols) + ); + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::1.sync.aligned;\n"); + } + } + else { + if(warpid() == 0) { + asm volatile( + "tcgen05.alloc.cta_group::2.sync.aligned.shared::cta.b32 [%0], %1;\n" + :: "l"((uint64_t)&shared_addr), "n"(cols) + ); + asm volatile("tcgen05.relinquish_alloc_permit.cta_group::2.sync.aligned;\n"); + } + } + asm volatile("tcgen05.fence::before_thread_sync;\n"); + asm volatile("bar.sync 0;\n"); + asm volatile("tcgen05.fence::after_thread_sync;\n"); + addr = shared_addr; + } + __device__ inline uint32_t get_addr(int superlane, int col_offset) const { return addr + ((superlane*16) << 16) + col_offset; } + template __device__ inline auto allocate(int superlane, int col_offset) { +#ifndef NDEBUG + if(col_offset + TT::cols > cols) { + printf("Tile allocation extends out of bounds of the tensor allocator! col_offset: %d, TT::cols: %d, allocator cols: %d\n", col_offset, TT::cols, cols); + asm volatile("trap;"); + } + if(superlane < 0 || superlane > 1) { + printf("Superlane must be 0 or 1! superlane: %d\n", superlane); + asm volatile("trap;"); + } +#endif + return TT(get_addr(superlane, col_offset)); + } + template __device__ inline auto allocate(int col_offset) { +#ifndef NDEBUG + if(col_offset + TT::cols > cols) { + printf("Tile allocation extends out of bounds of the tensor allocator! col_offset: %d, TT::cols: %d, allocator cols: %d\n", col_offset, TT::cols, cols); + asm volatile("trap;"); + } +#endif + return TT(get_addr(0, col_offset)); + } + __device__ inline ~tensor_allocator() { // Note that this must be called after all threads are done with that tensor memory -- likely after a syncthreads / cluster::sync()! + if constexpr (ncta == 1) { + if(warpid() == 0) { + asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;\n" + :: "r"(addr), "n"(cols) + ); + } + } + else { + if(warpid() == 0) { + asm volatile("tcgen05.dealloc.cta_group::2.sync.aligned.b32 %0, %1;\n" + :: "r"(addr), "n"(cols) + ); + } + } + } +}; + +} // namespace kittens \ No newline at end of file diff --git a/extra/thunder/cuda/include/types/tensor/tt.cuh b/extra/thunder/cuda/include/types/tensor/tt.cuh new file mode 100644 index 0000000000..2b1aba2f70 --- /dev/null +++ b/extra/thunder/cuda/include/types/tensor/tt.cuh @@ -0,0 +1,97 @@ +/** + * @file + * @brief The ThunderKittens tensor memory struct. 
+ */ + +#pragma once + +#include "../../common/common.cuh" + +/* ---------- MAIN tt STRUCT ---------- */ + +// these are helper structs for type inference +namespace kittens { +namespace ducks { +/** + * @namespace tt + * + * @brief The namespace where concepts and abstract types for shared tiles live. + */ +namespace tt { +/** + * @brief A dummy type used to identify tensor memory. + */ +struct identifier {}; +/** +* @brief Concept for all tt tiles. +* @tparam T The type to check against the concept requirements. +* +* Requires: +* - T has a nested type identifier that is the same as tt::identifier. +*/ +template concept all = requires { + typename T::identifier; // Checks if T::identifier exists +} && std::is_same_v; // Checks if T::identifier is ducks::tt::identifier +template concept half = all && T::rows == 64; +template concept full = all && T::rows == 128; +} // namespace tt +} // namespace ducks + +/** + * @brief Shared memory tile structure for various data types and layouts. + * + * @tparam T The data type of the elements in the tile. Not packed! + * @tparam _rows The height of the tile. + * @tparam _cols The width of the tile. + */ +template +struct tt { + using identifier = ducks::tt::identifier; ///< Type identifier for shared memory tile. + using T = base_types::packing<_T>::unpacked_type; + using T2 = base_types::packing<_T>::packed_type; + using dtype = T; ///< Data type of the elements in the tile. + + static constexpr int rows = _rows; + static constexpr int cols = _cols; + static constexpr int height = rows / kittens::TILE_ROW_DIM; + static constexpr int width = cols / kittens::TILE_COL_DIM; + + uint32_t addr; + + __device__ inline tt() : addr(0) {} + __device__ inline tt(uint32_t addr) : addr(addr) {} + + template __device__ inline TT subtile(int row_offset, int col_offset) const { +#ifndef NDEBUG + if(row_offset < 0 || row_offset+TT::rows > rows || col_offset < 0 || col_offset+TT::cols > cols) { + printf("Subtile out of bounds! full tile rows: %d, full tile cols: %d, subtile rows: %d, subtile cols: %d, row_offset: %d, col_offset: %d\n", rows, cols, TT::rows, TT::cols, row_offset, col_offset); + asm volatile("trap;"); + } +#endif + return TT(addr + (row_offset<<16) + col_offset/(4/(uint32_t)sizeof(T))); + } + template __device__ inline uint32_t chunk_addr(int chunk) const { + if constexpr (transpose) { + if constexpr (std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { + return addr + ((16 * chunk) << 16); + } + else { + static_assert(sizeof(T) == 999, "Currently unsupported type for input to an mma."); + } + } + else { + if constexpr (std::is_same_v || std::is_same_v) { + return addr + (16 * chunk / (4/(uint32_t)sizeof(T))); + } + else if constexpr (std::is_same_v || std::is_same_v) { + return addr + (32 * chunk / (4/(uint32_t)sizeof(T))); + } + else { + static_assert(sizeof(T) == 999, "Currently unsupported type for input to an mma."); + } + } + } + +}; + +} // namespace kittens diff --git a/extra/thunder/cuda/include/types/types.cuh b/extra/thunder/cuda/include/types/types.cuh new file mode 100644 index 0000000000..42fbb1a3d5 --- /dev/null +++ b/extra/thunder/cuda/include/types/types.cuh @@ -0,0 +1,68 @@ +/** + * @file + * @brief An aggregate header file for all the register and shared types defined by ThunderKittens. 
diff --git a/extra/thunder/cuda/include/types/types.cuh b/extra/thunder/cuda/include/types/types.cuh
new file mode 100644
index 0000000000..42fbb1a3d5
--- /dev/null
+++ b/extra/thunder/cuda/include/types/types.cuh
@@ -0,0 +1,68 @@
+/**
+ * @file
+ * @brief An aggregate header file for all the register and shared types defined by ThunderKittens.
+ */
+
+#pragma once
+
+#include "device/device.cuh"
+#include "register/register.cuh"
+#include "shared/shared.cuh"
+#include "global/global.cuh"
+#if defined(KITTENS_HOPPER) || defined(KITTENS_BLACKWELL)
+#include "device/device.cuh"
+#endif
+#ifdef KITTENS_BLACKWELL
+#include "tensor/tensor.cuh"
+#endif
+
+/* ---------- WRAPPERS FOR PRETTINESS ---------- */
+
+namespace kittens {
+
+/**
+ * @brief Row vector type alias.
+ *
+ * This template alias provides a convenient way to refer to the row vector type
+ * associated with a given class or type `T`. It assumes that the class `T` has
+ * a nested type named `row_vec`.
+ *
+ * @tparam T The class or type for which the row vector type is defined.
+ *
+ * Example usage:
+ * @code
+ * kittens::row_vec<T> row_vector;
+ * @endcode
+ */
+template<typename T>
+using row_vec = T::row_vec;
+
+/**
+ * @brief Column vector type alias.
+ *
+ * This template alias provides a convenient way to refer to the column vector type
+ * associated with a given class or type `T`. It assumes that the class `T` has
+ * a nested type named `col_vec`.
+ *
+ * @tparam T The class or type for which the column vector type is defined.
+ *
+ * Example usage:
+ * @code
+ * kittens::col_vec<T> col_vector;
+ * @endcode
+ */
+template<typename T>
+using col_vec = T::col_vec;
+
+// ^ this code lives here because it applies to both sv and rv types
+
+// register tile layouts
+using row_l = ducks::rt_layout::row;
+using col_l = ducks::rt_layout::col;
+
+// register vector layouts
+using align_l = ducks::rv_layout::align;
+using ortho_l = ducks::rv_layout::ortho;
+using naive_l = ducks::rv_layout::naive;
+
+}
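Editor's note: a final sketch, not part of the patch, showing the row_vec/col_vec wrappers and the layout shorthands above applied to a register tile. The rt_fl alias, the zero() and row_sum() warp ops, and the 16x64 shape are assumed from the register-tile headers added earlier in this patch.

using namespace kittens;

__device__ inline void vec_alias_example() {
    using tile_t = rt_fl<16, 64, row_l>; // fp32 register tile in the row layout (assumed alias)
    tile_t tile;
    row_vec<tile_t> r;   // one element per column of `tile`
    col_vec<tile_t> c;   // one element per row of `tile`
    zero(tile); zero(r); zero(c); // warp-level zero-fill (assumed API)
    row_sum(c, tile);             // reduce across each row into the col_vec (assumed API)
}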
diff --git a/extra/thunder/gemm.py b/extra/thunder/metal/gemm.py
similarity index 100%
rename from extra/thunder/gemm.py
rename to extra/thunder/metal/gemm.py
diff --git a/extra/thunder/include/common/base_ops.metal b/extra/thunder/metal/include/common/base_ops.metal
similarity index 100%
rename from extra/thunder/include/common/base_ops.metal
rename to extra/thunder/metal/include/common/base_ops.metal
diff --git a/extra/thunder/include/common/base_types.metal b/extra/thunder/metal/include/common/base_types.metal
similarity index 100%
rename from extra/thunder/include/common/base_types.metal
rename to extra/thunder/metal/include/common/base_types.metal
diff --git a/extra/thunder/include/common/common.metal b/extra/thunder/metal/include/common/common.metal
similarity index 100%
rename from extra/thunder/include/common/common.metal
rename to extra/thunder/metal/include/common/common.metal
diff --git a/extra/thunder/include/common/utils.metal b/extra/thunder/metal/include/common/utils.metal
similarity index 100%
rename from extra/thunder/include/common/utils.metal
rename to extra/thunder/metal/include/common/utils.metal
diff --git a/extra/thunder/include/ops/group/group.metal b/extra/thunder/metal/include/ops/group/group.metal
similarity index 100%
rename from extra/thunder/include/ops/group/group.metal
rename to extra/thunder/metal/include/ops/group/group.metal
diff --git a/extra/thunder/include/ops/group/memory/memory.metal b/extra/thunder/metal/include/ops/group/memory/memory.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/memory.metal
rename to extra/thunder/metal/include/ops/group/memory/memory.metal
diff --git a/extra/thunder/include/ops/group/memory/tile/global_to_register.metal b/extra/thunder/metal/include/ops/group/memory/tile/global_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/tile/global_to_register.metal
rename to extra/thunder/metal/include/ops/group/memory/tile/global_to_register.metal
diff --git a/extra/thunder/include/ops/group/memory/tile/global_to_shared.metal b/extra/thunder/metal/include/ops/group/memory/tile/global_to_shared.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/tile/global_to_shared.metal
rename to extra/thunder/metal/include/ops/group/memory/tile/global_to_shared.metal
diff --git a/extra/thunder/include/ops/group/memory/tile/shared_to_register.metal b/extra/thunder/metal/include/ops/group/memory/tile/shared_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/tile/shared_to_register.metal
rename to extra/thunder/metal/include/ops/group/memory/tile/shared_to_register.metal
diff --git a/extra/thunder/include/ops/group/memory/tile/tile.metal b/extra/thunder/metal/include/ops/group/memory/tile/tile.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/tile/tile.metal
rename to extra/thunder/metal/include/ops/group/memory/tile/tile.metal
diff --git a/extra/thunder/include/ops/group/memory/vec/global_to_register.metal b/extra/thunder/metal/include/ops/group/memory/vec/global_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/vec/global_to_register.metal
rename to extra/thunder/metal/include/ops/group/memory/vec/global_to_register.metal
diff --git a/extra/thunder/include/ops/group/memory/vec/global_to_shared.metal b/extra/thunder/metal/include/ops/group/memory/vec/global_to_shared.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/vec/global_to_shared.metal
rename to extra/thunder/metal/include/ops/group/memory/vec/global_to_shared.metal
diff --git a/extra/thunder/include/ops/group/memory/vec/shared_to_register.metal b/extra/thunder/metal/include/ops/group/memory/vec/shared_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/vec/shared_to_register.metal
rename to extra/thunder/metal/include/ops/group/memory/vec/shared_to_register.metal
diff --git a/extra/thunder/include/ops/group/memory/vec/vec.metal b/extra/thunder/metal/include/ops/group/memory/vec/vec.metal
similarity index 100%
rename from extra/thunder/include/ops/group/memory/vec/vec.metal
rename to extra/thunder/metal/include/ops/group/memory/vec/vec.metal
diff --git a/extra/thunder/include/ops/group/shared/shared.metal b/extra/thunder/metal/include/ops/group/shared/shared.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/shared.metal
rename to extra/thunder/metal/include/ops/group/shared/shared.metal
diff --git a/extra/thunder/include/ops/group/shared/tile/conversions.metal b/extra/thunder/metal/include/ops/group/shared/tile/conversions.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/tile/conversions.metal
rename to extra/thunder/metal/include/ops/group/shared/tile/conversions.metal
diff --git a/extra/thunder/include/ops/group/shared/tile/maps.metal b/extra/thunder/metal/include/ops/group/shared/tile/maps.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/tile/maps.metal
rename to extra/thunder/metal/include/ops/group/shared/tile/maps.metal
diff --git a/extra/thunder/include/ops/group/shared/tile/reductions.metal b/extra/thunder/metal/include/ops/group/shared/tile/reductions.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/tile/reductions.metal
rename to extra/thunder/metal/include/ops/group/shared/tile/reductions.metal
diff --git a/extra/thunder/include/ops/group/shared/tile/tile.metal b/extra/thunder/metal/include/ops/group/shared/tile/tile.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/tile/tile.metal
rename to extra/thunder/metal/include/ops/group/shared/tile/tile.metal
diff --git a/extra/thunder/include/ops/group/shared/vec/conversions.metal b/extra/thunder/metal/include/ops/group/shared/vec/conversions.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/vec/conversions.metal
rename to extra/thunder/metal/include/ops/group/shared/vec/conversions.metal
diff --git a/extra/thunder/include/ops/group/shared/vec/maps.metal b/extra/thunder/metal/include/ops/group/shared/vec/maps.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/vec/maps.metal
rename to extra/thunder/metal/include/ops/group/shared/vec/maps.metal
diff --git a/extra/thunder/include/ops/group/shared/vec/vec.metal b/extra/thunder/metal/include/ops/group/shared/vec/vec.metal
similarity index 100%
rename from extra/thunder/include/ops/group/shared/vec/vec.metal
rename to extra/thunder/metal/include/ops/group/shared/vec/vec.metal
diff --git a/extra/thunder/include/ops/ops.metal b/extra/thunder/metal/include/ops/ops.metal
similarity index 100%
rename from extra/thunder/include/ops/ops.metal
rename to extra/thunder/metal/include/ops/ops.metal
diff --git a/extra/thunder/include/ops/warp/memory/memory.metal b/extra/thunder/metal/include/ops/warp/memory/memory.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/memory.metal
rename to extra/thunder/metal/include/ops/warp/memory/memory.metal
diff --git a/extra/thunder/include/ops/warp/memory/tile/complex/complex_global_to_register.metal b/extra/thunder/metal/include/ops/warp/memory/tile/complex/complex_global_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/tile/complex/complex_global_to_register.metal
rename to extra/thunder/metal/include/ops/warp/memory/tile/complex/complex_global_to_register.metal
diff --git a/extra/thunder/include/ops/warp/memory/tile/complex/complex_global_to_shared.metal b/extra/thunder/metal/include/ops/warp/memory/tile/complex/complex_global_to_shared.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/tile/complex/complex_global_to_shared.metal
rename to extra/thunder/metal/include/ops/warp/memory/tile/complex/complex_global_to_shared.metal
diff --git a/extra/thunder/include/ops/warp/memory/tile/complex/complex_shared_to_register.metal b/extra/thunder/metal/include/ops/warp/memory/tile/complex/complex_shared_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/tile/complex/complex_shared_to_register.metal
rename to extra/thunder/metal/include/ops/warp/memory/tile/complex/complex_shared_to_register.metal
diff --git a/extra/thunder/include/ops/warp/memory/tile/global_to_register.metal b/extra/thunder/metal/include/ops/warp/memory/tile/global_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/tile/global_to_register.metal
rename to extra/thunder/metal/include/ops/warp/memory/tile/global_to_register.metal
diff --git a/extra/thunder/include/ops/warp/memory/tile/global_to_shared.metal b/extra/thunder/metal/include/ops/warp/memory/tile/global_to_shared.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/tile/global_to_shared.metal
rename to extra/thunder/metal/include/ops/warp/memory/tile/global_to_shared.metal
diff --git a/extra/thunder/include/ops/warp/memory/tile/shared_to_register.metal b/extra/thunder/metal/include/ops/warp/memory/tile/shared_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/tile/shared_to_register.metal
rename to extra/thunder/metal/include/ops/warp/memory/tile/shared_to_register.metal
diff --git a/extra/thunder/include/ops/warp/memory/tile/tile.metal b/extra/thunder/metal/include/ops/warp/memory/tile/tile.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/tile/tile.metal
rename to extra/thunder/metal/include/ops/warp/memory/tile/tile.metal
diff --git a/extra/thunder/include/ops/warp/memory/util/util.metal b/extra/thunder/metal/include/ops/warp/memory/util/util.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/util/util.metal
rename to extra/thunder/metal/include/ops/warp/memory/util/util.metal
diff --git a/extra/thunder/include/ops/warp/memory/vec/global_to_register.metal b/extra/thunder/metal/include/ops/warp/memory/vec/global_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/vec/global_to_register.metal
rename to extra/thunder/metal/include/ops/warp/memory/vec/global_to_register.metal
diff --git a/extra/thunder/include/ops/warp/memory/vec/global_to_shared.metal b/extra/thunder/metal/include/ops/warp/memory/vec/global_to_shared.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/vec/global_to_shared.metal
rename to extra/thunder/metal/include/ops/warp/memory/vec/global_to_shared.metal
diff --git a/extra/thunder/include/ops/warp/memory/vec/shared_to_register.metal b/extra/thunder/metal/include/ops/warp/memory/vec/shared_to_register.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/vec/shared_to_register.metal
rename to extra/thunder/metal/include/ops/warp/memory/vec/shared_to_register.metal
diff --git a/extra/thunder/include/ops/warp/memory/vec/vec.metal b/extra/thunder/metal/include/ops/warp/memory/vec/vec.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/memory/vec/vec.metal
rename to extra/thunder/metal/include/ops/warp/memory/vec/vec.metal
diff --git a/extra/thunder/include/ops/warp/register/register.metal b/extra/thunder/metal/include/ops/warp/register/register.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/register.metal
rename to extra/thunder/metal/include/ops/warp/register/register.metal
diff --git a/extra/thunder/include/ops/warp/register/tile/conversions.metal b/extra/thunder/metal/include/ops/warp/register/tile/conversions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/tile/conversions.metal
rename to extra/thunder/metal/include/ops/warp/register/tile/conversions.metal
diff --git a/extra/thunder/include/ops/warp/register/tile/maps.metal b/extra/thunder/metal/include/ops/warp/register/tile/maps.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/tile/maps.metal
rename to extra/thunder/metal/include/ops/warp/register/tile/maps.metal
diff --git a/extra/thunder/include/ops/warp/register/tile/mma.metal b/extra/thunder/metal/include/ops/warp/register/tile/mma.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/tile/mma.metal
rename to extra/thunder/metal/include/ops/warp/register/tile/mma.metal
diff --git a/extra/thunder/include/ops/warp/register/tile/reductions.metal b/extra/thunder/metal/include/ops/warp/register/tile/reductions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/tile/reductions.metal
rename to extra/thunder/metal/include/ops/warp/register/tile/reductions.metal
diff --git a/extra/thunder/include/ops/warp/register/tile/tile.metal b/extra/thunder/metal/include/ops/warp/register/tile/tile.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/tile/tile.metal
rename to extra/thunder/metal/include/ops/warp/register/tile/tile.metal
diff --git a/extra/thunder/include/ops/warp/register/vec/conversions.metal b/extra/thunder/metal/include/ops/warp/register/vec/conversions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/vec/conversions.metal
rename to extra/thunder/metal/include/ops/warp/register/vec/conversions.metal
diff --git a/extra/thunder/include/ops/warp/register/vec/maps.metal b/extra/thunder/metal/include/ops/warp/register/vec/maps.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/vec/maps.metal
rename to extra/thunder/metal/include/ops/warp/register/vec/maps.metal
diff --git a/extra/thunder/include/ops/warp/register/vec/reductions.metal b/extra/thunder/metal/include/ops/warp/register/vec/reductions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/vec/reductions.metal
rename to extra/thunder/metal/include/ops/warp/register/vec/reductions.metal
diff --git a/extra/thunder/include/ops/warp/register/vec/vec.metal b/extra/thunder/metal/include/ops/warp/register/vec/vec.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/register/vec/vec.metal
rename to extra/thunder/metal/include/ops/warp/register/vec/vec.metal
diff --git a/extra/thunder/include/ops/warp/shared/shared.metal b/extra/thunder/metal/include/ops/warp/shared/shared.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/shared.metal
rename to extra/thunder/metal/include/ops/warp/shared/shared.metal
diff --git a/extra/thunder/include/ops/warp/shared/tile/conversions.metal b/extra/thunder/metal/include/ops/warp/shared/tile/conversions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/tile/conversions.metal
rename to extra/thunder/metal/include/ops/warp/shared/tile/conversions.metal
diff --git a/extra/thunder/include/ops/warp/shared/tile/maps.metal b/extra/thunder/metal/include/ops/warp/shared/tile/maps.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/tile/maps.metal
rename to extra/thunder/metal/include/ops/warp/shared/tile/maps.metal
diff --git a/extra/thunder/include/ops/warp/shared/tile/reductions.metal b/extra/thunder/metal/include/ops/warp/shared/tile/reductions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/tile/reductions.metal
rename to extra/thunder/metal/include/ops/warp/shared/tile/reductions.metal
diff --git a/extra/thunder/include/ops/warp/shared/tile/tile.metal b/extra/thunder/metal/include/ops/warp/shared/tile/tile.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/tile/tile.metal
rename to extra/thunder/metal/include/ops/warp/shared/tile/tile.metal
diff --git a/extra/thunder/include/ops/warp/shared/vec/conversions.metal b/extra/thunder/metal/include/ops/warp/shared/vec/conversions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/vec/conversions.metal
rename to extra/thunder/metal/include/ops/warp/shared/vec/conversions.metal
diff --git a/extra/thunder/include/ops/warp/shared/vec/maps.metal b/extra/thunder/metal/include/ops/warp/shared/vec/maps.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/vec/maps.metal
rename to extra/thunder/metal/include/ops/warp/shared/vec/maps.metal
diff --git a/extra/thunder/include/ops/warp/shared/vec/reductions.metal b/extra/thunder/metal/include/ops/warp/shared/vec/reductions.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/vec/reductions.metal
rename to extra/thunder/metal/include/ops/warp/shared/vec/reductions.metal
diff --git a/extra/thunder/include/ops/warp/shared/vec/vec.metal b/extra/thunder/metal/include/ops/warp/shared/vec/vec.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/shared/vec/vec.metal
rename to extra/thunder/metal/include/ops/warp/shared/vec/vec.metal
diff --git a/extra/thunder/include/ops/warp/warp.metal b/extra/thunder/metal/include/ops/warp/warp.metal
similarity index 100%
rename from extra/thunder/include/ops/warp/warp.metal
rename to extra/thunder/metal/include/ops/warp/warp.metal
diff --git a/extra/thunder/include/tk.metal b/extra/thunder/metal/include/tk.metal
similarity index 100%
rename from extra/thunder/include/tk.metal
rename to extra/thunder/metal/include/tk.metal
diff --git a/extra/thunder/include/types/global/cgl.metal b/extra/thunder/metal/include/types/global/cgl.metal
similarity index 100%
rename from extra/thunder/include/types/global/cgl.metal
rename to extra/thunder/metal/include/types/global/cgl.metal
diff --git a/extra/thunder/include/types/global/gl.metal b/extra/thunder/metal/include/types/global/gl.metal
similarity index 100%
rename from extra/thunder/include/types/global/gl.metal
rename to extra/thunder/metal/include/types/global/gl.metal
diff --git a/extra/thunder/include/types/global/global.metal b/extra/thunder/metal/include/types/global/global.metal
similarity index 100%
rename from extra/thunder/include/types/global/global.metal
rename to extra/thunder/metal/include/types/global/global.metal
diff --git a/extra/thunder/include/types/global/util.metal b/extra/thunder/metal/include/types/global/util.metal
similarity index 100%
rename from extra/thunder/include/types/global/util.metal
rename to extra/thunder/metal/include/types/global/util.metal
diff --git a/extra/thunder/include/types/register/crt.metal b/extra/thunder/metal/include/types/register/crt.metal
similarity index 100%
rename from extra/thunder/include/types/register/crt.metal
rename to extra/thunder/metal/include/types/register/crt.metal
diff --git a/extra/thunder/include/types/register/crv.metal b/extra/thunder/metal/include/types/register/crv.metal
similarity index 100%
rename from extra/thunder/include/types/register/crv.metal
rename to extra/thunder/metal/include/types/register/crv.metal
diff --git a/extra/thunder/include/types/register/register.metal b/extra/thunder/metal/include/types/register/register.metal
similarity index 100%
rename from extra/thunder/include/types/register/register.metal
rename to extra/thunder/metal/include/types/register/register.metal
diff --git a/extra/thunder/include/types/register/rt.metal b/extra/thunder/metal/include/types/register/rt.metal
similarity index 100%
rename from extra/thunder/include/types/register/rt.metal
rename to extra/thunder/metal/include/types/register/rt.metal
diff --git a/extra/thunder/include/types/register/rt_base.metal b/extra/thunder/metal/include/types/register/rt_base.metal
similarity index 100%
rename from extra/thunder/include/types/register/rt_base.metal
rename to extra/thunder/metal/include/types/register/rt_base.metal
diff --git a/extra/thunder/include/types/register/rt_layout.metal b/extra/thunder/metal/include/types/register/rt_layout.metal
similarity index 100%
rename from extra/thunder/include/types/register/rt_layout.metal
rename to extra/thunder/metal/include/types/register/rt_layout.metal
diff --git a/extra/thunder/include/types/register/rv.metal b/extra/thunder/metal/include/types/register/rv.metal
similarity index 100%
rename from extra/thunder/include/types/register/rv.metal
rename to extra/thunder/metal/include/types/register/rv.metal
diff --git a/extra/thunder/include/types/register/rv_layout.metal b/extra/thunder/metal/include/types/register/rv_layout.metal
similarity index 100%
rename from extra/thunder/include/types/register/rv_layout.metal
rename to extra/thunder/metal/include/types/register/rv_layout.metal
diff --git a/extra/thunder/include/types/shared/cst.metal b/extra/thunder/metal/include/types/shared/cst.metal
similarity index 100%
rename from extra/thunder/include/types/shared/cst.metal
rename to extra/thunder/metal/include/types/shared/cst.metal
diff --git a/extra/thunder/include/types/shared/csv.metal b/extra/thunder/metal/include/types/shared/csv.metal
similarity index 100%
rename from extra/thunder/include/types/shared/csv.metal
rename to extra/thunder/metal/include/types/shared/csv.metal
diff --git a/extra/thunder/include/types/shared/shared.metal b/extra/thunder/metal/include/types/shared/shared.metal
similarity index 100%
rename from extra/thunder/include/types/shared/shared.metal
rename to extra/thunder/metal/include/types/shared/shared.metal
diff --git a/extra/thunder/include/types/shared/st.metal b/extra/thunder/metal/include/types/shared/st.metal
similarity index 100%
rename from extra/thunder/include/types/shared/st.metal
rename to extra/thunder/metal/include/types/shared/st.metal
diff --git a/extra/thunder/include/types/shared/sv.metal b/extra/thunder/metal/include/types/shared/sv.metal
similarity index 100%
rename from extra/thunder/include/types/shared/sv.metal
rename to extra/thunder/metal/include/types/shared/sv.metal
diff --git a/extra/thunder/include/types/types.metal b/extra/thunder/metal/include/types/types.metal
similarity index 100%
rename from extra/thunder/include/types/types.metal
rename to extra/thunder/metal/include/types/types.metal