diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index d571b99178..003221695d 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -26,9 +26,9 @@ base_rewrite = PatternMatcher([ (UPat(UOps.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]), (UPat(UOps.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"), # const - (UPat(UOps.CONST, arg=math.inf), lambda ctx: ctx.infinity), - (UPat(UOps.CONST, arg=-math.inf), lambda ctx: "-"+ctx.infinity), - (UPat(UOps.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: ctx.nan if math.isnan(x.arg) else None), + (UPat(UOps.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"), + (UPat(UOps.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"), + (UPat(UOps.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None), (UPat(UOps.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"), (UPat(UOps.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"), (UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),