# Files
# tinygrad/test/null/test_autogen.py
# 2026-04-13 16:54:39 -04:00
#
# 536 lines
# 16 KiB
# Python

import ctypes, struct, subprocess, tempfile, unittest
from tinygrad.helpers import OSX, WIN
from tinygrad.runtime.support.c import DLL, record, Field
from tinygrad.runtime.support import c
from tinygrad.runtime.support.autogen import gen
@unittest.skipIf(WIN, "doesn't compile on windows")
class TestC(unittest.TestCase):
  """Tests for the ``record``/``Field`` struct helpers and for calling clang-compiled C through ``DLL``.

  Most tests declare a small ``c.Struct`` subclass, compile a matching C snippet with ``compile`` and
  check that values survive the round trip. The C snippets are kept flush-left inside their string
  literals so the text handed to clang is exactly the text written here.
  """

  def compile(self, src):
    """Compile the C source string ``src`` into a shared object with clang and wrap it in a ``DLL``.

    The source is piped to clang on stdin and the output is written to a temporary ``.so`` file.
    The ``DLL`` is constructed while that temporary file still exists; the file is removed when the
    ``with`` block exits. Raises ``subprocess.CalledProcessError`` if clang exits with an error.
    """
    with tempfile.NamedTemporaryFile(suffix=".so") as f:
      subprocess.check_output(('clang', '-x', 'c', '-fPIC', '-shared', '-', '-o', f.name), input=src.encode())
      return DLL("test", f.name)

  def test_struct_array_init(self):
    """An array field can be initialised from a plain tuple or from a ctypes array."""
    @record
    class Foo(c.Struct):
      SIZE = 12
      a = Field(ctypes.c_int * 3, 0)
    for init in ((1,2,3), (ctypes.c_int * 3)(1,2,3)):
      f = Foo(init)
      self.assertEqual(f.a[0], 1)
      self.assertEqual(f.a[1], 2)
      self.assertEqual(f.a[2], 3)

  def test_field_ranges(self):
    """Storing -1 reads back as -1 from a signed 8-bit field and as 255 from an unsigned one."""
    @record
    class Foo(c.Struct):
      SIZE = 2
      s = Field(ctypes.c_int8, 0)
      u = Field(ctypes.c_uint8, 1)
    f = Foo()
    f.s = -1
    f.u = -1
    self.assertEqual(f.s, -1)
    self.assertEqual(f.u, 255)

  # this syntax is inherited from ctypes, but it seems a bit nonsensical?
  def test_voidp_none(self):
    """A ``c_void_p`` field reads back as ``None`` when null and as an int address otherwise."""
    @record
    class Foo(c.Struct):
      SIZE = 8
      p = Field(ctypes.c_void_p, 0)
    f = Foo(None)
    self.assertIsNone(f.p)
    f.p = ctypes.c_void_p(0xDEADBEEF)
    self.assertEqual(f.p, 0xDEADBEEF)
    f.p = None
    self.assertIsNone(f.p)

  def test_packed_struct(self):
    """Bit-packed unsigned fields keep their values, and writing one leaves its neighbours untouched."""
    @record
    class Baz(c.Struct):
      SIZE = 8
      # presumably Field(type, byte offset, bit width, bit offset) -- confirm against Field's definition
      a = Field(ctypes.c_uint, 0, 30)
      b = Field(ctypes.c_uint, 3, 30, 6)
      c = Field(ctypes.c_uint, 7, 2, 4)
      d = Field(ctypes.c_uint, 7, 2, 6)
    b = Baz(0x3AAADEAD, 0xBEEF, 1, 0)
    self.assertEqual(b.a, 0x3AAADEAD)
    self.assertEqual(b.b, 0xBEEF)
    self.assertEqual(b.c, 1)
    self.assertEqual(b.d, 0)
    # overwrite only the first field; the remaining fields must be unchanged
    b.a = 0xCAFE
    self.assertEqual(b.a, 0xCAFE)
    self.assertEqual(b.b, 0xBEEF)
    self.assertEqual(b.c, 1)
    self.assertEqual(b.d, 0)

  def test_packed_struct_interop(self):
    """A packed signed-bitfield struct passed by value is summed identically in C and in Python."""
    @record
    class Baz(c.Struct):
      SIZE = 8
      a = Field(ctypes.c_int, 0, 30)
      b = Field(ctypes.c_int, 3, 30, 6)
      c = Field(ctypes.c_int, 7, 2, 4)
      d = Field(ctypes.c_int, 7, 2, 6)
    src = '''
struct __attribute__((packed)) baz {
int a:30;
int b:30;
int c:2;
int d:2;
};
int test(struct baz x) {
return x.a + x.b + x.c + x.d;
}
'''
    dll = self.compile(src)
    b = Baz(0xAA000, 0x00BB0, 0, 1)
    @dll.bind(ctypes.c_int, Baz)
    def test(x:Baz) -> ctypes.c_int: ...
    self.assertEqual(test(b), b.a + b.b + b.c + b.d)

  # https://github.com/python/cpython/issues/90914
  def test_bitfield_interop(self):
    """Eight one-bit ``bool`` fields in one byte: C must see ``c`` set only when the third flag is set."""
    @record
    class Baz(c.Struct):
      SIZE = 1
      a = Field(ctypes.c_bool, 0, 1, 0)
      b = Field(ctypes.c_bool, 0, 1, 1)
      c = Field(ctypes.c_bool, 0, 1, 2)
      d = Field(ctypes.c_bool, 0, 1, 3)
      e = Field(ctypes.c_bool, 0, 1, 4)
      f = Field(ctypes.c_bool, 0, 1, 5)
      g = Field(ctypes.c_bool, 0, 1, 6)
      h = Field(ctypes.c_bool, 0, 1, 7)
    src = '''#include <stdbool.h>
struct baz {
bool a:1, b:1, c:1, d:1, e:1, f:1, g:1, h:1;
};
int test(struct baz x) {
return x.c;
}
'''
    dll = self.compile(src)
    @dll.bind(ctypes.c_int, Baz)
    def test(x:Baz) -> ctypes.c_int: ...
    # set exactly one flag at a time; the C side returns x.c, so only i == 2 yields a truthy result
    for i in range(8):
      self.assertEqual(test(Baz(*(j==i for j in range(8)))), i==2)

  def test_struct_interop(self):
    """A struct of eight ints passed and returned by value comes back with its members reversed."""
    @record
    class Baz(c.Struct):
      SIZE = 32
      a = Field(ctypes.c_int, 0)
      b = Field(ctypes.c_int, 4)
      c = Field(ctypes.c_int, 8)
      d = Field(ctypes.c_int, 12)
      e = Field(ctypes.c_int, 16)
      f = Field(ctypes.c_int, 20)
      g = Field(ctypes.c_int, 24)
      h = Field(ctypes.c_int, 28)
    src = '''#include <stdio.h>
struct baz {
int a, b, c, d, e, f, g, h;
};
struct baz test(struct baz x) {
return (struct baz){x.h, x.g, x.f, x.e, x.d, x.c, x.b, x.a};
}
'''
    dll = self.compile(src)
    @dll.bind(Baz, Baz)
    def test(x:Baz) -> Baz: ...
    self.assertEqual(bytes(test(Baz(*range(8)))), struct.pack("8i", *range(7, -1, -1)))

  def test_aos_interop(self):
    """An array of structs is passed to C, which sums the ``val`` member of each element."""
    @record
    class Item(c.Struct):
      SIZE = 4
      val = Field(ctypes.c_int, 0)
    src = """
struct item { int val; };
int test(struct item arr[3]) {
int ret = 0;
for (int i = 0; i < 3; i++) ret += arr[i].val;
return ret;
}
"""
    dll = self.compile(src)
    @dll.bind(ctypes.c_int, Item * 3)
    def test(arr:(Item * 3)) -> ctypes.c_int: ...
    self.assertEqual(test((Item * 3)(Item(10), Item(20), Item(30))), 60)

  def test_soa_interop(self):
    """A struct holding an int array is passed by value and returned with the array reversed."""
    @record
    class Row(c.Struct):
      # NOTE(review): the C struct below holds three ints, yet SIZE is 16 -- confirm this is intentional
      SIZE = 16
      data = Field(ctypes.c_int * 3, 0)
    src = """
struct row { int data[3]; };
struct row test(struct row x) {
return (struct row){{ x.data[2], x.data[1], x.data[0] }};
}
"""
    dll = self.compile(src)
    @dll.bind(Row, Row)
    def test(x:Row) -> Row: ...
    r = test(Row((ctypes.c_int * 3)(10, 20, 30)))
    self.assertIsInstance(r, Row)
    self.assertEqual(r.data[0], 30)
    self.assertEqual(r.data[1], 20)
    self.assertEqual(r.data[2], 10)

  def test_soa_ptr_interop(self):
    """A struct holding an ``int *`` is built from a ctypes array and dereferenced on the C side."""
    @record
    class Row(c.Struct):
      SIZE = 8
      data = Field(c.POINTER[ctypes.c_int], 0)
    src = """
struct row { int *data; };
int test(struct row x) {
return x.data[2] + x.data[1] + x.data[0];
}
"""
    dll = self.compile(src)
    @dll.bind(ctypes.c_int, Row)
    def test(x:Row) -> ctypes.c_int: ...
    self.assertEqual(test(Row((ctypes.c_int * 3)(10, 20, 30))), 60)

  def test_nested_struct_interop(self):
    """A struct embedding another struct is passed by value; C swaps the inner value with ``b``."""
    @record
    class Inner(c.Struct):
      SIZE = 4
      a = Field(ctypes.c_int, 0)
    @record
    class Outer(c.Struct):
      SIZE = 8
      inner = Field(Inner, 0)
      b = Field(ctypes.c_int, 4)
    src = """
struct i { int a; };
struct o { struct i i; int b; };
struct o test(struct o x) {
return (struct o){(struct i){ x.b }, x.i.a };
}
"""
    dll = self.compile(src)
    @dll.bind(Outer, Outer)
    def test(x:Outer) -> Outer: ...
    o = test(Outer(Inner(10), 20))
    self.assertEqual(o.inner.a, 20)
    self.assertEqual(o.b, 10)

  def test_struct_pointer_interop(self):
    """A pointer to a struct is passed to C, which swaps the two members in place and returns it."""
    @record
    class Foo(c.Struct):
      SIZE = 8
      a = Field(ctypes.c_int, 0)
      b = Field(ctypes.c_int, 4)
    src = """
struct foo { int a, b; };
struct foo *test(struct foo *f) {
int x = f->a;
f->a = f->b;
f->b = x;
return f;
}
"""
    dll = self.compile(src)
    @dll.bind(ctypes.POINTER(Foo), ctypes.POINTER(Foo))
    def test(f:ctypes.POINTER(Foo)) -> ctypes.POINTER(Foo): ...
    inp = ctypes.pointer(Foo(10, 20))
    out = test(inp)
    self.assertEqual(out.contents.a, 20)
    self.assertEqual(out.contents.b, 10)

  def test_pointer_field_roundtrip(self):
    """A pointer stored in a struct field can be read back out and handed to a C function."""
    # This tests storing a pointer in a record struct field and passing it to C
    # Mimics how mesa.struct_lp_build_tgsi_params.mask is used
    from tinygrad.runtime.support.c import POINTER
    @record
    class Inner(c.Struct):
      SIZE = 8
      value = Field(ctypes.c_int, 0)
      flag = Field(ctypes.c_int, 4)
    @record
    class Outer(c.Struct):
      SIZE = 16
      x = Field(ctypes.c_int, 0)
      inner_ptr = Field(POINTER[Inner], 8)
    src = """
struct inner { int value; int flag; };
struct outer { int x; struct inner *inner_ptr; };
int test(struct inner *p) {
return p->value + p->flag;
}
"""
    dll = self.compile(src)
    @dll.bind(ctypes.c_int, ctypes.POINTER(Inner))
    def test(p:POINTER[Inner]) -> ctypes.c_int: ...
    inner = Inner(value=42, flag=10)
    outer = Outer(x=1, inner_ptr=ctypes.pointer(inner))
    # Retrieve pointer from struct field and pass to C
    self.assertEqual(test(outer.inner_ptr), 52)

  def test_pointer_field_loses_reference(self):
    """A pointer field must keep working when its target was created inline with no other reference."""
    # BUG: When a pointer is stored in a record struct field, only the address bytes are saved.
    # The pointer's _objects dict (which prevents GC of the pointed-to object) is lost.
    # This causes the pointed-to object to be garbage collected, leading to use-after-free.
    from tinygrad.runtime.support.c import POINTER
    @record
    class MaskContext(c.Struct):
      SIZE = 16
      value = Field(ctypes.c_int, 0)
      initialized = Field(ctypes.c_int, 4)
      ptr = Field(ctypes.c_void_p, 8)
    @record
    class Params(c.Struct):
      SIZE = 16
      x = Field(ctypes.c_int, 0)
      mask = Field(POINTER[MaskContext], 8)
    src = """
struct mask_ctx { int value; int initialized; void *ptr; };
void mask_begin(struct mask_ctx *m, int val) { m->value = val; m->initialized = 1; }
int mask_end(struct mask_ctx *m) { return m->value + m->initialized; }
"""
    dll = self.compile(src)
    @dll.bind(None, ctypes.POINTER(MaskContext), ctypes.c_int)
    def mask_begin(m:POINTER[MaskContext], val:ctypes.c_int) -> None: ...
    @dll.bind(ctypes.c_int, ctypes.POINTER(MaskContext))
    def mask_end(m:POINTER[MaskContext]) -> ctypes.c_int: ...
    # When MaskContext() is created inline, it gets garbage collected after the pointer
    # is stored because only the address bytes are saved, not the _objects reference.
    params = Params(x=1, mask=ctypes.pointer(MaskContext()))
    mask_begin(params.mask, 42)
    result = mask_end(params.mask)
    self.assertEqual(result, 43) # 42 + 1
@unittest.skipIf(OSX and ('MTLCompiler' in DLL._loaded_ or 'llvm' in DLL._loaded_), "libclang can't be loaded after MTLCompiler or llvm on OSX")
@unittest.skipIf(WIN, "doesn't compile on windows")
class TestAutogen(unittest.TestCase):
  """Tests that run ``gen`` on small C headers and inspect the Python bindings it produces.

  The header snippets are kept flush-left inside their string literals so the text written to the
  temporary header file is exactly the text written here.
  """

  def run_gen(self, contents):
    """Write ``contents`` to a temporary ``.h`` file, run ``gen`` on it and execute the result.

    Returns the namespace dict populated by ``exec``-ing the generated code, so tests can look up
    the generated names by key.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.h') as f:
      f.write(contents)
      f.flush()  # make sure the header is on disk before gen reads it by name
      generated_code = gen(name="test_header", dll=None, files=[f.name])
    namespace = {}
    exec(generated_code, namespace)
    return namespace

  def test_packed_structs(self):
    """Packed structs nested in a packed struct serialise to the expected bytes and keep their types."""
    ns = self.run_gen("""
typedef unsigned NvU32;
typedef unsigned long NvU64;
typedef struct
{
NvU32 version;
NvU32 size;
NvU64 gfwImageOffset;
NvU32 gfwImageSize;
NvU32 flags;
} __attribute__((packed)) FWSECLIC_READ_VBIOS_DESC;
#define FWSECLIC_READ_VBIOS_STRUCT_FLAGS (2)
typedef struct
{
NvU32 version;
NvU32 size;
NvU32 frtsRegionOffset4K;
NvU32 frtsRegionSize;
NvU32 frtsRegionMediaType;
} __attribute__((packed)) FWSECLIC_FRTS_REGION_DESC;
#define FWSECLIC_FRTS_REGION_MEDIA_FB (2)
#define FWSECLIC_FRTS_REGION_SIZE_1MB_IN_4K (0x100)
typedef struct
{
FWSECLIC_READ_VBIOS_DESC readVbiosDesc;
FWSECLIC_FRTS_REGION_DESC frtsRegionDesc;
} __attribute__((packed)) FWSECLIC_FRTS_CMD;
""")
    FWSECLIC_READ_VBIOS_DESC = ns['FWSECLIC_READ_VBIOS_DESC']
    FWSECLIC_FRTS_REGION_DESC = ns['FWSECLIC_FRTS_REGION_DESC']
    FWSECLIC_FRTS_CMD = ns['FWSECLIC_FRTS_CMD']
    read_vbios_desc = FWSECLIC_READ_VBIOS_DESC(version=0x1, size=ctypes.sizeof(FWSECLIC_READ_VBIOS_DESC), flags=2)
    frst_reg_desc = FWSECLIC_FRTS_REGION_DESC(version=0x1, size=ctypes.sizeof(FWSECLIC_FRTS_REGION_DESC),
                                              frtsRegionOffset4K=0xdead, frtsRegionSize=0x100, frtsRegionMediaType=2)
    frts_cmd = FWSECLIC_FRTS_CMD(readVbiosDesc=read_vbios_desc, frtsRegionDesc=frst_reg_desc)
    # the whole command, read as one little-endian integer, must match the reference encoding
    self.assertEqual(int.from_bytes(frts_cmd, 'little'),
                     0x2000001000000dead0000001400000001000000020000000000000000000000000000001800000001)
    # each nested member must carry the same bytes as the struct it was built from
    self.assertEqual(int.from_bytes(frts_cmd.readVbiosDesc, 'little'), int.from_bytes(read_vbios_desc, 'little'))
    self.assertEqual(int.from_bytes(frts_cmd.frtsRegionDesc, 'little'), int.from_bytes(frst_reg_desc, 'little'))
    self.assertIs(frts_cmd.readVbiosDesc.__class__, FWSECLIC_READ_VBIOS_DESC)
    self.assertIs(frts_cmd.frtsRegionDesc.__class__, FWSECLIC_FRTS_REGION_DESC)

  def test_gen_from_header(self):
    """Structs, an enum and its constants from a simple header all appear in the generated namespace."""
    namespace = self.run_gen("""
typedef struct {
int x;
int y;
} Point;
typedef enum {
RED = 0,
GREEN = 1,
BLUE = 2
} Color;
typedef struct {
Point origin;
int width;
int height;
Color color;
} Rectangle;
int add_points(Point a, Point b);""")
    self.assertIn('Point', namespace)
    self.assertIn('Color', namespace)
    self.assertIn('Rectangle', namespace)
    self.assertIn('RED', namespace)
    self.assertIn('GREEN', namespace)
    self.assertIn('BLUE', namespace)
    self.assertEqual(namespace['RED'], 0)
    self.assertEqual(namespace['GREEN'], 1)
    self.assertEqual(namespace['BLUE'], 2)
    Point = namespace['Point']
    p = Point()
    self.assertTrue(hasattr(p, 'x'))
    self.assertTrue(hasattr(p, 'y'))
    Rectangle = namespace['Rectangle']
    rect = Rectangle()
    self.assertTrue(hasattr(rect, 'origin'))
    self.assertTrue(hasattr(rect, 'width'))
    self.assertTrue(hasattr(rect, 'height'))
    self.assertTrue(hasattr(rect, 'color'))
    p2 = Point(10, 20)
    self.assertEqual(p2.x, 10)
    self.assertEqual(p2.y, 20)

  def test_struct_ordering(self):
    """Structs that point at each other in a cycle, declared out of order, are all generated."""
    namespace = self.run_gen("""
struct A;
struct C;
typedef struct A A;
struct B {
struct C *c_ptr;
};
struct C {
struct A *a_ptr;
};
struct A {
int x;
struct B *b_ptr;
};""")
    self.assertIn('struct_A', namespace)
    self.assertIn('struct_B', namespace)
    self.assertIn('struct_C', namespace)
    A, B, C = namespace['A'], namespace['struct_B'], namespace['struct_C']
    # instance names chosen so the imported module ``c`` is not shadowed inside this test
    inst_a, inst_b, inst_c = A(), B(), C()
    self.assertTrue(hasattr(inst_a, 'x'))
    self.assertTrue(hasattr(inst_a, 'b_ptr'))
    self.assertTrue(hasattr(inst_b, 'c_ptr'))
    self.assertTrue(hasattr(inst_c, 'a_ptr'))

  def test_anonymous_children(self):
    """An anonymous struct used as a member gets its own generated name derived from parent and field."""
    namespace = self.run_gen("""
struct foo {
struct {
int a,b;
} bar;
};
""")
    self.assertIn('struct_foo', namespace)
    self.assertIn('struct_foo_bar', namespace)

  def test_enums(self):
    """Enum constants are exported by name, and each ``enum_*`` object maps a value back to its name."""
    namespace = self.run_gen("""
enum Foo { A, B, C };
enum Bar { X, Y, Z };
""")
    self.assertEqual(namespace["A"], 0)
    self.assertEqual(namespace["B"], 1)
    self.assertEqual(namespace["C"], 2)
    self.assertEqual(namespace["X"], 0)
    self.assertEqual(namespace["Y"], 1)
    self.assertEqual(namespace["Z"], 2)
    self.assertEqual(namespace["enum_Foo"].get(0), "A")
    self.assertEqual(namespace["enum_Foo"].get(1), "B")
    self.assertEqual(namespace["enum_Foo"].get(2), "C")
    self.assertEqual(namespace["enum_Bar"].get(0), "X")
    self.assertEqual(namespace["enum_Bar"].get(1), "Y")
    self.assertEqual(namespace["enum_Bar"].get(2), "Z")

  @unittest.skipIf(OSX, "can't find stdint?")
  def test_packed_fields(self):
    """A header with an anonymous union and bitfields parses a raw byte buffer into the right fields."""
    ns = self.run_gen("""#include <stdint.h>
typedef struct die_info
{
uint16_t die_id;
uint16_t die_offset; /* Points to the corresponding die_header structure */
} die_info;
typedef struct ip_discovery_header
{
uint32_t signature; /* Table Signature */
uint16_t version; /* Table Version */
uint16_t size; /* Table Size */
uint32_t id; /* Table ID */
uint16_t num_dies; /* Number of Dies */
die_info die_info[16]; /* list die information for up to 16 dies */
union {
uint16_t padding[1]; /* version <= 3 */
struct { /* version == 4 */
uint8_t base_addr_64_bit : 1; /* ip structures are using 64 bit base address */
uint8_t reserved : 7;
uint8_t reserved2;
};
};
} ip_discovery_header;
""")
    ip_discovery_header = ns['ip_discovery_header']
    hdr = b'IPDS\x04\x00|\x1d\x80\x1a\xffd\x01\x00\x00\x00\x8c\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x00' # noqa: E501
    ihdr = ip_discovery_header.from_buffer_copy(hdr)
    self.assertEqual(ctypes.sizeof(ihdr), 80)
    self.assertEqual(ihdr.signature, 0x53445049)
    self.assertEqual(ihdr.version, 0x0004)
    self.assertEqual(ihdr.num_dies, 1)
    self.assertEqual(ihdr.base_addr_64_bit, 1)
# Run every test in this module when the file is executed directly.
if __name__ == "__main__":
  unittest.main()