Bring back WebGPU (#7063)

* Start from andredaprato:webgpu-clean

* Fix infs

* inf wgsl function is not needed

* Emulated ulong for threefry, more tests passing (see sketch below)

* Randomness tests passing

* Update model export to support new changes in webgpu, efficientnet export works again

* Simplify shift emulation in wgsl

* Delete test file

* Fix bigger-than-u32 u32 literal

* Why was skip copies added here?

* Python3.12 for webgpu tests

* Fix model export syntax error

* Get test ops passing with some skips

* Fix lint

* Much simpler shift (see sketch below)

* Run more tests

* Timestamp queries are not supported in CI, so skip search tests

* All fancy indexing passing

* r is ctx

* Run more dtype tests by using is_dtype_supported

* Cleanup ulong shift rendering

* UPat -> Pat, UOps -> Ops

* Pat -> UPat

* Refactor render_ushift if-else

* Pattern to avoid ulong mul

* Remove vals_dtype

* is_nan trick + rewrite, test_isnan passing (see sketch below)

* Rewrite a * select(1, nan, gate) -> select(a, nan, gate) (see sketch below)

* No arg, just op

* Support char, uchar, short, ushort

* Run test_index_mnist now that we have uint8

* Fix pylint

* Save 3 lines by using base Compiler

* No more long emulation

* Remove fixup_binops

* No more WGSL-specific external_local_bufx cstyle modification, use base extra_pm

* Simpler, faster copyin/out

* Skip some new tests that use long

* Fix typo

* copyout touchup

* Save lines by using render_cast

* WebGL is not supported in core, delete it from is_dtype_supported

* More narrow test skips for some unary tests

* TernaryOps, UnaryOps -> Ops

* TinyGrad supports WebGPU

* StableDiffusion demo: f16tof32 gpu is a lib, update UI

* Packed load/store, no more scale_size, no core tinygrad changes (see sketch below)

* Rename copyin, copyout

* Device -> dev

* Fix lint

* Pattern matcher rule for packed load/store

* Refactor

* Shorter packed load/store

* This should fix lint

* Fix mypy

* SD compile script working

* New SD webgpu UI

* New default prompt

* New SD weights

* Fix title when webgpu not available

* Run symbolic tests, simplify is_nan, use round_up

* Show step time on UI

* Bump minimum wgpu version to v0.19

* Fix latent
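
Sketch for the ulong emulation bullet (illustrative Python, not this commit's actual helpers): threefry needs 64-bit adds, xors, and rotates, which a 32-bit-only backend can emulate on (hi, lo) pairs of u32 — e.g. addition with an explicit carry. The emulation was later dropped, per "No more long emulation" above.

    MASK32 = 0xFFFFFFFF

    def add64(a, b):
      # add the low words first, then propagate the carry into the high words
      (ahi, alo), (bhi, blo) = a, b
      lo = (alo + blo) & MASK32
      carry = 1 if lo < alo else 0
      return ((ahi + bhi + carry) & MASK32, lo)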
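Sketch for the shift bullets: WGSL only defines shifts for amounts below the bit width, so shifting a u32 by 32 or more has to be handled explicitly. A hedged guess at the shape of a render_ushift-style helper (the commit's real signature and output may differ):

    def render_ushift(a: str, amt: str, left: bool) -> str:
      # mask the amount into WGSL's legal range, then force the result to 0
      # when the original amount was >= 32, keeping C-style shift semantics
      op = "<<" if left else ">>"
      return f"select(({a} {op} ({amt} & 31u)), 0u, {amt} >= 32u)"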
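Sketch for the is_nan trick: under fast math, x != x can be folded to false, so NaN is detected from its IEEE-754 bit pattern instead (exponent bits all set, mantissa nonzero). The idea in Python terms; the commit's WGSL expression may differ:

    import struct

    def is_nan_f32(x: float) -> bool:
      # reinterpret the float as its raw 32-bit pattern
      bits = struct.unpack("<I", struct.pack("<f", x))[0]
      return (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF) != 0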
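The select rewrite rests on a simple identity: multiplying by 1 is a no-op and anything times NaN is NaN, so the multiply folds into the select. In plain Python (not the actual UPat rule):

    import math

    def before(a: float, gate: bool) -> float: return a * (math.nan if gate else 1.0)
    def after(a: float, gate: bool) -> float: return math.nan if gate else a

    assert before(3.0, False) == after(3.0, False) == 3.0
    assert math.isnan(before(3.0, True)) and math.isnan(after(3.0, True))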
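Sketch for packed load/store: WGSL storage buffers are not byte-addressable, so four u8 values share one u32 word and are extracted or inserted with shifts and masks. The addressing math, with hypothetical helpers (the commit implements this as pattern-matcher rules):

    def load_u8(words, i):
      # byte i lives in word i//4, at bit offset (i%4)*8
      return (words[i // 4] >> ((i % 4) * 8)) & 0xFF

    def store_u8(words, i, val):
      shift = (i % 4) * 8
      # clear the target byte, then OR in the new value
      words[i // 4] = (words[i // 4] & ~(0xFF << shift) & 0xFFFFFFFF) | ((val & 0xFF) << shift)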

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Ahmed Harmouche
Date: 2024-11-26 05:26:40 +01:00
Committed by: GitHub
Parent: ff3f2a9c1a
Commit: 10618aba98

18 changed files with 659 additions and 402 deletions

test/test_arange.py

@@ -1,6 +1,6 @@
 import unittest, contextlib
 import numpy as np
-from tinygrad import Tensor, GlobalCounters, dtypes, nn
+from tinygrad import Tensor, GlobalCounters, dtypes, nn, Device
 from tinygrad.helpers import CI, Context, getenv
 from tinygrad.engine.realize import run_schedule
 from tinygrad.codegen.kernel import Opt, OptOps, Kernel, KernelOptError
@@ -139,7 +139,7 @@ class TestIndexing(unittest.TestCase):
     np.testing.assert_equal(X.numpy(), 0)
   @unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
-  def test_index_mnist(self, noopt=1, op_limit=512*784*5):
+  def test_index_mnist(self, noopt=1, op_limit=512*784*10):
     from tinygrad.nn.datasets import mnist
     X_train, Y_train, _, _ = mnist()
     with Context(NOOPT=noopt, FUSE_ARANGE=1, SPLIT_REDUCEOP=0):
@@ -153,7 +153,7 @@ class TestIndexing(unittest.TestCase):
@unittest.skip("not ready")
def test_index_mnist_opt(self): self.test_index_mnist(0)
@unittest.skipIf(getenv("PTX"), "broken on ptx for some reason")
@unittest.skipIf(getenv("PTX") or Device.DEFAULT == "WEBGPU", "broken on ptx and WebGPU for some reason")
def test_llama_embedding(self, noopt=1, op_limit=65536):
# llama3 is 128256
vocab_size, embed_size = (10, 3) if CI else (32000, 4096)

test/test_dtype.py

@@ -35,11 +35,11 @@ def _test_to_np(a:Tensor, np_dtype, target):
   except AssertionError as e:
     raise AssertionError(f"\ntensor {a.numpy()} does not match target {target} with np_dtype {np_dtype}") from e
-def _assert_eq(tensor:Tensor, target_dtype:DType, target):
+def _assert_eq(tensor:Tensor, target_dtype:DType, target, tol_target_dtype:float=1e-7):
   if DEBUG >= 2: print(tensor.numpy())
   try:
     assert tensor.dtype == target_dtype
-    np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, 1e-7))
+    np.testing.assert_allclose(tensor.numpy(), target, rtol={dtypes.float16:1e-3, dtypes.bfloat16:1e-2}.get(target_dtype, tol_target_dtype))
   except AssertionError as e:
     raise AssertionError(f"\ntensor {tensor.numpy()} dtype {tensor.dtype} does not match target {target} with dtype {target_dtype}") from e
@@ -541,7 +541,7 @@ class TestTypeSpec(unittest.TestCase):
     _assert_eq(Tensor.arange(5, dtype=dtypes.int64), dtypes.int64, np.arange(5))
     if is_dtype_supported(dtypes.float16):
       _assert_eq(Tensor.arange(5, dtype=dtypes.float16), dtypes.float16, np.arange(5))
-    _assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7))
+    _assert_eq(Tensor.arange(3, 9, 0.7), dtypes.default_float, np.arange(3, 9, 0.7), 1e-6 if Device.DEFAULT == "WEBGPU" else 1e-7)
     _assert_eq(Tensor.arange(3, 8.5, 3), dtypes.default_float, np.arange(3, 8.5, 3))
     # stop-start and step have different signs
     _assert_eq(Tensor.arange(3, 5, -2), dtypes.default_int, np.arange(3, 5, -2))

test/test_dtype_alu.py

@@ -11,7 +11,7 @@ from tinygrad.engine.realize import run_schedule
 from tinygrad.ops import GroupOp
 from tinygrad.tensor import _to_np_dtype
 from tinygrad.device import is_dtype_supported
-import pytest
+import pytest, math
 pytestmark = pytest.mark.filterwarnings("ignore")
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
@@ -41,8 +41,8 @@ unary_operations = [(Tensor.exp, np.exp), (Tensor.log, np.log), (Tensor.sin, np.
 # TODO: (a+b)/2 in tensor.py's maximum can overflow. This requires a new implementation of maximum that can be backpropagated
 #binary_operations += [(Tensor.maximum, np.maximum)]
-# TODO: CI CUDA segfaults on sin
-if getenv("MOCKGPU") and Device.DEFAULT == "NV": unary_operations.remove((Tensor.sin, np.sin))
+# TODO: CI CUDA segfaults on sin, WEBGPU sin is not precise enough for large numbers
+if (getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU": unary_operations.remove((Tensor.sin, np.sin))
 class ht:
   float64 = strat.floats(width=64, allow_subnormal=False)
@@ -88,6 +88,8 @@ def universal_test_cast(a, in_dtype, dtype):
   np.testing.assert_equal(tensor_value.numpy(), numpy_value)
 def universal_test_midcast(a, b, c, op1, op2, d1:DType, d2:DType):
+  # the 'inf' and 'nan' cases are wrong on WEBGPU
+  if (c in [math.inf, -math.inf] or math.isnan(c)) and Device.DEFAULT == "WEBGPU": return
   if not isinstance(op1, tuple): op1 = (op1, op1)
   if not isinstance(op2, tuple): op2 = (op2, op2)
   at, bt, ct = Tensor([a], dtype=d1), Tensor([b], dtype=d1), Tensor([c], dtype=d2)
@@ -148,6 +150,7 @@ class TestDTypeALU(unittest.TestCase):
   @given(ht.int32, ht.int32, strat.sampled_from(integer_binary_operations))
   def test_int32(self, a, b, op): universal_test(a, b, dtypes.int32, op)
+  @unittest.skipUnless(is_dtype_supported(dtypes.int64, Device.DEFAULT), f"no int64 on {Device.DEFAULT}")
   @given(ht.int64, ht.int64, strat.sampled_from(integer_binary_operations))
   def test_int64(self, a, b, op): universal_test(a, b, dtypes.int64, op)

test/test_ops.py

@@ -312,7 +312,8 @@ class TestOps(unittest.TestCase):
   def _test_cmp(self, fxn, reverse=True):
     # test different dtypes
     helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0.,1,2], [2.,1,0]])
-    helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
+    if is_dtype_supported(dtypes.long):
+      helper_test_op(None, fxn, fxn, forward_only=True, vals=[[0,1,2], [2,1,0]])
     helper_test_op(None, fxn, fxn, forward_only=True, vals=[[True, True, False], [False,True,False]])
     # test broadcasting
     for shps in [[(3, 4, 5), (3, 4, 5)], [(3, 4, 5), (5,)], [(5,), (3, 4, 5)]]:
@@ -382,6 +383,7 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, torch.isinf, Tensor.isinf, vals=[val], forward_only=True)
     np.testing.assert_equal(Tensor(val).isinf(detect_positive=True, detect_negative=False).numpy(), [False, False, True, False, False])
     np.testing.assert_equal(Tensor(val).isinf(detect_positive=False, detect_negative=True).numpy(), [True, False, False, False, False])
   def test_isnan(self):
     helper_test_op(None, torch.isnan, Tensor.isnan, vals=[[float('-inf'), 0., float('inf'), float('nan'), 1.1]], forward_only=True)
@@ -499,8 +501,9 @@ class TestOps(unittest.TestCase):
     helper_test_op(None, lambda x: x//2, forward_only=True, vals=np.array([[3, 4, 5]], dtype=np.int32))
     torch_idiv, tiny_idiv = functools.partial(torch.div, rounding_mode="trunc"), Tensor.idiv
     helper_test_op(None, torch_idiv, tiny_idiv, forward_only=True, vals=np.array([[5, -6, 7],[1, 2, 3]], dtype=np.int32))
-    x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
-    np.testing.assert_equal(x.numpy(), 2**64 - 1)
+    if is_dtype_supported(dtypes.uint64):
+      x = Tensor(2**64 - 1, dtype=dtypes.uint64).idiv(1)
+      np.testing.assert_equal(x.numpy(), 2**64 - 1)
   def test_scalar_div(self):
     helper_test_op([(45,65)], lambda x: x/255)
     helper_test_op([(45,65)], lambda x: x/1)
@@ -525,6 +528,7 @@ class TestOps(unittest.TestCase):
   def test_pow_full(self):
     helper_test_op([(45,65), (45,65)], lambda x,y: x**y)
     helper_test_op([(45,65), (45,65)], lambda x,y: x.pow(y))
   def test_pow(self):
     helper_test_op([(45,65)], lambda x: x**0)
     helper_test_op([(45,65)], lambda x: x**1)
@@ -644,14 +648,14 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: x.sin())
     helper_test_op([()], lambda x: x.sin())
     # works on real CUDA but not CI
-    if not (getenv("MOCKGPU") and Device.DEFAULT == "NV"):
+    if not ((getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"):
       helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
       helper_test_op(None, lambda x: x.sin(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
                      atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
   def test_cos(self):
     helper_test_op([(45,65)], lambda x: x.cos())
     helper_test_op([()], lambda x: x.cos())
-    if not (getenv("MOCKGPU") and Device.DEFAULT == "NV"):
+    if not ((getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"):
       helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
       helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
                      atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
@@ -660,7 +664,7 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5)
     helper_test_op([(45,65)], lambda x: x.tan(), low=-5, high=5, forward_only=True)
     helper_test_op([()], lambda x: x.tan())
-    if not (getenv("MOCKGPU") and Device.DEFAULT == "NV"):
+    if not ((getenv("MOCKGPU") and Device.DEFAULT == "NV") or Device.DEFAULT == "WEBGPU"):
       helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
       helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
                      atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
@@ -994,7 +998,8 @@ class TestOps(unittest.TestCase):
                                                                          np.arange(64,128,dtype=np.float32).reshape(8,8)])
   def test_small_gemm_eye(self):
     helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
-  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE, "not supported on these in CI/IMAGE")
+  @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE \
+                   or Device.DEFAULT == "WEBGPU", "not supported on these in CI/IMAGE")
   def test_gemm_fp16(self):
     helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
   def test_gemm(self):
@@ -1076,8 +1081,9 @@ class TestOps(unittest.TestCase):
     helper_test_op([(45,3)], lambda x: x.min().mul(0.5))
     helper_test_op([()], lambda x: x.min())
-    helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
-    helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
+    if is_dtype_supported(dtypes.long):
+      helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[0, -2**31]])
+      helper_test_op(None, lambda x: x.type(torch.int32).min(), lambda x: x.cast(dtypes.int32).min(), forward_only=True, vals=[[-2**31, 0]])
     helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.type(torch.bool).min(), lambda x: x.cast(dtypes.bool).min(), forward_only=True, vals=[[True, False]])
@@ -1088,8 +1094,9 @@ class TestOps(unittest.TestCase):
     helper_test_op([(3,4,5,6)], lambda x: x.max(axis=1)[0], lambda x: x.max(axis=1))
     helper_test_op([()], lambda x: x.max())
-    helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
-    helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
+    if is_dtype_supported(dtypes.long):
+      helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[0, -2**31]])
+      helper_test_op(None, lambda x: x.type(torch.int32).max(), lambda x: x.cast(dtypes.int32).max(), forward_only=True, vals=[[-2**31, 0]])
     helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[False, True]])
     helper_test_op(None, lambda x: x.type(torch.bool).max(), lambda x: x.cast(dtypes.bool).max(), forward_only=True, vals=[[True, False]])
@@ -2343,16 +2350,19 @@ class TestOps(unittest.TestCase):
     helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.cross_entropy(x, y, label_smoothing=ls),
                    lambda x,y: x.cross_entropy(y, label_smoothing=ls))
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss(self):
     helper_test_op([(32,10), (32)],
       lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
       lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_3d(self):
     helper_test_op([(32,10,3,3,3), (32,3,3,3)],
       lambda x,y: torch.nn.functional.nll_loss(torch.nn.functional.log_softmax(x, dim=1), torch.clip(y,0).type(torch.long)),
       lambda x,y: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long)), forward_only=True)
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_reductions(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10), (32)],
@@ -2362,6 +2372,7 @@ class TestOps(unittest.TestCase):
       lambda x,y: torch.nn.functional.nll_loss(x, torch.clip(y,0).type(torch.long), reduction="typo"),
       lambda x,y: x.nll_loss(y.clip(0).cast(dtypes.long), reduction="typo"), expected=ValueError)
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_weight(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10), (32), (10)],
@@ -2369,6 +2380,7 @@ class TestOps(unittest.TestCase):
                      weight=z, reduction=r),
       lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_3d_weight(self):
     for r in ("mean", "sum", "none"):
       helper_test_op([(32,10,3,3,3), (32,3,3,3), (10)],
@@ -2376,6 +2388,7 @@ class TestOps(unittest.TestCase):
                      weight=z, reduction=r),
       lambda x,y,z: x.log_softmax(axis=1).nll_loss(y.clip(0).cast(dtypes.long), weight=z, reduction=r), forward_only=True)
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_nll_loss_ignore_index(self):
     logits = [[2.0, 0.5, -1.0],
               [1.5, 2.5, -0.5],
@@ -2405,7 +2418,8 @@ class TestOps(unittest.TestCase):
   @unittest.skipIf(Device.DEFAULT == "QCOM", "OpenCL fails to compile this (both on GPU(qcom)/QCOM backends)")
   def test_cast(self):
     helper_test_op([(3, 3)], lambda x: x.float())
-    helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
+    if is_dtype_supported(dtypes.long):
+      helper_test_op(None, lambda x: x.float(), vals=[[0, 1, 2, 3]], forward_only=True)
     helper_test_op(None, lambda x: x.float(), vals=[[True, False]], forward_only=True)
     helper_test_op([(3, 3)], lambda x: x.int(), forward_only=True)
     helper_test_op([(3, 3)], lambda x: x.bool(), forward_only=True)
@@ -2440,6 +2454,7 @@ class TestOpsUint8(unittest.TestCase):
       lambda x: torch.nn.functional.interpolate((10*x).relu().type(torch.uint8), size=out_sz, mode="nearest-exact"),
       lambda x: Tensor.interpolate((10*x).relu().cast('uint8'), size=out_sz, mode="nearest-exact"), forward_only=True)
+  @unittest.skipUnless(is_dtype_supported(dtypes.long), f"no long on {Device.DEFAULT}")
   def test_min(self):
     helper_test_op(None,
       lambda x: x.type(torch.uint8).min(),

test/web/test_webgpu.js (new file, 65 lines)

@@ -0,0 +1,65 @@
const puppeteer = require("puppeteer");
const { spawn } = require("child_process");
// serve the current directory on port 8000 so the browser can load the demo page
const res = spawn("python", ["-m", "http.server", "8000"], { shell: true });

async function timeout(time) {
  return new Promise((resolve) => setTimeout(resolve, time));
}

// kill the http server and exit, non-zero if an error was passed
function cleanup(err) {
  console.log("cleaning up");
  res.kill();
  if (err != null) {
    console.error(err);
    process.exit(1);
  }
  process.exit(0);
}

// poll the selector's textContent (once a second, up to 30 tries) until it matches
async function waitForText(selector, text) {
  let n = 0;
  let ready = false;
  while (n < 30) {
    const res = await (await selector.getProperty("textContent")).jsonValue();
    console.log(`waiting for text ${text} got ${res}`);
    if (res == text) {
      ready = true;
      break;
    }
    await timeout(1000);
    n += 1;
  }
  return ready;
}

async function runTest() {
  const browser = await puppeteer.launch({
    headless: false,
    args: ["--enable-unsafe-webgpu"],
  });
  const page = await browser.newPage();
  page
    .on("console", (message) =>
      console.log(`message from console ${message.text()}`),
    )
    .on("pageerror", ({ message }) =>
      console.log(`error from page ${message}`),
    );
  const res = await page.goto("http://localhost:8000/examples/index.html");
  if (res.status() !== 200) throw new Error("Failed to load page");
  const textSelector = await page.waitForSelector("#result");
  const buttonSelector = await page.waitForSelector("input[type=button]");
  const ready = await waitForText(textSelector, "ready");
  if (!ready) throw new Error("Failed to load page");
  await buttonSelector.evaluate((e) => e.click());
  const done = await waitForText(textSelector, "hen");
  if (!done) throw new Error("failed to get hen");
  cleanup(null);
}

runTest().catch((err) => cleanup(err));