mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
[Bounty] Vectorize Transcendental (#9058)
* init * cast everythig right * more casting * install pillow in test * quick tests * simplify * quick tests * delete test * tests * fix import error * add vec to ldexp3k * vec for bitcast * some helper tests * high level tests * clean tests * change tolerance so cuda passes * ruff passes * remove tests for transcendental helpers * ruff passes * make exponent in power vectorized * fix pow test * add newline * add vec dtype to ilogb2k * comment + clean up * ruff --------- Co-authored-by: chenyu <chenyu@fastmail.com> Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -447,7 +447,7 @@ jobs:
|
||||
with:
|
||||
key: dsp-minimal
|
||||
deps: testing_minimal
|
||||
pydeps: "onnx==1.16.0 onnxruntime"
|
||||
pydeps: "onnx==1.16.0 onnxruntime pillow"
|
||||
llvm: "true"
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
@@ -466,6 +466,8 @@ jobs:
|
||||
run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
|
||||
- name: Test LLVM=1 DEVECTORIZE=0
|
||||
run: LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
|
||||
- name: Test LLVM=1 DEVECTORIZE=0 for model
|
||||
run: PYTHONPATH="." LLVM=1 DEVECTORIZE=0 python3 test/models/test_efficientnet.py
|
||||
- name: Test CPU=1 DEVECTORIZE=0
|
||||
run: CPU=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
|
||||
|
||||
|
||||
@@ -128,5 +128,33 @@ class TestTranscendentalSchedule(unittest.TestCase):
|
||||
c = c.exp2()
|
||||
check_schedule(c, 1)
|
||||
|
||||
class TestTranscendentalVectorized(unittest.TestCase):
|
||||
def _vectorized_data(self, low, high, vec_size):
|
||||
np_data = np.linspace(low, high, num=(128 // vec_size) * vec_size, dtype=np.float32).reshape(-1, vec_size)
|
||||
data = Tensor(np_data, dtype=dtypes.float32.vec(vec_size))
|
||||
return data, np_data
|
||||
|
||||
def _test_vectorized_op(self, fxn, np_fxn, data_range, vec_size, param_range=None):
|
||||
data, np_data = self._vectorized_data(data_range[0], data_range[1], vec_size)
|
||||
if param_range:
|
||||
param, np_param = self._vectorized_data(param_range[0], param_range[1], vec_size)
|
||||
out, np_out = fxn(data, param), np_fxn(np_data, np_param)
|
||||
else:
|
||||
out, np_out = fxn(data), np_fxn(np_data)
|
||||
np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-4)
|
||||
|
||||
def test_exp2_vectorized(self):
|
||||
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.exp2, np.exp2, (-100, 100), vec_size)
|
||||
|
||||
def test_log2_vectorized(self):
|
||||
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.log2, np.log2, (0.001, 200), vec_size)
|
||||
|
||||
def test_sin_vectorized(self):
|
||||
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.sin, np.sin, (-100, 100), vec_size)
|
||||
|
||||
def test_pow_vectorized(self):
|
||||
# np.pow returns nan for negative values raised to a non-integral power
|
||||
for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.pow, np.pow, (0.001, 200), vec_size, param_range=(-10, 10))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -183,7 +183,7 @@ devectorize = PatternMatcher([
|
||||
|
||||
devectorize_load_store = PatternMatcher([
|
||||
# TODO: add vectorized support to transcendental
|
||||
(UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
|
||||
(UPat((Ops.INDEX), name="alu"), no_vectorized_alu),
|
||||
(UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
|
||||
])
|
||||
|
||||
|
||||
@@ -10,9 +10,9 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
|
||||
return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
|
||||
|
||||
# *** helper functions for bit manipulation ***
|
||||
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
|
||||
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
|
||||
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d]
|
||||
def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
|
||||
def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
|
||||
def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
|
||||
|
||||
# **** utils ****
|
||||
def shr(x:UOp, y:int) -> UOp: return x // (2**y)
|
||||
@@ -20,41 +20,41 @@ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
|
||||
|
||||
def rintk(d:UOp) -> UOp:
|
||||
"""round d:float to int away from 0"""
|
||||
out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
|
||||
out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount)
|
||||
return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
|
||||
|
||||
def pow2if(q:UOp, float_dtype:DType):
|
||||
"""cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
|
||||
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
|
||||
out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype.scalar()].vec(q.dtype.vcount)
|
||||
return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
|
||||
|
||||
def ilogb2k(d:UOp) -> UOp:
|
||||
"""calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount))
|
||||
# -1 <= ilog2bk(d) <= 128
|
||||
return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)
|
||||
|
||||
def ldexp3k(d:UOp, e:UOp) -> UOp:
|
||||
"""d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
|
||||
m1 = d.bitcast(cast_map[d.dtype])
|
||||
m2 = shl(e.cast(cast_map[d.dtype]), mantissa_bits(d.dtype))
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.count)
|
||||
m1 = d.bitcast(dtype)
|
||||
m2 = shl(e.cast(dtype), mantissa_bits(d.dtype))
|
||||
return (m1 + m2).bitcast(d.dtype).cast(d.dtype)
|
||||
|
||||
def ldexp2k(d:UOp, e:UOp) -> UOp:
|
||||
"""d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64)
|
||||
return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
|
||||
|
||||
def frexp(v:UOp) -> tuple[UOp, UOp]:
|
||||
"""frexp(v) -> (mantissa, exponent) assuming v != 0"""
|
||||
assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
assert v.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
# m1 = masks for mantissa, m2 = masks to normalize the mantissa.
|
||||
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
|
||||
m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype]
|
||||
bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype])
|
||||
m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()]
|
||||
m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()]
|
||||
bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count))
|
||||
exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
|
||||
# Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
|
||||
mantissa = ((bits & m1) | m2).bitcast(v.dtype)
|
||||
@@ -70,7 +70,7 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
|
||||
- `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`.
|
||||
- `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`.
|
||||
"""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
# https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
|
||||
# 190 bits of 2/pi for Payne-Hanek style argument reduction
|
||||
two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
|
||||
@@ -172,7 +172,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
|
||||
- fast=True assumes x <= switch_over.
|
||||
- switch_over is the threshold for switching to payne_hanek_reduction.
|
||||
"""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
# mask +-inf/nan as zero
|
||||
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
|
||||
# x_sign = sign(x)
|
||||
@@ -194,7 +194,7 @@ def xexp2(d:UOp) -> UOp:
|
||||
Implements a 1.0 ULP approximation for Ops.EXP2
|
||||
- Paper: https://arxiv.org/pdf/2001.09258
|
||||
"""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
# mask +=inf/nan as zero.
|
||||
x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
|
||||
q = rintk(x)
|
||||
@@ -207,7 +207,7 @@ def xexp2(d:UOp) -> UOp:
|
||||
0.6931471805599452862e+0, 0.1000000000000000000e+1])
|
||||
else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
|
||||
u = ldexp2k(u, q) # u*2^q
|
||||
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
|
||||
upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype.scalar()]
|
||||
# Replace x >= upper with +inf
|
||||
u = (d >= upper).where(d.const_like(math.inf), u)
|
||||
# Replace x < lower with zero.
|
||||
@@ -220,7 +220,7 @@ def xlog2(d:UOp) -> UOp:
|
||||
Implements a 1.0 ULP approximation for Ops.LOG2
|
||||
Paper: https://arxiv.org/pdf/2001.09258 5.5
|
||||
"""
|
||||
assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
# TODO: float16 denormal need float32 to achieve precision
|
||||
if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
|
||||
FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
|
||||
@@ -248,7 +248,7 @@ def xlog2(d:UOp) -> UOp:
|
||||
r = (d<-0.0).where(r.const_like(math.nan), r)
|
||||
# log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
|
||||
# log2_zero = the value of unmasked xlog2(0.0).
|
||||
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
|
||||
log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
|
||||
r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
|
||||
# log2(NaN) = NaN
|
||||
r = d.ne(d).where(r.const_like(math.nan), r)
|
||||
|
||||
Reference in New Issue
Block a user