force_transcendental on sqrt (#14368)

This commit is contained in:
George Hotz
2026-01-27 20:24:41 +08:00
committed by GitHub
parent 14574c68fa
commit 131ae604de

View File

@@ -325,12 +325,12 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental):
pat += [(UPat(op, dtype=TRANSCENDENTAL_DTYPES, src=(UPat.var("d"),)), f),
(UPat(op, dtype=tuple(dt for dt in dtypes.floats if dt not in TRANSCENDENTAL_DTYPES), src=(UPat.var("d"),), name="x"),
lambda x,d: d.cast(dtypes.float32).alu(x.op).cast(x.dtype))]
+ # rewrite SQRT to xpow 0.5
+ if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# no real hardware supports THREEFRY, but NullRenderer does
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
- # rewrite SQRT to xpow 0.5
- if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),