Approximations for SIN/LOG2/EXP2 passing all tests. (#5187)

* [WIP] Added an approximate implementation of Sin (FP32, FP64) passing all tests on the Clang runtime

* Map nan/-inf/inf to 1.0 to avoid doing as_const(math.inf)

* [WIP] Added support for LLVM IR

* Cleaned up the code for mypy and the linter

* [WIP] Updated fp64 support (the bitwise shift causes a compilation error), fixed a linter issue.

* [Add] Added a fast=True mode which disables the Payne-Hanek reduction, which is slow

* [Fix] Failure to compute elements when the shape includes zero

* [WIP] Added BinaryOps.ADD/BinaryOps.OR to assembly

* [WIP] Updated the assembly for PTX

* Enable fast=True when the device is one of PTX, NV, or CUDA, to avoid slow bitwise ops (as the lv3 reduction is not required).

* [WIP] Added an approximation of LOG2/EXP2 (FP32, FP64)

* [Fix] Cyclic dependencies in xlog2

* [Fix] Cyclic dependency in the graphs of exp2 and log2 (passing test_symbolic_ops.py)

* [Fix] Keep using higher precision for exp2, but the cyclic-graph issue remains to be fixed...

* [Refactor] Removed the is_metal option; xsin does not rely on fp64 in fp32 mode.

* [WIP] fp16 xsin implementation passing all tests. (still needs to be refactored)

* [WIP] Added fp16 exp2 implementation

* [WIP] Increased the precision of Log2 from 3.5 ULP to 1.0 ULP, and added FP16 Log2 approximation.

* stashed the changes for FP16 sin

* [Fix] Patch for FP16 Sin/Exp2. (updated the dtype_via, fp32_p, and lower)

* [Refactor] Migration to fastmath.py, some code simplification, renamed APIs in fastmath, et al.

* [Refactor] Added the function polyN to clean up N-term polynomial approximations (see the polyN sketch after this list).

* [Patch] Increase fp64 precision at ldexp3k where possible, and patch fp16 exp2 (the ldexp3k/ilogb2k bit trick is sketched after this list)

* [Patch] Added a bitcast_forward option

* [Patch] Resolved a cyclic graph

* Patch to fix a cyclic graph

* set bitcast_forward=True in ilogb2k

* bitcast_forward for multi.py

* E501

* Break into multiple small PRs

* [Patch] The FP16 -> FP64 upcast is no longer required since xlog2 uses the quad-precision polyN

* [Patch] NV still requires FP64 for xlog2

* Updated the schedule test

* Updated the kernel count

* [Update] Removed all bitwise ops (SHL/SHR), tweaked the NaN handling of log2; passing all tests except for AMD.

* Bitcast: make them API-compatible

* [Update] Force the use of bitcast

* Updated the constant-folding count

* [Patch] Create a mask for exp2 using x <= Inf, which is True as long as x is a real value (see the note after this list)

* [Update] isnan(x)-free log2 algorithm; passing PTX tests, METAL with fastmath enabled handles NaN well, and the AMD backend will not crash.

* xsin avoids calling payne_hanek_reduction, which is slow to compile; Stable Diffusion compilation now passes in a realistic time

* some minor simplification to payne hanek reduction

* [refactor] refactored some redundant parts of the payne hanek implementation

* [refactor] more readable payne hanek impl

* [refactor] improved the code consistency of payne hanek

* [experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)

* Revert "[experiment] topological sort when doing _recursive_group (i dunno if this is good but at least it works.)"

This reverts commit 0eee08b87c.

* use allow_buffer_view

* let's support multilazytensor

* updated the count of kernels

* [test] added the jit tests for approx ops

* Keep the failing constant-folding tests in the suite, marked with expectedFailure

* Make the timeout deadline explicit when testing the approx JIT timeout

* [WIP] Simplified the implementation of xsin; it never times out

* [Refactor] Improved the consistency of the approx sin implementation, passing the timeout tests

* integrated xexp2_base into xexp2

* Set switch_over=39800.0

* delete: is_buffer_fastmath_supported

* sin: compute against abs(x)

* some cleanups

* fix typo

* removed the space between param and dtype

* Allow 514 kernels on CI for Stable Diffusion

* [refactor] no need to upcast at ldexp3k

* [refactor] added some comments and references to help with understanding the code.

* [Fix] 1.0 ULP Sine Approximation for FP16

* [update] assume e != 0

* Use pow2if instead of ldexp3k to fuse the payne_hanek reduction into one kernel

* Check that approximated sin/log2/exp are fused into one kernel

* clean up changes

* test amd exp

* some code cleanup and test sigmoid

* Fix: enabled payne_hanek for fp16 to achieve higher accuracy

* Fix: payne_hanek always accumulates the value with uint64, and fp16 sin is fused into a single kernel

* [Refactor] Rename: fastmath -> transcendental

* [Refactor] Added TRANSCENDENTAL, Moved the gate function to function.py

* updated const folding tests

* TRANSCENDENTAL as a ContextVar, removed the old Cody-Waite reduction test, added assertions, et al.

* Add: unittest.main()

* Import TRANSCENDENTAL instead of getenv

* Refactor: Added a dtype check when TRANSCENDENTAL=2, more ContextVar usage (see the TRANSCENDENTAL usage note after this list)

* Patch: xlog2, break expt(2, 32) x 2 -> expt(2, 16) x 4 for fp16 math
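
Note (re: polyN): a minimal sketch of an N-term polynomial evaluator using Horner's scheme. This is plain Python for illustration only; the actual polyN added in this PR operates on tinygrad ops and its exact signature may differ.

# minimal sketch, not the PR's implementation: evaluates c[0]*x**(n-1) + ... + c[-2]*x + c[-1]
# with coefficients given from the highest order down (Horner's scheme)
def polyN(x, coeffs):
  acc = 0.0
  for c in coeffs: acc = acc * x + c
  return acc

# example: 3*x**2 + 2*x + 1 evaluated at x = 0.5 gives 2.75
assert polyN(0.5, [3.0, 2.0, 1.0]) == 2.75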
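
Note (re: ilogb2k/ldexp3k/bitcast_forward): a rough plain-Python illustration of the FP32 bit-pattern tricks these helpers are built on (exponent extraction and scaling by a power of two via bitcast). The _sketch helper names, the struct-based bitcasting, and the normal-input assumption are illustration choices, not the PR's actual kernels.

import struct

def ilogb2k_sketch(x: float) -> int:
  # unbiased exponent of a normal, finite FP32 value, read straight from the bit pattern
  bits = struct.unpack('<I', struct.pack('<f', x))[0]
  return ((bits >> 23) & 0xff) - 127

def ldexp3k_sketch(x: float, e: int) -> float:
  # x * 2**e by adding e to the exponent field (no overflow/denormal handling here)
  bits = struct.unpack('<I', struct.pack('<f', x))[0]
  return struct.unpack('<f', struct.pack('<I', (bits + (e << 23)) & 0xffffffff))[0]

assert ilogb2k_sketch(8.0) == 3
assert ldexp3k_sketch(1.5, 4) == 24.0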
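
Note (re: the x <= Inf mask for exp2): this relies on IEEE-754 comparison semantics; any comparison involving NaN is false, so x <= inf is true exactly for non-NaN inputs (including +/-inf) and the comparison doubles as an is-not-NaN mask. A quick plain-Python check:

import math
for x in [0.0, -1.5, 1e30, math.inf, -math.inf]:
  assert x <= math.inf             # true for every non-NaN value
assert not (math.nan <= math.inf)  # false only for NaN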
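
Note (re: the TRANSCENDENTAL ContextVar): judging from the tests in this PR, TRANSCENDENTAL=2 forces the approximate sin/log2/exp2 paths. A minimal usage sketch; the exact semantics of the other levels are not spelled out here and are an assumption.

from tinygrad.tensor import Tensor
from tinygrad.helpers import Context

# force the transcendental approximations instead of the backend builtins
with Context(TRANSCENDENTAL=2):
  print(Tensor([0.5, 1.0, 2.0]).sin().numpy())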

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Co-authored-by: chenyu <chenyu@fastmail.com>
Author: hikettei
Date: 2024-07-11 08:44:58 +09:00
Committed by: GitHub
Parent: 7c0a657f08
Commit: 320e7ed935
6 changed files with 428 additions and 8 deletions


@@ -13,8 +13,12 @@ def _check_ast_count(desired_count:int, t:Tensor):
  assert len(asts) == desired_count

class TestUnaryOpsConstFolding(unittest.TestCase):
  def test_all_consts_ops(self):
  @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CLANG"], f"no support on {Device.DEFAULT}")
  @unittest.expectedFailure
  def test_all_const_ops_todo(self):
    _check_ast_count(0, Tensor.ones(4).exp())
  def test_all_consts_ops(self):
    _check_ast_count(0, Tensor.ones(4).sqrt())
    _check_ast_count(0, Tensor.ones(4) + Tensor.ones(4))
    _check_ast_count(0, Tensor.ones(4) / Tensor.ones(4))
@@ -87,8 +91,12 @@ class TestBinaryOpsConstFolding(unittest.TestCase):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** 1)
  def test_pow_tensor_one(self):
    _check_ast_count(0, Tensor([1.0, 2, 3, 4]) ** Tensor.ones(4))
  @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CLANG"], f"no support on {Device.DEFAULT}")
  @unittest.expectedFailure
  def test_literal_one_pow(self):
    _check_ast_count(0, 1 ** Tensor([1.0, 2, 3, 4]))
  @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CLANG"], f"no support on {Device.DEFAULT}")
  @unittest.expectedFailure
  def test_tensor_one_pow(self):
    _check_ast_count(0, Tensor.ones(4) ** Tensor([1.0, 2, 3, 4]))


@@ -2,12 +2,14 @@
import unittest, functools
import numpy as np
from hypothesis import given, settings, strategies as strat
from test.helpers import assert_jit_cache_len
from tinygrad.tensor import Tensor
from tinygrad.engine.jit import TinyJit
from tinygrad.device import Device
from tinygrad.helpers import CI, Context
from tinygrad.dtype import dtypes
from extra.models.unet import ResBlock

def _simple_test(add, extract=lambda x: x, N=10):
  for _ in range(5):
@@ -18,6 +20,19 @@ def _simple_test(add, extract=lambda x: x, N=10):
  assert_jit_cache_len(add, 1)

class TestJit(unittest.TestCase):
  @settings(deadline=2e4)
  @unittest.skipUnless(Device.DEFAULT in ["LLVM", "CLANG"], f"no support on {Device.DEFAULT}")
  @given(strat.sampled_from([Tensor.exp2, Tensor.log2, Tensor.sin]))
  def test_approx_jit_timeout(self, op):
    with Context(TRANSCENDENTAL=2):
      model = [ResBlock(16, 24, 16) for _ in range(4)]
      @TinyJit
      def fw_approx(t, t2):
        for l in model: t = l(t, t2)
        return op(t).realize()
      fw_approx(Tensor.empty(4, 16, 8, 8), Tensor.empty(1, 24))

  def test_simple_jit(self):
    @TinyJit
    def add(a, b): return (a+b).realize()


@@ -0,0 +1,33 @@
import unittest
from tinygrad.tensor import Tensor
from tinygrad.helpers import Context
from test.test_schedule import check_schedule

class TestTranscendentalSchedule(unittest.TestCase):
  # w/ payne_hanek_reduction (fp32)
  def test_transcendental_sin_fusion(self):
    with Context(TRANSCENDENTAL=2):
      a = Tensor.empty(10)
      b = Tensor.empty(10)
      c = a.sin() + b.sin()
      c = c.sin()
      check_schedule(c, 1)

  def test_transcendental_log2_fusion(self):
    with Context(TRANSCENDENTAL=2):
      a = Tensor.empty(10)
      b = Tensor.empty(10)
      c = a.log2() + b.log2()
      c = c.log2()
      check_schedule(c, 1)

  def test_transcendental_exp2_fusion(self):
    with Context(TRANSCENDENTAL=2):
      a = Tensor.empty(10)
      b = Tensor.empty(10)
      c = a.exp2() + b.exp2()
      c = c.exp2()
      check_schedule(c, 1)

if __name__ == '__main__':
  unittest.main()