use ctyles for uops truncate (#3985)

This commit is contained in:
chenyu
2024-03-28 23:31:34 -04:00
committed by GitHub
parent 1bf0a7a2d1
commit 101a0c683d

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import functools, math, operator, itertools
import functools, math, operator, itertools, ctypes
from typing import List, Set, Optional, Tuple, Any, Dict, DefaultDict, Callable, cast
from collections import defaultdict
from tinygrad.helpers import DEBUG, flatten, prod
@@ -45,11 +45,12 @@ python_alu = {
TernaryOps.WHERE: lambda x,y,z: y if x else z}
truncate: Dict[DType, Callable] = {
dtypes.bool: lambda x: bool(x),
**{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
**{dt:functools.partial(lambda vv,x: x&vv, (1 << (dt.itemsize*8))-1) for dt in dtypes.fields().values() if dtypes.is_unsigned(dt)},
**{dt:functools.partial(lambda vv,aa,x: ((x+aa)&vv)-aa, (1 << (dt.itemsize*8))-1, 1 << (dt.itemsize*8-1)) \
for dt in dtypes.fields().values() if dtypes.is_int(dt) and not dtypes.is_unsigned(dt)}}
dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
dtypes.int32: lambda x: ctypes.c_int32(x).value, dtypes.int64: lambda x: ctypes.c_int64(x).value,}
def exec_alu(arg, dtype, p): return truncate[dtype](python_alu[arg](*p))
def uop_alu_resolve(u:UOp) -> sint: