#include "kittens.cuh"

using namespace kittens;

// NOTE(review): every template parameter / argument list in this copy appears
// to have been stripped (e.g. `template>`, `std::is_same_v,`,
// `reinterpret_cast(dst_ptr)`, `macros::v_cvt_pk_bf16_f32()`,
// `std::make_index_sequence{}`), so this text does not compile as-is. The
// comments below describe only what the remaining tokens show; restore the
// `<...>` lists from the upstream source before relying on them.

/*
 * atomic_pk_add_bf16_with_warpid -- single sub-tile overload.
 *
 * Atomically adds one register sub-tile of `src`, selected by the
 * compile-time indices N and M, into the bf16 global tensor `dst` at tile
 * coordinate `idx`. The float register values are first packed to bf16 with
 * macros::v_cvt_pk_bf16_f32, then written with two
 * macros::buffer_atomic_pk_add_bf16 operations through a buffer resource
 * whose base is the destination tile's first element.
 *
 * Parameters:
 *   dst    - destination global layout; its dtype must be bf16 (static_assert).
 *   src    - source register tile; must be row layout with float elements
 *            (static_asserts). The object is never named in the body; only
 *            the register ranges of its type RT are used.
 *   idx    - tile coordinate in dst at which the destination tile starts.
 *   warpid - selects this warp's element window: lane offset is
 *            laneid * 2 + warpid * 512.
 *
 * NOTE(review): N and M are used but not declared in the visible text;
 * presumably they are template parameters of this overload -- confirm.
 */
template>
__device__ inline static void atomic_pk_add_bf16_with_warpid(const GL &dst, const RT &src, const COORD &idx, int warpid) {
    // Element-type aliases: T/T2 for the register tile, U/U2 for global memory.
    // T2 and U2 are not referenced elsewhere in this function.
    using T = base_types::packing::unpacked_type;
    using T2 = base_types::packing::packed_type;
    using U = typename GL::dtype;
    using U2 = base_types::packing::packed_type;

    // Compile-time preconditions (per the messages): row-layout register tile,
    // bf16 destination, float register elements.
    static_assert(std::is_same_v, "RT must be a row layout");
    static_assert(std::is_same_v, "atomic_pk_add_bf16 is only supported for bf16");
    static_assert(std::is_same_v, "atomic_pk_add_bf16 is only supported where T is float");

    // Base address of the destination tile and dst's row stride (in elements).
    U *dst_ptr = (U*)&dst[(idx.template unit_coord())];
    const int row_stride = dst.template stride();
    int laneid = kittens::laneid();

    // Buffer resource based at dst_ptr, sized (in bytes) to RT::rows rows of dst.
    // NOTE(review): the extra element offsets added below (M * 256,
    // warpid * 512, +128) are not reflected in this size; whether every
    // access stays inside it depends on the caller's layout -- confirm.
    // NOTE(review): the meaning of 0x00020000 (third make_buffer_resource
    // argument) is not visible here; presumably descriptor config bits -- confirm.
    const uint32_t buffer_size = row_stride * RT::rows * sizeof(U);
    std::uintptr_t as_int = reinterpret_cast(dst_ptr);
    std::uint64_t as_u64 = static_cast(as_int);
    buffer_resource br = make_buffer_resource(as_u64, buffer_size, 0x00020000);

    // Element offset of this lane: 2 consecutive elements per lane,
    // a 512-element window per warpid.
    int lane_offset = laneid * 2 + warpid * 512;

    // Register range holding the selected sub-tile of src; it must be exactly
    // 4 consecutive registers, all below 256 (VGPRs).
    using range_type = ducks::art::get_nth_range_t;
    static_assert(range_type::lo + 3 == range_type::hi, "buffer_atomic_pk_add_bf16 requires 4 consecutive registers");
    static_assert(range_type::hi < 256, "registers need to be VGPRS");

    // Element offset of the (N, M) sub-tile: N steps of RT::base_tile_rows
    // rows, M steps of 256 elements.
    const int tile_offset = N * row_stride * RT::base_tile_rows + M * 256;

    // GPR_0_BF16 / GPR_1_BF16 name the first two registers of the range,
    // presumably the destinations of the two float->packed-bf16 conversions.
    // NOTE(review): the macros' register operands are stripped in this copy;
    // if the conversion writes in place, src's registers are overwritten even
    // though src is passed as const& -- confirm.
    constexpr int GPR_0_BF16 = range_type::lo;
    constexpr int GPR_1_BF16 = range_type::lo + 1;
    macros::v_cvt_pk_bf16_f32();
    macros::v_cvt_pk_bf16_f32();

    // Byte offsets into the buffer for the two packed registers; the second
    // targets the elements 128 past the first.
    const uint32_t byte_offset_0 = static_cast((tile_offset + lane_offset) * sizeof(U));
    const uint32_t byte_offset_1 = static_cast((tile_offset + lane_offset + 128) * sizeof(U));
    macros::buffer_atomic_pk_add_bf16(br, byte_offset_0);
    macros::buffer_atomic_pk_add_bf16(br, byte_offset_1);
}

/*
 * atomic_pk_add_bf16_with_warpid -- whole-tile overload.
 *
 * Performs the same setup and the same per-sub-tile packed-bf16 atomic add
 * as the single sub-tile overload, but for every (N, M) sub-tile of `src`,
 * driven by compile-time nested loops. The buffer resource and lane offset
 * are built once and captured by reference for all iterations.
 *
 * Parameters are the same as for the single sub-tile overload:
 * dst (bf16 global layout), src (row-layout float register tile),
 * idx (destination tile coordinate), warpid (512-element window selector).
 *
 * NOTE(review): the two overloads differ only in their (stripped) template
 * parameter lists, and this one duplicates the other's setup and body
 * verbatim. The bounds passed to std::make_index_sequence are also stripped;
 * presumably RT::height for N and RT::width for M -- confirm.
 */
template>
__device__ inline static void atomic_pk_add_bf16_with_warpid(const GL &dst, const RT &src, const COORD &idx, int warpid) {
    // Element-type aliases: T/T2 for the register tile, U/U2 for global memory.
    // T2 and U2 are not referenced elsewhere in this function.
    using T = base_types::packing::unpacked_type;
    using T2 = base_types::packing::packed_type;
    using U = typename GL::dtype;
    using U2 = base_types::packing::packed_type;

    // Compile-time preconditions (per the messages): row-layout register tile,
    // bf16 destination, float register elements.
    static_assert(std::is_same_v, "RT must be a row layout");
    static_assert(std::is_same_v, "atomic_pk_add_bf16 is only supported for bf16");
    static_assert(std::is_same_v, "atomic_pk_add_bf16 is only supported where T is float");

    // Base address of the destination tile and dst's row stride (in elements).
    U *dst_ptr = (U*)&dst[(idx.template unit_coord())];
    const int row_stride = dst.template stride();
    int laneid = kittens::laneid();

    // Buffer resource based at dst_ptr, sized (in bytes) to RT::rows rows of dst.
    // NOTE(review): the meaning of 0x00020000 (third make_buffer_resource
    // argument) is not visible here; presumably descriptor config bits -- confirm.
    const uint32_t buffer_size = row_stride * RT::rows * sizeof(U);
    std::uintptr_t as_int = reinterpret_cast(dst_ptr);
    std::uint64_t as_u64 = static_cast(as_int);
    buffer_resource br = make_buffer_resource(as_u64, buffer_size, 0x00020000);

    // Element offset of this lane: 2 consecutive elements per lane,
    // a 512-element window per warpid.
    int lane_offset = laneid * 2 + warpid * 512;

    // Per-sub-tile work. It is invoked below through `.template operator()`,
    // so it is a template lambda over (N, M); its template parameter list is
    // stripped in this copy.
    auto perform_atomic_pk_add_bf16_with_warpid = [&]() {
        // Register range of the current sub-tile; must be 4 consecutive VGPRs.
        using range_type = ducks::art::get_nth_range_t;
        static_assert(range_type::lo + 3 == range_type::hi, "buffer_atomic_pk_add_bf16 requires 4 consecutive registers");
        static_assert(range_type::hi < 256, "registers need to be VGPRS");

        // Element offset of the (N, M) sub-tile: N steps of
        // RT::base_tile_rows rows, M steps of 256 elements.
        const int tile_offset = N * row_stride * RT::base_tile_rows + M * 256;

        // First two registers of the range, presumably the destinations of
        // the two float->packed-bf16 conversions (operands stripped here).
        constexpr int GPR_0_BF16 = range_type::lo;
        constexpr int GPR_1_BF16 = range_type::lo + 1;
        macros::v_cvt_pk_bf16_f32();
        macros::v_cvt_pk_bf16_f32();

        // Two packed atomic adds; the second targets the elements 128 past the first.
        const uint32_t byte_offset_0 = static_cast((tile_offset + lane_offset) * sizeof(U));
        const uint32_t byte_offset_1 = static_cast((tile_offset + lane_offset + 128) * sizeof(U));
        macros::buffer_atomic_pk_add_bf16(br, byte_offset_0);
        macros::buffer_atomic_pk_add_bf16(br, byte_offset_1);
    };

    // Compile-time nested loops over N and M: each std::index_sequence is
    // expanded with a comma fold expression (outer sequence presumably N,
    // inner sequence M), calling the per-sub-tile lambda once per pair.
    [&](std::index_sequence) {
        ([&]() {
            [&](std::index_sequence) {
                ([&]() {
                    perform_atomic_pk_add_bf16_with_warpid.template operator()();
                }.template operator()(), ...);
            }(std::make_index_sequence{});
        }.template operator()(), ...);
    }(std::make_index_sequence{});
}