diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index edd3efac58..229ed788d4 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -55,7 +55,30 @@ jobs:
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py && PYTHONPATH=. python3 quickstart.py
- name: Run Pytest
run: python -m pytest -s -v -n=auto test/
+
+ testwebgpu:
+ name: WebGPU Tests
+ runs-on: macos-13
+ steps:
+ - name: Checkout Code
+ uses: actions/checkout@v3
+ - name: Set up Python 3.8
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+ - name: Install Dependencies
+ run: pip install -e '.[testing,webgpu]' --extra-index-url https://download.pytorch.org/whl/cpu
+ # - name: Set Env
+ # run: printf "WEBGPU=1\nWGPU_BACKEND_TYPE=D3D12\n" >> $GITHUB_ENV
+ - name: Run Pytest
+ run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -s -v -n=auto test/test_ops.py test/test_speed_v_torch.py test/test_nn.py test/test_jit.py test/test_randomness.py test/test_tensor.py test/test_assign.py test/test_conv.py test/test_custom_function.py test/test_conv_shapetracker.py
+ - name: Build WEBGPU Efficientnet
+ run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.webgpu.compile_webgpu
+ # - name: Install Puppeteer
+ # run: npm install puppeteer
+ # - name: Run Efficientnet
+ # run: node test/test_webgpu.js
testimagenet:
name: ImageNet to C Compile Test
runs-on: ubuntu-latest
diff --git a/.gitignore b/.gitignore
index 6bb74b33ce..2d03c2f41f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -31,3 +31,8 @@ extra/datasets/kits/
extra/datasets/COCO/
extra/datasets/audio*
venv
+examples/webgpu/net.js
+examples/webgpu/net.safetensors
+node_modules
+package.json
+package-lock.json
\ No newline at end of file
diff --git a/examples/compile_efficientnet.py b/examples/compile_efficientnet.py
index 5d8023fec8..4d7c0a7b84 100644
--- a/examples/compile_efficientnet.py
+++ b/examples/compile_efficientnet.py
@@ -1,15 +1,11 @@
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
+from tinygrad.jit import TinyJit
from extra.utils import fetch
import ast
def compile_net(run, special_names):
- # functions that run the net
- functions = {}
- bufs = {}
- bufnum = 0
- statements = []
- bufs_to_save = {}
+ functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for fxn,args in run.jit_cache:
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
@@ -17,26 +13,21 @@ def compile_net(run, special_names):
key = id(arg)
if key not in bufs:
if key in special_names:
- bufs[key] = (special_names[key], len(arg._buf))
+ bufs[key] = (special_names[key], arg._memsz, key)
else:
- bufs[key] = (f"buf_{bufnum}", len(arg._buf))
+ bufs[key] = (f"buf_{bufnum}", arg._memsz, key)
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
- statements.append(f"{fxn.name}({', '.join(cargs)});")
+ statements.append((fxn.name, cargs, fxn.global_size))
return functions, statements, bufs, bufs_to_save
-if __name__ == "__main__":
- model = EfficientNet(0)
- model.load_from_pretrained()
-
- from tinygrad.jit import TinyJit
+def jit_model(model, the_input):
@TinyJit
def run(x): return model.forward(x).realize()
# twice to run the JIT
- the_input = Tensor.randn(1,3,224,224)
the_output = run(the_input)
the_output = run(the_input)
@@ -47,7 +38,12 @@ if __name__ == "__main__":
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}
+ return run, special_names
+if __name__ == "__main__":
+ model = EfficientNet(0)
+ model.load_from_pretrained()
+ run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
# c header
@@ -68,13 +64,13 @@ if __name__ == "__main__":
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
# buffers (empty + weights)
- cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len in bufs.values()]
+ cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len,_key in bufs.values()]
# the functions
cprog += list(functions.values())
- # the net
- cprog += ["void net() {"] + statements + ["}"]
+ # the net
+ cprog += ["void net() {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size) in statements] + ["}"]
cprog += ["""
int main(int argc, char* argv[]) {
diff --git a/examples/webgpu/compile_webgpu.py b/examples/webgpu/compile_webgpu.py
new file mode 100644
index 0000000000..0fd00a5aa0
--- /dev/null
+++ b/examples/webgpu/compile_webgpu.py
@@ -0,0 +1,87 @@
+from os import path
+from examples.compile_efficientnet import compile_net, jit_model
+from models.efficientnet import EfficientNet
+from tinygrad.state import get_state_dict, safe_save
+from tinygrad.tensor import Tensor
+
+if __name__ == "__main__":
+ model = EfficientNet(0)
+ model.load_from_pretrained()
+ run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
+ functions, statements, bufs, _bufs_to_save = compile_net(run, special_names)
+
+ state = get_state_dict(model)
+ weights = {id(x.lazydata.realized): name for name, x in state.items()}
+ safe_save(state, path.join(path.dirname(__file__), "net.safetensors"))
+
+ kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
+ kernel_names = ', '.join([name for (name, _args, _global_size) in statements])
+ kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size) in enumerate(statements) ])
+ bufs = '\n '.join([f"const {buf[0]} = " + (f"createEmptyBuf(device, {buf[1]})" if buf[2] not in weights else f"createWeightBuf(device, {buf[1]}, getTensorBuffer(safetensor, metadata['{weights[buf[2]]}']))") + ";" for buf in bufs.values()])  # one JS declaration per buffer; the single trailing ";" terminates both branches
+
+ prg = f"""const getTensorMetadata = (safetensorBuffer) => {{
+ const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
+ const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
+ return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
+}};
+
+const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
+ return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
+}}
+
+const createEmptyBuf = (device, size) => {{
+ return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
+}};
+
+const createWeightBuf = (device, size, data) => {{
+ const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
+ new Uint8Array(buf.getMappedRange()).set(data);
+ buf.unmap();
+ return buf;
+}};
+
+const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
+ const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
+ const passEncoder = commandEncoder.beginComputePass();
+ passEncoder.setPipeline(pipeline);
+ passEncoder.setBindGroup(0, bindGroup);
+ passEncoder.dispatchWorkgroups(...workgroup);
+ passEncoder.end();
+}};
+
+{kernel_code}
+
+const setupNet = async (device, safetensor) => {{
+ const metadata = getTensorMetadata(safetensor);
+
+ {bufs}
+
+ const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
+ const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
+
+ const kernels = [{kernel_names}];
+ const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
+
+ return async (data) => {{
+ await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
+ new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
+ gpuWriteBuffer.unmap();
+
+ const commandEncoder = device.createCommandEncoder();
+ commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
+ {kernel_calls}
+ commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
+ const gpuCommands = commandEncoder.finish();
+ device.queue.submit([gpuCommands]);
+
+ await gpuReadBuffer.mapAsync(GPUMapMode.READ);
+ const resultBuffer = new Float32Array(gpuReadBuffer.size / Float32Array.BYTES_PER_ELEMENT); // GPUBuffer.size is in bytes, not elements
+ resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
+ gpuReadBuffer.unmap();
+ return resultBuffer;
+ }}
+}}
+"""
+
+ with open(path.join(path.dirname(__file__), "net.js"), "w") as text_file:
+ text_file.write(prg)
diff --git a/examples/webgpu/index.html b/examples/webgpu/index.html
new file mode 100644
index 0000000000..f878db0825
--- /dev/null
+++ b/examples/webgpu/index.html
@@ -0,0 +1,119 @@
+
+
+
+
+tinygrad has WebGPU
+
+
+
+WebGPU tinygrad EfficientNet!
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 14d3112eb2..c90ae5903b 100644
--- a/setup.py
+++ b/setup.py
@@ -25,6 +25,7 @@ setup(name='tinygrad',
'llvm': ["llvmlite"],
'cuda': ["pycuda"],
'triton': ["triton>=2.0.0.dev20221202"],
+ 'webgpu': ["wgpu"],
'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
'linting': [
"flake8",
diff --git a/test/models/test_train.py b/test/models/test_train.py
index c3280611be..6ea30ab0c8 100644
--- a/test/models/test_train.py
+++ b/test/models/test_train.py
@@ -46,6 +46,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
+ @unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu")
def test_vit(self):
model = ViT()
X = np.zeros((BS,3,224,224), dtype=np.float32)
diff --git a/test/test_dtype.py b/test/test_dtype.py
index 61fb6aa79d..e651523e9b 100644
--- a/test/test_dtype.py
+++ b/test/test_dtype.py
@@ -28,7 +28,7 @@ def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_o
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
-@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends")
+@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
class TestHalfDtype(unittest.TestCase):
def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
@@ -53,6 +53,7 @@ class TestHalfDtype(unittest.TestCase):
def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
+@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
diff --git a/test/test_nn.py b/test/test_nn.py
index ca7b5dda56..0c4f1bfc15 100755
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -95,7 +95,7 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
- @unittest.skipIf(getenv("CI", "") != "" and WINDOWS, "runs out of memory in CI")
+ @unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
def test_conv_transpose2d(self):
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1
diff --git a/test/test_ops.py b/test/test_ops.py
index 562ca3cdf7..250f4a6911 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -197,12 +197,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
helper_test_op([()], lambda x: x/2, lambda x: x/2)
helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
- @unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf")
+ @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf")
def test_mul_const_naninf(self):
helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf"))
helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf"))
helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan"))
- @unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf")
+ @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf")
def test_div_const_naninf(self):
helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf"))
helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf"))
@@ -726,7 +726,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
- @unittest.skipIf(IMAGE>0, "no conv1d on images")
+ @unittest.skipIf((IMAGE>0 or (Device.DEFAULT == "WEBGPU" and getenv("CI","") != "")), "no conv1d on images")
def test_conv1d(self):
for bs in [1,8]:
for cin in [1,3]:
diff --git a/test/test_specific_conv.py b/test/test_specific_conv.py
index 8e79095da2..1737a78d45 100644
--- a/test/test_specific_conv.py
+++ b/test/test_specific_conv.py
@@ -19,7 +19,7 @@ class TestSpecific(unittest.TestCase):
w = Tensor.randn(2048, 512)
(x @ w).reshape(1, 128, 4).contiguous().realize()
- @unittest.skipIf(Device.DEFAULT == "LLVM", "Broken on LLVM")
+ @unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU"], "Broken on LLVM and webgpu")
def test_big_vec_mul(self):
# from LLaMA
# 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]
diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py
index 3af3a3c66a..5d6e3c2ec0 100644
--- a/test/test_speed_v_torch.py
+++ b/test/test_speed_v_torch.py
@@ -131,7 +131,7 @@ class TestBigSpeed(unittest.TestCase):
def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)
def test_large_conv_3x3(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
-@unittest.skipIf(getenv("BIG") == 1, "only big tests")
+@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests")
class TestSpeed(unittest.TestCase):
def setUp(self):
global prefix
diff --git a/test/test_webgpu.js b/test/test_webgpu.js
new file mode 100644
index 0000000000..ce12330f66
--- /dev/null
+++ b/test/test_webgpu.js
@@ -0,0 +1,51 @@
+const puppeteer = require('puppeteer');
+const { spawn } = require('child_process');
+const res = spawn("python", ["-m", "http.server", "8000"]); // no shell wrapper, so res.kill() signals the server process itself
+
+async function timeout(time) {
+ return new Promise((resolve) => setTimeout(resolve, time));
+}
+
+function cleanup(err) {
+ res.kill();
+ if(err != null) {
+ console.error(err);
+ process.exit(1);
+ }
+}
+
+async function waitForText(selector, text) {
+ let n = 0;
+ let ready = false;
+ while (n < 10) {
+ const res = await (await selector.getProperty("textContent")).jsonValue();
+ console.log(`waiting for text ${text} got ${res}`);
+ if(res == text) {
+ ready = true;
+ break
+ }
+ await timeout(2000);
+ n += 1
+ }
+ return ready;
+}
+
+puppeteer.launch({ headless: false, args: ["--enable-unsafe-webgpu"]}).then(async browser => {
+ const page = await browser.newPage();
+ page.on("console", message => console.log(`message from console ${message.text()}`))
+ .on("pageerror", ({ message }) => console.log(`error from page ${message}`))
+
+ const res = await page.goto("http://localhost:8000/examples/webgpu/index.html");
+ if(res.status() != 200) throw new Error("Failed to load page");
+ const textSelector = await page.waitForSelector("#result");
+ const buttonSelector = await page.waitForSelector("input[type=button]");
+ const ready = await waitForText(textSelector, "ready");
+ if(!ready) throw new Error("Failed to load page");
+ await buttonSelector.evaluate(e => e.click());
+ const done = await waitForText(textSelector, "hen");
+ if(!done) throw new Error("failed to get hen");
+ await browser.close(); // wait for shutdown so a close failure reaches the .catch handler
+ cleanup(null);
+}).catch(err => {
+ cleanup(err);
+});
\ No newline at end of file
diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py
index d5e7c74ad0..b142b7e0a5 100644
--- a/test/unit/test_disk_tensor.py
+++ b/test/unit/test_disk_tensor.py
@@ -1,8 +1,8 @@
import pathlib
import unittest
import numpy as np
-from tinygrad.tensor import Tensor
-from tinygrad.state import safe_load, safe_save, torch_load, get_state_dict
+from tinygrad.tensor import Tensor, Device
+from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.helpers import Timing
@@ -46,7 +46,7 @@ class TestRawDiskBuffer(unittest.TestCase):
tst = np.empty(test_size, np.uint8)
with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
np.copyto(tst, db.toCPU())
-
+@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
class TestSafetensors(unittest.TestCase):
def test_real_safetensors(self):
import torch
diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py
index 0e5399c95c..309edbcfcd 100644
--- a/tinygrad/codegen/cstyle.py
+++ b/tinygrad/codegen/cstyle.py
@@ -1,8 +1,8 @@
-from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
+from typing import Final, Dict, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
-from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
-from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod
+from tinygrad.ops import ASTRunner, UnaryOps, BinaryOps, FusedOps
+from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod, DType
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
from tinygrad.lazy import LazyBuffer
@@ -13,6 +13,8 @@ render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})
render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
class CStyleLanguage(NamedTuple):
+ size_prefix: str = "int"
+ generic_var_prefix: str = ""
kernel_prefix: str = ""
buffer_prefix: str = ""
buffer_suffix: str = ""
@@ -24,9 +26,20 @@ class CStyleLanguage(NamedTuple):
float4: Optional[str] = None
half_prekernel: Optional[str] = None
uses_vload: bool = False
+ external_local_bufs: bool = False
+ code_for_op: Dict = {
+ UnaryOps.EXP2: lambda x: f"exp2({x})",
+ UnaryOps.LOG2: lambda x: f"log2({x})",
+ UnaryOps.SIN: lambda x: f"sin({x})",
+ UnaryOps.SQRT: lambda x: f"sqrt({x})",
+ BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
+ BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
+ BinaryOps.MAX: lambda a,b: f"max({a},{b})",
+ BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
+ }
# returns a str expression of the casted xs with the given type
- def render_cast(self, x:List[str], var_dtype) -> str:
+ def render_cast(self, x:List[str], var_dtype:DType) -> str:
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
@@ -50,9 +63,30 @@ class CStyleLanguage(NamedTuple):
if output_dtype.sz > 1:
return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
return f"{buf_name}[{idx.render(render_cl)}]"
+
+ def render_local(self, name:str, size:int):
+ return self.smem_prefix + f"float {name}[{size}];"
+
+ def render_for(self, expr: str, _min:int, _max:int) -> str:
+ return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
+
+ def render_conditional(self, cond: str, x:str, y:str) -> str:
+ return f"({cond})?({x}):{y}"
+
+ def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
+ tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(x.dtype, ImageDType) for x in bufs) else ""
+ buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
+ ("const " if i > 0 else "")+self.buffer_prefix+x.dtype.name+"*"+self.buffer_suffix) for i,x in enumerate(bufs)
+ if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
+ prg = ''.join([f"{self.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
+ [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + self.extra_args)] +
+ [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
+ if self.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
+
+ return prg, global_size[::-1], local_size[::-1]
# returns a str statement that does the store
- def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
+ def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes._float4, "images must be float4"
return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
@@ -62,18 +96,7 @@ class CStyleLanguage(NamedTuple):
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
-code_for_op: Final[Dict[Op, Callable]] = {
- UnaryOps.EXP2: lambda x: f"exp2({x})",
- UnaryOps.LOG2: lambda x: f"log2({x})",
- UnaryOps.SIN: lambda x: f"sin({x})",
- UnaryOps.SQRT: lambda x: f"sqrt({x})",
- BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
- BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
- BinaryOps.MAX: lambda a,b: f"max({a},{b})",
- BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
-}
-
-def add_gl_dimension(args, i, var, local_size, xid):
+def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
# for M1 tensor core stuff, support > 3 dims
if i >= 2 and len(args[0]) > len(xid):
# do this on the x dim for warps
@@ -82,19 +105,14 @@ def add_gl_dimension(args, i, var, local_size, xid):
lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
assert lidx.max == var.max and lidx.min == var.min
- return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
+ return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
local_size.append(var.max+1)
- return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
+ return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
- prekernel: Set[str] = set()
- kernel = []
- global_size = []
- local_size = []
+ kernel,global_size,local_size,prekernel = [],[],[],[]
pend_close = None
-
bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
-
depth = 0
def kk(s): kernel.append(" "*depth+s)
@@ -108,12 +126,12 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
kk("{")
else:
if args[1] == "global" and lang.gid:
- kk(add_gl_dimension(args, i, var, global_size, lang.gid))
+ kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
elif args[1] == "local" and lang.lid:
- kk(add_gl_dimension(args, i, var, local_size, lang.lid))
+ kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
else:
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
- kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
+ kk(lang.render_for(var.expr, var.min, var.max))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)
@@ -139,10 +157,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
elif uop == UOps.CONST:
assert newvar is not None
- kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};")
+ kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {lang.render_const(args, newvar.dtype)};")
elif uop == UOps.ALU:
assert newvar is not None
- kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};")
+ kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};")
elif uop == UOps.LOAD:
assert newvar is not None
# valids are handled here
@@ -152,8 +170,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype)
else:
val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer))
- if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}"
- kk(f"{newvar.render(True)} = {val};")
+ if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(0.0, newvar.dtype))
+ kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {val};")
elif uop == UOps.STORE:
assert args.valid.min == 1, "store must be valid"
# TODO: instead of dtypes.float, a base type
@@ -161,20 +179,14 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
elif uop == UOps.DEFINE_LOCAL:
- kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
+ if lang.external_local_bufs:
+ prekernel.append(lang.render_local(args[0], args[1]))
+ else:
+ kk(lang.render_local(args[0], args[1]))
else:
raise RuntimeError(f"failed to render {uop}")
- if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
- buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
- ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
- if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
- prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
- [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
- [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
-
- if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
- return prg, global_size, local_size
+ return lang.render_kernel(kernel, bufs, bufnames, global_size, local_size, prekernel)
class CStyleCodegen(Linearizer):
lang: ClassVar[CStyleLanguage] = CStyleLanguage()
@@ -202,5 +214,5 @@ class CStyleCodegen(Linearizer):
CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
- global_size[::-1], local_size[::-1],
+ global_size, local_size,
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index f3bad6287b..c8398deb87 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -200,7 +200,7 @@ class Linearizer:
should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
- def global_load(self, i, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
+ def global_load(self, i:int, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
expanded_nodes = [expand_node(idx) for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
upcast_dim = self.get_upcast_dim(i)
@@ -217,10 +217,10 @@ class Linearizer:
localtype = dtypes._float4 if amt == 4 else dtypes._float2
if idx.render() != ((idx//amt)*amt).render():
idx, valid = self.sts[i].expr_idxs(_idx)
- localtype = dtypes.float
+ localtype = dtypes.float32
else:
idx, valid = self.sts[i].expr_idxs(_idx)
- localtype = dtypes.float
+ localtype = dtypes.float32
key = f"{localtype}{idx.render()}{valid.render()}"
if key not in cache:
if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
diff --git a/tinygrad/codegen/wgsl.py b/tinygrad/codegen/wgsl.py
new file mode 100644
index 0000000000..55fe468832
--- /dev/null
+++ b/tinygrad/codegen/wgsl.py
@@ -0,0 +1,54 @@
+from tinygrad.lazy import LazyBuffer
+from tinygrad.codegen.cstyle import render_cl
+from tinygrad.helpers import dtypes, DType
+from tinygrad.codegen.linearizer import LocalBuffer
+from tinygrad.codegen.cstyle import CStyleLanguage
+from typing import List, Union
+from tinygrad.runtime.lib import RawConst
+from tinygrad.ops import UnaryOps, BinaryOps, FusedOps
+import math
+from typing import Tuple
+
+type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"} # tinygrad dtype -> WGSL scalar type name; used for buffer declarations and store casts
+class WGSLLanguage(CStyleLanguage):  # CStyleLanguage overrides so the C-style renderer emits WGSL compute shaders
+  gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]  # components of the workgroup_id builtin, cast to i32
+  lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]  # components of the local_invocation_id builtin, cast to i32
+  size_prefix = "let"
+  barrier="workgroupBarrier();"
+  generic_var_prefix = "var "
+  external_local_bufs = True
+  code_for_op = {
+    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
+    BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
+    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPEQ: lambda x,y: f"f32({x}=={y})",
+    FusedOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})",
+  }
+  # a workgroup array needs an address space, an element type and a fixed element count; f32 matches the f32(...) cast in render_load
+  def render_local(self, name: str, size: int):
+    return f"var<workgroup> {name}: array<f32,{size}>;"
+  # +/-inf is rendered as the largest finite f32 hex literal; non-int constants get an "f" suffix; multi-lane dtypes (sz > 1) go through render_cast
+  def render_const(self, x:Union[float,int], var_dtype) -> str:
+    if math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
+    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
+    return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
+  # returns (program text, global size, local size); both sizes are reversed and fall back to [1] when empty
+  def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
+    local_size = local_size[::-1] if len(local_size) else [1]
+    bind_it = iter(range(len(bufs)))  # binding slots are numbered consecutively over declared buffers only; data{i} keeps the index into bufs
+    prg = "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> data{i}: array<{type_map[x.dtype]}>;" for i,x in enumerate(bufs) if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)])
+    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn KERNEL_NAME_PLACEHOLDER(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
+    return prg, (global_size[::-1] if len(global_size) else [1]), local_size
+  # emits a loop header whose upper bound _max is inclusive
+  def render_for(self, expr:str, _min:int, _max:int) -> str:
+    return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
+  # WGSL select() takes (value_if_false, value_if_true, condition), hence y before x
+  def render_conditional(self, cond:str, x:str, y:str) -> str:
+    return f"select(f32({y}), {x}, bool({cond}))"
+  # wraps the parent class's load expression in an f32(...) conversion
+  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
+    return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
+  # converts the value to the buffer's WGSL type before storing when the two dtypes differ
+  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
+    if buf_dtype != var_dtype:
+      var_name = f"{type_map[buf_dtype]}({var_name})"
+    return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
\ No newline at end of file
diff --git a/tinygrad/jit.py b/tinygrad/jit.py
index 2b022fa01b..cc139efde5 100644
--- a/tinygrad/jit.py
+++ b/tinygrad/jit.py
@@ -18,7 +18,7 @@ class TinyJit:
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
- if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
+ if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py
new file mode 100644
index 0000000000..e226031d63
--- /dev/null
+++ b/tinygrad/runtime/ops_webgpu.py
@@ -0,0 +1,40 @@
+import numpy as np
+from wgpu.utils._device import get_default_device
+from tinygrad.runtime.lib import RawBufferCopyIn
+from tinygrad.helpers import dtypes, DType
+from tinygrad.ops import Compiled
+from tinygrad.codegen.cstyle import CStyleCodegen
+from tinygrad.codegen.wgsl import WGSLLanguage
+import wgpu
+
+device = get_default_device() # module-level wgpu device; used below to create shader modules and buffers and to submit/read/write through its queue
+
+class WebGPUProgram: # builds a shader module from WGSL source once, then records and submits one compute pass per call
+ def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg) # self.prg holds the created shader module, not the source text
+ def __call__(self, global_size, local_size, *bufs, wait=False): # NOTE(review): local_size and wait are never read in this body
+ binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))] # one storage-buffer slot per argument, binding index = position in bufs
+ bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)] # bind each buffer in full, from offset 0
+ bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
+ pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
+ bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)
+ compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},) # layouts and pipeline are rebuilt on every call
+ command_encoder = device.create_command_encoder()
+ compute_pass = command_encoder.begin_compute_pass()
+ compute_pass.set_pipeline(compute_pipeline)
+ compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
+ compute_pass.dispatch_workgroups(*global_size) # x y z
+ compute_pass.end()
+ device.queue.submit([command_encoder.finish()]) # nothing is returned and no explicit wait follows the submit
+
+class WGSLCodegen(CStyleCodegen): # CStyleCodegen configured to render with the WGSL language definition
+ lang = WGSLLanguage()
+ supports_float4: bool = False # the Linearizer gates its float4 upcast check on this flag
+
+class RawWebGPUBuffer(RawBufferCopyIn): # raw buffer backed by a wgpu buffer created on the module-level device
+ def __init__(self, size:int, dtype:DType): # size is an element count; the allocation is size*dtype.itemsize bytes
+ assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU" # 8-bit and 64-bit integer dtypes are rejected up front
+ super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)) # usable as shader storage and as copy source/destination
+ def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x)) # writes a contiguous version of x starting at byte offset 0
+ def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
+
+WebGpuBuffer = Compiled(RawWebGPUBuffer, WGSLCodegen, WebGPUProgram) # backend object built from the buffer class, the codegen class and the program class defined above
diff --git a/tinygrad/state.py b/tinygrad/state.py
index 13d7d35d71..dfa27543dd 100644
--- a/tinygrad/state.py
+++ b/tinygrad/state.py
@@ -24,7 +24,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:1].cast(dtypes.int64).assign([len(j)])
- t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
+ t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
for k,v in safe_load(t).items(): v.assign(tensors[k])
# state dict