Files
tinygrad/extra/thunder/amd/include/pyutils/pyutils.cuh
2026-02-12 20:16:43 -08:00

75 lines
3.0 KiB
Plaintext

#pragma once
#include "util.cuh"
#include <pybind11/pybind11.h>
#include <array>
#include <concepts>
#include <cstdint>
#include <stdexcept>
#include <string>
namespace kittens {
namespace py {
// Generic conversion of a Python argument into a C++ value of type T.
// This primary template simply defers to pybind11's own type casters;
// specializations may override make() for types that need custom handling.
template<typename T>
struct from_object {
    // Returns `obj` converted to T via pybind11::cast. Any exception raised
    // by pybind11 on a failed conversion propagates to the caller.
    static T make(pybind11::object obj) {
        return pybind11::cast<T>(obj);
    }
};
// Conversion for global-layout types (anything satisfying ducks::gl::all).
// The Python argument must be an object whose class is named "Tensor"
// (a torch.Tensor, per the error messages below). The tensor's shape and
// data pointer are read through its Python attributes and handed to make_gl.
template<ducks::gl::all GL>
struct from_object<GL> {
    // Builds a GL from `obj`.
    //
    // Throws std::runtime_error when:
    //   - the object's class name is not "Tensor",
    //   - is_contiguous() returns false,
    //   - device.type is "cpu",
    //   - the tensor has more than four dimensions.
    static GL make(pybind11::object obj) {
        // Identify the argument by class name only; no torch headers are needed.
        const bool looks_like_tensor =
            pybind11::hasattr(obj, "__class__") &&
            obj.attr("__class__").attr("__name__").cast<std::string>() == "Tensor";
        if (!looks_like_tensor) {
            throw std::runtime_error("Expected a torch.Tensor");
        }

        // Reject non-contiguous tensors.
        const bool contiguous = obj.attr("is_contiguous")().cast<bool>();
        if (!contiguous) {
            throw std::runtime_error("Tensor must be contiguous");
        }

        // Reject host-resident tensors.
        const std::string device_kind = obj.attr("device").attr("type").cast<std::string>();
        if (device_kind == "cpu") {
            throw std::runtime_error("Tensor must be on CUDA device");
        }

        // Read the shape; at most four dimensions are supported.
        auto dims_tuple = obj.attr("shape").cast<pybind11::tuple>();
        const size_t rank = dims_tuple.size();
        if (rank > 4) {
            throw std::runtime_error("Expected Tensor.ndim <= 4");
        }

        // Right-align the tensor's dimensions in a 4-entry array, leaving
        // the leading (missing) dimensions at 1.
        std::array<int, 4> extents;
        extents.fill(1);
        const size_t first_slot = 4 - rank;
        for (size_t axis = 0; axis < rank; ++axis) {
            extents[first_slot + axis] = pybind11::cast<int>(dims_tuple[axis]);
        }

        // Raw address of the tensor's storage as reported by data_ptr().
        uint64_t address = obj.attr("data_ptr")().cast<uint64_t>();

        return make_gl<GL>(address, extents[0], extents[1], extents[2], extents[3]);
    }
};
// Satisfied by types exposing a dynamic_shared_memory() member function
// whose result is convertible to int.
template<typename T>
concept has_dynamic_shared_memory = requires(T instance) {
    { instance.dynamic_shared_memory() } -> std::convertible_to<int>;
};
// Decomposes a pointer-to-data-member type `Member Owner::*`:
//   member_type -> Member (the type of the pointed-to member)
//   type        -> Owner  (the class that contains it)
// Only the pointer-to-member specialization is defined.
template<typename>
struct trait;

template<typename Member, typename Owner>
struct trait<Member Owner::*> {
    using member_type = Member;
    using type = Owner;
};
// Alias that maps any type to pybind11::object. It lets a parameter pack of
// arbitrary types be expanded into an equally long list of pybind11::object
// parameters (the template argument itself is ignored).
template<typename Ignored>
using object = pybind11::object;
template<auto kernel, typename TGlobal> static void bind_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {
m.def(name, [](object<decltype(member_ptrs)>... args) {
TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
if constexpr (has_dynamic_shared_memory<TGlobal>) {
int __dynamic_shared_memory__ = (int)__g__.dynamic_shared_memory();
hipFuncSetAttribute((void *) kernel, hipFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__);
kernel<<<__g__.grid(), __g__.block(), __dynamic_shared_memory__>>>(__g__);
} else {
kernel<<<__g__.grid(), __g__.block()>>>(__g__);
}
});
}
template<auto function, typename TGlobal> static void bind_function(auto m, auto name, auto TGlobal::*... member_ptrs) {
m.def(name, [](object<decltype(member_ptrs)>... args) {
TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
function(__g__);
});
}
} // namespace py
} // namespace kittens