c.Struct cleanup (#15640)

This commit is contained in:
Christopher Milan
2026-04-08 17:07:16 -07:00
committed by GitHub
parent 742b3894d7
commit d08c76d9cb
2 changed files with 74 additions and 35 deletions

View File

@@ -12,6 +12,51 @@ class TestC(unittest.TestCase):
subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
return DLL("test", f.name)
def test_struct_array_init(self):
@record
class Foo:
SIZE = 12
a: Annotated[ctypes.c_int * 3, 0]
init_records()
f = Foo((1,2,3))
assert f.a[0] == 1
assert f.a[1] == 2
assert f.a[2] == 3
f = Foo((ctypes.c_int * 3)(1,2,3))
assert f.a[0] == 1
assert f.a[1] == 2
assert f.a[2] == 3
def test_field_ranges(self):
@record
class Foo:
SIZE = 2
s: Annotated[ctypes.c_int8, 0]
u: Annotated[ctypes.c_uint8, 1]
init_records()
f = Foo()
f.s = -1
f.u = -1
assert f.s == -1
assert f.u == 255
# this syntax is inherited from ctypes, but it seems a bit nonsensical?
def test_voidp_none(self):
@record
class Foo:
SIZE = 8
p: Annotated[ctypes.c_void_p, 0]
init_records()
f = Foo(None)
assert f.p is None
f.p = ctypes.c_void_p(0xDEADBEEF)
assert f.p == 0xDEADBEEF
f.p = None
assert f.p is None
def test_packed_struct(self):
@record
class Baz:

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
import ctypes, functools, os, pathlib, re, sys, sysconfig
from tinygrad.helpers import ceildiv, getenv, unwrap, DEBUG, OSX, WIN
from _ctypes import Array as _CArray, _SimpleCData, _Pointer
from typing import TYPE_CHECKING, get_type_hints, get_args, get_origin, overload, Annotated, Any, Generic, Iterable, ParamSpec, TypeVar
def _do_ioctl(__idir, __base, __nr, __struct, __fd, *args, __payload=None, **kwargs):
@@ -34,22 +33,22 @@ if TYPE_CHECKING:
from _ctypes import _CData
class Array(Generic[T, U], _CData):
@overload
def __getitem__(self: Array[_SimpleCData[V], Any], key: int) -> V: ...
def __getitem__(self: Array[ctypes._SimpleCData[V], Any], key: int) -> V: ...
@overload
def __getitem__(self: Array[T, Any], key: slice) -> list[T]: ...
@overload
def __getitem__(self: Array[T, Any], key: int) -> T: ...
def __getitem__(self, key) -> Any: ...
@overload
def __setitem__(self: Array[_SimpleCData[V], Any], key: int, val: V): ...
def __setitem__(self: Array[ctypes._SimpleCData[V], Any], key: int, val: V): ...
@overload
def __setitem__(self: Array[T, Any], key: int, val: T): ...
@overload
def __setitem__(self: Array[T, Any], key: slice, val: Iterable[T]): ...
def __setitem__(self, key, val): ...
class POINTER(Generic[T], _Pointer): ...
class POINTER(Generic[T], ctypes._Pointer): ...
class CFUNCTYPE(Generic[T, P], _CFunctionType): ...
class Enum(_SimpleCData):
class Enum(ctypes._SimpleCData):
@classmethod
def get(cls, val:int, default="unknown") -> str: ...
@classmethod
@@ -80,14 +79,9 @@ else:
return val
def pointer(obj): return ctypes.pointer(obj)
def i2b(i:int, sz:int) -> bytes: return i.to_bytes(sz, sys.byteorder)
def b2i(b:bytes) -> int: return int.from_bytes(b, sys.byteorder)
def mv(st) -> memoryview: return memoryview(st).cast('B')
class Struct(ctypes.Structure):
def __init__(self, *args, **kwargs):
ctypes.Structure.__init__(self)
self._objects_ = {}
for f,v in [*zip((rf[0] for rf in self._real_fields_), args), *kwargs.items()]: setattr(self, f, v)
def record(cls) -> type[Struct]:
@@ -98,38 +92,38 @@ def record(cls) -> type[Struct]:
def init_records() -> None:
for cls, struct, ns in _pending_records:
setattr(struct, '_real_fields_', [])
for nm, t in get_type_hints(cls, globalns=ns, include_extras=True).items():
if t.__origin__ in (bool, bytes, str, int, float): setattr(struct, nm, Field(*(f:=t.__metadata__)))
else: setattr(struct, nm, Field(*(f:=(del_an(t.__origin__), *t.__metadata__))))
struct._real_fields_.append((nm,) + f) # type: ignore
for i, (nm, t) in enumerate(get_type_hints(cls, globalns=ns, include_extras=True).items()):
struct._real_fields_.append((nm, *(f:=(del_an(t.__origin__), *t.__metadata__) if isinstance(t.__metadata__[0], int) else t.__metadata__))) # type: ignore
setattr(struct, nm, Field(nm, i, *f))
_pending_records.clear()
class Field(property):
def __init__(self, typ, off:int, bit_width=None, bit_off=0):
if bit_width is not None:
sl, set_mask = slice(off,off+(sz:=ceildiv(bit_width+bit_off, 8))), ~((mask:=(1 << bit_width) - 1) << bit_off)
class Field:
def __init__(self, nm, idx, typ, off, bit_width=None, bit_off=0):
self.nm, self.idx, self.typ, self.off, self.bit_width, self.bit_off = nm, idx, typ, off, bit_width, bit_off
# lazily resolve field descriptors
def _resolve(self, cls):
if self.bit_width: # handle bitfields ourselves
sl, set_mask = slice(self.off, self.off+(sz:=ceildiv(self.bit_width+self.bit_off, 8))), ~((mask:=(1 << self.bit_width) - 1) << self.bit_off)
def b2i(obj): return int.from_bytes(memoryview(obj).cast("B")[sl], sys.byteorder)
def bset(obj, v): memoryview(obj).cast("B")[sl] = ((b2i(obj) & set_mask) | v << self.bit_off).to_bytes(sz, sys.byteorder)
# FIXME: signedness
super().__init__(lambda self: (b2i(mv(self)[sl]) >> bit_off) & mask,
lambda self,v: mv(self).__setitem__(sl, i2b((b2i(mv(self)[sl]) & set_mask) | (v << bit_off), sz)))
else:
sl = slice(off, off + ctypes.sizeof(typ))
def set_with_objs(f):
def wrapper(self, v):
if hasattr(v, '_objects') and hasattr(self, '_objects_'): self._objects_[off] = {'_self_': v, **(v._objects or {})}
mv(self).__setitem__(sl, bytes(v if isinstance(v, typ) else f(v)))
return wrapper
if issubclass(typ, _CArray):
getter = (lambda self: typ.from_buffer(mv(self)[sl]).value) if typ._type_ is ctypes.c_char else (lambda self: typ.from_buffer(mv(self)[sl]))
super().__init__(getter, set_with_objs(lambda v: typ(*v)))
else: super().__init__(lambda self: v.value if isinstance(v:=typ.from_buffer(mv(self)[sl]), _SimpleCData) else v, set_with_objs(typ))
self.offset = off
cf = property(lambda obj: b2i(obj) >> self.bit_off & mask, bset)
# pull the CField descriptor from a dummy class, zero length arrays are so ctypes manages references to child objects for us
else: cf = type(self.nm, (ctypes.Structure,), {"_layout_": "ms", "_pack_": 1, "_fields_": [(str(i), ctypes.c_byte * 0) for i in range(self.idx)] +
[("_", ctypes.c_byte * self.off), ("v", self.typ)]}).v # type: ignore
setattr(cls, self.nm, cf)
return cf
def __get__(self, obj, objtype=None): return self._resolve(objtype).__get__(obj, objtype) if objtype else self
def __set__(self, obj, value): self._resolve(obj.__class__).__set__(obj, value)
@functools.cache
def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
CStruct = type("CStruct", (Struct,), {'_fields_': [('_mem_', ctypes.c_byte * sz)], '_real_fields_': []})
for nm,ty,*args in fields:
setattr(CStruct, nm, Field(*(f:=(del_an(ty), *args))))
CStruct._real_fields_.append((nm,) + f) # type: ignore
for i,(nm,ty,*args) in enumerate(fields):
CStruct._real_fields_.append((nm, *(f:=(del_an(ty), *args)))) # type: ignore
setattr(CStruct, nm, Field(nm, i, *f))
return CStruct
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]