mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Approximations for SIN/LOG2/EXP2 passing all tests. (#5187)
* [WIP] Added an approximated implementation of Sin(FP32, FP64) passing all tests on Clang runtime
* Map nan/-inf/inf as 1.0 in order to avoid doing as_const(math.inf)
* [WIP] Added a support for LLVM IR
* cleaned up the code for the mypy and linter
* [WIP] Updated fp64 supports (bitwise shift causes the compilation error), fixed linter issue.
* [Add] added fast=true mode which disables the payne-hanek reduction which is slow
* [Fix] fails to compute elements when shape includes zero
* [WIP] Added BinaryOps.ADD/BinaryOps.OR to assembly
* [wip] update the assembly for ptx
* Enables fast=True when device is one of PTX, NV, CUDA, to avoid slow bitwise ops (as lv3 reduction is not required).
* [WIP] Added an approximation of LOG2/EXP2 (FP32, FP64)
* [Fix] Cyclic dependencies existing in xlog2
* [Fix] Cycle dependency in the graph of exp2, and log2. (passing test_symbolic_ops.py)
* [Fix] keep using higher precision for exp2, but cycle graph issue remained to be fixed...
* [Refactor] removed is_metal option. xsin does not rely on fp64 when fp32 mode.
* [WIP] fp16 xsin implementation passing all tests. (still needs to be refactored)
* [WIP] Added fp16 exp2 implementation
* [WIP] Increased the precision of Log2 from 3.5 ULP to 1.0 ULP, and added FP16 Log2 approximation.
* stashed the changes for FP16 sin
* [Fix] Patch for FP16 Sin/Exp2. (updated the dtype_via, fp32_p, and lower)
* [Refactor] migration to fastmath.py, some code simplification, renamed apis in fastmath, et al.
* [Refactor] Added the function polyN to clean-up N-terms polynomial approximation.
* [Patch] Increase fp64 precision when ldexp3k if possible, and patch for fp16 exp2
* [Patch] added bitcast_forward option
* [Patch] resolved cycle graph
* patch fix cycle graph
* set bitcast_forward=True in ilogb2k
* bitcast_forward for multi.py
* E501
* Break into multiple small PRs
* [Patch] FP16 -> FP64 upcast is not anymore required since xlog2 use quad precision polyN
* [Patch] NV still required FP64 for xlog2
* updated schedule test
* updated the count of kernels
* [Update] Removed all bitwise ops (SHL/SHR), tweaked the nan manipulation of log2, passing all tests except for AMD.
* Bitcast: make them api-compatible
* [update] force to use bitcast
* updated the count of constant folding
* [Patch] Creating a mask for exp2 using x <= Inf satisfies True as long as x is a real value
* [Update] isNaN(x) Free log2 algorithm, passing PTX tests, METAL with fastmath enabled is able to handle nan well, amd backend will not crash.
* xsin is reluctant to call payne_hanek_reduction which is slow to compile, passing stable diffusion compilation in a realistic time
* some minor simplification to payne hanek reduction
* [refactor] refactored some rebundant parts existing in payne hanek
* [refactor] more readable payne hanek impl
* [refactor] improved the code consistency of payne hanek
* [experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)
* Revert "[experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)"
This reverts commit 0eee08b87c.
* use allow_buffer_view
* lets support multilazytensor
* updated the count of kernels
* [test] added the jit tests for approx ops
* keep failed constant folding tests tested, added expectedFailure
* explict the timeout deadline when testing approx jit timeout
* [WIP] Simplified the implementation of xsin, never timeouts
* [Refactor] Improved the consistency of approx sin implementation, passing time out tests
* integrated xexp2_base into xexp2
* Set switch_over=39800.0
* delete: is_buffer_fastmath_supported
* sin: compute against abs(x)
* some cleanups
* fix typo
* removed the space between param and dtype
* allow 514 kernels on CI for sd
* [refactor] no need to upcast ad ldexp3k
* [refactor] added some comments, references to help understanding the code.
* [Fix] 1.0 ULP Sine Approximation for FP16
* [update] assume e != 0
* use pow2if instead of ldexp3k to fuse payne_hanek reduction into one
* check if approximated sin/log2/exp are fused into one
* clean up changes
* test amd exp
* some code cleanup and test sigmoid
* fix: enabled payne_hanek for fp16 to achieve higher acc
* fix: payne_hanek always accumlates the value with uint64, and fp16 sin is fused to a single kernel
* [Refactor] Rename: fastmath -> transcendental
* [Refactor] Added TRANSCENDENTAL, Moved the gate function to function.py
* updated const folding tests
* TRANSCENDENTAL as a ContextVar, removed old test of cody waite reduction, added assertions, et al.
* Add: unittest.main()
* Import TRANSCENDENTAL instead of getenv
* Refactor: Added dtype check when TRANSCENDENTAL=2, more context var
* Patch: xlog2, break expt(2, 32) x 2 -> expt(2, 16) x 4 for fp16 math
---------
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
This commit is contained in:
@@ -1,12 +1,25 @@
|
||||
"""This is where the forwards and backwards passes live."""
|
||||
import math
|
||||
from typing import Tuple, Optional
|
||||
from tinygrad.helpers import argsort
|
||||
from tinygrad.helpers import argsort, TRANSCENDENTAL
|
||||
from tinygrad.dtype import dtypes, DType, sum_acc_dtype
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.tensor import Function
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
from tinygrad.shape.symbolic import sint
|
||||
from tinygrad.transcendental import xsin, xlog2, xexp2, is_dtype_transcendental_supported
|
||||
|
||||
transcendental_supported_devices = ["CLANG", "LLVM"]
|
||||
def use_transcendental(d:LazyBuffer) -> bool:
|
||||
# TRANSCENDENTAL=0 to always ignore.
|
||||
# TRANSCENDENTAL=1 to run only in CLANG/LLVM (default).
|
||||
# TRANSCENDENTAL=2 to always run it.
|
||||
if TRANSCENDENTAL >= 2:
|
||||
return is_dtype_transcendental_supported(d.dtype)
|
||||
if TRANSCENDENTAL >= 1:
|
||||
return (is_dtype_transcendental_supported(d.dtype) and
|
||||
d.device in transcendental_supported_devices)
|
||||
return False
|
||||
|
||||
class Contiguous(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
|
||||
@@ -39,10 +52,11 @@ class Reciprocal(Function):
|
||||
class Sin(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return x.e(UnaryOps.SIN)
|
||||
return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
|
||||
def _xsin(x): return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
|
||||
return _xsin(self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
|
||||
|
||||
# NOTE: maximum(x, 0) behaves differently where x=0
|
||||
class Relu(Function):
|
||||
@@ -56,13 +70,15 @@ class Relu(Function):
|
||||
class Log(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.x = x
|
||||
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
|
||||
def _xlog2(x): return xlog2(x) if use_transcendental(x) else x.e(UnaryOps.LOG2)
|
||||
return _xlog2(x).e(BinaryOps.MUL, x.const(math.log(2)))
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
|
||||
|
||||
class Exp(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
|
||||
def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
|
||||
self.ret = _xexp2(x.e(BinaryOps.MUL, x.const(1/math.log(2))))
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)
|
||||
@@ -80,7 +96,8 @@ class Sqrt(Function):
|
||||
# TODO: have the backend automatically find this
|
||||
class Sigmoid(Function):
|
||||
def forward(self, x:LazyBuffer) -> LazyBuffer:
|
||||
self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
|
||||
def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
|
||||
self.ret = x.const(1).e(BinaryOps.ADD, _xexp2(x.e(BinaryOps.MUL, x.const(-1/math.log(2))))).e(UnaryOps.RECIP)
|
||||
return self.ret
|
||||
|
||||
def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
|
||||
|
||||
Reference in New Issue
Block a user