Ops.POW and transcendental (#8911)

This commit is contained in:
chenyu
2025-02-05 15:15:59 -05:00
committed by GitHub
parent bff7c70eef
commit 9307572fe3
5 changed files with 26 additions and 11 deletions

View File

@@ -6,7 +6,7 @@ from tinygrad.dtype import dtypes, ImageDType, PtrDType
from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve
from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
from tinygrad.renderer import Renderer
# ***** float4/image store handling *****
@@ -124,6 +124,7 @@ powers_of_two = {2**i:i for i in range(64)}
def get_late_rewrite_patterns(ops, force_transcendental=False):
pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \
((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental]
pat.append((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src)))
# rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1)
if Ops.AND in ops:
pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)]

View File

@@ -254,3 +254,13 @@ def xlog2(d:UOp) -> UOp:
r = d.ne(d).where(r.const_like(math.nan), r)
# log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal.
return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf))
def xpow(base:UOp, exponent:UOp) -> UOp:
  """
  Build `base ** exponent` out of LOG2/EXP2 plus sign and special-case corrections.

  The magnitude is computed as exp2(exponent * log2(|base|)). For a negative base the result is then
  multiplied by nan when the exponent is not an integer, by -1 when it is an odd integer and by 1 when it
  is an even integer. `0 ** 0` is forced to 1.
  """
  is_neg = base < 0
  # b ** e = exp2(e * log2(b)), evaluated on the absolute value of the base
  magnitude = is_neg.where(-base, base).log2().mul(exponent).exp2()
  one, minus_one, nan = magnitude.const_like(1), magnitude.const_like(-1), magnitude.const_like(math.nan)
  # the exponent is integral iff it survives a round trip through int32
  exp_int = exponent.cast(dtypes.int32)
  non_integer = exponent != exp_int.cast(exponent.dtype)
  odd = (exp_int.cast(dtypes.uint32)%2).eq(1)
  # the correction only applies to negative bases; every other base keeps the plain magnitude
  sign = is_neg.where(non_integer.where(nan, odd.where(minus_one, one)), one)
  # 0 ** 0 = 1
  return (base.eq(0) & exponent.eq(0)).where(one, magnitude * sign)

View File

@@ -22,6 +22,7 @@ pm_gradient = PatternMatcher([
(UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)),
(UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)),
(UPat(Ops.ADD), lambda ctx: (ctx, ctx)),
(UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))),
(UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)),
(ret.src[0]<ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)))),
(UPat(Ops.MUL, name="ret"), lambda ctx, ret: (ret.src[1]*ctx, ret.src[0]*ctx)),

View File

@@ -89,6 +89,7 @@ class MathTrait(SimpleMathTrait):
def sin(self):
  """Apply the unary `Ops.SIN` ALU op to `self` and return whatever `self.alu` produces."""
  return self.alu(Ops.SIN)
def log2(self):
  """Apply the unary `Ops.LOG2` ALU op to `self` and return whatever `self.alu` produces."""
  return self.alu(Ops.LOG2)
def exp2(self):
  """Apply the unary `Ops.EXP2` ALU op to `self` and return whatever `self.alu` produces."""
  return self.alu(Ops.EXP2)
def pow(self, x):
  """Apply the binary `Ops.POW` ALU op with `self` as the first source and `x` as the second.

  `x` is handed to `self.alu` unchanged.
  """
  return self.alu(Ops.POW, x)
# the order of these Ops controls the order of the toposort
class Ops(FastEnum):
@@ -133,7 +134,7 @@ class Ops(FastEnum):
# BinaryOps
ADD = auto(); MUL = auto(); IDIV = auto(); MAX = auto(); MOD = auto(); CMPLT = auto(); CMPNE = auto(); XOR = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto() # noqa: E702
SHL = auto(); SHR = auto(); OR = auto(); AND = auto(); THREEFRY = auto(); SUB = auto(); FDIV = auto(); POW = auto() # noqa: E702
# TernaryOps
WHERE = auto(); MULACC = auto() # noqa: E702
@@ -155,7 +156,7 @@ class Ops(FastEnum):
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
Binary = {Ops.ADD, Ops.MUL, Ops.IDIV, Ops.MAX, Ops.MOD, Ops.CMPLT, Ops.CMPNE, Ops.XOR, Ops.SHL, Ops.SHR, Ops.OR, Ops.AND, Ops.THREEFRY,
Ops.SUB, Ops.FDIV}
Ops.SUB, Ops.FDIV, Ops.POW}
Ternary = {Ops.WHERE, Ops.MULACC}
ALU = set.union(Unary, Binary, Ternary)
@@ -175,7 +176,7 @@ class GroupOp:
Idempotent = {Ops.OR, Ops.AND, Ops.MAX}
# do not preserve f(0) = 0
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV}
UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV, Ops.POW}
All = set(Ops)
@@ -675,10 +676,15 @@ def safe_exp2(x):
try: return 2 ** x
except OverflowError: return math.inf
def safe_pow(x, y):
  """
  Evaluate `x ** y` on Python scalars without raising.

  Returns the value of the builtin `pow(x, y)` when that value is real. Otherwise:
  - complex result (negative base with a non-integer exponent) -> nan
  - ZeroDivisionError (zero base with a negative exponent) -> inf
  - ValueError -> inf for a positive base, -inf otherwise
  - OverflowError (result outside the float range) -> inf carrying the sign of the true result, or nan
    when the result would have been complex
  """
  try: return math.nan if isinstance(p:=pow(x, y), complex) else p
  except ZeroDivisionError: return math.inf
  except ValueError: return math.inf if x > 0 else -math.inf
  except OverflowError:
    # saturate instead of raising when the magnitude exceeds the float range, like safe_exp2 does
    if x >= 0: return math.inf
    # negative base: only an integral exponent gives a real result, and its parity decides the sign
    if isinstance(y, int) or float(y).is_integer(): return -math.inf if int(y) % 2 else math.inf
    return math.nan
python_alu: dict[Ops, Callable] = {
Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2,
Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow,
Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt,
Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max,
Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0,

View File

@@ -3313,12 +3313,9 @@ class Tensor(SimpleMathTrait):
base, exponent = self._broadcasted(x, reverse=reverse)
# TODO: int pow
if not base.is_floating_point(): raise RuntimeError("base needs to be float")
# start with b ** e = exp(e * log(b))
ret = base.abs().log().mul(exponent).exp()
# negative base adjustment: nan for non-integer exponent and -1 for odd exponent
adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1)
# fix 0 ** 0 = 1
ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj)
# NOTE: pow(int, float) -> int
ret = base._apply_uop(UOp.pow, exponent)
return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret
def maximum(self, x:Union[Tensor, ConstType]) -> Tensor: