/**
 * @file
 * @brief Group maps on shared vectors.
 *
 * Every function here walks a shared (threadgroup) vector cooperatively:
 * the calling thread visits indices laneid(threadIdx),
 * laneid(threadIdx) + GROUP_THREADS, ... while the index is below
 * SV::length. No threadgroup barrier is issued by these functions.
 *
 * NOTE(review): the template argument lists in this file had been stripped
 * and one `&param` had been mis-encoded, leaving it uncompilable. They are
 * restored below using the ThunderKittens naming (`ducks::is_shared_vector`,
 * `base_ops::*`). Confirm these names against the project's type-trait and
 * base_ops headers.
 */

/**
 * @brief Applies a unary operation to each element of a shared memory vector.
 *
 * Each element is read and written at the same index by the same thread, so
 * `dst` and `src` may be the same vector (the const-op wrappers below rely on
 * this by passing `dst` twice).
 *
 * @tparam op Functor type with a static member template `op<T>(x)`.
 * @tparam SV Shared memory vector type.
 * @param[out] dst Destination vector in which to store the result.
 * @param[in] src Source vector to apply the unary operation to.
 * @param[in] threadIdx Index of the calling thread; laneid(threadIdx) selects
 *            this thread's first element.
 */
template<typename op, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
unary_op(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
    #pragma clang loop unroll(full)
    for(auto cur = laneid(threadIdx); cur < SV::length; cur+=GROUP_THREADS) {
        dst[cur] = op::template op<typename SV::dtype>(src[cur]);
    }
}

/**
 * @brief Performs a binary operation element-wise on two shared vectors.
 *
 * @tparam op Functor type with a static member template `op<T>(a, b)`.
 * @tparam SV Shared memory vector type.
 * @param[out] dst The destination vector where the result is stored.
 * @param[in] lhs The left-hand side vector for the operation.
 * @param[in] rhs The right-hand side vector for the operation.
 * @param[in] threadIdx Index of the calling thread; laneid(threadIdx) selects
 *            this thread's first element.
 */
template<typename op, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
bin_op(threadgroup SV &dst, threadgroup const SV &lhs, threadgroup const SV &rhs, const int threadIdx) {
    #pragma clang loop unroll(full)
    for(auto cur = laneid(threadIdx); cur < SV::length; cur+=GROUP_THREADS) {
        dst[cur] = op::template op<typename SV::dtype>(lhs[cur], rhs[cur]);
    }
}

/**
 * @brief Performs a binary operation between each element of a shared vector
 *        and a single scalar.
 *
 * @tparam op Functor type with a static member template `op<T>(a, b)`.
 * @tparam SV Shared memory vector type.
 * @param[out] dst The destination vector where the result is stored.
 * @param[in] src The source vector for the operation (left-hand operand).
 * @param[in] param The scalar right-hand operand, of the vector's element type.
 * @param[in] threadIdx Index of the calling thread; laneid(threadIdx) selects
 *            this thread's first element.
 */
template<typename op, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
bin_op(threadgroup SV &dst, threadgroup const SV &src, thread const typename SV::dtype &param, const int threadIdx) {
    #pragma clang loop unroll(full)
    for(auto cur = laneid(threadIdx); cur < SV::length; cur+=GROUP_THREADS) {
        dst[cur] = op::template op<typename SV::dtype>(src[cur], param);
    }
}

/* ---------- WRAPPERS FOR PRETTINESS ---------- */

// ---- const ops ----
// Each of these passes `dst` as both source and destination of unary_op; the
// constant functors ignore their input.

/**
 * @brief Sets all elements of a shared memory vector to zero.
 *
 * @tparam SV Shared memory vector type.
 * @param[out] dst Destination vector to be set to zero.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
zero(threadgroup SV &dst, const int threadIdx) {
    unary_op<base_ops::zero, SV>(dst, dst, threadIdx);
}

/**
 * @brief Sets all elements of a shared memory vector to one.
 *
 * @tparam SV Shared memory vector type.
 * @param[out] dst Destination vector to be set to one.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
one(threadgroup SV &dst, const int threadIdx) {
    unary_op<base_ops::one, SV>(dst, dst, threadIdx);
}

/**
 * @brief Sets all elements of a shared memory vector to positive infinity.
 *
 * @tparam SV Shared memory vector type.
 * @param[out] dst Destination vector to be set to positive infinity.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
pos_infty(threadgroup SV &dst, const int threadIdx) {
    unary_op<base_ops::pos_infty, SV>(dst, dst, threadIdx);
}

/**
 * @brief Sets all elements of a shared memory vector to negative infinity.
 *
 * @tparam SV Shared memory vector type.
 * @param[out] dst Destination vector to be set to negative infinity.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
neg_infty(threadgroup SV &dst, const int threadIdx) {
    unary_op<base_ops::neg_infty, SV>(dst, dst, threadIdx);
}

// ---- unary ops ----

/**
 * @brief Writes a thread-space value into every element of a shared vector.
 *
 * `src` is taken by `thread` reference, so it cannot bind the threadgroup
 * vector-vector bin_op overload in this file. It is forwarded to the
 * vector-scalar overload, which takes a `SV::dtype`.
 *
 * @tparam SV Shared vector type.
 * @tparam U Type of the thread-space source value.
 * @param[out] dst Destination vector whose elements are overwritten.
 * @param[in] src Value to write into each element.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
copy(threadgroup SV &dst, thread const U &src, const int threadIdx) {
    bin_op<base_ops::copy2, SV>(dst, dst, src, threadIdx); // the second arg (lhs = dst) is ignored by copy2.
}

/**
 * @brief Applies the exponential function element-wise to a shared vector.
 *
 * @tparam SV Shared vector type.
 * @param[out] dst Destination vector where the exponential values will be stored.
 * @param[in] src Source vector to apply the exponential function to.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
exp(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
    unary_op<base_ops::exp, SV>(dst, src, threadIdx);
}

/**
 * @brief Applies the base-2 exponential function element-wise to a shared vector.
 *
 * @tparam SV Shared vector type.
 * @param[out] dst Destination vector where the base-2 exponential values will be stored.
 * @param[in] src Source vector to apply the base-2 exponential function to.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
exp2(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
    unary_op<base_ops::exp2, SV>(dst, src, threadIdx);
}

/**
 * @brief Applies the natural logarithm function element-wise to a shared vector.
 *
 * @tparam SV Shared vector type.
 * @param[out] dst Destination vector where the logarithm values will be stored.
 * @param[in] src Source vector to apply the logarithm function to.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
log(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
    unary_op<base_ops::log, SV>(dst, src, threadIdx);
}

/**
 * @brief Applies the absolute value function element-wise to a shared vector.
 *
 * @tparam SV Shared vector type.
 * @param[out] dst Destination vector where the absolute values will be stored.
 * @param[in] src Source vector to apply the absolute value function to.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
abs(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
    unary_op<base_ops::abs, SV>(dst, src, threadIdx);
}

/**
 * @brief Applies the rectified linear unit (ReLU) function element-wise to a shared vector.
 *
 * @tparam SV Shared vector type.
 * @param[out] dst Destination vector where the ReLU values will be stored.
 * @param[in] src Source vector to apply the ReLU function to.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
relu(threadgroup SV &dst, threadgroup const SV &src, const int threadIdx) {
    unary_op<base_ops::relu, SV>(dst, src, threadIdx);
}

// ---- binary ops ----
// In every wrapper below, `rhs` is taken by `thread` reference. It therefore
// cannot bind the threadgroup vector-vector bin_op overload in this file and
// is forwarded to the vector-scalar overload, which takes a `SV::dtype`. For
// vector-vector operations, call bin_op directly with two threadgroup vectors.

/**
 * @brief Computes the element-wise maximum of a shared vector and a thread-space operand.
 *
 * @tparam SV Shared vector type.
 * @tparam U Type of the thread-space right-hand operand.
 * @param[out] dst Destination vector where the maximum values will be stored.
 * @param[in] lhs Shared vector operand.
 * @param[in] rhs Thread-space right-hand operand.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
max(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
    bin_op<base_ops::max, SV>(dst, lhs, rhs, threadIdx);
}

/**
 * @brief Computes the element-wise minimum of a shared vector and a thread-space operand.
 *
 * @tparam SV Shared vector type.
 * @tparam U Type of the thread-space right-hand operand.
 * @param[out] dst Destination vector where the minimum values will be stored.
 * @param[in] lhs Shared vector operand.
 * @param[in] rhs Thread-space right-hand operand.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
min(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
    bin_op<base_ops::min, SV>(dst, lhs, rhs, threadIdx);
}

/**
 * @brief Computes the element-wise sum of a shared vector and a thread-space operand.
 *
 * @tparam SV Shared vector type.
 * @tparam U Type of the thread-space right-hand operand.
 * @param[out] dst Destination vector where the sums will be stored.
 * @param[in] lhs Shared vector operand.
 * @param[in] rhs Thread-space right-hand operand.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
add(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
    bin_op<base_ops::sum, SV>(dst, lhs, rhs, threadIdx);
}

/**
 * @brief Computes the element-wise difference `lhs - rhs`.
 *
 * @tparam SV Shared vector type.
 * @tparam U Type of the thread-space right-hand operand.
 * @param[out] dst Destination vector where the differences will be stored.
 * @param[in] lhs Shared vector operand (minuend).
 * @param[in] rhs Thread-space right-hand operand (subtrahend).
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
sub(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
    bin_op<base_ops::sub, SV>(dst, lhs, rhs, threadIdx);
}

/**
 * @brief Computes the element-wise product of a shared vector and a thread-space operand.
 *
 * @tparam SV Shared vector type.
 * @tparam U Type of the thread-space right-hand operand.
 * @param[out] dst Destination vector where the products will be stored.
 * @param[in] lhs Shared vector operand.
 * @param[in] rhs Thread-space right-hand operand.
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
mul(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
    bin_op<base_ops::mul, SV>(dst, lhs, rhs, threadIdx);
}

/**
 * @brief Computes the element-wise quotient `lhs / rhs`.
 *
 * @tparam SV Shared vector type.
 * @tparam U Type of the thread-space right-hand operand.
 * @param[out] dst Destination vector where the quotients will be stored.
 * @param[in] lhs Shared vector operand (dividend).
 * @param[in] rhs Thread-space right-hand operand (divisor).
 * @param[in] threadIdx Index of the calling thread.
 */
template<typename SV, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_vector<SV>(), void>::type
div(threadgroup SV &dst, threadgroup const SV &lhs, thread const U &rhs, const int threadIdx) {
    bin_op<base_ops::div, SV>(dst, lhs, rhs, threadIdx);
}