Replace the mask-based integer truncation lambdas in `truncate` with ctypes
fixed-width integer types (c_uint8..c_uint64, c_int8..c_int64); `dtypes.bool`
now maps to `bool` directly. `ctypes` is added to the module imports.

NOTE(review): this patch was recovered from text whose line breaks and leading
indentation had been collapsed; two-space indentation is assumed for the hunk
lines. Confirm against the target file, or apply with
`git apply --ignore-whitespace`.

diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index 757c2d066d..b34ac132ee 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -1,5 +1,5 @@
 from __future__ import annotations
-import functools, math, operator, itertools
+import functools, math, operator, itertools, ctypes
 from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
 from collections import defaultdict
 from tinygrad.helpers import DEBUG, flatten, prod
@@ -45,11 +45,12 @@ python_alu = {
   TernaryOps.WHERE: lambda x,y,z: y if x else z}
 
 truncate: Dict[DType, Callable] = {
-  dtypes.bool: lambda x: bool(x),
-  **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
-  **{dt:functools.partial(lambda vv,x: x&vv, (1 << (dt.itemsize*8))-1) for dt in dtypes.fields().values() if dtypes.is_unsigned(dt)},
-  **{dt:functools.partial(lambda vv,aa,x: ((x+aa)&vv)-aa, (1 << (dt.itemsize*8))-1, 1 << (dt.itemsize*8-1)) \
-    for dt in dtypes.fields().values() if dtypes.is_int(dt) and not dtypes.is_unsigned(dt)}}
+  dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
+  dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
+  dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
+  dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
+  dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
+
 def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
 
 def uop_alu_resolve(u:UOp) -> sint: