mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
nn init matches torch (#901)
This commit is contained in:
@@ -43,14 +43,14 @@ def normal_test(func, shape=(20, 23), alpha=0.05):
|
||||
y = np.random.randn(*shape).flatten()
|
||||
return kstest(x, y) >= alpha
|
||||
|
||||
def equal_distribution(tiny_func, torch_func, numpy_func, shape=(20, 23), alpha=0.05):
|
||||
def equal_distribution(tiny_func, torch_func, numpy_func=None, shape=(20, 23), alpha=0.05):
|
||||
Tensor.manual_seed(1337)
|
||||
torch.manual_seed(1337)
|
||||
np.random.seed(1337)
|
||||
x = tiny_func(*shape).cpu().numpy().flatten()
|
||||
y = numpy_func(shape).flatten()
|
||||
if numpy_func is not None: y = numpy_func(shape).flatten()
|
||||
z = torch_func(shape).numpy().flatten()
|
||||
return kstest(x, y) >= alpha and kstest(x, z) >= alpha
|
||||
return (numpy_func is None or kstest(x, y) >= alpha) and kstest(x, z) >= alpha
|
||||
|
||||
class TestRandomness(unittest.TestCase):
|
||||
def test_rand(self):
|
||||
@@ -73,13 +73,12 @@ class TestRandomness(unittest.TestCase):
|
||||
self.assertFalse(normal_test(Tensor.glorot_uniform))
|
||||
self.assertTrue(equal_distribution(Tensor.glorot_uniform, lambda x: torch.nn.init.xavier_uniform_(torch.empty(x)), lambda x: (np.random.rand(*x) * 2 - 1) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
|
||||
|
||||
def test_kaiming_uniform(self, shape=(20, 23), a=0.01):
|
||||
def test_kaiming_uniform(self):
|
||||
Tensor.manual_seed(1337)
|
||||
torch.manual_seed(1337)
|
||||
np.random.seed(1337)
|
||||
|
||||
bound = (math.sqrt(3.0) * (math.sqrt(2.0 / (1 + a ** 2)) / math.sqrt(shape[1] * np.prod(shape[2:]))))
|
||||
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), lambda x: np.random.uniform(low=-bound, high=bound, size=shape)))
|
||||
for shape in [(128, 64, 3, 3), (20, 24)]:
|
||||
self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user