mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-11 07:05:04 -05:00
use ctypes to truncate float64 and float32 in uops (#3986)
this fixed the softmax.argmax bug for ops_python as the float is truncated to float32
This commit is contained in:
@@ -44,8 +44,9 @@ python_alu = {
|
||||
BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
|
||||
TernaryOps.WHERE: lambda x,y,z: y if x else z}
|
||||
|
||||
truncate: Dict[DType, Callable] = {
|
||||
dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
|
||||
truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
|
||||
# TODO: float16 and bfloat16?
|
||||
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
|
||||
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
|
||||
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
|
||||
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
|
||||
|
||||
Reference in New Issue
Block a user