#include <torch/extension.h>
#include <ATen/OpaqueTensorImpl.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>

// register guard
namespace at {
namespace detail {
C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl<c10::DeviceType::PrivateUse1>);
}
} // namespace at::detail

struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface {
  // NOTE: no idea what this is
  bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; }
};

int register_hook() {
  at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface());
  return 0;
}
int temp_register_hook = register_hook();

// code from chatgpt
struct GILSafeDeleter {
  void operator()(PyObject* ptr) const {
    if (ptr) {
      py::gil_scoped_acquire gil;
      Py_DECREF(ptr);
    }
  }
};

class TinyTensor {
 private:
  // We wrap the PyObject* inside a shared_ptr so the GILSafeDeleter runs on destruction.
  std::shared_ptr<PyObject> obj_;

 public:
  TinyTensor() : obj_(nullptr, GILSafeDeleter()) {}

  // From a py::object
  TinyTensor(const py::object& o) : obj_(o.inc_ref().ptr(), GILSafeDeleter()) {
    // o.inc_ref() bumps the PyObject reference count; we store the pointer in the shared_ptr
  }

  // Optional move or copy ctors if needed:
  TinyTensor(const TinyTensor& other) = default;
  TinyTensor(TinyTensor&& other) = default;
  TinyTensor& operator=(const TinyTensor& other) = default;
  TinyTensor& operator=(TinyTensor&& other) = default;

  py::object get_py_obj() const {
    if (!obj_) {
      return py::none();
    }
    // Safely borrow as a py::object (we must hold the GIL).
    py::gil_scoped_acquire gil;
    return py::reinterpret_borrow<py::object>(obj_.get());
  }
};

// tinygrad dtype names follow C naming ("float", "long", "unsigned char", ...);
// map them onto the matching ATen TypeMeta.
static caffe2::TypeMeta dtypeFromName(const std::string& dtype_name) {
  if (dtype_name == "float") {
    return caffe2::TypeMeta::Make<float>();
  } else if (dtype_name == "double") {
    return caffe2::TypeMeta::Make<double>();
  } else if (dtype_name == "int") {
    return caffe2::TypeMeta::Make<int32_t>();
  } else if (dtype_name == "long") {
    return caffe2::TypeMeta::Make<int64_t>();
  } else if (dtype_name == "bool") {
    return caffe2::TypeMeta::Make<bool>();
  } else if (dtype_name == "char") {
    return caffe2::TypeMeta::Make<int8_t>();
  } else if (dtype_name == "unsigned char") {
    return caffe2::TypeMeta::Make<uint8_t>();
  }
  throw std::runtime_error("Unsupported dtype: " + dtype_name);
}

at::Tensor wrap_tensor(py::object& py_obj) {
  // TODO: we have to get the dtype and the shape from the tinygrad Tensor
  std::vector<int64_t> sizes = py_obj.attr("shape").cast<std::vector<int64_t>>();
  std::string dtype_name = py_obj.attr("dtype").attr("name").cast<std::string>();
  return at::detail::make_tensor<at::OpaqueTensorImpl<TinyTensor>>(
      at::DispatchKeySet(at::DispatchKey::PrivateUse1),
      dtypeFromName(dtype_name),
      at::Device(at::kPrivateUse1),
      TinyTensor(py_obj),
      sizes);
}

py::object unwrap_tensor(const at::Tensor& tensor) {
  auto* impl = tensor.unsafeGetTensorImpl();
  auto* opaque_impl = static_cast<at::OpaqueTensorImpl<TinyTensor>*>(impl);
  const TinyTensor& tiny = opaque_impl->opaque_handle();
  return tiny.get_py_obj();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("wrap", &wrap_tensor);
  m.def("unwrap", &unwrap_tensor);
}
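For context, here is a minimal sketch of how this extension could be driven from Python, assuming it is compiled with torch.utils.cpp_extension.load. The module name tiny_backend, the source file name wrap_tensor.cpp, and the "tiny" backend rename are placeholders for illustration, not names from the original code.

# Hypothetical driver for the extension above; names are assumptions, not from the source.
import torch
from torch.utils.cpp_extension import load
from tinygrad import Tensor

# Build the C++ file as an inline extension (name/sources are placeholders).
ext = load(name="tiny_backend", sources=["wrap_tensor.cpp"], verbose=True)

# Optional: give the PrivateUse1 device type a friendlier name.
torch.utils.rename_privateuse1_backend("tiny")

t = Tensor.ones(2, 3)    # a tinygrad Tensor
wrapped = ext.wrap(t)    # at::Tensor backed by OpaqueTensorImpl<TinyTensor>
print(wrapped.shape, wrapped.dtype, wrapped.device)   # metadata comes from the wrapper

# unwrap should hand back the very same Python object that was wrapped.
assert ext.unwrap(wrapped) is t

Note that no kernels are registered for the PrivateUse1 dispatch key at this point, so only metadata access and unwrap can be expected to work on the wrapped tensor.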