From a9a1df785fe86037f95495d1c4d99e839d2d5d63 Mon Sep 17 00:00:00 2001
From: Diogo
Date: Wed, 12 Jul 2023 15:52:06 -0400
Subject: [PATCH] WebGPU support (#1077)

* initial commit
* 81 passing
* 105 passing tests
* 148 passing
* CI tests
* install dep on ci
* try opencl pkgs
* try using vulkan
* down to only 6 failing
* refactor
* cleaning up
* another test skipped due to buffer limit
* linter
* segfault
* indent fix
* another segfault found
* small touchups
* Fix max and maxpool tests
* Add constant folding
* Add javascript export script
* better asserts in codegen
* manual upcasting
* reverted token type change
* skip safetensor test due to unsupported type
* Fix efficientnet and all other model tests
* Remove np copy
* fixed indent and missing import
* manually destroy the buffer
* revert back to length
* linter errors
* removed extra val
* skip broken tests
* skipping more tests
* Make the page pretty
* Save model weights as safetensor
* Fix imagenet to c test
* Fix second imagenet to c bug
* Async and parallel kernel compilation
* workgroup support
* reversed local size
* fixed non local bug
* correct local groups
* ci experiment
* removed typo
* Fix define local by using shared memory
* Refactor
* try running on mac
* match metal tests
* add more workers
* scope down tests
* trying windows runner
* fixed windows env
* see how many it can do
* merged master
* refactor
* missed refactor
* increase test suite coverage
* missing import
* whitespace in test_efficientnet.py
* getting there
* fixed reset
* fixed bufs
* switched to cstyle
* cleanup
* min/max rename
* one more linter issue
* fixed demo
* linter
* testing ci chrome
* add unsafe webgpu arg
* add build step
* remove WEBGPU from cmd line
* use module
* try forcing directx
* trying forced metal backend
* temp disable conv2d for CI
* disable conv_transpose2d

---------

Co-authored-by: 0x4d - Martin Loretz <20306567+martinloretzzz@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
---
 .github/workflows/test.yml        |  23 ++++
 .gitignore                        |   5 ++
 examples/compile_efficientnet.py  |  32 ++++----
 examples/webgpu/compile_webgpu.py |  87 ++++++++++++++++++++++
 examples/webgpu/index.html        | 119 ++++++++++++++++++++++++++++++
 setup.py                          |   1 +
 test/models/test_train.py         |   1 +
 test/test_dtype.py                |   3 +-
 test/test_nn.py                   |   2 +-
 test/test_ops.py                  |   6 +-
 test/test_specific_conv.py        |   2 +-
 test/test_speed_v_torch.py        |   2 +-
 test/test_webgpu.js               |  51 +++++++++++++
 test/unit/test_disk_tensor.py     |   6 +-
 tinygrad/codegen/cstyle.py        | 100 ++++++++++++++-----------
 tinygrad/codegen/linearizer.py    |   6 +-
 tinygrad/codegen/wgsl.py          |  54 ++++++++++++++
 tinygrad/jit.py                   |   2 +-
 tinygrad/runtime/ops_webgpu.py    |  40 ++++++++++
 tinygrad/state.py                 |   2 +-
 20 files changed, 467 insertions(+), 77 deletions(-)
 create mode 100644 examples/webgpu/compile_webgpu.py
 create mode 100644 examples/webgpu/index.html
 create mode 100644 test/test_webgpu.js
 create mode 100644 tinygrad/codegen/wgsl.py
 create mode 100644 tinygrad/runtime/ops_webgpu.py

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index edd3efac58..229ed788d4 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -55,7 +55,30 @@ jobs:
       run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py && PYTHONPATH=.
python3 quickstart.py - name: Run Pytest run: python -m pytest -s -v -n=auto test/ + + testwebgpu: + name: WebGPU Tests + runs-on: macos-13 + steps: + - name: Checkout Code + uses: actions/checkout@v3 + - name: Set up Python 3.8 + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Install Dependencies + run: pip install -e '.[testing,webgpu]' --extra-index-url https://download.pytorch.org/whl/cpu + # - name: Set Env + # run: printf "WEBGPU=1\nWGPU_BACKEND_TYPE=D3D12\n" >> $GITHUB_ENV + - name: Run Pytest + run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -s -v -n=auto test/test_ops.py test/test_speed_v_torch.py test/test_nn.py test/test_jit.py test/test_randomness.py test/test_tensor.py test/test_assign.py test/test_conv.py test/test_nn.py test/test_custom_function.py test/test_conv_shapetracker.py + - name: Build WEBGPU Efficientnet + run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.webgpu.compile_webgpu + # - name: Install Puppeteer + # run: npm install puppeteer + # - name: Run Efficientnet + # run: node test/test_webgpu.js testimagenet: name: ImageNet to C Compile Test runs-on: ubuntu-latest diff --git a/.gitignore b/.gitignore index 6bb74b33ce..2d03c2f41f 100644 --- a/.gitignore +++ b/.gitignore @@ -31,3 +31,8 @@ extra/datasets/kits/ extra/datasets/COCO/ extra/datasets/audio* venv +examples/webgpu/net.js +examples/webgpu/net.safetensors +node_modules +package.json +package-lock.json \ No newline at end of file diff --git a/examples/compile_efficientnet.py b/examples/compile_efficientnet.py index 5d8023fec8..4d7c0a7b84 100644 --- a/examples/compile_efficientnet.py +++ b/examples/compile_efficientnet.py @@ -1,15 +1,11 @@ from models.efficientnet import EfficientNet from tinygrad.tensor import Tensor +from tinygrad.jit import TinyJit from extra.utils import fetch import ast def compile_net(run, special_names): - # functions that run the net - functions = {} - bufs = {} - bufnum = 0 - statements = [] - bufs_to_save = {} + functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0 for fxn,args in run.jit_cache: functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same cargs = [] @@ -17,26 +13,21 @@ def compile_net(run, special_names): key = id(arg) if key not in bufs: if key in special_names: - bufs[key] = (special_names[key], len(arg._buf)) + bufs[key] = (special_names[key], arg._memsz, key) else: - bufs[key] = (f"buf_{bufnum}", len(arg._buf)) + bufs[key] = (f"buf_{bufnum}", arg._memsz, key) bufnum += 1 if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name cargs.append(bufs[key][0]) - statements.append(f"{fxn.name}({', '.join(cargs)});") + statements.append((fxn.name, cargs, fxn.global_size)) return functions, statements, bufs, bufs_to_save -if __name__ == "__main__": - model = EfficientNet(0) - model.load_from_pretrained() - - from tinygrad.jit import TinyJit +def jit_model(model, the_input): @TinyJit def run(x): return model.forward(x).realize() # twice to run the JIT - the_input = Tensor.randn(1,3,224,224) the_output = run(the_input) the_output = run(the_input) @@ -47,7 +38,12 @@ if __name__ == "__main__": # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret) special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"} + return run, special_names +if __name__ == "__main__": + model = EfficientNet(0) + model.load_from_pretrained() + run, special_names = 
jit_model(model, Tensor.randn(1,3,224,224))
   functions, statements, bufs, bufs_to_save = compile_net(run, special_names)

   # c header
@@ -68,13 +64,13 @@ if __name__ == "__main__":
     cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")

   # buffers (empty + weights)
-  cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len in bufs.values()]
+  cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len,_key in bufs.values()]

   # the functions
   cprog += list(functions.values())

-  # the net
-  cprog += ["void net() {"] + statements + ["}"]
+  # the net
+  cprog += ["void net() {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size) in statements] + ["}"]

   cprog += ["""
 int main(int argc, char* argv[]) {
diff --git a/examples/webgpu/compile_webgpu.py b/examples/webgpu/compile_webgpu.py
new file mode 100644
index 0000000000..0fd00a5aa0
--- /dev/null
+++ b/examples/webgpu/compile_webgpu.py
@@ -0,0 +1,87 @@
+from os import path
+from examples.compile_efficientnet import compile_net, jit_model
+from models.efficientnet import EfficientNet
+from tinygrad.state import get_state_dict, safe_save
+from tinygrad.tensor import Tensor
+
+if __name__ == "__main__":
+  model = EfficientNet(0)
+  model.load_from_pretrained()
+  run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
+  functions, statements, bufs, _bufs_to_save = compile_net(run, special_names)
+
+  state = get_state_dict(model)
+  weights = {id(x.lazydata.realized): name for name, x in state.items()}
+  safe_save(state, path.join(path.dirname(__file__), "net.safetensors"))
+
+  kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
+  kernel_names = ', '.join([name for (name, _args, _global_size) in statements])
+  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size) in enumerate(statements)])
+  bufs = '\n '.join([f"const {buf[0]} = " + (f"createEmptyBuf(device, {buf[1]});" if buf[2] not in weights else f"createWeightBuf(device, {buf[1]}, getTensorBuffer(safetensor, metadata['{weights[buf[2]]}']))") + ";" for buf in bufs.values()])
+
+  prg = f"""const getTensorMetadata = (safetensorBuffer) => {{
+  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
+  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
+  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
+}};
+
+const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
+  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
+}}
+
+const createEmptyBuf = (device, size) => {{
+  return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
+}};
+
+const createWeightBuf = (device, size, data) => {{
+  const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
+  new Uint8Array(buf.getMappedRange()).set(data);
+  buf.unmap();
+  return buf;
+}};
+
+const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
+  const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index,
resource: {{ buffer }} }}))}});
+  const passEncoder = commandEncoder.beginComputePass();
+  passEncoder.setPipeline(pipeline);
+  passEncoder.setBindGroup(0, bindGroup);
+  passEncoder.dispatchWorkgroups(...workgroup);
+  passEncoder.end();
+}};
+
+{kernel_code}
+
+const setupNet = async (device, safetensor) => {{
+  const metadata = getTensorMetadata(safetensor);
+
+  {bufs}
+
+  const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
+  const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
+
+  const kernels = [{kernel_names}];
+  const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
+
+  return async (data) => {{
+    await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
+    new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
+    gpuWriteBuffer.unmap();
+
+    const commandEncoder = device.createCommandEncoder();
+    commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
+    {kernel_calls}
+    commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
+    const gpuCommands = commandEncoder.finish();
+    device.queue.submit([gpuCommands]);
+
+    await gpuReadBuffer.mapAsync(GPUMapMode.READ);
+    const resultBuffer = new Float32Array(gpuReadBuffer.size / 4);
+    resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
+    gpuReadBuffer.unmap();
+    return resultBuffer;
+  }}
+}}
+"""
+
+  with open(path.join(path.dirname(__file__), "net.js"), "w") as text_file:
+    text_file.write(prg)
diff --git a/examples/webgpu/index.html b/examples/webgpu/index.html
new file mode 100644
index 0000000000..f878db0825
--- /dev/null
+++ b/examples/webgpu/index.html
@@ -0,0 +1,119 @@

[The 119 added lines of index.html are the browser demo page; its HTML markup did not survive extraction. Recoverable content: the page title "tinygrad has WebGPU", a "WebGPU tinygrad EfficientNet!" heading, a run button, and a <div id="result"> that starts as "result will go here", switches to "ready" once net.js and net.safetensors are loaded, and ends up holding the predicted ImageNet label; test_webgpu.js below clicks the button and expects "hen".]
+ + + \ No newline at end of file diff --git a/setup.py b/setup.py index 14d3112eb2..c90ae5903b 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ setup(name='tinygrad', 'llvm': ["llvmlite"], 'cuda': ["pycuda"], 'triton': ["triton>=2.0.0.dev20221202"], + 'webgpu': ["wgpu"], 'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"], 'linting': [ "flake8", diff --git a/test/models/test_train.py b/test/models/test_train.py index c3280611be..6ea30ab0c8 100644 --- a/test/models/test_train.py +++ b/test/models/test_train.py @@ -46,6 +46,7 @@ class TestTrain(unittest.TestCase): train_one_step(model,X,Y) check_gc() + @unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu") def test_vit(self): model = ViT() X = np.zeros((BS,3,224,224), dtype=np.float32) diff --git a/test/test_dtype.py b/test/test_dtype.py index 61fb6aa79d..e651523e9b 100644 --- a/test/test_dtype.py +++ b/test/test_dtype.py @@ -28,7 +28,7 @@ def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_o # for GPU, cl_khr_fp16 isn't supported (except now we don't need it!) # for LLVM, it segfaults because it can't link to the casting function -@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends") +@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends") class TestHalfDtype(unittest.TestCase): def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4]) @@ -53,6 +53,7 @@ class TestHalfDtype(unittest.TestCase): def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]]) def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]]) +@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8") class TestInt8Dtype(unittest.TestCase): def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4]) def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4]) diff --git a/test/test_nn.py b/test/test_nn.py index ca7b5dda56..0c4f1bfc15 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -95,7 +95,7 @@ class TestNN(unittest.TestCase): torch_z = torch_layer(torch_x) np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) - @unittest.skipIf(getenv("CI", "") != "" and WINDOWS, "runs out of memory in CI") + @unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI") def test_conv_transpose2d(self): BS, C1, H, W = 4, 16, 224, 224 C2, K, S, P = 64, 7, 2, 1 diff --git a/test/test_ops.py b/test/test_ops.py index 562ca3cdf7..250f4a6911 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -197,12 +197,12 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x) helper_test_op([()], lambda x: x/2, lambda x: x/2) helper_test_op([()], lambda x: 2/x, lambda x: 2/x) - @unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf") + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf") def test_mul_const_naninf(self): helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf")) 
helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf")) helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan")) - @unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf") + @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf") def test_div_const_naninf(self): helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf")) helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf")) @@ -726,7 +726,7 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(), lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5) - @unittest.skipIf(IMAGE>0, "no conv1d on images") + @unittest.skipIf((IMAGE>0 or (Device.DEFAULT == "WEBGPU" and getenv("CI","") != "")), "no conv1d on images") def test_conv1d(self): for bs in [1,8]: for cin in [1,3]: diff --git a/test/test_specific_conv.py b/test/test_specific_conv.py index 8e79095da2..1737a78d45 100644 --- a/test/test_specific_conv.py +++ b/test/test_specific_conv.py @@ -19,7 +19,7 @@ class TestSpecific(unittest.TestCase): w = Tensor.randn(2048, 512) (x @ w).reshape(1, 128, 4).contiguous().realize() - @unittest.skipIf(Device.DEFAULT == "LLVM", "Broken on LLVM") + @unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU"], "Broken on LLVM and webgpu") def test_big_vec_mul(self): # from LLaMA # 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)] diff --git a/test/test_speed_v_torch.py b/test/test_speed_v_torch.py index 3af3a3c66a..5d6e3c2ec0 100644 --- a/test/test_speed_v_torch.py +++ b/test/test_speed_v_torch.py @@ -131,7 +131,7 @@ class TestBigSpeed(unittest.TestCase): def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128) def test_large_conv_3x3(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130) -@unittest.skipIf(getenv("BIG") == 1, "only big tests") +@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests") class TestSpeed(unittest.TestCase): def setUp(self): global prefix diff --git a/test/test_webgpu.js b/test/test_webgpu.js new file mode 100644 index 0000000000..ce12330f66 --- /dev/null +++ b/test/test_webgpu.js @@ -0,0 +1,51 @@ +const puppeteer = require('puppeteer'); +const { spawn } = require('child_process'); +const res = spawn("python", ["-m", "http.server", "8000"], { shell: true }); + +async function timeout(time) { + return new Promise((resolve) => setTimeout(resolve, time)); +} + +function cleanup(err) { + res.kill(); + if(err != null) { + console.error(err); + process.exit(1); + } +} + +async function waitForText(selector, text) { + let n = 0; + let ready = false; + while (n < 10) { + const res = await (await selector.getProperty("textContent")).jsonValue(); + console.log(`waiting for text ${text} got ${res}`); + if(res == text) { + ready = true; + break + } + await timeout(2000); + n += 1 + } + return ready; +} + +puppeteer.launch({ headless: false, args: ["--enable-unsafe-webgpu"]}).then(async browser => { + const page = await browser.newPage(); + page.on("console", message => console.log(`message from console ${message.text()}`)) + .on("pageerror", ({ message }) => console.log(`error from page ${message}`)) + + const res = await page.goto("http://localhost:8000/examples/webgpu/index.html"); + if(res.status() != 200) throw new Error("Failed to 
load page"); + const textSelector = await page.waitForSelector("#result"); + const buttonSelector = await page.waitForSelector("input[type=button]"); + const ready = await waitForText(textSelector, "ready"); + if(!ready) throw new Error("Failed to load page"); + await buttonSelector.evaluate(e => e.click()); + const done = await waitForText(textSelector, "hen"); + if(!done) throw new Error("failed to get hen"); + browser.close(); + cleanup(null); +}).catch(err => { + cleanup(err); +}); \ No newline at end of file diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index d5e7c74ad0..b142b7e0a5 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -1,8 +1,8 @@ import pathlib import unittest import numpy as np -from tinygrad.tensor import Tensor -from tinygrad.state import safe_load, safe_save, torch_load, get_state_dict +from tinygrad.tensor import Tensor, Device +from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load from tinygrad.helpers import dtypes from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.helpers import Timing @@ -46,7 +46,7 @@ class TestRawDiskBuffer(unittest.TestCase): tst = np.empty(test_size, np.uint8) with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"): np.copyto(tst, db.toCPU()) - +@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype") class TestSafetensors(unittest.TestCase): def test_real_safetensors(self): import torch diff --git a/tinygrad/codegen/cstyle.py b/tinygrad/codegen/cstyle.py index 0e5399c95c..309edbcfcd 100644 --- a/tinygrad/codegen/cstyle.py +++ b/tinygrad/codegen/cstyle.py @@ -1,8 +1,8 @@ -from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union +from typing import Final, Dict, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Union import math, collections from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer -from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps -from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod +from tinygrad.ops import ASTRunner, UnaryOps, BinaryOps, FusedOps +from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod, DType from tinygrad.runtime.lib import RawConst from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable from tinygrad.lazy import LazyBuffer @@ -13,6 +13,8 @@ render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b}) render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})" class CStyleLanguage(NamedTuple): + size_prefix: str = "int" + generic_var_prefix: str = "" kernel_prefix: str = "" buffer_prefix: str = "" buffer_suffix: str = "" @@ -24,9 +26,20 @@ class CStyleLanguage(NamedTuple): float4: Optional[str] = None half_prekernel: Optional[str] = None uses_vload: bool = False + external_local_bufs: bool = False + code_for_op: Dict = { + UnaryOps.EXP2: lambda x: f"exp2({x})", + UnaryOps.LOG2: lambda x: f"log2({x})", + UnaryOps.SIN: lambda x: f"sin({x})", + UnaryOps.SQRT: lambda x: f"sqrt({x})", + BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", + BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})", + BinaryOps.MAX: lambda a,b: f"max({a},{b})", + BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})" + } # returns a str expression of the casted xs with the 
given type - def render_cast(self, x:List[str], var_dtype) -> str: + def render_cast(self, x:List[str], var_dtype:DType) -> str: assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}" assert self.float4 is not None, "cast is not supported on this platform" if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})" @@ -50,9 +63,30 @@ class CStyleLanguage(NamedTuple): if output_dtype.sz > 1: return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))" return f"{buf_name}[{idx.render(render_cl)}]" + + def render_local(self, name:str, size:int): + return self.smem_prefix + f"float {name}[{size}];" + + def render_for(self, expr: str, _min:int, _max:int) -> str: + return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{" + + def render_conditional(self, cond: str, x:str, y:str) -> str: + return f"({cond})?({x}):{y}" + + def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]: + tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(x.dtype, ImageDType) for x in bufs) else "" + buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else + ("const " if i > 0 else "")+self.buffer_prefix+x.dtype.name+"*"+self.buffer_suffix) for i,x in enumerate(bufs) + if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)] + prg = ''.join([f"{self.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + + [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + self.extra_args)] + + [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) + if self.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg]) + + return prg, global_size[::-1], local_size[::-1] # returns a str statement that does the store - def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str: + def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str: if isinstance(buf_dtype, ImageDType): assert var_dtype == dtypes._float4, "images must be float4" return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});" @@ -62,18 +96,7 @@ class CStyleLanguage(NamedTuple): return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};" return f"{buf_name}[{idx.render(render_cl)}] = {var_name};" -code_for_op: Final[Dict[Op, Callable]] = { - UnaryOps.EXP2: lambda x: f"exp2({x})", - UnaryOps.LOG2: lambda x: f"log2({x})", - UnaryOps.SIN: lambda x: f"sin({x})", - UnaryOps.SQRT: lambda x: f"sqrt({x})", - BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})", - BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})", - BinaryOps.MAX: lambda a,b: f"max({a},{b})", - BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})" -} - -def add_gl_dimension(args, i, var, local_size, xid): +def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]): # for M1 tensor core stuff, support > 3 dims if i >= 2 and 
len(args[0]) > len(xid): # do this on the x dim for warps @@ -82,19 +105,14 @@ def add_gl_dimension(args, i, var, local_size, xid): lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1) lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1) assert lidx.max == var.max and lidx.min == var.min - return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */" + return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */" local_size.append(var.max+1) - return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */" + return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */" def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]: - prekernel: Set[str] = set() - kernel = [] - global_size = [] - local_size = [] + kernel,global_size,local_size,prekernel = [],[],[],[] pend_close = None - bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)] - depth = 0 def kk(s): kernel.append(" "*depth+s) @@ -108,12 +126,12 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan kk("{") else: if args[1] == "global" and lang.gid: - kk(add_gl_dimension(args, i, var, global_size, lang.gid)) + kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid)) elif args[1] == "local" and lang.lid: - kk(add_gl_dimension(args, i, var, local_size, lang.lid)) + kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid)) else: if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling - kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{") + kk(lang.render_for(var.expr, var.min, var.max)) depth += 1 elif uop == UOps.BARRIER: kk(lang.barrier) @@ -139,10 +157,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}") elif uop == UOps.CONST: assert newvar is not None - kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};") + kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {lang.render_const(args, newvar.dtype)};") elif uop == UOps.ALU: assert newvar is not None - kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};") + kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};") elif uop == UOps.LOAD: assert newvar is not None # valids are handled here @@ -152,8 +170,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype) else: val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer)) - if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? 
({val}) : {lang.render_const(0.0, newvar.dtype)}" - kk(f"{newvar.render(True)} = {val};") + if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(0.0, newvar.dtype)) + kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {val};") elif uop == UOps.STORE: assert args.valid.min == 1, "store must be valid" # TODO: instead of dtypes.float, a base type @@ -161,20 +179,14 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1: kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};") elif uop == UOps.DEFINE_LOCAL: - kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];") + if lang.external_local_bufs: + prekernel.append(lang.render_local(args[0], args[1])) + else: + kk(lang.render_local(args[0], args[1])) else: raise RuntimeError(f"failed to render {uop}") - if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n") - buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else - ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs) - if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)] - prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] + - [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] + - [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"]) - - if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg]) - return prg, global_size, local_size + return lang.render_kernel(kernel, bufs, bufnames, global_size, local_size, prekernel) class CStyleCodegen(Linearizer): lang: ClassVar[CStyleLanguage] = CStyleLanguage() @@ -202,5 +214,5 @@ class CStyleCodegen(Linearizer): CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK') return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), - global_size[::-1], local_size[::-1], + global_size, local_size, op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index f3bad6287b..c8398deb87 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -200,7 +200,7 @@ class Linearizer: should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1] - def global_load(self, i, idxs:Sequence[VariableOrNum], const=None) -> List[Token]: + def global_load(self, i:int, idxs:Sequence[VariableOrNum], const=None) -> List[Token]: expanded_nodes = [expand_node(idx) for idx in idxs] _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])] upcast_dim = self.get_upcast_dim(i) @@ -217,10 +217,10 @@ class Linearizer: localtype = dtypes._float4 if amt == 4 else dtypes._float2 if idx.render() != ((idx//amt)*amt).render(): idx, valid = self.sts[i].expr_idxs(_idx) - localtype = dtypes.float + localtype = dtypes.float32 else: idx, valid = 
self.sts[i].expr_idxs(_idx)
-          localtype = dtypes.float
+          localtype = dtypes.float32
       key = f"{localtype}{idx.render()}{valid.render()}"
       if key not in cache:
         if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
diff --git a/tinygrad/codegen/wgsl.py b/tinygrad/codegen/wgsl.py
new file mode 100644
index 0000000000..55fe468832
--- /dev/null
+++ b/tinygrad/codegen/wgsl.py
@@ -0,0 +1,54 @@
+from tinygrad.lazy import LazyBuffer
+from tinygrad.codegen.cstyle import render_cl
+from tinygrad.helpers import dtypes, DType
+from tinygrad.codegen.linearizer import LocalBuffer
+from tinygrad.codegen.cstyle import CStyleLanguage
+from typing import List, Union
+from tinygrad.runtime.lib import RawConst
+from tinygrad.ops import UnaryOps, BinaryOps, FusedOps
+import math
+from typing import Tuple
+
+type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"}
+class WGSLLanguage(CStyleLanguage):
+  gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
+  lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
+  size_prefix = "let"
+  barrier="workgroupBarrier();"
+  generic_var_prefix = "var "
+  external_local_bufs = True
+  code_for_op = {
+    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
+    BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
+    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPEQ: lambda x,y: f"f32({x}=={y})",
+    FusedOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})",
+  }
+
+  def render_local(self, name: str, size: int):
+    return f"var<workgroup> {name}: array<f32,{size}>;"
+
+  def render_const(self, x:Union[float,int], var_dtype) -> str:
+    if math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
+    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
+    return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
+
+  def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
+    local_size = local_size[::-1] if len(local_size) else [1]
+    bind_it = iter(range(len(bufs)))
+    prg = "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> data{i}: array<{type_map[x.dtype]}>;" for i,x in enumerate(bufs) if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)])
+    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn KERNEL_NAME_PLACEHOLDER(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
+    return prg, global_size[::-1] if len(global_size) else [1], local_size
+
+  def render_for(self, expr:str, _min:int, _max:int) -> str:
+    return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
+
+  def render_conditional(self, cond:str, x:str, y:str) -> str:
+    return f"select(f32({y}), {x}, bool({cond}))"
+
+  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
+    return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
+
+  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
+    if buf_dtype != var_dtype:
+      var_name = f"{type_map[buf_dtype]}({var_name})"
+    return f"{buf_name}[{idx.render(render_cl)}] =
{var_name};" \ No newline at end of file diff --git a/tinygrad/jit.py b/tinygrad/jit.py index 2b022fa01b..cc139efde5 100644 --- a/tinygrad/jit.py +++ b/tinygrad/jit.py @@ -18,7 +18,7 @@ class TinyJit: def __get__(self, obj, objtype): return functools.partial(self.__call__, obj) def __call__(self, *args, **kwargs) -> Any: - if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen + if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen # NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)} assert len(input_rawbuffers) != 0, "no inputs to JIT" diff --git a/tinygrad/runtime/ops_webgpu.py b/tinygrad/runtime/ops_webgpu.py new file mode 100644 index 0000000000..e226031d63 --- /dev/null +++ b/tinygrad/runtime/ops_webgpu.py @@ -0,0 +1,40 @@ +import numpy as np +from wgpu.utils._device import get_default_device +from tinygrad.runtime.lib import RawBufferCopyIn +from tinygrad.helpers import dtypes, DType +from tinygrad.ops import Compiled +from tinygrad.codegen.cstyle import CStyleCodegen +from tinygrad.codegen.wgsl import WGSLLanguage +import wgpu + +device = get_default_device() + +class WebGPUProgram: + def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg) + def __call__(self, global_size, local_size, *bufs, wait=False): + binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))] + bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)] + bind_group_layout = device.create_bind_group_layout(entries=binding_layouts) + pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout]) + bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings) + compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},) + command_encoder = device.create_command_encoder() + compute_pass = command_encoder.begin_compute_pass() + compute_pass.set_pipeline(compute_pipeline) + compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used + compute_pass.dispatch_workgroups(*global_size) # x y z + compute_pass.end() + device.queue.submit([command_encoder.finish()]) + +class WGSLCodegen(CStyleCodegen): + lang = WGSLLanguage() + supports_float4: bool = False + +class RawWebGPUBuffer(RawBufferCopyIn): + def __init__(self, size:int, dtype:DType): + assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU" + super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC)) + def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x)) + def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore + +WebGpuBuffer = Compiled(RawWebGPUBuffer, WGSLCodegen, WebGPUProgram) 
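ops_webgpu.py above is the entire runtime: one default wgpu device, storage buffers, and a bind group plus compute pipeline built per dispatch. For reviewers who want to poke at the new backend locally, a minimal smoke test (my sketch, not part of the patch; it assumes the `webgpu` extra from setup.py is installed and selects the backend with the same `WEBGPU=1` variable the CI job uses):

```python
# Run as: WEBGPU=1 python smoke_webgpu.py
# Exercises RawWebGPUBuffer -> WGSLCodegen -> WebGPUProgram on a tiny matmul.
import numpy as np
from tinygrad.tensor import Tensor

x = Tensor.randn(8, 8)
y = (x @ x).relu()  # built lazily, lowered to WGSL by the codegen above
np.testing.assert_allclose(y.numpy(), np.maximum(x.numpy() @ x.numpy(), 0), atol=1e-4)
print("webgpu backend ok")
```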
diff --git a/tinygrad/state.py b/tinygrad/state.py
index 13d7d35d71..dfa27543dd 100644
--- a/tinygrad/state.py
+++ b/tinygrad/state.py
@@ -24,7 +24,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
   pathlib.Path(fn).unlink(missing_ok=True)
   t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
   t[0:1].cast(dtypes.int64).assign([len(j)])
-  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
+  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
   for k,v in safe_load(t).items(): v.assign(tensors[k]) # state dict
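The state.py change pins the safetensors JSON header to a CPU tensor, so `safe_save` keeps working when the default device is WEBGPU, which cannot hold uint8 data (RawWebGPUBuffer asserts against it). A round-trip sketch of the API touched here (illustrative; `safe_load` is shown above taking a disk tensor, and this assumes it also accepts a filename, mirroring `safe_save`):

```python
# Saves and reloads a state dict through a .safetensors file; with the
# device="cpu" fix above this also works under WEBGPU=1.
from tinygrad.tensor import Tensor
from tinygrad.state import get_state_dict, safe_save, safe_load

class TinyNet:
  def __init__(self):
    self.w1, self.w2 = Tensor.randn(16, 32), Tensor.randn(32, 8)

net = TinyNet()
safe_save(get_state_dict(net), "/tmp/tinynet.safetensors")
loaded = safe_load("/tmp/tinynet.safetensors")  # dict of name -> Tensor
assert set(loaded.keys()) == {"w1", "w2"}
```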