Files
tinygrad/extra/thunder/include/common/base_ops.metal
George Hotz b4509fba31 thundermittens (#12471)
* thundermittens

* give device a type
2025-10-07 11:47:39 +08:00

393 lines
20 KiB
Metal

/**
* @file
* @brief Basic operations on generic types.
*/
#pragma once
#include "base_types.metal"
#include <metal_math>
namespace mittens {
/**
* @namespace base_ops
*
* @brief A namespace for operations on basic data types.
*/
namespace base_ops {
#define TEMPLATE_OPS_SINGLE(func_contents) \
template<typename T> static METAL_FUNC T op(device const T &x) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &x) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &x) { func_contents }
#define TEMPLATE_OPS_OVERRIDE_SINGLE(T, op_name, func_contents) \
template<> METAL_FUNC T op_name::op<T>(device const T &x) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &x) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &x) { func_contents }
#define TEMPLATE_OPS_DOUBLE(func_contents) \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b) { func_contents }
#define TEMPLATE_OPS_OVERRIDE_DOUBLE(T, op_name, func_contents) \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b) { func_contents }
#define TEMPLATE_OPS_TRIPLE(func_contents) \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, device const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(device const T &a, thread const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, device const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(threadgroup const T &a, thread const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, device const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b, device const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<typename T> static METAL_FUNC T op(thread const T &a, thread const T &b, thread const T &c) { func_contents }
#define TEMPLATE_OPS_OVERRIDE_TRIPLE(T, op_name, func_contents) \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, device const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(device const T &a, thread const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, device const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(threadgroup const T &a, thread const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, device const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, threadgroup const T &b, thread const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b, device const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b, threadgroup const T &c) { func_contents } \
template<> METAL_FUNC T op_name::op<T>(thread const T &a, thread const T &b, thread const T &c) { func_contents }
/* ---------- CONST OPS ---------- */
/**
* @brief Represents the zero constant operation.
*
* This operation returns the zero value of the specified type.
*
* @tparam T The data type for which to return the zero value.
* @return The zero value of type T.
*/
struct zero {
template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::zero(); }
};
/**
* @brief Represents the one constant operation.
*
* This operation returns the one value of the specified type.
*
* @tparam T The data type for which to return the one value.
* @return The one value of type T.
*/
struct one {
template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::one(); }
};
/**
* @brief Represents the positive infinity constant operation.
*
* This operation returns the positive infinity value of the specified type.
*
* @tparam T The data type for which to return the positive infinity value.
* @return The positive infinity value of type T.
*/
struct pos_infty {
template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::pos_infty(); }
};
/**
* @brief Represents the negative infinity constant operation.
*
* This operation returns the negative infinity value of the specified type.
*
* @tparam T The data type for which to return the negative infinity value.
* @return The negative infinity value of type T.
*/
struct neg_infty {
template<typename T, typename... args> static METAL_FUNC constexpr T op(args... _) { return base_types::constants<T>::neg_infty(); }
};
/* ---------- UNARY OPS ---------- */
/**
* @brief Exponential function operation.
*
* This operation calculates the exponential of the input value.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The exponential of the input value.
*/
struct exp {
TEMPLATE_OPS_SINGLE(return metal::exp(x);)
};
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, exp, return bf16(metal::exp((float)x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, exp, return bf16_2(metal::exp(float2(x)));)
/**
* @brief Exponential function operation, in base 2
*
* This operation calculates the exponential of the input value, in base 2.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The exponential of the input value.
*/
struct exp2 {
template<typename T> static METAL_FUNC T op(device const T &x) { return metal::exp2(x); } \
template<typename T> static METAL_FUNC T op(threadgroup const T &x) { return metal::exp2(x); } \
template<typename T> static METAL_FUNC T op(thread const T &x) { return metal::exp2(x); }
};
//template<> METAL_FUNC bf16 exp2::op<bf16>(device const bf16 &x) { return bf16(metal::exp2(x)); } \
//template<> METAL_FUNC bf16 exp2::op<bf16>(threadgroup const bf16 &x) { return bf16(metal::exp2(x)); } \
//template<> METAL_FUNC bf16 exp2::op<bf16>(thread const bf16 &x) { return bf16(metal::exp2(x)); }
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, exp2, return bf16(metal::exp2(x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, exp2, return bf16_2(metal::exp2((float2)x));)
/**
* @brief Natural log function operation.
*
* This operation calculates the natural logarithm of the input value.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The natural logarithm of the input value.
*/
struct log {
TEMPLATE_OPS_SINGLE(return metal::log(x);)
};
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16, log, return bf16(metal::log(x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, log, return bf16_2(metal::log((float2)x));)
/**
* @brief Absolute value operation.
*
* This operation calculates the absolute value of the input.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The absolute value of the input.
*/
struct abs {
TEMPLATE_OPS_SINGLE(return metal::abs(x);)
};
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16 , abs, return bf16(metal::abs((float)x));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, abs, return bf16_2(metal::abs((float2)x));)
/**
* @brief Rectified Linear Unit (ReLU) operation.
*
* This operation applies the ReLU function to the input, which is the
* maximum of zero and the input value.
*
* @tparam T The data type of the input and output values.
* @param x[in] The input value.
* @return The result of ReLU function applied to the input.
*/
struct relu {
TEMPLATE_OPS_SINGLE(return max(x, base_types::constants<T>::zero());)
};
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16 , relu, return bf16(metal::max((float)x, base_types::constants<float>::zero()));)
TEMPLATE_OPS_OVERRIDE_SINGLE(bf16_2, relu, return bf16_2(metal::max((float2)x, base_types::constants<float2>::zero()));)
/**
* @brief Copy operation.
*
* This operation returns the input value unchanged.
*
* @tparam T The data type of the input and output values.
* @param a[in] The input value.
* @return The same value as the input.
*/
struct copy { // for non-compile-time setters.
TEMPLATE_OPS_SINGLE(return x;)
};
/* ---------- BINARY OPS ---------- */
/**
* @brief Copy2 operation.
*
* This operation returns the second input value unchanged.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value (ignored).
* @param b[in] The second input value.
* @return The same value as the second input.
*/
struct copy2 { // this turns out to be a slightly hacky op that makes some code cleaner :/
TEMPLATE_OPS_DOUBLE(return b;)
};
/**
* @brief Sum operation.
*
* This operation calculates the sum of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The sum of the input values.
*/
struct sum {
TEMPLATE_OPS_DOUBLE(return a+b;)
};
/**
* @brief Subtraction operation.
*
* This operation calculates the difference between two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The difference between the input values.
*/
struct sub {
TEMPLATE_OPS_DOUBLE(return a-b;)
};
/**
* @brief Multiplication operation.
*
* This operation calculates the product of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The product of the input values.
*/
struct mul {
TEMPLATE_OPS_DOUBLE(return a*b;)
};
/**
* @brief Division operation.
*
* This operation calculates the quotient of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The quotient of the input values.
*/
struct div {
TEMPLATE_OPS_DOUBLE(return a/b;)
};
/**
* @brief Maximum operation.
*
* This operation calculates the maximum of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The maximum of the input values.
*/
struct max {
TEMPLATE_OPS_DOUBLE(return metal::max(a,b);)
};
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16 , max, return (bf16)metal::max((float)a, (float)b);)
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16_2, max, return (bf16_2)metal::max((float2)a, (float2)b);)
/**
* @brief Minimum operation.
*
* This operation calculates the minimum of two input values.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @return The minimum of the input values.
*/
struct min {
TEMPLATE_OPS_DOUBLE(return metal::min(a,b);)
};
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16 , min, return (bf16)metal::min((float)a, (float)b);)
TEMPLATE_OPS_OVERRIDE_DOUBLE(bf16_2, min, return (bf16_2)metal::min((float2)a, (float2)b);)
/* ---------- TERNARY OPS ---------- */
/**
* @brief Fused multiply-add operation A * B + C.
*
* This operation performs a fused multiply-add, computing (A * B) + C with only one rounding.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The second input value.
* @param c[in] The third input value to be added.
* @return The result of the fused multiply-add operation.
*/
struct fma_AxBtC {
TEMPLATE_OPS_TRIPLE(return sum::op<T>(mul::op<T>(a, b), c);)
};
/**
* @brief Fused multiply-add operation A * C + B.
*
* This operation performs a fused multiply-add, computing (A * C) + B with only one rounding.
* This is particularly useful for attention mechanisms in neural networks.
*
* @tparam T The data type of the input and output values.
* @param a[in] The first input value.
* @param b[in] The third input value to be added.
* @param c[in] The second input value.
* @return The result of the fused multiply-add operation.
*/
struct fma_AxCtB { // this is the one needed for attention
TEMPLATE_OPS_TRIPLE(return sum::op<T>(mul::op<T>(a, c), b);)
};
#undef TEMPLATE_OPS_SINGLE
#undef TEMPLATE_OPS_OVERRIDE_SINGLE
#undef TEMPLATE_OPS_DOUBLE
#undef TEMPLATE_OPS_OVERRIDE_DOUBLE
#undef TEMPLATE_OPS_TRIPLE
#undef TEMPLATE_OPS_OVERRIDE_TRIPLE
} // base_ops
} // mittens