From d08c76d9cb14eacdfdde18e6e20cafec1cab7290 Mon Sep 17 00:00:00 2001 From: Christopher Milan Date: Wed, 8 Apr 2026 17:07:16 -0700 Subject: [PATCH] c.Struct cleanup (#15640) --- test/null/test_autogen.py | 45 ++++++++++++++++++++++++ tinygrad/runtime/support/c.py | 64 ++++++++++++++++------------------- 2 files changed, 74 insertions(+), 35 deletions(-) diff --git a/test/null/test_autogen.py b/test/null/test_autogen.py index 55d107ce40..cf4281a96a 100644 --- a/test/null/test_autogen.py +++ b/test/null/test_autogen.py @@ -12,6 +12,51 @@ class TestC(unittest.TestCase): subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode()) return DLL("test", f.name) + def test_struct_array_init(self): + @record + class Foo: + SIZE = 12 + a: Annotated[ctypes.c_int * 3, 0] + init_records() + + f = Foo((1,2,3)) + assert f.a[0] == 1 + assert f.a[1] == 2 + assert f.a[2] == 3 + f = Foo((ctypes.c_int * 3)(1,2,3)) + assert f.a[0] == 1 + assert f.a[1] == 2 + assert f.a[2] == 3 + + def test_field_ranges(self): + @record + class Foo: + SIZE = 2 + s: Annotated[ctypes.c_int8, 0] + u: Annotated[ctypes.c_uint8, 1] + init_records() + + f = Foo() + f.s = -1 + f.u = -1 + assert f.s == -1 + assert f.u == 255 + + # this syntax is inherited from ctypes, but it seems a bit nonsensical? + def test_voidp_none(self): + @record + class Foo: + SIZE = 8 + p: Annotated[ctypes.c_void_p, 0] + init_records() + + f = Foo(None) + assert f.p is None + f.p = ctypes.c_void_p(0xDEADBEEF) + assert f.p == 0xDEADBEEF + f.p = None + assert f.p is None + def test_packed_struct(self): @record class Baz: diff --git a/tinygrad/runtime/support/c.py b/tinygrad/runtime/support/c.py index a4f092d8f8..95d18a8709 100644 --- a/tinygrad/runtime/support/c.py +++ b/tinygrad/runtime/support/c.py @@ -1,7 +1,6 @@ from __future__ import annotations import ctypes, functools, os, pathlib, re, sys, sysconfig from tinygrad.helpers import ceildiv, getenv, unwrap, DEBUG, OSX, WIN -from _ctypes import Array as _CArray, _SimpleCData, _Pointer from typing import TYPE_CHECKING, get_type_hints, get_args, get_origin, overload, Annotated, Any, Generic, Iterable, ParamSpec, TypeVar def _do_ioctl(__idir, __base, __nr, __struct, __fd, *args, __payload=None, **kwargs): @@ -34,22 +33,22 @@ if TYPE_CHECKING: from _ctypes import _CData class Array(Generic[T, U], _CData): @overload - def __getitem__(self: Array[_SimpleCData[V], Any], key: int) -> V: ... + def __getitem__(self: Array[ctypes._SimpleCData[V], Any], key: int) -> V: ... @overload def __getitem__(self: Array[T, Any], key: slice) -> list[T]: ... @overload def __getitem__(self: Array[T, Any], key: int) -> T: ... def __getitem__(self, key) -> Any: ... @overload - def __setitem__(self: Array[_SimpleCData[V], Any], key: int, val: V): ... + def __setitem__(self: Array[ctypes._SimpleCData[V], Any], key: int, val: V): ... @overload def __setitem__(self: Array[T, Any], key: int, val: T): ... @overload def __setitem__(self: Array[T, Any], key: slice, val: Iterable[T]): ... def __setitem__(self, key, val): ... - class POINTER(Generic[T], _Pointer): ... + class POINTER(Generic[T], ctypes._Pointer): ... class CFUNCTYPE(Generic[T, P], _CFunctionType): ... - class Enum(_SimpleCData): + class Enum(ctypes._SimpleCData): @classmethod def get(cls, val:int, default="unknown") -> str: ... @classmethod @@ -80,14 +79,9 @@ else: return val def pointer(obj): return ctypes.pointer(obj) -def i2b(i:int, sz:int) -> bytes: return i.to_bytes(sz, sys.byteorder) -def b2i(b:bytes) -> int: return int.from_bytes(b, sys.byteorder) -def mv(st) -> memoryview: return memoryview(st).cast('B') - class Struct(ctypes.Structure): def __init__(self, *args, **kwargs): ctypes.Structure.__init__(self) - self._objects_ = {} for f,v in [*zip((rf[0] for rf in self._real_fields_), args), *kwargs.items()]: setattr(self, f, v) def record(cls) -> type[Struct]: @@ -98,38 +92,38 @@ def record(cls) -> type[Struct]: def init_records() -> None: for cls, struct, ns in _pending_records: setattr(struct, '_real_fields_', []) - for nm, t in get_type_hints(cls, globalns=ns, include_extras=True).items(): - if t.__origin__ in (bool, bytes, str, int, float): setattr(struct, nm, Field(*(f:=t.__metadata__))) - else: setattr(struct, nm, Field(*(f:=(del_an(t.__origin__), *t.__metadata__)))) - struct._real_fields_.append((nm,) + f) # type: ignore + for i, (nm, t) in enumerate(get_type_hints(cls, globalns=ns, include_extras=True).items()): + struct._real_fields_.append((nm, *(f:=(del_an(t.__origin__), *t.__metadata__) if isinstance(t.__metadata__[0], int) else t.__metadata__))) # type: ignore + setattr(struct, nm, Field(nm, i, *f)) _pending_records.clear() -class Field(property): - def __init__(self, typ, off:int, bit_width=None, bit_off=0): - if bit_width is not None: - sl, set_mask = slice(off,off+(sz:=ceildiv(bit_width+bit_off, 8))), ~((mask:=(1 << bit_width) - 1) << bit_off) +class Field: + def __init__(self, nm, idx, typ, off, bit_width=None, bit_off=0): + self.nm, self.idx, self.typ, self.off, self.bit_width, self.bit_off = nm, idx, typ, off, bit_width, bit_off + + # lazily resolve field descriptors + def _resolve(self, cls): + if self.bit_width: # handle bitfields ourselves + sl, set_mask = slice(self.off, self.off+(sz:=ceildiv(self.bit_width+self.bit_off, 8))), ~((mask:=(1 << self.bit_width) - 1) << self.bit_off) + def b2i(obj): return int.from_bytes(memoryview(obj).cast("B")[sl], sys.byteorder) + def bset(obj, v): memoryview(obj).cast("B")[sl] = ((b2i(obj) & set_mask) | v << self.bit_off).to_bytes(sz, sys.byteorder) # FIXME: signedness - super().__init__(lambda self: (b2i(mv(self)[sl]) >> bit_off) & mask, - lambda self,v: mv(self).__setitem__(sl, i2b((b2i(mv(self)[sl]) & set_mask) | (v << bit_off), sz))) - else: - sl = slice(off, off + ctypes.sizeof(typ)) - def set_with_objs(f): - def wrapper(self, v): - if hasattr(v, '_objects') and hasattr(self, '_objects_'): self._objects_[off] = {'_self_': v, **(v._objects or {})} - mv(self).__setitem__(sl, bytes(v if isinstance(v, typ) else f(v))) - return wrapper - if issubclass(typ, _CArray): - getter = (lambda self: typ.from_buffer(mv(self)[sl]).value) if typ._type_ is ctypes.c_char else (lambda self: typ.from_buffer(mv(self)[sl])) - super().__init__(getter, set_with_objs(lambda v: typ(*v))) - else: super().__init__(lambda self: v.value if isinstance(v:=typ.from_buffer(mv(self)[sl]), _SimpleCData) else v, set_with_objs(typ)) - self.offset = off + cf = property(lambda obj: b2i(obj) >> self.bit_off & mask, bset) + # pull the CField descriptor from a dummy class, zero length arrays are so ctypes manages references to child objects for us + else: cf = type(self.nm, (ctypes.Structure,), {"_layout_": "ms", "_pack_": 1, "_fields_": [(str(i), ctypes.c_byte * 0) for i in range(self.idx)] + + [("_", ctypes.c_byte * self.off), ("v", self.typ)]}).v # type: ignore + setattr(cls, self.nm, cf) + return cf + + def __get__(self, obj, objtype=None): return self._resolve(objtype).__get__(obj, objtype) if objtype else self + def __set__(self, obj, value): self._resolve(obj.__class__).__set__(obj, value) @functools.cache def init_c_struct_t(sz:int, fields: tuple[tuple, ...]): CStruct = type("CStruct", (Struct,), {'_fields_': [('_mem_', ctypes.c_byte * sz)], '_real_fields_': []}) - for nm,ty,*args in fields: - setattr(CStruct, nm, Field(*(f:=(del_an(ty), *args)))) - CStruct._real_fields_.append((nm,) + f) # type: ignore + for i,(nm,ty,*args) in enumerate(fields): + CStruct._real_fields_.append((nm, *(f:=(del_an(ty), *args)))) # type: ignore + setattr(CStruct, nm, Field(nm, i, *f)) return CStruct def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]