[Bounty] Vectorize Transcendental (#9058)

* init

* cast everything right

* more casting

* install pillow in test

* quick tests

* simplify

* quick tests

* delete test

* tests

* fix import error

* add vec to ldexp3k

* vec for bitcast

* some helper tests

* high level tests

* clean tests

* change tolerance so cuda passes

* ruff passes

* remove tests for transcendental helpers

* ruff passes

* make exponent in power vectorized

* fix pow test

* add newline

* add vec dtype to ilogb2k

* comment + clean up

* ruff

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Eitan Turok
Date: 2025-02-28 09:47:25 +02:00
Committed by: GitHub
Parent: 8ae215dd3d
Commit: d657d5f754

4 changed files with 54 additions and 24 deletions


@@ -447,7 +447,7 @@ jobs:
       with:
         key: dsp-minimal
         deps: testing_minimal
-        pydeps: "onnx==1.16.0 onnxruntime"
+        pydeps: "onnx==1.16.0 onnxruntime pillow"
         llvm: "true"
     - name: Set up Docker Buildx
       uses: docker/setup-buildx-action@v3
@@ -466,6 +466,8 @@ jobs:
       run: PYTHONPATH="." DEBUG=2 DSP=1 python3 test/test_quantize_onnx.py
     - name: Test LLVM=1 DEVECTORIZE=0
       run: LLVM=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
+    - name: Test LLVM=1 DEVECTORIZE=0 for model
+      run: PYTHONPATH="." LLVM=1 DEVECTORIZE=0 python3 test/models/test_efficientnet.py
     - name: Test CPU=1 DEVECTORIZE=0
       run: CPU=1 DEVECTORIZE=0 python3 -m pytest -n auto test/test_tiny.py test/test_ops.py -k "not test_avg_pool3d_failure"
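The pillow dependency is presumably needed because the EfficientNet model test preprocesses its input image with PIL, and the new step can be reproduced locally with the same command the workflow runs:

    PYTHONPATH="." LLVM=1 DEVECTORIZE=0 python3 test/models/test_efficientnet.py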


@@ -128,5 +128,33 @@ class TestTranscendentalSchedule(unittest.TestCase):
     c = c.exp2()
     check_schedule(c, 1)
 
+class TestTranscendentalVectorized(unittest.TestCase):
+  def _vectorized_data(self, low, high, vec_size):
+    np_data = np.linspace(low, high, num=(128 // vec_size) * vec_size, dtype=np.float32).reshape(-1, vec_size)
+    data = Tensor(np_data, dtype=dtypes.float32.vec(vec_size))
+    return data, np_data
+
+  def _test_vectorized_op(self, fxn, np_fxn, data_range, vec_size, param_range=None):
+    data, np_data = self._vectorized_data(data_range[0], data_range[1], vec_size)
+    if param_range:
+      param, np_param = self._vectorized_data(param_range[0], param_range[1], vec_size)
+      out, np_out = fxn(data, param), np_fxn(np_data, np_param)
+    else:
+      out, np_out = fxn(data), np_fxn(np_data)
+    np.testing.assert_allclose(out.numpy(), np_out, rtol=1e-4)
+
+  def test_exp2_vectorized(self):
+    for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.exp2, np.exp2, (-100, 100), vec_size)
+
+  def test_log2_vectorized(self):
+    for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.log2, np.log2, (0.001, 200), vec_size)
+
+  def test_sin_vectorized(self):
+    for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.sin, np.sin, (-100, 100), vec_size)
+
+  def test_pow_vectorized(self):
+    # np.pow returns nan for negative values raised to a non-integral power
+    for vec_size in [1,2,3,4,5,127,128]: self._test_vectorized_op(Tensor.pow, np.pow, (0.001, 200), vec_size, param_range=(-10, 10))
+
 if __name__ == '__main__':
   unittest.main()
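The new class can be selected on its own with pytest's keyword filter (the test file path is not shown in this hunk, so this assumes the repo root as working directory); DEVECTORIZE=0, as in the CI steps above, keeps the vector width intact so the vectorized transcendental path is actually exercised:

    DEVECTORIZE=0 python3 -m pytest -k TestTranscendentalVectorized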


@@ -183,7 +183,7 @@ devectorize = PatternMatcher([
 devectorize_load_store = PatternMatcher([
-  # TODO: add vectorized support to transcendental
-  (UPat((Ops.INDEX, Ops.EXP2, Ops.LOG2, Ops.SIN), name="alu"), no_vectorized_alu),
+  (UPat((Ops.INDEX), name="alu"), no_vectorized_alu),
   (UPat((Ops.LOAD, Ops.STORE), name="ls"), no_vectorized_load_store),
 ])
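Dropping Ops.EXP2, Ops.LOG2 and Ops.SIN from this rule is what lets the transcendental lowering stay vectorized instead of being split back into scalar ops. If a backend ever needed the old scalar behaviour for a single op, the same rule shape could be restored per op; a hypothetical example, not part of this change:

    # scalarize only SIN again (hypothetical)
    (UPat(Ops.SIN, name="alu"), no_vectorized_alu),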


@@ -10,9 +10,9 @@ def _lazy_map_numbers(x:UOp, inf:UOp, _inf:UOp, nan:UOp, ratio:UOp):
   return x.ne(math.inf).where(x.ne(x).where(nan, x.ne(-math.inf).where(ratio, _inf)), inf)
 
 # *** helper functions for bit manipulation ***
-def mantissa_bits(d:DType) -> int: return dtypes.finfo(d)[1]
-def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d]
-def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d]
+def mantissa_bits(d:DType) -> int: return dtypes.finfo(d.scalar())[1]
+def exponent_bias(d:DType) -> int: return {dtypes.float64: 1023, dtypes.float32: 127, dtypes.float16: 15}[d.scalar()]
+def exponent_mask(d:DType) -> int: return {dtypes.float64: 2047, dtypes.float32: 255, dtypes.float16: 31}[d.scalar()]
 
 # **** utils ****
 def shr(x:UOp, y:int) -> UOp: return x // (2**y)
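For float32 these helpers resolve to the standard IEEE 754 constants: 23 mantissa bits, exponent bias 127, exponent mask 255. A small standalone check of the same bit layout in plain Python (no tinygrad required):

    import struct, math
    bits = struct.unpack("<I", struct.pack("<f", 6.0))[0]  # 0x40C00000
    exp_field = (bits >> 23) & 255                         # shr by mantissa_bits, mask with exponent_mask
    assert exp_field - 127 == math.ilogb(6.0) == 2         # subtract exponent_bias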
@@ -20,41 +20,41 @@ def shl(x:UOp, y:int) -> UOp: return x * (2**y)
 def rintk(d:UOp) -> UOp:
   """round d:float to int away from 0"""
-  out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype]
+  out_dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount)
   return (d + (d<0.0).where(d.const_like(-0.5), d.const_like(0.5))).cast(out_dtype)
 
 def pow2if(q:UOp, float_dtype:DType):
   """cast(2^q, float_dtype) where q is any integer in the range of [-126, 127]"""
-  out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype]
+  out_dtype = {dtypes.int64: dtypes.float64, dtypes.int32: dtypes.float32, dtypes.int16: float_dtype}[q.dtype.scalar()].vec(q.dtype.vcount)
   return shl(q + exponent_bias(out_dtype), mantissa_bits(out_dtype)).bitcast(out_dtype)
 
 def ilogb2k(d:UOp) -> UOp:
   """calculate the integer part of log2(d), where d is normalized fp value in the range of [0, +inf)."""
-  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
-  dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype])
+  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
+  dint = d.bitcast({dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.vcount))
   # -1 <= ilog2bk(d) <= 128
   return (shr(dint, mantissa_bits(d.dtype)) & exponent_mask(d.dtype)) - exponent_bias(d.dtype)
 
 def ldexp3k(d:UOp, e:UOp) -> UOp:
   """d*2^e. e is a number obtained by casting an integer in the range [-127, 127] to a float. d is any float number."""
-  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
-  cast_map = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}
-  m1 = d.bitcast(cast_map[d.dtype])
-  m2 = shl(e.cast(cast_map[d.dtype]), mantissa_bits(d.dtype))
+  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
+  dtype = {dtypes.float64: dtypes.int64, dtypes.float32: dtypes.int32, dtypes.float16: dtypes.int16}[d.dtype.scalar()].vec(d.dtype.count)
+  m1 = d.bitcast(dtype)
+  m2 = shl(e.cast(dtype), mantissa_bits(d.dtype))
   return (m1 + m2).bitcast(d.dtype).cast(d.dtype)
 
 def ldexp2k(d:UOp, e:UOp) -> UOp:
   """d*2^e. much faster than ldexp3k but risky. d > 0 and d is not denormal."""
-  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype in (dtypes.int16, dtypes.int32, dtypes.int64)
+  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES and e.dtype.scalar() in (dtypes.int16, dtypes.int32, dtypes.int64)
   return (d * pow2if(shr(e, 1), d.dtype)) * pow2if(e - shr(e, 1), d.dtype)
 
 def frexp(v:UOp) -> tuple[UOp, UOp]:
   """frexp(v) -> (mantissa, exponent) assuming v != 0"""
-  assert v.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  assert v.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
   # m1 = masks for mantissa, m2 = masks to normalize the mantissa.
-  m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype]
-  m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype]
-  bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype])
+  m1 = {dtypes.float64: 0x000FFFFFFFFFFFFF, dtypes.float32: 0x807FFFFF, dtypes.float16: 0x83FF}[v.dtype.scalar()]
+  m2 = {dtypes.float64: 0x3FE0000000000000, dtypes.float32: 0x3F000000, dtypes.float16: 0x3800}[v.dtype.scalar()]
+  bits = v.bitcast({dtypes.float64: dtypes.uint64, dtypes.float32: dtypes.uint32, dtypes.float16: dtypes.uint16}[v.dtype.scalar()].vec(v.dtype.count))
   exponent = shr(bits, mantissa_bits(v.dtype)) & exponent_mask(v.dtype)
   # Set the exponent bits appropriately to normalize the mantissa into the range of [0.5, 1.0).
   mantissa = ((bits & m1) | m2).bitcast(v.dtype)
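The float32 masks can be sanity-checked outside tinygrad: m1 = 0x807FFFFF keeps the sign and mantissa bits, m2 = 0x3F000000 forces the exponent field to 126, which lands the mantissa in [0.5, 1.0) exactly as math.frexp does, and pow2if builds 2^q by writing q + bias straight into the exponent field. A standalone illustration of both:

    import struct, math
    def f2u(x): return struct.unpack("<I", struct.pack("<f", x))[0]
    def u2f(x): return struct.unpack("<f", struct.pack("<I", x))[0]

    bits = f2u(6.0)
    mantissa = u2f((bits & 0x807FFFFF) | 0x3F000000)  # -> 0.75
    exponent = ((bits >> 23) & 255) - 126             # -> 3
    assert (mantissa, exponent) == math.frexp(6.0)

    # pow2if-style construction: (q + exponent_bias) << mantissa_bits, reinterpreted as float
    assert u2f((3 + 127) << 23) == 8.0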
@@ -70,7 +70,7 @@ def payne_hanek_reduction(d:UOp) -> tuple[UOp, UOp]:
   - `r`[d.dtype] is the reminder value corresponding to `round_to_nearest(x % pi/2)`.
   - `q`[int32] is an integer, and q % 4 is corresponding to the quadrant of the original angle `d`.
   """
-  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
   # https://stackoverflow.com/questions/30463616/payne-hanek-algorithm-implementation-in-c/30465751#30465751
   # 190 bits of 2/pi for Payne-Hanek style argument reduction
   two_over_pi_f = [0x00000000, 0x28be60db, 0x9391054a, 0x7f09d5f4, 0x7d4d3770, 0x36d8a566, 0x4f10e410]
@@ -172,7 +172,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp:
   - fast=True assumes x <= switch_over.
   - switch_over is the threshold for switching to payne_hanek_reduction.
   """
-  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
   # mask +-inf/nan as zero
   x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
   # x_sign = sign(x)
@@ -194,7 +194,7 @@ def xexp2(d:UOp) -> UOp:
   Implements a 1.0 ULP approximation for Ops.EXP2
   - Paper: https://arxiv.org/pdf/2001.09258
   """
-  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
   # mask +=inf/nan as zero.
   x = _lazy_map_numbers(d, d.const_like(0.0), d.const_like(0.0), d.const_like(0.0), d)
   q = rintk(x)
@@ -207,7 +207,7 @@ def xexp2(d:UOp) -> UOp:
                 0.6931471805599452862e+0, 0.1000000000000000000e+1])
   else: u = polyN(s, [0.1535920892e-3, 0.1339262701e-2, 0.9618384764e-2, 0.5550347269e-1, 0.2402264476e+0, 0.6931471825e+0, 1.0])
   u = ldexp2k(u, q) # u*2^q
-  upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype]
+  upper, lower = {dtypes.float64: (1024, -2000), dtypes.float32: (128, -150), dtypes.float16: (23, -22)}[d.dtype.scalar()]
   # Replace x >= upper with +inf
   u = (d >= upper).where(d.const_like(math.inf), u)
   # Replace x < lower with zero.
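The float32 bounds (128, -150) track where exp2 saturates: 2^128 is just past the largest finite float32, and anything below 2^-150 is too small even for a subnormal. A quick numpy check of the boundary behaviour the clamps reproduce:

    import numpy as np
    with np.errstate(over="ignore", under="ignore"):
        assert np.isinf(np.exp2(np.float32(128)))   # x >= 128 -> +inf
        assert np.exp2(np.float32(-151)) == 0.0     # x < -150 -> 0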
@@ -220,7 +220,7 @@ def xlog2(d:UOp) -> UOp:
   Implements a 1.0 ULP approximation for Ops.LOG2
   Paper: https://arxiv.org/pdf/2001.09258 5.5
   """
-  assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES
+  assert d.dtype.scalar() in TRANSCENDENTAL_SUPPORTED_DTYPES
   # TODO: float16 denormal need float32 to achieve precision
   if d.dtype == dtypes.float16: return xlog2(d.cast(dtypes.float32)).cast(dtypes.float16)
   FLT_MIN = d.const_like(1e-6 if d.dtype == dtypes.float16 else 1e-4)
@@ -248,7 +248,7 @@ def xlog2(d:UOp) -> UOp:
   r = (d<-0.0).where(r.const_like(math.nan), r)
   # log2(0) = -Inf, but we will compare using the value of y because 1e-200==0 is true.
   # log2_zero = the value of unmasked xlog2(0.0).
-  log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype]
+  log2_zero = {dtypes.float64: -1087, dtypes.float32: -191, dtypes.float16: -79}[d.dtype.scalar()]
   r = r.ne(log2_zero).where(r, r.const_like(-math.inf))
   # log2(NaN) = NaN
   r = d.ne(d).where(r.const_like(math.nan), r)
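The special cases handled here mirror IEEE log2 semantics, which numpy also follows; a quick check of the three cases called out in the comments above:

    import numpy as np
    with np.errstate(divide="ignore", invalid="ignore"):
        assert np.isneginf(np.log2(np.float32(0.0)))    # log2(0) = -inf
        assert np.isnan(np.log2(np.float32(-1.0)))      # log2 of a negative value = nan
        assert np.isnan(np.log2(np.float32(np.nan)))    # log2(nan) = nan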