hotfix: casted nan/inf

This commit is contained in:
George Hotz
2024-10-31 19:50:17 +08:00
parent a43b7a4b7c
commit f579693ec9

View File

@@ -26,9 +26,9 @@ base_rewrite = PatternMatcher([
(UPat(UOps.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]),
(UPat(UOps.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"),
# const
(UPat(UOps.CONST, arg=math.inf), lambda ctx: ctx.infinity),
(UPat(UOps.CONST, arg=-math.inf), lambda ctx: "-"+ctx.infinity),
(UPat(UOps.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: ctx.nan if math.isnan(x.arg) else None),
(UPat(UOps.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"),
(UPat(UOps.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"),
(UPat(UOps.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None),
(UPat(UOps.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"),
(UPat(UOps.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"),
(UPat(UOps.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"),