#pragma once

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include "kittens.cuh"
#include "parallel_tensor.cuh"

namespace kittens {
namespace py {

// Detects whether a kernel config declares MIN_BLOCKS_PER_SM.
template <typename Config>
concept has_min_blocks_per_sm = requires { std::integral_constant<int, Config::MIN_BLOCKS_PER_SM>{}; };

template <typename Config>
consteval int min_blocks_per_sm() {
    if constexpr (has_min_blocks_per_sm<Config>) return Config::MIN_BLOCKS_PER_SM;
    else return 1;
}

// Global entry points that forward a __grid_constant__ Globals struct to the user kernel.
template <typename Config, typename Globals, auto Kernel>
__global__ __launch_bounds__(Config::NUM_THREADS, min_blocks_per_sm<Config>())
void global_kernel_unclustered(const __grid_constant__ Globals G) {
    Kernel(G);
}

template <typename Config, typename Globals, auto Kernel>
__global__ __launch_bounds__(Config::NUM_THREADS, min_blocks_per_sm<Config>())
__cluster_dims__(Config::CLUSTER_SIZE)
void global_kernel_clustered(const __grid_constant__ Globals G) {
    Kernel(G);
}

// Validates that a torch tensor is a contiguous CUDA tensor of rank <= 4 whose
// dtype matches the element type T.
template <typename T>
static inline void tensor_check(const at::Tensor &t) {
    TORCH_CHECK(t.is_cuda(), "Tensor must be on CUDA device");
    TORCH_CHECK(t.is_contiguous(), "Tensor must be contiguous");
    TORCH_CHECK(t.dim() <= 4, "Expected Tensor.dim() <= 4");
    if constexpr (std::is_same_v<T, char>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Char, "Tensor has invalid dtype (expected int8)");
    } else if constexpr (std::is_same_v<T, short>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Short, "Tensor has invalid dtype (expected int16)");
    } else if constexpr (std::is_same_v<T, int>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Int, "Tensor has invalid dtype (expected int32)");
    } else if constexpr (std::is_same_v<T, long>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Long, "Tensor has invalid dtype (expected int64)");
    } else if constexpr (std::is_same_v<T, fp8e4m3>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Float8_e4m3fn, "Tensor has invalid dtype (expected fp8e4m3)");
    } else if constexpr (std::is_same_v<T, fp8e5m2>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Float8_e5m2, "Tensor has invalid dtype (expected fp8e5m2)");
#ifdef KITTENS_BLACKWELL
    } else if constexpr (std::is_same_v<T, fp8e8m0>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Byte, "Tensor has invalid dtype (expected fp8e8m0 represented as uint8)");
#endif
    } else if constexpr (std::is_same_v<T, bf16>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::BFloat16, "Tensor has invalid dtype (expected bfloat16)");
    } else if constexpr (std::is_same_v<T, half>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Half, "Tensor has invalid dtype (expected float16)");
    } else if constexpr (std::is_same_v<T, float>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Float, "Tensor has invalid dtype (expected float32)");
    } else if constexpr (std::is_same_v<T, double>) {
        TORCH_CHECK(t.dtype() == at::ScalarType::Double, "Tensor has invalid dtype (expected float64)");
    } else {
        TORCH_CHECK(false, "Unsupported dtype");
    }
}

// Validates that a TKParallelTensor is internally consistent and compatible with the PGL type.
template <typename PGL>
static inline void parallel_tensor_check(const TKParallelTensor &t) {
    tensor_check<typename PGL::dtype>(t.data_);
    TORCH_CHECK(t.data_.sizes().vec() == t.shape_, "Shape mismatch between TKParallelTensor and the underlying tensor");
    TORCH_CHECK(t.data_.dtype() == t.dtype_, "Dtype mismatch between TKParallelTensor and the underlying tensor");
    TORCH_CHECK(t.raw_ptrs_.size() == PGL::num_devices, "Number of devices mismatch between PGL and TKParallelTensor");
    TORCH_CHECK(t.local_rank_ == t.data_.device().index(), "Current tensor device index mismatch within TKParallelTensor");
    TORCH_CHECK(t.local_world_size_ == PGL::num_devices, "Number of devices mismatch between PGL and TKParallelTensor");
    TORCH_CHECK(t.multicast_ == PGL::multicast, "Multicast mismatch between PGL and TKParallelTensor");
    TORCH_CHECK(t.raw_ptrs_[t.local_rank_] == reinterpret_cast<void *>(t.data_.data_ptr()), "Current tensor data pointer not found in TKParallelTensor's raw_ptrs_");
}

// Converts a (rank <= 4) torch tensor into a kittens global layout, left-padding
// missing dimensions with 1.
template <typename GL>
static inline GL tensor_to_gl(const at::Tensor &t) {
    tensor_check<typename GL::dtype>(t);
    std::array<int, 4> shape = {1, 1, 1, 1};
    for (int i = 0; i < static_cast<int>(t.dim()); ++i)
        shape[4 - t.dim() + i] = static_cast<int>(t.size(i));
    uint64_t data_ptr = reinterpret_cast<uint64_t>(t.data_ptr());
    return ::kittens::make_gl<GL>(data_ptr, shape[0], shape[1], shape[2], shape[3]);
}

// Converts a TKParallelTensor into a kittens parallel global layout.
template <typename PGL>
static inline PGL parallel_tensor_to_pgl(TKParallelTensor &t) {
    parallel_tensor_check<PGL>(t);
    std::array<int, 4> shape = {1, 1, 1, 1};
    for (int i = 0; i < static_cast<int>(t.data_.dim()); ++i) {
        shape[4 - t.data_.dim() + i] = static_cast<int>(t.data_.size(i));
    }
    if constexpr (PGL::multicast)
        return ::kittens::make_pgl<PGL>(
            reinterpret_cast<uint64_t>(t.multicast_ptr_),
            reinterpret_cast<uint64_t *>(t.raw_ptrs_.data()),
            shape[0], shape[1], shape[2], shape[3]);
    else
        return ::kittens::make_pgl<PGL>(
            reinterpret_cast<uint64_t *>(t.raw_ptrs_.data()),
            shape[0], shape[1], shape[2], shape[3]);
}

// Builds a global layout with a null data pointer, useful as a placeholder input.
template <typename GL>
static inline GL make_fake_gl(const int batch, const int depth, const int rows, const int cols) {
    return ::kittens::make_gl<GL>(reinterpret_cast<uint64_t>(nullptr), batch, depth, rows, cols);
}

static inline void _device_check(const at::Tensor &first, const at::Tensor &second) {
    TORCH_CHECK(first.device() == second.device(), "All tensors must be on the same device");
}

template <typename T1, typename... Ts>
static inline void device_check(const T1 &first, const Ts &...rest) {
    (_device_check(first, rest), ...);
}

static inline void _parallel_tensor_check(const TKParallelTensor &first, const TKParallelTensor &second) {
    TORCH_CHECK(first.local_rank_ == second.local_rank_, "All parallel tensors must have the same local_rank");
    TORCH_CHECK(first.local_world_size_ == second.local_world_size_, "All parallel tensors must have the same local_world_size");
}

template <typename T1, typename... Ts>
static inline void parallel_tensor_check(const T1 &first, const Ts &...rest) {
    (_parallel_tensor_check(first, rest), ...);
}

// A config may pin the launch geometry statically; anything it omits is queried from Globals.
template <typename Config> concept static_grid = requires { Config::NUM_BLOCKS; };
template <typename Config> concept static_block = requires { Config::NUM_THREADS; };
template <typename Config> concept static_dynamic_shared_memory = requires { Config::DYNAMIC_SHARED_MEMORY; };

template <typename Config, typename Globals, auto Kernel>
static inline void launch_kernel(const Globals &G) {
    dim3 grid;
    if constexpr (static_grid<Config>) grid = dim3{Config::NUM_BLOCKS, 1, 1};
    else grid = G.grid();
    dim3 block;
    if constexpr (static_block<Config>) block = dim3{Config::NUM_THREADS, 1, 1};
    else block = G.block();
    int dynamic_shared_memory;
    if constexpr (static_dynamic_shared_memory<Config>) dynamic_shared_memory = static_cast<int>(Config::DYNAMIC_SHARED_MEMORY);
    else dynamic_shared_memory = G.dynamic_shared_memory();
    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if constexpr (Config::CLUSTER_SIZE <= 1) {
        CUDACHECK(cudaFuncSetAttribute(global_kernel_unclustered<Config, Globals, Kernel>,
                                       cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_shared_memory));
        global_kernel_unclustered<Config, Globals, Kernel><<<grid, block, dynamic_shared_memory, stream>>>(G);
    } else {
        CUDACHECK(cudaFuncSetAttribute(global_kernel_clustered<Config, Globals, Kernel>,
                                       cudaFuncAttributeMaxDynamicSharedMemorySize, dynamic_shared_memory));
        global_kernel_clustered<Config, Globals, Kernel><<<grid, block, dynamic_shared_memory, stream>>>(G);
    }
}

} // namespace py
} // namespace kittens
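
// -----------------------------------------------------------------------------
// Usage sketch. Everything below is a hypothetical example of how the helpers
// above might be wired into a torch extension; `MyConfig`, `MyGlobals`,
// `my_kernel`, `my_op`, the particular `gl` specialization, and the constant
// values are illustrative assumptions, not part of this header.
//
//   struct MyConfig {
//       static constexpr int NUM_THREADS = 256;              // consumed by __launch_bounds__ and static_block
//       static constexpr int NUM_BLOCKS = 148;               // satisfies static_grid
//       static constexpr int CLUSTER_SIZE = 1;               // <= 1 selects the unclustered wrapper
//       static constexpr int DYNAMIC_SHARED_MEMORY = 100000; // satisfies static_dynamic_shared_memory
//   };
//
//   using my_gl = ::kittens::gl<::kittens::bf16, -1, -1, -1, -1>; // runtime-sized dims
//   struct MyGlobals { my_gl a, b; };
//
//   __device__ inline void my_kernel(const MyGlobals &G) { /* ... device code ... */ }
//
//   static void my_op(const at::Tensor &a, const at::Tensor &b) {
//       ::kittens::py::device_check(a, b);
//       MyGlobals G{::kittens::py::tensor_to_gl<my_gl>(a),
//                   ::kittens::py::tensor_to_gl<my_gl>(b)};
//       ::kittens::py::launch_kernel<MyConfig, MyGlobals, my_kernel>(G);
//   }
//
// Because MyConfig defines NUM_BLOCKS, NUM_THREADS, and DYNAMIC_SHARED_MEMORY,
// launch_kernel never calls G.grid(), G.block(), or G.dynamic_shared_memory(),
// so MyGlobals does not need to provide them.
// -----------------------------------------------------------------------------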