mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00

feat: add thunderkittens (#12590)

400  extra/thunder/cuda/include/common/base_ops.cuh  (new file)
@@ -0,0 +1,400 @@
/**
 * @file
 * @brief Basic operations on generic types.
 */

#pragma once

#include <cuda_bf16.h>
#include <limits>
#include "base_types.cuh"

namespace kittens {

/**
 * @namespace base_ops
 *
 * @brief A namespace for operations on basic data types.
 */
namespace base_ops {

/* ---------- CONST OPS ---------- */

/**
 * @brief Represents the zero constant operation.
 *
 * This operation returns the zero value of the specified type.
 *
 * @tparam T The data type for which to return the zero value.
 * @return The zero value of type T.
 */
struct zero {
    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::zero(); }
};
/**
 * @brief Represents the one constant operation.
 *
 * This operation returns the one value of the specified type.
 *
 * @tparam T The data type for which to return the one value.
 * @return The one value of type T.
 */
struct one {
    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::one(); }
};
/**
 * @brief Represents the positive infinity constant operation.
 *
 * This operation returns the positive infinity value of the specified type.
 *
 * @tparam T The data type for which to return the positive infinity value.
 * @return The positive infinity value of type T.
 */
struct pos_infty {
    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::pos_infty(); }
};
/**
 * @brief Represents the negative infinity constant operation.
 *
 * This operation returns the negative infinity value of the specified type.
 *
 * @tparam T The data type for which to return the negative infinity value.
 * @return The negative infinity value of type T.
 */
struct neg_infty {
    template<typename T, typename... args> __device__ static inline constexpr T op(args... _) { return base_types::constants<T>::neg_infty(); }
};


/* ---------- UNARY OPS ---------- */

/**
 * @brief Exponential function operation.
 *
 * This operation calculates the exponential of the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The exponential of the input value.
 */
struct exp {
    template<typename T> static __device__ inline T op(const T &x) { return exp(x); }
};
template<> __device__ inline float  exp::op<float> (const float &x ) { return __expf(x); }
template<> __device__ inline float2 exp::op<float2>(const float2 &x) { return float2{__expf(x.x), __expf(x.y)}; }
template<> __device__ inline bf16   exp::op<bf16>  (const bf16 &x  ) { return hexp(x); }
template<> __device__ inline bf16_2 exp::op<bf16_2>(const bf16_2 &x) { return h2exp(x); }
template<> __device__ inline half   exp::op<half>  (const half &x  ) { return hexp(x); }
template<> __device__ inline half_2 exp::op<half_2>(const half_2 &x) { return h2exp(x); }

/**
 * @brief Exponential function operation, in base 2.
 *
 * This operation calculates the base-2 exponential of the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The base-2 exponential of the input value.
 */
struct exp2 {
    template<typename T> static __device__ inline T op(const T &x) { return exp2f(x); }
};
template<> __device__ inline float  exp2::op<float> (const float &x ) { return exp2f(x); }
template<> __device__ inline float2 exp2::op<float2>(const float2 &x) { return float2{exp2f(x.x), exp2f(x.y)}; }
template<> __device__ inline bf16   exp2::op<bf16>  (const bf16 &x  ) { return hexp2(x); }
template<> __device__ inline bf16_2 exp2::op<bf16_2>(const bf16_2 &x) { return h2exp2(x); }
template<> __device__ inline half   exp2::op<half>  (const half &x  ) { return hexp2(x); }
template<> __device__ inline half_2 exp2::op<half_2>(const half_2 &x) { return h2exp2(x); }

/**
 * @brief Natural log function operation.
 *
 * This operation calculates the natural logarithm of the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The natural logarithm of the input value.
 */
struct log {
    template<typename T> static __device__ inline T op(const T &x) { return log(x); }
};
template<> __device__ inline float  log::op<float> (const float &x ) { return __logf(x); }
template<> __device__ inline float2 log::op<float2>(const float2 &x) { return float2{__logf(x.x), __logf(x.y)}; }
template<> __device__ inline bf16   log::op<bf16>  (const bf16 &x  ) { return hlog(x); }
template<> __device__ inline bf16_2 log::op<bf16_2>(const bf16_2 &x) { return h2log(x); }
template<> __device__ inline half   log::op<half>  (const half &x  ) { return hlog(x); }
template<> __device__ inline half_2 log::op<half_2>(const half_2 &x) { return h2log(x); }

/**
 * @brief Logarithm base 2 operation.
 *
 * This operation calculates the logarithm base 2 of the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The logarithm base 2 of the input value.
 */
struct log2 {
    template<typename T> static __device__ inline T op(const T &x) { return log2(x); }
};
template<> __device__ inline float  log2::op<float> (const float &x ) { return __log2f(x); }
template<> __device__ inline float2 log2::op<float2>(const float2 &x) { return float2{__log2f(x.x), __log2f(x.y)}; }
template<> __device__ inline bf16   log2::op<bf16>  (const bf16 &x  ) { return hlog2(x); }
template<> __device__ inline bf16_2 log2::op<bf16_2>(const bf16_2 &x) { return h2log2(x); }
template<> __device__ inline half   log2::op<half>  (const half &x  ) { return hlog2(x); }
template<> __device__ inline half_2 log2::op<half_2>(const half_2 &x) { return h2log2(x); }

/**
 * @brief Absolute value operation.
 *
 * This operation calculates the absolute value of the input.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The absolute value of the input.
 */
struct abs {
    template<typename T> static __device__ inline T op(const T &x) { return abs(x); }
};
template<> __device__ inline float  abs::op<float> (const float &x ) { return fabsf(x); }
template<> __device__ inline float2 abs::op<float2>(const float2 &x) { return float2{fabsf(x.x), fabsf(x.y)}; }
template<> __device__ inline bf16   abs::op<bf16>  (const bf16 &x  ) { return __habs(x); }
template<> __device__ inline bf16_2 abs::op<bf16_2>(const bf16_2 &x) { return __habs2(x); }
template<> __device__ inline half   abs::op<half>  (const half &x  ) { return __habs(x); }
template<> __device__ inline half_2 abs::op<half_2>(const half_2 &x) { return __habs2(x); }

/**
 * @brief Rectified Linear Unit (ReLU) operation.
 *
 * This operation applies the ReLU function to the input, which is the
 * maximum of zero and the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The result of the ReLU function applied to the input.
 */
struct relu {
    template<typename T> static __device__ inline T op(const T &x) { return max(x, base_types::constants<T>::zero()); }
};
template<> __device__ inline float  relu::op<float> (const float &x ) { return max(x, 0.f); }
template<> __device__ inline float2 relu::op<float2>(const float2 &x) { return float2{max(x.x, 0.f), max(x.y, 0.f)}; }
template<> __device__ inline bf16   relu::op<bf16>  (const bf16 &x  ) { return __hmax(x, base_types::constants<bf16>::zero()); }
template<> __device__ inline bf16_2 relu::op<bf16_2>(const bf16_2 &x) { return __hmax2(x, base_types::constants<bf16_2>::zero()); }
template<> __device__ inline half   relu::op<half>  (const half &x  ) { return __hmax(x, base_types::constants<half>::zero()); }
template<> __device__ inline half_2 relu::op<half_2>(const half_2 &x) { return __hmax2(x, base_types::constants<half_2>::zero()); }

/**
 * @brief Copy operation.
 *
 * This operation returns the input value unchanged.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The input value.
 * @return The same value as the input.
 */
struct copy { // for non-compile-time setters.
    template<typename T> static __device__ inline T op(const T &a) { return a; }
};


/* ---------- BINARY OPS ---------- */

/**
 * @brief Copy2 operation.
 *
 * This operation returns the second input value unchanged.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value (ignored).
 * @param b[in] The second input value.
 * @return The same value as the second input.
 */
struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/
    template<typename T> static __device__ inline T op(const T &a, const T &b) { return b; }
};

/**
 * @brief Sum operation.
 *
 * This operation calculates the sum of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The sum of the input values.
 */
struct sum {
    template<typename T> static __device__ inline T op(const T &a, const T &b) { return a+b; }
};
template<> __device__ inline float2 sum::op<float2>(const float2 &a, const float2 &b) {
#ifdef KITTENS_BLACKWELL
    float2 c;
    asm volatile("add.f32x2 %0, %1, %2;" : "=l"(*(uint64_t*)&c) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b));
    return c;
#else
    return float2{a.x+b.x, a.y+b.y};
#endif
}
template<> __device__ inline bf16   sum::op<bf16>  (const bf16 &a, const bf16 &b)     { return __hadd(a, b); }
template<> __device__ inline bf16_2 sum::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hadd2(a, b); }
template<> __device__ inline half   sum::op<half>  (const half &a, const half &b)     { return __hadd(a, b); }
template<> __device__ inline half_2 sum::op<half_2>(const half_2 &a, const half_2 &b) { return __hadd2(a, b); }

/**
 * @brief Subtraction operation.
 *
 * This operation calculates the difference between two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The difference between the input values.
 */
struct sub {
    template<typename T> static __device__ inline T op(const T &a, const T &b) { return a-b; }
};
template<> __device__ inline float2 sub::op<float2>(const float2 &a, const float2 &b) {
#ifdef KITTENS_BLACKWELL
    float2 c;
    asm volatile("sub.f32x2 %0, %1, %2;" : "=l"(*(uint64_t*)&c) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b));
    return c;
#else
    return float2{a.x-b.x, a.y-b.y};
#endif
}
template<> __device__ inline bf16   sub::op<bf16>  (const bf16 &a, const bf16 &b)     { return __hsub(a, b); }
template<> __device__ inline bf16_2 sub::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hsub2(a, b); }
template<> __device__ inline half   sub::op<half>  (const half &a, const half &b)     { return __hsub(a, b); }
template<> __device__ inline half_2 sub::op<half_2>(const half_2 &a, const half_2 &b) { return __hsub2(a, b); }

/**
 * @brief Multiplication operation.
 *
 * This operation calculates the product of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The product of the input values.
 */
struct mul {
    template<typename T> static __device__ inline T op(const T &a, const T &b) { return a*b; }
};
template<> __device__ inline float2 mul::op<float2>(const float2 &a, const float2 &b) {
#ifdef KITTENS_BLACKWELL
    float2 c;
    asm volatile("mul.f32x2 %0, %1, %2;" : "=l"(*(uint64_t*)&c) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b));
    return c;
#else
    return float2{a.x*b.x, a.y*b.y};
#endif
}
template<> __device__ inline bf16   mul::op<bf16>  (const bf16 &a, const bf16 &b)     { return __hmul(a, b); }
template<> __device__ inline bf16_2 mul::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmul2(a, b); }
template<> __device__ inline half   mul::op<half>  (const half &a, const half &b)     { return __hmul(a, b); }
template<> __device__ inline half_2 mul::op<half_2>(const half_2 &a, const half_2 &b) { return __hmul2(a, b); }

/**
 * @brief Division operation.
 *
 * This operation calculates the quotient of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The quotient of the input values.
 */
struct div {
    template<typename T> static __device__ inline T op(const T &a, const T &b) { return a/b; }
};
template<> __device__ inline float2 div::op<float2>(const float2 &a, const float2 &b) { return float2{a.x/b.x, a.y/b.y}; }
template<> __device__ inline bf16   div::op<bf16>  (const bf16 &a, const bf16 &b)     { return __hdiv(a, b); }
template<> __device__ inline bf16_2 div::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __h2div(a, b); } // this op is a special snowflake
template<> __device__ inline half   div::op<half>  (const half &a, const half &b)     { return __hdiv(a, b); }
template<> __device__ inline half_2 div::op<half_2>(const half_2 &a, const half_2 &b) { return __h2div(a, b); }

/**
 * @brief Maximum operation.
 *
 * This operation calculates the maximum of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The maximum of the input values.
 */
struct max {
    template<typename T> static __device__ inline T op(const T &a, const T &b) { return ::max(a, b); }
};
template<> __device__ inline float2 max::op<float2>(const float2 &a, const float2 &b) { return float2{::max(a.x, b.x), ::max(a.y, b.y)}; }
template<> __device__ inline bf16   max::op<bf16>  (const bf16 &a, const bf16 &b)     { return __hmax(a, b); }
template<> __device__ inline bf16_2 max::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmax2(a, b); }
template<> __device__ inline half   max::op<half>  (const half &a, const half &b)     { return __hmax(a, b); }
template<> __device__ inline half_2 max::op<half_2>(const half_2 &a, const half_2 &b) { return __hmax2(a, b); }

/**
 * @brief Minimum operation.
 *
 * This operation calculates the minimum of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The minimum of the input values.
 */
struct min {
    template<typename T> static __device__ inline T op(const T &a, const T &b) { return ::min(a, b); }
};
template<> __device__ inline float2 min::op<float2>(const float2 &a, const float2 &b) { return float2{::min(a.x, b.x), ::min(a.y, b.y)}; }
template<> __device__ inline bf16   min::op<bf16>  (const bf16 &a, const bf16 &b)     { return __hmin(a, b); }
template<> __device__ inline bf16_2 min::op<bf16_2>(const bf16_2 &a, const bf16_2 &b) { return __hmin2(a, b); }
template<> __device__ inline half   min::op<half>  (const half &a, const half &b)     { return __hmin(a, b); }
template<> __device__ inline half_2 min::op<half_2>(const half_2 &a, const half_2 &b) { return __hmin2(a, b); }


/* ---------- TERNARY OPS ---------- */

/**
 * @brief Fused multiply-add operation A * B + C.
 *
 * This operation performs a fused multiply-add, computing (A * B) + C with only one rounding.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @param c[in] The third input value to be added.
 * @return The result of the fused multiply-add operation.
 */
struct fma_AxBtC {
    template<typename T> static __device__ inline T op(const T &a, const T &b, const T &c) {
        return sum::op<T>(mul::op<T>(a, b), c);
    }
};
template<> __device__ inline float2 fma_AxBtC::op<float2>(const float2 &a, const float2 &b, const float2 &c) {
#ifdef KITTENS_BLACKWELL
    float2 d;
    asm volatile("fma.rn.f32x2 %0, %1, %2, %3;" : "=l"(*(uint64_t*)&d) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&b), "l"(*(uint64_t*)&c));
    return d;
#else
    return float2{a.x*b.x+c.x, a.y*b.y+c.y};
#endif
}
/**
 * @brief Fused multiply-add operation A * C + B.
 *
 * This operation performs a fused multiply-add, computing (A * C) + B with only one rounding.
 * This is particularly useful for attention mechanisms in neural networks.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first multiplicand.
 * @param b[in] The addend.
 * @param c[in] The second multiplicand.
 * @return The result of the fused multiply-add operation.
 */
struct fma_AxCtB { // this is the one needed for attention
    template<typename T> static __device__ inline T op(const T &a, const T &b, const T &c) {
        return sum::op<T>(mul::op<T>(a, c), b);
    }
};
template<> __device__ inline float2 fma_AxCtB::op<float2>(const float2 &a, const float2 &b, const float2 &c) {
#ifdef KITTENS_BLACKWELL
    float2 d;
    asm volatile("fma.rn.f32x2 %0, %1, %2, %3;" : "=l"(*(uint64_t*)&d) : "l"(*(uint64_t*)&a), "l"(*(uint64_t*)&c), "l"(*(uint64_t*)&b));
    return d;
#else
    return float2{a.x*c.x+b.x, a.y*c.y+b.y};
#endif
}

} // namespace base_ops

} // namespace kittens
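
Usage note: each of these functor structs exposes a static op template, so higher-level ThunderKittens routines (and user kernels) can be parameterized over the operation itself. A minimal sketch of that pattern, not part of this commit; elementwise_kernel is a hypothetical name:

// Hypothetical usage sketch (not in this commit): a generic elementwise kernel
// parameterized over any of the base_ops functors above.
#include "base_ops.cuh"

template<typename op, typename T>
__global__ void elementwise_kernel(T *out, const T *in, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = op::template op<T>(in[i]); // e.g. op = kittens::base_ops::exp
}

// launch, e.g.: elementwise_kernel<kittens::base_ops::relu><<<grid, block>>>(out, in, n);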

519  extra/thunder/cuda/include/common/base_types.cuh  (new file)
@@ -0,0 +1,519 @@
/**
 * @file
 * @brief Declarations, manipulations, and wrappers for basic types.
 *
 * This file is a bunch of utilities for going back and forth between different types.
 *
 * Many of them are for the compiler, so as to clean up the code. This unfortunately
 * seems necessary when the types we really care about are less than a word wide.
 */

#pragma once

#ifdef KITTENS_HOPPER
#include <cuda_fp8.h>
#endif

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <string>
#include <bit>


namespace kittens {

/**
 * @brief Bfloat16 floating-point type.
 */
using bf16 = __nv_bfloat16;
/**
 * @brief Half-precision floating-point type.
 */
using half = __half;
/**
 * @brief Packed word of two bfloat16 floating-point values.
 */
using bf16_2 = __nv_bfloat162;
/**
 * @brief Packed word of two half-precision floating-point values.
 */
using half_2 = __half2;
#ifdef KITTENS_HOPPER
/**
 * @brief float8 floating-point types.
 */
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#ifdef KITTENS_BLACKWELL
using fp8e8m0 = __nv_fp8_e8m0;
#endif
/**
 * @brief 2-packed float8 floating-point types.
 */
using fp8e4m3_2 = __nv_fp8x2_e4m3;
using fp8e5m2_2 = __nv_fp8x2_e5m2;
#ifdef KITTENS_BLACKWELL
using fp8e8m0_2 = __nv_fp8x2_e8m0;
#endif
/**
 * @brief 4-packed float8 floating-point types.
 */
using fp8e4m3_4 = __nv_fp8x4_e4m3;
using fp8e5m2_4 = __nv_fp8x4_e5m2;
#ifdef KITTENS_BLACKWELL
using fp8e8m0_4 = __nv_fp8x4_e8m0;
#endif
#endif

namespace ducks {
/**
 * @namespace base_types
 *
 * @brief A namespace for concepts for basic data types.
 */
namespace base_types {

#ifdef KITTENS_HOPPER
#ifdef KITTENS_BLACKWELL
template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4> || std::is_same_v<T, fp8e5m2_4> || std::is_same_v<T, fp8e8m0_4>;
template<typename T>
concept T1 = std::is_same_v<T, float>  || std::is_same_v<T, bf16>   || std::is_same_v<T, half>   || std::is_same_v<T, fp8e4m3>   || std::is_same_v<T, fp8e5m2>   || std::is_same_v<T, fp8e8m0>;
#else
template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2> || std::is_same_v<T, fp8e4m3_4> || std::is_same_v<T, fp8e5m2_4>;
template<typename T>
concept T1 = std::is_same_v<T, float>  || std::is_same_v<T, bf16>   || std::is_same_v<T, half>   || std::is_same_v<T, fp8e4m3>   || std::is_same_v<T, fp8e5m2>;
#endif
#else
template<typename T>
concept T2 = std::is_same_v<T, float2> || std::is_same_v<T, bf16_2> || std::is_same_v<T, half_2>;
template<typename T>
concept T1 = std::is_same_v<T, float>  || std::is_same_v<T, bf16>   || std::is_same_v<T, half>;
#endif

} // namespace base_types
} // namespace ducks

/**
 * @namespace base_types
 *
 * @brief A namespace for ThunderKittens basic data types.
 */
namespace base_types {

/**
 * @brief Provides compile-time constants for different types.
 *
 * @tparam T The type for which to provide constants.
 */
template<typename T> struct constants {
    /**
     * @brief Zero
     * @return Constexpr zero with type T
     */
    static __device__ inline constexpr T zero()      { return T{0}; }
    /**
     * @brief One
     * @return Constexpr one with type T
     */
    static __device__ inline constexpr T one()       { return T{1}; }
    /**
     * @brief Positive infinity. Particularly useful for initializing before a min op.
     * @return Constexpr positive infinity with type T
     */
    static __device__ inline constexpr T pos_infty() { return T{INFINITY}; } // I'll find a better way at some point but this appears to work.
    /**
     * @brief Negative infinity. Particularly useful for initializing before a max op.
     * @return Constexpr negative infinity with type T
     */
    static __device__ inline constexpr T neg_infty() { return T{-INFINITY}; }
};
template<> struct constants<float2> {
    static __device__ inline constexpr float2 zero()      { return float2{0.f, 0.f}; }
    static __device__ inline constexpr float2 one()       { return float2{1.f, 1.f}; }
    static __device__ inline constexpr float2 pos_infty() { return float2{constants<float>::pos_infty(), constants<float>::pos_infty()}; }
    static __device__ inline constexpr float2 neg_infty() { return float2{constants<float>::neg_infty(), constants<float>::neg_infty()}; }
};
template<> struct constants<bf16> {
    static __device__ inline constexpr bf16 zero()      { return std::bit_cast<__nv_bfloat16>(uint16_t(0x0000)); } // unfortunately __float2bf16_rn is not constexpr
    static __device__ inline constexpr bf16 one()       { return std::bit_cast<__nv_bfloat16>(uint16_t(0x3F80)); }
    static __device__ inline constexpr bf16 pos_infty() { return std::bit_cast<__nv_bfloat16>(uint16_t(0x7F80)); }
    static __device__ inline constexpr bf16 neg_infty() { return std::bit_cast<__nv_bfloat16>(uint16_t(0xFF80)); }
};
template<> struct constants<bf16_2> {
    static __device__ inline constexpr bf16_2 zero()      { return bf16_2{constants<bf16>::zero(),      constants<bf16>::zero()}; }
    static __device__ inline constexpr bf16_2 one()       { return bf16_2{constants<bf16>::one(),       constants<bf16>::one()}; }
    static __device__ inline constexpr bf16_2 pos_infty() { return bf16_2{constants<bf16>::pos_infty(), constants<bf16>::pos_infty()}; }
    static __device__ inline constexpr bf16_2 neg_infty() { return bf16_2{constants<bf16>::neg_infty(), constants<bf16>::neg_infty()}; }
};
template<> struct constants<half> {
    static __device__ inline constexpr half zero()      { return std::bit_cast<__half>(uint16_t(0x0000)); }
    static __device__ inline constexpr half one()       { return std::bit_cast<__half>(uint16_t(0x3C00)); }
    static __device__ inline constexpr half pos_infty() { return std::bit_cast<__half>(uint16_t(0x7C00)); }
    static __device__ inline constexpr half neg_infty() { return std::bit_cast<__half>(uint16_t(0xFC00)); }
};
template<> struct constants<half_2> {
    static __device__ inline constexpr half_2 zero()      { return half_2{constants<half>::zero(),      constants<half>::zero()}; }
    static __device__ inline constexpr half_2 one()       { return half_2{constants<half>::one(),       constants<half>::one()}; }
    static __device__ inline constexpr half_2 pos_infty() { return half_2{constants<half>::pos_infty(), constants<half>::pos_infty()}; }
    static __device__ inline constexpr half_2 neg_infty() { return half_2{constants<half>::neg_infty(), constants<half>::neg_infty()}; }
};
#ifdef KITTENS_HOPPER
template<> struct constants<fp8e4m3> {
    static __device__ inline constexpr fp8e4m3 zero() { return std::bit_cast<__nv_fp8_e4m3>(uint8_t(0x00)); }
    static __device__ inline constexpr fp8e4m3 one()  { return std::bit_cast<__nv_fp8_e4m3>(uint8_t(0x38)); }
};
template<> struct constants<fp8e4m3_2> {
    static __device__ inline constexpr fp8e4m3_2 zero() { return std::bit_cast<fp8e4m3_2>(uint16_t(0x0000)); }
    static __device__ inline constexpr fp8e4m3_2 one()  { return std::bit_cast<fp8e4m3_2>(uint16_t(0x3838)); }
};
template<> struct constants<fp8e4m3_4> {
    static __device__ inline constexpr fp8e4m3_4 zero() { return std::bit_cast<fp8e4m3_4>(uint32_t(0x00000000)); }
    static __device__ inline constexpr fp8e4m3_4 one()  { return std::bit_cast<fp8e4m3_4>(uint32_t(0x38383838)); }
};
template<> struct constants<fp8e5m2> {
    static __device__ inline constexpr fp8e5m2 zero() { return std::bit_cast<__nv_fp8_e5m2>(uint8_t(0x00)); }
    static __device__ inline constexpr fp8e5m2 one()  { return std::bit_cast<__nv_fp8_e5m2>(uint8_t(0x3C)); }
};
template<> struct constants<fp8e5m2_2> {
    static __device__ inline constexpr fp8e5m2_2 zero() { return std::bit_cast<fp8e5m2_2>(uint16_t(0x0000)); }
    static __device__ inline constexpr fp8e5m2_2 one()  { return std::bit_cast<fp8e5m2_2>(uint16_t(0x3C3C)); }
};
template<> struct constants<fp8e5m2_4> {
    static __device__ inline constexpr fp8e5m2_4 zero() { return std::bit_cast<fp8e5m2_4>(uint32_t(0x00000000)); }
    static __device__ inline constexpr fp8e5m2_4 one()  { return std::bit_cast<fp8e5m2_4>(uint32_t(0x3C3C3C3C)); }
};
#endif

template<> struct constants<int> {
    static __device__ inline constexpr int zero() { return 0; }
    static __device__ inline constexpr int one()  { return 1; }
};
template<> struct constants<int2> {
    static __device__ inline constexpr int2 zero() { return int2{0, 0}; }
    static __device__ inline constexpr int2 one()  { return int2{1, 1}; }
};

/**
 * @brief Provides information about packing of elements for a given type.
 *
 * @tparam T The type for which to provide packing information.
 */
template<typename T> struct packing {
    /**
     * @brief The number of elements packed together.
     *
     * @return constexpr int representing number of elements within the type.
     */
    static __device__ inline constexpr int num() { return 1; }
    /**
     * @brief Packs a single T element twice (replicated) into its packed type.
     *
     * @param i[in] The element to pack.
     * @return The packed type.
     */
    static __device__ inline constexpr T pack(const bf16 &i);
};
template<> struct packing<bf16> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = bf16;
    using packed_type = bf16_2;
    static __device__ inline constexpr bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; }
};
template<> struct packing<bf16_2> {
    static __device__ inline constexpr int num() { return 2; }
    using unpacked_type = bf16;
    using packed_type = bf16_2;
    static __device__ inline constexpr bf16_2 pack(const bf16 &i) { return bf16_2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<half> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = half;
    using packed_type = half_2;
    static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; }
};
template<> struct packing<half_2> {
    static __device__ inline constexpr int num() { return 2; }
    using unpacked_type = half;
    using packed_type = half_2;
    static __device__ inline constexpr half_2 pack(const half &i) { return half_2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<float> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = float;
    using packed_type = float2;
    static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; }
};
template<> struct packing<float2> {
    static __device__ inline constexpr int num() { return 2; }
    using unpacked_type = float;
    using packed_type = float2;
    static __device__ inline constexpr float2 pack(const float &i) { return float2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<char> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = char;
    using packed_type = char2;
    static __device__ inline constexpr char2 pack(const char &i) { return char2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<char2> {
    static __device__ inline constexpr int num() { return 2; }
    using unpacked_type = char;
    using packed_type = char2;
    static __device__ inline constexpr char2 pack(const char &i) { return char2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<int> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = int;
    using packed_type = int2;
    static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<int2> {
    static __device__ inline constexpr int num() { return 2; }
    using unpacked_type = int;
    using packed_type = int2;
    static __device__ inline constexpr int2 pack(const int &i) { return int2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<uint> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = uint;
    using packed_type = uint2;
    static __device__ inline constexpr uint2 pack(const uint &i) { return uint2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<uint2> {
    static __device__ inline constexpr int num() { return 2; }
    using unpacked_type = uint;
    using packed_type = uint2;
    static __device__ inline constexpr uint2 pack(const uint &i) { return uint2{i, i}; } // this replication makes code cleaner later.
};
struct uint64_2 { uint64_t x, y; };
template<> struct packing<uint64_t> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = uint64_t;
    using packed_type = uint64_2;
    static __device__ inline constexpr uint64_2 pack(const uint64_t &i) { return uint64_2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<uint64_2> {
    static __device__ inline constexpr int num() { return 2; }
    using unpacked_type = uint64_t;
    using packed_type = uint64_2;
    static __device__ inline constexpr uint64_2 pack(const uint64_t &i) { return uint64_2{i, i}; } // this replication makes code cleaner later.
};
template<> struct packing<float4> {
    static __device__ inline constexpr int num() { return 4; }
};
template<> struct packing<int4> {
    static __device__ inline constexpr int num() { return 4; }
};
#ifdef KITTENS_HOPPER
template<> struct packing<fp8e4m3> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = fp8e4m3;
    using packed_type = fp8e4m3_4;
};
template<> struct packing<fp8e4m3_4> {
    static __device__ inline constexpr int num() { return 4; }
    using unpacked_type = fp8e4m3;
    using packed_type = fp8e4m3_4;
};
template<> struct packing<fp8e5m2> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = fp8e5m2;
    using packed_type = fp8e5m2_4;
};
template<> struct packing<fp8e5m2_4> {
    static __device__ inline constexpr int num() { return 4; }
    using unpacked_type = fp8e5m2;
    using packed_type = fp8e5m2_4;
};
#ifdef KITTENS_BLACKWELL
template<> struct packing<fp8e8m0> {
    static __device__ inline constexpr int num() { return 1; }
    using unpacked_type = fp8e8m0;
    using packed_type = fp8e8m0_4;
};
template<> struct packing<fp8e8m0_4> {
    static __device__ inline constexpr int num() { return 4; }
    using unpacked_type = fp8e8m0;
    using packed_type = fp8e8m0_4;
};
#endif
#endif


/**
 * @brief Provides templated functionality to convert between different types.
 *
 * @tparam T The target type for conversion.
 * @tparam U The source type for conversion.
 */
template<typename T, typename U> struct convertor {
    /**
     * @brief Converts a value of type U to type T.
     *
     * @param u[in] The value of type U to convert.
     * @return T The converted value of type T.
     */
    static __host__ __device__ inline T convert(const U &u) {
        return (T)u;
    }
};
template<> struct convertor<float, bf16> {
    static __host__ __device__ inline float convert(const bf16 &u) {
        return __bfloat162float(u);
    }
};
template<> struct convertor<bf16, float> {
    static __host__ __device__ inline bf16 convert(const float &u) {
        return __float2bfloat16_rn(u);
    }
};
template<> struct convertor<float2, bf16_2> {
    static __host__ __device__ inline float2 convert(const bf16_2 &u) {
        return __bfloat1622float2(u);
    }
};
template<> struct convertor<bf16_2, float2> {
    static __host__ __device__ inline bf16_2 convert(const float2 &u) {
        return __float22bfloat162_rn(u);
    }
};
template<> struct convertor<float, half> {
    static __host__ __device__ inline float convert(const half &u) {
        return __half2float(u);
    }
};
template<> struct convertor<half, float> {
    static __host__ __device__ inline half convert(const float &u) {
        return __float2half(u);
    }
};
template<> struct convertor<float2, half_2> {
    static __host__ __device__ inline float2 convert(const half_2 &u) {
        return __half22float2(u);
    }
};
template<> struct convertor<half_2, float2> {
    static __host__ __device__ inline half_2 convert(const float2 &u) {
        return __float22half2_rn(u);
    }
};
template<> struct convertor<bf16, half> {
    static __host__ __device__ inline bf16 convert(const half &u) {
        return __float2bfloat16_rn(__half2float(u));
    }
};
template<> struct convertor<half, bf16> {
    static __host__ __device__ inline half convert(const bf16 &u) {
        return __float2half(__bfloat162float(u));
    }
};
template<> struct convertor<bf16_2, half_2> {
    static __host__ __device__ inline bf16_2 convert(const half_2 &u) {
        return __float22bfloat162_rn(__half22float2(u));
    }
};
template<> struct convertor<half_2, bf16_2> {
    static __host__ __device__ inline half_2 convert(const bf16_2 &u) {
        return __float22half2_rn(__bfloat1622float2(u));
    }
};
#ifdef KITTENS_HOPPER
// fp8e4m3
template<> struct convertor<fp8e4m3_4, float4> {
    static __host__ __device__ inline fp8e4m3_4 convert(const float4 &u) {
        return __nv_fp8x4_e4m3(u);
    }
};
template<> struct convertor<float4, fp8e4m3_4> {
    static __host__ __device__ inline float4 convert(const fp8e4m3_4 &u) {
        __nv_fp8_e4m3 *vals = reinterpret_cast<__nv_fp8_e4m3*>(const_cast<__nv_fp8x4_e4m3*>(&u));
        return make_float4(float(vals[0]), float(vals[1]), float(vals[2]), float(vals[3]));
    }
};
template<> struct convertor<fp8e4m3_2, float2> {
    static __host__ __device__ inline fp8e4m3_2 convert(const float2 &u) {
        return __nv_fp8x2_e4m3(u);
    }
};
template<> struct convertor<float2, fp8e4m3_2> {
    static __host__ __device__ inline float2 convert(const fp8e4m3_2 &u) {
        __nv_fp8_e4m3 *vals = reinterpret_cast<__nv_fp8_e4m3*>(const_cast<__nv_fp8x2_e4m3*>(&u));
        return make_float2(float(vals[0]), float(vals[1]));
    }
};
template<> struct convertor<fp8e4m3, float> {
    static __host__ __device__ inline fp8e4m3 convert(const float &u) {
        return __nv_fp8_e4m3(u);
    }
};
template<> struct convertor<float, fp8e4m3> {
    static __host__ __device__ inline float convert(const fp8e4m3 &u) {
        return float(u);
    }
};
template<> struct convertor<bf16_2, fp8e4m3_4> {
    static __host__ __device__ inline bf16_2 convert(const fp8e4m3_4 &u) {
        float4 f4 = convertor<float4, fp8e4m3_4>::convert(u);
        float2 f2 = make_float2(f4.x, f4.y);
        return __float22bfloat162_rn(f2);
    }
};
template<> struct convertor<fp8e4m3_4, bf16_2> {
    static __host__ __device__ inline fp8e4m3_4 convert(const bf16_2 &u) {
        float2 f2 = __bfloat1622float2(u);
        float4 f4 = make_float4(f2.x, f2.y, 0.0f, 0.0f);
        return __nv_fp8x4_e4m3(f4);
    }
};
// fp8e5m2
template<> struct convertor<fp8e5m2_4, float4> {
    static __host__ __device__ inline fp8e5m2_4 convert(const float4 &u) {
        return __nv_fp8x4_e5m2(u);
    }
};
template<> struct convertor<float4, fp8e5m2_4> {
    static __host__ __device__ inline float4 convert(const fp8e5m2_4 &u) {
        __nv_fp8_e5m2 *vals = reinterpret_cast<__nv_fp8_e5m2*>(const_cast<__nv_fp8x4_e5m2*>(&u));
        return make_float4(float(vals[0]), float(vals[1]), float(vals[2]), float(vals[3]));
    }
};
template<> struct convertor<fp8e5m2_2, float2> {
    static __host__ __device__ inline fp8e5m2_2 convert(const float2 &u) {
        return __nv_fp8x2_e5m2(u);
    }
};
template<> struct convertor<float2, fp8e5m2_2> {
    static __host__ __device__ inline float2 convert(const fp8e5m2_2 &u) {
        __nv_fp8_e5m2 *vals = reinterpret_cast<__nv_fp8_e5m2*>(const_cast<__nv_fp8x2_e5m2*>(&u));
        return make_float2(float(vals[0]), float(vals[1]));
    }
};
template<> struct convertor<fp8e5m2, float> {
    static __host__ __device__ inline fp8e5m2 convert(const float &u) {
        return __nv_fp8_e5m2(u);
    }
};
template<> struct convertor<float, fp8e5m2> {
    static __host__ __device__ inline float convert(const fp8e5m2 &u) {
        return float(u);
    }
};
template<> struct convertor<bf16_2, fp8e5m2_4> {
    static __host__ __device__ inline bf16_2 convert(const fp8e5m2_4 &u) {
        float4 f4 = convertor<float4, fp8e5m2_4>::convert(u);
        float2 f2 = make_float2(f4.x, f4.y);
        return __float22bfloat162_rn(f2);
    }
};
template<> struct convertor<fp8e5m2_4, bf16_2> {
    static __host__ __device__ inline fp8e5m2_4 convert(const bf16_2 &u) {
        float2 f2 = __bfloat1622float2(u);
        float4 f4 = make_float4(f2.x, f2.y, 0.0f, 0.0f);
        return __nv_fp8x4_e5m2(f4);
    }
};
#endif

} // namespace base_types
} // namespace kittens
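
Usage note: a small sketch (not part of this commit) of how constants, packing, and convertor compose in device code; bf16x2_axpy is a hypothetical helper name:

// Hypothetical usage sketch (not in this commit).
#include "base_types.cuh"
using namespace kittens;

__device__ float2 bf16x2_axpy(bf16_2 x, float alpha) {
    // Replicate the scalar into a packed lane pair, widen both operands to
    // float2, and start from the compile-time zero constant.
    bf16_2 a2  = base_types::packing<bf16>::pack(__float2bfloat16(alpha));
    float2 xf  = base_types::convertor<float2, bf16_2>::convert(x);
    float2 af  = base_types::convertor<float2, bf16_2>::convert(a2);
    float2 acc = base_types::constants<float2>::zero();
    acc.x = af.x * xf.x; acc.y = af.y * xf.y;
    return acc;
}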

11  extra/thunder/cuda/include/common/common.cuh  (new file)
@@ -0,0 +1,11 @@
/**
 * @file
 * @brief A collection of common resources on which ThunderKittens depends.
 */

#pragma once

#include "util.cuh"
#include "base_types.cuh"
#include "base_ops.cuh"

56  extra/thunder/cuda/include/common/debug.cuh  (new file)
@@ -0,0 +1,56 @@
#pragma once

// Reset
#define TK_RESET "\033[0m"

// Foreground colors
#define TK_FG_BLACK   "\033[30m"
#define TK_FG_RED     "\033[31m"
#define TK_FG_GREEN   "\033[32m"
#define TK_FG_YELLOW  "\033[33m"
#define TK_FG_BLUE    "\033[34m"
#define TK_FG_MAGENTA "\033[35m"
#define TK_FG_CYAN    "\033[36m"
#define TK_FG_WHITE   "\033[37m"

// Background colors
#define TK_BG_BLACK   "\033[40m"
#define TK_BG_RED     "\033[41m"
#define TK_BG_GREEN   "\033[42m"
#define TK_BG_YELLOW  "\033[43m"
#define TK_BG_BLUE    "\033[44m"
#define TK_BG_MAGENTA "\033[45m"
#define TK_BG_CYAN    "\033[46m"
#define TK_BG_WHITE   "\033[47m"

// Bright foreground colors
#define TK_FG_BRIGHT_BLACK   "\033[90m"
#define TK_FG_BRIGHT_RED     "\033[91m"
#define TK_FG_BRIGHT_GREEN   "\033[92m"
#define TK_FG_BRIGHT_YELLOW  "\033[93m"
#define TK_FG_BRIGHT_BLUE    "\033[94m"
#define TK_FG_BRIGHT_MAGENTA "\033[95m"
#define TK_FG_BRIGHT_CYAN    "\033[96m"
#define TK_FG_BRIGHT_WHITE   "\033[97m"

// Bright background colors
#define TK_BG_BRIGHT_BLACK   "\033[100m"
#define TK_BG_BRIGHT_RED     "\033[101m"
#define TK_BG_BRIGHT_GREEN   "\033[102m"
#define TK_BG_BRIGHT_YELLOW  "\033[103m"
#define TK_BG_BRIGHT_BLUE    "\033[104m"
#define TK_BG_BRIGHT_MAGENTA "\033[105m"
#define TK_BG_BRIGHT_CYAN    "\033[106m"
#define TK_BG_BRIGHT_WHITE   "\033[107m"

// Text styles
#define TK_BOLD      "\033[1m"
#define TK_DIM       "\033[2m"
#define TK_ITALIC    "\033[3m"
#define TK_UNDERLINE "\033[4m"
#define TK_BLINK     "\033[5m"
#define TK_REVERSE   "\033[7m"
#define TK_HIDDEN    "\033[8m"

// Macro to combine styles
#define TK_STYLE(...) "\033[" #__VA_ARGS__ "m"
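
Usage note: because these macros expand to string literals, they concatenate with adjacent literals at compile time. A minimal sketch (not part of this commit):

// Hypothetical usage sketch (not in this commit).
#include <cstdio>
#include "debug.cuh"

int main() {
    // Adjacent string literals concatenate, so styles nest freely.
    std::printf(TK_BOLD TK_FG_RED "error:" TK_RESET " kernel launch failed\n");
    std::printf(TK_STYLE(1;32) "ok" TK_RESET "\n"); // TK_STYLE(1;32) expands to "\033[1;32m"
    return 0;
}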

314  extra/thunder/cuda/include/common/util.cuh  (new file)
@@ -0,0 +1,314 @@
/**
 * @file
 * @brief General utilities for ThunderKittens.
 */

#pragma once

#include <stdint.h>
#include <type_traits>
#include <concepts>
#include <memory>

// CUDA driver API
#define CUCHECK(cmd) do {                                     \
    CUresult err = cmd;                                       \
    if (err != CUDA_SUCCESS) {                                \
        const char *errStr;                                   \
        cuGetErrorString(err, &errStr);                       \
        fprintf(stderr, "Failed: CUDA error %s:%d '%s'\n",    \
                __FILE__, __LINE__, errStr);                  \
        exit(EXIT_FAILURE);                                   \
    }                                                         \
} while(0)

// CUDA runtime API
#define CUDACHECK(cmd) do {                                   \
    cudaError_t err = cmd;                                    \
    if (err != cudaSuccess) {                                 \
        fprintf(stderr, "Failed: CUDA error %s:%d '%s'\n",    \
                __FILE__, __LINE__, cudaGetErrorString(err)); \
        exit(EXIT_FAILURE);                                   \
    }                                                         \
} while(0)
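
Usage note: the typical host-side call pattern wraps every runtime call so a failure reports file and line and aborts immediately. A sketch (not part of this commit; my_kernel is a placeholder):

// Hypothetical usage sketch (not in this commit).
float *d_buf;
CUDACHECK(cudaMalloc(&d_buf, 1024 * sizeof(float)));
CUDACHECK(cudaMemset(d_buf, 0, 1024 * sizeof(float)));
my_kernel<<<1, 128>>>(d_buf);          // my_kernel is a placeholder
CUDACHECK(cudaGetLastError());         // catch launch-time errors
CUDACHECK(cudaDeviceSynchronize());    // catch asynchronous execution errors
CUDACHECK(cudaFree(d_buf));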

/**
 * @namespace kittens
 *
 * @brief The main namespace of ThunderKittens.
 */
namespace kittens {

/* ---------- GENERAL CONSTANTS FOR KITTENS ---------- */

/**
 * @brief Tile dimension constants.
 */
template<typename T> constexpr int TILE_COL_DIM = sizeof(T) == 1 ? 32 : 16;
template<typename T> constexpr int TILE_ROW_DIM = 16;
/**
 * @brief Tile num elements constant, calculated as TILE_COL_DIM * TILE_ROW_DIM.
 */
template<typename T> constexpr int TILE_ELEMENTS{TILE_COL_DIM<T>*TILE_ROW_DIM<T>};
/**
 * @brief Constant representing number of threads in a warp.
 */
constexpr int WARP_THREADS{32};
/**
 * @brief Constant representing number of threads in a warpgroup of four warps.
 */
constexpr int WARPGROUP_THREADS{128};
/**
 * @brief Constant representing number of warps in a warpgroup of four warps.
 */
constexpr int WARPGROUP_WARPS{4};
/**
 * @brief Get the warp ID of the current thread.
 * @return The warp ID.
 */
__device__ static __forceinline__ int warpid() {
    // uint32_t wid;
    // asm volatile("mov.u32 %0, %%warpid;" : "=r"(wid));
    // return wid;
    return threadIdx.x >> 5;
}
/**
 * @brief Get the warpgroup ID of the current thread.
 * @return The warpgroup ID.
 */
__device__ static __forceinline__ int warpgroupid() { return warpid() >> 2; }
/**
 * @brief Get the lane ID of the current thread within its warp.
 * @return The lane ID.
 */
__device__ static __forceinline__ int laneid() {
    // uint32_t lid;
    // asm volatile("mov.u32 %0, %%laneid;" : "=r"(lid));
    // return lid;
    return threadIdx.x & 31;
}

#if defined(KITTENS_HOPPER)
constexpr int MAX_SHARED_MEMORY = 227000;
#elif defined(KITTENS_A100)
constexpr int MAX_SHARED_MEMORY = 164000;
#elif defined(KITTENS_4090)
constexpr int MAX_SHARED_MEMORY = 100000;
#endif

struct transpose {
    static constexpr int N = 0; // not transposed
    static constexpr int T = 1; // transposed
};
struct axis {
    static constexpr int ROW = 0; // row axis of a tile
    static constexpr int COL = 1; // column axis of a tile
};

/* ---------- TYPE HELPERS ---------- */

/**
 * @namespace ducks
 *
 * @brief ThunderKittens' namespace for template metaprogramming.
 *
 * This includes primarily dummy types and concept wrappers, along
 * with a few additional utilities.
 */
namespace ducks {

/**
 * @brief A type representing an empty default for a template.
 */
struct default_type {};

// This macro can't be done as a template, so it doesn't really have a location in kittens.
#define typeof(A) typename std::remove_const<typename std::remove_reference<decltype(A)>::type>::type

}

/* ---------- SHUFFLE UTILS ---------- */

/**
 * @brief Mask constant for all active threads in a warp.
 */
static constexpr uint32_t MASK_ALL = 0xFFFFFFFF;

/**
 * @brief Perform a shuffle down operation on a packed type synchronously across a warp.
 * @tparam T The type of the value to be shuffled.
 * @param mask[in] The mask of active threads.
 * @param f[in] The value to be shuffled.
 * @param delta[in] The number of positions to shuffle down.
 * @return The result of the shuffle operation.
 */
template<typename T>
__device__ static inline T packed_shfl_down_sync(uint32_t mask, const T &f, int delta) {
    return __shfl_down_sync(mask, f, delta);
}
template<>
__device__ inline float2 packed_shfl_down_sync<float2>(uint32_t mask, const float2 &f, int delta) {
    float2 r;
    r.x = __shfl_down_sync(mask, f.x, delta);
    r.y = __shfl_down_sync(mask, f.y, delta);
    return r;
}
/**
 * @brief Perform a packed shuffle operation synchronously across a warp.
 * @tparam T The type of the value to be shuffled.
 * @param mask[in] The mask of active threads.
 * @param f[in] The value to be shuffled.
 * @param src[in] The source lane from which to shuffle.
 * @return The result of the shuffle operation.
 */
template<typename T>
__device__ static inline T packed_shfl_sync(uint32_t mask, const T &f, int src) {
    return __shfl_sync(mask, f, src);
}
template<>
__device__ inline float2 packed_shfl_sync<float2>(uint32_t mask, const float2 &f, int src) {
    float2 r;
    r.x = __shfl_sync(mask, f.x, src);
    r.y = __shfl_sync(mask, f.y, src);
    return r;
}
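
Usage note: these wrappers exist so warp reductions can be written once over scalar and packed types alike. A minimal warp-sum sketch (not part of this commit; assumes base_ops.cuh is also included, and warp_sum is a hypothetical name):

// Hypothetical usage sketch (not in this commit): tree reduction across a warp
// that works for float and float2 through the same template.
template<typename T>
__device__ inline T warp_sum(T v) {
    for (int offset = kittens::WARP_THREADS / 2; offset > 0; offset >>= 1)
        v = kittens::base_ops::sum::op<T>(v, kittens::packed_shfl_down_sync(kittens::MASK_ALL, v, offset));
    return kittens::packed_shfl_sync(kittens::MASK_ALL, v, 0); // broadcast lane 0's total
}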

/* ---------- SHARED MEMORY UTILS ---------- */

// namespace ducks {
// namespace sb {
// struct identifier {};
// }
// }

// template<typename Args...>
// struct sb {
//     using identifier = ducks::sb::identifier;
//     Args... args;
// };

// namespace ducks {
// namespace sb {
// template<typename T> concept all = requires {
//     typename T::identifier;
// } && std::is_same_v<T::identifier, identifier>;
// }
// }

// Joyously stolen from https://github.com/NVIDIA/cutlass/blob/5c447dd84f8ae0e1d48ff9a2eae26ce8c4958101/include/cute/container/alignment.hpp#L51
#if defined(__CUDACC__)
#define KITTENS_ALIGN_AS(n) __align__(n)
#else
#define KITTENS_ALIGN_AS(n) alignas(n)
#endif

#ifdef KITTENS_HOPPER
#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(128)
#else
#define KITTENS_DEFAULT_ALIGN KITTENS_ALIGN_AS(16)
#endif

/**
 * @brief Dummy structure for alignment purposes. Needed for WGMMA and TMA calls.
 */
struct KITTENS_DEFAULT_ALIGN alignment_dummy { int dummy; };
/**
 * @brief Very simple allocator for dynamic shared memory. Advances pointer and tracks alignments.
 * @tparam default_alignment The default alignment this allocator will enforce. If <= 0 it will not align.
 */
#ifdef KITTENS_HOPPER
template<int default_alignment=1024>
#else
template<int default_alignment=16>
#endif
struct shared_allocator {
    int *ptr;

    private:
        // Recursive template to generate N-dimensional array type
        template<typename A, size_t... dims>
        struct variadic_array;
        template<typename A, size_t first_dim, size_t... rest_dims>
        struct variadic_array<A, first_dim, rest_dims...> {
            using type = typename variadic_array<A, rest_dims...>::type[first_dim];
        };
        template<typename A>
        struct variadic_array<A> {
            using type = A;
        };
        template<typename A, size_t... dims>
        using variadic_array_t = typename variadic_array<A, dims...>::type;

        template<int alignment>
        __device__ inline void align_ptr() {
            if constexpr (alignment > 0) {
                uint64_t p = reinterpret_cast<uint64_t>(ptr);
                if(p % alignment != 0) {
                    ptr = (int*)(p + (alignment-(p%alignment)));
                }
            }
        }

    public:
        /**
         * @brief Construct a new shared allocator using a pointer to extern shared memory.
         * @param[in] _ptr Pointer to the start of the extern shared memory.
         */
        __device__ shared_allocator(int *_ptr): ptr(_ptr) {}
        /**
         * @brief Allocate shared memory for a single instance or N-dimensional array of type A.
         * @tparam A The type of the object to allocate.
         * @tparam dims... A list of dimensions for the N-dimensional array.
         * @return Reference to the allocated object.
         */
        template<typename A, size_t... dims>
        __device__ inline variadic_array_t<A, dims...>& allocate() {
            // static_assert(sizeof(A) % default_alignment == 0, "Type is not aligned properly for array allocation");
            align_ptr<default_alignment>();
            using at = variadic_array_t<A, dims...>;
            at *p = reinterpret_cast<at*>(ptr);
            ptr += sizeof(at)/sizeof(int);
            return *p;
        }
        /**
         * @brief Allocate shared memory for a single instance or N-dimensional array of type A, with a custom alignment.
         * @tparam alignment An alignment to enforce for this particular object.
         * @tparam A The type of the object to allocate.
         * @tparam dims... A list of dimensions for the N-dimensional array.
         * @return Reference to the allocated object.
         */
        template<int alignment, typename A, size_t... dims>
        __device__ inline variadic_array_t<A, dims...>& allocate() {
            // static_assert(sizeof(A) % alignment == 0, "Type is not aligned properly for array allocation");
            align_ptr<alignment>();
            using at = variadic_array_t<A, dims...>;
            at *p = reinterpret_cast<at*>(ptr);
            ptr += sizeof(at)/sizeof(int);
            return *p;
        }
};
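
Usage note: a sketch of the intended call pattern inside a kernel (not part of this commit; example_kernel is a hypothetical name, and the array shapes stand in for any shared-memory resident data):

// Hypothetical usage sketch (not in this commit).
__global__ void example_kernel() {
    extern __shared__ kittens::alignment_dummy __shm[];
    kittens::shared_allocator<> al((int*)&__shm[0]);

    // Carve a 2x4 float array out of dynamic shared memory; the allocator
    // bumps its pointer and re-aligns before each allocation.
    float (&scratch)[2][4] = al.allocate<float, 2, 4>();
    // Request a stricter alignment for a buffer destined for TMA.
    float (&tma_buf)[256]  = al.allocate<1024, float, 256>();
    (void)scratch; (void)tma_buf;
}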
|
||||
#if (defined(KITTENS_HOPPER) || defined(KITTENS_BLACKWELL))
|
||||
/**
|
||||
* @brief A wrapper for an allocator that enforces sufficient alignment to be used for TMA loads and stores.
|
||||
*/
|
||||
using tma_allocator = shared_allocator<1024>;
|
||||
using tma_swizzle_allocator = tma_allocator; // swizzled TMA modes require up to 1024 byte alignments :/
|
||||
|
||||
/* Get CTA ID within a cluster */
__device__ static inline int3 clusterIdx() {
    int3 cluster_idx;
    asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(cluster_idx.x));
    asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(cluster_idx.y));
    asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(cluster_idx.z));
    return cluster_idx;
}
__device__ static inline int cluster_ctarank() {
    uint32_t ctarank;
    asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(ctarank));
    return ctarank;
}
#endif

} // namespace kittens
12
extra/thunder/cuda/include/kittens.cuh
Normal file
@@ -0,0 +1,12 @@
/**
 * @file
 * @brief The master header file of ThunderKittens. This file includes everything you need!
 */

#pragma once

#include "common/common.cuh"
#include "types/types.cuh"
#include "ops/ops.cuh"
#include "pyutils/util.cuh"
// #include "pyutils/pyutils.cuh" // for simple binding without including torch
51
extra/thunder/cuda/include/ops/device/device.cuh
Normal file
@@ -0,0 +1,51 @@
/**
 * @file
 * @brief An aggregate header of all device (multi-GPU) operations defined by ThunderKittens
 */

#pragma once

#include "../../types/types.cuh"

namespace kittens {

template<int _NUM_DEVICES>
struct device {

    static_assert(_NUM_DEVICES >= 0 && _NUM_DEVICES <= 72, "Invalid number of devices");
    static constexpr int NUM_DEVICES = _NUM_DEVICES;

#ifdef KITTENS_HOPPER

    using barrier_t = pgl<gl<int, 1, 1, 1, -1>, NUM_DEVICES, true>;

    /**
     * @brief Multi-GPU synchronization barrier for coordinated kernel exit
     *
     * Performs a synchronization across all devices to ensure all GPUs complete
     * their work before any kernel exits. Does not synchronize intra-node threads
     * or threadblocks.
     *
     * @param barrier Pre-allocated barrier structure; must be initialized to 0
     * @param dev_idx Current device index (0 to NUM_DEVICES - 1)
     * @param id Synchronization point identifier (default: 0). 0 is fine for most cases.
     */
    __device__ static inline void sync_on_exit(const barrier_t &barrier, const int dev_idx, const int id = 0) {
        if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 &&
            threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
            cuda::atomic_ref<int, cuda::thread_scope_system> barrier_uc(barrier[dev_idx][{id}]);

            // Inter-node check-in
            multimem<int>::red<reduce_op::ADD>(barrier.mc_ptr_at({id}), 1);
            asm volatile ("{fence.proxy.alias;}" ::: "memory");
            while (barrier_uc.load(cuda::memory_order_acquire) < NUM_DEVICES);
            barrier_uc.fetch_sub(NUM_DEVICES, cuda::memory_order_release);
        }
    }
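    /* Usage sketch (illustrative, not part of this header; assumes the pgl barrier
       was allocated across devices and zero-initialized on the host):
         using dev = kittens::device<8>;
         __global__ void kernel(dev::barrier_t barrier, int dev_idx) {
             // ... compute ...
             dev::sync_on_exit(barrier, dev_idx);
         }
    */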

#endif

};

} // namespace kittens
96
extra/thunder/cuda/include/ops/group/group.cuh
Normal file
@@ -0,0 +1,96 @@
/**
 * @file
 * @brief An aggregate header of all group (multi-warp) operations defined by ThunderKittens
 */

#pragma once

#include <cuda/pipeline>

#include "../../common/common.cuh"
#include "../../types/types.cuh"
#include "../thread/thread.cuh" // several group memory ops rely on underlying warp-scope ops

#define KITTENS_CHECK_WARP static_assert(GROUP_WARPS==1, "Warp (GROUP_WARPS=1) function called from a non-warp group.");
// A "warpgroup" is a special group of 4 consecutive warps defined by NVIDIA for certain SM_90+ operations.
#define KITTENS_CHECK_WARPGROUP static_assert(GROUP_WARPS==4, "Warpgroup (GROUP_WARPS=4) function called from a non-warpgroup group.");

// WGMMA relies on some template structures that cannot be specialized within the group struct, so we declare them in advance.
#ifdef KITTENS_HOPPER
#include "mma/warpgroup/base/base.cuh"
#endif

namespace kittens {
/*
This is meant to be used with a `using group_N = kittens::group<NUM_WORKERS>;` at the start of every kernel.
*/
template<int _GROUP_WARPS>
struct group {
    static constexpr int GROUP_WARPS = _GROUP_WARPS; // This alias produces nice parallelism.
    static constexpr int GROUP_THREADS = GROUP_WARPS * kittens::WARP_THREADS; // This alias produces nice parallelism.
    __device__ static inline int laneid() { return threadIdx.x % GROUP_THREADS; }
    __device__ static inline int warpid() { return laneid() / kittens::WARP_THREADS; }
    __device__ static inline int groupid() { return threadIdx.x / GROUP_THREADS; }
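    // e.g. for group<4> (GROUP_THREADS=128), threadIdx.x==130 gives laneid 2, warpid 0, groupid 1.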

    __device__ static inline void sync(int id) {
        asm volatile("bar.sync %0, %1;\n" :: "r"(id), "n"(GROUP_THREADS));
    }
    template<uint32_t MASK=0xFFFFFFFF> __device__ static inline void sync() {
        static_assert(GROUP_WARPS==1, "barrier-less sync() can only be called by a single warp!");
        asm volatile("bar.warp.sync %0;\n" :: "n"(MASK));
    }
    __device__ static inline void arrive(int id) {
        asm volatile("bar.arrive %0, %1;\n" :: "r"(id), "n"(GROUP_THREADS));
    }

#include "memory/memory.cuh"
#include "shared/shared.cuh"
#include "register/register.cuh"

#ifdef KITTENS_HOPPER
#include "mma/mma.cuh"

    template<int n_reg> __device__ static inline void increase_registers() {
        static_assert(n_reg % 8 == 0, "n_reg must be a multiple of 8");
        asm volatile("setmaxnreg.inc.sync.aligned.u32 %0;\n" :: "n"(n_reg));
    }
    template<int n_reg> __device__ static inline void decrease_registers() {
        static_assert(n_reg % 8 == 0, "n_reg must be a multiple of 8");
        asm volatile("setmaxnreg.dec.sync.aligned.u32 %0;\n" :: "n"(n_reg));
    }
    __device__ static inline void producer_registers() { decrease_registers<24>(); }
    template<int NCWG> __device__ static inline void consumer_registers() { increase_registers<480/NCWG - 8*(NCWG>3) - 224*(NCWG==1)>(); }
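    // Worked values of the formula above: consumer_registers<1>() requests 256 registers
    // per thread, <2>() 240, <3>() 160, and <4>() 112 (all multiples of 8, as setmaxnreg requires).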

#endif

};

namespace everyone {

// Block-level synchronization
__device__ static inline void sync(int id) {
    asm volatile("bar.sync %0;\n" :: "r"(id));
}

// Cluster-level synchronization functions
namespace tma {
namespace cluster {
__device__ static inline void arrive_aligned() { // All threads in the cluster must call this
    asm volatile ("barrier.cluster.arrive.release.aligned;\n");
}
__device__ static inline void wait_aligned() {
    asm volatile ("barrier.cluster.wait.acquire.aligned;\n");
}
__device__ static inline void sync() {
    arrive_aligned();
    wait_aligned();
}
}
}

};

using warp = group<1>; // scope used by most pre-Hopper GPUs, and also for most register operations.
using warpgroup = group<4>; // special scope commonly used by Hopper and later.

}
21
extra/thunder/cuda/include/ops/group/memory/memory.cuh
Normal file
@@ -0,0 +1,21 @@
/**
 * @file
 * @brief An aggregate header of collaborative group memory movement operations
 */

#include "util/util.cuh"
#include "tile/tile.cuh"
#include "vec/vec.cuh"

#ifdef KITTENS_HOPPER
struct tma {
#include "util/tma.cuh"
#include "tile/tma.cuh"
#include "vec/tma.cuh"
struct cluster {
#include "util/tma_cluster.cuh"
#include "tile/tma_cluster.cuh"
#include "vec/tma_cluster.cuh"
};
};
#endif
@@ -0,0 +1,42 @@
/**
 * @file
 * @brief Functions for a group to collaboratively transfer data directly between global memory and registers and back.
 */

/**
 * @brief Collaboratively loads data from a source array into register tiles.
 *
 * @tparam RT The register tile type.
 * @tparam U The data type of the source array.
 * @param dst[out] The destination tile to load data into.
 * @param src[in] The source array to load data from.
 * @param row_stride[in] The stride in elements between rows in the source array.
 */
template<int axis, ducks::crt::all CRT, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<crt<typename CRT::T, GROUP_WARPS*CRT::rows, CRT::cols, typename CRT::layout>>>
__device__ inline static void load(CRT &dst, const CGL &src, const COORD &idx) {
    load<axis, CRT::component, CGL::component, COORD>(dst.real, src.real, idx);
    load<axis, CRT::component, CGL::component, COORD>(dst.imag, src.imag, idx);
}
template<ducks::crt::all CRT, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<crt<typename CRT::T, GROUP_WARPS*CRT::rows, CRT::cols, typename CRT::layout>>>
__device__ inline static void load(CRT &dst, const CGL &src, const COORD &idx) {
    load<2, CRT, CGL>(dst, src, idx);
}

/**
 * @brief Collaboratively stores data from register tiles to a destination array in global memory.
 *
 * @tparam RT The register tile type.
 * @tparam U The data type of the destination array.
 * @param[out] dst The destination array in global memory to store data into.
 * @param[in] src The source register tile to store data from.
 * @param row_stride[in] The stride in elements between rows in the destination array.
 */
template<int axis, ducks::crt::all CRT, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<crt<typename CRT::T, GROUP_WARPS*CRT::rows, CRT::cols, typename CRT::layout>>>
__device__ inline static void store(CGL &dst, const CRT &src, const COORD &idx) {
    store<axis, typename CRT::component, typename CGL::component>(dst.real, src.real, idx);
    store<axis, typename CRT::component, typename CGL::component>(dst.imag, src.imag, idx);
}
template<ducks::crt::all CRT, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<crt<typename CRT::T, GROUP_WARPS*CRT::rows, CRT::cols, typename CRT::layout>>>
__device__ inline static void store(CGL &dst, const CRT &src, const COORD &idx) {
    store<2, CRT, CGL>(dst, src, idx);
}
@@ -0,0 +1,37 @@
/**
 * @file
 * @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory.
 */

template<int axis, bool assume_aligned, ducks::cst::all CST, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<CST>>
__device__ static inline void load(CST &dst, const CGL &src, const COORD &idx) {
    load<axis, assume_aligned, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx);
    load<axis, assume_aligned, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx);
}
template<ducks::cst::all CST, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<CST>>
__device__ static inline void load(CST &dst, const CGL &src, const COORD &idx) {
    load<2, false, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx);
    load<2, false, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx);
}

template<int axis, bool assume_aligned, ducks::cst::all CST, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<CST>>
__device__ static inline void store(CGL &dst, const CST &src, const COORD &idx) {
    store<axis, assume_aligned, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx);
    store<axis, assume_aligned, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx);
}
template<ducks::cst::all CST, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<CST>>
__device__ static inline void store(CGL &dst, const CST &src, const COORD &idx) {
    store<2, false, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx);
    store<2, false, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx);
}

template<int axis, bool assume_aligned, ducks::cst::all CST, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<CST>>
__device__ static inline void load_async(CST &dst, const CGL &src, const COORD &idx) {
    load_async<axis, assume_aligned, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx);
    load_async<axis, assume_aligned, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx);
}
template<ducks::cst::all CST, ducks::cgl::all CGL, ducks::coord::tile COORD=coord<CST>>
__device__ static inline void load_async(CST &dst, const CGL &src, const COORD &idx) {
    load_async<2, false, typename CST::component, typename CGL::component, COORD>(dst.real, src.real, idx);
    load_async<2, false, typename CST::component, typename CGL::component, COORD>(dst.imag, src.imag, idx);
}
@@ -0,0 +1,34 @@
/**
 * @file
 * @brief Functions for a warpgroup to collaboratively transfer data directly between shared memory and registers and back.
 */

/**
 * @brief Collaboratively load data from a shared tile into register tiles split across a warpgroup.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination register tile.
 * @param src[in] The source shared tile.
 */
template<ducks::crt::all RT, ducks::cst::all ST>
__device__ inline static void load(RT &dst, const ST &src) {
    load(dst.real, src.real);
    load(dst.imag, src.imag);
}

/**
 * @brief Collaboratively store data into a shared tile from register tiles split across a warpgroup.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination shared tile.
 * @param src[in] The source register tile.
 */
template<ducks::cst::all ST, ducks::crt::all RT>
__device__ inline static void store(ST &dst, const RT &src) {
    store(dst.real, src.real);
    store(dst.imag, src.imag);
}
@@ -0,0 +1,207 @@
/**
 * @file
 * @brief Functions for a group to collaboratively transfer data directly between global memory and registers and back.
 */

/**
 * @brief Collaboratively loads data from a source array into row-major layout tiles.
 *
 * @tparam RT The row-major layout tile type.
 * @tparam U The data type of the source array.
 * @param dst[out] The destination tile to load data into.
 * @param src[in] The source array to load data from.
 * @param row_stride[in] The stride in elements between rows in the source array.
 */
template<int axis, ducks::rt::row_layout RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<rt<typename RT::T, GROUP_WARPS*RT::rows, RT::cols, typename RT::layout>>>
__device__ inline static void load(RT &dst, const GL &src, const COORD &idx) {
    using T2 = RT::dtype;
    using U = typename GL::dtype;

#ifdef KITTENS_HOPPER
    static_assert(!std::is_same_v<T2, fp8e4m3_4> && !std::is_same_v<T2, fp8e5m2_4>, "Unsupported type for load/store");
#endif

    U *src_ptr = (U*)&src[(idx.template unit_coord<axis, 3>())];
    const int row_stride = src.template stride<axis>();
    using U2 = base_types::packing<U>::packed_type;
    int warp_laneid = threadIdx.x % WARP_THREADS;
    int local_warpid;
    if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4));
    else local_warpid = warpid();
    const int row_offset = dst.rows*local_warpid;
    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        int row = row_offset + i*dst.tile_size_row + (warp_laneid / 4);
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            int col = j*dst.tile_size_col + 2*(warp_laneid % 4);
            dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(*(U2*)(&src_ptr[(row+0)*row_stride + (col+0)]));
            dst.tiles[i][j].data[2] = base_types::convertor<T2, U2>::convert(*(U2*)(&src_ptr[(row+0)*row_stride + (col+8)]));
        }
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            int col = j*dst.tile_size_col + 2*(warp_laneid % 4);
            dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(*(U2*)(&src_ptr[(row+8)*row_stride + (col+0)]));
            dst.tiles[i][j].data[3] = base_types::convertor<T2, U2>::convert(*(U2*)(&src_ptr[(row+8)*row_stride + (col+8)]));
        }
    }
}
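// The (warp_laneid/4, 2*(warp_laneid%4)) indexing above mirrors the 16x16 mma fragment layout:
// e.g. lane 5 owns rows {1, 9} and the packed column pairs {2,3} and {10,11} of each 16x16 tile.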
/**
 * @brief Collaboratively loads data from a source array into column-major layout tiles.
 *
 * @tparam RT The column-major layout tile type.
 * @tparam U The data type of the source array.
 * @param dst[out] The destination tile to load data into.
 * @param src[in] The source array to load data from.
 * @param row_stride[in] The stride in elements between rows in the source array.
 */
template<int axis, ducks::rt::col_layout RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<rt<typename RT::T, GROUP_WARPS*RT::rows, RT::cols, typename RT::layout>>>
__device__ inline static void load(RT &dst, const GL &src, const COORD &idx) {
    using T = typename RT::T;
    using U = typename GL::dtype;

#ifdef KITTENS_HOPPER
    static_assert(!std::is_same_v<T, fp8e4m3> && !std::is_same_v<T, fp8e5m2>, "Unsupported type for load/store");
#endif

    U *src_ptr = (U*)&src[(idx.template unit_coord<axis, 3>())];
    const int row_stride = src.template stride<axis>();
    int warp_laneid = threadIdx.x % WARP_THREADS;
    int local_warpid;
    if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4));
    else local_warpid = warpid();
    const int row_offset = dst.rows*local_warpid;
    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        int row = row_offset + i*dst.tile_size_row + 2*(warp_laneid % 4);
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            int col = j*dst.tile_size_col + (warp_laneid / 4);
            dst.tiles[i][j].data[0].x = base_types::convertor<T, U>::convert(src_ptr[(row+0)*row_stride + (col+0)]);
            dst.tiles[i][j].data[1].x = base_types::convertor<T, U>::convert(src_ptr[(row+0)*row_stride + (col+8)]);
        }
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            int col = j*dst.tile_size_col + (warp_laneid / 4);
            dst.tiles[i][j].data[0].y = base_types::convertor<T, U>::convert(src_ptr[(row+1)*row_stride + (col+0)]);
            dst.tiles[i][j].data[1].y = base_types::convertor<T, U>::convert(src_ptr[(row+1)*row_stride + (col+8)]);
        }
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            int col = j*dst.tile_size_col + (warp_laneid / 4);
            dst.tiles[i][j].data[2].x = base_types::convertor<T, U>::convert(src_ptr[(row+8)*row_stride + (col+0)]);
            dst.tiles[i][j].data[3].x = base_types::convertor<T, U>::convert(src_ptr[(row+8)*row_stride + (col+8)]);
        }
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            int col = j*dst.tile_size_col + (warp_laneid / 4);
            dst.tiles[i][j].data[2].y = base_types::convertor<T, U>::convert(src_ptr[(row+9)*row_stride + (col+0)]);
            dst.tiles[i][j].data[3].y = base_types::convertor<T, U>::convert(src_ptr[(row+9)*row_stride + (col+8)]);
        }
    }
}
template<ducks::rt::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<rt<typename RT::T, GROUP_WARPS*RT::rows, RT::cols, typename RT::layout>>>
__device__ inline static void load(RT &dst, const GL &src, const COORD &idx) {
    load<2>(dst, src, idx);
}
/**
 * @brief Collaboratively stores data from register tiles to a destination array in global memory with a row-major layout.
 *
 * @tparam RT The register tile type with a row-major layout.
 * @tparam U The data type of the destination array.
 * @param[out] dst The destination array in global memory to store data into.
 * @param[in] src The source register tile to store data from.
 * @param row_stride[in] The stride in elements between rows in the destination array.
 */
template<int axis, ducks::rt::row_layout RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<rt<typename RT::T, GROUP_WARPS*RT::rows, RT::cols, typename RT::layout>>>
__device__ inline static void store(const GL &dst, const RT &src, const COORD &idx) {
    using T2 = RT::dtype;
    using U = typename GL::dtype;

#ifdef KITTENS_HOPPER
    static_assert(!std::is_same_v<T2, fp8e4m3_4> && !std::is_same_v<T2, fp8e5m2_4>, "Unsupported type for load/store");
#endif

    U *dst_ptr = (U*)&dst[(idx.template unit_coord<axis, 3>())];
    const int row_stride = dst.template stride<axis>();
    using U2 = base_types::packing<U>::packed_type;
    int warp_laneid = threadIdx.x % WARP_THREADS;
    int local_warpid;
    if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4));
    else local_warpid = warpid();
    const int row_offset = src.rows*local_warpid;
    #pragma unroll
    for(int i = 0; i < src.height; i++) {
        int row = row_offset + i*src.tile_size_row + (warp_laneid / 4);
        #pragma unroll
        for(int j = 0; j < src.width; j++) {
            int col = j*src.tile_size_col + 2*(warp_laneid % 4);
            *(U2*)(&dst_ptr[(row+0)*row_stride + (col+0)]) = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[0]);
            *(U2*)(&dst_ptr[(row+0)*row_stride + (col+8)]) = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[2]);
        }
        #pragma unroll
        for(int j = 0; j < src.width; j++) {
            int col = j*src.tile_size_col + 2*(warp_laneid % 4);
            *(U2*)(&dst_ptr[(row+8)*row_stride + (col+0)]) = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[1]);
            *(U2*)(&dst_ptr[(row+8)*row_stride + (col+8)]) = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[3]);
        }
    }
}
/**
 * @brief Collaboratively stores data from register tiles to a destination array in global memory with a column-major layout.
 *
 * @tparam RT The register tile type with a column-major layout.
 * @tparam U The data type of the destination array.
 * @param[out] dst The destination array in global memory to store data into.
 * @param[in] src The source register tile to store data from.
 * @param row_stride[in] The stride in elements between rows in the destination array.
 */
template<int axis, ducks::rt::col_layout RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<rt<typename RT::T, GROUP_WARPS*RT::rows, RT::cols, typename RT::layout>>>
__device__ inline static void store(const GL &dst, const RT &src, const COORD &idx) {
    using T = base_types::packing<typename RT::dtype>::unpacked_type;
    using U = typename GL::dtype;

#ifdef KITTENS_HOPPER
    static_assert(!std::is_same_v<T, fp8e4m3_4> && !std::is_same_v<T, fp8e5m2_4>, "Unsupported type for load/store");
#endif

    U *dst_ptr = (U*)&dst[(idx.template unit_coord<axis, 3>())];
    const int row_stride = dst.template stride<axis>();
    int warp_laneid = threadIdx.x % WARP_THREADS;
    int local_warpid;
    if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4));
    else local_warpid = warpid();
    const int row_offset = src.rows*local_warpid;
    #pragma unroll
    for(int i = 0; i < src.height; i++) {
        int row = row_offset + i*src.tile_size_row + 2*(warp_laneid % 4);
        #pragma unroll
        for(int j = 0; j < src.width; j++) {
            int col = j*src.tile_size_col + (warp_laneid / 4);
            dst_ptr[(row+0)*row_stride + (col+0)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[0].x);
            dst_ptr[(row+0)*row_stride + (col+8)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[1].x);
        }
        #pragma unroll
        for(int j = 0; j < src.width; j++) {
            int col = j*src.tile_size_col + (warp_laneid / 4);
            dst_ptr[(row+1)*row_stride + (col+0)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[0].y);
            dst_ptr[(row+1)*row_stride + (col+8)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[1].y);
        }
        #pragma unroll
        for(int j = 0; j < src.width; j++) {
            int col = j*src.tile_size_col + (warp_laneid / 4);
            dst_ptr[(row+8)*row_stride + (col+0)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[2].x);
            dst_ptr[(row+8)*row_stride + (col+8)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[3].x);
        }
        #pragma unroll
        for(int j = 0; j < src.width; j++) {
            int col = j*src.tile_size_col + (warp_laneid / 4);
            dst_ptr[(row+9)*row_stride + (col+0)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[2].y);
            dst_ptr[(row+9)*row_stride + (col+8)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data[3].y);
        }
    }
}
template<ducks::rt::all RT, ducks::gl::all GL, ducks::coord::tile COORD=coord<rt<typename RT::T, GROUP_WARPS*RT::rows, RT::cols, typename RT::layout>>>
__device__ inline static void store(const GL &dst, const RT &src, const COORD &idx) {
    store<2>(dst, src, idx);
}
@@ -0,0 +1,168 @@
/**
 * @file
 * @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory.
 */

/**
 * @brief Loads data from global memory into a shared memory tile.
 *
 * @tparam ST The type of the shared tile.
 * @param[out] dst The destination shared memory tile.
 * @param[in] src The source global memory array.
 * @param[in] idx The coordinate of the tile in the global memory array.
 */
template<int axis, bool assume_aligned, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load(ST &dst, const GL &src, const COORD &idx) {
    using T = typename ST::dtype;
    const int row_stride = src.template stride<axis>();
    // we can handle this many rows each time we run a memcpy_async
    constexpr int elem_per_memcpy = sizeof(float4)/sizeof(typename ST::dtype);
    constexpr int memcpy_per_row = dst.cols / elem_per_memcpy;
    constexpr int total_calls = (dst.height*dst.width * kittens::TILE_ROW_DIM<T>*kittens::TILE_COL_DIM<T> + GROUP_THREADS*elem_per_memcpy-1) / (GROUP_THREADS*elem_per_memcpy); // round up
    constexpr int total_rows = dst.height*dst.width;

    coord<> unit_coord = idx.template unit_coord<axis, 3>();
    typename GL::dtype *src_ptr = (typename GL::dtype*)&src[unit_coord];
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst.data[0]));
    int laneid = threadIdx.x % GROUP_THREADS;

    #pragma unroll
    for(int i = 0; i < total_calls; i++) {

        int load_idx = i * GROUP_THREADS + laneid;

        int row = load_idx / memcpy_per_row;
        int col = (load_idx*elem_per_memcpy) % dst.cols;

        if constexpr (assume_aligned) {
            float4 tmp;
            move<float4>::ldg(tmp, (float4*)&src_ptr[row*row_stride + col]);
            move<float4>::sts(dst.idx(dst_ptr, {row, col}), tmp);
        }
        else {
            if (row + unit_coord.template dim<axis>() < src.template shape<axis>()) {
                float4 tmp;
                move<float4>::ldg(tmp, (float4*)&src_ptr[row*row_stride + col]);
                move<float4>::sts(dst.idx(dst_ptr, {row, col}), tmp);
            }
            else {
                float4 zeros = {0.f,0.f,0.f,0.f};
                move<float4>::sts(dst.idx(dst_ptr, {row, col}), zeros); // use the default value
            }
        }
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load(ST &dst, const GL &src, const COORD &idx) {
    load<2, false, ST, GL, COORD>(dst, src, idx);
}

/**
 * @brief Stores data from a shared memory tile into global memory.
 *
 * @tparam ST The type of the shared tile.
 * @param[out] dst The destination global memory array.
 * @param[in] src The source shared memory tile.
 * @param row_stride[in] The stride between rows in the destination array.
 */
template<int axis, bool assume_aligned, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
    using T = typename ST::dtype;
    const int row_stride = dst.template stride<axis>();
    // we can handle this many rows each time we run a memcpy_async
    constexpr int elem_per_memcpy = sizeof(float4)/sizeof(typename ST::dtype);
    constexpr int memcpy_per_row = src.cols / elem_per_memcpy;
    constexpr int total_calls = (src.height*src.width * kittens::TILE_ROW_DIM<T>*kittens::TILE_COL_DIM<T> + GROUP_THREADS*elem_per_memcpy-1) / (GROUP_THREADS*elem_per_memcpy); // round up

    coord<> unit_coord = idx.template unit_coord<axis, 3>();
    typename GL::dtype *dst_ptr = (typename GL::dtype*)&dst[unit_coord];
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src.data[0]));
    int laneid = threadIdx.x % GROUP_THREADS;

    #pragma unroll
    for(int i = 0; i < total_calls; i++) {

        int load_idx = i * GROUP_THREADS + laneid;

        int row = load_idx / memcpy_per_row;
        int col = (load_idx*elem_per_memcpy) % src.cols;

        if constexpr (assume_aligned) {
            float4 tmp;
            move<float4>::lds(tmp, src.idx(src_ptr, {row, col}));
            move<float4>::stg((float4*)&dst_ptr[row*row_stride + col], tmp);
        }
        else {
            if (row + unit_coord.template dim<axis>() < dst.template shape<axis>()) {
                float4 tmp;
                move<float4>::lds(tmp, src.idx(src_ptr, {row, col}));
                move<float4>::stg((float4*)&dst_ptr[row*row_stride + col], tmp);
            }
        }
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store(const GL &dst, const ST &src, const COORD &idx) {
    store<2, false, ST, GL, COORD>(dst, src, idx);
}

/**
 * @brief Asynchronously loads data from global memory into a shared memory tile.
 *
 * @tparam ST The type of the shared tile.
 * @param[out] dst The destination shared memory tile.
 * @param[in] src The source global memory array.
 *
 * @note This function expects 16-byte alignments. Otherwise, behavior is undefined.
 */
template<int axis, bool assume_aligned, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx) {
    using T = typename ST::dtype;
    const int row_stride = src.template stride<axis>();
    // we can handle this many rows each time we run a memcpy_async
    constexpr int elem_per_memcpy = sizeof(float4)/sizeof(typename ST::dtype);
    constexpr int memcpy_per_row = dst.cols / elem_per_memcpy;
    constexpr int total_calls = (dst.height*dst.width * kittens::TILE_ROW_DIM<T>*kittens::TILE_COL_DIM<T> + GROUP_THREADS*elem_per_memcpy-1) / (GROUP_THREADS*elem_per_memcpy); // round up

    coord<> unit_coord = idx.template unit_coord<axis, 3>();
    typename GL::dtype *src_ptr = (typename GL::dtype*)&src[unit_coord];
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst.data[0]));
    int laneid = threadIdx.x % GROUP_THREADS;

    #pragma unroll
    for(int i = 0; i < total_calls; i++) {

        int load_idx = i * GROUP_THREADS + laneid;

        int row = load_idx / memcpy_per_row;
        int col = (load_idx*elem_per_memcpy) % dst.cols;

        if constexpr (assume_aligned) {
            asm volatile(
                "cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\n"
                :: "r"(dst.idx(dst_ptr, {row, col})), "l"(&src_ptr[row*row_stride + col])
                : "memory"
            );
        }
        else {
            if (row + unit_coord.template dim<axis>() < src.template shape<axis>()) {
                asm volatile(
                    "cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\n"
                    :: "r"(dst.idx(dst_ptr, {row, col})), "l"(&src_ptr[row*row_stride + col])
                    : "memory"
                );
            }
            else {
                // printf("thread %d skipping async load on row %d, col %d\n", threadIdx.x, row + unit_coord.template dim<axis>(), col);
                float4 zeros = {0.f,0.f,0.f,0.f};
                move<float4>::sts(dst.idx(dst_ptr, {row, col}), zeros); // use the default value
            }
        }
    }
    asm volatile("cp.async.commit_group;\n" ::: "memory");
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx) {
    load_async<2, false, ST, GL, COORD>(dst, src, idx);
}
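// Note: load_async only issues cp.async.commit_group; callers are expected to wait on the
// group (e.g. "cp.async.wait_all;" inline PTX, or a library wait helper) before reading dst.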
@@ -0,0 +1,323 @@
/**
 * @file
 * @brief Functions for a warpgroup to collaboratively transfer data directly between shared memory and registers and back.
 */

/**
 * @brief Collaboratively load data from a shared tile into register tiles split across a warpgroup.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination register tile.
 * @param src[in] The source shared tile.
 */
template<ducks::rt::all RT, ducks::st::all ST>
__device__ inline static void load(RT &dst, const ST &src) {
    constexpr int height = ST::height;
    constexpr int warp_height = RT::height;
    static_assert(height%GROUP_WARPS == 0, "Group load / store requires tile height to be a multiple of GROUP_WARPS.");
    static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height.");
    static_assert(ST::width==RT::width, "Group load / store requires tile widths to match.");
    int local_warpid;
    if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4));
    else local_warpid = warpid();
    using T2 = RT::dtype;
    using U = ST::dtype;
    using T = base_types::packing<T2>::unpacked_type;
    using U2 = base_types::packing<U>::packed_type;
    int warp_laneid = ::kittens::laneid();

    // convert to shared state space
    uint32_t shared_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&src.data[0]));

    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            if constexpr (sizeof(typename ST::dtype) == 2) {
                // handle the row-major layout for 16-bit types
                U2 tmp[4];
                int row = (local_warpid*warp_height + i)*dst.tile_size_row + (warp_laneid % 16);
                int col = j*dst.tile_size_col + (warp_laneid / 16) * 8;
                if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row>) {
                    move<U2>::ldsm4(tmp[0], tmp[1], tmp[2], tmp[3], src.idx(shared_addr, {row, col}));
                }
                else {
                    move<U2>::ldsm4t(tmp[0], tmp[2], tmp[1], tmp[3], src.idx(shared_addr, {row, col}));
                }
                dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(tmp[0]);
                dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(tmp[1]);
                dst.tiles[i][j].data[2] = base_types::convertor<T2, U2>::convert(tmp[2]);
                dst.tiles[i][j].data[3] = base_types::convertor<T2, U2>::convert(tmp[3]);
            }
            else if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row> && sizeof(typename ST::dtype) == 1) {
                // handle the row-major layout for 8-bit types
                int warp_group_16 = (warp_laneid / 16); // divide each warp into two groups of 16 threads
                int lane_in_16 = warp_laneid % 16; // position in group of 16 threads
                int row = (local_warpid*warp_height + i)*dst.tile_size_row + (lane_in_16 % 16); // find base row for warp in warpgroup and then distribute the 16 threads in the warp across the rows
                int col = j*dst.tile_size_col + warp_group_16 * 16; // find base column and then *16 for second half of the warp

                U2 tmp[4];
                if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row>) {
                    move<U2>::ldsm4(tmp[0], tmp[1], tmp[2], tmp[3], src.idx(shared_addr, {row, col}));
                }
                else {
                    move<U2>::ldsm4t(tmp[0], tmp[2], tmp[1], tmp[3], src.idx(shared_addr, {row, col}));
                }
                dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(tmp[0]);
                dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(tmp[1]);
                dst.tiles[i][j].data[2] = base_types::convertor<T2, U2>::convert(tmp[2]);
                dst.tiles[i][j].data[3] = base_types::convertor<T2, U2>::convert(tmp[3]);
            }
            else if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row> && sizeof(typename ST::dtype) == 4) {
                // handle the row-major layout for 32-bit types
                int row = (local_warpid*warp_height + i)*dst.tile_size_row + (warp_laneid / 4);
                int col = j*dst.tile_size_col + 2*(warp_laneid % 4);
                if constexpr (ST::rows != ST::underlying_rows || ST::cols != ST::underlying_cols) { // subtile case
                    row += src.row_offset;
                    col += src.col_offset;
                }
                int blit = sizeof(typename ST::dtype) * ((warp_laneid%4) / 2);
                U2 tmp[4];
                static constexpr int swizzle_repeat = ST::swizzle_bytes * 8;
                static constexpr int subtile_cols = ST::swizzle_bytes / sizeof(U);
                const int outer_idx = col/subtile_cols;
                const uint32_t addr_1 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+0)*subtile_cols + col%subtile_cols);
                const uint32_t addr_2 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+8)*subtile_cols + col%subtile_cols);
                const int swizzle_1 = blit ^ ((addr_1 % swizzle_repeat) >> 7) << 4;
                const int swizzle_2 = blit ^ ((addr_2 % swizzle_repeat) >> 7) << 4;
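                // The XOR below applies the shared-memory bank swizzle: the row bits of the
                // address within each swizzle_repeat window (shifted from bit 7 down to bit 4)
                // permute 16-byte chunks so successive rows map to distinct bank groups.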
                move<U>::lds(tmp[0].x, (addr_1+ 0)^swizzle_1);
                move<U>::lds(tmp[0].y, (addr_1+ 4)^swizzle_1);
                move<U>::lds(tmp[2].x, (addr_1+32)^swizzle_1);
                move<U>::lds(tmp[2].y, (addr_1+36)^swizzle_1);
                move<U>::lds(tmp[1].x, (addr_2+ 0)^swizzle_2);
                move<U>::lds(tmp[1].y, (addr_2+ 4)^swizzle_2);
                move<U>::lds(tmp[3].x, (addr_2+32)^swizzle_2);
                move<U>::lds(tmp[3].y, (addr_2+36)^swizzle_2);
                dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(tmp[0]);
                dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(tmp[1]);
                dst.tiles[i][j].data[2] = base_types::convertor<T2, U2>::convert(tmp[2]);
                dst.tiles[i][j].data[3] = base_types::convertor<T2, U2>::convert(tmp[3]);
                if(blit) {
                    #pragma unroll
                    for(int k = 0; k < 4; k++) {
                        dst.tiles[i][j].data[k] = T2{dst.tiles[i][j].data[k].y, dst.tiles[i][j].data[k].x};
                    }
                }
            }
            else {
                // handle the column-major layout
                int row = (local_warpid*warp_height + i)*dst.tile_size_row + 2*(warp_laneid % 4);
                int col = j*dst.tile_size_col + (warp_laneid / 4);
                U2 tmp[4];
                move<U>::lds(tmp[0].x, src.idx(shared_addr, {row+0, col+0}));
                move<U>::lds(tmp[0].y, src.idx(shared_addr, {row+1, col+0}));
                move<U>::lds(tmp[1].x, src.idx(shared_addr, {row+0, col+8}));
                move<U>::lds(tmp[1].y, src.idx(shared_addr, {row+1, col+8}));
                move<U>::lds(tmp[2].x, src.idx(shared_addr, {row+8, col+0}));
                move<U>::lds(tmp[2].y, src.idx(shared_addr, {row+9, col+0}));
                move<U>::lds(tmp[3].x, src.idx(shared_addr, {row+8, col+8}));
                move<U>::lds(tmp[3].y, src.idx(shared_addr, {row+9, col+8}));
                dst.tiles[i][j].data[0] = base_types::convertor<T2, U2>::convert(tmp[0]);
                dst.tiles[i][j].data[1] = base_types::convertor<T2, U2>::convert(tmp[1]);
                dst.tiles[i][j].data[2] = base_types::convertor<T2, U2>::convert(tmp[2]);
                dst.tiles[i][j].data[3] = base_types::convertor<T2, U2>::convert(tmp[3]);
            }
        }
    }
}

/**
 * @brief Collaboratively store data into a shared tile from register tiles split across a warpgroup.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination shared tile.
 * @param src[in] The source register tile.
 */
template<ducks::st::all ST, ducks::rt::all RT>
__device__ inline static void store(ST &dst, const RT &src) {
    constexpr int height = ST::height;
    constexpr int warp_height = RT::height;
    static_assert(height%GROUP_WARPS == 0, "Group load / store requires tile height to be a multiple of GROUP_WARPS.");
    static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height.");
    static_assert(ST::width==RT::width, "Group load / store requires tile widths to match.");
    int local_warpid;
    if constexpr(GROUP_WARPS % 4 == 0) local_warpid = (warpid()/4+(warpid()%4)*(GROUP_WARPS/4));
    else local_warpid = warpid();
    using T2 = RT::dtype;
    using U = ST::dtype;
    using T = base_types::packing<T2>::unpacked_type;
    using U2 = base_types::packing<U>::packed_type;
    int warp_laneid = ::kittens::laneid();

    // convert to shared state space
    uint32_t shared_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst.data[0]));

    #pragma unroll
    for(int i = 0; i < warp_height; i++) {
        #pragma unroll
        for(int j = 0; j < src.width; j++) {
            if constexpr (sizeof(typename ST::dtype) == 2) {
                // handle the row-major layout
                U2 tmp[4];
                tmp[0] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[0]);
                tmp[1] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[1]);
                tmp[2] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[2]);
                tmp[3] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[3]);
#ifdef KITTENS_HOPPER
                int row = (local_warpid*warp_height + i)*src.tile_size_row + (warp_laneid % 16);
                int col = j*src.tile_size_col + (warp_laneid / 16) * 8;
                if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row>) {
                    move<U2>::stsm4(dst.idx(shared_addr, {row, col}), tmp[0], tmp[1], tmp[2], tmp[3]);
                }
                else {
                    move<U2>::stsm4t(dst.idx(shared_addr, {row, col}), tmp[0], tmp[2], tmp[1], tmp[3]);
                }
#else
                if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row>) {
                    int row = (local_warpid*warp_height + i)*src.tile_size_row + (warp_laneid / 4);
                    int col = j*src.tile_size_col + 2*(warp_laneid % 4);
                    move<U2>::sts(dst.idx(shared_addr, {row+0, col+0}), tmp[0]);
                    move<U2>::sts(dst.idx(shared_addr, {row+8, col+0}), tmp[1]);
                    move<U2>::sts(dst.idx(shared_addr, {row+0, col+8}), tmp[2]);
                    move<U2>::sts(dst.idx(shared_addr, {row+8, col+8}), tmp[3]);
                }
                else {
                    int row = (local_warpid*warp_height + i)*src.tile_size_row + 2*(warp_laneid % 4);
                    int col = j*src.tile_size_col + (warp_laneid / 4);
                    move<U>::sts(dst.idx(shared_addr, {row+0, col+0}), tmp[0].x);
                    move<U>::sts(dst.idx(shared_addr, {row+1, col+0}), tmp[0].y);
                    move<U>::sts(dst.idx(shared_addr, {row+0, col+8}), tmp[1].x);
                    move<U>::sts(dst.idx(shared_addr, {row+1, col+8}), tmp[1].y);
                    move<U>::sts(dst.idx(shared_addr, {row+8, col+0}), tmp[2].x);
                    move<U>::sts(dst.idx(shared_addr, {row+9, col+0}), tmp[2].y);
                    move<U>::sts(dst.idx(shared_addr, {row+8, col+8}), tmp[3].x);
                    move<U>::sts(dst.idx(shared_addr, {row+9, col+8}), tmp[3].y);
                }
#endif
            }
            else if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row> && sizeof(typename ST::dtype) == 1) {
                // handle the row-major layout for 8-bit types

                int warp_group_16 = (warp_laneid / 16); // divide each warp into two groups of 16 threads
                int lane_in_16 = warp_laneid % 16; // position in group of 16 threads
                int row = (local_warpid*warp_height + i)*src.tile_size_row + (lane_in_16 % 16); // find base row for warp in warpgroup and then distribute the 16 threads in the warp across the rows
                int col = j*src.tile_size_col + warp_group_16 * 16; // find base column and then *16 for second half of the warp

                U2 tmp[4];
                tmp[0] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[0]);
                tmp[1] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[1]);
                tmp[2] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[2]);
                tmp[3] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[3]);
                if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row>) {
                    move<U2>::stsm4(dst.idx(shared_addr, {row, col}), tmp[0], tmp[1], tmp[2], tmp[3]);
                }
                else {
                    move<U2>::stsm4t(dst.idx(shared_addr, {row, col}), tmp[0], tmp[2], tmp[1], tmp[3]);
                }
            }
            else if constexpr (std::is_same_v<typename RT::layout, ducks::rt_layout::row> && sizeof(typename ST::dtype) == 4) {
                // handle the row-major layout for 32-bit types
                int row = (local_warpid*warp_height + i)*src.tile_size_row + (warp_laneid / 4);
                int col = j*src.tile_size_col + 2*(warp_laneid % 4);
                if constexpr (ST::rows != ST::underlying_rows || ST::cols != ST::underlying_cols) { // subtile case
                    row += dst.row_offset;
                    col += dst.col_offset;
                }
                int blit = sizeof(typename ST::dtype) * ((warp_laneid%4) / 2);
                T2 reg_tmp[4];
                if(blit) {
                    #pragma unroll
                    for(int k = 0; k < 4; k++) {
                        reg_tmp[k] = T2{src.tiles[i][j].data[k].y, src.tiles[i][j].data[k].x};
                    }
                }
                else {
                    #pragma unroll
                    for(int k = 0; k < 4; k++) {
                        reg_tmp[k] = src.tiles[i][j].data[k];
                    }
                }
                U2 tmp[4];
                tmp[0] = base_types::convertor<U2, T2>::convert(reg_tmp[0]);
                tmp[1] = base_types::convertor<U2, T2>::convert(reg_tmp[1]);
                tmp[2] = base_types::convertor<U2, T2>::convert(reg_tmp[2]);
                tmp[3] = base_types::convertor<U2, T2>::convert(reg_tmp[3]);
                static constexpr int swizzle_repeat = ST::swizzle_bytes * 8;
                static constexpr int subtile_cols = ST::swizzle_bytes / sizeof(U);
                const int outer_idx = col/subtile_cols;
                const uint32_t addr_1 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+0)*subtile_cols + col%subtile_cols);
                const uint32_t addr_2 = shared_addr + sizeof(U)*(outer_idx*ST::underlying_rows*subtile_cols + (row+8)*subtile_cols + col%subtile_cols);
                const int swizzle_1 = blit ^ ((addr_1 % swizzle_repeat) >> 7) << 4;
                const int swizzle_2 = blit ^ ((addr_2 % swizzle_repeat) >> 7) << 4;
                move<U>::sts((addr_1+ 0)^swizzle_1, tmp[0].x);
                move<U>::sts((addr_1+ 4)^swizzle_1, tmp[0].y);
                move<U>::sts((addr_1+32)^swizzle_1, tmp[2].x);
                move<U>::sts((addr_1+36)^swizzle_1, tmp[2].y);
                move<U>::sts((addr_2+ 0)^swizzle_2, tmp[1].x);
                move<U>::sts((addr_2+ 4)^swizzle_2, tmp[1].y);
                move<U>::sts((addr_2+32)^swizzle_2, tmp[3].x);
                move<U>::sts((addr_2+36)^swizzle_2, tmp[3].y);
            }
            else {
                // handle the column-major layout
                int row = (local_warpid*warp_height + i)*src.tile_size_row + 2*(warp_laneid % 4);
                int col = j*src.tile_size_col + (warp_laneid / 4);
                U2 tmp[4];
                tmp[0] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[0]);
                tmp[1] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[1]);
                tmp[2] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[2]);
                tmp[3] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[3]);
                move<U>::sts(dst.idx(shared_addr, {row+0, col+0}), tmp[0].x);
                move<U>::sts(dst.idx(shared_addr, {row+1, col+0}), tmp[0].y);
                move<U>::sts(dst.idx(shared_addr, {row+0, col+8}), tmp[1].x);
                move<U>::sts(dst.idx(shared_addr, {row+1, col+8}), tmp[1].y);
                move<U>::sts(dst.idx(shared_addr, {row+8, col+0}), tmp[2].x);
                move<U>::sts(dst.idx(shared_addr, {row+9, col+0}), tmp[2].y);
                move<U>::sts(dst.idx(shared_addr, {row+8, col+8}), tmp[3].x);
                move<U>::sts(dst.idx(shared_addr, {row+9, col+8}), tmp[3].y);
            }
        }
    }
}

// Load and store of vectors from/to shared tiles.

template<ducks::rv::naive_layout RV, ducks::st::all ST>
__device__ inline static auto load(RV &dst, const ST &src, int2 row_col) {
    KITTENS_CHECK_WARP;
    static_assert(ST::cols>=RV::length, "Shared tile must be at least as wide as the vector.");
    using T = RV::T;
    using U = ST::T;
    int warp_laneid = ::kittens::laneid();

    // convert to shared state space
    uint32_t shared_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&src.data[0]));

    #pragma unroll
    for(int col = warp_laneid; col < dst.length; col+=WARP_THREADS) {
        U tmp;
        move<U>::lds(tmp, src.idx(shared_addr, {row_col.x, row_col.y + col}));
        dst.data[col/WARP_THREADS][0] = base_types::convertor<T, U>::convert(tmp);
    }
}

template<ducks::rv::naive_layout RV, ducks::st::all ST>
__device__ inline static auto store(ST &dst, const RV &src, int2 row_col) {
    KITTENS_CHECK_WARP;
    static_assert(ST::cols>=RV::length, "Shared tile must be at least as wide as the vector.");
    using T = RV::T;
    using U = ST::T;
    int warp_laneid = ::kittens::laneid();

    // convert to shared state space
    uint32_t shared_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst.data[0]));

    #pragma unroll
    for(int col = warp_laneid; col < src.length; col+=WARP_THREADS) {
        U tmp = base_types::convertor<U, T>::convert(src.data[col/WARP_THREADS][0]);
        move<U>::sts(dst.idx(shared_addr, {row_col.x, row_col.y + col}), tmp);
    }
}
@@ -0,0 +1,325 @@
/**
 * @file
 * @brief Group (collaborative warp) ops for loading tensor tiles into register tiles.
 */

/**
 * @brief Load data from a tensor tile into a register tile.
 *
 * @tparam RT The register tile type
 * @tparam TM The tensor memory tile type
 * @param dst[out] The destination register tile.
 * @param src[in] The source tensor tile.
 */
template<ducks::rt::row_layout RT, ducks::tt::all TM>
__device__ inline static void load_async(RT &dst, const TM &src) {
    if constexpr (GROUP_WARPS == 1) {
        static_assert(RT::height == TM::height, "register tile and tensor tile must match height");
        static_assert(RT::width == TM::width, "register tile and tensor tile must match width");

        using T2 = RT::dtype;
        using U = typename TM::dtype;
        using U2 = base_types::packing<typename TM::dtype>::packed_type;

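        // Tensor-memory addresses pack the lane (row) in the upper 16 bits and the column,
        // counted in 32-bit words, in the lower 16; hence the <<16 and /(4/sizeof(U)) below.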
        if constexpr (sizeof(typename TM::dtype) == 1) {
            #pragma unroll
            for(int i = 0; i < dst.height; i++) {
                #pragma unroll
                for(int j = 0; j < dst.width; j++) {
                    asm volatile(
                        "tcgen05.ld.sync.aligned.16x128b.x2.pack::16b.b32 {%0, %1, %2, %3}, [%4];\n"
                        : "=r"(*(uint32_t*) &dst.tiles[i][j].data[0]),
                          "=r"(*(uint32_t*) &dst.tiles[i][j].data[1]),
                          "=r"(*(uint32_t*) &dst.tiles[i][j].data[2]),
                          "=r"(*(uint32_t*) &dst.tiles[i][j].data[3])
                        : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U)))
                    );
                }
            }
        } else if constexpr (sizeof(typename TM::dtype) == 2) {
            #pragma unroll
            for(int i = 0; i < dst.height; i++) {
                #pragma unroll
                for(int j = 0; j < dst.width; j++) {
                    asm volatile(
                        "tcgen05.ld.sync.aligned.16x128b.x2.pack::16b.b32 {%0, %1, %2, %3}, [%4];\n"
                        : "=r"(*(uint32_t*) &dst.tiles[i][j].data[0]),
                          "=r"(*(uint32_t*) &dst.tiles[i][j].data[1]),
                          "=r"(*(uint32_t*) &dst.tiles[i][j].data[2]),
                          "=r"(*(uint32_t*) &dst.tiles[i][j].data[3])
                        : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col))
                    );
                }
            }
        }
        else if constexpr (sizeof(typename TM::dtype) == 4) {
            #pragma unroll
            for(int i = 0; i < dst.height; i++) {
                if constexpr (dst.width%4 == 0) {
                    #pragma unroll
                    for(int j = 0; j < dst.width; j+=4) {
                        U2 data[16];
                        asm volatile(
                            "tcgen05.ld.sync.aligned.16x256b.x8.b32 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, [%32];\n"
                            : "=f"(data[0].x), "=f"(data[0].y),
                              "=f"(data[1].x), "=f"(data[1].y),
                              "=f"(data[2].x), "=f"(data[2].y),
                              "=f"(data[3].x), "=f"(data[3].y),
                              "=f"(data[4].x), "=f"(data[4].y),
                              "=f"(data[5].x), "=f"(data[5].y),
                              "=f"(data[6].x), "=f"(data[6].y),
                              "=f"(data[7].x), "=f"(data[7].y),
                              "=f"(data[8].x), "=f"(data[8].y),
                              "=f"(data[9].x), "=f"(data[9].y),
                              "=f"(data[10].x), "=f"(data[10].y),
                              "=f"(data[11].x), "=f"(data[11].y),
                              "=f"(data[12].x), "=f"(data[12].y),
                              "=f"(data[13].x), "=f"(data[13].y),
                              "=f"(data[14].x), "=f"(data[14].y),
                              "=f"(data[15].x), "=f"(data[15].y)
                            : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U)))
                        );
                        #pragma unroll
                        for(int k = 0; k < 4; k++) {
                            dst.tiles[i][j+0].data[k] = base_types::convertor<T2, U2>::convert(data[k]);
                            dst.tiles[i][j+1].data[k] = base_types::convertor<T2, U2>::convert(data[k+4]);
                            dst.tiles[i][j+2].data[k] = base_types::convertor<T2, U2>::convert(data[k+8]);
                            dst.tiles[i][j+3].data[k] = base_types::convertor<T2, U2>::convert(data[k+12]);
                        }
                    }
                }
                else if constexpr (dst.width%2 == 0) {
                    #pragma unroll
                    for(int j = 0; j < dst.width; j+=2) {
                        U2 data[8];
                        asm volatile(
                            "tcgen05.ld.sync.aligned.16x256b.x4.b32 {%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, [%16];\n"
                            : "=f"(data[0].x), "=f"(data[0].y),
                              "=f"(data[1].x), "=f"(data[1].y),
                              "=f"(data[2].x), "=f"(data[2].y),
                              "=f"(data[3].x), "=f"(data[3].y),
                              "=f"(data[4].x), "=f"(data[4].y),
                              "=f"(data[5].x), "=f"(data[5].y),
                              "=f"(data[6].x), "=f"(data[6].y),
                              "=f"(data[7].x), "=f"(data[7].y)
                            : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U)))
                        );
                        #pragma unroll
                        for(int k = 0; k < 4; k++) {
                            dst.tiles[i][j+0].data[k] = base_types::convertor<T2, U2>::convert(data[k]);
                            dst.tiles[i][j+1].data[k] = base_types::convertor<T2, U2>::convert(data[k+4]);
                        }
                    }
                }
                else {
                    #pragma unroll
                    for(int j = 0; j < dst.width; j++) {
                        U2 data[4];
                        asm volatile(
                            "tcgen05.ld.sync.aligned.16x256b.x2.b32 {%0, %1, %2, %3, %4, %5, %6, %7}, [%8];\n"
                            : "=f"(data[0].x), "=f"(data[0].y),
                              "=f"(data[1].x), "=f"(data[1].y),
                              "=f"(data[2].x), "=f"(data[2].y),
                              "=f"(data[3].x), "=f"(data[3].y)
                            : "r"(src.addr + ((i * dst.tile_size_row) << 16) + (j * dst.tile_size_col)/(4/(uint32_t)sizeof(U)))
                        );
                        #pragma unroll
                        for(int k = 0; k < 4; k++) {
                            dst.tiles[i][j].data[k] = base_types::convertor<T2, U2>::convert(data[k]);
                        }
                    }
                }
            }
        }
    }
    else {
        static_assert(GROUP_WARPS==4 || GROUP_WARPS==8);
        constexpr int warp_rows = TM::rows/GROUP_WARPS;
        static_assert(TM::cols==RT::cols);
        static_assert(warp_rows==RT::rows);
        if constexpr (GROUP_WARPS == 4) {
            auto src_subtile = src.template subtile<tt<typename TM::dtype, warp_rows, TM::cols>>(32*warpid(), 0);
            ::kittens::group<1>::load_async(dst, src_subtile);
        }
        else {
            auto src_subtile = src.template subtile<tt<typename TM::dtype, warp_rows, TM::cols>>(32*(warpid()%4)+16*(warpid()/4), 0);
            ::kittens::group<1>::load_async(dst, src_subtile);
        }
    }
}
|
||||
|
||||
/**
|
||||
* @brief Store data into a tensor tile from a register tile.
|
||||
*
|
||||
* @tparam RT The register tile type
|
||||
* @tparam TM The tensor memory tile type
|
||||
* @param dst[out] The destination tensor tile.
|
||||
* @param src[in] The source register tile.
|
||||
*/
|
||||
template<ducks::rt::all RT, ducks::tt::all TM>
|
||||
__device__ inline static void store_async(TM &dst, const RT &src) {
|
||||
if constexpr (GROUP_WARPS == 1) {
|
||||
static_assert(RT::height == TM::height, "register tile and tensor tile must match height");
|
||||
static_assert(RT::width == TM::width, "register tile and tensor tile must match width");
|
||||
|
||||
using T2 = RT::dtype;
|
||||
using T = base_types::packing<T2>::unpacked_type;
|
||||
using U = TM::dtype;
|
||||
using U2 = base_types::packing<U>::packed_type;
|
||||
|
||||
if constexpr (sizeof(typename TM::dtype) == 2) {
|
||||
#pragma unroll
|
||||
for(int i = 0; i < src.height; i++) {
|
||||
if constexpr (src.width%4 == 0) {
|
||||
#pragma unroll
|
||||
for(int j = 0; j < src.width; j+=4) {
|
||||
asm volatile(
|
||||
"tcgen05.st.sync.aligned.16x128b.x8.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16};\n"
|
||||
:: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[0]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[1]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[2]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[3]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[0]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[1]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[2]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[3]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+2].data[0]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+2].data[1]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+2].data[2]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+2].data[3]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+3].data[0]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+3].data[1]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+3].data[2]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+3].data[3])
|
||||
);
|
||||
}
|
||||
}
|
||||
else if constexpr (src.width%2 == 0) {
|
||||
#pragma unroll
|
||||
for(int j = 0; j < src.width; j+=2) {
|
||||
asm volatile(
|
||||
"tcgen05.st.sync.aligned.16x128b.x4.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n"
|
||||
:: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[0]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[1]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[2]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+0].data[3]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[0]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[1]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[2]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j+1].data[3])
|
||||
);
|
||||
}
|
||||
}
|
||||
else {
|
||||
#pragma unroll
|
||||
for(int j = 0; j < src.width; j++) {
|
||||
asm volatile(
|
||||
"tcgen05.st.sync.aligned.16x128b.x2.b32 [%0], {%1, %2, %3, %4};\n"
|
||||
:: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j].data[0]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j].data[1]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j].data[2]),
|
||||
"r"(*(uint32_t*)&src.tiles[i][j].data[3])
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else if constexpr (sizeof(typename TM::dtype) == 4) {
|
||||
#pragma unroll
|
||||
for(int i = 0; i < src.height; i++) {
|
||||
if constexpr(src.width%4 == 0) {
|
||||
#pragma unroll
|
||||
for(int j = 0; j < src.width; j+=4) {
|
||||
U2 data[16];
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k++) {
|
||||
data[k] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[k]);
|
||||
data[k+4] = base_types::convertor<U2, T2>::convert(src.tiles[i][j+1].data[k]);
|
||||
data[k+8] = base_types::convertor<U2, T2>::convert(src.tiles[i][j+2].data[k]);
|
||||
data[k+12] = base_types::convertor<U2, T2>::convert(src.tiles[i][j+3].data[k]);
|
||||
}
|
||||
asm volatile(
|
||||
"tcgen05.st.sync.aligned.16x256b.x8.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32};\n"
|
||||
:: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))),
|
||||
"f"(data[0].x), "f"(data[0].y),
|
||||
"f"(data[1].x), "f"(data[1].y),
|
||||
"f"(data[2].x), "f"(data[2].y),
|
||||
"f"(data[3].x), "f"(data[3].y),
|
||||
"f"(data[4].x), "f"(data[4].y),
|
||||
"f"(data[5].x), "f"(data[5].y),
|
||||
"f"(data[6].x), "f"(data[6].y),
|
||||
"f"(data[7].x), "f"(data[7].y),
|
||||
"f"(data[8].x), "f"(data[8].y),
|
||||
"f"(data[9].x), "f"(data[9].y),
|
||||
"f"(data[10].x), "f"(data[10].y),
|
||||
"f"(data[11].x), "f"(data[11].y),
|
||||
"f"(data[12].x), "f"(data[12].y),
|
||||
"f"(data[13].x), "f"(data[13].y),
|
||||
"f"(data[14].x), "f"(data[14].y),
|
||||
"f"(data[15].x), "f"(data[15].y)
|
||||
);
|
||||
}
|
||||
}
|
||||
else if constexpr(src.width%2 == 0) {
|
||||
#pragma unroll
|
||||
for(int j = 0; j < src.width; j+=2) {
|
||||
U2 data[8];
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k++) {
|
||||
data[k] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[k]);
|
||||
data[k+4] = base_types::convertor<U2, T2>::convert(src.tiles[i][j+1].data[k]);
|
||||
}
|
||||
asm volatile(
|
||||
"tcgen05.st.sync.aligned.16x256b.x4.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16};\n"
|
||||
:: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))),
|
||||
"f"(data[0].x), "f"(data[0].y),
|
||||
"f"(data[1].x), "f"(data[1].y),
|
||||
"f"(data[2].x), "f"(data[2].y),
|
||||
"f"(data[3].x), "f"(data[3].y),
|
||||
"f"(data[4].x), "f"(data[4].y),
|
||||
"f"(data[5].x), "f"(data[5].y),
|
||||
"f"(data[6].x), "f"(data[6].y),
|
||||
"f"(data[7].x), "f"(data[7].y)
|
||||
);
|
||||
}
|
||||
}
|
||||
else {
|
||||
#pragma unroll
|
||||
for(int j = 0; j < src.width; j++) {
|
||||
U2 data[4];
|
||||
#pragma unroll
|
||||
for(int k = 0; k < 4; k++) {
|
||||
data[k] = base_types::convertor<U2, T2>::convert(src.tiles[i][j].data[k]);
|
||||
}
|
||||
asm volatile(
|
||||
"tcgen05.st.sync.aligned.16x256b.x2.b32 [%0], {%1, %2, %3, %4, %5, %6, %7, %8};\n"
|
||||
:: "r"(dst.addr + ((i * src.tile_size_row) << 16) + (j * src.tile_size_col)/(4/(uint32_t)sizeof(U))),
|
||||
"f"(data[0].x), "f"(data[0].y),
|
||||
"f"(data[1].x), "f"(data[1].y),
|
||||
"f"(data[2].x), "f"(data[2].y),
|
||||
"f"(data[3].x), "f"(data[3].y)
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
static_assert(GROUP_WARPS==4 || GROUP_WARPS==8);
|
||||
constexpr int warp_rows = TM::rows/GROUP_WARPS;
|
||||
static_assert(TM::cols==RT::cols);
|
||||
static_assert(warp_rows==RT::rows);
|
||||
if constexpr (GROUP_WARPS == 4) {
|
||||
auto dst_subtile = dst.template subtile<tt<typename TM::dtype, warp_rows, TM::cols>>(32*warpid(), 0);
|
||||
::kittens::group<1>::store_async(dst_subtile, src);
|
||||
}
|
||||
else {
|
||||
auto dst_subtile = dst.template subtile<tt<typename TM::dtype, warp_rows, TM::cols>>(32*(warpid()%4)+16*(warpid()/4), 0);
|
||||
::kittens::group<1>::store_async(dst_subtile, src);
|
||||
}
|
||||
}
|
||||
}
|
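
// --- Illustrative usage (not part of the original header) ---
// A minimal sketch of the round trip the two functions above enable: a single
// warp pulls an accumulator out of tensor memory into registers and writes it
// back. The register-tile alias and the `rows`/`cols` trait names are
// assumptions for the example; any register/tensor tile pair with matching
// dimensions follows the same pattern, and any tcgen05 pipeline
// synchronization the kernel needs is left to the caller.
template<ducks::tt::all TM>
__device__ static inline void example_tensor_roundtrip(TM &acc_tm) {
    rt_fl<TM::rows, TM::cols> acc; // register tile sized to match (assumed alias)
    load_async(acc, acc_tm);       // tcgen05.ld path above
    store_async(acc_tm, acc);      // tcgen05.st path above
}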
16
extra/thunder/cuda/include/ops/group/memory/tile/tile.cuh
Normal file
@@ -0,0 +1,16 @@
/**
 * @file
 * @brief An aggregate header of group memory operations on tiles.
 */

#include "shared_to_register.cuh"
#include "global_to_register.cuh"
#include "global_to_shared.cuh"
#ifdef KITTENS_BLACKWELL
#include "tensor_to_register.cuh"
#endif

#include "complex/complex_shared_to_register.cuh"
#include "complex/complex_global_to_register.cuh"
#include "complex/complex_global_to_shared.cuh"
134
extra/thunder/cuda/include/ops/group/memory/tile/tma.cuh
Normal file
@@ -0,0 +1,134 @@
/**
 * @file
 * @brief Functions for a group scope to call tile TMA functions.
 */

template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::prefetch<axis, policy, ST, GL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::prefetch<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx); // Don't do the mask
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_async<axis, policy, ST, GL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_async<axis, policy, ST, PGL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_async<dim::ROW, cache_policy::NORMAL, ST, PGL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_add_async<axis, policy, ST, GL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_add_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_add_async<axis, policy, ST, PGL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_add_async<dim::ROW, cache_policy::NORMAL, ST, PGL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_min_async<axis, policy, ST, GL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_min_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_min_async<axis, policy, ST, PGL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_min_async<dim::ROW, cache_policy::NORMAL, ST, PGL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_max_async<axis, policy, ST, GL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_max_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_max_async<axis, policy, ST, PGL, COORD>(dst, src, idx); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) {
    if(laneid() == 0) {
        ::kittens::tma::store_max_async<dim::ROW, cache_policy::NORMAL, ST, PGL, COORD>(dst, src, idx);
    }
}

template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) {
    if(laneid() == 0) {
        ::kittens::tma::load_async<axis, policy, ST, GL, COORD>(dst, src, idx, bar); // Don't do the mask
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) {
    if(laneid() == 0) {
        ::kittens::tma::load_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx, bar);
    }
}
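
// --- Illustrative usage (not part of the original header) ---
// The intended load-side call pattern, sketched under assumptions: `bar` must
// have been set up once with init_semaphore (defined in the group util header
// later in this commit), lane 0 arms it with the tile's byte count and issues
// the TMA load, and every thread waits on the phase bit before touching the
// tile. The phase bit alternates 0/1 on each reuse of the semaphore.
template<ducks::st::all ST, ducks::gl::all GL>
__device__ static inline void example_fetch(ST &dst, const GL &src, const coord<ST> &idx, semaphore &bar, int phase) {
    expect(bar, dst);               // size_bytes<ST> expected at the mbarrier
    load_async(dst, src, idx, bar); // lane 0 issues cp.async.bulk.tensor
    wait(bar, phase);               // block until the transaction lands
}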
@@ -0,0 +1,33 @@
/**
 * @file
 * @brief Functions for a group scope to call tile TMA cluster functions.
 */

#ifdef KITTENS_BLACKWELL
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) {
    if(laneid() == 0) {
        ::kittens::tma::cluster::load_async<axis, policy, ST, GL, COORD>(dst, src, idx, bar, cluster_mask, dst_mbar_cta);
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) {
    if(laneid() == 0) {
        ::kittens::tma::cluster::load_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx, bar, cluster_mask, dst_mbar_cta);
    }
}
#else
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask) {
    if(laneid() == 0) {
        ::kittens::tma::cluster::load_async<axis, policy, ST, GL, COORD>(dst, src, idx, bar, cluster_mask);
    }
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask) {
    if(laneid() == 0) {
        ::kittens::tma::cluster::load_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx, bar, cluster_mask);
    }
}
#endif
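
// --- Illustrative usage (not part of the original header) ---
// Sketch of a multicast load across a 2-CTA cluster: the mask carries one bit
// per destination CTA, so both CTAs receive the same tile into their own
// shared memory. The mask value and the barrier management are assumptions
// for the example.
template<ducks::st::all ST, ducks::gl::all GL>
__device__ static inline void example_multicast_fetch(ST &dst, const GL &src, const coord<ST> &idx, semaphore &bar) {
    constexpr uint16_t both_ctas = 0b11; // CTA ranks 0 and 1
    load_async(dst, src, idx, bar, both_ctas);
}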
68
extra/thunder/cuda/include/ops/group/memory/util/tma.cuh
Normal file
@@ -0,0 +1,68 @@
/**
 * @file
 * @brief Various utilities for group TMA memory operations.
 */

/* ---------- Barrier functions for async load ---------- */

/**
 * @brief Sets the number of bytes expected at the semaphore.
 *
 * This function sets the number of bytes expected at the semaphore for the first thread in the warp.
 * It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly
 * instruction to set the expected number of bytes.
 *
 * @param bar Reference to the semaphore variable.
 * @param bytes The number of bytes expected at the semaphore.
 */
__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes) {
    if(laneid() == 0) {
        ::kittens::tma::expect_bytes(bar, bytes);
    }
}
/**
 * @brief Sets the number of bytes expected at the semaphore.
 *
 * This function sets the number of bytes expected at the mbarrier before the transaction arrives.
 */
template<typename T, typename... args>
__device__ static inline void expect(semaphore& bar, const T& _1, const args&... _2) {
    expect_bytes(bar, size_bytes<T, args...>);
}

/* ---------- Synchronization functions for async store ---------- */

/**
 * @brief Commits previous asynchronous TMA stores to a group and performs them.
 */
__device__ static inline void store_commit_group() {
    asm volatile("cp.async.bulk.commit_group;");
}
/**
 * @brief Waits for previously committed TMA store groups to complete.
 *
 * @tparam N The maximum number of remaining TMA store groups. Defaults to 0.
 */
template <int N=0>
__device__ static inline void store_async_wait() {
    asm volatile (
        "cp.async.bulk.wait_group %0;"
        :
        : "n"(N)
        : "memory"
    );
}
/**
 * @brief Waits for previously committed TMA store groups to finish reading from shared memory.
 *
 * @tparam N The maximum number of remaining TMA store groups. Defaults to 0.
 */
template <int N=0>
__device__ static inline void store_async_read_wait() {
    asm volatile (
        "cp.async.bulk.wait_group.read %0;"
        :
        : "n"(N)
        : "memory"
    );
}
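
// --- Illustrative usage (not part of the original header) ---
// The store-side pairing these helpers imply, sketched under assumptions:
// issue the TMA store, commit the bulk-async group, then block until the
// committed stores have finished reading shared memory so the tile can be
// safely overwritten.
template<ducks::st::all ST, ducks::gl::all GL>
__device__ static inline void example_flush(const GL &dst, const ST &src, const coord<ST> &idx) {
    store_async(dst, src, idx); // lane 0 issues cp.async.bulk.tensor
    store_commit_group();       // close the bulk-async group
    store_async_read_wait();    // shared tile may now be reused
}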
@@ -0,0 +1,90 @@
/**
 * @brief Waits for the requested semaphore phase, at cluster scope.
 *
 * @param bar Reference to the semaphore variable.
 * @param kPhaseBit The phase bit used for the semaphore.
 */
__device__ static inline void wait(semaphore& bar, int kPhaseBit) {
    void const* const ptr = &bar;
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "LAB_WAIT:\n"
        "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\n"
        "@P1 bra.uni DONE;\n"
        "bra.uni LAB_WAIT;\n"
        "DONE:\n"
        "}\n"
        :: "r"(mbar_ptr),
           "r"(kPhaseBit)
    );
}

/**
 * @brief Sets the number of bytes expected at the semaphore, assuming a multicast instruction.
 *
 * This function sets the number of bytes expected at the semaphore for the first thread in the warp.
 * It converts the semaphore pointer to a generic shared memory pointer and uses an inline assembly
 * instruction to set the expected number of bytes.
 *
 * It's worth being aware that this function is particularly necessary for multicast loads, and
 * distributed shared memory can actually be done with a normal tma::expect followed by wait. See
 * the unit tests of dsmem for an example.
 *
 * @param bar Reference to the semaphore variable.
 * @param bytes The number of bytes expected at the semaphore.
 * @param dst_cta The destination CTA rank within the cluster.
 */
__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes, int dst_cta) {
    if(laneid() == 0) {
        ::kittens::tma::cluster::expect_bytes(bar, bytes, dst_cta);
    }
}
/**
 * @brief Sets the number of bytes expected at the semaphore.
 *
 * This function sets the number of bytes expected at the mbarrier before the transaction arrives.
 *
 * @tparam T The type of the data to be stored at the semaphore.
 * @param bar Reference to the semaphore variable.
 * @param dst_cta The destination CTA rank within the cluster.
 */
template<typename T, typename... args>
__device__ static inline void expect(semaphore& bar, int dst_cta, const T& _1, const args&... _2) {
    expect_bytes(bar, size_bytes<T, args...>, dst_cta);
}

/**
 * @brief Arrives at a semaphore in cluster scope.
 *
 * Marks a thread arrival at an mbarrier.
 *
 * @param bar Reference to the semaphore variable.
 * @param dst_cta The destination CTA rank within the cluster.
 * @param count The number of arrivals to mark. Defaults to 1.
 */
__device__ static inline void arrive(semaphore& bar, int dst_cta, uint32_t count=1) {
    if(laneid() == 0) {
        ::kittens::tma::cluster::arrive(bar, dst_cta, count);
    }
}

// Generic transfer
__device__ static inline void store_async(void *dst, void *src, int dst_cta, uint32_t size_bytes, semaphore& bar) {
    if(laneid() == 0) {
        ::kittens::tma::cluster::store_async(dst, src, dst_cta, size_bytes, bar);
    }
}

// Templated transfer for convenience
template<typename T>
__device__ static inline void store_async(T &dst_, T &src_, int dst_cta, semaphore& bar) {
    store_async((void*)&dst_, (void*)&src_, dst_cta, size_bytes<T>, bar);
}
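
// --- Illustrative usage (not part of the original header) ---
// A rough sketch of pushing a shared-memory object to a peer CTA in the
// cluster with the helpers above: arm the mbarrier with the expected byte
// count, launch the bulk copy, then wait on the phase bit. The exact pairing
// of ranks, barriers, and phases is kernel-specific and assumed here.
template<typename T>
__device__ static inline void example_cluster_push(T &peer_dst, T &local_src, int peer_cta, semaphore &bar, int phase) {
    expect(bar, peer_cta, peer_dst);                 // arm the barrier for the transfer
    store_async(peer_dst, local_src, peer_cta, bar); // bulk DSM copy to the peer CTA
    wait(bar, phase);                                // cluster-scope phase wait above
}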
168
extra/thunder/cuda/include/ops/group/memory/util/util.cuh
Normal file
@@ -0,0 +1,168 @@
/**
 * @file
 * @brief Various utilities for group memory operations.
 */

template<int N=0> __device__ static inline void load_async_wait(int bar_id) { // for completing (non-TMA) async loads
    asm volatile("cp.async.wait_group %0;\n" : : "n"(N) : "memory");
    sync(bar_id);
}
template<int N=0> __device__ static inline void load_async_wait() { // for completing (non-TMA) async loads
    KITTENS_CHECK_WARP
    asm volatile("cp.async.wait_group %0;\n" : : "n"(N) : "memory");
    __syncwarp();
}

__device__ static inline void arrive(barrier<GROUP_WARPS> bar) {
    asm volatile("bar.arrive %0, %1;\n" :: "r"(bar.barrier_id), "n"(GROUP_WARPS*WARP_THREADS) : "memory");
}
__device__ static inline void arrive_and_wait(barrier<GROUP_WARPS> bar) {
    asm volatile("bar.sync %0, %1;\n" :: "r"(bar.barrier_id), "n"(GROUP_WARPS*WARP_THREADS) : "memory");
}

/**
 * @brief Initializes a synchronization semaphore.
 *
 * This function sets up a semaphore that is used to synchronize threads within a block during asynchronous operations.
 * It initializes the underlying mbarrier with the combined thread and transaction arrival count.
 *
 * @param[out] bar The semaphore variable to initialize.
 * @param[in] thread_count The number of thread arrivals expected per phase.
 * @param[in] transaction_count The number of transaction arrivals expected per phase. Defaults to 0.
 */
__device__ static inline void init_semaphore(semaphore& bar, int thread_count, int transaction_count=0) {
    if (laneid() == 0) {
        void const* const ptr = &bar;
        uint32_t bar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

        asm volatile (
            "mbarrier.init.shared::cta.b64 [%0], %1;\n"
            :: "r"(bar_ptr), "r"(thread_count+transaction_count)
        );
    }
}
/**
 * @brief Invalidates an mbarrier.
 *
 * @param[out] bar The semaphore variable to invalidate.
 */
__device__ static inline void invalidate_semaphore(semaphore& bar) {
    if (laneid() == 0) {
        void const* const ptr = &bar;
        uint32_t bar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
        asm volatile (
            "mbarrier.inval.shared::cta.b64 [%0];\n"
            :: "r"(bar_ptr)
        );
    }
}

/**
 * @brief Arrives at a semaphore.
 *
 * Marks a warp arrival at an mbarrier.
 *
 * @param sem Reference to the semaphore variable.
 */
__device__ static inline void arrive(semaphore& sem) {
    if(laneid() == 0) {
        uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&sem));
        asm volatile (
            "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];\n"
            :
            : "r"(mbar_ptr)
            : "memory"
        );
    }
}
template<int num_warps> __device__ static inline void arrive(barrier<num_warps> bar) {
    asm volatile("bar.arrive %0, %1;\n" :: "r"(bar.barrier_id), "n"(num_warps*WARP_THREADS) : "memory");
}

#ifdef KITTENS_HOPPER
/**
 * @brief Arrives at a semaphore with a given arrival count.
 *
 * Marks a warp arrival at an mbarrier.
 *
 * @param sem Reference to the semaphore variable.
 * @param count The number of arrivals to mark.
 */
__device__ static inline void arrive(semaphore& sem, uint32_t count) {
    if(laneid() == 0) {
        uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&sem));
        asm volatile (
            "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n"
            :
            : "r"(mbar_ptr), "r"(count)
            : "memory"
        );
    }
}
#endif

/**
 * @brief Waits for the requested semaphore phase.
 *
 * @param sem Reference to the semaphore variable.
 * @param kPhaseBit The phase bit used for the semaphore.
 */
__device__ static inline void wait(semaphore& sem, int kPhaseBit) {
    void const* const ptr = &sem;
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

#ifdef KITTENS_HOPPER
    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "LAB_WAIT:\n"
        "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
        "@P1 bra.uni DONE;\n"
        "bra.uni LAB_WAIT;\n"
        "DONE:\n"
        "}\n"
        :: "r"(mbar_ptr),
           "r"(kPhaseBit)
    );
#else
    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "LAB_WAIT:\n"
        "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
        "@P1 bra.uni DONE;\n"
        "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures to save instruction issue slots
        "bra.uni LAB_WAIT;\n"
        "DONE:\n"
        "}\n"
        :: "r"(mbar_ptr),
           "r"(kPhaseBit)
    );
#endif
}

/**
 * @brief Checks whether the requested semaphore phase is ready.
 *
 * @param sem Reference to the semaphore variable.
 * @param kPhaseBit The phase bit used for the semaphore.
 */
__device__ static inline int test_wait(semaphore& sem, int kPhaseBit) {
    void const* const ptr = &sem;
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
    int result;
    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2;\n"
        "selp.u32 %0,1,0,P1;"
        "}\n"
        : "=r"(result)
        : "r"(mbar_ptr), "r"(kPhaseBit)
    );
    return result;
}
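
// --- Illustrative usage (not part of the original header) ---
// A minimal producer/consumer handshake with the mbarrier helpers above. The
// single-arrival setup and the iteration-parity phase scheme are assumptions
// for the example: `ready` must have been initialized once with
// init_semaphore(ready, 1), so one arrival completes each phase.
__device__ static inline void example_handshake(semaphore &ready, int iter) {
    if(warpid() == 0) {
        // ... producer fills a shared buffer ...
        arrive(ready);         // lane 0 marks the buffer ready
    } else {
        wait(ready, iter & 1); // consumers spin on the alternating phase bit
        // ... consumers read the shared buffer ...
    }
}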
@@ -0,0 +1,138 @@
/**
 * @file
 * @brief Functions for a warpgroup to collaboratively transfer data directly between global memory and registers and back.
 */

/**
 * @brief Collaboratively loads data into register vectors from a source array in global memory.
 *
 * @tparam RV The register vector type.
 * @tparam GL The global layout type of the source array.
 * @param[out] dst The destination register vector to load data into.
 * @param[in] src The source array in global memory to load data from.
 * @param[in] idx The coord of the requested vector.
 */
template<ducks::rv::all RV, ducks::gl::all GL>
__device__ inline static void load(RV &dst, const GL &src, const coord<rv<typename RV::T, GROUP_WARPS*RV::length, typename RV::layout>> &idx) {
    if constexpr (GROUP_WARPS == 1) {
        using T2 = RV::dtype;
        using U = typename GL::dtype;
        using U2 = base_types::packing<U>::packed_type;
        using T = base_types::packing<T2>::unpacked_type;

        U *src_ptr = (U*)&src[(idx.template unit_coord<-1, 3>())];
        int laneid = ::kittens::laneid();

        if constexpr (std::is_same_v<typename RV::layout, align_l>) {
            #pragma unroll
            for(auto w = 0; w < (dst.outer_dim+3)/4; w++) {
                int idx = w*64 + (laneid/4)*8 + 2*(laneid%4);
                int o_dim = w*4 + (laneid/4) / 2;
                int i_dim = (laneid/4) % 2;
                // this should be a maximally coalesced load.
                if(idx < dst.outer_dim*16)
                    dst[o_dim][i_dim] = base_types::convertor<T2, U2>::convert(*(U2*)&src_ptr[idx]);
            }
            // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need.
            #pragma unroll
            for(auto w = 0; w < dst.outer_dim; w++) {
                int leader = 8*(w%4) + (laneid%4); // repeats every 64 columns
                dst[w][0] = packed_shfl_sync(MASK_ALL, dst[w][0], leader);
                dst[w][1] = packed_shfl_sync(MASK_ALL, dst[w][1], leader+4);
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, ortho_l>) {
            // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true
            // otherwise there will be some pain :/
            #pragma unroll
            for(auto w = 0; w < (dst.outer_dim+1)/2; w++) {
                int idx = w*32 + (laneid%4)*8 + (laneid/4);
                int o_dim = w*2 + (laneid%4) / 2;
                // this should be a maximally coalesced load.
                if(idx < dst.outer_dim*16) {
                    T tmp = base_types::convertor<T, U>::convert(src_ptr[idx]);
                    if(laneid%2==0) dst[o_dim][0].x = tmp;
                    else dst[o_dim][0].y = tmp;
                }
            }
            // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need.
            #pragma unroll
            for(auto w = 0; w < dst.outer_dim; w++) {
                int leader = (laneid/4)*4 + 2*(w%2); // repeats every 64 columns
                dst[w][0].x = __shfl_sync(MASK_ALL, dst[w][0].x, leader);
                dst[w][0].y = __shfl_sync(MASK_ALL, dst[w][0].y, leader+1);
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, naive_l>) {
            #pragma unroll
            for(auto w = 0; w < dst.outer_dim; w++) {
                if(w < dst.outer_dim-1 || dst.length%32 == 0 || laneid<16) {
                    dst[w][0] = base_types::convertor<T, U>::convert(src_ptr[w*32 + laneid]);
                }
            }
        }
    }
    else {
        // Call warp-level load
        ::kittens::group<1>::load(dst, src, coord<RV>(idx.b, idx.d, idx.r, idx.c*GROUP_WARPS+warpid()));
    }
}
/**
 * @brief Collaboratively stores data from register vectors to a destination array in global memory.
 *
 * @tparam RV The register vector type.
 * @tparam GL The global layout type of the destination array.
 * @param[out] dst The destination array in global memory to store data into.
 * @param[in] src The source register vector to store data from.
 * @param[in] idx The coord of the vector destination.
 */
template<ducks::rv::all RV, ducks::gl::all GL>
__device__ inline static void store(GL &dst, const RV &src, const coord<rv<typename RV::T, GROUP_WARPS*RV::length, typename RV::layout>> &idx) {
    if constexpr (GROUP_WARPS == 1) {
        using T2 = RV::dtype;
        using U = typename GL::dtype;
        using U2 = base_types::packing<U>::packed_type;
        using T = base_types::packing<T2>::unpacked_type;

        U *dst_ptr = (U*)&dst[(idx.template unit_coord<-1, 3>())];
        int laneid = ::kittens::laneid();

        if constexpr (std::is_same_v<typename RV::layout, align_l>) {
            #pragma unroll
            for(auto w = 0; w < (src.outer_dim+3)/4; w++) {
                int idx = w*64 + (laneid/4)*8 + 2*(laneid%4);
                int o_dim = w*4 + (laneid/4) / 2;
                int i_dim = (laneid/4) % 2;
                // this should be a maximally coalesced store. I hope!
                if(idx < src.outer_dim*16)
                    *(U2*)&dst_ptr[idx] = base_types::convertor<U2, T2>::convert(src[o_dim][i_dim]);
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, ortho_l>) {
            // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true
            // otherwise there will be some pain :/
            #pragma unroll
            for(auto w = 0; w < (src.outer_dim+1)/2; w++) {
                int idx = w*32 + (laneid%4)*8 + (laneid/4);
                int o_dim = w*2 + (laneid%4) / 2;
                // this should be a maximally coalesced store.
                if(idx < src.outer_dim*16) {
                    U tmp;
                    if(laneid%2==0) tmp = base_types::convertor<U, T>::convert(src[o_dim][0].x);
                    else tmp = base_types::convertor<U, T>::convert(src[o_dim][0].y);
                    dst_ptr[idx] = tmp;
                }
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, naive_l>) {
            #pragma unroll
            for(auto w = 0; w < src.outer_dim; w++) {
                if(w < src.outer_dim-1 || src.length%32 == 0 || laneid<16) {
                    dst_ptr[w*32 + laneid] = base_types::convertor<U, T>::convert(src[w][0]);
                }
            }
        }
    }
    else {
        // Call warp-level store
        ::kittens::group<1>::store(dst, src, coord<RV>(idx.b, idx.d, idx.r, idx.c*GROUP_WARPS+warpid()));
    }
}
@@ -0,0 +1,77 @@
/**
 * @file
 * @brief Group (collaborative warp) ops for loading shared vectors from and storing to global memory.
 */

/**
 * @brief Loads data from global memory into a shared memory vector.
 *
 * This function loads data from a global memory location pointed to by `src` into a shared memory vector `dst`.
 * It calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`.
 * The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among the threads of the group.
 *
 * @tparam SV Shared vector type, must satisfy ducks::sv::all concept.
 * @param dst Reference to the shared vector where the data will be loaded.
 * @param src The global memory array from which the data will be loaded.
 * @param idx The coord of the requested vector.
 */
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void load(SV &dst, const GL &src, const COORD &idx) {
    constexpr uint32_t elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
    constexpr uint32_t total_calls = SV::length / elem_per_transfer; // guaranteed to divide
    typename GL::dtype *src_ptr = (typename GL::dtype*)&src[(idx.template unit_coord<-1, 3>())];
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst.data[0]));
    #pragma unroll
    for(uint32_t i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) {
        if(i * elem_per_transfer < dst.length) {
            float4 tmp;
            move<float4>::ldg(tmp, (float4*)&src_ptr[i*elem_per_transfer]);
            move<float4>::sts(dst_ptr + sizeof(typename SV::dtype)*i*elem_per_transfer, tmp);
        }
    }
}

/**
 * @brief Stores data from a shared memory vector to global memory.
 *
 * This function stores data from a shared memory vector `src` to a global memory location pointed to by `dst`.
 * Similar to the load function, it calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`.
 * The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among the threads of the group.
 *
 * @tparam SV Shared vector type, must satisfy ducks::sv::all concept.
 * @param dst The global memory array where the data will be stored.
 * @param src Reference to the shared vector from which the data will be stored.
 * @param idx The coord of the vector destination.
 */
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store(GL &dst, const SV &src, const COORD &idx) {
    constexpr uint32_t elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
    constexpr uint32_t total_calls = SV::length / elem_per_transfer; // guaranteed to divide
    typename GL::dtype *dst_ptr = (typename GL::dtype*)&dst[(idx.template unit_coord<-1, 3>())];
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src.data[0]));
    #pragma unroll
    for(uint32_t i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) {
        if(i * elem_per_transfer < src.length) {
            float4 tmp;
            move<float4>::lds(tmp, src_ptr + sizeof(typename SV::dtype)*i*elem_per_transfer);
            move<float4>::stg((float4*)&dst_ptr[i*elem_per_transfer], tmp);
        }
    }
}

template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx) {
    constexpr uint32_t elem_per_transfer = sizeof(float4) / sizeof(typename SV::dtype);
    constexpr uint32_t total_calls = SV::length / elem_per_transfer; // guaranteed to divide
    typename GL::dtype *src_ptr = (typename GL::dtype*)&src[(idx.template unit_coord<-1, 3>())];
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst.data[0]));
    #pragma unroll
    for(uint32_t i = threadIdx.x%GROUP_THREADS; i < total_calls; i+=GROUP_THREADS) {
        if(i * elem_per_transfer < dst.length) {
            asm volatile(
                "cp.async.cg.shared.global.L2::128B [%0], [%1], 16;\n"
                :: "r"(dst_ptr + (uint32_t)sizeof(typename SV::dtype)*i*elem_per_transfer), "l"((uint64_t)&src_ptr[i*elem_per_transfer])
                : "memory"
            );
        }
    }
    asm volatile("cp.async.commit_group;\n" ::: "memory");
}
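
// --- Illustrative usage (not part of the original header) ---
// Pairing the cp.async vector load above with the matching wait from the
// group util header. This sketch assumes a warp-scope group; larger groups
// would use the load_async_wait(bar_id) overload instead.
template<ducks::sv::all SV, ducks::gl::all GL>
__device__ static inline void example_vec_prefetch(SV &dst, const GL &src, const coord<SV> &idx) {
    load_async(dst, src, idx); // cp.async.cg per 16-byte chunk + commit_group
    load_async_wait();         // cp.async.wait_group 0, then __syncwarp()
}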
@@ -0,0 +1,159 @@
/**
 * @file
 * @brief Functions for a group to collaboratively transfer data directly between shared memory and registers and back.
 */

/**
 * @brief Collaboratively load data from a shared vector into register vectors split across a warpgroup.
 *
 * @tparam RV The register vector type
 * @tparam SV The shared vector type
 * @param dst[out] The destination register vector.
 * @param src[in] The source shared vector.
 */
template<ducks::rv::all RV, ducks::sv::all SV>
__device__ inline static void load(RV &dst, const SV &src) {
    using T2 = RV::dtype;
    using U = SV::dtype;
    using U2 = base_types::packing<U>::packed_type;
    using T = base_types::packing<T2>::unpacked_type;
    if constexpr (GROUP_WARPS == 1) {
        static_assert(SV::length == RV::length);

        int laneid = ::kittens::laneid();
        uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src.data[0]));

        __syncwarp();
        if constexpr (std::is_same_v<typename RV::layout, align_l>) {
            #pragma unroll
            for(auto w = 0; w < (dst.outer_dim+3)/4; w++) {
                int idx = w*64 + (laneid/4)*8 + 2*(laneid%4);
                int o_dim = w*4 + (laneid/4) / 2;
                int i_dim = (laneid/4) % 2;
                // this should be a maximally coalesced load.
                if(idx < dst.outer_dim*16) {
                    U2 tmp;
                    move<U2>::lds(tmp, src_ptr + sizeof(typename SV::dtype)*idx);
                    dst[o_dim][i_dim] = base_types::convertor<T2, U2>::convert(tmp);
                }
            }
            __syncwarp();
            // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need.
            #pragma unroll
            for(auto w = 0; w < dst.outer_dim; w++) {
                int leader = 8*(w%4) + (laneid%4); // repeats every 64 columns
                dst[w][0] = packed_shfl_sync(MASK_ALL, dst[w][0], leader);
                dst[w][1] = packed_shfl_sync(MASK_ALL, dst[w][1], leader+4);
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, ortho_l>) {
            // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true
            // otherwise there will be some pain :/
            #pragma unroll
            for(auto w = 0; w < (dst.outer_dim+1)/2; w++) {
                int idx = w*32 + (laneid%4)*8 + (laneid/4);
                int o_dim = w*2 + (laneid%4) / 2;
                // this should be a maximally coalesced load.
                if(idx < dst.outer_dim*16) {
                    U tmp;
                    move<U>::lds(tmp, src_ptr + sizeof(typename SV::dtype)*idx);
                    if(laneid%2==0) dst[o_dim][0].x = base_types::convertor<T, U>::convert(tmp);
                    else dst[o_dim][0].y = base_types::convertor<T, U>::convert(tmp);
                }
            }
            __syncwarp();
            // now we need to do a bunch of shuffle_sync's to make sure everyone has everything they need.
            #pragma unroll
            for(auto w = 0; w < dst.outer_dim; w++) {
                int leader = (laneid/4)*4 + 2*(w%2); // repeats every 64 columns
                dst[w][0].x = __shfl_sync(MASK_ALL, dst[w][0].x, leader);
                dst[w][0].y = __shfl_sync(MASK_ALL, dst[w][0].y, leader+1);
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, naive_l>) {
            #pragma unroll
            for(auto w = 0; w < dst.outer_dim; w++) {
                if(w < dst.outer_dim-1 || RV::length%32 == 0 || laneid<16) {
                    U tmp;
                    move<U>::lds(tmp, src_ptr + sizeof(typename SV::dtype)*(w*32 + laneid));
                    dst[w][0] = base_types::convertor<T, U>::convert(tmp);
                }
            }
        }
    }
    else {
        static_assert(SV::length == RV::length*GROUP_WARPS); // confirm size correct
        auto &_src = src.template subvec<RV::length>(warpid()); // pretend it's smaller and do a warp-level load

        ::kittens::group<1>::load(dst, _src); // warp-level
    }
}

/**
 * @brief Collaboratively store data into a shared vector from register vectors split across a warpgroup.
 *
 * @tparam RV The register vector type
 * @tparam SV The shared vector type
 * @param dst[out] The destination shared vector.
 * @param src[in] The source register vector.
 */
template<ducks::sv::all SV, ducks::rv::all RV>
__device__ inline static void store(SV &dst, const RV &src) {
    using T2 = RV::dtype;
    using U = SV::dtype;
    using U2 = base_types::packing<U>::packed_type;
    using T = base_types::packing<T2>::unpacked_type;

    if constexpr (GROUP_WARPS == 1) {
        static_assert(SV::length == RV::length);

        int laneid = ::kittens::laneid();
        uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst.data[0]));

        __syncwarp();
        if constexpr (std::is_same_v<typename RV::layout, align_l>) {
            #pragma unroll
            for(auto w = 0; w < (src.outer_dim+3)/4; w++) {
                int idx = w*64 + (laneid/4)*8 + 2*(laneid%4);
                int o_dim = w*4 + (laneid/4) / 2;
                int i_dim = (laneid/4) % 2;
                // this should be a maximally coalesced store. I hope!
                if(idx < src.outer_dim*16) {
                    U2 tmp = base_types::convertor<U2, T2>::convert(src[o_dim][i_dim]);
                    move<U2>::sts(dst_ptr + sizeof(typename SV::dtype)*idx, tmp);
                }
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, ortho_l>) {
            // really hoping https://stackoverflow.com/questions/15029765/is-coalescing-triggered-for-accessing-memory-in-reverse-order is still true
            // otherwise there will be some pain :/
            #pragma unroll
            for(auto w = 0; w < (src.outer_dim+1)/2; w++) {
                int idx = w*32 + (laneid%4)*8 + (laneid/4);
                int o_dim = w*2 + (laneid%4) / 2;
                // this should be a maximally coalesced store.
                if(idx < src.outer_dim*16) {
                    U tmp;
                    if(laneid%2==0) tmp = base_types::convertor<U, T>::convert(src[o_dim][0].x);
                    else tmp = base_types::convertor<U, T>::convert(src[o_dim][0].y);
                    move<U>::sts(dst_ptr + sizeof(typename SV::dtype)*idx, tmp);
                }
            }
        }
        else if constexpr (std::is_same_v<typename RV::layout, naive_l>) {
            #pragma unroll
            for(auto w = 0; w < src.outer_dim; w++) {
                if(w < src.outer_dim-1 || RV::length%32 == 0 || laneid<16) {
                    U tmp = base_types::convertor<U, T>::convert(src[w][0]);
                    move<U>::sts(dst_ptr + sizeof(typename SV::dtype)*(w*32 + laneid), tmp);
                }
            }
        }
    }
    else {
        static_assert(SV::length == RV::length*GROUP_WARPS); // confirm size correct
        auto &_dst = dst.template subvec<RV::length>(warpid()); // pretend it's smaller and do a warp-level store

        ::kittens::group<1>::store(_dst, src); // warp-level
    }
}
221
extra/thunder/cuda/include/ops/group/memory/vec/tma.cuh
Normal file
@@ -0,0 +1,221 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Functions for a group scope to call vec TMA functions.
|
||||
*/
|
||||
|
||||
/* ---------- Prefetch Tensor Map ---------- */
|
||||
|
||||
/**
|
||||
* @brief Prefetches data from global memory into a shared memory vector, along with the tensormap.
|
||||
*
|
||||
* @tparam SV A shared vector type with a TMA-compatible layout
|
||||
* @param[out] dst The destination shared memory vector.
|
||||
* @param[in] src_tma_map The source tensormap address in global memory
|
||||
* @param[in] vec_idx The coord of the requested vector.
|
||||
*/
|
||||
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
|
||||
__device__ static inline void prefetch(SV &dst, const GL &src, const COORD &idx) {
|
||||
coord<> unit_coord = idx.template unit_coord<-1, 3>();
|
||||
uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<SV, -1>());
|
||||
for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
|
||||
coord<> tma_coord = unit_coord;
|
||||
tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
|
||||
::kittens::detail::tma::vec_prefetch_tma_internal<policy>(tma_ptr, tma_coord);
|
||||
}
|
||||
}
|
||||
__KITTENS_TMA_DEFINE_DEFAULT_LOAD_CACHE_VEC__(prefetch)
|
||||
|
||||
|
||||
/* ---------- Async load and store data from gmem/smem ---------- */
|
||||
|
||||
/**
|
||||
* @brief Asynchronously stores data into global memory from a shared memory vector.
|
||||
*
|
||||
* This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
|
||||
*
|
||||
* @tparam SV A shared vector type with a TMA-compatible layout
|
||||
* @param[out] dst_tma_map The destination tensormap address in global memory
|
||||
* @param[in] src The source shared memory vector.
|
||||
* @param[in] vec_idx The coord of the vector destination.
|
||||
*/
|
||||
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
|
||||
__device__ static inline void store_async(const GL &dst, const SV &src, const COORD &idx) {
|
||||
coord<> unit_coord = idx.template unit_coord<-1, 3>();
|
||||
uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
|
||||
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
|
||||
for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
|
||||
coord<> tma_coord = unit_coord;
|
||||
tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
|
||||
uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
|
||||
::kittens::detail::tma::vec_store_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
|
||||
}
|
||||
store_commit_group();
|
||||
}
|
||||
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_async)
|
||||
|
||||
template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
|
||||
__device__ static inline void store_async(const PGL &dst, const SV &src, const COORD &idx) {
|
||||
coord<> unit_coord = idx.template unit_coord<-1, 3>();
|
||||
uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
|
||||
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
|
||||
for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
|
||||
coord<> tma_coord = unit_coord;
|
||||
tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
|
||||
uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
|
||||
::kittens::detail::tma::vec_store_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
|
||||
}
|
||||
store_commit_group();
|
||||
}
|
||||
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_async)
|
||||
|
||||
|
||||
/**
|
||||
* @brief Asynchronously performs an add reduction and stores the result into global memory.
|
||||
*
|
||||
* This function performs an asynchronous add reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction.
|
||||
*
|
||||
* @tparam SV A shared vector type with a TMA-compatible layout
|
||||
* @param[out] dst_tma_map The destination tensormap address in global memory
|
||||
* @param[in] src The source shared memory vector.
|
||||
* @param[in] vec_idx The coord of the vector destination.
|
||||
*/
|
||||
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
|
||||
__device__ static inline void store_add_async(const GL &dst, const SV &src, const COORD &idx) {
|
||||
coord<> unit_coord = idx.template unit_coord<-1, 3>();
|
||||
uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
|
||||
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
|
||||
for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
|
||||
coord<> tma_coord = unit_coord;
|
||||
tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
|
||||
uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
|
||||
::kittens::detail::tma::vec_store_add_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
|
||||
}
|
||||
store_commit_group();
|
||||
}
|
||||
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_add_async)
|
||||
|
||||
template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
|
||||
__device__ static inline void store_add_async(const PGL &dst, const SV &src, const COORD &idx) {
|
||||
coord<> unit_coord = idx.template unit_coord<-1, 3>();
|
||||
uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
|
||||
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
|
||||
for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
|
||||
coord<> tma_coord = unit_coord;
|
||||
tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
|
||||
uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
|
||||
::kittens::detail::tma::vec_store_add_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
|
||||
}
|
||||
store_commit_group();
|
||||
}
|
||||
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_add_async)
|
||||
|
||||
|
||||
/**
 * @brief Asynchronously performs a min reduction and stores the result into global memory.
 *
 * This function performs an asynchronous min reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction.
 *
 * @tparam SV A shared vector type with a TMA-compatible layout.
 * @param[out] dst The destination global layout, which holds the tensormap.
 * @param[in] src The source shared memory vector.
 * @param[in] idx The coord of the vector destination.
 */
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_min_async(const GL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_min_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    store_commit_group();
}
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_min_async)

template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_min_async(const PGL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_min_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    store_commit_group();
}
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_min_async)

/**
 * @brief Asynchronously performs a max reduction and stores the result into global memory.
 *
 * This function performs an asynchronous max reduction operation using CUDA's cp.reduce.async.bulk.tensor instruction.
 *
 * @tparam SV A shared vector type with a TMA-compatible layout.
 * @param[out] dst The destination global layout, which holds the tensormap.
 * @param[in] src The source shared memory vector.
 * @param[in] idx The coord of the vector destination.
 */
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_max_async(const GL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_max_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    store_commit_group();
}
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_max_async)

template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_max_async(const PGL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_max_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    store_commit_group();
}
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_max_async)

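// Note: unlike store_add_async, the min/max variants static_assert away fp32
// sources, since cp.reduce.async.bulk.tensor has no fp32 min/max path; e.g.
// an fp32 shared vector compiles for store_add_async but not store_min_async.
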
/**
 * @brief Asynchronously loads data from global memory into a shared memory vector.
 *
 * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
 *
 * @tparam SV A shared vector type with a TMA-compatible layout.
 * @param[out] dst The destination shared memory vector.
 * @param[in] src The source global layout, which holds the tensormap.
 * @param[in] idx The coord of the requested vector.
 * @param[in,out] bar The semaphore used for synchronization of the asynchronous copy.
 */
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<SV, -1>());
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst));
    for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_load_async_tma_internal<policy>(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord);
    }
}
__KITTENS_TMA_DEFINE_SEMAPHORE_CACHE_VEC__(load_async)
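// Hedged usage sketch (illustrative names; the expect/wait helpers are assumed
// from elsewhere in the library): one lane arms the semaphore with the byte
// count, the warp issues the load, then consumers wait before reading v.
//
//   if(laneid() == 0) tma::expect_bytes(bar, sizeof(v));
//   tma::load_async(v, g, {0}, bar);
//   wait(bar, 0);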
@@ -0,0 +1,31 @@
/**
 * @file
 * @brief Functions for a group scope to call vec TMA cluster functions.
 */

/**
 * @brief Asynchronously loads data from global memory into a shared memory vector, broadcast across a cluster.
 *
 * This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
 *
 * @tparam SV A shared vector type with a TMA-compatible layout.
 * @param[out] dst The destination shared memory vector.
 * @param[in] src The source global layout, which holds the tensormap.
 * @param[in] idx The coord of the requested vector.
 * @param[in,out] bar The semaphore used for synchronization of the asynchronous copy.
 * @param[in] cluster_mask The mask of the cluster CTAs to broadcast to.
 */
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<SV, -1>());
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst));
    for(int i = ::kittens::laneid(); i < ::kittens::detail::tma::sv_tma_dim2<SV>; i += WARP_THREADS) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::cluster::vec_load_async_tma_internal<policy>(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord, cluster_mask, dst_mbar_cta);
    }
}
__KITTENS_TMA_DEFINE_CLUSTER_SEMAPHORE_CACHE_VEC__(load_async)
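// Note: cluster_mask is the 16-bit CTA mask handed to the multicast TMA load;
// bit k requests delivery into the shared memory of cluster rank k (e.g.
// 0b0011 broadcasts the vector to ranks 0 and 1).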
8
extra/thunder/cuda/include/ops/group/memory/vec/vec.cuh
Normal file
@@ -0,0 +1,8 @@
/**
 * @file
 * @brief An aggregate header of group memory operations on vectors.
 */

#include "shared_to_register.cuh"
#include "global_to_register.cuh"
#include "global_to_shared.cuh"
17
extra/thunder/cuda/include/ops/group/mma/mma.cuh
Normal file
@@ -0,0 +1,17 @@
/**
 * @file
 * @brief An aggregate header for all group-scope MMA operations.
 */

// All compilation targets can use the warp-scope MMA operations.
#include "warp/warp.cuh"

// Hopper has its own warpgroup-scope MMA operations.
#ifdef KITTENS_HOPPER
#include "warpgroup/warpgroup.cuh"
#endif

// Blackwell has its own tensor-scope MMA operations.
#ifdef KITTENS_BLACKWELL
#include "tensor/tensor.cuh"
#endif
172
extra/thunder/cuda/include/ops/group/mma/tensor/tensor.cuh
Normal file
@@ -0,0 +1,172 @@
/**
 * @file
 * @brief Group-level tcgen05 MMA operations.
 */

template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B, int acc=1, int ncta=1>
__device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem) {
    if(laneid() == 0) ::kittens::mma<trans_a, trans_b, D, A, B, acc, ncta>(d, a, b, sem);
}
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B, int acc=1>
__device__ static inline void mma2(D &d, const A &a, const B &b, semaphore &sem) {
    mma<trans_a, trans_b, D, A, B, acc, 2>(d, a, b, sem);
}
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm(D &d, const A &a, const B &b, semaphore &sem) {
    mma<trans_a, trans_b, D, A, B, 0>(d, a, b, sem);
}
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<trans_a, trans_b, D, A, B, 0>(d, a, b, sem);
}

template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::T, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::T, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::T, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::T, D, A, B, 1>(d, a, b, sem);
}

template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::T, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::T, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::T, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::T, D, A, B, 0>(d, a, b, sem);
}

// no-semaphore versions

template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B, int acc=1, int ncta=1>
__device__ static inline void mma(D &d, const A &a, const B &b) {
    if(laneid() == 0) ::kittens::mma<trans_a, trans_b, D, A, B, acc, ncta>(d, a, b);
}
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B, int acc=1>
__device__ static inline void mma2(D &d, const A &a, const B &b) {
    mma<trans_a, trans_b, D, A, B, acc, 2>(d, a, b);
}
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm(D &d, const A &a, const B &b) {
    mma<trans_a, trans_b, D, A, B, 0>(d, a, b);
}
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2(D &d, const A &a, const B &b) {
    mma2<trans_a, trans_b, D, A, B, 0>(d, a, b);
}

template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AB(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AB(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_ABt(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::T, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_ABt(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::T, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtB(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtB(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtBt(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::T, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::T, D, A, B, 1>(d, a, b);
}

template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AB(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AB(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_ABt(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::T, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_ABt(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::T, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtB(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtB(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtBt(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::T, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::T, D, A, B, 0>(d, a, b);
}
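// Hedged usage sketch (Blackwell tcgen05; names and surrounding setup are
// illustrative, not part of this header): the group launches an MMA into a
// tensor-memory tile, then waits on the semaphore before reading d.
//
//   mm_ABt(d, a, b, sem);   // d = a @ b^T with a zero-initialized accumulator
//   wait(sem, 0);           // assumed semaphore-wait helper from the library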
947
extra/thunder/cuda/include/ops/group/mma/warp/warp.cuh
Normal file
@@ -0,0 +1,947 @@
/**
 * @file
 * @brief Matrix multiply-accumulate operations for tiles stored in registers.
 */

/**
 * @brief Perform the HMMA.16816 operation.
 *
 * This function performs the half-precision matrix multiply-accumulate operation
 * using the `mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32` instruction.
 *
 * @param[out] d0 The first half of the output float2 accumulator.
 * @param[out] d1 The second half of the output float2 accumulator.
 * @param[in] a0 The first half of the first input bf16_2 matrix.
 * @param[in] a1 The second half of the first input bf16_2 matrix.
 * @param[in] a2 The first half of the second input bf16_2 matrix.
 * @param[in] a3 The second half of the second input bf16_2 matrix.
 * @param[in] b0 The first half of the bf16_2 matrix B.
 * @param[in] b1 The second half of the bf16_2 matrix B.
 * @param[in] c0 The first half of the float2 accumulator matrix C.
 * @param[in] c1 The second half of the float2 accumulator matrix C.
 */
__device__ static inline void hmma16816( float2 &d0, float2 &d1,
                                         const bf16_2 &a0, const bf16_2 &a1, const bf16_2 &a2, const bf16_2 &a3,
                                         const bf16_2 &b0, const bf16_2 &b1,
                                         const float2 &c0, const float2 &c1 ) {
    asm volatile(
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma
        "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " \
        "{%0, %1, %2, %3}, " \
        "{%4, %5, %6, %7}, " \
        "{%8, %9}, " \
        "{%10, %11, %12, %13};"

        // D matrix
        : "+f"(d0.x), "+f"(d0.y),
          "+f"(d1.x), "+f"(d1.y)

        // A matrix
        : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)),
          "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)),

        // B matrix
          "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)),

        // C matrix
          "f"(c0.x), "f"(c0.y),
          "f"(c1.x), "f"(c1.y)
    );
}
/**
 * @brief Perform the HMMA.16816 operation with fp16 inputs and fp32 accumulators.
 *
 * This function performs the half-precision matrix multiply-accumulate operation
 * using the `mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32` instruction.
 *
 * @param[out] d0 The first half of the output float2 accumulator.
 * @param[out] d1 The second half of the output float2 accumulator.
 * @param[in] a0 The first half of the first input half_2 matrix.
 * @param[in] a1 The second half of the first input half_2 matrix.
 * @param[in] a2 The first half of the second input half_2 matrix.
 * @param[in] a3 The second half of the second input half_2 matrix.
 * @param[in] b0 The first half of the half_2 matrix B.
 * @param[in] b1 The second half of the half_2 matrix B.
 * @param[in] c0 The first half of the float2 accumulator matrix C.
 * @param[in] c1 The second half of the float2 accumulator matrix C.
 */
__device__ static inline void hmma16816( float2 &d0, float2 &d1,
                                         const half_2 &a0, const half_2 &a1, const half_2 &a2, const half_2 &a3,
                                         const half_2 &b0, const half_2 &b1,
                                         const float2 &c0, const float2 &c1 ) {
    asm volatile(
        // https://docs.nvidia.com/cuda/parallel-thread-execution/#multiply-and-accumulate-instruction-mma
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " \
        "{%0, %1, %2, %3}, " \
        "{%4, %5, %6, %7}, " \
        "{%8, %9}, " \
        "{%10, %11, %12, %13};"

        // D matrix
        : "+f"(d0.x), "+f"(d0.y),
          "+f"(d1.x), "+f"(d1.y)

        // A matrix
        : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)),
          "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)),

        // B matrix
          "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)),

        // C matrix
          "f"(c0.x), "f"(c0.y),
          "f"(c1.x), "f"(c1.y)
    );
}
/**
 * @brief Perform the HMMA.16816 operation.
 *
 * This function performs the half-precision matrix multiply-accumulate operation
 * using the `mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16` instruction.
 *
 * @param[out] d0 The first half of the output half_2 accumulator.
 * @param[out] d1 The second half of the output half_2 accumulator.
 * @param[in] a0 The first half of the first input half_2 matrix.
 * @param[in] a1 The second half of the first input half_2 matrix.
 * @param[in] a2 The first half of the second input half_2 matrix.
 * @param[in] a3 The second half of the second input half_2 matrix.
 * @param[in] b0 The first half of the half_2 matrix B.
 * @param[in] b1 The second half of the half_2 matrix B.
 * @param[in] c0 The first half of the half_2 accumulator matrix C.
 * @param[in] c1 The second half of the half_2 accumulator matrix C.
 */
__device__ static inline void hmma16816( half_2 &d0, half_2 &d1,
                                         const half_2 &a0, const half_2 &a1, const half_2 &a2, const half_2 &a3,
                                         const half_2 &b0, const half_2 &b1,
                                         const half_2 &c0, const half_2 &c1 ) {
    asm volatile(
        // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#multiply-and-accumulate-instruction-mma
        "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " \
        "{%0, %1}, " \
        "{%2, %3, %4, %5}, " \
        "{%6, %7}, " \
        "{%8, %9};"

        // D matrix
        : "=r"(*(uint32_t*)(&d0)), "=r"(*(uint32_t*)(&d1))

        // A matrix
        : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)),
          "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)),

        // B matrix
          "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)),

        // C matrix
          "r"(*(uint32_t*)(&c0)), "r"(*(uint32_t*)(&c1))
    );
}
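// Note on m16n8k16 operand shapes: per thread, A spans four .b32 registers
// (eight 16-bit values) and B spans two; the f32-accumulator variants carry
// C/D in four f32 registers (two float2 pairs), while the fp16-accumulator
// variant packs C/D into two .b32 registers.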

#ifdef KITTENS_HOPPER
/**
 * @brief Perform the HMMA.16816 operation for FP8.
 *
 * This function performs the fp8-precision matrix multiply-accumulate operation
 * using the `mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32` instruction.
 *
 * @param[out] d0 The first half of the output float2 accumulator.
 * @param[out] d1 The second half of the output float2 accumulator.
 * @param[in] a0,a1,a2,a3 Input FP8 matrix A values.
 * @param[in] b0,b1 Input FP8 matrix B values.
 * @param[in] c0,c1 Input float2 accumulator matrix C values.
 */
__device__ static inline void hmma16816( float2 &d0, float2 &d1,
                                         const fp8e4m3_4 &a0, const fp8e4m3_4 &a1,
                                         const fp8e4m3_4 &a2, const fp8e4m3_4 &a3,
                                         const fp8e4m3_4 &b0, const fp8e4m3_4 &b1,
                                         const float2 &c0, const float2 &c1) {
    asm volatile(
        "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 "
        "{%0, %1, %2, %3}, "
        "{%4, %5, %6, %7}, "
        "{%8, %9}, "
        "{%10, %11, %12, %13};"

        // D matrix (output)
        : "+f"(d0.x), "+f"(d0.y),
          "+f"(d1.x), "+f"(d1.y)

        // A matrix
        : "r"(*(uint32_t*)(&a0)), "r"(*(uint32_t*)(&a1)),
          "r"(*(uint32_t*)(&a2)), "r"(*(uint32_t*)(&a3)),

        // B matrix
          "r"(*(uint32_t*)(&b0)), "r"(*(uint32_t*)(&b1)),

        // C matrix
          "f"(c0.x), "f"(c0.y),
          "f"(c1.x), "f"(c1.y)
    );
}
#endif

/**
 * @brief Base matrix multiply-accumulate operation for row layout.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<bf16_2, row_layout> matrix.
 * @param[in] b The second input rt_base<bf16_2, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::row> &d,
                                          const rt_base<bf16, ducks::rt_layout::row> &a,
                                          const rt_base<bf16, ducks::rt_layout::col> &b, // in col-major mode
                                          const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
/**
 * @brief Base matrix multiply-accumulate operation for row layout
 * with fp16 inputs and fp32 accumulators.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<half_2, row_layout> matrix.
 * @param[in] b The second input rt_base<half_2, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::row> &d,
                                          const rt_base<half, ducks::rt_layout::row> &a,
                                          const rt_base<half, ducks::rt_layout::col> &b, // in col-major mode
                                          const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
#ifdef KITTENS_HOPPER
/**
 * @brief Base matrix multiply-accumulate operation for row layout.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<fp8e4m3, row_layout> matrix.
 * @param[in] b The second input rt_base<fp8e4m3, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AB_base(rt_base<float, ducks::rt_layout::row> &d,
                                          const rt_base<fp8e4m3, ducks::rt_layout::row> &a,
                                          const rt_base<fp8e4m3, ducks::rt_layout::col> &b, // in col-major mode
                                          const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
#endif
/**
 * @brief Base matrix multiply-accumulate operation for row layout.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<half_2, row_layout> accumulator.
 * @param[in] a The first input rt_base<half_2, row_layout> matrix.
 * @param[in] b The second input rt_base<half_2, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_base<half_2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AB_base(rt_base<half, ducks::rt_layout::row> &d,
                                          const rt_base<half, ducks::rt_layout::row> &a,
                                          const rt_base<half, ducks::rt_layout::col> &b, // in col-major mode
                                          const rt_base<half, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
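// Note: an rt_base tile is 16x16 while one mma.m16n8k16 produces a 16x8 output,
// so each *_base routine here issues two HMMAs, one per 8-column half of B
// (b.data[0]/b.data[2], then b.data[1]/b.data[3]).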
/**
 * @brief Base dot product operation for row layout.
 *
 * This function performs the base dot product operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<bf16_2, row_layout> matrix.
 * @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::row> &d,
                                           const rt_base<bf16, ducks::rt_layout::row> &a,
                                           const rt_base<bf16, ducks::rt_layout::row> &b, // in row-major mode
                                           const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2], // for some reason this one seems to need to be backwards
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3], // for some reason this one seems to need to be backwards
        c.data[2], c.data[3]
    );
}
/**
 * @brief Base dot product operation for row layout
 * with fp16 inputs and fp32 accumulators.
 *
 * This function performs the base dot product operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<half_2, row_layout> matrix.
 * @param[in] b The second input rt_base<half_2, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::row> &d,
                                           const rt_base<half, ducks::rt_layout::row> &a,
                                           const rt_base<half, ducks::rt_layout::row> &b, // in row-major mode
                                           const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2], // for some reason this one seems to need to be backwards
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3], // for some reason this one seems to need to be backwards
        c.data[2], c.data[3]
    );
}
#ifdef KITTENS_HOPPER
/**
 * @brief Base dot product operation for row layout.
 *
 * This function performs the base dot product operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<fp8e4m3x4, row_layout> matrix.
 * @param[in] b The second input rt_base<fp8e4m3x4, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_ABt_base(rt_base<float, ducks::rt_layout::row> &d,
                                           const rt_base<fp8e4m3, ducks::rt_layout::row> &a,
                                           const rt_base<fp8e4m3, ducks::rt_layout::row> &b, // in row-major mode
                                           const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2], // for some reason this one seems to need to be backwards
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3], // for some reason this one seems to need to be backwards
        c.data[2], c.data[3]
    );
}
#endif

/**
 * @brief Base matrix multiply-accumulate operation for row layout with transposed A.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<bf16_2, col_layout> matrix.
 * @param[in] b The second input rt_base<bf16_2, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AtB_base(rt_base<float, ducks::rt_layout::row> &d,
                                           const rt_base<bf16, ducks::rt_layout::col> &a,
                                           const rt_base<bf16, ducks::rt_layout::col> &b, // in col-major mode
                                           const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
/**
 * @brief Base matrix multiply-accumulate operation for row layout with transposed A
 * with fp16 inputs and fp32 accumulators.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<half_2, col_layout> matrix.
 * @param[in] b The second input rt_base<half_2, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AtB_base(rt_base<float, ducks::rt_layout::row> &d,
                                           const rt_base<half, ducks::rt_layout::col> &a,
                                           const rt_base<half, ducks::rt_layout::col> &b, // in col-major mode
                                           const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
#ifdef KITTENS_HOPPER
/**
 * @brief Base matrix multiply-accumulate operation for row layout with transposed A.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<fp8e4m3x4, col_layout> matrix.
 * @param[in] b The second input rt_base<fp8e4m3x4, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AtB_base(rt_base<float, ducks::rt_layout::row> &d,
                                           const rt_base<fp8e4m3, ducks::rt_layout::col> &a,
                                           const rt_base<fp8e4m3, ducks::rt_layout::col> &b, // in col-major mode
                                           const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
#endif

/**
 * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<bf16_2, col_layout> matrix.
 * @param[in] b The second input rt_base<bf16_2, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AtBt_base(rt_base<float, ducks::rt_layout::row> &d,
                                            const rt_base<bf16, ducks::rt_layout::col> &a,
                                            const rt_base<bf16, ducks::rt_layout::row> &b, // in row-major mode
                                            const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
/**
 * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B
 * with fp16 inputs and fp32 accumulators.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<half_2, col_layout> matrix.
 * @param[in] b The second input rt_base<half_2, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AtBt_base(rt_base<float, ducks::rt_layout::row> &d,
                                            const rt_base<half, ducks::rt_layout::col> &a,
                                            const rt_base<half, ducks::rt_layout::row> &b, // in row-major mode
                                            const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
#ifdef KITTENS_HOPPER
/**
 * @brief Base matrix multiply-accumulate operation for row layout with transposed A and B.
 *
 * This function performs the base matrix multiply-accumulate operation
 * using the `hmma16816` function for matrices in row layout.
 *
 * @param[out] d The output rt_base<float2, row_layout> accumulator.
 * @param[in] a The first input rt_base<fp8e4m3x4, col_layout> matrix.
 * @param[in] b The second input rt_base<fp8e4m3x4, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_base<float2, row_layout> accumulator matrix.
 */
__device__ static inline void mma_AtBt_base(rt_base<float, ducks::rt_layout::row> &d,
                                            const rt_base<fp8e4m3, ducks::rt_layout::col> &a,
                                            const rt_base<fp8e4m3, ducks::rt_layout::row> &b, // in row-major mode
                                            const rt_base<float, ducks::rt_layout::row> &c) {
    hmma16816(
        d.data[0], d.data[1],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[0], b.data[2],
        c.data[0], c.data[1]
    );
    hmma16816(
        d.data[2], d.data[3],
        a.data[0], a.data[1], a.data[2], a.data[3],
        b.data[1], b.data[3],
        c.data[2], c.data[3]
    );
}
#endif

/**
 * @brief Matrix multiply-accumulate operation.
 *
 * This function performs the matrix multiply-accumulate operation
 * using the `hmma16816` function.
 *
 * @tparam N The number of row tiles.
 * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
 * @tparam M The number of column tiles for the B matrix.
 * @param[out] d The output rt_hf<N, M, row_layout> accumulator.
 * @param[in] a The first input rt_hf<N, K, row_layout> matrix.
 * @param[in] b The second input rt_hf<K, M, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_hf<N, M, row_layout> accumulator matrix.
 */
template<ducks::rt::row_layout D, ducks::rt::row_layout A, ducks::rt::col_layout B, ducks::rt::row_layout C>
__device__ static inline void mma_AB(D &d,
                                     const A &a,
                                     const B &b,
                                     const C &c) {
    KITTENS_CHECK_WARP
    static_assert(D::rows == A::rows && D::cols == B::cols); // Check D matches A, B
    static_assert(A::cols == B::rows); // Check reduction dim is same
    static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C
    #ifdef KITTENS_HOPPER
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
         std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
    );
    #else
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
    );
    #endif
    #pragma unroll
    for(int n = 0; n < D::height; n++) {
        #pragma unroll
        for(int m = 0; m < D::width; m++) {
            mma_AB_base(
                d.tiles[n][m],
                a.tiles[n][0],
                b.tiles[0][m],
                c.tiles[n][m]
            );
            #pragma unroll
            for(int k = 1; k < A::width; k++) {
                mma_AB_base(
                    d.tiles[n][m],
                    a.tiles[n][k],
                    b.tiles[k][m],
                    d.tiles[n][m]
                );
            }
        }
    }
}
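// Hedged usage sketch (illustrative names, assuming the usual rt_bf/rt_fl
// register-tile aliases): a 16x32 @ 32x16 bf16 matmul with fp32 accumulation,
// entirely within one warp's registers.
//
//   rt_bf<16, 32> a;                          // row-major A
//   rt_bf<32, 16, ducks::rt_layout::col> b;   // col-major B
//   rt_fl<16, 16> c, d;
//   zero(c);                                  // assumed zero-fill helper
//   mma_AB(d, a, b, c);                       // d = a @ b + c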
/**
 * @brief Dot product operation for row layout.
 *
 * This function performs the dot product operation
 * using the `hmma16816` function.
 *
 * @tparam N The number of row tiles.
 * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
 * @tparam M The number of column tiles for the B matrix.
 * @param[out] d The output rt_fl<N, M, row_layout> accumulator.
 * @param[in] a The first input rt_bf<N, K, row_layout> matrix.
 * @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
 */
template<ducks::rt::row_layout D, ducks::rt::row_layout A, ducks::rt::row_layout B, ducks::rt::row_layout C>
__device__ static inline void mma_ABt(D &d,
                                      const A &a,
                                      const B &b, // notice row and (M, K) instead of col and (K, M)
                                      const C &c) {
    KITTENS_CHECK_WARP
    static_assert(D::rows == A::rows && D::cols == B::rows); // Check D matches A, B
    static_assert(A::cols == B::cols); // Check reduction dim is same
    static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C
    #ifdef KITTENS_HOPPER
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
         std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
    );
    #else
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
    );
    #endif
    #pragma unroll
    for(int n = 0; n < D::height; n++) {
        #pragma unroll
        for(int m = 0; m < D::width; m++) {
            mma_ABt_base(
                d.tiles[n][m],
                a.tiles[n][0],
                b.tiles[m][0],
                c.tiles[n][m]
            );
            #pragma unroll
            for(int k = 1; k < A::width; k++) {
                mma_ABt_base(
                    d.tiles[n][m],
                    a.tiles[n][k],
                    b.tiles[m][k],
                    d.tiles[n][m]
                );
            }
        }
    }
}
/**
 * @brief Matrix multiply-accumulate operation with transposed A.
 *
 * This function performs the matrix multiply-accumulate operation
 * using the `hmma16816` instruction.
 *
 * @tparam N The number of row tiles.
 * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
 * @tparam M The number of column tiles for the B matrix.
 * @param[out] d The output rt_fl<N, M, row_layout> accumulator.
 * @param[in] a The first input rt_bf<K, N, col_layout> matrix.
 * @param[in] b The second input rt_bf<K, M, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
 */
template<ducks::rt::row_layout D, ducks::rt::col_layout A, ducks::rt::col_layout B, ducks::rt::row_layout C>
__device__ static inline void mma_AtB(D &d,
                                      const A &a,
                                      const B &b,
                                      const C &c) {
    KITTENS_CHECK_WARP
    static_assert(D::rows == A::cols && D::cols == B::cols); // Check D matches A, B
    static_assert(A::rows == B::rows); // Check reduction dim is same
    static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C
    #ifdef KITTENS_HOPPER
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
         std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
    );
    #else
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
    );
    #endif
    #pragma unroll
    for(int n = 0; n < D::height; n++) {
        #pragma unroll
        for(int m = 0; m < D::width; m++) {
            mma_AtB_base(
                d.tiles[n][m],
                a.tiles[0][n],
                b.tiles[0][m],
                c.tiles[n][m]
            );
            #pragma unroll
            for(int k = 1; k < A::height; k++) {
                mma_AtB_base(
                    d.tiles[n][m],
                    a.tiles[k][n],
                    b.tiles[k][m],
                    d.tiles[n][m]
                );
            }
        }
    }
}
/**
 * @brief Matrix multiply-accumulate operation with transposed A and B.
 *
 * This function performs the matrix multiply-accumulate operation
 * using the `hmma16816` instruction.
 *
 * @tparam N The number of row tiles.
 * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
 * @tparam M The number of column tiles for the B matrix.
 * @param[out] d The output rt_fl<N, M, row_layout> accumulator.
 * @param[in] a The first input rt_bf<K, N, col_layout> matrix.
 * @param[in] b The second input rt_bf<M, K, row_layout> matrix in row-major mode.
 * @param[in] c The input rt_fl<N, M, row_layout> accumulator matrix.
 */
template<ducks::rt::row_layout D, ducks::rt::col_layout A, ducks::rt::row_layout B, ducks::rt::row_layout C>
__device__ static inline void mma_AtBt(D &d,
                                       const A &a,
                                       const B &b,
                                       const C &c) {
    KITTENS_CHECK_WARP
    static_assert(D::rows == A::cols && D::cols == B::rows); // Check D matches A, B
    static_assert(A::rows == B::cols); // Check reduction dim is same
    static_assert(D::rows == C::rows && D::cols == C::cols); // Check D matches C
    #ifdef KITTENS_HOPPER
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, fp8e4m3> &&
         std::is_same_v<typename B::T, fp8e4m3> && std::is_same_v<typename C::T, float>)
    );
    #else
    static_assert(
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, bf16> &&
         std::is_same_v<typename B::T, bf16> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, float> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, float>) ||
        (std::is_same_v<typename D::T, half> && std::is_same_v<typename A::T, half> &&
         std::is_same_v<typename B::T, half> && std::is_same_v<typename C::T, half>)
    );
    #endif
    #pragma unroll
    for(int n = 0; n < D::height; n++) {
        #pragma unroll
        for(int m = 0; m < D::width; m++) {
            mma_AtBt_base(
                d.tiles[n][m],
                a.tiles[0][n],
                b.tiles[m][0],
                c.tiles[n][m]
            );
            #pragma unroll
            for(int k = 1; k < A::height; k++) {
                mma_AtBt_base(
                    d.tiles[n][m],
                    a.tiles[k][n],
                    b.tiles[m][k],
                    d.tiles[n][m]
                );
            }
        }
    }
}

template<int trans_A, int trans_B, ducks::rt::all D, ducks::rt::all A, ducks::rt::all B, ducks::rt::all C>
__device__ static inline void mma(D &d,
                                  const A &a,
                                  const B &b,
                                  const C &c) {
    KITTENS_CHECK_WARP
    if constexpr(trans_A == transpose::T) {
        if constexpr(trans_B == transpose::T) {
            mma_AtBt(d, a, b, c);
        } else {
            mma_AtB(d, a, b, c);
        }
    } else {
        if constexpr(trans_B == transpose::T) {
            mma_ABt(d, a, b, c);
        } else {
            mma_AB(d, a, b, c);
        }
    }
}
template<int trans_A, int trans_B, ducks::rt::all A, ducks::rt::all B, ducks::rt::all C>
__device__ static inline C mma(const A &a,
                               const B &b,
                               const C &c) {
    KITTENS_CHECK_WARP
    C d;
    if constexpr(trans_A == transpose::T) {
        if constexpr(trans_B == transpose::T) {
            mma_AtBt(d, a, b, c);
        } else {
            mma_AtB(d, a, b, c);
        }
    } else {
        if constexpr(trans_B == transpose::T) {
            mma_ABt(d, a, b, c);
        } else {
            mma_AB(d, a, b, c);
        }
    }
    return d;
}

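// Note: the transpose flags resolve at compile time, so
//   mma<transpose::N, transpose::T>(d, a, b, c);
// lowers to exactly mma_ABt(d, a, b, c). The value-returning overload simply
// default-constructs its own output tile and returns it.
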
// --------------------------------------------------------------------------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------
// -------------------------------------------------- COMPLEX INPUTS --------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------

/**
|
||||
* @brief Matrix multiply-accumulate operation for complex tiles
|
||||
*
|
||||
* This function calls mma_AB with hf arguments
|
||||
*
|
||||
* @tparam N The number of row tiles.
|
||||
* @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
|
||||
* @tparam M The number of column tiles for the B matrix.
|
||||
* @param[out] d The output rt_cmplx_hf<N, M, row_layout> accumulator.
|
||||
* @param[in] a The first input rt_cmplx_hf<N, K, row_layout> matrix.
|
||||
* @param[in] b The second input rt_cmplx_hf<K, M, col_layout> matrix in column-major mode.
|
||||
* @param[in] c The input rt_cmplx_hf<N, M, row_layout> accumulator matrix.
|
||||
*/
|
||||
template<int N, int K, int M>
|
||||
__device__ static inline void mma_AB(crt_hf<N, M, ducks::rt_layout::row> &d,
|
||||
const crt_hf<N, K, ducks::rt_layout::row> &a,
|
||||
const crt_hf<K, M, ducks::rt_layout::col> &b,
|
||||
const crt_hf<N, M, ducks::rt_layout::row> &c) {
|
||||
KITTENS_CHECK_WARP
|
||||
|
||||
// Copy data from input accumulate register into output
|
||||
::kittens::group<1>::copy(d.real, c.real);
|
||||
::kittens::group<1>::copy(d.imag, c.imag);
|
||||
|
||||
// Negative on B matrix so we can use single accum register
|
||||
rt_hf<N, K, ducks::rt_layout::row> tmp;
|
||||
// Hex value for -1 in float16
|
||||
constexpr half factor = std::bit_cast<__half>(uint16_t(0xFB80));
|
||||
::kittens::group<1>::mul(tmp, a.imag, factor);
|
||||
mma_AB(d.real, a.real, b.real, d.real);
|
||||
mma_AB(d.real, tmp, b.imag, d.real);
|
||||
|
||||
mma_AB(d.imag, a.real, b.imag, d.imag);
|
||||
mma_AB(d.imag, a.imag, b.real, d.imag);
|
||||
}
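// Worked decomposition for the four calls above: with A = Ar + i*Ai, B = Br + i*Bi,
//   real(D) = Ar@Br - Ai@Bi + real(C)
//   imag(D) = Ar@Bi + Ai@Br + imag(C)
// The subtraction is realized by scaling Ai by -1 into `tmp`, so both real-part
// products can accumulate into the same register tile d.real.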
/**
 * @brief Matrix multiply-accumulate operation for complex tiles
 *
 * This function implements the complex multiply-accumulate with four bf16 mma_AB calls, accumulating in fp32.
 *
 * @tparam N The number of row tiles.
 * @tparam K The number of column tiles for the A matrix and row tiles for the B matrix.
 * @tparam M The number of column tiles for the B matrix.
 * @param[out] d The output rt_cmplx_fl<N, M, row_layout> accumulator.
 * @param[in] a The first input rt_cmplx_bf<N, K, row_layout> matrix.
 * @param[in] b The second input rt_cmplx_bf<K, M, col_layout> matrix in column-major mode.
 * @param[in] c The input rt_cmplx_fl<N, M, row_layout> accumulator matrix.
 */

template<int N, int K, int M>
__device__ static inline void mma_AB(crt_fl<N, M, ducks::rt_layout::row> &d,
                                     const crt_bf<N, K, ducks::rt_layout::row> &a,
                                     const crt_bf<K, M, ducks::rt_layout::col> &b,
                                     const crt_fl<N, M, ducks::rt_layout::row> &c) {
    KITTENS_CHECK_WARP

    // Copy data from the input accumulator registers into the output
    ::kittens::group<1>::copy(d.real, c.real);
    ::kittens::group<1>::copy(d.imag, c.imag);

    // Negate the imaginary part of A so the real-part accumulator can subtract
    // a.imag @ b.imag in place
    kittens::rt_bf<N, K, ducks::rt_layout::row> tmp;
    // Bit pattern of -1.0 in bfloat16
    constexpr bf16 factor = std::bit_cast<__nv_bfloat16>(uint16_t(0xBF80));
    ::kittens::group<1>::mul(tmp, a.imag, factor);
    mma_AB(d.real, a.real, b.real, d.real);
    mma_AB(d.real, tmp, b.imag, d.real);

    mma_AB(d.imag, a.real, b.imag, d.imag);
    mma_AB(d.imag, a.imag, b.real, d.imag);
}
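// Design note: this variant takes bf16 inputs but accumulates in fp32 (crt_fl).
// bf16 keeps only 7 mantissa bits (vs 10 for fp16), so accumulating the K-dimension
// sum in bf16 would lose precision quickly; fp32 accumulation is the safe default.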
@@ -0,0 +1,334 @@
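// The `base` specializations below wrap Hopper's wgmma.mma_async PTX instruction for
// a fixed N (tile width). rt_st takes A from registers and B via a shared-memory
// matrix descriptor; st_st takes both A and B as descriptors. The predicate `p`
// forwards the runtime scale_d flag (0 overwrites the accumulator, 1 accumulates),
// and scale_b == -1 negates B at no extra cost.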
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 112, trans_a, trans_b> {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, 112, ducks::rt_layout::row> &dst,
        const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %61, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \
                "{%56, %57, %58, %59}, " \
                "%60, " \
                "p, 1, %63, %62;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %61, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \
                "{%56, %57, %58, %59}, " \
                "%60, " \
                "p, 1, %63, %62;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %33, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27}, " \
                "{%28, %29, %30, %31}, " \
                "%32, " \
                "p, 1, %35, %34;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[3])

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
    }
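    // st_st below is the shared-memory analogue of rt_st: instead of sourcing A from
    // registers ({%56..%59} above), both A and B arrive as 64-bit shared-memory matrix
    // descriptors, which shifts the operand numbering and adds an imm-trans-a flag.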
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 112, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %58, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \
                "%56, " \
                "%57, " \
                "p, 1, %61, %59, %60;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %58, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \
                "%56, " \
                "%57, " \
                "p, 1, %61, %59, %60;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %30, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27}, " \
                "%28, " \
                "%29, " \
                "p, 1, %33, %31, %32;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[3])

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
    }
};
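// Hypothetical call site (names assumed for illustration, not defined in this file):
// a warpgroup-level matmul would build the shared-memory descriptors elsewhere in the
// library and then, per 16-row slice of A, issue something like
//   base<float, bf16, 112, 0, 1>::rt_st(d_reg, a_reg.tiles[0][k], b_desc_for_k);
// between the appropriate wgmma fence/commit/wait barriers.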
@@ -0,0 +1,813 @@
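// The n=128 specialization mirrors n=112 but adds fp8 (e4m3/e5m2) input support.
// Two fp8 quirks are visible below: k doubles from 16 to 32 (m64n128k32), and PTX
// exposes no transpose immediates for fp8 wgmma, hence the commented-out
// "n"(trans_a) / "n"(trans_b) operands in those branches.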
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 128, trans_a, trans_b> {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, 128, ducks::rt_layout::row> &dst,
        const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %69, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "{%64, %65, %66, %67}, " \
                "%68, " \
                "p, 1, %71, %70;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %69, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "{%64, %65, %66, %67}, " \
                "%68, " \
                "p, 1, %71, %70;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %37, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "{%32, %33, %34, %35}, " \
                "%36, " \
                "p, 1, %39, %38;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %69, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "{%64, %65, %66, %67}, " \
                "%68, " \
                "p, 1, %70;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %69, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "{%64, %65, %66, %67}, " \
                "%68, " \
                "p, 1, %70;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %37, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "{%32, %33, %34, %35}, " \
                "%36, " \
                "p, 1, %38;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc),
                  "r"(scale_d),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %37, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "{%32, %33, %34, %35}, " \
                "%36, " \
                "p, 1, %38;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc),
                  "r"(scale_d),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }

    }
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 128, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %66, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "%64, " \
                "%65, " \
                "p, 1, %69, %67, %68;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %66, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "%64, " \
                "%65, " \
                "p, 1, %69, %67, %68;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %34, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "%32, " \
                "%33, " \
                "p, 1, %37, %35, %36;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][7].data[3])

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %66, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "%64, " \
                "%65, " \
                "p, 1, %67;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  // "n"(trans_a),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %66, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63}, " \
                "%64, " \
                "%65, " \
                "p, 1, %67;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                  "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                  "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                  "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                  "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                  "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                  "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                  "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                  "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                  "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                  "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                  "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                  "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                  "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                  "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                  "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                  "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y),
                  "+f"(dst.tiles[0][6].data[0].x), "+f"(dst.tiles[0][6].data[0].y),
                  "+f"(dst.tiles[0][6].data[1].x), "+f"(dst.tiles[0][6].data[1].y),
                  "+f"(dst.tiles[0][6].data[2].x), "+f"(dst.tiles[0][6].data[2].y),
                  "+f"(dst.tiles[0][6].data[3].x), "+f"(dst.tiles[0][6].data[3].y),
                  "+f"(dst.tiles[0][7].data[0].x), "+f"(dst.tiles[0][7].data[0].y),
                  "+f"(dst.tiles[0][7].data[1].x), "+f"(dst.tiles[0][7].data[1].y),
                  "+f"(dst.tiles[0][7].data[2].x), "+f"(dst.tiles[0][7].data[2].y),
                  "+f"(dst.tiles[0][7].data[3].x), "+f"(dst.tiles[0][7].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  // "n"(trans_a),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %34, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "%32, " \
                "%33, " \
                "p, 1, %35;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[3])
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %34, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
|
||||
"%32, " \
|
||||
"%33, " \
|
||||
"p, 1, %35;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][6].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][7].data[3])
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
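Each of these `base<T_D, T_AB, N, trans_a, trans_b>` specializations is a thin wrapper that pins the accumulator tile to inline-asm operands and issues a single `wgmma.mma_async` of output width N. The instruction is asynchronous, so a caller has to bracket it with the PTX wgmma fence/commit/wait protocol. Below is a minimal sketch of driving the n=128 descriptor-descriptor path, assuming the kittens types above and two already-built shared-memory matrix descriptors; how `a_desc`/`b_desc` are constructed is library-specific and not shown here.

```cuda
// Hedged sketch (not the library's own API): one D += A @ B step through
// the m64n128 fp8e4m3 st_st path defined above.
__device__ inline void wgmma_128_step(rt<float, 16, 128, ducks::rt_layout::row> &dst,
                                      uint64_t a_desc, uint64_t b_desc) {
    asm volatile("wgmma.fence.sync.aligned;\n");        // order prior accumulator writes
    base<float, fp8e4m3, 128, 0, 0>::st_st<1>(dst, a_desc, b_desc, /*scale_d=*/1);
    asm volatile("wgmma.commit_group.sync.aligned;\n"); // close the async group
    asm volatile("wgmma.wait_group.sync.aligned 0;\n"); // wait before reading dst
}
```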
@@ -0,0 +1,382 @@
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 144, trans_a, trans_b> {
template<int scale_b=1> __device__ static inline void rt_st(
rt<T_D, 16, 144, ducks::rt_layout::row> &dst,
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
const uint64_t b_st_desc,
int scale_d = 1
) {
static_assert(
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
// ----- BF16,BF16 -> FP32 ----- //
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %77, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \
"{%72, %73, %74, %75}, " \
"%76, " \
"p, 1, %79, %78;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %77, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \
"{%72, %73, %74, %75}, " \
"%76, " \
"p, 1, %79, %78;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %41, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35}, " \
"{%36, %37, %38, %39}, " \
"%40, " \
"p, 1, %43, %42;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3])

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
}
template<int scale_b=1> __device__ static inline void st_st(
rt<T_D, 16, 144, ducks::rt_layout::row> &dst,
const uint64_t a_st_desc,
const uint64_t b_st_desc,
int scale_d = 1
) {
static_assert(
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
// ----- BF16,BF16 -> FP32 ----- //
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %74, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \
"%72, " \
"%73, " \
"p, 1, %77, %75, %76;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y)

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
// ----- FP16,FP16 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %74, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71}, " \
"%72, " \
"%73, " \
"p, 1, %77, %75, %76;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y)

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %38, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35}, " \
"%36, " \
"%37, " \
"p, 1, %41, %39, %40;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3])

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
}
};
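The two entry points differ only in where the A operand comes from: `rt_st` feeds A as four packed 32-bit registers (the `{%72, %73, %74, %75}` bundle above), while `st_st` passes a second shared-memory descriptor. A hedged sketch of driving the register-A path over the K dimension is below; `next_b_desc` is a hypothetical stand-in, since how the B descriptor is re-pointed at the next 16-deep slice is library-specific.

```cuda
// Hypothetical descriptor step; real code would rebuild or offset the
// shared-memory matrix descriptor for the next K slice.
__device__ inline uint64_t next_b_desc(uint64_t d) { return d; }

// Hedged sketch: acc += A[k] @ B[k] over k via the n=144 register-A path.
__device__ inline void wgmma_144_k_loop(rt<float, 16, 144, ducks::rt_layout::row> &acc,
                                        const rt_base<bf16, ducks::rt_layout::row> *a_tiles,
                                        uint64_t b_desc, int k_tiles) {
    for (int k = 0; k < k_tiles; ++k) {
        // scale_d=1 keeps the predicate true, so each call accumulates into acc.
        base<float, bf16, 144, 0, 0>::rt_st<1>(acc, a_tiles[k], b_desc, 1);
        b_desc = next_b_desc(b_desc);
    }
}
```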
@@ -0,0 +1,190 @@
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 16, trans_a, trans_b> {
template<int scale_b=1> __device__ static inline void rt_st(
rt<T_D, 16, 16, ducks::rt_layout::row> &dst,
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
const uint64_t b_st_desc,
int scale_d = 1
) {
static_assert(
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
// ----- BF16,BF16 -> FP32 ----- //
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %13, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7}, " \
"{%8, %9, %10, %11}, " \
"%12, " \
"p, 1, %15, %14;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %13, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7}, " \
"{%8, %9, %10, %11}, " \
"%12, " \
"p, 1, %15, %14;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %9, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " \
"{%0, %1, %2, %3}, " \
"{%4, %5, %6, %7}, " \
"%8, " \
"p, 1, %11, %10;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3])

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
}
template<int scale_b=1> __device__ static inline void st_st(
rt<T_D, 16, 16, ducks::rt_layout::row> &dst,
const uint64_t a_st_desc,
const uint64_t b_st_desc,
int scale_d = 1
) {
static_assert(
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
// ----- BF16,BF16 -> FP32 ----- //
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %10, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7}, " \
"%8, " \
"%9, " \
"p, 1, %13, %11, %12;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y)

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
// ----- FP16,FP16 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %10, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7}, " \
"%8, " \
"%9, " \
"p, 1, %13, %11, %12;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y)

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %6, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " \
"{%0, %1, %2, %3}, " \
"%4, " \
"%5, " \
"p, 1, %9, %7, %8;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3])

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
}
};
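One detail worth calling out: `scale_d` is the only runtime operand in all of these blocks. The `setp.ne.b32 p, %N, 0` line converts it to a predicate, so passing 0 makes `wgmma` overwrite D while any nonzero value accumulates into it; scale-a, scale-b, and the transposes are all `"n"` compile-time immediates. A hedged sketch of the resulting overwrite-then-accumulate idiom, using the n=16 FP16 path above:

```cuda
// Hedged sketch: the first call zero-initializes D via scale_d=0,
// the second accumulates via scale_d=1.
__device__ inline void wgmma_16_two_steps(rt<half, 16, 16, ducks::rt_layout::row> &acc,
                                          const rt_base<half, ducks::rt_layout::row> &a0,
                                          const rt_base<half, ducks::rt_layout::row> &a1,
                                          uint64_t b_desc) {
    base<half, half, 16, 0, 0>::rt_st<1>(acc, a0, b_desc, /*scale_d=*/0); // D  = a0 @ B
    base<half, half, 16, 0, 0>::rt_st<1>(acc, a1, b_desc, /*scale_d=*/1); // D += a1 @ B
}
```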
@@ -0,0 +1,666 @@
|
||||
template<typename T_D, typename T_AB, int trans_a, int trans_b>
|
||||
struct base<T_D, T_AB, 160, trans_a, trans_b> {
|
||||
template<int scale_b=1> __device__ static inline void rt_st(
|
||||
rt<T_D, 16, 160, ducks::rt_layout::row> &dst,
|
||||
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
|
||||
// ----- BF16,BF16 -> FP32 ----- //
|
||||
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %85, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
|
||||
"{%80, %81, %82, %83}, " \
|
||||
"%84, " \
|
||||
"p, 1, %87, %86;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %85, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
|
||||
"{%80, %81, %82, %83}, " \
|
||||
"%84, " \
|
||||
"p, 1, %87, %86;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %45, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \
|
||||
"{%40, %41, %42, %43}, " \
|
||||
"%44, " \
|
||||
"p, 1, %47, %46;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3])
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %85, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
|
||||
"{%80, %81, %82, %83}, " \
|
||||
"%84, " \
|
||||
"p, 1, %86;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc),
|
||||
"r"(scale_d),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %85, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
|
||||
"{%80, %81, %82, %83}, " \
|
||||
"%84, " \
|
||||
"p, 1, %86;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc),
|
||||
"r"(scale_d),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
}
|
||||
template<int scale_b=1> __device__ static inline void st_st(
|
||||
rt<T_D, 16, 160, ducks::rt_layout::row> &dst,
|
||||
const uint64_t a_st_desc,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %82, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
                "%80, " \
                "%81, " \
                "p, 1, %85, %83, %84;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %82, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
                "%80, " \
                "%81, " \
                "p, 1, %85, %83, %84;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
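        // f16 accumulators pack two halves per 32-bit register, so this variant
        // has 40 "+r" outputs (%0-%39) reinterpreted through uint32_t, with the
        // descriptors at %40/%41 and scale_d at %42.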
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %42, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \
                "%40, " \
                "%41, " \
                "p, 1, %45, %43, %44;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3])

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8E4M3,FP8E4M3 -> FP32 ----- //
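        // The fp8 paths use k=32 (m64n160k32) and pass no transpose immediates,
        // since wgmma expects K-major fp8 operands; only scale_b (%83) trails the
        // predicate and the imm-scale-a constant.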
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %82, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
                "%80, " \
                "%81, " \
                "p, 1, %83;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no transpose immediates for fp8)

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                // "n"(trans_a),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8E5M2,FP8E5M2 -> FP32 ----- //
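        // Identical operand layout to the e4m3 path above; only the element type
        // in the wgmma opcode changes.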
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %82, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79}, " \
                "%80, " \
                "%81, " \
                "p, 1, %83;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no transpose immediates for fp8)

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                // "n"(trans_a),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
    }
};
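
// Usage sketch (hypothetical; a_desc/b_desc stand for shared-memory matrix
// descriptors built elsewhere in this library):
//   rt<float, 16, 160, ducks::rt_layout::row> acc;
//   base<float, bf16, 160, 0, 0>::st_st(acc, a_desc, b_desc, 0); // acc  = A @ B
//   base<float, bf16, 160, 0, 0>::st_st(acc, a_desc, b_desc, 1); // acc += A @ B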
@@ -0,0 +1,430 @@
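
// N=176: 11 column tiles of 16, i.e. 88 f32 accumulator registers (%0-%87) or
// 44 packed f16 registers; rt_st takes the A fragment at {%88-%91} and the B
// descriptor at %92.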
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 176, trans_a, trans_b> {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, 176, ducks::rt_layout::row> &dst,
        const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half>  && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %93, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \
                "{%88, %89, %90, %91}, " \
                "%92, " \
                "p, 1, %95, %94;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
                "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
                "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
                "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
                "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %93, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \
                "{%88, %89, %90, %91}, " \
                "%92, " \
                "p, 1, %95, %94;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
                "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
                "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
                "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
                "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %49, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43}, " \
                "{%44, %45, %46, %47}, " \
                "%48, " \
                "p, 1, %51, %50;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[3])

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
    }
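
    // st_st for N=176: same accumulator layout as rt_st above, but A is a
    // shared-memory descriptor (%88) alongside B (%89) instead of a register
    // fragment.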
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 176, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half>  && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %90, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \
                "%88, " \
                "%89, " \
                "p, 1, %93, %91, %92;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
                "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
                "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
                "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
                "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %90, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87}, " \
                "%88, " \
                "%89, " \
                "p, 1, %93, %91, %92;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
                "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
                "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
                "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
                "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %46, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43}, " \
                "%44, " \
                "%45, " \
                "p, 1, %49, %47, %48;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][10].data[3])

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
    }
};
@@ -0,0 +1,674 @@
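
// N=192: 12 column tiles of 16, i.e. 96 f32 accumulator registers (%0-%95) or
// 48 packed f16 registers; rt_st takes the A fragment at {%96-%99} and the B
// descriptor at %100.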
template<typename T_D, typename T_AB, int trans_a, int trans_b>
|
||||
struct base<T_D, T_AB, 192, trans_a, trans_b> {
|
||||
template<int scale_b=1> __device__ static inline void rt_st(
|
||||
rt<T_D, 16, 192, ducks::rt_layout::row> &dst,
|
||||
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
|
||||
// ----- BF16,BF16 -> FP32 ----- //
|
||||
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %101, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \
|
||||
"{%96, %97, %98, %99}, " \
|
||||
"%100, " \
|
||||
"p, 1, %103, %102;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %101, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \
|
||||
"{%96, %97, %98, %99}, " \
|
||||
"%100, " \
|
||||
"p, 1, %103, %102;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %53, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
"{%48, %49, %50, %51}, " \
"%52, " \
"p, 1, %55, %54;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[3])

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP8,FP8 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %101, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \
"{%96, %97, %98, %99}, " \
"%100, " \
"p, 1, %102;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc),
"r"(scale_d),
// "n"(trans_b),
"n"(scale_b)
);
}
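// Note: in the fp8 paths the trans immediate stays commented out and the
// instruction template ends at "p, 1, %102" (scale-d, imm-scale-a,
// imm-scale-b only). PTX wgmma only accepts transpose immediates for
// f16/bf16 inputs; e4m3/e5m2 operands must already be laid out K-major.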
// ----- FP8,FP8 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %101, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \
"{%96, %97, %98, %99}, " \
"%100, " \
"p, 1, %102;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc),
"r"(scale_d),
// "n"(trans_b),
"n"(scale_b)
);
}
}
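// st_st is the shared*shared variant: unlike rt_st above, the A operand is
// not held in registers but is described, like B, by a 64-bit shared-memory
// matrix descriptor, so the trans_a immediate becomes meaningful here.
// A minimal usage sketch (hypothetical names; the descriptors would come
// from whatever helper builds wgmma descriptors for tiles already staged in
// shared memory):
//
//   rt<float, 16, 192, ducks::rt_layout::row> acc;   // accumulator fragment
//   uint64_t a_desc = make_st_descriptor(a_smem);    // hypothetical helper
//   uint64_t b_desc = make_st_descriptor(b_smem);    // hypothetical helper
//   base<float, bf16, 0, 0>::st_st(acc, a_desc, b_desc, /*scale_d=*/1);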
template<int scale_b=1> __device__ static inline void st_st(
rt<T_D, 16, 192, ducks::rt_layout::row> &dst,
const uint64_t a_st_desc,
const uint64_t b_st_desc,
int scale_d = 1
) {
static_assert(
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
// ----- BF16,BF16 -> FP32 ----- //
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %98, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \
"%96, " \
"%97, " \
"p, 1, %101, %99, %100;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y)

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
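// Operand numbering for the m64n192 FP32 st_st cases: %0-%95 are the 96
// accumulator registers, %96 is the A descriptor, %97 the B descriptor,
// %98 feeds the scale-d predicate, %99/%100 are the trans_a/trans_b
// immediates, and %101 is the scale_b immediate (imm-scale-a is fixed to 1).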
// ----- FP16,FP16 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %98, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \
"%96, " \
"%97, " \
"p, 1, %101, %99, %100;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y)

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %50, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
"%48, " \
"%49, " \
"p, 1, %53, %51, %52;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[3])

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
"n"(trans_a),
"n"(trans_b),
"n"(scale_b)
);
}
// ----- FP8,FP8 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %98, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95}, " \
"%96, " \
"%97, " \
"p, 1, %99;\n" \
"}\n"
// a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y)

: "l"(a_st_desc),
"l"(b_st_desc),

"r"(scale_d),
// "n"(trans_a),
// "n"(trans_b),
"n"(scale_b)
);
}
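// Note: the static_assert above also admits T_AB == fp8e5m2, but no
// corresponding constexpr branch follows the e4m3 case, so that combination
// appears to compile to a silent no-op in this st_st overload.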
}
};
@@ -0,0 +1,478 @@
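// The N=208 specialization below follows the same pattern with 13 column
// tiles (208/16), i.e. 104 FP32 accumulator registers per wgmma; note that
// the fp8 input paths are not provided at this width.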
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 208, trans_a, trans_b> {
template<int scale_b=1> __device__ static inline void rt_st(
rt<T_D, 16, 208, ducks::rt_layout::row> &dst,
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
const uint64_t b_st_desc,
int scale_d = 1
) {
static_assert(
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
// ----- BF16,BF16 -> FP32 ----- //
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %109, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \
"{%104, %105, %106, %107}, " \
"%108, " \
"p, 1, %111, %110;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %109, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \
"{%104, %105, %106, %107}, " \
"%108, " \
"p, 1, %111, %110;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %57, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51}, " \
"{%52, %53, %54, %55}, " \
"%56, " \
"p, 1, %59, %58;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[3])

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
}
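// st_st for N=208: as in the N=192 case, both operands are shared-memory
// descriptors. With 104 accumulators the indices shift: %104/%105 are the
// A/B descriptors, %106 feeds scale-d, %107/%108 are trans_a/trans_b, and
// %109 is the scale_b immediate.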
template<int scale_b=1> __device__ static inline void st_st(
|
||||
rt<T_D, 16, 208, ducks::rt_layout::row> &dst,
|
||||
const uint64_t a_st_desc,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
|
||||
// ----- BF16,BF16 -> FP32 ----- //
|
||||
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %106, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \
|
||||
"%104, " \
|
||||
"%105, " \
|
||||
"p, 1, %109, %107, %108;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %106, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103}, " \
|
||||
"%104, " \
|
||||
"%105, " \
|
||||
"p, 1, %109, %107, %108;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %54, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51}, " \
|
||||
"%52, " \
|
||||
"%53, " \
|
||||
"p, 1, %57, %55, %56;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[3])
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
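#if 0 // Illustrative sketch only -- not part of the original ThunderKittens
      // header. Shows how a kernel might invoke the N=208 specialization
      // above once both operands are staged in shared memory and encoded as
      // the 64-bit wgmma descriptors the instruction consumes;
      // `a_desc`/`b_desc` construction is assumed to happen elsewhere.
__device__ inline void example_mma_208(rt<float, 16, 208, ducks::rt_layout::row> &acc,
                                       uint64_t a_desc, uint64_t b_desc) {
    // acc = A @ B + acc; trans_a = trans_b = 0, scale_d = 1 accumulates into acc.
    base<float, bf16, 208, 0, 0>::st_st(acc, a_desc, b_desc, /*scale_d=*/1);
}
#endif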
@@ -0,0 +1,826 @@
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 224, trans_a, trans_b> {
template<int scale_b=1> __device__ static inline void rt_st(
rt<T_D, 16, 224, ducks::rt_layout::row> &dst,
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
const uint64_t b_st_desc,
int scale_d = 1
) {
static_assert(
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
"Invalid type combination for WGMMA."
);
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
// ----- BF16,BF16 -> FP32 ----- //
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %117, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
"{%112, %113, %114, %115}, " \
"%116, " \
"p, 1, %119, %118;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %117, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
"{%112, %113, %114, %115}, " \
"%116, " \
"p, 1, %119, %118;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP16,FP16 -> FP16 ----- //
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %61, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \
"{%56, %57, %58, %59}, " \
"%60, " \
"p, 1, %63, %62;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][11].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][12].data[3]),
"+r"(*(uint32_t*)&dst.tiles[0][13].data[0]),
"+r"(*(uint32_t*)&dst.tiles[0][13].data[1]),
"+r"(*(uint32_t*)&dst.tiles[0][13].data[2]),
"+r"(*(uint32_t*)&dst.tiles[0][13].data[3])

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
);
}
// ----- FP8,FP8 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %117, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
"{%112, %113, %114, %115}, " \
"%116, " \
"p, 1, %118;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc),
"r"(scale_d),
// "n"(trans_b),
"n"(scale_b)
);
}
// ----- FP8,FP8 -> FP32 ----- //
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
asm volatile (
"{\n"
".reg .pred p;\n" \
"setp.ne.b32 p, %117, 0;\n" \
"wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " \
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
"{%112, %113, %114, %115}, " \
"%116, " \
"p, 1, %118;\n" \
"}\n"
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)

: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

"l"(b_st_desc),
"r"(scale_d),
// "n"(trans_b),
"n"(scale_b)
);
}
}
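// st_st below differs from rt_st above in where the A operand lives: instead
// of four "r" registers holding the A fragment, it takes a second 64-bit
// shared-memory wgmma descriptor (a_st_desc), which is why this variant also
// exposes an imm-trans-a immediate alongside imm-trans-b.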
template<int scale_b=1> __device__ static inline void st_st(
|
||||
rt<T_D, 16, 224, ducks::rt_layout::row> &dst,
|
||||
const uint64_t a_st_desc,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
|
||||
// ----- BF16,BF16 -> FP32 ----- //
|
||||
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %114, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
|
||||
"%112, " \
|
||||
"%113, " \
|
||||
"p, 1, %117, %115, %116;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
|
||||
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
|
||||
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
|
||||
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
|
||||
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %114, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
|
||||
"%112, " \
|
||||
"%113, " \
|
||||
"p, 1, %117, %115, %116;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
|
||||
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
|
||||
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
|
||||
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
|
||||
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %58, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55}, " \
|
||||
"%56, " \
|
||||
"%57, " \
|
||||
"p, 1, %61, %59, %60;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[3])
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %114, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
|
||||
"%112, " \
|
||||
"%113, " \
|
||||
"p, 1, %117, %115, %116;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
|
||||
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
|
||||
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
|
||||
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
|
||||
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %114, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111}, " \
|
||||
"%112, " \
|
||||
"%113, " \
|
||||
"p, 1, %117, %115, %116;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
|
||||
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
|
||||
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
|
||||
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
|
||||
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,526 @@
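// The N=240 specialization below follows the same pattern with 240/16 = 15
// accumulator tiles, i.e. 15 * 8 = 120 f32 output registers (%0..%119). Note
// that its static_assert admits only the bf16 and half input combinations:
// unlike N=224 above, there is no FP8 path at this width in this header.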
template<typename T_D, typename T_AB, int trans_a, int trans_b>
|
||||
struct base<T_D, T_AB, 240, trans_a, trans_b> {
|
||||
template<int scale_b=1> __device__ static inline void rt_st(
|
||||
rt<T_D, 16, 240, ducks::rt_layout::row> &dst,
|
||||
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
|
||||
// ----- BF16,BF16 -> FP32 ----- //
|
||||
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %125, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \
|
||||
"{%120, %121, %122, %123}, " \
|
||||
"%124, " \
|
||||
"p, 1, %127, %126;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
|
||||
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
|
||||
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
|
||||
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
|
||||
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y),
|
||||
"+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y),
|
||||
"+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y),
|
||||
"+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y),
|
||||
"+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %125, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \
|
||||
"{%120, %121, %122, %123}, " \
|
||||
"%124, " \
|
||||
"p, 1, %127, %126;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
|
||||
"+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
|
||||
"+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
|
||||
"+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
|
||||
"+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
|
||||
"+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
|
||||
"+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
|
||||
"+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
|
||||
"+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
|
||||
"+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
|
||||
"+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
|
||||
"+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
|
||||
"+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
|
||||
"+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
|
||||
"+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
|
||||
"+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
|
||||
"+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
|
||||
"+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
|
||||
"+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
|
||||
"+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
|
||||
"+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
|
||||
"+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
|
||||
"+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
|
||||
"+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
|
||||
"+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
|
||||
"+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
|
||||
"+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
|
||||
"+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
|
||||
"+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
|
||||
"+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
|
||||
"+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
|
||||
"+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
|
||||
"+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
|
||||
"+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
|
||||
"+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
|
||||
"+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
|
||||
"+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
|
||||
"+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
|
||||
"+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
|
||||
"+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
|
||||
"+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
|
||||
"+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
|
||||
"+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
|
||||
"+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
|
||||
"+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
|
||||
"+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
|
||||
"+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
|
||||
"+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
|
||||
"+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
|
||||
"+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
|
||||
"+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
|
||||
"+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
|
||||
"+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
|
||||
"+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
|
||||
"+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
|
||||
"+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y),
|
||||
"+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y),
|
||||
"+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y),
|
||||
"+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y),
|
||||
"+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %65, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59}, " \
|
||||
"{%60, %61, %62, %63}, " \
|
||||
"%64, " \
|
||||
"p, 1, %67, %66;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][11].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][12].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][13].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][14].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][14].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][14].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][14].data[3])
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
}
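    // note (annotation, not in the original source): st_st below is the
    // shared-memory x shared-memory variant -- instead of per-thread A
    // registers it takes a 64-bit matrix descriptor for each of A and B, and
    // the trans_a/trans_b template immediates are passed through to the
    // instruction's transpose operands.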
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 240, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %122, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \
                "%120, " \
                "%121, " \
                "p, 1, %125, %123, %124;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                  "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                  "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                  "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                  "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                  "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                  "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                  "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                  "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                  "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                  "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                  "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                  "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                  "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                  "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                  "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                  "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                  "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                  "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                  "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                  "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                  "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                  "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                  "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                  "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                  "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                  "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                  "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                  "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                  "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                  "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                  "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                  "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                  "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                  "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                  "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                  "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                  "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                  "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                  "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
                  "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
                  "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
                  "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
                  "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
                  "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
                  "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
                  "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
                  "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
                  "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
                  "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
                  "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
                  "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
                  "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
                  "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
                  "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
                  "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y),
                  "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y),
                  "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y),
                  "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y),
                  "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %122, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59, %60, %61, %62, %63, %64, %65, %66, %67, %68, %69, %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92, %93, %94, %95, %96, %97, %98, %99, %100, %101, %102, %103, %104, %105, %106, %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119}, " \
                "%120, " \
                "%121, " \
                "p, 1, %125, %123, %124;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][ 0].data[0].x), "+f"(dst.tiles[0][ 0].data[0].y),
                  "+f"(dst.tiles[0][ 0].data[1].x), "+f"(dst.tiles[0][ 0].data[1].y),
                  "+f"(dst.tiles[0][ 0].data[2].x), "+f"(dst.tiles[0][ 0].data[2].y),
                  "+f"(dst.tiles[0][ 0].data[3].x), "+f"(dst.tiles[0][ 0].data[3].y),
                  "+f"(dst.tiles[0][ 1].data[0].x), "+f"(dst.tiles[0][ 1].data[0].y),
                  "+f"(dst.tiles[0][ 1].data[1].x), "+f"(dst.tiles[0][ 1].data[1].y),
                  "+f"(dst.tiles[0][ 1].data[2].x), "+f"(dst.tiles[0][ 1].data[2].y),
                  "+f"(dst.tiles[0][ 1].data[3].x), "+f"(dst.tiles[0][ 1].data[3].y),
                  "+f"(dst.tiles[0][ 2].data[0].x), "+f"(dst.tiles[0][ 2].data[0].y),
                  "+f"(dst.tiles[0][ 2].data[1].x), "+f"(dst.tiles[0][ 2].data[1].y),
                  "+f"(dst.tiles[0][ 2].data[2].x), "+f"(dst.tiles[0][ 2].data[2].y),
                  "+f"(dst.tiles[0][ 2].data[3].x), "+f"(dst.tiles[0][ 2].data[3].y),
                  "+f"(dst.tiles[0][ 3].data[0].x), "+f"(dst.tiles[0][ 3].data[0].y),
                  "+f"(dst.tiles[0][ 3].data[1].x), "+f"(dst.tiles[0][ 3].data[1].y),
                  "+f"(dst.tiles[0][ 3].data[2].x), "+f"(dst.tiles[0][ 3].data[2].y),
                  "+f"(dst.tiles[0][ 3].data[3].x), "+f"(dst.tiles[0][ 3].data[3].y),
                  "+f"(dst.tiles[0][ 4].data[0].x), "+f"(dst.tiles[0][ 4].data[0].y),
                  "+f"(dst.tiles[0][ 4].data[1].x), "+f"(dst.tiles[0][ 4].data[1].y),
                  "+f"(dst.tiles[0][ 4].data[2].x), "+f"(dst.tiles[0][ 4].data[2].y),
                  "+f"(dst.tiles[0][ 4].data[3].x), "+f"(dst.tiles[0][ 4].data[3].y),
                  "+f"(dst.tiles[0][ 5].data[0].x), "+f"(dst.tiles[0][ 5].data[0].y),
                  "+f"(dst.tiles[0][ 5].data[1].x), "+f"(dst.tiles[0][ 5].data[1].y),
                  "+f"(dst.tiles[0][ 5].data[2].x), "+f"(dst.tiles[0][ 5].data[2].y),
                  "+f"(dst.tiles[0][ 5].data[3].x), "+f"(dst.tiles[0][ 5].data[3].y),
                  "+f"(dst.tiles[0][ 6].data[0].x), "+f"(dst.tiles[0][ 6].data[0].y),
                  "+f"(dst.tiles[0][ 6].data[1].x), "+f"(dst.tiles[0][ 6].data[1].y),
                  "+f"(dst.tiles[0][ 6].data[2].x), "+f"(dst.tiles[0][ 6].data[2].y),
                  "+f"(dst.tiles[0][ 6].data[3].x), "+f"(dst.tiles[0][ 6].data[3].y),
                  "+f"(dst.tiles[0][ 7].data[0].x), "+f"(dst.tiles[0][ 7].data[0].y),
                  "+f"(dst.tiles[0][ 7].data[1].x), "+f"(dst.tiles[0][ 7].data[1].y),
                  "+f"(dst.tiles[0][ 7].data[2].x), "+f"(dst.tiles[0][ 7].data[2].y),
                  "+f"(dst.tiles[0][ 7].data[3].x), "+f"(dst.tiles[0][ 7].data[3].y),
                  "+f"(dst.tiles[0][ 8].data[0].x), "+f"(dst.tiles[0][ 8].data[0].y),
                  "+f"(dst.tiles[0][ 8].data[1].x), "+f"(dst.tiles[0][ 8].data[1].y),
                  "+f"(dst.tiles[0][ 8].data[2].x), "+f"(dst.tiles[0][ 8].data[2].y),
                  "+f"(dst.tiles[0][ 8].data[3].x), "+f"(dst.tiles[0][ 8].data[3].y),
                  "+f"(dst.tiles[0][ 9].data[0].x), "+f"(dst.tiles[0][ 9].data[0].y),
                  "+f"(dst.tiles[0][ 9].data[1].x), "+f"(dst.tiles[0][ 9].data[1].y),
                  "+f"(dst.tiles[0][ 9].data[2].x), "+f"(dst.tiles[0][ 9].data[2].y),
                  "+f"(dst.tiles[0][ 9].data[3].x), "+f"(dst.tiles[0][ 9].data[3].y),
                  "+f"(dst.tiles[0][10].data[0].x), "+f"(dst.tiles[0][10].data[0].y),
                  "+f"(dst.tiles[0][10].data[1].x), "+f"(dst.tiles[0][10].data[1].y),
                  "+f"(dst.tiles[0][10].data[2].x), "+f"(dst.tiles[0][10].data[2].y),
                  "+f"(dst.tiles[0][10].data[3].x), "+f"(dst.tiles[0][10].data[3].y),
                  "+f"(dst.tiles[0][11].data[0].x), "+f"(dst.tiles[0][11].data[0].y),
                  "+f"(dst.tiles[0][11].data[1].x), "+f"(dst.tiles[0][11].data[1].y),
                  "+f"(dst.tiles[0][11].data[2].x), "+f"(dst.tiles[0][11].data[2].y),
                  "+f"(dst.tiles[0][11].data[3].x), "+f"(dst.tiles[0][11].data[3].y),
                  "+f"(dst.tiles[0][12].data[0].x), "+f"(dst.tiles[0][12].data[0].y),
                  "+f"(dst.tiles[0][12].data[1].x), "+f"(dst.tiles[0][12].data[1].y),
                  "+f"(dst.tiles[0][12].data[2].x), "+f"(dst.tiles[0][12].data[2].y),
                  "+f"(dst.tiles[0][12].data[3].x), "+f"(dst.tiles[0][12].data[3].y),
                  "+f"(dst.tiles[0][13].data[0].x), "+f"(dst.tiles[0][13].data[0].y),
                  "+f"(dst.tiles[0][13].data[1].x), "+f"(dst.tiles[0][13].data[1].y),
                  "+f"(dst.tiles[0][13].data[2].x), "+f"(dst.tiles[0][13].data[2].y),
                  "+f"(dst.tiles[0][13].data[3].x), "+f"(dst.tiles[0][13].data[3].y),
                  "+f"(dst.tiles[0][14].data[0].x), "+f"(dst.tiles[0][14].data[0].y),
                  "+f"(dst.tiles[0][14].data[1].x), "+f"(dst.tiles[0][14].data[1].y),
                  "+f"(dst.tiles[0][14].data[2].x), "+f"(dst.tiles[0][14].data[2].y),
                  "+f"(dst.tiles[0][14].data[3].x), "+f"(dst.tiles[0][14].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %62, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47, %48, %49, %50, %51, %52, %53, %54, %55, %56, %57, %58, %59}, " \
                "%60, " \
                "%61, " \
                "p, 1, %65, %63, %64;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 1].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 2].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 3].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 4].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 5].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 6].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 7].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 8].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][ 9].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][10].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][10].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][10].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][10].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][11].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][11].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][11].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][11].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][12].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][12].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][12].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][12].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][13].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][13].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][13].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][13].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][14].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][14].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][14].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][14].data[3])

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
    }
};
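// note -- a minimal, hypothetical usage sketch (annotation, not part of this
// file; `acc`, `a_desc`, and `b_desc` are assumed to be set up elsewhere).
// Every wgmma.mma_async issued by these wrappers must be bracketed by the
// standard warpgroup fence/commit/wait sequence, e.g.:
//
//   kittens::rt<float, 16, 240, kittens::ducks::rt_layout::row> acc;
//   asm volatile("wgmma.fence.sync.aligned;\n");
//   base<float, bf16, 240, 0, 0>::st_st(acc, a_desc, b_desc, 0); // scale_d=0 overwrites acc
//   asm volatile("wgmma.commit_group.sync.aligned;\n");
//   asm volatile("wgmma.wait_group.sync.aligned 0;\n");
//
// The setp.ne.b32 on scale_d in each branch means scale_d != 0 accumulates
// into dst while scale_d == 0 overwrites it.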
1260
extra/thunder/cuda/include/ops/group/mma/warpgroup/base/64x256.impl
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,446 @@
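// note (annotation, not in the original source): the N=32 specialization
// additionally covers the FP8 (e4m3/e5m2) input paths. The FP8 variants use
// the m64n32k32 shape (K=32 rather than K=16, since four 8-bit values pack
// into each 32-bit register) and omit the transpose immediates, which PTX
// does not accept for FP8 WGMMA -- hence the commented-out "n"(trans_a) /
// "n"(trans_b) operands below.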
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 32, trans_a, trans_b> {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, 32, ducks::rt_layout::row> &dst,
        const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %21, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "{%16, %17, %18, %19}, " \
                "%20, " \
                "p, 1, %23, %22;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %21, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "{%16, %17, %18, %19}, " \
                "%20, " \
                "p, 1, %23, %22;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %13, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7}, " \
                "{%8, %9, %10, %11}, " \
                "%12, " \
                "p, 1, %15, %14;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3])

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %21, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "{%16, %17, %18, %19}, " \
                "%20, " \
                "p, 1, %22;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc),
                  "r"(scale_d),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %21, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "{%16, %17, %18, %19}, " \
                "%20, " \
                "p, 1, %22;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc),
                  "r"(scale_d),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %13, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7}, " \
                "{%8, %9, %10, %11}, " \
                "%12, " \
                "p, 1, %14;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3])

                : "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                  "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                  "l"(b_st_desc),
                  "r"(scale_d),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
    }
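    // note (annotation, not in the original source): the static_assert above
    // also permits T_D=half with T_AB=fp8e5m2, but rt_st has no matching
    // branch for that combination, so it silently emits no MMA; a mirror of
    // the e4m3 case using m64n32k32.f16.e5m2.e5m2 (as st_st below provides)
    // appears to be missing.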
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 32, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %18, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "%16, " \
                "%17, " \
                "p, 1, %21, %19, %20;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %18, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "%16, " \
                "%17, " \
                "p, 1, %21, %19, %20;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %10, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7}, " \
                "%8, " \
                "%9, " \
                "p, 1, %13, %11, %12;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3])

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  "n"(trans_a),
                  "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %18, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "%16, " \
                "%17, " \
                "p, 1, %19;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  // "n"(trans_a),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %18, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "%16, " \
                "%17, " \
                "p, 1, %19;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                  "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                  "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                  "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                  "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                  "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                  "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                  "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y)

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  // "n"(trans_a),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %10, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7}, " \
                "%8, " \
                "%9, " \
                "p, 1, %11;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3])

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  // "n"(trans_a),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %10, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7}, " \
                "%8, " \
                "%9, " \
                "p, 1, %11;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b

                : "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                  "+r"(*(uint32_t*)&dst.tiles[0][1].data[3])

                : "l"(a_st_desc),
                  "l"(b_st_desc),

                  "r"(scale_d),
                  // "n"(trans_a),
                  // "n"(trans_b),
                  "n"(scale_b)
            );
        }
    }
};
@@ -0,0 +1,238 @@
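// note (annotation, not in the original source): N=48 specialization --
// 64*48/128 = 24 FP32 accumulator values per thread (%0..%23), or 12
// packed-FP16 registers (%0..%11). Only the bf16/f16 input paths are
// provided at this width.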
|
||||
template<typename T_D, typename T_AB, int trans_a, int trans_b>
|
||||
struct base<T_D, T_AB, 48, trans_a, trans_b> {
|
||||
template<int scale_b=1> __device__ static inline void rt_st(
|
||||
rt<T_D, 16, 48, ducks::rt_layout::row> &dst,
|
||||
const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
|
||||
// ----- BF16,BF16 -> FP32 ----- //
|
||||
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %29, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
|
||||
"{%24, %25, %26, %27}, " \
|
||||
"%28, " \
|
||||
"p, 1, %31, %30;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
|
||||
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
|
||||
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
|
||||
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
|
||||
"+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
|
||||
"+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
|
||||
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
|
||||
"+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
|
||||
"+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
|
||||
"+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
|
||||
"+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
|
||||
"+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %29, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
|
||||
"{%24, %25, %26, %27}, " \
|
||||
"%28, " \
|
||||
"p, 1, %31, %30;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
|
||||
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
|
||||
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
|
||||
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
|
||||
"+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
|
||||
"+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
|
||||
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
|
||||
"+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
|
||||
"+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
|
||||
"+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
|
||||
"+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
|
||||
"+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y)
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %17, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11}, " \
|
||||
"{%12, %13, %14, %15}, " \
|
||||
"%16, " \
|
||||
"p, 1, %19, %18;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3])
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
|
||||
);
|
||||
}
|
||||
}
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 48, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %26, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
                "%24, " \
                "%25, " \
                "p, 1, %29, %27, %28;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %26, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
                "%24, " \
                "%25, " \
                "p, 1, %29, %27, %28;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %14, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11}, " \
                "%12, " \
                "%13, " \
                "p, 1, %17, %15, %16;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3])

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
    }
};
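// A minimal caller sketch for the n=48 specialization above (illustration only,
// not part of this diff; make_b_desc is a hypothetical stand-in for whatever
// builds the 64-bit shared-memory matrix descriptor). One m64n48k16 step of
// C += A @ B with A in registers and B in shared memory:
//
//     rt<float, 16, 48, ducks::rt_layout::row> acc;   // C accumulator tile
//     rt_base<bf16, ducks::rt_layout::row> a_frag;    // A fragment in registers
//     uint64_t b_desc = make_b_desc(/* B tile in shared memory */);
//     base<float, bf16, 48, 0, 0>::rt_st(acc, a_frag, b_desc, /*scale_d=*/1);
//
// scale_d=1 accumulates into acc while scale_d=0 overwrites it (that is what the
// setp/predicate pair encodes), and instantiating with scale_b=-1 negates B.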
@@ -0,0 +1,587 @@
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 64, trans_a, trans_b> {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, 64, ducks::rt_layout::row> &dst,
        const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %37, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "{%32, %33, %34, %35}, " \
                "%36, " \
                "p, 1, %39, %38;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %37, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "{%32, %33, %34, %35}, " \
                "%36, " \
                "p, 1, %39, %38;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %21, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "{%16, %17, %18, %19}, " \
                "%20, " \
                "p, 1, %23, %22;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3])

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %37, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "{%32, %33, %34, %35}, " \
                "%36, " \
                "p, 1, %38;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc),
                "r"(scale_d),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %37, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "{%32, %33, %34, %35}, " \
                "%36, " \
                "p, 1, %38;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc),
                "r"(scale_d),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %21, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "{%16, %17, %18, %19}, " \
                "%20, " \
                "p, 1, %22;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3])

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc),
                "r"(scale_d),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %21, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "{%16, %17, %18, %19}, " \
                "%20, " \
                "p, 1, %22;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3])

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
    }
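    // Operand numbering in the n=64 rt_st cases above, spelled out once: an FP32
    // accumulator takes n/2 = 32 "+f" operands (%0-%31), so the A fragment lands
    // at %32-%35, the B descriptor at %36, and scale_d at %37, which is why the
    // predicate reads "setp.ne.b32 p, %37, 0". The FP16 accumulator packs two
    // halves per 32-bit register and needs only n/4 = 16 "+r" operands, shifting
    // the same chain down so that scale_d becomes %21.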
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 64, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %34, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "%32, " \
                "%33, " \
                "p, 1, %37, %35, %36;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %34, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "%32, " \
                "%33, " \
                "p, 1, %37, %35, %36;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %18, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "%16, " \
                "%17, " \
                "p, 1, %21, %19, %20;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3])

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %34, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "%32, " \
                "%33, " \
                "p, 1, %35;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                // "n"(trans_a),
                // "n"(trans_b), // transpose is not supported for FP8
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %34, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31}, " \
                "%32, " \
                "%33, " \
                "p, 1, %35;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                // "n"(trans_a),
                // "n"(trans_b), // transpose is not supported for FP8
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %18, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "%16, " \
                "%17, " \
                "p, 1, %19;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3])

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                // "n"(trans_a),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %18, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15}, " \
                "%16, " \
                "%17, " \
                "p, 1, %19;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3])

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                // "n"(trans_a),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
    }
};
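// Note on the FP8 paths above: e4m3/e5m2 use k=32 (m64n64k32) instead of k=16,
// since 8-bit elements pack twice as densely along K, and the transpose
// immediates are omitted entirely; wgmma does not support transposed FP8
// operands, which is why the "n"(trans_a)/"n"(trans_b) constraints stay
// commented out in those cases.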
@@ -0,0 +1,286 @@
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 80, trans_a, trans_b> {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, 80, ducks::rt_layout::row> &dst,
        const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %45, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \
                "{%40, %41, %42, %43}, " \
                "%44, " \
                "p, 1, %47, %46;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %45, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \
                "{%40, %41, %42, %43}, " \
                "%44, " \
                "p, 1, %47, %46;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %25, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19}, " \
                "{%20, %21, %22, %23}, " \
                "%24, " \
                "p, 1, %27, %26;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[3])

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
    }
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, 80, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %42, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \
                "%40, " \
                "%41, " \
                "p, 1, %45, %43, %44;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %42, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39}, " \
                "%40, " \
                "%41, " \
                "p, 1, %45, %43, %44;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y)

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %22, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19}, " \
                "%20, " \
                "%21, " \
                "p, 1, %25, %23, %24;\n" \
                "}\n"
                // a_mat descriptor, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-a, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[3])

            :   "l"(a_st_desc),
                "l"(b_st_desc),

                "r"(scale_d),
                "n"(trans_a),
                "n"(trans_b),
                "n"(scale_b)
            );
        }
    }
};
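// The per-width pattern, for reference: an m64nN wgmma spreads the 64xN D tile
// across the 128-thread warpgroup as N/2 floats ("+f") or N/4 packed half2
// words ("+r") per thread. So n=80 binds %0-%39 / %0-%19, and n=96 below binds
// %0-%47 / %0-%23, with every operand after the accumulator list shifting to
// match.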
@@ -0,0 +1,703 @@
template<typename T_D, typename T_AB, int trans_a, int trans_b>
struct base<T_D, T_AB, 96, trans_a, trans_b> {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, 96, ducks::rt_layout::row> &dst,
        const rt_base<T_AB, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    ) {
        static_assert(
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
            (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
            "Invalid type combination for WGMMA."
        );
        static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
        // ----- BF16,BF16 -> FP32 ----- //
        if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %53, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
                "{%48, %49, %50, %51}, " \
                "%52, " \
                "p, 1, %55, %54;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %53, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
                "{%48, %49, %50, %51}, " \
                "%52, " \
                "p, 1, %55, %54;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP16,FP16 -> FP16 ----- //
        else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %29, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
                "{%24, %25, %26, %27}, " \
                "%28, " \
                "p, 1, %31, %30;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, imm-trans-b

            :   "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
                "+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
                "+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
                "+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
                "+r"(*(uint32_t*)&dst.tiles[0][5].data[3])

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc), "r"(scale_d), "n"(trans_b), "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %53, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
                "{%48, %49, %50, %51}, " \
                "%52, " \
                "p, 1, %54;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc),
                "r"(scale_d),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
        // ----- FP8,FP8 -> FP32 ----- //
        else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
            asm volatile (
                "{\n"
                ".reg .pred p;\n" \
                "setp.ne.b32 p, %53, 0;\n" \
                "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " \
                "{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
                "{%48, %49, %50, %51}, " \
                "%52, " \
                "p, 1, %54;\n" \
                "}\n"
                // a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b (no trans for FP8)

            :   "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
                "+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
                "+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
                "+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
                "+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
                "+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
                "+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
                "+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
                "+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
                "+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
                "+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
                "+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
                "+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
                "+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
                "+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
                "+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
                "+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
                "+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
                "+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
                "+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
                "+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
                "+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
                "+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
                "+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)

            :   "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
                "r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),

                "l"(b_st_desc),
                "r"(scale_d),
                // "n"(trans_b),
                "n"(scale_b)
            );
        }
|
||||
// ----- FP8,FP8 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %29, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
|
||||
"{%24, %25, %26, %27}, " \
|
||||
"%28, " \
|
||||
"p, 1, %30;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3])
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc),
|
||||
"r"(scale_d),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %29, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
|
||||
"{%24, %25, %26, %27}, " \
|
||||
"%28, " \
|
||||
"p, 1, %30;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3])
|
||||
|
||||
: "r"(*(uint32_t*)&a_rt.data[0]), "r"(*(uint32_t*)&a_rt.data[1]),
|
||||
"r"(*(uint32_t*)&a_rt.data[2]), "r"(*(uint32_t*)&a_rt.data[3]),
|
||||
|
||||
"l"(b_st_desc),
|
||||
"r"(scale_d),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
template<int scale_b=1> __device__ static inline void st_st(
|
||||
rt<T_D, 16, 96, ducks::rt_layout::row> &dst,
|
||||
const uint64_t a_st_desc,
|
||||
const uint64_t b_st_desc,
|
||||
int scale_d = 1
|
||||
) {
|
||||
static_assert(
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) ||
|
||||
(std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) ||
|
||||
(std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>),
|
||||
"Invalid type combination for WGMMA."
|
||||
);
|
||||
static_assert(scale_b==1 || scale_b==-1, "Invalid scale B (invert) option");
|
||||
// ----- BF16,BF16 -> FP32 ----- //
|
||||
if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, bf16>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %50, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
|
||||
"%48, " \
|
||||
"%49, " \
|
||||
"p, 1, %53, %51, %52;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
|
||||
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
|
||||
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
|
||||
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
|
||||
"+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
|
||||
"+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
|
||||
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
|
||||
"+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
|
||||
"+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
|
||||
"+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
|
||||
"+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
|
||||
"+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
|
||||
"+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
|
||||
"+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
|
||||
"+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
|
||||
"+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
|
||||
"+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
|
||||
"+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
|
||||
"+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
|
||||
"+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
|
||||
"+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
|
||||
"+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
|
||||
"+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
|
||||
"+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %50, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
|
||||
"%48, " \
|
||||
"%49, " \
|
||||
"p, 1, %53, %51, %52;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
|
||||
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
|
||||
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
|
||||
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
|
||||
"+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
|
||||
"+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
|
||||
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
|
||||
"+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
|
||||
"+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
|
||||
"+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
|
||||
"+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
|
||||
"+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
|
||||
"+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
|
||||
"+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
|
||||
"+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
|
||||
"+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
|
||||
"+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
|
||||
"+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
|
||||
"+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
|
||||
"+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
|
||||
"+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
|
||||
"+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
|
||||
"+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
|
||||
"+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP16,FP16 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, half>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %26, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
|
||||
"%24, " \
|
||||
"%25, " \
|
||||
"p, 1, %29, %27, %28;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3])
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
"n"(trans_a),
|
||||
"n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e4m3>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %50, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
|
||||
"%48, " \
|
||||
"%49, " \
|
||||
"p, 1, %51;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
|
||||
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
|
||||
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
|
||||
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
|
||||
"+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
|
||||
"+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
|
||||
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
|
||||
"+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
|
||||
"+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
|
||||
"+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
|
||||
"+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
|
||||
"+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
|
||||
"+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
|
||||
"+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
|
||||
"+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
|
||||
"+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
|
||||
"+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
|
||||
"+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
|
||||
"+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
|
||||
"+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
|
||||
"+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
|
||||
"+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
|
||||
"+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
|
||||
"+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP32 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, float> && std::is_same_v<T_AB, fp8e5m2>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %50, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45, %46, %47}, " \
|
||||
"%48, " \
|
||||
"%49, " \
|
||||
"p, 1, %51;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+f"(dst.tiles[0][0].data[0].x), "+f"(dst.tiles[0][0].data[0].y),
|
||||
"+f"(dst.tiles[0][0].data[1].x), "+f"(dst.tiles[0][0].data[1].y),
|
||||
"+f"(dst.tiles[0][0].data[2].x), "+f"(dst.tiles[0][0].data[2].y),
|
||||
"+f"(dst.tiles[0][0].data[3].x), "+f"(dst.tiles[0][0].data[3].y),
|
||||
"+f"(dst.tiles[0][1].data[0].x), "+f"(dst.tiles[0][1].data[0].y),
|
||||
"+f"(dst.tiles[0][1].data[1].x), "+f"(dst.tiles[0][1].data[1].y),
|
||||
"+f"(dst.tiles[0][1].data[2].x), "+f"(dst.tiles[0][1].data[2].y),
|
||||
"+f"(dst.tiles[0][1].data[3].x), "+f"(dst.tiles[0][1].data[3].y),
|
||||
"+f"(dst.tiles[0][2].data[0].x), "+f"(dst.tiles[0][2].data[0].y),
|
||||
"+f"(dst.tiles[0][2].data[1].x), "+f"(dst.tiles[0][2].data[1].y),
|
||||
"+f"(dst.tiles[0][2].data[2].x), "+f"(dst.tiles[0][2].data[2].y),
|
||||
"+f"(dst.tiles[0][2].data[3].x), "+f"(dst.tiles[0][2].data[3].y),
|
||||
"+f"(dst.tiles[0][3].data[0].x), "+f"(dst.tiles[0][3].data[0].y),
|
||||
"+f"(dst.tiles[0][3].data[1].x), "+f"(dst.tiles[0][3].data[1].y),
|
||||
"+f"(dst.tiles[0][3].data[2].x), "+f"(dst.tiles[0][3].data[2].y),
|
||||
"+f"(dst.tiles[0][3].data[3].x), "+f"(dst.tiles[0][3].data[3].y),
|
||||
"+f"(dst.tiles[0][4].data[0].x), "+f"(dst.tiles[0][4].data[0].y),
|
||||
"+f"(dst.tiles[0][4].data[1].x), "+f"(dst.tiles[0][4].data[1].y),
|
||||
"+f"(dst.tiles[0][4].data[2].x), "+f"(dst.tiles[0][4].data[2].y),
|
||||
"+f"(dst.tiles[0][4].data[3].x), "+f"(dst.tiles[0][4].data[3].y),
|
||||
"+f"(dst.tiles[0][5].data[0].x), "+f"(dst.tiles[0][5].data[0].y),
|
||||
"+f"(dst.tiles[0][5].data[1].x), "+f"(dst.tiles[0][5].data[1].y),
|
||||
"+f"(dst.tiles[0][5].data[2].x), "+f"(dst.tiles[0][5].data[2].y),
|
||||
"+f"(dst.tiles[0][5].data[3].x), "+f"(dst.tiles[0][5].data[3].y)
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e4m3>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %26, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
|
||||
"%24, " \
|
||||
"%25, " \
|
||||
"p, 1, %27;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3])
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
// ----- FP8,FP8 -> FP16 ----- //
|
||||
else if constexpr (std::is_same_v<T_D, half> && std::is_same_v<T_AB, fp8e5m2>) {
|
||||
asm volatile (
|
||||
"{\n"
|
||||
".reg .pred p;\n" \
|
||||
"setp.ne.b32 p, %26, 0;\n" \
|
||||
"wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " \
|
||||
"{%0, %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12, %13, %14, %15, %16, %17, %18, %19, %20, %21, %22, %23}, " \
|
||||
"%24, " \
|
||||
"%25, " \
|
||||
"p, 1, %27;\n" \
|
||||
"}\n"
|
||||
// a_regs, b_mat descriptor, scale-d, imm-scale-a, imm-scale-b, im-trans-a im-trans-b
|
||||
|
||||
: "+r"(*(uint32_t*)&dst.tiles[0][0].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][0].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][1].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][2].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][3].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][4].data[3]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[0]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[1]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[2]),
|
||||
"+r"(*(uint32_t*)&dst.tiles[0][5].data[3])
|
||||
|
||||
: "l"(a_st_desc),
|
||||
"l"(b_st_desc),
|
||||
|
||||
"r"(scale_d),
|
||||
// "n"(trans_a),
|
||||
// "n"(trans_b),
|
||||
"n"(scale_b)
|
||||
);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -0,0 +1,47 @@
#pragma once

#include "../../../../../common/common.cuh"
#include "../../../../../types/types.cuh"

namespace kittens {
namespace detail {
namespace wgmma {

// templated wrapper for PTX
template<typename T_D, typename T_AB, int cols, int trans_a, int trans_b, int inv=1>
struct base {
    template<int scale_b=1> __device__ static inline void rt_st(
        rt<T_D, 16, cols, ducks::rt_layout::row> &dst,
        const rt<T_AB, 16, cols, ducks::rt_layout::row> & a_rt,
        const uint64_t b_st_desc,
        int scale_d = 1
    );
    template<int scale_b=1> __device__ static inline void st_st(
        rt<T_D, 16, cols, ducks::rt_layout::row> &dst,
        const uint64_t a_st_desc,
        const uint64_t b_st_desc,
        int scale_d = 1
    );
};

// all the ptx's
#include "64x16.impl"
#include "64x32.impl"
#include "64x48.impl"
#include "64x64.impl"
#include "64x80.impl"
#include "64x96.impl"
#include "64x112.impl"
#include "64x128.impl"
#include "64x144.impl"
#include "64x160.impl"
#include "64x176.impl"
#include "64x192.impl"
#include "64x208.impl"
#include "64x224.impl"
#include "64x240.impl"
#include "64x256.impl"

} // namespace wgmma
} // namespace detail
} // namespace kittens
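// Illustrative dispatch sketch (not part of the original header; d_tile, a_regs,
// a_desc, and b_desc are hypothetical names). Each 64xN.impl included above fills
// in the rt_st/st_st bodies for one WGMMA N dimension, so a caller selects a
// shape purely through template parameters:
//
//   using mma = kittens::detail::wgmma::base<float, kittens::bf16, 96, 0, 0>;
//   mma::st_st(d_tile, a_desc, b_desc); // A and B both described in shared memory
//   mma::rt_st(d_tile, a_regs, b_desc); // A held in registers, B in shared memory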
1170
extra/thunder/cuda/include/ops/group/mma/warpgroup/warpgroup.cuh
Normal file
File diff suppressed because it is too large
@@ -0,0 +1,7 @@
/**
 * @file
 * @brief An aggregate header for warp operations on data stored in registers.
 */

#include "tile/tile.cuh"
#include "vec/vec.cuh"
@@ -0,0 +1,98 @@
/**
 * @file
 * @brief Conversions between data layouts and types for complex register tiles.
 */

/* ---------- LAYOUT SWAPS ---------- */

/**
 * @brief Swaps the layout of a complex register tile.
 *
 * This function swaps the layout of a complex register tile by
 * swapping the layouts of its real and imaginary component tiles.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the register tile.
 * @tparam _width The width of the register tile.
 * @tparam layout The current layout of the register tile.
 * @param dst[out] Reference to the destination register tile where the result will be stored.
 * @param src[in] Reference to the source register tile to be swapped.
 */
template<typename T2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline void swap_layout(crt<T2, _height, _width, typename ducks::rt_layout::transpose<layout>::type> &dst, const crt<T2, _height, _width, layout> &src) {
    swap_layout(dst.real, src.real);
    swap_layout(dst.imag, src.imag); // both component tiles must be swapped
}
/**
 * @brief Swaps the layout of a complex register tile in place.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the register tile.
 * @tparam _width The width of the register tile.
 * @tparam layout The current layout of the register tile.
 * @param tile[in,out] Reference to the register tile to be swapped in place.
 * @return A reference to the swapped register tile.
 */
template<typename T2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline crt<T2, _height, _width, typename ducks::rt_layout::transpose<layout>::type>& swap_layout_inplace(crt<T2, _height, _width, layout> &tile) {
    swap_layout_inplace(tile.real);
    swap_layout_inplace(tile.imag);
    return *(crt<T2, _height, _width, typename ducks::rt_layout::transpose<layout>::type>*)(&tile);
}

/* ---------- TRANSPOSE ---------- */

/**
 * @brief Transposes a complex register tile.
 *
 * This function is marked "sep", which means that the registers underlying dst MUST be separate
 * from the registers underlying src.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the src register tile, and the width of the dst tile.
 * @tparam _width The width of the src register tile, and the height of the dst tile.
 * @tparam layout The layout of the register tile.
 * @param dst[out] Reference to the register tile in which to store the transposed src.
 * @param src[in] Reference to the register tile to be transposed.
 */
template<typename T2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline void transpose_sep(crt<T2, _width, _height, layout> &dst, const crt<T2, _height, _width, layout> &src) {
    transpose_sep(dst.real, src.real);
    transpose_sep(dst.imag, src.imag);
}
/**
 * @brief Transposes a square complex register tile in-place.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height (in units of 16) of the src register tile, and the width of the dst tile. (Must be the same as _width.)
 * @tparam _width The width (in units of 16) of the src register tile, and the height of the dst tile. (Must be the same as _height.)
 * @tparam layout The current layout of the register tile.
 * @param src[in] Reference to the register tile to be transposed.
 * @return A reference to the transposed register tile.
 */
template<typename T2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline crt<T2, _height, _width, layout>& transpose_inplace(crt<T2, _height, _width, layout> &tile) {
    tile.real = transpose_inplace(tile.real);
    tile.imag = transpose_inplace(tile.imag);

    return tile;
}

/* ---------- TYPE SWAPS ---------- */

/**
 * @brief Copies a complex register tile, converting the underlying type if necessary.
 *
 * @tparam T2 The data type of the destination register elements.
 * @tparam U2 The data type of the source register elements.
 * @tparam _height The height (in units of 16) of the register tiles.
 * @tparam _width The width (in units of 16) of the register tiles.
 * @tparam layout The current layout of the register tile.
 * @param[out] dst A reference to the destination register tile.
 * @param[in] src A reference to the source register tile.
 */
template<typename T2, typename U2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline void copy(crt<T2, _height, _width, layout> &dst, const crt<U2, _height, _width, layout> &src) {
    copy(dst.real, src.real);
    copy(dst.imag, src.imag);
}
@@ -0,0 +1,137 @@
/**
 * @file
 * @brief Map operations between complex tiles.
 */

/**
 * @brief Sets all elements of a complex tile to zero.
 *
 * @tparam T Complex tile type.
 * @param dst[out] Destination tile where the result is stored.
 */
template<ducks::crt::all T>
__device__ static inline void zero(T &dst) {
    zero(dst.real);
    zero(dst.imag);
}
/**
 * @brief Applies the exponential function to each element of a complex tile.
 *
 * @tparam T Complex tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the exponential function on.
 */
template<ducks::crt::all T>
__device__ static inline void exp(T &dst, const T &src) {
    using dtype = T::component; // the underlying real/imag component tile type
    dtype tmp;
    // out of place storage
    dtype rdst;
    dtype idst;

    // exp(a)
    exp(rdst, src.real);
    copy(idst, rdst);
    // exp(a)cos(b) + exp(a)sin(b)i
    cos(tmp, src.imag);
    mul(rdst, rdst, tmp);
    sin(tmp, src.imag);
    mul(idst, idst, tmp);

    copy(dst.real, rdst);
    copy(dst.imag, idst);
}
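/*
 * Identity used above, worked out for reference: with z = a + bi,
 *   exp(z) = e^a * cos(b) + e^a * sin(b) * i.
 * For example, exp(1 + (pi/2)i) = e*cos(pi/2) + e*sin(pi/2)i = 0 + e*i.
 * The routine computes e^a once (rdst), duplicates it into idst, then scales
 * the two copies by cos(b) and sin(b) respectively.
 */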
/**
 * @brief Adds two complex tiles element-wise.
 *
 * @tparam T Complex tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the addition.
 * @param rhs[in] Right-hand side source tile for the addition.
 */
template<ducks::crt::all T>
__device__ static inline void add(T &dst, const T &lhs, const T &rhs) {
    add(dst.real, lhs.real, rhs.real);
    add(dst.imag, lhs.imag, rhs.imag);
}
/**
 * @brief Subtracts two complex tiles element-wise.
 *
 * @tparam T Complex tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the subtraction.
 * @param rhs[in] Right-hand side source tile for the subtraction.
 */
template<ducks::crt::all T>
__device__ static inline void sub(T &dst, const T &lhs, const T &rhs) {
    sub(dst.real, lhs.real, rhs.real);
    sub(dst.imag, lhs.imag, rhs.imag);
}
/**
 * @brief Multiplies two complex tiles element-wise.
 *
 * @tparam T Complex tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the multiplication.
 * @param rhs[in] Right-hand side source tile for the multiplication.
 */
template<ducks::crt::all T>
__device__ static inline void mul(T &dst, const T &lhs, const T &rhs) {
    using dtype = T::component;
    dtype tmp;
    // out of place storage regs
    dtype rdst;
    dtype idst;

    // (a + bi) * (c + di) --> (ac - bd) + (ad + bc)i
    // Real component
    mul(rdst, lhs.real, rhs.real);
    mul(tmp, lhs.imag, rhs.imag);
    sub(rdst, rdst, tmp);

    // Imag component
    mul(idst, lhs.imag, rhs.real);
    mul(tmp, lhs.real, rhs.imag);
    add(idst, idst, tmp);

    copy(dst.real, rdst);
    copy(dst.imag, idst);
}
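/*
 * Worked check of the product formula, for reference:
 *   (1 + 2i) * (3 + 4i) = (1*3 - 2*4) + (1*4 + 2*3)i = -5 + 10i.
 * The result is staged in rdst/idst and only copied into dst at the end, so
 * dst may safely alias lhs or rhs.
 */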
/**
 * @brief Divides two complex tiles element-wise.
 *
 * @tparam T Complex tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the division.
 * @param rhs[in] Right-hand side source tile for the division.
 */
template<ducks::crt::all T>
__device__ static inline void div(T &dst, const T &lhs, const T &rhs) {
    using dtype = T::component; // the underlying real/imag component tile type
    dtype tmp;
    dtype denom;
    // out of place storage regs
    dtype rdst;
    dtype idst;

    // Denominator: rhs.real^2 + rhs.imag^2
    mul(tmp, rhs.real, rhs.real);
    mul(denom, rhs.imag, rhs.imag);
    add(denom, tmp, denom);
    // Real component
    mul(rdst, lhs.real, rhs.real);
    mul(tmp, lhs.imag, rhs.imag);
    add(rdst, rdst, tmp);
    // Imag component
    mul(idst, lhs.imag, rhs.real); // staged out of place so dst may alias lhs/rhs
    mul(tmp, lhs.real, rhs.imag);
    sub(idst, idst, tmp);
    // Divide components by denom
    div(rdst, rdst, denom);
    div(idst, idst, denom);
    copy(dst.real, rdst);
    copy(dst.imag, idst);
}
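/*
 * Division by the conjugate method, for reference: with rhs = c + di,
 *   (a + bi) / (c + di) = [(ac + bd) + (bc - ad)i] / (c^2 + d^2).
 * Example: (1 + 2i) / (3 + 4i) = [(3 + 8) + (6 - 4)i] / 25 = 0.44 + 0.08i.
 */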
@@ -0,0 +1,415 @@
/**
 * @file
 * @brief Conversions between data layouts and types for register tiles.
 */

/* ---------- LAYOUT SWAPS ---------- */

/**
 * @brief Perform a matrix transpose on a block of 8 bf16_2 elements using inline assembly.
 *
 * This low-level operation is utilized by higher-level layout swap functions to transpose
 * the layout of bf16_2 elements within a register tile. The function leverages inline PTX
 * assembly to efficiently swap the layout of the given block.
 *
 * @param[out] dst A reference to the destination bf16_2 element where the transposed result is stored.
 * @param[in] src A reference to the source bf16_2 element to be transposed.
 */
__device__ static inline void swap_layout_8(bf16_2 &dst, const bf16_2 &src) {
    KITTENS_CHECK_WARP
    asm volatile (
        "movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;\n"
        : "+r"(*(uint32_t*)(&dst))
        : "r"(*(uint32_t*)(&src))
    );
}
/**
 * @brief Swaps the layout of a register base tile.
 *
 * This function swaps the layout of a register base tile by performing a series of layout swaps
 * on its constituent bf16_2 elements. It is used to change the data layout within a register tile.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam layout The current layout of the register tile.
 * @param dst[out] Reference to the destination register base tile where the result will be stored.
 * @param src[in] Reference to the source register base tile to be swapped.
 */
template<typename T, ducks::rt_layout::all layout>
__device__ static inline void swap_layout(rt_base<T, typename ducks::rt_layout::transpose<layout>::type> &dst, const rt_base<T, layout> &src) {
    swap_layout_8(dst.data[0], src.data[0]);
    // technically this swap can be eliminated if we simply reinterpret the layout of the registers
    // everywhere else in the code, but that feels... very likely to cause bugs and not worth it.
    typename rt_base<T, layout>::T2 data1_cache = src.data[1]; // important for swap!
    swap_layout_8(dst.data[1], src.data[2]);
    swap_layout_8(dst.data[2], data1_cache);
    swap_layout_8(dst.data[3], src.data[3]);
}
/**
 * @brief Swaps the layout of a register tile.
 *
 * This function swaps the layout of a register tile by iterating over its height and width
 * and performing layout swaps on each of its base elements.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the register tile.
 * @tparam _width The width of the register tile.
 * @tparam layout The current layout of the register tile.
 * @param dst[out] Reference to the destination register tile where the result will be stored.
 * @param src[in] Reference to the source register tile to be swapped.
 */
template<typename T2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline void swap_layout(rt<T2, _height, _width, typename ducks::rt_layout::transpose<layout>::type> &dst, const rt<T2, _height, _width, layout> &src) {
    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            swap_layout(dst.tiles[i][j], src.tiles[i][j]);
        }
    }
}
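/*
 * Usage sketch (illustrative; assumes the rt_fl float-tile alias and the
 * row/col layout tags used elsewhere in the library):
 *
 *   rt_fl<16, 16, ducks::rt_layout::row> r;
 *   rt_fl<16, 16, ducks::rt_layout::col> c;
 *   swap_layout(c, r); // same logical 16x16 data, re-distributed across the warp
 */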

/**
 * @brief Swaps the layout of a register base tile in place.
 *
 * This function swaps the layout of a register base tile in place by casting it to the
 * transposed layout type and then performing the layout swap.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam layout The current layout of the register tile.
 * @param src[in] Reference to the register base tile to be swapped in place.
 * @return A reference to the swapped register base tile.
 */
template<typename T2, ducks::rt_layout::all layout>
__device__ static inline rt_base<T2, typename ducks::rt_layout::transpose<layout>::type>& swap_layout_inplace(const rt_base<T2, layout> &src) {
    rt_base<T2, typename ducks::rt_layout::transpose<layout>::type> &dst = *(rt_base<T2, typename ducks::rt_layout::transpose<layout>::type>*)(&src);
    swap_layout(dst, src);
    return dst;
}
/**
 * @brief Swaps the layout of a register tile in place.
 *
 * This function swaps the layout of a register tile in place by iterating over its height and width
 * and performing in-place layout swaps on each of its base elements.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the register tile.
 * @tparam _width The width of the register tile.
 * @tparam layout The current layout of the register tile.
 * @param tile[in,out] Reference to the register tile to be swapped in place.
 * @return A reference to the swapped register tile.
 */
template<typename T2, int _rows, int _cols, ducks::rt_layout::all layout>
__device__ static inline rt<T2, _rows, _cols, typename ducks::rt_layout::transpose<layout>::type>& swap_layout_inplace(rt<T2, _rows, _cols, layout> &tile) {
    #pragma unroll
    for(int i = 0; i < tile.height; i++) {
        #pragma unroll
        for(int j = 0; j < tile.width; j++) {
            swap_layout_inplace(tile.tiles[i][j]);
        }
    }
    return *(rt<T2, _rows, _cols, typename ducks::rt_layout::transpose<layout>::type>*)(&tile);
}

/* ---------- TRANSPOSE ---------- */

/**
 * @brief Transposes a register base tile.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam layout The current layout of the register tile.
 * @param dst[out] Reference to the register tile in which to store the transposed src.
 * @param src[in] Reference to the register base tile to be transposed.
 */
template<typename T, ducks::rt_layout::all layout>
__device__ static inline void transpose(rt_base<T, layout> &dst, const rt_base<T, layout> &src) {
    swap_layout_8(dst.data[0], src.data[0]);
    // technically this swap can be eliminated if we simply reinterpret the layout of the registers
    // everywhere else in the code, but that feels... very likely to cause bugs and not worth it.
    typename rt_base<T, layout>::T2 data1_cache = src.data[1]; // important for swap!
    swap_layout_8(dst.data[1], src.data[2]);
    swap_layout_8(dst.data[2], data1_cache);
    swap_layout_8(dst.data[3], src.data[3]);
}
/**
 * @brief Transposes a register tile.
 *
 * This function is marked "sep", which means that the registers underlying dst MUST be separate
 * from the registers underlying src.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the src register tile, and the width of the dst tile.
 * @tparam _width The width of the src register tile, and the height of the dst tile.
 * @tparam layout The layout of the register tile.
 * @param dst[out] Reference to the register tile in which to store the transposed src.
 * @param src[in] Reference to the register tile to be transposed.
 */
template<ducks::rt::all RT>
__device__ static inline void transpose_sep(RT &dst, const rt<typename RT::T, RT::cols, RT::rows, typename RT::layout> &src) {
    #pragma unroll
    for(int i = 0; i < RT::height; i++) {
        #pragma unroll
        for(int j = 0; j < RT::width; j++) {
            transpose(dst.tiles[i][j], src.tiles[j][i]);
        }
    }
}
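/*
 * Usage sketch (illustrative; assumes the rt_bf bf16-tile alias): transposing a
 * non-square tile requires separate destination registers, hence the "_sep" suffix:
 *
 *   rt_bf<16, 32> a;
 *   rt_bf<32, 16> at;
 *   transpose_sep(at, a); // at = a^T; a and at must not alias
 */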

/**
 * @brief Transposes a register base tile in-place.
 *
 * @tparam T2 The data type of the register base tile elements.
 * @tparam layout The current layout of the register base tile.
 * @param src[in] Reference to the register tile to be transposed.
 * @return A reference to the transposed register base tile.
 */
template<typename T2, ducks::rt_layout::all layout>
__device__ static inline rt_base<T2, layout>& transpose_inplace(rt_base<T2, layout> &src) {
    transpose(src, src);
    return src;
}
/**
 * @brief Transposes a square register tile in-place.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height (in units of 16) of the src register tile, and the width of the dst tile. (Must be the same as _width.)
 * @tparam _width The width (in units of 16) of the src register tile, and the height of the dst tile. (Must be the same as _height.)
 * @tparam layout The current layout of the register tile.
 * @param src[in] Reference to the register tile to be transposed.
 * @return A reference to the transposed register tile.
 */
template<typename T2, int _rows, int _cols, ducks::rt_layout::all layout>
__device__ static inline rt<T2, _rows, _cols, layout>& transpose_inplace(rt<T2, _rows, _cols, layout> &tile) {
    static_assert(_cols == _rows, "in-place register tile transpose is only allowed for square tiles.");
    #pragma unroll
    for(int i = 0; i < tile.height; i++) {
        #pragma unroll
        for(int j = 0; j < i; j++) {
            rt_base<T2, layout> tmp;
            copy(tmp, tile.tiles[i][j]);
            transpose(tile.tiles[i][j], tile.tiles[j][i]);
            transpose(tile.tiles[j][i], tmp);
        }
        transpose_inplace(tile.tiles[i][i]);
    }
    return tile;
}
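/*
 * Usage sketch (illustrative; assumes the rt_fl alias): in-place transpose is
 * restricted to square tiles, since rows and columns swap roles:
 *
 *   rt_fl<32, 32> t;
 *   transpose_inplace(t); // swaps off-diagonal 16x16 base tiles, transposes each in place
 */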

/* ---------- TYPE SWAPS ---------- */

/**
 * @brief Copies a register base tile, converting the underlying type if necessary.
 *
 * @tparam T2 The data type of the destination register elements.
 * @tparam U2 The data type of the source register elements.
 * @tparam layout The current layout of the register base tile.
 * @param[out] dst A reference to the destination register base tile.
 * @param[in] src A reference to the source register base tile.
 */
template<typename T, typename U, ducks::rt_layout::all layout>
__device__ static inline void copy(rt_base<T, layout> &dst, const rt_base<U, layout> &src) {
    using T2 = typename base_types::packing<T>::packed_type;
    using U2 = typename base_types::packing<U>::packed_type;
    #pragma unroll
    for(int k = 0; k < dst.packed_per_thread; k++) {
        dst.data[k] = base_types::convertor<T2, U2>::convert(src.data[k]);
    }
}
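/*
 * Usage sketch (illustrative; assumes the rt_fl/rt_bf aliases): copy doubles as
 * a type cast between register tiles of the same shape and layout:
 *
 *   rt_fl<16, 16> acc;
 *   rt_bf<16, 16> lowp;
 *   copy(lowp, acc); // per-packed-element float2 -> bf16_2 conversion
 */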
#ifdef KITTENS_HOPPER
/**
 * @brief Copies a register tile, converting the underlying type if necessary.
 *
 * @tparam T2 The data type of the destination register elements.
 * @tparam U2 The data type of the source register elements.
 * @tparam _height The height (in units of 16) of the register tiles.
 * @tparam _width The width (in units of 16) of the register tiles.
 * @tparam layout The current layout of the register tile.
 * @param[out] dst A reference to the destination register tile.
 * @param[in] src A reference to the source register tile.
 */
template<typename T2, typename U2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline void copy(rt<T2, _height, _width, layout> &dst, const rt<U2, _height, _width, layout> &src) {

    if constexpr (
        (std::is_same_v<U2, float> && std::is_same_v<T2, fp8e4m3>) ||
        (std::is_same_v<U2, float> && std::is_same_v<T2, fp8e5m2>) ||
        (std::is_same_v<U2, kittens::bf16> && std::is_same_v<T2, fp8e4m3>) ||
        (std::is_same_v<U2, kittens::bf16> && std::is_same_v<T2, fp8e5m2>) ||
        (std::is_same_v<U2, half> && std::is_same_v<T2, fp8e4m3>) ||
        (std::is_same_v<U2, half> && std::is_same_v<T2, fp8e5m2>)
    ) {
        // FLOAT (SRC -- 1H x 2W) to FP8 (DST -- 1H x 1W)
        int laneid = threadIdx.x % 32;

        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int j = 0; j < dst.width; j++) {
                #pragma unroll
                for(int k = 0; k < dst.tiles[0][0].packed_per_thread; k++) {

                    // check for half, float, bf16
                    using src_t = std::conditional_t<std::is_same_v<U2, float>, float2, std::conditional_t<std::is_same_v<U2, kittens::bf16>, bf16_2, half2>>;
                    src_t val1, val2;

                    // Put something up for adoption
                    if (laneid % 2 == 0) {
                        // put up src left core matrix first as 0, 2
                        val1 = src.tiles[i][2*j + k/2].data[(k%2)+0];
                        val2 = src.tiles[i][2*j + k/2].data[(k%2)+2];
                    } else {
                        // put up src right core matrix first as 1, 3
                        val1 = src.tiles[i][2*j + k/2].data[(k%2)+2];
                        val2 = src.tiles[i][2*j + k/2].data[(k%2)+0];
                    }

                    // Shuffle first 4 floats
                    int row_mask = 4 * ( laneid / 4 );
                    int row_offset = row_mask + ( (laneid-row_mask) / 2 ) + ( laneid % 2 );
                    int src_offset = (laneid % 2 == 0 ) ? row_offset + 0 : ( row_offset + 1 );
                    src_t val01 = packed_shfl_sync(MASK_ALL, val1, src_offset); // Get from even thread

                    int src_offset2 = (laneid % 4 < 2 ) ? src_offset + 1 : (src_offset - 1);
                    src_t val23 = packed_shfl_sync(MASK_ALL, val2, src_offset2); // Get from odd thread

                    // Convert to fp8e4m3_4
                    float4 f4;
                    using fp8_4_t = std::conditional_t<std::is_same_v<T2, fp8e4m3>, fp8e4m3_4, fp8e5m2_4>;
                    fp8_4_t f4_fp8;
                    if ( laneid % 4 < 2 ) {
                        f4.x = val01.x; // Thread 2N's first value
                        f4.y = val01.y; // Thread 2N's second value
                        f4.z = val23.x; // Thread 2N+1's first value
                        f4.w = val23.y; // Thread 2N+1's second value
                        f4_fp8 = base_types::convertor<fp8_4_t, float4>::convert(f4);
                        dst.tiles[i][j].data[k] = f4_fp8;
                    } else {
                        f4.x = val23.x; // Thread 2N+1's first value
                        f4.y = val23.y; // Thread 2N+1's second value
                        f4.z = val01.x; // Thread 2N's first value
                        f4.w = val01.y; // Thread 2N's second value
                        f4_fp8 = base_types::convertor<fp8_4_t, float4>::convert(f4);
                        dst.tiles[i][j].data[k] = f4_fp8;
                    }
                }
            }
        }
    }
    else if constexpr (
        (std::is_same_v<U2, fp8e4m3> && std::is_same_v<T2, float>) ||
        (std::is_same_v<U2, fp8e5m2> && std::is_same_v<T2, float>) ||
        (std::is_same_v<U2, fp8e4m3> && std::is_same_v<T2, kittens::bf16>) ||
        (std::is_same_v<U2, fp8e5m2> && std::is_same_v<T2, kittens::bf16>) ||
        (std::is_same_v<U2, fp8e4m3> && std::is_same_v<T2, half>) ||
        (std::is_same_v<U2, fp8e5m2> && std::is_same_v<T2, half>)
    ) {
        // FP8 (SRC -- 1H x 1W) to FLOAT (DST -- 1H x 2W)
        int laneid = threadIdx.x % 32;

        #pragma unroll
        for(int i = 0; i < src.height; i++) {
            #pragma unroll
            for(int j = 0; j < src.width; j++) {
                #pragma unroll
                for(int k = 0; k < src.tiles[0][0].packed_per_thread; k++) {
                    int dst_j = 2*j + k/2;

                    // Put something up for adoption
                    using fp8_4_t = std::conditional_t<std::is_same_v<U2, fp8e4m3>, fp8e4m3_4, fp8e5m2_4>;
                    fp8_4_t val = src.tiles[i][j].data[k];
                    float4 f4 = base_types::convertor<float4, fp8_4_t>::convert(val);
                    float2 f2_0, f2_1;
                    if ( laneid % 4 < 2 ) { // src 0 and 1 should put up .x and .y first
                        f2_0 = make_float2(f4.x, f4.y);
                        f2_1 = make_float2(f4.z, f4.w);
                    }
                    else { // src 2 and 3 should put up .z and .w first
                        f2_0 = make_float2(f4.z, f4.w);
                        f2_1 = make_float2(f4.x, f4.y);
                    }

                    int row_offset = 4 * (laneid/4) + (laneid%2) * 2 + (laneid%4) / 2;
                    float2 f2_0_shfl = packed_shfl_sync(MASK_ALL, f2_0, row_offset);
                    float2 f2_1_shfl = packed_shfl_sync(MASK_ALL, f2_1, row_offset^2);

                    // convert to dst type if needed
                    using dst_t = std::conditional_t<std::is_same_v<T2, float>, float2, std::conditional_t<std::is_same_v<T2, kittens::bf16>, bf16_2, half2>>;
                    if constexpr (!(std::is_same_v<T2, float>)) {
                        dst_t f2_0_shfl_t = base_types::convertor<dst_t, float2>::convert(f2_0_shfl);
                        dst_t f2_1_shfl_t = base_types::convertor<dst_t, float2>::convert(f2_1_shfl);
                        if (laneid % 2 == 0) {
                            dst.tiles[i][dst_j].data[(k%2)+0] = f2_0_shfl_t;
                            dst.tiles[i][dst_j].data[(k%2)+2] = f2_1_shfl_t;
                        } else {
                            dst.tiles[i][dst_j].data[(k%2)+0] = f2_1_shfl_t;
                            dst.tiles[i][dst_j].data[(k%2)+2] = f2_0_shfl_t;
                        }
                    } else {
                        if (laneid % 2 == 0) {
                            dst.tiles[i][dst_j].data[(k%2)+0] = f2_0_shfl;
                            dst.tiles[i][dst_j].data[(k%2)+2] = f2_1_shfl;
                        } else {
                            dst.tiles[i][dst_j].data[(k%2)+0] = f2_1_shfl;
                            dst.tiles[i][dst_j].data[(k%2)+2] = f2_0_shfl;
                        }
                    }
                }
            }
        }
    }
    // default case where the layouts map 1:1 in thread ownership logic
    else {
        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int j = 0; j < dst.width; j++) {
                copy(dst.tiles[i][j], src.tiles[i][j]);
            }
        }
    }
}
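/*
 * Note on the fp8 paths above (illustrative; assumes fp8 base tiles are 16x32
 * rather than 16x16, so a tile over the same logical columns has half as many
 * base tiles across). That mismatch is why the float<->fp8 branches shuffle
 * between lanes: each thread's four packed fp8 values come from two different
 * threads' float2 pairs.
 *
 *   rt_fl<16, 64> f;       // 1x4 base tiles of 16x16
 *   rt<fp8e4m3, 16, 64> q; // 1x2 base tiles of 16x32
 *   copy(q, f);            // repack via warp shuffles, then convert
 */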
#else
/**
 * @brief Copies a register tile, converting the underlying type if necessary.
 *
 * @tparam T2 The data type of the destination register elements.
 * @tparam U2 The data type of the source register elements.
 * @tparam _height The height (in units of 16) of the register tiles.
 * @tparam _width The width (in units of 16) of the register tiles.
 * @tparam layout The current layout of the register tile.
 * @param[out] dst A reference to the destination register tile.
 * @param[in] src A reference to the source register tile.
 */
template<typename T2, typename U2, int _height, int _width, ducks::rt_layout::all layout>
__device__ static inline void copy(rt<T2, _height, _width, layout> &dst, const rt<U2, _height, _width, layout> &src) {
    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            copy(dst.tiles[i][j], src.tiles[i][j]);
        }
    }
}
#endif

/* ---------- SUBTILE ---------- */

/**
 * @brief Returns a reference to a subtile of the given tile.
 *
 * @tparam subtile_rows The number of rows in the subtile.
 * @tparam RT The type of the input tile, which must satisfy the ducks::rt::all concept.
 * @param src The input tile.
 * @param idx The coord of the subtile.
 * @return A reference to the subtile.
 *
 * @note The subtile height must evenly divide the tile height.
 */
template<int subtile_rows, ducks::rt::all RT>
__device__ static inline rt<typename RT::T, subtile_rows, RT::cols, typename RT::layout> &subtile_inplace(RT & src, int idx) {
    KITTENS_CHECK_WARP
    using T = typename RT::T;
    static_assert(RT::height % (subtile_rows / TILE_ROW_DIM<T>) == 0, "subtile height should evenly divide tile height.");
    return reinterpret_cast<rt<typename RT::T, subtile_rows, RT::cols, typename RT::layout>&>(
        src.tiles[idx*(subtile_rows / TILE_ROW_DIM<T>)]
    );
}
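/*
 * Usage sketch (illustrative; assumes the rt_fl alias): carving a tall
 * accumulator into row blocks without any copying:
 *
 *   rt_fl<64, 64> acc;
 *   auto &top  = subtile_inplace<16>(acc, 0); // rows 0..15, an rt_fl<16, 64>&
 *   auto &next = subtile_inplace<16>(acc, 1); // rows 16..31
 */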
836
extra/thunder/cuda/include/ops/group/register/tile/maps.cuh
Normal file
@@ -0,0 +1,836 @@
/**
 * @file
 * @brief Map operations: between tiles, and those which apply vectors to tiles.
 */

/* ---------- Uniform tile maps (independent of layout) ---------- */

/**
 * @brief Applies a unary operation to each element of a tile.
 *
 * @tparam op Unary operation to apply.
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the operation on.
 */
template<typename op, ducks::rt::all T>
__device__ static inline void unary_map(T &dst, const T &src) {
    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile; k++) {
                dst.tiles[i][j].data[k] = op::template op<typename T::dtype>(src.tiles[i][j].data[k]);
            }
        }
    }
}
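/*
 * Usage sketch (illustrative): ops are structs exposing a static templated op()
 * invoked on the tile's packed dtype. A hypothetical squaring op, specialized
 * for float2 (the packed dtype of rt_fl), and its application:
 *
 *   struct square {
 *       template<typename T> static __device__ inline T op(const T &x); // per-type specializations
 *   };
 *   template<> __device__ inline float2 square::op<float2>(const float2 &x) {
 *       return float2{x.x*x.x, x.y*x.y};
 *   }
 *   rt_fl<16, 16> d, s;
 *   unary_map<square>(d, s); // d = s*s on every packed element
 */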

/**
 * @brief Applies a binary operation to each element of a tile with a scalar parameter.
 *
 * @tparam op Binary operation to apply.
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the operation on.
 * @param param[in] Scalar parameter for the binary operation.
 */
template<typename op, ducks::rt::all T>
__device__ static inline void bin_map(T &dst, const T &src, const typename T::dtype &param) {
    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile; k++) {
                dst.tiles[i][j].data[k] = op::template op<typename T::dtype>(src.tiles[i][j].data[k], param);
            }
        }
    }
}
/**
 * @brief Applies a binary operation to each element of a tile with an unpacked scalar parameter.
 *
 * @tparam op Binary operation to apply.
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the operation on.
 * @param param[in] Unpacked scalar parameter for the binary operation.
 */
template<typename op, ducks::rt::all T>
__device__ static inline void bin_map(T &dst, const T &src, const typename base_types::packing<typename T::dtype>::unpacked_type &param) {
    // The optimizing compiler should eliminate this pack in the 32-bit case but not in the 16-bit case
    bin_map<op, T>(dst, src, base_types::packing<typename T::dtype>::pack(param));
}
/**
 * @brief Applies a binary operation element-wise between two tiles.
 *
 * @tparam op Binary operation to apply.
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the operation.
 * @param rhs[in] Right-hand side source tile for the operation.
 */
template<typename op, ducks::rt::all T>
__device__ static inline void bin_map(T &dst, const T &lhs, const T &rhs) {
    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile; k++) {
                dst.tiles[i][j].data[k] = op::template op<typename T::dtype>(lhs.tiles[i][j].data[k], rhs.tiles[i][j].data[k]);
            }
        }
    }
}
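/**
 * A minimal sketch of the three bin_map overloads -- unpacked scalar, packed
 * scalar, and tile-tile -- assuming warp scope and the rt_fl alias.
 * @code
 * __device__ void bin_map_demo(rt_fl<16, 64> &a, const rt_fl<16, 64> &b) {
 *     bin_map<base_ops::sum>(a, a, 1.0f);             // unpacked float, packed internally
 *     bin_map<base_ops::mul>(a, a, float2{2.f, 2.f}); // already-packed scalar
 *     bin_map<base_ops::max>(a, a, b);                // elementwise tile-tile
 * }
 * @endcode
 */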

template<ducks::rt::all RT, typename Lambda>
__device__ static inline void apply(RT &dst, const RT &src, Lambda &&lambda) {
    int row_offset = 0;
    if constexpr(GROUP_WARPS > 1) {
        row_offset = warpid()*RT::height;
    }
    static_assert(sizeof(typename RT::T) != 1, "Cannot apply lambda to 8-bit types");
    if constexpr (ducks::rt::row_layout<RT>) {
        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int j = 0; j < dst.width; j++) {
                #pragma unroll
                for(int k = 0; k < dst.packed_per_tile; k++) {
                    int row = row_offset + i*TILE_ROW_DIM<typename RT::T> + (k%2) * (TILE_ROW_DIM<typename RT::T>/2) + ::kittens::laneid()/4;
                    int col = j*TILE_COL_DIM<typename RT::T> + (k/2) * (TILE_COL_DIM<typename RT::T>/2) + (::kittens::laneid()%4)*2;
                    dst.tiles[i][j].data[k].x = lambda(row, col+0, src.tiles[i][j].data[k].x);
                    dst.tiles[i][j].data[k].y = lambda(row, col+1, src.tiles[i][j].data[k].y);
                }
            }
        }
    }
    else {
        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int j = 0; j < dst.width; j++) {
                #pragma unroll
                for(int k = 0; k < dst.packed_per_tile; k++) {
                    int row = row_offset + i*TILE_ROW_DIM<typename RT::T> + (k/2) * (TILE_ROW_DIM<typename RT::T>/2) + (::kittens::laneid()%4)*2;
                    int col = j*TILE_COL_DIM<typename RT::T> + (k%2) * (TILE_COL_DIM<typename RT::T>/2) + ::kittens::laneid()/4;
                    dst.tiles[i][j].data[k].x = lambda(row+0, col, src.tiles[i][j].data[k].x);
                    dst.tiles[i][j].data[k].y = lambda(row+1, col, src.tiles[i][j].data[k].y);
                }
            }
        }
    }
}
template<ducks::rt::all RT, typename Lambda>
__device__ static inline RT apply(const RT &src, Lambda &&lambda) {
    RT dst;
    apply<RT, Lambda>(dst, src, std::forward<Lambda>(lambda));
    return dst;
}
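/**
 * A minimal sketch of apply(), which hands the lambda each element's logical
 * (row, col) coordinate so index-dependent maps are one-liners; assumes warp
 * scope, the rt_fl alias, and extended device lambdas (nvcc --extended-lambda).
 * @code
 * __device__ void apply_demo(rt_fl<16, 64> &t) {
 *     apply(t, t, [] __device__ (int row, int col, float v) {
 *         return (col % 2 == 0) ? v : 0.0f; // keep even columns, zero the rest
 *     });
 * }
 * @endcode
 */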

/* ---------- Row tile maps ----------*/

/**
 * @brief Applies an operation across the rows of a tile in a row-major layout.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with row-major layout.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the operation on.
 * @param row_values[in] Column vector containing values to apply across each row.
 */
template<typename op, ducks::rt::row_layout T, ducks::rv::all V>
__device__ static inline void row_map(T &dst, const T &src, const V &row_values) {

    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::col_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::height); // compatible size

    using dtype = T::dtype;

    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        dtype packed_top_row = base_types::packing<dtype>::pack(row_values[i][0].x); // first value in eager mode
        dtype packed_bottom_row = base_types::packing<dtype>::pack(row_values[i][0].y); // second value in eager mode
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile; k+=2) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(src.tiles[i][j].data[k+0], packed_top_row);
                dst.tiles[i][j].data[k+1] = op::template op<dtype>(src.tiles[i][j].data[k+1], packed_bottom_row);
            }
        }
    }
}
/**
 * @brief Applies an operation across the rows of a tile in a column-major layout.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with column-major layout.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the operation on.
 * @param row_values[in] Column vector containing values to apply across each row.
 */
template<typename op, ducks::rt::col_layout T, ducks::rv::all V>
__device__ static inline void row_map(T &dst, const T &src, const V &row_values) {

    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::col_vec_layout>); // compatible layout
    static_assert(V::outer_dim == T::height); // compatible size

    using dtype = T::dtype;

    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile/2; k++) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(src.tiles[i][j].data[k+0], row_values[i][0]);
                dst.tiles[i][j].data[k+2] = op::template op<dtype>(src.tiles[i][j].data[k+2], row_values[i][1]);
            }
        }
    }
}


// Three-operand row map. Mostly useful for FMA instructions.

/**
 * @brief Applies an operation across the rows of two tiles in a row-major layout, using a third operand.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with row-major layout.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param a[in] First source tile to apply the operation on.
 * @param b[in] Second source tile to apply the operation on.
 * @param row_values[in] Column vector containing values to apply across each row.
 */
template<typename op, ducks::rt::row_layout T, ducks::rv::all V>
__device__ static inline void row_map(T &dst, const T &a, const T &b, const V &row_values) {

    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::col_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::height); // compatible size

    using dtype = T::dtype;

    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        dtype packed_top_row = base_types::packing<dtype>::pack(row_values[i][0].x); // first value in eager mode
        dtype packed_bottom_row = base_types::packing<dtype>::pack(row_values[i][0].y); // second value in eager mode
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile; k+=2) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], packed_top_row);
                dst.tiles[i][j].data[k+1] = op::template op<dtype>(a.tiles[i][j].data[k+1], b.tiles[i][j].data[k+1], packed_bottom_row);
            }
        }
    }
}
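/**
 * A minimal sketch of a user-defined ternary op in the base_ops style, fused
 * through the three-operand row_map: dst = a*b + v broadcast along rows.
 * Assumes warp scope, the rt_fl alias, and the col_vec alias rt types expose.
 * @code
 * struct fma_op {
 *     template<typename T> static __device__ inline T op(const T &a, const T &b, const T &c);
 * };
 * template<> __device__ inline float2 fma_op::op<float2>(const float2 &a, const float2 &b, const float2 &c) {
 *     return float2{fmaf(a.x, b.x, c.x), fmaf(a.y, b.y, c.y)};
 * }
 *
 * __device__ void fma_demo(rt_fl<16, 64> &d, const rt_fl<16, 64> &a,
 *                          const rt_fl<16, 64> &b, const rt_fl<16, 64>::col_vec &v) {
 *     row_map<fma_op>(d, a, b, v);
 * }
 * @endcode
 */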
/**
 * @brief Applies an operation across the rows of two tiles in a column-major layout, using a third operand.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with column-major layout.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param a[in] First source tile to apply the operation on.
 * @param b[in] Second source tile to apply the operation on.
 * @param row_values[in] Column vector containing values to apply across each row.
 */
template<typename op, ducks::rt::col_layout T, ducks::rv::all V>
__device__ static inline void row_map(T &dst, const T &a, const T &b, const V &row_values) {

    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::col_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::height); // compatible size

    using dtype = T::dtype;

    #pragma unroll
    for(int i = 0; i < dst.height; i++) {
        #pragma unroll
        for(int j = 0; j < dst.width; j++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile/2; k++) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], row_values[i][0]);
                dst.tiles[i][j].data[k+2] = op::template op<dtype>(a.tiles[i][j].data[k+2], b.tiles[i][j].data[k+2], row_values[i][1]);
            }
        }
    }
}

/* ---------- Col major tile maps ----------*/

/**
 * @brief Applies an operation across the columns of a tile in a row-major layout.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with row-major layout.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the operation on.
 * @param col_values[in] Row vector containing values to apply across each column.
 */
template<typename op, ducks::rt::row_layout T, ducks::rv::all V>
__device__ static inline void col_map(T &dst, const T &src, const V &col_values) {
    KITTENS_CHECK_WARP

    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::row_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::width); // compatible size

    using dtype = T::dtype;

    #pragma unroll
    for(int j = 0; j < dst.width; j++) {
        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile/2; k++) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(src.tiles[i][j].data[k+0], col_values[j][0]);
                dst.tiles[i][j].data[k+2] = op::template op<dtype>(src.tiles[i][j].data[k+2], col_values[j][1]);
            }
        }
    }
}
/**
 * @brief Applies an operation across the columns of a tile in a column-major layout.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with column-major layout.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the operation on.
 * @param col_values[in] Row vector containing values to apply across each column.
 */
template<typename op, ducks::rt::col_layout T, ducks::rv::all V>
__device__ static inline void col_map(T &dst, const T &src, const V &col_values) {
    KITTENS_CHECK_WARP

    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::row_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::width); // compatible size

    using dtype = T::dtype;

    #pragma unroll
    for(int j = 0; j < dst.width; j++) {
        dtype packed_left_col = base_types::packing<dtype>::pack(col_values[j][0].x); // first value in eager mode
        dtype packed_right_col = base_types::packing<dtype>::pack(col_values[j][0].y); // second value in eager mode
        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile; k+=2) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(src.tiles[i][j].data[k+0], packed_left_col);
                dst.tiles[i][j].data[k+1] = op::template op<dtype>(src.tiles[i][j].data[k+1], packed_right_col);
            }
        }
    }
}

// Three-operand col map
/**
 * @brief Applies an operation across the columns of two tiles in a row-major layout, using a third operand.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with row-major layout.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param a[in] First source tile to apply the operation on.
 * @param b[in] Second source tile to apply the operation on.
 * @param col_values[in] Row vector containing values to apply across each column.
 */
template<typename op, ducks::rt::row_layout T, ducks::rv::all V>
__device__ static inline void col_map(T &dst, const T &a, const T &b, const V &col_values) {
    KITTENS_CHECK_WARP

    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::row_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::width); // compatible size

    using dtype = T::dtype;

    #pragma unroll
    for(int j = 0; j < dst.width; j++) {
        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile/2; k++) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], col_values[j][0]);
                dst.tiles[i][j].data[k+2] = op::template op<dtype>(a.tiles[i][j].data[k+2], b.tiles[i][j].data[k+2], col_values[j][1]);
            }
        }
    }
}
/**
 * @brief Applies an operation across the columns of two tiles in a column-major layout, using a third operand.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with column-major layout.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param a[in] First source tile to apply the operation on.
 * @param b[in] Second source tile to apply the operation on.
 * @param col_values[in] Row vector containing values to apply across each column.
 */
template<typename op, ducks::rt::col_layout T, ducks::rv::all V>
__device__ static inline void col_map(T &dst, const T &a, const T &b, const V &col_values) {
    KITTENS_CHECK_WARP

    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::row_vec_layout>); // compatible layout
    static_assert(V::outer_dim == T::width); // compatible size

    using dtype = T::dtype;
    #pragma unroll
    for(int j = 0; j < dst.width; j++) {
        dtype packed_left_col = base_types::packing<dtype>::pack(col_values[j][0].x); // first value in eager mode
        dtype packed_right_col = base_types::packing<dtype>::pack(col_values[j][0].y); // second value in eager mode
        #pragma unroll
        for(int i = 0; i < dst.height; i++) {
            #pragma unroll
            for(int k = 0; k < dst.packed_per_tile; k+=2) {
                dst.tiles[i][j].data[k+0] = op::template op<dtype>(a.tiles[i][j].data[k+0], b.tiles[i][j].data[k+0], packed_left_col);
                dst.tiles[i][j].data[k+1] = op::template op<dtype>(a.tiles[i][j].data[k+1], b.tiles[i][j].data[k+1], packed_right_col);
            }
        }
    }
}


/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// All of the annoying qualifiers *should* be automatically inferred during compile-time.
// So, syntax should just be kittens::add_row(tile, colvec);

/**
 * @brief Sets all elements of a tile to zero.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 */
template<ducks::rt::all T>
__device__ static inline void zero(T &dst) {
    unary_map<base_ops::zero, T>(dst, dst);
}
/**
 * @brief Sets all elements of a tile to one.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 */
template<ducks::rt::all T>
__device__ static inline void one(T &dst) {
    unary_map<base_ops::one, T>(dst, dst);
}
/**
 * @brief Sets all elements of a tile to positive infinity.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 */
template<ducks::rt::all T>
__device__ static inline void pos_infty(T &dst) {
    unary_map<base_ops::pos_infty, T>(dst, dst);
}
/**
 * @brief Sets all elements of a tile to negative infinity.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 */
template<ducks::rt::all T>
__device__ static inline void neg_infty(T &dst) {
    unary_map<base_ops::neg_infty, T>(dst, dst);
}

/**
 * @brief Applies the exponential function to each element of a tile.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the exponential function on.
 */
template<ducks::rt::all T>
__device__ static inline void exp(T &dst, const T &src) {
    unary_map<base_ops::exp, T>(dst, src);
}
template<ducks::rt::all T>
__device__ static inline T exp(const T &src) {
    T dst;
    exp(dst, src);
    return dst;
}
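/**
 * A minimal sketch of the two calling conventions every wrapper here follows,
 * write-into-destination and construct-and-return; assumes warp scope and the
 * rt_fl alias.
 * @code
 * __device__ void exp_demo(rt_fl<16, 64> &t) {
 *     exp(t, t);       // in-place: t = e^t elementwise
 *     auto u = exp(t); // out-of-place: u is a fresh tile
 * }
 * @endcode
 */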

/**
 * @brief Applies the exponential function to each element of a tile, in base 2.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the exponential function on.
 */
template<ducks::rt::all T>
__device__ static inline void exp2(T &dst, const T &src) {
    unary_map<base_ops::exp2, T>(dst, src);
}
template<ducks::rt::all T>
__device__ static inline T exp2(const T &src) {
    T dst;
    exp2(dst, src);
    return dst;
}

/**
 * @brief Applies the natural logarithm function to each element of a tile.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the natural logarithm function on.
 */
template<ducks::rt::all T>
__device__ static inline void log(T &dst, const T &src) {
    unary_map<base_ops::log, T>(dst, src);
}
template<ducks::rt::all T>
__device__ static inline T log(const T &src) {
    T dst;
    log(dst, src);
    return dst;
}

/**
 * @brief Applies the logarithm base 2 function to each element of a tile.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the logarithm base 2 function on.
 */
template<ducks::rt::all T>
__device__ static inline void log2(T &dst, const T &src) {
    unary_map<base_ops::log2, T>(dst, src);
}
template<ducks::rt::all T>
__device__ static inline T log2(const T &src) {
    T dst;
    log2(dst, src);
    return dst;
}

/**
 * @brief Applies the absolute value function to each element of a tile.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the absolute value function on.
 */
template<ducks::rt::all T>
__device__ static inline void abs(T &dst, const T &src) {
    unary_map<base_ops::abs, T>(dst, src);
}
template<ducks::rt::all T>
__device__ static inline T abs(const T &src) {
    T dst;
    abs(dst, src);
    return dst;
}

/**
 * @brief Applies the rectified linear unit (ReLU) function to each element of a tile.
 *
 * @tparam T Tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the ReLU function on.
 */
template<ducks::rt::all T>
__device__ static inline void relu(T &dst, const T &src) {
    unary_map<base_ops::relu, T>(dst, src);
}
template<ducks::rt::all T>
__device__ static inline T relu(const T &src) {
    T dst;
    relu(dst, src);
    return dst;
}

/**
 * @brief Copies the elements from one tile to another.
 *
 * @tparam T Destination tile type.
 * @tparam U Source tile type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to copy from.
 */
template<ducks::rt::all T, typename U>
__device__ static inline void copy(T &dst, const U &src) {
    bin_map<base_ops::copy2, T>(dst, dst, src);
}

/**
 * @brief Applies the max operation element-wise between two tiles or a tile and a scalar.
 *
 * @tparam T Tile type.
 * @tparam U Second operand type, which can be a tile or a scalar.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the operation.
 * @param rhs[in] Right-hand side source tile or scalar for the operation.
 */
template<ducks::rt::all T, typename U>
__device__ static inline void max(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::max, T>(dst, lhs, rhs);
}
template<ducks::rt::all T, typename U>
__device__ static inline T max(const T &lhs, const U &rhs) {
    T dst;
    max(dst, lhs, rhs);
    return dst;
}

/**
 * @brief Applies the min operation element-wise between two tiles or a tile and a scalar.
 *
 * @tparam T Tile type.
 * @tparam U Second operand type, which can be a tile or a scalar.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the operation.
 * @param rhs[in] Right-hand side source tile or scalar for the operation.
 */
template<ducks::rt::all T, typename U>
__device__ static inline void min(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::min, T>(dst, lhs, rhs);
}
template<ducks::rt::all T, typename U>
__device__ static inline T min(const T &lhs, const U &rhs) {
    T dst;
    min(dst, lhs, rhs);
    return dst;
}

/**
 * @brief Adds two tiles element-wise or adds a scalar to each element of a tile.
 *
 * @tparam T Tile type.
 * @tparam U Second operand type, which can be a tile or a scalar.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the addition.
 * @param rhs[in] Right-hand side source tile or scalar for the addition.
 */
template<ducks::rt::all T, typename U>
__device__ static inline void add(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::sum, T>(dst, lhs, rhs);
}

/**
 * @brief Subtracts two tiles element-wise or subtracts a scalar from each element of a tile.
 *
 * @tparam T Tile type.
 * @tparam U Second operand type, which can be a tile or a scalar.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the subtraction.
 * @param rhs[in] Right-hand side source tile or scalar for the subtraction.
 */
template<ducks::rt::all T, typename U>
__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::sub, T>(dst, lhs, rhs);
}
/**
 * @brief Multiplies two tiles element-wise or multiplies each element of a tile by a scalar.
 *
 * @tparam T Tile type.
 * @tparam U Second operand type, which can be a tile or a scalar.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the multiplication.
 * @param rhs[in] Right-hand side source tile or scalar for the multiplication.
 */
template<ducks::rt::all T, typename U>
__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::mul, T>(dst, lhs, rhs);
}

/**
 * @brief Divides two tiles element-wise or divides each element of a tile by a scalar.
 *
 * @tparam T Tile type.
 * @tparam U Second operand type, which can be a tile or a scalar.
 * @param dst[out] Destination tile where the result is stored.
 * @param lhs[in] Left-hand side source tile for the division.
 * @param rhs[in] Right-hand side source tile or scalar for the division.
 */
template<ducks::rt::all T, typename U>
__device__ static inline void div(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::div, T>(dst, lhs, rhs);
}

/**
 * @brief Adds row values to each row of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the addition on.
 * @param row_values[in] Column vector containing values to add to each row.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void add_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::sum, T, V>(dst, src, row_values);
}

/**
 * @brief Subtracts row values from each row of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the subtraction on.
 * @param row_values[in] Column vector containing values to subtract from each row.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void sub_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::sub, T, V>(dst, src, row_values);
}

/**
 * @brief Multiplies each row of a tile by row values.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the multiplication on.
 * @param row_values[in] Column vector containing values to multiply each row by.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void mul_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::mul, T, V>(dst, src, row_values);
}

/**
 * @brief Divides each row of a tile by row values.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the division on.
 * @param row_values[in] Column vector containing values to divide each row by.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void div_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::div, T, V>(dst, src, row_values);
}

/**
 * @brief Broadcast a vector into a tile's rows.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param row_values[in] Column vector containing values to broadcast into rows.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void broadcast_row(T &dst, const V &row_values) {
    row_map<base_ops::copy2, T, V>(dst, dst, row_values);
}
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline T broadcast_row(const V &row_values) {
    T dst;
    broadcast_row(dst, row_values);
    return dst;
}
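/**
 * A minimal sketch of the row-wise wrappers, assuming warp scope and the
 * rt_fl alias; v holds one value per row of the tile.
 * @code
 * __device__ void row_wrapper_demo(rt_fl<16, 64> &t, const rt_fl<16, 64>::col_vec &v) {
 *     sub_row(t, t, v);                         // t[r][c] -= v[r]
 *     auto w = broadcast_row<rt_fl<16, 64>>(v); // w[r][c] = v[r] for every c
 * }
 * @endcode
 */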


// col maps
/**
 * @brief Adds column values to each column of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the addition on.
 * @param col_values[in] Row vector containing values to add to each column.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void add_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::sum, T, V>(dst, src, col_values);
}

/**
 * @brief Subtracts column values from each column of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the subtraction on.
 * @param col_values[in] Row vector containing values to subtract from each column.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void sub_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::sub, T, V>(dst, src, col_values);
}

/**
 * @brief Multiplies each column of a tile by column values.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the multiplication on.
 * @param col_values[in] Row vector containing values to multiply each column by.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void mul_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::mul, T, V>(dst, src, col_values);
}

/**
 * @brief Divides each column of a tile by column values.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the division on.
 * @param col_values[in] Row vector containing values to divide each column by.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void div_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::div, T, V>(dst, src, col_values);
}

/**
 * @brief Broadcast a vector into a tile's columns.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param col_values[in] Row vector containing values to broadcast into columns.
 */
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline void broadcast_col(T &dst, const V &col_values) {
    col_map<base_ops::copy2, T, V>(dst, dst, col_values);
}
template<ducks::rt::all T, ducks::rv::all V>
__device__ static inline T broadcast_col(const V &col_values) {
    T dst;
    broadcast_col(dst, col_values);
    return dst;
}

// Triangular masks
template<ducks::rt::all RT>
__device__ static inline void tril(RT &dst, const RT &src, int diagonal=0, const typename base_types::packing<typename RT::dtype>::unpacked_type &val=0) {
    apply(dst, src, [val, diagonal]__device__(int row, int col, auto &src_val) {
        return col <= row + diagonal ? src_val : val;
    });
}
template<ducks::rt::all RT>
__device__ static inline void triu(RT &dst, const RT &src, int diagonal=0, const typename base_types::packing<typename RT::dtype>::unpacked_type &val=0) {
    apply(dst, src, [val, diagonal]__device__(int row, int col, auto &src_val) {
        return col >= row + diagonal ? src_val : val;
    });
}
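/**
 * A minimal sketch of tril as a causal attention mask -- entries above the
 * diagonal are sent to -inf so a following softmax zeroes them out; assumes
 * warp scope and the rt_fl alias.
 * @code
 * __device__ void causal_mask_demo(rt_fl<16, 16> &scores) {
 *     tril(scores, scores, 0, -INFINITY);
 * }
 * @endcode
 */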
@@ -0,0 +1,554 @@
/**
 * @file
 * @brief Reduction operations mapping tiles to vectors.
 */

/**
 * @brief Perform a row-wise reduction on a matrix in row-major layout.
 *
 * This function template performs a parallel reduction across the rows of a matrix using a specified operation.
 * It leverages warp shuffle functions for efficient intra-warp communication.
 *
 * @tparam op The operation to be applied for reduction.
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type with row layout.
 * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when reset is false.
 */
template<typename op, ducks::rv::all V, ducks::rt::row_layout T, bool reset>
__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) {
    // I actually like these static asserts because they give more verbose errors when things go wrong.
    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::col_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::height); // compatible size

    using dtype = V::dtype;

    const int leader = threadIdx.x & 0x1C; // 11100 in binary
    #pragma unroll
    for(int i = 0; i < src.height; i++) {
        dtype accum_top_row = op::template op<dtype>(src.tiles[i][0].data[0], src.tiles[i][0].data[2]);
        dtype accum_bottom_row = op::template op<dtype>(src.tiles[i][0].data[1], src.tiles[i][0].data[3]);
        #pragma unroll
        for(int j = 1; j < src.width; j++) {
            #pragma unroll
            for(int k = 0; k < src.packed_per_tile; k+=2) {
                accum_top_row = op::template op<dtype>(accum_top_row, src.tiles[i][j].data[k+0]);
                accum_bottom_row = op::template op<dtype>(accum_bottom_row, src.tiles[i][j].data[k+1]);
            }
        }
        dtype accum_packed;
        accum_packed.x = op::template op<typename base_types::packing<dtype>::unpacked_type>(accum_top_row.x, accum_top_row.y);
        accum_packed.y = op::template op<typename base_types::packing<dtype>::unpacked_type>(accum_bottom_row.x, accum_bottom_row.y);

        // Now we need to do a lil shuffle to make everyone happy.

        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2));
        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1));

        accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader);

        if(reset) {
            row_accum[i][0] = accum_packed;
        }
        else {
            row_accum[i][0] = op::template op<dtype>(src_accum[i][0], accum_packed);
        }
    }
}
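/**
 * For intuition, a minimal plain-CUDA sketch of the shuffle pattern above:
 * the four lanes of a quad each hold a partial result for the same rows; two
 * shfl_down steps fold the quad and a final broadcast from the quad leader
 * makes all four lanes agree. packed_shfl_down_sync / packed_shfl_sync do
 * the same over packed pairs.
 * @code
 * __device__ float quad_reduce_sum(float v) {
 *     v += __shfl_down_sync(0xFFFFFFFFu, v, 2);               // lanes 0,1 pull from 2,3
 *     v += __shfl_down_sync(0xFFFFFFFFu, v, 1);               // lane 0 pulls from lane 1
 *     return __shfl_sync(0xFFFFFFFFu, v, threadIdx.x & 0x1C); // broadcast quad leader
 * }
 * @endcode
 */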
/**
 * @brief Perform a row-wise reduction on a matrix in column-major layout.
 *
 * This function template performs a parallel reduction across the rows of a matrix using a specified operation.
 * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for column-major matrices.
 *
 * @tparam op The operation to be applied for reduction.
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type with column layout.
 * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when reset is false.
 */
template<typename op, ducks::rv::all V, ducks::rt::col_layout T, bool reset>
__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) {
    // I actually like these static asserts because they give more verbose errors when things go wrong.
    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::col_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::height); // compatible size

    using dtype = V::dtype;

    const int leader = threadIdx.x & 0x3; // 00011 in binary
    #pragma unroll
    for(int i = 0; i < src.height; i++) {
        dtype accum_top_rows = op::template op<dtype>(src.tiles[i][0].data[0], src.tiles[i][0].data[1]);
        dtype accum_bottom_rows = op::template op<dtype>(src.tiles[i][0].data[2], src.tiles[i][0].data[3]);
        #pragma unroll
        for(int j = 1; j < src.width; j++) {
            #pragma unroll
            for(int k = 0; k < src.packed_per_tile/2; k++) {
                accum_top_rows = op::template op<dtype>(accum_top_rows, src.tiles[i][j].data[k+0]);
                accum_bottom_rows = op::template op<dtype>(accum_bottom_rows, src.tiles[i][j].data[k+2]);
            }
        }

        // Now we need to do a lil shuffle to make everyone happy.

        accum_top_rows = op::template op<dtype>(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 16));
        accum_top_rows = op::template op<dtype>(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 8));
        accum_top_rows = op::template op<dtype>(accum_top_rows, packed_shfl_down_sync(MASK_ALL, accum_top_rows, 4));

        accum_bottom_rows = op::template op<dtype>(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 16));
        accum_bottom_rows = op::template op<dtype>(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 8));
        accum_bottom_rows = op::template op<dtype>(accum_bottom_rows, packed_shfl_down_sync(MASK_ALL, accum_bottom_rows, 4));

        accum_top_rows = packed_shfl_sync(MASK_ALL, accum_top_rows, leader);
        accum_bottom_rows = packed_shfl_sync(MASK_ALL, accum_bottom_rows, leader);

        if(reset) {
            row_accum[i][0] = accum_top_rows;
            row_accum[i][1] = accum_bottom_rows;
        }
        else {
            row_accum[i][0] = op::template op<dtype>(src_accum[i][0], accum_top_rows);
            row_accum[i][1] = op::template op<dtype>(src_accum[i][1], accum_bottom_rows);
        }
    }
}

// Col reduction.
/**
 * @brief Perform a column-wise reduction on a matrix in row-major layout.
 *
 * This function template performs a parallel reduction across the columns of a matrix using a specified operation.
 * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for row-major matrices.
 *
 * @tparam op The operation to be applied for reduction.
 * @tparam V The vector type for the column accumulator.
 * @tparam T The matrix type with row layout.
 * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when reset is false.
 */
template<typename op, ducks::rv::all V, ducks::rt::row_layout T, bool reset>
__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) {
    // I actually like these static asserts because they give more verbose errors when things go wrong.
    KITTENS_CHECK_WARP
    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::row_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::width); // compatible size

    using dtype = V::dtype;

    const int leader = threadIdx.x & 0x3; // 00011 in binary
    #pragma unroll
    for(int j = 0; j < src.width; j++) {
        dtype accum_left_cols = op::template op<dtype>(src.tiles[0][j].data[0], src.tiles[0][j].data[1]);
        dtype accum_right_cols = op::template op<dtype>(src.tiles[0][j].data[2], src.tiles[0][j].data[3]);
        #pragma unroll
        for(int i = 1; i < src.height; i++) {
            #pragma unroll
            for(int k = 0; k < src.packed_per_tile/2; k++) {
                accum_left_cols = op::template op<dtype>(accum_left_cols, src.tiles[i][j].data[k+0]);
                accum_right_cols = op::template op<dtype>(accum_right_cols, src.tiles[i][j].data[k+2]);
            }
        }

        // Now we need to do a lil shuffle to make everyone happy.

        accum_left_cols = op::template op<dtype>(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 16));
        accum_left_cols = op::template op<dtype>(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 8));
        accum_left_cols = op::template op<dtype>(accum_left_cols, packed_shfl_down_sync(MASK_ALL, accum_left_cols, 4));

        accum_right_cols = op::template op<dtype>(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 16));
        accum_right_cols = op::template op<dtype>(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 8));
        accum_right_cols = op::template op<dtype>(accum_right_cols, packed_shfl_down_sync(MASK_ALL, accum_right_cols, 4));

        accum_left_cols = packed_shfl_sync(MASK_ALL, accum_left_cols, leader);
        accum_right_cols = packed_shfl_sync(MASK_ALL, accum_right_cols, leader);

        if(reset) {
            col_accum[j][0] = accum_left_cols;
            col_accum[j][1] = accum_right_cols;
        }
        else {
            col_accum[j][0] = op::template op<dtype>(src_accum[j][0], accum_left_cols);
            col_accum[j][1] = op::template op<dtype>(src_accum[j][1], accum_right_cols);
        }
    }
}
/**
 * @brief Perform a column-wise reduction on a matrix in column-major layout.
 *
 * This function template performs a parallel reduction across the columns of a matrix using a specified operation.
 * It leverages warp shuffle functions for efficient intra-warp communication and is optimized for column-major matrices.
 *
 * @tparam op The operation to be applied for reduction.
 * @tparam V The vector type for the column accumulator.
 * @tparam T The matrix type with column layout.
 * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when reset is false.
 */
template<typename op, ducks::rv::all V, ducks::rt::col_layout T, bool reset>
__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) {
    // I actually like these static asserts because they give more verbose errors when things go wrong.
    KITTENS_CHECK_WARP
    static_assert(std::is_same_v<typename V::layout, typename rt_base<typename T::T, typename T::layout>::row_vec_layout>); // compatible layout
    static_assert(std::is_same_v<typename V::dtype, typename T::dtype>); // compatible type
    static_assert(V::outer_dim == T::width); // compatible size

    using dtype = V::dtype;
    const int leader = threadIdx.x & 0x1C; // 11100 in binary
    #pragma unroll
    for(int j = 0; j < src.width; j++) { // note now width is the outer loop
        dtype accum_left_col = op::template op<dtype>(src.tiles[0][j].data[0], src.tiles[0][j].data[2]);
        dtype accum_right_col = op::template op<dtype>(src.tiles[0][j].data[1], src.tiles[0][j].data[3]);
        #pragma unroll
        for(int i = 1; i < src.height; i++) { // and height is the inner loop
            #pragma unroll
            for(int k = 0; k < src.packed_per_tile; k+=2) {
                accum_left_col = op::template op<dtype>(accum_left_col, src.tiles[i][j].data[k+0]);
                accum_right_col = op::template op<dtype>(accum_right_col, src.tiles[i][j].data[k+1]);
            }
        }
        dtype accum_packed;
        accum_packed.x = op::template op<typename base_types::packing<dtype>::unpacked_type>(accum_left_col.x, accum_left_col.y);
        accum_packed.y = op::template op<typename base_types::packing<dtype>::unpacked_type>(accum_right_col.x, accum_right_col.y);

        // Now we need to do a lil shuffle to make everyone happy.

        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 2));
        accum_packed = op::template op<dtype>(accum_packed, packed_shfl_down_sync(MASK_ALL, accum_packed, 1));

        accum_packed = packed_shfl_sync(MASK_ALL, accum_packed, leader);

        if(reset) {
            col_accum[j][0] = accum_packed;
        }
        else {
            col_accum[j][0] = op::template op<dtype>(src_accum[j][0], accum_packed);
        }
    }
}


/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// two-operand row reductions. (Accumulate and REPLACE.)
/**
 * @brief Store the maximum of each row of the src register tile in the row_accum column vector.
 *
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void row_max(V &row_accum, const T &src) {
    row_reduce<base_ops::max, V, T, true>(row_accum, src, row_accum);
}
/**
 * @brief Store the minimum of each row of the src register tile in the row_accum column vector.
 *
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void row_min(V &row_accum, const T &src) {
    row_reduce<base_ops::min, V, T, true>(row_accum, src, row_accum);
}
/**
 * @brief Store the sum of each row of the src register tile in the row_accum column vector.
 *
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void row_sum(V &row_accum, const T &src) {
    row_reduce<base_ops::sum, V, T, true>(row_accum, src, row_accum);
}
/**
 * @brief Store the product of each row of the src register tile in the row_accum column vector.
 *
 * @tparam V The vector type for the row accumulator.
 * @tparam T The matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void row_prod(V &row_accum, const T &src) {
    row_reduce<base_ops::mul, V, T, true>(row_accum, src, row_accum);
}
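/**
 * A minimal sketch of the canonical use of these wrappers: a numerically
 * stable row-wise softmax, combining the reductions here with the tile maps;
 * assumes warp scope and the rt_fl alias.
 * @code
 * __device__ void row_softmax_demo(rt_fl<16, 64> &scores) {
 *     rt_fl<16, 64>::col_vec m, s;
 *     row_max(m, scores);         // m[r] = max over row r
 *     sub_row(scores, scores, m); // subtract the row max for stability
 *     exp(scores, scores);        // elementwise e^x
 *     row_sum(s, scores);         // s[r] = row total
 *     div_row(scores, scores, s); // each row now sums to 1
 * }
 * @endcode
 */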
|
||||
// three-operand row reductions. (Accumulate ONTO.)
|
||||
/**
|
||||
* @brief Store the maximum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void row_max(V &row_accum, const T &src, const V &src_accum) {
|
||||
row_reduce<base_ops::max, V, T, false>(row_accum, src, src_accum);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void row_min(V &row_accum, const T &src, const V &src_accum) {
|
||||
row_reduce<base_ops::min, V, T, false>(row_accum, src, src_accum);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void row_sum(V &row_accum, const T &src, const V &src_accum) {
|
||||
row_reduce<base_ops::sum, V, T, false>(row_accum, src, src_accum);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void row_prod(V &row_accum, const T &src, const V &src_accum) {
|
||||
row_reduce<base_ops::mul, V, T, false>(row_accum, src, src_accum);
|
||||
}
|
||||
|
||||
// two-operand col reductions. (Accumulate and REPLACE.)
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void col_max(V &col_accum, const T &src) {
|
||||
col_reduce<base_ops::max, V, T, true>(col_accum, src, col_accum);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void col_min(V &col_accum, const T &src) {
|
||||
col_reduce<base_ops::min, V, T, true>(col_accum, src, col_accum);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void col_sum(V &col_accum, const T &src) {
|
||||
col_reduce<base_ops::sum, V, T, true>(col_accum, src, col_accum);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<ducks::rv::all V, ducks::rt::all T>
|
||||
__device__ static inline void col_prod(V &col_accum, const T &src) {
|
||||
col_reduce<base_ops::mul, V, T, true>(col_accum, src, col_accum);
|
||||
}
// three-operand col reductions. (Accumulate ONTO.)
/**
 * @brief Store the maximum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
 *
 * @tparam V The vector type for the column accumulator.
 * @tparam T The matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void col_max(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::max, V, T, false>(col_accum, src, src_accum);
}
/**
 * @brief Store the minimum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
 *
 * @tparam V The vector type for the column accumulator.
 * @tparam T The matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void col_min(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::min, V, T, false>(col_accum, src, src_accum);
}
/**
 * @brief Store the sum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
 *
 * @tparam V The vector type for the column accumulator.
 * @tparam T The matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void col_sum(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::sum, V, T, false>(col_accum, src, src_accum);
}
/**
 * @brief Store the product of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
 *
 * @tparam V The vector type for the column accumulator.
 * @tparam T The matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::rv::all V, ducks::rt::all T>
__device__ static inline void col_prod(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::mul, V, T, false>(col_accum, src, src_accum);
}
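
// Usage sketch (illustrative, not part of the header): the two-operand forms REPLACE the
// accumulator, while the three-operand forms fold a prior value in. The aliases below are
// assumptions in the spirit of this codebase, not exact library names:
//
//     rt_fl<16, 64> tile;            // 16x64 fp32 register tile (assumed alias)
//     decltype(tile)::row_vec cmax;  // one entry per column
//     col_max(cmax, tile);           // replace: per-column max over all rows
//     col_max(cmax, tile, cmax);     // accumulate: fold the previous maxes back in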

// templated versions of each
// (axis::COL collapses the column axis -- one value per row, via the row_* reductions;
//  axis::ROW collapses the row axis -- one value per column, via the col_* reductions.)

template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void max(RV &dst, const T &src, const RV &src_accum) {
    if constexpr (ax == axis::COL) row_max(dst, src, src_accum);
    else col_max(dst, src, src_accum);
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline auto max(const T &src, const RV &src_accum) {
    RV dst;
    if constexpr (ax == axis::COL) row_max(dst, src, src_accum);
    else col_max(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void max(RV &dst, const T &src) {
    if constexpr (ax == axis::COL) row_max(dst, src);
    else col_max(dst, src);
}
template<int ax, ducks::rt::all T>
__device__ static inline auto max(const T &src) {
    using RV = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    RV dst;
    if constexpr (ax == axis::COL) row_max(dst, src);
    else col_max(dst, src);
    return dst;
}

template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void min(RV &dst, const T &src, const RV &src_accum) {
    if constexpr (ax == axis::COL) row_min(dst, src, src_accum);
    else col_min(dst, src, src_accum);
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline auto min(const T &src, const RV &src_accum) {
    RV dst;
    if constexpr (ax == axis::COL) row_min(dst, src, src_accum);
    else col_min(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void min(RV &dst, const T &src) {
    if constexpr (ax == axis::COL) row_min(dst, src);
    else col_min(dst, src);
}
template<int ax, ducks::rt::all T>
__device__ static inline auto min(const T &src) {
    using RV = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    RV dst;
    if constexpr (ax == axis::COL) row_min(dst, src);
    else col_min(dst, src);
    return dst;
}

template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void sum(RV &dst, const T &src, const RV &src_accum) {
    if constexpr (ax == axis::COL) row_sum(dst, src, src_accum);
    else col_sum(dst, src, src_accum);
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline auto sum(const T &src, const RV &src_accum) {
    RV dst;
    if constexpr (ax == axis::COL) row_sum(dst, src, src_accum);
    else col_sum(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void sum(RV &dst, const T &src) {
    if constexpr (ax == axis::COL) row_sum(dst, src);
    else col_sum(dst, src);
}
template<int ax, ducks::rt::all T>
__device__ static inline auto sum(const T &src) {
    using RV = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    RV dst;
    if constexpr (ax == axis::COL) row_sum(dst, src);
    else col_sum(dst, src);
    return dst;
}

template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void prod(RV &dst, const T &src, const RV &src_accum) {
    if constexpr (ax == axis::COL) row_prod(dst, src, src_accum);
    else col_prod(dst, src, src_accum);
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline auto prod(const T &src, const RV &src_accum) {
    RV dst;
    if constexpr (ax == axis::COL) row_prod(dst, src, src_accum);
    else col_prod(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::rv::all RV, ducks::rt::all T>
__device__ static inline void prod(RV &dst, const T &src) {
    if constexpr (ax == axis::COL) row_prod(dst, src);
    else col_prod(dst, src);
}
template<int ax, ducks::rt::all T>
__device__ static inline auto prod(const T &src) {
    using RV = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    RV dst;
    if constexpr (ax == axis::COL) row_prod(dst, src);
    else col_prod(dst, src);
    return dst;
}
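
// Usage sketch (illustrative): the axis-templated wrappers choose the row_*/col_* variant
// at compile time, and the source-only forms also deduce the matching vector type.
// Assuming the same rt_fl alias as in the sketch above:
//
//     auto per_row = max<axis::COL>(tile);  // col_vec: collapses columns, one max per row
//     auto per_col = sum<axis::ROW>(tile);  // row_vec: collapses rows, one sum per column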
47
extra/thunder/cuda/include/ops/group/register/tile/tile.cuh
Normal file
@@ -0,0 +1,47 @@
/**
 * @file
 * @brief An aggregate header for warp operations on register tiles.
 */

#include "conversions.cuh"
#include "maps.cuh"
#include "reductions.cuh"

template<ducks::rt::all RT>
__device__ static inline bool hasnan(const RT &src) {
    KITTENS_CHECK_WARP
    bool nan_detected = false;
    #pragma unroll
    for(int i = 0; i < RT::height; i++) {
        #pragma unroll
        for(int j = 0; j < RT::width; j++) {
            #pragma unroll
            for(int k = 0; k < RT::packed_per_tile; k++) {
                if constexpr (std::is_same_v<typename RT::T, float>) {
                    if(isnan(src.tiles[i][j].data[k].x) || isnan(src.tiles[i][j].data[k].y)) {
                        nan_detected = true;
                    }
                }
                else if constexpr (std::is_same_v<typename RT::T, bf16>) {
                    if(isnan(__bfloat162float(src.tiles[i][j].data[k].x)) || isnan(__bfloat162float(src.tiles[i][j].data[k].y))) {
                        nan_detected = true;
                    }
                }
                else if constexpr (std::is_same_v<typename RT::T, half>) {
                    if(isnan(__half2float(src.tiles[i][j].data[k].x)) || isnan(__half2float(src.tiles[i][j].data[k].y))) {
                        nan_detected = true;
                    }
                }
                else {
                    static_assert(sizeof(typename RT::T) == 999, "Unsupported dtype");
                }
            }
        }
    }
    // Ballot across the warp to see if any lane detected a nan
    return (__ballot_sync(0xffffffff, nan_detected) != 0);
}
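
// Usage sketch (illustrative): hasnan is warp-collective -- every lane must call it, and
// the ballot hands every lane the same verdict:
//
//     if(hasnan(att_block)) { /* uniform per-warp bailout or printf debugging */ }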

#include "complex/complex_conversions.cuh"
#include "complex/complex_maps.cuh"

@@ -0,0 +1,153 @@
/**
 * @file
 * @brief Conversions on vectors stored in registers.
 */

struct vec_conversion_detail {

    // i am not smart enough to figure out these indices without these helpers :/
    // again, blame nvidia for these stupid, stupid layouts
    __device__ static inline int row_from_indices_dim2(int laneid, int inner_dim, int x_or_y) {
        return 8*inner_dim + (laneid%4)*2 + x_or_y;
    }
    __device__ static inline int row_from_indices_dim1(int laneid, int x_or_y) {
        return 8*x_or_y + (laneid/4);
    }
    __device__ static inline int canonical_src_lane_dim2(int row) {
        return (row/2)%4 + 4*(row%2); // draw even rows from 0...3 and odds from 4...7
    }
    __device__ static inline int canonical_src_lane_dim1(int row) {
        return (row*4)%32;
    }

};

/**
 * @brief Copies data from one register vector to another.
 *
 * @tparam RV1 The type of the destination register vector.
 * @tparam RV2 The type of the source register vector.
 * @param dst[out] The destination register vector.
 * @param src[in] The source register vector to copy from.
 */
template<ducks::rv::all RV1, ducks::rv::all RV2>
__device__ static inline void copy(RV1 &dst, const RV2 &src) {
    KITTENS_CHECK_WARP
    static_assert(RV1::length == RV2::length, "Register vectors must be the same length.");
    using D1 = RV1::dtype;
    using D2 = RV2::dtype;
    if constexpr (std::is_same_v<typename RV1::layout, typename RV2::layout>) { // just a simple copy / typecast
        #pragma unroll
        for(int i = 0; i < RV1::outer_dim; i++) {
            #pragma unroll
            for(int j = 0; j < RV1::inner_dim; j++) {
                dst[i][j] = base_types::convertor<D1, D2>::convert(src[i][j]);
            }
        }
    }
    else { // Inner dimensions are not the same, this is really a layout conversion.
        int laneid = ::kittens::laneid();
        if constexpr (std::is_same_v<typename RV1::layout, ortho_l> && std::is_same_v<typename RV2::layout, align_l>) { // align -> ortho layout
            #pragma unroll
            for(int i = 0; i < RV1::outer_dim; i++) {
                dst[i][0].x = packed_shfl_sync(
                    kittens::MASK_ALL,
                    laneid < 4 ? src[i][0].x : src[i][0].y, // mirrors canonical_src_lane_dim2
                    vec_conversion_detail::canonical_src_lane_dim2(vec_conversion_detail::row_from_indices_dim1(laneid, 0))
                );
                dst[i][0].y = packed_shfl_sync(
                    kittens::MASK_ALL,
                    laneid < 4 ? src[i][1].x : src[i][1].y, // mirrors canonical_src_lane_dim2
                    vec_conversion_detail::canonical_src_lane_dim2(vec_conversion_detail::row_from_indices_dim1(laneid, 1))
                );
            }
        }
        else if constexpr (std::is_same_v<typename RV1::layout, align_l> && std::is_same_v<typename RV2::layout, ortho_l>) { // ortho -> align layout
            #pragma unroll
            for(int i = 0; i < RV1::outer_dim; i++) {
                dst[i][0].x = packed_shfl_sync(
                    kittens::MASK_ALL,
                    src[i][0].x, // first 8 rows
                    vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 0, 0))
                );
                dst[i][0].y = packed_shfl_sync(
                    kittens::MASK_ALL,
                    src[i][0].x, // first 8 rows
                    vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 0, 1))
                );
                dst[i][1].x = packed_shfl_sync(
                    kittens::MASK_ALL,
                    src[i][0].y, // last 8 rows
                    vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 1, 0))
                );
                dst[i][1].y = packed_shfl_sync(
                    kittens::MASK_ALL,
                    src[i][0].y, // last 8 rows
                    vec_conversion_detail::canonical_src_lane_dim1(vec_conversion_detail::row_from_indices_dim2(laneid, 1, 1))
                );
            }
        }
        else if constexpr (std::is_same_v<typename RV1::layout, ortho_l> && std::is_same_v<typename RV2::layout, naive_l>) { // naive -> ortho layout
            #pragma unroll
            for(int i = 0; i < RV1::outer_dim; i++) {
                dst[i][0].x = packed_shfl_sync(
                    kittens::MASK_ALL, src[i/2][0],
                    16*(i%2) + 0 + (laneid/4)
                );
                dst[i][0].y = packed_shfl_sync(
                    kittens::MASK_ALL, src[i/2][0],
                    16*(i%2) + 8 + (laneid/4)
                );
            }
        }
        else if constexpr (std::is_same_v<typename RV1::layout, naive_l> && std::is_same_v<typename RV2::layout, ortho_l>) { // ortho -> naive layout
            int lane_replication = laneid%4; // 0...3
            #pragma unroll
            for(int i = 0; i < RV1::outer_dim; i++) {
                D1 tmp = 0;
                if(RV1::length%32==0 || i < RV1::outer_dim-1 || lane_replication<2) {
                    tmp = lane_replication%2 ? src[2*i + (lane_replication>=2)][0].y : src[2*i + (lane_replication>=2)][0].x;
                }
                dst[i][0] = packed_shfl_sync(
                    kittens::MASK_ALL, tmp,
                    (laneid%8)*4 + (laneid/8)
                );
            }
        }
        else if constexpr (std::is_same_v<typename RV1::layout, align_l> && std::is_same_v<typename RV2::layout, naive_l>) { // naive -> align layout
            #pragma unroll
            for(int i = 0; i < RV1::outer_dim; i++) {
                dst[i][0].x = packed_shfl_sync(
                    kittens::MASK_ALL, src[i/2][0],
                    16*(i%2) + 0 + 2*(laneid%4) + 0
                );
                dst[i][0].y = packed_shfl_sync(
                    kittens::MASK_ALL, src[i/2][0],
                    16*(i%2) + 0 + 2*(laneid%4) + 1
                );
                dst[i][1].x = packed_shfl_sync(
                    kittens::MASK_ALL, src[i/2][0],
                    16*(i%2) + 8 + 2*(laneid%4) + 0
                );
                dst[i][1].y = packed_shfl_sync(
                    kittens::MASK_ALL, src[i/2][0],
                    16*(i%2) + 8 + 2*(laneid%4) + 1
                );
            }
        }
        else if constexpr (std::is_same_v<typename RV1::layout, naive_l> && std::is_same_v<typename RV2::layout, align_l>) { // align -> naive layout
            int lane_replication = laneid/8; // 0...3
            #pragma unroll
            for(int i = 0; i < RV1::outer_dim; i++) {
                D1 tmp = 0;
                if(RV1::length%32==0 || i < RV1::outer_dim-1 || laneid<16) {
                    tmp = (laneid%8)<4 ? src[2*i + (lane_replication>=2)][lane_replication%2].x : src[2*i + (lane_replication>=2)][lane_replication%2].y;
                }
                dst[i][0] = packed_shfl_sync(
                    kittens::MASK_ALL, tmp,
                    4*(laneid%2) + (laneid%8)/2 + (laneid&0b11000)
                );
            }
        }
    }
}
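
// Usage sketch (illustrative): copy handles dtype casts when the layouts match and
// warp-shuffle reindexing when they differ. The vector alias below is a hypothetical
// name for a length-64 register vector with an explicit layout, not an exact library type:
//
//     rv<float, 64, align_l> a;  // align-layout register vector (hypothetical alias)
//     rv<float, 64, ortho_l> b;  // same length, ortho layout
//     copy(b, a);                // align -> ortho: elements reindexed via warp shuffles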
374
extra/thunder/cuda/include/ops/group/register/vec/maps.cuh
Normal file
@@ -0,0 +1,374 @@
/**
 * @file
 * @brief Maps on vectors stored in registers.
 */

/* ---------- Vector Maps ---------- */

/**
 * @brief Perform a unary operation on a vector.
 *
 * @tparam op The unary operation to perform.
 * @tparam T The type of the vector.
 * @param dst[out] The destination vector where the result is stored.
 * @param src[in] The source vector to perform the operation on.
 */
template<typename op, ducks::rv::all T>
__device__ static inline void unary_op(T &dst, const T &src) {
    #pragma unroll
    for(int i = 0; i < dst.outer_dim; i++) {
        #pragma unroll
        for(int j = 0; j < dst.inner_dim; j++) {
            dst[i][j] = op::template op<typename T::dtype>(src[i][j]);
        }
    }
}
/**
 * @brief Perform a binary operation on two vectors.
 *
 * @tparam op The binary operation to perform.
 * @tparam T The type of the vectors.
 * @param dst[out] The destination vector where the result is stored.
 * @param lhs[in] The left-hand side vector for the operation.
 * @param rhs[in] The right-hand side vector for the operation.
 */
template<typename op, ducks::rv::all T>
__device__ static inline void bin_op(T &dst, const T &lhs, const T &rhs) {
    #pragma unroll
    for(int i = 0; i < dst.outer_dim; i++) {
        #pragma unroll
        for(int j = 0; j < dst.inner_dim; j++) {
            dst[i][j] = op::template op<typename T::dtype>(lhs[i][j], rhs[i][j]);
        }
    }
}
/**
 * @brief Perform a binary operation on a vector and a scalar.
 *
 * @tparam op The binary operation to perform.
 * @tparam T The type of the vector.
 * @param dst[out] The destination vector where the result is stored.
 * @param src[in] The source vector for the operation.
 * @param param[in] The scalar parameter for the operation.
 */
template<typename op, ducks::rv::all T>
__device__ static inline void bin_op(T &dst, const T &src, const typename T::dtype &param) {
    #pragma unroll
    for(int i = 0; i < dst.outer_dim; i++) {
        #pragma unroll
        for(int j = 0; j < dst.inner_dim; j++) {
            dst[i][j] = op::template op<typename T::dtype>(src[i][j], param);
        }
    }
}
/**
 * @brief Perform a binary operation on a vector and an unpacked scalar.
 *
 * @tparam op The binary operation to perform.
 * @tparam T The type of the vector.
 * @param dst[out] The destination vector where the result is stored.
 * @param src[in] The source vector for the operation.
 * @param param[in] The unpacked scalar parameter for the operation.
 */
template<typename op, ducks::rv::tile_layout T>
__device__ static inline void bin_op(T &dst, const T &src, const typename base_types::packing<typename T::dtype>::unpacked_type &param) {
    bin_op<op, T>(dst, src, base_types::packing<typename T::dtype>::pack(param));
}


/**
 * @brief Apply an index-aware lambda element-wise to a register vector.
 *
 * The lambda receives each element's logical index within the (group-distributed)
 * vector along with its current value, and returns the new value.
 */
template<ducks::rv::all RV, typename Lambda>
__device__ static inline void apply(RV &dst, const RV &src, Lambda &&lambda) {
    int group_offset = 0;
    if constexpr(GROUP_WARPS > 1) {
        group_offset = warpid()*RV::length;
    }
    static_assert(sizeof(RV::T) != 1, "Cannot apply lambda to 8-bit types");
    if constexpr (ducks::rv::ortho_layout<RV>) {
        #pragma unroll
        for(int i = 0; i < dst.outer_dim; i++) {
            int base_idx = group_offset + i*16 + ::kittens::laneid()/4;
            dst[i][0].x = lambda(base_idx+0, src[i][0].x);
            dst[i][0].y = lambda(base_idx+8, src[i][0].y);
        }
    }
    else if constexpr (ducks::rv::align_layout<RV>) {
        #pragma unroll
        for(int i = 0; i < dst.outer_dim; i++) {
            int base_idx = group_offset + i*16 + 2*(::kittens::laneid()%4);
            dst[i][0].x = lambda(base_idx+0, src[i][0].x);
            dst[i][0].y = lambda(base_idx+1, src[i][0].y);
            dst[i][1].x = lambda(base_idx+8, src[i][1].x);
            dst[i][1].y = lambda(base_idx+9, src[i][1].y);
        }
    }
    else {
        #pragma unroll
        for(int i = 0; i < dst.outer_dim; i++) {
            int base_idx = group_offset + i*32 + ::kittens::laneid();
            if (i < dst.outer_dim-1 || dst.length%32 == 0 || ::kittens::laneid()<16) {
                dst[i][0] = lambda(base_idx, src[i][0]);
            }
        }
    }
}
template<ducks::rv::all RV, typename Lambda>
__device__ static inline RV apply(const RV &src, Lambda &&lambda) {
    RV dst;
    apply<RV, Lambda>(dst, src, std::forward<Lambda>(lambda));
    return dst;
}
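
// Usage sketch (illustrative): apply hands the lambda each element's logical index along
// with its value, which makes position-dependent maps (masks, positional phases) easy.
// For a float vector, the lambda signature is (int idx, float v) -> float:
//
//     apply(vec, vec, [](int idx, float v) { return idx % 2 == 0 ? v : -v; });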

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// ---- const ops ----

/**
 * @brief Sets all elements of a register vector to zero.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector to be set to zero.
 */
template<ducks::rv::all T>
__device__ static inline void zero(T &dst) {
    unary_op<base_ops::zero, T>(dst, dst);
}
/**
 * @brief Sets all elements of a register vector to one.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector to be set to one.
 */
template<ducks::rv::all T>
__device__ static inline void one(T &dst) {
    unary_op<base_ops::one, T>(dst, dst);
}
/**
 * @brief Sets all elements of a register vector to positive infinity.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector to be set to positive infinity.
 */
template<ducks::rv::all T>
__device__ static inline void pos_infty(T &dst) {
    unary_op<base_ops::pos_infty, T>(dst, dst);
}
/**
 * @brief Sets all elements of a register vector to negative infinity.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector to be set to negative infinity.
 */
template<ducks::rv::all T>
__device__ static inline void neg_infty(T &dst) {
    unary_op<base_ops::neg_infty, T>(dst, dst);
}

// ---- unary ops ----

/**
 * @brief Copies the elements from one register vector to another.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the source vector.
 * @param dst[out] Destination vector where the elements will be copied to.
 * @param src[in] Source vector to copy the elements from.
 */
template<ducks::rv::all T, typename U>
__device__ static inline void copy(T &dst, const U &src) {
    bin_op<base_ops::copy2, T>(dst, dst, src); // the second arg is ignored here.
}
/**
 * @brief Applies the exponential function element-wise to a register vector.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector where the exponential values will be stored.
 * @param src[in] Source vector to apply the exponential function to.
 */
template<ducks::rv::all T>
__device__ static inline void exp(T &dst, const T &src) {
    unary_op<base_ops::exp, T>(dst, src);
}
template<ducks::rv::all T>
__device__ static inline T exp(const T &src) {
    T dst;
    exp(dst, src);
    return dst;
}
/**
 * @brief Applies the base-2 exponential function element-wise to a register vector.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector where the exponential values will be stored.
 * @param src[in] Source vector to apply the base-2 exponential function to.
 */
template<ducks::rv::all T>
__device__ static inline void exp2(T &dst, const T &src) {
    unary_op<base_ops::exp2, T>(dst, src);
}
template<ducks::rv::all T>
__device__ static inline T exp2(const T &src) {
    T dst;
    exp2(dst, src);
    return dst;
}
/**
 * @brief Applies the natural logarithm function element-wise to a register vector.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector where the logarithm values will be stored.
 * @param src[in] Source vector to apply the natural logarithm to.
 */
template<ducks::rv::all T>
__device__ static inline void log(T &dst, const T &src) {
    unary_op<base_ops::log, T>(dst, src);
}
template<ducks::rv::all T>
__device__ static inline T log(const T &src) {
    T dst;
    log(dst, src);
    return dst;
}
/**
 * @brief Applies the base-2 logarithm function element-wise to a register vector.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector where the logarithm values will be stored.
 * @param src[in] Source vector to apply the base-2 logarithm to.
 */
template<ducks::rv::all T>
__device__ static inline void log2(T &dst, const T &src) {
    unary_op<base_ops::log2, T>(dst, src);
}
template<ducks::rv::all T>
__device__ static inline T log2(const T &src) {
    T dst;
    log2(dst, src);
    return dst;
}
/**
 * @brief Applies the absolute value function element-wise to a register vector.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector where the absolute values will be stored.
 * @param src[in] Source vector to apply the absolute value function to.
 */
template<ducks::rv::all T>
__device__ static inline void abs(T &dst, const T &src) {
    unary_op<base_ops::abs, T>(dst, src);
}
template<ducks::rv::all T>
__device__ static inline T abs(const T &src) {
    T dst;
    abs(dst, src);
    return dst;
}
/**
 * @brief Applies the rectified linear unit (ReLU) function element-wise to a register vector.
 *
 * @tparam T Register vector type.
 * @param dst[out] Destination vector where the ReLU values will be stored.
 * @param src[in] Source vector to apply the ReLU function to.
 */
template<ducks::rv::all T>
__device__ static inline void relu(T &dst, const T &src) {
    unary_op<base_ops::relu, T>(dst, src);
}
template<ducks::rv::all T>
__device__ static inline T relu(const T &src) {
    T dst;
    relu(dst, src);
    return dst;
}

// ---- binary ops ----

/**
 * @brief Computes the element-wise maximum of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the maximum values will be stored.
 * @param lhs[in] First vector for the maximum operation.
 * @param rhs[in] Second vector for the maximum operation.
 */
template<ducks::rv::all T, typename U>
__device__ static inline void max(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::max, T>(dst, lhs, rhs);
}
template<ducks::rv::all T, typename U>
__device__ static inline T max(const T &lhs, const U &rhs) {
    T dst;
    max(dst, lhs, rhs);
    return dst;
}
/**
 * @brief Computes the element-wise minimum of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the minimum values will be stored.
 * @param lhs[in] First vector for the minimum operation.
 * @param rhs[in] Second vector for the minimum operation.
 */
template<ducks::rv::all T, typename U>
__device__ static inline void min(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::min, T>(dst, lhs, rhs);
}
template<ducks::rv::all T, typename U>
__device__ static inline T min(const T &lhs, const U &rhs) {
    T dst;
    min(dst, lhs, rhs);
    return dst;
}
/**
 * @brief Computes the element-wise sum of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the sum values will be stored.
 * @param lhs[in] First vector for the sum operation.
 * @param rhs[in] Second vector for the sum operation.
 */
template<ducks::rv::all T, typename U>
__device__ static inline void add(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::sum, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise difference of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the difference values will be stored.
 * @param lhs[in] First vector for the difference operation.
 * @param rhs[in] Second vector for the difference operation.
 */
template<ducks::rv::all T, typename U>
__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::sub, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise product of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the product values will be stored.
 * @param lhs[in] First vector for the product operation.
 * @param rhs[in] Second vector for the product operation.
 */
template<ducks::rv::all T, typename U>
__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::mul, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise division of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the division values will be stored.
 * @param lhs[in] First vector for the division operation.
 * @param rhs[in] Second vector for the division operation.
 */
template<ducks::rv::all T, typename U>
__device__ static inline void div(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::div, T>(dst, lhs, rhs);
}
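
// Usage sketch (illustrative): the wrappers accept a matching vector or a scalar on the
// right-hand side (scalars are packed automatically by the bin_op overloads), so an
// element-wise update reads naturally:
//
//     sub(scores, scores, max_val);  // vector minus a broadcast scalar
//     exp(scores, scores);           // element-wise exponential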
233
extra/thunder/cuda/include/ops/group/register/vec/reductions.cuh
Normal file
@@ -0,0 +1,233 @@
/**
 * @file
 * @brief Reductions on vectors stored in registers.
 */

/* ---------- Vector Reductions ---------- */

/**
 * @brief Performs a reduction operation on elements of a register vector within a warp.
 *
 * This function applies a specified operation to reduce the elements of a register vector `src` to a single value.
 * The result is stored in `dst_accum`. If `reset` is false, the reduction also incorporates the initial value `src_accum`.
 * The reduction operation is performed in a warp-wide context, ensuring synchronization between threads in the warp.
 *
 * @tparam op The operation to perform on the elements. Must provide a static `op` method.
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @tparam reset A boolean flag indicating whether to ignore the initial value (true) or fold it in (false).
 * @param[out] dst_accum The result of the reduction operation.
 * @param[in] src The register vector to reduce.
 * @param[in] src_accum The initial value to include in the reduction if `reset` is false.
 */
template<typename op, ducks::rv::all RV, bool reset>
__device__ static inline void reduce(
        typename base_types::packing<typename RV::dtype>::unpacked_type &dst_accum,
        const RV &src,
        const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    KITTENS_CHECK_WARP
    using T = base_types::packing<typename RV::dtype>::unpacked_type;
    int laneid = kittens::laneid();
    if constexpr (std::is_same_v<typename RV::layout, ortho_l>) {
        T accum = op::template op<T>(src[0][0].x, src[0][0].y);
        #pragma unroll
        for(int i = 1; i < src.outer_dim; i++) {
            accum = op::template op<T>(accum, src[i][0].x);
            accum = op::template op<T>(accum, src[i][0].y);
        }
        // we've now reduced everything into 8 distinct values, replicated across lanes x, x+1, x+2, x+3 for x≡0(mod4)
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16));
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8));
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4));
        // we've now reduced everything into 1 distinct value, replicated across lanes 0, 1, 2, 3
        if constexpr (!reset) accum = op::template op<T>(accum, src_accum);
        // final result has now been achieved (incorporating src_accum if necessary), finally broadcast back to all threads.
        dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0);
    }
    else if constexpr (std::is_same_v<typename RV::layout, align_l>) {
        T accum = op::template op<T>(src[0][0].x, src[0][0].y);
        accum = op::template op<T>(accum, src[0][1].x);
        accum = op::template op<T>(accum, src[0][1].y);
        #pragma unroll
        for(int i = 1; i < src.outer_dim; i++) {
            // it is possible that shfl_sync's would be faster but I doubt it, replication is likely better. Certainly simpler.
            accum = op::template op<T>(accum, src[i][0].x);
            accum = op::template op<T>(accum, src[i][0].y);
            accum = op::template op<T>(accum, src[i][1].x);
            accum = op::template op<T>(accum, src[i][1].y);
        }
        // we've now reduced everything into 4 distinct values, replicated across lanes x, x+4, x+8, ..., x+28 for x<4
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2));
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1));
        // we've now reduced everything into 1 distinct value, replicated across lanes 0, 4, 8, 12, ..., 28
        if constexpr (!reset) accum = op::template op<T>(accum, src_accum);
        // final result has now been achieved (incorporating src_accum if necessary), finally broadcast back to all threads from lane 0
        dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0);
    }
    else if constexpr (std::is_same_v<typename RV::layout, naive_l>) {
        T accum = src[0][0];
        #pragma unroll
        for(int i = 1; i < src.outer_dim; i++) {
            if (i < src.outer_dim-1 || i*kittens::TILE_ROW_DIM<T>*2 + laneid < src.length) {
                accum = op::template op<T>(accum, src[i][0]);
            }
        }
        if(src.length > 16) accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16));
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8));
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4));
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2));
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1));
        if constexpr (!reset) accum = op::template op<T>(accum, src_accum);
        dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0);
    }
}

/**
 * @brief Finds the maximum element in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] max_val The maximum value found in the vector.
 * @param[in] src The register vector to find the maximum in.
 */
template<ducks::rv::all RV>
__device__ static inline void max(typename base_types::packing<typename RV::dtype>::unpacked_type &max_val, const RV &src) {
    reduce<base_ops::max, RV, true>(max_val, src, max_val);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type max(const RV &src) {
    typename base_types::packing<typename RV::dtype>::unpacked_type max_val;
    reduce<base_ops::max, RV, true>(max_val, src, max_val);
    return max_val;
}

/**
 * @brief Finds the minimum element in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] min_val The minimum value found in the vector.
 * @param[in] src The register vector to find the minimum in.
 */
template<ducks::rv::all RV>
__device__ static inline void min(typename base_types::packing<typename RV::dtype>::unpacked_type &min_val, const RV &src) {
    reduce<base_ops::min, RV, true>(min_val, src, min_val);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type min(const RV &src) {
    typename base_types::packing<typename RV::dtype>::unpacked_type min_val;
    reduce<base_ops::min, RV, true>(min_val, src, min_val);
    return min_val;
}

/**
 * @brief Calculates the sum of elements in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] sum_val The sum of the values in the vector.
 * @param[in] src The register vector to sum.
 */
template<ducks::rv::all RV>
__device__ static inline void sum(typename base_types::packing<typename RV::dtype>::unpacked_type &sum_val, const RV &src) {
    reduce<base_ops::sum, RV, true>(sum_val, src, sum_val);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type sum(const RV &src) {
    typename base_types::packing<typename RV::dtype>::unpacked_type sum_val;
    reduce<base_ops::sum, RV, true>(sum_val, src, sum_val);
    return sum_val;
}

/**
 * @brief Calculates the product of elements in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] prod_val The product of the values in the vector.
 * @param[in] src The register vector to multiply.
 */
template<ducks::rv::all RV>
__device__ static inline void prod(typename base_types::packing<typename RV::dtype>::unpacked_type &prod_val, const RV &src) {
    reduce<base_ops::mul, RV, true>(prod_val, src, prod_val);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type prod(const RV &src) {
    typename base_types::packing<typename RV::dtype>::unpacked_type prod_val;
    reduce<base_ops::mul, RV, true>(prod_val, src, prod_val);
    return prod_val;
}

// Three operand versions.

/**
 * @brief Finds the maximum element in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] max_val The maximum value found in the vector, accumulated with src_accum.
 * @param[in] src The register vector to find the maximum in.
 * @param[in] src_accum The initial value to accumulate with the maximum value found.
 */
template<ducks::rv::all RV>
__device__ static inline void max(typename base_types::packing<typename RV::dtype>::unpacked_type &max_val, const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    reduce<base_ops::max, RV, false>(max_val, src, src_accum);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type max(const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    typename base_types::packing<typename RV::dtype>::unpacked_type max_val;
    reduce<base_ops::max, RV, false>(max_val, src, src_accum);
    return max_val;
}

/**
 * @brief Finds the minimum element in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] min_val The minimum value found in the vector, accumulated with src_accum.
 * @param[in] src The register vector to find the minimum in.
 * @param[in] src_accum The initial value to accumulate with the minimum value found.
 */
template<ducks::rv::all RV>
__device__ static inline void min(typename base_types::packing<typename RV::dtype>::unpacked_type &min_val, const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    reduce<base_ops::min, RV, false>(min_val, src, src_accum);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type min(const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    typename base_types::packing<typename RV::dtype>::unpacked_type min_val;
    reduce<base_ops::min, RV, false>(min_val, src, src_accum);
    return min_val;
}

/**
 * @brief Calculates the sum of elements in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] sum_val The sum of the values in the vector, accumulated with src_accum.
 * @param[in] src The register vector to sum.
 * @param[in] src_accum The initial value to accumulate with the sum of the vector.
 */
template<ducks::rv::all RV>
__device__ static inline void sum(typename base_types::packing<typename RV::dtype>::unpacked_type &sum_val, const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    reduce<base_ops::sum, RV, false>(sum_val, src, src_accum);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type sum(const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    typename base_types::packing<typename RV::dtype>::unpacked_type sum_val;
    reduce<base_ops::sum, RV, false>(sum_val, src, src_accum);
    return sum_val;
}

/**
 * @brief Calculates the product of elements in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] prod_val The product of the values in the vector, accumulated with src_accum.
 * @param[in] src The register vector to multiply.
 * @param[in] src_accum The initial value to accumulate with the product of the vector.
 */
template<ducks::rv::all RV>
__device__ static inline void prod(typename base_types::packing<typename RV::dtype>::unpacked_type &prod_val, const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    reduce<base_ops::mul, RV, false>(prod_val, src, src_accum);
}
template<ducks::rv::all RV>
__device__ static inline typename base_types::packing<typename RV::dtype>::unpacked_type prod(const RV &src, const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum) {
    typename base_types::packing<typename RV::dtype>::unpacked_type prod_val;
    reduce<base_ops::mul, RV, false>(prod_val, src, src_accum);
    return prod_val;
}
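
// Usage sketch (illustrative): these reductions collapse a register vector to a single
// scalar and broadcast it to every lane, so running statistics can live in plain registers:
//
//     float m = max(scores);         // warp-reduced maximum of a register vector
//     float l = sum(scores, l_prev); // three-operand form folds a prior accumulator in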
59
extra/thunder/cuda/include/ops/group/register/vec/vec.cuh
Normal file
@@ -0,0 +1,59 @@
/**
 * @file
 * @brief An aggregate header for warp operations on register vectors.
 */

#include "conversions.cuh"
#include "maps.cuh"
#include "reductions.cuh"

template<ducks::rv::all RV>
__device__ static inline bool hasnan(const RV &src) {
    KITTENS_CHECK_WARP
    bool nan_detected = false;
    #pragma unroll
    for(int i = 0; i < RV::outer_dim; i++) {
        #pragma unroll
        for(int j = 0; j < RV::inner_dim; j++) {
            if constexpr (std::is_same_v<typename RV::dtype, typename RV::T2>) {
                if constexpr (std::is_same_v<typename RV::dtype, float2>) {
                    if(isnan(src[i][j].x) || isnan(src[i][j].y)) {
                        nan_detected = true;
                    }
                }
                else if constexpr (std::is_same_v<typename RV::dtype, bf16_2>) {
                    if(isnan(__bfloat162float(src[i][j].x)) || isnan(__bfloat162float(src[i][j].y))) {
                        nan_detected = true;
                    }
                }
                else if constexpr (std::is_same_v<typename RV::dtype, half_2>) {
                    if(isnan(__half2float(src[i][j].x)) || isnan(__half2float(src[i][j].y))) {
                        nan_detected = true;
                    }
                }
            }
            else if constexpr (std::is_same_v<typename RV::dtype, typename RV::T>) {
                if constexpr (std::is_same_v<typename RV::dtype, float>) {
                    if(isnan(src[i][j])) {
                        nan_detected = true;
                    }
                }
                else if constexpr (std::is_same_v<typename RV::dtype, bf16>) {
                    if(isnan(__bfloat162float(src[i][j]))) {
                        nan_detected = true;
                    }
                }
                else if constexpr (std::is_same_v<typename RV::dtype, half>) {
                    if(isnan(__half2float(src[i][j]))) {
                        nan_detected = true;
                    }
                }
            }
            else {
                static_assert(sizeof(typename RV::dtype) == 999, "Unsupported dtype");
            }
        }
    }
    // Ballot across the warp to see if any lane detected a nan
    return (__ballot_sync(0xffffffff, nan_detected) != 0);
}
7
extra/thunder/cuda/include/ops/group/shared/shared.cuh
Normal file
@@ -0,0 +1,7 @@
/**
 * @file
 * @brief An aggregate header of group operations on data in shared memory
 */

#include "tile/tile.cuh"
#include "vec/vec.cuh"
@@ -0,0 +1,16 @@
/**
 * @file
 * @brief Group conversions between different shared memory tile types.
 */

/* ---------- COPIES ---------- */

template<ducks::st::all ST1, ducks::st::all ST2>
__device__ static inline void copy(ST1 &dst, const ST2 &src) {
    static_assert(ST1::height == ST2::height && ST1::width == ST2::width, "Tiles must have the same height and width");
    #pragma unroll
    for(int i = laneid(); i < dst.num_elements; i+=GROUP_THREADS) {
        int row = i/dst.cols, col = i%dst.cols;
        dst[{row, col}] = base_types::convertor<typename ST1::dtype, typename ST2::dtype>::convert(src[{row, col}]);
    }
}
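
// Usage sketch (illustrative): this copy is group-strided, so every thread of the group
// must participate; it converts dtypes along the way, e.g. staging a bf16 tile out of an
// fp32 one (st_fl / st_bf are assumed shared-tile aliases of equal shape):
//
//     copy(dst_bf_tile, src_fl_tile);  // fp32 -> bf16 element-wise, strided over the group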
236
extra/thunder/cuda/include/ops/group/shared/tile/maps.cuh
Normal file
@@ -0,0 +1,236 @@
/**
 * @file
 * @brief Group maps on shared tiles.
 */

template<typename op, ducks::st::all T> // T2, w, h can be inferred from dst as long as op is specialized
__device__ static inline void unary_map(T &dst, const T &src) {
    #pragma unroll
    for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) {
        dst.data[i] = op::template op<typename T::dtype>(src.data[i]);
    }
}

template<typename op, ducks::st::all T>
__device__ static inline void bin_map(T &dst, const T &src, const typename T::dtype &param) {
    #pragma unroll
    for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) {
        dst.data[i] = op::template op<typename T::dtype>(src.data[i], param);
    }
}

template<typename op, ducks::st::all T>
__device__ static inline void bin_map(T &dst, const T &lhs, const T &rhs) {
    #pragma unroll
    for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) {
        dst.data[i] = op::template op<typename T::dtype>(lhs.data[i], rhs.data[i]);
    }
}

template<typename op, ducks::st::all T, ducks::sv::all V>
__device__ static inline void row_map(T &dst, const T &src, const V &vec) {
    static_assert(std::is_same<typename T::dtype, typename V::dtype>::value, "Tile and vector must have the same data type");
    static_assert(V::length == T::rows, "Vector length must match the number of rows in the tile");
    #pragma unroll
    for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) {
        int row = i/dst.cols, col = i%dst.cols;
        dst[{row, col}] = op::template op<typename T::dtype>(src[{row, col}], vec[row]);
    }
}

template<typename op, ducks::st::all T, ducks::sv::all V>
__device__ static inline void col_map(T &dst, const T &src, const V &vec) {
    static_assert(std::is_same<typename T::dtype, typename V::dtype>::value, "Tile and vector must have the same data type");
    static_assert(V::length == T::cols, "Vector length must match the number of columns in the tile");
    #pragma unroll
    for(int i = laneid(); i < dst.num_elements; i += GROUP_THREADS) {
        int row = i/dst.cols, col = i%dst.cols;
        dst[{row, col}] = op::template op<typename T::dtype>(src[{row, col}], vec[col]);
    }
}


/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// All of the annoying qualifiers *should* be automatically inferred during compile-time.
// So, syntax should just be kittens::add_row(tile, tile, colvec);

// const maps

template<ducks::st::all T>
__device__ static inline void zero(T &dst) {
    unary_map<base_ops::zero, T>(dst, dst);
}

template<ducks::st::all T>
__device__ static inline void one(T &dst) {
    unary_map<base_ops::one, T>(dst, dst);
}

template<ducks::st::all T>
__device__ static inline void pos_infty(T &dst) {
    unary_map<base_ops::pos_infty, T>(dst, dst);
}

template<ducks::st::all T>
__device__ static inline void neg_infty(T &dst) {
    unary_map<base_ops::neg_infty, T>(dst, dst);
}

// unary maps

template<ducks::st::all T>
__device__ static inline void exp(T &dst, const T &src) {
    unary_map<base_ops::exp, T>(dst, src);
}

template<ducks::st::all T>
__device__ static inline void exp2(T &dst, const T &src) {
    unary_map<base_ops::exp2, T>(dst, src);
}

template<ducks::st::all T>
__device__ static inline void log(T &dst, const T &src) {
    unary_map<base_ops::log, T>(dst, src);
}

template<ducks::st::all T>
__device__ static inline void log2(T &dst, const T &src) {
    unary_map<base_ops::log2, T>(dst, src);
}

template<ducks::st::all T>
__device__ static inline void abs(T &dst, const T &src) {
    unary_map<base_ops::abs, T>(dst, src);
}

template<ducks::st::all T>
__device__ static inline void relu(T &dst, const T &src) {
    unary_map<base_ops::relu, T>(dst, src);
}

template<ducks::st::all T, typename U>
__device__ static inline void copy(T &dst, const U &src) {
    bin_map<base_ops::copy2, T>(dst, dst, src); // mirrors the register-vector copy; the second arg is ignored by copy2.
}

// uniform binary maps

template<ducks::st::all T, typename U>
__device__ static inline void max(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::max, T>(dst, lhs, rhs);
}

template<ducks::st::all T, typename U>
__device__ static inline void min(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::min, T>(dst, lhs, rhs);
}

template<ducks::st::all T, typename U>
__device__ static inline void add(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::sum, T>(dst, lhs, rhs);
}

template<ducks::st::all T, typename U>
__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::sub, T>(dst, lhs, rhs);
}

template<ducks::st::all T, typename U>
__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::mul, T>(dst, lhs, rhs);
}

template<ducks::st::all T, typename U>
__device__ static inline void div(T &dst, const T &lhs, const U &rhs) {
    bin_map<base_ops::div, T>(dst, lhs, rhs);
}

// Row and col maps

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void add_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::sum, T, V>(dst, src, row_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void sub_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::sub, T, V>(dst, src, row_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void mul_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::mul, T, V>(dst, src, row_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void div_row(T &dst, const T &src, const V &row_values) {
    row_map<base_ops::div, T, V>(dst, src, row_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void broadcast_row(T &dst, const V &row_values) {
    row_map<base_ops::copy2, T, V>(dst, dst, row_values);
}

// col maps

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void add_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::sum, T, V>(dst, src, col_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void sub_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::sub, T, V>(dst, src, col_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void mul_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::mul, T, V>(dst, src, col_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void div_col(T &dst, const T &src, const V &col_values) {
    col_map<base_ops::div, T, V>(dst, src, col_values);
}

template<ducks::st::all T, ducks::sv::all V>
__device__ static inline void broadcast_col(T &dst, const V &col_values) {
    col_map<base_ops::copy2, T, V>(dst, dst, col_values);
}

// Templated versions of each

template<int axis, ducks::st::all T, ducks::sv::all V>
__device__ static inline void add(T &dst, const T &src, const V &col_values) {
    if constexpr (axis == axis::COL) add_col(dst, src, col_values);
    else add_row(dst, src, col_values);
}

template<int axis, ducks::st::all T, ducks::sv::all V>
__device__ static inline void sub(T &dst, const T &src, const V &col_values) {
    if constexpr (axis == axis::COL) sub_col(dst, src, col_values);
    else sub_row(dst, src, col_values);
}

template<int axis, ducks::st::all T, ducks::sv::all V>
__device__ static inline void mul(T &dst, const T &src, const V &col_values) {
    if constexpr (axis == axis::COL) mul_col(dst, src, col_values);
    else mul_row(dst, src, col_values);
}

template<int axis, ducks::st::all T, ducks::sv::all V>
__device__ static inline void div(T &dst, const T &src, const V &col_values) {
    if constexpr (axis == axis::COL) div_col(dst, src, col_values);
    else div_row(dst, src, col_values);
}

template<int axis, ducks::st::all T, ducks::sv::all V>
__device__ static inline void broadcast(T &dst, const V &col_values) {
    if constexpr (axis == axis::COL) broadcast_col(dst, col_values);
    else broadcast_row(dst, col_values);
}
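
// Usage sketch (illustrative): the row/col maps broadcast a shared vector across a shared
// tile, e.g. a bias add or a per-row normalization (tile and vector names assumed):
//
//     add_col(tile, tile, bias);      // bias[col] added to every row of that column
//     div_row(tile, tile, row_sums);  // each row divided by its own sum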
372
extra/thunder/cuda/include/ops/group/shared/tile/reductions.cuh
Normal file
@@ -0,0 +1,372 @@
/**
|
||||
* @file
|
||||
* @brief Group reductions on shared tiles.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Performs row-wise reduction on a matrix using a specified operation.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type with row layout.
|
||||
* @param row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param src The source matrix on which to perform the reduction.
|
||||
* @param src_accum The initial value of the accumulator, used when reset is false.
|
||||
* @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
*/
|
||||
template<typename op, ducks::sv::all V, ducks::st::all T, bool reset>
|
||||
__device__ static inline void row_reduce(V &row_accum, const T &src, const V &src_accum) {
|
||||
using dtype = typename V::dtype;
|
||||
for (int row = laneid(); row < src.rows; row += GROUP_THREADS) {
|
||||
dtype accum = src[{row, 0}];
|
||||
#pragma unroll
|
||||
for (int col = 1; col < src.cols; col++) {
|
||||
accum = op::template op<dtype>(accum, src[{row, col}]);
|
||||
}
|
||||
if (reset) {
|
||||
row_accum[row] = accum;
|
||||
} else {
|
||||
row_accum[row] = op::template op<dtype>(src_accum[row], accum);
|
||||
}
|
||||
}
|
||||
}

/**
 * Performs column-wise reduction on a matrix using a specified operation.
 *
 * @tparam op The operation to be applied for reduction.
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type with column layout.
 * @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
 * @param col_accum The accumulator where the result of the reduction is stored.
 * @param src The source matrix on which to perform the reduction.
 * @param src_accum The initial value of the accumulator, used when reset is false.
 */
template<typename op, ducks::sv::all V, ducks::st::all T, bool reset>
__device__ static inline void col_reduce(V &col_accum, const T &src, const V &src_accum) {
    using dtype = typename V::dtype;
    for (int col = laneid(); col < src.cols; col += GROUP_THREADS) {
        dtype accum = src[{0, col}];
        #pragma unroll
        for (int row = 1; row < src.rows; row++) {
            accum = op::template op<dtype>(accum, src[{row, col}]);
        }
        if (reset) {
            col_accum[col] = accum;
        } else {
            col_accum[col] = op::template op<dtype>(src_accum[col], accum);
        }
    }
}

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

/**
 * @brief Store the maximum of each row of the src shared matrix in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_max(V &row_accum, const T &src) {
    row_reduce<base_ops::max, V, T, true>(row_accum, src, row_accum);
}
/**
 * @brief Store the minimum of each row of the src shared matrix in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_min(V &row_accum, const T &src) {
    row_reduce<base_ops::min, V, T, true>(row_accum, src, row_accum);
}
/**
 * @brief Store the sum of each row of the src shared matrix in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_sum(V &row_accum, const T &src) {
    row_reduce<base_ops::sum, V, T, true>(row_accum, src, row_accum);
}
/**
 * @brief Store the product of each row of the src shared matrix in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_prod(V &row_accum, const T &src) {
    row_reduce<base_ops::mul, V, T, true>(row_accum, src, row_accum);
}
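
// Usage sketch (editorial): combined with the tile maps, these one-pass reductions
// express a numerically stable row softmax over a shared tile `t` with shared row
// vectors `mx` and `sm` (names are illustrative; `exp` is the elementwise tile map):
//
//     row_max(mx, t);     // mx[r] = max_c t[r, c]
//     sub_row(t, t, mx);  // subtract the row max for stability
//     exp(t, t);
//     row_sum(sm, t);     // sm[r] = sum_c t[r, c]
//     div_row(t, t, sm);  // normalize each row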

/**
 * @brief Store the maximum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_max(V &row_accum, const T &src, const V &src_accum) {
    row_reduce<base_ops::max, V, T, false>(row_accum, src, src_accum);
}
/**
 * @brief Store the minimum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_min(V &row_accum, const T &src, const V &src_accum) {
    row_reduce<base_ops::min, V, T, false>(row_accum, src, src_accum);
}
/**
 * @brief Store the sum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_sum(V &row_accum, const T &src, const V &src_accum) {
    row_reduce<base_ops::sum, V, T, false>(row_accum, src, src_accum);
}
/**
 * @brief Store the product of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
 *
 * @tparam V The shared vector type for the row accumulator.
 * @tparam T The shared matrix type.
 * @param[out] row_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void row_prod(V &row_accum, const T &src, const V &src_accum) {
    row_reduce<base_ops::mul, V, T, false>(row_accum, src, src_accum);
}
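
// Usage sketch (editorial): the three-operand forms fold an existing accumulator
// into the reduction, which is the natural shape for chunked scans over long
// sequences. Assuming tiles `t0`, `t1` cover consecutive chunks and `mx` is a
// running row vector (illustrative names):
//
//     row_max(mx, t0);        // initialize from the first chunk
//     row_max(mx, t1, mx);    // mx[r] = max(mx[r], max_c t1[r, c])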

/**
 * @brief Store the maximum of each column of the src shared matrix in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_max(V &col_accum, const T &src) {
    col_reduce<base_ops::max, V, T, true>(col_accum, src, col_accum);
}
/**
 * @brief Store the minimum of each column of the src shared matrix in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_min(V &col_accum, const T &src) {
    col_reduce<base_ops::min, V, T, true>(col_accum, src, col_accum);
}
/**
 * @brief Store the sum of each column of the src shared matrix in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_sum(V &col_accum, const T &src) {
    col_reduce<base_ops::sum, V, T, true>(col_accum, src, col_accum);
}
/**
 * @brief Store the product of each column of the src shared matrix in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_prod(V &col_accum, const T &src) {
    col_reduce<base_ops::mul, V, T, true>(col_accum, src, col_accum);
}

/**
 * @brief Store the maximum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_max(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::max, V, T, false>(col_accum, src, src_accum);
}
/**
 * @brief Store the minimum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_min(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::min, V, T, false>(col_accum, src, src_accum);
}
/**
 * @brief Store the sum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_sum(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::sum, V, T, false>(col_accum, src, src_accum);
}
/**
 * @brief Store the product of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
 *
 * @tparam V The shared vector type for the column accumulator.
 * @tparam T The shared matrix type.
 * @param[out] col_accum The accumulator where the result of the reduction is stored.
 * @param[in] src The source matrix on which to perform the reduction.
 * @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
 */
template<ducks::sv::all V, ducks::st::all T>
__device__ static inline void col_prod(V &col_accum, const T &src, const V &src_accum) {
    col_reduce<base_ops::mul, V, T, false>(col_accum, src, src_accum);
}

// templated versions of each

template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void max(V &dst, const T &src, const V &src_accum) {
    if constexpr (ax == axis::COL) row_max(dst, src, src_accum);
    else col_max(dst, src, src_accum);
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline auto max(const T &src, const V &src_accum) {
    V dst;
    if constexpr (ax == axis::COL) row_max(dst, src, src_accum);
    else col_max(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void max(V &dst, const T &src) {
    if constexpr (ax == axis::COL) row_max(dst, src);
    else col_max(dst, src);
}
template<int ax, ducks::st::all T>
__device__ static inline auto max(const T &src) {
    using V = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    V dst;
    if constexpr (ax == axis::COL) row_max(dst, src);
    else col_max(dst, src);
    return dst;
}

template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void min(V &dst, const T &src, const V &src_accum) {
    if constexpr (ax == axis::COL) row_min(dst, src, src_accum);
    else col_min(dst, src, src_accum);
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline auto min(const T &src, const V &src_accum) {
    V dst;
    if constexpr (ax == axis::COL) row_min(dst, src, src_accum);
    else col_min(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void min(V &dst, const T &src) {
    if constexpr (ax == axis::COL) row_min(dst, src);
    else col_min(dst, src);
}
template<int ax, ducks::st::all T>
__device__ static inline auto min(const T &src) {
    using V = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    V dst;
    if constexpr (ax == axis::COL) row_min(dst, src);
    else col_min(dst, src);
    return dst;
}

template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void sum(V &dst, const T &src, const V &src_accum) {
    if constexpr (ax == axis::COL) row_sum(dst, src, src_accum);
    else col_sum(dst, src, src_accum);
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline auto sum(const T &src, const V &src_accum) {
    V dst;
    if constexpr (ax == axis::COL) row_sum(dst, src, src_accum);
    else col_sum(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void sum(V &dst, const T &src) {
    if constexpr (ax == axis::COL) row_sum(dst, src);
    else col_sum(dst, src);
}
template<int ax, ducks::st::all T>
__device__ static inline auto sum(const T &src) {
    using V = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    V dst;
    if constexpr (ax == axis::COL) row_sum(dst, src);
    else col_sum(dst, src);
    return dst;
}

template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void prod(V &dst, const T &src, const V &src_accum) {
    if constexpr (ax == axis::COL) row_prod(dst, src, src_accum);
    else col_prod(dst, src, src_accum);
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline auto prod(const T &src, const V &src_accum) {
    V dst;
    if constexpr (ax == axis::COL) row_prod(dst, src, src_accum);
    else col_prod(dst, src, src_accum);
    return dst;
}
template<int ax, ducks::sv::all V, ducks::st::all T>
__device__ static inline void prod(V &dst, const T &src) {
    if constexpr (ax == axis::COL) row_prod(dst, src);
    else col_prod(dst, src);
}
template<int ax, ducks::st::all T>
__device__ static inline auto prod(const T &src) {
    using V = std::conditional_t<ax==axis::COL, typename T::col_vec, typename T::row_vec>;
    V dst;
    if constexpr (ax == axis::COL) row_prod(dst, src);
    else col_prod(dst, src);
    return dst;
}
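
// Usage sketch (editorial): the axis-templated forms let generic code pick the
// reduction direction at compile time, with the result vector type derived from
// the tile (assuming the companion axis::ROW value alongside the axis::COL used above):
//
//     auto per_row = sum<axis::COL>(t);  // reduces across columns, one entry per row
//     auto per_col = sum<axis::ROW>(t);  // reduces down rows, one entry per column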
37
extra/thunder/cuda/include/ops/group/shared/tile/tile.cuh
Normal file
@@ -0,0 +1,37 @@
/**
 * @file
 * @brief An aggregate header for group operations on shared tiles.
 */

#include "conversions.cuh"
#include "maps.cuh"
#include "reductions.cuh"

template<ducks::st::all ST>
__device__ static inline bool hasnan(const ST &src) {
    KITTENS_CHECK_WARP
    bool nan_detected = false;
    #pragma unroll
    for(int i = laneid(); i < ST::num_elements; i+=GROUP_THREADS) {
        if constexpr (std::is_same_v<typename ST::T, float>) {
            if(isnan(src[i])) {
                nan_detected = true;
            }
        }
        else if constexpr (std::is_same_v<typename ST::T, bf16>) {
            if(isnan(__bfloat162float(src[i]))) {
                nan_detected = true;
            }
        }
        else if constexpr (std::is_same_v<typename ST::T, half>) {
            if(isnan(__half2float(src[i]))) {
                nan_detected = true;
            }
        }
        else {
            static_assert(sizeof(typename ST::T) == 999, "Unsupported dtype");
        }
    }
    // Ballot across the warp to see if any lane detected a nan
    return (__ballot_sync(0xffffffff, nan_detected) != 0);
}
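
// Usage sketch (editorial): hasnan is a debugging aid; the final ballot means all
// lanes of a warp agree on the result, so it can safely gate a single-thread print:
//
//     if (hasnan(t) && threadIdx.x == 0) {
//         printf("NaN detected in tile, block %d\n", blockIdx.x);
//     }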
27
extra/thunder/cuda/include/ops/group/shared/vec/conversions.cuh
Normal file
@@ -0,0 +1,27 @@
/**
 * @file
 * @brief Group conversions on shared vectors.
 */

/**
 * @brief Copies data from one shared vector to another, converting data types if necessary.
 *
 * This function copies data from the source shared vector `src` to the destination shared vector `dst`.
 * If the data types of `src` and `dst` are the same, it performs a direct memory copy. Otherwise, it
 * converts each element from the source data type to the destination data type using the appropriate
 * converter before copying.
 *
 * @tparam SV1 The type of the destination shared vector, must satisfy the ducks::sv::all concept.
 * @tparam SV2 The type of the source shared vector, must satisfy the ducks::sv::all concept.
 * @param[out] dst The destination shared vector.
 * @param[in] src The source shared vector.
 * @note The lengths of `src` and `dst` must be equal. This is enforced at compile time.
 */
template<ducks::sv::all SV1, ducks::sv::all SV2>
__device__ static inline void copy(SV1 &dst, const SV2 &src) {
    static_assert(SV1::length == SV2::length, "Source and destination vectors must have the same length.");
    #pragma unroll
    for(int i = laneid(); i < dst.length; i+=GROUP_THREADS) {
        dst[i] = base_types::convertor<typename SV1::dtype, typename SV2::dtype>::convert(src[i]);
    }
}
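
// Usage sketch (editorial): the same call handles plain copies and dtype
// conversion, e.g. widening a bf16 vector into a float vector for a more precise
// epilogue (vector names are illustrative); each element goes through the
// convertor, which for bf16 to float is effectively __bfloat162float:
//
//     copy(v_float, v_bf16);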
259
extra/thunder/cuda/include/ops/group/shared/vec/maps.cuh
Normal file
@@ -0,0 +1,259 @@
/**
 * @file
 * @brief Group maps on shared vectors.
 */

/**
 * @brief Applies a unary operation to each element of a shared memory vector.
 *
 * @tparam op Unary operation type.
 * @tparam T Shared memory vector type.
 * @param[out] dst Destination vector in which to store the result.
 * @param[in] src Source vector to apply the unary operation.
 */
template<typename op, ducks::sv::all T>
__device__ static inline void unary_op(T &dst, const T &src) {
    #pragma unroll
    for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) {
        dst[cur] = op::template op<typename T::dtype>(src[cur]);
    }
}
/**
 * @brief Perform a binary operation on two shared vectors.
 *
 * @tparam op The binary operation to perform.
 * @tparam T The type of the vectors.
 * @param[out] dst The destination vector where the result is stored.
 * @param[in] lhs The left-hand side vector for the operation.
 * @param[in] rhs The right-hand side vector for the operation.
 */
template<typename op, ducks::sv::all T>
__device__ static inline void bin_op(T &dst, const T &lhs, const T &rhs) {
    #pragma unroll
    for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) {
        dst[cur] = op::template op<typename T::dtype>(lhs[cur], rhs[cur]);
    }
}
/**
 * @brief Perform a binary operation on a shared vector and a scalar.
 *
 * @tparam op The binary operation to perform.
 * @tparam T The type of the vector.
 * @param[out] dst The destination vector where the result is stored.
 * @param[in] src The source vector for the operation.
 * @param[in] param The scalar parameter for the operation.
 */
template<typename op, ducks::sv::all T>
__device__ static inline void bin_op(T &dst, const T &src, const typename T::dtype &param) {
    #pragma unroll
    for(auto cur = laneid(); cur < T::length; cur+=GROUP_THREADS) {
        dst[cur] = op::template op<typename T::dtype>(src[cur], param);
    }
}

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// ---- const ops ----

/**
 * @brief Sets all elements of a shared memory vector to zero.
 *
 * @tparam T Shared memory vector type.
 * @param[out] dst Destination vector to be set to zero.
 */
template<ducks::sv::all T>
__device__ static inline void zero(T &dst) {
    unary_op<base_ops::zero, T>(dst, dst);
}
/**
 * @brief Sets all elements of a shared memory vector to one.
 *
 * @tparam T Shared memory vector type.
 * @param[out] dst Destination vector to be set to one.
 */
template<ducks::sv::all T>
__device__ static inline void one(T &dst) {
    unary_op<base_ops::one, T>(dst, dst);
}
/**
 * @brief Sets all elements of a shared memory vector to positive infinity.
 *
 * @tparam T Shared memory vector type.
 * @param[out] dst Destination vector to be set to positive infinity.
 */
template<ducks::sv::all T>
__device__ static inline void pos_infty(T &dst) {
    unary_op<base_ops::pos_infty, T>(dst, dst);
}
/**
 * @brief Sets all elements of a shared memory vector to negative infinity.
 *
 * @tparam T Shared memory vector type.
 * @param[out] dst Destination vector to be set to negative infinity.
 */
template<ducks::sv::all T>
__device__ static inline void neg_infty(T &dst) {
    unary_op<base_ops::neg_infty, T>(dst, dst);
}

// ---- unary ops ----

/**
 * @brief Copies the elements from one shared vector to another.
 *
 * @tparam T Shared vector type.
 * @tparam U Type of the source vector.
 * @param[out] dst Destination vector where the elements will be copied to.
 * @param[in] src Source vector to copy the elements from.
 */
template<ducks::sv::all T, typename U>
__device__ static inline void copy(T &dst, const U &src) {
    bin_op<base_ops::copy2, T>(dst, dst, src); // the second arg is ignored here.
}
/**
 * @brief Applies the exponential function element-wise to a shared vector.
 *
 * @tparam T Shared vector type.
 * @param[out] dst Destination vector where the exponential values will be stored.
 * @param[in] src Source vector to apply the exponential function to.
 */
template<ducks::sv::all T>
__device__ static inline void exp(T &dst, const T &src) {
    unary_op<base_ops::exp, T>(dst, src);
}
/**
 * @brief Applies the exponential function element-wise to a shared vector, in base 2.
 *
 * @tparam T Shared vector type.
 * @param[out] dst Destination vector where the exponential values will be stored.
 * @param[in] src Source vector to apply the exponential function to.
 */
template<ducks::sv::all T>
__device__ static inline void exp2(T &dst, const T &src) {
    unary_op<base_ops::exp2, T>(dst, src);
}
/**
 * @brief Applies the natural logarithm function element-wise to a shared vector.
 *
 * @tparam T Shared vector type.
 * @param[out] dst Destination vector where the logarithm values will be stored.
 * @param[in] src Source vector to apply the logarithm function to.
 */
template<ducks::sv::all T>
__device__ static inline void log(T &dst, const T &src) {
    unary_op<base_ops::log, T>(dst, src);
}
/**
 * @brief Applies the logarithm base 2 function element-wise to a shared vector.
 *
 * @tparam T Shared vector type.
 * @param[out] dst Destination vector where the logarithm base 2 values will be stored.
 * @param[in] src Source vector to apply the logarithm base 2 function to.
 */
template<ducks::sv::all T>
__device__ static inline void log2(T &dst, const T &src) {
    unary_op<base_ops::log2, T>(dst, src);
}
/**
 * @brief Applies the absolute value function element-wise to a shared vector.
 *
 * @tparam T Shared vector type.
 * @param[out] dst Destination vector where the absolute values will be stored.
 * @param[in] src Source vector to apply the absolute value function to.
 */
template<ducks::sv::all T>
__device__ static inline void abs(T &dst, const T &src) {
    unary_op<base_ops::abs, T>(dst, src);
}
/**
 * @brief Applies the rectified linear unit (ReLU) function element-wise to a shared vector.
 *
 * @tparam T Shared vector type.
 * @param[out] dst Destination vector where the ReLU values will be stored.
 * @param[in] src Source vector to apply the ReLU function to.
 */
template<ducks::sv::all T>
__device__ static inline void relu(T &dst, const T &src) {
    unary_op<base_ops::relu, T>(dst, src);
}

// ---- binary ops ----

/**
 * @brief Computes the element-wise maximum of two shared vectors.
 *
 * @tparam T Shared vector type.
 * @tparam U Type of the second vector.
 * @param[out] dst Destination vector where the maximum values will be stored.
 * @param[in] lhs First vector for the maximum operation.
 * @param[in] rhs Second vector for the maximum operation.
 */
template<ducks::sv::all T, typename U>
__device__ static inline void max(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::max, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise minimum of two shared vectors.
 *
 * @tparam T Shared vector type.
 * @tparam U Type of the second vector.
 * @param[out] dst Destination vector where the minimum values will be stored.
 * @param[in] lhs First vector for the minimum operation.
 * @param[in] rhs Second vector for the minimum operation.
 */
template<ducks::sv::all T, typename U>
__device__ static inline void min(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::min, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise sum of two shared vectors.
 *
 * @tparam T Shared vector type.
 * @tparam U Type of the second vector.
 * @param[out] dst Destination vector where the sum values will be stored.
 * @param[in] lhs First vector for the sum operation.
 * @param[in] rhs Second vector for the sum operation.
 */
template<ducks::sv::all T, typename U>
__device__ static inline void add(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::sum, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise difference of two shared vectors.
 *
 * @tparam T Shared vector type.
 * @tparam U Type of the second vector.
 * @param[out] dst Destination vector where the difference values will be stored.
 * @param[in] lhs First vector for the difference operation.
 * @param[in] rhs Second vector for the difference operation.
 */
template<ducks::sv::all T, typename U>
__device__ static inline void sub(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::sub, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise product of two shared vectors.
 *
 * @tparam T Shared vector type.
 * @tparam U Type of the second vector.
 * @param[out] dst Destination vector where the product values will be stored.
 * @param[in] lhs First vector for the product operation.
 * @param[in] rhs Second vector for the product operation.
 */
template<ducks::sv::all T, typename U>
__device__ static inline void mul(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::mul, T>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise division of two shared vectors.
 *
 * @tparam T Shared vector type.
 * @tparam U Type of the second vector.
 * @param[out] dst Destination vector where the division values will be stored.
 * @param[in] lhs First vector for the division operation.
 * @param[in] rhs Second vector for the division operation.
 */
template<ducks::sv::all T, typename U>
__device__ static inline void div(T &dst, const T &lhs, const U &rhs) {
    bin_op<base_ops::div, T>(dst, lhs, rhs);
}
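
// Usage sketch (editorial): vector maps compose in place, e.g. exponentiating a
// shared vector `v` after subtracting a scalar maximum `m` (the vector-scalar
// overload of sub handles the broadcast; names are illustrative):
//
//     sub(v, v, m);    // v[i] -= m
//     exp(v, v);       // v[i] = exp(v[i])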
193
extra/thunder/cuda/include/ops/group/shared/vec/reductions.cuh
Normal file
@@ -0,0 +1,193 @@
/**
 * @file
 * @brief Group reductions on shared vectors.
 */

// The fastest way to do this, under most circumstances, is actually to just have each warp replicate it.
// This is not true for enormous shared vectors, but doing that efficiently actually requires some extra scratch shared memory.
// So, this is sufficient for the time being.
template<typename op, ducks::sv::all SV, bool reset>
__device__ static inline void reduce(typename SV::dtype &dst_accum, const SV &src, const typename SV::dtype &src_accum) {
    if constexpr (GROUP_WARPS == 1) {
        using T = SV::dtype;
        int lane = laneid();
        T accum;
        if(lane < src.length) accum = src[lane]; // initialize a register accumulator
        __syncwarp();
        for(int i = lane+kittens::WARP_THREADS; i < src.length; i+=kittens::WARP_THREADS) {
            accum = op::template op<T>(accum, src[i]);
        }
        __syncwarp();
        // We can now reduce within the warp.
        if constexpr (src.length > 16) {
            accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 16));
            __syncwarp();
        }
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 8));
        __syncwarp();
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 4));
        __syncwarp();
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 2));
        __syncwarp();
        accum = op::template op<T>(accum, packed_shfl_down_sync(kittens::MASK_ALL, accum, 1));
        __syncwarp();
        if constexpr (!reset) accum = op::template op<T>(accum, src_accum);
        // broadcast to all threads in the warp.
        dst_accum = packed_shfl_sync(kittens::MASK_ALL, accum, 0); // everyone takes from warp leader
    }
    else {
        ::kittens::group<1>::reduce<op, SV, reset>(dst_accum, src, src_accum);
    }
}
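
// A worked example of the shuffle tree above (editorial): for 32 lanes, the
// offset-16 step folds lanes 16..31 into 0..15, offset 8 folds 8..15 into 0..7,
// and so on until lane 0 holds the full result: log2(32) = 5 rounds instead of a
// 32-step loop. The same pattern in plain CUDA, without ThunderKittens types:
//
//     __device__ float warp_sum(float x) {
//         for (int offset = 16; offset > 0; offset >>= 1)
//             x += __shfl_down_sync(0xffffffff, x, offset);
//         return __shfl_sync(0xffffffff, x, 0); // broadcast lane 0's total
//     }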

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

/**
 * @brief Finds the maximum element in a shared memory vector.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] max_val The maximum value found in the vector.
 * @param[in] src The shared memory vector to find the maximum in.
 */
template<ducks::sv::all SV>
__device__ static inline void max(typename SV::dtype &max_val, const SV &src) {
    reduce<base_ops::max, SV, true>(max_val, src, max_val);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype max(const SV &src) {
    typename SV::dtype max_val;
    reduce<base_ops::max, SV, true>(max_val, src, max_val);
    return max_val;
}

/**
 * @brief Finds the minimum element in a shared memory vector.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] min_val The minimum value found in the vector.
 * @param[in] src The shared memory vector to find the minimum in.
 */
template<ducks::sv::all SV>
__device__ static inline void min(typename SV::dtype &min_val, const SV &src) {
    reduce<base_ops::min, SV, true>(min_val, src, min_val);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype min(const SV &src) {
    typename SV::dtype min_val;
    reduce<base_ops::min, SV, true>(min_val, src, min_val);
    return min_val;
}

/**
 * @brief Calculates the sum of elements in a shared memory vector.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] sum_val The sum of the values in the vector.
 * @param[in] src The shared memory vector to sum.
 */
template<ducks::sv::all SV>
__device__ static inline void sum(typename SV::dtype &sum_val, const SV &src) {
    reduce<base_ops::sum, SV, true>(sum_val, src, sum_val);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype sum(const SV &src) {
    typename SV::dtype sum_val;
    reduce<base_ops::sum, SV, true>(sum_val, src, sum_val);
    return sum_val;
}

/**
 * @brief Calculates the product of elements in a shared memory vector.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] prod_val The product of the values in the vector.
 * @param[in] src The shared memory vector to multiply.
 */
template<ducks::sv::all SV>
__device__ static inline void prod(typename SV::dtype &prod_val, const SV &src) {
    reduce<base_ops::mul, SV, true>(prod_val, src, prod_val);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype prod(const SV &src) {
    typename SV::dtype prod_val;
    reduce<base_ops::mul, SV, true>(prod_val, src, prod_val);
    return prod_val;
}

// Three operand versions.

/**
 * @brief Finds the maximum element in a shared memory vector and accumulates it with src_accum.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] max_val The maximum value found in the vector, accumulated with src_accum.
 * @param[in] src The shared memory vector to find the maximum in.
 * @param[in] src_accum The initial value to accumulate with the maximum value found.
 */
template<ducks::sv::all SV>
__device__ static inline void max(typename SV::dtype &max_val, const SV &src, const typename SV::dtype &src_accum) {
    reduce<base_ops::max, SV, false>(max_val, src, src_accum);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype max(const SV &src, const typename SV::dtype &src_accum) {
    typename SV::dtype max_val;
    reduce<base_ops::max, SV, false>(max_val, src, src_accum);
    return max_val;
}

/**
 * @brief Finds the minimum element in a shared memory vector and accumulates it with src_accum.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] min_val The minimum value found in the vector, accumulated with src_accum.
 * @param[in] src The shared memory vector to find the minimum in.
 * @param[in] src_accum The initial value to accumulate with the minimum value found.
 */
template<ducks::sv::all SV>
__device__ static inline void min(typename SV::dtype &min_val, const SV &src, const typename SV::dtype &src_accum) {
    reduce<base_ops::min, SV, false>(min_val, src, src_accum);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype min(const SV &src, const typename SV::dtype &src_accum) {
    typename SV::dtype min_val;
    reduce<base_ops::min, SV, false>(min_val, src, src_accum);
    return min_val;
}

/**
 * @brief Calculates the sum of elements in a shared memory vector and accumulates it with src_accum.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] sum_val The sum of the values in the vector, accumulated with src_accum.
 * @param[in] src The shared memory vector to sum.
 * @param[in] src_accum The initial value to accumulate with the sum of the vector.
 */
template<ducks::sv::all SV>
__device__ static inline void sum(typename SV::dtype &sum_val, const SV &src, const typename SV::dtype &src_accum) {
    reduce<base_ops::sum, SV, false>(sum_val, src, src_accum);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype sum(const SV &src, const typename SV::dtype &src_accum) {
    typename SV::dtype sum_val;
    reduce<base_ops::sum, SV, false>(sum_val, src, src_accum);
    return sum_val;
}

/**
 * @brief Calculates the product of elements in a shared memory vector and accumulates it with src_accum.
 *
 * @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
 * @param[out] prod_val The product of the values in the vector, accumulated with src_accum.
 * @param[in] src The shared memory vector to multiply.
 * @param[in] src_accum The initial value to accumulate with the product of the vector.
 */
template<ducks::sv::all SV>
__device__ static inline void prod(typename SV::dtype &prod_val, const SV &src, const typename SV::dtype &src_accum) {
    reduce<base_ops::mul, SV, false>(prod_val, src, src_accum);
}
template<ducks::sv::all SV>
__device__ static inline typename SV::dtype prod(const SV &src, const typename SV::dtype &src_accum) {
    typename SV::dtype prod_val;
    reduce<base_ops::mul, SV, false>(prod_val, src, src_accum);
    return prod_val;
}
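
// Usage sketch (editorial): the accumulator forms chain across multiple vectors,
// e.g. a running total over partial vectors v0, v1, v2 (illustrative names):
//
//     float total = sum(v0);
//     total = sum(v1, total);
//     total = sum(v2, total);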
38
extra/thunder/cuda/include/ops/group/shared/vec/vec.cuh
Normal file
@@ -0,0 +1,38 @@
/**
 * @file
 * @brief An aggregate header for group operations on shared vectors.
 */

#include "conversions.cuh"
#include "maps.cuh"
// no group vector reductions as they would require additional shared memory and synchronization, and those side effects just aren't worth it.
// warp vector reductions should be plenty fast in 99.9% of situations.

template<ducks::sv::all SV>
__device__ static inline bool hasnan(const SV &src) {
    KITTENS_CHECK_WARP
    bool nan_detected = false;
    #pragma unroll
    for(int i = laneid(); i < SV::length; i+=GROUP_THREADS) {
        if constexpr (std::is_same_v<typename SV::T, float>) {
            if(isnan(src[i])) {
                nan_detected = true;
            }
        }
        else if constexpr (std::is_same_v<typename SV::T, bf16>) {
            if(isnan(__bfloat162float(src[i]))) {
                nan_detected = true;
            }
        }
        else if constexpr (std::is_same_v<typename SV::T, half>) {
            if(isnan(__half2float(src[i]))) {
                nan_detected = true;
            }
        }
        else {
            static_assert(sizeof(typename SV::T) == 999, "Unsupported dtype");
        }
    }
    // Ballot across the warp to see if any lane detected a nan
    return (__ballot_sync(0xffffffff, nan_detected) != 0);
}
262
extra/thunder/cuda/include/ops/ops.cuh
Normal file
@@ -0,0 +1,262 @@
/**
 * @file
 * @brief A collection of all of the operations that ThunderKittens defines.
 */

#pragma once

#include "thread/thread.cuh"
#include "group/group.cuh"
#include "device/device.cuh"

namespace kittens {

// Operator overloading, which defaults to warp scope.

// Tile operators

template<ducks::rt::all T, typename U>
__device__ static inline T operator+(const T &lhs, const U &rhs) {
    T dst;
    warp::add(dst, lhs, rhs);
    return dst;
}
template<ducks::rt::all T, typename U>
__device__ static inline void operator+=(T &lhs, const U &rhs) {
    warp::add(lhs, lhs, rhs);
}
template<ducks::rt::all T, typename U>
__device__ static inline T operator-(const T &lhs, const U &rhs) {
    T dst;
    warp::sub(dst, lhs, rhs);
    return dst;
}
template<ducks::rt::all T, typename U>
__device__ static inline void operator-=(T &lhs, const U &rhs) {
    warp::sub(lhs, lhs, rhs);
}
template<ducks::rt::all T, typename U>
__device__ static inline T operator*(const T &lhs, const U &rhs) {
    T dst;
    warp::mul(dst, lhs, rhs);
    return dst;
}
template<ducks::rt::all T, typename U>
__device__ static inline void operator*=(T &lhs, const U &rhs) {
    warp::mul(lhs, lhs, rhs);
}
template<ducks::rt::all T, typename U>
__device__ static inline T operator/(const T &lhs, const U &rhs) {
    T dst;
    warp::div(dst, lhs, rhs);
    return dst;
}
template<ducks::rt::all T, typename U>
__device__ static inline void operator/=(T &lhs, const U &rhs) {
    warp::div(lhs, lhs, rhs);
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator+(const T &src, const V &row_values) {
    T dst;
    warp::add_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline T operator+(const T &src, const V &row_values) {
    T dst;
    warp::add_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator+=(T &lhs, const V &row_values) {
    warp::add_row(lhs, lhs, row_values);
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline void operator+=(T &lhs, const V &row_values) {
    warp::add_row(lhs, lhs, row_values);
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator-(const T &src, const V &row_values) {
    T dst;
    warp::sub_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline T operator-(const T &src, const V &row_values) {
    T dst;
    warp::sub_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator-=(T &lhs, const V &row_values) {
    warp::sub_row(lhs, lhs, row_values);
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline void operator-=(T &lhs, const V &row_values) {
    warp::sub_row(lhs, lhs, row_values);
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator*(const T &src, const V &row_values) {
    T dst;
    warp::mul_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline T operator*(const T &src, const V &row_values) {
    T dst;
    warp::mul_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator*=(T &lhs, const V &row_values) {
    warp::mul_row(lhs, lhs, row_values);
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline void operator*=(T &lhs, const V &row_values) {
    warp::mul_row(lhs, lhs, row_values);
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator/(const T &src, const V &row_values) {
    T dst;
    warp::div_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline T operator/(const T &src, const V &row_values) {
    T dst;
    warp::div_row(dst, src, row_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator/=(T &lhs, const V &row_values) {
    warp::div_row(lhs, lhs, row_values);
}
template<ducks::rt::col_layout T, ducks::rv::align_layout V>
__device__ static inline void operator/=(T &lhs, const V &row_values) {
    warp::div_row(lhs, lhs, row_values);
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline T operator+(const T &src, const V &col_values) {
    T dst;
    warp::add_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator+(const T &src, const V &col_values) {
    T dst;
    warp::add_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline void operator+=(T &lhs, const V &col_values) {
    warp::add_col(lhs, lhs, col_values);
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator+=(T &lhs, const V &col_values) {
    warp::add_col(lhs, lhs, col_values);
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline T operator-(const T &src, const V &col_values) {
    T dst;
    warp::sub_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator-(const T &src, const V &col_values) {
    T dst;
    warp::sub_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline void operator-=(T &lhs, const V &col_values) {
    warp::sub_col(lhs, lhs, col_values);
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator-=(T &lhs, const V &col_values) {
    warp::sub_col(lhs, lhs, col_values);
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline T operator*(const T &src, const V &col_values) {
    T dst;
    warp::mul_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator*(const T &src, const V &col_values) {
    T dst;
    warp::mul_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline void operator*=(T &lhs, const V &col_values) {
    warp::mul_col(lhs, lhs, col_values);
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator*=(T &lhs, const V &col_values) {
    warp::mul_col(lhs, lhs, col_values);
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline T operator/(const T &src, const V &col_values) {
    T dst;
    warp::div_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline T operator/(const T &src, const V &col_values) {
    T dst;
    warp::div_col(dst, src, col_values);
    return dst;
}
template<ducks::rt::row_layout T, ducks::rv::align_layout V>
__device__ static inline void operator/=(T &lhs, const V &col_values) {
    warp::div_col(lhs, lhs, col_values);
}
template<ducks::rt::col_layout T, ducks::rv::ortho_layout V>
__device__ static inline void operator/=(T &lhs, const V &col_values) {
    warp::div_col(lhs, lhs, col_values);
}

// Vector operators

template<ducks::rv::all T, typename U>
__device__ static inline T operator+(const T &lhs, const U &rhs) {
    T dst;
    warp::add(dst, lhs, rhs);
    return dst;
}
template<ducks::rv::all T, typename U>
__device__ static inline void operator+=(T &lhs, const U &rhs) {
    warp::add(lhs, lhs, rhs);
}
template<ducks::rv::all T, typename U>
__device__ static inline T operator-(const T &lhs, const U &rhs) {
    T dst;
    warp::sub(dst, lhs, rhs);
    return dst;
}
template<ducks::rv::all T, typename U>
__device__ static inline void operator-=(T &lhs, const U &rhs) {
    warp::sub(lhs, lhs, rhs);
}
template<ducks::rv::all T, typename U>
__device__ static inline T operator*(const T &lhs, const U &rhs) {
    T dst;
    warp::mul(dst, lhs, rhs);
    return dst;
}
template<ducks::rv::all T, typename U>
__device__ static inline void operator*=(T &lhs, const U &rhs) {
    warp::mul(lhs, lhs, rhs);
}
template<ducks::rv::all T, typename U>
__device__ static inline T operator/(const T &lhs, const U &rhs) {
    T dst;
    warp::div(dst, lhs, rhs);
    return dst;
}
template<ducks::rv::all T, typename U>
__device__ static inline void operator/=(T &lhs, const U &rhs) {
    warp::div(lhs, lhs, rhs);
}

}
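
// Usage sketch (editorial): with these overloads, register-tile and register-vector
// math reads like scalar code while dispatching to the warp-scope calls above
// (type aliases here are assumptions for illustration):
//
//     rt_fl<16, 64> a, b;
//     a += b;          // warp::add(a, a, b)
//     a = a * 0.5f;    // tile-scalar multiply via warp::mul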
10
extra/thunder/cuda/include/ops/thread/memory/memory.cuh
Normal file
@@ -0,0 +1,10 @@
/**
 * @file
 * @brief An aggregate header of warp memory operations, where a single warp loads or stores data on its own.
 */

#pragma once

#include "util/util.cuh"
#include "tile/tile.cuh"
#include "vec/vec.cuh"
10
extra/thunder/cuda/include/ops/thread/memory/tile/tile.cuh
Normal file
@@ -0,0 +1,10 @@
/**
 * @file
 * @brief An aggregate header of warp memory operations on tiles, where a single warp loads or stores data on its own.
 */

#pragma once

#ifdef KITTENS_HOPPER
#include "tma.cuh"
#endif
564
extra/thunder/cuda/include/ops/thread/memory/tile/tma.cuh
Normal file
@@ -0,0 +1,564 @@
|
||||
#pragma once
|
||||
|
||||
#include "../../../../common/common.cuh"
|
||||
#include "../../../../types/types.cuh"
|
||||
#include "../util/util.cuh"
|
||||
|
||||
#include <cuda.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace kittens {
|
||||
namespace tma {
|
||||
|
||||
namespace detail {
|
||||
template<kittens::ducks::st::all ST, int axis> __device__ inline int4 tma_coords(const coord<ducks::default_type> &unit_coord) {
|
||||
constexpr int swizzle_elements = ST::swizzle_bytes / sizeof(typename ST::dtype);
|
||||
if constexpr (axis == 2) return {unit_coord.r, unit_coord.c / swizzle_elements, unit_coord.d, unit_coord.b};
|
||||
else if constexpr (axis == 1) return {unit_coord.d, unit_coord.c / swizzle_elements, unit_coord.r, unit_coord.b};
|
||||
else if constexpr (axis == 0) return {unit_coord.b, unit_coord.c / swizzle_elements, unit_coord.r, unit_coord.d};
|
||||
}
|
||||
}

/* ---------- Prefetch Tensor Map ---------- */

/**
* @brief Prefetches a tile of data from global memory into the L2 cache, along with its tensormap.
*
* @tparam ST A shared tile type with a TMA-compatible layout
* @param[out] dst The destination shared memory tile (used only to deduce the tile type).
* @param[in] src The source global memory array, which owns the tensormap.
* @param[in] idx The coord of the requested tile. This is in units of complete tiles.
*/
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) {
uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<ST, axis>());
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.async.bulk.prefetch.tensor.5d.L2.global.tile"
" [%0, {%1, %2, %3, %4, %5}];"
:
: "l"(tma_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.async.bulk.prefetch.tensor.5d.L2.global.tile.L2::cache_hint"
" [%0, {%1, %2, %3, %4, %5}], %6;"
:
: "l"(tma_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void prefetch(ST &dst, const GL &src, const COORD &idx) {
prefetch<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
}
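/* Editor's usage sketch (illustrative names, not part of the library): warm
* the L2 with the tile this block will consume before the real load.
*
*   __shared__ st_bf<64, 64> tile;
*   tma::prefetch(tile, g_in, {blockIdx.x, blockIdx.y}); // g_in: a gl<> over the input
*/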

/* ---------- Async load and store data from gmem/smem ---------- */

/**
* @brief Asynchronously stores data into global memory from a shared memory tile.
*
* This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
*
* @tparam ST A shared tile type with a TMA-compatible layout
* @param[out] dst The destination global memory array, which owns the tensormap.
* @param[in] src The source shared memory tile.
* @param[in] idx The coord of the tile destination. This is in units of complete tiles.
*/
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) {
uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const GL &dst, const ST &src, const COORD &idx) {
store_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
}
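/* Editor's usage sketch (illustrative): each store_async commits its own
* bulk group internally, so pair it with one of the waits in util/tma.cuh.
*
*   tma::store_async(g_out, tile, {blockIdx.x, blockIdx.y});
*   tma::store_async_read_wait(); // shared tile is now safe to overwrite
*   // or tma::store_async_wait() to wait for the copy to fully land
*/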
template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) {
uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.async.bulk.tensor.5d.global.shared::cta.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_async(const PGL &dst, const ST &src, const COORD &idx) {
store_async<dim::ROW, cache_policy::NORMAL>(dst, src, idx);
}

/* ---------- Async reduction + store data from gmem/smem ---------- */

/**
* @brief Asynchronously performs an add reduction and stores the result into global memory from a shared memory tile.
*
* This function performs an asynchronous add reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction.
*
* @tparam ST A shared tile type with a TMA-compatible layout
* @param[out] dst The destination global memory array, which owns the tensormap.
* @param[in] src The source shared memory tile.
* @param[in] idx The coord of the tile destination. This is in units of complete tiles.
*/
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) {

static_assert(!(std::is_same_v<typename ST::dtype, fp8e4m3> ||
std::is_same_v<typename ST::dtype, fp8e5m2>),
"TMA does not support async add reductions for fp8 types.");

uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const GL &dst, const ST &src, const COORD &idx) {
store_add_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
}
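/* Editor's usage sketch (illustrative): many blocks accumulating partial
* tiles into one global buffer, letting the TMA unit perform the add.
*
*   tma::store_add_async(g_acc, partial_tile, {blockIdx.x, blockIdx.y});
*   tma::store_async_wait();
*/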
template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) {

static_assert(!(std::is_same_v<typename ST::dtype, fp8e4m3> ||
std::is_same_v<typename ST::dtype, fp8e5m2>),
"TMA does not support async add reductions for fp8 types.");

uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_add_async(const PGL &dst, const ST &src, const COORD &idx) {
store_add_async<dim::ROW, cache_policy::NORMAL>(dst, src, idx);
}

/**
* @brief Asynchronously performs a min reduction and stores the result into global memory from a shared memory tile.
*
* This function performs an asynchronous min reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction.
*
* @tparam ST A shared tile type with a TMA-compatible layout
* @param[out] dst The destination global memory array, which owns the tensormap.
* @param[in] src The source shared memory tile.
* @param[in] idx The coord of the tile destination. This is in units of complete tiles.
*/
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) {
static_assert(!std::is_same_v<typename ST::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");

static_assert(!(std::is_same_v<typename ST::dtype, fp8e4m3> ||
std::is_same_v<typename ST::dtype, fp8e5m2>),
"TMA does not support async min/max reductions for fp8 types.");

uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const GL &dst, const ST &src, const COORD &idx) {
store_min_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
}
template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) {
static_assert(!std::is_same_v<typename ST::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");

static_assert(!(std::is_same_v<typename ST::dtype, fp8e4m3> ||
std::is_same_v<typename ST::dtype, fp8e5m2>),
"TMA does not support async min/max reductions for fp8 types.");

uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.min.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_min_async(const PGL &dst, const ST &src, const COORD &idx) {
store_min_async<dim::ROW, cache_policy::NORMAL>(dst, src, idx);
}

/**
* @brief Asynchronously performs a max reduction and stores the result into global memory from a shared memory tile.
*
* This function performs an asynchronous max reduction and copy operation using CUDA's cp.reduce.async.bulk.tensor instruction.
*
* @tparam ST A shared tile type with a TMA-compatible layout
* @param[out] dst The destination global memory array, which owns the tensormap.
* @param[in] src The source shared memory tile.
* @param[in] idx The coord of the tile destination. This is in units of complete tiles.
*/
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) {
static_assert(!std::is_same_v<typename ST::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");

static_assert(!(std::is_same_v<typename ST::dtype, fp8e4m3> ||
std::is_same_v<typename ST::dtype, fp8e5m2>),
"TMA does not support async min/max reductions for fp8 types.");

uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const GL &dst, const ST &src, const COORD &idx) {
store_max_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx);
}
template<int axis, cache_policy policy, ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) {
static_assert(!std::is_same_v<typename ST::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");

static_assert(!(std::is_same_v<typename ST::dtype, fp8e4m3> ||
std::is_same_v<typename ST::dtype, fp8e5m2>),
"TMA does not support async min/max reductions for fp8 types.");

uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<ST, axis>());
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group"
" [%0, {%2, %3, %4, %5, %6}], [%1];"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile (
"cp.reduce.async.bulk.tensor.5d.global.shared::cta.max.tile.bulk_group.L2::cache_hint"
" [%0, {%2, %3, %4, %5, %6}], [%1], %7;"
:
: "l"(tma_ptr), "r"(src_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
store_commit_group();
}
template<ducks::st::all ST, ducks::pgl::all PGL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void store_max_async(const PGL &dst, const ST &src, const COORD &idx) {
store_max_async<dim::ROW, cache_policy::NORMAL>(dst, src, idx);
}

/**
* @brief Asynchronously loads data from global memory into a shared memory tile.
*
* This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
*
* @tparam ST A shared tile type with a TMA-compatible layout
* @param[out] dst The destination shared memory tile.
* @param[in] src The source global memory array, which owns the tensormap.
* @param[in] idx The coord of the requested tile. This is in units of complete tiles.
* @param[in,out] bar The semaphore used for synchronization of the asynchronous copy.
*/
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) {
uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<ST, axis>());
uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

if constexpr (policy == cache_policy::NORMAL) {
asm volatile(
"cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2];"
:
: "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w)
: "memory"
);
}
else {
asm volatile(
"cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
:
: "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "l"(make_cache_policy<policy>())
: "memory"
);
}
}
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar) {
load_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx, bar);
}
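/* Editor's usage sketch (illustrative; init_semaphore and wait are assumed
* from ThunderKittens' semaphore utilities): one elected thread issues the
* load, everyone waits on the mbarrier.
*
*   __shared__ semaphore bar;
*   if (threadIdx.x == 0) {
*       init_semaphore(bar, 0, 1);
*       tma::expect(bar, tile);                      // register expected bytes
*       tma::load_async(tile, g_in, {blockIdx.x, blockIdx.y}, bar);
*   }
*   __syncthreads();
*   wait(bar, 0);                                    // spin on phase 0
*/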

namespace cluster {

/**
* @brief Asynchronously loads data from global memory into a shared memory tile, across a threadblock cluster.
*
* This function performs an asynchronous copy operation using CUDA's cp.async.bulk.tensor instruction.
*
* @tparam ST A shared tile type with a TMA-compatible layout
* @param[out] dst The destination shared memory tile.
* @param[in] src The source global memory array, which owns the tensormap.
* @param[in] idx The coord of the requested tile. This is in units of complete tiles.
* @param[in,out] bar The semaphore used for synchronization of the asynchronous copy.
* @param[in] cluster_mask The mask of the CTAs in the cluster to broadcast to.
*/
#ifdef KITTENS_BLACKWELL
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1)
#else
template<int axis, cache_policy policy, ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask)
#endif
{
uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<ST, axis>());
uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst));
coord<ducks::default_type> unit_coord = idx.template unit_coord<axis, 3>(); // convert to unit coordinates
int4 tma_coords = detail::tma_coords<ST, axis>(unit_coord);

#ifdef KITTENS_BLACKWELL
if(dst_mbar_cta != -1) {
uint32_t neighbor_mbar_ptr;
asm volatile (
"mapa.shared::cluster.u32 %0, %1, %2;\n"
: "=r"(neighbor_mbar_ptr)
: "r"(mbar_ptr), "r"(dst_mbar_cta)
);
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
:
: "r"(dst_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask)
: "memory"
);
}
else {
asm volatile (
"cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8, %9;"
:
: "r"(dst_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask), "l"(make_cache_policy<policy>())
: "memory"
);
}
} else
#endif
if constexpr (policy == cache_policy::NORMAL) {
asm volatile (
"cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;"
:
: "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask)
: "memory"
);
}
else {
asm volatile (
"cp.async.bulk.tensor.5d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
" [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8, %9;"
:
: "r"(dst_ptr), "l"(tma_ptr), "r"(mbar_ptr),
"n"(0), "r"(tma_coords.x), "r"(tma_coords.y), "r"(tma_coords.z), "r"(tma_coords.w), "h"(cluster_mask), "l"(make_cache_policy<policy>())
: "memory"
);
}
}
#ifdef KITTENS_BLACKWELL
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) {
load_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx, bar, cluster_mask, dst_mbar_cta);
}
#else
template<ducks::st::all ST, ducks::gl::all GL, ducks::coord::tile COORD=coord<ST>>
__device__ static inline void load_async(ST &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask) {
load_async<dim::ROW, cache_policy::NORMAL, ST, GL, COORD>(dst, src, idx, bar, cluster_mask);
}
#endif
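/* Editor's usage sketch (illustrative): broadcast one tile to every CTA in
* the cluster with a single TMA transaction; each receiving CTA must have
* registered the expected bytes on its own copy of `bar` beforehand.
*
*   uint16_t mask = (1 << cluster_size) - 1;  // cluster_size: CTAs per cluster (assumed)
*   tma::cluster::load_async(tile, g_in, {blockIdx.x, blockIdx.y}, bar, mask);
*/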

} // namespace cluster
} // namespace tma

} // namespace kittens
405
extra/thunder/cuda/include/ops/thread/memory/util/multimem.cuh
Normal file
405
extra/thunder/cuda/include/ops/thread/memory/util/multimem.cuh
Normal file
@@ -0,0 +1,405 @@
/**
* @file
* @brief Wrappers for multimem operations
*/

#pragma once

namespace kittens {

enum class reduce_op {
ADD = 0,
MIN = 1,
MAX = 2
};

enum class memory_model {
WEAK = 0,
STRONG = 1
};

template <typename T>
struct multimem;

template <>
struct multimem<int> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(int &dst, const int *src) {
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.s32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.s32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MIN) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.min.s32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.min.s32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MAX) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.max.s32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.max.s32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(int *dst, const int &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.s32 [%0], %1;"
:: "l"(dst), "r"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.s32 [%0], %1;"
:: "l"(dst), "r"(src) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(int *dst, const int &src) {
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.s32 [%0], %1;"
: : "l"(dst), "r"(src) : "memory");
} else if constexpr (Op == reduce_op::MIN) {
asm volatile("multimem.red.release.sys.global.min.s32 [%0], %1;"
: : "l"(dst), "r"(src) : "memory");
} else if constexpr (Op == reduce_op::MAX) {
asm volatile("multimem.red.release.sys.global.max.s32 [%0], %1;"
: : "l"(dst), "r"(src) : "memory");
}
}
};
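/* Editor's usage sketch (illustrative): `mc_ptr` is assumed to be a multicast
* address bound via the CUDA driver's cuMulticastCreate/cuMulticastBindMem,
* so a single instruction touches every device mapped into it.
*
*   int sum;
*   multimem<int>::ld_reduce<reduce_op::ADD>(sum, mc_ptr);    // load + reduce across replicas
*   multimem<int>::red<reduce_op::ADD>(mc_ptr, local_part);   // accumulate into all replicas
*/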

template <>
struct multimem<uint> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(uint &dst, const uint *src) {
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.u32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.u32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MIN) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.min.u32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.min.u32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MAX) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.max.u32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.max.u32 %0, [%1];"
: "=r"(dst) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(uint *dst, const uint &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.u32 [%0], %1;"
:: "l"(dst), "r"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.u32 [%0], %1;"
:: "l"(dst), "r"(src) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(uint *dst, const uint &src) {
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.u32 [%0], %1;"
: : "l"(dst), "r"(src) : "memory");
} else if constexpr (Op == reduce_op::MIN) {
asm volatile("multimem.red.release.sys.global.min.u32 [%0], %1;"
: : "l"(dst), "r"(src) : "memory");
} else if constexpr (Op == reduce_op::MAX) {
asm volatile("multimem.red.release.sys.global.max.u32 [%0], %1;"
: : "l"(dst), "r"(src) : "memory");
}
}
};

template <>
struct multimem<float> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(float &dst, const float *src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 ld_reduce operations");
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.f32 %0, [%1];"
: "=f"(dst) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.f32 %0, [%1];"
: "=f"(dst) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(float *dst, const float &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.f32 [%0], %1;"
:: "l"(dst), "f"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.f32 [%0], %1;"
:: "l"(dst), "f"(src) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(float *dst, const float &src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 red operations");
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.f32 [%0], %1;"
: : "l"(dst), "f"(src) : "memory");
}
}
};

template <>
struct multimem<float2> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(float2 &dst, const float2 *src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 ld_reduce operations");
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.v2.f32 {%0, %1}, [%2];"
: "=f"(dst.x), "=f"(dst.y) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.v2.f32 {%0, %1}, [%2];"
: "=f"(dst.x), "=f"(dst.y) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(float2 *dst, const float2 &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.v2.f32 [%0], {%1, %2};"
:: "l"(dst), "f"(src.x), "f"(src.y) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.v2.f32 [%0], {%1, %2};"
:: "l"(dst), "f"(src.x), "f"(src.y) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(float2 *dst, const float2 &src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f32 red operations");
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.v2.f32 [%0], {%1, %2};"
: : "l"(dst), "f"(src.x), "f"(src.y) : "memory");
}
}
};

template <>
struct multimem<bf16> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(bf16 &dst, const bf16 *src) {
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.bf16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.bf16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MIN) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.min.bf16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.min.bf16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MAX) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.max.bf16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.max.bf16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(bf16 *dst, const bf16 &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.bf16 [%0], %1;"
:: "l"(dst), "h"(*reinterpret_cast<const uint16_t *>(&src)) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.bf16 [%0], %1;"
:: "l"(dst), "h"(*reinterpret_cast<const uint16_t *>(&src)) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(bf16 *dst, const bf16 &src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for bf16 red operations");
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.bf16 [%0], %1;"
: : "l"(dst), "h"(*reinterpret_cast<const uint16_t *>(&src)) : "memory");
}
}
};

template <>
struct multimem<bf16_2> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(bf16_2 &dst, const bf16_2 *src) {
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.bf16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.bf16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MIN) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.min.bf16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.min.bf16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MAX) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.max.bf16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.max.bf16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(bf16_2 *dst, const bf16_2 &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.bf16x2 [%0], %1;"
:: "l"(dst), "r"(*reinterpret_cast<const uint32_t *>(&src)) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.bf16x2 [%0], %1;"
:: "l"(dst), "r"(*reinterpret_cast<const uint32_t *>(&src)) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(bf16_2 *dst, const bf16_2 &src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for bf16_2 red operations");
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.bf16x2 [%0], %1;"
: : "l"(dst), "r"(*reinterpret_cast<const uint32_t *>(&src)) : "memory");
}
}
};

template <>
struct multimem<half> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(half &dst, const half *src) {
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.f16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.f16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MIN) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.min.f16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.min.f16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MAX) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.max.f16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.max.f16 %0, [%1];"
: "=h"(*reinterpret_cast<uint16_t *>(&dst)) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(half *dst, const half &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.f16 [%0], %1;"
:: "l"(dst), "h"(*reinterpret_cast<const uint16_t *>(&src)) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.f16 [%0], %1;"
:: "l"(dst), "h"(*reinterpret_cast<const uint16_t *>(&src)) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(half *dst, const half &src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f16 red operations");
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.f16 [%0], %1;"
: : "l"(dst), "h"(*reinterpret_cast<const uint16_t *>(&src)) : "memory");
}
}
};

template <>
struct multimem<half_2> {
template <reduce_op Op, memory_model M = memory_model::WEAK>
__device__ static inline void ld_reduce(half_2 &dst, const half_2 *src) {
if constexpr (Op == reduce_op::ADD) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.add.acc::f32.f16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.add.acc::f32.f16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MIN) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.min.f16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.min.f16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
}
} else if constexpr (Op == reduce_op::MAX) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.ld_reduce.weak.global.max.f16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.ld_reduce.acquire.sys.global.max.f16x2 %0, [%1];"
: "=r"(*reinterpret_cast<uint32_t *>(&dst)) : "l"(src) : "memory");
}
}
}
template <memory_model M = memory_model::WEAK>
__device__ static inline void st(half_2 *dst, const half_2 &src) {
if constexpr (M == memory_model::WEAK) {
asm volatile("multimem.st.weak.global.f16x2 [%0], %1;"
:: "l"(dst), "r"(*reinterpret_cast<const uint32_t *>(&src)) : "memory");
} else if constexpr (M == memory_model::STRONG) {
asm volatile("multimem.st.release.sys.global.f16x2 [%0], %1;"
:: "l"(dst), "r"(*reinterpret_cast<const uint32_t *>(&src)) : "memory");
}
}
template <reduce_op Op>
__device__ static inline void red(half_2 *dst, const half_2 &src) {
static_assert(Op == reduce_op::ADD, "MIN/MAX are not supported for f16_2 red operations");
if constexpr (Op == reduce_op::ADD) {
asm volatile("multimem.red.release.sys.global.add.f16x2 [%0], %1;"
: : "l"(dst), "r"(*reinterpret_cast<const uint32_t *>(&src)) : "memory");
}
}
};

} // namespace kittens
30
extra/thunder/cuda/include/ops/thread/memory/util/tensor.cuh
Normal file
30
extra/thunder/cuda/include/ops/thread/memory/util/tensor.cuh
Normal file
@@ -0,0 +1,30 @@
/**
* @file
* @brief Functions for transferring data directly between tensor memory and register memory.
*/

#pragma once

#include <type_traits>

#include "../../../../common/common.cuh"
#include "../../../../types/types.cuh"
#include "util.cuh"

namespace kittens {

__device__ static inline void tensor_before_thread_sync() {
asm volatile("tcgen05.fence::before_thread_sync;\n");
}
__device__ static inline void tensor_after_thread_sync() {
asm volatile("tcgen05.fence::after_thread_sync;\n");
}

__device__ inline static void tensor_load_wait() {
asm volatile("tcgen05.wait::ld.sync.aligned;");
}
__device__ inline static void tensor_store_wait() {
asm volatile("tcgen05.wait::st.sync.aligned;");
}
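/* Editor's note (illustrative): the fence pair orders tcgen05 traffic around
* a generic synchronization point, e.g.
*
*   tensor_before_thread_sync();  // make prior tcgen05 ops visible before the sync
*   __syncthreads();
*   tensor_after_thread_sync();   // order later tcgen05 ops after the sync
*/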

}
249
extra/thunder/cuda/include/ops/thread/memory/util/tma.cuh
Normal file
249
extra/thunder/cuda/include/ops/thread/memory/util/tma.cuh
Normal file
@@ -0,0 +1,249 @@
#pragma once

#include "../../../../common/common.cuh"
#include "../../../../types/types.cuh"

#include <cuda.h>
#include <iostream>

namespace kittens {
/**
* @brief A namespace for all of ThunderKittens' TMA functionality.
*/
namespace tma {

/* ---------- Barrier functions for async load ---------- */

/**
* @brief Sets the number of bytes expected at the semaphore.
*
* This function sets the number of bytes expected at the semaphore, and is meant to be
* called by a single thread (e.g., the first thread in the warp). It converts the semaphore
* pointer to a generic shared memory pointer and uses an inline assembly instruction to set
* the expected number of bytes.
*
* @param bar Reference to the semaphore variable.
* @param bytes The number of bytes expected at the semaphore.
*/
__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes) {
void const* const ptr = &bar;
uint32_t bar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n"
:: "r"(bar_ptr), "r"(bytes));
}
/**
* @brief Sets the number of bytes expected at the semaphore.
*
* This function sets the number of bytes expected at the mbarrier before the transaction arrives.
*/
template<typename T, typename... args>
__device__ static inline void expect(semaphore& bar, const T& _1, const args&... _2) {
expect_bytes(bar, size_bytes<T, args...>);
}

/* ---------- Synchronization functions for async store ---------- */

/**
* @brief Commits previous asynchronous TMA stores to a group and performs them.
*/
__device__ static inline void store_commit_group() {
asm volatile("cp.async.bulk.commit_group;");
}
/**
* @brief Waits for previous committed TMA store groups to complete.
*
* @tparam N The maximum number of remaining TMA store groups. Defaults to 0.
*/
template <int N=0>
__device__ static inline void store_async_wait() {
asm volatile (
"cp.async.bulk.wait_group %0;"
:
: "n"(N)
: "memory"
);
}
/**
* @brief Waits for previous committed TMA store groups to finish reading from shared memory.
*
* @tparam N The maximum number of remaining TMA store groups. Defaults to 0.
*/
template <int N=0>
__device__ static inline void store_async_read_wait() {
asm volatile (
"cp.async.bulk.wait_group.read %0;"
:
: "n"(N)
: "memory"
);
}
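/* Editor's usage sketch (illustrative): double-buffered stores that keep at
* most one bulk group in flight while the other shared buffer is refilled.
*
*   tma::store_async(g_out, tile[cur], {row, col}); // commits a group internally
*   tma::store_async_read_wait<1>();                // tile[1 - cur] is now reusable
*/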

/* ---------- Cluster-scope operations ---------- */

namespace cluster {

/**
* @brief Waits for the requested semaphore phase, at cluster scope.
*
* @param bar Reference to the semaphore variable.
* @param kPhaseBit The phase bit used for the semaphore.
*/
__device__ static inline void wait(semaphore& bar, int kPhaseBit) {
void const* const ptr = &bar;
uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

asm volatile (
"{\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n"
:: "r"(mbar_ptr),
"r"(kPhaseBit)
);
}
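
/**
* @brief Like wait(), but traps if the phase bit does not flip within roughly
* 10^6 clock cycles, turning a silent hang into a debuggable fault. (Editor's
* summary; the threshold is the 1000000-cycle check in the PTX below.)
*/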
__device__ static inline void careful_wait(semaphore& bar, int kPhaseBit) {
void const* const ptr = &bar;
uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

asm volatile (
"{\n"
".reg .b64 start_clock, current_clock;\n"
"mov.b64 start_clock, %%clock64;\n"
".reg .pred P_CLOCK;\n"
".reg .pred P1;\n"
"LAB_WAIT:\n"
"mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1, [%0], %1;\n"
"@P1 bra.uni DONE;\n"
"mov.b64 current_clock, %%clock64;\n"
"sub.u64 current_clock, current_clock, start_clock;\n"
"setp.ge.u64 P_CLOCK, current_clock, 1000000;\n"
"@P_CLOCK trap;\n"
"bra.uni LAB_WAIT;\n"
"DONE:\n"
"}\n"
:: "r"(mbar_ptr),
"r"(kPhaseBit)
);
}

/**
* @brief Sets the number of bytes expected at the semaphore, assuming a multicast instruction.
*
* This function sets the number of bytes expected at the semaphore, and is meant to be
* called by a single thread (e.g., the first thread in the warp). It converts the semaphore
* pointer to a generic shared memory pointer and uses an inline assembly instruction to set
* the expected number of bytes.
*
* It's worth being aware that this function is particularly necessary for multicast loads;
* distributed shared memory can actually be done with a normal tma::expect followed by wait. See
* the unit tests of dsmem for an example.
*
* @param bar Reference to the semaphore variable.
* @param bytes The number of bytes expected at the semaphore.
* @param dst_cta The rank of the destination CTA within the cluster.
*/
__device__ static inline void expect_bytes(semaphore& bar, uint32_t bytes, int dst_cta) {
uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
uint32_t neighbor_mbar_addr;
asm volatile (
"mapa.shared::cluster.u32 %0, %1, %2;\n"
: "=r"(neighbor_mbar_addr)
: "r"(mbar_addr), "r"(dst_cta)
);

asm volatile ("mbarrier.arrive.expect_tx.shared::cluster.b64 _, [%0], %1;\n"
:: "r"(neighbor_mbar_addr), "r"(bytes));
}
/**
* @brief Sets the number of bytes expected at the semaphore.
*
* This function sets the number of bytes expected at the mbarrier before the transaction arrives.
*
* @tparam T The type of the data to be stored at the semaphore.
* @param bar Reference to the semaphore variable.
* @param dst_cta The rank of the destination CTA within the cluster.
*/
template<typename T, typename... args>
__device__ static inline void expect(semaphore& bar, int dst_cta, const T& _1, const args&... _2) {
expect_bytes(bar, size_bytes<T, args...>, dst_cta);
}

/**
* @brief Arrives at a semaphore in cluster scope.
*
* Marks a thread arrival at an mbarrier in another CTA of the cluster.
*
* @param bar Reference to the semaphore variable.
* @param dst_cta The rank of the destination CTA within the cluster.
* @param count The number of arrivals to mark.
*/
__device__ static inline void arrive(semaphore& bar, int dst_cta, uint32_t count=1) {
uint32_t mbar_addr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
uint32_t neighbor_mbar_addr;
asm volatile (
"mapa.shared::cluster.u32 %0, %1, %2;\n"
: "=r"(neighbor_mbar_addr)
: "r"(mbar_addr), "r"(dst_cta)
);
asm volatile (
"mbarrier.arrive.shared::cluster.b64 _, [%0], %1;\n"
:
: "r"(neighbor_mbar_addr), "r" (count)
: "memory"
);
}

// Generic transfer
__device__ static inline void store_async(void *dst, void *src, int dst_cta, uint32_t size_bytes, semaphore& bar) {
void const* const ptr = &bar;
uint32_t mbarrier_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

// load from src to dst in different threadblocks
uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(src));
uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(dst));

// mapa instr = https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mapa
// find dst addr in neighbor's cta
uint32_t neighbor_addr_dst;
asm volatile (
"mapa.shared::cluster.u32 %0, %1, %2;\n"
: "=r"(neighbor_addr_dst)
: "r"(dst_ptr), "r"(dst_cta)
);

uint32_t neighbor_addr_mbarrier = mbarrier_ptr;
asm volatile (
"mapa.shared::cluster.u32 %0, %1, %2;\n"
: "=r"(neighbor_addr_mbarrier)
: "r"(mbarrier_ptr), "r"(dst_cta)
);

// cp.async instr = https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk
// copy src into dst in neighbor's cta
asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");
asm volatile (
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n"
:
: "r"(neighbor_addr_dst), "r"(src_ptr), "r"(size_bytes), "r"(neighbor_addr_mbarrier)
: "memory"
);
}

// Templated transfer for convenience
template<typename T>
__device__ static inline void store_async(T &dst_, T &src_, int dst_cta, semaphore& bar) {
store_async((void*)&dst_, (void*)&src_, dst_cta, size_bytes<T>, bar);
}
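/* Editor's usage sketch (illustrative): push a shared tile into a neighboring
* CTA's shared memory (DSMEM); the receiving CTA waits on its own `bar`.
*
*   tma::cluster::expect(bar, dst_cta, dst_tile);  // arm the neighbor's mbarrier
*   tma::cluster::store_async(dst_tile, src_tile, dst_cta, bar);
*/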

} // namespace cluster
} // namespace tma
} // namespace kittens
443
extra/thunder/cuda/include/ops/thread/memory/util/util.cuh
Normal file
443
extra/thunder/cuda/include/ops/thread/memory/util/util.cuh
Normal file
@@ -0,0 +1,443 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief General memory utilities not specialized for either tiles or vectors.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace kittens {
|
||||
|
||||
/* ---------- To prevent generic addressing, PTX ---------- */
|
||||
|
||||
template<typename T> struct move {
    __device__ static inline void lds(T& dst, uint32_t src);
    __device__ static inline void sts(uint32_t dst, const T& src);
    __device__ static inline void ldg(T& dst, T* src);
    __device__ static inline void stg(T* dst, const T& src);
};
// unpacked types
template<> struct move<bf16> {
    __device__ static inline void lds(bf16& dst, uint32_t src) {
        asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const bf16& src) {
        asm volatile("st.shared.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "r"(dst));
    }
    __device__ static inline void ldg(bf16& dst, bf16* src) {
        asm volatile("ld.global.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "l"(src));
    }
    __device__ static inline void stg(bf16* dst, const bf16& src) {
        asm volatile("st.global.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "l"(dst));
    }
};
template<> struct move<half> {
    __device__ static inline void lds(half& dst, uint32_t src) {
        asm volatile("ld.shared.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const half& src) {
        asm volatile("st.shared.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "r"(dst));
    }
    __device__ static inline void ldg(half& dst, half* src) {
        asm volatile("ld.global.b16 %0, [%1];\n" : "=h"(*(uint16_t*)&dst) : "l"(src));
    }
    __device__ static inline void stg(half* dst, const half& src) {
        asm volatile("st.global.b16 [%1], %0;\n" : : "h"(*(uint16_t*)&src), "l"(dst));
    }
};
template<> struct move<float> {
    __device__ static inline void lds(float& dst, uint32_t src) {
        asm volatile("ld.shared.f32 %0, [%1];\n" : "=f"(dst) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const float& src) {
        asm volatile("st.shared.f32 [%1], %0;\n" : : "f"(src), "r"(dst));
    }
    __device__ static inline void ldg(float& dst, float* src) {
        asm volatile("ld.global.f32 %0, [%1];\n" : "=f"(dst) : "l"(src));
    }
    __device__ static inline void stg(float* dst, const float& src) {
        asm volatile("st.global.f32 [%1], %0;\n" : : "f"(src), "l"(dst));
    }
};
template<> struct move<int> {
    __device__ static inline void lds(int& dst, uint32_t src) {
        asm volatile("ld.shared.u32 %0, [%1];\n" : "=r"(dst) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const int& src) {
        asm volatile("st.shared.u32 [%1], %0;\n" : : "r"(src), "r"(dst));
    }
    __device__ static inline void ldg(int& dst, int* src) {
        asm volatile("ld.global.u32 %0, [%1];\n" : "=r"(dst) : "l"(src));
    }
    __device__ static inline void stg(int* dst, const int& src) {
        asm volatile("st.global.u32 [%1], %0;\n" : : "r"(src), "l"(dst));
    }
};
// packed types
template<> struct move<bf16_2> {
    __device__ static inline void lds(bf16_2& dst, uint32_t src) {
        asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const bf16_2& src) {
        asm volatile("st.shared.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "r"(dst));
    }
    __device__ static inline void ldg(bf16_2& dst, bf16_2* src) {
        asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "l"(src));
    }
    __device__ static inline void stg(bf16_2* dst, const bf16_2& src) {
        asm volatile("st.global.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "l"(dst));
    }
    __device__ static inline void ldsm4(bf16_2& dst1, bf16_2& dst2, bf16_2& dst3, bf16_2& dst4, uint32_t src) {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" :
                     "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src));
    }
    __device__ static inline void ldsm4t(bf16_2& dst1, bf16_2& dst2, bf16_2& dst3, bf16_2& dst4, uint32_t src) {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" :
                     "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src));
    }
    __device__ static inline void stsm4(uint32_t dst, bf16_2& src1, bf16_2& src2, bf16_2& src3, bf16_2& src4) {
        asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" ::
                     "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst));
    }
    __device__ static inline void stsm4t(uint32_t dst, bf16_2& src1, bf16_2& src2, bf16_2& src3, bf16_2& src4) {
        asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" ::
                     "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst));
    }
};
template<> struct move<half_2> {
    __device__ static inline void lds(half_2& dst, uint32_t src) {
        asm volatile("ld.shared.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const half_2& src) {
        asm volatile("st.shared.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "r"(dst));
    }
    __device__ static inline void ldg(half_2& dst, half_2* src) {
        asm volatile("ld.global.b32 %0, [%1];\n" : "=r"(*(uint32_t*)&dst) : "l"(src));
    }
    __device__ static inline void stg(half_2* dst, const half_2& src) {
        asm volatile("st.global.b32 [%1], %0;\n" : : "r"(*(uint32_t*)&src), "l"(dst));
    }
    __device__ static inline void ldsm4(half_2& dst1, half_2& dst2, half_2& dst3, half_2& dst4, uint32_t src) {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" :
                     "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src));
    }
    __device__ static inline void ldsm4t(half_2& dst1, half_2& dst2, half_2& dst3, half_2& dst4, uint32_t src) {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" :
                     "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src));
    }
    __device__ static inline void stsm4(uint32_t dst, half_2& src1, half_2& src2, half_2& src3, half_2& src4) {
        asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" ::
                     "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst));
    }
    __device__ static inline void stsm4t(uint32_t dst, half_2& src1, half_2& src2, half_2& src3, half_2& src4) {
        asm volatile("stmatrix.sync.aligned.m8n8.x4.trans.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" ::
                     "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst));
    }
};
template<> struct move<float2> {
    __device__ static inline void lds(float2& dst, uint32_t src) {
        asm volatile("ld.shared.v2.f32 {%0, %1}, [%2];\n" : "=f"(dst.x), "=f"(dst.y) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const float2& src) {
        asm volatile("st.shared.v2.f32 [%2], {%0, %1};\n" : : "f"(src.x), "f"(src.y), "r"(dst));
    }
    __device__ static inline void ldg(float2& dst, float2* src) {
        asm volatile("ld.global.v2.f32 {%0, %1}, [%2];\n" : "=f"(dst.x), "=f"(dst.y) : "l"(src));
    }
    __device__ static inline void stg(float2* dst, const float2& src) {
        asm volatile("st.global.v2.f32 [%2], {%0, %1};\n" : : "f"(src.x), "f"(src.y), "l"(dst));
    }
};
template<> struct move<float4> {
    __device__ static inline void lds(float4& dst, uint32_t src) {
        asm volatile("ld.shared.v4.f32 {%0, %1, %2, %3}, [%4];\n" : "=f"(dst.x), "=f"(dst.y), "=f"(dst.z), "=f"(dst.w) : "r"(src));
    }
    __device__ static inline void sts(uint32_t dst, const float4& src) {
        asm volatile("st.shared.v4.f32 [%4], {%0, %1, %2, %3};\n" : : "f"(src.x), "f"(src.y), "f"(src.z), "f"(src.w), "r"(dst));
    }
    __device__ static inline void ldg(float4& dst, float4* src) {
        asm volatile("ld.global.v4.f32 {%0, %1, %2, %3}, [%4];\n" : "=f"(dst.x), "=f"(dst.y), "=f"(dst.z), "=f"(dst.w) : "l"(src));
    }
    __device__ static inline void stg(float4* dst, const float4& src) {
        asm volatile("st.global.v4.f32 [%4], {%0, %1, %2, %3};\n" : : "f"(src.x), "f"(src.y), "f"(src.z), "f"(src.w), "l"(dst));
    }
};
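// --- Illustrative usage (a sketch, not part of the original source): round-tripping
// a float2 through shared memory with the explicit-address wrappers above, so the
// compiler never falls back to generic addressing. `buf` is a hypothetical buffer.
//
//   __shared__ float2 buf[32];
//   uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(&buf[threadIdx.x]));
//   float2 v{1.f, 2.f};
//   move<float2>::sts(addr, v);   // st.shared.v2.f32
//   float2 out;
//   move<float2>::lds(out, addr); // ld.shared.v2.f32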
#ifdef KITTENS_HOPPER
template<> struct move<fp8e4m3_4> {
    __device__ static inline void ldsm4(fp8e4m3_4& dst1, fp8e4m3_4& dst2, fp8e4m3_4& dst3, fp8e4m3_4& dst4, uint32_t src) {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" :
                     "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src));
    }
    __device__ static inline void stsm4(uint32_t dst, fp8e4m3_4& src1, fp8e4m3_4& src2, fp8e4m3_4& src3, fp8e4m3_4& src4) {
        asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" ::
                     "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst));
    }
};
template<> struct move<fp8e5m2_4> {
    __device__ static inline void ldsm4(fp8e5m2_4& dst1, fp8e5m2_4& dst2, fp8e5m2_4& dst3, fp8e5m2_4& dst4, uint32_t src) {
        asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared::cta.b16 {%0, %1, %2, %3}, [%4];\n" :
                     "=r"(*(uint32_t*)&dst1), "=r"(*(uint32_t*)&dst2), "=r"(*(uint32_t*)&dst3), "=r"(*(uint32_t*)&dst4) : "r"(src));
    }
    __device__ static inline void stsm4(uint32_t dst, fp8e5m2_4& src1, fp8e5m2_4& src2, fp8e5m2_4& src3, fp8e5m2_4& src4) {
        asm volatile("stmatrix.sync.aligned.m8n8.x4.shared::cta.b16 [%4], {%0, %1, %2, %3};\n" ::
                     "r"(*(uint32_t*)&src1), "r"(*(uint32_t*)&src2), "r"(*(uint32_t*)&src3), "r"(*(uint32_t*)&src4), "r"(dst));
    }
};
#endif
/* ---------- Constants for cache policies ---------- */

enum cache_policy {
    NORMAL = 0,
    EVICT_FIRST = 1,
    EVICT_LAST = 2
};
template<cache_policy policy> __device__ inline uint64_t make_cache_policy() {
    uint64_t cache_policy_val;
    constexpr float fraction = 1.0f;
    static_assert(policy == cache_policy::EVICT_FIRST || policy == cache_policy::EVICT_LAST, "Unexpected cache policy");
    if constexpr (policy == cache_policy::EVICT_FIRST) {
        asm volatile("createpolicy.fractional.L2::evict_first.b64 %0, %1;\n" : "=l"(cache_policy_val) : "f"(fraction));
    }
    else {
        asm volatile("createpolicy.fractional.L2::evict_last.b64 %0, %1;\n" : "=l"(cache_policy_val) : "f"(fraction));
    }
    return cache_policy_val;
}
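// --- Illustrative usage (a sketch, not part of the original source): building an L2
// eviction hint for streaming data that will not be reused.
//
//   uint64_t hint = make_cache_policy<cache_policy::EVICT_FIRST>();
//   // `hint` is then passed as the L2::cache_hint operand of a cp.async.bulk.*
//   // instruction, as the TMA wrappers elsewhere in this library do. NORMAL
//   // deliberately fails the static_assert: no descriptor is needed when no hint
//   // is attached.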
/* ---------- Generic (non-Hopper specific) semaphore functions ---------- */

struct semaphore {
private:
    uint64_t value;
}; // note that this is an opaque type, so the value should not be accessed directly.
template<int num_warps> struct barrier {
    int barrier_id;
    __device__ __forceinline__ barrier(int _id) : barrier_id(_id) {}
    __device__ __forceinline__ barrier operator[](int i) {
        return barrier(barrier_id + i);
    }
};
/**
 * @brief Initializes a synchronization semaphore with a thread count and, optionally, a transaction count.
 *
 * This function sets up a semaphore that is used to synchronize threads within a block during asynchronous operations.
 * The underlying mbarrier is initialized with the sum of the two counts, so both thread arrivals and completed
 * transactions contribute toward flipping the phase.
 *
 * @param[out] bar The semaphore variable to initialize.
 * @param[in] thread_count The number of thread arrivals per phase.
 * @param[in] transaction_count The number of expected transaction arrivals per phase (defaults to 0).
 */
__device__ static inline void init_semaphore(semaphore& bar, int thread_count, int transaction_count=0) {
    void const* const ptr = &bar;
    uint32_t bar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

    asm volatile (
        "mbarrier.init.shared::cta.b64 [%0], %1;\n"
        :: "r"(bar_ptr), "r"(thread_count+transaction_count)
    );
}
/**
 * @brief Invalidates an mbarrier.
 *
 * @param[out] bar The semaphore variable to invalidate.
 */
__device__ static inline void invalidate_semaphore(semaphore& bar) {
    void const* const ptr = &bar;
    uint32_t bar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
    asm volatile (
        "mbarrier.inval.shared::cta.b64 [%0];\n"
        :: "r"(bar_ptr)
    );
}
/**
 * @brief Arrives at a semaphore.
 *
 * Marks a warp arrival at an mbarrier.
 *
 * @param sem Reference to the semaphore variable.
 */
__device__ static inline void arrive(semaphore& sem) {
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&sem));
    asm volatile (
        "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0];\n"
        :
        : "r"(mbar_ptr)
        : "memory"
    );
}
template<int num_warps> __device__ static inline void arrive(barrier<num_warps> bar) {
    asm volatile("bar.arrive %0, %1;\n" :: "r"(bar.barrier_id), "n"(num_warps*WARP_THREADS) : "memory");
}
#ifdef KITTENS_HOPPER
/**
 * @brief Arrives at a semaphore with a custom arrival count.
 *
 * Marks multiple arrivals at an mbarrier in a single operation.
 *
 * @param sem Reference to the semaphore variable.
 * @param count The number of arrivals to register.
 */
__device__ static inline void arrive(semaphore& sem, uint32_t count) {
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&sem));
    asm volatile (
        "mbarrier.arrive.release.cta.shared::cta.b64 _, [%0], %1;\n"
        :
        : "r"(mbar_ptr), "r"(count)
        : "memory"
    );
}
#endif
/**
 * @brief Waits for the requested semaphore phase.
 *
 * @param sem Reference to the semaphore variable.
 * @param kPhaseBit The phase bit used for the semaphore.
 */
__device__ static inline void wait(semaphore& sem, int kPhaseBit) {
    void const* const ptr = &sem;
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

#ifdef KITTENS_HOPPER
    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "LAB_WAIT:\n"
        "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
        "@P1 bra.uni DONE;\n"
        "bra.uni LAB_WAIT;\n"
        "DONE:\n"
        "}\n"
        :: "r"(mbar_ptr),
           "r"(kPhaseBit)
    );
#else
    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "LAB_WAIT:\n"
        "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
        "@P1 bra.uni DONE;\n"
        "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures to save instruction issue slots
        "bra.uni LAB_WAIT;\n"
        "DONE:\n"
        "}\n"
        :: "r"(mbar_ptr),
           "r"(kPhaseBit)
    );
#endif
}
// Like wait(), but on Hopper traps if the phase does not resolve within 1,000,000
// clock cycles -- useful for debugging deadlocks.
__device__ static inline void careful_wait(semaphore& sem, int kPhaseBit) {
    void const* const ptr = &sem;
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));

#ifdef KITTENS_HOPPER
    asm volatile (
        "{\n"
        ".reg .b64 start_clock, current_clock;\n"
        "mov.b64 start_clock, %%clock64;\n"
        ".reg .pred P_CLOCK;\n"
        ".reg .pred P1;\n"
        "LAB_WAIT:\n"
        "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
        "@P1 bra.uni DONE;\n"
        "mov.b64 current_clock, %%clock64;\n"
        "sub.u64 current_clock, current_clock, start_clock;\n"
        "setp.ge.u64 P_CLOCK, current_clock, 1000000;\n"
        "@P_CLOCK trap;\n"
        "bra.uni LAB_WAIT;\n"
        "DONE:\n"
        "}\n"
        :: "r"(mbar_ptr),
           "r"(kPhaseBit)
    );
#else
    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "LAB_WAIT:\n"
        "mbarrier.test_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
        "@P1 bra.uni DONE;\n"
        "nanosleep.u32 5;\n" // wait a few nanoseconds on pre-Hopper architectures to save instruction issue slots
        "bra.uni LAB_WAIT;\n"
        "DONE:\n"
        "}\n"
        :: "r"(mbar_ptr),
           "r"(kPhaseBit)
    );
#endif
}
/**
 * @brief Checks whether the requested semaphore phase is ready, without blocking.
 *
 * @param sem Reference to the semaphore variable.
 * @param kPhaseBit The phase bit used for the semaphore.
 * @return 1 if the phase has completed, 0 otherwise.
 */
__device__ static inline int test_wait(semaphore& sem, int kPhaseBit) {
    void const* const ptr = &sem;
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(ptr));
    int result;
    asm volatile (
        "{\n"
        ".reg .pred P1;\n"
        "mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2;\n"
        "selp.u32 %0, 1, 0, P1;\n"
        "}\n"
        : "=r"(result)
        : "r"(mbar_ptr), "r"(kPhaseBit)
    );
    return result;
}
__device__ static inline void arrive_and_wait(semaphore& sem, int kPhaseBit) {
    arrive(sem);
    wait(sem, kPhaseBit);
}
template<int num_warps> __device__ static inline void arrive_and_wait(barrier<num_warps> bar) {
    asm volatile("bar.sync %0, %1;\n" :: "r"(bar.barrier_id), "n"(num_warps*WARP_THREADS) : "memory");
}
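// --- Illustrative usage (a sketch, not part of the original source): a minimal
// producer/consumer handoff through the mbarrier-backed semaphore. The phase bit
// starts at 0 and flips each time a phase completes.
//
//   __shared__ semaphore bar;
//   if (threadIdx.x == 0) init_semaphore(bar, /*thread_count=*/1);
//   __syncthreads();
//   if (threadIdx.x == 0) arrive(bar); // producer signals completion of phase 0
//   wait(bar, 0);                      // everyone blocks until phase 0 resolves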
template<int N=0> __device__ static inline void load_async_wait() { // for completing (non-TMA) async loads
    if constexpr (N == 0) {
        asm volatile("cp.async.wait_all;\n" ::);
    } else {
        asm volatile("cp.async.wait_group %0;\n" :: "n"(N));
    }
    __syncwarp();
}
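// --- Illustrative usage (a sketch, not part of the original source): in a software
// pipeline that issues cp.async copies in commit groups elsewhere, the template
// argument bounds how many of the most recent groups may remain in flight.
//
//   load_async_wait<2>(); // block until at most the 2 newest groups are still pending
//   load_async_wait<>();  // N == 0: block until every outstanding cp.async has landed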
// meant to be used only with shared tiles and shared vectors
namespace detail {
template<typename T> struct size_info {
    static constexpr uint32_t bytes = sizeof(std::remove_reference_t<T>);
};
template<ducks::st::all ST> struct size_info<ST> {
    static constexpr uint32_t elements = ST::num_elements;
    static constexpr uint32_t bytes    = ST::num_elements * sizeof(typename ST::dtype);
};
template<ducks::sv::all SV> struct size_info<SV> {
    static constexpr uint32_t elements = SV::length;
    static constexpr uint32_t bytes    = SV::length * sizeof(typename SV::dtype);
};
}
template<typename... Args>             inline constexpr uint32_t size_bytes             = 0; // base case
template<typename T, typename... Args> inline constexpr uint32_t size_bytes<T, Args...> = detail::size_info<T>::bytes + size_bytes<Args...>; // recursive case
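// --- Illustrative check (a sketch, not part of the original source): size_bytes
// folds byte counts over its parameter pack. For a hypothetical 64-element float
// shared-vector type V (64 * 4 = 256 bytes):
//
//   static_assert(size_bytes<V>    == 256);
//   static_assert(size_bytes<V, V> == 512); // the pack sums across arguments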
} // namespace kittens

#ifdef KITTENS_HOPPER
#include "multimem.cuh"
#include "tma.cuh"
#endif

#ifdef KITTENS_BLACKWELL
#include "tensor.cuh"
#endif
416
extra/thunder/cuda/include/ops/thread/memory/vec/tma.cuh
Normal file
@@ -0,0 +1,416 @@
#pragma once

#include "../../../../common/common.cuh"
#include "../../../../types/types.cuh"
#include "../util/util.cuh"

#include <cuda.h>
#include <iostream>
// This is a macro that helps us define default cache policy versions of each function.
#define __KITTENS_TMA_DEFINE_DEFAULT_LOAD_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(SV &dst, const GL &src, const COORD &idx) { \
    function_name<cache_policy::NORMAL>(dst, src, idx); \
}
#define __KITTENS_TMA_DEFINE_PGL_DEFAULT_LOAD_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(SV &dst, const PGL &src, const COORD &idx) { \
    function_name<cache_policy::NORMAL>(dst, src, idx); \
}
#define __KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(const GL &dst, const SV &src, const COORD &idx) { \
    function_name<cache_policy::NORMAL>(dst, src, idx); \
}
#define __KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(const PGL &dst, const SV &src, const COORD &idx) { \
    function_name<cache_policy::NORMAL>(dst, src, idx); \
}
#define __KITTENS_TMA_DEFINE_SEMAPHORE_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(SV &dst, const GL &src, const COORD &idx, semaphore& bar) { \
    function_name<cache_policy::NORMAL>(dst, src, idx, bar); \
}
#define __KITTENS_TMA_DEFINE_PGL_SEMAPHORE_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(SV &dst, const PGL &src, const COORD &idx, semaphore& bar) { \
    function_name<cache_policy::NORMAL>(dst, src, idx, bar); \
}
#define __KITTENS_TMA_DEFINE_CLUSTER_SEMAPHORE_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(SV &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { \
    function_name<cache_policy::NORMAL>(dst, src, idx, bar, cluster_mask, dst_mbar_cta); \
}
#define __KITTENS_TMA_DEFINE_PGL_CLUSTER_SEMAPHORE_CACHE_VEC__(function_name) \
template<ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>> \
__device__ static inline void function_name(SV &dst, const PGL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) { \
    function_name<cache_policy::NORMAL>(dst, src, idx, bar, cluster_mask, dst_mbar_cta); \
}
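// --- Illustrative expansion (a sketch, not part of the original source):
// __KITTENS_TMA_DEFINE_DEFAULT_LOAD_CACHE_VEC__(prefetch) generates the
// default-policy overload
//
//   template<ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
//   __device__ static inline void prefetch(SV &dst, const GL &src, const COORD &idx) {
//       prefetch<cache_policy::NORMAL>(dst, src, idx);
//   }
//
// so callers that don't care about L2 eviction hints can omit the policy parameter.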
namespace kittens {

namespace detail {
namespace tma {
template<cache_policy policy> __device__ static inline void vec_prefetch_tma_internal(uint64_t tma_ptr, coord<> tma_coord) {
    if constexpr (policy == cache_policy::NORMAL) {
        asm volatile (
            "cp.async.bulk.prefetch.tensor.4d.L2.global.tile"
            " [%0, {%1, %2, %3, %4}];"
            :
            : "l"(tma_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b)
            : "memory"
        );
    }
    else {
        asm volatile (
            "cp.async.bulk.prefetch.tensor.4d.L2.global.tile.L2::cache_hint"
            " [%0, {%1, %2, %3, %4}], %5;"
            :
            : "l"(tma_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy<policy>())
            : "memory"
        );
    }
}
template<cache_policy policy> __device__ static inline void vec_store_async_tma_internal(uint64_t tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) {
    asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");
    if constexpr (policy == cache_policy::NORMAL) {
        asm volatile (
            "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group"
            " [%0, {%2, %3, %4, %5}], [%1];"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b)
            : "memory"
        );
    }
    else {
        asm volatile (
            "cp.async.bulk.tensor.4d.global.shared::cta.tile.bulk_group.L2::cache_hint"
            " [%0, {%2, %3, %4, %5}], [%1], %6;"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy<policy>())
            : "memory"
        );
    }
}
template<cache_policy policy> __device__ static inline void vec_store_add_async_tma_internal(uint64_t tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) {
    asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");
    if constexpr (policy == cache_policy::NORMAL) {
        asm volatile (
            "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.tile.bulk_group"
            " [%0, {%2, %3, %4, %5}], [%1];"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b)
            : "memory"
        );
    }
    else {
        asm volatile (
            "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.tile.bulk_group.L2::cache_hint"
            " [%0, {%2, %3, %4, %5}], [%1], %6;"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy<policy>())
            : "memory"
        );
    }
}
template<cache_policy policy> __device__ static inline void vec_store_min_async_tma_internal(uint64_t tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) {
    asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");
    if constexpr (policy == cache_policy::NORMAL) {
        asm volatile (
            "cp.reduce.async.bulk.tensor.4d.global.shared::cta.min.tile.bulk_group"
            " [%0, {%2, %3, %4, %5}], [%1];"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b)
            : "memory"
        );
    }
    else {
        asm volatile (
            "cp.reduce.async.bulk.tensor.4d.global.shared::cta.min.tile.bulk_group.L2::cache_hint"
            " [%0, {%2, %3, %4, %5}], [%1], %6;"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy<policy>())
            : "memory"
        );
    }
}
template<cache_policy policy> __device__ static inline void vec_store_max_async_tma_internal(uint64_t tma_ptr, uint32_t src_i_ptr, coord<> tma_coord) {
    asm volatile("fence.proxy.async.shared::cta;\n" ::: "memory");
    if constexpr (policy == cache_policy::NORMAL) {
        asm volatile (
            "cp.reduce.async.bulk.tensor.4d.global.shared::cta.max.tile.bulk_group"
            " [%0, {%2, %3, %4, %5}], [%1];"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b)
            : "memory"
        );
    }
    else {
        asm volatile (
            "cp.reduce.async.bulk.tensor.4d.global.shared::cta.max.tile.bulk_group.L2::cache_hint"
            " [%0, {%2, %3, %4, %5}], [%1], %6;"
            :
            : "l"(tma_ptr), "r"(src_i_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy<policy>())
            : "memory"
        );
    }
}
template<cache_policy policy> __device__ static inline void vec_load_async_tma_internal(uint64_t tma_ptr, uint32_t dst_i_ptr, uint32_t mbar_ptr, coord<> tma_coord) {
    if constexpr (policy == cache_policy::NORMAL) {
        asm volatile (
            "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes"
            " [%0], [%1, {%3, %4, %5, %6}], [%2];"
            :
            : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b)
            : "memory"
        );
    }
    else {
        asm volatile (
            "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.L2::cache_hint"
            " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
            :
            : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr), "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "l"(make_cache_policy<policy>())
            : "memory"
        );
    }
}
namespace cluster {
template<cache_policy policy> __device__ static inline void vec_load_async_tma_internal(uint64_t tma_ptr, uint32_t dst_i_ptr, uint32_t mbar_ptr, coord<> tma_coord, uint16_t cluster_mask, int dst_mbar_cta=-1) {
#ifdef KITTENS_BLACKWELL
    if(dst_mbar_cta != -1) {
        uint32_t neighbor_mbar_ptr;
        asm volatile (
            "mapa.shared::cluster.u32 %0, %1, %2;\n"
            : "=r"(neighbor_mbar_ptr)
            : "r"(mbar_ptr), "r"(dst_mbar_cta)
        );
        if constexpr (policy == cache_policy::NORMAL) {
            asm volatile (
                "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster"
                " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
                :
                : "r"(dst_i_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr),
                  "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask)
                : "memory"
            );
        }
        else {
            asm volatile (
                "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.cta_group::2.multicast::cluster.L2::cache_hint"
                " [%0], [%1, {%3, %4, %5, %6}], [%2], %7, %8;"
                :
                : "r"(dst_i_ptr), "l"(tma_ptr), "r"(neighbor_mbar_ptr),
                  "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask), "l"(make_cache_policy<policy>())
                : "memory"
            );
        }
    } else
#endif
    if constexpr (policy == cache_policy::NORMAL) {
        asm volatile (
            "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster"
            " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;"
            :
            : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr),
              "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask)
            : "memory"
        );
    }
    else {
        asm volatile (
            "cp.async.bulk.tensor.4d.shared::cluster.global.tile.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint"
            " [%0], [%1, {%3, %4, %5, %6}], [%2], %7, %8;"
            :
            : "r"(dst_i_ptr), "l"(tma_ptr), "r"(mbar_ptr),
              "r"(tma_coord.c), "r"(tma_coord.r), "r"(tma_coord.d), "r"(tma_coord.b), "h"(cluster_mask), "l"(make_cache_policy<policy>())
            : "memory"
        );
    }
}
} // namespace cluster

} // namespace tma
} // namespace detail
namespace tma {

template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void prefetch(SV &dst, const GL &src, const COORD &idx) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<SV, -1>());
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        ::kittens::detail::tma::vec_prefetch_tma_internal<policy>(tma_ptr, tma_coord);
    }
}
__KITTENS_TMA_DEFINE_DEFAULT_LOAD_CACHE_VEC__(prefetch)
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_async(const GL &dst, const SV &src, const COORD &idx) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_async)
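// --- Illustrative usage (a sketch, not part of the original source): one thread
// stores a shared vector to global memory and later drains the bulk group. `sv` and
// `gl` are a hypothetical shared vector and global layout; a store_async_wait-style
// helper is assumed from the wider library's TMA utilities.
//
//   if (threadIdx.x == 0) {
//       store_async(gl, sv, {0}); // enqueue per-chunk bulk stores + commit the group
//   }
//   // later, before reusing sv's memory: wait on the committed bulk group.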
template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_async(const PGL &dst, const SV &src, const COORD &idx) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_async)

template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_add_async(const GL &dst, const SV &src, const COORD &idx) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_add_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_add_async)

template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_add_async(const PGL &dst, const SV &src, const COORD &idx) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_add_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_add_async)

template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_min_async(const GL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_min_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_min_async)

template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_min_async(const PGL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_min_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_min_async)

template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_max_async(const GL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_max_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_DEFAULT_STORE_CACHE_VEC__(store_max_async)

template<cache_policy policy, ducks::sv::all SV, ducks::pgl::all PGL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void store_max_async(const PGL &dst, const SV &src, const COORD &idx) {
    static_assert(!std::is_same_v<typename SV::dtype, float>, "TMA does not support async min/max reductions for fp32 types.");
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(dst.template get_tma<SV, -1>());
    uint32_t src_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&src));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t src_i_ptr = src_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_store_max_async_tma_internal<policy>(tma_ptr, src_i_ptr, tma_coord);
    }
    ::kittens::tma::store_commit_group();
}
__KITTENS_TMA_DEFINE_PGL_DEFAULT_STORE_CACHE_VEC__(store_max_async)
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<SV, -1>());
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::vec_load_async_tma_internal<policy>(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord);
    }
}
__KITTENS_TMA_DEFINE_SEMAPHORE_CACHE_VEC__(load_async)
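// --- Illustrative usage (a sketch, not part of the original source): the usual
// expect / load / wait pattern for a TMA vector load. `sv`, `gl`, and `bar` are
// hypothetical, and an expect_bytes-style helper is assumed from the surrounding
// library's semaphore utilities; the phase bit alternates 0/1 across iterations.
//
//   if (threadIdx.x == 0) {
//       expect_bytes(bar, size_bytes<decltype(sv)>); // register the transaction bytes
//       load_async(sv, gl, {0}, bar);                // issue the bulk tensor copy
//   }
//   kittens::wait(bar, 0); // all consumers block until the transaction completes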
namespace cluster {
template<cache_policy policy, ducks::sv::all SV, ducks::gl::all GL, ducks::coord::vec COORD=coord<SV>>
__device__ static inline void load_async(SV &dst, const GL &src, const COORD &idx, semaphore& bar, uint16_t cluster_mask, int dst_mbar_cta=-1) {
    coord<> unit_coord = idx.template unit_coord<-1, 3>();
    uint64_t tma_ptr = reinterpret_cast<uint64_t>(src.template get_tma<SV, -1>());
    uint32_t mbar_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&bar));
    uint32_t dst_ptr = static_cast<uint32_t>(__cvta_generic_to_shared(&dst));
    for(int i = 0; i < ::kittens::detail::tma::sv_tma_dim2<SV>; i++) {
        coord<> tma_coord = unit_coord;
        tma_coord.c += i * ::kittens::detail::tma::sv_tma_dim1<SV>;
        uint32_t dst_i_ptr = dst_ptr + i*::kittens::detail::tma::sv_tma_dim1<SV>*sizeof(typename SV::dtype);
        ::kittens::detail::tma::cluster::vec_load_async_tma_internal<policy>(tma_ptr, dst_i_ptr, mbar_ptr, tma_coord, cluster_mask, dst_mbar_cta);
    }
}
__KITTENS_TMA_DEFINE_CLUSTER_SEMAPHORE_CACHE_VEC__(load_async)
} // namespace cluster
} // namespace tma
} // namespace kittens
10
extra/thunder/cuda/include/ops/thread/memory/vec/vec.cuh
Normal file
@@ -0,0 +1,10 @@
/**
 * @file
 * @brief An aggregate header of warp memory operations on vectors, where a single warp loads or stores data on its own.
 */

#pragma once

#ifdef KITTENS_HOPPER
#include "tma.cuh"
#endif
8
extra/thunder/cuda/include/ops/thread/mma/mma.cuh
Normal file
@@ -0,0 +1,8 @@
/**
 * @file
 * @brief An aggregate header for warp operations on data stored in tensor memory.
 */

#pragma once

#include "tensor/tensor.cuh"
523
extra/thunder/cuda/include/ops/thread/mma/tensor/tensor.cuh
Normal file
@@ -0,0 +1,523 @@
/**
 * @file
 * @brief Matrix multiply-accumulate operations for tiles stored in tensor memory.
 */

#pragma once

#include "../../../../common/common.cuh"
#include "../../../../types/types.cuh"

namespace kittens {
namespace detail {
namespace tcgen05 {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#instruction-descriptor
template<typename D, typename AB, int M, int N, bool trans_a, bool trans_b, bool neg=false>
__device__ static inline uint32_t instruction_descriptor() {
    uint32_t desc = 0;
    if constexpr (sizeof(AB) == 2) { // kind::f16
        // either accumulate to float, or the input is half and the output is half
        static_assert(std::is_same_v<D, float> || std::is_same_v<AB, half>);
        desc |= 0b00 << 0; // sparsity bits unneeded
        desc |= 0b0 << 2;  // dense
        desc |= 0b0 << 3;  // no saturate on fp types
        if constexpr (std::is_same_v<D, float>) {
            desc |= 0b01 << 4; // D matrix is FP32
        }
        else {
            desc |= 0b00 << 4; // D matrix is FP16
        }
        desc |= 0b0 << 6; // reserved
        if constexpr (std::is_same_v<AB, half>) {
            desc |= 0b000 << 7;  // 16-bit A input type as FP16
            desc |= 0b000 << 10; // 16-bit B input type as FP16
        } else if constexpr (std::is_same_v<AB, bf16>) {
            desc |= 0b001 << 7;  // 16-bit A input type as BF16
            desc |= 0b001 << 10; // 16-bit B input type as BF16
        } else if constexpr (std::is_same_v<AB, fp8e4m3>) {
            desc |= 0b000 << 7;  // 8-bit A input type as FP8 e4m3
            desc |= 0b000 << 10; // 8-bit B input type as FP8 e4m3
        } else if constexpr (std::is_same_v<AB, fp8e5m2>) {
            desc |= 0b001 << 7;  // 8-bit A input type as FP8 e5m2
            desc |= 0b001 << 10; // 8-bit B input type as FP8 e5m2
        }
        /* fp6 and fp4
        else if constexpr (std::is_same_v<AB, fp6e2m3>) {
            desc |= 0b011 << 7;  // 6-bit A input type as FP6 e2m3
            desc |= 0b011 << 10; // 6-bit B input type as FP6 e2m3
        }
        else if constexpr (std::is_same_v<AB, fp6e3m2>) {
            desc |= 0b100 << 7;  // 6-bit A input type as FP6 e3m2
            desc |= 0b100 << 10; // 6-bit B input type as FP6 e3m2
        }
        else if constexpr (std::is_same_v<AB, fp4e3m1>) {
            desc |= 0b101 << 7;  // 4-bit A input type as FP4 e3m1
            desc |= 0b101 << 10; // 4-bit B input type as FP4 e3m1
        }
        */
        if constexpr (neg) {
            desc |= 0b1 << 13; // Do negate A matrix
        }
        else {
            desc |= 0b0 << 13; // Don't negate A matrix
        }
        desc |= 0b0 << 14; // Don't negate B matrix (in all cases)
        if constexpr (trans_a) {
            desc |= 0b1 << 15; // Transpose A matrix
        }
        else {
            desc |= 0b0 << 15; // Don't transpose A matrix
        }
        if constexpr (trans_b) {
            desc |= 0b1 << 16; // Transpose B matrix
        }
        else {
            desc |= 0b0 << 16; // Don't transpose B matrix
        }
        desc |= (N >> 3) << 17; // B matrix has dimension N, encoded
        desc |= 0b0 << 23;      // reserved
        desc |= (M >> 4) << 24; // A matrix has dimension M, encoded
        desc |= 0b0 << 29;      // reserved
        desc |= 0b00 << 30;     // no shift for B-matrix reuse
    } else if constexpr (sizeof(AB) == 1) { // kind::f8f6f4
        static_assert(std::is_same_v<D, float> || std::is_same_v<D, half>); // FP8/6/4 has to accumulate to float or half
        desc |= 0b00 << 0; // sparsity bits unneeded
        desc |= 0b0 << 2;  // dense
        desc |= 0b0 << 3;  // no saturate on fp types
        if constexpr (std::is_same_v<D, float>) {
            desc |= 0b01 << 4; // D matrix is FP32
        }
        else {
            desc |= 0b00 << 4; // D matrix is FP16
        }
        desc |= 0b0 << 6; // reserved
        if constexpr (std::is_same_v<AB, fp8e4m3>) {
            desc |= 0b000 << 7;  // 8-bit A input type as FP8 e4m3
            desc |= 0b000 << 10; // 8-bit B input type as FP8 e4m3
        } else if constexpr (std::is_same_v<AB, fp8e5m2>) {
            desc |= 0b001 << 7;  // 8-bit A input type as FP8 e5m2
            desc |= 0b001 << 10; // 8-bit B input type as FP8 e5m2
        }
        /* fp6 and fp4
        else if constexpr (std::is_same_v<AB, fp6e2m3>) {
            desc |= 0b011 << 7;  // 6-bit A input type as FP6 e2m3
            desc |= 0b011 << 10; // 6-bit B input type as FP6 e2m3
        }
        else if constexpr (std::is_same_v<AB, fp6e3m2>) {
            desc |= 0b100 << 7;  // 6-bit A input type as FP6 e3m2
            desc |= 0b100 << 10; // 6-bit B input type as FP6 e3m2
        }
        else if constexpr (std::is_same_v<AB, fp4e3m1>) {
            desc |= 0b101 << 7;  // 4-bit A input type as FP4 e3m1
            desc |= 0b101 << 10; // 4-bit B input type as FP4 e3m1
        }
        */
        if constexpr (neg) {
            desc |= 0b1 << 13; // Do negate A matrix
        }
        else {
            desc |= 0b0 << 13; // Don't negate A matrix
        }
        desc |= 0b0 << 14; // Don't negate B matrix (in all cases)
        if constexpr (trans_a) {
            desc |= 0b1 << 15; // Transpose A matrix
        }
        else {
            desc |= 0b0 << 15; // Don't transpose A matrix
        }
        if constexpr (trans_b) {
            desc |= 0b1 << 16; // Transpose B matrix
        }
        else {
            desc |= 0b0 << 16; // Don't transpose B matrix
        }
        desc |= (N >> 3) << 17; // B matrix has dimension N, encoded
        desc |= 0b0 << 23;      // reserved
        desc |= (M >> 4) << 24; // A matrix has dimension M, encoded
        desc |= 0b0 << 29;      // reserved
        desc |= 0b00 << 30;     // no shift for B-matrix reuse
    }
    else {
        static_assert(sizeof(AB) == 999, "Invalid AB type size; not implemented yet.");
    }
    return desc;
}
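// --- Illustrative expansion (a sketch, not part of the original source): for
// D=float, AB=bf16, M=128, N=256, no transposes and no negation, the fields above
// pack as
//
//   bits  4-5 : 0b01      (FP32 accumulator)   -> 0x00000010
//   bits  7-9 : 0b001     (A is BF16)          -> 0x00000080
//   bits 10-12: 0b001     (B is BF16)          -> 0x00000400
//   bits 17-22: 256 >> 3  (N encoded as 32)    -> 0x00400000
//   bits 24-28: 128 >> 4  (M encoded as 8)     -> 0x08000000
//
//   instruction_descriptor<float, bf16, 128, 256, false, false>() == 0x08400490
//   // hypothetical check; the value is just the OR of the fields listed above.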
template<typename T_AB, int acc, int ncta=1>
__device__ static inline void tt_st(uint32_t d_tt_addr, uint32_t a_tt_addr, uint64_t b_desc, uint32_t idesc) {
    if constexpr (std::is_same_v<T_AB, fp8e4m3> || std::is_same_v<T_AB, fp8e5m2>) {
        // TODO(danfu): is there a better way to do this with string manipulation that the compiler likes?
        if constexpr (ncta == 1) {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], [%1], %2, %3, p;}\n"
                :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
        else {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], [%1], %2, %3, p;}\n"
                :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
    } else {
        if constexpr (ncta == 1) {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::1.kind::f16 [%0], [%1], %2, %3, p;}\n"
                :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
        else {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::2.kind::f16 [%0], [%1], %2, %3, p;}\n"
                :: "r"(d_tt_addr), "r"(a_tt_addr), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
    }
}
template<typename T_AB, int acc, int ncta=1>
__device__ static inline void st_st(uint32_t d_tt_addr, uint64_t a_desc, uint64_t b_desc, uint32_t idesc) {
    if constexpr (std::is_same_v<T_AB, fp8e4m3> || std::is_same_v<T_AB, fp8e5m2>) {
        // TODO(danfu): is there a better way to do this with string manipulation that the compiler likes?
        if constexpr (ncta == 1) {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::1.kind::f8f6f4 [%0], %1, %2, %3, p;}\n"
                :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
        else {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::2.kind::f8f6f4 [%0], %1, %2, %3, p;}\n"
                :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
    } else {
        if constexpr (ncta == 1) {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, p;}\n"
                :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
        else {
            asm volatile(
                "{.reg .pred p;\n"
                "setp.eq.u32 p, 1, %4;\n"
                "tcgen05.mma.cta_group::2.kind::f16 [%0], %1, %2, %3, p;}\n"
                :: "r"(d_tt_addr), "l"(a_desc), "l"(b_desc), "r"(idesc), "n"(acc)
            );
        }
    }
}
template<int ncta=1> __device__ static inline void commit(kittens::semaphore &sem) {
    if constexpr (ncta == 1) {
        asm volatile(
            "tcgen05.commit.cta_group::1.mbarrier::arrive::one.b64 [%0];\n"
            :: "l"(&sem)
        );
    }
    else {
        asm volatile(
            "tcgen05.commit.cta_group::2.mbarrier::arrive::one.shared::cluster.multicast::cluster.b64 [%0], %1;\n"
            :: "l"(&sem), "h"((uint16_t)(0b11))
        );
    }
}
} // namespace tcgen05
} // namespace detail

template<typename T_AB> constexpr int reduction_dimension = sizeof(T_AB) == 2 ? 16 : sizeof(T_AB) == 4 ? 8 : 32; // haven't added fp4 yet.
// RS matmul equivalent
template<int trans_a, int n_trans_b, ducks::tt::all D, ducks::tt::all A, ducks::st_descriptor::input B, int acc=1, int ncta=1>
__device__ static inline void mma(D &d, const A &a, const B &b) {
    constexpr int trans_b = 1 - n_trans_b;

    // Validate shapes and types, then issue the MMA over K-dimension chunks.
    constexpr int M = (trans_a ? A::cols : A::rows) * ncta;
    static_assert(M == D::rows*ncta && ((ncta == 1 && (M == 64 || M == 128)) || (ncta == 2 && (M == 128 || M == 256)))); // output register is correctly sized

    constexpr int N = (trans_b ? B::cols : B::rows) * ncta;
    static_assert(N == D::cols); // output register is correctly sized

    constexpr int K = trans_a ? A::rows : A::cols;
    static_assert((trans_b ? B::rows : B::cols) == K); // K dimension must match
    static_assert(std::is_same_v<typename A::T, typename B::T>); // A and B must match type.

    // Usings
    using T_AB = A::T; static_assert(std::is_same_v<T_AB, typename B::T>);
    using T_D = D::T;

    constexpr int red_dim = reduction_dimension<T_AB>;
    static_assert(K%red_dim == 0, "K dimension must be divisible by red_dim.");

    static_assert(
        (std::is_same_v<T_D, half> && !std::is_same_v<T_AB, half>) ||
        (std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e4m3>) ||
        (std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e5m2>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, bf16>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, half>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e4m3>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e5m2>),
        "Currently unsupported type combination for matrix multiply."
    );
    uint32_t idesc = detail::tcgen05::instruction_descriptor<T_D, T_AB, M, N, trans_a, trans_b, false>();
    kittens::st_descriptor<ducks::st_descriptor::detail::get_st<B>, trans_b> b_desc(b);

    asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");

    detail::tcgen05::template tt_st<T_AB, acc, ncta>(
        d.addr,
        a.template chunk_addr<trans_a>(0),
        b_desc.chunk_descriptor(0),
        idesc
    );
    #pragma unroll
    for(int i = 1; i < K/red_dim; i++) {
        detail::tcgen05::template tt_st<T_AB, 1, ncta>(
            d.addr,
            a.template chunk_addr<trans_a>(i),
            b_desc.chunk_descriptor(i),
            idesc
        );
    }
}
template<int trans_a, int n_trans_b, ducks::tt::all D, ducks::tt::all A, ducks::st_descriptor::input B, int acc=1, int ncta=1>
__device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem) {
    mma<trans_a, n_trans_b, D, A, B, acc, ncta>(d, a, b);
    detail::tcgen05::commit<ncta>(sem);
}
// SS matmul equivalent
template<int trans_a, int n_trans_b, ducks::tt::all D, ducks::st_descriptor::input A, ducks::st_descriptor::input B, int acc=1, int ncta=1>
__device__ static inline void mma(D &d, const A &a, const B &b) {
    constexpr int trans_b = 1 - n_trans_b;

    // Validate shapes and types, then issue the MMA over K-dimension chunks.
    constexpr int M = (trans_a ? A::cols : A::rows) * ncta;
    static_assert(M == D::rows*ncta && ((ncta == 1 && (M == 64 || M == 128)) || (ncta == 2 && (M == 128 || M == 256)))); // output register is correctly sized

    constexpr int N = (trans_b ? B::cols : B::rows) * ncta;
    static_assert(N == D::cols); // output register is correctly sized

    constexpr int K = trans_a ? A::rows : A::cols;
    static_assert((trans_b ? B::rows : B::cols) == K); // K dimension must match
    static_assert(std::is_same_v<typename A::T, typename B::T>); // A and B must match type.

    // Usings
    using T_AB = A::T; static_assert(std::is_same_v<T_AB, typename B::T>);
    using T_D = D::T;

    constexpr int red_dim = reduction_dimension<T_AB>;
    static_assert(K%red_dim == 0, "K dimension must be divisible by red_dim.");

    static_assert(
        (std::is_same_v<T_D, half> && !std::is_same_v<T_AB, half>) ||
        (std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e4m3>) ||
        (std::is_same_v<T_D, half> && !std::is_same_v<T_AB, fp8e5m2>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, bf16>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, half>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e4m3>) ||
        (std::is_same_v<T_D, float> && !std::is_same_v<T_AB, fp8e5m2>),
        "Currently unsupported type combination for matrix multiply."
    );
    uint32_t idesc = detail::tcgen05::instruction_descriptor<T_D, T_AB, M, N, trans_a, trans_b, false>();
    kittens::st_descriptor<ducks::st_descriptor::detail::get_st<A>, trans_a> a_desc(a);
    kittens::st_descriptor<ducks::st_descriptor::detail::get_st<B>, trans_b> b_desc(b);

    asm volatile ("fence.proxy.async.shared::cta;\n" ::: "memory");

    detail::tcgen05::template st_st<T_AB, acc, ncta>(
        d.addr,
        a_desc.chunk_descriptor(0),
        b_desc.chunk_descriptor(0),
        idesc
    );
    #pragma unroll
    for(int i = 1; i < K/red_dim; i++) {
        detail::tcgen05::template st_st<T_AB, 1, ncta>(
            d.addr,
            a_desc.chunk_descriptor(i),
            b_desc.chunk_descriptor(i),
            idesc
        );
    }
}
template<int trans_a, int n_trans_b, ducks::tt::all D, ducks::st_descriptor::input A, ducks::st_descriptor::input B, int acc=1, int ncta=1>
__device__ static inline void mma(D &d, const A &a, const B &b, semaphore &sem) {
    mma<trans_a, n_trans_b, D, A, B, acc, ncta>(d, a, b);
    detail::tcgen05::commit<ncta>(sem);
}
// Accumulator / numcta wrappers
|
||||
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B, int acc=1>
|
||||
__device__ static inline void mma2(D &d, const A &a, const B &b, semaphore &sem) {
|
||||
mma<trans_a, trans_b, D, A, B, acc, 2>(d, a, b, sem);
|
||||
}
|
||||
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B, int acc=1>
|
||||
__device__ static inline void mma2(D &d, const A &a, const B &b) {
|
||||
mma<trans_a, trans_b, D, A, B, acc, 2>(d, a, b);
|
||||
}
|
||||
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
|
||||
__device__ static inline void mm(D &d, const A &a, const B &b, semaphore &sem) {
|
||||
mma<trans_a, trans_b, D, A, B, 0>(d, a, b, sem);
|
||||
}
|
||||
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
|
||||
__device__ static inline void mm(D &d, const A &a, const B &b) {
|
||||
mma<trans_a, trans_b, D, A, B, 0>(d, a, b);
|
||||
}
|
||||
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
|
||||
__device__ static inline void mm2(D &d, const A &a, const B &b, semaphore &sem) {
|
||||
mma2<trans_a, trans_b, D, A, B, 0>(d, a, b, sem);
|
||||
}
|
||||
template<int trans_a, int trans_b, ducks::tt::all D, typename A, ducks::st_descriptor::input B>
|
||||
__device__ static inline void mm2(D &d, const A &a, const B &b) {
|
||||
mma2<trans_a, trans_b, D, A, B, 0>(d, a, b);
|
||||
}
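
// Convention summary (from the definitions above): `mma*` accumulates into d
// (acc=1), `mm*` overwrites it (acc=0), and the `2` suffix selects the 2-CTA
// variants (ncta=2). Within one call, only the first K-chunk honors `acc`;
// the remaining chunks always accumulate (acc=1) into the same tile.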

// Transpose wrappers
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AB(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AB(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::T, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_ABt(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::T, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::T, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_ABt(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::T, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtB(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::N, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtB(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::N, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::T, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma_AtBt(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::T, D, A, B, 1>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::T, D, A, B, 1>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mma2_AtBt(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::T, D, A, B, 1>(d, a, b);
}

template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AB(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AB(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::N, transpose::T, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_ABt(D &d, const A &a, const B &b) {
    mma<transpose::N, transpose::T, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_ABt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::N, transpose::T, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_ABt(D &d, const A &a, const B &b) {
    mma2<transpose::N, transpose::T, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtB(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtB(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::N, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtB(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::N, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma<transpose::T, transpose::T, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm_AtBt(D &d, const A &a, const B &b) {
    mma<transpose::T, transpose::T, D, A, B, 0>(d, a, b);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b, semaphore &sem) {
    mma2<transpose::T, transpose::T, D, A, B, 0>(d, a, b, sem);
}
template<ducks::tt::all D, typename A, ducks::st_descriptor::input B>
__device__ static inline void mm2_AtBt(D &d, const A &a, const B &b) {
    mma2<transpose::T, transpose::T, D, A, B, 0>(d, a, b);
}
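
/* Example (illustrative only; tile shapes, names and the surrounding kernel are
   hypothetical -- assumes a Blackwell kernel where `d` is a tensor-memory
   accumulator tile, `a`/`b` are shared tiles with matching dimensions, and
   `sem` is an initialized semaphore):

       mm_ABt(d, a, b, sem);    // d  = a @ b^T, overwriting d
       wait(sem, 0);            // block until the committed MMA lands
       mma_ABt(d, a, b, sem);   // d += a @ b^T, accumulating
       wait(sem, 1);
*/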


} // namespace kittens

13
extra/thunder/cuda/include/ops/thread/thread.cuh
Normal file
@@ -0,0 +1,13 @@
/**
 * @file
 * @brief An aggregate header of all thread operations defined by ThunderKittens
 */

#pragma once

// no namespace wrapper needed here

#include "memory/memory.cuh"
#ifdef KITTENS_BLACKWELL
#include "mma/mma.cuh"
#endif
551
extra/thunder/cuda/include/pyutils/broker.cuh
Normal file
@@ -0,0 +1,551 @@
/**
 * @file broker.cuh
 * @brief Utility for multiprocess data exchange and synchronization.
 *
 * This file provides the KittensBroker class, which enables efficient inter-process
 * communication and synchronization using POSIX shared memory, semaphores, and sockets.
 * The broker is designed to work in multi-GPU environments where processes need to
 * exchange data and synchronize execution across different local ranks.
 *
 * @note This implementation relies on POSIX IPC mechanisms and is intended for
 * Unix-like systems. All processes must be running on the same node.
 */

#pragma once

#include <cerrno>
#include <cstdint>
#include <cstring>
#include <fcntl.h>
#include <semaphore.h>
#include <stdexcept>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <sys/uio.h>
#include <unistd.h>
#include <vector>

#if defined(WIN32) || defined(_WIN32) || defined(WIN64) || defined(_WIN64)
#error "KittensBroker is not supported on Windows"
#endif

namespace kittens {

namespace detail {
namespace broker {

static constexpr int MAX_LOCAL_WORLD_SIZE = 72;
static constexpr int VAULT_SIZE_PER_RANK = 64; // sizeof(cudaIpcMemHandle_t)

struct KittensVault {
    static constexpr int INIT_CODE = 0x43617473; // "Cats"
    int init;
    int barrier;
    int sense;
    uint8_t data[MAX_LOCAL_WORLD_SIZE * VAULT_SIZE_PER_RANK];
};

static constexpr int SHM_SIZE = (sizeof(KittensVault) + 4095) / 4096 * 4096;

__host__ inline static void init_sync(
    int local_rank,
    volatile KittensVault *vault
) {
    if (local_rank == 0) {
        // initialize barrier resources
        vault->barrier = 0;
        vault->sense = 0;
        __sync_synchronize(); // make previous writes visible
        vault->init = KittensVault::INIT_CODE;
    } else {
        while (vault->init != KittensVault::INIT_CODE) usleep(1);
        __sync_synchronize(); // see leader's previous writes
    }
}

__host__ inline static void sync(
    int local_world_size,
    volatile KittensVault *vault
) {
    if (vault->init != KittensVault::INIT_CODE)
        throw std::runtime_error("KittensBroker: KittensVault not initialized");

    // Phase 1
    int arrived = __sync_add_and_fetch(&vault->barrier, 1);
    if (arrived == local_world_size) vault->sense = 1;
    while (!vault->sense) usleep(1);

    // Make previous writes visible
    __sync_synchronize();

    // Phase 2
    arrived = __sync_add_and_fetch(&vault->barrier, -1);
    if (arrived == 0) vault->sense = 0;
    while (vault->sense) usleep(1);
}
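
// Note: the two phases above form a sense-reversing barrier. Phase 2 keeps a
// fast process from re-entering sync() and incrementing the counter again
// before the slower processes have observed the phase-1 release.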

__host__ inline void *create_shm(const char *key, size_t size) {
    int shm_fd;
    shm_fd = shm_open(key, O_RDWR | O_CREAT | O_EXCL | O_CLOEXEC, 0600);

    if (shm_fd < 0) {
        if (errno == EEXIST)
            throw std::runtime_error("KittensBroker: Named shared memory already exists");
        throw std::runtime_error("KittensBroker: Failed to create shared memory");
    }

    if (ftruncate(shm_fd, size) != 0) {
        shm_unlink(key);
        close(shm_fd);
        throw std::runtime_error("KittensBroker: Failed to truncate shared memory");
    }

    void *addr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
    close(shm_fd);
    if (addr == MAP_FAILED) {
        shm_unlink(key);
        throw std::runtime_error("KittensBroker: Failed to map to shared memory");
    }

    return addr;
}

__host__ inline void *open_shm(const char *key, size_t size) {
    int shm_fd;
    while (true) {
        shm_fd = shm_open(key, O_RDWR | O_CLOEXEC, 0);
        if (shm_fd >= 0)
            break;
        if (errno != ENOENT)
            throw std::runtime_error("KittensBroker: Failed to open shared memory");
        usleep(1);
    }

    struct stat shm_st;
    do {
        if (fstat(shm_fd, &shm_st) != 0) {
            shm_unlink(key);
            close(shm_fd);
            throw std::runtime_error("KittensBroker: Failed to open shared memory stats");
        }
        usleep(1);
    } while ((size_t)shm_st.st_size < size);

    void *addr = mmap(0, size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, 0);
    close(shm_fd);
    if (addr == MAP_FAILED) {
        shm_unlink(key);
        throw std::runtime_error("KittensBroker: Failed to map to shared memory");
    }

    return addr;
}

__host__ inline void unlink_shm(const char *key) {
    shm_unlink(key);
}

__host__ inline void unmap_shm(void *addr, size_t size) {
    munmap(addr, size);
}

__host__ inline int create_socket(const char *key, int local_rank) {
    int sock_fd;
    if ((sock_fd = socket(AF_UNIX, SOCK_DGRAM | SOCK_CLOEXEC, 0)) < 0)
        throw std::runtime_error("KittensBroker: Socket creation error");

    struct sockaddr_un addr;
    memset(&addr, 0, sizeof(addr));
    addr.sun_family = AF_UNIX;

    char unique_key[64];
    int n = snprintf(unique_key, sizeof(unique_key), "%s%d", key, local_rank);
    if (n < 0 || n >= (int)sizeof(unique_key)) {
        close(sock_fd);
        throw std::runtime_error("KittensBroker: Socket name too long");
    }

    size_t len = strnlen(unique_key, sizeof(addr.sun_path));
    if (len > (sizeof(addr.sun_path) - 1)) {
        close(sock_fd);
        throw std::runtime_error("KittensBroker: Socket name too long");
    }
    strcpy(addr.sun_path, unique_key);
    unlink(unique_key);

    if (bind(sock_fd, (struct sockaddr *)&addr, SUN_LEN(&addr)) < 0) {
        close(sock_fd);
        throw std::runtime_error("KittensBroker: Failed to bind socket");
    }

    return sock_fd;
}

__host__ inline void send_fd(
    int sock_fd,
    int data_fd,
    const char *dst_key,
    int dst_local_rank,
    int src_local_rank
) {
    union {
        struct cmsghdr cm;
        char* control;
    } control_un;

    size_t sizeof_control = CMSG_SPACE(sizeof(int));
    control_un.control = reinterpret_cast<char *>(malloc(sizeof_control));
    if (!control_un.control) {
        close(sock_fd);
        close(data_fd);
        throw std::runtime_error("KittensBroker: Failed to allocate a control buffer");
    }

    struct msghdr msg {};
    msg.msg_control = control_un.control;
    msg.msg_controllen = sizeof_control;

    struct cmsghdr *cmptr = CMSG_FIRSTHDR(&msg);
    cmptr->cmsg_len = CMSG_LEN(sizeof(int));
    cmptr->cmsg_level = SOL_SOCKET;
    cmptr->cmsg_type = SCM_RIGHTS;
    memmove(CMSG_DATA(cmptr), &data_fd, sizeof(data_fd));

    struct sockaddr_un addr {};
    addr.sun_family = AF_UNIX;
    char dst_unique_key[64];
    int n = snprintf(dst_unique_key, sizeof(dst_unique_key), "%s%d", dst_key, dst_local_rank);
    if (n < 0 || n >= (int)sizeof(dst_unique_key)) {
        free(control_un.control);
        close(sock_fd);
        close(data_fd);
        throw std::runtime_error("KittensBroker: dst path too long");
    }
    strcpy(addr.sun_path, dst_unique_key);
    msg.msg_name = (void *)&addr;
    msg.msg_namelen = sizeof(struct sockaddr_un);

    int payload = src_local_rank;
    struct iovec iov[1];
    iov[0].iov_base = &payload;
    iov[0].iov_len = sizeof(payload);
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;

    while (true) {
        ssize_t sent = sendmsg(sock_fd, &msg, 0);
        if (sent <= 0) {
            if (errno == EINTR) continue;
            close(sock_fd);
            close(data_fd);
            free(control_un.control);
            throw std::runtime_error("KittensBroker: Failed to send FD over socket");
        }
        break;
    }

    free(control_un.control);
}

__host__ inline void recv_fd(int sock_fd, int *data_fd, int *src_local_rank) {
    union {
        struct cmsghdr cm;
        char* control;
    } control_un;

    size_t sizeof_control = CMSG_SPACE(sizeof(int));
    control_un.control = reinterpret_cast<char *>(malloc(sizeof_control));
    if (!control_un.control) {
        close(sock_fd);
        throw std::runtime_error("KittensBroker: Failed to allocate a control buffer");
    }

    struct msghdr msg {};
    msg.msg_control = control_un.control;
    msg.msg_controllen = sizeof_control;

    int payload = -1;
    struct iovec iov[1];
    iov[0].iov_base = &payload;
    iov[0].iov_len = sizeof(payload);
    msg.msg_iov = iov;
    msg.msg_iovlen = 1;

    while (true) {
        ssize_t received = recvmsg(sock_fd, &msg, 0);
        if (received < 0 && errno == EINTR) {
            msg.msg_controllen = sizeof_control;
            msg.msg_iovlen = 1;
            continue;
        }
        if (received < static_cast<ssize_t>(sizeof(*data_fd))) {
            free(control_un.control);
            close(sock_fd);
            throw std::runtime_error("KittensBroker: Failed to receive data over socket");
        }
        break;
    }

    if (msg.msg_flags & MSG_CTRUNC) {
        free(control_un.control);
        close(sock_fd);
        throw std::runtime_error("KittensBroker: Control data truncated");
    }

    struct cmsghdr *cmptr = CMSG_FIRSTHDR(&msg);
    if (!cmptr ||
        cmptr->cmsg_len != CMSG_LEN(sizeof(int)) ||
        cmptr->cmsg_level != SOL_SOCKET ||
        cmptr->cmsg_type != SCM_RIGHTS) {
        free(control_un.control);
        close(sock_fd);
        throw std::runtime_error("KittensBroker: Failed to receive data over socket");
    }

    memmove(data_fd, CMSG_DATA(cmptr), sizeof(*data_fd));
    free(control_un.control);
    *src_local_rank = payload;
}

__host__ inline void unlink_socket(const char *key, int local_rank) {
    char unique_key[64];
    int n = snprintf(unique_key, sizeof(unique_key), "%s%d", key, local_rank);
    if (n < 0 || n >= (int)sizeof(unique_key))
        throw std::runtime_error("KittensBroker: Socket name too long");
    unlink(unique_key);
}

__host__ inline void close_socket(int sock_fd) {
    close(sock_fd);
}

} // namespace broker
} // namespace detail

/**
@brief KittensBroker utility for multiprocess data exchange.

Note that the code relies on POSIX sockets/shared memory/semaphores for
inter-process communication and synchronization.

The main functions meant to be used by the user are:

    KittensBroker broker(local_rank, local_world_size);
    broker.exchange_data(dst, src, size); // exchange data between all processes
    broker.exchange_fds(dst, src_fd); // exchange file descriptors between all processes
    broker.broadcast_fd(dst, src_fd, src_rank); // broadcast file descriptor from src_rank to all processes
    broker.sync(); // wait until all processes reach here
*/
struct KittensBroker {
    // TODO: make unique per process group
    static inline constexpr const char *SHM_KEY_ = "/kittens_broker_shm";
    static inline constexpr const char *SOCK_KEY_ = "/tmp/kittens_broker.sock";

    int local_rank_;
    int local_world_size_;

    void *shm_raw_;
    volatile detail::broker::KittensVault *shm_;
    int sock_;

    __host__ inline KittensBroker(int local_rank, int local_world_size)
        : local_rank_(local_rank),
          local_world_size_(local_world_size),
          shm_raw_(nullptr),
          shm_(nullptr),
          sock_(-1) {
        if (local_rank_ < 0)
            throw std::runtime_error("KittensBroker: Local rank must be non-negative");
        if (local_rank_ >= local_world_size_)
            throw std::runtime_error("KittensBroker: Local rank is greater than local world size");
        if (local_world_size_ > detail::broker::MAX_LOCAL_WORLD_SIZE)
            throw std::runtime_error("KittensBroker: Local world size is greater than MAX_LOCAL_WORLD_SIZE");

        if (local_rank_ == 0) {
            shm_raw_ = detail::broker::create_shm(SHM_KEY_, sizeof(detail::broker::KittensVault));
            shm_ = reinterpret_cast<volatile detail::broker::KittensVault *>(shm_raw_);
            memset(shm_raw_, 0, sizeof(detail::broker::KittensVault));
        } else {
            shm_raw_ = detail::broker::open_shm(SHM_KEY_, sizeof(detail::broker::KittensVault));
            shm_ = reinterpret_cast<volatile detail::broker::KittensVault *>(shm_raw_);
        }
        detail::broker::init_sync(local_rank_, shm_);
        detail::broker::sync(local_world_size_, shm_);

        if (local_rank_ == 0)
            detail::broker::unlink_shm(SHM_KEY_);
        detail::broker::sync(local_world_size_, shm_);

        sock_ = detail::broker::create_socket(SOCK_KEY_, local_rank_);
        detail::broker::sync(local_world_size_, shm_);
    }

    KittensBroker(const KittensBroker&) = delete;
    KittensBroker& operator=(const KittensBroker&) = delete;

    __host__ inline KittensBroker(KittensBroker&& other) noexcept
        : local_rank_(other.local_rank_),
          local_world_size_(other.local_world_size_),
          shm_raw_(other.shm_raw_),
          shm_(other.shm_),
          sock_(other.sock_) {
        other.local_rank_ = -1;
        other.local_world_size_ = -1;
        other.shm_raw_ = nullptr;
        other.shm_ = nullptr;
        other.sock_ = -1;
    }

    __host__ inline void destroy() {
        if (shm_raw_) {
            detail::broker::unmap_shm(shm_raw_, sizeof(detail::broker::KittensVault));
            shm_raw_ = nullptr;
            shm_ = nullptr;
        }
        if (sock_ >= 0) {
            detail::broker::unlink_socket(SOCK_KEY_, local_rank_);
            detail::broker::close_socket(sock_);
            sock_ = -1;
        }
        local_rank_ = -1;
        local_world_size_ = -1;
    }

    __host__ inline KittensBroker& operator=(KittensBroker&& other) noexcept {
        if (this != &other) {
            destroy();
            local_rank_ = other.local_rank_;
            local_world_size_ = other.local_world_size_;
            shm_raw_ = other.shm_raw_;
            shm_ = other.shm_;
            sock_ = other.sock_;
            other.local_rank_ = -1;
            other.local_world_size_ = -1;
            other.shm_raw_ = nullptr;
            other.shm_ = nullptr;
            other.sock_ = -1;
        }
        return *this;
    }

    __host__ inline ~KittensBroker() {
        destroy();
    }

    __host__ inline void sync(int num_ranks = -1) {
        if (num_ranks == -1)
            num_ranks = local_world_size_;
        else if (num_ranks < 0 || num_ranks > local_world_size_)
            throw std::runtime_error("KittensBroker: Invalid number of ranks");

        detail::broker::sync(num_ranks, shm_);
    }

    __host__ inline void exchange_data(void *dst_, const void *src_, size_t size) {
        if (size > detail::broker::VAULT_SIZE_PER_RANK)
            throw std::runtime_error("KittensBroker: Size is greater than VAULT_SIZE_PER_RANK");

        uint8_t *dst = reinterpret_cast<uint8_t *>(dst_);
        const uint8_t *src = reinterpret_cast<const uint8_t *>(src_);

        // Exchange data
        sync(); // ensure all processes enter together
        memcpy(const_cast<uint8_t *>(shm_->data) + local_rank_ * detail::broker::VAULT_SIZE_PER_RANK, src, size);
        sync(); // ensure all processes exit together

        // Pack and copy back to destination
        for (int i = 0; i < local_world_size_; i++)
            memcpy(dst + i * size, const_cast<uint8_t *>(shm_->data) + i * detail::broker::VAULT_SIZE_PER_RANK, size);
    }
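
    /* Example (illustrative only; `local_buffer` is hypothetical -- assumes
       each rank contributes one pointer-sized payload, which fits within
       VAULT_SIZE_PER_RANK):

           void *mine = local_buffer;                  // this rank's payload
           std::vector<void *> all(local_world_size_);
           broker.exchange_data(all.data(), &mine, sizeof(void *));
           // all[i] now holds rank i's pointer value
    */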

    __host__ inline void exchange_fds(int *dst, const int data_fd) {
        if (dst == nullptr)
            throw std::runtime_error("KittensBroker: dst is null");
        if (data_fd < 0)
            throw std::runtime_error("KittensBroker: source fd is negative");

        // Initialize dst buffer
        for (int i = 0; i < local_world_size_; ++i)
            dst[i] = -1;

        // Ensure all processes enter together
        sync();

        if (local_rank_ == 0) {
            // Rank 0 receives the FDs from all other ranks, then distributes them
            dst[0] = data_fd;
            for (int i = 0; i < local_world_size_ - 1; i++) {
                int received_fd;
                int src_local_rank;
                detail::broker::recv_fd(sock_, &received_fd, &src_local_rank);
                if (received_fd < 0)
                    throw std::runtime_error("KittensBroker: Failed to receive FD over socket");
                if (src_local_rank == local_rank_)
                    throw std::runtime_error("KittensBroker: Invalid source rank");
                dst[src_local_rank] = received_fd;
            }
            for (int dst_local_rank = 1; dst_local_rank < local_world_size_; dst_local_rank++) {
                for (int src_local_rank = 0; src_local_rank < local_world_size_; src_local_rank++) {
                    if (dst_local_rank == src_local_rank)
                        continue;
                    detail::broker::send_fd(sock_, dst[src_local_rank], SOCK_KEY_, dst_local_rank, src_local_rank);
                }
            }
            close(dst[0]); // no longer needed
            dst[0] = -1;
        } else {
            // Every other rank sends its FD to rank 0, then receives the remaining FDs from it
            detail::broker::send_fd(sock_, data_fd, SOCK_KEY_, 0, local_rank_);
            close(data_fd); // no longer needed
            for (int i = 0; i < local_world_size_ - 1; i++) {
                int received_fd;
                int src_local_rank;
                detail::broker::recv_fd(sock_, &received_fd, &src_local_rank);
                if (received_fd < 0)
                    throw std::runtime_error("KittensBroker: Failed to receive FD over socket");
                if (src_local_rank == local_rank_)
                    throw std::runtime_error("KittensBroker: Invalid source rank");
                dst[src_local_rank] = received_fd;
            }
        }

        // Ensure all processes exit together
        sync();
    }

    __host__ inline void broadcast_fd(int *dst, const int data_fd, const int src_local_rank) {
        if (src_local_rank < 0 || src_local_rank >= local_world_size_)
            throw std::runtime_error("KittensBroker: Invalid source rank");

        // Ensure all processes enter together
        sync();

        if (local_rank_ == src_local_rank) {
            if (data_fd < 0)
                throw std::runtime_error("KittensBroker: Source rank has invalid FD");
            for (int dst_local_rank = 0; dst_local_rank < local_world_size_; dst_local_rank++) {
                if (dst_local_rank == src_local_rank)
                    continue;
                detail::broker::send_fd(sock_, data_fd, SOCK_KEY_, dst_local_rank, src_local_rank);
            }
            close(data_fd); // no longer needed
        } else {
            if (!dst)
                throw std::runtime_error("KittensBroker: Destination rank has invalid buffer");
            int _src_local_rank;
            detail::broker::recv_fd(sock_, dst, &_src_local_rank);
            if (*dst < 0)
                throw std::runtime_error("KittensBroker: Failed to receive valid FD over socket");
            if (_src_local_rank != src_local_rank)
                throw std::runtime_error("KittensBroker: Invalid source rank");
        }

        // Ensure all processes exit together
        sync();
    }
};

} // namespace kittens
122
extra/thunder/cuda/include/pyutils/club.cuh
Normal file
@@ -0,0 +1,122 @@
#include <condition_variable>
#include <functional>
#include <mutex>
#include <thread>
#include <vector>

/*
CUDA-specific ThreadPool

Example usage

    // Construction
    KittensClub club(device_ids, NUM_DEVICES);

    // Dispatch work to all threads (no need to set device)
    club.execute([&](int dev_idx, cudaStream_t stream) {
        int dev;
        CUDACHECK(cudaGetDevice(&dev));
        if (dev != dev_idx) {
            fprintf(stderr, "Device mismatch: expected %d, got %d\n", dev_idx, dev);
            exit(1);
        }
    });
*/
class KittensClub {
public:
    __host__ inline KittensClub(const int *device_ids, const int num_devices);
    __host__ inline KittensClub(const int *device_ids, const cudaStream_t *streams, const int num_devices);
    __host__ inline ~KittensClub();

    // Dispatches `task` to all threads, and waits for all threads to finish (using cv)
    __host__ inline void execute(std::function<void(int, cudaStream_t)> task);

private:
    // Condition indicators
    bool stop;
    std::vector<bool> task_available;
    int n_task_done;

    // Threadpool
    std::vector<std::thread> workers;

    // Streams for each device
    std::vector<cudaStream_t> streams;

    // Main entry point for each thread
    __host__ inline void worker(int worker_id, int device_id);

    // Used to dispatch work to all threads
    std::function<void(int, cudaStream_t)> current_task;

    // Synchronization
    std::mutex mutex;
    std::condition_variable cond_task_available;
    std::condition_variable cond_task_done;
};

__host__ inline KittensClub::KittensClub(const int *device_ids, const int num_devices) : stop(false), n_task_done(0) {
    for (size_t dev_idx = 0; dev_idx < num_devices; ++dev_idx) {
        task_available.push_back(false);
        streams.push_back(0); // Use default stream (null stream)
        workers.emplace_back([this, dev_idx, device_ids] { worker(dev_idx, device_ids[dev_idx]); });
    }
}

__host__ inline KittensClub::KittensClub(const int *device_ids, const cudaStream_t *streams_in, const int num_devices) : stop(false), n_task_done(0) {
    for (size_t dev_idx = 0; dev_idx < num_devices; ++dev_idx) {
        task_available.push_back(false);
        streams.push_back(streams_in[dev_idx]);
        workers.emplace_back([this, dev_idx, device_ids] { worker(dev_idx, device_ids[dev_idx]); });
    }
}

__host__ inline KittensClub::~KittensClub() {
    {
        std::lock_guard<std::mutex> lock(mutex);
        stop = true;
    }
    cond_task_available.notify_all();
    for (std::thread &worker : workers) {
        worker.join();
    }
}

__host__ inline void KittensClub::execute(std::function<void(int, cudaStream_t)> task) {
    {
        std::lock_guard<std::mutex> lock(mutex);
        current_task = task;
        for (size_t i = 0; i < task_available.size(); ++i)
            task_available[i] = true;
    }
    cond_task_available.notify_all();
    {
        std::unique_lock<std::mutex> lock(mutex);
        cond_task_done.wait(lock, [this] { return n_task_done == workers.size(); });
        n_task_done = 0;
    }
}

__host__ inline void KittensClub::worker(int worker_id, int device_id) {
    cudaSetDevice(device_id); // done once and never again! This saves a LOT of time
    while (true) {
        std::function<void(int, cudaStream_t)> task;
        {
            std::unique_lock<std::mutex> lock(mutex);
            cond_task_available.wait(lock, [this, worker_id] { return stop || task_available[worker_id]; });

            if (stop)
                return;

            task = current_task;
            task_available[worker_id] = false;
        }
        task(worker_id, streams[worker_id]);
        {
            std::lock_guard<std::mutex> lock(mutex); // adds about 10 microseconds overhead
            ++n_task_done;
            if (n_task_done == workers.size())
                cond_task_done.notify_one();
        }
    }
}
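
/* Example (illustrative only; `my_kernel`, `globals`, `grid` and `block` are
   hypothetical -- assumes one worker per device and per-device streams passed
   at construction):

       KittensClub club(device_ids.data(), raw_streams.data(), NUM_DEVICES);
       club.execute([&](int dev_idx, cudaStream_t stream) {
           my_kernel<<<grid, block, 0, stream>>>(globals[dev_idx]);
       });
*/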
336
extra/thunder/cuda/include/pyutils/parallel_tensor.cuh
Normal file
@@ -0,0 +1,336 @@
#pragma once

#include <iostream>
#include <map>
#include <vector>

#include <ATen/ops/from_blob.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/utils/pybind.h>

#include "../types/device/vmm.cuh"
#include "../types/device/ipc.cuh"
#include "broker.cuh"

namespace kittens {
namespace py {

/**
 * @brief Distributed tensor wrapper for multi-GPU IPC sharing and multicast.
 * Can later be used for easy PGL creation right before a kernel call.
 * Meant to be used as a single object per thread/process.
 */
struct TKParallelTensor {
    inline static std::map<std::pair<int, int>, KittensBroker> brokers_; // lazily initialized

    at::Tensor data_; // for direct access from PyTorch
    std::vector<int64_t> shape_;
    at::ScalarType dtype_;

    std::vector<void *> raw_ptrs_;
    size_t allocated_size_;

    int local_rank_; // identical to device index
    int local_world_size_;

    bool multicast_;
    void *multicast_ptr_;
    size_t multicast_allocated_size_;

    detail::ipc::flavor ipc_flavor_;

    __host__ inline TKParallelTensor(
        const at::Tensor &tensor,
        int local_rank,
        int local_world_size,
        bool multicast
    ) : data_(tensor),
        shape_(tensor.sizes().vec()),
        dtype_(tensor.scalar_type()),
        raw_ptrs_(local_world_size, nullptr),
        allocated_size_(tensor.nbytes()),
        local_rank_(local_rank),
        local_world_size_(local_world_size),
        multicast_(multicast),
        multicast_ptr_(nullptr),
        multicast_allocated_size_(0),
        ipc_flavor_(detail::ipc::flavor::LEGACY) {

        TORCH_CHECK(tensor.is_cuda(), "Tensor must be on CUDA device");
        TORCH_CHECK(tensor.is_contiguous(), "Tensor must be contiguous");
        TORCH_CHECK(tensor.dim() <= 4, "Only tensors with dim <= 4 are supported for TKParallelTensor");
        TORCH_CHECK(tensor.device().index() == local_rank_, "Tensor device index must match local_rank");
        TORCH_CHECK(local_rank_ >= 0, "local_rank must be non-negative");
        TORCH_CHECK(local_rank_ < local_world_size_, "local_rank must be less than local_world_size");
        TORCH_CHECK(!multicast, "Multicast is not supported for pre-allocated tensors");

        brokers_.try_emplace(
            {local_rank_, local_world_size_},
            local_rank_, local_world_size_
        );

        if (brokers_.size() > 1)
            std::cerr << "WARNING: multiple KittensBroker instances created in the same process. This is not safe." << std::endl;

        c10::cuda::CUDAGuard device_guard(local_rank_);
        exchange_ipc_handles<detail::ipc::flavor::LEGACY>();
    }

    __host__ inline TKParallelTensor(
        const std::vector<int64_t> &shape,
        const at::ScalarType dtype,
        int local_rank,
        int local_world_size,
        bool multicast
    ) : shape_(shape),
        dtype_(dtype),
        raw_ptrs_(local_world_size, nullptr),
        allocated_size_(0),
        local_rank_(local_rank),
        local_world_size_(local_world_size),
        multicast_(multicast),
        multicast_ptr_(nullptr),
        multicast_allocated_size_(0),
        ipc_flavor_(detail::ipc::flavor::VMM) {

        TORCH_CHECK(local_rank_ >= 0, "local_rank must be non-negative");
        TORCH_CHECK(local_rank_ < local_world_size_, "local_rank must be less than local_world_size");

        brokers_.try_emplace(
            {local_rank_, local_world_size_},
            local_rank_, local_world_size_
        );

        if (brokers_.size() > 1)
            std::cerr << "WARNING: multiple KittensBroker instances created in the same process. This is not safe." << std::endl;

        c10::cuda::CUDAGuard device_guard(local_rank_);
        create_shareable_cuda_tensor();
        exchange_ipc_handles<detail::ipc::flavor::VMM>();

        if (multicast_)
            initialize_multicast();
    }

    TKParallelTensor(const TKParallelTensor&) = delete;
    TKParallelTensor& operator=(const TKParallelTensor&) = delete;
    TKParallelTensor& operator=(TKParallelTensor&& other) = delete;

    __host__ inline TKParallelTensor(TKParallelTensor&& other) :
        data_(std::move(other.data_)),
        shape_(std::move(other.shape_)),
        dtype_(std::move(other.dtype_)),
        raw_ptrs_(std::move(other.raw_ptrs_)),
        allocated_size_(other.allocated_size_),
        local_rank_(other.local_rank_),
        local_world_size_(other.local_world_size_),
        multicast_(other.multicast_),
        multicast_ptr_(other.multicast_ptr_),
        multicast_allocated_size_(other.multicast_allocated_size_),
        ipc_flavor_(other.ipc_flavor_) {
        other.data_ = at::Tensor();
        other.shape_.clear();
        other.dtype_ = at::ScalarType::Undefined;
        other.raw_ptrs_.clear();
        other.allocated_size_ = 0;
        other.local_rank_ = -1;
        other.local_world_size_ = -1;
        other.multicast_ = false;
        other.multicast_ptr_ = nullptr;
        other.multicast_allocated_size_ = 0;
    }

    __host__ inline ~TKParallelTensor() {
        destroy();
    }

    __host__ inline at::Tensor data() const {
        return data_;
    }

    __host__ inline void create_shareable_cuda_tensor() {
        c10::cuda::CUDAGuard device_guard(local_rank_);

        TORCH_CHECK(!shape_.empty(), "Shape must be non-empty");
        TORCH_CHECK(shape_.size() <= 4, "Shape must have at most 4 dimensions for TKParallelTensor");
        size_t size = c10::elementSize(dtype_);
        for (auto dim : shape_) {
            TORCH_CHECK(dim > 0, "Size dimensions must be positive");
            size *= static_cast<size_t>(dim);
        }

        void *raw_ptr;
        detail::vmm::vm_alloc_map_set_access(
            &raw_ptr, &allocated_size_, size, local_rank_, local_world_size_);

        // Create local copies for capture
        int local_rank = local_rank_;
        size_t allocated_size = allocated_size_;

        auto deleter = [local_rank, raw_ptr, allocated_size](void* p) mutable {
            if (!p) return;
            c10::cuda::CUDAGuard device_guard(local_rank);
            auto stream = c10::cuda::getCurrentCUDAStream().stream();
            CUDACHECK(cudaStreamSynchronize(stream));
            detail::vmm::vm_unmap(raw_ptr, allocated_size);
        };

        at::TensorOptions options = at::TensorOptions()
            .dtype(dtype_)
            .device(at::kCUDA, local_rank_);

        data_ = at::from_blob(raw_ptr, shape_, std::move(deleter), options);
    }

    template <detail::ipc::flavor IPC_FLAVOR>
    __host__ inline void exchange_ipc_handles() {
        using handle_t = detail::ipc::handle<IPC_FLAVOR>;

        // Get IPC handle
        detail::ipc::check_support(local_rank_);
        void *raw_ptr = reinterpret_cast<void *>(data_.data_ptr());
        handle_t ipc_handle;
        detail::ipc::export_handle(&ipc_handle, raw_ptr);

        // Exchange IPC handles
        std::vector<handle_t> all_ipc_handles(local_world_size_);
        if constexpr (IPC_FLAVOR == detail::ipc::flavor::LEGACY) {
            brokers_.at({local_rank_, local_world_size_}).exchange_data(
                reinterpret_cast<void *>(all_ipc_handles.data()),
                reinterpret_cast<void *>(&ipc_handle),
                sizeof(handle_t)
            );
        } else if constexpr (IPC_FLAVOR == detail::ipc::flavor::VMM) {
            brokers_.at({local_rank_, local_world_size_}).exchange_fds(
                reinterpret_cast<int *>(all_ipc_handles.data()),
                ipc_handle.handle_
            );
        } else {
            throw std::runtime_error("Invalid IPC flavor");
        }

        // Import IPC handles
        for (int i = 0; i < local_world_size_; i++) {
            if (i == local_rank_)
                raw_ptrs_[i] = raw_ptr;
            else
                detail::ipc::import_handle(&raw_ptrs_[i], all_ipc_handles[i], allocated_size_, local_world_size_);
        }
    }

    __host__ inline void initialize_multicast() {
        using handle_t = detail::ipc::handle<detail::ipc::flavor::VMM>;

        detail::vmm::multicast_check(local_rank_);
        detail::ipc::check_support(local_rank_);
        detail::vmm::handle multicast_handle;

        if (local_rank_ == 0) {
            // Create multicast handle; only a single rank should create MC handle
            detail::vmm::multicast_create_handle(
                &multicast_handle,
                &multicast_allocated_size_,
                allocated_size_,
                local_world_size_
            );

            // Currently, non-rank-0 path assumes allocated_size_ == multicast_allocated_size_
            if (allocated_size_ != multicast_allocated_size_)
                throw std::runtime_error("Multicast allocated size does not match memory allocated size");

            // Get IPC handle
            handle_t ipc_handle;
            detail::ipc::export_handle(&ipc_handle, multicast_handle);

            // Broadcast the IPC multicast handle
            brokers_.at({local_rank_, local_world_size_}).broadcast_fd(nullptr, ipc_handle.handle_, 0);
        } else {
            // Receive the IPC multicast handle from rank 0
            handle_t ipc_handle;
            brokers_.at({local_rank_, local_world_size_}).broadcast_fd(&ipc_handle.handle_, -1, 0);
            multicast_allocated_size_ = allocated_size_;
            detail::ipc::import_handle(&multicast_handle, ipc_handle, multicast_allocated_size_, local_world_size_);
        }

        // Add all devices to the MC handle. Must sync
        detail::vmm::multicast_bind_device(multicast_handle, local_rank_);
        brokers_.at({local_rank_, local_world_size_}).sync(); // must ensure all devices are added

        // Bind all memory to the MC handle and map to a virtual address; must be done after adding all devices
        detail::vmm::handle memory_handle;
        detail::vmm::vm_retrieve_handle(&memory_handle, raw_ptrs_[local_rank_]);
        detail::vmm::multicast_bind_memory(multicast_handle, memory_handle, allocated_size_);
        brokers_.at({local_rank_, local_world_size_}).sync();

        // Map virtual address to multicast handle and set access; must be done after adding all devices
        detail::vmm::vm_map(&multicast_ptr_, multicast_handle, multicast_allocated_size_);
        detail::vmm::vm_set_access(multicast_ptr_, multicast_allocated_size_, local_world_size_);

        // Free the handles immediately
        detail::vmm::vm_free(multicast_handle);
        detail::vmm::vm_free(memory_handle);
    }

    __host__ inline void destroy() {
        // 1. Multicast cleanup
        if (multicast_ && multicast_ptr_) {
            brokers_.at({local_rank_, local_world_size_}).sync();
            detail::vmm::handle multicast_handle;
            detail::vmm::vm_retrieve_handle(&multicast_handle, multicast_ptr_);
            detail::vmm::vm_unmap(multicast_ptr_, multicast_allocated_size_);
            detail::vmm::multicast_unbind_device(multicast_handle, multicast_allocated_size_, local_rank_);
            brokers_.at({local_rank_, local_world_size_}).sync();
            detail::vmm::vm_free(multicast_handle);
        }

        // 2. Imported handle cleanup
        for (int i = 0; i < local_world_size_; i++) {
            if (i != local_rank_ && i < raw_ptrs_.size()) {
                if (ipc_flavor_ == detail::ipc::flavor::LEGACY) {
                    detail::ipc::free_handle<detail::ipc::flavor::LEGACY>(raw_ptrs_[i], allocated_size_);
                } else if (ipc_flavor_ == detail::ipc::flavor::VMM) {
                    detail::ipc::free_handle<detail::ipc::flavor::VMM>(raw_ptrs_[i], allocated_size_);
                } else {
                    throw std::runtime_error("Invalid IPC flavor");
                }
            }
        }
        brokers_.at({local_rank_, local_world_size_}).sync(); // must sync before destroying the tensor

        // 3. Tensor cleanup
        if (data_.defined())
            data_.reset(); // properly decreases the ref count

        // 4. Member variables cleanup
        shape_.clear();
        dtype_ = at::ScalarType::Undefined;
        raw_ptrs_.clear();
        allocated_size_ = 0;
        local_rank_ = -1;
        local_world_size_ = -1;
        multicast_ = false;
        multicast_ptr_ = nullptr;
        multicast_allocated_size_ = 0;
    }
};

} // namespace py
} // namespace kittens

#define BIND_TK_PARALLEL_TENSOR(m) \
    pybind11::class_<kittens::py::TKParallelTensor>(m, "TKParallelTensor") \
        .def(pybind11::init<const at::Tensor&, int, int, bool>(), \
            pybind11::arg("tensor"), \
            pybind11::arg("local_rank"), \
            pybind11::arg("local_world_size"), \
            pybind11::arg("multicast") = false) \
        .def(pybind11::init<const std::vector<int64_t>&, const at::ScalarType&, int, int, bool>(), \
            pybind11::arg("shape"), \
            pybind11::arg("dtype"), \
            pybind11::arg("local_rank"), \
            pybind11::arg("local_world_size"), \
            pybind11::arg("multicast") = false) \
        .def("data", &kittens::py::TKParallelTensor::data) \
        .def_readonly("data_", &kittens::py::TKParallelTensor::data_) \
        .def_readonly("local_rank_", &kittens::py::TKParallelTensor::local_rank_) \
        .def_readonly("local_world_size_", &kittens::py::TKParallelTensor::local_world_size_)
235
extra/thunder/cuda/include/pyutils/pyutils.cuh
Normal file
@@ -0,0 +1,235 @@
#pragma once

#include "util.cuh"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // for automatic Python list -> std::vector conversion

namespace kittens {
namespace py {

template<typename T> struct from_object {
    static T make(pybind11::object obj) {
        return obj.cast<T>();
    }
    static T unwrap(pybind11::object obj, int dev_idx) {
        return make(obj); // Scalars should be passed in as a scalar
    }
};
template<ducks::gl::all GL> struct from_object<GL> {
    static GL make(pybind11::object obj) {
        // Check if argument is a torch.Tensor
        if (pybind11::hasattr(obj, "__class__") &&
            obj.attr("__class__").attr("__name__").cast<std::string>() == "Tensor") {

            // Check if tensor is contiguous
            if (!obj.attr("is_contiguous")().cast<bool>()) {
                throw std::runtime_error("Tensor must be contiguous");
            }
            if (obj.attr("device").attr("type").cast<std::string>() == "cpu") {
                throw std::runtime_error("Tensor must be on CUDA device");
            }

            // Get shape, pad with 1s if needed
            std::array<int, 4> shape = {1, 1, 1, 1};
            auto py_shape = obj.attr("shape").cast<pybind11::tuple>();
            size_t dims = py_shape.size();
            if (dims > 4) {
                throw std::runtime_error("Expected Tensor.ndim <= 4");
            }
            for (size_t i = 0; i < dims; ++i) {
                shape[4 - dims + i] = pybind11::cast<int>(py_shape[i]);
            }

            // Get data pointer using data_ptr()
            uint64_t data_ptr = obj.attr("data_ptr")().cast<uint64_t>();

            // Create GL object using make_gl
            return make_gl<GL>(data_ptr, shape[0], shape[1], shape[2], shape[3]);
        }
        throw std::runtime_error("Expected a torch.Tensor");
    }
    static GL unwrap(pybind11::object obj, int dev_idx) {
        if (!pybind11::isinstance<pybind11::list>(obj))
            throw std::runtime_error("GL unwrap expected a Python list.");
        pybind11::list lst = pybind11::cast<pybind11::list>(obj);
        if (dev_idx >= lst.size())
            throw std::runtime_error("Device index out of bounds.");
        return *lst[dev_idx].cast<std::shared_ptr<GL>>();
    }
};
template<ducks::pgl::all PGL> struct from_object<PGL> {
    static PGL make(pybind11::object obj) {
        static_assert(!PGL::MULTICAST, "Multicast not yet supported on pyutils. Please initialize the multicast pointer manually.");
        if (!pybind11::isinstance<pybind11::list>(obj))
            throw std::runtime_error("PGL from_object expected a Python list.");
        pybind11::list tensors = pybind11::cast<pybind11::list>(obj);
        if (tensors.size() != PGL::num_devices)
            throw std::runtime_error("Expected a list of " + std::to_string(PGL::num_devices) + " tensors");
        std::array<int, 4> shape = {1, 1, 1, 1};
        uint64_t data_ptrs[PGL::num_devices];
        for (int i = 0; i < PGL::num_devices; i++) {
            auto tensor = tensors[i];
            if (!pybind11::hasattr(tensor, "__class__") ||
                tensor.attr("__class__").attr("__name__").cast<std::string>() != "Tensor")
                throw std::runtime_error("Expected a list of torch.Tensor");
            if (!tensor.attr("is_contiguous")().cast<bool>())
                throw std::runtime_error("Tensor must be contiguous");
            if (tensor.attr("device").attr("type").cast<std::string>() == "cpu")
                throw std::runtime_error("Tensor must be on CUDA device");
            auto py_shape = tensor.attr("shape").cast<pybind11::tuple>();
            size_t dims = py_shape.size();
            if (dims > 4)
                throw std::runtime_error("Expected Tensor.ndim <= 4");
            for (size_t j = 0; j < dims; ++j) {
                if (i == 0)
                    shape[4 - dims + j] = pybind11::cast<int>(py_shape[j]);
                else if (shape[4 - dims + j] != pybind11::cast<int>(py_shape[j]))
                    throw std::runtime_error("All tensors must have the same shape");
            }
            data_ptrs[i] = tensor.attr("data_ptr")().cast<uint64_t>();
        }
        return make_pgl<PGL>(data_ptrs, shape[0], shape[1], shape[2], shape[3]);
    }
    static PGL unwrap(pybind11::object obj, int dev_idx) {
        return *obj.cast<std::shared_ptr<PGL>>();
    }
};

static std::unordered_set<std::string> registered;
template<typename T> static void register_pyclass(pybind11::module &m) {
    if constexpr (ducks::gl::all<T> || ducks::pgl::all<T>) {
        std::string _typename = typeid(T).name();
        if (registered.find(_typename) == registered.end()) {
            pybind11::class_<T, std::shared_ptr<T>>(m, _typename.c_str());
            registered.insert(_typename);
        }
    }
}
template<typename T> static pybind11::object multigpu_make(pybind11::object obj) {
    if constexpr (ducks::gl::all<T>) {
        if (!pybind11::isinstance<pybind11::list>(obj))
            throw std::runtime_error("multigpu_make [GL] expected a Python list.");
        pybind11::list lst = pybind11::cast<pybind11::list>(obj);
        std::vector<std::shared_ptr<T>> gls;
        for (int i = 0; i < lst.size(); i++)
            gls.push_back(std::make_shared<T>(from_object<T>::make(lst[i])));
        return pybind11::cast(gls);
    } else if constexpr (ducks::pgl::all<T>) {
        return pybind11::cast(std::make_shared<T>(from_object<T>::make(obj)));
    } else {
        return pybind11::cast(from_object<T>::make(obj));
    }
}

template<typename T> concept has_dynamic_shared_memory = requires(T t) { { t.dynamic_shared_memory() } -> std::convertible_to<int>; };
template<typename T> concept is_multigpu_globals = requires {
    { T::num_devices } -> std::convertible_to<std::size_t>;
    { T::dev_idx } -> std::convertible_to<std::size_t>;
} && T::num_devices >= 1;

template<typename> struct trait;
template<typename MT, typename T> struct trait<MT T::*> { using member_type = MT; using type = T; };
template<typename> using object = pybind11::object;
template<auto kernel, typename TGlobal> static void bind_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {
    m.def(name, [](object<decltype(member_ptrs)>... args, pybind11::kwargs kwargs) {
        TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
        cudaStream_t raw_stream = nullptr;
        if (kwargs.contains("stream")) {
            // Extract stream pointer
            uintptr_t stream_ptr = kwargs["stream"].attr("cuda_stream").cast<uintptr_t>();
            raw_stream = reinterpret_cast<cudaStream_t>(stream_ptr);
        }
        if constexpr (has_dynamic_shared_memory<TGlobal>) {
            int __dynamic_shared_memory__ = (int)__g__.dynamic_shared_memory();
            cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__);
            kernel<<<__g__.grid(), __g__.block(), __dynamic_shared_memory__, raw_stream>>>(__g__);
        } else {
            kernel<<<__g__.grid(), __g__.block(), 0, raw_stream>>>(__g__);
        }
    });
}
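
/* Example (illustrative only; the kernel, globals struct and module name are
   hypothetical -- grid() and block() are the accessors bind_kernel expects):

       struct my_globals { my_gl_layout A, B; dim3 grid() const; dim3 block() const; };
       __global__ void my_kernel(my_globals g);

       PYBIND11_MODULE(tk_kernels, m) {
           kittens::py::bind_kernel<my_kernel, my_globals>(m, "my_kernel",
                                                           &my_globals::A, &my_globals::B);
       }

   From Python, my_kernel(a_tensor, b_tensor, stream=torch.cuda.current_stream())
   builds the globals from the arguments and launches the kernel.
*/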
template<auto function, typename TGlobal> static void bind_function(auto m, auto name, auto TGlobal::*... member_ptrs) {
    m.def(name, [](object<decltype(member_ptrs)>... args) {
        TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
        function(__g__);
    });
}
static void bind_multigpu_boilerplate(auto m) {
    m.def("enable_all_p2p_access", [](const std::vector<int>& device_ids) {
        int device_count;
        CUDACHECK(cudaGetDeviceCount(&device_count));
        if (device_count < device_ids.size())
            throw std::runtime_error("Not enough CUDA devices available");
        for (int i = 0; i < device_ids.size(); i++) {
            CUDACHECK(cudaSetDevice(device_ids[i]));
            for (int j = 0; j < device_ids.size(); j++) {
                if (i == j) continue;
                int can_access = 0;
                CUDACHECK(cudaDeviceCanAccessPeer(&can_access, device_ids[i], device_ids[j]));
                if (!can_access)
                    throw std::runtime_error("Device " + std::to_string(device_ids[i]) + " cannot access device " + std::to_string(device_ids[j]));
                cudaError_t res = cudaDeviceEnablePeerAccess(device_ids[j], 0);
                if (res != cudaSuccess && res != cudaErrorPeerAccessAlreadyEnabled) {
                    CUDACHECK(res);
                }
            }
        }
    });
    pybind11::class_<KittensClub, std::shared_ptr<KittensClub>>(m, "KittensClub")
        .def(pybind11::init([](const std::vector<int>& device_ids) {
            int device_count;
            CUDACHECK(cudaGetDeviceCount(&device_count));
            if (device_count < device_ids.size())
                throw std::runtime_error("Not enough CUDA devices available");
            auto club = std::make_shared<KittensClub>(device_ids.data(), device_ids.size());
            club->execute([&](int dev_idx, cudaStream_t stream) {}); // warmup
            return club;
        }), pybind11::arg("device_ids"))
        .def(pybind11::init([](const std::vector<int>& device_ids, const std::vector<pybind11::object>& streams) {
            int device_count;
            CUDACHECK(cudaGetDeviceCount(&device_count));
            if (device_count < device_ids.size())
                throw std::runtime_error("Not enough CUDA devices available");
            if (streams.size() != device_ids.size())
                throw std::runtime_error("Number of streams must match number of devices");

            std::vector<cudaStream_t> raw_streams(streams.size());
            for (size_t i = 0; i < streams.size(); ++i) {
                uintptr_t stream_ptr = streams[i].attr("cuda_stream").cast<uintptr_t>();
                raw_streams[i] = reinterpret_cast<cudaStream_t>(stream_ptr);
            }

            auto club = std::make_shared<KittensClub>(device_ids.data(), raw_streams.data(), device_ids.size());
            club->execute([&](int dev_idx, cudaStream_t stream) {}); // warmup
            return club;
        }), pybind11::arg("device_ids"), pybind11::arg("streams"));
}
template<auto kernel, typename TGlobal> static void bind_multigpu_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {
    static_assert(is_multigpu_globals<TGlobal>, "Multigpu globals must have a member num_devices >= 1 and dev_idx");
    (register_pyclass<typename trait<decltype(member_ptrs)>::member_type>(m), ...);
    m.def((std::string("make_globals_")+name).c_str(), [](object<decltype(member_ptrs)>... args) -> std::vector<pybind11::object> {
        return {multigpu_make<typename trait<decltype(member_ptrs)>::member_type>(args)...};
    });
    m.def(name, [](std::shared_ptr<KittensClub> club, object<decltype(member_ptrs)>... args) {
        std::vector<TGlobal> __g__;
        for (int i = 0; i < TGlobal::num_devices; i++) {
            __g__.emplace_back(from_object<typename trait<decltype(member_ptrs)>::member_type>::unwrap(args, i)...);
            __g__.back().dev_idx = i;
        }
        if constexpr (has_dynamic_shared_memory<TGlobal>) {
            club->execute([&](int dev_idx, cudaStream_t stream) {
                int __dynamic_shared_memory__ = (int)__g__[dev_idx].dynamic_shared_memory();
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__);
|
||||
kernel<<<__g__[dev_idx].grid(), __g__[dev_idx].block(), __dynamic_shared_memory__, stream>>>(__g__[dev_idx]);
|
||||
});
|
||||
} else {
|
||||
club->execute([&](int dev_idx, cudaStream_t stream) {
|
||||
kernel<<<__g__[dev_idx].grid(), __g__[dev_idx].block(), 0, stream>>>(__g__[dev_idx]);
|
||||
});
|
||||
}
|
||||
});
|
||||
// TODO: PGL destructor binding
|
||||
}
|
||||
|
||||
} // namespace py
|
||||
} // namespace kittens
|
||||
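
For orientation, here is a minimal sketch of how these helpers are typically wired into a module definition. The globals struct, kernel, and module name below are hypothetical illustrations, not part of this header:

struct demo_globals {
    kittens::gl<float, 1, 1, -1, -1> in, out;    // dynamic rows/cols
    dim3 grid()  const { return dim3(132); }
    dim3 block() const { return dim3(256); }
};
__global__ void demo_kernel(const __grid_constant__ demo_globals g) { /* ... */ }

PYBIND11_MODULE(demo_module, m) {
    // Exposes demo_kernel to Python; the member pointers tell bind_kernel which
    // Python arguments map to which fields of demo_globals.
    kittens::py::bind_kernel<demo_kernel, demo_globals>(m, "demo_kernel",
        &demo_globals::in, &demo_globals::out);
}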
7
extra/thunder/cuda/include/pyutils/torch_helpers.cuh
Normal file
@@ -0,0 +1,7 @@
#pragma once

#include <torch/extension.h>

#define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
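
A hypothetical call site for these macros inside a torch extension op. Note that CHECK_INPUT expands to two statements, so it should not be used as the unbraced body of an if/else:

void demo_op(at::Tensor x) { // illustrative function, not part of this header
    CHECK_INPUT(x); // CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
    // ... now safe to hand x.data_ptr() to a kernel expecting contiguous CUDA memory ...
}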
180
extra/thunder/cuda/include/pyutils/torchutils.cuh
Normal file
@@ -0,0 +1,180 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <ATen/core/Tensor.h>

#include "kittens.cuh"
#include "parallel_tensor.cuh"

namespace kittens {
namespace py {

template <typename Config>
concept has_min_blocks_per_sm = requires { std::integral_constant<int, int(Config::MIN_BLOCKS_PER_SM)>{}; };

template <typename Config>
consteval int min_blocks_per_sm() {
    if constexpr(has_min_blocks_per_sm<Config>)
        return Config::MIN_BLOCKS_PER_SM;
    else
        return 1;
}

template <typename Config, typename Globals, auto Kernel>
__global__
__launch_bounds__(Config::NUM_THREADS, min_blocks_per_sm<Config>())
void global_kernel_unclustered(const __grid_constant__ Globals G) {
    Kernel(G);
}

template <typename Config, typename Globals, auto Kernel>
__global__
__launch_bounds__(Config::NUM_THREADS, min_blocks_per_sm<Config>())
__cluster_dims__(Config::CLUSTER_SIZE)
void global_kernel_clustered(const __grid_constant__ Globals G) {
    Kernel(G);
}

template <typename Layout>
static inline void tensor_check(const at::Tensor &t) {
    TORCH_CHECK(t.is_cuda(), "Tensor must be on CUDA device");
    TORCH_CHECK(t.is_contiguous(), "Tensor must be contiguous");
    TORCH_CHECK(t.dim() <= 4, "Expected Tensor.dim() <= 4");

    if constexpr (std::is_same_v<typename Layout::dtype, char>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Char, "Tensor has invalid dtype (expected int8)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, short>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Short, "Tensor has invalid dtype (expected int16)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, int>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Int, "Tensor has invalid dtype (expected int32)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, long>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Long, "Tensor has invalid dtype (expected int64)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, ::kittens::fp8e4m3>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Float8_e4m3fn, "Tensor has invalid dtype (expected fp8e4m3)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, ::kittens::fp8e5m2>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Float8_e5m2, "Tensor has invalid dtype (expected fp8e5m2)");
#ifdef KITTENS_BLACKWELL
    } else if constexpr (std::is_same_v<typename Layout::dtype, ::kittens::fp8e8m0>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Byte, "Tensor has invalid dtype (expected fp8e8m0 represented as uint8)");
#endif
    } else if constexpr (std::is_same_v<typename Layout::dtype, ::kittens::bf16>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::BFloat16, "Tensor has invalid dtype (expected bfloat16)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, ::kittens::half>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Half, "Tensor has invalid dtype (expected float16)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, float>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Float, "Tensor has invalid dtype (expected float32)");
    } else if constexpr (std::is_same_v<typename Layout::dtype, double>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Double, "Tensor has invalid dtype (expected float64)");
    } else {
        TORCH_CHECK(false, "Unsupported dtype");
    }
}

template <kittens::ducks::pgl::all PGL>
static inline void parallel_tensor_check(const TKParallelTensor& t) {
    tensor_check<PGL>(t.data_);
    TORCH_CHECK(t.data_.sizes().vec() == t.shape_, "Shape mismatch between TKParallelTensor and the underlying tensor");
    TORCH_CHECK(t.data_.dtype() == t.dtype_, "Dtype mismatch between TKParallelTensor and the underlying tensor");
    TORCH_CHECK(t.raw_ptrs_.size() == PGL::num_devices, "Number of devices mismatch between PGL and TKParallelTensor");
    TORCH_CHECK(t.local_rank_ == t.data_.device().index(), "Current tensor device index mismatch within TKParallelTensor");
    TORCH_CHECK(t.local_world_size_ == PGL::num_devices, "local_world_size mismatch between PGL and TKParallelTensor");
    TORCH_CHECK(t.multicast_ == PGL::multicast, "Multicast mismatch between PGL and TKParallelTensor");
    TORCH_CHECK(t.raw_ptrs_[t.local_rank_] == reinterpret_cast<void *>(t.data_.data_ptr()), "Current tensor data pointer not found in TKParallelTensor's raw_ptrs_");
}

template <kittens::ducks::gl::all GL>
static inline GL tensor_to_gl(const at::Tensor &t) {
    tensor_check<GL>(t);

    std::array<int, 4> shape = {1, 1, 1, 1};
    for (int i = 0; i < static_cast<int>(t.dim()); ++i)
        shape[4 - t.dim() + i] = static_cast<int>(t.size(i));

    uint64_t data_ptr = reinterpret_cast<uint64_t>(t.data_ptr());

    return ::kittens::make_gl<GL>(data_ptr, shape[0], shape[1], shape[2], shape[3]);
}

template <kittens::ducks::pgl::all PGL>
static inline PGL parallel_tensor_to_pgl(TKParallelTensor &t) {
    parallel_tensor_check<PGL>(t);

    std::array<int, 4> shape = {1, 1, 1, 1};
    for (int i = 0; i < static_cast<int>(t.data_.dim()); ++i) {
        shape[4 - t.data_.dim() + i] = static_cast<int>(t.data_.size(i));
    }

    if constexpr (PGL::multicast)
        return ::kittens::make_pgl<PGL>(
            reinterpret_cast<uint64_t>(t.multicast_ptr_), reinterpret_cast<uint64_t *>(t.raw_ptrs_.data()), shape[0], shape[1], shape[2], shape[3]);
    else
        return ::kittens::make_pgl<PGL>(
            reinterpret_cast<uint64_t *>(t.raw_ptrs_.data()), shape[0], shape[1], shape[2], shape[3]);
}

template <kittens::ducks::gl::all GL>
static inline GL make_fake_gl(const int batch, const int depth, const int rows, const int cols) {
    return ::kittens::make_gl<GL>(reinterpret_cast<uint64_t>(nullptr), batch, depth, rows, cols);
}

static inline void _device_check(const at::Tensor& first, const at::Tensor& second) {
    TORCH_CHECK(first.device() == second.device(), "All tensors must be on the same device");
}

template <typename T1, typename... Ts>
static inline void device_check(const T1& first, const Ts&... rest) {
    (_device_check(first, rest), ...);
}

static inline void _parallel_tensor_check(const TKParallelTensor& first, const TKParallelTensor& second) {
    TORCH_CHECK(first.local_rank_ == second.local_rank_, "All parallel tensors must have the same local_rank");
    TORCH_CHECK(first.local_world_size_ == second.local_world_size_, "All parallel tensors must have the same local_world_size");
}

template <typename T1, typename... Ts>
static inline void parallel_tensor_check(const T1& first, const Ts&... rest) {
    (_parallel_tensor_check(first, rest), ...);
}

template <typename Config>
concept static_grid = requires { Config::NUM_BLOCKS; };

template <typename Config>
concept static_block = requires { Config::NUM_THREADS; };

template <typename Config>
concept static_dynamic_shared_memory = requires { Config::DYNAMIC_SHARED_MEMORY; };

template <typename Config, typename Globals, auto Kernel>
static inline void launch_kernel(const Globals &G) {
    dim3 grid;
    if constexpr (static_grid<Config>)
        grid = dim3{Config::NUM_BLOCKS, 1, 1};
    else
        grid = G.grid();

    dim3 block;
    if constexpr (static_block<Config>)
        block = dim3{Config::NUM_THREADS, 1, 1};
    else
        block = G.block();

    int dynamic_shared_memory;
    if constexpr (static_dynamic_shared_memory<Config>)
        dynamic_shared_memory = static_cast<int>(Config::DYNAMIC_SHARED_MEMORY);
    else
        dynamic_shared_memory = G.dynamic_shared_memory();

    cudaStream_t stream = at::cuda::getCurrentCUDAStream();

    if constexpr (Config::CLUSTER_SIZE <= 1) {
        CUDACHECK(cudaFuncSetAttribute(global_kernel_unclustered<Config, Globals, Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_shared_memory));
        global_kernel_unclustered<Config, Globals, Kernel><<<grid, block, dynamic_shared_memory, stream>>>(G);
    } else {
        CUDACHECK(cudaFuncSetAttribute(global_kernel_clustered<Config, Globals, Kernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_shared_memory));
        global_kernel_clustered<Config, Globals, Kernel><<<grid, block, dynamic_shared_memory, stream>>>(G);
    }
}

} // namespace py
} // namespace kittens
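
A minimal sketch of driving launch_kernel, assuming a hypothetical config/globals pair; none of the names below exist in the header above. Static members of the config take precedence over the globals' accessors, per the concepts above:

struct demo_config {
    static constexpr int NUM_THREADS = 256;              // satisfies static_block
    static constexpr int CLUSTER_SIZE = 1;               // <= 1 selects the unclustered path
    static constexpr int DYNAMIC_SHARED_MEMORY = 48 * 1024; // satisfies static_dynamic_shared_memory
};
struct demo_globals {
    float *data;
    dim3 grid() const { return dim3(64); } // used because demo_config has no NUM_BLOCKS
};
__device__ static void demo_body(const demo_globals &g) { /* device-side work */ }

void run_demo(const demo_globals &g) {
    kittens::py::launch_kernel<demo_config, demo_globals, demo_body>(g);
}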
19
extra/thunder/cuda/include/pyutils/util.cuh
Normal file
@@ -0,0 +1,19 @@
#pragma once

#include "../ops/ops.cuh"
#include "club.cuh"
#include <iostream>

#define CHECK_CUDA_ERROR(val) check((val), #val, __FILE__, __LINE__)
template <typename T>
void check(T err, char const* const func, char const* const file,
           int const line)
{
    if (err != cudaSuccess)
    {
        std::cerr << "CUDA Runtime Error at: " << file << ":" << line
                  << std::endl;
        std::cerr << cudaGetErrorString(err) << " " << func << std::endl;
        //std::exit(EXIT_FAILURE);
    }
}
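
A typical call site for the macro above. Note that check() only reports the error (the exit is commented out), so execution continues past a failure:

void demo_copy(void *dst, const void *src, size_t bytes, cudaStream_t stream) { // illustrative
    CHECK_CUDA_ERROR(cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, stream));
}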
12
extra/thunder/cuda/include/types/device/device.cuh
Normal file
@@ -0,0 +1,12 @@
/**
 * @file
 * @brief An aggregate header file for all the device types defined by ThunderKittens.
 */

#pragma once

#if defined(KITTENS_HOPPER) || defined(KITTENS_BLACKWELL)
#include "ipc.cuh"
#include "pgl.cuh"
#include "vmm.cuh"
#endif
195
extra/thunder/cuda/include/types/device/ipc.cuh
Normal file
@@ -0,0 +1,195 @@
#pragma once

#include <concepts>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>
#include <stdexcept>
#include <type_traits>
#include <unistd.h> // for close() on the POSIX file descriptors used below

#include "../../common/common.cuh"
#include "vmm.cuh"

namespace kittens {
namespace ducks {
namespace ipc {
namespace handle {

struct identifier {};

template<typename T> concept all = requires {
    typename T::identifier;
} && std::is_same_v<typename T::identifier, identifier>;

} // namespace handle
} // namespace ipc
} // namespace ducks

namespace detail {
namespace ipc {

enum flavor {
    LEGACY = 0,
    VMM = 1
};

template<flavor _flavor>
struct handle;

template<>
struct handle<flavor::LEGACY> {
    using identifier = ducks::ipc::handle::identifier;
    static constexpr flavor flavor_ = flavor::LEGACY;
    cudaIpcMemHandle_t handle_ {};
};

template<>
struct handle<flavor::VMM> {
    using identifier = ducks::ipc::handle::identifier;
    static constexpr flavor flavor_ = flavor::VMM;
    int handle_;
};

__host__ inline static void check_support(const int device_id) {
    CUdevice device;
    CUCHECK(cuDeviceGet(&device, device_id));

    int ipc_supported = 0;
    CUDACHECK(cudaDeviceGetAttribute(&ipc_supported, cudaDevAttrIpcEventSupport, device_id));
    int ipc_handle_supported = 0;
    CUCHECK(cuDeviceGetAttribute(&ipc_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, device));

    if (!ipc_supported || !ipc_handle_supported)
        throw std::runtime_error("CUDA IPC is not supported on this device");
}

template<ducks::ipc::handle::all IPC_HANDLE>
__host__ inline static void export_handle(
    IPC_HANDLE *ipc_handle,
    void *ptr
) {
    if constexpr (IPC_HANDLE::flavor_ == flavor::LEGACY) {
        CUDACHECK(cudaIpcGetMemHandle(&ipc_handle->handle_, ptr));
    } else if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) {
        CUmemGenericAllocationHandle memory_handle;
        detail::vmm::vm_retrieve_handle(&memory_handle, ptr);
        // ** Important: this handle (FD) must be manually closed by the user **
        CUCHECK(cuMemExportToShareableHandle(&ipc_handle->handle_, memory_handle, detail::vmm::HANDLE_TYPE, 0));
        detail::vmm::vm_free(memory_handle);
    } else {
        throw std::runtime_error("Invalid IPC handle type");
    }
}

template<ducks::ipc::handle::all IPC_HANDLE>
__host__ inline static void export_handle(
    IPC_HANDLE *ipc_handle,
    CUmemGenericAllocationHandle &memory_handle
) {
    if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) {
        CUCHECK(cuMemExportToShareableHandle(&ipc_handle->handle_, memory_handle, detail::vmm::HANDLE_TYPE, 0));
    } else {
        throw std::runtime_error("Invalid IPC handle type");
    }
}

template<ducks::ipc::handle::all IPC_HANDLE>
__host__ inline static void import_handle(
    void **ptr,
    IPC_HANDLE &ipc_handle,
    const size_t size,
    int local_world_size
) {
    if constexpr (IPC_HANDLE::flavor_ == flavor::LEGACY) {
        CUDACHECK(cudaIpcOpenMemHandle(ptr, ipc_handle.handle_, cudaIpcMemLazyEnablePeerAccess)); // this is the only flag supported
    } else if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) {
        CUmemGenericAllocationHandle memory_handle;
        CUCHECK(cuMemImportFromShareableHandle(&memory_handle, reinterpret_cast<void *>(static_cast<uintptr_t>(ipc_handle.handle_)), detail::vmm::HANDLE_TYPE));
        detail::vmm::vm_map(ptr, memory_handle, size);
        detail::vmm::vm_set_access(*ptr, size, local_world_size);
        detail::vmm::vm_free(memory_handle);
        close(ipc_handle.handle_); // close fd immediately
        ipc_handle.handle_ = -1;
    } else {
        throw std::runtime_error("Invalid IPC handle type");
    }
}

template<ducks::ipc::handle::all IPC_HANDLE>
__host__ inline static void import_handle(
    CUmemGenericAllocationHandle *memory_handle,
    IPC_HANDLE &ipc_handle,
    const size_t size,
    int local_world_size
) {
    if constexpr (IPC_HANDLE::flavor_ == flavor::VMM) {
        CUCHECK(cuMemImportFromShareableHandle(memory_handle, reinterpret_cast<void *>(static_cast<uintptr_t>(ipc_handle.handle_)), detail::vmm::HANDLE_TYPE));
        close(ipc_handle.handle_); // close fd immediately
        ipc_handle.handle_ = -1;
    } else {
        throw std::runtime_error("Invalid IPC handle type");
    }
}

template<flavor _flavor>
__host__ inline static void free_handle(
    void *ptr,
    const size_t size
) {
    if constexpr (_flavor == flavor::LEGACY) {
        CUDACHECK(cudaIpcCloseMemHandle(ptr));
    } else if constexpr (_flavor == flavor::VMM) {
        detail::vmm::vm_unmap(ptr, size);
    } else {
        throw std::runtime_error("Invalid IPC handle type");
    }
}

__host__ inline static void enable_all_peer_access(int num_devices) {
    int num_available_devices;
    CUCHECK(cuDeviceGetCount(&num_available_devices));
    if (num_available_devices < num_devices)
        throw std::runtime_error("Not enough GPUs available");

    std::vector<CUdevice> devices(num_devices);
    std::vector<CUcontext> contexts(num_devices);

    for (int i = 0; i < num_devices; i++) {
        CUCHECK(cuDeviceGet(&devices[i], i));
        CUCHECK(cuCtxCreate(&contexts[i], 0, devices[i]));
    }

    for (int i = 0; i < num_devices; i++) {
        int device_compute_mode;
        CUCHECK(cuDeviceGetAttribute(&device_compute_mode, CU_DEVICE_ATTRIBUTE_COMPUTE_MODE, devices[i]));
        if (device_compute_mode != CU_COMPUTEMODE_DEFAULT)
            throw std::runtime_error("Device is in an unsupported compute mode");

        int vmm_supported = 0;
        CUCHECK(cuDeviceGetAttribute(&vmm_supported, CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED, devices[i]));
        if (!vmm_supported)
            throw std::runtime_error("Device does not support CUDA VMM");

        int ipc_handle_supported;
        CUCHECK(cuDeviceGetAttribute(&ipc_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR_SUPPORTED, devices[i]));
        if (!ipc_handle_supported)
            throw std::runtime_error("Device does not support IPC handles");

        for (int j = 0; j < num_devices; j++) {
            if (i == j) continue;
            int can_access_peer;
            CUCHECK(cuDeviceCanAccessPeer(&can_access_peer, devices[i], devices[j]));
            if (!can_access_peer)
                throw std::runtime_error("Device cannot access peer device");
            CUCHECK(cuCtxSetCurrent(contexts[i]));
            CUCHECK(cuCtxEnablePeerAccess(contexts[j], 0));
        }
    }

    for (size_t i = 0; i < contexts.size(); ++i)
        CUCHECK(cuCtxDestroy(contexts[i]));
}

} // namespace ipc
} // namespace detail
} // namespace kittens
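
The rough shape of sharing one VMM allocation between two processes on the same node, as a sketch only: the exporter/importer split below is hypothetical, and how the FD travels between processes (it must be passed via SCM_RIGHTS over a Unix socket or inherited, not sent as a plain integer) is outside this header:

using vmm_ipc_handle = kittens::detail::ipc::handle<kittens::detail::ipc::flavor::VMM>;

// Exporting process: allocate, then turn the allocation into a shareable FD.
vmm_ipc_handle export_allocation(void **ptr, size_t *allocated, size_t bytes,
                                 int device_id, int num_devices) {
    kittens::detail::vmm::vm_alloc_map_set_access(ptr, allocated, bytes, device_id, num_devices);
    vmm_ipc_handle h;
    kittens::detail::ipc::export_handle(&h, *ptr); // h.handle_ now holds the FD
    return h;
}

// Importing process: map the peer allocation; import_handle also closes the FD.
void *import_allocation(vmm_ipc_handle &h, size_t allocated, int local_world_size) {
    void *peer_ptr = nullptr;
    kittens::detail::ipc::import_handle(&peer_ptr, h, allocated, local_world_size);
    return peer_ptr; // later: free_handle<flavor::VMM>(peer_ptr, allocated)
}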
173
extra/thunder/cuda/include/types/device/pgl.cuh
Normal file
@@ -0,0 +1,173 @@
/**
 * @file
 * @brief Templated layouts for parallel global memory.
 */

#pragma once

#include "../../common/common.cuh"
#include "../shared/shared.cuh"
#include "../global/global.cuh"

namespace kittens {

/* ---------- Parallel global layout descriptor ---------- */

namespace ducks {
namespace pgl {

struct identifier {};

/**
 * @brief Concept for all parallel global layouts.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T has a nested type identifier that is the same as ducks::pgl::identifier.
 */
template<typename T> concept all = requires {
    typename T::identifier;
} && std::is_same_v<typename T::identifier, identifier>;

} // namespace pgl
} // namespace ducks

/**
 * @brief Parallel global layout. Represents a region of data spread across multiple devices.
 * @tparam GL The underlying global layout on each device.
 * @tparam NUM_DEVICES The number of GPU devices.
 * @tparam MULTICAST Whether the multicast object should be initialized by the caller.
 * @tparam TMA_Types The types of TMA descriptors to use for the multicast locations.
 *         Only valid if MULTICAST is true.
 */
template<kittens::ducks::gl::all _GL, int NUM_DEVICES = 8, bool MULTICAST = true, typename... TMA_Types>
struct pgl {
    using identifier = ducks::pgl::identifier;
    using GL = _GL;
    using T = GL::dtype;
    using dtype = T;

    static constexpr int num_devices = NUM_DEVICES;
    static constexpr bool multicast = MULTICAST;

    T *mc_ptr; // multicast pointer; nullptr if MULTICAST is false
    GL gls[NUM_DEVICES];

    detail::descriptor_dict<TMA_Types...> tma_descs;

    __host__ __device__ const GL &operator[](int idx) const { return gls[idx]; }
    __device__ inline T* mc_ptr_at(const coord<ducks::default_type> &idx) const {
        static_assert(MULTICAST, "Multicast is not enabled for this PGL.");
        const GL &gl = gls[0]; // all gls have the same shape
        return &mc_ptr[((idx.b * gl.depth() + idx.d) * gl.rows() + idx.r) * gl.cols() + idx.c];
    }

    __host__ inline pgl(T **_data, // an array of NUM_DEVICES pointers to the data on each device
                        ducks::gl::make_arg_t<GL::__b__> _batch,
                        ducks::gl::make_arg_t<GL::__d__> _depth,
                        ducks::gl::make_arg_t<GL::__r__> _rows,
                        ducks::gl::make_arg_t<GL::__c__> _cols) :
        pgl(std::make_index_sequence<NUM_DEVICES>{}, _data, _batch, _depth, _rows, _cols) { }

    __host__ inline pgl(T *_mc_ptr, // multicast pointer, initialized by the caller
                        T **_data, // an array of NUM_DEVICES pointers to the data on each device
                        ducks::gl::make_arg_t<GL::__b__> _batch,
                        ducks::gl::make_arg_t<GL::__d__> _depth,
                        ducks::gl::make_arg_t<GL::__r__> _rows,
                        ducks::gl::make_arg_t<GL::__c__> _cols) :
        pgl(std::make_index_sequence<NUM_DEVICES>{}, _mc_ptr, _data, _batch, _depth, _rows, _cols) { }

    template<size_t... I>
    __host__ inline pgl(std::index_sequence<I...>,
                        T **_data,
                        ducks::gl::make_arg_t<GL::__b__> _batch,
                        ducks::gl::make_arg_t<GL::__d__> _depth,
                        ducks::gl::make_arg_t<GL::__r__> _rows,
                        ducks::gl::make_arg_t<GL::__c__> _cols) :
        mc_ptr(nullptr), gls{GL(_data[I], _batch, _depth, _rows, _cols)...} {
        static_assert(!MULTICAST, "Multicast pointer not passed to multicast-enabled PGL.");
    }

    template<size_t... I>
    __host__ inline pgl(std::index_sequence<I...>,
                        T *_mc_ptr,
                        T **_data,
                        ducks::gl::make_arg_t<GL::__b__> _batch,
                        ducks::gl::make_arg_t<GL::__d__> _depth,
                        ducks::gl::make_arg_t<GL::__r__> _rows,
                        ducks::gl::make_arg_t<GL::__c__> _cols) :
        mc_ptr(_mc_ptr), gls{GL(_data[I], _batch, _depth, _rows, _cols)...} {
        static_assert(MULTICAST, "Multicast pointer passed to multicast-disabled PGL.");
        tma_descs = detail::descriptor_dict<TMA_Types...>(
            mc_ptr, gls[0].batch_internal, gls[0].depth_internal, gls[0].rows_internal, gls[0].cols_internal);
    }

    template<typename U, int axis>
    __device__ inline const CUtensorMap* get_tma() const {
        return tma_descs.template get<U, axis>();
    }

    __host__ __device__ inline auto batch() const { return gls[0].batch(); }
    __host__ __device__ inline auto depth() const { return gls[0].depth(); }
    __host__ __device__ inline auto rows() const { return gls[0].rows(); }
    __host__ __device__ inline auto cols() const { return gls[0].cols(); }
    __host__ __device__ inline size_t numel() const { return static_cast<size_t>(batch()) * depth() * rows() * cols(); }

    template<int axis> __device__ inline size_t shape() const { return gls[0].template shape<axis>(); }
    template<int axis> __device__ inline size_t stride() const { return gls[0].template stride<axis>(); }
};

template<ducks::pgl::all PGL, bool safe=true> __host__ inline PGL make_pgl(
    uint64_t *data, int b, int d, int r, int c
) {
    if constexpr (safe) {
        if (PGL::GL::__b__ > 0 && b != PGL::GL::__b__) {
            throw std::runtime_error("Batch dimension mismatch. Expected: " + std::to_string(PGL::GL::__b__) + ", Got: " + std::to_string(b));
        }
        if (PGL::GL::__d__ > 0 && d != PGL::GL::__d__) {
            throw std::runtime_error("Depth dimension mismatch. Expected: " + std::to_string(PGL::GL::__d__) + ", Got: " + std::to_string(d));
        }
        if (PGL::GL::__r__ > 0 && r != PGL::GL::__r__) {
            throw std::runtime_error("Row dimension mismatch. Expected: " + std::to_string(PGL::GL::__r__) + ", Got: " + std::to_string(r));
        }
        if (PGL::GL::__c__ > 0 && c != PGL::GL::__c__) {
            throw std::runtime_error("Column dimension mismatch. Expected: " + std::to_string(PGL::GL::__c__) + ", Got: " + std::to_string(c));
        }
    }
    return PGL(
        reinterpret_cast<typename PGL::dtype**>(data),
        make_unsafe_gl_arg<PGL::GL::__b__>(b),
        make_unsafe_gl_arg<PGL::GL::__d__>(d),
        make_unsafe_gl_arg<PGL::GL::__r__>(r),
        make_unsafe_gl_arg<PGL::GL::__c__>(c)
    );
}

template<ducks::pgl::all PGL, bool safe=true> __host__ inline PGL make_pgl(
    uint64_t mc_ptr, uint64_t *data, int b, int d, int r, int c
) {
    if constexpr (safe) {
        if (PGL::GL::__b__ > 0 && b != PGL::GL::__b__) {
            throw std::runtime_error("Batch dimension mismatch. Expected: " + std::to_string(PGL::GL::__b__) + ", Got: " + std::to_string(b));
        }
        if (PGL::GL::__d__ > 0 && d != PGL::GL::__d__) {
            throw std::runtime_error("Depth dimension mismatch. Expected: " + std::to_string(PGL::GL::__d__) + ", Got: " + std::to_string(d));
        }
        if (PGL::GL::__r__ > 0 && r != PGL::GL::__r__) {
            throw std::runtime_error("Row dimension mismatch. Expected: " + std::to_string(PGL::GL::__r__) + ", Got: " + std::to_string(r));
        }
        if (PGL::GL::__c__ > 0 && c != PGL::GL::__c__) {
            throw std::runtime_error("Column dimension mismatch. Expected: " + std::to_string(PGL::GL::__c__) + ", Got: " + std::to_string(c));
        }
    }
    return PGL(
        reinterpret_cast<typename PGL::dtype*>(mc_ptr),
        reinterpret_cast<typename PGL::dtype**>(data),
        make_unsafe_gl_arg<PGL::GL::__b__>(b),
        make_unsafe_gl_arg<PGL::GL::__d__>(d),
        make_unsafe_gl_arg<PGL::GL::__r__>(r),
        make_unsafe_gl_arg<PGL::GL::__c__>(c)
    );
}

} // namespace kittens
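
An illustrative instantiation of a pgl, assuming two devices, multicast enabled, and bf16 data with runtime rows/cols; the pointer arguments are placeholders that would come from the vmm/multicast setup elsewhere in this directory:

using base_gl = kittens::gl<kittens::bf16, 1, 1, -1, -1>;
using demo_pgl = kittens::pgl<base_gl, 2, true>;

demo_pgl wrap_buffers(uint64_t mc_ptr, uint64_t device_ptrs[2], int rows, int cols) {
    // make_pgl validates any compile-time-fixed dimensions before constructing.
    return kittens::make_pgl<demo_pgl>(mc_ptr, device_ptrs, 1, 1, rows, cols);
}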
180
extra/thunder/cuda/include/types/device/vmm.cuh
Normal file
@@ -0,0 +1,180 @@
#pragma once

#include <cuda.h>
#include <vector>
#include <stdexcept>

#include "../../common/common.cuh"

namespace kittens {
namespace detail {
namespace vmm {

// Intra-node shareable handle type
// This makes the handle shareable with cuMemExportToShareableHandle/cuMemImportFromShareableHandle
static constexpr CUmemAllocationHandleType HANDLE_TYPE = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR;

typedef CUmemGenericAllocationHandle handle;

__host__ inline static void vm_alloc(
    CUmemGenericAllocationHandle *handle,
    size_t *allocated_size,
    const size_t size,
    const int device_id
) {
    CUmemAllocationProp prop = {};
    prop.location.id = device_id;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.requestedHandleTypes = HANDLE_TYPE;
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;

    size_t granularity;
    CUCHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
    *allocated_size = (size + granularity - 1) / granularity * granularity; // round up

    CUCHECK(cuMemCreate(handle, *allocated_size, &prop, 0));
}

__host__ inline static void vm_map(
    void **ptr,
    const CUmemGenericAllocationHandle &handle,
    const size_t size
) {
    CUdeviceptr device_ptr;
    CUCHECK(cuMemAddressReserve(&device_ptr, size, 0, 0, 0));
    CUCHECK(cuMemMap(device_ptr, size, 0, handle, 0));
    *ptr = (void *)device_ptr;
}

__host__ inline static void vm_set_access(
    void *ptr,
    const size_t size,
    const int num_devices
) {
    std::vector<CUmemAccessDesc> descs(num_devices);
    for (int i = 0; i < num_devices; i++) {
        descs[i].location.id = i;
        descs[i].location.type = CU_MEM_LOCATION_TYPE_DEVICE;
        descs[i].flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    }
    CUCHECK(cuMemSetAccess(reinterpret_cast<CUdeviceptr>(ptr), size, descs.data(), num_devices));
}

__host__ inline static void vm_retrieve_handle(
    CUmemGenericAllocationHandle *handle,
    void *ptr
) {
    // Every call to this requires a corresponding call to cuMemRelease
    CUCHECK(cuMemRetainAllocationHandle(handle, ptr));
}

__host__ inline static void vm_unmap(
    void *ptr,
    const size_t size
) {
    CUCHECK(cuMemUnmap(reinterpret_cast<CUdeviceptr>(ptr), size));
    CUCHECK(cuMemAddressFree(reinterpret_cast<CUdeviceptr>(ptr), size));
}

__host__ inline static void vm_free(CUmemGenericAllocationHandle &handle) {
    // It is recommended to free the handle ASAP; the backing memory will
    // only be freed when all handles AND address mappings are released
    CUCHECK(cuMemRelease(handle));
}

__host__ inline static void vm_alloc_map_set_access(
    void **ptr,
    size_t *allocated_size,
    const size_t size,
    const int device_id,
    const int num_devices
) {
    CUmemGenericAllocationHandle handle;
    vm_alloc(&handle, allocated_size, size, device_id);
    vm_map(ptr, handle, *allocated_size);
    vm_set_access(*ptr, *allocated_size, num_devices);
    vm_free(handle); // release the handle ASAP
}

__host__ inline static void multicast_check(const int device_id) {
    CUdevice device;
    CUCHECK(cuDeviceGet(&device, device_id));

    int multicast_supported;
    CUCHECK(cuDeviceGetAttribute(
        &multicast_supported,
        CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED,
        device
    ));

    if (!multicast_supported)
        throw std::runtime_error("Device does not support multicast");
}

__host__ inline static void multicast_create_handle(
    CUmemGenericAllocationHandle *handle,
    size_t *allocated_size,
    const size_t size,
    const int num_devices
) {
    if (num_devices <= 1)
        throw std::runtime_error("Multicast requires at least 2 devices");

    CUmulticastObjectProp prop = {};
    prop.numDevices = num_devices;
    prop.handleTypes = HANDLE_TYPE;

    size_t granularity;
    CUCHECK(cuMulticastGetGranularity(&granularity, &prop, CU_MULTICAST_GRANULARITY_RECOMMENDED));
    *allocated_size = (size + granularity - 1) / granularity * granularity;
    prop.size = *allocated_size;

    // After this, the handle must be shared with all processes through MPI, KittensBroker, etc.
    CUCHECK(cuMulticastCreate(handle, &prop));
}

__host__ inline static void multicast_bind_device(
    const CUmemGenericAllocationHandle &handle,
    const int device_id
) {
    // All processes must sync after this, before binding any memory
    CUdevice device;
    CUCHECK(cuDeviceGet(&device, device_id));
    CUCHECK(cuMulticastAddDevice(handle, device));
}

__host__ inline static void multicast_bind_memory(
    const CUmemGenericAllocationHandle &multicast_handle,
    const CUmemGenericAllocationHandle &memory_handle,
    const size_t size
) {
    // All processes should finish adding device before calling this function
    CUCHECK(cuMulticastBindMem(multicast_handle, 0, memory_handle, 0, size, 0));
}

__host__ inline static void multicast_bind_address(
    const CUmemGenericAllocationHandle &multicast_handle,
    void *ptr,
    const size_t size
) {
    // All processes should finish adding device before calling this function
    CUmemGenericAllocationHandle memory_handle;
    vm_retrieve_handle(&memory_handle, ptr);
    multicast_bind_memory(multicast_handle, memory_handle, size);
    vm_free(memory_handle);
}

__host__ inline static void multicast_unbind_device(
    const CUmemGenericAllocationHandle &handle,
    const size_t size,
    const int device_id
) {
    // Unbinding memory is not needed
    CUdevice device;
    CUCHECK(cuDeviceGet(&device, device_id));
    CUCHECK(cuMulticastUnbind(handle, device, 0, size));
}

} // namespace vmm
} // namespace detail
} // namespace kittens
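
A sketch of the allocation lifecycle these helpers compose; the sizes and ids are placeholders. Note that vm_alloc rounds the request up to the allocation granularity, so callers must track the returned size rather than the requested one:

void vmm_demo(size_t bytes, int device_id, int num_devices) { // illustrative
    void *ptr = nullptr;
    size_t allocated = 0;
    kittens::detail::vmm::vm_alloc_map_set_access(&ptr, &allocated, bytes, device_id, num_devices);
    // ... the mapping is now read/write-accessible from all num_devices GPUs ...
    kittens::detail::vmm::vm_unmap(ptr, allocated); // releases the map and the VA reservation
}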
56
extra/thunder/cuda/include/types/global/cgl.cuh
Normal file
@@ -0,0 +1,56 @@
/**
 * @file
 * @brief Templated layouts for complex global memory.
 */

#pragma once

#include "../../common/common.cuh"
#include "../shared/cst.cuh"
#include "gl.cuh"
#include "util.cuh"
#ifdef KITTENS_HOPPER
#include "tma.cuh"
#endif

namespace kittens {

/* ---------- Global layout descriptor ---------- */

namespace ducks {
namespace cgl {
struct identifier {};
}
}

// namespace detail {
// template<typename T> concept tile = ducks::cst::all<T> || ducks::crt::all<T>;
// template<typename T> concept vec = ducks::csv::all<T> || ducks::crv::all<T>;
// }

template<kittens::ducks::gl::all _GL>
struct cgl {
    using identifier = ducks::cgl::identifier;
    using component = _GL;
    using T = component::T;
    using T2 = component::T2;
    using dtype = component::dtype;
    component real, imag;
};

namespace ducks {
namespace cgl {
/**
 * @brief Concept for all complex global layouts.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T has a nested type identifier that is the same as ducks::cgl::identifier.
 */
template<typename T> concept all = requires {
    typename T::identifier; // Checks if T::identifier exists
} && std::is_same_v<typename T::identifier, identifier>; // Checks if T::identifier is ducks::cgl::identifier
}
}

}
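
A minimal illustration of the complex layout: it simply pairs two identically-shaped global layouts for the real and imaginary planes. The aliases below are hypothetical:

using plane_gl = kittens::gl<float, 1, 1, -1, -1>;
using complex_gl = kittens::cgl<plane_gl>;

complex_gl make_complex(const plane_gl &re, const plane_gl &im) {
    return complex_gl{re, im};
}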
225
extra/thunder/cuda/include/types/global/gl.cuh
Normal file
@@ -0,0 +1,225 @@
/**
 * @file
 * @brief Templated layouts for global memory.
 */

#pragma once

#include "../../common/common.cuh"
#include "../shared/shared.cuh"
#include "util.cuh"
#ifdef KITTENS_HOPPER
#include <utility>
#include "tma.cuh"
#endif

namespace kittens {

/* ---------- Global layout axes ---------- */

struct dim {
    static constexpr int BATCH = 0;
    static constexpr int DEPTH = 1;
    static constexpr int ROW = 2;
    static constexpr int COL = 3;
};

/* ---------- Associative dictionary for global layouts ---------- */

#ifdef KITTENS_HOPPER
namespace ducks {
namespace tma {
namespace descriptor {
struct identifier {};
template<typename T> concept all = requires {
    typename T::identifier;
} && std::is_same_v<typename T::identifier, identifier>;
} // namespace descriptor
} // namespace tma
} // namespace ducks
namespace detail {
namespace tma {
template<typename T> struct descriptor_copy_helper {};
template<kittens::ducks::tma::descriptor::all _T> struct descriptor_copy_helper<_T> { static constexpr int value = _T::axis; using T = _T::T; static constexpr bool swizzle_flag = _T::swizzle_flag; };
template<kittens::ducks::st::all _T> struct descriptor_copy_helper<_T> { static constexpr int value = 2; using T = _T; static constexpr bool swizzle_flag = true; };
template<kittens::ducks::sv::all _T> struct descriptor_copy_helper<_T> { static constexpr int value = -1; using T = _T; static constexpr bool swizzle_flag = true; };
template<typename T> using descriptor_copy_helper_t = descriptor_copy_helper<T>::T;
template<typename T> static constexpr int descriptor_copy_helper_v = descriptor_copy_helper<T>::value;
template<typename T> static constexpr bool descriptor_copy_helper_swizzle_flag = descriptor_copy_helper<T>::swizzle_flag;
} // namespace tma
} // namespace detail
namespace tma {
template<typename _T, int _axis=-9999, bool _swizzle_flag=true> struct descriptor {
    using identifier = ducks::tma::descriptor::identifier;
    using T = detail::tma::descriptor_copy_helper_t<_T>;
    static_assert(ducks::st::all<T> || ducks::sv::all<T> || ducks::tma::descriptor::all<T>, "Must be a shared TK type to generate a TMA descriptor.");
    static constexpr int axis = (
        ducks::tma::descriptor::all<_T> ? detail::tma::descriptor_copy_helper_v<_T> : // if a copy, inherit the axis from the original descriptor.
        (_axis != -9999) ? _axis : detail::tma::descriptor_copy_helper_v<_T>); // if a default value was provided, use it.
    static_assert((kittens::ducks::st::all<T> && axis >= 0 && axis <= 2) || (kittens::ducks::sv::all<T> && axis == -1), "Internal template error detected.");
    static constexpr bool swizzle_flag = ducks::tma::descriptor::all<_T> ? detail::tma::descriptor_copy_helper_swizzle_flag<_T> : _swizzle_flag;
};
} // namespace tma
#endif

namespace detail {
template<typename... Args>
struct descriptor_dict {
    __host__ descriptor_dict() {}
    template<typename T> __host__ descriptor_dict(T _, int b, int d, int r, int c) {}
    __host__ __device__ descriptor_dict(const descriptor_dict &other) {}
#ifdef KITTENS_HOPPER
    template<typename T, int U> __device__ const CUtensorMap* get() const {
        static_assert(
            std::is_same_v<T, std::true_type> && std::is_same_v<T, std::false_type>,
            "SKILL ISSUE: Requested a TMA descriptor for a type not initialized in the global layout."
        );
    }
#endif
};

#ifdef KITTENS_HOPPER
template<typename _T, typename... Args>
struct descriptor_dict<_T, Args...> {
    static_assert(ducks::sv::all<_T> || ducks::st::all<_T> || ducks::tma::descriptor::all<_T>, "Must be a shared TK type to generate a TMA descriptor.");
    using DESC = kittens::tma::descriptor<_T>; // copy or initialize with a default value
    CUtensorMap tma_desc;
    descriptor_dict<Args...> other_descs;
    __host__ descriptor_dict() {}
    __host__ descriptor_dict(typename DESC::T::dtype *data, int b, int d, int r, int c): other_descs(data, b, d, r, c) {
        kittens::detail::tma::create_tensor_map<typename DESC::T, DESC::axis, DESC::swizzle_flag>(&tma_desc, data, b, d, r, c);
    }
    __host__ __device__ inline descriptor_dict(const descriptor_dict &other) :
        tma_desc(other.tma_desc), other_descs(other.other_descs) {}
    template<typename U, int axis> __device__ inline const CUtensorMap* get() const {
        if constexpr (std::is_same_v<typename DESC::T, U> && DESC::axis == axis) { return &tma_desc; }
        else { return other_descs.template get<U, axis>(); }
    }
};
#endif
}

/* ---------- Global layout descriptor ---------- */

namespace ducks {
namespace gl {
struct identifier {};
}
}

template<typename _T, int b, int d, int r, int c, typename... TMA_Types>
struct gl {
    using identifier = ducks::gl::identifier;

    using T = base_types::packing<_T>::unpacked_type;
    using T2 = base_types::packing<_T>::packed_type;
    using dtype = T;

    T* raw_ptr;

    static constexpr int __b__ = b, __d__ = d, __r__ = r, __c__ = c; // Not to be touched by the user.

    ducks::gl::make_dim_t<b> batch_internal;
    ducks::gl::make_dim_t<d> depth_internal;
    ducks::gl::make_dim_t<r> rows_internal;
    ducks::gl::make_dim_t<c> cols_internal;

    template <int B=__b__> __device__ __host__ static constexpr std::enable_if_t<(B > 0), int> batch() { return B; }
    template <int B=__b__> __device__ __host__ std::enable_if_t<(B == -1), int> batch() const { return batch_internal; }
    template <int D=__d__> __device__ __host__ static constexpr std::enable_if_t<(D > 0), int> depth() { return D; }
    template <int D=__d__> __device__ __host__ std::enable_if_t<(D == -1), int> depth() const { return depth_internal; }
    template <int R=__r__> __device__ __host__ static constexpr std::enable_if_t<(R > 0), int> rows() { return R; }
    template <int R=__r__> __device__ __host__ std::enable_if_t<(R == -1), int> rows() const { return rows_internal; }
    template <int C=__c__> __device__ __host__ static constexpr std::enable_if_t<(C > 0), int> cols() { return C; }
    template <int C=__c__> __device__ __host__ std::enable_if_t<(C == -1), int> cols() const { return cols_internal; }

    detail::descriptor_dict<TMA_Types...> tma_descs;

    __host__ inline gl(T *_data,
                       ducks::gl::make_arg_t<b> _batch,
                       ducks::gl::make_arg_t<d> _depth,
                       ducks::gl::make_arg_t<r> _rows,
                       ducks::gl::make_arg_t<c> _cols) :
        raw_ptr(_data), batch_internal(_batch), depth_internal(_depth), rows_internal(_rows), cols_internal(_cols) {
        tma_descs = detail::descriptor_dict<TMA_Types...>(raw_ptr, batch_internal, depth_internal, rows_internal, cols_internal);
    }
    __host__ __device__ inline gl(const gl &other) :
        raw_ptr(other.raw_ptr), batch_internal(other.batch_internal), depth_internal(other.depth_internal), rows_internal(other.rows_internal), cols_internal(other.cols_internal), tma_descs(other.tma_descs) {}
#ifdef KITTENS_HOPPER
    template<typename U, int axis> __device__ inline const CUtensorMap* get_tma() const {
        return tma_descs.template get<U, axis>();
    }
#endif
    __device__ inline T& operator[](const coord<ducks::default_type> &idx) const { // yes I am abusing the const qualifier here a bit.
        return raw_ptr[((idx.b*depth() + idx.d)*rows() + idx.r)*cols() + idx.c];
    }
    template<int axis> __device__ inline size_t shape() const {
        static_assert(axis==0 || axis==1 || axis==2 || axis==3, "Axis must be 0, 1, 2, or 3.");
        if constexpr (axis==0) { return size_t(batch()); }
        else if constexpr (axis==1) { return size_t(depth()); }
        else if constexpr (axis==2) { return size_t(rows()); }
        else if constexpr (axis==3) { return size_t(cols()); }
    }
    template<int axis> __device__ inline size_t stride() const {
        static_assert(axis==0 || axis==1 || axis==2 || axis==3, "Axis must be 0, 1, 2, or 3.");
        if constexpr (axis==0) { return depth()*rows()*cols(); }
        else if constexpr (axis==1) { return rows()*cols(); }
        else if constexpr (axis==2) { return cols(); }
        else if constexpr (axis==3) { return 1; }
    }
};

template<typename _T, int d, int r, int c, typename... TMA_Types> using gl3 = gl<_T, 1, d, r, c, TMA_Types...>;
template<typename _T, int r, int c, typename... TMA_Types> using gl2 = gl<_T, 1, 1, r, c, TMA_Types...>;
template<typename _T, int c, typename... TMA_Types> using gl1 = gl<_T, 1, 1, 1, c, TMA_Types...>;

namespace ducks {
namespace gl {
/**
 * @brief Concept for all global layouts.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T has a nested type identifier that is the same as ducks::gl::identifier.
 */
template<typename T> concept all = requires {
    typename T::identifier; // Checks if T::identifier exists
} && std::is_same_v<typename T::identifier, identifier>; // Checks if T::identifier is ducks::gl::identifier
}
}

// Structs for initializing global layouts automatically.
// struct unsafe_gl {
//     uint64_t data;
//     int b, d, r, c;
//     unsafe_gl(uint64_t data, int b, int d, int r, int c) : data(data), b(b), d(d), r(r), c(c) {}
// };
template<int N> auto make_unsafe_gl_arg(int param) { // typename std::conditional_t<(N < 0), std::nullptr_t, int>
    if constexpr (N > 0) { return nullptr; }
    else { return param; }
}
template<ducks::gl::all GL, bool safe=true> __host__ inline GL make_gl(uint64_t data, int b, int d, int r, int c) {
    if constexpr (safe) {
        if(GL::__b__ > 0 && b != GL::__b__) {
            throw std::runtime_error("Batch dimension mismatch. Expected: " + std::to_string(GL::__b__) + ", Got: " + std::to_string(b));
        }
        if(GL::__d__ > 0 && d != GL::__d__) {
            throw std::runtime_error("Depth dimension mismatch. Expected: " + std::to_string(GL::__d__) + ", Got: " + std::to_string(d));
        }
        if(GL::__r__ > 0 && r != GL::__r__) {
            throw std::runtime_error("Row dimension mismatch. Expected: " + std::to_string(GL::__r__) + ", Got: " + std::to_string(r));
        }
        if(GL::__c__ > 0 && c != GL::__c__) {
            throw std::runtime_error("Column dimension mismatch. Expected: " + std::to_string(GL::__c__) + ", Got: " + std::to_string(c));
        }
    }
    return GL(
        reinterpret_cast<typename GL::dtype*>(data),
        make_unsafe_gl_arg<GL::__b__>(b),
        make_unsafe_gl_arg<GL::__d__>(d),
        make_unsafe_gl_arg<GL::__r__>(r),
        make_unsafe_gl_arg<GL::__c__>(c)
    );
}

} // namespace kittens
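
An illustrative construction of a global layout with two static and two dynamic dimensions; the device pointer is a placeholder for a real allocation:

using demo_gl = kittens::gl<kittens::bf16, 1, 1, -1, 64>; // rows dynamic, cols fixed at 64
demo_gl wrap_pointer(uint64_t device_ptr, int rows) {
    // The safe path of make_gl checks that the static dimensions match (here, cols == 64).
    return kittens::make_gl<demo_gl>(device_ptr, 1, 1, rows, 64);
}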
13
extra/thunder/cuda/include/types/global/global.cuh
Normal file
@@ -0,0 +1,13 @@
/**
 * @file
 * @brief An aggregate header file for all the global types defined by ThunderKittens.
 */

#pragma once

#ifdef KITTENS_HOPPER
#include "tma.cuh"
#endif
#include "util.cuh"
#include "gl.cuh"
#include "cgl.cuh"
428
extra/thunder/cuda/include/types/global/tma.cuh
Normal file
@@ -0,0 +1,428 @@
#pragma once
|
||||
|
||||
#include <cuda.h>
|
||||
#include <iostream>
|
||||
#include <assert.h>
|
||||
#include <functional> // for std::hash
|
||||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
#include "../../common/common.cuh"
|
||||
#include "../shared/shared.cuh"
|
||||
|
||||
namespace kittens {
|
||||
namespace detail {
|
||||
namespace tma {
|
||||
|
||||
__host__ static inline std::string format_tma_error(
|
||||
const char* error_type,
|
||||
const char* error_string,
|
||||
int batch, int depth, int rows, int cols,
|
||||
CUtensorMap* tma_map,
|
||||
CUtensorMapDataType tma_format,
|
||||
uint32_t tma_dim,
|
||||
void* global_addr,
|
||||
const uint64_t* gmem_shape,
|
||||
const uint64_t* gmem_stride,
|
||||
const uint32_t* smem_shape,
|
||||
const uint32_t* smem_stride,
|
||||
size_t gmem_shape_size,
|
||||
size_t gmem_stride_size,
|
||||
size_t smem_shape_size,
|
||||
size_t smem_stride_size,
|
||||
CUtensorMapInterleave tma_interleave,
|
||||
CUtensorMapSwizzle tma_swizzle,
|
||||
CUtensorMapL2promotion tma_l2Promotion,
|
||||
CUtensorMapFloatOOBfill tma_oobFill,
|
||||
const std::string& extra_info = ""
|
||||
) {
|
||||
std::ostringstream oss;
|
||||
oss << "Error in " << error_type << " TMA descriptor creation: ";
|
||||
oss << (error_string ? error_string : "Unknown CUDA error");
|
||||
oss << "\nParameters:";
|
||||
oss << "\n batch: " << batch;
|
||||
oss << "\n depth: " << depth;
|
||||
oss << "\n rows: " << rows;
|
||||
oss << "\n cols: " << cols;
|
||||
if (!extra_info.empty())
|
||||
oss << "\n " << extra_info;
|
||||
|
||||
oss << "\ncuTensorMapEncodeTiled arguments:";
|
||||
oss << "\n tma_map: " << reinterpret_cast<uintptr_t>(tma_map);
|
||||
oss << "\n tma_format: " << tma_format;
|
||||
oss << "\n tma_dim: " << tma_dim;
|
||||
oss << "\n global_addr: " << reinterpret_cast<uintptr_t>(global_addr);
|
||||
|
||||
// Check if global_addr is valid device memory
|
||||
cudaPointerAttributes attributes;
|
||||
cudaError_t err = cudaPointerGetAttributes(&attributes, global_addr);
|
||||
if (err == cudaSuccess) {
|
||||
oss << "\n global_addr memory type: ";
|
||||
if (attributes.type == cudaMemoryTypeDevice) {
|
||||
oss << "valid device memory";
|
||||
} else if (attributes.type == cudaMemoryTypeHost) {
|
||||
oss << "host memory (invalid for TMA)";
|
||||
} else if (attributes.type == cudaMemoryTypeManaged) {
|
||||
oss << "managed memory";
|
||||
} else {
|
||||
oss << "unknown memory type";
|
||||
}
|
||||
} else {
|
||||
oss << "\n global_addr memory type: unable to determine (error: " << cudaGetErrorString(err) << ")";
|
||||
}
|
||||
|
||||
oss << "\n gmem_shape: " << reinterpret_cast<uintptr_t>(gmem_shape) << " [";
|
||||
for (size_t i = 0; i < gmem_shape_size; ++i)
|
||||
oss << gmem_shape[i] << (i < gmem_shape_size - 1 ? ", " : "");
|
||||
oss << "]";
|
||||
|
||||
oss << "\n gmem_stride: " << reinterpret_cast<uintptr_t>(gmem_stride) << " [";
|
||||
for (size_t i = 0; i < gmem_stride_size; ++i)
|
||||
oss << gmem_stride[i] << (i < gmem_stride_size - 1 ? ", " : "");
|
||||
oss << "]";
|
||||
|
||||
oss << "\n smem_shape: " << reinterpret_cast<uintptr_t>(smem_shape) << " [";
|
||||
for (size_t i = 0; i < smem_shape_size; ++i)
|
||||
oss << smem_shape[i] << (i < smem_shape_size - 1 ? ", " : "");
|
||||
oss << "]";
|
||||
|
||||
oss << "\n smem_stride: " << reinterpret_cast<uintptr_t>(smem_stride) << " [";
|
||||
for (size_t i = 0; i < smem_stride_size; ++i)
|
||||
oss << smem_stride[i] << (i < smem_stride_size - 1 ? ", " : "");
|
||||
oss << "]";
|
||||
|
||||
oss << "\n tma_interleave: " << tma_interleave;
|
||||
oss << "\n tma_swizzle: " << tma_swizzle;
|
||||
oss << "\n tma_l2Promotion: " << tma_l2Promotion;
|
||||
oss << "\n tma_oobFill: " << tma_oobFill;
|
||||
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
/* ---------- Create tile tensor map descriptor (HOST) ---------- */
|
||||
|
||||
/**
|
||||
* @brief Creates a tensor map for the given source tensor.
|
||||
*
|
||||
* This function creates a tensor map (CUtensorMap) for the specified source shared tile type. The tensor map
|
||||
* is used to describe the shape and layout of the tensor in memory. The function sets up the tensor
|
||||
* map based on the provided source tensor pointer and the layout specified by the ST template parameter.
|
||||
*
|
||||
* @tparam ST The source tensor type, which must be TMA-compatible.
|
||||
* @tparam blocks_height The number of tiles present on the height axis in global memory.
|
||||
* @tparam blocks_width The number of tiles present on the width axis in global memory. Defaults to 1.
|
||||
* @param tma_map Pointer to the CUtensorMap object to be initialized.
|
||||
* @param src Pointer to the source tensor data in global memory.
|
||||
*/
|
||||
template<ducks::st::all ST, int axis, bool enable_swizzle = true>
__host__ static inline void create_tensor_map(CUtensorMap *tma_map, const typename ST::dtype *src, int batch, int depth, int rows, int cols) {
    using dtype = typename ST::dtype;
    static_assert(axis==0 || axis==1 || axis==2, "axis must be 0, 1, or 2");

    constexpr uint32_t tma_dim = enable_swizzle ? 5 : 4;
    void *global_addr = (void*)(src);

    constexpr CUtensorMapDataType tma_format = (
        std::is_same_v<dtype, bf16>    ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 :
        std::is_same_v<dtype, half>    ? CU_TENSOR_MAP_DATA_TYPE_FLOAT16 :
        std::is_same_v<dtype, float>   ? CU_TENSOR_MAP_DATA_TYPE_FLOAT32 :
        std::is_same_v<dtype, fp8e4m3> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
        std::is_same_v<dtype, fp8e5m2> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
#ifdef KITTENS_BLACKWELL
        std::is_same_v<dtype, fp8e8m0> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
#endif
        CUtensorMapDataType(-1)
    );
    constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
    constexpr CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE;
    constexpr CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
    constexpr CUtensorMapSwizzle tma_swizzle = enable_swizzle ? (
        ST::swizzle_bytes == 32  ? CU_TENSOR_MAP_SWIZZLE_32B :
        ST::swizzle_bytes == 64  ? CU_TENSOR_MAP_SWIZZLE_64B :
        ST::swizzle_bytes == 128 ? CU_TENSOR_MAP_SWIZZLE_128B :
        CU_TENSOR_MAP_SWIZZLE_NONE
    ) : CU_TENSOR_MAP_SWIZZLE_NONE;

    // Sized for tma_dim == 5; the tma_dim == 4 path simply leaves the trailing entries unused.
    uint64_t gmem_shape [5] = {0, 0, 0, 0, 0};
    uint64_t gmem_stride[4] = {0, 0, 0, 0};
    uint32_t smem_shape [5] = {0, 0, 0, 0, 0};
    uint32_t smem_stride[5] = {1, 1, 1, 1, 1};

    constexpr uint64_t shared_tile_height = ST::rows;
    constexpr uint64_t shared_tile_width  = ST::cols;

    constexpr int swizzle_elements = ST::swizzle_bytes / sizeof(dtype);

    if constexpr (enable_swizzle) {
        if constexpr (axis == 2) {
            gmem_shape[0] = swizzle_elements;
            gmem_shape[1] = (uint64_t)rows;
            gmem_shape[2] = (uint64_t)(cols+swizzle_elements-1) / swizzle_elements; // round up; note this can potentially break out-of-bounds access handling :/
            gmem_shape[3] = (uint64_t)depth;
            gmem_shape[4] = (uint64_t)batch;

            gmem_stride[0] = (uint64_t)cols * sizeof(dtype);
            gmem_stride[1] = ST::swizzle_bytes;
            gmem_stride[2] = (uint64_t)rows * cols * sizeof(dtype);
            gmem_stride[3] = (uint64_t)depth * rows * cols * sizeof(dtype);
        }
        else if constexpr (axis == 1) {
            gmem_shape[0] = swizzle_elements;
            gmem_shape[1] = (uint64_t)depth;
            gmem_shape[2] = (uint64_t)(cols+swizzle_elements-1) / swizzle_elements; // round up; note this can potentially break out-of-bounds access handling :/
            gmem_shape[3] = (uint64_t)rows;
            gmem_shape[4] = (uint64_t)batch;

            gmem_stride[0] = (uint64_t)rows * cols * sizeof(dtype);
            gmem_stride[1] = ST::swizzle_bytes;
            gmem_stride[2] = (uint64_t)cols * sizeof(dtype);
            gmem_stride[3] = (uint64_t)depth * rows * cols * sizeof(dtype);
        }
        else {
            gmem_shape[0] = swizzle_elements;
            gmem_shape[1] = (uint64_t)batch;
            gmem_shape[2] = (uint64_t)(cols+swizzle_elements-1) / swizzle_elements; // round up; note this can potentially break out-of-bounds access handling :/
            gmem_shape[3] = (uint64_t)rows;
            gmem_shape[4] = (uint64_t)depth;

            gmem_stride[0] = (uint64_t)depth * rows * cols * sizeof(dtype);
            gmem_stride[1] = ST::swizzle_bytes;
            gmem_stride[2] = (uint64_t)cols * sizeof(dtype);
            gmem_stride[3] = (uint64_t)rows * cols * sizeof(dtype);
        }
        smem_shape[0] = swizzle_elements;
        smem_shape[1] = shared_tile_height;
        smem_shape[2] = shared_tile_width / swizzle_elements;
        smem_shape[3] = 1;
        smem_shape[4] = 1;
    } else {
        gmem_shape[0] = (uint64_t)cols;
        gmem_shape[1] = (uint64_t)rows;
        gmem_shape[2] = (uint64_t)depth;
        gmem_shape[3] = (uint64_t)batch;

        gmem_stride[0] = (uint64_t)cols * sizeof(dtype);
        gmem_stride[1] = (uint64_t)rows * cols * sizeof(dtype);
        gmem_stride[2] = (uint64_t)depth * rows * cols * sizeof(dtype);

        smem_shape[0] = shared_tile_width;
        smem_shape[1] = shared_tile_height;
        smem_shape[2] = 1;
        smem_shape[3] = 1;
    }

    // ensure that the global address is always 16-byte aligned
    assert((reinterpret_cast<uint64_t>(global_addr) & 0b1111) == 0);

    assert(gmem_stride[0] % 16 == 0); // gmem_stride[0] must be a multiple of 16B
    assert(gmem_stride[1] % 16 == 0); // gmem_stride[1] must be a multiple of 16B
    assert(gmem_stride[2] % 16 == 0); // gmem_stride[2] must be a multiple of 16B
    assert(gmem_stride[3] % 16 == 0); // gmem_stride[3] must be a multiple of 16B

    assert(smem_shape[0] <= 256); // smem_shape[0] elements must be <= 256
    assert(smem_shape[1] <= 256); // smem_shape[1] elements must be <= 256
    assert(smem_shape[2] <= 256); // smem_shape[2] elements must be <= 256

    assert((smem_shape[0]*sizeof(dtype)) % 16 == 0); // with no interleave, smem_shape[0] * sizeof(dtype) must be a multiple of 16B

    assert(smem_stride[0] <= 8); // smem_stride[0] must be <= 8
    assert(smem_stride[1] <= 8); // smem_stride[1] must be <= 8
    assert(smem_stride[2] <= 8); // smem_stride[2] must be <= 8
    assert(smem_stride[3] <= 8); // smem_stride[3] must be <= 8
    assert(smem_stride[4] <= 8); // smem_stride[4] must be <= 8

    assert(smem_stride[0] == 1); // smem_stride[0] is ignored when tma_interleave is none

    if constexpr (tma_interleave == CU_TENSOR_MAP_INTERLEAVE_NONE && tma_swizzle != CU_TENSOR_MAP_SWIZZLE_NONE) {
        assert(smem_shape[0] * sizeof(dtype) <= ST::swizzle_bytes);
    }

    const uint64_t *gmem_shape_ptr  = &gmem_shape[0];
    const uint64_t *gmem_stride_ptr = &gmem_stride[0];
    const uint32_t *smem_shape_ptr  = &smem_shape[0];
    const uint32_t *smem_stride_ptr = &smem_stride[0];

    CUresult result = cuTensorMapEncodeTiled(
        tma_map,
        tma_format,
        tma_dim,
        global_addr,
        gmem_shape_ptr,
        gmem_stride_ptr,
        smem_shape_ptr,
        smem_stride_ptr,
        tma_interleave,
        tma_swizzle,
        tma_l2Promotion,
        tma_oobFill);

    if (result != CUDA_SUCCESS) {
        const char *error_string;
        cuGetErrorString(result, &error_string);
        std::string error_msg = format_tma_error(
            "tile", error_string,
            batch, depth, rows, cols,
            tma_map, tma_format, tma_dim, global_addr,
            gmem_shape_ptr, gmem_stride_ptr,
            smem_shape_ptr, smem_stride_ptr,
            5, 4, 5, 5,
            tma_interleave, tma_swizzle, tma_l2Promotion, tma_oobFill,
            "ST::rows: " + std::to_string(ST::rows) + "\n ST::cols: " + std::to_string(ST::cols)
        );
        throw std::runtime_error(error_msg);
    }
}
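/*
 Worked example (illustrative numbers only, not part of the API): for a bf16 tile with
 ST::swizzle_bytes == 128, swizzle_elements == 128/sizeof(bf16) == 64. With axis == 2 and a
 1024x1024 global tensor (batch == depth == 1), the encoded 5-D global shape becomes
 {64, 1024, (1024+63)/64 == 16, 1, 1} and the byte strides become
 {1024*2, 128, 1024*1024*2, 1024*1024*2}: the tensor is walked in swizzle-width column chunks.
*/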

/**
 * @brief Allocates on the GPU and initializes a tensor map for the given source tensor.
 *
 * This function creates a tensor map (CUtensorMap) for the specified source shared tile type, copies it
 * to device memory, and returns a device pointer to it.
 *
 * @tparam ST The source tile type, which must be TMA-compatible.
 * @param src Pointer to the source tensor data in global memory.
 * @param batch,depth,rows,cols Runtime dimensions of the global tensor.
 * @returns Device pointer to the initialized CUtensorMap object.
 */
template<ducks::st::all ST>
__host__ static inline CUtensorMap* allocate_and_create_tensor_map(const typename ST::dtype *src, int batch, int depth, int rows, int cols) {
    CUtensorMap *tma_map_d;
    cudaMalloc(&tma_map_d, sizeof(CUtensorMap));
    CUtensorMap tma_map_host; // put it on the stack, why not.
    create_tensor_map<ST, 2>(&tma_map_host, src, batch, depth, rows, cols);
    cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
    return tma_map_d;
}
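/*
 Usage sketch (hedged: `st_bf<64, 64>` stands in for any TMA-compatible shared tile type, and
 the dimensions are illustrative):

     using tile_t = st_bf<64, 64>;
     bf16 *d_src; // assume allocated with cudaMalloc and therefore 16-byte aligned
     CUtensorMap *tma = allocate_and_create_tensor_map<tile_t>(d_src, 1, 1, 1024, 1024);
     // pass `tma` to the kernel; cudaFree(tma) once no kernel needs it anymore.
*/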

/* ---------- Create vector tensor map descriptor (HOST) ---------- */

// First, we need a template system to determine how to divide a long shared vector into multiple subvectors.
// We have to do this because the first dimension for TMA is limited to 256 elements.
// Our goal is to find the largest multiple of 16 that is <= 256, satisfies TMA's alignment rules, and divides the vector length evenly.

template<typename SV, int D=16> struct find_vector_divider {
    static constexpr int value = (SV::length % (16*D) == 0 && (SV::length < 256 || ((16*D)*sizeof(typename SV::dtype)) % 128 == 0)) ?
        16*D : find_vector_divider<SV, D-1>::value;
};
template<typename SV> struct find_vector_divider<SV, 1> { static constexpr int value = 16; }; // base case
template<typename SV> constexpr int sv_tma_dim1 = find_vector_divider<SV>::value; // inner dim
template<typename SV> constexpr int sv_tma_dim2 = (SV::length / sv_tma_dim1<SV>); // outer dim
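/*
 Worked example (illustrative): for a bf16 vector of length 1024, D starts at 16, so the first
 candidate chunk is 256 elements; 1024 % 256 == 0 and (256*sizeof(bf16)) % 128 == 0, so
 sv_tma_dim1 == 256 and sv_tma_dim2 == 1024/256 == 4. For a bf16 vector of length 320 the search
 falls through to 64 (320 is not divisible by 256, and the 160- and 80-element chunks fail the
 128-byte alignment check), giving 64 x 5.
*/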

/**
 * @brief Creates a tensor map for the given source vector.
 *
 * This function creates a tensor map (CUtensorMap) for the specified source shared vector type. The tensor map
 * describes the shape and layout of the tensor in global memory, and is set up from the provided source
 * pointer, the runtime dimensions, and the layout specified by the SV template parameter.
 *
 * @tparam SV The source vector type, which must be TMA-compatible.
 * @tparam axis Unused for vectors; must be -1.
 * @tparam disable_swizzle Must be true; vector TMA does not use swizzling.
 * @param tma_map Pointer to the CUtensorMap object to be initialized.
 * @param src Pointer to the source tensor data in global memory.
 * @param batch,depth,rows,cols Runtime dimensions of the global tensor.
 */
template<ducks::sv::all SV, int axis, bool disable_swizzle = true>
__host__ static inline void create_tensor_map(CUtensorMap *tma_map, const typename SV::dtype *src, int batch, int depth, int rows, int cols) {
    using dtype = typename SV::dtype;
    static_assert(axis == -1, "for vector TMA, row axis must be -1 as it's unused");
    static_assert(SV::length <= 256 || (SV::length*sizeof(dtype)) % 128 == 0);
    // There is technically a way around ^ that involves instantiating two separate TMA descriptors, one of size 256
    // and the other of size SV::length % 256, but this is a fairly mild restriction and the other approach is a real PITA and incurs other costs.
    static_assert(disable_swizzle, "for vector TMA, swizzle should be disabled");

    constexpr uint32_t tma_dim = 4;
    void *global_addr = (void*)(src);

    constexpr CUtensorMapDataType tma_format = (
        std::is_same_v<dtype, bf16>    ? CU_TENSOR_MAP_DATA_TYPE_BFLOAT16 :
        std::is_same_v<dtype, half>    ? CU_TENSOR_MAP_DATA_TYPE_FLOAT16 :
        std::is_same_v<dtype, float>   ? CU_TENSOR_MAP_DATA_TYPE_FLOAT32 :
        std::is_same_v<dtype, fp8e4m3> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
        std::is_same_v<dtype, fp8e5m2> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
#ifdef KITTENS_BLACKWELL
        std::is_same_v<dtype, fp8e8m0> ? CU_TENSOR_MAP_DATA_TYPE_UINT8 :
#endif
        CUtensorMapDataType(-1)
    );
    constexpr CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
    constexpr CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_NONE;
    constexpr CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
    constexpr CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE;

    constexpr uint64_t dim1 = sv_tma_dim1<SV>; // inner dim
    // constexpr uint64_t dim2 = sv_tma_dim2<SV>; // outer dim, not used here.

    uint64_t gmem_shape [4] = {(uint64_t)cols, (uint64_t)rows, (uint64_t)depth, (uint64_t)batch};
    uint64_t gmem_stride[3] = {(uint64_t)cols*sizeof(dtype), (uint64_t)cols*rows*sizeof(dtype), (uint64_t)cols*rows*depth*sizeof(dtype)};
    uint32_t smem_shape [4] = {(uint32_t)dim1, 1, 1, 1};
    uint32_t smem_stride[4] = {1, 1, 1, 1};

    // ensure that the global address is always 16-byte aligned
    assert((reinterpret_cast<uint64_t>(global_addr) & 0b1111) == 0);

    assert(smem_shape[0] <= 256); // smem_shape[0] elements must be <= 256.

    const uint64_t *gmem_shape_ptr  = &gmem_shape[0];
    const uint64_t *gmem_stride_ptr = &gmem_stride[0];
    const uint32_t *smem_shape_ptr  = &smem_shape[0];
    const uint32_t *smem_stride_ptr = &smem_stride[0];

    CUresult result = cuTensorMapEncodeTiled(
        tma_map,
        tma_format,
        tma_dim,
        global_addr,
        gmem_shape_ptr,
        gmem_stride_ptr,
        smem_shape_ptr,
        smem_stride_ptr,
        tma_interleave,
        swizzle,
        tma_l2Promotion,
        tma_oobFill
    );

    if (result != CUDA_SUCCESS) {
        const char *error_string;
        cuGetErrorString(result, &error_string);
        std::string error_msg = format_tma_error(
            "vector", error_string,
            batch, depth, rows, cols,
            tma_map, tma_format, tma_dim, global_addr,
            gmem_shape_ptr, gmem_stride_ptr,
            smem_shape_ptr, smem_stride_ptr,
            4, 3, 4, 4,
            tma_interleave, swizzle, tma_l2Promotion, tma_oobFill,
            "SV::length: " + std::to_string(SV::length)
        );
        throw std::runtime_error(error_msg);
    }
}

/**
 * @brief Allocates on the GPU and initializes a tensor map for the given source vector.
 *
 * This function creates a tensor map (CUtensorMap) for the specified source shared vector type, copies it
 * to device memory, and returns a device pointer to it.
 *
 * @tparam SV The source vector type, which must be TMA-compatible.
 * @param src Pointer to the source tensor data in global memory.
 * @param batch,depth,rows,cols Runtime dimensions of the global tensor.
 * @returns Device pointer to the initialized CUtensorMap object.
 */
template<ducks::sv::all SV>
__host__ static inline CUtensorMap* allocate_and_create_tensor_map(const typename SV::dtype *src, int batch, int depth, int rows, int cols) {
    CUtensorMap *tma_map_d;
    cudaMalloc(&tma_map_d, sizeof(CUtensorMap));
    CUtensorMap tma_map_host; // put it on the stack, why not.
    create_tensor_map<SV, -1>(&tma_map_host, src, batch, depth, rows, cols);
    cudaMemcpy(tma_map_d, &tma_map_host, sizeof(CUtensorMap), cudaMemcpyHostToDevice);
    return tma_map_d;
}
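/*
 Usage sketch (hedged: `sv_fl<1024>` stands in for any shared vector type that passes the
 length/alignment static_asserts above):

     using vec_t = sv_fl<1024>;
     float *d_src; // assume allocated with cudaMalloc and therefore 16-byte aligned
     CUtensorMap *tma = allocate_and_create_tensor_map<vec_t>(d_src, 1, 1, 1, 1024);
*/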

} // namespace tma
} // namespace detail
} // namespace kittens
99
extra/thunder/cuda/include/types/global/util.cuh
Normal file
@@ -0,0 +1,99 @@
#pragma once

#include <type_traits>
#include <cstddef>
#include "../register/register.cuh"

namespace kittens {
namespace ducks {
namespace gl {

template<int d> concept cdim = (d > 0);   // represents a compile-time dimension
template<int d> concept rdim = (d == -1); // represents a runtime dimension
template<int _v> struct compiled_dim {
    static_assert(cdim<_v>, "Invalid compile-time dimension value");
    static constexpr size_t v = _v;
    __host__ __device__ inline compiled_dim(const std::nullptr_t &_) {}
    __host__ __device__ inline constexpr operator size_t() const { return v; }
};
struct runtime_dim {
    size_t v;
    __host__ __device__ inline runtime_dim(const size_t &_v) : v(_v) {}
    __host__ __device__ inline operator size_t() const { return v; }
};
template<int d> using make_dim_t = std::conditional_t<rdim<d>, runtime_dim, compiled_dim<d>>;
template<int d> using make_arg_t = std::conditional_t<rdim<d>, size_t, std::nullptr_t>; // we pass runtime dims as size_t, compile-time dims as nullptr_t

} // namespace gl
} // namespace ducks
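/*
 Illustrative sketch (the gl global-layout type itself is defined elsewhere in the library): a
 layout with a runtime batch dimension and compile-time depth/rows/cols would store the first as
 a runtime_dim (constructed from a size_t) and the rest as compiled_dim<...> (constructed from
 nullptr). The aliases resolve as:

     static_assert(std::is_same_v<ducks::gl::make_dim_t<-1>, ducks::gl::runtime_dim>);
     static_assert(std::is_same_v<ducks::gl::make_dim_t<16>, ducks::gl::compiled_dim<16>>);
     static_assert(std::is_same_v<ducks::gl::make_arg_t<-1>, size_t>);
*/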

namespace detail {
template<typename T> concept tile = ducks::st::all<T> || ducks::rt::all<T> || ducks::cst::all<T> || ducks::crt::all<T>;
template<typename T> concept vec  = ducks::sv::all<T> || ducks::rv::all<T> || ducks::csv::all<T> || ducks::crv::all<T>;
}

namespace ducks {
namespace coord {
struct identifier {};
}
}
template<typename _T=ducks::default_type> struct coord { // essentially a named int4 for tensor coordinates.
    using identifier = ducks::coord::identifier;
    using BASE = _T; // in units of what type?
    static_assert(std::is_same_v<BASE, ducks::default_type> || detail::tile<BASE> || detail::vec<BASE>); // ensure BASE is a valid type
    int b, d, r, c;
    __device__ inline coord(int _b, int _d, int _r, int _c) : b(_b), d(_d), r(_r), c(_c) {}
    __device__ inline coord(        int _d, int _r, int _c) : b(  0), d(_d), r(_r), c(_c) {}
    __device__ inline coord(                int _r, int _c) : b(  0), d(  0), r(_r), c(_c) {}
    __device__ inline coord(                        int _c) : b(  0), d(  0), r(  0), c(_c) {}
    __device__ inline coord(                               ) : b(  0), d(  0), r(  0), c(  0) {}
    template<typename U> __device__ inline coord(const coord<U> &other) : b(other.b), d(other.d), r(other.r), c(other.c) {}
    __device__ inline coord(const int4 &other) : b(other.x), d(other.y), r(other.z), c(other.w) {}
    __device__ inline operator int4() const { return int4{b, d, r, c}; }
    template<int row_axis, int col_axis> __device__ inline coord<ducks::default_type> unit_coord() const {
        if constexpr (detail::tile<BASE>) {
            static_assert(row_axis != col_axis, "row and column axes must be different");
            static_assert(row_axis >= 0 && row_axis <= 3, "row axis must be between 0 and 3");
            static_assert(col_axis >= 0 && col_axis <= 3, "column axis must be between 0 and 3");
            static_assert(col_axis == 3, "for now, column axis must be 3");
            return coord<ducks::default_type>(
                row_axis == 0 ? b*BASE::rows : b,
                row_axis == 1 ? d*BASE::rows : d,
                row_axis == 2 ? r*BASE::rows : r,
                c*BASE::cols
            );
        }
        else if constexpr (detail::vec<BASE>) {
            static_assert(row_axis == -1, "row axis must be -1 for a vector coordinate to be converted to a unit coordinate");
            static_assert(col_axis >= 0 && col_axis <= 3, "column axis must be between 0 and 3");
            static_assert(col_axis == 3, "for now, column axis must be 3");
            return coord<ducks::default_type>(b, d, r, c*BASE::length);
        }
        else {
            return coord<ducks::default_type>(*this);
        }
    }
    template<int axis> __device__ inline int dim() const {
        static_assert(axis >= 0 && axis <= 3, "axis must be between 0 and 3");
        if constexpr      (axis == 0) { return b; }
        else if constexpr (axis == 1) { return d; }
        else if constexpr (axis == 2) { return r; }
        else                          { return c; }
    }
};
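/*
 Usage sketch (hypothetical tile type, for illustration only): given a 16x64 shared tile type
 `tile_t`, coord<tile_t>(0, 0, 2, 3) names the tile 2 down and 3 across, and
 unit_coord<2, 3>() rescales it to element units: {0, 0, 2*16, 3*64} == {0, 0, 32, 192}.
 dim<2>() on the original coordinate still returns 2, the row index in tile units.
*/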
namespace ducks {
namespace coord {
/**
 * @brief Concept for all coordinate types.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T has a nested type identifier that is the same as ducks::coord::identifier.
 */
template<typename T> concept all = requires {
    typename T::identifier; // Checks if T::identifier exists
} && std::is_same_v<typename T::identifier, identifier>; // Checks if T::identifier is ducks::coord::identifier
template<typename T> concept tile = all<T> && (std::is_same_v<typename T::BASE, ducks::default_type> || detail::tile<typename T::BASE>);
template<typename T> concept vec  = all<T> && (std::is_same_v<typename T::BASE, ducks::default_type> || detail::vec<typename T::BASE>);
} // namespace coord
} // namespace ducks
} // namespace kittens
95
extra/thunder/cuda/include/types/register/crt.cuh
Normal file
@@ -0,0 +1,95 @@
/**
 * @file
 * @brief Abstraction for a complex register tile composed of real and imaginary tiles.
 */

#pragma once

#include "rt.cuh"
#include "crv.cuh"

namespace kittens {

namespace ducks {
namespace crt {
/**
 * @brief A dummy type used to identify complex register tiles.
 *
 * For a type to quack like an rt_cmplx, it should define its identifier as ducks::crt::identifier.
 * If a type quacks like ducks::crt::identifier, it will be treated as an rt_cmplx by compiler checks.
 */
struct identifier {};
/**
 * @brief Concept for register tiles that are complex.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T has a complex tile identifier.
 * - T's component is a register tile.
 */
template <typename T> concept all = requires {
    typename T::identifier;
} && std::is_same_v<typename T::identifier, identifier> && ducks::rt::all<typename T::component>;

/**
 * @brief Concept for complex register tiles with row layout.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T is a complex register tile.
 * - T has an internal type layout that is ducks::rt_layout::row.
 */
template<typename T>
concept row_layout = all<T> && std::is_same_v<typename T::layout, ducks::rt_layout::row>;
/**
 * @brief Concept for complex register tiles with col layout.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T is a complex register tile.
 * - T has an internal type layout that is ducks::rt_layout::col.
 */
template<typename T>
concept col_layout = all<T> && std::is_same_v<typename T::layout, ducks::rt_layout::col>;
} // namespace crt
} // namespace ducks

/**
 * @brief Complex register tile structure.
 *
 * @tparam _T The data type used for the matrix elements.
 * @tparam _rows The height of the tile.
 * @tparam _cols The width of the tile.
 * @tparam _layout The layout of the internal register tiles, either row-major or column-major.
 *
 * This structure abstracts complex number operations by delegating them internally to the real
 * and imaginary register tiles, respectively.
 *
 * In general, you probably want a row-major tile, unless you specifically want to call mma.
 */
template<typename _T, int _rows, int _cols, ducks::rt_layout::all _layout=ducks::rt_layout::row>
struct crt {
    using identifier = ducks::crt::identifier;
    using component  = rt<_T, _rows, _cols, _layout>; /// Data type of each internal tile.
    using layout     = component::layout; ///< Layout of the matrix tile, ensures compatibility with the rt concepts.
    using T     = component::T;
    using T2    = component::T2;
    using dtype = component::dtype; ///< Data type of the elements in the tile.

    static constexpr int rows   = component::rows;
    static constexpr int cols   = component::cols;
    static constexpr int height = component::height;
    static constexpr int width  = component::width;

    // Real/imag tiles have the same internal layout and size.
    component real;
    component imag;

    using row_vec = crv<T, cols, typename rt_base<T, layout>::row_vec_layout>; ///< A type representing a row vector for this tile.
    using col_vec = crv<T, rows, typename rt_base<T, layout>::col_vec_layout>; ///< A type representing a column vector for this tile.
};

template<int _rows, int _cols, ducks::rt_layout::all layout=ducks::rt_layout::row> using crt_fl = crt<float, _rows, _cols, layout>;
template<int _rows, int _cols, ducks::rt_layout::all layout=ducks::rt_layout::row> using crt_bf = crt<bf16,  _rows, _cols, layout>;
template<int _rows, int _cols, ducks::rt_layout::all layout=ducks::rt_layout::row> using crt_hf = crt<half,  _rows, _cols, layout>;
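/*
 Usage sketch (hedged; assumes the rt semantics from rt.cuh):

     kittens::crt_fl<16, 16> a; // a.real and a.imag are each an rt_fl<16, 16>
     // Existing rt operations apply to the components directly, so a complex multiply is four
     // real products combined as: out.real = ar*br - ai*bi, out.imag = ar*bi + ai*br.
*/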

} // namespace kittens
88
extra/thunder/cuda/include/types/register/crv.cuh
Normal file
@@ -0,0 +1,88 @@
/**
 * @file
 * @brief Complex register vectors for computations on axes.
 */

#pragma once

#include <concepts>
#include <type_traits>

#include "../../common/common.cuh"
#include "rv_layout.cuh"

namespace kittens {

/* ---------- MAIN VECTOR STRUCT ---------- */

// helper struct for type inference
namespace ducks {
/**
 * @namespace crv
 *
 * @brief The namespace where concepts and abstract types for complex register vectors live.
 */
namespace crv {
/**
 * @brief A dummy type used to identify complex register vectors.
 *
 * For a type to quack like a crv, it should define its identifier as ducks::crv::identifier.
 * If a type quacks like ducks::crv::identifier, it will be treated as a crv by compiler checks.
 */
struct identifier {};
/**
 * @brief Concept for all complex register vectors.
 * @tparam T The type to check against the concept requirements.
 *
 * Requires:
 * - T has a nested type identifier that is the same as crv::identifier.
 */
template<typename T>
concept all = requires {
    typename T::identifier; // Checks if T::identifier exists
} && std::is_same_v<typename T::identifier, identifier>; // Checks if T::identifier is ducks::crv::identifier.

template<typename T> concept naive_layout = all<T> && std::is_same_v<typename T::layout, ducks::rv_layout::naive>;
template<typename T> concept align_layout = all<T> && std::is_same_v<typename T::layout, ducks::rv_layout::align>;
template<typename T> concept ortho_layout = all<T> && std::is_same_v<typename T::layout, ducks::rv_layout::ortho>;
template<typename T> concept tile_layout  = align_layout<T> || ortho_layout<T>; // vector layouts for interacting with tiles.
}
}

/**
 * @brief Complex register vector structure.
 *
 * @tparam _T The packed data type used for the vector elements.
 * @tparam _length The length of the vector.
 * @tparam _layout The layout of the underlying register vectors.
 *
 * Register vectors are used to accumulate and map values across tiles. You can do computation
 * on them directly if you want, but they're not designed to be maximally efficient vectors,
 * as they have substantial duplication and strange layouts to help them work efficiently with
 * the register layouts used by the tensor cores. ThunderKittens wants you working with tiles
 * where possible!
 */
template<typename _T, size_t _length, ducks::rv_layout::all _layout=ducks::rv_layout::naive>
struct crv {
    using identifier = ducks::crv::identifier;
    using component  = rv<_T, _length, _layout>; /// Data type of each internal vector.
    using layout     = component::layout; ///< Layout of the vector, ensures compatibility with the rv concepts.

    using T     = component::T;
    using T2    = component::T2;
    using dtype = component::dtype; ///< Data type of the elements in the vector.

    static constexpr int length = component::length;
    static constexpr int tiles  = component::tiles;

    // Real/imag vectors have the same internal layout and size.
    component real;
    component imag;
};

template<int _l, ducks::rv_layout::all layout=ducks::rv_layout::naive> using crv_fl = crv<float, _l, layout>;
template<int _l, ducks::rv_layout::all layout=ducks::rv_layout::naive> using crv_bf = crv<bf16,  _l, layout>;
template<int _l, ducks::rv_layout::all layout=ducks::rv_layout::naive> using crv_hf = crv<half,  _l, layout>;
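/*
 Usage sketch (illustrative): a crv simply pairs two rv components, so complex axis reductions
 reuse the real-vector machinery componentwise, e.g.

     kittens::crv_bf<64> v; // v.real and v.imag are rv<bf16, 64> vectors
     // an elementwise complex add is just two real adds: out.real = a.real + b.real, etc.
*/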

} // namespace kittens
Some files were not shown because too many files have changed in this diff