WebGPU support (#1077)

* initial commit

* 81 passing

* 105 passing tests

* 148 passing

* CI tests

* install dep on ci

* try opencl pkgs

* try using vulkan

* down to only 6 failing

* refactor

* cleaning up

* another test skipped due to buffer limit

* linter

* segfault

* indent fix

* another segfault found

* small touchups

* Fix max and maxpool tests

* Add constant folding

* Add javascript export script

* better asserts in codegen

* manual upcasting

* reverted token type change

* skip safetensor test due to unsupported type

* Fix efficientnet and all other model tests

* Remove np copy

* fixed indent and missing import

* manually destroy the buffer

* revert back to length

* linter errors

* removed extra val

* skip broken tests

* skipping more tests

* Make the page pretty

* Save model weights as safetensor

* Fix imagenet to c test

* Fix second imagenet to c bug

* Async and parallel kernel compilation

* workgroup support

* reversed local size

* fixed non local bug

* correct local groups

* ci experiment

* removed typo

* Fix define local by using shared memory

* Refactor

* try running on mac

* match metal tests

* add more workers

* scope down tests

* trying windows runner

* fixed windows env

* see how many it can do

* merged master

* refactor

* missed refactor

* increase test suite coverage

* missing import

* whitespace in test_efficientnet.py

* getting there

* fixed reset

* fixed bufs

* switched to cstyle

* cleanup

* min/max rename

* one more linter issue

* fixed demo

* linter

* testing ci chrome

* add unsafe webgpu arg

* add build step

* remove WEBGPU from cmd line

* use module

* try forcing directx

* trying forced metal backend

* temp disable conv2d for CI

* disable conv_transpose2d

---------

Co-authored-by: 0x4d - Martin Loretz <20306567+martinloretzzz@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
Author: Diogo
Date: 2023-07-12 15:52:06 -04:00
Committed by: GitHub
Parent: 613bcd945d
Commit: a9a1df785f
20 changed files with 467 additions and 77 deletions

.github/workflows/test.yml

@@ -55,7 +55,30 @@ jobs:
run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py && PYTHONPATH=. python3 quickstart.py
- name: Run Pytest
run: python -m pytest -s -v -n=auto test/
testwebgpu:
name: WebGPU Tests
runs-on: macos-13
steps:
- name: Checkout Code
uses: actions/checkout@v3
- name: Set up Python 3.8
uses: actions/setup-python@v4
with:
python-version: 3.8
- name: Install Dependencies
run: pip install -e '.[testing,webgpu]' --extra-index-url https://download.pytorch.org/whl/cpu
# - name: Set Env
# run: printf "WEBGPU=1\nWGPU_BACKEND_TYPE=D3D12\n" >> $GITHUB_ENV
- name: Run Pytest
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -s -v -n=auto test/test_ops.py test/test_speed_v_torch.py test/test_nn.py test/test_jit.py test/test_randomness.py test/test_tensor.py test/test_assign.py test/test_conv.py test/test_custom_function.py test/test_conv_shapetracker.py
- name: Build WEBGPU Efficientnet
run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.webgpu.compile_webgpu
# - name: Install Puppeteer
# run: npm install puppeteer
# - name: Run Efficientnet
# run: node test/test_webgpu.js
testimagenet:
name: ImageNet to C Compile Test
runs-on: ubuntu-latest
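The new job drives everything through the WEBGPU=1 environment variable, which selects tinygrad's default device. A minimal local smoke test, assuming the wgpu package from the new `webgpu` extra is installed:

# Minimal smoke test for the new backend (assumes `pip install -e '.[webgpu]'`).
# WEBGPU must be set before tinygrad is imported, since Device.DEFAULT is
# resolved from the environment at import time.
import os
os.environ["WEBGPU"] = "1"

from tinygrad.tensor import Tensor

x = Tensor([[1.0, 2.0], [3.0, 4.0]])
print((x @ x).numpy())  # dispatched through a WGSL compute kernel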

.gitignore

@@ -31,3 +31,8 @@ extra/datasets/kits/
extra/datasets/COCO/
extra/datasets/audio*
venv
examples/webgpu/net.js
examples/webgpu/net.safetensors
node_modules
package.json
package-lock.json

examples/compile_efficientnet.py

@@ -1,15 +1,11 @@
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from extra.utils import fetch
import ast
def compile_net(run, special_names):
# functions that run the net
functions = {}
bufs = {}
bufnum = 0
statements = []
bufs_to_save = {}
functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
for fxn,args in run.jit_cache:
functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
cargs = []
@@ -17,26 +13,21 @@ def compile_net(run, special_names):
key = id(arg)
if key not in bufs:
if key in special_names:
bufs[key] = (special_names[key], len(arg._buf))
bufs[key] = (special_names[key], arg._memsz, key)
else:
bufs[key] = (f"buf_{bufnum}", len(arg._buf))
bufs[key] = (f"buf_{bufnum}", arg._memsz, key)
bufnum += 1
if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
cargs.append(bufs[key][0])
statements.append(f"{fxn.name}({', '.join(cargs)});")
statements.append((fxn.name, cargs, fxn.global_size))
return functions, statements, bufs, bufs_to_save
if __name__ == "__main__":
model = EfficientNet(0)
model.load_from_pretrained()
from tinygrad.jit import TinyJit
def jit_model(model, the_input):
@TinyJit
def run(x): return model.forward(x).realize()
# twice to run the JIT
the_input = Tensor.randn(1,3,224,224)
the_output = run(the_input)
the_output = run(the_input)
@@ -47,7 +38,12 @@ if __name__ == "__main__":
# TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}
return run, special_names
if __name__ == "__main__":
model = EfficientNet(0)
model.load_from_pretrained()
run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
functions, statements, bufs, bufs_to_save = compile_net(run, special_names)
# c header
@@ -68,13 +64,13 @@ if __name__ == "__main__":
cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")
# buffers (empty + weights)
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len in bufs.values()]
cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len,_key in bufs.values()]
# the functions
cprog += list(functions.values())
# the net
cprog += ["void net() {"] + statements + ["}"]
# the net
cprog += ["void net() {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size) in statements] + ["}"]
cprog += ["""
int main(int argc, char* argv[]) {
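For orientation, here is a hand-made illustration of the shapes compile_net now returns; every name, size, and id below is invented for the sketch. Statements carry (kernel_name, args, global_size) tuples instead of pre-rendered C calls, so the same JIT cache can target both C and WebGPU:

# Illustrative only -- names, sizes, and ids are made up:
functions  = {"E_512": "void E_512(float* data0, const float* data1) {...}"}
statements = [("E_512", ["buf_0", "input"], [512, 1, 1])]   # (name, args, global_size)
bufs       = {140230: ("buf_0", 2048, 140230)}              # id -> (name, memsz, key)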

examples/webgpu/compile_webgpu.py

@@ -0,0 +1,87 @@
from os import path
from examples.compile_efficientnet import compile_net, jit_model
from models.efficientnet import EfficientNet
from tinygrad.state import get_state_dict, safe_save
from tinygrad.tensor import Tensor
if __name__ == "__main__":
model = EfficientNet(0)
model.load_from_pretrained()
run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
functions, statements, bufs, _bufs_to_save = compile_net(run, special_names)
state = get_state_dict(model)
weights = {id(x.lazydata.realized): name for name, x in state.items()}
safe_save(state, path.join(path.dirname(__file__), "net.safetensors"))
kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
kernel_names = ', '.join([name for (name, _args, _global_size) in statements])
kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, pipelines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size) in enumerate(statements)])
bufs = '\n '.join([f"const {buf[0]} = " + (f"createEmptyBuf(device, {buf[1]});" if buf[2] not in weights else f"createWeightBuf(device, {buf[1]}, getTensorBuffer(safetensor, metadata['{weights[buf[2]]}']))") + ";" for buf in bufs.values()])
prg = f"""const getTensorMetadata = (safetensorBuffer) => {{
const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};
const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}}
const createEmptyBuf = (device, size) => {{
return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};
const createWeightBuf = (device, size, data) => {{
const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
new Uint8Array(buf.getMappedRange()).set(data);
buf.unmap();
return buf;
}};
const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
const passEncoder = commandEncoder.beginComputePass();
passEncoder.setPipeline(pipeline);
passEncoder.setBindGroup(0, bindGroup);
passEncoder.dispatchWorkgroups(...workgroup);
passEncoder.end();
}};
{kernel_code}
const setupNet = async (device, safetensor) => {{
const metadata = getTensorMetadata(safetensor);
{bufs}
const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
const kernels = [{kernel_names}];
const pipelines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));
return async (data) => {{
await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
gpuWriteBuffer.unmap();
const commandEncoder = device.createCommandEncoder();
commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
{kernel_calls}
commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
const gpuCommands = commandEncoder.finish();
device.queue.submit([gpuCommands]);
await gpuReadBuffer.mapAsync(GPUMapMode.READ);
const resultBuffer = new Float32Array(gpuReadBuffer.size / 4); // size is in bytes; Float32Array length is in elements
resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
gpuReadBuffer.unmap();
return resultBuffer;
}}
}}
"""
with open(path.join(path.dirname(__file__), "net.js"), "w") as text_file:
text_file.write(prg)
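The getTensorMetadata helper above mirrors the safetensors layout: an 8-byte little-endian length, a JSON header, then raw tensor bytes, with data_offsets relative to the end of the header. A Python sketch of the same parse:

import json, struct

def read_safetensors_header(path: str) -> dict:
    with open(path, "rb") as f:
        (n,) = struct.unpack("<Q", f.read(8))        # metadata length, little-endian u64
        meta = json.loads(f.read(n).decode("utf-8"))
        meta.pop("__metadata__", None)
        # data_offsets are relative to the end of the header, i.e. byte 8 + n
        return meta  # tensor name -> {dtype, shape, data_offsets}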

examples/webgpu/index.html (new file)

@@ -0,0 +1,119 @@
<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
#result { font-size: 48px; }
#time { font-size: 16px; color: grey; }
#mybox { padding: 20px; }
#resultbox { padding: 50px; }
.bigggg { font-size: 18px; margin-top: 10px; }
.bigg { font-size: 18px; }
#url { font-size: 18px; width: 70%; }
a { text-decoration: none; }
h1 { padding: 50px; padding-bottom: 0px; font-size: 36px; font-weight: normal; }
#imagebox { height:224px; width:224px; border: 1px dotted black; }
#video { height:0px; width:0px; border: 1px dotted black; object-fit: cover;}
canvas { display: none; }
* { text-align: center; font-family: monospace; }
</style>
<title>tinygrad has WebGPU</title>
<script src="./net.js"></script>
</head>
<body>
<h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
<div id="mybox">
<input type="text" id="url" placeholder="put url here" value="https://upload.wikimedia.org/wikipedia/commons/d/da/Norwegian_hen.jpg">
<input class="bigg" type="button" onclick="runNetWResource(document.getElementById('url').value)" value="Use URL">
</div>
<br/>
<img id="imagebox">
<canvas id="canvas" width="200" height="200"> </canvas>
<div id="resultbox">
<div id="result">result will go here</div>
<div id="time"></div>
</div>
<script>
const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
const resultText = document.getElementById('result');
let labels, net;
const error = (err) => {
resultText.innerHTML = `Error: ${err}`;
throw new Error(err);
}
const getDevice = async () => {
if (!navigator.gpu) error("WebGPU not supported.");
const adapter = await navigator.gpu.requestAdapter();
return await adapter.requestDevice();
};
const timer = async (func, label = "") => {
document.getElementById('time').innerHTML = "";
const start = performance.now();
const out = await func();
const delta = (performance.now() - start).toFixed(1)
console.log(`${delta} ms ${label}`);
document.getElementById('time').innerHTML = `${delta} ms ${label}`;
return out;
}
const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();
const getSafetensorBuffer = async () => new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer());
const reorderChannelsAndRemoveAlpha = (data) => {
const out = [];
let i = 0;
for (let c = 0; c < 3; c++) {
for (let x = 0; x < 224 * 224; x++) {
out[i] = data[x * 4 + c];
i++;
}
}
return out;
};
const runNetWResource = async (resource) => {
resultText.innerHTML = "pending..."
if (resource == "") error("sir. please type in a URL");
const response = await fetch(resource)
if (!response.ok) error("sir. that is not a good URL. try a new one");
document.getElementById("imagebox").src = resource
const img = new Image();
img.crossOrigin = "Anonymous";
img.onload = () => {
URL.revokeObjectURL(img.src);
ctx.drawImage(img, 0, 0, 224, 224);
const data = ctx.getImageData(0, 0, 224, 224).data;
runNet(data)
};
img.src = resource;
}
const loadNet = async () => {
resultText.innerHTML = "loading..."
labels = await getLabels();
const safetensor = await getSafetensorBuffer();
const device = await getDevice();
net = await timer(() => setupNet(device, safetensor), "(compilation)");
resultText.innerHTML = "ready"
}
const runNet = async (data) => {
if (!net) error("Net not loaded yet.");
const input = reorderChannelsAndRemoveAlpha(Array.from(data).map((pix) => (pix / 255.0) * 0.45 - 0.225));
const out = await timer(() => net(new Float32Array(input)));
const arr = Array.from(new Float32Array(out));
const index = arr.indexOf(Math.max(...arr));
resultText.textContent = labels[index];
};
loadNet();
</script>
</body>
</html>
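The page's preprocessing (scale pixels to roughly [-0.225, 0.225], drop the alpha channel, reorder HWC to CHW) is equivalent to this NumPy sketch:

import numpy as np

def preprocess(rgba: np.ndarray) -> np.ndarray:
    # rgba: flat uint8 array from canvas getImageData, length 224*224*4
    x = rgba.astype(np.float32) / 255.0 * 0.45 - 0.225
    x = x.reshape(224, 224, 4)[:, :, :3]   # drop the alpha channel
    return x.transpose(2, 0, 1).ravel()    # HWC -> CHW, then flatten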

setup.py

@@ -25,6 +25,7 @@ setup(name='tinygrad',
'llvm': ["llvmlite"],
'cuda': ["pycuda"],
'triton': ["triton>=2.0.0.dev20221202"],
'webgpu': ["wgpu"],
'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
'linting': [
"flake8",

test/models/test_train.py

@@ -46,6 +46,7 @@ class TestTrain(unittest.TestCase):
train_one_step(model,X,Y)
check_gc()
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu")
def test_vit(self):
model = ViT()
X = np.zeros((BS,3,224,224), dtype=np.float32)

test/test_dtype.py

@@ -28,7 +28,7 @@ def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_o
# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends")
@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
class TestHalfDtype(unittest.TestCase):
def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])
@@ -53,6 +53,7 @@ class TestHalfDtype(unittest.TestCase):
def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
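These skips line up with the allocation-time guard in the new runtime; a sketch of the relationship, mirroring the assert in RawWebGPUBuffer further down (float16 is skipped separately because it fails in codegen rather than allocation):

from tinygrad.helpers import dtypes

# dtypes rejected when a WebGPU buffer is allocated, hence the skipped
# int8/uint8 test classes above
WEBGPU_UNSUPPORTED = [dtypes.int8, dtypes.uint8, dtypes.int64, dtypes.uint64]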

test/test_nn.py

@@ -95,7 +95,7 @@ class TestNN(unittest.TestCase):
torch_z = torch_layer(torch_x)
np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)
@unittest.skipIf(getenv("CI", "") != "" and WINDOWS, "runs out of memory in CI")
@unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
def test_conv_transpose2d(self):
BS, C1, H, W = 4, 16, 224, 224
C2, K, S, P = 64, 7, 2, 1

test/test_ops.py

@@ -197,12 +197,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
helper_test_op([()], lambda x: x/2, lambda x: x/2)
helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
@unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf")
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf")
def test_mul_const_naninf(self):
helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf"))
helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf"))
helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan"))
@unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf")
@unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf")
def test_div_const_naninf(self):
helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf"))
helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf"))
@@ -726,7 +726,7 @@ class TestOps(unittest.TestCase):
lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(),
lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)
@unittest.skipIf(IMAGE>0, "no conv1d on images")
@unittest.skipIf((IMAGE>0 or (Device.DEFAULT == "WEBGPU" and getenv("CI","") != "")), "no conv1d on images")
def test_conv1d(self):
for bs in [1,8]:
for cin in [1,3]:

test/test_specific_conv.py

@@ -19,7 +19,7 @@ class TestSpecific(unittest.TestCase):
w = Tensor.randn(2048, 512)
(x @ w).reshape(1, 128, 4).contiguous().realize()
@unittest.skipIf(Device.DEFAULT == "LLVM", "Broken on LLVM")
@unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU"], "Broken on LLVM and webgpu")
def test_big_vec_mul(self):
# from LLaMA
# 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]

test/test_speed_v_torch.py

@@ -131,7 +131,7 @@ class TestBigSpeed(unittest.TestCase):
def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)
def test_large_conv_3x3(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)
@unittest.skipIf(getenv("BIG") == 1, "only big tests")
@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests")
class TestSpeed(unittest.TestCase):
def setUp(self):
global prefix

test/test_webgpu.js (new file)

@@ -0,0 +1,51 @@
const puppeteer = require('puppeteer');
const { spawn } = require('child_process');
const res = spawn("python", ["-m", "http.server", "8000"], { shell: true });
async function timeout(time) {
return new Promise((resolve) => setTimeout(resolve, time));
}
function cleanup(err) {
res.kill();
if(err != null) {
console.error(err);
process.exit(1);
}
}
async function waitForText(selector, text) {
let n = 0;
let ready = false;
while (n < 10) {
const res = await (await selector.getProperty("textContent")).jsonValue();
console.log(`waiting for text ${text} got ${res}`);
if(res == text) {
ready = true;
break
}
await timeout(2000);
n += 1
}
return ready;
}
puppeteer.launch({ headless: false, args: ["--enable-unsafe-webgpu"]}).then(async browser => {
const page = await browser.newPage();
page.on("console", message => console.log(`message from console ${message.text()}`))
.on("pageerror", ({ message }) => console.log(`error from page ${message}`))
const res = await page.goto("http://localhost:8000/examples/webgpu/index.html");
if(res.status() != 200) throw new Error("Failed to load page");
const textSelector = await page.waitForSelector("#result");
const buttonSelector = await page.waitForSelector("input[type=button]");
const ready = await waitForText(textSelector, "ready");
if(!ready) throw new Error("Failed to load page");
await buttonSelector.evaluate(e => e.click());
const done = await waitForText(textSelector, "hen");
if(!done) throw new Error("failed to get hen");
browser.close();
cleanup(null);
}).catch(err => {
cleanup(err);
});

test/test_state.py

@@ -1,8 +1,8 @@
import pathlib
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.state import safe_load, safe_save, torch_load, get_state_dict
from tinygrad.tensor import Tensor, Device
from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.helpers import Timing
@@ -46,7 +46,7 @@ class TestRawDiskBuffer(unittest.TestCase):
tst = np.empty(test_size, np.uint8)
with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
np.copyto(tst, db.toCPU())
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
class TestSafetensors(unittest.TestCase):
def test_real_safetensors(self):
import torch

tinygrad/codegen/cstyle.py

@@ -1,8 +1,8 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
from typing import Final, Dict, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod
from tinygrad.ops import ASTRunner, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod, DType
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
from tinygrad.lazy import LazyBuffer
@@ -13,6 +13,8 @@ render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})
render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"
class CStyleLanguage(NamedTuple):
size_prefix: str = "int"
generic_var_prefix: str = ""
kernel_prefix: str = ""
buffer_prefix: str = ""
buffer_suffix: str = ""
@@ -24,9 +26,20 @@ class CStyleLanguage(NamedTuple):
float4: Optional[str] = None
half_prekernel: Optional[str] = None
uses_vload: bool = False
external_local_bufs: bool = False
code_for_op: Dict = {
UnaryOps.EXP2: lambda x: f"exp2({x})",
UnaryOps.LOG2: lambda x: f"log2({x})",
UnaryOps.SIN: lambda x: f"sin({x})",
UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})",
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
}
# returns a str expression of the casted xs with the given type
def render_cast(self, x:List[str], var_dtype) -> str:
def render_cast(self, x:List[str], var_dtype:DType) -> str:
assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
assert self.float4 is not None, "cast is not supported on this platform"
if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
@@ -50,9 +63,30 @@ class CStyleLanguage(NamedTuple):
if output_dtype.sz > 1:
return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
return f"{buf_name}[{idx.render(render_cl)}]"
def render_local(self, name:str, size:int):
return self.smem_prefix + f"float {name}[{size}];"
def render_for(self, expr: str, _min:int, _max:int) -> str:
return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str:
return f"({cond})?({x}):{y}"
def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(x.dtype, ImageDType) for x in bufs) else ""
buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
("const " if i > 0 else "")+self.buffer_prefix+x.dtype.name+"*"+self.buffer_suffix) for i,x in enumerate(bufs)
if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
prg = ''.join([f"{self.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + self.extra_args)] +
[") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
if self.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])
return prg, global_size[::-1], local_size[::-1]
# returns a str statement that does the store
def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
if isinstance(buf_dtype, ImageDType):
assert var_dtype == dtypes._float4, "images must be float4"
return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
@@ -62,18 +96,7 @@ class CStyleLanguage(NamedTuple):
return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
code_for_op: Final[Dict[Op, Callable]] = {
UnaryOps.EXP2: lambda x: f"exp2({x})",
UnaryOps.LOG2: lambda x: f"log2({x})",
UnaryOps.SIN: lambda x: f"sin({x})",
UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
BinaryOps.MAX: lambda a,b: f"max({a},{b})",
BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
}
def add_gl_dimension(args, i, var, local_size, xid):
def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
# for M1 tensor core stuff, support > 3 dims
if i >= 2 and len(args[0]) > len(xid):
# do this on the x dim for warps
@@ -82,19 +105,14 @@ def add_gl_dimension(args, i, var, local_size, xid):
lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
assert lidx.max == var.max and lidx.min == var.min
return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
local_size.append(var.max+1)
return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
prekernel: Set[str] = set()
kernel = []
global_size = []
local_size = []
kernel,global_size,local_size,prekernel = [],[],[],[]
pend_close = None
bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
depth = 0
def kk(s): kernel.append(" "*depth+s)
@@ -108,12 +126,12 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
kk("{")
else:
if args[1] == "global" and lang.gid:
kk(add_gl_dimension(args, i, var, global_size, lang.gid))
kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
elif args[1] == "local" and lang.lid:
kk(add_gl_dimension(args, i, var, local_size, lang.lid))
kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
else:
if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
kk(lang.render_for(var.expr, var.min, var.max))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)
@@ -139,10 +157,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
elif uop == UOps.CONST:
assert newvar is not None
kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};")
kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {lang.render_const(args, newvar.dtype)};")
elif uop == UOps.ALU:
assert newvar is not None
kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};")
kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};")
elif uop == UOps.LOAD:
assert newvar is not None
# valids are handled here
@@ -152,8 +170,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype)
else:
val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer))
if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}"
kk(f"{newvar.render(True)} = {val};")
if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(0.0, newvar.dtype))
kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {val};")
elif uop == UOps.STORE:
assert args.valid.min == 1, "store must be valid"
# TODO: instead of dtypes.float, a base type
@@ -161,20 +179,14 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
elif uop == UOps.DEFINE_LOCAL:
kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
if lang.external_local_bufs:
prekernel.append(lang.render_local(args[0], args[1]))
else:
kk(lang.render_local(args[0], args[1]))
else:
raise RuntimeError(f"failed to render {uop}")
if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
[', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
[") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])
if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
return prg, global_size, local_size
return lang.render_kernel(kernel, bufs, bufnames, global_size, local_size, prekernel)
class CStyleCodegen(Linearizer):
lang: ClassVar[CStyleLanguage] = CStyleLanguage()
@@ -202,5 +214,5 @@ class CStyleCodegen(Linearizer):
CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
global_size[::-1], local_size[::-1],
global_size, local_size,
op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)
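The new generic_var_prefix hook is what lets WGSL reuse this renderer: C-style targets declare `float val0 = ...` while WGSL declares `var val0 = ...` and infers the type. A standalone sketch of the branch (not the real Token API):

# Sketch of the declaration logic above, isolated for clarity
def render_decl(generic_var_prefix: str, dtype: str, name: str, rhs: str) -> str:
    if generic_var_prefix:                      # WGSL path: type is inferred
        return f"{generic_var_prefix}{name} = {rhs};"
    return f"{dtype} {name} = {rhs};"           # C/OpenCL/Metal path: explicit type

print(render_decl("", "float", "val0", "data1[0]"))      # float val0 = data1[0];
print(render_decl("var ", "float", "val0", "data1[0]"))  # var val0 = data1[0];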

tinygrad/codegen/linearizer.py

@@ -200,7 +200,7 @@ class Linearizer:
should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]
def global_load(self, i, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
def global_load(self, i:int, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
expanded_nodes = [expand_node(idx) for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
upcast_dim = self.get_upcast_dim(i)
@@ -217,10 +217,10 @@ class Linearizer:
localtype = dtypes._float4 if amt == 4 else dtypes._float2
if idx.render() != ((idx//amt)*amt).render():
idx, valid = self.sts[i].expr_idxs(_idx)
localtype = dtypes.float
localtype = dtypes.float32
else:
idx, valid = self.sts[i].expr_idxs(_idx)
localtype = dtypes.float
localtype = dtypes.float32
key = f"{localtype}{idx.render()}{valid.render()}"
if key not in cache:
if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)

tinygrad/codegen/wgsl.py (new file)

@@ -0,0 +1,54 @@
from tinygrad.lazy import LazyBuffer
from tinygrad.codegen.cstyle import render_cl
from tinygrad.helpers import dtypes, DType
from tinygrad.codegen.linearizer import LocalBuffer
from tinygrad.codegen.cstyle import CStyleLanguage
from typing import List, Union
from tinygrad.runtime.lib import RawConst
from tinygrad.ops import UnaryOps, BinaryOps, FusedOps
import math
from typing import Tuple
type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"}
class WGSLLanguage(CStyleLanguage):
gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
size_prefix = "let"
barrier="workgroupBarrier();"
generic_var_prefix = "var "
external_local_bufs = True
code_for_op = {
UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPEQ: lambda x,y: f"f32({x}=={y})",
FusedOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})",
}
def render_local(self, name: str, size: int):
return f"var<workgroup> {name}: array<f32,{size}>;"
def render_const(self, x:Union[float,int], var_dtype) -> str:
if math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val
def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
local_size = local_size[::-1] if len(local_size) else [1]
bind_it = iter(range(len(bufs)))
prg = "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> data{i}: array<{type_map[x.dtype]}>;" for i,x in enumerate(bufs) if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)])
prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn KERNEL_NAME_PLACEHOLDER(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
return prg, global_size[::-1] if len(global_size) else [1], local_size
def render_for(self, expr:str, _min:int, _max:int) -> str:
return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"
def render_conditional(self, cond:str, x:str, y:str) -> str:
return f"select(f32({y}), {x}, bool({cond}))"
def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"
def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
if buf_dtype != var_dtype:
var_name = f"{type_map[buf_dtype]}({var_name})"
return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
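Put together, render_kernel assembles shaders shaped like the hand-written example below; the kernel name, buffer count, and loop bounds are invented for the illustration:

# Hand-written illustration of the WGSL this class emits for a trivial
# elementwise add; names and sizes are made up.
EXAMPLE_WGSL = """
@group(0) @binding(0) var<storage,read_write> data0: array<f32>;
@group(0) @binding(1) var<storage,read_write> data1: array<f32>;
@compute @workgroup_size(1) fn E_4(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {
  for(var idx0 = 0; idx0 <= 3; idx0++) {
    var val0 = f32(data1[idx0]);
    data0[idx0] = (val0+val0);
  }
}
"""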

tinygrad/jit.py

@@ -18,7 +18,7 @@ class TinyJit:
def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)
def __call__(self, *args, **kwargs) -> Any:
if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
# NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
assert len(input_rawbuffers) != 0, "no inputs to JIT"
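With WEBGPU added to the allow-list, the jit_model pattern from the compile scripts works unchanged; a minimal use, built only from APIs shown in this diff:

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

@TinyJit
def double(x: Tensor) -> Tensor:
    return (x * 2).realize()

a = Tensor.randn(4)
double(a)   # first call runs normally
double(a)   # second call captures jit_cache, which compile_net then walks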

tinygrad/runtime/ops_webgpu.py

@@ -0,0 +1,40 @@
import numpy as np
from wgpu.utils._device import get_default_device
from tinygrad.runtime.lib import RawBufferCopyIn
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import Compiled
from tinygrad.codegen.cstyle import CStyleCodegen
from tinygrad.codegen.wgsl import WGSLLanguage
import wgpu
device = get_default_device()
class WebGPUProgram:
def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg)
def __call__(self, global_size, local_size, *bufs, wait=False):
binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)
compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
command_encoder = device.create_command_encoder()
compute_pass = command_encoder.begin_compute_pass()
compute_pass.set_pipeline(compute_pipeline)
compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
compute_pass.dispatch_workgroups(*global_size) # x y z
compute_pass.end()
device.queue.submit([command_encoder.finish()])
class WGSLCodegen(CStyleCodegen):
lang = WGSLLanguage()
supports_float4: bool = False
class RawWebGPUBuffer(RawBufferCopyIn):
def __init__(self, size:int, dtype:DType):
assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU"
super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC))
def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore
WebGpuBuffer = Compiled(RawWebGPUBuffer, WGSLCodegen, WebGPUProgram)
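A sanity-check round-trip through the new raw buffer, assuming the module path above and a wgpu install that can reach a backend:

import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_webgpu import RawWebGPUBuffer  # path assumed from this diff

buf = RawWebGPUBuffer(4, dtypes.float32)               # 4 elements, 16 bytes on device
buf._copyin(np.array([1, 2, 3, 4], dtype=np.float32))  # queue.write_buffer under the hood
print(buf.toCPU())                                     # [1. 2. 3. 4.]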

tinygrad/state.py

@@ -24,7 +24,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
t[0:1].cast(dtypes.int64).assign([len(j)])
t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
for k,v in safe_load(t).items(): v.assign(tensors[k])
# state dict
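The touched line only pins the metadata staging tensor to the CPU device; the public API is unchanged:

from tinygrad.tensor import Tensor
from tinygrad.state import safe_save, safe_load

safe_save({"w": Tensor([1.0, 2.0, 3.0])}, "/tmp/net.safetensors")
print(safe_load("/tmp/net.safetensors")["w"].numpy())  # [1. 2. 3.]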