diff --git a/test/test_ops.py b/test/test_ops.py index ec3d4cf611..28ba1bb40e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -150,6 +150,10 @@ class TestOps(unittest.TestCase): def test_mul(self): helper_test_op([(64,64), (64,64)], lambda x,y: x*y, Tensor.mul) helper_test_op([(), ()], lambda x,y: x*y, Tensor.mul) + def test_mul_const(self): + helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf")) + helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf")) + helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan")) def test_div(self): helper_test_op([(45,65), (45,65)], lambda x,y: x/y, Tensor.div) helper_test_op([(), ()], lambda x,y: x/y, Tensor.div) @@ -159,6 +163,12 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: 1/x, lambda x: 1/x) helper_test_op([(45,65)], lambda x: x/2, lambda x: x/2) helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x) + helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf")) + helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf")) + helper_test_op([(45,65)], lambda x: x/float("nan"), lambda x: x/float("nan")) + helper_test_op([(45,65)], lambda x: float("inf")/x, lambda x: float("inf")/x) + helper_test_op([(45,65)], lambda x: (-float("inf"))/x, lambda x: (-float("inf"))/x) + helper_test_op([(45,65)], lambda x: float("nan")/x, lambda x: float("nan")/x) helper_test_op([()], lambda x: x/2, lambda x: x/2) helper_test_op([()], lambda x: 2/x, lambda x: 2/x) def test_pow(self): diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index 73de73b159..0633986760 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -120,8 +120,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan # TODO: merge with CONST? if bufs[args.i] is not None and isinstance(bufs[args.i].realized, RawConst): assert newvar.ltype == LocalTypes.float, "const can't be float4" - # nan? inf? - val = f"{bufs[args.i].realized._buf}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "") + x = bufs[args.i].realized._buf + if math.isnan(x): val = "NAN" + elif math.isinf(x): val = ("-" if x < 0 else "") + "INFINITY" + else: val = f"{x}" + ("f" if not dtypes.is_int(bufs[args.i].dtype) else "") elif isinstance(bufs[args.i].dtype, ImageDType): assert newvar.ltype == LocalTypes.float4, "image must be float4" prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")