mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
c.Struct cleanup (#15640)
This commit is contained in:
committed by
GitHub
parent
742b3894d7
commit
d08c76d9cb
@@ -12,6 +12,51 @@ class TestC(unittest.TestCase):
|
||||
subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
|
||||
return DLL("test", f.name)
|
||||
|
||||
def test_struct_array_init(self):
|
||||
@record
|
||||
class Foo:
|
||||
SIZE = 12
|
||||
a: Annotated[ctypes.c_int * 3, 0]
|
||||
init_records()
|
||||
|
||||
f = Foo((1,2,3))
|
||||
assert f.a[0] == 1
|
||||
assert f.a[1] == 2
|
||||
assert f.a[2] == 3
|
||||
f = Foo((ctypes.c_int * 3)(1,2,3))
|
||||
assert f.a[0] == 1
|
||||
assert f.a[1] == 2
|
||||
assert f.a[2] == 3
|
||||
|
||||
def test_field_ranges(self):
|
||||
@record
|
||||
class Foo:
|
||||
SIZE = 2
|
||||
s: Annotated[ctypes.c_int8, 0]
|
||||
u: Annotated[ctypes.c_uint8, 1]
|
||||
init_records()
|
||||
|
||||
f = Foo()
|
||||
f.s = -1
|
||||
f.u = -1
|
||||
assert f.s == -1
|
||||
assert f.u == 255
|
||||
|
||||
# this syntax is inherited from ctypes, but it seems a bit nonsensical?
|
||||
def test_voidp_none(self):
|
||||
@record
|
||||
class Foo:
|
||||
SIZE = 8
|
||||
p: Annotated[ctypes.c_void_p, 0]
|
||||
init_records()
|
||||
|
||||
f = Foo(None)
|
||||
assert f.p is None
|
||||
f.p = ctypes.c_void_p(0xDEADBEEF)
|
||||
assert f.p == 0xDEADBEEF
|
||||
f.p = None
|
||||
assert f.p is None
|
||||
|
||||
def test_packed_struct(self):
|
||||
@record
|
||||
class Baz:
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import ctypes, functools, os, pathlib, re, sys, sysconfig
|
||||
from tinygrad.helpers import ceildiv, getenv, unwrap, DEBUG, OSX, WIN
|
||||
from _ctypes import Array as _CArray, _SimpleCData, _Pointer
|
||||
from typing import TYPE_CHECKING, get_type_hints, get_args, get_origin, overload, Annotated, Any, Generic, Iterable, ParamSpec, TypeVar
|
||||
|
||||
def _do_ioctl(__idir, __base, __nr, __struct, __fd, *args, __payload=None, **kwargs):
|
||||
@@ -34,22 +33,22 @@ if TYPE_CHECKING:
|
||||
from _ctypes import _CData
|
||||
class Array(Generic[T, U], _CData):
|
||||
@overload
|
||||
def __getitem__(self: Array[_SimpleCData[V], Any], key: int) -> V: ...
|
||||
def __getitem__(self: Array[ctypes._SimpleCData[V], Any], key: int) -> V: ...
|
||||
@overload
|
||||
def __getitem__(self: Array[T, Any], key: slice) -> list[T]: ...
|
||||
@overload
|
||||
def __getitem__(self: Array[T, Any], key: int) -> T: ...
|
||||
def __getitem__(self, key) -> Any: ...
|
||||
@overload
|
||||
def __setitem__(self: Array[_SimpleCData[V], Any], key: int, val: V): ...
|
||||
def __setitem__(self: Array[ctypes._SimpleCData[V], Any], key: int, val: V): ...
|
||||
@overload
|
||||
def __setitem__(self: Array[T, Any], key: int, val: T): ...
|
||||
@overload
|
||||
def __setitem__(self: Array[T, Any], key: slice, val: Iterable[T]): ...
|
||||
def __setitem__(self, key, val): ...
|
||||
class POINTER(Generic[T], _Pointer): ...
|
||||
class POINTER(Generic[T], ctypes._Pointer): ...
|
||||
class CFUNCTYPE(Generic[T, P], _CFunctionType): ...
|
||||
class Enum(_SimpleCData):
|
||||
class Enum(ctypes._SimpleCData):
|
||||
@classmethod
|
||||
def get(cls, val:int, default="unknown") -> str: ...
|
||||
@classmethod
|
||||
@@ -80,14 +79,9 @@ else:
|
||||
return val
|
||||
def pointer(obj): return ctypes.pointer(obj)
|
||||
|
||||
def i2b(i:int, sz:int) -> bytes: return i.to_bytes(sz, sys.byteorder)
|
||||
def b2i(b:bytes) -> int: return int.from_bytes(b, sys.byteorder)
|
||||
def mv(st) -> memoryview: return memoryview(st).cast('B')
|
||||
|
||||
class Struct(ctypes.Structure):
|
||||
def __init__(self, *args, **kwargs):
|
||||
ctypes.Structure.__init__(self)
|
||||
self._objects_ = {}
|
||||
for f,v in [*zip((rf[0] for rf in self._real_fields_), args), *kwargs.items()]: setattr(self, f, v)
|
||||
|
||||
def record(cls) -> type[Struct]:
|
||||
@@ -98,38 +92,38 @@ def record(cls) -> type[Struct]:
|
||||
def init_records() -> None:
|
||||
for cls, struct, ns in _pending_records:
|
||||
setattr(struct, '_real_fields_', [])
|
||||
for nm, t in get_type_hints(cls, globalns=ns, include_extras=True).items():
|
||||
if t.__origin__ in (bool, bytes, str, int, float): setattr(struct, nm, Field(*(f:=t.__metadata__)))
|
||||
else: setattr(struct, nm, Field(*(f:=(del_an(t.__origin__), *t.__metadata__))))
|
||||
struct._real_fields_.append((nm,) + f) # type: ignore
|
||||
for i, (nm, t) in enumerate(get_type_hints(cls, globalns=ns, include_extras=True).items()):
|
||||
struct._real_fields_.append((nm, *(f:=(del_an(t.__origin__), *t.__metadata__) if isinstance(t.__metadata__[0], int) else t.__metadata__))) # type: ignore
|
||||
setattr(struct, nm, Field(nm, i, *f))
|
||||
_pending_records.clear()
|
||||
|
||||
class Field(property):
|
||||
def __init__(self, typ, off:int, bit_width=None, bit_off=0):
|
||||
if bit_width is not None:
|
||||
sl, set_mask = slice(off,off+(sz:=ceildiv(bit_width+bit_off, 8))), ~((mask:=(1 << bit_width) - 1) << bit_off)
|
||||
class Field:
|
||||
def __init__(self, nm, idx, typ, off, bit_width=None, bit_off=0):
|
||||
self.nm, self.idx, self.typ, self.off, self.bit_width, self.bit_off = nm, idx, typ, off, bit_width, bit_off
|
||||
|
||||
# lazily resolve field descriptors
|
||||
def _resolve(self, cls):
|
||||
if self.bit_width: # handle bitfields ourselves
|
||||
sl, set_mask = slice(self.off, self.off+(sz:=ceildiv(self.bit_width+self.bit_off, 8))), ~((mask:=(1 << self.bit_width) - 1) << self.bit_off)
|
||||
def b2i(obj): return int.from_bytes(memoryview(obj).cast("B")[sl], sys.byteorder)
|
||||
def bset(obj, v): memoryview(obj).cast("B")[sl] = ((b2i(obj) & set_mask) | v << self.bit_off).to_bytes(sz, sys.byteorder)
|
||||
# FIXME: signedness
|
||||
super().__init__(lambda self: (b2i(mv(self)[sl]) >> bit_off) & mask,
|
||||
lambda self,v: mv(self).__setitem__(sl, i2b((b2i(mv(self)[sl]) & set_mask) | (v << bit_off), sz)))
|
||||
else:
|
||||
sl = slice(off, off + ctypes.sizeof(typ))
|
||||
def set_with_objs(f):
|
||||
def wrapper(self, v):
|
||||
if hasattr(v, '_objects') and hasattr(self, '_objects_'): self._objects_[off] = {'_self_': v, **(v._objects or {})}
|
||||
mv(self).__setitem__(sl, bytes(v if isinstance(v, typ) else f(v)))
|
||||
return wrapper
|
||||
if issubclass(typ, _CArray):
|
||||
getter = (lambda self: typ.from_buffer(mv(self)[sl]).value) if typ._type_ is ctypes.c_char else (lambda self: typ.from_buffer(mv(self)[sl]))
|
||||
super().__init__(getter, set_with_objs(lambda v: typ(*v)))
|
||||
else: super().__init__(lambda self: v.value if isinstance(v:=typ.from_buffer(mv(self)[sl]), _SimpleCData) else v, set_with_objs(typ))
|
||||
self.offset = off
|
||||
cf = property(lambda obj: b2i(obj) >> self.bit_off & mask, bset)
|
||||
# pull the CField descriptor from a dummy class, zero length arrays are so ctypes manages references to child objects for us
|
||||
else: cf = type(self.nm, (ctypes.Structure,), {"_layout_": "ms", "_pack_": 1, "_fields_": [(str(i), ctypes.c_byte * 0) for i in range(self.idx)] +
|
||||
[("_", ctypes.c_byte * self.off), ("v", self.typ)]}).v # type: ignore
|
||||
setattr(cls, self.nm, cf)
|
||||
return cf
|
||||
|
||||
def __get__(self, obj, objtype=None): return self._resolve(objtype).__get__(obj, objtype) if objtype else self
|
||||
def __set__(self, obj, value): self._resolve(obj.__class__).__set__(obj, value)
|
||||
|
||||
@functools.cache
|
||||
def init_c_struct_t(sz:int, fields: tuple[tuple, ...]):
|
||||
CStruct = type("CStruct", (Struct,), {'_fields_': [('_mem_', ctypes.c_byte * sz)], '_real_fields_': []})
|
||||
for nm,ty,*args in fields:
|
||||
setattr(CStruct, nm, Field(*(f:=(del_an(ty), *args))))
|
||||
CStruct._real_fields_.append((nm,) + f) # type: ignore
|
||||
for i,(nm,ty,*args) in enumerate(fields):
|
||||
CStruct._real_fields_.append((nm, *(f:=(del_an(ty), *args)))) # type: ignore
|
||||
setattr(CStruct, nm, Field(nm, i, *f))
|
||||
return CStruct
|
||||
def init_c_var(ty, creat_cb): return (creat_cb(v:=del_an(ty)()), v)[1]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user