diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index b2677cface..06a29c9f19 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -238,9 +238,6 @@ jobs:
     - if: ${{ matrix.task == 'onnx' }}
       name: Test MLPerf datasets
      run: GPU=1 python -m pytest -n=auto test/external/external_test_datasets.py --durations=20
-    - if: ${{ matrix.task == 'onnx' }}
-      name: Test THREEFRY
-      run: PYTHONPATH=. THREEFRY=1 GPU=1 python3 -m pytest test/test_randomness.py test/test_jit.py --durations=20
     - if: ${{ matrix.task == 'onnx' }}
       name: Run handcode_opt
       run: PYTHONPATH=. MODEL=resnet GPU=1 DEBUG=1 BS=4 HALF=0 python3 examples/handcode_opt.py
diff --git a/examples/sdxl_seed0.png b/examples/sdxl_seed0.png
index b80cc03e22..26569f6243 100644
Binary files a/examples/sdxl_seed0.png and b/examples/sdxl_seed0.png differ
diff --git a/examples/stable_diffusion.py b/examples/stable_diffusion.py
index 7c12727d61..51b8748056 100644
--- a/examples/stable_diffusion.py
+++ b/examples/stable_diffusion.py
@@ -218,7 +218,7 @@ class StableDiffusion:
 if __name__ == "__main__":
   default_prompt = "a horse sized cat eating a bagel"
   parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-  parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
+  parser.add_argument('--steps', type=int, default=6, help="Number of steps in diffusion")
   parser.add_argument('--prompt', type=str, default=default_prompt, help="Phrase to render")
   parser.add_argument('--out', type=str, default=Path(tempfile.gettempdir()) / "rendered.png", help="Output filename")
   parser.add_argument('--noshow', action='store_true', help="Don't show the image")
@@ -287,8 +287,8 @@ if __name__ == "__main__":
   if not args.noshow: im.show()
 
   # validation!
-  if args.prompt == default_prompt and args.steps == 5 and args.seed == 0 and args.guidance == 7.5:
+  if args.prompt == default_prompt and args.steps == 6 and args.seed == 0 and args.guidance == 7.5:
     ref_image = Tensor(np.array(Image.open(Path(__file__).parent / "stable_diffusion_seed0.png")))
     distance = (((x.cast(dtypes.float) - ref_image.cast(dtypes.float)) / ref_image.max())**2).mean().item()
-    assert distance < 3e-4, colored(f"validation failed with {distance=}", "red")
+    assert distance < 45e-5, colored(f"validation failed with {distance=}", "red")
     print(colored(f"output validated with {distance=}", "green"))
diff --git a/examples/stable_diffusion_seed0.png b/examples/stable_diffusion_seed0.png
index 31af0c2dea..17855b1d8f 100644
Binary files a/examples/stable_diffusion_seed0.png and b/examples/stable_diffusion_seed0.png differ
diff --git a/test/external/openpilot/b1ab7897cbfa35981e1636fe551e4ce5.npy b/test/external/openpilot/b1ab7897cbfa35981e1636fe551e4ce5.npy
index 9cb7382c80..a10a135705 100644
Binary files a/test/external/openpilot/b1ab7897cbfa35981e1636fe551e4ce5.npy and b/test/external/openpilot/b1ab7897cbfa35981e1636fe551e4ce5.npy differ
diff --git a/test/test_arange.py b/test/test_arange.py
index 09597cecc7..822000661d 100644
--- a/test/test_arange.py
+++ b/test/test_arange.py
@@ -142,8 +142,8 @@ class TestIndexing(unittest.TestCase):
     from tinygrad.nn.datasets import mnist
     X_train, Y_train, _, _ = mnist()
     with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
+      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
       GlobalCounters.reset()
-      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
       x = X_train[samples].numpy()
       y = Y_train[samples].numpy()
       assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
diff --git a/test/test_gc.py b/test/test_gc.py
index cbb9a98aa5..3732802af7 100644
--- a/test/test_gc.py
+++ b/test/test_gc.py
@@ -22,7 +22,7 @@ class TestGC(unittest.TestCase):
     Tensor.manual_seed(0)
     a = Tensor(np.zeros((4, 4), dtype=np.float32), requires_grad=True)
     b = Tensor.rand(4, 4, requires_grad=True)
-    assert (tensors_allocated() == 3)
+    assert (tensors_allocated() == 4)
     (a*b).mean().backward()
     assert (tensors_allocated() == 5)
     del b
diff --git a/test/test_multitensor.py b/test/test_multitensor.py
index b0e889809b..98b6db14ad 100644
--- a/test/test_multitensor.py
+++ b/test/test_multitensor.py
@@ -555,10 +555,15 @@ class TestMultiTensor(unittest.TestCase):
       # don't allow assigns that change axes
       t_none.assign(t_zero)
 
-  def test_rand_on_multiple_devices(self):
+  def test_rand_with_multiple_devices(self):
     with self.assertRaises(ValueError):
       Tensor.rand(256, device=devices_2)
 
+  def test_rand_on_multiple_devices(self):
+    d0_rand = Tensor.rand(256, device=d0).realize()
+    d1_rand = Tensor.rand(256, device=d1).realize()
+    assert not np.allclose(d0_rand.numpy(), d1_rand.numpy())
+
   def test_rand_like_on_shard(self):
     t = Tensor.empty((16, 16)).shard(devices_2)
     t2 = Tensor.rand_like(t)
@@ -591,11 +596,11 @@ class TestMultiTensor(unittest.TestCase):
 
   def test_dropout_on_shard_axis(self):
     with Tensor.train():
-      X = Tensor.ones(256).shard(devices_2, axis=0)
+      X = Tensor.ones(512).shard(devices_2, axis=0)
       output = X.dropout(0.5).numpy()
       unique, counts = np.unique(output, return_counts=True)
       assert set(unique) == {0, 2}, unique
-      assert 100 < counts[0] < 156, counts[0]
+      assert 228 < counts[0] < 284, counts[0]
 
   def test_dropout_on_uneven_shard_axis(self):
     with Tensor.train():
diff --git a/test/test_nn.py b/test/test_nn.py
index 40c09d7d24..d598e26c07 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -452,6 +452,7 @@ class TestNN(unittest.TestCase):
 
   def test_embedding_one_kernel(self):
     layer = Embedding(20, 30)
+    layer.weight = Tensor.zeros_like(layer.weight).contiguous()
     a = Tensor([[1, 5, 9, 11], [12, 19, 8, 1]])
     result = layer(a)
diff --git a/test/test_randomness.py b/test/test_randomness.py
index 14f6f1ece6..43c6ea59ee 100644
--- a/test/test_randomness.py
+++ b/test/test_randomness.py
@@ -4,7 +4,7 @@ from functools import partial
 import numpy as np
 import torch
 from tinygrad import nn, dtypes, Tensor, Device, TinyJit
-from tinygrad.helpers import THREEFRY, getenv, CI
+from tinygrad.helpers import getenv, CI
 from test.helpers import is_dtype_supported
 from hypothesis import given, settings, strategies as strat
 
@@ -75,7 +75,7 @@ class TestRandomness(unittest.TestCase):
     assert nx[nx == 0].size > 0
     equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.float16), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
 
-  @unittest.skipIf(not THREEFRY.value, "not using threefry")
+  @unittest.skipIf(CI and Device.DEFAULT == "NV", "gpuocelot doesn't support certain ops needed for threefry")
   def test_threefly_against_reference(self):
     Tensor.manual_seed(1337)
 
@@ -96,28 +96,28 @@ class TestRandomness(unittest.TestCase):
 
     np.testing.assert_allclose(jr, r)
 
-  @unittest.skipUnless(Device.DEFAULT == "GPU", "reference is on GPU")
-  @unittest.skipIf(not THREEFRY.value, "not using threefry")
   def test_threefly_against_reference_full(self):
     Tensor.manual_seed(1337)
 
     # reference generated using
     """
     key0 = 1337
-    key1 = 0
+    key1 = int.from_bytes(hashlib.sha256(int(0).to_bytes(4)).digest(), "big") & 0xffffffff
     values = jax.extend.random.threefry_2x32((np.uint32(key1), np.uint32(key0)), np.arange(20, dtype=np.uint32))
+    values = (values >> (32 - 23)) | np.array(1, dtype=np.float32).view(np.uint32)
+    values = values.view(np.float32) - 1
     print(f"[{', '.join(f'{v}' for v in values)}]")
     """
-    jr = np.array([0.7882130146026611, 0.0680311918258667, 0.6758031845092773, 0.2525523900985718, 0.5712389945983887,
-                   0.8758237361907959, 0.13559412956237793, 0.9069793224334717, 0.8781528472900391, 0.7737162113189697,
-                   0.050452232360839844, 0.1645597219467163, 0.06776463985443115, 0.09560704231262207, 0.2754603624343872,
-                   0.10108339786529541, 0.3488548994064331, 0.7904064655303955, 0.2519160509109497, 0.7925788164138794], dtype=np.float32)
+    jr = np.array([0.9073467254638672, 0.8235964775085449, 0.6872662305831909, 0.9920015335083008, 0.4941047430038452,
+                   0.3108327388763428, 0.09639489650726318, 0.004686474800109863, 0.8435229063034058, 0.824237585067749,
+                   0.5873836278915405, 0.4232727289199829, 0.2530076503753662, 0.40300023555755615, 0.03966474533081055,
+                   0.27904558181762695, 0.9150195121765137, 0.48057758808135986, 0.23821306228637695, 0.7676635980606079], dtype=np.float32)
 
     r = Tensor.rand(20).numpy()
     np.testing.assert_allclose(jr, r, atol=1e-5, rtol=1e-5)
 
-  @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL"), "no GPU CI")
+  @unittest.skipIf(CI and Device.DEFAULT in ("GPU", "CUDA", "METAL", "NV"), "no GPU CI")
   def test_threefly_tensors_cnt(self):
     Tensor.manual_seed(1337)
 
@@ -141,10 +141,9 @@ class TestRandomness(unittest.TestCase):
     N = 128
     x = Tensor.rand((2, N, N), dtype=dtypes.bfloat16)
     assert x.dtype == dtypes.bfloat16
-    if THREEFRY.value:
-      nx = x.numpy()
-      assert nx[nx == 1].size == 0
-      assert nx[nx == 0].size > 0
+    nx = x.numpy()
+    assert nx[nx == 1].size == 0
+    assert nx[nx == 0].size > 0
     equal_distribution(lambda *x: Tensor.rand(*x, dtype=dtypes.bfloat16).float(), torch.rand, lambda x: np.random.rand(*x), shape=(2, N, N))
 
   def test_rand_like(self):
diff --git a/test/test_schedule.py b/test/test_schedule.py
index e0ad407f4f..86f295acbb 100644
--- a/test/test_schedule.py
+++ b/test/test_schedule.py
@@ -47,13 +47,16 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr
     l.linearize()
   return sched
 
+def _realize_weights(m):
+  for p in nn.state.get_parameters(m): p.realize()
+
 def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
   old_default_float, dtypes.default_float = dtypes.default_float, dtype
   dtypes.default_float = dtype
   Tensor.manual_seed(0)
   BS, CIN = 2, 3
-  img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True)
-  w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True)
+  img = Tensor.randn(BS, CIN, 64, 64, requires_grad=True).realize()
+  w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
   ret = Tensor.conv2d(img, w).relu().mean().backward()
   dtypes.default_float = old_default_float
   with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
@@ -256,12 +259,13 @@ class TestSchedule(unittest.TestCase):
 
   def test_fold_conv_batchnorm_optim(self):
     # this is too high
-    for optim, cnt in [(nn.optim.Adam, 19), (nn.optim.SGD, 17)]:
+    for optim, cnt in [(nn.optim.Adam, 17), (nn.optim.SGD, 15)]:
       with self.subTest(optim=optim.__name__):
         with Tensor.train():
           img = Tensor.ones(1,3,4,4)
           c1 = nn.Conv2d(3,32,3)
           bn = nn.BatchNorm2d(32, track_running_stats=False)
+          _realize_weights([c1, bn])
          opt = optim(nn.state.get_parameters([c1, bn]))
           img_bn = bn(c1(img)).elu().sum()
           opt.zero_grad()
@@ -919,57 +923,63 @@ class TestSchedule(unittest.TestCase):
     with Tensor.train():
       x = Tensor.empty(4, 64, 768)
       layer = nn.Linear(768, 768*4)
+      _realize_weights(layer)
       opt = nn.optim.Adam(nn.state.get_parameters(layer), lr=1e-4)
       layer(x).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 11)
+      check_schedule(opt.schedule_step(), 9)
 
   def test_adam_conv_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,32,3)
+      _realize_weights(c1)
       opt = nn.optim.Adam(nn.state.get_parameters(c1), lr=1e-4)
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 11)
+      check_schedule(opt.schedule_step(), 9)
 
   def test_adam_2convs_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,16,3,bias=False)
       c2 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2])
       opt = nn.optim.Adam(nn.state.get_parameters([c1, c2]), lr=1e-4)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 13)
+      check_schedule(opt.schedule_step(), 12)
 
   def test_sgd_conv_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,32,3)
+      _realize_weights(c1)
       opt = nn.optim.SGD(nn.state.get_parameters(c1))
       opt.zero_grad()
       c1(img).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 7)
+      check_schedule(opt.schedule_step(), 5)
 
   def test_sgd_2convs_fuse(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,16,3,bias=False)
       c2 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2])
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]))
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 7)
+      check_schedule(opt.schedule_step(), 6)
 
   def test_fold_2convs_sgd_nesterov_momentum_wd(self):
     with Tensor.train():
       img = Tensor.empty(2,3,4,4)
       c1 = nn.Conv2d(3,16,3,bias=False)
       c2 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2])
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2]), nesterov=True, momentum=0.9, weight_decay=0.1)
       opt.zero_grad()
       c2(c1(img).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 9)
+      check_schedule(opt.schedule_step(), 8)
 
   def test_sgd_4convs_fuse(self):
     with Tensor.train():
@@ -978,10 +988,11 @@ class TestSchedule(unittest.TestCase):
       c2 = nn.Conv2d(4,8,3,bias=False)
       c3 = nn.Conv2d(8,16,3,bias=False)
       c4 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2, c3, c4])
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
       opt.zero_grad()
       c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      check_schedule(opt.schedule_step(), 22)
+      check_schedule(opt.schedule_step(), 18)
 
   def test_sgd_4convs_fuse_conv_bw(self):
     with Tensor.train():
@@ -990,10 +1001,11 @@ class TestSchedule(unittest.TestCase):
       c2 = nn.Conv2d(4,8,3,bias=False)
       c3 = nn.Conv2d(8,16,3,bias=False)
       c4 = nn.Conv2d(16,32,3,bias=False)
+      _realize_weights([c1, c2, c3, c4])
       opt = nn.optim.SGD(nn.state.get_parameters([c1, c2, c3, c4]))
       opt.zero_grad()
       c4(c3(c2(c1(img).relu()).relu()).relu()).relu().sum().backward()
-      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 19)
+      with Context(FUSE_CONV_BW=1): check_schedule(opt.schedule_step(), 15)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_prefer_half_buffer(self):
@@ -1184,7 +1196,7 @@ class TestSchedule(unittest.TestCase):
     b = Tensor.rand(3, 4, 5).realize()
     out = (a + b).pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
     run_schedule(check_schedule(out, 1))
-    np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum())
+    np.testing.assert_allclose(out.numpy(), np.pad(a.numpy()+b.numpy(), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
 
   # multireduce spec
   def test_multireduce_pad_reduce_safe(self):
@@ -1202,7 +1214,7 @@ class TestSchedule(unittest.TestCase):
     a = Tensor.rand(3, 4, 5).realize()
     out = a.log2().pad(((0, 1), (0, 1), (0, 1)), 1.0).sum().contiguous()
     run_schedule(check_schedule(out, 2))
-    np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), rtol=1e-6)
+    np.testing.assert_allclose(out.numpy(), np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-5, rtol=1e-6)
 
   # multireduce spec
   def test_multireduce_pad_reduce_unsafe(self):
@@ -1213,7 +1225,7 @@ class TestSchedule(unittest.TestCase):
     # run_schedule(check_schedule(out, 1))
     run_schedule(check_schedule(out, 4))
     np.testing.assert_allclose(out.numpy(), np.pad(np.log2(np.abs(np.pad(np.log2(a.numpy()), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum() + \
-                                                   b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=1e-4, rtol=1e-6)
+                                                   b.numpy())), ((0, 1), (0, 1), (0, 1)), constant_values=1.0).sum(), atol=3e-4, rtol=1e-6)
 
   def test_shrink_pad_safe(self):
     a = Tensor.ones((3, )).contiguous().realize()
@@ -1301,18 +1313,18 @@ class TestSchedule(unittest.TestCase):
     out = x.argmax(1)
     run_schedule(check_schedule(out, 3)) # TODO: push a reduceop through a reshape
 
-  def test_conv2d(self): _test_conv2d(8)
-  def test_conv2d_fused(self): _test_conv2d(7, FUSE_CONV_BW=1)
-  def test_conv2d_fused_ast_rewrite(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1)
+  def test_conv2d(self): _test_conv2d(7)
+  def test_conv2d_fused(self): _test_conv2d(6, FUSE_CONV_BW=1)
+  def test_conv2d_fused_ast_rewrite(self): _test_conv2d(6, FUSE_CONV_BW=1, AST_REWRITE=1)
 
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
-  def test_conv2d_half(self): _test_conv2d(8, dtype=dtypes.half)
+  def test_conv2d_half(self): _test_conv2d(7, dtype=dtypes.half)
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @unittest.expectedFailure
-  def test_conv2d_fused_half(self): _test_conv2d(7, dtype=dtypes.half)
+  def test_conv2d_fused_half(self): _test_conv2d(5, dtype=dtypes.half)
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   @unittest.expectedFailure
-  def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(7, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)
+  def test_conv2d_fused_ast_rewrite_half(self): _test_conv2d(6, FUSE_CONV_BW=1, AST_REWRITE=1, dtype=dtypes.half)
 
   def test_buf_cnt_at_limit(self): _test_buf_cnt(5, buf_max=5, allowed=1)
   @unittest.expectedFailure
@@ -1395,18 +1407,18 @@ class TestIndexing(unittest.TestCase):
 
   def test_arange_transposed(self):
     Tensor.manual_seed(0)
-    x = Tensor.randint(4, 1)
+    x = Tensor.randint(4, 1).realize()
     a = (Tensor.arange(4,)*x).T
-    self.check_schedule(a, 2)
+    self.check_schedule(a, 1)
     np.testing.assert_equal(a.numpy(), (np.arange(4)*x.numpy()).T)
 
   def test_arange_transposed_descendants(self):
     Tensor.manual_seed(0)
-    x = Tensor.randint(4, 1)
+    x = Tensor.randint(4, 1).realize()
     a = (Tensor.arange(4,)*x).T
     b = Tensor.randint(4, 4).realize()
     out = a+b
-    self.check_schedule(out, 2)
+    self.check_schedule(out, 1)
     np.testing.assert_equal(out.numpy(), (np.arange(4)*x.numpy()).T+b.numpy())
 
   def test_arange_index(self):
@@ -1415,7 +1427,7 @@ class TestIndexing(unittest.TestCase):
     a = Tensor.arange(10)
     out = (x + a[2]).sum()
     self.check_schedule(out, 1)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
 
   def test_arange_index_contiguous(self):
     Tensor.manual_seed(0)
@@ -1423,7 +1435,7 @@ class TestIndexing(unittest.TestCase):
     a = Tensor.arange(10).contiguous()
     out = (x + a[2]).sum()
     self.check_schedule(out, 2)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
 
   def test_arange_index_child(self):
     Tensor.manual_seed(0)
@@ -1431,7 +1443,7 @@ class TestIndexing(unittest.TestCase):
     a = Tensor.arange(10)+1
     out = (x + a[2]).sum()
     self.check_schedule(out, 1)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
 
   def test_arange_index_contiguous_child(self):
     Tensor.manual_seed(0)
@@ -1439,7 +1451,7 @@ class TestIndexing(unittest.TestCase):
     a = (Tensor.arange(10)+1).contiguous()
     out = (x + a[2]).sum()
     self.check_schedule(out, 2)
-    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum())
+    np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
 
   def test_arange_childless_base(self):
     a = Tensor.arange(4)
@@ -1453,7 +1465,7 @@ class TestIndexing(unittest.TestCase):
 
   def test_arange_group_childless_base(self):
     Tensor.manual_seed(0)
-    x = Tensor.randint(4)
+    x = Tensor.randint(4).realize()
     a = Tensor.arange(4)+x
     self.check_schedule(a, 1)
     np.testing.assert_equal(a.numpy(), np.arange(4)+x.numpy())
@@ -1527,8 +1539,8 @@ class TestIndexing(unittest.TestCase):
     from tinygrad.nn.datasets import mnist
     import torch
     _, Y_train, _, _ = mnist()
-    samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1]))
-    yt = Tensor.randn(BS, 10)
+    samples = Tensor.randint(BS:=getenv("BS", 512), high=cast(int,Y_train.shape[-1])).realize()
+    yt = Tensor.randn(BS, 10).realize()
     with Context(SPLIT_REDUCEOP=0):
       loss = yt.sparse_categorical_crossentropy(Y_train[samples])
       self.check_schedule(loss, 6)
diff --git a/test/test_search.py b/test/test_search.py
index 0c554ca53a..2a2ca4255b 100644
--- a/test/test_search.py
+++ b/test/test_search.py
@@ -32,7 +32,7 @@ class TestTimeLinearizer(unittest.TestCase):
     assert all(r.size > 0 for r in rawbufs)
 
   def test_bufs_from_lin_alt(self):
-    a = Tensor.randn(4, 4)
+    a = Tensor.randn(4, 4).realize()
     b = a+a[0]
     si = [si for si in b.schedule() if si.ast.op is UOps.SINK][0]
     rawbufs = bufs_from_lin(k:=Kernel(si.ast))
diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py
index 0ea8f55c38..5a4d3d540b 100644
--- a/tinygrad/helpers.py
+++ b/tinygrad/helpers.py
@@ -105,7 +105,7 @@ class ContextVar:
   def __lt__(self, x): return self.value < x
 
 DEBUG, IMAGE, BEAM, NOOPT, JIT = ContextVar("DEBUG", 0), ContextVar("IMAGE", 0), ContextVar("BEAM", 0), ContextVar("NOOPT", 0), ContextVar("JIT", 1)
-WINO, THREEFRY, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("THREEFRY", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
+WINO, CAPTURING, TRACEMETA = ContextVar("WINO", 0), ContextVar("CAPTURING", 1), ContextVar("TRACEMETA", 1)
 GRAPH, GRAPHPATH, SAVE_SCHEDULE, RING = ContextVar("GRAPH", 0), getenv("GRAPHPATH", "/tmp/net"), ContextVar("SAVE_SCHEDULE", 0), ContextVar("RING", 1)
 MULTIOUTPUT, PROFILE, PROFILEPATH = ContextVar("MULTIOUTPUT", 1), ContextVar("PROFILE", 0), ContextVar("PROFILEPATH", temp("tinygrad_profile.json"))
 USE_TC, TC_OPT, AMX, TRANSCENDENTAL = ContextVar("TC", 1), ContextVar("TC_OPT", 0), ContextVar("AMX", 0), ContextVar("TRANSCENDENTAL", 1)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index afa60614ce..31dfb6e987 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -8,7 +8,7 @@ import numpy as np
 
 from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype
 from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, get_shape, fully_flatten, dedup
-from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY, _METADATA, Metadata, TRACEMETA
+from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA
 from tinygrad.lazy import LazyBuffer
 from tinygrad.multi import MultiLazyBuffer
 from tinygrad.ops import MetaOps, truncate
@@ -438,26 +438,20 @@ class Tensor:
     if not dtypes.is_float(dtype := to_dtype(dtype or dtypes.default_float)): raise ValueError(f"rand only supports float dtypes, got {dtype}")
     if not all_int(shape:=argfix(*shape)) or not all(s >= 0 for s in shape): raise ValueError(f"invalid input {shape=}")
     if device is not None and not isinstance(device, str): raise ValueError(f"rand only supports single device, got {device=}")
-    device, had_counter = Device.canonicalize(device), False
+    _device = device = Device.canonicalize(device)
 
     # when using MOCKGPU and NV generate rand on CLANG
-    if THREEFRY and getenv("MOCKGPU") and device.startswith("NV"): _device, device = device, "CLANG"
-    else: _device = None
+    if getenv("MOCKGPU") and device.startswith("NV"): device = "CLANG"
 
     # generate per device seeds and rng counter if we haven't seen this device yet
     if device not in Tensor._device_seeds:
-      Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(device.encode()).digest(), "big") & 0xffffffff
+      Tensor._device_seeds[device] = int.from_bytes(hashlib.sha256(len(Tensor._device_seeds).to_bytes(4, "big")).digest(), "big") & 0xffffffff
       Tensor._device_rng_counters[device] = Tensor([0], device=device, dtype=dtypes.uint32, requires_grad=False)
+      had_counter = False
     else: had_counter = True
-    if not THREEFRY:
-      # for bfloat16, numpy rand passes buffer in float
-      if to_dtype(dtype or dtypes.default_float) == dtypes.bfloat16:
-        return Tensor.rand(*shape, **kwargs, device=device, dtype=dtypes.float).cast(dtypes.bfloat16)
-      return Tensor._metaop(MetaOps.CUSTOM, shape, arg=custom_random, device=device, dtype=dtype, **kwargs)
-
     # if shape has 0, return zero tensor
-    if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=device, dtype=dtype, **kwargs)
+    if (num := math.ceil(((num_ := prod(shape)) * dtype.itemsize) / 4)) == 0: return Tensor.zeros(shape, device=_device, dtype=dtype, **kwargs)
     # increment rng counter for devices
     if had_counter: Tensor._device_rng_counters[device].assign(Tensor._device_rng_counters[device] + num)
@@ -3462,14 +3456,6 @@ if IMAGE:
   setattr(Tensor, "conv2d", Tensor.image_conv2d)
   setattr(Tensor, "dot", Tensor.image_dot)
 
-# TODO: eventually remove this
-def custom_random(out:Buffer):
-  Tensor._seed += 1
-  rng = np.random.default_rng(Tensor._seed)
-  if out.dtype == dtypes.half: rng_np_buffer = (rng.integers(low=0, high=2047, size=out.size) / 2048).astype(np.half, copy=False)
-  else: rng_np_buffer = rng.random(size=out.size, dtype=np.float32).astype(dtype=_to_np_dtype(out.dtype), copy=False)
-  out.copyin(rng_np_buffer.data)
-
 def _metadata_wrapper(fn):
   def _wrapper(*args, **kwargs):
     if _METADATA.get() is not None: return fn(*args, **kwargs)
diff --git a/viz/test_viz.py b/viz/test_viz.py
index e14cd68699..0e5bcdfc14 100644
--- a/viz/test_viz.py
+++ b/viz/test_viz.py
@@ -42,8 +42,8 @@ class TestViz(unittest.TestCase):
 
   def test_ctx_groups(self):
     contexts.clear()
-    schedule1 = Tensor.randn(4, 1).contiguous().schedule()
-    schedule2 = Tensor.randn(4, 4).contiguous().schedule()
+    schedule1 = Tensor.zeros(4, 1).contiguous().exp().schedule()
+    schedule2 = Tensor.zeros(4, 1).contiguous().exp().schedule()
     list(lower_schedule(schedule1))
     list(lower_schedule(schedule2))
     ret = load_kernels(contexts)
@@ -118,8 +118,8 @@ class TestViz(unittest.TestCase):
 
   def test_dedup_ast(self):
     contexts.clear()
-    a = Tensor.randn(4, 4)+2
-    b = Tensor.randn(4, 4)+2
+    a = Tensor.empty(4, 4).contiguous().realize()+2
+    b = Tensor.empty(4, 4).contiguous().realize()+2
     Tensor.schedule(a, b)
     kernels = load_kernels(contexts)
     self.assertEqual(len(kernels), 1)