tinygrad/test/test_arange.py
Ahmed Harmouche · 10618aba98 · Bring back WebGPU (#7063)
* Start from andredaprato:webgpu-clean

* Fix infs

* inf wgsl function is not needed

* Emulated ulong for threefry, more tests passing

* Randomness tests passing

* Update model export to support new changes in webgpu, efficientnet export works again

* Simplify shift emulation in wgsl

* Delete test file

* Fix bigger than u32 u32 literal

* Why was skip copies added here?

* Python3.12 for webgpu tests

* Fix model export syntax error

* Get test ops passing with some skips

* Fix lint

* Much simpler shift

* Run more tests

* Timestamp queries are not supported in CI, so skip search tests

* All fancy indexing passing

* r is ctx

* Run more dtype tests by using is_dtype_supported

* Cleanup ulong shift rendering

* UPat -> Pat, UOps -> Ops

* Pat -> UPat

* Refactor render_ushift if-else

* Pattern to avoid ulong mul

* Remove vals_dtype

* is_nan trick + rewrite, test_isnan passing

* Rewrite a * select(1, nan, gate) -> select(a, nan, gate)

* No arg, just op

* Support char, uchar, short, ushort

* Run test_index_mnist now that we have uint8

* Fix pylint

* Save 3 lines by using base Compiler

* No more long emulation

* Remove fixup_binops

* No more external_local_bufx wgsl specific cstyle modif, use base extra_pm

* Simpler, faster copyin/out

* Skip some new tests that use long

* Fix typo

* copyout touchup

* Save lines by using render_cast

* WebGL is not supported in core, delete it from is_dtype_supported

* More narrow test skips for some unary tests

* TernaryOps, UnaryOps -> Ops

* TinyGrad supports WebGPU

* StableDiffusion demo: f16tof32 gpu is a lib, update UI

* Packed load/store, no more scale_size, no core tinygrad changes

* Rename copyin, copyout

* Device -> dev

* Fix lint

* Pattern matcher rule for packed load/store

* Refactor

* Shorter packed load/store

* this should fix lint

* Fix mypy

* SD compile script working

* New SD webgpu UI

* New default prompt

* New SD weights

* Fix title when webgpu not available

* Run symbolic tests, simplify is_nan, use round_up

* Show step time on UI

* Bump minimum wgpu version to v0.19

* Fix latent

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
2024-11-26 12:26:40 +08:00

181 lines · 8.2 KiB · Python
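
The test file follows. For reference, a sketch of a typical local invocation (assuming tinygrad is installed in the environment, e.g. via pip install -e ., and run from the repo root; the backend is selected with an environment variable, such as WEBGPU=1 for the backend this commit restores):

  WEBGPU=1 python3 test/test_arange.py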

import unittest, contextlib
import numpy as np
from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device
from tinygrad.helpers import CI, Context, getenv
from tinygrad.engine.realize import run_schedule
from tinygrad.codegen.kernel import Opt, OptOps, Kernel, KernelOptError
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.engine.search import get_kernel_actions
from tinygrad.ops import Ops
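
# TestArange checks that Tensor.arange(N) lowers to a cheap kernel: the op estimate
# must not blow up when N grows 10x, and with upcast/unroll opts applied it should
# need essentially no flops at all.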
class TestArange(unittest.TestCase):
  def _get_flops(self, N, opts=None):
    GlobalCounters.reset()
    tt = Tensor.arange(N)
    sched = tt.schedule()
    self.assertEqual(len(sched), 1)
    k = Kernel(sched[-1].ast)
    if opts is not None:
      for o in opts: k.apply_opt(o)
    p = k.to_program()
    print(p.name)
    #print(p.src)
    ExecItem(CompiledRunner(p), [tt.lazydata.buffer]).run()
    np.testing.assert_equal(tt.numpy(), np.arange(N))
    return p.op_estimate
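
  # complexity check: 10x more elements (256 -> 2560) must not blow up the flop
  # estimate; with the opts below the kernel should need essentially zero flops.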
  def test_complexity(self, opts=None, limit=None):
    # add 1 to avoid divide by 0. arange is 0 flops now!
    f1 = self._get_flops(256, opts) + 1
    f2 = self._get_flops(2560, opts) + 1
    print(f"{f1=}, {f2=}")
    assert (f1 < 5000 and f2 < 5000) or (f2 / f1 < 15), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
    if limit is not None and not getenv("PTX"):
      # PTX counts index ALU in flops
      assert f1 <= limit, f"{f1=}, {limit=}"

  def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1)
  def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1)
  def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1)
  def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1)
  def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)

  @unittest.skip("doesn't work yet")
  def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, amt=32)])
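
  # sweep: take every kernel action get_kernel_actions proposes for both sizes and
  # re-run the complexity check under each, minus a few known-bad exclusions.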
  def test_all_opts(self, opts=None, exclude=None):
    k = Kernel(Tensor.arange(256).schedule()[-1].ast)
    if opts is not None:
      for o in opts: k.apply_opt(o)
    all_opts_256 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()]
    k = Kernel(Tensor.arange(2560).schedule()[-1].ast)
    if opts is not None:
      for o in opts: k.apply_opt(o)
    all_opts_2560 = [kk.applied_opts for kk in get_kernel_actions(k, include_0=False).values()]
    all_opts = [x for x in all_opts_256 if x in all_opts_2560]
    for opts in all_opts:
      if exclude is not None and opts[-1] in exclude: continue
      print(opts)
      self.test_complexity(opts)

  def test_all_opts_w_local(self):
    with contextlib.suppress(KernelOptError):
      return self.test_all_opts([Opt(OptOps.LOCAL, 0, 16)], [Opt(op=OptOps.PADTO, axis=1, amt=32)])
  def test_all_opts_w_upcast(self): return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4)])
  def test_all_opts_w_unroll(self): return self.test_all_opts([Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
  def test_all_opts_w_upcast_and_unroll(self):
    return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
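
# TestIndexing covers fancy indexing (dataset[idxs]) and embedding lookups: with
# FUSE_ARANGE=1 the arange behind the gather should fuse into the indexing kernel,
# keeping kernel counts and global op counts small.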
class TestIndexing(unittest.TestCase):
  def test_arange_2_reduce(self):
    needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
    needle[1337] = 1
    needle.realize()
    with Context(NOOPT=1, FUSE_ARANGE=1):
      GlobalCounters.reset()
      # TODO: it should work without these reshapes
      out = ((Tensor.arange(1,16385).reshape(16384,1)-1)*needle.reshape(16384,1)).sum()
      sched = out.schedule()
      assert len(sched) == 1
      run_schedule(sched)
    assert out.item() == 1337, f"expected 1337, got {out.item()}"
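
  # hand-built gather: a cumulative-sum "range", an equality compare against the
  # indices, and a where+sum reduce, expected to schedule as a single kernel.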
  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  def test_manual_index(self):
    dataset = Tensor.rand(16384, 256).realize()
    idxs = Tensor([0,3,5,6]).realize()
    real_index = dataset.numpy()[idxs.numpy()]
    print("*** indexing ***")
    with Context(NOOPT=1, FUSE_ARANGE=1):
      GlobalCounters.reset()
      rng = Tensor.ones(4, 256, 16384, dtype=dtypes.int)._cumalu(axis=-1, op=Ops.ADD, _include_initial=True).reshape(4, 256, 16384, 1)
      idxs = idxs.reshape(4,1,1,1).expand(4, 256, 16384, 1)
      reshape_dataset = dataset.T.reshape(1, 256, 16384, 1).expand(4, 256, 16384, 1)
      full = (rng==idxs).where(reshape_dataset, Tensor.zeros(4, 256, 16384, 1))
      X = full.sum(axis=(2,3))
      sched = X.schedule()
      assert len(sched) == 1
      run_schedule(sched)
      assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
    np.testing.assert_allclose(real_index, X.numpy())
  def test_index(self):
    dataset = Tensor.rand(16384, 256).realize()
    idxs = Tensor([0,3,5,6]).realize()
    real_index = dataset.numpy()[idxs.numpy()]
    print("*** indexing ***")
    with Context(NOOPT=1):
      GlobalCounters.reset()
      X = dataset[idxs]
      assert X.shape == (4,256)
      sched = X.schedule()
      # TODO: enable these asserts when the scheduler can handle this
      #assert len(sched) == 1, f"{len(sched)} != 1"
      run_schedule(sched)
      #assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops}"
    np.testing.assert_allclose(real_index, X.numpy())
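
  # same gather as test_index, but with FUSE_ARANGE=1 it must schedule as exactly
  # two kernels and stay under 4*16384 global ops.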
  def test_index_fused(self, noopt=1):
    dataset = Tensor.rand(16384, 256).realize()
    idxs = Tensor([0,3,5,6]).realize()
    real_index = dataset.numpy()[idxs.numpy()]
    print("*** indexing ***")
    with Context(NOOPT=noopt, FUSE_ARANGE=1):
      GlobalCounters.reset()
      X = dataset[idxs]
      assert X.shape == (4,256)
      sched = X.schedule()
      assert len(sched) == 2
      run_schedule(sched)
      assert GlobalCounters.global_ops < 4*16384, f"too many ops {GlobalCounters.global_ops} != {4*16384}"
    np.testing.assert_allclose(real_index, X.numpy())

  @unittest.skip("not ready")
  def test_index_fused_opt(self): self.test_index_fused(0)
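
  # indices far outside the dataset must gate to 0 instead of reading out of bounds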
  def test_index_fused_out_of_bounds(self):
    dataset = Tensor.rand(256, 256).realize()
    idxs = Tensor([-19238, -257, 256, 495, 10982377]).realize()
    with Context(NOOPT=1, FUSE_ARANGE=1):
      X = dataset[idxs]
      np.testing.assert_equal(X.numpy(), 0)
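
  # gather a random batch from the MNIST training set; the fused indexing must stay
  # under op_limit ops and match a plain numpy gather.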
  @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
  def test_index_mnist(self, noopt=1, op_limit=512*784*10):
    from tinygrad.nn.datasets import mnist
    X_train, Y_train, _, _ = mnist()
    with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
      samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0]).realize()
      GlobalCounters.reset()
      x = X_train[samples].numpy()
      y = Y_train[samples].numpy()
      assert GlobalCounters.global_ops < op_limit, f"too many ops {GlobalCounters.global_ops} != {op_limit}"
    np.testing.assert_allclose(X_train.numpy()[samples.numpy()], x)
    np.testing.assert_allclose(Y_train.numpy()[samples.numpy()], y)

  @unittest.skip("not ready")
  def test_index_mnist_opt(self): self.test_index_mnist(0)
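
  # nn.Embedding forward with FUSE_ARANGE: at most op_limit ops and exactly two
  # kernels, numerically matching a torch.nn.Embedding with the same weights.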
  @unittest.skipIf(getenv("PTX") or Device.DEFAULT == "WEBGPU", "broken on ptx and WebGPU for some reason")
  def test_llama_embedding(self, noopt=1, op_limit=65536):
    # llama3 is 128256
    vocab_size, embed_size = (10, 3) if CI else (32000, 4096)
    emb = nn.Embedding(vocab_size, embed_size)
    emb_w = emb.weight.numpy()
    x = Tensor([1,2,3,4])
    with Context(NOOPT=noopt, FUSE_ARANGE=1):
      GlobalCounters.reset()
      z = emb(x).realize()
      self.assertLessEqual(GlobalCounters.global_ops, op_limit)
      self.assertEqual(GlobalCounters.kernel_count, 2)
    if getenv("CHECK", 1):
      import torch
      with torch.no_grad():
        torch_emb = torch.nn.Embedding(vocab_size, embed_size).eval()
        torch_emb.weight[:] = torch.tensor(emb_w, dtype=torch.float32)
      torch_z = torch_emb(torch.tensor(x.numpy()))
      # TODO: reshape to match torch, should we do this in nn?
      np.testing.assert_allclose(z.numpy().reshape(4, embed_size), torch_z.detach().numpy(), atol=1e-8, rtol=1e-8)

  # at least the arange is being fused
  def test_llama_embedding_opt(self): self.test_llama_embedding(0, 1_736_704_000 if CI else 5_898_240_000)
if __name__ == "__main__":
  unittest.main()