Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-08 22:48:25 -05:00)
extra/thunder/gemm.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# include directory copied from https://github.com/HazyResearch/ThunderMittens

gemm = """
#include <metal_stdlib>
#include "include/tk.metal"
using namespace mittens;

#define GEMM_PARAMS_DEF(T) \
  device T* D [[buffer(0)]], \
  device T* A [[buffer(1)]], \
  device T* B [[buffer(2)]], \
  const constant int &N [[buffer(3)]], \
  const constant int &K [[buffer(4)]], \
  const constant int &M [[buffer(5)]], \
  uint3 tg_id [[threadgroup_position_in_grid]], \
  uint simd_lane_id [[thread_index_in_simdgroup]]

template<typename T, unsigned N_BLOCK, unsigned K_BLOCK, unsigned M_BLOCK>
kernel void matmul_naive(GEMM_PARAMS_DEF(T)) {
  using global_layout = gl<T, 1, 1, -1, -1>;
  global_layout gl_a(A, nullptr, nullptr, N, K);
  global_layout gl_b(B, nullptr, nullptr, K, M);
  global_layout gl_d(D, nullptr, nullptr, N, M);
  rt<T, N_BLOCK * TILE_DIM, K_BLOCK * TILE_DIM> a_reg;
  rt<T, K_BLOCK * TILE_DIM, M_BLOCK * TILE_DIM> b_reg;
  rt<float, N_BLOCK * TILE_DIM, M_BLOCK * TILE_DIM> d_reg;
  zero(d_reg);
  #pragma clang loop unroll(full)
  for (int k = 0; k < K / (K_BLOCK * TILE_DIM); k++) {
    load(a_reg, gl_a, {0, 0, (int)tg_id.y, k}, simd_lane_id);
    load(b_reg, gl_b, {0, 0, k, (int)tg_id.x}, simd_lane_id);
    mma_AB(d_reg, a_reg, b_reg, d_reg);
  }
  store(gl_d, d_reg, {0, 0, (int)tg_id.y, (int)tg_id.x}, simd_lane_id);
}

#define instantiate_matmul_custom(type_name, T) \
  template [[host_name("matmul_custom_" #type_name)]] [[kernel]] \
  void matmul_naive<T, 4, 2, 4>(GEMM_PARAMS_DEF(T)); \

instantiate_matmul_custom(float32, float);
"""

from tinygrad import Device, Tensor

if __name__ == "__main__":
  # TODO: why isn't this type inferred?
  device = Device["METAL"]
  lib = device.compiler.compile(gemm)
  prg = device.runtime("matmul_custom_float32", lib)

  N = 4096
  a = Tensor.randn(N, N)
  b = Tensor.randn(N, N)
  c = Tensor.empty(N, N)
  Tensor.realize(a, b, c)

  TILE_DIM = 8
  N_BLOCK = 4
  M_BLOCK = 4

  gsz = (N // (M_BLOCK * TILE_DIM), N // (N_BLOCK * TILE_DIM), 1)
  for _ in range(5):
    et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf,
             global_size=gsz, local_size=(32,1,1), vals=(N, N, N), wait=True)
    print(f"{N*N*N*2/(et*1e9):.2f} GFLOPS")

  val = ((a@b).contiguous()-c).mean()
  print(val.item())
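A note on the launch in gemm.py: the kernel is instantiated as matmul_naive<T, 4, 2, 4> with TILE_DIM = 8, so each threadgroup is a single 32-thread simdgroup that produces one 32x32 tile of D, which is where gsz = (N/32, N/32, 1) and local_size = (32,1,1) come from. The standalone C++ sketch below is not part of the repo; it only replays that grid arithmetic and the 2*N^3 FLOP count behind the GFLOPS print, with an illustrative 5 ms timing.

// Sketch (not repo code): grid sizing and FLOP math for the matmul_naive launch.
#include <cstdio>

int main() {
  const int N = 4096;                        // square matrices, as in gemm.py
  const int TILE_DIM = 8;                    // mittens tile dimension
  const int N_BLOCK = 4, M_BLOCK = 4;        // register-tile blocking from matmul_naive<T, 4, 2, 4>

  // Each threadgroup (one 32-thread simdgroup) computes a (N_BLOCK*TILE_DIM) x (M_BLOCK*TILE_DIM) tile of D.
  const int tile_rows = N_BLOCK * TILE_DIM;  // 32
  const int tile_cols = M_BLOCK * TILE_DIM;  // 32
  const int grid_x = N / tile_cols;          // tg_id.x walks output columns
  const int grid_y = N / tile_rows;          // tg_id.y walks output rows
  printf("grid = (%d, %d, 1), threadgroup = (32, 1, 1)\n", grid_x, grid_y);

  // 2*N^3 flops for an N x N x N matmul; dividing by elapsed seconds * 1e9 gives the printed GFLOPS.
  const double flops = 2.0 * N * (double)N * N;
  const double et = 5e-3;                    // hypothetical elapsed time in seconds
  printf("%.2f GFLOPS at %.1f ms\n", flops / (et * 1e9), et * 1e3);
  return 0;
}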
extra/thunder/include/common/base_ops.metal (new file, 392 lines)
@@ -0,0 +1,392 @@
/**
 * @file
 * @brief Basic operations on generic types.
 */
#pragma once
#include "base_types.metal"
#include <metal_math>

namespace mittens {
/**
 * @namespace base_ops
 *
 * @brief A namespace for operations on basic data types.
 */
namespace base_ops {

#define TEMPLATE_OPS_SINGLE(func_contents) \
template<typename T> static METAL_FUNC T op(device const T &x) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &x) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &x) { func_contents }

#define TEMPLATE_OPS_OVERRIDE_SINGLE(T, op_name, func_contents) \
template<> METAL_FUNC T op_name::op<T>(device const T &x) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &x) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &x) { func_contents }

#define TEMPLATE_OPS_DOUBLE(func_contents) \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b) { func_contents }

#define TEMPLATE_OPS_OVERRIDE_DOUBLE(T, op_name, func_contents) \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b) { func_contents }

#define TEMPLATE_OPS_TRIPLE(func_contents) \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b, thread const T &c) { func_contents }

#define TEMPLATE_OPS_OVERRIDE_TRIPLE(T, op_name, func_contents) \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b, thread const T &c) { func_contents }


/* ---------- CONST OPS ---------- */

/**
 * @brief Represents the zero constant operation.
 *
 * This operation returns the zero value of the specified type.
 *
 * @tparam T The data type for which to return the zero value.
 * @return The zero value of type T.
 */
struct zero {
    template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::zero(); }
};
/**
 * @brief Represents the one constant operation.
 *
 * This operation returns the one value of the specified type.
 *
 * @tparam T The data type for which to return the one value.
 * @return The one value of type T.
 */
struct one {
    template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::one(); }
};

/**
 * @brief Represents the positive infinity constant operation.
 *
 * This operation returns the positive infinity value of the specified type.
 *
 * @tparam T The data type for which to return the positive infinity value.
 * @return The positive infinity value of type T.
 */
struct pos_infty {
    template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::pos_infty(); }
};
/**
 * @brief Represents the negative infinity constant operation.
 *
 * This operation returns the negative infinity value of the specified type.
 *
 * @tparam T The data type for which to return the negative infinity value.
 * @return The negative infinity value of type T.
 */
struct neg_infty {
    template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::neg_infty(); }
};


/* ---------- UNARY OPS ---------- */
/**
 * @brief Exponential function operation.
 *
 * This operation calculates the exponential of the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The exponential of the input value.
 */
struct exp {
    TEMPLATE_OPS_SINGLE(return metal::exp(x);)
};

TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, exp, return bf16(metal::exp((float)x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, exp, return bf16_2(metal::exp(float2(x)));)

/**
 * @brief Exponential function operation, in base 2
 *
 * This operation calculates the exponential of the input value, in base 2.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The exponential of the input value.
 */
struct exp2 {
    template<typename T> static METAL_FUNC T op(device const T &x) { return metal::exp2(x); } \
    template<typename T> static METAL_FUNC T op(threadgroup const T &x) { return metal::exp2(x); } \
    template<typename T> static METAL_FUNC T op(thread const T &x) { return metal::exp2(x); }
};

//template<> METAL_FUNC bf16 exp2::op<bf16>(device const bf16 &x) { return bf16(metal::exp2(x)); } \
//template<> METAL_FUNC bf16 exp2::op<bf16>(threadgroup const bf16 &x) { return bf16(metal::exp2(x)); } \
//template<> METAL_FUNC bf16 exp2::op<bf16>(thread const bf16 &x) { return bf16(metal::exp2(x)); }
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, exp2, return bf16(metal::exp2(x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, exp2, return bf16_2(metal::exp2((float2)x));)

/**
 * @brief Natural log function operation.
 *
 * This operation calculates the natural logarithm of the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The natural logarithm of the input value.
 */
struct log {
    TEMPLATE_OPS_SINGLE(return metal::log(x);)
};
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, log, return bf16(metal::log(x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, log, return bf16_2(metal::log((float2)x));)

/**
 * @brief Absolute value operation.
 *
 * This operation calculates the absolute value of the input.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The absolute value of the input.
 */
struct abs {
    TEMPLATE_OPS_SINGLE(return metal::abs(x);)
};
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16 , abs, return bf16(metal::abs((float)x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, abs, return bf16_2(metal::abs((float2)x));)
/**
 * @brief Rectified Linear Unit (ReLU) operation.
 *
 * This operation applies the ReLU function to the input, which is the
 * maximum of zero and the input value.
 *
 * @tparam T The data type of the input and output values.
 * @param x[in] The input value.
 * @return The result of ReLU function applied to the input.
 */
struct relu {
    TEMPLATE_OPS_SINGLE(return max(x, base_types::constants<T>::zero());)
};
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16 , relu, return bf16(metal::max((float)x, base_types::constants<float>::zero()));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, relu, return bf16_2(metal::max((float2)x, base_types::constants<float2>::zero()));)
/**
 * @brief Copy operation.
 *
 * This operation returns the input value unchanged.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The input value.
 * @return The same value as the input.
 */
struct copy { // for non-compile-time setters.
    TEMPLATE_OPS_SINGLE(return x;)
};

/* ---------- BINARY OPS ---------- */

/**
 * @brief Copy2 operation.
 *
 * This operation returns the second input value unchanged.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value (ignored).
 * @param b[in] The second input value.
 * @return The same value as the second input.
 */
struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/
    TEMPLATE_OPS_DOUBLE(return b;)
};
/**
 * @brief Sum operation.
 *
 * This operation calculates the sum of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The sum of the input values.
 */
struct sum {
    TEMPLATE_OPS_DOUBLE(return a+b;)
};

/**
 * @brief Subtraction operation.
 *
 * This operation calculates the difference between two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The difference between the input values.
 */
struct sub {
    TEMPLATE_OPS_DOUBLE(return a-b;)
};
/**
 * @brief Multiplication operation.
 *
 * This operation calculates the product of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The product of the input values.
 */
struct mul {
    TEMPLATE_OPS_DOUBLE(return a*b;)
};
/**
 * @brief Division operation.
 *
 * This operation calculates the quotient of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The quotient of the input values.
 */
struct div {
    TEMPLATE_OPS_DOUBLE(return a/b;)
};
/**
 * @brief Maximum operation.
 *
 * This operation calculates the maximum of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The maximum of the input values.
 */
struct max {
    TEMPLATE_OPS_DOUBLE(return metal::max(a,b);)
};
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16 , max, return (bf16)metal::max((float)a, (float)b);)
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16_2, max, return (bf16_2)metal::max((float2)a, (float2)b);)
/**
 * @brief Minimum operation.
 *
 * This operation calculates the minimum of two input values.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @return The minimum of the input values.
 */
struct min {
    TEMPLATE_OPS_DOUBLE(return metal::min(a,b);)
};
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16 , min, return (bf16)metal::min((float)a, (float)b);)
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16_2, min, return (bf16_2)metal::min((float2)a, (float2)b);)


/* ---------- TERNARY OPS ---------- */
/**
 * @brief Fused multiply-add operation A * B + C.
 *
 * This operation performs a fused multiply-add, computing (A * B) + C with only one rounding.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The second input value.
 * @param c[in] The third input value to be added.
 * @return The result of the fused multiply-add operation.
 */
struct fma_AxBtC {
    TEMPLATE_OPS_TRIPLE(return sum::op<T>(mul::op<T>(a, b), c);)
};

/**
 * @brief Fused multiply-add operation A * C + B.
 *
 * This operation performs a fused multiply-add, computing (A * C) + B with only one rounding.
 * This is particularly useful for attention mechanisms in neural networks.
 *
 * @tparam T The data type of the input and output values.
 * @param a[in] The first input value.
 * @param b[in] The third input value to be added.
 * @param c[in] The second input value.
 * @return The result of the fused multiply-add operation.
 */
struct fma_AxCtB { // this is the one needed for attention
    TEMPLATE_OPS_TRIPLE(return sum::op<T>(mul::op<T>(a, c), b);)
};

#undef TEMPLATE_OPS_SINGLE
#undef TEMPLATE_OPS_OVERRIDE_SINGLE
#undef TEMPLATE_OPS_DOUBLE
#undef TEMPLATE_OPS_OVERRIDE_DOUBLE
#undef TEMPLATE_OPS_TRIPLE
#undef TEMPLATE_OPS_OVERRIDE_TRIPLE
} // base_ops
} // mittens
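All of the ops above follow one pattern: a struct whose static, templated op is stamped out by the TEMPLATE_OPS_* macros for every Metal address-space combination, so tile-level code can take the struct as a template parameter, and ternary ops compose the binary ones (fma_AxBtC is literally sum::op(mul::op(a, b), c)). The sketch below restates that pattern in plain host C++ with the address-space qualifiers dropped; it is an illustration, not code from this header, and the map driver is a made-up stand-in for the tile-wide ops that consume these functors.

// Sketch (not repo code): the base_ops functor pattern in host C++.
#include <cassert>

namespace sketch {
struct sum { template<typename T> static T op(const T &a, const T &b) { return a + b; } };
struct mul { template<typename T> static T op(const T &a, const T &b) { return a * b; } };
// Ternary op composed from the binary ones, mirroring fma_AxBtC above.
struct fma_AxBtC {
  template<typename T> static T op(const T &a, const T &b, const T &c) { return sum::op<T>(mul::op<T>(a, b), c); }
};
// Generic elementwise driver, standing in for the tile map/reduce code that takes an op as a template parameter.
template<typename Op, typename T, int N> void map(T (&dst)[N], const T (&a)[N], const T (&b)[N]) {
  for (int i = 0; i < N; i++) dst[i] = Op::template op<T>(a[i], b[i]);
}
}

int main() {
  float a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40}, d[4];
  sketch::map<sketch::sum>(d, a, b);                        // d = a + b, elementwise
  assert(d[2] == 33.f);
  assert(sketch::fma_AxBtC::op<float>(2.f, 3.f, 4.f) == 10.f);
  return 0;
}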
extra/thunder/include/common/base_types.metal (new file, 321 lines)
@@ -0,0 +1,321 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace mittens {
|
||||
|
||||
using bf16 = bfloat;
|
||||
using bf16_2 = bfloat2;
|
||||
using bf16_4 = bfloat4;
|
||||
//using half_2 = half2;
|
||||
|
||||
namespace ducks {
|
||||
namespace base_types {
|
||||
template <typename T>
|
||||
static METAL_FUNC constexpr const bool isT1() {
|
||||
return metal::is_same<typename T::dtype, float>::value ||
|
||||
metal::is_same<typename T::dtype, bf16 >::value ||
|
||||
metal::is_same<typename T::dtype, half>::value;
|
||||
}
|
||||
template <typename T>
|
||||
static METAL_FUNC constexpr const bool isT2() {
|
||||
return metal::is_same<typename T::dtype, float2>::value ||
|
||||
metal::is_same<typename T::dtype, bf16_2>::value ||
|
||||
metal::is_same<typename T::dtype, half2>::value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static METAL_FUNC constexpr const bool isT1Type() {
|
||||
return metal::is_same<T, float>::value ||
|
||||
metal::is_same<T, bf16 >::value ||
|
||||
metal::is_same<T, half>::value;
|
||||
}
|
||||
template <typename T>
|
||||
static METAL_FUNC constexpr const bool isT2Type() {
|
||||
return metal::is_same<T, float2>::value ||
|
||||
metal::is_same<T, bf16_2>::value ||
|
||||
metal::is_same<T, half2>::value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static METAL_FUNC constexpr const bool isT1Ptr() {
|
||||
return metal::is_same<T, device float*>::value ||
|
||||
metal::is_same<T, threadgroup float*>::value ||
|
||||
metal::is_same<T, thread float*>::value ||
|
||||
metal::is_same<T, device bf16*>::value ||
|
||||
metal::is_same<T, threadgroup bf16*>::value ||
|
||||
metal::is_same<T, thread bf16*>::value ||
|
||||
metal::is_same<T, device half*>::value ||
|
||||
metal::is_same<T, threadgroup half*>::value ||
|
||||
metal::is_same<T, thread half*>::value;
|
||||
}
|
||||
template <typename T>
|
||||
static METAL_FUNC constexpr const bool isT2Ptr() {
|
||||
return metal::is_same<T, device float2*>::value ||
|
||||
metal::is_same<T, threadgroup float2*>::value ||
|
||||
metal::is_same<T, thread float2*>::value ||
|
||||
metal::is_same<T, device bf16_2*>::value ||
|
||||
metal::is_same<T, threadgroup bf16_2*>::value ||
|
||||
metal::is_same<T, thread bf16_2*>::value ||
|
||||
metal::is_same<T, device half2*>::value ||
|
||||
metal::is_same<T, threadgroup half2*>::value ||
|
||||
metal::is_same<T, thread half2*>::value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static METAL_FUNC constexpr const bool isTKType() { // good enough
|
||||
return !isT1Type<T>() && !isT2Type<T>() && !isT1Ptr<T>() && !isT2Ptr<T>();
|
||||
}
|
||||
|
||||
} // namespace base_types
|
||||
} // namespace ducks
|
||||
|
||||
/**
|
||||
* @namespace base_types
|
||||
*
|
||||
* @brief A namespace for Thundermittens basic data types.
|
||||
*/
|
||||
namespace base_types {
|
||||
/**
|
||||
* @brief Provides compile-time constants for different types.
|
||||
*
|
||||
* @tparam T The type for which to provide constants.
|
||||
*/
|
||||
template<typename T> struct constants {
|
||||
/**
|
||||
* @brief Zero
|
||||
* @return Constexpr zero with type T
|
||||
*/
|
||||
static METAL_FUNC constexpr T zero() { return T{0}; }
|
||||
/**
|
||||
* @brief One
|
||||
* @return Constexpr one with type T
|
||||
*/
|
||||
static METAL_FUNC constexpr T one() { return T{1}; }
|
||||
/**
|
||||
* @brief Positive infinity. Particularly useful for initializing before a min op.
|
||||
* @return Constexpr positive infinity with type T
|
||||
*/
|
||||
static METAL_FUNC constexpr T pos_infty() { return T{INFINITY}; } // I'll find a better way at some point but this appears to work.
|
||||
/**
|
||||
* @brief Negative infinity. Particularly useful for initializing before a max op.
|
||||
* @return Constexpr negative infinity with type T
|
||||
*/
|
||||
static METAL_FUNC constexpr T neg_infty() { return T{-INFINITY}; }
|
||||
};
|
||||
template<> struct constants<float> {
|
||||
static METAL_FUNC constexpr float zero() { return 0.f; }
|
||||
static METAL_FUNC constexpr float one() { return 1.f; }
|
||||
static METAL_FUNC constexpr float pos_infty() { return INFINITY; }
|
||||
static METAL_FUNC constexpr float neg_infty() { return -INFINITY; }
|
||||
};
|
||||
template<> struct constants<float2> {
|
||||
static METAL_FUNC constexpr float2 zero() { return float2(0.f, 0.f); }
|
||||
static METAL_FUNC constexpr float2 one() { return float2(1.f, 1.f); }
|
||||
static METAL_FUNC constexpr float2 pos_infty() { return float2(constants<float>::pos_infty(), constants<float>::pos_infty()); }
|
||||
static METAL_FUNC constexpr float2 neg_infty() { return float2(constants<float>::neg_infty(), constants<float>::neg_infty()); }
|
||||
};
|
||||
template<> struct constants<bf16> {
|
||||
static METAL_FUNC constexpr bf16 zero() { return 0.bf; }
|
||||
static METAL_FUNC constexpr bf16 one() { return 1.bf; }
|
||||
static METAL_FUNC constexpr bf16 pos_infty() { return HUGE_VALBF; }
|
||||
static METAL_FUNC constexpr bf16 neg_infty() { return -HUGE_VALBF; }
|
||||
};
|
||||
template<> struct constants<bf16_2> {
|
||||
static METAL_FUNC constexpr bf16_2 zero() { return bf16_2(constants<bf16>::zero(), constants<bf16>::zero()); }
|
||||
static METAL_FUNC constexpr bf16_2 one() { return bf16_2(constants<bf16>::one(), constants<bf16>::one()); }
|
||||
static METAL_FUNC constexpr bf16_2 pos_infty() { return bf16_2(constants<bf16>::pos_infty(), constants<bf16>::pos_infty()); }
|
||||
static METAL_FUNC constexpr bf16_2 neg_infty() { return bf16_2(constants<bf16>::neg_infty(), constants<bf16>::neg_infty()); }
|
||||
};
|
||||
template<> struct constants<half> {
|
||||
static METAL_FUNC constexpr half zero() { return half(0.h); }
|
||||
static METAL_FUNC constexpr half one() { return half(1.h); }
|
||||
static METAL_FUNC constexpr half pos_infty() { return HUGE_VALH; }
|
||||
static METAL_FUNC constexpr half neg_infty() { return -HUGE_VALH; }
|
||||
};
|
||||
|
||||
template<> struct constants<half2> {
|
||||
static METAL_FUNC constexpr half2 zero() { return half2(constants<half>::zero(), constants<half>::zero()); }
|
||||
static METAL_FUNC constexpr half2 one() { return half2(constants<half>::one(), constants<half>::one()); }
|
||||
static METAL_FUNC constexpr half2 pos_infty() { return half2(constants<half>::pos_infty(), constants<half>::pos_infty()); }
|
||||
static METAL_FUNC constexpr half2 neg_infty() { return half2(constants<half>::neg_infty(), constants<half>::neg_infty()); }
|
||||
};
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @brief Provides information about packing of elements for a given type.
|
||||
*
|
||||
* @tparam T The type for which to provide packing information.
|
||||
*/
|
||||
template<typename T> struct packing {
|
||||
// /**
|
||||
// * @brief The number of elements packed together.
|
||||
// *
|
||||
// * @return constexpr int representing number of elements within the type.
|
||||
// */
|
||||
// static METAL_FUNC constexpr int num() { return 1; }
|
||||
// /**
|
||||
// * @brief Packs a single T element twice (replicated) into its packed type.
|
||||
// *
|
||||
// * @param i[in] The element to pack.
|
||||
// * @return The packed type.
|
||||
// */
|
||||
// static METAL_FUNC constexpr T pack(device const bf16 &i);
|
||||
// static METAL_FUNC constexpr T pack(threadgroup const bf16 &i);
|
||||
// static METAL_FUNC constexpr T pack(thread const bf16 &i);
|
||||
};
|
||||
|
||||
#define PACK_FUNCTIONS(T1, T2) \
|
||||
static METAL_FUNC constexpr T2 pack(device const T1 &i) { return T2{i, i}; } \
|
||||
static METAL_FUNC constexpr T2 pack(threadgroup const T1 &i) { return T2{i, i}; } \
|
||||
static METAL_FUNC constexpr T2 pack(thread const T1 &i) { return T2{i, i}; }
|
||||
|
||||
template<> struct packing<bf16> {
|
||||
static METAL_FUNC constexpr int num() { return 1; }
|
||||
using unpacked_type = bf16;
|
||||
using packed_type = bf16_2;
|
||||
using packed_four = bf16_4;
|
||||
PACK_FUNCTIONS(unpacked_type, packed_type)
|
||||
};
|
||||
template<> struct packing<half> {
|
||||
static METAL_FUNC constexpr int num() { return 1; }
|
||||
using unpacked_type = half;
|
||||
using packed_type = half2;
|
||||
using packed_four = half4;
|
||||
PACK_FUNCTIONS(unpacked_type, packed_type)
|
||||
};
|
||||
template<> struct packing<float> {
|
||||
static METAL_FUNC constexpr int num() { return 1; }
|
||||
using unpacked_type = float;
|
||||
using packed_type = float2;
|
||||
using packed_four = float4;
|
||||
|
||||
PACK_FUNCTIONS(unpacked_type, packed_type)
|
||||
};
|
||||
template<> struct packing<bf16_2> {
|
||||
static METAL_FUNC constexpr int num() { return 2; }
|
||||
using unpacked_type = bf16;
|
||||
using packed_type = bf16_2;
|
||||
using packed_four = bf16_4;
|
||||
PACK_FUNCTIONS(unpacked_type, packed_type)
|
||||
};
|
||||
template<> struct packing<half2> {
|
||||
static METAL_FUNC constexpr int num() { return 2; }
|
||||
using unpacked_type = half;
|
||||
using packed_type = half2;
|
||||
using packed_four = half4;
|
||||
PACK_FUNCTIONS(unpacked_type, packed_type)
|
||||
};
|
||||
template<> struct packing<float2> {
|
||||
static METAL_FUNC constexpr int num() { return 2; }
|
||||
using unpacked_type = float;
|
||||
using packed_type = float2;
|
||||
using packed_four = float4;
|
||||
PACK_FUNCTIONS(unpacked_type, packed_type)
|
||||
};
|
||||
template<> struct packing<int2> {
|
||||
static METAL_FUNC constexpr int num() { return 2; }
|
||||
};
|
||||
template<> struct packing<float4> {
|
||||
static METAL_FUNC constexpr int num() { return 4; }
|
||||
};
|
||||
template<> struct packing<int4> {
|
||||
static METAL_FUNC constexpr int num() { return 4; }
|
||||
};
|
||||
|
||||
|
||||
/**
|
||||
* @brief Provides templated functionality to convert between different types.
|
||||
*
|
||||
* @tparam T The target type for conversion.
|
||||
* @tparam U The source type for conversion.
|
||||
*/
|
||||
template<typename T, typename U> struct convertor {
|
||||
/**
|
||||
* @brief Converts a value of type U to type T.
|
||||
*
|
||||
* @param u[in] The value of type U to convert.
|
||||
* @return T The converted value of type T.
|
||||
*/
|
||||
static METAL_FUNC T convert(device const U & u) { return (T)u; }
|
||||
static METAL_FUNC T convert(threadgroup const U & u) { return (T)u; }
|
||||
static METAL_FUNC T convert(thread const U & u) { return (T)u; }
|
||||
};
|
||||
|
||||
template<> struct convertor<float, bf16> {
|
||||
// fptrunc float %_ to bfloat
|
||||
static METAL_FUNC float convert(device const bf16 & u) { return float(u);}
|
||||
static METAL_FUNC float convert(threadgroup const bf16 & u) { return float(u);}
|
||||
static METAL_FUNC float convert(thread const bf16 & u) { return float(u);}
|
||||
};
|
||||
template<> struct convertor<bf16, float> {
|
||||
// fpext bfloat %_ to float
|
||||
static METAL_FUNC bf16 convert(device const float & u) { return bf16(u); }
|
||||
static METAL_FUNC bf16 convert(threadgroup const float & u) { return bf16(u); }
|
||||
static METAL_FUNC bf16 convert(thread const float & u) { return bf16(u); }
|
||||
};
|
||||
template<> struct convertor<float2, bf16_2> {
|
||||
// tail call fast <2 x float> @air.convert.f.v2f32.f.v2bf16(<2 x bfloat> %_)
|
||||
static METAL_FUNC float2 convert(device const bf16_2 & u) { return float2(u); }
|
||||
static METAL_FUNC float2 convert(threadgroup const bf16_2 & u) { return float2(u); }
|
||||
static METAL_FUNC float2 convert(thread const bf16_2 & u) { return float2(u); }
|
||||
};
|
||||
template<> struct convertor<bf16_2, float2> {
|
||||
// tail call fast <2 x bfloat> @air.convert.f.v2bf16.f.v2f32(<2 x float> %_)
|
||||
static METAL_FUNC bf16_2 convert(device const float2 & u) { return bf16_2(u); }
|
||||
static METAL_FUNC bf16_2 convert(threadgroup const float2 & u) { return bf16_2(u); }
|
||||
static METAL_FUNC bf16_2 convert(thread const float2 & u) { return bf16_2(u); }
|
||||
};
|
||||
|
||||
template<> struct convertor<float, half> {
|
||||
// fptrunc float %_ to half
|
||||
static METAL_FUNC float convert(device const half & u) { return float(u); }
|
||||
static METAL_FUNC float convert(threadgroup const half & u) { return float(u); }
|
||||
static METAL_FUNC float convert(thread const half & u) { return float(u); }
|
||||
};
|
||||
template<> struct convertor<half, float> {
|
||||
//fpext half %_ to float
|
||||
static METAL_FUNC half convert(device const float & u) { return half(u); }
|
||||
static METAL_FUNC half convert(threadgroup const float & u) { return half(u); }
|
||||
static METAL_FUNC half convert(thread const float & u) { return half(u); }
|
||||
};
|
||||
template<> struct convertor<float2, half2> {
|
||||
// tail call fast <2 x float> @air.convert.f.v2f32.f.v2f16(<2 x half> %_)
|
||||
static METAL_FUNC float2 convert(device const half2 & u) { return float2(u); }
|
||||
static METAL_FUNC float2 convert(threadgroup const half2 & u) { return float2(u); }
|
||||
static METAL_FUNC float2 convert(thread const half2 & u) { return float2(u); }
|
||||
};
|
||||
template<> struct convertor<half2, float2> {
|
||||
// tail call fast <2 x half> @air.convert.f.v2f16.f.v2f32(<2 x float> %_)
|
||||
static METAL_FUNC half2 convert(device const float2 & u) { return half2(u); }
|
||||
static METAL_FUNC half2 convert(threadgroup const float2 & u) { return half2(u); }
|
||||
static METAL_FUNC half2 convert(thread const float2 & u) { return half2(u); }
|
||||
};
|
||||
template<> struct convertor<bf16, half> {
|
||||
static METAL_FUNC bf16 convert(device const half & u) { return bf16(u); }
|
||||
static METAL_FUNC bf16 convert(threadgroup const half & u) { return bf16(u); }
|
||||
static METAL_FUNC bf16 convert(thread const half & u) { return bf16(u); }
|
||||
};
|
||||
template<> struct convertor<half, bf16> {
|
||||
static METAL_FUNC half convert(device const bf16 & u) { return half(u); }
|
||||
static METAL_FUNC half convert(threadgroup const bf16 & u) { return half(u); }
|
||||
static METAL_FUNC half convert(thread const bf16 & u) { return half(u); }
|
||||
};
|
||||
template<> struct convertor<bf16_2, half2> {
|
||||
// tail call fast <2 x bfloat> @air.convert.f.v2bf16.f.v2f16(<2 x half> %_)
|
||||
static METAL_FUNC bf16_2 convert(device const half2 & u) { return bf16_2(u); }
|
||||
static METAL_FUNC bf16_2 convert(threadgroup const half2 & u) { return bf16_2(u); }
|
||||
static METAL_FUNC bf16_2 convert(thread const half2 & u) { return bf16_2(u); }
|
||||
};
|
||||
template<> struct convertor<half2, bf16_2> {
|
||||
// tail call fast <2 x half> @air.convert.f.v2f16.f.v2bf16(<2 x bfloat> %_)
|
||||
static METAL_FUNC half2 convert(device const bf16_2 & u) { return half2(u); }
|
||||
static METAL_FUNC half2 convert(threadgroup const bf16_2 & u) { return half2(u); }
|
||||
static METAL_FUNC half2 convert(thread const bf16_2 & u) { return half2(u); }
|
||||
};
|
||||
|
||||
|
||||
|
||||
} // base_types
|
||||
|
||||
} // mittens
|
||||
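base_types.metal supplies three small trait families that the tile code keys off: constants<T> (zero/one and the infinities used to initialize registers before min/max reductions), packing<T> (the scalar-to-2-wide pairing such as bf16 to bf16_2, plus how many scalars a type packs), and convertor<T,U> (per-address-space casts between element types). The sketch below reproduces the same shape in plain C++ with float/double stand-ins instead of Metal's bf16/half types; the names inside namespace sketch are illustrative only.

// Sketch (not repo code): the constants<T> / packing<T> / convertor<T,U> trait pattern in plain C++.
#include <cassert>
#include <limits>

namespace sketch {
template<typename T> struct constants {
  static constexpr T zero() { return T{0}; }
  static constexpr T one()  { return T{1}; }
  static constexpr T neg_infty() { return -std::numeric_limits<T>::infinity(); }  // init value before a max-reduction
};
struct float2 { float x, y; };                  // stand-in for the Metal vector type
template<typename T> struct packing;            // primary template left undefined, as in the header
template<> struct packing<float> {
  static constexpr int num() { return 1; }
  using packed_type = float2;
  static constexpr float2 pack(const float &i) { return float2{i, i}; }  // replicate a scalar into both lanes
};
template<typename T, typename U> struct convertor {
  static T convert(const U &u) { return (T)u; }  // specializations would handle bf16/half pairs explicitly
};
}

int main() {
  assert(sketch::constants<float>::one() == 1.f);
  assert(sketch::packing<float>::pack(3.f).y == 3.f);
  assert((sketch::convertor<float, double>::convert(2.5) == 2.5f));
  return 0;
}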
extra/thunder/include/common/common.metal (new file, 10 lines)
@@ -0,0 +1,10 @@
/**
 * @file
 * @brief A collection of common resources on which Thundermittens depends.
 */


#pragma once
#include "base_types.metal"
#include "base_ops.metal"
#include "utils.metal"
extra/thunder/include/common/utils.metal (new file, 225 lines)
@@ -0,0 +1,225 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief General utilities for Thundermittens.
|
||||
*/
|
||||
#pragma once // not done
|
||||
/*
|
||||
TODO:
|
||||
shared allocator
|
||||
max shared mem for other hardware
|
||||
*/
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include "base_types.metal"
|
||||
/**
|
||||
* @namespace mittens
|
||||
*
|
||||
* @brief The main namespace of Thundermittens.
|
||||
*/
|
||||
namespace mittens {
|
||||
/**
|
||||
* @namespace ore
|
||||
*
|
||||
* @brief The main namespace of Thundermittens Metal.
|
||||
*/
|
||||
|
||||
/* ---------- GENERAL CONSTANTS FOR mittens ---------- */
|
||||
|
||||
/**
|
||||
* @brief Tile dimension constant.
|
||||
*/
|
||||
constant constexpr const int TILE_DIM{8};
|
||||
constant constexpr const int TILE_ELEMENTS{TILE_DIM*TILE_DIM};
|
||||
constant constexpr const int SIMD_THREADS{32};
|
||||
|
||||
|
||||
#ifdef M2_PRO
|
||||
constant constexpr int MAX_SHARED_MEMORY = 32768;
|
||||
#else
|
||||
constant constexpr int MAX_SHARED_MEMORY = 32768;
|
||||
#endif
|
||||
/* ---------- TYPE HELPERS ---------- */
|
||||
/**
|
||||
* @namespace ducks
|
||||
*
|
||||
* @brief Thundermittens' namespace for template metaprogramming.
|
||||
*
|
||||
* This includes primarily dummy types and concept wrappers, along
|
||||
* with a few additional utilities.
|
||||
*/
|
||||
namespace ducks {
|
||||
|
||||
/**
|
||||
* @brief A type representing an empty default for a template.
|
||||
*/
|
||||
struct default_type {};
|
||||
|
||||
// This macro can't be done as a template, so it doesn't really have a location in mittens.
|
||||
#define typeof(A) typename std::remove_const<typename std::remove_reference<decltype(A)>::type>::type
|
||||
|
||||
|
||||
}
|
||||
|
||||
/* ---------- SHUFFLE UTILS ---------- */
|
||||
/**
|
||||
* @brief Mask constant for all active threads in a warp.
|
||||
*/
|
||||
constant static constexpr uint32_t MASK_ALL = 0xFFFFFFFF;
|
||||
|
||||
template<typename T>
|
||||
static METAL_FUNC T shfl_sync(thread const T &f, const ushort laneid) {
|
||||
return metal::simd_shuffle(f, laneid);
|
||||
}
|
||||
|
||||
template<>
|
||||
METAL_FUNC bfloat shfl_sync<bfloat>(thread const bf16 &f, const ushort laneid) {
|
||||
// return as_type<bf16>(metal::simd_shuffle(*(thread half*)(&f), laneid));
|
||||
float f_val = (float)f;
|
||||
float shfl_val = metal::simd_shuffle(f_val, laneid);
|
||||
return (bf16)shfl_val;
|
||||
}
|
||||
|
||||
template<>
|
||||
METAL_FUNC bfloat2 shfl_sync<bfloat2>(thread const bf16_2 &f, const ushort laneid) {
|
||||
// return as_type<bf16_2>(metal::simd_shuffle(*(thread half2*)(&f), laneid));
|
||||
float2 f_val = (float2)f;
|
||||
float2 shfl_val = metal::simd_shuffle(f_val, laneid);
|
||||
return (bf16_2)shfl_val;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static METAL_FUNC T shfl_down_fill_sync(thread const T &f, thread const T& fill_data, const ushort laneid) {
|
||||
return metal::simd_shuffle_and_fill_down(f, laneid, fill_data);
|
||||
}
|
||||
|
||||
template<>
|
||||
METAL_FUNC bfloat shfl_down_fill_sync<bfloat>(thread const bfloat &f, thread const bfloat &fill_data, const ushort laneid) {
|
||||
// return as_type<bf16>(metal::simd_shuffle_and_fill_down(*(thread half*)(&f), *(thread half*)(&fill_data), laneid));
|
||||
float f_val = (float)f;
|
||||
float fill_data_f = (float)fill_data;
|
||||
float shfl_val = metal::simd_shuffle_and_fill_down(f_val, fill_data_f, laneid);
|
||||
return (bf16)shfl_val;
|
||||
}
|
||||
template<>
|
||||
METAL_FUNC bfloat2 shfl_down_fill_sync<bfloat2>(thread const bfloat2 &f, thread const bfloat2 &fill_data, const ushort laneid) {
|
||||
// return as_type<bf16_2>(metal::simd_shuffle_and_fill_down(*(thread half2*)(&f), *(thread half2*)(&fill_data), laneid));
|
||||
float2 f_val = (float2)f;
|
||||
float2 fill_data_f = (float2)fill_data;
|
||||
float2 shfl_val = metal::simd_shuffle_and_fill_down(f_val, fill_data_f, laneid);
|
||||
return (bf16_2)shfl_val;
|
||||
}
|
||||
/**
|
||||
* @brief Perform a shuffle down operation on a packed type synchronously across a warp.
|
||||
* @tparam T The type of the value to be shuffled.
|
||||
* @param mask[in] The mask of active threads.
|
||||
* @param f[in] The value to be shuffled.
|
||||
* @param delta[in] The number of positions to shuffle down.
|
||||
* @return The result of the shuffle operation.
|
||||
*/
|
||||
template<typename T>
|
||||
static METAL_FUNC T shfl_down_sync(thread const T &f, int delta) {
|
||||
return metal::simd_shuffle_rotate_down(f, delta);
|
||||
}
|
||||
|
||||
template<>
|
||||
METAL_FUNC bfloat shfl_down_sync<bfloat>(thread const bf16 &f, int delta) {
|
||||
// return base_types::convertor<bf16, float>::convert(metal::simd_shuffle_rotate_down(base_types::convertor<float, bf16>::convert(f), delta));
|
||||
// return as_type<bf16>(metal::simd_shuffle_rotate_down(*(thread half*)(&f), delta));
|
||||
float f_val = (float)f;
|
||||
float shfl_val = metal::simd_shuffle_rotate_down(f_val, delta);
|
||||
return (bf16)shfl_val;
|
||||
}
|
||||
|
||||
template<>
|
||||
METAL_FUNC bfloat2 shfl_down_sync<bfloat2>(thread const bf16_2 &f, int delta) {
|
||||
// return as_type<bf16_2>(metal::simd_shuffle_rotate_down(*(thread const half2*)(&f), delta));
|
||||
// return base_types::convertor<bf16_2, float2>::convert(metal::simd_shuffle_rotate_down(base_types::convertor<float2, bf16_2>::convert(f), delta));
|
||||
|
||||
float2 f_val = (float2)f;
|
||||
float2 shfl_val = metal::simd_shuffle_rotate_down(f_val, delta);
|
||||
return (bf16_2)shfl_val;
|
||||
// return as_type<bf16_2>(metal::simd_shuffle_rotate_down(*(thread half2*)(&f), delta));
|
||||
}
|
||||
|
||||
|
||||
/* ---------- LOOP UNROLLING UTILS ---------- */
|
||||
|
||||
namespace meta {
|
||||
template <int Start, int End, int Stride, bool = (Start < End)>
|
||||
struct unroll_i_in_range {
|
||||
template<class F, typename... Args>
|
||||
static METAL_FUNC void run(F f, Args... args) {
|
||||
f(Start, args...);
|
||||
unroll_i_in_range<Start + Stride, End, Stride>::run(f, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <int Start, int End, int Stride>
|
||||
struct unroll_i_in_range<Start, End, Stride, false> {
|
||||
template<class F, typename... Args>
|
||||
static METAL_FUNC void run(F, Args...) {
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <int Start, int End, int Stride, bool = (Start < End)>
|
||||
struct unroll_i_j_in_range_inner {
|
||||
template<class F, typename... Args>
|
||||
static METAL_FUNC void run(F f, int outerIndex, Args... args) {
|
||||
f(outerIndex, Start, args...);
|
||||
unroll_i_j_in_range_inner<Start + Stride, End, Stride>::run(f, outerIndex, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <int Start, int End, int Stride>
|
||||
struct unroll_i_j_in_range_inner<Start, End, Stride, false> {
|
||||
template<class F, typename... Args>
|
||||
static METAL_FUNC void run(F, int, Args...) {
|
||||
}
|
||||
};
|
||||
|
||||
template <int StartOuter, int EndOuter, int StrideOuter,
|
||||
int StartInner, int EndInner, int StrideInner,
|
||||
bool = (StartOuter < EndOuter)>
|
||||
struct unroll_i_j_in_range {
|
||||
template<class F, typename... Args>
|
||||
static METAL_FUNC void run(F f, Args... args) {
|
||||
unroll_i_j_in_range_inner<StartInner, EndInner, StrideInner>::run(
|
||||
f, StartOuter, args...
|
||||
);
|
||||
unroll_i_j_in_range<
|
||||
StartOuter + StrideOuter, EndOuter, StrideOuter,
|
||||
StartInner, EndInner, StrideInner
|
||||
>::run(f, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <int StartOuter, int EndOuter, int StrideOuter,
|
||||
int StartInner, int EndInner, int StrideInner>
|
||||
struct unroll_i_j_in_range<StartOuter, EndOuter, StrideOuter,
|
||||
StartInner, EndInner, StrideInner, false> {
|
||||
template<class F, typename... Args>
|
||||
static METAL_FUNC void run(F, Args...) {
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
|
||||
template <int N>
|
||||
struct ReadVector {
|
||||
float _[N];
|
||||
};
|
||||
|
||||
/* ---------- SHARED MEMORY UTILS ---------- */
|
||||
|
||||
#define mittens_ALIGN_AS(n) alignas(n)
|
||||
#define mittens_DEFAULT_ALIGN mittens_ALIGN_AS(16)
|
||||
|
||||
/**
|
||||
* @brief Dummy structure for alignment purposes. Needed for WGMMA and TMA calls.
|
||||
*/
|
||||
struct mittens_DEFAULT_ALIGN alignment_dummy { int dummy; };
|
||||
}
|
||||
|
||||
|
||||
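meta::unroll_i_in_range above is a classic compile-time loop: the primary template invokes the functor with Start and then recurses with Start + Stride, and the partial specialization on (Start < End) == false terminates the recursion, yielding a fully unrolled loop. A standalone C++ rendering of the same idiom, with METAL_FUNC dropped, is shown below for reference.

// Sketch (not repo code): the unroll_i_in_range recursive-template idiom from utils.metal, in plain C++.
#include <cstdio>

template <int Start, int End, int Stride, bool = (Start < End)>
struct unroll_i_in_range {
  template<class F, typename... Args>
  static void run(F f, Args... args) {
    f(Start, args...);                                                // body for this "iteration"
    unroll_i_in_range<Start + Stride, End, Stride>::run(f, args...);  // compile-time next iteration
  }
};
template <int Start, int End, int Stride>
struct unroll_i_in_range<Start, End, Stride, false> {                 // Start >= End: recursion ends
  template<class F, typename... Args> static void run(F, Args...) {}
};

int main() {
  // Prints i=0 i=2 i=4 i=6, i.e. an unrolled for (i = 0; i < 8; i += 2) loop.
  unroll_i_in_range<0, 8, 2>::run([](int i) { printf("i=%d ", i); });
  printf("\n");
  return 0;
}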
extra/thunder/include/ops/group/group.metal (new file, 24 lines)
@@ -0,0 +1,24 @@
/**
 * @file
 * @brief An aggregate header of all group (multi-warp) operations defined by Thundermittens
 */

#pragma once
#include "../../common/common.metal"
#include "../../types/types.metal"
#include "../warp/warp.metal" // several group memory ops rely on underlying warp-scope ops
namespace mittens {
template<int N_WARPS>
struct group {
    constant static constexpr int GROUP_WARPS = N_WARPS; // This alias produces nice parallelism.
    constant static constexpr int GROUP_THREADS = N_WARPS * mittens::SIMD_THREADS; // This alias produces nice parallelism.
    static METAL_FUNC int simd_laneid(const unsigned threadIdx) { return threadIdx % mittens::SIMD_THREADS; }
    static METAL_FUNC int laneid     (const unsigned threadIdx) { return threadIdx % GROUP_THREADS; }
    static METAL_FUNC int warpid     (const unsigned threadIdx) { return laneid(threadIdx) / mittens::SIMD_THREADS; }
    static METAL_FUNC int groupid    (const unsigned threadIdx) { return threadIdx / GROUP_THREADS; }
    #include "memory/memory.metal"
    #include "shared/shared.metal"
};


}
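group<N_WARPS> only layers index arithmetic on top of the warp-scope ops: for a flat thread index within the threadgroup, simd_laneid is the position inside a 32-thread simdgroup, warpid the simdgroup within the group, and groupid which block of N_WARPS simdgroups the thread belongs to. The C++ sketch below evaluates those same expressions on the host for a few sample thread indices, assuming N_WARPS = 4.

// Sketch (not repo code): the index helpers from group.metal, evaluated for N_WARPS = 4.
#include <cstdio>

int main() {
  const int SIMD_THREADS = 32;                     // threads per simdgroup (mittens::SIMD_THREADS)
  const int N_WARPS = 4;
  const int GROUP_THREADS = N_WARPS * SIMD_THREADS;
  int samples[] = {0, 31, 32, 97, 130};
  for (int threadIdx : samples) {
    int simd_laneid = threadIdx % SIMD_THREADS;    // lane within the simdgroup
    int laneid      = threadIdx % GROUP_THREADS;   // lane within the whole group
    int warpid      = laneid / SIMD_THREADS;       // which simdgroup inside the group
    int groupid     = threadIdx / GROUP_THREADS;   // which group of N_WARPS simdgroups
    printf("threadIdx=%3d -> simd_laneid=%2d laneid=%3d warpid=%d groupid=%d\n",
           threadIdx, simd_laneid, laneid, warpid, groupid);
  }
  return 0;
}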
extra/thunder/include/ops/group/memory/memory.metal (new file, 2 lines)
@@ -0,0 +1,2 @@
#include "tile/tile.metal"
#include "vec/vec.metal"
@@ -0,0 +1,132 @@
|
||||
|
||||
/**
|
||||
* @file
|
||||
* @brief Functions for a group to collaboratively transfer data directly between global memory and registers and back.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Collaboratively loads data from a source array into row-major layout tiles.
|
||||
*
|
||||
* @tparam RT The row-major layout tile type.
|
||||
* @tparam U The data type of the source array.
|
||||
* @param dst[out] The destination tile to load data into.
|
||||
* @param src[in] The source array to load data from.
|
||||
* @param row_stride[in] The stride in elements between rows in the source array.
|
||||
*/
|
||||
template<typename RT, typename GL>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_global_layout<GL>(), void>::type
|
||||
load(thread RT &dst, thread const GL &_src, thread const coord &idx, const int threadIdx) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
const device U *src = (device U*)&_src.template get<RT>(idx);
|
||||
const int row_stride = _src.row_stride();
|
||||
|
||||
int warp_laneid = threadIdx % 32;
|
||||
const int row_offset = dst.rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
const short simd_y = row_offset + (qid & 4) + (warp_laneid / 2) % 4;
|
||||
const short simd_x = (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
int row = simd_y + i * RT::tile_size;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
int col = simd_x + j * RT::tile_size;
|
||||
T2 src2 = base_types::convertor<T2, U2>::convert(*((device U2*)(&src[row * row_stride + col])));
|
||||
dst.tiles[i][j].data.thread_elements()[0] = src2[0];
|
||||
dst.tiles[i][j].data.thread_elements()[1] = src2[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename RT, typename GL>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_global_layout<GL>(), void>::type
|
||||
load(thread RT &dst, thread const GL &_src, thread const coord &idx, const int threadIdx) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
const device U *src = (device U*)&_src.template get<RT>(idx);
|
||||
const int row_stride = _src.row_stride();
|
||||
|
||||
int warp_laneid = threadIdx % 32;
|
||||
const int row_offset = dst.rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
const short simd_y = row_offset + (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
const short simd_x = (qid & 4) + (warp_laneid / 2) % 4;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
int row = simd_y + i * RT::tile_size;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
int col = simd_x + j * RT::tile_size;
|
||||
T2 src2 = base_types::convertor<T2, U2>::convert(*((device U2*)(&src[row * row_stride + col])));
|
||||
dst.tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert(src[row * row_stride + col]);
|
||||
dst.tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert(src[(row + 1) * row_stride + col]);
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Collaboratively stores data from register tiles to a destination array in global memory with a row-major layout.
|
||||
*
|
||||
* @tparam RT The register tile type with a row-major layout.
|
||||
* @tparam U The data type of the destination array.
|
||||
* @param[out] dst The destination array in global memory to store data into.
|
||||
* @param[in] src The source register tile to store data from.
|
||||
* @param row_stride[in] The stride in elements between rows in the destination array.
|
||||
*/
|
||||
template<typename RT, typename GL>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>(), void>::type
|
||||
store(thread GL &_dst, thread const RT &src, thread const coord &idx, const int threadIdx) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
device U *dst = (device U*)&(_dst.template get<RT>(idx));
|
||||
const int row_stride = _dst.row_stride();
|
||||
int warp_laneid = simd_laneid(threadIdx);
|
||||
const int row_offset = src.rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
const short simd_y = row_offset + (qid & 4) + (warp_laneid / 2) % 4;
|
||||
const short simd_x = (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < src.height; i++) {
|
||||
int row = simd_y + i * RT::tile_size;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < src.width; j++) {
|
||||
int col = simd_x + j * RT::tile_size;
|
||||
U2 src2 = base_types::convertor<U2, T2>::convert(T2(src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]));
|
||||
*(device U2*)(&dst[row*row_stride + col]) = src2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename RT, typename GL>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>(), void>::type
|
||||
store(thread GL &_dst, thread const RT &src, thread const coord &idx, const int threadIdx) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
device U *dst = (device U*)&(_dst.template get<RT>(idx));
|
||||
const int row_stride = _dst.row_stride();
|
||||
int warp_laneid = simd_laneid(threadIdx);
|
||||
const int row_offset = src.rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
// const short simd_y = row_offset + (qid & 4) + (warp_laneid / 2) % 4;
|
||||
// const short simd_x = (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
const short simd_y = row_offset + (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
const short simd_x = (qid & 4) + (warp_laneid / 2) % 4;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < src.height; i++) {
|
||||
int row = simd_y + i * RT::tile_size;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < src.width; j++) {
|
||||
int col = simd_x + j * RT::tile_size;
|
||||
dst[row*row_stride + col] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[0]);
|
||||
dst[(row + 1) * row_stride + col] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Group (collaborative warp) ops for loading shared tiles from and storing to global memory.
|
||||
*/
|
||||
|
||||
|
||||
//template<typename ST, typename U>
|
||||
//static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
//load(int i,
|
||||
// threadgroup ST *dst, device U* src,
|
||||
// thread const int& group_laneid,
|
||||
// thread const int& memcpy_per_row,
|
||||
// thread const int& elem_per_memcpy,
|
||||
// thread const int& row_stride)
|
||||
//{
|
||||
// int idx = i * GROUP_THREADS + group_laneid;
|
||||
// int row = idx / memcpy_per_row;
|
||||
// int col = (idx*elem_per_memcpy) % ST::cols;
|
||||
// if (row < ST::rows) {
|
||||
// *(threadgroup float4*)(&(*dst)[{row, col}]) = *(device float4*)(&src[row*row_stride + col]);
|
||||
// }
|
||||
//}
|
||||
|
||||
|
||||
template<typename ST, typename GL>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_global_layout<GL>(), void>::type
|
||||
load(threadgroup ST &dst, thread const GL &_src, thread const coord &idx, const int threadIdx) {
|
||||
int group_laneid = threadIdx % GROUP_THREADS;
|
||||
using T = typename ST::T;
|
||||
using U = typename GL::dtype;
|
||||
device U *src = (device U*)&_src.template get<ST>(idx);
|
||||
const int row_stride = _src.row_stride();
|
||||
using read_vector = ReadVector<1>;
|
||||
// number of elements each thread moves per vectorized copy
|
||||
constexpr const int elem_per_memcpy = sizeof(read_vector)/sizeof(typename ST::dtype);
|
||||
constexpr const int memcpy_per_row = ST::cols / elem_per_memcpy;
|
||||
int total_calls = (ST::height * ST::width + (N_WARPS-1)) * TILE_DIM*TILE_DIM / (N_WARPS*SIMD_THREADS*elem_per_memcpy); // round up
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < total_calls; i++) {
|
||||
|
||||
int idx = i * GROUP_THREADS + group_laneid;
|
||||
int row = idx / memcpy_per_row;
|
||||
int col = (idx*elem_per_memcpy) % dst.cols;
|
||||
if (row<dst.rows && col < dst.cols) {
|
||||
*(threadgroup read_vector*)(&dst[{row, col}]) = *(device read_vector*)(&src[row*row_stride + col]);
|
||||
// *(threadgroup float*)(&dst[{row, col}]) = 1.0f;
|
||||
}
|
||||
}
|
||||
// dst[{0, 0}] = base_types::convertor<T, float>::convert(1.f);
|
||||
// dst[{0, 0}] = total_calls;
|
||||
// meta::unroll_i_in_range<0, total_calls, 1>::run(load<ST, typename GL::dtype>, &dst, src, group_laneid, memcpy_per_row, elem_per_memcpy, row_stride);
|
||||
}
|
||||
|
||||
|
||||
//template<typename ST, typename GL>
|
||||
//static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_global_layout<GL>(), void>::type
|
||||
//load(threadgroup ST &dst, thread const GL &_src, thread const coord &idx, const int threadIdx) {
|
||||
// int group_laneid = threadIdx % GROUP_THREADS;
|
||||
// int groupid = threadIdx / GROUP_THREADS;
|
||||
// int laneid = threadIdx % SIMD_THREADS;
|
||||
//
|
||||
// using U = typename GL::dtype;
|
||||
// device U *src = (device U*)&_src.template get<ST>(idx);
|
||||
// const int row_stride = _src.row_stride();
|
||||
//
|
||||
// int elem_per_memcpy = sizeof(float)/sizeof(typename ST::dtype);
|
||||
// int memcpy_per_row = ST::cols / elem_per_memcpy;
|
||||
// int total_calls = ((ST::height * ST::width + (N_WARPS-1))) * TILE_DIM*TILE_DIM / (N_WARPS*SIMD_THREADS*elem_per_memcpy); // round up
|
||||
// /*
|
||||
// 1x16 or 8 x 128
|
||||
// */
|
||||
// int offset = ST::num_elements / (GROUP_WARPS);
|
||||
//// int offset = group_laneid
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < total_calls; i++) {
|
||||
// int idx = i * SIMD_THREADS + laneid;
|
||||
//// int idx = i * () + group_laneid;
|
||||
// int row = idx / memcpy_per_row;
|
||||
// int col = (idx*elem_per_memcpy) % dst.cols;
|
||||
// if (row<dst.rows) {
|
||||
// *(threadgroup float*)(&dst[{row, col}]) = *(device float*)(&src[row*row_stride + col]);
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//template<typename ST, typename GL>
|
||||
//static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_global_layout<GL>(), void>::type
|
||||
//load(threadgroup ST &dst, thread const GL &_src, thread const coord &idx, const int threadIdx) {
|
||||
// int warp_id = threadIdx / SIMD_THREADS;
|
||||
// int lane_id = threadIdx % SIMD_THREADS;
|
||||
//// int N_WARPS = /* number of warps in your group */;
|
||||
//
|
||||
// using U = typename GL::dtype;
|
||||
// device U *src = (device U*)&_src.template get<ST>(idx);
|
||||
// const int row_stride = _src.row_stride();
|
||||
//
|
||||
// int elem_per_memcpy = sizeof(float)/sizeof(typename ST::dtype);
|
||||
// int memcpy_per_row = ST::cols / elem_per_memcpy;
|
||||
// int total_memcpy_elems = (ST::height * ST::cols) / elem_per_memcpy;
|
||||
// int elems_per_warp = (total_memcpy_elems + N_WARPS - 1) / N_WARPS; // Ceiling division
|
||||
//
|
||||
// int start_idx = warp_id * elems_per_warp;
|
||||
// int end_idx = metal::min(start_idx + elems_per_warp, total_memcpy_elems);
|
||||
//
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int idx = start_idx + lane_id; idx < end_idx; idx += SIMD_THREADS) {
|
||||
// int row = idx / memcpy_per_row;
|
||||
// int col = (idx % memcpy_per_row) * elem_per_memcpy;
|
||||
// if (row < ST::height) {
|
||||
// *(threadgroup float*)(&dst[{row, col}]) = *(device float*)(&src[row * row_stride + col]);
|
||||
// }
|
||||
// }
|
||||
//}
|
||||
|
||||
template<typename ST, typename GL>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_global_layout<GL>(), void>::type
|
||||
store(thread const GL &_dst, threadgroup const ST &src, thread const coord &idx, const int threadIdx) {
|
||||
int group_laneid = threadIdx % GROUP_THREADS;
|
||||
using U = typename GL::dtype;
|
||||
device U *dst = (device U*)&_dst.template get<ST>(idx);
|
||||
const int row_stride = _dst.row_stride();
|
||||
using read_vector = ReadVector<1>;
|
||||
// number of elements each thread moves per vectorized copy
|
||||
int elem_per_memcpy = sizeof(read_vector)/sizeof(typename ST::dtype); // float/float -> 1
|
||||
int memcpy_per_row = ST::cols / elem_per_memcpy; // vectorized copies needed per tile row
|
||||
int total_calls = (src.height * src.width + (N_WARPS-1)) * TILE_DIM*TILE_DIM / (N_WARPS*SIMD_THREADS*elem_per_memcpy); // round up
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < total_calls; i++) {
|
||||
|
||||
int idx = i * GROUP_THREADS + group_laneid;
|
||||
|
||||
int row = idx / memcpy_per_row;
|
||||
int col = (idx*elem_per_memcpy) % src.cols;
|
||||
if (row<src.rows && col < src.cols) {
|
||||
*(device read_vector*)(&dst[row*row_stride + col]) = *(threadgroup read_vector*)(&src[{row, col}]);
|
||||
// *(device float*)(&dst[row*row_stride + col]) = 1.f;
|
||||
}
|
||||
}
|
||||
// dst[0] = src[{0,0}];
|
||||
// dst[0] = total_calls;
|
||||
// dst[0] = base_types::convertor<U, float>::convert(1);
|
||||
}
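
/* Editorial sketch (not part of the upstream file): the two group transfers above are meant
   to be used as a staging pair. A minimal helper under that assumption; the concrete
   st<>/gl<> instantiations and the coordinate come from the calling kernel: */
template<typename ST, typename GL>
static METAL_FUNC void stage_tile_roundtrip(thread const GL &src, thread const GL &dst,
                                            threadgroup ST &scratch,
                                            thread const coord &idx, const int threadIdx) {
    load(scratch, src, idx, threadIdx);                             // global -> threadgroup
    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);  // publish the staged tile
    store(dst, scratch, idx, threadIdx);                            // threadgroup -> global
}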
|
||||
|
||||
@@ -0,0 +1,152 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Functions for a warpgroup to collaboratively transfer data directly between shared memory and registers and back.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Collaboratively load data from a shared tile into register tiles split across a warpgroup.
|
||||
*
|
||||
* @tparam RT The register tile type
|
||||
* @tparam ST The shared tile type
|
||||
* @param dst[out] The destination register tile.
|
||||
* @param src[in] The source shared tile.
|
||||
*/
|
||||
template<typename RT, typename ST>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
|
||||
load(thread RT &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
constexpr int height = ST::height;
|
||||
constexpr int warp_height = RT::height;
|
||||
static_assert(height%N_WARPS == 0, "Group load / store requires tile height to be a multiple of N_WARPS.");
|
||||
static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height.");
|
||||
static_assert(warp_height * N_WARPS == height, "RT height * N_WARPS must equal ST height");
|
||||
static_assert(ST::width==RT::width, "Group load / store requires tile widths to match.");
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename ST::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
|
||||
int warp_laneid = simd_laneid(threadIdx);
|
||||
const int row_offset = RT::rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
const short simd_y = row_offset + (qid & 4) + (warp_laneid / 2) % 4;
|
||||
const short simd_x = (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
int row = simd_y + i * mittens::TILE_DIM;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
int col = simd_x + j * mittens::TILE_DIM;
|
||||
T2 src2 = base_types::convertor<T2, U2>::convert(*((threadgroup U2*)(&src[{row, col}])));
|
||||
dst.tiles[i][j].data.thread_elements()[0] = src2[0];
|
||||
dst.tiles[i][j].data.thread_elements()[1] = src2[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename RT, typename ST>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
|
||||
load(thread RT &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
constexpr int height = ST::height;
|
||||
constexpr int warp_height = RT::height;
|
||||
static_assert(height%N_WARPS == 0, "Group load / store requires tile height to be a multiple of N_WARPS.");
|
||||
static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height.");
|
||||
static_assert(warp_height * N_WARPS == height, "RT height * N_WARPS must equal ST height");
|
||||
static_assert(ST::width==RT::width, "Group load / store requires tile widths to match.");
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename ST::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
|
||||
int warp_laneid = simd_laneid(threadIdx);
|
||||
const int row_offset = RT::rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
const short simd_y = row_offset + (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
const short simd_x = (qid & 4) + (warp_laneid / 2) % 4;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
int row = simd_y + i * mittens::TILE_DIM;
|
||||
int col = simd_x + j * mittens::TILE_DIM;
|
||||
dst.tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert(src[{row + 0, col}]);
|
||||
dst.tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert(src[{row + 1, col}]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Collaboratively store data into a shared tile from register tiles split across a warpgroup.
|
||||
*
|
||||
* @tparam RT The register tile type
|
||||
* @tparam ST The shared tile type
|
||||
* @param dst[out] The destination shared tile.
|
||||
* @param src[in] The source register tile.
|
||||
*/
|
||||
template<typename ST, typename RT>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
|
||||
store(threadgroup ST &dst, thread const RT &src, const int threadIdx) {
|
||||
constexpr int height = ST::height;
|
||||
constexpr int warp_height = RT::height;
|
||||
static_assert(height%N_WARPS == 0, "Group load / store requires tile height to be a multiple of N_WARPS.");
|
||||
static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height.");
|
||||
static_assert(warp_height * N_WARPS == height, "RT height * N_WARPS must equal ST height");
|
||||
static_assert(ST::width==RT::width, "Group load / store requires tile widths to match.");
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename ST::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
int warp_laneid = simd_laneid(threadIdx);
|
||||
const int row_offset = RT::rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
const short simd_y = row_offset + (qid & 4) + (warp_laneid / 2) % 4;
|
||||
const short simd_x = (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < RT::height; i++) {
|
||||
int row = simd_y + i * mittens::TILE_DIM;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < RT::width; j++) {
|
||||
int col = simd_x + j * mittens::TILE_DIM;
|
||||
U2 src2 = base_types::convertor<U2, T2>::convert(T2(src.tiles[i][j].data.thread_elements()[0],
|
||||
src.tiles[i][j].data.thread_elements()[1]));
|
||||
*(threadgroup U2*)(&dst[{row, col}]) = src2;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename ST, typename RT>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
|
||||
store(threadgroup ST &dst, thread const RT &src, const int threadIdx) {
|
||||
constexpr int height = ST::height;
|
||||
constexpr int warp_height = RT::height;
|
||||
static_assert(height%N_WARPS == 0, "Group load / store requires tile height to be a multiple of N_WARPS.");
|
||||
static_assert(height%warp_height == 0, "Group load / store requires tile height to be a multiple of the RT height.");
|
||||
static_assert(warp_height * N_WARPS == height, "RT height * N_WARPS must equal ST height");
|
||||
static_assert(ST::width==RT::width, "Group load / store requires tile widths to match.");
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U = typename ST::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
int warp_laneid = simd_laneid(threadIdx);
|
||||
const int row_offset = RT::rows * warpid(threadIdx);
|
||||
const short qid = warp_laneid / 4;
|
||||
// const short simd_y = row_offset + (qid & 4) + (warp_laneid / 2) % 4;
|
||||
// const short simd_x = (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
const short simd_y = row_offset + (qid & 2) * 2 + (warp_laneid % 2) * 2;
|
||||
const short simd_x = (qid & 4) + (warp_laneid / 2) % 4;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < RT::height; i++) {
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < RT::width; j++) {
|
||||
int row = simd_y + i * mittens::TILE_DIM;
|
||||
int col = simd_x + j * mittens::TILE_DIM;
|
||||
// U2 src2 = base_types::convertor<U2, T2>::convert(T2(src.tiles[i][j].data.thread_elements()[0],
|
||||
// src.tiles[i][j].data.thread_elements()[1]));
|
||||
// *(threadgroup U2*)(&dst[{row, col}]) = src2;
|
||||
|
||||
dst[{row + 0, col}] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[0]);
|
||||
dst[{row + 1, col}] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[1]);
|
||||
}
|
||||
}
|
||||
}
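
/* Editorial sketch (not part of the upstream file): each warp owns RT::rows of the shared
   tile's rows, which is what the static_asserts above enforce, so a load/store pair
   partitions the tile across the warpgroup. The types are placeholders and only the
   load()/store() call shapes come from this file. */
template<typename RT, typename ST>
static METAL_FUNC void roundtrip_shared_register(threadgroup ST &tile, const int threadIdx) {
    RT reg;                       // per-warp register slice
    load(reg, tile, threadIdx);   // shared -> registers: warp w reads rows starting at w*RT::rows
    // ... per-warp register-tile math would go here ...
    store(tile, reg, threadIdx);  // registers -> shared: warp w writes the same rows back
}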
|
||||
8
extra/thunder/include/ops/group/memory/tile/tile.metal
Normal file
8
extra/thunder/include/ops/group/memory/tile/tile.metal
Normal file
@@ -0,0 +1,8 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief An aggregate header of group memory operations on tiles.
|
||||
*/
|
||||
|
||||
#include "shared_to_register.metal"
|
||||
#include "global_to_register.metal"
|
||||
#include "global_to_shared.metal"
|
||||
@@ -0,0 +1,47 @@
|
||||
|
||||
/**
|
||||
* @file
|
||||
* @brief Functions for a warpgroup to collaboratively transfer data directly between global memory and registers and back.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Collaboratively loads data into register vectors from a source array in global memory.
|
||||
*
|
||||
* @tparam RV The register vector type.
|
||||
* @tparam U The data type of the source array.
|
||||
* @param[out] dst The destination register vector to load data into.
|
||||
* @param[in] src The source array in global memory to load data from.
|
||||
*/
|
||||
template<typename RV, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
load(thread RV &dst, thread const GL &_src, thread coord idx, const int threadIdx) {
|
||||
using T = typename RV::dtype;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
|
||||
idx.c += warpid(threadIdx);
|
||||
// Call warp level store
|
||||
::mittens::load(dst, _src, idx, simd_laneid(threadIdx));
|
||||
}
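
/* Editorial note (not part of the upstream file): the group form only fans the warp-level
   load out along the c coordinate, so warp w reads the vector tile at idx.c + w and a single
   call covers N_WARPS consecutive vector tiles. Illustrative call, names assumed:
       load(v_reg, gl_vec, {0, 0, row, 0}, threadIdx);   // warp w reads vector tile c = w
*/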
|
||||
|
||||
/**
|
||||
* @brief Collaboratively stores data from register vectors to a destination array in global memory.
|
||||
*
|
||||
* @tparam RV The register vector type.
|
||||
* @tparam U The data type of the destination array.
|
||||
* @param[out] dst The destination array in global memory to store data into.
|
||||
* @param[in] src The source register vector to store data from.
|
||||
*/
|
||||
template<typename RV, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
store(thread GL &_dst, thread const RV &src, thread coord idx, const int threadIdx) {
|
||||
using T = typename RV::dtype;
|
||||
// using U2 = typename base_types::packing<U>::packed_type;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
|
||||
idx.c += warpid(threadIdx);
|
||||
|
||||
// Call warp level store
|
||||
::mittens::store(_dst, src, idx, simd_laneid(threadIdx));
|
||||
}
|
||||
@@ -0,0 +1,59 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Group (collaborative warp) ops for loading shared vectors from and storing to global memory.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Loads data from global memory into shared memory vector.
|
||||
*
|
||||
* This function loads data from a global memory location pointed to by `src` into a shared memory vector `dst`.
|
||||
* It calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`.
|
||||
* The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among threads in a warp.
|
||||
*
|
||||
* @tparam SV Shared vector type, must satisfy ducks::sv::all concept.
|
||||
* @param dst Reference to the shared vector where the data will be loaded.
|
||||
* @param src Pointer to the global memory location from where the data will be loaded.
|
||||
*/
|
||||
template<typename SV, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
load(threadgroup SV &dst, thread const GL &_src, thread const coord &idx, const int threadIdx) {
|
||||
using U = typename GL::dtype;
|
||||
using read_vector = ReadVector<1>;
|
||||
constexpr int elem_per_transfer = sizeof(read_vector) / sizeof(typename SV::dtype);
|
||||
constexpr int total_calls = SV::length / elem_per_transfer; // guaranteed to divide
|
||||
device U *src = (device U*)&_src.template get<SV>(idx);
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < total_calls; i+=GROUP_THREADS) {
|
||||
if(i * elem_per_transfer < dst.length)
|
||||
*(threadgroup read_vector*)&dst[i*elem_per_transfer] = *(device read_vector*)&src[i*elem_per_transfer];
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Stores data from a shared memory vector to global memory.
|
||||
*
|
||||
* This function stores data from a shared memory vector `src` to a global memory location pointed to by `dst`.
|
||||
* Similar to the load function, it calculates the number of elements that can be transferred in one operation based on the size ratio of `float4` to the data type of `SV`.
|
||||
* The function ensures coalesced memory access and efficient use of bandwidth by dividing the work among threads in a warp.
|
||||
*
|
||||
* @tparam SV Shared vector type, must satisfy ducks::sv::all concept.
|
||||
* @param dst Pointer to the global memory location where the data will be stored.
|
||||
* @param src Reference to the shared vector from where the data will be stored.
|
||||
*/
|
||||
template<typename SV, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
store(thread const GL &_dst, threadgroup const SV &src, thread const coord &idx, const int threadIdx) {
|
||||
using read_vector = ReadVector<1>;
|
||||
using U = typename GL::dtype;
|
||||
constexpr int elem_per_transfer = sizeof(read_vector) / sizeof(typename SV::dtype);
|
||||
constexpr int total_calls = SV::length / elem_per_transfer; // guaranteed to divide
|
||||
device U *dst = (device U*)&_dst.template get<SV>(idx);
|
||||
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < total_calls; i+= GROUP_THREADS) {
|
||||
if(i * elem_per_transfer < src.length)
|
||||
*(device read_vector*)&dst[i*elem_per_transfer] = *(threadgroup read_vector*)&src[i*elem_per_transfer]; // mirror of the load above, with source and destination swapped
|
||||
}
|
||||
}
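
/* Editorial sketch (not part of the upstream file): a common use of this store is flushing a
   result that every warp helped produce, in which case a threadgroup barrier belongs between
   the last write to the shared vector and the store. Names below are illustrative:
       // ... all warps finish writing `partials` (a threadgroup shared vector) ...
       metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
       store(gl_out, partials, {0, 0, row_block, 0}, threadIdx);   // threadgroup -> global
*/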
|
||||
@@ -0,0 +1,60 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Functions for a group to collaboratively transfer data directly between shared memory and registers and back.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Collaboratively load data from a shared vector into register vectors split across a warpgroup.
|
||||
*
|
||||
* @tparam RV The register vector type
|
||||
* @tparam SV The shared vector type
|
||||
* @param dst[out] The destination register vector.
|
||||
* @param src[in] The source shared vector.
|
||||
*/
|
||||
template<typename RV, typename SV>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
load(thread RV &dst, threadgroup const SV &_src, const int threadIdx) {
|
||||
using T = typename RV::dtype;
|
||||
using U = typename SV::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
|
||||
static_assert(SV::length == RV::length*N_WARPS, "rv and sv dimensions do not match");// confirm size correct
|
||||
// threadgroup typename SV::template subvec<typename SV::dtype, RV::outer_dim> &src = subvec_inplace<RV::outer_dim, SV>(_src, warpid(threadIdx));
|
||||
// threadgroup subvec &src = subvec_inplace<RV::outer_dim, SV>(_src, warpid(threadIdx));
|
||||
unsigned warpId = warpid(threadIdx);
|
||||
using subvec = typename SV::template subvec<RV::length>;
|
||||
|
||||
threadgroup subvec& src = *(threadgroup subvec*)(&_src[warpId *RV::length]);
|
||||
|
||||
::mittens::load<RV, subvec>(dst, src, simd_laneid(threadIdx)); // warp-level
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Collaboratively store data into a shared vector from register vectors split across a warpgroup.
|
||||
*
|
||||
* @tparam RV The register vector type
|
||||
* @tparam SV The shared vector type
|
||||
* @param dst[out] The destination shared vector.
|
||||
* @param src[in] The source register vector.
|
||||
*/
|
||||
template<typename SV, typename RV>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
store(threadgroup SV &_dst, thread const RV &src, const int threadIdx) {
|
||||
using T = typename RV::dtype;
|
||||
using U = typename SV::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
|
||||
|
||||
static_assert(SV::length == RV::length*N_WARPS, "rv and sv dimensions do not match");// confirm size correct
|
||||
|
||||
// threadgroup typename SV::template subvec<typename SV::dtype, RV::outer_dim> &dst = subvec_inplace<RV::outer_dim, SV>(_dst, warpid(threadIdx));
|
||||
// ::mittens::store<threadgroup typename SV::template subvec<typename SV::dtype, RV::outer_dim>, RV>(dst, src, simd_laneid(threadIdx)); // warp-level
|
||||
|
||||
unsigned warpId = warpid(threadIdx);
|
||||
using subvec = typename SV::template subvec<RV::length>;
|
||||
threadgroup subvec& dst = *(threadgroup subvec*)(&_dst[warpId * RV::length]);
|
||||
|
||||
::mittens::store(dst, src, simd_laneid(threadIdx)); // warp-level
|
||||
}
|
||||
8
extra/thunder/include/ops/group/memory/vec/vec.metal
Normal file
8
extra/thunder/include/ops/group/memory/vec/vec.metal
Normal file
@@ -0,0 +1,8 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief An aggregate header of group memory operations on vectors.
|
||||
*/
|
||||
|
||||
#include "shared_to_register.metal"
|
||||
#include "global_to_register.metal"
|
||||
#include "global_to_shared.metal"
|
||||
3
extra/thunder/include/ops/group/shared/shared.metal
Normal file
3
extra/thunder/include/ops/group/shared/shared.metal
Normal file
@@ -0,0 +1,3 @@
|
||||
|
||||
#include "tile/tile.metal"
|
||||
#include "vec/vec.metal"
|
||||
@@ -0,0 +1,27 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Group conversions between different shared memory tile types.
|
||||
*/
|
||||
|
||||
/* ---------- COPIES ---------- */
|
||||
|
||||
/**
|
||||
* @brief Copies data from one shared memory tile to another, potentially with different data types and layouts.
|
||||
*
|
||||
* @tparam T The data type of the destination tile.
|
||||
* @tparam U The data type of the source tile.
|
||||
* @tparam _height The height of the tile.
|
||||
* @tparam _width The width of the tile.
|
||||
* @tparam L1 The layout of the destination tile.
|
||||
* @tparam L2 The layout of the source tile.
|
||||
* @param[out] dst The destination tile.
|
||||
* @param[in] src The source tile.
|
||||
*/
|
||||
template<typename T, typename U, int _height, int _width>
|
||||
static METAL_FUNC void copy(threadgroup st<T, _height, _width> &dst, threadgroup const st<U, _height, _width> &src, const int threadIdx) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < dst.num_elements; i+=GROUP_THREADS) {
|
||||
int row = i/dst.cols, col = i%dst.cols;
|
||||
dst[{row, col}] = base_types::convertor<T, U>::convert(src[{row, col}]);
|
||||
}
|
||||
}
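
/* Editorial sketch (not part of the upstream file): copy() is the group-wide elementwise
   convert between two equally shaped shared tiles of different dtypes. The half/float pair
   and the tile shape below are illustrative assumptions:
       threadgroup st<float, 2, 2> acc_f32;
       threadgroup st<half,  2, 2> acc_f16;
       copy(acc_f16, acc_f32, threadIdx);   // each group thread converts a strided subset
*/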
|
||||
475
extra/thunder/include/ops/group/shared/tile/maps.metal
Normal file
475
extra/thunder/include/ops/group/shared/tile/maps.metal
Normal file
@@ -0,0 +1,475 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Group maps on shared tiles.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Performs a uniform unary operation on a tile.
|
||||
*
|
||||
* This function applies a given unary operation to each element of the source tile and stores the result in the destination tile.
|
||||
* The operation is applied independently to each element, without considering its position or the values of neighboring elements.
|
||||
*
|
||||
* @tparam op The unary operation to be applied. Must be specialized to support operation on the data type of T.
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the unary operation is applied.
|
||||
*/
|
||||
template<typename op, typename ST> // T2, w, h can be inferred from dst as long as op is specialized
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
unary_map(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < dst.num_elements; i += GROUP_THREADS) {
|
||||
dst.data[i] = op::template op<typename ST::dtype>(src.data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs a uniform binary operation on a tile with a scalar parameter.
|
||||
*
|
||||
* This function applies a given binary operation to each element of the source tile and a scalar parameter, then stores the result in the destination tile.
|
||||
* The operation is applied independently to each element, treating the scalar parameter as the second operand for each operation.
|
||||
*
|
||||
* @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the scalar parameter.
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the binary operation is applied.
|
||||
* @param[in] param The scalar parameter to be used as the second operand in the binary operation.
|
||||
*/
|
||||
template<typename op, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
bin_map(threadgroup ST &dst, threadgroup const ST &src, thread const typename ST::dtype ¶m, const int threadIdx) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < dst.num_elements; i += GROUP_THREADS) {
|
||||
dst.data[i] = op::template op<typename ST::dtype>(src.data[i], param);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs a uniform binary operation on two tiles.
|
||||
*
|
||||
* This function applies a given binary operation to corresponding elements of two source tiles and stores the result in the destination tile.
|
||||
* The operation is applied independently to each pair of elements, without considering their positions or the values of neighboring elements.
|
||||
*
|
||||
* @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T.
|
||||
* @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] lhs The first source tile to which the binary operation is applied.
|
||||
* @param[in] rhs The second source tile to which the binary operation is applied.
|
||||
*/
|
||||
template<typename op, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
bin_map(threadgroup ST &dst, threadgroup const ST &lhs, threadgroup const ST &rhs, const int threadIdx) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < dst.num_elements; i += GROUP_THREADS) {
|
||||
dst.data[i] = op::template op<typename ST::dtype>(lhs.data[i], rhs.data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs a row-wise binary operation on a tile with a vector.
|
||||
*
|
||||
* This function applies a given binary operation to each row of the source tile and the corresponding element of the source vector,
|
||||
* then stores the result in the destination tile. The operation is applied independently to each row, using the vector element as
|
||||
* the second operand for each element in the row.
|
||||
*
|
||||
* @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements.
|
||||
* @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam V The type of the vector. Must have the same data type as T.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the binary operation is applied.
|
||||
* @param[in] vec The source vector containing the second operand for each row operation.
|
||||
*/
|
||||
template<typename op, typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_map(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &vec, const int threadIdx) {
|
||||
static_assert(metal::is_same<typename ST::dtype, typename SV::dtype>::value, "Tile and vector must have the same data type");
|
||||
static_assert(SV::length == ST::rows, "Vector length must match the number of rows in the tile");
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < dst.num_elements; i += GROUP_THREADS) {
|
||||
int row = i/dst.cols, col = i%dst.cols;
|
||||
dst[{row, col}] = op::template op<typename ST::dtype>(src[{row, col}], vec[row]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Performs a column-wise binary operation on a tile with a vector.
|
||||
*
|
||||
* This function applies a given binary operation to each column of the source tile and the corresponding element of the source vector,
|
||||
* then stores the result in the destination tile. The operation is applied independently to each column, using the vector element as
|
||||
* the second operand for each element in the column.
|
||||
*
|
||||
* @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements.
|
||||
* @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam V The type of the vector. Must have the same data type as T.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the binary operation is applied.
|
||||
* @param[in] vec The source vector containing the second operand for each column operation.
|
||||
*/
|
||||
template<typename op, typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_map(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &vec, const int threadIdx) {
|
||||
static_assert(metal::is_same<typename ST::dtype, typename SV::dtype>::value, "Tile and vector must have the same data type");
|
||||
static_assert(SV::length == ST::cols, "Vector length must match the number of columns in the tile");
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid(threadIdx); i < dst.num_elements; i += GROUP_THREADS) {
|
||||
int row = i/dst.cols, col = i%dst.cols;
|
||||
dst[{row, col}] = op::template op<typename ST::dtype>(src[{row, col}], vec[col]);
|
||||
}
|
||||
}
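
/* Editorial sketch (not part of the upstream file): these maps accept any functor shaped like
   the library's base_ops, i.e. a static templated op() with overloads for the address spaces
   it will be fed (the sources here live in threadgroup memory). For example:
       struct square {
           template<typename T> static METAL_FUNC T op(thread const T &x)      { return x * x; }
           template<typename T> static METAL_FUNC T op(threadgroup const T &x) { return x * x; }
       };
       unary_map<square>(tile, tile, threadIdx);   // group-wide elementwise x -> x*x
*/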
|
||||
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
// All of the annoying qualifiers *should* be automatically inferred during compile-time.
|
||||
// So, syntax should just be mittens::add_row(tile, colvec);
|
||||
|
||||
// const maps
|
||||
/**
|
||||
* @brief Sets all elements of the destination tile to zero.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
zero(threadgroup ST &dst, const int threadIdx) {
|
||||
unary_map<base_ops::zero, ST>(dst, dst, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of the destination tile to one.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
one(threadgroup ST &dst, const int threadIdx) {
|
||||
unary_map<base_ops::one, ST>(dst, dst, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of the destination tile to positive infinity.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
pos_infty(threadgroup ST &dst, const int threadIdx) {
|
||||
unary_map<base_ops::pos_infty, ST>(dst, dst, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of the destination tile to negative infinity.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
neg_infty(threadgroup ST &dst, const int threadIdx) {
|
||||
unary_map<base_ops::neg_infty, ST>(dst, dst, threadIdx);
|
||||
}
|
||||
|
||||
// unary maps
|
||||
/**
|
||||
* @brief Applies the exponential function to each element of the source tile and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the exponential function is applied.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
exp(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
unary_map<base_ops::exp, ST>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Applies the base-2 exponential function to each element of the source tile and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the exponential function is applied.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
exp2(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
unary_map<base_ops::exp2, ST>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the natural logarithm function to each element of the source tile and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the natural logarithm function is applied.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
log(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
unary_map<base_ops::log, ST>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the absolute function to each element of the source tile and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the absolute function is applied.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
abs(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
unary_map<base_ops::abs, ST>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the rectified linear unit function to each element of the source tile and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source tile to which the rectified linear unit function is applied.
|
||||
*/
|
||||
template<typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
relu(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
|
||||
unary_map<base_ops::relu, ST>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Copies the elements of the source tile to the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam U The type of the source data. Must be convertible to the data type of the destination tile.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] src The source data to be copied.
|
||||
*/
|
||||
template<typename ST, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
copy(threadgroup ST &dst, thread const U &src, const int threadIdx) {
|
||||
bin_map<base_ops::copy2, ST>(dst, dst, src, threadIdx); // broadcast src into every element (copy2 keeps the second operand)
|
||||
}
|
||||
|
||||
// uniform binary maps
|
||||
/**
|
||||
 * @brief Finds the element-wise maximum of the source tile and a second operand and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] lhs The first source tile.
|
||||
* @param[in] rhs The second source data.
|
||||
*/
|
||||
template<typename ST, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
max(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_map<base_ops::max, ST>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Finds the element-wise minimum of the source tile and a second operand and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] lhs The first source tile.
|
||||
* @param[in] rhs The second source data.
|
||||
*/
|
||||
template<typename ST, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
min(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_map<base_ops::min, ST>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Adds a second operand to each element of the source tile and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] lhs The first source tile.
|
||||
* @param[in] rhs The second source data.
|
||||
*/
|
||||
template<typename ST, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
add(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_map<base_ops::sum, ST>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Subtracts a second operand from each element of the source tile and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] lhs The first source tile.
|
||||
* @param[in] rhs The second source data.
|
||||
*/
|
||||
template<typename ST, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
sub(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_map<base_ops::sub, ST>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Multiplies each element of the source tile by a second operand and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] lhs The first source tile.
|
||||
* @param[in] rhs The second source data.
|
||||
*/
|
||||
template<typename ST, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
mul(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_map<base_ops::mul, ST>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Divides each element of the source tile by a second operand and stores the result in the destination tile.
|
||||
*
|
||||
* @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
|
||||
* @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
|
||||
* @param[out] dst The destination tile where the results are stored.
|
||||
* @param[in] lhs The first source tile.
|
||||
* @param[in] rhs The second source data.
|
||||
*/
|
||||
template<typename ST, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
div(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_map<base_ops::div, ST>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
|
||||
// Row and col maps
|
||||
|
||||
/**
|
||||
* @brief Adds row values to each row of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the addition on.
|
||||
* @param row_values[in] Column vector containing values to add to each row.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
add_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
|
||||
row_map<base_ops::sum, ST, SV>(dst, src, row_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Subtracts row values from each row of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the subtraction on.
|
||||
* @param row_values[in] Column vector containing values to subtract from each row.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
sub_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
|
||||
row_map<base_ops::sub, ST, SV>(dst, src, row_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Multiplies each row of a tile by row values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the multiplication on.
|
||||
* @param row_values[in] Column vector containing values to multiply each row by.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
mul_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
|
||||
row_map<base_ops::mul, ST, SV>(dst, src, row_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Divides each row of a tile by row values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the division on.
|
||||
* @param row_values[in] Column vector containing values to divide each row by.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
div_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
|
||||
row_map<base_ops::div, ST, SV>(dst, src, row_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Broadcast a vector into a tile's rows.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param row_values[in] Column vector containing values to broadcast into rows.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
broadcast_row(threadgroup ST &dst, threadgroup const SV &row_values, const int threadIdx) {
|
||||
row_map<base_ops::copy2, ST, SV>(dst, dst, row_values, threadIdx);
|
||||
}
|
||||
|
||||
|
||||
// col maps
|
||||
/**
|
||||
* @brief Adds column values to each column of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the addition on.
|
||||
* @param col_values[in] Row vector containing values to add to each column.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
add_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
|
||||
col_map<base_ops::sum, ST, SV>(dst, src, col_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Subtracts column values from each column of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the subtraction on.
|
||||
* @param col_values[in] Row vector containing values to subtract from each column.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
sub_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
|
||||
col_map<base_ops::sub, ST, SV>(dst, src, col_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Multiplies each column of a tile by column values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the multiplication on.
|
||||
* @param col_values[in] Row vector containing values to multiply each column by.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
mul_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
|
||||
col_map<base_ops::mul, ST, SV>(dst, src, col_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Divides each column of a tile by column values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the division on.
|
||||
* @param col_values[in] Row vector containing values to divide each column by.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
div_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
|
||||
col_map<base_ops::div, ST, SV>(dst, src, col_values, threadIdx);
|
||||
}
|
||||
/**
|
||||
 * @brief Broadcast a vector into a tile's columns.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param row_values[in] Row vector containing values to broadcast into cols.
|
||||
*/
|
||||
template<typename ST, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
broadcast_col(threadgroup ST &dst, threadgroup const SV &col_values, const int threadIdx) {
|
||||
col_map<base_ops::copy2, ST, SV>(dst, dst, col_values, threadIdx);
|
||||
}
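
/* Editorial sketch (not part of the upstream file): the wrappers compose into short fused
   passes over a shared tile. A softmax-style numerator, assuming `row_maxes` already holds
   each row's maximum (for instance from the row_max reduction shipped alongside these maps).
   The helper name is illustrative; only the calls and their argument order come from this file. */
template<typename ST, typename SV>
static METAL_FUNC void softmax_numerator(threadgroup ST &tile, threadgroup const SV &row_maxes, const int threadIdx) {
    sub_row(tile, tile, row_maxes, threadIdx);   // x - max(row), in place
    exp(tile, tile, threadIdx);                  // elementwise exp, in place
    // in-place is safe: every map reads and writes a given element from the same thread
}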
|
||||
284
extra/thunder/include/ops/group/shared/tile/reductions.metal
Normal file
284
extra/thunder/include/ops/group/shared/tile/reductions.metal
Normal file
@@ -0,0 +1,284 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Group reductions on shared tiles.
|
||||
*/
|
||||
|
||||
/**
|
||||
* Performs row-wise reduction on a matrix using a specified operation.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type with row layout.
|
||||
* @param row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param src The source matrix on which to perform the reduction.
|
||||
* @param src_accum The initial value of the accumulator, used when reset is false.
|
||||
* @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
*/
|
||||
template<typename op, typename SV, typename ST, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_reduce(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
using dtype = typename SV::dtype;
|
||||
for (int row = laneid(threadIdx); row < src.rows; row += GROUP_THREADS) {
|
||||
dtype accum = src[{row, 0}];
|
||||
#pragma clang loop unroll(full)
|
||||
for (int col = 1; col < src.cols; col++) {
|
||||
accum = op::template op<dtype>(accum, src[{row, col}]);
|
||||
}
|
||||
if (reset) {
|
||||
row_accum[row] = accum;
|
||||
} else {
|
||||
row_accum[row] = op::template op<dtype>(src_accum[row], accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs column-wise reduction on a matrix using a specified operation.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The shared vector type for the column accumulator.
|
||||
* @tparam T The shared matrix type with column layout.
|
||||
* @param col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param src The source matrix on which to perform the reduction.
|
||||
* @param src_accum The initial value of the accumulator, used when reset is false.
|
||||
* @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
*/
|
||||
template<typename op, typename SV, typename ST, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_reduce(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
using dtype = typename SV::dtype;
|
||||
for (int col = laneid(threadIdx); col < src.cols; col += GROUP_THREADS) {
|
||||
dtype accum = src[{0, col}];
|
||||
#pragma clang loop unroll(full)
|
||||
for (int row = 1; row < src.rows; row++) {
|
||||
accum = op::template op<dtype>(accum, src[{row, col}]);
|
||||
}
|
||||
if (reset) {
|
||||
col_accum[col] = accum;
|
||||
} else {
|
||||
col_accum[col] = op::template op<dtype>(src_accum[col], accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_max(threadgroup SV &row_accum, threadgroup const ST &src, const int threadIdx) {
|
||||
row_reduce<base_ops::max, SV, ST, true>(row_accum, src, row_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_min(threadgroup SV &row_accum, threadgroup const ST &src, const int threadIdx) {
|
||||
row_reduce<base_ops::min, SV, ST, true>(row_accum, src, row_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_sum(threadgroup SV &row_accum, threadgroup const ST &src, const int threadIdx) {
|
||||
row_reduce<base_ops::sum, SV, ST, true>(row_accum, src, row_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_prod(threadgroup SV &row_accum, threadgroup const ST &src, const int threadIdx) {
|
||||
row_reduce<base_ops::mul, SV, ST, true>(row_accum, src, row_accum, threadIdx);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_max(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
row_reduce<base_ops::max, SV, ST, false>(row_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_min(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
row_reduce<base_ops::min, SV, ST, false>(row_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_sum(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
row_reduce<base_ops::sum, SV, ST, false>(row_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_prod(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
row_reduce<base_ops::mul, SV, ST, false>(row_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_max(threadgroup SV &col_accum, threadgroup const ST &src, const int threadIdx) {
|
||||
col_reduce<base_ops::max, SV, ST, true>(col_accum, src, col_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_min(threadgroup SV &col_accum, threadgroup const ST &src, const int threadIdx) {
    col_reduce<base_ops::min, SV, ST, true>(col_accum, src, col_accum, threadIdx);
}
|
||||
/**
|
||||
* @brief Store the sum of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_sum(threadgroup SV &col_accum, threadgroup const ST &src, const int threadIdx) {
|
||||
col_reduce<base_ops::sum, SV, ST, true>(col_accum, src, col_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_prod(threadgroup SV &col_accum, threadgroup const ST &src, const int threadIdx) {
|
||||
col_reduce<base_ops::mul, SV, ST, true>(col_accum, src, col_accum, threadIdx);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_max(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
col_reduce<base_ops::max, SV, ST, false>(col_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_min(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
col_reduce<base_ops::min, SV, ST, false>(col_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_sum(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
col_reduce<base_ops::sum, SV, ST, false>(col_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_prod(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const int threadIdx) {
|
||||
col_reduce<base_ops::mul, SV, ST, false>(col_accum, src, src_accum, threadIdx);
|
||||
}
|
||||
3
extra/thunder/include/ops/group/shared/tile/tile.metal
Normal file
3
extra/thunder/include/ops/group/shared/tile/tile.metal
Normal file
@@ -0,0 +1,3 @@
#include "conversions.metal"
#include "maps.metal"
#include "reductions.metal"
29
extra/thunder/include/ops/group/shared/vec/conversions.metal
Normal file
29
extra/thunder/include/ops/group/shared/vec/conversions.metal
Normal file
@@ -0,0 +1,29 @@
/**
 * @file
 * @brief Group conversions on shared vectors.
 */

/**
 * @brief Copies data from one shared vector to another, converting data types if necessary.
 *
 * This function copies data from the source shared vector `src` to the destination shared vector `dst`.
 * If the data types of `src` and `dst` are the same, it performs a direct memory copy. Otherwise, it
 * converts each element from the source data type to the destination data type using the appropriate
 * converter before copying.
 *
 * @tparam SV1 The type of the destination shared vector, must satisfy the ducks::sv::all concept.
 * @tparam SV2 The type of the source shared vector, must satisfy the ducks::sv::all concept.
 * @param[out] dst The destination shared vector.
 * @param[in] src The source shared vector.
 * @note The lengths of `src` and `dst` must be equal. This is enforced at compile time.
 */
template<typename SV1, typename SV2>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV1>() && ducks::is_shared_vector<SV2>(), void>::type
copy(threadgroup SV1 &dst, threadgroup const SV2 &src, const int threadIdx) {
    static_assert(SV1::length == SV2::length, "Source and destination vectors must have the same length.");
    #pragma clang loop unroll(full)
    for(int i = laneid(threadIdx); i < dst.length; i+=GROUP_THREADS) {
        dst[i] = base_types::convertor<typename SV1::dtype, typename SV2::dtype>::convert(src[i]);
    }
}
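// Usage sketch (illustrative only): copy converts element types on the fly through
// base_types::convertor, so e.g. a half-precision shared vector can be widened to float.
// The sv<...> names below are assumed aliases for the library's shared vector types.
//
//   threadgroup sv<half, 64>  src_vec;
//   threadgroup sv<float, 64> dst_vec;
//   copy(dst_vec, src_vec, threadIdx);   // each group thread handles every GROUP_THREADS-th element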
267
extra/thunder/include/ops/group/shared/vec/maps.metal
Normal file
267
extra/thunder/include/ops/group/shared/vec/maps.metal
Normal file
@@ -0,0 +1,267 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Group maps on shared vectors.
|
||||
*/
|
||||
|
||||
/**
|
||||
* @brief Applies a unary operation to each element of a shared memory vector.
|
||||
*
|
||||
* @tparam op Unary operation type.
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector in which to store the result.
|
||||
* @param src[in] Source vector to apply the unary operation.
|
||||
*/
|
||||
template<typename op, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
unary_op(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(auto cur = laneid(threadIdx); cur < SV::length; cur+=GROUP_THREADS) {
|
||||
dst[cur] = op::template op<typename SV::dtype>(src[cur]);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Perform a binary operation on two shared vectors.
|
||||
*
|
||||
* @tparam op The binary operation to perform.
|
||||
* @tparam T The type of the vectors.
|
||||
* @param dst[out] The destination vector where the result is stored.
|
||||
* @param lhs[in] The left-hand side vector for the operation.
|
||||
* @param rhs[in] The right-hand side vector for the operation.
|
||||
*/
|
||||
template<typename op, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
bin_op(threadgroup SV &dst, threadgroup const SV &lhs, threadgroup const SV &rhs, const int threadIdx) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(auto cur = laneid(threadIdx); cur < SV::length; cur+=GROUP_THREADS) {
|
||||
dst[cur] = op::template op<typename SV::dtype>(lhs[cur], rhs[cur]);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Perform a binary operation on a shared vector and a scalar.
|
||||
*
|
||||
* @tparam op The binary operation to perform.
|
||||
* @tparam T The type of the vector.
|
||||
* @param dst[out] The destination vector where the result is stored.
|
||||
* @param src[in] The source vector for the operation.
|
||||
* @param param[in] The scalar parameter for the operation.
|
||||
*/
|
||||
template<typename op, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
bin_op(threadgroup SV &dst, threadgroup const SV &src, thread const typename SV::dtype ¶m, const int threadIdx) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(auto cur = laneid(threadIdx); cur < SV::length; cur+=GROUP_THREADS) {
|
||||
dst[cur] = op::template op<typename SV::dtype>(src[cur], param);
|
||||
}
|
||||
}
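// Note (added): this scalar overload broadcasts `param` against every element, and the loop
// is strided by GROUP_THREADS, so each thread of the group touches a disjoint subset of the
// vector. Illustrative call, assuming an sv<float, 64> shared-vector alias:
//
//   threadgroup sv<float, 64> v;
//   mul(v, v, 2.0f, threadIdx);   // v *= 2 via the mul wrapper defined later in this file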
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
// ---- const ops ----
|
||||
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to zero.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to zero.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
zero(threadgroup SV &dst, const int threadIdx) {
|
||||
unary_op<base_ops::zero, SV>(dst, dst, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to one.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to one.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
one(threadgroup SV &dst, const int threadIdx) {
|
||||
unary_op<base_ops::one, SV>(dst, dst, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to positive infinity.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to positive infinity.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
pos_infty(threadgroup SV &dst, const int threadIdx) {
|
||||
unary_op<base_ops::pos_infty, SV>(dst, dst, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to negative infinity.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to negative infinity.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
neg_infty(threadgroup SV &dst, const int threadIdx) {
|
||||
unary_op<base_ops::neg_infty, SV>(dst, dst, threadIdx);
|
||||
}
|
||||
|
||||
// ---- unary ops ----
|
||||
|
||||
/**
|
||||
* @brief Copies the elements from one shared vector to another.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the source vector.
|
||||
* @param dst[out] Destination vector where the elements will be copied to.
|
||||
* @param src[in] Source vector to copy the elements from.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
copy(threadgroup SV &dst, thread const U &src, const int threadIdx) {
|
||||
bin_op<base_ops::copy2, SV>(dst, dst, src, threadIdx); // the second arg is ignored here.
|
||||
}
|
||||
/**
|
||||
* @brief Applies the exponential function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the exponential function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
exp(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
|
||||
unary_op<base_ops::exp, SV>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the exponential function element-wise to a shared vector, in base 2.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the exponential function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
exp2(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
|
||||
unary_op<base_ops::exp2, SV>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the natural logarithm function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the logarithm function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
log(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
|
||||
unary_op<base_ops::log, SV>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the absolute value function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the absolute values will be stored.
|
||||
* @param src[in] Source vector to apply the absolute value function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
abs(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
|
||||
unary_op<base_ops::abs, SV>(dst, src, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the rectified linear unit (ReLU) function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the ReLU values will be stored.
|
||||
* @param src[in] Source vector to apply the ReLU function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
relu(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
|
||||
unary_op<base_ops::relu, SV>(dst, src, threadIdx);
|
||||
}
|
||||
|
||||
// ---- binary ops ----
|
||||
|
||||
/**
|
||||
* @brief Computes the element-wise maximum of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the maximum values will be stored.
|
||||
* @param lhs[in] First vector for the maximum operation.
|
||||
* @param rhs[in] Second vector for the maximum operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
max(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_op<base_ops::max, SV>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise minimum of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the minimum values will be stored.
|
||||
* @param lhs[in] First vector for the minimum operation.
|
||||
* @param rhs[in] Second vector for the minimum operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
min(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_op<base_ops::min, SV>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise sum of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the sum values will be stored.
|
||||
* @param lhs[in] First vector for the sum operation.
|
||||
* @param rhs[in] Second vector for the sum operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
add(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_op<base_ops::sum, SV>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise difference of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the difference values will be stored.
|
||||
* @param lhs[in] First vector for the difference operation.
|
||||
* @param rhs[in] Second vector for the difference operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
sub(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_op<base_ops::sub, SV>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise product of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the product values will be stored.
|
||||
* @param lhs[in] First vector for the product operation.
|
||||
* @param rhs[in] Second vector for the product operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
mul(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_op<base_ops::mul, SV>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise division of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the division values will be stored.
|
||||
* @param lhs[in] First vector for the division operation.
|
||||
* @param rhs[in] Second vector for the division operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
div(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
|
||||
bin_op<base_ops::div, SV>(dst, lhs, rhs, threadIdx);
|
||||
}
|
||||
3
extra/thunder/include/ops/group/shared/vec/vec.metal
Normal file
3
extra/thunder/include/ops/group/shared/vec/vec.metal
Normal file
@@ -0,0 +1,3 @@
#include "conversions.metal"
#include "maps.metal"
3
extra/thunder/include/ops/ops.metal
Normal file
3
extra/thunder/include/ops/ops.metal
Normal file
@@ -0,0 +1,3 @@
#pragma once
#include "group/group.metal"
#include "warp/warp.metal"
4
extra/thunder/include/ops/warp/memory/memory.metal
Normal file
4
extra/thunder/include/ops/warp/memory/memory.metal
Normal file
@@ -0,0 +1,4 @@
#pragma once
#include "tile/tile.metal"
#include "util/util.metal"
#include "vec/vec.metal"
@@ -0,0 +1,51 @@
/**
 * @file
 * @brief Functions for transferring data directly between global memory and registers and back.
 */

#pragma once

#include "../../../../../common/common.metal"
#include "../../../../../types/types.metal"

#include "../global_to_register.metal"

namespace mittens {
/**
 * @brief Load data from a complex global layout into a complex register tile.
 *
 * @tparam CRT The complex register tile type.
 * @tparam CGL The complex global layout type of the source.
 * @param dst[out] The destination complex register tile to load data into.
 * @param src[in] The source complex global layout to load data from.
 * @param idx[in] The coordinate of the tile within the global layout.
 * @param laneid[in] The thread's index within its SIMD group.
 */
template<typename CRT, typename CGL>
METAL_FUNC static typename metal::enable_if<ducks::is_complex_register_tile<CRT>() && ducks::is_complex_global_layout<CGL>(), void>::type
load(thread CRT &dst, thread const CGL &src, thread const coord &idx, const short laneid) {
    // Internally will use the correct load() method for row and column types
    load(dst.real, src.real, idx);
    load(dst.imag, src.imag, idx);
}

/**
 * @brief Store data from a complex register tile to a complex global layout in global memory.
 *
 * @tparam CRT The complex register tile type.
 * @tparam CGL The complex global layout type of the destination.
 * @param dst[out] The destination complex global layout to store data into.
 * @param src[in] The source complex register tile to store data from.
 * @param idx[in] The coordinate of the tile within the global layout.
 */
template<typename CRT, typename CGL>
METAL_FUNC static typename metal::enable_if<ducks::is_complex_register_tile<CRT>() && ducks::is_complex_global_layout<CGL>(), void>::type
store(thread CGL &dst, thread const CRT &src, thread const coord &idx) {
    // Internally will use the correct store() method for row and column types
    store(dst.real, src.real, idx);
    store(dst.imag, src.imag, idx);
}
}
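// Usage sketch (illustrative, not from the original file): the complex wrappers simply
// forward to the real-valued load/store on the .real and .imag members, so they are called
// exactly like their non-complex counterparts. The crt/cgl names below are assumed aliases
// for the complex register tile and complex global layout types declared in the types headers.
//
//   crt<float, 32, 32> acc;                          // hypothetical complex register tile
//   load(acc, input, {0, 0, r, c}, simd_lane_id);    // input: hypothetical complex global layout
//   store(output, acc, {0, 0, r, c});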
@@ -0,0 +1,48 @@
/**
 * @file
 * @brief Functions for transferring data directly between global and shared memory and back.
 */

#pragma once

#include "../../../../../common/common.metal"
#include "../../../../../types/types.metal"

#include "../global_to_shared.metal"

namespace mittens {
/**
 * @brief Loads data from global memory into a complex shared memory tile with a row layout.
 *
 * @tparam CST The type of the complex shared tile.
 * @tparam CGL The type of the global layout holding the source data.
 * @param[out] dst The destination complex shared memory tile.
 * @param[in] src The source global layout.
 * @param[in] idx The coordinate of the tile within the global layout.
 */
template<typename CST, typename CGL>
METAL_FUNC static typename metal::enable_if<ducks::is_complex_shared_tile<CST>() && ducks::is_global_layout<CGL>(), void>::type
load(threadgroup CST &dst, thread const CGL &src, thread const coord &idx) {
    load(dst.real, src.real, idx);
    load(dst.imag, src.imag, idx);
}

/**
 * @brief Stores data from a complex shared memory tile with a row layout into global memory.
 *
 * @tparam CST The type of the complex shared tile.
 * @tparam CGL The type of the complex global layout receiving the data.
 * @param[out] dst The destination complex global layout.
 * @param[in] src The source complex shared memory tile.
 * @param[in] idx The coordinate of the tile within the global layout.
 */
template<typename CST, typename CGL>
METAL_FUNC static typename metal::enable_if<ducks::is_complex_shared_tile<CST>() && ducks::is_complex_global_layout<CGL>(), void>::type
store(thread const CGL &dst, threadgroup CST &src, thread const coord &idx) {
    store(dst.real, src.real, idx);
    store(dst.imag, src.imag, idx);
}

}
@@ -0,0 +1,47 @@
/**
 * @file
 * @brief Functions for transferring data directly between shared memory and registers and back.
 */

#pragma once

#include "../../../../../common/common.metal"
#include "../../../../../types/types.metal"

#include "../shared_to_register.metal"

namespace mittens {
/**
 * @brief Load data from a complex shared tile into a complex register tile.
 *
 * @tparam CRT The complex register tile type
 * @tparam CST The complex shared tile type
 * @param dst[out] The destination complex register tile.
 * @param src[in] The source complex shared tile.
 */
template<typename CRT, typename CST>
METAL_FUNC static typename metal::enable_if<ducks::is_complex_shared_tile<CST>() && ducks::is_complex_register_tile<CRT>(), void>::type
load(thread CRT &dst, threadgroup const CST &src) {
    load(dst.real, src.real);
    load(dst.imag, src.imag);
}

/**
 * @brief Store data into a complex shared tile from a complex register tile.
 *
 * @tparam CRT The complex register tile type
 * @tparam CST The complex shared tile type
 * @param dst[out] The destination complex shared tile.
 * @param src[in] The source complex register tile.
 */
template<typename CRT, typename CST>
METAL_FUNC static typename metal::enable_if<ducks::is_complex_shared_tile<CST>() && ducks::is_complex_register_tile<CRT>(), void>::type
store(threadgroup CST &dst, thread const CRT &src) {
    store(dst.real, src.real);
    store(dst.imag, src.imag);
}

}
@@ -0,0 +1,217 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Functions for transferring data directly between global memory and registers and back.
|
||||
*/
|
||||
|
||||
#pragma once // done!
|
||||
#include "../../../../types/types.metal"
|
||||
#include "../../../../common/common.metal"
|
||||
#include <metal_stdlib>
|
||||
namespace mittens{
|
||||
|
||||
namespace meta {
|
||||
template<typename RT, typename U>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>(), void>::type
|
||||
load(int i, int j, thread RT *dst, const device U *src_ptr, const short simd_y, const short simd_x, const int row_stride) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using layout = typename RT::layout;
|
||||
unsigned offset = (simd_y + i * rt_base<T, layout>::tile_size) * row_stride + (simd_x + j * rt_base<T, layout>::tile_size);
|
||||
T2 src2 = base_types::convertor<T2, U2>::convert(*((device U2*)(&src_ptr[offset])));
|
||||
dst->tiles[i][j].data.thread_elements()[0] = src2[0];
|
||||
dst->tiles[i][j].data.thread_elements()[1] = src2[1];
|
||||
}
|
||||
|
||||
template<typename RT, typename U>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>(), void>::type
|
||||
load(int i, int j, thread RT *dst, const device U *src_ptr, const short simd_y, const short simd_x, const int row_stride) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using layout = typename RT::layout;
|
||||
unsigned offset = (simd_y + i * rt_base<T, layout>::tile_size) * row_stride + (simd_x + j * rt_base<T, layout>::tile_size);
|
||||
dst->tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert(src_ptr[offset]);
|
||||
offset += row_stride;
|
||||
dst->tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert(src_ptr[offset]);
|
||||
}
|
||||
|
||||
template<typename RT, typename U>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>(), void>::type
|
||||
store(int i, int j, device U *dst_ptr, const thread RT *src, const short simd_y, const short simd_x, const int row_stride) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using layout = typename RT::layout;
|
||||
unsigned offset = (simd_y + i * TILE_DIM) * row_stride + (simd_x + j * TILE_DIM);
|
||||
U2 src2 = base_types::convertor<U2, T2>::convert(
|
||||
T2(src->tiles[i][j].data.thread_elements()[0],
|
||||
src->tiles[i][j].data.thread_elements()[1])
|
||||
);
|
||||
*((device U2*)&dst_ptr[offset]) = src2;
|
||||
}
|
||||
|
||||
template<typename RT, typename U>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>(), void>::type
|
||||
store(int i, int j, device U *dst_ptr, const thread RT *src, const short simd_y, const short simd_x, const int row_stride) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using layout = typename RT::layout;
|
||||
unsigned offset = (simd_y + i * rt_base<T, layout>::tile_size) * row_stride + (simd_x + j * rt_base<T, layout>::tile_size);
|
||||
dst_ptr[offset] = base_types::convertor<U, T>::convert(src->tiles[i][j].data.thread_elements()[0]);
|
||||
offset += row_stride;
|
||||
dst_ptr[offset] = base_types::convertor<U, T>::convert(src->tiles[i][j].data.thread_elements()[1]);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Load data from a source array into a row-major layout tile.
|
||||
*
|
||||
* @tparam RT The row-major layout tile type.
|
||||
* @tparam U The data type of the source array.
|
||||
* @param dst[out] The destination tile to load data into.
|
||||
* @param src[in] The source array to load data from.
|
||||
* @param row_stride[in] The stride in elements between rows in the source array.
|
||||
*/
|
||||
template<typename RT, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_global_layout<GL>(), void>::type
|
||||
load(thread RT &dst, thread const GL &src, thread const coord &idx, const short laneid) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using layout = typename RT::layout;
|
||||
const device U *src_ptr = (device U*)&src.template get<RT>(idx);
|
||||
const int row_stride = src.row_stride();
|
||||
|
||||
const short qid = laneid / 4;
|
||||
const short simd_y = (qid & 4) + (laneid / 2) % 4;
|
||||
const short simd_x = (qid & 2) * 2 + (laneid % 2) * 2;
|
||||
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int i = 0; i < RT::height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int j = 0; j < RT::width; j++) {
|
||||
// unsigned offset = (simd_y + i * rt_base<T, layout>::tile_size) * row_stride + (simd_x + j * rt_base<T, layout>::tile_size);
|
||||
// T2 src2 = base_types::convertor<T2, U2>::convert(*((device U2*)(&src_ptr[offset])));
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = src2[0];
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = src2[1];
|
||||
// }
|
||||
// }
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::load<RT, U>, &dst, src_ptr, simd_y, simd_x, row_stride);
|
||||
}
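// Worked example of the lane mapping above (comment added for clarity): each SIMD lane owns
// one packed pair of adjacent elements inside every 8x8 sub-tile, presumably to match the
// (undocumented) element ownership of Apple's simdgroup 8x8 matrices.
//   laneid = 5  -> qid = 1, simd_y = (1 & 4) + (5 / 2) % 4 = 2, simd_x = (1 & 2) * 2 + (5 % 2) * 2 = 2,
//                  so lane 5 reads elements (2, 2) and (2, 3) of each tile.
//   laneid = 30 -> qid = 7, simd_y = 4 + 3 = 7, simd_x = 4 + 0 = 4,
//                  so lane 30 reads elements (7, 4) and (7, 5).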
|
||||
/**
|
||||
* @brief Load data from a source array into a col-major layout tile.
|
||||
*
|
||||
 * @tparam RT The col-major layout tile type.
|
||||
* @tparam U The data type of the source array.
|
||||
* @param dst[out] The destination tile to load data into.
|
||||
* @param src[in] The source array to load data from.
|
||||
* @param row_stride[in] The stride in elements between rows in the source array.
|
||||
*/
|
||||
template<typename RT, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_global_layout<GL>(), void>::type
|
||||
load(thread RT &dst, thread const GL &src, thread const coord &idx, const short laneid) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U = typename GL::dtype;
|
||||
using layout = typename RT::layout;
|
||||
const device U *src_ptr = (device U*)&(src.template get<RT>(idx));
|
||||
const int row_stride = src.row_stride();
|
||||
|
||||
const short qid = laneid / 4;
|
||||
const short simd_x = (qid & 4) + (laneid / 2) % 4;
|
||||
const short simd_y = (qid & 2) * 2 + (laneid % 2) * 2;
|
||||
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int i = 0; i < RT::height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int j = 0; j < RT::width; j++) {
|
||||
// unsigned offset = (simd_y + i * rt_base<T, layout>::tile_size) * row_stride + (simd_x + j * rt_base<T, layout>::tile_size);
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert(src_ptr[offset]);
|
||||
// offset += row_stride;
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert(src_ptr[offset]);
|
||||
// }
|
||||
// }
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::load<RT, U>, &dst, src_ptr, simd_y, simd_x, row_stride);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store data from a register tile to a destination array in global memory with a row-major layout.
|
||||
*
|
||||
* @tparam RT The register tile type with a row-major layout.
|
||||
* @tparam U The data type of the destination array.
|
||||
* @param[out] dst The destination array in global memory to store data into.
|
||||
* @param[in] src The source register tile to store data from.
|
||||
* @param row_stride[in] The stride in elements between rows in the destination array.
|
||||
*/
|
||||
template<typename RT, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_global_layout<GL>(), void>::type
|
||||
store(thread GL &dst, thread const RT &src, thread const coord &idx, const short laneid) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using layout = typename RT::layout;
|
||||
device U *dst_ptr = (device U*)&(dst.template get<RT>(idx));
|
||||
// device U* dst_ptr = dst.raw_ptr;
|
||||
const int row_stride = dst.row_stride();
|
||||
const short qid = laneid / 4;
|
||||
const short simd_y = (qid & 4) + (laneid / 2) % 4;
|
||||
const short simd_x = (qid & 2) * 2 + (laneid % 2) * 2;
|
||||
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int i = 0; i < RT::height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int j = 0; j < RT::width; j++) {
|
||||
// unsigned offset = (simd_y + i * TILE_DIM) * row_stride + (simd_x + j * TILE_DIM);
|
||||
// U2 src2 = base_types::convertor<U2, T2>::convert(
|
||||
// T2(src.tiles[i][j].data.thread_elements()[0],
|
||||
// src.tiles[i][j].data.thread_elements()[1])
|
||||
// );
|
||||
// *((device U2*)&dst_ptr[offset]) = src2;
|
||||
// }
|
||||
// }
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::store<RT, U>, dst_ptr, &src, simd_y, simd_x, row_stride);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store data from a register tile to a destination array in global memory with a col-major layout.
|
||||
*
|
||||
 * @tparam RT The register tile type with a col-major layout.
|
||||
* @tparam U The data type of the destination array.
|
||||
* @param[out] dst The destination array in global memory to store data into.
|
||||
* @param[in] src The source register tile to store data from.
|
||||
* @param row_stride[in] The stride in elements between rows in the destination array.
|
||||
*/
|
||||
template<typename RT, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_global_layout<GL>(), void>::type
|
||||
store(thread GL &dst, thread const RT &src, thread const coord &idx, const short laneid) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename RT::T2;
|
||||
using U = typename GL::dtype;
|
||||
using U2 = typename base_types::packing<U>::packed_type;
|
||||
using layout = typename RT::layout;
|
||||
device U *dst_ptr = (device U*)&(dst.template get<RT>(idx));
|
||||
const int row_stride = dst.row_stride();
|
||||
const short qid = laneid / 4;
|
||||
const short simd_x = (qid & 4) + (laneid / 2) % 4;
|
||||
const short simd_y = (qid & 2) * 2 + (laneid % 2) * 2;
|
||||
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int i = 0; i < RT::height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for (int j = 0; j < RT::width; j++) {
|
||||
// unsigned offset = (simd_y + i * rt_base<T, layout>::tile_size) * row_stride + (simd_x + j * rt_base<T, layout>::tile_size);
|
||||
// dst_ptr[offset] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[0]);
|
||||
// offset += row_stride;
|
||||
// dst_ptr[offset] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[1]);
|
||||
// }
|
||||
// }
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::store<RT, U>, dst_ptr, &src, simd_y, simd_x, row_stride);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,192 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Functions for transferring data directly between global and shared memory and back.
|
||||
*/
|
||||
|
||||
#pragma once // not done!
|
||||
#include "../../../../types/types.metal"
|
||||
#include "../../../../common/common.metal"
|
||||
#include <metal_stdlib>
|
||||
namespace mittens {
|
||||
|
||||
//
|
||||
namespace meta {
|
||||
template<typename ST, int memcpy_per_row, int elem_per_memcpy, int READ_FLOATS>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
load(int i, threadgroup ST *dst, device const typename ST::dtype *src, thread const int& row_stride, thread const short& laneid) {
|
||||
{
|
||||
unsigned idx = i + laneid;
|
||||
unsigned row = idx / memcpy_per_row;
|
||||
unsigned col = (idx*elem_per_memcpy) % ST::cols;
|
||||
*(threadgroup ReadVector<READ_FLOATS>*)(&(*dst)[int2(row, col)]) = *(device ReadVector<READ_FLOATS>*)(&src[row*row_stride + col]);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename ST, int memcpy_per_row, int elem_per_memcpy, int READ_FLOATS>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
|
||||
store(int i, device typename ST::dtype *dst, threadgroup const ST *src, thread const int& row_stride, thread const short& laneid) {
|
||||
{
|
||||
unsigned idx = i + laneid;
|
||||
unsigned row = idx / memcpy_per_row;
|
||||
unsigned col = (idx*elem_per_memcpy) % ST::cols;
|
||||
*(device ReadVector<READ_FLOATS>*)(&dst[row*row_stride + col]) = *(threadgroup ReadVector<READ_FLOATS>*)(&(*src)[int2(row, col)]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace meta
|
||||
|
||||
//
|
||||
///**
|
||||
// * @brief Loads data from global memory into a shared memory tile with a row layout.
|
||||
// *
|
||||
// * @tparam ST The type of the shared tile.
|
||||
// * @param[out] dst The destination shared memory tile.
|
||||
// * @param[in] src The source global memory array.
|
||||
// * @param row_stride[in] The stride between rows in the source array.
|
||||
// * @param laneid[in] Thread's index in SIMD group
|
||||
// */
|
||||
//template<typename ST>
|
||||
//static METAL_FUNC void load(threadgroup ST &dst, device const typename ST::dtype *src, const int row_stride, short laneid) {
|
||||
// using read_type = float;
|
||||
// ducks::assert_shared_tile<ST>();
|
||||
// constexpr const unsigned elem_per_memcpy = sizeof(read_type)/sizeof(typename ST::dtype); // 2
|
||||
// constexpr const unsigned memcpy_per_row = ST::cols / elem_per_memcpy; // 32/2=16 not power of 2
|
||||
// constexpr const unsigned total_calls = ST::num_elements / (SIMD_THREADS*elem_per_memcpy); // 1024/(32*2)=16
|
||||
//// #pragma clang loop unroll_count(1)
|
||||
//// #pragma clang loop unroll(disable)
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(unsigned i = 0; i < total_calls; i++) {
|
||||
// unsigned idx = i * 32 + laneid;
|
||||
// unsigned row = idx / memcpy_per_row;
|
||||
// unsigned col = (idx*elem_per_memcpy) % ST::cols;
|
||||
// *(threadgroup read_type*)(&dst[int2(row, col)]) = *(device read_type*)(&src[row*row_stride + col]);
|
||||
// }
|
||||
//
|
||||
//// ducks::assert_shared_tile<ST>();
|
||||
//// const constexpr int read_size = 1;
|
||||
//// using read_type = ReadVector<read_size>;
|
||||
//// constexpr const unsigned elem_per_memcpy = sizeof(read_type)/sizeof(typename ST::dtype); // 2
|
||||
//// constexpr const unsigned memcpy_per_row = ST::cols / elem_per_memcpy; // 32/2=16 not power of 2
|
||||
//// constexpr const unsigned total_calls = ST::num_elements / (SIMD_THREADS*elem_per_memcpy); // 1024/(32*2)=16
|
||||
////
|
||||
////
|
||||
//// meta::unroll_i_in_range<0, total_calls * SIMD_THREADS, SIMD_THREADS>::run(meta::load<ST, memcpy_per_row, elem_per_memcpy, read_size>, &dst, src, row_stride, laneid);
|
||||
//}
|
||||
//
|
||||
//
|
||||
///**
|
||||
// * @brief Stores data from a shared memory tile with a row layout into global memory.
|
||||
// *
|
||||
// * @tparam ST The type of the shared tile.
|
||||
// * @param[out] dst The destination global memory array.
|
||||
// * @param[in] src The source shared memory tile.
|
||||
// * @param row_stride[in] The stride between rows in the destination array.
|
||||
// * @param laneid[in] Thread's index in SIMD group
|
||||
// */
|
||||
//template<typename ST>
|
||||
//static METAL_FUNC void store(device typename ST::dtype *dst, threadgroup const ST &src, const int row_stride, short laneid) {
|
||||
// using read_type = float4;
|
||||
// ducks::assert_shared_tile<ST>();
|
||||
// constexpr const unsigned elem_per_memcpy = sizeof(read_type)/sizeof(typename ST::dtype);
|
||||
// constexpr const unsigned memcpy_per_row = ST::cols / elem_per_memcpy;
|
||||
// constexpr const unsigned total_calls = ST::num_elements / (SIMD_THREADS*elem_per_memcpy);
|
||||
//// #pragma clang loop unroll_count(READ_SIZE)
|
||||
////#pragma clang loop unroll(disable)
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(unsigned i = 0; i < total_calls; i++) {
|
||||
// unsigned idx = i * 32 + laneid;
|
||||
// unsigned row = idx / memcpy_per_row;
|
||||
// unsigned col = (idx*elem_per_memcpy) % src.cols;
|
||||
// *(device read_type*)(&dst[row*row_stride + col]) = *(threadgroup read_type*)(&src[int2(row, col)]);
|
||||
// }
|
||||
//
|
||||
////
|
||||
//// ducks::assert_shared_tile<ST>();
|
||||
//// const constexpr int read_size = 1;
|
||||
//// using read_type = ReadVector<read_size>;
|
||||
////
|
||||
//// constexpr const unsigned elem_per_memcpy = sizeof(read_type)/sizeof(typename ST::dtype);
|
||||
//// constexpr const unsigned memcpy_per_row = ST::cols / elem_per_memcpy;
|
||||
//// constexpr const unsigned total_calls = ST::num_elements / (SIMD_THREADS*elem_per_memcpy);
|
||||
////
|
||||
////
|
||||
//// meta::unroll_i_in_range<0, total_calls * SIMD_THREADS, SIMD_THREADS>::run(meta::store<ST, memcpy_per_row, elem_per_memcpy, read_size>, dst, &src, row_stride, laneid);
|
||||
//}
|
||||
|
||||
|
||||
|
||||
/**
|
||||
* @brief Loads data from global memory into a shared memory tile with a row layout.
|
||||
*
|
||||
* @tparam ST The type of the shared tile.
|
||||
* @param[out] dst The destination shared memory tile.
|
||||
* @param[in] src The source global memory array.
|
||||
* @param row_stride[in] The stride between rows in the source array.
|
||||
* @param laneid[in] Thread's index in SIMD group
|
||||
*/
|
||||
template<typename ST, typename GL>
|
||||
METAL_FUNC static typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_global_layout<GL>(), void>::type
|
||||
load(threadgroup ST &dst, thread const GL &src, thread const coord &idx, short laneid) {
|
||||
using U = typename GL::dtype;
|
||||
constexpr const int read_size = 1;
|
||||
using read_type = ReadVector<read_size>;
|
||||
device U *src_ptr = (device U*)&src.template get<ST>(idx);
|
||||
const int row_stride = src.row_stride();
|
||||
constexpr const unsigned elem_per_memcpy = sizeof(read_type)/sizeof(typename ST::dtype); // 2
|
||||
constexpr const unsigned memcpy_per_row = ST::cols / elem_per_memcpy; // 32/2=16 not power of 2
|
||||
constexpr const unsigned total_calls = ST::num_elements / (SIMD_THREADS*elem_per_memcpy); // 1024/(32*2)=16
|
||||
// #pragma clang loop unroll_count(1)
|
||||
// #pragma clang loop unroll(disable)
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(unsigned i = 0; i < total_calls; i++) {
|
||||
// unsigned idx = i * 32 + laneid;
|
||||
// unsigned row = idx / memcpy_per_row;
|
||||
// unsigned col = (idx*elem_per_memcpy) % ST::cols;
|
||||
// *(threadgroup read_type*)(&dst[int2(row, col)]) = *(device read_type*)(&src_ptr[row*row_stride + col]);
|
||||
// }
|
||||
meta::unroll_i_in_range<0, total_calls * SIMD_THREADS, SIMD_THREADS>::run(meta::load<ST, memcpy_per_row, elem_per_memcpy, read_size>, &dst, src_ptr, row_stride, laneid);
|
||||
}
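// Worked example of the constants above (comment added for clarity), assuming a 64x64 float
// shared tile and that ReadVector<1> wraps a single float:
//   elem_per_memcpy = sizeof(ReadVector<1>) / sizeof(float) = 1 element per copy
//   memcpy_per_row  = 64 / 1 = 64 copies per row
//   total_calls     = 64*64 / (SIMD_THREADS * 1) = 4096 / 32 = 128 iterations per lane
// The "2", "16" and "1024/(32*2)=16" figures in the inline comments were carried over from
// the commented-out float-read variant above and do not hold for every instantiation.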

/**
 * @brief Stores data from a shared memory tile with a row layout into global memory.
 *
 * @tparam ST The type of the shared tile.
 * @param[out] dst The destination global memory array.
 * @param[in] src The source shared memory tile.
 * @param row_stride[in] The stride between rows in the destination array.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename ST, typename GL>
METAL_FUNC static typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_global_layout<GL>(), void>::type
store(thread GL &dst, threadgroup const ST &src, thread const coord &idx, short laneid) {
    using U = typename GL::dtype;
    constexpr const int read_size = 1;
    using read_type = ReadVector<read_size>;
    device U *dst_ptr = (device U*)&dst.template get<ST>(idx);
    const int row_stride = dst.row_stride();

    constexpr const unsigned elem_per_memcpy = sizeof(read_type)/sizeof(typename ST::dtype);
    constexpr const unsigned memcpy_per_row = ST::cols / elem_per_memcpy;
    constexpr const unsigned total_calls = ST::num_elements / (SIMD_THREADS*elem_per_memcpy);
    // #pragma clang loop unroll_count(READ_SIZE)
    // #pragma clang loop unroll(disable)
    // #pragma clang loop unroll(full)
    // for(unsigned i = 0; i < total_calls; i++) {
    //     unsigned idx = i * 32 + laneid;
    //     unsigned row = idx / memcpy_per_row;
    //     unsigned col = (idx*elem_per_memcpy) % src.cols;
    //     *(device read_type*)(&dst_ptr[row*row_stride + col]) = *(threadgroup read_type*)(&src[int2(row, col)]);
    // }

    meta::unroll_i_in_range<0, total_calls * SIMD_THREADS, SIMD_THREADS>::run(meta::store<ST, memcpy_per_row, elem_per_memcpy, read_size>, dst_ptr, &src, row_stride, laneid);
}

}

@@ -0,0 +1,461 @@
/**
 * @file
 * @brief Functions for transferring data directly between shared memory and registers and back.
 */
#pragma once // done!

#include "../../../../types/types.metal"
#include "../../../../common/common.metal"
#include <metal_stdlib>
namespace mittens {

// These probably need to be redone to reduce bank conflicts.
// They currently work fine with xor layout but it should be
// possible to reduce their bank conflicts with other layouts too.
//
namespace meta {

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
loadStR(int i, int j, thread RT *dst, threadgroup const ST *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    T2 values = base_types::convertor<T2, U2>::convert(*((threadgroup U2*)(&(*src)[int2(y, x)])));
    dst->tiles[i][j].data.thread_elements()[0] = values[0];
    dst->tiles[i][j].data.thread_elements()[1] = values[1];
    //
    // simdgroup_load(dst->tiles[i][j].data,
    //                (threadgroup T*)(src->data),
    //                src->cols,
    //                {i * mittens::TILE_DIM, j * mittens::TILE_DIM},
    //
}

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
storeStR(int i, int j, threadgroup ST *dst, thread const RT *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    U2 values = base_types::convertor<U2, T2>::convert({src->tiles[i][j].data.thread_elements()[0], src->tiles[i][j].data.thread_elements()[1]});
    *((threadgroup U2*)(&(*dst)[int2(y, x)])) = values;
}

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
loadStR(int i, int j, thread RT *dst, threadgroup const ST *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    // dst->tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert((*src)[int2(y , x)]);
    // dst->tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert((*src)[int2(y+1, x)]);
    T2 vals = base_types::convertor<T2, U2>::convert({(*src)[int2(y , x)], (*src)[int2(y+1, x)]});
    dst->tiles[i][j].data.thread_elements()[0] = vals[0];
    dst->tiles[i][j].data.thread_elements()[1] = vals[1];
}

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
storeStR(int i, int j, threadgroup ST *dst, thread const RT *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    // (*dst)[int2(y , x)] = base_types::convertor<U, T>::convert(src->tiles[i][j].data.thread_elements()[0]);
    // (*dst)[int2(y+1, x)] = base_types::convertor<U, T>::convert(src->tiles[i][j].data.thread_elements()[1]);

    U2 vals = base_types::convertor<U2, T2>::convert({src->tiles[i][j].data.thread_elements()[0], src->tiles[i][j].data.thread_elements()[1]});
    (*dst)[int2(y , x)] = vals[0];
    (*dst)[int2(y+1, x)] = vals[1];
}

}

/**
 * @brief Load data from a shared tile into a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination register tile.
 * @param src[in] The source shared tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
load(thread RT &dst, threadgroup const ST &src, short laneid) {
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    const short qid = laneid / 4;
    int offsetY = (qid & 4) + (laneid / 2) % 4;
    int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;
    // #pragma clang loop unroll(full)
    // for(int i = 0; i < dst.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < dst.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         T2 values = base_types::convertor<T2, U2>::convert(*((threadgroup U2*)(&src[int2(y, x)])));
    //         dst.tiles[i][j].data.thread_elements()[0] = values[0];
    //         dst.tiles[i][j].data.thread_elements()[1] = values[1];
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::loadStR<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}
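
// Illustrative usage sketch (a_smem is a threadgroup shared tile staged as in the
// global-to-shared example; its exact typename is an assumption):
//
//   rt<float, 32, 32> a_reg;
//   load(a_reg, a_smem, simd_lane_id);   // shared -> register, row layout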

/**
 * @brief Load data from a shared tile into a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination register tile.
 * @param src[in] The source shared tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
load(thread RT &dst, threadgroup const ST &src, short laneid) {
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    const short qid = laneid / 4;
    // int offsetY = (qid & 4) + (laneid / 2) % 4;
    // int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;
    int offsetX = (qid & 4) + (laneid / 2) % 4;
    int offsetY = (qid & 2) * 2 + (laneid % 2) * 2;
    // #pragma clang loop unroll(full)
    // for(int i = 0; i < dst.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < dst.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         dst.tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert(src[int2(y , x)]);
    //         dst.tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert(src[int2(y+1, x)]);
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::loadStR<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}

/**
 * @brief Store data into a shared tile from a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination shared tile.
 * @param src[in] The source register tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
store(threadgroup ST &dst, thread const RT &src, short laneid) {
    ducks::assert_register_tile<RT>();
    ducks::assert_shared_tile<ST>();
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;

    const short qid = laneid / 4;
    int offsetY = (qid & 4) + (laneid / 2) % 4;
    int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;

    // #pragma clang loop unroll(full)
    // for(int i = 0; i < src.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < src.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         U2 values = base_types::convertor<U2, T2>::convert({src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]});
    //         *((threadgroup U2*)(&dst[int2(y, x)])) = values;
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::storeStR<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}

/**
 * @brief Store data into a shared tile from a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination shared tile.
 * @param src[in] The source register tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
store(threadgroup ST &dst, thread const RT &src, short laneid) {
    ducks::assert_register_tile<RT>();
    ducks::assert_shared_tile<ST>();
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;

    const short qid = laneid / 4;
    // int offsetY = (qid & 4) + (laneid / 2) % 4;
    // int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;
    int offsetX = (qid & 4) + (laneid / 2) % 4;
    int offsetY = (qid & 2) * 2 + (laneid % 2) * 2;

    // #pragma clang loop unroll(full)
    // for(int i = 0; i < src.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < src.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         dst[int2(y , x)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[0]);
    //         dst[int2(y+1, x)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[1]);
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::storeStR<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}

/*---------------------------------------------------------------------------------*/
// These probably need to be redone to reduce bank conflicts.
// They currently work fine with xor layout but it should be
// possible to reduce their bank conflicts with other layouts too.
//
namespace meta {

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
loadStR_r(int i, int j, thread RT *dst, thread const ST *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    T2 values = base_types::convertor<T2, U2>::convert(*((threadgroup U2*)(&(*src)[int2(y, x)])));
    dst->tiles[i][j].data.thread_elements()[0] = values[0];
    dst->tiles[i][j].data.thread_elements()[1] = values[1];
    //
    // simdgroup_load(dst->tiles[i][j].data,
    //                (threadgroup T*)(src->data),
    //                src->cols,
    //                {i * mittens::TILE_DIM, j * mittens::TILE_DIM},
    //
}

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
storeStR_r(int i, int j, thread ST *dst, thread const RT *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    U2 values = base_types::convertor<U2, T2>::convert({src->tiles[i][j].data.thread_elements()[0], src->tiles[i][j].data.thread_elements()[1]});
    *((threadgroup U2*)(&(*dst)[int2(y, x)])) = values;
}

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
loadStR_c(int i, int j, thread RT *dst, thread const ST *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    // dst->tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert((*src)[int2(y , x)]);
    // dst->tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert((*src)[int2(y+1, x)]);
    T2 vals = base_types::convertor<T2, U2>::convert({(*src)[int2(y , x)], (*src)[int2(y+1, x)]});
    dst->tiles[i][j].data.thread_elements()[0] = vals[0];
    dst->tiles[i][j].data.thread_elements()[1] = vals[1];
}

template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
storeStR_c(int i, int j, thread ST *dst, thread const RT *src, short laneid, int offsetY, int offsetX) {
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    int y = offsetY + i * mittens::TILE_DIM;
    int x = offsetX + j * mittens::TILE_DIM;
    // (*dst)[int2(y , x)] = base_types::convertor<U, T>::convert(src->tiles[i][j].data.thread_elements()[0]);
    // (*dst)[int2(y+1, x)] = base_types::convertor<U, T>::convert(src->tiles[i][j].data.thread_elements()[1]);

    U2 vals = base_types::convertor<U2, T2>::convert({src->tiles[i][j].data.thread_elements()[0], src->tiles[i][j].data.thread_elements()[1]});
    (*dst)[int2(y , x)] = vals[0];
    (*dst)[int2(y+1, x)] = vals[1];
}

}

/**
 * @brief Load data from a shared tile into a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination register tile.
 * @param src[in] The source shared tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
load(thread RT &dst, thread const ST &src, short laneid) {
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    const short qid = laneid / 4;
    int offsetY = (qid & 4) + (laneid / 2) % 4;
    int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;
    // #pragma clang loop unroll(full)
    // for(int i = 0; i < dst.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < dst.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         T2 values = base_types::convertor<T2, U2>::convert(*((threadgroup U2*)(&src[int2(y, x)])));
    //         dst.tiles[i][j].data.thread_elements()[0] = values[0];
    //         dst.tiles[i][j].data.thread_elements()[1] = values[1];
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::loadStR_r<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}

/**
 * @brief Load data from a shared tile into a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination register tile.
 * @param src[in] The source shared tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
load(thread RT &dst, thread const ST &src, short laneid) {
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    const short qid = laneid / 4;
    // int offsetY = (qid & 4) + (laneid / 2) % 4;
    // int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;
    int offsetX = (qid & 4) + (laneid / 2) % 4;
    int offsetY = (qid & 2) * 2 + (laneid % 2) * 2;
    // #pragma clang loop unroll(full)
    // for(int i = 0; i < dst.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < dst.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         dst.tiles[i][j].data.thread_elements()[0] = base_types::convertor<T, U>::convert(src[int2(y , x)]);
    //         dst.tiles[i][j].data.thread_elements()[1] = base_types::convertor<T, U>::convert(src[int2(y+1, x)]);
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::loadStR_c<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}

/**
 * @brief Store data into a shared tile from a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination shared tile.
 * @param src[in] The source register tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
store(thread ST &dst, thread const RT &src, short laneid) {
    ducks::assert_register_tile<RT>();
    ducks::assert_shared_tile<ST>();
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;

    const short qid = laneid / 4;
    int offsetY = (qid & 4) + (laneid / 2) % 4;
    int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;

    // #pragma clang loop unroll(full)
    // for(int i = 0; i < src.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < src.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         U2 values = base_types::convertor<U2, T2>::convert({src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]});
    //         *((threadgroup U2*)(&dst[int2(y, x)])) = values;
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::storeStR_r<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}

/**
 * @brief Store data into a shared tile from a register tile.
 *
 * @tparam RT The register tile type
 * @tparam ST The shared tile type
 * @param dst[out] The destination shared tile.
 * @param src[in] The source register tile.
 * @param laneid[in] Thread's index in SIMD group
 */
template<typename RT, typename ST>
METAL_FUNC static typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_shared_tile<ST>(), void>::type
store(thread ST &dst, thread const RT &src, short laneid) {
    ducks::assert_register_tile<RT>();
    ducks::assert_shared_tile<ST>();
    static_assert(RT::height == ST::height, "register tile and shared tile must match height");
    static_assert(RT::width == ST::width, "register tile and shared tile must match width");
    using T = typename RT::dtype;
    using T2 = typename base_types::packing<T>::packed_type;
    using U = typename ST::dtype;
    using U2 = typename base_types::packing<U>::packed_type;

    const short qid = laneid / 4;
    // int offsetY = (qid & 4) + (laneid / 2) % 4;
    // int offsetX = (qid & 2) * 2 + (laneid % 2) * 2;
    int offsetX = (qid & 4) + (laneid / 2) % 4;
    int offsetY = (qid & 2) * 2 + (laneid % 2) * 2;

    // #pragma clang loop unroll(full)
    // for(int i = 0; i < src.height; i++) {
    //     #pragma clang loop unroll(full)
    //     for(int j = 0; j < src.width; j++) {
    //         int y = offsetY + i * mittens::TILE_DIM;
    //         int x = offsetX + j * mittens::TILE_DIM;
    //         dst[int2(y , x)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[0]);
    //         dst[int2(y+1, x)] = base_types::convertor<U, T>::convert(src.tiles[i][j].data.thread_elements()[1]);
    //     }
    // }
    meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::storeStR_c<RT, ST>, &dst, &src, laneid, offsetY, offsetX);
}

}

7
extra/thunder/include/ops/warp/memory/tile/tile.metal
Normal file
@@ -0,0 +1,7 @@
#pragma once

#include "global_to_register.metal"
#include "global_to_shared.metal"
#include "shared_to_register.metal"

37
extra/thunder/include/ops/warp/memory/util/util.metal
Normal file
@@ -0,0 +1,37 @@
/**
 * @file
 * @brief General utilities not specialized for either tiles or vectors.
 */
#pragma once // done!
#include "../tile/tile.metal"
#include "../../../../types/shared/shared.metal"
namespace mittens {

// sizeof() can be unreliable when working with references to objects
// plus, template magic allows arrays of these objects to be copied, too.
namespace detail {

template <typename T, uint32_t... dims>
struct size_info;

template <typename T>
struct size_info<T> {
    // members stay public: size_elements / size_bytes below read them directly
    static_assert(ducks::is_shared_tile<T>() || ducks::is_shared_vector<T>(), "T must be a shared tile or shared vector");
    constant static constexpr uint32_t elements = ducks::is_shared_tile<T>() ? T::num_elements : T::length;
    constant static constexpr uint32_t bytes = elements * sizeof(typename T::dtype);
};

template <typename T, uint32_t dim, uint32_t... rest_dims>
struct size_info<T, dim, rest_dims...> {
    constant static constexpr uint32_t elements = dim * size_info<T, rest_dims...>::elements;
    constant static constexpr uint32_t bytes = dim * size_info<T, rest_dims...>::bytes;
};
}

template<typename T, uint32_t... dims> constant constexpr uint32_t size_elements = detail::size_info<T, dims...>::elements;
template<typename T, uint32_t... dims> constant constexpr uint32_t size_bytes = detail::size_info<T, dims...>::bytes;
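
// Illustrative example: size_bytes<T, dims...> multiplies the per-object byte count by any
// trailing dims, so for a 32x32 float shared tile type ST (an assumption -- any type passing
// is_shared_tile works), size_bytes<ST> is 32*32*4 = 4096 and size_bytes<ST, 2> is 8192,
// which is handy for sizing threadgroup scratch:
//
//   threadgroup char scratch[size_bytes<ST, 2>];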

}
@@ -0,0 +1,103 @@
/**
 * @file
 * @brief Functions for transferring data directly between global memory and registers and back.
 */
#pragma once // not done
/*
TODO:
change loads/stores, prevent unnecessary
*/
#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace mittens {
/**
 * @brief Load data into a register vector from a source array in global memory.
 *
 * @tparam RV The register vector type.
 * @tparam U The data type of the source array.
 * @param[out] dst The destination register vector to load data into.
 * @param[in] src The source array in global memory to load data from.
 */
template<typename RV, typename GL>
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::is_global_layout<GL>(), void>::type
load(thread RV &dst, thread const GL &src, thread const coord &idx, const short laneid) {
    using RV_T = typename RV::dtype;
    using RV_T2 = typename base_types::packing<RV_T>::packed_type;
    using U = typename GL::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    device U *src_ptr = (device U*)&src.template get<RV>(idx);
    if (ducks::is_align_layout<typename RV::layout>()) {
        constexpr const uint32_t MASK_1 = 0x00AA00AA; // kitty bit magic
        constexpr const uint32_t MASK_2 = 0x55005500;
        constexpr const uint32_t MASK_3 = 0xAA00AA00;
        unsigned offset = ((MASK_1 >> laneid) & 1u) * 2 + ((MASK_2 >> laneid) & 1u) * 4 + ((MASK_3 >> laneid) & 1u) * 6;
        #pragma clang loop unroll(full)
        for (int t = 0; t < RV::outer_dim; offset+=8, t++) {
            RV_T2 src2 = base_types::convertor<RV_T2, U2>::convert(*(device U2*)(&src_ptr[offset]));
            dst.data[t][0] = src2[0];
            dst.data[t][1] = src2[1];
        }
    } else if (ducks::is_ortho_layout<typename RV::layout>()) { // RV::inner_dim == 1
        const short laneid_div2 = laneid / 2;
        unsigned offset = laneid_div2 % 4 + (laneid_div2 / 8) * 4;
        #pragma clang loop unroll(full)
        for (int t = 0; t < RV::outer_dim; offset+=8, t++) {
            dst.data[t][0] = base_types::convertor<RV_T, U>::convert(src_ptr[offset]);
        }
    } else if (ducks::is_naive_layout<typename RV::layout>()) {
        #pragma clang loop unroll(full)
        for(auto w = 0; w < RV::outer_dim; w++) {
            // if(w < dst.outer_dim-1 || dst.length%32 == 0 || laneid<16) {
            if (w * SIMD_THREADS + laneid < RV::length) {
                dst[w][0] = base_types::convertor<RV_T, U>::convert(src_ptr[w * SIMD_THREADS + laneid]);
            }
        }
    }
}

/**
 * @brief Store data from a register vector to a destination array in global memory.
 *
 * @tparam RV The register vector type.
 * @tparam U The data type of the destination array.
 * @param[out] dst The destination array in global memory to store data into.
 * @param[in] src The source register vector to store data from.
 */
template<typename RV, typename GL>
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::is_global_layout<GL>(), void>::type
store(thread GL &dst, thread const RV &src, thread const coord &idx, const short laneid) {
    using RV_T = typename RV::dtype;
    using RV_T2 = typename base_types::packing<RV_T>::packed_type;
    using U = typename GL::dtype;
    using U2 = typename base_types::packing<U>::packed_type;
    device U *dst_ptr = (device U*)&(dst.template get<RV>(idx));
    if (ducks::is_align_layout<typename RV::layout>()) {
        constexpr const uint32_t MASK_1 = 0x00AA00AA; // kitty bit magic
        constexpr const uint32_t MASK_2 = 0x55005500;
        constexpr const uint32_t MASK_3 = 0xAA00AA00;
        unsigned offset = ((MASK_1 >> laneid) & 1u) * 2 + ((MASK_2 >> laneid) & 1u) * 4 + ((MASK_3 >> laneid) & 1u) * 6;
        #pragma clang loop unroll(full)
        for (int t = 0; t < RV::outer_dim; offset+=8, t++) {
            U2 src2 = base_types::convertor<U2, RV_T2>::convert({src.data[t][0], src.data[t][1]});
            *(device U2*)(&dst_ptr[offset]) = src2;
        }
    } else if (ducks::is_ortho_layout<typename RV::layout>()){ // RV::inner_dim == 1
        const short laneid_div2 = laneid / 2;
        unsigned offset = laneid_div2 % 4 + (laneid_div2 / 8) * 4;
        #pragma clang loop unroll(full)
        for (int t = 0; t < RV::outer_dim; offset+=8, t++) {
            dst_ptr[offset] = base_types::convertor<U, RV_T>::convert(src.data[t][0]);
        }
    } else {
        #pragma clang loop unroll(full)
        for(auto w = 0; w < RV::outer_dim; w++) {
            // if(w < dst.outer_dim-1 || dst.length%32 == 0 || laneid<16) {
            if (w * SIMD_THREADS + laneid < RV::length) {
                dst_ptr[w * SIMD_THREADS + laneid] = base_types::convertor<U, RV_T>::convert(src.data[w][0]);
            }
        }
    }
}

}
@@ -0,0 +1,44 @@
/**
 * @file
 * @brief Functions for transferring data directly between global and shared memory and back.
 */

#pragma once // done!
#include "../../../../types/types.metal"

namespace mittens {

template<typename SV, typename GL>
METAL_FUNC static typename metal::enable_if<ducks::is_shared_vector<SV>() && ducks::is_global_layout<GL>(), void>::type
load(threadgroup SV &dst, thread const GL &src, thread const coord &idx, const unsigned laneid) {
    using read_type = float4;
    using U = typename GL::dtype;
    constexpr int elem_per_transfer = sizeof(read_type) / sizeof(typename SV::dtype);
    constexpr int total_calls = SV::length / elem_per_transfer; // guaranteed to divide
    device U *src_ptr = (device U*)&src.template get<SV>(idx);
    #pragma clang loop unroll(full)
    for (int i = laneid; i < total_calls; i += mittens::SIMD_THREADS) {
        if(i * elem_per_transfer < dst.length) {
            *(threadgroup read_type*)&dst[i*elem_per_transfer] = *(device read_type*)&src_ptr[i*elem_per_transfer];
        }
    }
}

template<typename SV, typename GL>
METAL_FUNC static typename metal::enable_if<ducks::is_shared_vector<SV>() && ducks::is_global_layout<GL>(), void>::type
store(thread const GL &dst, threadgroup const SV &src, thread const coord &idx, const unsigned laneid) {
    using read_type = float4;
    using U = typename GL::dtype;
    constexpr int elem_per_transfer = sizeof(read_type) / sizeof(typename SV::dtype);
    constexpr int total_calls = SV::length / elem_per_transfer; // guaranteed to divide
    device U *dst_ptr = (device U*)&dst.template get<SV>(idx);
    #pragma clang loop unroll(full)
    for (int i = laneid; i < total_calls; i += mittens::SIMD_THREADS) {
        if(i * elem_per_transfer < src.length) {
            *(device read_type*)&dst_ptr[i*elem_per_transfer] = *(threadgroup read_type*)&src[i*elem_per_transfer];
        }
    }
}
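
// Note on granularity: both transfers move float4-sized chunks, so elem_per_transfer is
// sizeof(float4)/sizeof(dtype) -- 4 for float, 8 for half -- and SV::length is assumed to
// divide evenly ("guaranteed to divide" above). For example, a length-64 float shared
// vector gives total_calls = 16, so with 32 SIMD lanes only lanes 0..15 copy one chunk each.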

}
@@ -0,0 +1,208 @@
/**
 * @file
 * @brief Functions for transferring data directly between shared memory and registers and back.
 */

#pragma once // not done
/*
TODO:
prevent unnecessary memory back-and-forth
*/
#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace mittens {

/**
 * @brief Load data from a shared vector into a register vector.
 *
 * @tparam RV The register vector type
 * @tparam SV The shared vector type
 * @param dst[out] The destination register vector.
 * @param src[in] The source shared vector.
 */

/*
"For row-vectors:
0,2,4,6,16,18,20,22 holds %8+0 & %8+1
1,3,5,7,17,19,21,23 holds %8+2 & %8+3
00000000101010100000000010101010 = 0x00AA00AA
8,10,12,14,24,26,28,30 holds %8+4 & %8+5
01010101000000000101010100000000 = 0x55005500
9,11,13,15,25,27,29,31 holds %8+6 & %8+7"
10101010000000001010101000000000 = 0xAA00AA00

"For column-vectors:
0,1,8,9 holds %8+0
2,3,10,11 holds %8+1
4,5,12,13 holds %8+2
6,7,14,15 holds %8+3
16,17,24,25 holds %8+4
18,19,26,27 holds %8+5
20,21,28,29 holds %8+6
22,23,30,31 holds %8+7

0,0,4,4 holds %8+0
1,1,5,5 holds %8+1
2,2,6,6 holds %8+2
3,3,7,7 holds %8+3
8,8,12,12 holds %8+4
9,9,13,13 holds %8+5
10,10,14,14 holds %8+6
11,11,15,15 holds %8+7
"

0  0  1  1  8  8  9  9
2  2  3  3  10 10 11 11
4  4  5  5  12 12 13 13
6  6  7  7  14 14 15 15
16 16 17 17 24 24 25 25
18 18 19 19 26 26 27 27
20 20 21 21 28 28 29 29
22 22 23 23 30 30 31 31
*/
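// Worked example of the row-vector masks used below: offset = ((MASK_1>>laneid)&1)*2
// + ((MASK_2>>laneid)&1)*4 + ((MASK_3>>laneid)&1)*6. Lane 0 hits no mask and reads
// elements {0,1} of each 8-wide chunk; lane 1 sets bit 1 of MASK_1 and reads {2,3};
// lane 8 sets bit 8 of MASK_2 and reads {4,5}; lane 9 sets bit 9 of MASK_3 and reads
// {6,7}. Each loop iteration then advances offset by 8 to the next chunk, matching the
// tables above.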

// optimize later
template<typename RV, typename SV>
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::is_shared_vector<SV>(), void>::type
load(thread RV &dst, threadgroup const SV &src, const short laneid) {
    using RV_T = typename RV::dtype;
    using RV_T2 = typename base_types::packing<RV_T>::packed_type;
    using SV_T = typename SV::dtype;
    using SV_T2 = typename base_types::packing<SV_T>::packed_type;

    static_assert(SV::tiles == RV::tiles, "RV and SV dimensions must match");

    if (ducks::is_align_layout<typename RV::layout>()) {
        constexpr const uint32_t MASK_1 = 0x00AA00AA; // kitty bit magic
        constexpr const uint32_t MASK_2 = 0x55005500;
        constexpr const uint32_t MASK_3 = 0xAA00AA00;
        unsigned offset = ((MASK_1 >> laneid) & 1u) * 2 + ((MASK_2 >> laneid) & 1u) * 4 + ((MASK_3 >> laneid) & 1u) * 6;
        #pragma clang loop unroll(full)
        for (int t = 0; t < SV::tiles; offset+=8, t++) {
            RV_T2 src2 = base_types::convertor<RV_T2, SV_T2>::convert(*(threadgroup SV_T2*)(&src.data[offset]));
            dst.data[t][0] = src2[0];
            dst.data[t][1] = src2[1];
            // dst.data[t][0] = 7.f;
            // dst.data[t][1] = 7.f;
        }
    } else if (ducks::is_ortho_layout<typename RV::layout>()) {
        const short laneid_div2 = laneid / 2;
        unsigned offset = laneid_div2 % 4 + (laneid_div2 / 8) * 4;
        #pragma clang loop unroll(full)
        for (int t = 0; t < SV::tiles; offset+=8, t++) {
            dst.data[t][0] = base_types::convertor<RV_T, SV_T>::convert(src[offset]);
        }
    } else if (ducks::is_naive_layout<typename RV::layout>()) {
        #pragma clang loop unroll(full)
        for(auto w = 0; w < RV::outer_dim; w++) {
            if (w * SIMD_THREADS + laneid < RV::length) {
                dst.data[w][0] = base_types::convertor<RV_T, SV_T>::convert(src[w * SIMD_THREADS + laneid]);
            }
        }
    }
}

/**
 * @brief Store data into a shared vector from a register vector.
 *
 * @tparam RV The register vector type
 * @tparam SV The shared vector type
 * @param dst[out] The destination shared vector.
 * @param src[in] The source register vector.
 */
// optimize later
template<typename SV, typename RV>
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::is_shared_vector<SV>(), void>::type
store(threadgroup SV &dst, thread const RV &src, const short laneid) {
    ducks::assert_shared_vector<SV>();
    ducks::assert_register_vector<RV>();
    using RV_T = typename RV::dtype;
    using RV_T2 = typename base_types::packing<RV_T>::packed_type;
    using SV_T = typename SV::dtype;
    using SV_T2 = typename base_types::packing<SV_T>::packed_type;

    static_assert(SV::tiles == RV::tiles, "RV and SV dimensions must match");

    if (ducks::is_align_layout<typename RV::layout>()) {
        constexpr const uint32_t MASK_1 = 0x00AA00AA; // kitty bit magic
        constexpr const uint32_t MASK_2 = 0x55005500;
        constexpr const uint32_t MASK_3 = 0xAA00AA00;
        unsigned offset = ((MASK_1 >> laneid) & 1u) * 2 + ((MASK_2 >> laneid) & 1u) * 4 + ((MASK_3 >> laneid) & 1u) * 6;
        #pragma clang loop unroll(full)
        for (int t = 0; t < SV::tiles; offset+=8, t++) {
            SV_T2 src2 = base_types::convertor<SV_T2, RV_T2>::convert({src.data[t][0], src.data[t][1]});
            *(threadgroup SV_T2*)(&dst.data[offset]) = src2;

            // *(threadgroup SV_T2*)(&dst.data[offset]) = (SV_T2)1.f;
        }
    } else if (ducks::is_ortho_layout<typename RV::layout>()) {
        const short laneid_div2 = laneid / 2;
        unsigned offset = laneid_div2 % 4 + (laneid_div2 / 8) * 4;
        #pragma clang loop unroll(full)
        for (int t = 0; t < SV::tiles; offset+=8, t++) {
            dst[offset] = base_types::convertor<SV_T, RV_T>::convert(src.data[t][0]);
        }
    } else if (ducks::is_naive_layout<typename RV::layout>()) {
        #pragma clang loop unroll(full)
        for(auto w = 0; w < RV::outer_dim; w++) {
            if (w * SIMD_THREADS + laneid < RV::length) {
                dst[w * SIMD_THREADS + laneid] = base_types::convertor<SV_T, RV_T>::convert(src.data[w][0]);
            }
        }
    }
}

}


///// TRASH CAN

/*
template<typename RV, typename SV>
METAL_FUNC static typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::is_shared_vector<SV>(), void>::type
load(thread RV &dst, threadgroup const SV &src, const short laneid, const int start_tile, const int size_tile) {
    using RV_T = typename RV::dtype;
    using RV_T2 = typename base_types::packing<RV_T>::packed_type;
    using SV_T = typename SV::dtype;
    using SV_T2 = typename base_types::packing<SV_T>::packed_type;

    // static_assert(RV::tiles == size_tile , "RV and SV dimensions must match");

    if (ducks::is_align_layout<typename RV::layout>()) {
        constexpr const uint32_t MASK_1 = 0x00AA00AA; // kitty bit magic
        constexpr const uint32_t MASK_2 = 0x55005500;
        constexpr const uint32_t MASK_3 = 0xAA00AA00;
        unsigned offset = ((MASK_1 >> laneid) & 1u) * 2 + ((MASK_2 >> laneid) & 1u) * 4 + ((MASK_3 >> laneid) & 1u) * 6
                          + 8 * start_tile;
        #pragma clang loop unroll(full)
        for (int t = start_tile; t < start_tile + size_tile; offset+=8, t++) {
            // RV_T2 src2 = base_types::convertor<RV_T2, SV_T2>::convert(*(threadgroup SV_T2*)(&src.data[offset]));
            // dst.data[t][0] = src2[0];
            // dst.data[t][1] = src2[1];
        }
    } else if (ducks::is_ortho_layout<typename RV::layout>()) {
        const short laneid_div2 = laneid / 2;
        unsigned offset = laneid_div2 % 4 + (laneid_div2 / 8) * 4
                          + 8 * start_tile;
        #pragma clang loop unroll(full)
        for (int t = start_tile; t < start_tile + size_tile; offset+=8, t++) {
            dst.data[t][0] = base_types::convertor<RV_T, SV_T>::convert(src[offset]);
        }
    }
    // else if (ducks::is_naive_layout<typename RV::layout>()) {
    //     #pragma clang loop unroll(full)
    //     for(auto w = 0; w < RV::outer_dim; w++) {
    //         if (w * SIMD_THREADS + laneid < RV::length) {
    //             dst.data[w][0] = base_types::convertor<RV_T, SV_T>::convert(src[w * SIMD_THREADS + laneid]);
    //         }
    //     }
    // }
}

*/
4
extra/thunder/include/ops/warp/memory/vec/vec.metal
Normal file
@@ -0,0 +1,4 @@
#pragma once
#include "global_to_register.metal"
#include "global_to_shared.metal"
#include "shared_to_register.metal"
3
extra/thunder/include/ops/warp/register/register.metal
Normal file
@@ -0,0 +1,3 @@
#pragma once
#include "tile/tile.metal"
#include "vec/vec.metal"
313
extra/thunder/include/ops/warp/register/tile/conversions.metal
Normal file
@@ -0,0 +1,313 @@
/**
 * @file
 * @brief Conversions between data layouts and types for register tiles.
 */

#pragma once // not done:
/*
swapping register layout doesn't exist. no layout to swap
SUBTILE

*/
#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace mittens {
/* ---------- TRANSPOSE ---------- */
METAL_FUNC int compute_laneid(ushort y, ushort x) {
    // Extract bits from simd_y
    ushort b1 = y & 1;
    ushort temp_y = y >> 1;
    ushort b2 = temp_y & 1;
    ushort b4 = temp_y >> 1;

    // Extract bits from simd_x
    ushort b0 = (x >> 1) & 1;
    ushort b3 = x >> 2;

    // Reconstruct laneid
    ushort laneid = (b4 << 4) | (b3 << 3) | (b2 << 2) | (b1 << 1) | b0;
    return laneid;
}
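// Worked example: for lane 5 the forward mapping gives qid = 1, simd_y = (1&4) + (5/2)%4 = 2
// and simd_x = (1&2)*2 + (5%2)*2 = 2, and compute_laneid(2, 2) reassembles b1=0, b2=1, b4=0
// from y with b0=1, b3=0 from x into (1<<2)|1 = 5. swap_layout and transpose below call it
// with the coordinates swapped, compute_laneid(simd_x, simd_y), to find the lane that owns
// the mirrored element.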

/**
 * @brief Swaps the layout of a register base tile.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam layout The current layout of the register tile.
 * @param dst[out] Reference to the register tile in which to store the layout-swapped src.
 * @param src[in] Reference to the register base tile whose layout is to be swapped.
 */
template<typename T, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), void>::type
swap_layout(thread rt_base<T, typename ducks::rt_layout::transpose<layout>::type> &dst,
            thread const rt_base<T, layout> &src,
            const ushort laneid) {
    const ushort qid = laneid / 4;
    const ushort simd_y = (qid & 4) + (laneid / 2) % 4;
    const ushort simd_x = (qid & 2) * 2 + (laneid % 2) * 2;

    const ushort src_laneid_start = compute_laneid(simd_x, simd_y);
    const ushort2 src_laneid = ushort2(src_laneid_start, src_laneid_start+(ushort)2);
    const ushort first_idx = (laneid / 2) % 2;

    dst.data.thread_elements()[first_idx] = shfl_sync<T>(src.data.thread_elements()[first_idx], src_laneid[first_idx]);

    dst.data.thread_elements()[1 - first_idx] = shfl_sync<T>(src.data.thread_elements()[1 - first_idx], src_laneid[1 - first_idx]);
}

/**
 * @brief Swaps the layout of a register tile.
 *
 * This function swaps the layout of a register tile by iterating over its height and width
 * and performing layout swaps on each of its base elements.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the register tile.
 * @tparam _width The width of the register tile.
 * @tparam layout The current layout of the register tile.
 * @param dst[out] Reference to the destination register tile where the result will be stored.
 * @param src[in] Reference to the source register tile to be swapped.
 */
template<typename T, int _height, int _width, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), void>::type
swap_layout(thread rt<T, _height, _width, typename ducks::rt_layout::transpose<layout>::type> &dst, thread const rt<T, _height, _width, layout> &src, const short laneid) {
    #pragma clang loop unroll(full)
    for(int i = 0; i < dst.height; i++) {
        #pragma clang loop unroll(full)
        for(int j = 0; j < dst.width; j++) {
            swap_layout(dst.tiles[i][j], src.tiles[i][j], laneid);
        }
    }
}

/**
 * @brief Swaps the layout of a register base tile in place.
 *
 * This function swaps the layout of a register base tile in place by casting it to the
 * transposed layout type and then performing the layout swap.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam layout The current layout of the register tile.
 * @param src[in] Reference to the register base tile to be swapped in place.
 * @return A reference to the swapped register base tile.
 */
template<typename T2, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), thread rt_base<T2, typename ducks::rt_layout::transpose<layout>::type>&>::type
swap_layout_inplace(thread const rt_base<T2, layout> &src) {
    thread rt_base<T2, typename ducks::rt_layout::transpose<layout>::type> &dst = *(thread rt_base<T2, typename ducks::rt_layout::transpose<layout>::type>*)(&src);
    swap_layout(dst, src);
    return dst;
}

/* ---------- TRANSPOSE ---------- */

/**
 * @brief Transposes a register base tile.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam layout The current layout of the register tile.
 * @param dst[out] Reference to the register tile in which to store the transposed src.
 * @param src[in] Reference to the register base tile to be transposed.
 */
template<typename T, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), void>::type
transpose(thread rt_base<T, layout> &dst, thread const rt_base<T, layout> &src, const ushort laneid) {
    const ushort qid = laneid / 4;
    const ushort simd_y = (qid & 4) + (laneid / 2) % 4;
    const ushort simd_x = (qid & 2) * 2 + (laneid % 2) * 2;

    const ushort src_laneid_start = compute_laneid(simd_x, simd_y);
    const ushort2 src_laneid = ushort2(src_laneid_start, src_laneid_start+(ushort)2);
    const ushort first_idx = (laneid / 2) % 2;

    dst.data.thread_elements()[first_idx] = shfl_sync<T>(src.data.thread_elements()[first_idx], src_laneid[first_idx]);

    dst.data.thread_elements()[1 - first_idx] = shfl_sync<T>(src.data.thread_elements()[1 - first_idx], src_laneid[1 - first_idx]);
}
/**
 * @brief Transposes a register tile.
 *
 * This function is marked "sep", which means that the registers underlying dst MUST be separate
 * from the registers underlying src.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height of the src register tile, and the width of the dst tile.
 * @tparam _width The width of the src register tile, and the height of the dst tile.
 * @tparam layout The layout of the register tile.
 * @param dst[out] Reference to the register tile in which to store the transposed src.
 * @param src[in] Reference to the register tile to be transposed.
 */
template<typename RT>
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
transpose_sep(thread RT &dst, thread const rt<typename RT::T, RT::cols, RT::rows, typename RT::layout> &src,
              const int laneid) {
    #pragma clang loop unroll(full)
    for(int i = 0; i < RT::height; i++) {
        #pragma clang loop unroll(full)
        for(int j = 0; j < RT::width; j++) {
            transpose(dst.tiles[i][j], src.tiles[j][i], laneid);
        }
    }
}

/**
 * @brief Transposes a register base tile in-place.
 *
 * @tparam T2 The data type of the register base tile elements.
 * @tparam layout The current layout of the register base tile.
 * @param src[in] Reference to the register tile to be transposed.
 * @return A reference to the transposed register base tile.
 */
template<typename T2, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), thread rt_base<T2, layout>&>::type
transpose_inplace(thread rt_base<T2, layout> &src, const ushort laneid) {
    transpose(src, src, laneid);
    return src;
}

template<typename T, typename U, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), void>::type
copy(thread rt_base<T, layout> &dst, thread const rt_base<U, layout> &src);

/**
 * @brief Transposes a square register tile in-place.
 *
 * @tparam T2 The data type of the register tile elements.
 * @tparam _height The height (in units of 16) of the src register tile, and the width of the dst tile. (Must be the same as _width.)
 * @tparam _width The width (in units of 16) of the src register tile, and the height of the dst tile. (Must be the same as _height.)
 * @tparam layout The current layout of the register tile.
 * @param src[in] Reference to the register tile to be transposed.
 * @return A reference to the transposed register tile.
 */
template<typename RT>
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && RT::cols == RT::rows, thread RT&>::type
transpose_inplace(thread RT &tile, const ushort laneid) {
    #pragma clang loop unroll(full)
    for(int i = 0; i < tile.height; i++) {
        #pragma clang loop unroll(full)
        for(int j = 0; j < i; j++) {
            rt_base<typename RT::T, typename RT::layout> tmp;
            copy(tmp, tile.tiles[i][j]);
            transpose(tile.tiles[i][j], tile.tiles[j][i], laneid);
            transpose(tile.tiles[j][i], tmp, laneid);
        }
        transpose_inplace(tile.tiles[i][i], laneid);
    }
    return tile;
}
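
// Illustrative usage sketch: square tiles can be transposed in place, rectangular ones need
// a separate destination (the shapes and simd_lane_id below are assumptions):
//
//   rt<float, 16, 16> a;   // 2x2 base tiles
//   transpose_inplace(a, simd_lane_id);
//
//   rt<float, 8, 32> b;    // 1x4 base tiles
//   rt<float, 32, 8> bt;
//   transpose_sep(bt, b, simd_lane_id);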
/* ---------- TYPE SWAPS ---------- */
/**
 * @brief Copies a register base tile, converting the underlying type if necessary.
 *
 * @tparam T2 The data type of the destination register elements.
 * @tparam U2 The data type of the source register elements.
 * @tparam layout The current layout of the register base tile.
 * @param[out] dst A reference to the destination register base tile.
 * @param[in] src A reference to the source register base tile.
 */
template<typename T, typename U, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), void>::type
copy(thread rt_base<T, layout> &dst, thread const rt_base<U, layout> &src) {
    using T1 = typename base_types::packing<T>::unpacked_type;
    using U1 = typename base_types::packing<U>::unpacked_type;
    dst.data.thread_elements()[0] = base_types::convertor<T1, U1>::convert(src.data.thread_elements()[0]);
    dst.data.thread_elements()[1] = base_types::convertor<T1, U1>::convert(src.data.thread_elements()[1]);
}

/**
 * @brief Copies a register tile, converting the underlying type if necessary.
 *
 * @tparam T2 The data type of the destination register elements.
 * @tparam U2 The data type of the source register elements.
 * @tparam _height The height (in units of 8) of the register tiles.
 * @tparam _width The width (in units of 8) of the register tiles.
 * @tparam layout The current layout of the register tile.
 * @param[out] dst A reference to the destination register tile.
 * @param[in] src A reference to the source register tile.
 */
template<typename T, typename U, int _height, int _width, typename layout>
static METAL_FUNC typename metal::enable_if<ducks::is_rt_layout<layout>(), void>::type
copy(thread rt<T, _height, _width, layout> &dst, thread const rt<U, _height, _width, layout> &src) {
    #pragma clang loop unroll(full)
    for(int i = 0; i < dst.height; i++) {
        #pragma clang loop unroll(full)
        for(int j = 0; j < dst.width; j++) {
            copy(dst.tiles[i][j], src.tiles[i][j]);
        }
    }
}
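
// Illustrative usage sketch (assumes half is among the supported packed register dtypes;
// if not, substitute any supported pair of element types):
//
//   rt<half, 32, 32>  x_lo;
//   rt<float, 32, 32> x_hi;
//   copy(x_hi, x_lo);   // element-wise convert, same shape and layout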

/* ---------- CAUSAL ---------- */

/**
 * @brief Makes a square register tile causal by zeroing elements above the main diagonal.
 *
 * This function writes a causal copy of a square register tile: elements above the main
 * diagonal are set to `val` (zero by default), while elements on or below the main
 * diagonal are copied from the source.
 *
 * @tparam RT The register tile type.
 * @param dst[out] Reference to the register tile in which to store the causal result.
 * @param src[in] Reference to the source register tile.
 * @param laneid[in] Thread's index in SIMD group
 * @param val[in] The value written above the diagonal (defaults to zero).
 */
template<typename RT>
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
make_causal(thread RT &dst, thread const RT &src, const unsigned laneid, thread const typename base_types::packing<typename RT::dtype>::unpacked_type &val=0) {
    ducks::assert_register_tile<RT>();
    #pragma clang loop unroll(full)
    for(int i = 0; i < dst.height; i++) {
        #pragma clang loop unroll(full)
        for(int j = 0; j < dst.width; j++) {
            if(j < i) { // below the diagonal, copy
                dst.tiles[i][j].data.thread_elements()[0] = src.tiles[i][j].data.thread_elements()[0];
                dst.tiles[i][j].data.thread_elements()[1] = src.tiles[i][j].data.thread_elements()[1];
            }
            else if(j > i) { // above the diagonal, zero
                dst.tiles[i][j].data.thread_elements()[0] = val;
                dst.tiles[i][j].data.thread_elements()[1] = val;
            }
            else { // on the diagonal
                constexpr uint32_t MASK_0 = (ducks::is_row_register_tile<RT>()) ? 0x0A00FF0A : 0xD4FF00D4;
                constexpr uint32_t MASK_1 = (ducks::is_row_register_tile<RT>()) ? 0x2B00FF2B : 0x50FF0050;
                if((MASK_0 >> laneid) & 1) {
                    dst.tiles[i][j].data.thread_elements()[0] = val;
                }
                else {
                    dst.tiles[i][j].data.thread_elements()[0] = src.tiles[i][j].data.thread_elements()[0];
                }
                if((MASK_1 >> laneid) & 1) {
                    dst.tiles[i][j].data.thread_elements()[1] = val;
                }
                else {
                    dst.tiles[i][j].data.thread_elements()[1] = src.tiles[i][j].data.thread_elements()[1];
                }
            }
        }
    }
}
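
// Illustrative usage sketch: for a square attention-style tile such as rt<float, 16, 16>
// (2x2 base tiles), a masked copy can be taken before a softmax; att and att_raw are
// assumed names and -1e30f stands in for negative infinity:
//
//   make_causal(att, att_raw, simd_lane_id, -1e30f);
//
// Base tiles strictly below the diagonal are copied, those strictly above are filled with
// val, and the per-lane masks split ownership of elements within the diagonal base tiles.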

/* ---------- SUBTILE ---------- */

/**
 * @brief Returns a reference to a subtile of the given tile.
 *
 * @tparam subtile_height The height of the subtile.
 * @tparam RT The type of the input tile, which must satisfy the ducks::rt::all concept.
 * @param src The input tile.
 * @param idx The index of the subtile.
 * @return A reference to the subtile.
 *
 * @note The subtile height must evenly divide the tile height.
 */
//template<int subtile_height, ducks::rt::all RT>
//__device__ inline rt<typename RT::T, subtile_height, RT::width, typename RT::layout> &subtile_inplace(RT & src, int idx) {
//    static_assert(RT::height % subtile_height == 0, "subtile height should evenly divide tile height.");
//    return reinterpret_cast<rt<typename RT::T, subtile_height, RT::width, typename RT::layout>&>(
//        src.tiles[idx*subtile_height]
//    );
//}

}
878
extra/thunder/include/ops/warp/register/tile/maps.metal
Normal file
@@ -0,0 +1,878 @@
|
||||
#pragma once // doneington but add register tile col
|
||||
#include "../../../../common/common.metal"
|
||||
#include "../../../../types/types.metal"
|
||||
|
||||
namespace mittens {
|
||||
/* ---------- Uniform tile maps (independent of layout) ---------- */
|
||||
|
||||
namespace meta {
|
||||
template<typename op, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
unary_map_unroll(int i, int j, thread RT *dst, thread const RT *src) {
|
||||
using T2 = typename RT::T2;
|
||||
T2 vals = op::template op<T2>(T2{src->tiles[i][j].data.thread_elements()[0], src->tiles[i][j].data.thread_elements()[1]});
|
||||
dst->tiles[i][j].data.thread_elements()[0] = vals[0];
|
||||
dst->tiles[i][j].data.thread_elements()[1] = vals[1];
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Applies a unary operation to each element of a tile.
|
||||
*
|
||||
* @tparam op Unary operation to apply.
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the operation on.
|
||||
*/
|
||||
template<typename op, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
unary_map(thread RT &dst, thread const RT &src) {
|
||||
using T = typename RT::T;
|
||||
ducks::assert_register_tile<RT>();
|
||||
using T2 = typename RT::T2;
|
||||
using T4 = typename base_types::packing<typename RT::dtype>::packed_four;
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < dst.height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 0; j < dst.width; j++) {
|
||||
// T2 op2 = op::template op<T2>(T2{src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]});
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = op::template op<typename RT::dtype>(src.tiles[i][j].data.thread_elements()[0]);
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = op::template op<typename RT::dtype>(src.tiles[i][j].data.thread_elements()[1]);
|
||||
//
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = op2[0];
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = op2[1];
|
||||
//
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = base_ops::abs::template op<T>(src.tiles[i][j].data.thread_elements()[0]);
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = base_ops::abs::template op<T>(src.tiles[i][j].data.thread_elements()[1]);
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = (T)(metal::abs(-1.f));
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = (T)(metal::abs(-1.f));
|
||||
//
|
||||
//// ((T)(((float)src.tiles[i][j].data.thread_elements()[0])));
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = metal::abs((T)((float)src.tiles[i][j].data.thread_elements()[1]));
|
||||
//
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = base_types::constants<typename RT::dtype>::one();
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = base_types::constants<typename RT::dtype>::one();
|
||||
//// metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
//
|
||||
////// T2 val = op::template op<T2>(T2{src.tiles[i][j].data.thread_elements()[0],
|
||||
////// src.tiles[i][j].data.thread_elements()[1]});
|
||||
////// dst.tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
////// dst.tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
////////
|
||||
////// T4 val = op::template op<T4>(T4{src.tiles[i][j].data.thread_elements()[0],
|
||||
////// src.tiles[i][j].data.thread_elements()[1],
|
||||
////// src.tiles[i][j+1].data.thread_elements()[0],
|
||||
////// src.tiles[i][j+1].data.thread_elements()[1],});
|
||||
////// dst.tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
////// dst.tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
////// dst.tiles[i][j+1].data.thread_elements()[0] = val[2];
|
||||
////// dst.tiles[i][j+1].data.thread_elements()[1] = val[3];
|
||||
// }
|
||||
// }
|
||||
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::unary_map_unroll<op, RT>, &dst, &src);
|
||||
}
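// Usage sketch (added for illustration): a base_ops-style functor exposes a single static,
// templated op(); the square_op name and the 32x32 tile shape here are assumptions.
struct square_op {
    template<typename T> static METAL_FUNC T op(T x) { return x * x; }  // also works on packed T2
};
static METAL_FUNC void square_example(thread rt<float, 32, 32> &tile) {
    unary_map<square_op>(tile, tile);  // tile[i][j] = tile[i][j] * tile[i][j]
}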
|
||||
|
||||
|
||||
namespace meta {
|
||||
template<typename op, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
bin_map_unroll(int i, int j, thread RT *dst, thread const RT *src, thread const typename RT::dtype *param) {
|
||||
using T = typename RT::T;
|
||||
using T2 = typename RT::T2;
|
||||
// T2 vals = op::template op<T2>({src->tiles[i][j].data.thread_elements()[0], src->tiles[i][j].data.thread_elements()[1]}, {*param, *param});
|
||||
// dst->tiles[i][j].data.thread_elements()[0] = vals[0];
|
||||
// dst->tiles[i][j].data.thread_elements()[1] = vals[1];
|
||||
dst->tiles[i][j].data.thread_elements()[0] = op::template op<T>(src->tiles[i][j].data.thread_elements()[0], *param);
|
||||
dst->tiles[i][j].data.thread_elements()[1] = op::template op<T>(src->tiles[i][j].data.thread_elements()[1], *param);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Applies a binary operation to each element of a tile with a scalar parameter.
|
||||
*
|
||||
* @tparam op Binary operation to apply.
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the operation on.
|
||||
* @param param[in] Scalar parameter for the binary operation.
|
||||
*/
|
||||
template<typename op, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
bin_map(thread RT &dst, thread const RT &src, thread const typename RT::dtype ¶m) {
|
||||
// using T = typename RT::T;
|
||||
// using T2 = typename RT::T2;
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < dst.height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 0; j < dst.width; j++) {
|
||||
// T2 vals = op::template op<T2>({src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]}, {param, param});
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = vals[0];
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = vals[1];
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = op::template op<typename RT::dtype>(src.tiles[i][j].data.thread_elements()[0], param);
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = op::template op<typename RT::dtype>(src.tiles[i][j].data.thread_elements()[1], param);
|
||||
// }
|
||||
// }
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::bin_map_unroll<op, RT>, &dst, &src, ¶m);
|
||||
}
|
||||
|
||||
namespace meta {
|
||||
template<typename op, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
binary_map_unroll(int i, int j, thread RT *dst, thread const RT *lhs, thread const RT *rhs) {
|
||||
using T2 = typename RT::T2;
|
||||
using T4 = typename base_types::packing<typename RT::dtype>::packed_four;
|
||||
dst->tiles[i][j].data.thread_elements()[0] = op::template op<typename RT::dtype>(lhs->tiles[i][j].data.thread_elements()[0],
|
||||
rhs->tiles[i][j].data.thread_elements()[0]);
|
||||
dst->tiles[i][j].data.thread_elements()[1] = op::template op<typename RT::dtype>(lhs->tiles[i][j].data.thread_elements()[1],
|
||||
rhs->tiles[i][j].data.thread_elements()[1]);
|
||||
// T2 vals = op::template op<T2>({lhs->tiles[i][j].data.thread_elements()[0], lhs->tiles[i][j].data.thread_elements()[1]},
|
||||
// {rhs->tiles[i][j].data.thread_elements()[0], rhs->tiles[i][j].data.thread_elements()[1]});
|
||||
////
|
||||
// dst->tiles[i][j].data.thread_elements()[0] = vals[0];
|
||||
// dst->tiles[i][j].data.thread_elements()[1] = vals[1];
|
||||
|
||||
// dst->tiles[i][j].data.thread_elements()[0] = op::template op<typename RT::dtype>(lhs->tiles[i][j].data.thread_elements()[0],
|
||||
// rhs->tiles[i][j].data.thread_elements()[0]);
|
||||
// dst->tiles[i][j].data.thread_elements()[1] = op::template op<typename RT::dtype>(lhs->tiles[i][j].data.thread_elements()[1],
|
||||
// rhs->tiles[i][j].data.thread_elements()[1]);
|
||||
// T4 val = op::template op<T4>(T4{src->tiles[i][j].data.thread_elements()[0],
|
||||
// src->tiles[i][j].data.thread_elements()[1],
|
||||
// src->tiles[i][j+1].data.thread_elements()[0],
|
||||
// src->tiles[i][j+1].data.thread_elements()[1]});
|
||||
// dst->tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
// dst->tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
// dst->tiles[i][j+1].data.thread_elements()[0] = val[2];
|
||||
// dst->tiles[i][j+1].data.thread_elements()[1] = val[3];
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Applies a binary operation element-wise between two tiles.
|
||||
*
|
||||
* @tparam op Binary operation to apply.
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param lhs[in] Left-hand side source tile for the operation.
|
||||
* @param rhs[in] Right-hand side source tile for the operation.
|
||||
*/
|
||||
template<typename op, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
bin_map(thread RT &dst, thread const RT &lhs, thread const RT &rhs) {
|
||||
using T = typename RT::dtype;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < dst.height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 0; j < dst.width; j++) {
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = op::template op<typename RT::dtype>(lhs.tiles[i][j].data.thread_elements()[0],
|
||||
// rhs.tiles[i][j].data.thread_elements()[0]);
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = op::template op<typename RT::dtype>(lhs.tiles[i][j].data.thread_elements()[1],
|
||||
// rhs.tiles[i][j].data.thread_elements()[1]);
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = lhs.tiles[i][j].data.thread_elements()[0] + rhs.tiles[i][j].data.thread_elements()[0];
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = lhs.tiles[i][j].data.thread_elements()[1] + rhs.tiles[i][j].data.thread_elements()[1];
|
||||
////
|
||||
// T2 vals = op::template op<T2>(T2(lhs.tiles[i][j].data.thread_elements()[0], lhs.tiles[i][j].data.thread_elements()[1]),
|
||||
// T2(rhs.tiles[i][j].data.thread_elements()[0], rhs.tiles[i][j].data.thread_elements()[1]));
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = vals[0];
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = vals[1];
|
||||
// }
|
||||
// }
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::binary_map_unroll<op, RT>, &dst, &lhs, &rhs);
|
||||
}
|
||||
|
||||
/* ---------- Row tile maps ----------*/
|
||||
|
||||
namespace meta {
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_map_unroll(int i, int j, thread RT *dst, thread const RT *src, thread const RV *row_values) {
|
||||
using T2 = typename RT::T2;
|
||||
T2 val = op::template op<T2>({src->tiles[i][j].data.thread_elements()[0], src->tiles[i][j].data.thread_elements()[1]}, {(*row_values)[i][0], (*row_values)[i][0]});
|
||||
dst->tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
dst->tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
}
|
||||
|
||||
}
|
||||
/**
|
||||
* @brief Applies an operation across the rows of a tile in a row-major layout.
|
||||
*
|
||||
* @tparam op Operation to apply.
|
||||
* @tparam T Tile type with row-major layout.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the operation on.
|
||||
* @param row_values[in] Column vector containing values to apply across each row.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_map(thread RT &dst, thread const RT &src, thread const RV &row_values) {
|
||||
    static_assert(ducks::is_ortho_layout<typename RV::layout>(), "RV must be ortho layout (col vec for row rt)");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rt and rv must be of same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::height, "RV outer dim and RT height do not match"); // compatible size
|
||||
using T4 = typename base_types::packing<typename RT::dtype>::packed_four;
|
||||
using T2 = typename RT::T2;
|
||||
using T = typename RT::dtype;
|
||||
|
||||
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < RT::height; i++) {
|
||||
// T row_val = row_values[i][0];
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 0; j < RT::width; j++) {
|
||||
// T2 val = op::template op<T2>({src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]}, {row_val, row_val});
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = op::template op<T>(src.tiles[i][j].data.thread_elements()[0], row_values[i][0]);
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = op::template op<T>(src.tiles[i][j].data.thread_elements()[1], row_values[i][0]);
|
||||
// }
|
||||
// }
|
||||
meta::unroll_i_j_in_range<0, RT::height, 1, 0, RT::width, 1>::run(meta::row_map_unroll<op, RT, RV>, &dst, &src, &row_values);
|
||||
|
||||
|
||||
// meta::unroll_i_j_in_range<0, RT::height, 1,
|
||||
// 0, (RT::width / 2) * 2, 2>::run(meta::row_map_unroll<op, RT, RV, 0, 1>, &dst, &src, &row_values);
|
||||
// meta::unroll_i_j_in_range<0, (RT::height / 2) * 2, 2,
|
||||
// (RT::width / 2) * 2, RT::width, 1>::run(meta::row_map_unroll<op, RT, RV, 1, 0>, &dst, &src, &row_values);
|
||||
//
|
||||
// meta::unroll_i_j_in_range<(RT::height / 2) * 2, RT::height, 1,
|
||||
// (RT::width / 2) * 2, RT::width, 1>::run(meta::row_map_unroll<op, RT, RV>, &dst, &src, &row_values);
|
||||
}
|
||||
|
||||
/**
|
||||
 * @brief Applies an operation across the rows of a tile in a column-major layout.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with column-major layout.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the operation on.
|
||||
* @param row_values[in] Column vector containing values to apply across each row.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_map(thread RT &dst, thread const RT &src, thread const RV &row_values) {
|
||||
static_assert(ducks::is_align_layout<typename RV::layout>(), "RV must be align layout (col vec for col rt)");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rt and rv must be of same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::height, "RV outer dim and RT height do not match"); // compatible size
|
||||
using T4 = typename base_types::packing<typename RT::dtype>::packed_four;
|
||||
using T2 = typename RT::T2;
|
||||
using T = typename RT::dtype;
|
||||
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < RT::height; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < RT::width; j++) {
|
||||
dst.tiles[i][j].data.thread_elements()[0] = op::template op<T>(src.tiles[i][j].data.thread_elements()[0], row_values[i][0]);
|
||||
dst.tiles[i][j].data.thread_elements()[1] = op::template op<T>(src.tiles[i][j].data.thread_elements()[1], row_values[i][1]);
|
||||
}
|
||||
}
|
||||
//
|
||||
// meta::unroll_i_j_in_range<0, RT::height, 1,
|
||||
// 0, (RT::width / 2) * 2, 2>::run(meta::row_map_unroll<op, RT, RV, 0, 1>, &dst, &src, &row_values);
|
||||
// meta::unroll_i_j_in_range<0, (RT::height / 2) * 2, 2,
|
||||
// (RT::width / 2) * 2, RT::width, 1>::run(meta::row_map_unroll<op, RT, RV, 1, 0>, &dst, &src, &row_values);
|
||||
//
|
||||
// meta::unroll_i_j_in_range<(RT::height / 2) * 2, RT::height, 1,
|
||||
// (RT::width / 2) * 2, RT::width, 1>::run(meta::row_map_unroll<op, RT, RV>, &dst, &src, &row_values);
|
||||
}
|
||||
|
||||
// Three-operand row map. Mostly useful for FMA instructions.
|
||||
|
||||
/**
|
||||
* @brief Applies an operation across the rows of two tiles in a row-major layout, using a third operand.
|
||||
*
|
||||
* @tparam op Operation to apply.
|
||||
* @tparam T Tile type with row-major layout.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param a[in] First source tile to apply the operation on.
|
||||
* @param b[in] Second source tile to apply the operation on.
|
||||
* @param row_values[in] Column vector containing values to apply across each row.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_map(thread RT &dst, thread const RT &a, thread const RT &b, thread const RV &row_values) {
|
||||
    static_assert(ducks::is_ortho_layout<typename RV::layout>(), "rv must be ortho layout for row rt");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rt and rv must be same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::height, "rv and rt dimensions don't match"); // compatible size
|
||||
|
||||
|
||||
using dtype = typename RT::dtype;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
dtype vec_val = row_values[i][0];
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
dst.tiles[i][j].data.thread_elements()[0] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[0], b.tiles[i][j].data.thread_elements()[0], vec_val);
|
||||
|
||||
dst.tiles[i][j].data.thread_elements()[1] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[1], b.tiles[i][j].data.thread_elements()[1], vec_val);
|
||||
}
|
||||
}
|
||||
}
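// Usage sketch (added for illustration): the three-operand form is shaped for fused
// multiply-add; the fma_op functor and the wrapper below are assumptions, not library code.
struct fma_op {
    template<typename T> static METAL_FUNC T op(T a, T b, T c) { return a * b + c; }
};
template<typename RT, typename RV>
static METAL_FUNC void fma_row_example(thread RT &d, thread const RT &a, thread const RT &b,
                                       thread const RV &v) {
    row_map<fma_op>(d, a, b, v);  // d[i][j] = a[i][j] * b[i][j] + v[i]
}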
|
||||
|
||||
/**
|
||||
* @brief Applies an operation across the rows of two tiles in a column-major layout, using a third operand.
|
||||
*
|
||||
* @tparam op Operation to apply.
|
||||
* @tparam T Tile type with column-major layout.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param a[in] First source tile to apply the operation on.
|
||||
* @param b[in] Second source tile to apply the operation on.
|
||||
* @param row_values[in] Column vector containing values to apply across each row.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_map(thread RT &dst, thread const RT &a, thread const RT &b, thread const RV &row_values) {
|
||||
    static_assert(ducks::is_align_layout<typename RV::layout>(), "rv must be align layout for col rt");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rt and rv must be same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::height, "rv and rt dimensions don't match"); // compatible size
|
||||
|
||||
|
||||
using dtype = typename RT::dtype;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
dst.tiles[i][j].data.thread_elements()[0] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[0], b.tiles[i][j].data.thread_elements()[0], row_values[i][0]);
|
||||
|
||||
dst.tiles[i][j].data.thread_elements()[1] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[1], b.tiles[i][j].data.thread_elements()[1], row_values[i][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------- Col major tile maps ----------*/
|
||||
|
||||
/**
|
||||
* @brief Applies an operation across the columns of a tile in a row-major layout.
|
||||
*
|
||||
* @tparam op Operation to apply.
|
||||
* @tparam T Tile type with row-major layout.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the operation on.
|
||||
* @param col_values[in] Row vector containing values to apply across each column.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_map(thread RT &dst, thread const RT &src, thread const RV &col_values) {
|
||||
static_assert(ducks::is_align_layout<typename RV::layout>(), "rv must be align layout for row rt"); // compatible type
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rv and rt must be of the same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::width, "rv and rt dimensions do not match"); // compatible size
|
||||
|
||||
using dtype = typename RT::dtype;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
dst.tiles[i][j].data.thread_elements()[0] = op::template op<dtype>(src.tiles[i][j].data.thread_elements()[0], col_values[j][0]);
|
||||
dst.tiles[i][j].data.thread_elements()[1] = op::template op<dtype>(src.tiles[i][j].data.thread_elements()[1], col_values[j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Applies an operation across the columns of a tile in a col-major layout.
|
||||
*
|
||||
* @tparam op Operation to apply.
|
||||
 * @tparam T Tile type with column-major layout.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the operation on.
|
||||
* @param col_values[in] Row vector containing values to apply across each column.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_map(thread RT &dst, thread const RT &src, thread const RV &col_values) {
|
||||
    static_assert(ducks::is_ortho_layout<typename RV::layout>(), "rv must be ortho layout for col rt"); // compatible type
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rv and rt must be of the same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::width, "rv and rt dimensions do not match"); // compatible size
|
||||
|
||||
using dtype = typename RT::dtype;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
dst.tiles[i][j].data.thread_elements()[0] = op::template op<dtype>(src.tiles[i][j].data.thread_elements()[0], col_values[j][0]);
|
||||
dst.tiles[i][j].data.thread_elements()[1] = op::template op<dtype>(src.tiles[i][j].data.thread_elements()[1], col_values[j][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Three-operand col map
|
||||
/**
|
||||
* @brief Applies an operation across the columns of two tiles in a row-major layout, using a third operand.
|
||||
*
|
||||
* @tparam op Operation to apply.
|
||||
* @tparam T Tile type with row-major layout.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param a[in] First source tile to apply the operation on.
|
||||
* @param b[in] Second source tile to apply the operation on.
|
||||
* @param col_values[in] Row vector containing values to apply across each column.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_map(thread RT &dst, thread const RT &a, thread const RT &b, thread const RV &col_values) {
|
||||
    static_assert(ducks::is_align_layout<typename RV::layout>(), "rv must be align layout");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rv and rt must be of the same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::width, "rv and rt dims don't match"); // compatible size
|
||||
|
||||
|
||||
using dtype = typename RT::dtype;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
dst.tiles[i][j].data.thread_elements()[0] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[0], b.tiles[i][j].data.thread_elements()[0], col_values[j][0]);
|
||||
dst.tiles[i][j].data.thread_elements()[1] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[1], b.tiles[i][j].data.thread_elements()[1], col_values[j][1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
 * @brief Applies an operation across the columns of two tiles in a column-major layout, using a third operand.
 *
 * @tparam op Operation to apply.
 * @tparam T Tile type with column-major layout.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param a[in] First source tile to apply the operation on.
|
||||
* @param b[in] Second source tile to apply the operation on.
|
||||
* @param col_values[in] Row vector containing values to apply across each column.
|
||||
*/
|
||||
template<typename op, typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_map(thread RT &dst, thread const RT &a, thread const RT &b, thread const RV &col_values) {
|
||||
    static_assert(ducks::is_ortho_layout<typename RV::layout>(), "rv must be ortho layout");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rv and rt must be of the same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::width, "rv and rt dims don't match"); // compatible size
|
||||
|
||||
|
||||
using dtype = typename RT::dtype;
|
||||
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.width; j++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.height; i++) {
|
||||
dst.tiles[i][j].data.thread_elements()[0] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[0], b.tiles[i][j].data.thread_elements()[0], col_values[j][0]);
|
||||
dst.tiles[i][j].data.thread_elements()[1] = op::template op<dtype>(a.tiles[i][j].data.thread_elements()[1], b.tiles[i][j].data.thread_elements()[1], col_values[j][0]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
// All of the annoying qualifiers *should* be automatically inferred during compile-time.
|
||||
// So, syntax should just be mittens::add_row(tile, colvec);
|
||||
|
||||
/**
|
||||
* @brief Sets all elements of a tile to zero.
|
||||
*
|
||||
* @tparam RT Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
zero(thread RT &dst) {
|
||||
unary_map<base_ops::zero, RT>(dst, dst);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a tile to one.
|
||||
*
|
||||
* @tparam RT Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
one(thread RT &dst) {
|
||||
unary_map<base_ops::one, RT>(dst, dst);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a tile to positive infinity.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
pos_infty(thread RT &dst) {
|
||||
unary_map<base_ops::pos_infty, RT>(dst, dst);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a tile to negative infinity.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
neg_infty(thread RT &dst) {
|
||||
unary_map<base_ops::neg_infty, RT>(dst, dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Applies the exponential function to each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the exponential function on.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
exp(thread RT &dst, thread const RT &src) {
|
||||
unary_map<base_ops::exp, RT>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the exponential function to each element of a tile, in base 2.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the exponential function on.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
exp2(thread RT &dst, thread const RT &src) {
|
||||
unary_map<base_ops::exp2, RT>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the natural logarithm function to each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the natural logarithm function on.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
log(thread RT &dst, thread const RT &src) {
|
||||
unary_map<base_ops::log, RT>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the absolute value function to each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the absolute value function on.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
abs(thread RT &dst, thread const RT &src) {
|
||||
unary_map<base_ops::abs, RT>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the rectified linear unit (ReLU) function to each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the ReLU function on.
|
||||
*/
|
||||
template<typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
relu(thread RT &dst, thread const RT &src) {
|
||||
unary_map<base_ops::relu, RT>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Copies the elements from one tile to another.
|
||||
*
|
||||
* @tparam T Destination tile type.
|
||||
* @tparam U Source tile type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to copy from.
|
||||
*/
|
||||
template<typename RT, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
copy(thread RT &dst, thread const U &src) {
|
||||
bin_map<base_ops::copy2, RT>(dst, dst, src);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Applies the max operation element-wise between two tiles or a tile and a scalar.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam U Second operand type, which can be a tile or a scalar.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param lhs[in] Left-hand side source tile for the operation.
|
||||
* @param rhs[in] Right-hand side source tile or scalar for the operation.
|
||||
*/
|
||||
template<typename RT, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
max(thread RT &dst, thread const RT &lhs, thread const U &rhs) {
|
||||
bin_map<base_ops::max, RT>(dst, lhs, rhs);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the min operation element-wise between two tiles or a tile and a scalar.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam U Second operand type, which can be a tile or a scalar.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param lhs[in] Left-hand side source tile for the operation.
|
||||
* @param rhs[in] Right-hand side source tile or scalar for the operation.
|
||||
*/
|
||||
template<typename RT, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
min(thread RT &dst, thread const RT &lhs, thread const U &rhs) {
|
||||
bin_map<base_ops::min, RT>(dst, lhs, rhs);
|
||||
}
|
||||
/**
|
||||
* @brief Adds two tiles element-wise or adds a scalar to each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam U Second operand type, which can be a tile or a scalar.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param lhs[in] Left-hand side source tile for the addition.
|
||||
* @param rhs[in] Right-hand side source tile or scalar for the addition.
|
||||
*/
|
||||
template<typename RT, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
add(thread RT &dst, thread const RT &lhs, thread const U &rhs) {
|
||||
bin_map<base_ops::sum, RT>(dst, lhs, rhs);
|
||||
}
|
||||
/**
|
||||
* @brief Subtracts two tiles element-wise or subtracts a scalar from each element of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam U Second operand type, which can be a tile or a scalar.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param lhs[in] Left-hand side source tile for the subtraction.
|
||||
* @param rhs[in] Right-hand side source tile or scalar for the subtraction.
|
||||
*/
|
||||
template<typename RT, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
sub(thread RT &dst, const thread RT &lhs, thread const U &rhs) {
|
||||
bin_map<base_ops::sub, RT>(dst, lhs, rhs);
|
||||
}
|
||||
/**
|
||||
* @brief Multiplies two tiles element-wise or multiplies each element of a tile by a scalar.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam U Second operand type, which can be a tile or a scalar.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param lhs[in] Left-hand side source tile for the multiplication.
|
||||
* @param rhs[in] Right-hand side source tile or scalar for the multiplication.
|
||||
*/
|
||||
template<typename RT, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
mul(thread RT &dst, thread const RT &lhs, thread const U &rhs) {
|
||||
bin_map<base_ops::mul, RT>(dst, lhs, rhs);
|
||||
}
|
||||
/**
|
||||
* @brief Divides two tiles element-wise or divides each element of a tile by a scalar.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam U Second operand type, which can be a tile or a scalar.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param lhs[in] Left-hand side source tile for the division.
|
||||
* @param rhs[in] Right-hand side source tile or scalar for the division.
|
||||
*/
|
||||
template<typename RT, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>(), void>::type
|
||||
div(thread RT &dst, thread const RT &lhs, thread const U &rhs) {
|
||||
bin_map<base_ops::div, RT>(dst, lhs, rhs);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Adds row values to each row of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the addition on.
|
||||
* @param row_values[in] Column vector containing values to add to each row.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
add_row(thread RT &dst, thread const RT &src, thread const RV &row_values) {
|
||||
row_map<base_ops::sum, RT, RV>(dst, src, row_values);
|
||||
}
|
||||
/**
|
||||
* @brief Subtracts row values from each row of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the subtraction on.
|
||||
* @param row_values[in] Column vector containing values to subtract from each row.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
sub_row(thread RT &dst, thread const RT &src, thread const RV &row_values) {
|
||||
row_map<base_ops::sub, RT, RV>(dst, src, row_values);
|
||||
// using T4 = typename base_types::packing<typename RT::dtype>::packed_four;
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < RT::height; i++) {
|
||||
// // #pragma clang loop unroll(full)
|
||||
// // for(int j = 0; j < RT::width; j+=2) {
|
||||
// // T4 val = op::template op<T4>({src.tiles[i][j].data.thread_elements()[0],
|
||||
// // src.tiles[i][j].data.thread_elements()[1],
|
||||
// // src.tiles[i][j+1].data.thread_elements()[0],
|
||||
// // src.tiles[i][j+1].data.thread_elements()[1],},
|
||||
// // {row_values[i][0], row_values[i][0],row_values[i][0], row_values[i][0]});
|
||||
// //
|
||||
// // dst.tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
// // dst.tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
// // dst.tiles[i][j+1].data.thread_elements()[0] = val[2];
|
||||
// // dst.tiles[i][j+1].data.thread_elements()[1] = val[3];
|
||||
// // }
|
||||
//
|
||||
// // #pragma clang loop unroll(full)
|
||||
// // for(int j = 0; j < RT::width; j++) {
|
||||
// // T2 val = op::template op<T2>({src.tiles[i][j].data.thread_elements()[0],
|
||||
// // src.tiles[i][j].data.thread_elements()[1]},
|
||||
// // {row_values[i][0], row_values[i][0]});
|
||||
// //
|
||||
// // dst.tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
// // dst.tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
// // }
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 0; j < RT::width; j+=2) {
|
||||
// T4 val = T4(src.tiles[i][j].data.thread_elements()[0],
|
||||
// src.tiles[i][j].data.thread_elements()[1],
|
||||
// src.tiles[i][j+1].data.thread_elements()[0],
|
||||
// src.tiles[i][j+1].data.thread_elements()[1]) - T4(row_values[i][0], row_values[i][0], row_values[i][0], row_values[i][0]);
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = val[0];
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = val[1];
|
||||
// dst.tiles[i][j+1].data.thread_elements()[0] = val[2];
|
||||
// dst.tiles[i][j+1].data.thread_elements()[1] = val[3];
|
||||
// }
|
||||
// }
|
||||
}
|
||||
/**
|
||||
* @brief Multiplies each row of a tile by row values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the multiplication on.
|
||||
* @param row_values[in] Column vector containing values to multiply each row by.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
mul_row(thread RT &dst, thread const RT &src, thread const RV &row_values) {
|
||||
// using T = typename RT::T;
|
||||
// using T2 = typename RT::T2;
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < RT::height; i++) {
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 0; j < RT::width; j++) {
|
||||
//// T s1 = src.tiles[i][j].data.thread_elements()[0];
|
||||
//// T v1 = row_values[i][0];
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = s1 * v1;
|
||||
//// T s2 = src.tiles[i][j].data.thread_elements()[1];
|
||||
//// T v2 = row_values[i][1];
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = s2 * v2;
|
||||
//
|
||||
//
|
||||
//// dst.tiles[i][j].data.thread_elements()[0] = op::template op<T>(src.tiles[i][j].data.thread_elements()[0], row_values[i][0]);
|
||||
//// dst.tiles[i][j].data.thread_elements()[1] = op::template op<T>(src.tiles[i][j].data.thread_elements()[1], row_values[i][0]);
|
||||
// T2 val = op::template op<T2>({src.tiles[i][j].data.thread_elements()[0], row_values[i][0]);
|
||||
// dst.tiles[i][j].data.thread_elements()[0] = op::template op<T>(src.tiles[i][j].data.thread_elements()[0], row_values[i][0]);
|
||||
// dst.tiles[i][j].data.thread_elements()[1] = op::template op<T>(src.tiles[i][j].data.thread_elements()[1], row_values[i][0]);
|
||||
// }
|
||||
// }
|
||||
row_map<base_ops::mul, RT, RV>(dst, src, row_values);
|
||||
}
|
||||
/**
|
||||
* @brief Divides each row of a tile by row values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the division on.
|
||||
* @param row_values[in] Column vector containing values to divide each row by.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
div_row(thread RT &dst, thread const RT &src, thread const RV &row_values) {
|
||||
row_map<base_ops::div, RT, RV>(dst, src, row_values);
|
||||
}
|
||||
/**
|
||||
 * @brief Broadcasts a vector into a tile's rows.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Column vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param row_values[in] Column vector containing values to broadcast into rows.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
broadcast_row(thread RT &dst, thread const RV &row_values) {
|
||||
row_map<base_ops::copy2, RT, RV>(dst, dst, row_values);
|
||||
}
|
||||
|
||||
|
||||
// col maps
|
||||
/**
|
||||
* @brief Adds column values to each column of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the addition on.
|
||||
* @param col_values[in] Row vector containing values to add to each column.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
add_col(thread RT &dst, thread const RT &src, thread const RV &col_values) {
|
||||
col_map<base_ops::sum, RT, RV>(dst, src, col_values);
|
||||
}
|
||||
/**
|
||||
* @brief Subtracts column values from each column of a tile.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the subtraction on.
|
||||
* @param col_values[in] Row vector containing values to subtract from each column.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
sub_col(thread RT &dst, thread const RT &src, thread const RV &col_values) {
|
||||
col_map<base_ops::sub, RT, RV>(dst, src, col_values);
|
||||
}
|
||||
/**
|
||||
* @brief Multiplies each column of a tile by column values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the multiplication on.
|
||||
* @param col_values[in] Row vector containing values to multiply each column by.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
mul_col(thread RT &dst, thread const RT &src, thread const RV &col_values) {
|
||||
col_map<base_ops::mul, RT, RV>(dst, src, col_values);
|
||||
}
|
||||
/**
|
||||
* @brief Divides each column of a tile by column values.
|
||||
*
|
||||
* @tparam T Tile type.
|
||||
* @tparam V Row vector type.
|
||||
* @param dst[out] Destination tile where the result is stored.
|
||||
* @param src[in] Source tile to apply the division on.
|
||||
* @param col_values[in] Row vector containing values to divide each column by.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC void div_col(thread RT &dst, thread const RT &src, thread const RV &col_values) {
|
||||
col_map<base_ops::div, RT, RV>(dst, src, col_values);
|
||||
}
|
||||
/**
|
||||
 * @brief Broadcasts a vector into a tile's columns.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param col_values[in] Row vector containing values to broadcast into each column.
|
||||
*/
|
||||
template<typename RT, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
broadcast_col(thread RT &dst, thread const RV &col_values) {
|
||||
col_map<base_ops::copy2, RT, RV>(dst, dst, col_values);
|
||||
}
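// Usage sketch (added for illustration): the wrappers compose directly; row_max and row_sum
// are assumed to be precomputed column vectors matching the tile height.
template<typename RT, typename RV>
static METAL_FUNC void softmax_rows_example(thread RT &dst, thread const RT &src,
                                            thread const RV &row_max, thread const RV &row_sum) {
    sub_row(dst, src, row_max);  // subtract the per-row maximum for numerical stability
    exp(dst, dst);               // elementwise exponential
    div_row(dst, dst, row_sum);  // normalize each row by its precomputed sum
}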
|
||||
|
||||
|
||||
}
|
||||
214
extra/thunder/include/ops/warp/register/tile/mma.metal
Normal file
@@ -0,0 +1,214 @@
|
||||
#pragma once // doneington
|
||||
|
||||
#include <metal_stdlib>
|
||||
#include "../../../../types/types.metal"
|
||||
#include "../../../../common/common.metal"
|
||||
namespace mittens {
|
||||
|
||||
template <typename R, typename T, typename U, typename V,
|
||||
typename l1, typename l2, typename l3, typename l4>
|
||||
METAL_FUNC static void mma_base(thread rt_base<R, l1>& d,
|
||||
thread rt_base<T, l2>& a,
|
||||
thread rt_base<U, l3>& b,
|
||||
thread rt_base<V, l4>& c) {
|
||||
metal::simdgroup_multiply_accumulate(d.data, a.data, b.data, c.data);
|
||||
}
|
||||
|
||||
template <typename R, typename T, typename U,
|
||||
typename l1, typename l2, typename l3>
|
||||
METAL_FUNC static void mm_base(thread rt_base<R, l1>& d,
|
||||
thread rt_base<T, l2>& a,
|
||||
thread rt_base<U, l3>& b) {
|
||||
metal::simdgroup_multiply(d.data, a.data, b.data);
|
||||
}
|
||||
|
||||
namespace meta {
|
||||
template<typename R, typename T, typename U, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>(), void>::type
|
||||
mma_AB_unroll_inner(int k, int n, int m,
|
||||
thread rt<R, N, M, ducks::rt_layout::row>* d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>* a,
|
||||
thread rt<U, K, M, ducks::rt_layout::row>* b) {
|
||||
mma_base(
|
||||
d->tiles[n][m],
|
||||
a->tiles[n][k],
|
||||
b->tiles[k][m],
|
||||
d->tiles[n][m]
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
template<typename R, typename T, typename U, typename V, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>() && ducks::base_types::isT1Type<V>(), void>::type
|
||||
mma_AB_unroll(int n, int m,
|
||||
thread rt<R, N, M, ducks::rt_layout::row>* d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>* a,
|
||||
thread rt<U, K, M, ducks::rt_layout::row>* b,
|
||||
thread rt<V, N, M, ducks::rt_layout::row>* c) {
|
||||
mma_base(
|
||||
d->tiles[n][m],
|
||||
a->tiles[n][0],
|
||||
b->tiles[0][m],
|
||||
c->tiles[n][m]
|
||||
);
|
||||
meta::unroll_i_in_range<1, K/TILE_DIM, 1>::run(meta::mma_AB_unroll_inner<R, T, U, N, K, M>, n, m, d, a, b);
|
||||
}
|
||||
|
||||
template<typename R, typename T, typename U, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>(), void>::type
|
||||
mm_AB_unroll(int n, int m,
|
||||
thread rt<R, N, M, ducks::rt_layout::row>* d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>* a,
|
||||
thread rt<U, K, M, ducks::rt_layout::row>* b) {
|
||||
mm_base(
|
||||
d->tiles[n][m],
|
||||
a->tiles[n][0],
|
||||
b->tiles[0][m]
|
||||
);
|
||||
meta::unroll_i_in_range<1, K/TILE_DIM, 1>::run(meta::mma_AB_unroll_inner<R, T, U, N, K, M>, n, m, d, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename R, typename T, typename U, typename V, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>() && ducks::base_types::isT1Type<V>(), void>::type
|
||||
mma_AB(thread rt<R, N, M, ducks::rt_layout::row>& d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>& a,
|
||||
thread rt<U, K, M, ducks::rt_layout::row>& b,
|
||||
thread rt<V, N, M, ducks::rt_layout::row>& c) {
|
||||
meta::unroll_i_j_in_range<0, N/TILE_DIM, 1, 0, M/TILE_DIM, 1>::run(meta::mma_AB_unroll<R, T, U, V, N, K, M>, &d, &a, &b, &c);
|
||||
}
|
||||
|
||||
template<typename R, typename T, typename U, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>(), void>::type
|
||||
mm_AB(thread rt<R, N, M, ducks::rt_layout::row>& d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>& a,
|
||||
thread rt<U, K, M, ducks::rt_layout::row>& b) {
|
||||
meta::unroll_i_j_in_range<0, N/TILE_DIM, 1, 0, M/TILE_DIM, 1>::run(meta::mm_AB_unroll<R, T, U, N, K, M>, &d, &a, &b);
|
||||
}
|
||||
|
||||
namespace meta {
|
||||
template<typename R, typename T, typename U, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>(), void>::type
|
||||
mma_ABt_unroll_inner(int k, int n, int m,
|
||||
thread rt<R, N, M, ducks::rt_layout::row>* d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>* a,
|
||||
thread rt<U, M, K, ducks::rt_layout::col>* b) {
|
||||
mma_base(
|
||||
d->tiles[n][m],
|
||||
a->tiles[n][k],
|
||||
b->tiles[m][k],
|
||||
d->tiles[n][m]
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
template<typename R, typename T, typename U, typename V, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>() && ducks::base_types::isT1Type<V>(), void>::type
|
||||
mma_ABt_unroll(int n, int m,
|
||||
thread rt<R, N, M, ducks::rt_layout::row>* d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>* a,
|
||||
thread rt<U, M, K, ducks::rt_layout::col>* b,
|
||||
thread rt<V, N, M, ducks::rt_layout::row>* c) {
|
||||
mma_base(
|
||||
d->tiles[n][m],
|
||||
a->tiles[n][0],
|
||||
b->tiles[m][0],
|
||||
c->tiles[n][m]
|
||||
);
|
||||
meta::unroll_i_in_range<1, K/TILE_DIM, 1>::run(meta::mma_ABt_unroll_inner<R, T, U, N, K, M>, n, m, d, a, b);
|
||||
}
|
||||
|
||||
template<typename R, typename T, typename U, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>(), void>::type
|
||||
mm_ABt_unroll(int n, int m,
|
||||
thread rt<R, N, M, ducks::rt_layout::row>* d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>* a,
|
||||
thread rt<U, M, K, ducks::rt_layout::col>* b) {
|
||||
mm_base(
|
||||
d->tiles[n][m],
|
||||
a->tiles[n][0],
|
||||
b->tiles[m][0]
|
||||
);
|
||||
meta::unroll_i_in_range<1, K/TILE_DIM, 1>::run(meta::mma_ABt_unroll_inner<R, T, U, N, K, M>, n, m, d, a, b);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename R, typename T, typename U, typename V, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>() && ducks::base_types::isT1Type<V>(), void>::type
|
||||
mma_ABt(thread rt<R, N, M, ducks::rt_layout::row>& d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>& a,
|
||||
thread rt<U, M, K, ducks::rt_layout::col>& b,
|
||||
thread rt<V, N, M, ducks::rt_layout::row>& c) {
|
||||
meta::unroll_i_j_in_range<0, N/TILE_DIM, 1, 0, M/TILE_DIM, 1>::run(meta::mma_ABt_unroll<R, T, U, V, N, K, M>, &d, &a, &b, &c);
|
||||
}
|
||||
|
||||
template<typename R, typename T, typename U, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>(), void>::type
|
||||
mm_ABt(thread rt<R, N, M, ducks::rt_layout::row>& d,
|
||||
thread rt<T, N, K, ducks::rt_layout::row>& a,
|
||||
thread rt<U, M, K, ducks::rt_layout::col>& b) {
|
||||
meta::unroll_i_j_in_range<0, N/TILE_DIM, 1, 0, M/TILE_DIM, 1>::run(meta::mm_ABt_unroll<R, T, U, N, K, M>, &d, &a, &b);
|
||||
}
|
||||
|
||||
template<typename R, typename T, typename U, typename V, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>() && ducks::base_types::isT1Type<V>(), void>::type
|
||||
mma_AtB(thread rt<R, N, M, ducks::rt_layout::row>& d,
|
||||
thread rt<T, K, N, ducks::rt_layout::col>& a,
|
||||
thread rt<U, K, M, ducks::rt_layout::row>& b,
|
||||
thread rt<V, N, M, ducks::rt_layout::row>& c) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int n = 0; n < N / TILE_DIM; n++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int m = 0; m < M / TILE_DIM; m++) {
|
||||
mma_base(
|
||||
d.tiles[n][m],
|
||||
a.tiles[0][n],
|
||||
b.tiles[0][m],
|
||||
c.tiles[n][m]
|
||||
);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k = 1; k < K / TILE_DIM; k++) {
|
||||
mma_base(
|
||||
d.tiles[n][m],
|
||||
a.tiles[k][n],
|
||||
b.tiles[k][m],
|
||||
d.tiles[n][m]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
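// Layout sketch (added for illustration; the 32/64 sizes are arbitrary): for d = a^T * b + c
// the A operand is stored K x N in column layout while B, C, and D stay row layout.
static METAL_FUNC void mma_AtB_example(thread rt<float, 32, 32, ducks::rt_layout::row> &d,
                                       thread rt<float, 64, 32, ducks::rt_layout::col> &a,
                                       thread rt<float, 64, 32, ducks::rt_layout::row> &b,
                                       thread rt<float, 32, 32, ducks::rt_layout::row> &c) {
    mma_AtB(d, a, b, c);  // N = M = 32, K = 64
}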
|
||||
|
||||
|
||||
template<typename R, typename T, typename U, typename V, int N, int K, int M>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::base_types::isT1Type<R>() && ducks::base_types::isT1Type<T>() && ducks::base_types::isT1Type<U>() && ducks::base_types::isT1Type<V>(), void>::type
|
||||
mma_AtBt(thread rt<R, N, M, ducks::rt_layout::row>& d,
|
||||
thread rt<T, K, N, ducks::rt_layout::col>& a,
|
||||
thread rt<U, M, K, ducks::rt_layout::col>& b,
|
||||
thread rt<V, N, M, ducks::rt_layout::row>& c) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int n = 0; n < N / TILE_DIM; n++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for (int m = 0; m < M / TILE_DIM; m++) {
|
||||
mma_base(
|
||||
d.tiles[n][m],
|
||||
a.tiles[0][n],
|
||||
b.tiles[m][0],
|
||||
c.tiles[n][m]
|
||||
);
|
||||
#pragma clang loop unroll(full)
|
||||
for (int k = 1; k < K / TILE_DIM; k++) {
|
||||
mma_base(
|
||||
d.tiles[n][m],
|
||||
a.tiles[k][n],
|
||||
b.tiles[m][k],
|
||||
d.tiles[n][m]
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
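// --- Editor's note: hedged usage sketch, not part of the original ThunderMittens source. ---
// Illustrates the mma_ABt wrapper defined above: d accumulates a * b^T, where a is a row-layout
// N x K register tile and b is a col-layout (K-major) M x K register tile. The 32/16 element
// sizes are example assumptions (any multiple of TILE_DIM works); only rt and mma_ABt from this
// header are used.
static METAL_FUNC void example_accumulate_ABt(
        thread mittens::rt<float, 32, 16, mittens::ducks::rt_layout::row> &a,
        thread mittens::rt<float, 32, 16, mittens::ducks::rt_layout::col> &b,
        thread mittens::rt<float, 32, 32, mittens::ducks::rt_layout::row> &d) {
    // d = a * b^T + d, expanded over the 8x8 base tiles
    mittens::mma_ABt(d, a, b, d);
}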
636
extra/thunder/include/ops/warp/register/tile/reductions.metal
Normal file
636
extra/thunder/include/ops/warp/register/tile/reductions.metal
Normal file
@@ -0,0 +1,636 @@
/**
 * @file
 * @brief Reduction operations mapping tiles to vectors.
 */

#pragma once // doneington (but register col layouts)

#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace mittens {

namespace meta {

//template<typename op, typename RT>
|
||||
//static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>(), void>::type
|
||||
//row_reduce_unroll_inner(int i, thread const RT *src, thread typename RT::T& accum_thread) {
|
||||
// accum_thread = op::template op<typename RT::T>(accum_thread, src->tiles[i][0].data.thread_elements()[0]);
|
||||
// accum_thread = op::template op<typename RT::T>(accum_thread, src->tiles[i][0].data.thread_elements()[1]);
|
||||
//}
|
||||
//
|
||||
//template<typename op, typename RV, typename RT, bool reset>
|
||||
//static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
//row_reduce_unroll(int i, thread RV *row_accum, thread const RT *src, thread const RV *src_accum, const short leader) {
|
||||
// using T = typename RV::T;
|
||||
// T accum_thread = op::template op<T>(src->tiles[i][0].data.thread_elements()[0], src->tiles[i][0].data.thread_elements()[1]);
|
||||
//
|
||||
// meta::unroll_i_in_range<1, RT::width, 1>::run(meta::row_reduce_unroll_inner<op, RT>, src, accum_thread);
|
||||
// accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 1));
|
||||
// accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 8));
|
||||
//
|
||||
// accum_thread = shfl_sync<T>(accum_thread, leader);
|
||||
//
|
||||
// if(reset) { (*row_accum)[i][0] = accum_thread; }
|
||||
// else { (*row_accum)[i][0] = op::template op<T>((*src_accum)[i][0], accum_thread); }
|
||||
//}
|
||||
|
||||
//template<typename op, typename RT>
|
||||
//static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>(), void>::type
|
||||
//row_reduce_unroll_inner(int i, thread const RT *src, thread typename RT::T2& accum_thread) {
|
||||
// accum_thread = op::template op<typename RT::T2>(accum_thread, {src->tiles[i][0].data.thread_elements()[0], src->tiles[i][0].data.thread_elements()[1]});
|
||||
//}
|
||||
|
||||
/*
|
||||
pragma clang loop unroll(full)
|
||||
for(int i = 0; i < src.height; i++) {
|
||||
T accum_thread = op::template op<T>(src.tiles[i][0].data.thread_elements()[0], src.tiles[i][0].data.thread_elements()[1]);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 1; j < src.width; j++) {
|
||||
accum_thread = op::template op<T>(accum_thread, src.tiles[i][j].data.thread_elements()[0]);
|
||||
accum_thread = op::template op<T>(accum_thread, src.tiles[i][j].data.thread_elements()[1]);
|
||||
}
|
||||
accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 1));
|
||||
accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 8));
|
||||
|
||||
accum_thread = shfl_sync<T>(accum_thread, leader);
|
||||
|
||||
if(reset) { row_accum[i][0] = accum_thread; }
|
||||
else { row_accum[i][0] = op::template op<T>(src_accum[i][0], accum_thread); }
|
||||
}
|
||||
*/
|
||||
|
||||
template<typename op, typename RV, typename RT, bool reset>
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
row_reduce_unroll(int i, thread RV *row_accum, thread const RT *src, thread const RV *src_accum, const short leader) {
    using T = typename RV::T;
    using T2 = typename RV::T2;
    // reduce this thread's pair of elements across every base tile in row i
    T accum_thread = op::template op<T>(src->tiles[i][0].data.thread_elements()[0], src->tiles[i][0].data.thread_elements()[1]);
    for(int j = 1; j < src->width; j++) {
        accum_thread = op::template op<T>(accum_thread, src->tiles[i][j].data.thread_elements()[0]);
        accum_thread = op::template op<T>(accum_thread, src->tiles[i][j].data.thread_elements()[1]);
    }

    // combine partials across the lanes that share this row, then broadcast from the row leader
    T shfl_val = shfl_down_sync<T>(accum_thread, 1);
    accum_thread = op::template op<T>(accum_thread, shfl_val);
    shfl_val = shfl_down_sync<T>(accum_thread, 8);
    accum_thread = op::template op<T>(accum_thread, shfl_val);

    accum_thread = shfl_sync<T>(accum_thread, leader);

    if(reset) {
        (*row_accum)[i][0] = accum_thread;
    }
    else {
        (*row_accum)[i][0] = op::template op<T>((*src_accum)[i][0], accum_thread);
    }
}

}
|
||||
/**
|
||||
* @brief Perform a row-wise reduction on a matrix in row-major layout.
|
||||
*
|
||||
* This function template performs a parallel reduction across the rows of a matrix using a specified operation.
|
||||
* It leverages warp shuffle functions for efficient intra-warp communication.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type with row layout.
|
||||
* @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when reset is false.
|
||||
*/
|
||||
template<typename op, typename RV, typename RT, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_reduce(thread RV &row_accum, thread const RT &src, thread const RV &src_accum, const short laneid) {
|
||||
static_assert(ducks::is_ortho_layout<typename RV::layout>(), "rv must be ortho for row RT");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rv and rt must be the same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::height, "rv and rt dims don't match"); // compatible size
|
||||
using T = typename RV::T;
|
||||
using T2 = typename RV::T2;
|
||||
const short leader = (laneid / 16) * 16 + ((laneid / 2) % 4) * 2;
|
||||
|
||||
// constexpr const uint32_t COL_0 = 0x00550055;
|
||||
// constexpr const uint32_t COL_1 = 0x00AA00AA;
|
||||
// constexpr const uint32_t COL_2 = 0x55005500;
|
||||
// constexpr const uint32_t COL_3 = 0xAA00AA00;
|
||||
//
|
||||
// constexpr const uint32_t COL_0_2 = COL_0 | COL_2;
|
||||
// constexpr const uint32_t COL_0_1 = COL_0 | COL_1;
|
||||
// constexpr const uint32_t COL_2_3 = COL_2 | COL_3;
|
||||
// const ushort src_lane1 = laneid + ((COL_0_2 >> laneid) & 1) * 1 + ((COL_1 >> laneid) & 1) * 7 - ((COL_3 >> laneid) & 1) * 9;
|
||||
// const ushort src_lane2 = laneid + ((COL_0_1 >> laneid) & 1) * 8 - ((COL_2_3 >> laneid) & 1) * 8;
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < src.height; i++) {
|
||||
// T accum_thread = op::template op<T>(src.tiles[i][0].data.thread_elements()[0], src.tiles[i][0].data.thread_elements()[1]);
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 1; j < src.width; j++) {
|
||||
// accum_thread = op::template op<T>(accum_thread, src.tiles[i][j].data.thread_elements()[0]);
|
||||
// accum_thread = op::template op<T>(accum_thread, src.tiles[i][j].data.thread_elements()[1]);
|
||||
// }
|
||||
// accum_thread = op::template op<T>(accum_thread, shfl_sync<T>(accum_thread, src_lane1));
|
||||
// accum_thread = op::template op<T>(accum_thread, shfl_sync<T>(accum_thread, src_lane2));
|
||||
//
|
||||
//
|
||||
// if(reset) { row_accum[i][0] = accum_thread; }
|
||||
// else { row_accum[i][0] = op::template op<T>(src_accum[i][0], accum_thread); }
|
||||
// }
|
||||
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < src.height; i++) {
|
||||
// T accum_thread = op::template op<T>(src.tiles[i][0].data.thread_elements()[0], src.tiles[i][0].data.thread_elements()[1]);
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 1; j < src.width; j++) {
|
||||
// accum_thread = op::template op<T>(accum_thread, src.tiles[i][j].data.thread_elements()[0]);
|
||||
// accum_thread = op::template op<T>(accum_thread, src.tiles[i][j].data.thread_elements()[1]);
|
||||
// }
|
||||
// accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 1));
|
||||
// accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 8));
|
||||
//
|
||||
// accum_thread = shfl_sync<T>(accum_thread, leader);
|
||||
//
|
||||
// if(reset) { row_accum[i][0] = accum_thread; }
|
||||
// else { row_accum[i][0] = op::template op<T>(src_accum[i][0], accum_thread); }
|
||||
// }
|
||||
|
||||
meta::unroll_i_in_range<0, RT::height, 1>::run(meta::row_reduce_unroll<op, RV, RT, reset>, &row_accum, &src, &src_accum, leader);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Perform a row-wise reduction on a matrix stored as a column-layout register tile.
|
||||
*
|
||||
* This function template performs a parallel reduction across the rows of a matrix using a specified operation.
|
||||
* It leverages warp shuffle functions for efficient intra-warp communication.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type with row layout.
|
||||
* @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when reset is false.
|
||||
*/
|
||||
template<typename op, typename RV, typename RT, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_reduce(thread RV &row_accum, thread const RT &src, thread const RV &src_accum, const short laneid) {
|
||||
static_assert(ducks::is_align_layout<typename RV::layout>(), "rv must be align for row RT");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rv and rt must be the same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::height, "rv and rt dims don't match"); // compatible size
|
||||
|
||||
using T = typename RV::T;
|
||||
using T2 = typename RV::T2;
|
||||
|
||||
const int leader = (laneid % 2) + ((laneid / 8) % 2) * 8;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < src.height; i++) {
|
||||
T2 accum_thread = {src.tiles[i][0].data.thread_elements()[0], src.tiles[i][0].data.thread_elements()[1]};
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 1; j < src.width; j++) {
|
||||
accum_thread = op::template op<T2>(accum_thread, {src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]});
|
||||
}
|
||||
// Combine partials across the lanes that share these rows, then broadcast from the leader lane.
|
||||
|
||||
accum_thread = op::template op<T2>(accum_thread, shfl_down_sync<T2>(accum_thread, 2));
|
||||
accum_thread = op::template op<T2>(accum_thread, shfl_down_sync<T2>(accum_thread, 4));
|
||||
accum_thread = op::template op<T2>(accum_thread, shfl_down_sync<T2>(accum_thread, 16));
|
||||
|
||||
accum_thread = shfl_sync<T2>(accum_thread, leader);
|
||||
|
||||
if(reset) {
|
||||
row_accum[i][0] = accum_thread[0];
|
||||
row_accum[i][1] = accum_thread[1];
|
||||
}
|
||||
else {
|
||||
row_accum[i][0] = op::template op<T>(row_accum[i][0], accum_thread[0]);
|
||||
row_accum[i][1] = op::template op<T>(row_accum[i][1], accum_thread[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Perform a column-wise reduction on a matrix in row-major layout.
|
||||
*
|
||||
* This function template performs a parallel reduction across the columns of a matrix using a specified operation.
|
||||
* It leverages warp shuffle functions for efficient intra-warp communication and is optimized for row-major matrices.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The vector type for the column accumulator.
|
||||
* @tparam T The matrix type with row layout.
|
||||
* @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when reset is false.
|
||||
*/
|
||||
template<typename op, typename RV, typename RT, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_row_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_reduce(thread RV &col_accum, thread const RT &src, thread const RV &src_accum, const ushort laneid) {
|
||||
static_assert(ducks::is_align_layout<typename RV::layout>(), "rv must be align layout");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rt and rv must be same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::width, "rv and rt dims don't match"); // compatible size
|
||||
|
||||
using dtype = typename RV::dtype;
|
||||
using T2 = typename base_types::packing<dtype>::packed_type;
|
||||
|
||||
const int leader = (laneid % 2) + ((laneid / 8) % 2) * 8;
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < src.width; j++) {
|
||||
// dtype accum_left_cols = src.tiles[0][j].data.thread_elements()[0];
|
||||
// dtype accum_right_cols = src.tiles[0][j].data.thread_elements()[1];
|
||||
T2 accum_cols = {src.tiles[0][j].data.thread_elements()[0], src.tiles[0][j].data.thread_elements()[1]};
|
||||
// dtype accum_right_cols = src.tiles[0][j].data.thread_elements()[1];
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 1; i < src.height; i++) {
|
||||
// accum_left_cols = op::template op<dtype>(accum_left_cols , src.tiles[i][j].data.thread_elements()[0]);
|
||||
// accum_right_cols = op::template op<dtype>(accum_right_cols, src.tiles[i][j].data.thread_elements()[1]);
|
||||
accum_cols = op::template op<T2>(accum_cols, {src.tiles[i][j].data.thread_elements()[0], src.tiles[i][j].data.thread_elements()[1]});
|
||||
}
|
||||
|
||||
// accum_left_cols = op::template op<dtype>(accum_left_cols, shfl_down_sync<dtype>(accum_left_cols, 2));
|
||||
// accum_left_cols = op::template op<dtype>(accum_left_cols, shfl_down_sync<dtype>(accum_left_cols, 4));
|
||||
// accum_left_cols = op::template op<dtype>(accum_left_cols, shfl_down_sync<dtype>(accum_left_cols, 16));
|
||||
|
||||
// accum_right_cols = op::template op<dtype>(accum_right_cols, shfl_down_sync<dtype>(accum_right_cols, 2));
|
||||
// accum_right_cols = op::template op<dtype>(accum_right_cols, shfl_down_sync<dtype>(accum_right_cols, 4));
|
||||
// accum_right_cols = op::template op<dtype>(accum_right_cols, shfl_down_sync<dtype>(accum_right_cols, 16));
|
||||
accum_cols = op::template op<T2>(accum_cols, shfl_down_sync<T2>(accum_cols, 2));
|
||||
accum_cols = op::template op<T2>(accum_cols, shfl_down_sync<T2>(accum_cols, 4));
|
||||
accum_cols = op::template op<T2>(accum_cols, shfl_down_sync<T2>(accum_cols, 16));
|
||||
|
||||
// accum_left_cols = shfl_sync<dtype>(accum_left_cols, leader);
|
||||
// accum_right_cols = shfl_sync<dtype>(accum_right_cols, leader);
|
||||
accum_cols = shfl_sync<T2>(accum_cols, leader);
|
||||
|
||||
|
||||
if(reset) {
|
||||
// col_accum[j][0] = accum_left_cols;
|
||||
// col_accum[j][1] = accum_right_cols;
|
||||
col_accum[j][0] = accum_cols[0];
|
||||
col_accum[j][1] = accum_cols[1];
|
||||
}
|
||||
else {
|
||||
// col_accum[j][0] = op::template op<dtype>(src_accum[j][0], accum_left_cols);
|
||||
// col_accum[j][1] = op::template op<dtype>(src_accum[j][1], accum_right_cols);
|
||||
col_accum[j][0] = op::template op<dtype>(src_accum[j][0], accum_cols[0]);
|
||||
col_accum[j][1] = op::template op<dtype>(src_accum[j][1], accum_cols[1]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Perform a column-wise reduction on a matrix stored as a column-layout register tile.
|
||||
*
|
||||
* This function template performs a parallel reduction across the columns of a matrix using a specified operation.
|
||||
* It leverages warp shuffle functions for efficient intra-warp communication and is optimized for row-major matrices.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The vector type for the column accumulator.
|
||||
* @tparam T The matrix type with row layout.
|
||||
* @tparam reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when reset is false.
|
||||
*/
|
||||
template<typename op, typename RV, typename RT, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_col_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_reduce(thread RV &col_accum, thread const RT &src, thread const RV &src_accum, const ushort laneid) {
|
||||
static_assert(ducks::is_ortho_layout<typename RV::layout>(), "rv must be ortho layout");
|
||||
static_assert(metal::is_same_v<typename RV::dtype, typename RT::dtype>, "rt and rv must be same type"); // compatible type
|
||||
static_assert(RV::outer_dim == RT::width, "rv and rt dims don't match"); // compatible size
|
||||
|
||||
using T = typename RV::T;
|
||||
using T2 = typename base_types::packing<T>::packed_type;
|
||||
|
||||
const int leader = (laneid / 16) * 16 + ((laneid / 2) % 4) * 2; // leader lane within each reduction group
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < src.width; i++) {
|
||||
T accum_thread = op::template op<T>(src.tiles[0][i].data.thread_elements()[0], src.tiles[0][i].data.thread_elements()[1]);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 1; j < src.height; j++) {
|
||||
accum_thread = op::template op<T>(accum_thread, src.tiles[j][i].data.thread_elements()[0]);
|
||||
accum_thread = op::template op<T>(accum_thread, src.tiles[j][i].data.thread_elements()[1]);
|
||||
}
|
||||
// Combine partials across the lanes that share these columns, then broadcast from the leader lane.
|
||||
|
||||
accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 1));
|
||||
accum_thread = op::template op<T>(accum_thread, shfl_down_sync<T>(accum_thread, 8));
|
||||
|
||||
accum_thread = shfl_sync<T>(accum_thread, leader);
|
||||
|
||||
if(reset) {
|
||||
col_accum[i][0] = accum_thread;
|
||||
}
|
||||
else {
|
||||
col_accum[i][0] = op::template op<T>(col_accum[i][0], accum_thread);
|
||||
}
|
||||
}
|
||||
}
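// --- Editor's note: hedged usage sketch, not part of the original ThunderMittens source. ---
// row_reduce/col_reduce can also be called directly with an explicit op and reset flag; the
// wrappers below are thin shims over calls like this one, which folds a tile's per-row maxima
// into a running accumulator (reset = false keeps the previous values in the reduction).
// RV and RT are assumed to satisfy the layout constraints static_asserted above.
template<typename RV, typename RT>
static METAL_FUNC void example_running_row_max(thread RV &running_max, thread const RT &src,
                                               const short laneid) {
    row_reduce<base_ops::max, RV, RT, false>(running_max, src, running_max, laneid);
}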
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
// two-operand row reductions. (Accumulate and REPLACE.)
|
||||
/**
|
||||
* @brief Store the maximum of each row of the src register tile in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_max(thread RV &row_accum, thread const RT &src, const int laneid) {
|
||||
row_reduce<base_ops::max, RV, RT, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each row of the src register tile in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_min(thread RV &row_accum, thread const RT &src, const int laneid) {
|
||||
row_reduce<base_ops::min, RV, RT, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each row of the src register tile in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_sum(thread RV &row_accum, thread const RT &src, const int laneid) {
|
||||
row_reduce<base_ops::sum, RV, RT, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each row of the src register tile in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_prod(thread RV &row_accum, thread const RT &src, const int laneid) {
|
||||
row_reduce<base_ops::mul, RV, RT, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
|
||||
// three-operand row reductions. (Accumulate ONTO.)
|
||||
/**
|
||||
* @brief Store the maximum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_max(thread RV &row_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
// using T = typename RV::T;
|
||||
// using T2 = typename RV::T2;
|
||||
// const short leader = (laneid / 16) * 16 + ((laneid / 2) % 4) * 2;
|
||||
//
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < src.height; i++) {
|
||||
// T accum_thread = metal::max(src.tiles[i][0].data.thread_elements()[0], src.tiles[i][0].data.thread_elements()[1]);
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 1; j < src.width; j++) {
|
||||
// accum_thread = metal::max(accum_thread, src.tiles[i][j].data.thread_elements()[0]);
|
||||
// accum_thread = metal::max(accum_thread, src.tiles[i][j].data.thread_elements()[1]);
|
||||
// }
|
||||
// accum_thread = metal::max(accum_thread, shfl_down_sync<T>(accum_thread, 1));
|
||||
// accum_thread = metal::max(accum_thread, shfl_down_sync<T>(accum_thread, 8));
|
||||
// accum_thread = shfl_sync<T>(accum_thread, leader);
|
||||
// if(false) { row_accum[i][0] = accum_thread; }
|
||||
// else { row_accum[i][0] = metal::max(src_accum[i][0], accum_thread); }
|
||||
// }
|
||||
|
||||
row_reduce<base_ops::max, RV, RT, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_min(thread RV &row_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
row_reduce<base_ops::min, RV, RT, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_sum(thread RV &row_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
// using T = typename RV::T;
|
||||
// using T2 = typename RV::T2;
|
||||
// const short leader = (laneid / 16) * 16 + ((laneid / 2) % 4) * 2;
|
||||
//
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < src.height; i++) {
|
||||
// T accum_thread = (src.tiles[i][0].data.thread_elements()[0] + src.tiles[i][0].data.thread_elements()[1]);
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 1; j < src.width; j++) {
|
||||
// accum_thread = (accum_thread + src.tiles[i][j].data.thread_elements()[0]);
|
||||
// accum_thread = (accum_thread + src.tiles[i][j].data.thread_elements()[1]);
|
||||
// }
|
||||
// T shfl_val = shfl_down_sync<T>(accum_thread, 1);
|
||||
// accum_thread = (accum_thread + shfl_val);
|
||||
// shfl_val = shfl_down_sync<T>(accum_thread, 8);
|
||||
// accum_thread = (accum_thread + shfl_val);
|
||||
// accum_thread = shfl_sync<T>(accum_thread, leader);
|
||||
//// accum_thread = metal::simd_sum(accum_thread);
|
||||
// if(false) {
|
||||
// row_accum[i][0] = accum_thread;
|
||||
// }
|
||||
// else {
|
||||
// T src_val = src_accum[i][0];
|
||||
// row_accum[i][0] = (src_val + accum_thread);
|
||||
// }
|
||||
// }
|
||||
row_reduce<base_ops::sum, RV, RT, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
//template<typename RV, typename RT>
|
||||
//static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
//row_sum(thread RV &row_accum, thread const RT &src, thread const RV &src_accum, const int laneid, const int warpId, threadgroup typename RT::T* smem) {
|
||||
// using T = typename RV::T;
|
||||
// using T2 = typename RV::T2;
|
||||
// using T4 = typename base_types::packing<T>::packed_four;
|
||||
// const short leader = (laneid / 16) * 16 + ((laneid / 2) % 4) * 2;
|
||||
// const short qid = laneid / 4;
|
||||
// const int offsetX = (qid & 4) + (laneid / 2) % 4;
|
||||
// const int offsetY = (qid & 2) + laneid % 2;
|
||||
// const int smem_idx_row = 32 * warpId + offsetY * 4;
|
||||
// const int smem_idx = smem_idx_row + offsetX;
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int i = 0; i < src.height; i++) {
|
||||
// T accum_thread = src.tiles[i][0].data.thread_elements()[0] + src.tiles[i][0].data.thread_elements()[1];
|
||||
// #pragma clang loop unroll(full)
|
||||
// for(int j = 1; j < src.width; j++) {
|
||||
// accum_thread = accum_thread + src.tiles[i][0].data.thread_elements()[0];
|
||||
// accum_thread = accum_thread + src.tiles[i][0].data.thread_elements()[1];
|
||||
// }
|
||||
// {
|
||||
// metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
// smem[smem_idx] = accum_thread;
|
||||
// metal::simdgroup_barrier(metal::mem_flags::mem_threadgroup);
|
||||
// T4 vals = *(threadgroup T4*)(&smem[smem_idx_row]);
|
||||
// accum_thread = vals[0] + vals[1] + vals[2] + vals[3];
|
||||
// }
|
||||
// row_accum[i][0] = src_accum[i][0] + accum_thread;
|
||||
//
|
||||
// }
|
||||
//}
|
||||
|
||||
|
||||
/**
|
||||
* @brief Store the product of each row of the src register tile, as well as the src_accum column vector, in the row_accum column vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
row_prod(thread RV &row_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
row_reduce<base_ops::mul, RV, RT, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
// two-operand col reductions. (Accumulate and REPLACE.)
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_max(thread RV &col_accum, thread const RT &src, const int laneid) {
|
||||
col_reduce<base_ops::max, RV, RT, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_min(thread RV &col_accum, thread const RT &src, const int laneid) {
|
||||
col_reduce<base_ops::min, RV, RT, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_sum(thread RV &col_accum, thread const RT &src, const int laneid) {
|
||||
col_reduce<base_ops::sum, RV, RT, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the product of each column of the src register tile in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_prod(thread RV &col_accum, thread const RT &src, const int laneid) {
|
||||
col_reduce<base_ops::mul, RV, RT, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
// three-operand col reductions. (Accumulate ONTO.)
|
||||
/**
|
||||
* @brief Store the maximum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_max(thread RV &col_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
col_reduce<base_ops::max, RV, RT, false>(col_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_min(thread RV &col_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
col_reduce<base_ops::min, RV, RT, false>(col_accum, src, src_accum, laneid);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the sum of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_sum(thread RV &col_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
col_reduce<base_ops::sum, RV, RT, false>(col_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each column of the src register tile, as well as the src_accum row vector, in the col_accum row vector.
|
||||
*
|
||||
* @tparam V The vector type for the row accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename RV, typename RT>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_tile<RT>() && ducks::is_register_vector<RV>(), void>::type
|
||||
col_prod(thread RV &col_accum, thread const RT &src, thread const RV &src_accum, const int laneid) {
|
||||
col_reduce<base_ops::mul, RV, RT, false>(col_accum, src, src_accum, laneid);
|
||||
}
|
||||
|
||||
|
||||
}
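// --- Editor's note: hedged usage sketch, not part of the original ThunderMittens source. ---
// Shows the row-reduction wrappers defined above on a matching register-tile / register-vector
// pair: row_max replaces the accumulator with the per-row maximum, while the three-operand
// row_sum accumulates onto an existing vector. RT and RV are assumed to satisfy the same
// constraints the wrappers static_assert (same dtype, RV::outer_dim == RT::height).
template<typename RV, typename RT>
static METAL_FUNC void example_row_stats(thread RV &row_max_vec, thread RV &row_sum_vec,
                                         thread const RT &src, const int laneid) {
    mittens::row_max(row_max_vec, src, laneid);              // row_max_vec = max over each row of src
    mittens::row_sum(row_sum_vec, src, row_sum_vec, laneid); // row_sum_vec += sum over each row of src
}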
11
extra/thunder/include/ops/warp/register/tile/tile.metal
Normal file
11
extra/thunder/include/ops/warp/register/tile/tile.metal
Normal file
@@ -0,0 +1,11 @@
/**
 * @file
 * @brief An aggregate header for warp operations on register tiles.
 */

#pragma once

#include "conversions.metal"
#include "maps.metal"
#include "mma.metal"
#include "reductions.metal"
162
extra/thunder/include/ops/warp/register/vec/conversions.metal
Normal file
162
extra/thunder/include/ops/warp/register/vec/conversions.metal
Normal file
@@ -0,0 +1,162 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Conversions on vectors stored in registers.
|
||||
*/
|
||||
|
||||
#pragma once // done
|
||||
|
||||
#include "../../../../common/common.metal"
|
||||
#include "../../../../types/types.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
namespace detail {
|
||||
static METAL_FUNC int colstart_from_laneid(const int laneid) { // rowvec
|
||||
return (laneid % 2) * 2 + ((laneid / 8) % 2) * 4;
|
||||
}
|
||||
// 0,1,2,3,4,5,6,7 -> 0,2,1,3,8,10,9,11
|
||||
static METAL_FUNC int leader_from_col(const int col) { // rowvec
|
||||
return (col / 4) * 8 + (col / 2) % 2 + (col % 2) * 2;
|
||||
}
|
||||
// 0,2,1,3,8,10,9,11 -> 0,1,0,1,0,1,0,1
|
||||
static METAL_FUNC int idx_from_colleader(const int laneid) { // rowvec
|
||||
return ((laneid % 8) / 2) % 2; // % 2 to protect against non-leaders
|
||||
}
|
||||
|
||||
static METAL_FUNC int row_from_laneid(const int laneid) { // rowvec
|
||||
return (laneid / 2) % 4 + (laneid / 16) * 4;
|
||||
}
|
||||
// 0,1,2,3,4,5,6,7 -> 0, 2, 4, 6, 16, 18, 20, 22
|
||||
static METAL_FUNC int leader_from_row(const int row) { // rowvec
|
||||
return (row/4) * 16 + (row % 4) * 2;
|
||||
}
|
||||
|
||||
|
||||
/* ----- ducks::is_align_register_vector<RV1>() && ducks::is_naive_register_vector<RV2>() -----*/
|
||||
static METAL_FUNC int col_leader_from_naive_laneid(const int laneid) { // rowvec
|
||||
int tile_col = laneid % 8;
|
||||
int base_leader = (tile_col / 4) * 8 + (tile_col / 2) % 2 + (tile_col % 2) * 16;
|
||||
return base_leader + 2 * (laneid / 8);
|
||||
}
|
||||
|
||||
static METAL_FUNC int local_send_idx_from_col(const int laneid) {
|
||||
return laneid >= 16;
|
||||
}
|
||||
|
||||
static METAL_FUNC int src_basetile_from_laneid(const int laneid) { // rowvec
|
||||
return (laneid/ 2) % 4;
|
||||
}
|
||||
|
||||
/* ----- ducks::is_ortho_register_vector<RV1>() && ducks::is_naive_register_vector<RV2>() -----*/
|
||||
static METAL_FUNC int row_leader_from_naive_laneid(const int laneid) { // rowvec
|
||||
int row = laneid % 8;
|
||||
int base_row = (row/4) * 16 + (row % 4) * 2;
|
||||
return base_row + (laneid / 8) % 2 + (laneid >= 16) * 8;
|
||||
}
|
||||
|
||||
static METAL_FUNC int ortho_send_tile_from_laneid(const int laneid) { // rowvec
|
||||
// uint32_t MASK_1 = 0b00000000010101010000000001010101;
|
||||
uint32_t MASK_2 = 0b00000000101010100000000010101010;
|
||||
uint32_t MASK_3 = 0b01010101000000000101010100000000;
|
||||
uint32_t MASK_4 = 0b10101010000000001010101000000000;
|
||||
return ((MASK_2 >> laneid) & 1) + ((MASK_3 >> laneid) & 1) * 2 + ((MASK_4 >> laneid) & 1) * 3;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
/**
|
||||
* @brief Copies data from one register vector to another.
|
||||
*
|
||||
* @tparam RV1 The type of the destination register vector.
|
||||
* @tparam RV2 The type of the source register vector.
|
||||
* @param dst[out] The destination register vector.
|
||||
* @param src[in] The source register vector to copy from.
|
||||
*/
|
||||
template<typename RV2, typename RV1>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV1>() && ducks::is_register_vector<RV2>(), void>::type
|
||||
copy(thread RV2 &dst, thread const RV1 &src, const ushort laneid) {
|
||||
static_assert(RV1::length == RV2::length, "Outer dimensions of the register vectors must be the same.");
|
||||
using D1 = typename RV1::dtype;
|
||||
using D2 = typename RV2::dtype;
|
||||
if (metal::is_same_v<typename RV1::layout, typename RV2::layout>) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < RV1::outer_dim; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < RV1::inner_dim; j++) {
|
||||
dst[i][j] = base_types::convertor<D1, D2>::convert(src[i][j]);
|
||||
}
|
||||
}
|
||||
} else if (ducks::is_align_register_vector<RV1>() && ducks::is_ortho_register_vector<RV2>()) { // align vector -> ortho vector
|
||||
const int row = detail::row_from_laneid(laneid);
|
||||
const int laneid_src = detail::leader_from_col(row);
|
||||
const int send_idx = detail::idx_from_colleader(laneid);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < RV1::outer_dim; i++) {
|
||||
dst[i][0] = base_types::convertor<D1,D2>::convert(shfl_sync<D2>(src[i][send_idx], laneid_src));
|
||||
// dst[i][0] = 1;
|
||||
}
|
||||
} else if (ducks::is_ortho_register_vector<RV1>() && ducks::is_align_register_vector<RV2>()) { // ortho vector -> align vector
|
||||
const int col1 = detail::colstart_from_laneid(laneid);
|
||||
const int col2 = col1 + 1;
|
||||
const int laneid_src1 = detail::leader_from_row(col1);
|
||||
const int laneid_src2 = detail::leader_from_row(col2);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < RV1::outer_dim; i++) {
|
||||
dst[i][0] = base_types::convertor<D2,D1>::convert(shfl_sync<D1>(src[i][0], laneid_src1));
|
||||
dst[i][1] = base_types::convertor<D2,D1>::convert(shfl_sync<D1>(src[i][0], laneid_src2));
|
||||
}
|
||||
} else if (ducks::is_align_register_vector<RV1>() && ducks::is_naive_register_vector<RV2>()) {
|
||||
const int src_laneid = detail::col_leader_from_naive_laneid(laneid);
|
||||
int align_send_tile = detail::src_basetile_from_laneid(laneid);
|
||||
int align_local_send_idx = detail::local_send_idx_from_col(laneid);
|
||||
int naive_tile_idx = 0;
|
||||
for (int l_idx = 0;
|
||||
l_idx < RV2::length;
|
||||
l_idx += 32, naive_tile_idx++, align_send_tile += 4)
|
||||
{
|
||||
D1 send_val = 0;
|
||||
if (align_send_tile < RV1::outer_dim) send_val = src[align_send_tile][align_local_send_idx];
|
||||
D1 receive_val = shfl_sync<D1>(send_val, src_laneid);
if (l_idx + laneid < RV2::length) dst[l_idx / 32][0] = base_types::convertor<D2,D1>::convert(receive_val);
|
||||
}
|
||||
} else if (ducks::is_naive_register_vector<RV1>() && ducks::is_align_register_vector<RV2>()) {
|
||||
int col1 = detail::colstart_from_laneid(laneid);
|
||||
int col2 = col1 + 1;
|
||||
for (int i = 0; i < RV2::outer_dim; i++) {
|
||||
int src1 = (i%4) * 8 + col1;
|
||||
int src2 = (i%4) * 8 + col2;
|
||||
D1 send_val = src[i / 4][0];
|
||||
D1 receive_val1 = shfl_sync<D1>(send_val, src1);
D1 receive_val2 = shfl_sync<D1>(send_val, src2);
dst[i][0] = receive_val1;
dst[i][1] = receive_val2;
|
||||
}
|
||||
} else if (ducks::is_ortho_register_vector<RV1>() && ducks::is_naive_register_vector<RV2>()) {
|
||||
const int src_laneid = detail::row_leader_from_naive_laneid(laneid);
|
||||
int ortho_send_tile = detail::ortho_send_tile_from_laneid(laneid);
|
||||
int naive_tile_idx = 0;
|
||||
for (int l_idx = 0; l_idx < RV2::length;
|
||||
l_idx += 32, naive_tile_idx++, ortho_send_tile += 4)
|
||||
{
|
||||
D1 send_val = 10;
|
||||
if (ortho_send_tile < RV1::outer_dim) send_val = src[ortho_send_tile][0];
|
||||
D1 receive_val = shfl_sync<D1>(send_val, src_laneid);
if (l_idx + laneid < RV2::length) dst[l_idx / 32][0] = base_types::convertor<D2,D1>::convert(receive_val);
|
||||
}
|
||||
} else if (ducks::is_naive_register_vector<RV1>() && ducks::is_ortho_register_vector<RV2>()) {
|
||||
int row = detail::row_from_laneid(laneid);
|
||||
for (int i = 0; i < RV2::outer_dim; i++) {
|
||||
int src_laneid = (i%4) * 8 + row;
|
||||
D1 send_val = src[i / 4][0];
|
||||
D1 receive_val = shfl_sync<D1>(send_val, src_laneid);
dst[i][0] = receive_val;
|
||||
}
|
||||
}
|
||||
else {
|
||||
// static_assert(RV1::inner_dim == RV2::inner_dim, "Something has gone deeply wrong with how register vectors were instantiated.");
|
||||
}
|
||||
}
|
||||
|
||||
}
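// --- Editor's note: hedged usage sketch, not part of the original ThunderMittens source. ---
// The copy() above also converts between register-vector layouts by shuffling values between
// lanes, e.g. re-laying an align-layout row accumulator into the ortho layout expected by a
// row reduction. RVOrtho/RVAlign are assumed to be register-vector types of equal length and dtype.
template<typename RVOrtho, typename RVAlign>
static METAL_FUNC void example_vec_relayout(thread RVOrtho &dst, thread const RVAlign &src,
                                            const ushort laneid) {
    // layout-aware copy: picks the align -> ortho shuffle path when the layouts differ
    mittens::copy(dst, src, laneid);
}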
288
extra/thunder/include/ops/warp/register/vec/maps.metal
Normal file
288
extra/thunder/include/ops/warp/register/vec/maps.metal
Normal file
@@ -0,0 +1,288 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Maps on vectors stored in registers.
|
||||
*/
|
||||
|
||||
#pragma once // doneington
|
||||
|
||||
#include "../../../../common/common.metal"
|
||||
#include "../../../../types/types.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
/* ---------- Vector Maps ---------- */
|
||||
|
||||
/**
|
||||
* @brief Perform a unary operation on a vector.
|
||||
*
|
||||
* @tparam op The unary operation to perform.
|
||||
* @tparam T The type of the vector.
|
||||
* @param dst[out] The destination vector where the result is stored.
|
||||
* @param src[in] The source vector to perform the operation on.
|
||||
*/
|
||||
template<typename op, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
unary_op(thread RV &dst, thread const RV &src) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.outer_dim; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.inner_dim; j++) {
|
||||
dst[i][j] = op::template op<typename RV::dtype>(src[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Perform a binary operation on two vectors.
|
||||
*
|
||||
* @tparam op The binary operation to perform.
|
||||
* @tparam T The type of the vectors.
|
||||
* @param dst[out] The destination vector where the result is stored.
|
||||
* @param lhs[in] The left-hand side vector for the operation.
|
||||
* @param rhs[in] The right-hand side vector for the operation.
|
||||
*/
|
||||
template<typename op, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
bin_op(thread RV &dst, thread const RV &lhs, thread const RV &rhs) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.outer_dim; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.inner_dim; j++) {
|
||||
dst[i][j] = op::template op<typename RV::dtype>(lhs[i][j], rhs[i][j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Perform a binary operation on a vector and a scalar.
|
||||
*
|
||||
* @tparam op The binary operation to perform.
|
||||
* @tparam T The type of the vector.
|
||||
* @param dst[out] The destination vector where the result is stored.
|
||||
* @param src[in] The source vector for the operation.
|
||||
* @param param[in] The scalar parameter for the operation.
|
||||
*/
|
||||
template<typename op, typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
bin_op(thread RV &dst, thread const RV &src, thread const typename RV::dtype ¶m) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = 0; i < dst.outer_dim; i++) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int j = 0; j < dst.inner_dim; j++) {
|
||||
dst[i][j] = op::template op<typename RV::dtype>(src[i][j], param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
// ---- const ops ----
|
||||
|
||||
/**
|
||||
* @brief Sets all elements of a register vector to zero.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector to be set to zero.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
zero(thread RV &dst) {
|
||||
unary_op<base_ops::zero, RV>(dst, dst);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Sets all elements of a register vector to one.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector to be set to one.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
one(thread RV &dst) {
|
||||
unary_op<base_ops::one, RV>(dst, dst);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a register vector to positive infinity.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector to be set to positive infinity.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
pos_infty(thread RV &dst) {
|
||||
unary_op<base_ops::pos_infty, RV>(dst, dst);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a register vector to negative infinity.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector to be set to negative infinity.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
neg_infty(thread RV &dst) {
|
||||
unary_op<base_ops::neg_infty, RV>(dst, dst);
|
||||
}
|
||||
|
||||
// ---- unary ops ----
|
||||
|
||||
/**
|
||||
* @brief Copies the elements from one register vector to another.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @tparam U Type of the source vector.
|
||||
* @param dst[out] Destination vector where the elements will be copied to.
|
||||
* @param src[in] Source vector to copy the elements from.
|
||||
*/
|
||||
template<typename RV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>() && ducks::base_types::isT1Type<U>(), void>::type
|
||||
copy(thread RV &dst, thread const U &src) {
|
||||
bin_op<base_ops::copy2, RV>(dst, dst, src); // the second arg is ignored here.
|
||||
}
|
||||
/**
|
||||
* @brief Applies the exponential function element-wise to a register vector.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the exponential function to.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
exp(thread RV &dst, thread const RV &src) {
|
||||
unary_op<base_ops::exp, RV>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the exponential function element-wise to a register vector, in base 2.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the exponential function to.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
exp2(thread RV &dst, thread const RV &src) {
|
||||
unary_op<base_ops::exp2, RV>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the natural logarithm function element-wise to a register vector.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the exponential function to.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
log(thread RV &dst, thread const RV &src) {
|
||||
unary_op<base_ops::log, RV>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the absolute value function element-wise to a register vector.
|
||||
*
|
||||
* @tparam T Register vector type.
|
||||
* @param dst[out] Destination vector where the absolute values will be stored.
|
||||
* @param src[in] Source vector to apply the absolute value function to.
|
||||
*/
|
||||
template<typename RV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
|
||||
abs(thread RV &dst, thread const RV &src) {
|
||||
unary_op<base_ops::abs, RV>(dst, src);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the rectified linear unit (ReLU) function element-wise to a register vector.
|
||||
*
|
||||
 * @tparam T Register vector type.
 * @param dst[out] Destination vector where the ReLU values will be stored.
 * @param src[in] Source vector to apply the ReLU function to.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
relu(thread RV &dst, thread const RV &src) {
    unary_op<base_ops::relu, RV>(dst, src);
}

// ---- binary ops ----

/**
 * @brief Computes the element-wise maximum of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the maximum values will be stored.
 * @param lhs[in] First vector for the maximum operation.
 * @param rhs[in] Second vector for the maximum operation.
 */
template<typename RV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
max(thread RV &dst, thread const RV &lhs, thread const U &rhs) {
    bin_op<base_ops::max, RV>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise minimum of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the minimum values will be stored.
 * @param lhs[in] First vector for the minimum operation.
 * @param rhs[in] Second vector for the minimum operation.
 */
template<typename RV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
min(thread RV &dst, thread const RV &lhs, thread const U &rhs) {
    bin_op<base_ops::min, RV>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise sum of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the sum values will be stored.
 * @param lhs[in] First vector for the sum operation.
 * @param rhs[in] Second vector for the sum operation.
 */
template<typename RV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
add(thread RV &dst, thread const RV &lhs, thread const U &rhs) {
    bin_op<base_ops::sum, RV>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise difference of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the difference values will be stored.
 * @param lhs[in] First vector for the difference operation.
 * @param rhs[in] Second vector for the difference operation.
 */
template<typename RV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
sub(thread RV &dst, thread const RV &lhs, thread const U &rhs) {
    bin_op<base_ops::sub, RV>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise product of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the product values will be stored.
 * @param lhs[in] First vector for the product operation.
 * @param rhs[in] Second vector for the product operation.
 */
template<typename RV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
mul(thread RV &dst, thread const RV &lhs, thread const U &rhs) {
    bin_op<base_ops::mul, RV>(dst, lhs, rhs);
}
/**
 * @brief Computes the element-wise division of two register vectors.
 *
 * @tparam T Register vector type.
 * @tparam U Type of the second vector.
 * @param dst[out] Destination vector where the division values will be stored.
 * @param lhs[in] First vector for the division operation.
 * @param rhs[in] Second vector for the division operation.
 */
template<typename RV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
div(thread RV &dst, thread const RV &lhs, thread const U &rhs) {
    bin_op<base_ops::div, RV>(dst, lhs, rhs);
}
}
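/* --- editor's note: hedged usage sketch, not part of the upstream file ---
   The scalar-broadcast binary wrappers above compose directly; for example a clamp can be
   built from max() followed by min(). Assumptions: the U operand accepts a plain scalar of
   the vector's value type (as the `thread const U &` signature suggests), register vectors
   are default-constructible, and the calls resolve to the mittens wrappers above.

   template<typename RV>
   static METAL_FUNC void clamp_sketch(thread RV &dst, thread const RV &src,
                                       thread const typename RV::T &lo,
                                       thread const typename RV::T &hi) {
       RV tmp;
       mittens::max(tmp, src, lo);   // element-wise max against the lower bound
       mittens::min(dst, tmp, hi);   // then element-wise min against the upper bound
   } */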

236
extra/thunder/include/ops/warp/register/vec/reductions.metal
Normal file
236
extra/thunder/include/ops/warp/register/vec/reductions.metal
Normal file
@@ -0,0 +1,236 @@
/**
 * @file
 * @brief Reductions on vectors stored in registers.
 */

#pragma once // done

#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace mittens {
/* ---------- Vector Reductions ---------- */

/**
 * @brief Performs a reduction operation on elements of a register vector within a warp.
 *
 * This function applies a specified operation to reduce the elements of a register vector `src` to a single value.
 * The result is stored in `dst_accum`. If the `reset` parameter is false, the reduction also folds in the initial value `src_accum`.
 * The reduction operation is performed in a warp-wide context, ensuring synchronization between threads in the warp.
 *
 * @tparam op The operation to perform on the elements. Must provide a static `op` method.
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @tparam reset A boolean flag indicating whether to ignore (true) or include (false) the initial value in the reduction.
 * @param[out] dst_accum The result of the reduction operation.
 * @param[in] src The register vector to reduce.
 * @param[in] src_accum The initial value to include in the reduction if `reset` is false.
 */
template<typename op, typename RV, bool reset>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
reduce(
        thread typename RV::T &dst_accum,
        thread const RV &src,
        thread const typename RV::T &src_accum,
        const ushort laneid) {
    using T = typename RV::T;
    if (ducks::is_ortho_register_vector<RV>()) { // col vector
        T accum = src[0][0];
        #pragma clang loop unroll(full)
        for(int i = 1; i < src.outer_dim; i++) {
            accum = op::template op<T>(accum, src[i][0]);
        }
        accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
        accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));
        accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 16));
        if (!reset) accum = op::template op<T>(accum, src_accum);
        dst_accum = shfl_sync(accum, 0);
    }
    else if (ducks::is_align_register_vector<RV>()) { // row vector
        T accum = op::template op<T>(src[0][0], src[0][1]);
        #pragma clang loop unroll(full)
        for(int i = 1; i < src.outer_dim; i++) {
            accum = op::template op<T>(accum, src[i][0]);
            accum = op::template op<T>(accum, src[i][1]);
        }
        metal::simdgroup_barrier(metal::mem_flags::mem_none);
        accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
        metal::simdgroup_barrier(metal::mem_flags::mem_none);
        accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 8));
        metal::simdgroup_barrier(metal::mem_flags::mem_none);

        accum = shfl_sync<T>(accum, 0);
        metal::simdgroup_barrier(metal::mem_flags::mem_none);
        if (!reset) accum = op::template op<T>(accum, src_accum);
        dst_accum = accum;
    }
    else if (ducks::is_naive_register_vector<RV>()) {
        // T accum = src[0][0];
        T accum;
        if (laneid < src.length) accum = src[0][0];
        #pragma clang loop unroll(full)
        for(int i = 1; i < src.outer_dim; i++) {
            if (i*SIMD_THREADS + laneid < src.length) {
                accum = op::template op<T>(accum, src[i][0]);
            }
        }
        if (src.length == 8) {
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));
        } else if (src.length == 16) {
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 8));
        } else if (src.length == 24) {
            if (laneid < 24) {
                accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
                accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
                accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));

                T shfle_val = shfl_down_sync<T>(accum, 8);
                if (laneid < 16) {
                    accum = op::template op<T>(accum, shfle_val);
                }
                metal::simdgroup_barrier(metal::mem_flags::mem_none);
                accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 16));
            }

        } else {
            metal::simdgroup_barrier(metal::mem_flags::mem_none);
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
            metal::simdgroup_barrier(metal::mem_flags::mem_none);
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
            metal::simdgroup_barrier(metal::mem_flags::mem_none);
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));
            metal::simdgroup_barrier(metal::mem_flags::mem_none);
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 8));
            metal::simdgroup_barrier(metal::mem_flags::mem_none);
            accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 16));
            metal::simdgroup_barrier(metal::mem_flags::mem_none);
        }

        if (!reset) accum = op::template op<T>(accum, src_accum);
        dst_accum = shfl_sync(accum, 0);
    }
}

/**
 * @brief Finds the maximum element in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] max_val The maximum value found in the vector.
 * @param[in] src The register vector to find the maximum in.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
max(thread typename base_types::packing<typename RV::dtype>::unpacked_type &max_val, thread const RV &src, const ushort laneid) {
    reduce<base_ops::max, RV, true>(max_val, src, max_val, laneid);
}

/**
 * @brief Finds the minimum element in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] min_val The minimum value found in the vector.
 * @param[in] src The register vector to find the minimum in.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
min(thread typename base_types::packing<typename RV::dtype>::unpacked_type &min_val, thread const RV &src, const ushort laneid) {
    reduce<base_ops::min, RV, true>(min_val, src, min_val, laneid);
}

/**
 * @brief Calculates the sum of elements in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] sum_val The sum of the values in the vector.
 * @param[in] src The register vector to sum.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
sum(thread typename base_types::packing<typename RV::dtype>::unpacked_type &sum_val, thread const RV &src, const ushort laneid) {
    reduce<base_ops::sum, RV, true>(sum_val, src, sum_val, laneid);
}

/**
 * @brief Calculates the product of elements in a register vector.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] prod_val The product of the values in the vector.
 * @param[in] src The register vector to multiply.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
prod(thread typename base_types::packing<typename RV::dtype>::unpacked_type &prod_val, thread const RV &src, const ushort laneid) {
    reduce<base_ops::mul, RV, true>(prod_val, src, prod_val, laneid);
}

// Three operand versions.

/**
 * @brief Finds the maximum element in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] max_val The maximum value found in the vector, accumulated with src_accum.
 * @param[in] src The register vector to find the maximum in.
 * @param[in] src_accum The initial value to accumulate with the maximum value found.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
max(thread typename base_types::packing<typename RV::dtype>::unpacked_type &max_val,
    thread const RV &src,
    thread const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum, const ushort laneid) {
    reduce<base_ops::max, RV, false>(max_val, src, src_accum, laneid);
}

/**
 * @brief Finds the minimum element in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] min_val The minimum value found in the vector, accumulated with src_accum.
 * @param[in] src The register vector to find the minimum in.
 * @param[in] src_accum The initial value to accumulate with the minimum value found.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
min(thread typename base_types::packing<typename RV::dtype>::unpacked_type &min_val,
    thread const RV &src,
    thread const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum, const ushort laneid) {
    reduce<base_ops::min, RV, false>(min_val, src, src_accum, laneid);
}

/**
 * @brief Calculates the sum of elements in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] sum_val The sum of the values in the vector, accumulated with src_accum.
 * @param[in] src The register vector to sum.
 * @param[in] src_accum The initial value to accumulate with the sum of the vector.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
sum(thread typename base_types::packing<typename RV::dtype>::unpacked_type &sum_val,
    thread const RV &src,
    thread const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum, const ushort laneid) {
    reduce<base_ops::sum, RV, false>(sum_val, src, src_accum, laneid);
}

/**
 * @brief Calculates the product of elements in a register vector and accumulates it with src_accum.
 *
 * @tparam RV The type of the register vector. Must satisfy the `ducks::rv::all` concept.
 * @param[out] prod_val The product of the values in the vector, accumulated with src_accum.
 * @param[in] src The register vector to multiply.
 * @param[in] src_accum The initial value to accumulate with the product of the vector.
 */
template<typename RV>
static METAL_FUNC typename metal::enable_if<ducks::is_register_vector<RV>(), void>::type
prod(thread typename base_types::packing<typename RV::dtype>::unpacked_type &prod_val,
    thread const RV &src,
    thread const typename base_types::packing<typename RV::dtype>::unpacked_type &src_accum, const ushort laneid) {
    reduce<base_ops::mul, RV, false>(prod_val, src, src_accum, laneid);
}

}
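/* --- editor's note: hedged usage sketch, not part of the upstream file ---
   The two-operand wrappers above reduce a whole register vector across the simdgroup; a mean
   is just sum() followed by a divide. Assumptions: RV::length counts the logical elements (as
   the naive branch of reduce() uses it) and the unpacked accumulator type supports division by
   that count.

   template<typename RV>
   static METAL_FUNC typename base_types::packing<typename RV::dtype>::unpacked_type
   mean_sketch(thread const RV &v, const ushort laneid) {
       typename base_types::packing<typename RV::dtype>::unpacked_type total;
       mittens::sum(total, v, laneid);   // warp-cooperative sum of all elements
       return total / v.length;          // divide by the element count
   }

   For a running reduction across several vectors, the three-operand forms fold the previous
   accumulator back in, e.g. mittens::sum(total, v2, total, laneid). */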
4
extra/thunder/include/ops/warp/register/vec/vec.metal
Normal file
4
extra/thunder/include/ops/warp/register/vec/vec.metal
Normal file
@@ -0,0 +1,4 @@
#pragma once
#include "conversions.metal"
#include "maps.metal"
#include "reductions.metal"
3
extra/thunder/include/ops/warp/shared/shared.metal
Normal file
3
extra/thunder/include/ops/warp/shared/shared.metal
Normal file
@@ -0,0 +1,3 @@
#pragma once
#include "tile/tile.metal"
#include "vec/vec.metal"
59
extra/thunder/include/ops/warp/shared/tile/conversions.metal
Normal file
59
extra/thunder/include/ops/warp/shared/tile/conversions.metal
Normal file
@@ -0,0 +1,59 @@
/**
 * @file
 * @brief Conversions between shared tile types.
 */

#pragma once // not done, add subtile

#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace mittens {
/* ---------- COPIES ---------- */
/**
 * @brief Copies data from one shared memory tile to another, potentially with different data types.
 *
 * @tparam T The data type of the destination tile.
 * @tparam U The data type of the source tile.
 * @tparam _height The height of the tile.
 * @tparam _width The width of the tile.
 * @param[out] dst The destination tile.
 * @param[in] src The source tile.
 */
template<typename T, typename U, int _height, int _width>
static METAL_FUNC void copy(threadgroup st<T, _height, _width> &dst, threadgroup const st<U, _height, _width> &src, const ushort laneid) {
    #pragma clang loop unroll(full)
    for(int i = laneid; i < dst.num_elements; i+=mittens::SIMD_THREADS) {
        int row = i/dst.cols, col = i%dst.cols;
        dst[{row, col}] = base_types::convertor<T, U>::convert(src[{row, col}]);
    }
}

///* ---------- SUBTILE ---------- */
//
///**
//* @brief Returns a reference to a subtile of the given shared tile.
//*
//* @tparam subtile_height The height of the subtile.
//* @tparam subtile_width The width of the subtile.
//* @tparam ST The type of the input tile, which must satisfy the ducks::st::all concept.
//* @param src The input tile.
//* @param row_idx The row index of the subtile, in units of subtile_height*16 elements.
//* @param col_idx The col index of the subtile, in units of subtile_width*16 elements.
//* @return A reference to the subtile.
//*
//* @note The subtile {height, width} must evenly divide the tile {height, width}.
//*/
//template<int subtile_height, int subtile_width, ducks::st::all ST>
//__device__ inline typename ST::subtile<subtile_height, subtile_width> subtile_inplace(ST &src, int row_idx, int col_idx) {
//    static_assert(ST::height % subtile_height == 0);
//    static_assert(ST::width % subtile_width == 0);
//    return typename ST::subtile<subtile_height, subtile_width>(
//        &src[0], subtile_height*16*row_idx, subtile_width*16*col_idx
//    );
//}

}
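/* --- editor's note: hedged usage sketch, not part of the upstream file ---
   copy() above converts element-wise through base_types::convertor, so a dtype change is just
   a copy between two tiles of the same shape. A minimal sketch, assuming a half-precision
   source tile and that convertor<float, half> is provided by base_types:

   template<int H, int W>
   static METAL_FUNC void widen_sketch(threadgroup st<float, H, W> &dst,
                                       threadgroup const st<half, H, W> &src,
                                       const ushort laneid) {
       mittens::copy(dst, src, laneid);   // strided over SIMD_THREADS lanes, converting each element
   } */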

485
extra/thunder/include/ops/warp/shared/tile/maps.metal
Normal file
485
extra/thunder/include/ops/warp/shared/tile/maps.metal
Normal file
@@ -0,0 +1,485 @@
/**
 * @file
 * @brief Warp-scope maps on shared tiles.
 */

#pragma once

#include "../../../../common/common.metal"
#include "../../../../types/types.metal"

namespace mittens {
/* ---------- Uniform tile maps (independent of layout) ---------- */

/**
 * @brief Performs a uniform unary operation on a tile.
 *
 * This function applies a given unary operation to each element of the source tile and stores the result in the destination tile.
 * The operation is applied independently to each element, without considering its position or the values of neighboring elements.
 *
 * @tparam op The unary operation to be applied. Must be specialized to support operation on the data type of T.
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the unary operation is applied.
 */
template<typename op, typename ST> // T2, w, h can be inferred from dst as long as op is specialized
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
unary_map(threadgroup ST &dst, threadgroup const ST &src, const ushort laneid) {
    #pragma clang loop unroll(full)
    for(int i = laneid; i < ST::num_elements; i += SIMD_THREADS) {
        dst.data[i] = op::template op<typename ST::dtype>(src.data[i]);
    }
}


/**
 * @brief Performs a uniform binary operation on a tile with a scalar parameter.
 *
 * This function applies a given binary operation to each element of the source tile and a scalar parameter, then stores the result in the destination tile.
 * The operation is applied independently to each element, treating the scalar parameter as the second operand for each operation.
 *
 * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the scalar parameter.
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the binary operation is applied.
 * @param[in] param The scalar parameter to be used as the second operand in the binary operation.
 */
template<typename op, typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
bin_map(threadgroup ST &dst, threadgroup const ST &src, thread const typename ST::dtype &param, const ushort laneid) {
    #pragma clang loop unroll(full)
    for(int i = laneid; i < dst.num_elements; i += SIMD_THREADS) {
        dst.data[i] = op::template op<typename ST::dtype>(src.data[i], param);
    }
}

/**
 * @brief Performs a uniform binary operation on two tiles.
 *
 * This function applies a given binary operation to corresponding elements of two source tiles and stores the result in the destination tile.
 * The operation is applied independently to each pair of elements, without considering their positions or the values of neighboring elements.
 *
 * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T.
 * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] lhs The first source tile to which the binary operation is applied.
 * @param[in] rhs The second source tile to which the binary operation is applied.
 */
template<typename op, typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
bin_map(threadgroup ST &dst, threadgroup const ST &lhs, threadgroup const ST &rhs, const ushort laneid) {
    #pragma clang loop unroll(full)
    for(int i = laneid; i < dst.num_elements; i += SIMD_THREADS) {
        int row = i/dst.cols, col = i%dst.cols;
        dst.data[i] = op::template op<typename ST::dtype>(lhs.data[i], rhs.data[i]);
    }
}

/**
 * @brief Performs a row-wise binary operation on a tile with a vector.
 *
 * This function applies a given binary operation to each row of the source tile and the corresponding element of the source vector,
 * then stores the result in the destination tile. The operation is applied independently to each row, using the vector element as
 * the second operand for each element in the row.
 *
 * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements.
 * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
 * @tparam V The type of the vector. Must have the same data type as T.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the binary operation is applied.
 * @param[in] vec The source vector containing the second operand for each row operation.
 */
template<typename op, typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
row_map(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &vec, const ushort laneid) {
    static_assert(metal::is_same<typename ST::dtype, typename SV::dtype>::value, "Tile and vector must have the same data type");
    static_assert(SV::length == ST::rows, "Vector length must match the number of rows in the tile");
    #pragma clang loop unroll(full)
    for(int i = laneid; i < dst.num_elements; i += SIMD_THREADS) {
        int row = i/ST::cols, col = i%ST::cols;
        dst[{row, col}] = op::template op<typename ST::dtype>(src[{row, col}], vec[row]);
    }
}
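/* --- editor's note: hedged illustration, not part of the upstream file ---
   row_map/col_map are not limited to the base_ops functors: any type exposing a static,
   templated op<T>(lhs, rhs) (the contract the calls above rely on) can be plugged in. A
   minimal sketch of such a functor; the name and the 0.5 scale are made up for illustration
   only, and T(0.5f) assumes the tile dtype is constructible from float.

   struct half_axpy_op {
       template<typename T> static METAL_FUNC T op(T a, T b) { return a + b * T(0.5f); }
   };
   // usage, with ST/SV as in the wrappers further below:
   //   row_map<half_axpy_op, ST, SV>(dst, src, vec, laneid); */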

/**
 * @brief Performs a column-wise binary operation on a tile with a vector.
 *
 * This function applies a given binary operation to each column of the source tile and the corresponding element of the source vector,
 * then stores the result in the destination tile. The operation is applied independently to each column, using the vector element as
 * the second operand for each element in the column.
 *
 * @tparam op The binary operation to be applied. Must be specialized to support operation on the data type of T and the vector elements.
 * @tparam T The type of the tiles. Must satisfy the `ducks::st::all` concept.
 * @tparam V The type of the vector. Must have the same data type as T.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the binary operation is applied.
 * @param[in] vec The source vector containing the second operand for each column operation.
 */
template<typename op, typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
col_map(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &vec, const ushort laneid) {
    static_assert(metal::is_same<typename ST::dtype, typename SV::dtype>::value, "Tile and vector must have the same data type");
    static_assert(SV::length == ST::cols, "Vector length must match the number of columns in the tile");
    #pragma clang loop unroll(full)
    for(int i = laneid; i < dst.num_elements; i += SIMD_THREADS) {
        int row = i/dst.cols, col = i%dst.cols;
        dst[{row, col}] = op::template op<typename ST::dtype>(src[{row, col}], vec[col]);
    }
}


/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// const maps
/**
 * @brief Sets all elements of the destination tile to zero.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
zero(threadgroup ST &dst, const ushort laneid) {
    unary_map<base_ops::zero, ST>(dst, dst, laneid);
}
/**
 * @brief Sets all elements of the destination tile to one.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
one(threadgroup ST &dst, const ushort laneid) {
    unary_map<base_ops::one, ST>(dst, dst, laneid);
}
/**
 * @brief Sets all elements of the destination tile to positive infinity.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
pos_infty(threadgroup ST &dst, const ushort laneid) {
    unary_map<base_ops::pos_infty, ST>(dst, dst, laneid);
}
/**
 * @brief Sets all elements of the destination tile to negative infinity.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
neg_infty(threadgroup ST &dst, const ushort laneid) {
    unary_map<base_ops::neg_infty, ST>(dst, dst, laneid);
}

// unary maps
/**
 * @brief Applies the exponential function to each element of the source tile and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the exponential function is applied.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
exp(threadgroup ST &dst, threadgroup const ST &src, const ushort laneid) {
    unary_map<base_ops::exp, ST>(dst, src, laneid);
}
/**
 * @brief Applies the base-2 exponential function to each element of the source tile and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the exponential function is applied.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
exp2(threadgroup ST &dst, threadgroup const ST &src, const ushort laneid) {
    unary_map<base_ops::exp2, ST>(dst, src, laneid);
}
/**
 * @brief Applies the natural logarithm function to each element of the source tile and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the natural logarithm function is applied.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
log(threadgroup ST &dst, threadgroup const ST &src, const ushort laneid) {
    unary_map<base_ops::log, ST>(dst, src, laneid);
}
/**
 * @brief Applies the absolute value function to each element of the source tile and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the absolute value function is applied.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
abs(threadgroup ST &dst, threadgroup const ST &src, const ushort laneid) {
    unary_map<base_ops::abs, ST>(dst, src, laneid);
}
/**
 * @brief Applies the rectified linear unit function to each element of the source tile and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source tile to which the rectified linear unit function is applied.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
relu(threadgroup ST &dst, threadgroup const ST &src, const ushort laneid) {
    unary_map<base_ops::relu, ST>(dst, src, laneid);
}
/**
 * @brief Copies the elements of the source tile to the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @tparam U The type of the source data. Must be convertible to the data type of the destination tile.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] src The source data to be copied.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
copy(threadgroup ST &dst, thread const U &src, const ushort laneid) {
    bin_map<base_ops::copy2, ST>(dst, dst, src, laneid);
}

// uniform binary maps
/**
 * @brief Finds the maximum of each pair of corresponding elements in the two source tiles and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] lhs The first source tile.
 * @param[in] rhs The second source data.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
max(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const ushort laneid) {
    bin_map<base_ops::max, ST>(dst, lhs, rhs, laneid);
}
/**
 * @brief Finds the minimum of each pair of corresponding elements in the two source tiles and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] lhs The first source tile.
 * @param[in] rhs The second source data.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
min(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const ushort laneid) {
    bin_map<base_ops::min, ST>(dst, lhs, rhs, laneid);
}
/**
 * @brief Adds each pair of corresponding elements in the two source tiles and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] lhs The first source tile.
 * @param[in] rhs The second source data.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
add(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const ushort laneid) {
    bin_map<base_ops::sum, ST>(dst, lhs, rhs, laneid);
}
/**
 * @brief Subtracts each pair of corresponding elements in the two source tiles and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] lhs The first source tile.
 * @param[in] rhs The second source data.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
sub(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const ushort laneid) {
    bin_map<base_ops::sub, ST>(dst, lhs, rhs, laneid);
}
/**
 * @brief Multiplies each pair of corresponding elements in the two source tiles and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] lhs The first source tile.
 * @param[in] rhs The second source data.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
mul(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const ushort laneid) {
    bin_map<base_ops::mul, ST>(dst, lhs, rhs, laneid);
}
/**
 * @brief Divides each pair of corresponding elements in the two source tiles and stores the result in the destination tile.
 *
 * @tparam T The type of the tile. Must satisfy the `ducks::st::all` concept.
 * @tparam U The type of the second source data. Must be convertible to the data type of the destination tile.
 * @param[out] dst The destination tile where the results are stored.
 * @param[in] lhs The first source tile.
 * @param[in] rhs The second source data.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
div(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const ushort laneid) {
    bin_map<base_ops::div, ST>(dst, lhs, rhs, laneid);
}

// Row and col maps

/**
 * @brief Adds row values to each row of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the addition on.
 * @param row_values[in] Column vector containing values to add to each row.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
add_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const ushort laneid) {
    row_map<base_ops::sum, ST, SV>(dst, src, row_values, laneid);
}
/**
 * @brief Subtracts row values from each row of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the subtraction on.
 * @param row_values[in] Column vector containing values to subtract from each row.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
sub_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const ushort laneid) {
    row_map<base_ops::sub, ST, SV>(dst, src, row_values, laneid);
}
/**
 * @brief Multiplies each row of a tile by row values.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the multiplication on.
 * @param row_values[in] Column vector containing values to multiply each row by.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
mul_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const ushort laneid) {
    row_map<base_ops::mul, ST, SV>(dst, src, row_values, laneid);
}
/**
 * @brief Divides each row of a tile by row values.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the division on.
 * @param row_values[in] Column vector containing values to divide each row by.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
div_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const ushort laneid) {
    row_map<base_ops::div, ST, SV>(dst, src, row_values, laneid);
}
/**
 * @brief Broadcasts a vector into a tile's rows.
 *
 * @tparam T Tile type.
 * @tparam V Column vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param row_values[in] Column vector containing values to broadcast into rows.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
broadcast_row(threadgroup ST &dst, threadgroup const SV &row_values, const ushort laneid) {
    row_map<base_ops::copy2, ST, SV>(dst, dst, row_values, laneid);
}


// col maps
/**
 * @brief Adds column values to each column of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the addition on.
 * @param col_values[in] Row vector containing values to add to each column.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
add_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const ushort laneid) {
    col_map<base_ops::sum, ST, SV>(dst, src, col_values, laneid);
}
/**
 * @brief Subtracts column values from each column of a tile.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the subtraction on.
 * @param col_values[in] Row vector containing values to subtract from each column.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
sub_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const ushort laneid) {
    col_map<base_ops::sub, ST, SV>(dst, src, col_values, laneid);
}
/**
 * @brief Multiplies each column of a tile by column values.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the multiplication on.
 * @param col_values[in] Row vector containing values to multiply each column by.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
mul_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const ushort laneid) {
    col_map<base_ops::mul, ST, SV>(dst, src, col_values, laneid);
}
/**
 * @brief Divides each column of a tile by column values.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param src[in] Source tile to apply the division on.
 * @param col_values[in] Row vector containing values to divide each column by.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
div_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const ushort laneid) {
    col_map<base_ops::div, ST, SV>(dst, src, col_values, laneid);
}
/**
 * @brief Broadcasts a vector into a tile's columns.
 *
 * @tparam T Tile type.
 * @tparam V Row vector type.
 * @param dst[out] Destination tile where the result is stored.
 * @param col_values[in] Row vector containing values to broadcast into columns.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
broadcast_col(threadgroup ST &dst, threadgroup const SV &col_values, const ushort laneid) {
    col_map<base_ops::copy2, ST, SV>(dst, dst, col_values, laneid);
}


}
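/* --- editor's note: hedged usage sketch, not part of the upstream file ---
   The wrappers above chain naturally; for example a per-column bias add followed by a ReLU,
   written only in terms of add_col and relu as defined in this file. Assumptions: the bias
   vector's length equals the tile's column count (enforced by col_map's static_assert) and any
   required simdgroup synchronization between the two steps is handled by the caller.

   template<typename ST, typename SV>
   static METAL_FUNC void bias_relu_sketch(threadgroup ST &tile, threadgroup const SV &bias,
                                           const ushort laneid) {
       mittens::add_col(tile, tile, bias, laneid);  // tile[r][c] += bias[c]
       mittens::relu(tile, tile, laneid);           // clamp negatives to zero in place
   } */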
|
||||
295
extra/thunder/include/ops/warp/shared/tile/reductions.metal
Normal file
295
extra/thunder/include/ops/warp/shared/tile/reductions.metal
Normal file
@@ -0,0 +1,295 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Warp-scope reductions on shared tiles.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../../../common/common.metal"
|
||||
#include "../../../../types/types.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
/**
|
||||
* Performs row-wise reduction on a matrix using a specified operation.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type with row layout.
|
||||
* @param row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param src The source matrix on which to perform the reduction.
|
||||
* @param src_accum The initial value of the accumulator, used when reset is false.
|
||||
* @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
*/
|
||||
template<typename op, typename SV, typename ST, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_reduce(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
using dtype = typename SV::dtype;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int row = laneid; row < ST::rows; row += mittens::SIMD_THREADS) {
|
||||
dtype accum = src[{row, 0}];
|
||||
#pragma clang loop unroll(full)
|
||||
for (int col = 1; col < src.cols; col++) {
|
||||
accum = op::template op<dtype>(accum, src[{row, col}]);
|
||||
}
|
||||
if (reset) {
|
||||
row_accum[row] = accum;
|
||||
} else {
|
||||
row_accum[row] = op::template op<dtype>(src_accum[row], accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Performs column-wise reduction on a matrix using a specified operation.
|
||||
*
|
||||
* @tparam op The operation to be applied for reduction.
|
||||
* @tparam V The shared vector type for the column accumulator.
|
||||
* @tparam T The shared matrix type with column layout.
|
||||
* @param col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param src The source matrix on which to perform the reduction.
|
||||
* @param src_accum The initial value of the accumulator, used when reset is false.
|
||||
* @param reset A boolean flag indicating whether to reset the accumulator (ignore src_accum) or not.
|
||||
*/
|
||||
template<typename op, typename SV, typename ST, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_reduce(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
using dtype = typename SV::dtype;
|
||||
#pragma clang loop unroll(full)
|
||||
for (int col = laneid; col < src.cols; col += mittens::SIMD_THREADS) {
|
||||
dtype accum = src[int2(0, col)];
|
||||
#pragma clang loop unroll(full)
|
||||
for (int row = 1; row < src.rows; row++) {
|
||||
accum = op::template op<dtype>(accum, src[int2(row, col)]);
|
||||
}
|
||||
if (reset) {
|
||||
col_accum[col] = accum;
|
||||
} else {
|
||||
col_accum[col] = op::template op<dtype>(src_accum[col], accum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_max(threadgroup SV &row_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
row_reduce<base_ops::max, SV, ST, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_min(threadgroup SV &row_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
row_reduce<base_ops::min, SV, ST, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_sum(threadgroup SV &row_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
row_reduce<base_ops::sum, SV, ST, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each row of the src shared matrix in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_prod(threadgroup SV &row_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
row_reduce<base_ops::mul, SV, ST, true>(row_accum, src, row_accum, laneid);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_max(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
row_reduce<base_ops::max, SV, ST, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_min(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
row_reduce<base_ops::min, SV, ST, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_sum(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
row_reduce<base_ops::sum, SV, ST, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each row of the src shared matrix, as well as the src_accum shared vector, in the row_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] row_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
row_prod(threadgroup SV &row_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
row_reduce<base_ops::mul, SV, ST, false>(row_accum, src, src_accum, laneid);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_max(threadgroup SV &col_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
col_reduce<base_ops::max, SV, ST, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_min(threadgroup SV &col_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
col_reduce<base_ops::min, SV, ST, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_sum(threadgroup SV &col_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
col_reduce<base_ops::sum, SV, ST, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each column of the src shared matrix in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_prod(threadgroup SV &col_accum, threadgroup const ST &src, const ushort laneid) {
|
||||
col_reduce<base_ops::mul, SV, ST, true>(col_accum, src, col_accum, laneid);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Store the maximum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the row accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_max(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
col_reduce<base_ops::max, SV, ST, false>(col_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the minimum of each column of the src shared matrix, as well as the src_accum shared vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the column accumulator.
|
||||
* @tparam T The matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_min(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
col_reduce<base_ops::min, SV, ST, false>(col_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the sum of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the column accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_sum(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
col_reduce<base_ops::sum, SV, ST, false>(col_accum, src, src_accum, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Store the product of each column of the src shared tile, as well as the src_accum row vector, in the col_accum shared vector.
|
||||
*
|
||||
* @tparam V The shared vector type for the column accumulator.
|
||||
* @tparam T The shared matrix type.
|
||||
* @param[out] col_accum The accumulator where the result of the reduction is stored.
|
||||
* @param[in] src The source matrix on which to perform the reduction.
|
||||
* @param[in] src_accum The initial value of the accumulator, used when accumulating onto an existing value.
|
||||
*/
|
||||
template<typename SV, typename ST>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
|
||||
col_prod(threadgroup SV &col_accum, threadgroup const ST &src, threadgroup const SV &src_accum, const ushort laneid) {
|
||||
col_reduce<base_ops::mul, SV, ST, false>(col_accum, src, src_accum, laneid);
|
||||
}
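// Usage sketch (illustrative only, not part of the library): column reductions
// over a shared tile. `tile`, `col_totals`, and `col_peaks` are hypothetical
// shared-tile / shared-vector instances with matching shapes.
//
//   col_sum(col_totals, tile, laneid);            // fresh reduction: one sum per column
//   col_max(col_peaks, tile, col_peaks, laneid);  // accumulating form: folds the previous peaks back in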
|
||||
|
||||
}
|
||||
4
extra/thunder/include/ops/warp/shared/tile/tile.metal
Normal file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
#include "conversions.metal"
|
||||
#include "maps.metal"
|
||||
#include "reductions.metal"
|
||||
60
extra/thunder/include/ops/warp/shared/vec/conversions.metal
Normal file
@@ -0,0 +1,60 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Warp-scope conversions on shared vectors.
|
||||
*/
|
||||
|
||||
#pragma once // done!
|
||||
|
||||
#include "../../../../common/common.metal"
|
||||
#include "../../../../types/types.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
|
||||
/**
|
||||
* @brief Copies data from one shared vector to another, converting data types if necessary.
|
||||
*
|
||||
* This function copies data from the source shared vector `src` to the destination shared vector `dst`.
|
||||
* If the data types of `src` and `dst` are the same, it performs a direct memory copy. Otherwise, it
|
||||
* converts each element from the source data type to the destination data type using the appropriate
|
||||
* converter before copying.
|
||||
*
|
||||
* @tparam SV1 The type of the destination shared vector, must satisfy the ducks::sv::all concept.
|
||||
* @tparam SV2 The type of the source shared vector, must satisfy the ducks::sv::all concept.
|
||||
* @param[out] dst The destination shared vector.
|
||||
* @param[in] src The source shared vector.
|
||||
* @note The lengths of `src` and `dst` must be equal. This is enforced at compile time.
|
||||
*/
|
||||
template<typename SV1, typename SV2>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV1>() && ducks::is_shared_vector<SV2>(), void>::type
|
||||
copy(threadgroup SV1 &dst, threadgroup const SV2 &src, const ushort laneid) {
|
||||
static_assert(SV1::length == SV2::length, "Source and destination vectors must have the same length.");
|
||||
#pragma clang loop unroll(full)
|
||||
for(int i = laneid; i < dst.length; i+=SIMD_THREADS) {
|
||||
dst[i] = base_types::convertor<typename SV1::dtype, typename SV2::dtype>::convert(src[i]);
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------- SUBVEC ---------- */
|
||||
|
||||
/**
|
||||
* @brief Returns a reference to a subvec of a given shared vector
|
||||
*
|
||||
* @tparam subvec_tiles The length, in subtiles, of the subvec.
|
||||
* @tparam SV The type of the input vector, which must satisfy the ducks::sv::all concept.
|
||||
* @param src The input tile.
|
||||
* @param vec_idx The index of the subtile, in units of subvec_tiles*16 elements.
|
||||
* @return A reference to the subvec.
|
||||
*
|
||||
* @note The subvec length must evenly divide the vector length.
|
||||
*/
|
||||
template<int subvec_tiles, typename SV>
|
||||
//using subvec = typename SV::template subvec<SV::length>;
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), threadgroup typename SV::template subvec<typename SV::dtype, subvec_tiles>&>::type
|
||||
subvec_inplace(threadgroup SV &src, int vec_idx) {
|
||||
return *(threadgroup typename SV::template subvec<typename SV::dtype, subvec_tiles>*)(&src[vec_idx*TILE_DIM*subvec_tiles]);
|
||||
}
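// Usage sketch (illustrative, hypothetical names): a dtype-converting copy and
// an in-place subvector view over a shared vector.
//
//   copy(dst_half_vec, src_float_vec, laneid);        // element-wise convert + copy; lengths must match
//   auto &head = subvec_inplace<1>(src_float_vec, 0);  // alias the first TILE_DIM elements, no data movement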
|
||||
|
||||
}
|
||||
|
||||
|
||||
278
extra/thunder/include/ops/warp/shared/vec/maps.metal
Normal file
@@ -0,0 +1,278 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Warp-scope maps on shared vectors.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../../../common/common.metal"
|
||||
#include "../../../../types/types.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
/**
|
||||
* @brief Applies a unary operation to each element of a shared memory vector.
|
||||
*
|
||||
* @tparam op Unary operation type.
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector in which to store the result.
|
||||
* @param src[in] Source vector to apply the unary operation.
|
||||
*/
|
||||
template<typename op, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
unary_op(threadgroup SV &dst, threadgroup const SV &src, const ushort laneid) {
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int cur = laneid; cur < SV::length; cur+=SIMD_THREADS) {
|
||||
dst[cur] = op::template op<typename SV::dtype>(src[cur]);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Perform a binary operation on two shared vectors.
|
||||
*
|
||||
* @tparam op The binary operation to perform.
|
||||
* @tparam T The type of the vectors.
|
||||
* @param dst[out] The destination vector where the result is stored.
|
||||
* @param lhs[in] The left-hand side vector for the operation.
|
||||
* @param rhs[in] The right-hand side vector for the operation.
|
||||
*/
|
||||
template<typename op, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
bin_op(threadgroup SV &dst, threadgroup const SV &lhs, threadgroup const SV &rhs, const ushort laneid) {
|
||||
#pragma clang loop unroll(full)
|
||||
for(int cur = laneid; cur < SV::length; cur+=SIMD_THREADS) {
|
||||
dst[cur] = op::template op<typename SV::dtype>(lhs[cur], rhs[cur]);
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Perform a binary operation on a shared vector and a scalar.
|
||||
*
|
||||
* @tparam op The binary operation to perform.
|
||||
* @tparam T The type of the vector.
|
||||
* @param dst[out] The destination vector where the result is stored.
|
||||
* @param src[in] The source vector for the operation.
|
||||
* @param param[in] The scalar parameter for the operation.
|
||||
*/
|
||||
template<typename op, typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
bin_op(threadgroup SV &dst, threadgroup const SV &src, thread const typename SV::T ¶m, const ushort laneid) {
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
#pragma clang loop unroll(full)
|
||||
for(int cur = laneid; cur < SV::length; cur+=SIMD_THREADS) {
|
||||
dst[cur] = op::template op<typename SV::dtype>(src[cur], param);
|
||||
}
|
||||
}
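// Usage sketch (illustrative, hypothetical vector name `v`): the wrappers below
// all funnel into unary_op / bin_op, so "scale then clamp at zero" on a shared
// float vector looks like this:
//
//   mul(v, v, 0.5f, laneid);   // bin_op<base_ops::mul>: v[i] *= 0.5
//   relu(v, v, laneid);        // unary_op<base_ops::relu>: v[i] = max(v[i], 0)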
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
// ---- const ops ----
|
||||
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to zero.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to zero.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
zero(threadgroup SV &dst, const ushort laneid) {
|
||||
unary_op<base_ops::zero, SV>(dst, dst, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to one.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to one.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
one(threadgroup SV &dst, const ushort laneid) {
|
||||
unary_op<base_ops::one, SV>(dst, dst, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to positive infinity.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to positive infinity.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
pos_infty(threadgroup SV &dst, const ushort laneid) {
|
||||
unary_op<base_ops::pos_infty, SV>(dst, dst, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Sets all elements of a shared memory vector to negative infinity.
|
||||
*
|
||||
* @tparam T Shared memory vector type.
|
||||
* @param dst[out] Destination vector to be set to negative infinity.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
neg_infty(threadgroup SV &dst, const ushort laneid) {
|
||||
unary_op<base_ops::neg_infty, SV>(dst, dst, laneid);
|
||||
}
|
||||
|
||||
// ---- unary ops ----
|
||||
|
||||
/**
|
||||
* @brief Broadcasts a scalar value to all elements of a shared vector.
*
* @tparam T Shared vector type.
* @tparam U Type of the source value.
* @param dst[out] Destination vector where the value will be written.
* @param src[in] Source value to broadcast to every element.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
copy(threadgroup SV &dst, thread const U &src, const ushort laneid) {
|
||||
bin_op<base_ops::copy2, SV>(dst, dst, src, laneid); // the second arg is ignored here.
|
||||
}
|
||||
/**
|
||||
* @brief Applies the exponential function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the exponential function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
exp(threadgroup SV &dst, threadgroup const SV &src, const ushort laneid) {
|
||||
unary_op<base_ops::exp, SV>(dst, src, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the exponential function element-wise to a shared vector, in base 2.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the exponential values will be stored.
|
||||
* @param src[in] Source vector to apply the exponential function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
exp2(threadgroup SV &dst, threadgroup const SV &src, const ushort laneid) {
|
||||
unary_op<base_ops::exp2, SV>(dst, src, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the natural logarithm function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the logarithm values will be stored.
|
||||
* @param src[in] Source vector to apply the logarithm function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
log(threadgroup SV &dst, threadgroup const SV &src, const ushort laneid) {
|
||||
unary_op<base_ops::log, SV>(dst, src, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the absolute value function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the absolute values will be stored.
|
||||
* @param src[in] Source vector to apply the absolute value function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
abs(threadgroup SV &dst, threadgroup const SV &src, const ushort laneid) {
|
||||
unary_op<base_ops::abs, SV>(dst, src, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Applies the rectified linear unit (ReLU) function element-wise to a shared vector.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @param dst[out] Destination vector where the ReLU values will be stored.
|
||||
* @param src[in] Source vector to apply the ReLU function to.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
relu(threadgroup SV &dst, threadgroup const SV &src, const ushort laneid) {
|
||||
unary_op<base_ops::relu, SV>(dst, src, laneid);
|
||||
}
|
||||
|
||||
// ---- binary ops ----
|
||||
|
||||
/**
|
||||
* @brief Computes the element-wise maximum of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the maximum values will be stored.
|
||||
* @param lhs[in] First vector for the maximum operation.
|
||||
* @param rhs[in] Second vector for the maximum operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
max(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const ushort laneid) {
|
||||
bin_op<base_ops::max, SV>(dst, lhs, rhs, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise minimum of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the minimum values will be stored.
|
||||
* @param lhs[in] First vector for the minimum operation.
|
||||
* @param rhs[in] Second vector for the minimum operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
min(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const ushort laneid) {
|
||||
bin_op<base_ops::min, SV>(dst, lhs, rhs, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise sum of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the sum values will be stored.
|
||||
* @param lhs[in] First vector for the sum operation.
|
||||
* @param rhs[in] Second vector for the sum operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
add(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const ushort laneid) {
|
||||
bin_op<base_ops::sum, SV>(dst, lhs, rhs, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise difference of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the difference values will be stored.
|
||||
* @param lhs[in] First vector for the difference operation.
|
||||
* @param rhs[in] Second vector for the difference operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
sub(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const ushort laneid) {
|
||||
bin_op<base_ops::sub, SV>(dst, lhs, rhs, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise product of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the product values will be stored.
|
||||
* @param lhs[in] First vector for the product operation.
|
||||
* @param rhs[in] Second vector for the product operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
mul(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const ushort laneid) {
|
||||
bin_op<base_ops::mul, SV>(dst, lhs, rhs, laneid);
|
||||
}
|
||||
/**
|
||||
* @brief Computes the element-wise division of two shared vectors.
|
||||
*
|
||||
* @tparam T Shared vector type.
|
||||
* @tparam U Type of the second vector.
|
||||
* @param dst[out] Destination vector where the division values will be stored.
|
||||
* @param lhs[in] First vector for the division operation.
|
||||
* @param rhs[in] Second vector for the division operation.
|
||||
*/
|
||||
template<typename SV, typename U>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
div(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const ushort laneid) {
|
||||
bin_op<base_ops::div, SV>(dst, lhs, rhs, laneid);
|
||||
}
|
||||
|
||||
}
|
||||
268
extra/thunder/include/ops/warp/shared/vec/reductions.metal
Normal file
@@ -0,0 +1,268 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Warp-scope reductions on shared vectors.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../../../common/common.metal"
|
||||
#include "../../../../types/types.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
/**
|
||||
* @brief Performs a reduction operation on elements of a shared memory vector within a warp.
|
||||
*
|
||||
* This function applies a specified operation to reduce the elements of a shared memory vector `src` to a single value.
|
||||
* The result is stored in `accum`. If the `reset` parameter is true, the reduction includes an initial value `src_accum`.
|
||||
* The reduction operation is performed in a warp-wide context, ensuring synchronization between threads in the warp.
|
||||
*
|
||||
* @tparam op The operation to perform on the elements. Must provide a static `op` method.
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @tparam reset A boolean flag indicating whether to include an initial value in the reduction.
|
||||
* @param[out] accum The result of the reduction operation.
|
||||
* @param[in] src The shared memory vector to reduce.
|
||||
* @param[in] src_accum The initial value to include in the reduction if `reset` is false.
|
||||
*/
|
||||
template<typename op, typename SV, bool reset>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
reduce(thread typename SV::dtype &dst_accum, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
|
||||
using T = typename SV::dtype;
|
||||
|
||||
{
|
||||
T accum = src[0];
|
||||
for (int i = 1; i < SV::length; i++) {
|
||||
accum = op::template op<T>(accum, src[i]);
|
||||
}
if (!reset) accum = op::template op<T>(accum, src_accum); // fold the running accumulator in when not resetting
dst_accum = shfl_sync(accum, 0);
return; // NOTE: this serial path returns unconditionally, so the simd-shuffle path below is currently unreachable
|
||||
}
|
||||
|
||||
//
|
||||
T accum;
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = op::template op<T>(accum, src[i]);
|
||||
}
|
||||
if (src.length >= 32) {
|
||||
// accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
|
||||
accum = op::template op<T>(accum, (T)metal::simd_shuffle_rotate_down((float)accum, 1));
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
// accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
|
||||
accum = op::template op<T>(accum, (T)metal::simd_shuffle_rotate_down((float)accum, 2));
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
// accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));
|
||||
accum = op::template op<T>(accum, (T)metal::simd_shuffle_rotate_down((float)accum, 4));
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
// accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 8));
|
||||
accum = op::template op<T>(accum, (T)metal::simd_shuffle_rotate_down((float)accum, 8));
|
||||
metal::simdgroup_barrier(metal::mem_flags::mem_none);
|
||||
// accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 16));
|
||||
accum = op::template op<T>(accum, (T)metal::simd_shuffle_rotate_down((float)accum, 16));
|
||||
|
||||
} else if (src.length == 24) {
|
||||
T shfl_val = shfl_down_sync<T>(accum, 1);
|
||||
accum = op::template op<T>(accum, shfl_val);
|
||||
|
||||
shfl_val = shfl_down_sync<T>(accum, 2);
|
||||
accum = op::template op<T>(accum, shfl_val);
|
||||
|
||||
shfl_val = shfl_down_sync<T>(accum, 4);
|
||||
accum = op::template op<T>(accum, shfl_val);
|
||||
|
||||
shfl_val = shfl_down_sync<T>(accum, 8);
|
||||
if (laneid < 16) {
|
||||
accum = op::template op<T>(accum, shfl_val);
|
||||
}
|
||||
shfl_val = shfl_down_sync<T>(accum, 16);
|
||||
accum = op::template op<T>(accum, shfl_val);
|
||||
} else if (src.length == 16) {
|
||||
accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
|
||||
accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
|
||||
accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));
|
||||
accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 8));
|
||||
} else if (src.length == 8) {
|
||||
accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 1));
|
||||
accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 2));
|
||||
accum = op::template op<T>(accum, shfl_down_sync<T>(accum, 4));
|
||||
}
|
||||
if (!reset) accum = op::template op<T>(accum, src_accum);
|
||||
dst_accum = shfl_sync(accum, 0);
|
||||
}
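// Usage sketch (illustrative, hypothetical names `SVec`, `vec`, `total`,
// `running`): collapsing a shared vector to a single per-simdgroup value,
// either via the generic reduce() above or the wrappers below.
//
//   sum(total, vec, laneid);                                           // fresh reduction
//   reduce<base_ops::sum, SVec, false>(running, vec, running, laneid); // fold an existing total back in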
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
|
||||
/**
|
||||
* @brief Finds the maximum element in a shared memory vector.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] max_val The maximum value found in the vector.
|
||||
* @param[in] src The shared memory vector to find the maximum in.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
max(thread typename SV::dtype &max_val, threadgroup const SV &src, const ushort laneid) {
|
||||
// reduce<base_ops::max, SV, true>(max_val, src, max_val, laneid);
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::neg_infty();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::max::template op<T>(accum, src[i]);
|
||||
}
|
||||
max_val = (T)metal::simd_max((float)accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Finds the minimum element in a shared memory vector.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] min_val The minimum value found in the vector.
|
||||
* @param[in] src The shared memory vector to find the minimum in.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
min(thread typename SV::dtype &min_val, threadgroup const SV &src, const ushort laneid) {
|
||||
// reduce<base_ops::min, SV, true>(min_val, src, min_val);
|
||||
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::pos_infty();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::min::template op<T>(accum, src[i]);
|
||||
}
|
||||
min_val = (T)metal::simd_min((float)accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the sum of elements in a shared memory vector.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] sum_val The sum of the values in the vector.
|
||||
* @param[in] src The shared memory vector to sum.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
sum(thread typename SV::dtype &sum_val, threadgroup const SV &src, const ushort laneid) {
|
||||
// reduce<base_ops::sum, SV, true>(sum_val, src, sum_val, laneid);
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::zero();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::sum::template op<T>(accum, src[i]);
|
||||
}
|
||||
sum_val = (T)metal::simd_sum((float)accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the product of elements in a shared memory vector.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] prod_val The product of the values in the vector.
|
||||
* @param[in] src The shared memory vector to multiply.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
prod(thread typename SV::dtype &prod_val, threadgroup const SV &src, const ushort laneid) {
|
||||
// reduce<base_ops::mul, SV, true>(prod_val, src, prod_val, laneid);
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::one();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::mul::template op<T>(accum, src[i]);
|
||||
}
|
||||
prod_val = (T)metal::simd_product((float)accum);
|
||||
}
|
||||
|
||||
// Three operand versions.
|
||||
|
||||
/**
|
||||
* @brief Finds the maximum element in a shared memory vector and accumulates it with src_accum.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] max_val The maximum value found in the vector, accumulated with src_accum.
|
||||
* @param[in] src The shared memory vector to find the maximum in.
|
||||
* @param[in] src_accum The initial value to accumulate with the maximum value found.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
max(thread typename SV::dtype &max_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
|
||||
// reduce<base_ops::max, SV, false>(max_val, src, src_accum, laneid);
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::neg_infty();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::max::template op<T>(accum, src[i]);
|
||||
}
|
||||
max_val = (T)metal::simd_max((float)accum);
|
||||
max_val = base_ops::max::template op<T>(max_val, src_accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Finds the minimum element in a shared memory vector and accumulates it with src_accum.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] min_val The minimum value found in the vector, accumulated with src_accum.
|
||||
* @param[in] src The shared memory vector to find the minimum in.
|
||||
* @param[in] src_accum The initial value to accumulate with the minimum value found.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
min(thread typename SV::dtype &min_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
|
||||
// reduce<base_ops::min, SV, false>(min_val, src, src_accum, laneid);
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::pos_infty();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::min::template op<T>(accum, src[i]);
|
||||
}
|
||||
min_val = (T)metal::simd_min((float)accum);
|
||||
min_val = base_ops::min::template op<T>(min_val, src_accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the sum of elements in a shared memory vector and accumulates it with src_accum.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] sum_val The sum of the values in the vector, accumulated with src_accum.
|
||||
* @param[in] src The shared memory vector to sum.
|
||||
* @param[in] src_accum The initial value to accumulate with the sum of the vector.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
sum(thread typename SV::dtype &sum_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
|
||||
// reduce<base_ops::sum, SV, false>(sum_val, src, src_accum, laneid);
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::zero();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::sum::template op<T>(accum, src[i]);
|
||||
}
|
||||
sum_val = (T)metal::simd_sum((float)accum);
|
||||
sum_val = base_ops::sum::template op<T>(sum_val, src_accum);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Calculates the product of elements in a shared memory vector and accumulates it with src_accum.
|
||||
*
|
||||
* @tparam SV The type of the shared memory vector. Must satisfy the `ducks::sv::all` concept.
|
||||
* @param[out] prod_val The product of the values in the vector, accumulated with src_accum.
|
||||
* @param[in] src The shared memory vector to multiply.
|
||||
* @param[in] src_accum The initial value to accumulate with the product of the vector.
|
||||
*/
|
||||
template<typename SV>
|
||||
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
|
||||
prod(thread typename SV::dtype &prod_val, threadgroup const SV &src, thread const typename SV::dtype &src_accum, const ushort laneid) {
|
||||
// reduce<base_ops::mul, SV, false>(prod_val, src, src_accum, laneid);
|
||||
using T = typename SV::dtype;
|
||||
T accum = base_types::constants<T>::one();
|
||||
if(laneid < SV::length) accum = src[laneid]; // initialize a register accumulator
|
||||
for(int i = laneid + 32; i < SV::length; i+=32) {
|
||||
accum = base_ops::mul::template op<T>(accum, src[i]);
|
||||
}
|
||||
prod_val = (T)metal::simd_product((float)accum);
|
||||
prod_val = base_ops::mul::template op<T>(prod_val, src_accum);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
4
extra/thunder/include/ops/warp/shared/vec/vec.metal
Normal file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
#include "conversions.metal"
|
||||
#include "maps.metal"
|
||||
#include "reductions.metal"
|
||||
4
extra/thunder/include/ops/warp/warp.metal
Normal file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
#include "memory/memory.metal"
|
||||
#include "register/register.metal"
|
||||
#include "shared/shared.metal"
|
||||
4
extra/thunder/include/tk.metal
Normal file
@@ -0,0 +1,4 @@
|
||||
#pragma once
|
||||
#include "common/common.metal"
|
||||
#include "ops/ops.metal"
|
||||
#include "types/types.metal"
|
||||
63
extra/thunder/include/types/global/cgl.metal
Normal file
@@ -0,0 +1,63 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Templated layouts for complex global memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../common/common.metal"
|
||||
//#include "../shared/cst.metal"
|
||||
#include "gl.metal"
|
||||
#include "util.metal"
|
||||
#ifdef mittens_HOPPER
|
||||
#include "tma.metal"
|
||||
#endif
|
||||
|
||||
namespace mittens {
|
||||
/* ---------- Global layout descriptor ---------- */
|
||||
|
||||
namespace ducks {
|
||||
namespace cgl {
|
||||
struct identifier {};
|
||||
}
|
||||
}
|
||||
|
||||
template<typename GL>
|
||||
struct cgl {
|
||||
static_assert(ducks::is_global_layout<GL>(), "GL must satisfy global layout requirements.");
|
||||
|
||||
using identifier = ducks::cgl::identifier;
|
||||
using T = typename GL::T;
|
||||
using T2 = typename GL::T2;
|
||||
using dtype = typename GL::dtype;
|
||||
|
||||
GL real, imag;
|
||||
};
|
||||
|
||||
namespace ducks {
|
||||
template <typename T>
|
||||
struct has_cgl_identifier {
|
||||
static constant constexpr bool value = false; // Default case
|
||||
};
|
||||
|
||||
//template <typename _T, int b, int d, int r, int c, typename... TMA_Types>
|
||||
//struct has_cgl_identifier<mittens::gl<_T, b, d, r, c, TMA_Types ...>> {
|
||||
// static constant constexpr bool value = true;
|
||||
//};
|
||||
template <typename _T, int b, int d, int r, int c>
|
||||
struct has_cgl_identifier<mittens::gl<_T, b, d, r, c>> {
|
||||
static constant constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename GL>
|
||||
static constexpr bool is_complex_global_layout() {
|
||||
return has_cgl_identifier<GL>::value;
|
||||
}
|
||||
template <typename GL>
|
||||
static constexpr void assert_cgl() {
|
||||
static_assert(is_complex_global_layout<GL>(), "T must be a cgl");
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
213
extra/thunder/include/types/global/gl.metal
Normal file
@@ -0,0 +1,213 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Templated layouts for global memory.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../common/common.metal"
|
||||
#include "../shared/shared.metal"
|
||||
#include "../register/register.metal"
|
||||
#include "util.metal"
|
||||
|
||||
|
||||
namespace mittens {
|
||||
/* ---------- Associative dictionary for global layouts ---------- */
|
||||
|
||||
namespace detail {
|
||||
template<typename... Args>
|
||||
struct descriptor_dict {
|
||||
METAL_FUNC descriptor_dict() {}
|
||||
template<typename T> METAL_FUNC descriptor_dict(T _, int b, int d, int r, int c) {}
|
||||
METAL_FUNC descriptor_dict(thread const descriptor_dict &other) {}
|
||||
};
|
||||
}
|
||||
|
||||
/* ---------- Global layout descriptor ---------- */
|
||||
|
||||
namespace ducks {
|
||||
namespace gl {
|
||||
struct identifier {};
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_tile() {
|
||||
return mittens::ducks::is_shared_tile<T>() || mittens::ducks::is_register_tile<T>();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static constexpr bool is_vec() {
|
||||
return mittens::ducks::is_shared_vector<T>() || mittens::ducks::is_register_vector<T>();
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename _T, int b, int d, int r, int c>
|
||||
struct gl {
|
||||
using identifier = ducks::gl::identifier;
|
||||
|
||||
using T = typename base_types::packing<_T>::unpacked_type;
|
||||
using T2 = typename base_types::packing<_T>::packed_type;
|
||||
using dtype = T;
|
||||
|
||||
device T* raw_ptr;
|
||||
|
||||
ducks::g::make_dim_t<b> batch;
|
||||
ducks::g::make_dim_t<d> depth;
|
||||
ducks::g::make_dim_t<r> rows;
|
||||
ducks::g::make_dim_t<c> cols;
|
||||
// int batch;
|
||||
// int depth;
|
||||
// int rows;
|
||||
// int cols;
|
||||
|
||||
METAL_FUNC gl(device T *_data,
|
||||
ducks::g::make_arg_t<b> _batch,
|
||||
ducks::g::make_arg_t<d> _depth,
|
||||
ducks::g::make_arg_t<r> _rows,
|
||||
ducks::g::make_arg_t<c> _cols) :
|
||||
raw_ptr(_data), batch(_batch), depth(_depth), rows(_rows), cols(_cols) {
|
||||
}
|
||||
// METAL_FUNC gl(device T *_data,
|
||||
// int _batch,
|
||||
// int _depth,
|
||||
// int _rows,
|
||||
// int _cols) :
|
||||
// raw_ptr(_data), batch(_batch), depth(_depth), rows(_rows), cols(_cols) {
|
||||
// }
|
||||
//
|
||||
METAL_FUNC gl(thread const gl &other) :
|
||||
raw_ptr(other.raw_ptr), batch(other.batch), depth(other.depth), rows(other.rows), cols(other.cols) {}
|
||||
|
||||
METAL_FUNC gl(constant const gl &other) :
|
||||
raw_ptr(other.raw_ptr), batch(other.batch), depth(other.depth), rows(other.rows), cols(other.cols) {}
|
||||
|
||||
METAL_FUNC device T& operator[](const thread coord &idx) {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c];
|
||||
}
|
||||
METAL_FUNC device const T& operator[](const thread coord &idx) const {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c];
|
||||
}
|
||||
template<typename TILE>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_tile<TILE>(), device T&>::type
|
||||
get(const thread coord &idx) {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r*TILE::rows)*cols + idx.c*TILE::cols];
|
||||
}
|
||||
template<typename TILE>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_tile<TILE>(), device const T&>::type
|
||||
get(const thread coord &idx) const {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r*TILE::rows)*cols + idx.c*TILE::cols];
|
||||
}
|
||||
template<typename VEC>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_vec<VEC>(), device T&>::type
|
||||
get(const thread coord &idx) {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c*VEC::length];
|
||||
}
|
||||
template<typename VEC>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_vec<VEC>(), device const T&>::type
|
||||
get(const thread coord &idx) const {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c*VEC::length];
|
||||
}
|
||||
METAL_FUNC size_t row_stride() const { return cols; }
|
||||
};
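// Construction sketch (illustrative, not part of the library): a 2-D global
// layout over row-major data with batch/depth fixed at compile time and
// rows/cols supplied at runtime; `data`, `n_rows`, `n_cols` are hypothetical.
//
//   using layout_2d = gl<float, 1, 1, -1, -1>;
//   layout_2d a(data, nullptr, nullptr, n_rows, n_cols); // nullptr for compile-time dims, sizes for runtime dims
//   device float &x = a[{0, 0, r, c}];                   // flat offset = ((b*depth + d)*rows + r)*cols + c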
|
||||
|
||||
namespace ducks {
|
||||
template <typename T>
|
||||
struct has_gl_identifier {
|
||||
static constant constexpr bool value = false; // Default case
|
||||
};
|
||||
|
||||
template <typename _T, int b, int d, int r, int c>
|
||||
struct has_gl_identifier<mittens::gl<_T, b, d, r, c>> {
|
||||
static constant constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename GL>
|
||||
static constexpr bool is_global_layout() {
|
||||
return has_gl_identifier<GL>::value;
|
||||
}
|
||||
template <typename GL>
|
||||
static constexpr void assert_gl() {
|
||||
static_assert(is_global_layout<GL>(), "T must be a gl");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template<typename _T, int b, int d, int r, int c>
|
||||
struct gl2 {
|
||||
using identifier = ducks::gl::identifier;
|
||||
|
||||
using T = typename base_types::packing<_T>::unpacked_type;
|
||||
using T2 = typename base_types::packing<_T>::packed_type;
|
||||
using dtype = T;
|
||||
|
||||
device T* raw_ptr;
|
||||
|
||||
// ducks::g::make_dim_t<b> batch;
|
||||
// ducks::g::make_dim_t<d> depth;
|
||||
// ducks::g::make_dim_t<r> rows;
|
||||
// ducks::g::make_dim_t<c> cols;
|
||||
//
|
||||
// METAL_FUNC gl2(device T *_data,
|
||||
// ducks::g::make_arg_t<b> _batch,
|
||||
// ducks::g::make_arg_t<d> _depth,
|
||||
// ducks::g::make_arg_t<r> _rows,
|
||||
// ducks::g::make_arg_t<c> _cols) :
|
||||
// raw_ptr(_data), batch(_batch), depth(_depth), rows(_rows), cols(_cols) {
|
||||
// }
|
||||
|
||||
int batch;
|
||||
int depth;
|
||||
int rows;
|
||||
int cols;
|
||||
|
||||
METAL_FUNC gl2(device T *_data,
|
||||
int _batch,
|
||||
int _depth,
|
||||
int _rows,
|
||||
int _cols) :
|
||||
raw_ptr(_data), batch(_batch), depth(_depth), rows(_rows), cols(_cols) {
|
||||
}
|
||||
|
||||
|
||||
// METAL_FUNC gl2(thread const gl2 &other) :
|
||||
// raw_ptr(other.raw_ptr), batch(other.batch), depth(other.depth), rows(other.rows), cols(other.cols) {}
|
||||
//
|
||||
// METAL_FUNC gl2(constant const gl2 &other) :
|
||||
// raw_ptr(other.raw_ptr), batch(other.batch), depth(other.depth), rows(other.rows), cols(other.cols) {}
|
||||
|
||||
METAL_FUNC device T& operator[](const thread coord &idx) {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c];
|
||||
}
|
||||
METAL_FUNC device const T& operator[](const thread coord &idx) const {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c];
|
||||
}
|
||||
template<typename TILE>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_tile<TILE>(), device T&>::type
|
||||
get(const thread coord &idx) {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r*TILE::rows)*cols + idx.c*TILE::cols];
|
||||
}
|
||||
template<typename TILE>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_tile<TILE>(), device const T&>::type
|
||||
get(const thread coord &idx) const {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r*TILE::rows)*cols + idx.c*TILE::cols];
|
||||
}
|
||||
template<typename VEC>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_vec<VEC>(), device T&>::type
|
||||
get(const thread coord &idx) {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c*VEC::length];
|
||||
}
|
||||
template<typename VEC>
|
||||
METAL_FUNC typename metal::enable_if<ducks::is_vec<VEC>(), device const T&>::type
|
||||
get(const thread coord &idx) const {
|
||||
return raw_ptr[((idx.b*depth + idx.d)*rows + idx.r)*cols + idx.c*VEC::length];
|
||||
}
|
||||
METAL_FUNC size_t row_stride() const { return cols; }
|
||||
};
|
||||
|
||||
}
|
||||
9
extra/thunder/include/types/global/global.metal
Normal file
@@ -0,0 +1,9 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief An aggregate header file for all the global types defined by Thundermittens.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "util.metal"
|
||||
#include "gl.metal"
|
||||
#include "cgl.metal"
|
||||
44
extra/thunder/include/types/global/util.metal
Normal file
@@ -0,0 +1,44 @@
|
||||
#pragma once
|
||||
|
||||
namespace mittens {
|
||||
namespace ducks {
|
||||
namespace g {
|
||||
|
||||
//template<int d> concept cdim = (d > 0); // represents a compile-time dimension
|
||||
//template<int d> concept rdim = (d == -1); // represents a runtime dimension
|
||||
|
||||
template<int d>
|
||||
struct compiled_dim {
|
||||
static_assert(d > 0, "Invalid compile-time dimension value"); // Replace `cdim` concept check
|
||||
static constant constexpr uint32_t v = d;
|
||||
|
||||
METAL_FUNC compiled_dim(thread const metal::nullptr_t &_) {}
|
||||
|
||||
METAL_FUNC constexpr operator uint32_t() const { return v; }
|
||||
};
|
||||
|
||||
struct runtime_dim {
|
||||
uint32_t v;
|
||||
METAL_FUNC runtime_dim(thread const uint32_t &_v) : v(_v) {}
|
||||
METAL_FUNC operator uint32_t() const { return v; }
|
||||
};
|
||||
|
||||
template<int d> using make_dim_t = metal::conditional_t<d == -1, runtime_dim, compiled_dim<d>>;
|
||||
template<int d> using make_arg_t = metal::conditional_t<d == -1, size_t, metal::nullptr_t>; // we pass runtime dims as size_t, comptime dims as nullptr_t
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
struct coord { // essentially a named int4 for tensor coordinates.
|
||||
int b, d, r, c;
|
||||
METAL_FUNC coord(int _b, int _d, int _r, int _c) : b(_b), d(_d), r(_r), c(_c) {}
|
||||
METAL_FUNC coord( int _d, int _r, int _c) : b( 0), d(_d), r(_r), c(_c) {}
|
||||
METAL_FUNC coord( int _r, int _c) : b( 0), d( 0), r(_r), c(_c) {}
|
||||
METAL_FUNC coord( int _c) : b( 0), d( 0), r( 0), c(_c) {}
|
||||
METAL_FUNC coord( ) : b( 0), d( 0), r( 0), c( 0) {}
|
||||
METAL_FUNC coord(thread const coord &other) : b(other.b), d(other.d), r(other.r), c(other.c) {}
|
||||
METAL_FUNC coord(thread const int4 &other) : b(other.x), d(other.y), r(other.z), c(other.w) {}
|
||||
METAL_FUNC operator int4() const { return int4(b, d, r, c); }
|
||||
};
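// Example (illustrative): coord is a named (b, d, r, c) tuple, so these all
// address the same element of a 4-D layout.
//
//   coord idx(0, 0, row, col);  // explicit batch and depth
//   coord idx2(row, col);       // two-arg form defaults batch = depth = 0
//   int4 packed = idx;          // converts to int4(b, d, r, c)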
|
||||
|
||||
}
|
||||
91
extra/thunder/include/types/register/crt.metal
Normal file
@@ -0,0 +1,91 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Abstraction for a complex register tile composed of real and imaginary tiles
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "rt.metal"
|
||||
#include "crv.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
namespace ducks {
|
||||
namespace crt {
|
||||
/**
|
||||
* @brief A dummy type used to identify complex register tiles.
|
||||
*
|
||||
* For a type to quack like an rt_cmplx, it should define its identifier as ducks::rt::cmplx_identifier.
|
||||
* If a type quacks like ducks::rt::cmplx_identifier, it will be treated as an rt_cmplx by compiler checks.
|
||||
*/
|
||||
struct identifier {};
|
||||
} // namespace crt
|
||||
} // namespace ducks
|
||||
|
||||
/**
|
||||
* @brief Complex tile structure
|
||||
*
|
||||
* @tparam _T The data type used for the matrix elements.
* @tparam _rows The total number of rows in the tile.
* @tparam _cols The total number of columns in the tile.
|
||||
* @tparam _layout The layout of the internal register tiles, either row-major or column-major.
|
||||
*
|
||||
* This structure is designed to abstract complex number operations internally to the real and imaginary
|
||||
* register tiles, respectively
|
||||
*
|
||||
* In general, you probably want a row-major tile, unless you specifically want to call mma
|
||||
*/
|
||||
template<typename _T, int _rows, int _cols, typename _layout>
|
||||
struct crt {
|
||||
using identifier = ducks::crt::identifier;
|
||||
static_assert(ducks::is_rt_layout<_layout>(), "crt was given invalid layout");
|
||||
using component = rt<_T, _rows, _cols, _layout>; ///< The real/imaginary component tile type.
|
||||
using layout = typename component::layout; ///< Layout of the matrix tile, ensures compatibility with the rt concepts
|
||||
using T = typename component::T;
|
||||
using T2 = typename component::T2;
|
||||
using dtype = typename component::dtype; ///< Data type of the elements in the tile.
|
||||
|
||||
constant static constexpr int rows = component::rows;
|
||||
constant static constexpr int cols = component::cols;
|
||||
constant static constexpr int height = component::height;
|
||||
constant static constexpr int width = component::width;
|
||||
|
||||
// Real/imag tiles have same internal layout and size
|
||||
component real;
|
||||
component imag;
|
||||
|
||||
using row_vec = crv<T, cols, typename rt_base<T, layout>::row_vec_layout>; ///< A type representing a row vector for this tile.
|
||||
using col_vec = crv<T, rows, typename rt_base<T, layout>::col_vec_layout>; ///< A type representing a column vector for this tile.
|
||||
};
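// Usage sketch (illustrative): a complex register tile is two identically
// shaped real tiles, so complex arithmetic decomposes into plain rt operations
// applied to each component.
//
//   crt_fl<32, 32> z;  // 32x32 complex float tile, row-major by default
//   // z.real and z.imag are ordinary rt<float, 32, 32> tiles; any warp-level
//   // rt operation can be applied to each component separately.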
|
||||
|
||||
/* ---------- CONCEPTS ---------- */
|
||||
|
||||
namespace ducks {
|
||||
template <typename T>
|
||||
struct has_crt_identifier {
|
||||
static constant constexpr bool value = false; // Default case
|
||||
};
|
||||
|
||||
// Specialize for specific template instantiations of st
|
||||
template <typename _T, int _rows, int _cols, typename _layout>
|
||||
struct has_crt_identifier<mittens::crt<_T, _rows, _cols, _layout>> {
|
||||
static constant constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename CRT>
|
||||
static constexpr bool is_complex_register_tile() {
|
||||
return has_crt_identifier<CRT>::value;
|
||||
}
|
||||
template <typename CRT>
|
||||
static constexpr void assert_complex_register_tile() {
|
||||
static_assert(is_complex_register_tile<CRT>(), "T must be a crt");
|
||||
}
|
||||
}
|
||||
|
||||
template<int _rows, int _cols, typename _layout=ducks::rt_layout::row> using crt_fl = crt<float, _rows, _cols, _layout>;
|
||||
template<int _rows, int _cols, typename _layout=ducks::rt_layout::row> using crt_bf = crt<bf16, _rows, _cols, _layout>;
|
||||
template<int _rows, int _cols, typename _layout=ducks::rt_layout::row> using crt_hf = crt<half, _rows, _cols, _layout>;
|
||||
|
||||
|
||||
}
|
||||
|
||||
97
extra/thunder/include/types/register/crv.metal
Normal file
@@ -0,0 +1,97 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief Register vectors for computations on axes.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "../../common/common.metal"
|
||||
#include "rv_layout.metal"
|
||||
#include "rv.metal"
|
||||
|
||||
namespace mittens {
|
||||
|
||||
/* ---------- MAIN VECTOR STRUCT ---------- */
|
||||
|
||||
// helper struct for type inference
|
||||
namespace ducks {
|
||||
/**
|
||||
* @namespace crv
|
||||
*
|
||||
* @brief The namespace where concepts and abstract types for register vectors live.
|
||||
*/
|
||||
namespace crv {
|
||||
/**
|
||||
* @brief A dummy type used to identify complex register vectors.
*
* For a type to quack like a crv, it should define its identifier as ducks::crv::identifier.
* If a type quacks like ducks::crv::identifier, it will be treated as a crv by compiler checks.
|
||||
*/
|
||||
struct identifier {};
|
||||
}
|
||||
}
|
||||
/**
|
||||
* @brief Register vector structure.
|
||||
*
|
||||
* @tparam _T The packed data type used for the vector elements.
|
||||
* @tparam _length The length of the vector, in elements.
* @tparam _layout The register-vector layout used for the real and imaginary components.
|
||||
*
|
||||
* Register vectors are used to accumulate and map values across tiles. You can do computation
|
||||
* on them directly if you want, but they're not designed to be maximally efficient vectors
|
||||
* as they have substantial duplication and strange layouts to help them work efficiently with
|
||||
* the register layouts used by the tensor cores. Thundermittens wants you working with tiles
|
||||
* where possible!
|
||||
*/
|
||||
|
||||
template<typename _T, size_t _length, typename _layout=ducks::rv_layout::naive>
|
||||
struct crv {
|
||||
static_assert(ducks::is_rv_layout<_layout>(), "_layout must be a rv layout");
|
||||
static_assert(ducks::base_types::isT1Type<_T>(), "T must be float, bf16, or half");
|
||||
using identifier = ducks::crv::identifier;
|
||||
using component = rv<_T, _length, _layout>; ///< The real/imaginary component vector type.
|
||||
using layout = typename component::layout; ///< Layout of the matrix tile, ensures compatibility with the rv concepts
|
||||
|
||||
using T = typename component::T;
|
||||
using T2 = typename component::T2;
|
||||
using dtype = typename component::dtype; ///< Data type of the elements in the tile.
|
||||
|
||||
constant static constexpr int length = component::length;
|
||||
constant static constexpr int tiles = component::tiles;
|
||||
|
||||
// Real/imag tiles have same internal layout and size
|
||||
component real;
|
||||
component imag;
|
||||
};
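// Usage sketch (illustrative): crv mirrors crt one dimension down -- a pair of
// identically laid-out register vectors.
//
//   crv_fl<64> v;  // 64-element complex float vector, naive layout by default
//   // v.real and v.imag are the underlying rv components; maps and reductions
//   // on complex data are applied to each one.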
|
||||
|
||||
/* ---------- CONCEPTS ---------- */
|
||||
|
||||
namespace ducks {
|
||||
template <typename T>
|
||||
struct has_crv_identifier {
|
||||
static constant constexpr bool value = false; // Default case
|
||||
};
|
||||
|
||||
// Specialize for specific template instantiations of st
|
||||
template <typename _T, int _length, typename _layout>
|
||||
struct has_crv_identifier<mittens::crv<_T, _length, _layout>> {
|
||||
static constant constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename CRV>
|
||||
static constexpr bool is_complex_register_vector() {
|
||||
return has_crv_identifier<CRV>::value;
|
||||
}
|
||||
template <typename CRV>
|
||||
static constexpr void assert_complex_register_vector() {
|
||||
static_assert(is_complex_register_vector<CRV>(), "T must be a crv");
|
||||
}
|
||||
} // namespace ducks
|
||||
|
||||
template<int _l, typename layout=ducks::rv_layout::naive> using crv_fl = crv<float, _l, layout>;
|
||||
template<int _l, typename layout=ducks::rv_layout::naive> using crv_bf = crv<bf16, _l, layout>;
|
||||
template<int _l, typename layout=ducks::rv_layout::naive> using crv_hf = crv<half, _l, layout>;
|
||||
|
||||
|
||||
} // namespace mittens
|
||||
|
||||
15
extra/thunder/include/types/register/register.metal
Normal file
@@ -0,0 +1,15 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief An aggregate header file for all the register types defined by Thundermittens.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "crv.metal"
|
||||
#include "rv.metal"
|
||||
#include "rv_layout.metal"
|
||||
#include "crt.metal"
|
||||
#include "rt.metal"
|
||||
#include "rt_layout.metal"
|
||||
#include "rt_base.metal"
|
||||
|
||||
|
||||
129
extra/thunder/include/types/register/rt.metal
Normal file
@@ -0,0 +1,129 @@
|
||||
/**
|
||||
* @file
|
||||
* @brief The main Thundermittens register tile struct, where most computation happens.
|
||||
*/
|
||||
#pragma once // kinda done
|
||||
/*
|
||||
TODO:
|
||||
decide whether a column-major layout is genuinely unnecessary, or whether it still needs to be implemented
|
||||
*/
|
||||
#include <metal_stdlib>
|
||||
#include "../../common/common.metal"
|
||||
#include "rt_base.metal"
|
||||
#include "rv.metal"
|
||||
|
||||
/* ---------- MAIN TILE STRUCT ---------- */
|
||||
|
||||
|
||||
namespace mittens {
|
||||
/* ---------- MAIN TILE STRUCT ---------- */
|
||||
// helper struct for type inference
|
||||
namespace ducks {
|
||||
/**
|
||||
* @namespace rt
|
||||
*
|
||||
* @brief The namespace where concepts and abstract types for register tiles live.
|
||||
*/
|
||||
namespace rt {
|
||||
/**
|
||||
* @brief A dummy type used to identify register tiles.
|
||||
*
|
||||
* For a type to quack like an rt, it should define its identifier as ducks::rt::identifier.
|
||||
* If a type quacks like ducks::rt::identifier, it will be treated as an rt by compiler checks.
|
||||
*/
|
||||
struct identifier {};
|
||||
|
||||
} // namespace rt
|
||||
|
||||
} // namespace ducks
|
||||
|
||||
/**
|
||||
* @brief Main tile structure for manipulating data in registers.
|
||||
*
|
||||
* @tparam _T The data type used for the matrix elements.
|
||||
* @tparam _height The height of the tile in terms of the number of subtiles.
|
||||
* @tparam _width The width of the tile in terms of the number of subtiles.
|
||||
*
|
||||
* This structure is designed to handle matrix tiles in a flexible manner, allowing
|
||||
* for operations on tiles that are composed of smaller subtiles.
|
||||
*/
|
||||
template<typename _T, int _rows, int _cols, typename _layout=ducks::rt_layout::row>
|
||||
struct rt {
|
||||
using identifier = ducks::rt::identifier; ///< Type identifier for the rt structure.
|
||||
using layout = _layout;
|
||||
using T = typename base_types::packing<_T>::unpacked_type;
|
||||
static_assert(ducks::base_types::isT1Type<T>(), "T must be float, bf16, or half");
|
||||
static_assert(ducks::is_rt_layout<_layout>(), "T must be float, bf16, or half");
|
||||
using T2 = typename base_types::packing<_T>::packed_type;
|
||||
using dtype = T; ///< Data type of the elements in the tile.
|
||||
constant static constexpr int rows = _rows; ///< Total number of rows.
|
||||
static_assert(rows % rt_base<T, _layout>::tile_size == 0, "Rows must be divisible by the tile size");
|
||||
constant static constexpr int cols = _cols; ///< Total number of columns.
|
||||
static_assert(cols % rt_base<T, _layout>::tile_size == 0, "Columns must be divisible by the tile size");
|
||||
constant static constexpr int height = rows / rt_base<T, _layout>::tile_size; ///< Height in subtiles.
|
||||
constant static constexpr int width = cols / rt_base<T, _layout>::tile_size; ///< Width in subtiles.
|
||||
constant static constexpr int tile_size = rt_base<T, _layout>::tile_size; ///< Size of the base tile.
|
||||
constant static constexpr int num_elements = rt_base<T, _layout>::num_elements * width * height; ///< Total number of elements.
|
||||
constant static constexpr int elements_per_thread = rt_base<T, _layout>::elements_per_thread * width * height; ///< Elements handled per thread.
|
||||
constant static constexpr int packed_per_thread = rt_base<T, _layout>::packed_per_thread * width * height; ///< Packed elements per thread.
|
||||
constant static constexpr int packed_per_tile = rt_base<T, _layout>::packed_per_thread; ///< Packed elements per tile.
|
||||
|
||||
rt_base<dtype, _layout> tiles[height][width]; ///< The actual storage for the matrix tile, organized in subtiles.
|
||||
|
||||
using row_vec = rv<T, cols, typename rt_base<T, _layout>::row_vec_layout>; ///< A type representing a column vector for this tile.
|
||||
using col_vec = rv<T, rows, typename rt_base<T, _layout>::col_vec_layout>; ///< A type representing a column vector for this tile.
|
||||
};
|
||||
|
||||
|
||||
|
||||
namespace ducks{
|
||||
template <typename T>
|
||||
struct has_rt_identifier {
|
||||
static constant constexpr bool value = false; // Default case
|
||||
static constant constexpr bool is_row = false;
|
||||
static constant constexpr bool is_col = false;
|
||||
};
|
||||
|
||||
template <typename _T, int _rows, int _cols>
|
||||
struct has_rt_identifier<mittens::rt<_T, _rows, _cols, rt_layout::row>> {
|
||||
static constant constexpr bool value = true;
|
||||
static constant constexpr bool is_row = true; // Row-specific indicator
|
||||
static constant constexpr bool is_col = false;
|
||||
};
|
||||
|
||||
template <typename _T, int _rows, int _cols>
|
||||
struct has_rt_identifier<mittens::rt<_T, _rows, _cols, rt_layout::col>> {
|
||||
static constant constexpr bool value = true;
|
||||
static constant constexpr bool is_row = false;
|
||||
static constant constexpr bool is_col = true; // Col-specific indicator
|
||||
};
|
||||
|
||||
template <typename RT>
|
||||
static constexpr bool is_register_tile() {
|
||||
return has_rt_identifier<RT>::value;
|
||||
}
|
||||
|
||||
template <typename RT>
|
||||
static constexpr bool is_row_register_tile() {
|
||||
return has_rt_identifier<RT>::is_row;
|
||||
}
|
||||
|
||||
template <typename RT>
|
||||
static constexpr bool is_col_register_tile() {
|
||||
return has_rt_identifier<RT>::is_col;
|
||||
}
|
||||
|
||||
|
||||
template <typename RT>
|
||||
static constexpr void assert_register_tile() {
|
||||
static_assert(is_register_tile<RT>(), "T must be a rt");
|
||||
}
|
||||
}
|
||||
|
||||
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
||||
// layout and type wrappers
|
||||
|
||||
template<int _r, int _c, typename layout=ducks::rt_layout::row> using rt_fl = rt<float, _r, _c, layout>;
|
||||
template<int _r, int _c, typename layout=ducks::rt_layout::row> using rt_bf = rt<bf16, _r, _c, layout>;
|
||||
template<int _r, int _c, typename layout=ducks::rt_layout::row> using rt_hf = rt<half, _r, _c, layout>;
|
||||
} // namespace mittens
|
||||
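// --------------------------------------------------------------------------
// Illustrative sizing sketch (not part of the original header), assuming
// TILE_DIM == 8 as in the accompanying GEMM kernel:
//
//   rt_fl<32, 16> acc;    // 32x16 fp32 tile -> height 4, width 2,
//                         // i.e. a 4x2 grid of 8x8 simdgroup subtiles
//   zero(acc);            // tile ops take the tile by reference
//
//   // Reduction vector types derived from the (row-layout) tile:
//   rt_fl<32, 16>::row_vec col_sums;   // rv<float, 16, align layout>
//   rt_fl<32, 16>::col_vec row_sums;   // rv<float, 32, ortho layout>
// --------------------------------------------------------------------------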
84
extra/thunder/include/types/register/rt_base.metal
Normal file
@@ -0,0 +1,84 @@
/**
 * @file
 * @brief The basic 8x8 register tile on which larger register tiles are built.
 */
#pragma once // todo: col/row layout if needed
#include <metal_stdlib>

#include "../../common/common.metal"
#include "rt_layout.metal"
#include "rv_layout.metal"
namespace mittens {
/* ---------- BASE 8x8 SUBTILE STRUCT ---------- */
namespace ducks {
/**
 * @namespace rt_base
 *
 * @brief The namespace where concepts and abstract types for register base (8x8) tiles live.
 */
namespace rt_base {
/**
 * @brief A dummy type used to identify register base tiles.
 *
 * For a type to quack like an rt_base, it should define its identifier as ducks::rt_base::identifier.
 * If a type quacks like ducks::rt_base::identifier, it will be treated as an rt_base by compiler checks.
 */
struct identifier {};
} // namespace rt_base
template <typename T>
static constexpr bool is_register_tile_base() {
    return metal::is_same<typename T::identifier, ducks::rt_base::identifier>::value;
}
template <typename RT>
static constexpr void assert_register_tile_base() {
    static_assert(is_register_tile_base<RT>(), "T must be a rt_base");
}
} // namespace ducks

/**
 * @brief Basic tile structure for computation in registers.
 *
 * @tparam _T The data type used for the matrix elements.
 * @tparam _layout The layout of the base tile, either row-major or column-major.
 *
 * This type is primarily a utility for building larger tiles out of the
 * simdgroup matrix primitives and for managing layouts.
 *
 * In general, you probably want a row-major tile, unless you specifically want to call mma.
 */
template <typename _T, typename _layout>
struct rt_base {
    using identifier = ducks::rt_base::identifier; ///< Type identifier for the rt_base structure.
    using layout = _layout; ///< Layout of the matrix tile.
    static_assert(ducks::base_types::isT1Type<_T>(), "rt_base was provided an unsupported type");
    static_assert(ducks::is_rt_layout<layout>(), "rt_base was provided an unsupported layout");
    using T = typename base_types::packing<_T>::unpacked_type;
    using T2 = typename base_types::packing<_T>::packed_type;
    using dtype = T;

    static constant constexpr const int tile_size = mittens::TILE_DIM;
    static constant constexpr const int rows = tile_size;
    static constant constexpr const int cols = tile_size;
    static constant constexpr const int num_elements = rows*cols;
    static constant constexpr const int elements_per_thread = num_elements / mittens::SIMD_THREADS;

    static constant constexpr const int registers_per_thread = elements_per_thread;
    static constant constexpr const int packed_per_thread = elements_per_thread / base_types::packing<T2>::num();
    metal::simdgroup_matrix<dtype, mittens::TILE_DIM, mittens::TILE_DIM> data;

    using row_vec_layout = metal::conditional_t<metal::is_same_v<layout, ducks::rt_layout::row>, ducks::rv_layout::align, ducks::rv_layout::ortho>; // for holding column reductions
    using col_vec_layout = metal::conditional_t<metal::is_same_v<layout, ducks::rt_layout::row>, ducks::rv_layout::ortho, ducks::rv_layout::align>; // for holding row reductions
};

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

template<typename L=ducks::rt_layout::row> using rt_base_fl = rt_base<float, L>;
template<typename L=ducks::rt_layout::row> using rt_base_bf = rt_base<bf16, L>;
template<typename L=ducks::rt_layout::row> using rt_base_hf = rt_base<half, L>;

} // namespace mittens
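// --------------------------------------------------------------------------
// Worked numbers (illustrative, not part of the original header), assuming
// TILE_DIM == 8 and SIMD_THREADS == 32:
//
//   rt_base<float, ducks::rt_layout::row> b;
//   // rows == cols == 8, num_elements == 64,
//   // elements_per_thread == 64 / 32 == 2, all backed by the single
//   // metal::simdgroup_matrix<float, 8, 8> in `data`.
//   // A larger rt<float, 32, 16> is therefore a 4x2 array of these subtiles.
// --------------------------------------------------------------------------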
45
extra/thunder/include/types/register/rt_layout.metal
Normal file
@@ -0,0 +1,45 @@
/**
 * @file
 * @brief Layouts and their manipulations for register tiles.
 */

#pragma once

namespace mittens {
namespace ducks {
/**
 * @namespace rt_layout
 *
 * @brief A namespace for template metaprogramming with register tile layouts.
 */
namespace rt_layout {

/**
 * @brief A dummy type used to identify a row-major layout for a register tile.
 */
struct row {}; // for most matrices
/**
 * @brief A dummy type used to identify a col-major layout for a register tile.
 */
struct col {}; // for the B-matrix of MMA ops.

template<typename l> struct transpose { using type = rt_layout::col; };
template<> struct transpose<rt_layout::col> { using type = rt_layout::row; };
} // namespace rt_layout
template <typename _layout>
METAL_FUNC static constexpr bool is_row_layout() {
    return metal::is_same_v<_layout, rt_layout::row>;
}
template <typename _layout>
METAL_FUNC static constexpr bool is_col_layout() {
    return metal::is_same_v<_layout, rt_layout::col>;
}
template <typename _layout>
METAL_FUNC static constexpr bool is_rt_layout() {
    return is_row_layout<_layout>() || is_col_layout<_layout>();
}

} // namespace ducks
} // namespace mittens
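// --------------------------------------------------------------------------
// Illustrative use of the transpose metafunction (not part of the original
// header): it flips row <-> col, with the primary template mapping anything
// that is not `col` to `col`.
//
//   using A  = ducks::rt_layout::row;
//   using At = ducks::rt_layout::transpose<A>::type;   // rt_layout::col
//   static_assert(ducks::is_col_layout<At>(), "transpose of row is col");
// --------------------------------------------------------------------------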
125
extra/thunder/include/types/register/rv.metal
Normal file
@@ -0,0 +1,125 @@
/**
 * @file
 * @brief Register vectors for computations on axes.
 */
#pragma once
#include "../../common/common.metal"
#include "rv_layout.metal"
namespace mittens {
/* ---------- MAIN VECTOR STRUCT ---------- */

// helper struct for type inference
namespace ducks {
/**
 * @namespace rv
 *
 * @brief The namespace where concepts and abstract types for register vectors live.
 */
namespace rv {
/**
 * @brief A dummy type used to identify register vectors.
 *
 * For a type to quack like an rv, it should define its identifier as ducks::rv::identifier.
 * If a type quacks like ducks::rv::identifier, it will be treated as an rv by compiler checks.
 */
struct identifier {};
} // namespace rv
} // namespace ducks

/**
 * @brief Register vector structure.
 *
 * @tparam _T The data type used for the vector elements.
 * @tparam _length The length of the vector in elements (a multiple of TILE_DIM).
 * @tparam _layout The layout of the vector, which controls how it maps onto the register tile layout.
 *
 * Register vectors are used to accumulate and map values across tiles. You can do computation
 * on them directly if you want, but they're not designed to be maximally efficient vectors
 * as they have substantial duplication and strange layouts to help them work efficiently with
 * the register layouts used by the tensor cores. Thundermittens wants you working with tiles
 * where possible!
 */
template<typename _T, size_t _length, typename _layout>
struct rv {
    using identifier = ducks::rv::identifier; ///< Type identifier for the rv structure.

    static_assert(ducks::is_rv_layout<_layout>(), "_layout must be a rv layout");
    static_assert(ducks::base_types::isT1Type<_T>(), "T must be float, bf16, or half");
    using layout = _layout;
    constant static constexpr bool is_naive = ducks::is_naive_layout<layout>();
    using T = typename mittens::base_types::packing<_T>::unpacked_type;
    using T2 = typename mittens::base_types::packing<_T>::packed_type;
    using dtype = T; ///< Data type of the vector elements

    constant static constexpr int length = _length; ///< Length in elements.
    static_assert(length % mittens::TILE_DIM == 0, "Length must be divisible by the tile dimension");
    constant static constexpr int tiles = _length / mittens::TILE_DIM; ///< Length in subtiles, aliased for consistency with sv type
    constant static constexpr int inner_dim = layout::inner_dim; ///< Internal layout within a subtile. Either 1 or 2.
    constant static constexpr int outer_dim = is_naive ? (tiles+3)/4 : tiles; ///< Outer dim (also length in tiles)
    dtype data[outer_dim][inner_dim]; ///< The actual register vector data.

    METAL_FUNC thread dtype* operator[](size_t idx) { return &data[idx][0]; } ///< A wrapper for indexing into vector data.
    METAL_FUNC thread const dtype* operator[](size_t idx) const { return &data[idx][0]; } ///< A wrapper for indexing into vector data.
    METAL_FUNC thread dtype& operator[](int2 outin) { return data[outin.x][outin.y]; } ///< A wrapper for indexing into vector data.
    METAL_FUNC thread const dtype& operator[](int2 outin) const { return data[outin.x][outin.y]; } ///< A wrapper for indexing into vector data.
};

namespace ducks {
template <typename T>
struct has_rv_align_identifier {
    static constant constexpr bool value = false; // Default case
};
template <typename _T, size_t _length>
struct has_rv_align_identifier<mittens::rv<_T, _length, ducks::rv_layout::align>> {
    static constant constexpr bool value = true;
};
template <typename RV>
static constexpr bool is_align_register_vector() {
    return has_rv_align_identifier<RV>::value;
}

template <typename T>
struct has_rv_ortho_identifier {
    static constant constexpr bool value = false; // Default case
};
template <typename _T, size_t _length>
struct has_rv_ortho_identifier<mittens::rv<_T, _length, ducks::rv_layout::ortho>> {
    static constant constexpr bool value = true;
};

template <typename RV>
static constexpr bool is_ortho_register_vector() {
    return has_rv_ortho_identifier<RV>::value;
}

template <typename T>
struct has_rv_naive_identifier {
    static constant constexpr bool value = false; // Default case
};
template <typename _T, size_t _length>
struct has_rv_naive_identifier<mittens::rv<_T, _length, ducks::rv_layout::naive>> {
    static constant constexpr bool value = true;
};
template <typename RV>
static constexpr bool is_naive_register_vector() {
    return has_rv_naive_identifier<RV>::value;
}

template <typename RV>
static constexpr bool is_register_vector() {
    return is_align_register_vector<RV>() || is_ortho_register_vector<RV>() || is_naive_register_vector<RV>();
}

template <typename RV>
static constexpr void assert_register_vector() {
    static_assert(is_register_vector<RV>(), "T must be a rv");
}
} // namespace ducks
template<int _l, typename layout=ducks::rv_layout::naive> using rv_fl = rv<float, _l, layout>;
template<int _l, typename layout=ducks::rv_layout::naive> using rv_bf = rv<bf16, _l, layout>;
template<int _l, typename layout=ducks::rv_layout::naive> using rv_hf = rv<half, _l, layout>;

} // namespace mittens
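// --------------------------------------------------------------------------
// Illustrative storage shapes (not part of the original header), assuming
// TILE_DIM == 8, for a 32-element vector (tiles == 4):
//
//   rv_fl<32, ducks::rv_layout::naive> n;  // inner_dim 1, outer_dim (4+3)/4 == 1 -> data[1][1]
//   rv_fl<32, ducks::rv_layout::ortho> o;  // inner_dim 1, outer_dim 4            -> data[4][1]
//   rv_fl<32, ducks::rv_layout::align> a;  // inner_dim 2, outer_dim 4            -> data[4][2]
//
//   a[int2(2, 1)] = 0.0f;                  // subtile 2, replicated lane 1
// --------------------------------------------------------------------------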
54
extra/thunder/include/types/register/rv_layout.metal
Normal file
@@ -0,0 +1,54 @@
/**
 * @file
 * @brief Layouts and their manipulations for register vectors.
 */

#pragma once

namespace mittens {
namespace ducks {
/**
 * @namespace rv_layout
 *
 * @brief A namespace for template metaprogramming with register vector layouts.
 */
namespace rv_layout {

/**
 * @brief A dummy type used to identify an aligned (8x replicated) layout.
 */
struct align { constant constexpr static int inner_dim = 2; };
/**
 * @brief A dummy type used to identify an orthogonal (4x replicated) layout.
 */
struct ortho { constant constexpr static int inner_dim = 1; };
/**
 * @brief A dummy type used to identify an unreplicated layout, for better coalesced loads and vector operations like layernorm.
 */
struct naive { constant constexpr static int inner_dim = 1; };

} // namespace rv_layout

template <typename _layout>
METAL_FUNC static constexpr bool is_align_layout() {
    return metal::is_same_v<_layout, rv_layout::align>;
}
template <typename _layout>
METAL_FUNC static constexpr bool is_ortho_layout() {
    return metal::is_same_v<_layout, rv_layout::ortho>;
}
template <typename _layout>
METAL_FUNC static constexpr bool is_naive_layout() {
    return metal::is_same_v<_layout, rv_layout::naive>;
}
template <typename _layout>
METAL_FUNC static constexpr bool is_rv_layout() {
    return is_align_layout<_layout>() || is_ortho_layout<_layout>() || is_naive_layout<_layout>();
}

} // namespace ducks
} // namespace mittens
94
extra/thunder/include/types/shared/cst.metal
Normal file
@@ -0,0 +1,94 @@
/**
 * @file
 * @brief Abstraction for a complex shared tile composed of real and imaginary tiles
 */

#pragma once

#include "st.metal"
#include "csv.metal"
namespace mittens {
namespace ducks {
namespace cst {
/**
 * @brief A dummy type used to identify complex shared tiles.
 *
 * For a type to quack like a cst, it should define its identifier as ducks::cst::identifier.
 * If a type quacks like ducks::cst::identifier, it will be treated as a cst by compiler checks.
 */
struct identifier {};
} // namespace cst
} // namespace ducks

/**
 * @brief Complex shared tile structure
 *
 * @tparam _T The data type used for the matrix elements.
 * @tparam _rows The number of rows in each component tile, in elements.
 * @tparam _cols The number of columns in each component tile, in elements.
 *
 * This structure is designed to abstract complex number operations internally to the real and imaginary
 * shared tiles, respectively.
 */
template<typename _T, int _rows, int _cols>
struct cst {
    using identifier = ducks::cst::identifier;
    using component = st<_T, _rows, _cols>; ///< Type of the real and imaginary component tiles.
    using T = typename component::T;
    using T2 = typename component::T2;
    using dtype = typename component::dtype; ///< Data type of the elements in the tile.

    constant static constexpr int rows = component::rows;
    constant static constexpr int cols = component::cols;
    constant static constexpr int height = component::height;
    constant static constexpr int width = component::width;

    // todo: fill in the rest for convenience, but they're all accessible via component so it's not urgent.

    // Real/imag tiles have same internal layout and size
    component real;
    component imag;

    // vector types
    using col_vec = csv<dtype, rows>;
    using row_vec = csv<dtype, cols>;
};

/* ---------- CONCEPTS ---------- */

namespace ducks {
template <typename T>
struct has_cst_identifier {
    static constant constexpr bool value = false; // Default case
};

// Specialize for specific template instantiations of cst
template <typename _T, int _height, int _width>
struct has_cst_identifier<mittens::cst<_T, _height, _width>> {
    static constant constexpr bool value = true;
};

template <typename CST>
static constexpr bool is_complex_shared_tile() {
    return has_cst_identifier<CST>::value;
}
template <typename CST>
static constexpr void assert_complex_shared_tile() {
    static_assert(is_complex_shared_tile<CST>(), "T must be a cst");
}

} // namespace ducks

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

template<int _rows, int _cols> using cst_bf = cst<bf16, _rows, _cols>;
template<int _rows, int _cols> using cst_hf = cst<half, _rows, _cols>;
template<int _rows, int _cols> using cst_fl = cst<float, _rows, _cols>;

} // namespace mittens
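// --------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header): a complex
// shared tile is just two ordinary shared tiles living side by side.
//
//   threadgroup cst_fl<32, 32> z;      // .real and .imag are each st<float, 32, 32>
//   z.real[int2(0, 0)] = 1.0f;         // element access goes through the component st's
//   z.imag[int2(0, 0)] = 0.0f;
//   // col_vec / row_vec are csv<float, 32>, mirroring st's sv-based vectors.
// --------------------------------------------------------------------------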
86
extra/thunder/include/types/shared/csv.metal
Normal file
@@ -0,0 +1,86 @@
/**
 * @file
 * @brief Abstraction for a complex shared vector composed of real and imaginary vectors
 */

#pragma once

#include "st.metal"

namespace mittens {
namespace ducks {
namespace csv {
/**
 * @brief A dummy type used to identify complex shared vectors.
 *
 * For a type to quack like a csv, it should define its identifier as ducks::csv::identifier.
 * If a type quacks like ducks::csv::identifier, it will be treated as a csv by compiler checks.
 */
struct identifier {};
} // namespace csv
} // namespace ducks

/**
 * @brief Complex shared vector structure
 *
 * @tparam _T The data type used for the vector elements.
 * @tparam _length The length of each component vector, in elements.
 *
 * This structure is designed to abstract complex number operations internally to the real and imaginary
 * shared vectors, respectively.
 */
template<typename _T, int _length>
struct csv {
    using identifier = ducks::csv::identifier;
    using component = sv<_T, _length>; ///< Type of the real and imaginary component vectors.
    using T = typename component::T;
    using T2 = typename component::T2;
    using dtype = typename component::dtype; ///< Data type of the elements in the vector.

    constant static constexpr int length = component::length;
    constant static constexpr int tiles = component::tiles;

    // todo: fill in the rest for convenience, but they're all accessible via component so it's not urgent.

    // Real/imag vectors have same internal layout and size
    component real;
    component imag;
};

/* ---------- CONCEPTS ---------- */

namespace ducks {
template <typename T>
struct has_csv_identifier {
    static constant constexpr bool value = false; // Default case
};

// Specialize for specific template instantiations of csv
template <typename _T, int _length>
struct has_csv_identifier<mittens::csv<_T, _length>> {
    static constant constexpr bool value = true;
};

template <typename CSV>
static constexpr bool is_complex_shared_vector() {
    return has_csv_identifier<CSV>::value;
}
template <typename CSV>
static constexpr void assert_complex_shared_vector() {
    static_assert(is_complex_shared_vector<CSV>(), "T must be a csv");
}
} // namespace ducks

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

template<int _length> using csv_bf = csv<bf16, _length>;
template<int _length> using csv_hf = csv<half, _length>;
template<int _length> using csv_fl = csv<float, _length>;

} // namespace mittens
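// --------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header):
//
//   threadgroup csv_fl<64> zv;     // .real and .imag are each sv<float, 64>
//   zv.real[0] = 1.0f;             // shared vectors are plain arrays,
//   zv.imag[0] = 0.0f;             // so indexing is direct
// --------------------------------------------------------------------------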
10
extra/thunder/include/types/shared/shared.metal
Normal file
@@ -0,0 +1,10 @@
/**
 * @file
 * @brief An aggregate header file for all the shared types defined by Thundermittens.
 */

#pragma once
#include "st.metal"
#include "sv.metal"
#include "cst.metal"
#include "csv.metal"
379
extra/thunder/include/types/shared/st.metal
Normal file
@@ -0,0 +1,379 @@
/**
 * @file
 * @brief The Thundermittens shared tile struct.
 */

#pragma once // kinda done

/*
add subtile, make it work
*/
#include <metal_stdlib>
#include "../../common/common.metal"
#include "sv.metal"
/* ---------- MAIN TILE STRUCT ---------- */

// these are helper structs for type inference
namespace mittens {

namespace ducks {
/**
 * @namespace st
 *
 * @brief The namespace where concepts and abstract types for shared tiles live.
 */
namespace st {
/**
 * @brief A dummy type used to identify shared tiles.
 *
 * For a type to quack like an st, it should define its identifier as ducks::st::identifier.
 * If a type quacks like ducks::st::identifier, it will be treated as an st by compiler checks.
 * This is particularly useful for subtiles.
 */
struct identifier {};
} // namespace st

} // namespace ducks

// Forward declaration of subtile
template<
    typename ST,
    int _subtile_height,
    int _subtile_width
>
struct st_subtile;

/**
 * @brief Shared memory tile structure for various data types and layouts.
 *
 * @tparam _T The data type of the elements in the tile. Not packed!
 * @tparam _rows The number of rows in the tile, in elements (a multiple of TILE_DIM).
 * @tparam _cols The number of columns in the tile, in elements (a multiple of TILE_DIM).
 */
template<typename _T, int _rows, int _cols>
struct mittens_DEFAULT_ALIGN st {
    using identifier = ducks::st::identifier; ///< Type identifier for the st structure.
    using T = typename base_types::packing<_T>::unpacked_type;
    using T2 = typename base_types::packing<_T>::packed_type;
    using dtype = T; ///< Data type of the elements in the tile.
    static_assert(base_types::packing<dtype>::num() == 1, "st type must be 1-packed (float, bf16, etc)"); // must be a 1-packed type (e.g. float, bf16, etc)
    // define underlying data as same as that projected, to make clear that this is *not* a subtile.
    static constant constexpr const int underlying_rows = _rows;
    static constant constexpr const int underlying_cols = _cols;
    static constant constexpr const int underlying_height = _rows / TILE_DIM;
    static constant constexpr const int underlying_width = _cols / TILE_DIM;
    static constant constexpr const int underlying_num_elements = underlying_rows * underlying_cols;

    static constant constexpr const int rows = _rows; ///< Total number of rows in the tile.
    static_assert(rows % TILE_DIM == 0, "Rows must be divisible by the tile dimension");
    static constant constexpr const int cols = _cols; ///< Total number of cols in the tile.
    static_assert(cols % TILE_DIM == 0, "Cols must be divisible by the tile dimension");
    static constant constexpr const int height = _rows / TILE_DIM; ///< Height of the tile in terms of TILE_DIM-element subtiles.
    static constant constexpr const int width = _cols / TILE_DIM; ///< Width of the tile in terms of TILE_DIM-element subtiles.

    static constant constexpr const int num_elements = rows * cols; ///< Total number of elements in the tile.
    // static constant constexpr const int row_incr = 32 / memcpy_per_row;

    dtype data[rows*cols]; ///< Raw data storage for the tile.

    /* ---------- static vars ---------- */
    // /* static METAL_FUNC threadgroup float* idx(threadgroup float *ptr, int r, int c)*/
    static constant constexpr const int swizzle_bytes = underlying_width % 4 == 0 ? 128 : underlying_width % 2 == 0 ? 64 : 32;
    static constant constexpr const int swizzle_repeat = swizzle_bytes * 8;
    static constant constexpr const int subtile_cols = swizzle_bytes / sizeof(T);

    static constant constexpr const int subtile_cols_log2 = (swizzle_bytes == 128) ? 5 : (swizzle_bytes == 64) ? 4 : 3;
    static constant constexpr const int subtile_cols_mask = subtile_cols - 1;
    static constant constexpr int swizzle_mask = swizzle_repeat - 1;
    static constant constexpr int swizzle_offset_shift = 7;
    static constant constexpr int swizzle_adjust_shift = 4;
    static constant constexpr int mask = (swizzle_repeat - 1) >> swizzle_offset_shift;

    // static constant constexpr const int load_block_bytes = 8;
    static constant constexpr const int load_block_words = 4;
    // static constant constexpr const int load_block_words = 2;
    static constant constexpr const int col_load_block_words = cols / load_block_words;
    static constant constexpr const int load_block_words_mask = load_block_words - 1;

    static METAL_FUNC threadgroup T* idx(threadgroup T * __restrict ptr, int2 coord) { // naive row-major index default
        int r = coord.x, c = coord.y;
        return ptr + r * underlying_cols + c;
        //
        // c = (c + ((r / 2) * 8)) % cols;
        // return ptr + r * underlying_cols + c;
        //// CORRECT 0.124 | 0.168
        // const int outer_idx = c/subtile_cols;
        // const uint64_t addr = (uint64_t)(&ptr[outer_idx*rows*subtile_cols + r*subtile_cols + c%subtile_cols]);
        // const int swizzle = ((addr % swizzle_repeat) >> 7) << 4;
        // return (threadgroup T*)(addr ^ swizzle);

        // const int outer_idx = c/subtile_cols;
        // ptr = &ptr[outer_idx*rows*subtile_cols + r*subtile_cols + c%subtile_cols];
        // const int swizzle = (((uintptr_t)ptr % swizzle_repeat) >> 7) << 4;
        // return (threadgroup T*)((uintptr_t)ptr ^ swizzle);
        ////
        //// CORRECT 0.097 | 0.120
        // int idx = (((c >> subtile_cols_log2) * rows + r) << subtile_cols_log2) + (c & subtile_cols_mask);
        // // Compute address in bytes (since ptr is float*, multiply idx by sizeof(float) = 4)
        // int addr_bytes = idx << 2; // Equivalent to idx * 4
        // // Compute swizzle without modulo operation
        // int swizzle = (((addr_bytes & swizzle_mask) >> 7) << 4);
        // // Compute final swizzled address
        // return (threadgroup T*)((threadgroup char*)ptr + (addr_bytes ^ swizzle));
        //
        //// CORRECT ____ | 0.169
        // int idx = (((c >> subtile_cols_log2) * rows + r) << subtile_cols_log2) + (c & subtile_cols_mask);
        //
        // // Compute address in bytes (since ptr is float*, multiply idx by sizeof(float) = 4)
        // uint64_t addr_bytes = ((uint64_t)ptr) + ((uint64_t)idx << 2); // Full address in bytes
        //
        // // Compute swizzle including the base address
        // int swizzle = ((addr_bytes % swizzle_repeat) >> 7) << 4;
        //
        // // Compute final swizzled address
        // addr_bytes ^= swizzle;
        //
        // // Return the swizzled address
        // return (threadgroup float*)(addr_bytes);
        //
    }
    static METAL_FUNC uint32_t idx(uint32_t ptr, int2 coord) { // naive row-major index
        int r = coord.x, c = coord.y; // alias
        return ptr + sizeof(T) * (r * underlying_cols + c);

        // c = (c + ((r / 2) * 8)) % cols;
        // return ptr + r * underlying_cols + c;
        // return ptr + sizeof(T) * (r * underlying_cols + c);
    }
    /**
     * @brief Access a shared tile element using a row and column, as if the tile were row-major.
     *
     * This is the preferred way to access memory within a shared tile, which abstracts
     * indexing calculations for swizzled layouts.
     */
    METAL_FUNC threadgroup T& operator[](thread const int2& rowcol) threadgroup {
        return *idx(data, rowcol);
    }
    METAL_FUNC const threadgroup T& operator[](thread const int2 &rowcol) const threadgroup {
        return *(const threadgroup T*)idx((threadgroup T*)data, rowcol);
    }

    METAL_FUNC threadgroup T& operator[](int idx) threadgroup {
        return data[idx];
    }
    METAL_FUNC const threadgroup T& operator[](int idx) const threadgroup {
        return data[idx];
    }

    using col_vec = sv<dtype, rows>; ///< Column vector type for this tile
    using row_vec = sv<dtype, cols>; ///< Row vector type for this tile
    template<int subtile_rows, int subtile_cols> using subtile = st_subtile<
        st<T, rows, cols>, subtile_rows, subtile_cols
    >; ///< A templated subtile type wrapper for this tile.
};

/**
 * @brief A reference into a chunk of shared tile memory.
 *
 * The st_subtile is a drop-in replacement for an st which internally
 * references the appropriate memory while performing minimal address
 * calculations. You should never create this directly, but instead
 * have subtile_inplace return it for you. (`auto` is nice.)
 *
 * You can generally just pretend this is an st. But not for wgmma's.
 */
template<
    typename _ST,
    int _subtile_rows,
    int _subtile_cols
>
struct st_subtile {
    using identifier = ducks::st::identifier; // i quack like an st, gcc will never know the difference
    using ST = _ST;
    using T = typename ST::T;
    using T2 = typename ST::T2;
    using dtype = T; ///< Data type of the elements in the tile.

    constant static constexpr int underlying_rows = ST::underlying_rows;
    static_assert(underlying_rows % TILE_DIM == 0, "Underlying rows must be divisible by the tile dimension");
    constant static constexpr int underlying_cols = ST::underlying_cols;
    static_assert(underlying_cols % TILE_DIM == 0, "Underlying cols must be divisible by the tile dimension");
    constant static constexpr int underlying_height = ST::underlying_height;
    constant static constexpr int underlying_width = ST::underlying_width;
    constant static constexpr int underlying_num_elements = ST::underlying_num_elements;

    constant static constexpr int rows = _subtile_rows;
    static_assert(rows % TILE_DIM == 0, "Rows must be divisible by the tile dimension");
    constant static constexpr int cols = _subtile_cols;
    static_assert(cols % TILE_DIM == 0, "Cols must be divisible by the tile dimension");
    constant static constexpr int height = rows / TILE_DIM;
    constant static constexpr int width = cols / TILE_DIM;
    constant static constexpr int num_elements = rows * cols;

    // constant static constexpr int swizzle_bytes = ST::swizzle_bytes;

    // device dtype *data;
    threadgroup T* data;
    int row_offset, col_offset;

    // METAL_FUNC st_subtile(threadgroup ST &src, int2 rowcol) {
    //     data = reinterpret_cast<uint64_t>(&src.data[0]);
    //     row_offset = rowcol.x * rows;
    //     col_offset = rowcol.y * cols;
    // }
    // void METAL_FUNC init_subtile(threadgroup ST &src, int2 rowcol) {
    //     // data = &(src.data[0]);
    //     row_offset = rowcol.x * rows;
    //     col_offset = rowcol.y * cols;
    // }
    template<typename SUBTILE, typename ST>
    static void METAL_FUNC init_subtile(threadgroup SUBTILE& sub_st, threadgroup ST& src, int2 rowcol) {
        sub_st.data = (threadgroup T*)src.data;
        sub_st.row_offset = rowcol.x * rows;
        sub_st.col_offset = rowcol.y * cols;
    }

    template<typename SUBTILE, typename ST>
    static void METAL_FUNC init_subtile(thread SUBTILE& sub_st, threadgroup ST& src, int2 rowcol) {
        sub_st.data = (threadgroup T*)src.data;
        sub_st.row_offset = rowcol.x * rows;
        sub_st.col_offset = rowcol.y * cols;
    }

    // METAL_FUNC threadgroup T* idx(threadgroup T *ptr, const int2 coord) { // naive row-major index default
    //     int r = coord.x+row_offset, c = coord.y+col_offset; // alias
    //     return ptr + r * underlying_cols + c;
    // }
    // // Add this const overload of idx
    // METAL_FUNC const threadgroup T* idx(const threadgroup T *ptr, const int2 coord) const {
    //     int r = coord.x + row_offset, c = coord.y + col_offset;
    //     return ptr + r * underlying_cols + c;
    // }
    //
    // METAL_FUNC uint32_t idx(uint32_t ptr, const int2 coord) const { // naive row-major index default
    //     int r = coord.x+row_offset, c = coord.y+col_offset; // alias
    //     return ptr + sizeof(T) * (r * underlying_cols + c);
    // }
    // METAL_FUNC threadgroup T& operator[](thread const int2 &rowcol) threadgroup {
    //     return *idx(data, rowcol);
    // }
    // METAL_FUNC const threadgroup T& operator[](thread const int2 &rowcol) const threadgroup {
    //     return *idx(data, rowcol);
    // }
    // Declare idx as a const member function
    // METAL_FUNC threadgroup T* idx(threadgroup T * __restrict ptr, const int2 coord) const {
    //     int r = coord.x + row_offset, c = coord.y + col_offset;
    //     return ptr + r * underlying_cols + c;
    // }
    //
    // // New idx function (const overload)
    // METAL_FUNC uint32_t idx(uint32_t ptr, int2 coord) {
    //     int r = coord.x + row_offset, c = coord.y + col_offset;
    //     return ptr + r * underlying_cols + c;
    // }
    //
    // // Non-const operator[]
    // METAL_FUNC threadgroup T& operator[](thread const int2& rowcol) threadgroup {
    //     return *idx(data, rowcol);
    // }
    //
    // // Const operator[]
    // METAL_FUNC const threadgroup T& operator[](thread const int2 &rowcol) threadgroup const {
    //     return *idx(data, rowcol);
    // }
    // idx function returning threadgroup T*
    METAL_FUNC threadgroup T* idx(threadgroup T * __restrict ptr, const int2 coord) threadgroup const {
        int r = coord.x + row_offset, c = coord.y + col_offset;
        return ptr + r * underlying_cols + c;
    }

    // idx function returning uint32_t
    METAL_FUNC uint32_t idx(uint32_t ptr, int2 coord) threadgroup const {
        int r = coord.x + row_offset, c = coord.y + col_offset;
        return ptr + r * underlying_cols + c;
    }

    // Non-const operator[]
    METAL_FUNC threadgroup T& operator[](thread const int2& rowcol) threadgroup {
        return *idx(data, rowcol);
    }

    // Const operator[]
    METAL_FUNC const threadgroup T& operator[](thread const int2 &rowcol) threadgroup const {
        return *idx(data, rowcol);
    }

    METAL_FUNC threadgroup T* idx(threadgroup T * __restrict ptr, const int2 coord) thread const {
        int r = coord.x + row_offset, c = coord.y + col_offset;
        return ptr + r * underlying_cols + c;
    }

    // idx function returning uint32_t
    METAL_FUNC uint32_t idx(uint32_t ptr, int2 coord) thread const {
        int r = coord.x + row_offset, c = coord.y + col_offset;
        return ptr + r * underlying_cols + c;
    }

    // Non-const operator[]
    METAL_FUNC threadgroup T& operator[](thread const int2& rowcol) thread {
        return *idx(data, rowcol);
    }

    // Const operator[]
    METAL_FUNC const threadgroup T& operator[](thread const int2 &rowcol) thread const {
        return *idx(data, rowcol);
    }

    // single-index operator[] is left undefined as it would likely be an improper use of st_subtile type.
    // can of course be end-run by just accessing .data directly.
};

namespace ducks {
template <typename T>
struct has_st_identifier {
    static constant constexpr bool value = false; // Default case
};

// Specialize for specific template instantiations of st
template <typename _T, int _height, int _width>
struct has_st_identifier<mittens::st<_T, _height, _width>> {
    static constant constexpr bool value = true;
};

template<typename _T, int _subtile_rows, int _subtile_cols>
struct has_st_identifier<mittens::st_subtile<_T, _subtile_rows, _subtile_cols>> {
    static constant constexpr bool value = true;
};

template <typename ST>
static constexpr bool is_shared_tile() {
    return has_st_identifier<ST>::value;
}
template <typename ST>
static constexpr void assert_shared_tile() {
    static_assert(is_shared_tile<ST>(), "T must be a st");
}
} // namespace ducks

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// layout and type wrappers
template<int _height, int _width> using st_bf = st<bf16, _height, _width>;
template<int _height, int _width> using st_hf = st<half, _height, _width>;
template<int _height, int _width> using st_fl = st<float, _height, _width>;
} // namespace mittens
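// --------------------------------------------------------------------------
// Worked indexing example (illustrative, not part of the original header),
// assuming TILE_DIM == 8:
//
//   threadgroup st_fl<64, 64> tile;    // underlying_width = 64/8 = 8, so
//                                      // swizzle_bytes = 128 (8 % 4 == 0)
//   // The active idx() path is plain row-major:
//   //   &tile[int2(r, c)] == tile.data + r*underlying_cols + c
//   tile[int2(3, 5)] = 1.0f;           // element at row 3, column 5
//
//   // A subtile is a view plus offsets into the same memory, not a copy:
//   st_fl<64, 64>::subtile<32, 32> quad;
//   decltype(quad)::init_subtile(quad, tile, int2(1, 0));  // rows 32..63, cols 0..31
//   quad[int2(0, 0)] = 2.0f;           // same element as tile[int2(32, 0)]
// --------------------------------------------------------------------------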
86
extra/thunder/include/types/shared/sv.metal
Normal file
@@ -0,0 +1,86 @@
/**
 * @file
 * @brief The Thundermittens shared vector struct.
 */

#pragma once
#include "../../common/common.metal"
#include <metal_stdlib>
namespace mittens {
namespace ducks {
/**
 * @namespace sv
 *
 * @brief The namespace where concepts and abstract types for shared vectors live.
 */
namespace sv {
/**
 * @brief A dummy type used to identify shared vectors.
 *
 * For a type to quack like an sv, it should define its identifier as ducks::sv::identifier.
 * If a type quacks like ducks::sv::identifier, it will be treated as an sv by compiler checks.
 */
struct identifier {};
} // namespace sv
} // namespace ducks

/**
 * @brief Shared vector structure.
 *
 * @tparam _T The data type used for the vector elements. Not packed!
 * @tparam _length The length of the vector in elements (a multiple of TILE_DIM).
 *
 * Shared vectors are used to accumulate and map values across shared tiles.
 * Unlike every other structure present in Thundermittens, these have a simple
 * uniform layout which is just an array in memory. EZ!
 */
template<typename _T, size_t _length>
struct mittens_DEFAULT_ALIGN sv {
    using identifier = ducks::sv::identifier;
    using T = typename base_types::packing<_T>::unpacked_type;
    using T2 = typename base_types::packing<_T>::packed_type;
    using dtype = T; ///< Data type of the elements in the vector.

    constant static constexpr int length = _length; ///< Length in elements.
    static_assert(length % TILE_DIM == 0, "Length must be divisible by the tile dimension");
    constant static constexpr int tiles = length / TILE_DIM; ///< Length in subtiles.

    dtype data[length]; ///< The actual shared vector data.

    METAL_FUNC threadgroup dtype& operator[](size_t idx) threadgroup { return data[idx]; }
    METAL_FUNC const threadgroup dtype& operator[](size_t idx) const threadgroup { return data[idx]; }

    template<size_t _len> using subvec = sv<dtype, _len>;
};

namespace ducks {
template <typename T>
struct has_sv_identifier {
    static constant constexpr bool value = false; // Default case
};

// Specialize for specific template instantiations of sv
template<typename _T, size_t _length>
struct has_sv_identifier<mittens::sv<_T, _length>> {
    static constant constexpr bool value = true;
};

template <typename SV>
static constexpr bool is_shared_vector() {
    return has_sv_identifier<SV>::value;
}
template <typename SV>
static constexpr void assert_shared_vector() {
    static_assert(is_shared_vector<SV>(), "T must be a sv");
}
} // namespace ducks

template<size_t _length> using sv_bf = sv<bfloat, _length>;
template<size_t _length> using sv_hf = sv<half  , _length>;
template<size_t _length> using sv_fl = sv<float , _length>;
} // namespace mittens
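// --------------------------------------------------------------------------
// Illustrative usage sketch (not part of the original header), assuming
// TILE_DIM == 8:
//
//   threadgroup sv_fl<32> acc;                    // length 32, tiles == 4, stored as data[32]
//   acc[0] = 0.0f;                                // plain array indexing, no swizzle
//   using half_vec = decltype(acc)::subvec<16>;   // sv<float, 16>
// --------------------------------------------------------------------------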
49
extra/thunder/include/types/types.metal
Normal file
@@ -0,0 +1,49 @@
#pragma once
#include "global/global.metal"
#include "register/register.metal"
#include "shared/shared.metal"

/* ---------- WRAPPERS FOR PRETTINESS ---------- */
namespace mittens {
/**
 * @brief Row vector type alias.
 *
 * This template alias provides a convenient way to refer to the row vector type
 * associated with a given class or type `T`. It assumes that the class `T` has
 * a nested type named `row_vec`.
 *
 * @tparam T The class or type for which the row vector type is defined.
 *
 * Example usage:
 * @code
 * mittens::row_vec<decltype(some_tile)> row_vector;
 * @endcode
 */
template<typename T>
using row_vec = typename T::row_vec;

/**
 * @brief Column vector type alias.
 *
 * This template alias provides a convenient way to refer to the column vector type
 * associated with a given class or type `T`. It assumes that the class `T` has
 * a nested type named `col_vec`.
 *
 * @tparam T The class or type for which the column vector type is defined.
 *
 * Example usage:
 * @code
 * mittens::col_vec<decltype(some_tile)> col_vector;
 * @endcode
 */
template<typename T>
using col_vec = typename T::col_vec;

// register vector layouts
using align_l = ducks::rv_layout::align;
using ortho_l = ducks::rv_layout::ortho;
using naive_l = ducks::rv_layout::naive;

// ^ this code lives here because it applies to both sv and rv types
} // namespace mittens
@@ -47,7 +47,7 @@ class _Device:
       os.environ[device] = "1"  # we set this in environment for spawned children
       return device
     except StopIteration as exc: raise RuntimeError("no usable devices") from exc
-Device = _Device()
+Device: _Device = _Device()
 atexit.register(lambda: [Device[dn].finalize() for dn in Device._opened_devices])

 # **************** Profile ****************