From 793ab0512eff3d3ec458e17497f6018f8eff99d4 Mon Sep 17 00:00:00 2001
From: chenyu
Date: Thu, 28 Mar 2024 23:56:50 -0400
Subject: [PATCH] use ctypes to truncate float64 and float32 in uops (#3986)

This fixes the softmax.argmax bug for ops_python, as the float is now
truncated to float32.
---
 test/test_dtype.py       |  7 +------
 test/test_ops.py         |  1 -
 test/test_uops.py        | 12 ++++++------
 tinygrad/codegen/uops.py |  5 +++--
 4 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/test/test_dtype.py b/test/test_dtype.py
index 67dd9983a9..9254b4ac07 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -620,12 +620,7 @@ class TestImplicitFunctionTypeChange(unittest.TestCase):
       t = func(Tensor([4.0, 3.0])).max() == func(Tensor([4.0, 3.0]))
       result.append(t.numpy().sum())
 
-    if Device.DEFAULT not in ["PYTHON"]:
-      assert all(result)
-    else:
-      # PYTHON function default returns in double, and comparison to float can fail
-      # TODO: fix this
-      assert not all(result)
+    assert all(result)
 
 if __name__ == '__main__':
   unittest.main()
diff --git a/test/test_ops.py b/test/test_ops.py
index ee5be0ecbd..cecf244cc5 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -793,7 +793,6 @@ class TestOps(unittest.TestCase):
     helper_test_op([(10,10,10)], lambda x: x.softmax(0), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(1), atol=1e-7, grad_atol=1e-7)
     helper_test_op([(10,10,10)], lambda x: x.softmax(2), atol=1e-7, grad_atol=1e-7)
-  @unittest.skipIf(Device.DEFAULT in ["PYTHON"], "Broken ISSUE #3552")
   def test_softmax_argmax(self):
     helper_test_op([(45,65)], lambda x: x.softmax(0).argmax().type(torch.int32),
                               lambda x: x.softmax(0).argmax(), forward_only=True, atol=1e-7, grad_atol=1e-7)
diff --git a/test/test_uops.py b/test/test_uops.py
index a71f2c4d69..0584be645a 100644
--- a/test/test_uops.py
+++ b/test/test_uops.py
@@ -165,9 +165,9 @@ class TestExecALU(TestUOps):
     self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (7, -3)), -2)
     self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.int8, (-50, 6)), -8)
 
-    self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (8.0, 2.0)), 4.0)
-    self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, 3.0)), 2+(1.0/3.0))
-    self.assertEqual(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, -3.0)), -2-(1.0/3.0))
+    np.testing.assert_allclose(exec_alu(BinaryOps.DIV, dtypes.float32, (8.0, 2.0)), 4.0)
+    np.testing.assert_allclose(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, 3.0)), 2+(1.0/3.0))
+    np.testing.assert_allclose(exec_alu(BinaryOps.DIV, dtypes.float32, (7.0, -3.0)), -2-(1.0/3.0))
 
   def test_bool_neg(self):
     self.assertEqual(exec_alu(UnaryOps.NEG, dtypes.bool, (False,)), True)
@@ -180,9 +180,9 @@ class TestExecALU(TestUOps):
     self.assertEqual(exec_alu(BinaryOps.CMPLT, dtypes.bool, (True, True)), False)
 
   def test_bool_where(self):
-    self.assertIs(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
-    self.assertIs(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)
-    self.assertIs(exec_alu(TernaryOps.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
+    self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.bool, (False, False, False)), False)
+    self.assertEqual(exec_alu(TernaryOps.WHERE, dtypes.int, (False, 2, 4)), 4)
+    np.testing.assert_allclose(exec_alu(TernaryOps.WHERE, dtypes.float, (False, 2.2, 4.5)), 4.5)
 
   def test_overflow(self):
     self.assertEqual(exec_alu(BinaryOps.ADD, dtypes.uint8, (250, 250)), 244)
diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py
index b34ac132ee..f36ed79d76 100644
--- a/tinygrad/codegen/uops.py
+++ b/tinygrad/codegen/uops.py
@@ -44,8 +44,9 @@ python_alu = {
   BinaryOps.DIV: lambda x,y: int(x/y) if isinstance(x, int) else (x/y if y != 0 else x*math.inf),
   TernaryOps.WHERE: lambda x,y,z: y if x else z}
 
-truncate: Dict[DType, Callable] = {
-  dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
+truncate: Dict[DType, Callable] = {dtypes.bool: bool, **{dt:lambda x: x for dt in dtypes.fields().values() if dtypes.is_float(dt)},
+  # TODO: float16 and bfloat16?
+  dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
   dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
   dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
   dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
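
For context on why this change works: a Python float is a C double (float64), so
ops_python was producing float64 results where compiled backends produce float32,
and exact comparisons between the two could fail. Round-tripping a value through
ctypes.c_float rounds it to the nearest representable float32, which is what the
new dtypes.float32 entry in truncate computes. A minimal standalone sketch (the
truncate_fp32 helper name is illustrative, not part of the patch):

import ctypes

def truncate_fp32(x: float) -> float:
  # cast the Python float (a C double) to a C float and back,
  # rounding to the nearest representable float32 value
  return ctypes.c_float(x).value

x = 1.0 / 3.0
print(x)                      # 0.3333333333333333 (float64)
print(truncate_fp32(x))       # 0.3333333432674408 (rounded to float32)
print(x == truncate_fp32(x))  # False: why float64-vs-float32 comparisons can fail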