diff --git a/tinygrad/codegen/rewriter.py b/tinygrad/codegen/rewriter.py index 22da6d9871..c49bcf2e4d 100644 --- a/tinygrad/codegen/rewriter.py +++ b/tinygrad/codegen/rewriter.py @@ -6,7 +6,7 @@ from tinygrad.dtype import dtypes, ImageDType, PtrDType from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple, resolve from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same -from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES +from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES from tinygrad.renderer import Renderer # ***** float4/image store handling ***** @@ -124,6 +124,7 @@ powers_of_two = {2**i:i for i in range(64)} def get_late_rewrite_patterns(ops, force_transcendental=False): pat: list[tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \ ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental] + pat.append((UPat(Ops.POW, name="p"), lambda p: xpow(*p.src))) # rewrite MOD to AND (which should always be supported, but not for generic in tests): x % (2**y) -> x & (2**y-1) if Ops.AND in ops: pat += [(UPat.var("x", dtypes.ints)%UPat.cvar("c"), lambda x,c: x & (c.arg-1) if c.arg in powers_of_two else None)] diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 753aed3787..5611cae1d1 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -254,3 +254,13 @@ def xlog2(d:UOp) -> UOp: r = d.ne(d).where(r.const_like(math.nan), r) # log2(-0.0) = -Inf. In certain devices like PTX, x == -0.0 won't be true. so making reciprocal. return d.reciprocal().ne(-math.inf).where(r, r.const_like(-math.inf)) + +def xpow(base:UOp, exponent:UOp) -> UOp: + # start with b ** e = exp2(e * log2(b)) + ret = (base < 0).where(-base, base).log2().mul(exponent).exp2() + # negative base adjustment: nan for non-integer exponent and -1 for odd exponent + adj = (base < 0).where((exponent != exponent.cast(dtypes.int32).cast(exponent.dtype)).where( + ret.const_like(math.nan), + (exponent.cast(dtypes.int32).cast(dtypes.uint32)%2).eq(1).where(ret.const_like(-1), ret.const_like(1))), ret.const_like(1)) + # fix 0 ** 0 = 1 + return (base.eq(0) & exponent.eq(0)).where(ret.const_like(1), ret * adj) diff --git a/tinygrad/gradient.py b/tinygrad/gradient.py index 91d0d3a599..54e6a3465f 100644 --- a/tinygrad/gradient.py +++ b/tinygrad/gradient.py @@ -22,6 +22,7 @@ pm_gradient = PatternMatcher([ (UPat(Ops.SQRT, name="ret"), lambda ctx, ret: (ctx / (ret*2),)), (UPat((Ops.CMPLT, Ops.CMPNE)), lambda: (None, None)), (UPat(Ops.ADD), lambda ctx: (ctx, ctx)), + (UPat(Ops.POW, name="ret"), lambda ctx, ret: (ctx*ret*ret.src[1]/ret.src[0], ctx*ret*ret.src[0].log2()*math.log(2.0))), (UPat(Ops.MAX, name="ret"), lambda ctx, ret: ((ret.src[0]>ret.src[1]).where(ctx, (ret.src[0]!=ret.src[1]).where(ctx.const_like(0), ctx * 0.5)), (ret.src[0] 0 else -math.inf + python_alu: dict[Ops, Callable] = { Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: safe_exp2, Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), - Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, + Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, Ops.POW: safe_pow, Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor, Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, Ops.MAX: max, Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else 0, diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 529554f18a..95bb5837bd 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3313,12 +3313,9 @@ class Tensor(SimpleMathTrait): base, exponent = self._broadcasted(x, reverse=reverse) # TODO: int pow if not base.is_floating_point(): raise RuntimeError("base needs to be float") - # start with b ** e = exp(e * log(b)) - ret = base.abs().log().mul(exponent).exp() - # negative base adjustment: nan for non-integer exponent and -1 for odd exponent - adj = (base < 0).detach().where((exponent != exponent.int()).detach().where(math.nan, (exponent.int()%2==1).where(-1, 1)), 1) - # fix 0 ** 0 = 1 - ret = ((base == 0) * (exponent == 0)).detach().where(1, ret * adj) + + # NOTE: pow(int, float) -> int + ret = base._apply_uop(UOp.pow, exponent) return ret.round().cast(self.dtype) if not dtypes.is_float(self.dtype) else ret def maximum(self, x:Union[Tensor, ConstType]) -> Tensor: