diff --git a/test/test_dtype.py b/test/test_dtype.py
index 34f7befb51..cb26729e1c 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -118,6 +118,13 @@ class TestDType(unittest.TestCase):
     assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
     np.testing.assert_allclose(tin, tor, atol=1e-6, rtol=1e-3)
 
+  def test_finfo(self):
+    if self.DTYPE not in [dtypes.float16, dtypes.bfloat16, dtypes.float32, dtypes.float64]: return
+    info = np.finfo(_to_np_dtype(self.DTYPE))
+    assert info.bits == self.DTYPE.itemsize*8
+    assert info.nexp == dtypes.finfo(self.DTYPE)[0]
+    assert info.nmant == dtypes.finfo(self.DTYPE)[1]
+
 def _test_ops(a_dtype:DType, b_dtype:DType, target_dtype=None):
   target_dtype = target_dtype or least_upper_dtype(a_dtype, b_dtype)
   if not is_dtype_supported(a_dtype) or not is_dtype_supported(b_dtype) or not is_dtype_supported(target_dtype): return
@@ -490,7 +497,7 @@ class TestTypeSpec(unittest.TestCase):
 
   @given(strat.sampled_from(core_dtypes), strat.sampled_from([operator.gt, operator.ge, operator.le, operator.lt, operator.eq, operator.ne]))
   def test_bool_ops(self, dtype, op):
-    assert op(Tensor.rand(4, 4, dtype=dtype), Tensor.rand(4, 4, dtype=dtype)).dtype == dtypes.bool
+    assert op(Tensor.ones(4, 4, dtype=dtype), Tensor.ones(4, 4, dtype=dtype)).dtype == dtypes.bool
 
   @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints), strat.sampled_from(dtype_floats))
   def test_functions_return_index(self, dtype, default_int, default_float):
@@ -501,7 +508,7 @@ class TestTypeSpec(unittest.TestCase):
 
   @given(strat.sampled_from(core_dtypes), strat.sampled_from(dtype_ints))
   def test_tensor_indexing_returns_same_dtype(self, data_dtype, indices_dtype):
-    X_data = Tensor.rand(60000, 1, 28, 28, dtype=data_dtype)
+    X_data = Tensor.ones(60000, 1, 28, 28, dtype=data_dtype)
     indices = Tensor.randint(512, high=X_data.shape[0]).cast(indices_dtype)
     assert X_data[indices].dtype == X_data.dtype
 
@@ -584,10 +591,10 @@ class TestAutoCastType(unittest.TestCase):
 
   @given(strat.sampled_from(core_dtypes))
   def test_broadcast_scalar(self, dt):
-    assert (Tensor.rand(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
-    assert (Tensor.rand(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
+    assert (Tensor.ones(4, 4, dtype=dt) + 2.3).dtype == (dt if dtypes.is_float(dt) else dtypes.default_float)
+    assert (Tensor.ones(4, 4, dtype=dt) + 2).dtype == (dt if dtypes.is_float(dt) or dtypes.is_int(dt) else dtypes.default_int)
     if Device.DEFAULT != "WEBGPU" and dt != dtypes.bool:
-      assert (Tensor.rand(4, 4, dtype=dt) + True).dtype == dt
+      assert (Tensor.ones(4, 4, dtype=dt) + True).dtype == dt
 
   def test_sum(self):
     assert (Tensor([0, 1], dtype=dtypes.bool)).sum().dtype == dtypes.int32
diff --git a/test/test_hcq.py b/test/test_hcq.py
index 4501b489a9..4b60b0fb4d 100644
--- a/test/test_hcq.py
+++ b/test/test_hcq.py
@@ -136,7 +136,7 @@ class TestHCQ(unittest.TestCase):
     assert (val:=TestHCQ.b.lazydata.buffer.as_buffer().cast("f")[1]) == 0.0, f"got val {val}, should not be updated"
 
   def test_exec_update_fuzz(self):
-    a = Tensor.rand((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
+    a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
     b = a + 1
     si = create_schedule([b.lazydata])[-1]
     k = Kernel(si.ast, opts=TestHCQ.d0.renderer)
diff --git a/test/test_randomness.py b/test/test_randomness.py
index 1b7b4c07e2..f6dc61a986 100644
--- a/test/test_randomness.py
+++ b/test/test_randomness.py
@@ -64,16 +64,16 @@ class TestRandomness(unittest.TestCase):
     self.assertFalse(normal_test(Tensor.rand))
     self.assertTrue(equal_distribution(Tensor.rand, torch.rand, lambda x: np.random.rand(*x)))
 
-  @unittest.skipIf(THREEFRY.value, "broken with threefry")
+  @unittest.skipUnless(is_dtype_supported(dtypes.float16), "need bfloat16 support")
   def test_rand_half(self):
     N = 128
     x = Tensor.rand((2, N, N), dtype=dtypes.half)
     assert x.dtype == dtypes.half
     x = x.numpy()
-    ones = np.take(x, np.where(x == 1))
-    zeros = np.take(x, np.where(x == 0))
-    self.assertTrue(ones.size == 0)
-    self.assertTrue(zeros.size > 0)
+    ones = x[x == 1]
+    zeros = x[x == 0]
+    assert ones.size == 0
+    assert zeros.size > 0
     equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
 
   @unittest.skipIf(not THREEFRY.value, "not using threefry")
@@ -149,11 +149,11 @@ class TestRandomness(unittest.TestCase):
                                        lambda x: np.random.uniform(-1, 1, size=x) * math.sqrt(6 / (x[0] + math.prod(x[1:])))))
 
   def test_kaiming_uniform(self):
-    for shape in [(128, 64, 3, 3), (20, 24)]:
+    for shape in [(128, 64, 3, 3), (20, 24), (3, 55, 5)]:
       self.assertTrue(equal_distribution(Tensor.kaiming_uniform, lambda x: torch.nn.init.kaiming_uniform_(torch.empty(x)), shape=shape))
 
   def test_kaiming_normal(self):
-    for shape in [(128, 64, 3, 3), (20, 24)]:
+    for shape in [(128, 64, 3, 3), (20, 24), (3, 55, 5)]:
       self.assertTrue(equal_distribution(Tensor.kaiming_normal, lambda x: torch.nn.init.kaiming_normal_(torch.empty(x)), shape=shape))
 
   def test_multinomial(self):
diff --git a/tinygrad/dtype.py b/tinygrad/dtype.py
index 64931618fc..0c9166854d 100644
--- a/tinygrad/dtype.py
+++ b/tinygrad/dtype.py
@@ -64,6 +64,10 @@ class dtypes:
     if dtypes.is_int(dtype): return (2**(dtype.itemsize*8-(0 if dtypes.is_unsigned(dtype) else 1)))-1
     return float("inf") if dtypes.is_float(dtype) else True
   @staticmethod
+  def finfo(dtype:DType) -> Tuple[int, int]:
+    if not dtypes.is_float(dtype): raise ValueError(f"{dtype} is not a floating point type")
+    return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
+  @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
   # TODO: priority should be higher than bool
   pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index f2b3c29b8d..51ab18897d 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -417,6 +417,7 @@ class Tensor:
     print(t.numpy())
     ```
     """
+    if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
     if (had_counter := Tensor._rng_counter is None): Tensor._rng_counter = Tensor([0], dtype=dtypes.uint32, requires_grad=False)
     if not all(s >= 0 for s in argfix(*shape)): raise ValueError(f"cannot create tensor with negative dimension in {shape=}")
     if not THREEFRY.value:
@@ -426,17 +427,28 @@ class Tensor:
       return Tensor._metaop(MetaOps.CUSTOM, argfix(*shape), arg=custom_random, device=device, dtype=dtype, **kwargs)
 
     # threefry
-    if (num := prod((shape:=argfix(*shape)))) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
+    assert all_int(shape:=argfix(*shape)), f"symbolic shape not supported, {shape=}"
+    if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
     if not had_counter: Tensor._rng_counter.assign(Tensor._rng_counter + num)
 
     counts1 = (Tensor.arange(math.ceil(num / 2), device=device, dtype=dtypes.uint32, requires_grad=False)+Tensor._rng_counter.to(device))
     counts2 = counts1 + math.ceil(num / 2)
+    # threefry random bits
     x = counts2.cast(dtypes.uint64) << 32 | counts1.cast(dtypes.uint64)
     x = F.Threefry.apply(*x._broadcasted(Tensor._seed))
     counts1, counts2 = (x & 0xffffffff).cast(dtypes.uint32), ((x >> 32) & 0xffffffff).cast(dtypes.uint32)
+    bits = counts1.cat(counts2)[:num]
 
-    out = counts1.cat(counts2).rshift(8).cast(dtypes.float32).div(2 ** 24)[:num]
-    out = out.reshape(shape).cast(dtypes.default_float if dtype is None else dtype)
+    # bitcast to uint with same number of bits
+    _, nmant = dtypes.finfo(dtype)
+    uint_dtype = {1: dtypes.uint8, 2: dtypes.uint16, 4: dtypes.uint32, 8: dtypes.uint64}[dtype.itemsize]
+    bits = bits.bitcast(uint_dtype)
+    # only randomize the mantissa bits and set the exponent to 1
+    one = Tensor.ones_like(bits, device=bits.device, dtype=dtype).bitcast(uint_dtype)
+    bits = bits.rshift((dtype.itemsize * 8) - nmant).bitwise_or(one)
+
+    # bitcast back to the original dtype
+    out = bits.bitcast(dtype)[:num_].sub(1).reshape(shape)
     out.requires_grad = kwargs.get("requires_grad")
     return out.contiguous()
 
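
For readers of the `rand()` change above: the new threefry path keeps only `nmant` random bits, ORs in the bit pattern of 1.0 so the exponent is fixed, bitcasts back to float (giving a value in [1, 2)), and subtracts 1. Below is a standalone NumPy sketch of that trick, not part of the patch; the constants are the float32 values (nexp=8, nmant=23, matching the new `dtypes.finfo`) and all variable names are illustrative.

```python
import numpy as np

# stand-in for the uint32 words the patch gets from threefry
bits = np.random.randint(0, 2**32, size=8, dtype=np.uint32)

nexp, nmant = 8, 23                             # float32; what dtypes.finfo(dtypes.float32) returns
one = np.ones(bits.shape, np.float32).view(np.uint32)  # 0x3f800000: sign=0, exponent of 1.0, mantissa=0

# keep nmant random bits in the mantissa and force the exponent of 1.0
mant = (bits >> (32 - nmant)) | one

# bitcast back: values lie in [1.0, 2.0), so subtracting 1 gives uniform [0.0, 1.0)
out = mant.view(np.float32) - 1.0
assert ((out >= 0.0) & (out < 1.0)).all()
```

Because only the mantissa is randomized, the largest possible output is 1 - 2**-nmant while 0.0 stays reachable, which is what the updated test_rand_half asserts for half precision (no exact 1.0 values, some exact 0.0 values).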
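The sizing arithmetic in the new `num` follows the same idea: threefry emits 32-bit words while each output element needs `dtype.itemsize` random bytes, so `ceil(num_ * itemsize / 4)` words are enough, and the bitcast plus `[:num_]` trims the surplus. A tiny worked example (illustrative only, float16 numbers assumed):

```python
import math

# float16: 5 requested elements, 2 random bytes each
num_, itemsize = 5, 2
num = math.ceil(num_ * itemsize / 4)   # 3 uint32 words = 12 random bytes
halves = num * 4 // itemsize           # bitcasting 3 uint32 words yields 6 uint16 values
assert num == 3 and halves == 6 and halves >= num_   # [:num_] keeps the first 5
```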