use ctypes to truncate float64 and float32 in uops (#3986)

This fixes the softmax.argmax bug for ops_python, because float values are now truncated to float32.
This commit is contained in:
chenyu
2024-03-28 23:56:50 -04:00
committed by GitHub
parent 101a0c683d
commit 793ab0512e
4 changed files with 10 additions and 15 deletions

View File

@@ -44,8 +44,9 @@ python_alu = {
BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
TernaryOps.WHERE: lambda x,y,z: y if x else z}
truncate: Dict[DType, Callable] = {
dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
# TODO: float16 and bfloat16?
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,