Files
tinygrad/test/test_randomness.py
wozeparrot a0ab755317 threefry again (#3785)
* feat: initial xor

* feat: initial threefly

* feat: remove custom random

* fix: really need to install precommit

* feat: lmao forgot that this is rotate not a shift

* clean: put that there

* feat: numpy xor

* feat: quick test for xor

* feat: llvm xor

* feat: slightly working xor in torch

* feat: rand works in jit

* clean: save a line

* feat: match jax

* feat: maybe test against jax

* feat: requires_grad

* fix: fix test_symbolic_ops

* feat: lower alpha

* feat: just pad

* fix: maybe fix training tests?

* fix: fix some llvm stuff

* feat: cursed realize on the way out

* feat: testing jax

* fix: why is the jax install process not simple

* fix: maybe passing test

* fix: symbolic workarounds

* clean: still need that precommit

* fix: aaaa

* fix: more test fixes

* fix: quick fix for wgsl

* feat: need to set requires_grad on the final tensor

* feat: one more tensor

* feat: don't take forever

* feat: seeing y ci is brok

* feat: can't allocate 64GiB lmao

* fix: fix this

* feat: hope this doesn't break smth before i go to bed

* feat: don't destroy ram

* feat: int

* feat: remove jax

* feat: properish workaround?

* feat: skip slow webgpu tests

* feat: no longer fails

* feat: use dtypes

* feat: real number

* fix: torch

* fix: don't test against reference for torch

* feat: to device

* feat: fix advanced indexing

* feat: correct casting

* feat: even rng_counter

* feat: match master

* feat: this was actually bad

* fix: maybe?

* feat: store

* feat: remove realizes

* feat: somehow this is important

* feat: somehow this is also important

* feat: save a line

* fix: don't need that anymore

* feat: restore this

* fix: linter

* feat: remove realizes

* fix: realized is in base now

* fix: add back cast

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: bump deadline

* fix: :(

* fix: :(

* fix: not being dumb

* feat: try changing less tests

* feat: shouldn't have to change that

* feat: contiguous bumps it by one

* fix: hmm

* fix: numpy memory moment

* fix: cl_khr_fp16

* fix: torch has different tensor count

* fix: missing contiguous

* hmm: hmm

* fix: some fixes

* fix: typing

* feat: dont do that

* feat: typing fixes

* feat: why is this realize required?

* feat: ngl kinda odd typing

* feat: oh

* feat: remove realizes

* feat: why is this realize required?

* fix: hacky patch for cudacpu

* fix: without this realize pytest crashes?????

* fix: shorter line

* fix: cudacpu fixes

* fix: cudacpu fixes

* feat: real buffer

* feat: don't search when searching lmao

* fix: can't use contiguous things

* fix: no more 100GB arrays

* fix: revert

* fix: skip 7 and 10

* feat: working ish beam

* feat: minimize changes

* feat: seed 0 stable diffusion example changed

* fix: different on ci

* fix: no beam

* feat: make threefry optional

* fix: check value

* fix: unused import

* feat: threefry default

* fix: 5d

* feat: allow non upcast div

* fix: 5d better

* fix: 5d better

* fix: save all dtype

* feat: proper error

* feat: lazyop key

* fix: check float

* feat: try removing this realize now

* feat: disable threefry for uops hip tensor cores

* feat: don't need that

* feat: only check upcast

* fix: disable threefry for some metal tests

* feat: disable for metal tensor uops as well

* feat: disable for most uops

* fix: disable threefry for new uops tests

* feat: multitensor

* fix: typing

* feat: threefry default off

* feat: skip threefry half rand

* feat: restore old

* fix: bad git

* clean: ruff

* feat: bfloat16 fix

* fix: :|

* feat: restore old

---------

Co-authored-by: chenyu <chenyu@fastmail.com>
2024-03-18 16:47:07 -04:00

183 lines
8.9 KiB
Python

import math
import unittest
from functools import partial
import numpy as np
import torch
from tinygrad import nn, dtypes, Tensor
from tinygrad.helpers import THREEFRY
from test.helpers import is_dtype_supported
# https://gist.github.com/devries/11405101
def ksprob(a):
fac, total, termbf = 2.0, 0.0, 0.0
a2 = -2.0 * a * a
for j in range(1, 101):
term = fac * math.exp(a2 * j * j)
total += term
if math.fabs(term) <= 0.001 * termbf or math.fabs(term) <= 1e-8 * total:
return total
fac = -fac
termbf = math.fabs(term)
return 1.0
def kstest(l1, l2):
n1, n2 = len(l1), len(l2)
l1.sort()
l2.sort()
j1, j2, d, fn1, fn2 = 0, 0, 0.0, 0.0, 0.0
while j1 < n1 and j2 < n2:
d1, d2 = l1[j1], l2[j2]
if d1 <= d2:
fn1 = (float(j1) + 1.0) / float(n1)
j1 += 1
if d2 <= d1:
fn2 = (float(j2) + 1.0) / float(n2)
j2 += 1
dtemp = math.fabs(fn2 - fn1)
if dtemp > d:
d = dtemp
ne = float(n1 * n2) / float(n1 + n2)
nesq = math.sqrt(ne)
prob = ksprob((nesq + 0.12 + 0.11 / nesq) * d)
return prob
def equal_distribution(tiny_func, torch_func=None, numpy_func=None, shape=(20, 23), alpha=0.04):
Tensor.manual_seed(1337)
torch.manual_seed(1337)
np.random.seed(1337)
assert not (torch_func is None and numpy_func is None), "no function to compare with"
x1 = tiny_func(*shape).numpy().flatten()
x2 = tiny_func(shape).numpy().flatten()
if numpy_func is not None: y = numpy_func(shape).flatten()
if torch_func is not None: z = torch_func(shape).numpy().flatten()
return (numpy_func is None or (kstest(x1, y) >= alpha and kstest(x2, y) >= alpha)) and \
(torch_func is None or (kstest(x1, z) >= alpha and kstest(x2, z) >= alpha))
def normal_test(func, shape=(20, 23), alpha=0.05): return equal_distribution(func, numpy_func=lambda x: np.random.randn(*x), shape=shape, alpha=alpha)
class TestRandomness(unittest.TestCase):
def test_rand(self):
self.assertFalse(normal_test(Tensor.rand))
self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x)))
@unittest.skipIf(THREEFRY.value, "broken with threefry")
def test_rand_half(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.half)
assert x.dtype == dtypes.half
x = x.numpy()
ones = np.take(x, np.where(x == 1))
zeros = np.take(x, np.where(x == 0))
self.assertTrue(ones.size == 0)
self.assertTrue(zeros.size > 0)
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
@unittest.skipIf(not THREEFRY.value, "not using threefry")
def test_threefly_against_reference(self):
Tensor.manual_seed(1337)
# generated using
# (jax.extend.random.threefry_2x32((np.uint32(1337), np.uint32(0x0)), np.arange(20, dtype=np.uint32)) >> 8).astype(float) / np.float32(2**24)
jr = np.array([0.30984968, 0.42723763, 0.92448753, 0.27268296, 0.48820806, 0.29587173, 0.3213513, 0.05805135, 0.4954177, 0.23303074,
0.62478125, 0.51861334, 0.24712527, 0.12718695, 0.5236074, 0.50704265, 0.9166272, 0.6918763, 0.6530086, 0.34640658])
r = Tensor.rand(20).numpy()
np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
@unittest.skipUnless(is_dtype_supported(dtypes.bfloat16), "need bfloat16 support")
def test_rand_bfloat16(self):
N = 128
x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
assert x.dtype == dtypes.bfloat16
# TODO: fix this property for bfloat16 random
# x = x.numpy()
# ones = np.take(x, np.where(x == 1))
# zeros = np.take(x, np.where(x == 0))
# self.assertTrue(ones.size == 0)
# self.assertTrue(zeros.size > 0)
equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
def test_randn(self):
self.assertTrue(normal_test(Tensor.randn))
self.assertTrue(equal_distribution(Tensor.randn, torch.randn, lambda x: np.random.randn(*x)))
def test_randint(self):
self.assertFalse(normal_test(Tensor.randint))
self.assertTrue(equal_distribution(partial(Tensor.randint, low=-2, high=5), numpy_func=lambda x: np.random.randint(low=-2, high=5, size=x)))
self.assertTrue(Tensor.randint(1,device="CLANG").device=="CLANG")
def test_normal(self):
self.assertTrue(normal_test(Tensor.normal))
self.assertTrue(equal_distribution(Tensor.normal, lambda x: torch.nn.init.normal_(torch.empty(x), mean=0, std=1),
lambda x: np.random.normal(loc=0, scale=1, size=x)))
def test_uniform(self):
self.assertFalse(normal_test(Tensor.uniform))
self.assertTrue(equal_distribution(Tensor.uniform, lambda x: torch.nn.init.uniform_(torch.empty(x)), lambda x: np.random.uniform(size=x)))
self.assertTrue(equal_distribution(partial(Tensor.uniform, low=-100, high=100, dtype=dtypes.int32),
numpy_func=lambda x: np.random.randint(low=-100, high=100, size=x)))
def test_scaled_uniform(self):
self.assertFalse(normal_test(Tensor.scaled_uniform))
self.assertTrue(equal_distribution(Tensor.scaled_uniform, lambda x: torch.nn.init.uniform_(torch.empty(x), a=-1, b=1) / math.sqrt(math.prod(x)),
lambda x: np.random.uniform(-1, 1, size=x) / math.sqrt(math.prod(x))))
def test_glorot_uniform(self):
self.assertFalse(normal_test(Tensor.glorot_uniform))
self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)),
lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
def test_kaiming_uniform(self):
for shape in [(128, 64, 3, 3), (20, 24)]:
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
def test_kaiming_normal(self):
for shape in [(128, 64, 3, 3), (20, 24)]:
self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
def test_multinomial(self):
self.assertRaises(AssertionError, lambda: Tensor(2).multinomial(1, replacement=False))
self.assertRaises(AssertionError, lambda: Tensor([1, 9]).multinomial(0, replacement=False))
def _check_with_torch(w, num_samples, replacement):
tiny_res = Tensor(w).multinomial(num_samples, replacement=replacement)
torch_res = torch.tensor(w).multinomial(num_samples, replacement=replacement)
self.assertEqual(tiny_res.shape, torch_res.shape)
if torch_res.ndim == 1:
tiny_res = tiny_res.unsqueeze(0)
torch_res = torch_res.unsqueeze(0)
for i in range(torch_res.shape[0]):
self.assertTrue(equal_distribution(lambda *_: tiny_res[i], lambda _: torch_res[i]))
_check_with_torch(w=[0.231, 0., 1., 0.5], num_samples=2000, replacement=True)
_check_with_torch(w=[[0.2, 0.8]], num_samples=2000, replacement=True) # 2D but only 1 row
_check_with_torch(w=[[0.453, 0., 1., 0.81], [0.1, 0.8, 0., 0.1]], num_samples=2000, replacement=True)
# no-replacement isn't supported, unless taking only one sample
w = [0.1, 0.9]
self.assertRaises(AssertionError, lambda: Tensor(w).multinomial(100, replacement=False))
tiny_samples = [Tensor(w).multinomial(1, replacement=False).numpy().item() for _ in range(1000)]
torch_samples = [torch.tensor(w).multinomial(1, replacement=False).item() for _ in range(1000)]
self.assertTrue(equal_distribution(lambda *_: Tensor(tiny_samples), lambda _: torch.tensor(torch_samples)))
def test_multinomial_counterexample(self):
tiny_res = Tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True)
torch_res = torch.tensor([0.3, 0.6, 0.1]).multinomial(2000, replacement=True)
self.assertTrue(equal_distribution(lambda *_: tiny_res, lambda _: torch_res))
torch_res = torch.tensor([0.2, 0.7, 0.1]).multinomial(2000, replacement=True)
self.assertFalse(equal_distribution(lambda *_: tiny_res, lambda _: torch_res))
def test_conv2d_init(self):
params = (128, 256, (3,3))
assert equal_distribution(lambda *_: nn.Conv2d(*params).weight, lambda _: torch.nn.Conv2d(*params).weight.detach())
assert equal_distribution(lambda *_: nn.Conv2d(*params).bias, lambda _: torch.nn.Conv2d(*params).bias.detach())
def test_linear_init(self):
params = (64, 64)
assert equal_distribution(lambda *_: nn.Linear(*params).weight, lambda _: torch.nn.Linear(*params).weight.detach())
assert equal_distribution(lambda *_: nn.Linear(*params).bias, lambda _: torch.nn.Linear(*params).bias.detach())
def test_bn_init(self):
params = (64,)
assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).weight, lambda _: torch.nn.BatchNorm2d(*params).weight.detach())
assert equal_distribution(lambda *_: nn.BatchNorm2d(*params).bias, lambda _: torch.nn.BatchNorm2d(*params).bias.detach())
if __name__ == "__main__":
unittest.main()