mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
force_transcendental on sqrt (#14368)
This commit is contained in:
@@ -325,12 +325,12 @@ def get_late_rewrite_patterns(ops:tuple[Ops, ...], force_transcendental):
|
||||
pat += [(UPat(op, dtype=TRANSCENDENTAL_DTYPES, src=(UPat.var("d"),)), f),
|
||||
(UPat(op, dtype=tuple(dt for dt in dtypes.floats if dt not in TRANSCENDENTAL_DTYPES), src=(UPat.var("d"),), name="x"),
|
||||
lambda x,d: d.cast(dtypes.float32).alu(x.op).cast(x.dtype))]
|
||||
+ # rewrite SQRT to xpow 0.5
|
||||
+ if Ops.SQRT not in ops or force_transcendental: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
|
||||
# no real hardware supports THREEFRY, but NullRenderer does
|
||||
if Ops.THREEFRY not in ops: pat.append((UPat(Ops.THREEFRY, dtype=dtypes.uint64, src=(UPat.var("x"), UPat.var("key"))), threefry2x32))
|
||||
# MAX can be rewritten as CMPLT + WHERE (max function is annoying on many cstyle backends)
|
||||
if Ops.MAX not in ops and Ops.CMPLT in ops: pat.append((UPat(Ops.MAX, name="m"), lambda m: (m.src[0] < m.src[1]).where(m.src[1], m.src[0])))
|
||||
- # rewrite SQRT to xpow 0.5
|
||||
- if Ops.SQRT not in ops: pat.append((UPat(Ops.SQRT, src=UPat.var("d")), lambda d: xpow(d, d.const_like(0.5))))
|
||||
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
|
||||
if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]
|
||||
if Ops.OR in ops: pat += [(UPat.var("x", dtypes.bool).logical_not()&UPat.var("y", dtypes.bool).logical_not(),
|
||||
|
||||
Reference in New Issue
Block a user