#include #include #include #include #include // register guard namespace at { namespace detail { C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl); } } struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { // NOTE: no idea what this is bool hasPrimaryContext(c10::DeviceIndex device_index) const override { return true; } }; int register_hook() { at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); return 0; } int temp_register_hook = register_hook(); at::Tensor wrap_tensor(py::object &py_obj, c10::ScalarType dtype) { // TODO: we have to get the dtype and the shape from the tinygrad Tensor std::vector sizes = py_obj.attr("shape").cast>(); return at::detail::make_tensor>>( at::DispatchKeySet(at::DispatchKey::PrivateUse1), c10::scalarTypeToTypeMeta(dtype), at::Device(at::kPrivateUse1), std::make_shared(py_obj.release().ptr(), getPyInterpreter()), sizes); } py::object unwrap_tensor(const at::Tensor &tensor) { auto* impl = tensor.unsafeGetTensorImpl(); auto* opaque_impl = static_cast>*>(impl); std::shared_ptr tiny = opaque_impl->opaque_handle(); return py::reinterpret_borrow(tiny->ptr(getPyInterpreter())); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("wrap", &wrap_tensor); m.def("unwrap", &unwrap_tensor); }