// Source: mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29)
#pragma once
#include "util.cuh"

#include <array>
#include <concepts>
#include <cstdint>
#include <stdexcept>
#include <string>

#include <pybind11/pybind11.h>
namespace kittens {
namespace py {
template<typename T> struct from_object {
|
|
static T make(pybind11::object obj) {
|
|
return obj.cast<T>();
|
|
}
|
|
};
template<ducks::gl::all GL> struct from_object<GL> {
|
|
static GL make(pybind11::object obj) {
|
|
// Check if argument is a torch.Tensor
|
|
if (pybind11::hasattr(obj, "__class__") &&
|
|
obj.attr("__class__").attr("__name__").cast<std::string>() == "Tensor") {
|
|
|
|
// Check if tensor is contiguous
|
|
if (!obj.attr("is_contiguous")().cast<bool>()) {
|
|
throw std::runtime_error("Tensor must be contiguous");
|
|
}
|
|
if (obj.attr("device").attr("type").cast<std::string>() == "cpu") {
|
|
throw std::runtime_error("Tensor must be on CUDA device");
|
|
}
|
|
|
|
// Get shape, pad with 1s if needed
|
|
std::array<int, 4> shape = {1, 1, 1, 1};
|
|
auto py_shape = obj.attr("shape").cast<pybind11::tuple>();
|
|
size_t dims = py_shape.size();
|
|
if (dims > 4) {
|
|
throw std::runtime_error("Expected Tensor.ndim <= 4");
|
|
}
|
|
for (size_t i = 0; i < dims; ++i) {
|
|
shape[4 - dims + i] = pybind11::cast<int>(py_shape[i]);
|
|
}
|
|
|
|
// Get data pointer using data_ptr()
|
|
uint64_t data_ptr = obj.attr("data_ptr")().cast<uint64_t>();
|
|
|
|
// Create GL object using make_gl
|
|
return make_gl<GL>(data_ptr, shape[0], shape[1], shape[2], shape[3]);
|
|
}
|
|
throw std::runtime_error("Expected a torch.Tensor");
|
|
}
|
|
};
template<typename T> concept has_dynamic_shared_memory = requires(T t) { { t.dynamic_shared_memory() } -> std::convertible_to<int>; };
template<typename> struct trait;
|
|
template<typename MT, typename T> struct trait<MT T::*> { using member_type = MT; using type = T; };
|
|
template<typename> using object = pybind11::object;
template<auto kernel, typename TGlobal> static void bind_kernel(auto m, auto name, auto TGlobal::*... member_ptrs) {
|
|
m.def(name, [](object<decltype(member_ptrs)>... args) {
|
|
TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
|
|
if constexpr (has_dynamic_shared_memory<TGlobal>) {
|
|
int __dynamic_shared_memory__ = (int)__g__.dynamic_shared_memory();
|
|
hipFuncSetAttribute((void *) kernel, hipFuncAttributeMaxDynamicSharedMemorySize, __dynamic_shared_memory__);
|
|
kernel<<<__g__.grid(), __g__.block(), __dynamic_shared_memory__>>>(__g__);
|
|
} else {
|
|
kernel<<<__g__.grid(), __g__.block()>>>(__g__);
|
|
}
|
|
});
|
|
}
template<auto function, typename TGlobal> static void bind_function(auto m, auto name, auto TGlobal::*... member_ptrs) {
|
|
m.def(name, [](object<decltype(member_ptrs)>... args) {
|
|
TGlobal __g__ {from_object<typename trait<decltype(member_ptrs)>::member_type>::make(args)...};
|
|
function(__g__);
|
|
});
|
|
}
} // namespace py
} // namespace kittens