mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
476 lines
23 KiB
Metal
476 lines
23 KiB
Metal
/**
|
|
* @file
|
|
* @brief Group maps on shared tiles.
|
|
*/
|
|
|
|
/**
 * @brief Applies a unary operation to every element of a shared tile.
 *
 * The work is divided across the threadgroup: each thread starts at index
 * laneid(threadIdx) and then advances by GROUP_THREADS. Every index is read from
 * src and written to dst at the same position by the same thread, so dst and src
 * may refer to the same tile (the constant-fill wrappers rely on this).
 * NOTE(review): no barrier is issued here; callers presumably synchronize the
 * threadgroup before other threads read dst — confirm.
 *
 * @tparam op  Operation type providing `op::template op<dtype>(x)`.
 * @tparam ST  Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  src       Tile supplying the inputs.
 * @param[in]  threadIdx Thread index passed to laneid() to select this thread's first element.
 */
template<typename op, typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
unary_map(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
    using T = typename ST::dtype;
    #pragma clang loop unroll(full)
    for (int idx = laneid(threadIdx); idx < dst.num_elements; idx += GROUP_THREADS) {
        dst.data[idx] = op::template op<T>(src.data[idx]);
    }
}
|
|
|
|
/**
 * @brief Applies a binary operation between every element of a shared tile and one scalar.
 *
 * Each thread starts at index laneid(threadIdx) and advances by GROUP_THREADS.
 * For each index i it computes op(src[i], param) and stores the result in dst[i].
 * Reads and writes use the same index from the same thread, so dst may alias src.
 * NOTE(review): no barrier is issued here; callers presumably synchronize the
 * threadgroup before other threads read dst — confirm.
 *
 * @tparam op  Binary operation type providing `op::template op<dtype>(a, b)`.
 * @tparam ST  Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  src       Tile supplying the left-hand operands.
 * @param[in]  param     Scalar used as the right-hand operand for every element.
 * @param[in]  threadIdx Thread index passed to laneid() to select this thread's first element.
 */
template<typename op, typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
bin_map(threadgroup ST &dst, threadgroup const ST &src, thread const typename ST::dtype &param, const int threadIdx) {
    #pragma clang loop unroll(full)
    for(int i = laneid(threadIdx); i < dst.num_elements; i += GROUP_THREADS) {
        dst.data[i] = op::template op<typename ST::dtype>(src.data[i], param);
    }
}
|
|
|
|
/**
 * @brief Applies a binary operation element-wise to two shared tiles.
 *
 * Each thread starts at index laneid(threadIdx) and advances by GROUP_THREADS.
 * For each index i it computes op(lhs[i], rhs[i]) and stores the result in dst[i].
 * Because each index is read and written by the same thread, dst may alias lhs or rhs.
 * NOTE(review): no barrier is issued here — confirm callers synchronize.
 *
 * @tparam op  Binary operation type providing `op::template op<dtype>(a, b)`.
 * @tparam ST  Shared tile type shared by all three tiles; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  lhs       Tile supplying the left-hand operands.
 * @param[in]  rhs       Tile supplying the right-hand operands.
 * @param[in]  threadIdx Thread index passed to laneid() to select this thread's first element.
 */
template<typename op, typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
bin_map(threadgroup ST &dst, threadgroup const ST &lhs, threadgroup const ST &rhs, const int threadIdx) {
    using T = typename ST::dtype;
    #pragma clang loop unroll(full)
    for (int idx = laneid(threadIdx); idx < dst.num_elements; idx += GROUP_THREADS) {
        dst.data[idx] = op::template op<T>(lhs.data[idx], rhs.data[idx]);
    }
}
|
|
|
|
/**
 * @brief Combines each row of a shared tile with one entry of a shared vector.
 *
 * The flat index idx is split as row = idx / cols and col = idx % cols.
 * dst(row, col) receives op(src(row, col), vec[row]), so every element of a row
 * shares the same vector entry. Each element is read and written by the same
 * thread, so dst may alias src (broadcast_row relies on this).
 *
 * @tparam op  Binary operation type providing `op::template op<dtype>(a, b)`.
 * @tparam ST  Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV  Shared vector type; its dtype must equal ST::dtype and its length must equal ST::rows (both static_asserted).
 * @param[out] dst       Tile receiving the results.
 * @param[in]  src       Tile supplying the left-hand operands.
 * @param[in]  vec       Vector with one right-hand operand per row.
 * @param[in]  threadIdx Thread index passed to laneid() to select this thread's first element.
 */
template<typename op, typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
row_map(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &vec, const int threadIdx) {
    using T = typename ST::dtype;
    static_assert(metal::is_same<T, typename SV::dtype>::value, "Tile and vector must have the same data type");
    static_assert(SV::length == ST::rows, "Vector length must match the number of rows in the tile");
    #pragma clang loop unroll(full)
    for (int idx = laneid(threadIdx); idx < dst.num_elements; idx += GROUP_THREADS) {
        const int r = idx / dst.cols;
        const int c = idx % dst.cols;
        dst[{r, c}] = op::template op<T>(src[{r, c}], vec[r]);
    }
}
|
|
|
|
/**
 * @brief Combines each column of a shared tile with one entry of a shared vector.
 *
 * The flat index idx is split as row = idx / cols and col = idx % cols.
 * dst(row, col) receives op(src(row, col), vec[col]), so every element of a column
 * shares the same vector entry. Each element is read and written by the same
 * thread, so dst may alias src (broadcast_col relies on this).
 *
 * @tparam op  Binary operation type providing `op::template op<dtype>(a, b)`.
 * @tparam ST  Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV  Shared vector type; its dtype must equal ST::dtype and its length must equal ST::cols (both static_asserted).
 * @param[out] dst       Tile receiving the results.
 * @param[in]  src       Tile supplying the left-hand operands.
 * @param[in]  vec       Vector with one right-hand operand per column.
 * @param[in]  threadIdx Thread index passed to laneid() to select this thread's first element.
 */
template<typename op, typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
col_map(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &vec, const int threadIdx) {
    using T = typename ST::dtype;
    static_assert(metal::is_same<T, typename SV::dtype>::value, "Tile and vector must have the same data type");
    static_assert(SV::length == ST::cols, "Vector length must match the number of columns in the tile");
    #pragma clang loop unroll(full)
    for (int idx = laneid(threadIdx); idx < dst.num_elements; idx += GROUP_THREADS) {
        const int r = idx / dst.cols;
        const int c = idx % dst.cols;
        dst[{r, c}] = op::template op<T>(src[{r, c}], vec[c]);
    }
}
|
|
|
|
|
|
/* ---------- WRAPPERS FOR PRETTINESS ---------- */
|
|
|
|
// All of the annoying qualifiers *should* be automatically inferred during compile-time.
// So, syntax should just be mittens::add_row(dst, src, colvec, threadIdx);
|
|
|
|
// const maps
|
|
/**
 * @brief Fills every element of a shared tile with zero.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile to fill.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
zero(threadgroup ST &dst, const int threadIdx) {
    // dst is passed as both output and input of unary_map.
    using Op = base_ops::zero;
    unary_map<Op, ST>(dst, dst, threadIdx);
}
|
|
/**
 * @brief Fills every element of a shared tile with one.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile to fill.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
one(threadgroup ST &dst, const int threadIdx) {
    // dst is passed as both output and input of unary_map.
    using Op = base_ops::one;
    unary_map<Op, ST>(dst, dst, threadIdx);
}
|
|
/**
 * @brief Fills every element of a shared tile with positive infinity.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile to fill.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
pos_infty(threadgroup ST &dst, const int threadIdx) {
    // dst is passed as both output and input of unary_map.
    using Op = base_ops::pos_infty;
    unary_map<Op, ST>(dst, dst, threadIdx);
}
|
|
/**
 * @brief Fills every element of a shared tile with negative infinity.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile to fill.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
neg_infty(threadgroup ST &dst, const int threadIdx) {
    // dst is passed as both output and input of unary_map.
    using Op = base_ops::neg_infty;
    unary_map<Op, ST>(dst, dst, threadIdx);
}
|
|
|
|
// unary maps
|
|
/**
 * @brief Element-wise natural exponential: dst[i] = exp(src[i]).
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results (may be the same tile as src).
 * @param[in]  src       Input tile.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
exp(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
    using Op = base_ops::exp;
    unary_map<Op, ST>(dst, src, threadIdx);
}
|
|
/**
 * @brief Element-wise base-2 exponential: dst[i] = 2^src[i].
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results (may be the same tile as src).
 * @param[in]  src       Input tile.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
exp2(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
    using Op = base_ops::exp2;
    unary_map<Op, ST>(dst, src, threadIdx);
}
|
|
/**
 * @brief Element-wise natural logarithm: dst[i] = log(src[i]).
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results (may be the same tile as src).
 * @param[in]  src       Input tile.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
log(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
    using Op = base_ops::log;
    unary_map<Op, ST>(dst, src, threadIdx);
}
|
|
/**
 * @brief Element-wise absolute value: dst[i] = |src[i]|.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results (may be the same tile as src).
 * @param[in]  src       Input tile.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
abs(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
    using Op = base_ops::abs;
    unary_map<Op, ST>(dst, src, threadIdx);
}
|
|
/**
 * @brief Element-wise rectified linear unit: dst[i] = relu(src[i]).
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @param[out] dst       Tile receiving the results (may be the same tile as src).
 * @param[in]  src       Input tile.
 * @param[in]  threadIdx Thread index forwarded to unary_map.
 */
template<typename ST>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
relu(threadgroup ST &dst, threadgroup const ST &src, const int threadIdx) {
    using Op = base_ops::relu;
    unary_map<Op, ST>(dst, src, threadIdx);
}
|
|
/**
 * @brief Writes a single value into every element of a shared tile.
 *
 * The call is routed through the scalar overload of bin_map with
 * base_ops::copy2, the same op broadcast_row/broadcast_col use to overwrite
 * dst with their second operand. dst itself is passed as the first operand.
 * NOTE(review): copy2 is assumed to return its second operand, as the broadcast
 * wrappers rely on.
 *
 * Because src is a `thread` reference, it presumably cannot bind a `threadgroup`
 * tile, so this is a fill-with-value rather than a tile-to-tile copy — confirm.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam U  Type of src; must be convertible to typename ST::dtype.
 * @param[out] dst       Tile to overwrite.
 * @param[in]  src       Value written to every element of dst.
 * @param[in]  threadIdx Thread index forwarded to bin_map.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
copy(threadgroup ST &dst, thread const U &src, const int threadIdx) {
    // bin_map has no 3-argument overload; the previous call bin_map<copy>(dst, src, threadIdx)
    // could not be instantiated. Use the (dst, src-tile, scalar, threadIdx) form instead.
    bin_map<base_ops::copy2, ST>(dst, dst, src, threadIdx);
}
|
|
|
|
// uniform binary maps
|
|
/**
 * @brief Element-wise maximum of a shared tile and a right-hand operand.
 *
 * NOTE(review): rhs is a `thread` reference, so it presumably cannot bind a
 * `threadgroup` tile. In practice this resolves to the scalar overload of bin_map
 * (each element compared with one value). Confirm before relying on tile–tile use.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam U  Type of rhs; expected to be convertible to typename ST::dtype.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  lhs       Left-hand operand tile.
 * @param[in]  rhs       Right-hand operand.
 * @param[in]  threadIdx Thread index forwarded to bin_map.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
max(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
    using Op = base_ops::max;
    bin_map<Op, ST>(dst, lhs, rhs, threadIdx);
}
|
|
/**
 * @brief Element-wise minimum of a shared tile and a right-hand operand.
 *
 * NOTE(review): rhs is a `thread` reference, so it presumably cannot bind a
 * `threadgroup` tile. In practice this resolves to the scalar overload of bin_map.
 * Confirm before relying on tile–tile use.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam U  Type of rhs; expected to be convertible to typename ST::dtype.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  lhs       Left-hand operand tile.
 * @param[in]  rhs       Right-hand operand.
 * @param[in]  threadIdx Thread index forwarded to bin_map.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
min(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
    using Op = base_ops::min;
    bin_map<Op, ST>(dst, lhs, rhs, threadIdx);
}
|
|
/**
 * @brief Element-wise sum of a shared tile and a right-hand operand (via base_ops::sum).
 *
 * NOTE(review): rhs is a `thread` reference, so it presumably cannot bind a
 * `threadgroup` tile. In practice this resolves to the scalar overload of bin_map.
 * Confirm before relying on tile–tile use.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam U  Type of rhs; expected to be convertible to typename ST::dtype.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  lhs       Left-hand operand tile.
 * @param[in]  rhs       Right-hand operand.
 * @param[in]  threadIdx Thread index forwarded to bin_map.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
add(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
    using Op = base_ops::sum;
    bin_map<Op, ST>(dst, lhs, rhs, threadIdx);
}
|
|
/**
 * @brief Element-wise difference lhs - rhs of a shared tile and a right-hand operand.
 *
 * NOTE(review): rhs is a `thread` reference, so it presumably cannot bind a
 * `threadgroup` tile. In practice this resolves to the scalar overload of bin_map.
 * Confirm before relying on tile–tile use.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam U  Type of rhs; expected to be convertible to typename ST::dtype.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  lhs       Left-hand operand tile (minuend).
 * @param[in]  rhs       Right-hand operand (subtrahend).
 * @param[in]  threadIdx Thread index forwarded to bin_map.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
sub(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
    using Op = base_ops::sub;
    bin_map<Op, ST>(dst, lhs, rhs, threadIdx);
}
|
|
/**
 * @brief Element-wise product of a shared tile and a right-hand operand.
 *
 * NOTE(review): rhs is a `thread` reference, so it presumably cannot bind a
 * `threadgroup` tile. In practice this resolves to the scalar overload of bin_map.
 * Confirm before relying on tile–tile use.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam U  Type of rhs; expected to be convertible to typename ST::dtype.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  lhs       Left-hand operand tile.
 * @param[in]  rhs       Right-hand operand.
 * @param[in]  threadIdx Thread index forwarded to bin_map.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
mul(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
    using Op = base_ops::mul;
    bin_map<Op, ST>(dst, lhs, rhs, threadIdx);
}
|
|
/**
 * @brief Element-wise quotient lhs / rhs of a shared tile and a right-hand operand.
 *
 * NOTE(review): rhs is a `thread` reference, so it presumably cannot bind a
 * `threadgroup` tile. In practice this resolves to the scalar overload of bin_map.
 * Confirm before relying on tile–tile use.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam U  Type of rhs; expected to be convertible to typename ST::dtype.
 * @param[out] dst       Tile receiving the results.
 * @param[in]  lhs       Left-hand operand tile (dividend).
 * @param[in]  rhs       Right-hand operand (divisor).
 * @param[in]  threadIdx Thread index forwarded to bin_map.
 */
template<typename ST, typename U>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>(), void>::type
div(threadgroup ST &dst, threadgroup const ST &lhs, thread const U &rhs, const int threadIdx) {
    using Op = base_ops::div;
    bin_map<Op, ST>(dst, lhs, rhs, threadIdx);
}
|
|
|
|
// Row and col maps
|
|
|
|
/**
 * @brief Adds row_values[r] to every element of row r of a shared tile.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (column) vector type with one entry per tile row; checked by row_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  row_values Per-row addends.
 * @param[in]  threadIdx  Thread index forwarded to row_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
add_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
    using Op = base_ops::sum;
    row_map<Op, ST, SV>(dst, src, row_values, threadIdx);
}
|
|
/**
 * @brief Subtracts row_values[r] from every element of row r of a shared tile.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (column) vector type with one entry per tile row; checked by row_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  row_values Per-row subtrahends.
 * @param[in]  threadIdx  Thread index forwarded to row_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
sub_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
    using Op = base_ops::sub;
    row_map<Op, ST, SV>(dst, src, row_values, threadIdx);
}
|
|
/**
 * @brief Multiplies every element of row r of a shared tile by row_values[r].
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (column) vector type with one entry per tile row; checked by row_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  row_values Per-row multipliers.
 * @param[in]  threadIdx  Thread index forwarded to row_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
mul_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
    using Op = base_ops::mul;
    row_map<Op, ST, SV>(dst, src, row_values, threadIdx);
}
|
|
/**
 * @brief Divides every element of row r of a shared tile by row_values[r].
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (column) vector type with one entry per tile row; checked by row_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  row_values Per-row divisors.
 * @param[in]  threadIdx  Thread index forwarded to row_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
div_row(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &row_values, const int threadIdx) {
    using Op = base_ops::div;
    row_map<Op, ST, SV>(dst, src, row_values, threadIdx);
}
|
|
/**
 * @brief Broadcasts a vector into a tile's rows: every element of row r becomes row_values[r].
 *
 * Implemented as row_map with base_ops::copy2, with dst passed as both output and
 * input. NOTE(review): copy2 is assumed to return its second operand — confirm.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (column) vector type with one entry per tile row; checked by row_map.
 * @param[out] dst        Tile to overwrite.
 * @param[in]  row_values Values to broadcast, one per row.
 * @param[in]  threadIdx  Thread index forwarded to row_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
broadcast_row(threadgroup ST &dst, threadgroup const SV &row_values, const int threadIdx) {
    using Op = base_ops::copy2;
    row_map<Op, ST, SV>(dst, dst, row_values, threadIdx);
}
|
|
|
|
|
|
// col maps
|
|
/**
 * @brief Adds col_values[c] to every element of column c of a shared tile.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (row) vector type with one entry per tile column; checked by col_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  col_values Per-column addends.
 * @param[in]  threadIdx  Thread index forwarded to col_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
add_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
    using Op = base_ops::sum;
    col_map<Op, ST, SV>(dst, src, col_values, threadIdx);
}
|
|
/**
 * @brief Subtracts col_values[c] from every element of column c of a shared tile.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (row) vector type with one entry per tile column; checked by col_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  col_values Per-column subtrahends.
 * @param[in]  threadIdx  Thread index forwarded to col_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
sub_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
    using Op = base_ops::sub;
    col_map<Op, ST, SV>(dst, src, col_values, threadIdx);
}
|
|
/**
 * @brief Multiplies every element of column c of a shared tile by col_values[c].
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (row) vector type with one entry per tile column; checked by col_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  col_values Per-column multipliers.
 * @param[in]  threadIdx  Thread index forwarded to col_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
mul_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
    using Op = base_ops::mul;
    col_map<Op, ST, SV>(dst, src, col_values, threadIdx);
}
|
|
/**
 * @brief Divides every element of column c of a shared tile by col_values[c].
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (row) vector type with one entry per tile column; checked by col_map.
 * @param[out] dst        Tile receiving the results.
 * @param[in]  src        Input tile.
 * @param[in]  col_values Per-column divisors.
 * @param[in]  threadIdx  Thread index forwarded to col_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
div_col(threadgroup ST &dst, threadgroup const ST &src, threadgroup const SV &col_values, const int threadIdx) {
    using Op = base_ops::div;
    col_map<Op, ST, SV>(dst, src, col_values, threadIdx);
}
|
|
/**
 * @brief Broadcasts a vector into a tile's columns: every element of column c becomes col_values[c].
 *
 * Implemented as col_map with base_ops::copy2, with dst passed as both output and
 * input. NOTE(review): copy2 is assumed to return its second operand — confirm.
 *
 * @tparam ST Shared tile type; must satisfy ducks::is_shared_tile.
 * @tparam SV Shared (row) vector type with one entry per tile column; checked by col_map.
 * @param[out] dst        Tile to overwrite.
 * @param[in]  col_values Values to broadcast, one per column.
 * @param[in]  threadIdx  Thread index forwarded to col_map.
 */
template<typename ST, typename SV>
static METAL_FUNC typename metal::enable_if<ducks::is_shared_tile<ST>() && ducks::is_shared_vector<SV>(), void>::type
broadcast_col(threadgroup ST &dst, threadgroup const SV &col_values, const int threadIdx) {
    using Op = base_ops::copy2;
    col_map<Op, ST, SV>(dst, dst, col_values, threadIdx);
}
|