diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 97440fa738..240c0233a7 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -387,7 +387,7 @@ jobs: WEBGPU=1 WGPU_BACKEND_TYPE=Vulkan python3 -m pytest -n=auto test/test_assign.py test/test_arange.py test/test_const_folding.py test/test_dtype.py \ test/test_dtype_alu.py test/test_conv.py test/test_conv_shapetracker.py test/test_nn.py test/test_ops.py test/test_optim.py \ test/test_jit.py test/test_randomness.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_uops_stats.py test/test_uops.py \ - test/testextra/test_export_model.py --durations=20 + test/testextra/test_export_model.py test/testextra/test_f16_decompress.py --durations=20 - name: Run process replay tests run: | export PR_TITLE=$(jq -r .pull_request.title "$GITHUB_EVENT_PATH") diff --git a/extra/f16_decompress.py b/extra/f16_decompress.py new file mode 100644 index 0000000000..26bad9b844 --- /dev/null +++ b/extra/f16_decompress.py @@ -0,0 +1,16 @@ +from tinygrad import Tensor + +def bit_extract(x: Tensor, e: int, s: int) -> Tensor: + mask = (1 << (e - s + 1)) - 1 + return (x >> s) & mask + +def u16_to_f16(x: Tensor) -> Tensor: + sign = bit_extract(x, 15, 15).float() + exponent = bit_extract(x, 14, 10).float() + fraction = bit_extract(x, 9, 0).float() + return sign.where(-1, 1) * exponent.where((exponent - 15.0).exp2() * (1 + fraction / 1024.0), 6.103515625e-5 * (fraction / 1024.0)) + +def u32_to_f16(oo: Tensor) -> Tensor: + f1 = u16_to_f16(oo>>16) + f2 = u16_to_f16(oo&0xFFFF) + return Tensor.cat(f2.reshape(-1, 1), f1.reshape(-1, 1), dim=1).flatten() diff --git a/extra/f16_w_uint32.py b/extra/f16_w_uint32.py deleted file mode 100644 index 82105e9b77..0000000000 --- a/extra/f16_w_uint32.py +++ /dev/null @@ -1,40 +0,0 @@ -import numpy as np -from tinygrad import Device, dtypes, Tensor - -# TODO: will be better when tinygrad does math in the target dtype, can remove the floor and use a mul -def bit_extract(x, s, e) -> Tensor: - # extract the top bits we don't want - top_bits = (x / (1<<(s+1))).floor() * (1<<(s+1)) - x = (x - top_bits) / (1<