From 8ea53951c1bb241326cc286d8faff230687e8dcd Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 15 Mar 2024 15:05:13 -0400 Subject: [PATCH] bfloat16 Tensor.rand (#3764) * Tensor.rand for bfloat16 for numpy based random, generate one for float then cast for bfloat16. close #3653 * remove realize --- .github/workflows/test.yml | 2 +- test/test_randomness.py | 21 +++++++++++++++++++-- tinygrad/tensor.py | 7 ++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 8a13dc7e2f..a432379b0e 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -468,7 +468,7 @@ jobs: run: python -m pytest -n=auto test/ -k 'not (half or test_efficientnet_safetensors)' --ignore=test/external --ignore=test/models --durations=20 - name: Run pytest (hip) if: matrix.backend=='hip' - run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/imported/test_indexing.py test/external/external_test_hip_compile.py --durations=20 + run: python -m pytest -n=auto test/test_ops.py test/test_dtype.py test/test_dtype_alu.py test/test_linearizer.py test/test_randomness.py test/imported/test_indexing.py test/external/external_test_hip_compile.py --durations=20 #testunicorn: # name: ARM64 unicorn Test diff --git a/test/test_randomness.py b/test/test_randomness.py index ce79e85f45..2d04b44606 100644 --- a/test/test_randomness.py +++ b/test/test_randomness.py @@ -1,9 +1,11 @@ import math import unittest +from functools import partial + import numpy as np import torch from tinygrad import nn, dtypes, Tensor -from functools import partial +from test.helpers import is_dtype_supported # https://gist.github.com/devries/11405101 def ksprob(a): @@ -60,13 +62,28 @@ class TestRandomness(unittest.TestCase): def test_rand_half(self): N = 128 - x = Tensor.rand((2, N, N), dtype=dtypes.half).realize().numpy() + x = Tensor.rand((2, N, N), dtype=dtypes.half) + assert x.dtype == dtypes.half + x = x.numpy() ones = np.take(x, np.where(x == 1)) zeros = np.take(x, np.where(x == 0)) self.assertTrue(ones.size == 0) self.assertTrue(zeros.size > 0) equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) + @unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support") + def test_rand_bfloat16(self): + N = 128 + x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16) + assert x.dtype == dtypes.bfloat16 + # TODO: fix this property for bfloat16 random + # x = x.numpy() + # ones = np.take(x, np.where(x == 1)) + # zeros = np.take(x, np.where(x == 0)) + # self.assertTrue(ones.size == 0) + # self.assertTrue(zeros.size > 0) + equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N)) + def test_randn(self): self.assertTrue(normal_test(Tensor.randn)) self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x))) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 162903b667..034368d884 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -226,7 +226,12 @@ class Tensor: def manual_seed(seed=0): Tensor._seed = seed @staticmethod - def rand(*shape, **kwargs): return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs) + def rand(*shape, **kwargs): + if kwargs.get("dtype") == dtypes.bfloat16: + # TODO: remove this once we use threefry for rand. + kwargs.pop("dtype") + return Tensor.rand(*shape, **kwargs, dtype=dtypes.float).cast(dtypes.bfloat16) + return Tensor._loadop(LoadOps.CUSTOM, argfix(*shape), arg=custom_random, **kwargs) # ***** creation helper functions *****