Approximations for SIN/LOG2/EXP2 passing all tests. (#5187)

* [WIP] Added an approximate implementation of Sin (FP32, FP64) passing all tests on the Clang runtime

* Map nan/-inf/inf to 1.0 to avoid doing as_const(math.inf)

* [WIP] Added support for LLVM IR

* cleaned up the code for mypy and the linter

* [WIP] Updated fp64 support (bitwise shifts cause a compilation error), fixed a linter issue.

* [Add] added a fast=true mode that disables the slow Payne-Hanek reduction

* [Fix] failure to compute elements when the shape includes zero

* [WIP] Added BinaryOps.ADD/BinaryOps.OR to assembly

* [WIP] updated the assembly for PTX

* Enables fast=True when the device is one of PTX, NV, or CUDA, to avoid slow bitwise ops (the lv3 reduction is not required there).

* [WIP] Added an approximation of LOG2/EXP2 (FP32, FP64)

* [Fix] Cyclic dependencies in xlog2

* [Fix] Cyclic dependency in the graph of exp2 and log2 (passing test_symbolic_ops.py)

* [Fix] keep using higher precision for exp2; the cyclic-graph issue remains to be fixed...

* [Refactor] removed the is_metal option; xsin does not rely on fp64 in fp32 mode.

* [WIP] fp16 xsin implementation passing all tests. (still needs to be refactored)

* [WIP] Added fp16 exp2 implementation

* [WIP] Increased the precision of Log2 from 3.5 ULP to 1.0 ULP, and added an FP16 Log2 approximation.
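  The ULP figures above can be sanity-checked with a small harness like the following sketch (an assumed methodology, not the repo's test code): the error is expressed in units of the spacing between adjacent floats around the exact result.

    # Hedged sketch: measure an approximation's error in ULPs.
    import numpy as np

    def ulp_error(approx: float, exact: float, dtype=np.float32) -> float:
      exact_t = dtype(exact)
      # np.spacing gives the distance to the next representable float, i.e. one ULP
      return float(abs(dtype(approx) - exact_t) / np.spacing(exact_t))

    # a deliberately truncated log2(10) lands on the order of 1e2 ULPs away
    print(ulp_error(3.3219, np.log2(10.0)))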

* stashed the changes for FP16 sin

* [Fix] Patch for FP16 Sin/Exp2. (updated the dtype_via, fp32_p, and lower)

* [Refactor] migrated to fastmath.py, simplified some code, renamed APIs in fastmath, et al.

* [Refactor] Added the function polyN to clean up N-term polynomial approximations (sketched below).
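  The idea behind polyN, sketched here on plain floats (the real helper operates on lazy buffers, and this exact signature is an assumption), is Horner's scheme: an N-term polynomial approximation collapses into one chain of multiply-adds.

    # Minimal sketch of an N-term polynomial evaluator in Horner form.
    # coeffs run from the highest-degree term down to the constant term.
    from typing import List

    def polyN(x: float, coeffs: List[float]) -> float:
      acc = 0.0
      for c in coeffs:
        acc = acc * x + c
      return acc

    # example: x - x**3/6 + x**5/120, the leading terms of the sin series
    print(polyN(0.5, [1/120, 0.0, -1/6, 0.0, 1.0, 0.0]))  # ~0.4794, close to sin(0.5)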

* [Patch] Increase fp64 precision in ldexp3k when possible, and patch fp16 exp2

* [Patch] added bitcast_forward option

* [Patch] resolved the cyclic graph

* patch: fix the cyclic graph

* set bitcast_forward=True in ilogb2k

* bitcast_forward for multi.py

* E501 (line too long)

* Break into multiple small PRs

* [Patch] the FP16 -> FP64 upcast is no longer required since xlog2 uses a quad-precision polyN

* [Patch] NV still requires FP64 for xlog2

* updated schedule test

* updated the count of kernels

* [Update] Removed all bitwise ops (SHL/SHR) and tweaked the NaN handling of log2; passing all tests except on AMD.

* Bitcast: make them API-compatible

* [update] force the use of bitcast

* updated the count of constant folding

* [Patch] Create the exp2 mask with x <= Inf, which is True as long as x is a real value

* [Update] isnan(x)-free log2 algorithm: passing the PTX tests, METAL with fastmath enabled handles NaN well, and the AMD backend will not crash.
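  The trick behind the two items above, shown with plain floats: for IEEE-754 values, x <= inf is True for every real input (including -inf) and False only when x is NaN, so a "this is a real number" mask needs no isnan() call.

    # Hedged illustration of the comparison-based NaN mask (not the kernel code).
    import math

    def real_mask(x: float) -> bool:
      return x <= math.inf  # False exactly when x is NaN

    print(real_mask(1.5), real_mask(-math.inf), real_mask(math.nan))  # True True False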

* xsin avoids calling payne_hanek_reduction, which is slow to compile; Stable Diffusion compilation now finishes in a realistic time

* some minor simplifications to the Payne-Hanek reduction

* [refactor] refactored some redundant parts of the Payne-Hanek implementation

* [refactor] more readable Payne-Hanek implementation

* [refactor] improved the code consistency of the Payne-Hanek implementation
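  For context on the Payne-Hanek items above, a rough illustration of the idea (an assumed simplification, not the kernel implementation, which works on integer chunks of 2/pi): compute the quotient x / (pi/2) with far more precision than a double carries, so the remainder stays accurate even for huge x, where a naive reduction against a double-precision pi/2 drifts.

    # Sketch: high-precision argument reduction r = x - round(x / (pi/2)) * (pi/2).
    from decimal import Decimal, getcontext
    import math

    getcontext().prec = 60  # plenty of guard digits
    PI_OVER_2 = Decimal("3.14159265358979323846264338327950288419716939937510") / 2

    def reduce_mod_pi_over_2(x: float) -> float:
      q = (Decimal(x) / PI_OVER_2).to_integral_value()  # nearest multiple of pi/2
      return float(Decimal(x) - q * PI_OVER_2)

    x = 1e10
    print(reduce_mod_pi_over_2(x))         # accurate reduction, about -0.5092
    print(math.remainder(x, math.pi / 2))  # naive double-pi/2 version drifts by roughly 4e-7 here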

* [experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)

* Revert "[experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)"

This reverts commit 0eee08b87c.

* use allow_buffer_view

* let's support multilazytensor

* updated the count of kernels

* [test] added JIT tests for the approx ops

* keep the failing constant folding tests, marked as expectedFailure

* made the timeout deadline explicit when testing the approx JIT timeout

* [WIP] Simplified the implementation of xsin; it never times out

* [Refactor] Improved the consistency of the approx sin implementation, passing the timeout tests

* integrated xexp2_base into xexp2

* Set switch_over=39800.0 (the |x| threshold for switching to the Payne-Hanek reduction)

* delete: is_buffer_fastmath_supported

* sin: compute against abs(x) (sin is odd, so the sign is restored afterwards)

* some cleanups

* fix typo

* removed the space between param and dtype

* allow 514 kernels on CI for sd

* [refactor] no need to upcast at ldexp3k

* [refactor] added some comments and references to help understand the code.

* [Fix] 1.0 ULP Sine Approximation for FP16

* [update] assume e != 0

* use pow2if instead of ldexp3k to fuse the payne_hanek reduction into one kernel

* check that the approximated sin/log2/exp are fused into one kernel (see the sketch below)
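  One hedged way to verify that kind of fusion from user code (counter names assumed from tinygrad.helpers, not the repo's actual test): realize the input first, reset the counters, then count the kernels launched by the op.

    from tinygrad import Tensor
    from tinygrad.helpers import GlobalCounters

    x = Tensor.rand(1024).realize()     # realize the input so only sin's kernels are counted
    GlobalCounters.reset()
    x.sin().realize()
    print(GlobalCounters.kernel_count)  # 1 if the approximation fused into a single kernel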

* clean up changes

* test amd exp

* some code cleanup and test sigmoid

* fix: enabled payne_hanek for fp16 to achieve higher accuracy

* fix: payne_hanek always accumulates the value with uint64, and fp16 sin is fused into a single kernel

* [Refactor] Rename: fastmath -> transcendental

* [Refactor] Added TRANSCENDENTAL, moved the gate function to function.py

* updated const folding tests

* TRANSCENDENTAL as a ContextVar, removed the old Cody-Waite reduction test, added assertions, et al.
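  Since TRANSCENDENTAL is a ContextVar it can be flipped from the environment or scoped in Python; a hedged usage sketch (import paths assumed, see the function.py diff below for the 0/1/2 semantics):

    # TRANSCENDENTAL=2 python3 test/test_ops.py   # force the approximations on every backend
    from tinygrad import Tensor
    from tinygrad.helpers import Context

    with Context(TRANSCENDENTAL=2):               # scoped override of the ContextVar
      print(Tensor([0.5, 1.0, 2.0]).sin().numpy())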

* Add: unittest.main()

* Import TRANSCENDENTAL instead of getenv

* Refactor: Added a dtype check when TRANSCENDENTAL=2, more ContextVars

* Patch: xlog2, break expt(2, 32) x 2 -> expt(2, 16) x 4 for fp16 math

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
Author: hikettei
Date: 2024-07-11 08:44:58 +09:00
Committed by: GitHub
Parent commit: 7c0a657f08
Commit: 320e7ed935
6 changed files with 428 additions and 8 deletions


tinygrad/function.py

@@ -1,12 +1,25 @@
 """This is where the forwards and backwards passes live."""
 import math
 from typing import Tuple, Optional
-from tinygrad.helpers import argsort
+from tinygrad.helpers import argsort, TRANSCENDENTAL
 from tinygrad.dtype import dtypes, DType, sum_acc_dtype
 from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
 from tinygrad.tensor import Function
 from tinygrad.lazy import LazyBuffer
 from tinygrad.shape.symbolic import sint
+from tinygrad.transcendental import xsin, xlog2, xexp2, is_dtype_transcendental_supported
+
+transcendental_supported_devices = ["CLANG", "LLVM"]
+def use_transcendental(d:LazyBuffer) -> bool:
+  # TRANSCENDENTAL=0 to always ignore.
+  # TRANSCENDENTAL=1 to run only in CLANG/LLVM (default).
+  # TRANSCENDENTAL=2 to always run it.
+  if TRANSCENDENTAL >= 2:
+    return is_dtype_transcendental_supported(d.dtype)
+  if TRANSCENDENTAL >= 1:
+    return (is_dtype_transcendental_supported(d.dtype) and
+            d.device in transcendental_supported_devices)
+  return False
 
 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()

@@ -39,10 +52,11 @@ class Reciprocal(Function):
 class Sin(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(UnaryOps.SIN)
+    return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
-    return self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG)).e(UnaryOps.SIN).e(BinaryOps.MUL, grad_output)
+    def _xsin(x): return xsin(x) if use_transcendental(x) else x.e(UnaryOps.SIN)
+    return _xsin(self.x.const(math.pi / 2).e(BinaryOps.ADD, self.x.e(UnaryOps.NEG))).e(BinaryOps.MUL, grad_output)
 
 # NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):

@@ -56,13 +70,15 @@ class Relu(Function):
 class Log(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.x = x
-    return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
+    def _xlog2(x): return xlog2(x) if use_transcendental(x) else x.e(UnaryOps.LOG2)
+    return _xlog2(x).e(BinaryOps.MUL, x.const(math.log(2)))
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return grad_output.e(BinaryOps.MUL, self.x.e(UnaryOps.RECIP))
 
 class Exp(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.e(BinaryOps.MUL, x.const(1/math.log(2))).e(UnaryOps.EXP2)
+    def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
+    self.ret = _xexp2(x.e(BinaryOps.MUL, x.const(1/math.log(2))))
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return self.ret.e(BinaryOps.MUL, grad_output)

@@ -80,7 +96,8 @@ class Sqrt(Function):
 # TODO: have the backend automatically find this
 class Sigmoid(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
-    self.ret = x.const(1).e(BinaryOps.ADD, x.e(BinaryOps.MUL, x.const(-1/math.log(2))).e(UnaryOps.EXP2)).e(UnaryOps.RECIP)
+    def _xexp2(x): return xexp2(x) if use_transcendental(x) else x.e(UnaryOps.EXP2)
+    self.ret = x.const(1).e(BinaryOps.ADD, _xexp2(x.e(BinaryOps.MUL, x.const(-1/math.log(2))))).e(UnaryOps.RECIP)
     return self.ret
 
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: