Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
Webgpu support (#1077)
* initial commit
* 81 passing
* 105 passing tests
* 148 passing
* CI tests
* install dep on ci
* try opencl pkgs
* try using vulkan
* down to only 6 failing
* refactor
* cleaning up
* another test skipped due to buffer limit
* linter
* segfault
* indent fix
* another segfault found
* small touchups
* Fix max and maxpool tests
* Add constant folding
* Add javascript export script
* better asserts in codegen
* manual upcasting
* reverted token type change
* skip safetensor test due to unsupported type
* FIx efficientnet and all other model tests
* Remove np copy
* fixed indent and missing import
* manually destroy the buffer
* revert back to length
* linter errors
* removed extra val
* skip broken tests
* skipping more tests
* Make the page pretty
* Save model weights as safetensor
* Fix imagenet to c test
* Fix second imagenet to c bug
* Async and paralel kernel compilation
* workgroup support
* reversed local size
* fixed non local bug
* correct local groups
* ci experiment
* removed typo
* Fix define local by using shared memory
* Refactor
* try running on mac
* match metal tests
* add more workers
* scope down tests
* trying windows runner
* fixed windows env
* see how many it can do
* merged master
* refactor
* missed refactor
* increase test suite coverage
* missing import
* whitespace in test_efficientnet.py
* getting there
* fixed reset
* fixed bufs
* switched to cstyle
* cleanup
* min/max rename
* one more linter issue
* fixed demo
* linter
* testing ci chrome
* add unsafe webgpu arg
* add build step
* remove WEBGPU from cmd line
* use module
* try forcing directx
* trying forced metal backend
* temp disable conv2d for CI
* disable conv_trasnpose2d

---------

Co-authored-by: 0x4d - Martin Loretz <20306567+martinloretzzz@users.noreply.github.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
.github/workflows/test.yml | 23 (vendored)
@@ -55,7 +55,30 @@ jobs:
      run: awk '/```python/{flag=1;next}/```/{flag=0}flag' docs/quickstart.md > quickstart.py && PYTHONPATH=. python3 quickstart.py
    - name: Run Pytest
      run: python -m pytest -s -v -n=auto test/

  testwebgpu:
    name: WebGPU Tests
    runs-on: macos-13

    steps:
    - name: Checkout Code
      uses: actions/checkout@v3
    - name: Set up Python 3.8
      uses: actions/setup-python@v4
      with:
        python-version: 3.8
    - name: Install Dependencies
      run: pip install -e '.[testing,webgpu]' --extra-index-url https://download.pytorch.org/whl/cpu
    # - name: Set Env
    #   run: printf "WEBGPU=1\nWGPU_BACKEND_TYPE=D3D12\n" >> $GITHUB_ENV
    - name: Run Pytest
      run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -s -v -n=auto test/test_ops.py test/test_speed_v_torch.py test/test_nn.py test/test_jit.py test/test_randomness.py test/test_tensor.py test/test_assign.py test/test_conv.py test/test_nn.py test/test_custom_function.py test/test_conv_shapetracker.py
    - name: Build WEBGPU Efficientnet
      run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.webgpu.compile_webgpu
    # - name: Install Puppeteer
    #   run: npm install puppeteer
    # - name: Run Efficientnet
    #   run: node test/test_webgpu.js

  testimagenet:
    name: ImageNet to C Compile Test
    runs-on: ubuntu-latest
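The same backend selection can be reproduced outside CI. A minimal local sketch, not part of the diff (it assumes the `webgpu` extra from setup.py is installed and a WebGPU adapter is available; the environment variable names come from the workflow above and must be set before tinygrad is imported):

import os
os.environ["WEBGPU"] = "1"                    # pick the new WEBGPU device, as the CI step does
# os.environ["WGPU_BACKEND_TYPE"] = "Metal"   # optional backend hint, mirroring the workflow above

from tinygrad.tensor import Tensor

x = Tensor([[1.0, 2.0], [3.0, 4.0]])
print((x + x).numpy())                        # executes through the WGSL/wgpu path added in this PR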
.gitignore | 5 (vendored)
@@ -31,3 +31,8 @@ extra/datasets/kits/
extra/datasets/COCO/
extra/datasets/audio*
venv
examples/webgpu/net.js
examples/webgpu/net.safetensors
node_modules
package.json
package-lock.json
examples/compile_efficientnet.py

@@ -1,15 +1,11 @@
from models.efficientnet import EfficientNet
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from extra.utils import fetch
import ast

def compile_net(run, special_names):
  # functions that run the net
  functions = {}
  bufs = {}
  bufnum = 0
  statements = []
  bufs_to_save = {}
  functions, bufs, bufs_to_save, statements, bufnum = {}, {}, {}, [], 0
  for fxn,args in run.jit_cache:
    functions[fxn.name] = fxn.prg # NOTE: this assumes all with the same name are the same
    cargs = []
@@ -17,26 +13,21 @@ def compile_net(run, special_names):
      key = id(arg)
      if key not in bufs:
        if key in special_names:
          bufs[key] = (special_names[key], len(arg._buf))
          bufs[key] = (special_names[key], arg._memsz, key)
        else:
          bufs[key] = (f"buf_{bufnum}", len(arg._buf))
          bufs[key] = (f"buf_{bufnum}", arg._memsz, key)
          bufnum += 1
          if i > 0: bufs_to_save[bufs[key][0]] = arg # if first usage of a buffer is not an output, and it's not a special name
      cargs.append(bufs[key][0])
    statements.append(f"{fxn.name}({', '.join(cargs)});")
    statements.append((fxn.name, cargs, fxn.global_size))

  return functions, statements, bufs, bufs_to_save

if __name__ == "__main__":
  model = EfficientNet(0)
  model.load_from_pretrained()

  from tinygrad.jit import TinyJit
def jit_model(model, the_input):
  @TinyJit
  def run(x): return model.forward(x).realize()

  # twice to run the JIT
  the_input = Tensor.randn(1,3,224,224)
  the_output = run(the_input)
  the_output = run(the_input)

@@ -47,7 +38,12 @@ if __name__ == "__main__":

  # TODO: fetch this from the jit in self.input_replace and self.ret (hint: use get_parameters on self.ret)
  special_names = {id(the_input.lazydata.realized): "input", id(the_output.lazydata.realized): "outputs"}
  return run, special_names

if __name__ == "__main__":
  model = EfficientNet(0)
  model.load_from_pretrained()
  run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
  functions, statements, bufs, bufs_to_save = compile_net(run, special_names)

  # c header
@@ -68,13 +64,13 @@ if __name__ == "__main__":
  cprog.append(f"char *lbls[] = {{{','.join(lbls)}}};")

  # buffers (empty + weights)
  cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len in bufs.values()]
  cprog += [f"float {name}[{len}];" if name not in bufs_to_save else f"float *{name} = (float *){name}_data;" for name,len,_key in bufs.values()]

  # the functions
  cprog += list(functions.values())

  # the net
  cprog += ["void net() {"] + statements + ["}"]
  # the net
  cprog += ["void net() {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size) in statements] + ["}"]

  cprog += ["""
int main(int argc, char* argv[]) {
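The data contract this refactor introduces is easiest to see in isolation: compile_net now returns statements as (kernel_name, arg_names, global_size) tuples instead of pre-rendered C calls, so each exporter renders them itself. A small standalone sketch (the kernel entry below is made up for illustration):

statements = [("r_256_16", ["buf_0", "input", "buf_1"], [16, 1, 1])]  # hypothetical entry

# C exporter (compile_efficientnet.py above): drop the global size, render a plain call.
c_net = ["void net() {"] + [f"{name}({', '.join(args)});" for (name, args, _global_size) in statements] + ["}"]

# WebGPU exporter (compile_webgpu.py below): keep the global size as the dispatch dimensions.
js_calls = [f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});"
            for i, (_name, args, global_size) in enumerate(statements)]

print("\n".join(c_net))
print("\n".join(js_calls))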
examples/webgpu/compile_webgpu.py | 87 (new file)
@@ -0,0 +1,87 @@
from os import path
from examples.compile_efficientnet import compile_net, jit_model
from models.efficientnet import EfficientNet
from tinygrad.state import get_state_dict, safe_save
from tinygrad.tensor import Tensor

if __name__ == "__main__":
  model = EfficientNet(0)
  model.load_from_pretrained()
  run, special_names = jit_model(model, Tensor.randn(1,3,224,224))
  functions, statements, bufs, _bufs_to_save = compile_net(run, special_names)

  state = get_state_dict(model)
  weights = {id(x.lazydata.realized): name for name, x in state.items()}
  safe_save(state, path.join(path.dirname(__file__), "net.safetensors"))

  kernel_code = '\n\n'.join([f"const {key} = `{code.replace(key, 'main')}`;" for key, code in functions.items()])
  kernel_names = ', '.join([name for (name, _args, _global_size) in statements])
  kernel_calls = '\n '.join([f"addComputePass(device, commandEncoder, piplines[{i}], [{', '.join(args)}], {global_size});" for i, (_name, args, global_size) in enumerate(statements) ])
  bufs = '\n '.join([f"const {buf[0]} = " + (f"createEmptyBuf(device, {buf[1]});" if buf[2] not in weights else f"createWeightBuf(device, {buf[1]}, getTensorBuffer(safetensor, metadata['{weights[buf[2]]}']))") + ";" for buf in bufs.values()])

  prg = f"""const getTensorMetadata = (safetensorBuffer) => {{
  const metadataLength = Number(new DataView(safetensorBuffer.buffer).getBigUint64(0, true));
  const metadata = JSON.parse(new TextDecoder("utf8").decode(safetensorBuffer.subarray(8, 8 + metadataLength)));
  return Object.fromEntries(Object.entries(metadata).filter(([k, v]) => k !== "__metadata__").map(([k, v]) => [k, {{...v, data_offsets: v.data_offsets.map(x => 8 + metadataLength + x)}}]));
}};

const getTensorBuffer = (safetensorBuffer, tensorMetadata) => {{
  return safetensorBuffer.subarray(...tensorMetadata.data_offsets);
}}

const createEmptyBuf = (device, size) => {{
  return device.createBuffer({{size, usage: GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST }});
}};

const createWeightBuf = (device, size, data) => {{
  const buf = device.createBuffer({{ mappedAtCreation: true, size, usage: GPUBufferUsage.STORAGE }});
  new Uint8Array(buf.getMappedRange()).set(data);
  buf.unmap();
  return buf;
}};

const addComputePass = (device, commandEncoder, pipeline, bufs, workgroup) => {{
  const bindGroup = device.createBindGroup({{layout: pipeline.getBindGroupLayout(0), entries: bufs.map((buffer, index) => ({{ binding: index, resource: {{ buffer }} }}))}});
  const passEncoder = commandEncoder.beginComputePass();
  passEncoder.setPipeline(pipeline);
  passEncoder.setBindGroup(0, bindGroup);
  passEncoder.dispatchWorkgroups(...workgroup);
  passEncoder.end();
}};

{kernel_code}

const setupNet = async (device, safetensor) => {{
  const metadata = getTensorMetadata(safetensor);

  {bufs}

  const gpuWriteBuffer = device.createBuffer({{size:input.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});
  const gpuReadBuffer = device.createBuffer({{ size: outputs.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});

  const kernels = [{kernel_names}];
  const piplines = await Promise.all(kernels.map(name => device.createComputePipelineAsync({{layout: "auto", compute: {{ module: device.createShaderModule({{ code: name }}), entryPoint: "main" }}}})));

  return async (data) => {{
    await gpuWriteBuffer.mapAsync(GPUMapMode.WRITE);
    new Float32Array(gpuWriteBuffer.getMappedRange()).set(data);
    gpuWriteBuffer.unmap();

    const commandEncoder = device.createCommandEncoder();
    commandEncoder.copyBufferToBuffer(gpuWriteBuffer, 0, input, 0, gpuWriteBuffer.size);
    {kernel_calls}
    commandEncoder.copyBufferToBuffer(outputs, 0, gpuReadBuffer, 0, outputs.size);
    const gpuCommands = commandEncoder.finish();
    device.queue.submit([gpuCommands]);

    await gpuReadBuffer.mapAsync(GPUMapMode.READ);
    const resultBuffer = new Float32Array(gpuReadBuffer.size);
    resultBuffer.set(new Float32Array(gpuReadBuffer.getMappedRange()));
    gpuReadBuffer.unmap();
    return resultBuffer;
  }}
}}
"""

  with open(path.join(path.dirname(__file__), "net.js"), "w") as text_file:
    text_file.write(prg)
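The generated net.js reads the weights with the getTensorMetadata helper above. A rough Python counterpart can be handy for sanity-checking the net.safetensors file this script writes; this is a sketch, not part of the PR, and only reads the 8-byte length prefix plus the JSON header:

import json, struct

def read_safetensors_header(path="examples/webgpu/net.safetensors"):
  with open(path, "rb") as f:
    (header_len,) = struct.unpack("<Q", f.read(8))       # little-endian u64 length, as in the JS helper
    metadata = json.loads(f.read(header_len))
  metadata.pop("__metadata__", None)
  # data_offsets are relative to the end of the header, matching the JS helper
  return {k: {**v, "data_offsets": [8 + header_len + o for o in v["data_offsets"]]}
          for k, v in metadata.items()}

if __name__ == "__main__":
  meta = read_safetensors_header()
  print(len(meta), "tensors;", next(iter(meta.items())))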
examples/webgpu/index.html | 119 (new file)
@@ -0,0 +1,119 @@
<html>
<head>
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <style>
    #result { font-size: 48px; }
    #time { font-size: 16px; color: grey; }
    #mybox { padding: 20px; }
    #resultbox { padding: 50px; }
    .bigggg { font-size: 18px; margin-top: 10px; }
    .bigg { font-size: 18px; }
    #url { font-size: 18px; width: 70%; }
    a { text-decoration: none; }
    h1 { padding: 50px; padding-bottom: 0px; font-size: 36px; font-weight: normal; }
    #imagebox { height:224px; width:224px; border: 1px dotted black; }
    #video { height:0px; width:0px; border: 1px dotted black; object-fit: cover;}
    canvas { display: none; }
    * { text-align: center; font-family: monospace; }
  </style>
  <title>tinygrad has WebGPU</title>
  <script src="./net.js"></script>
</head>
<body>
  <h1>WebGPU <a href="https://github.com/geohot/tinygrad">tinygrad</a> EfficientNet!</h1>
  <div id="mybox">
    <input type="text" id="url" placeholder="put url here" value="https://upload.wikimedia.org/wikipedia/commons/d/da/Norwegian_hen.jpg">
    <input class="bigg" type="button" onclick="runNetWResource(document.getElementById('url').value)" value="Use URL">
  </div>
  <br/>
  <img id="imagebox"></img>
  <canvas id="canvas" width="200" height="200"> </canvas>
  <div id="resultbox">
    <div id="result">result will go here</div>
    <div id="time"></div>
  </div>
  <script>
    const ctx = document.getElementById("canvas").getContext("2d", { willReadFrequently: true });
    const resultText = document.getElementById('result');
    let labels, net;

    const error = (err) => {
      resultText.innerHTML = `Error: ${err}`;
      throw new Error(err);
    }

    const getDevice = async () => {
      if (!navigator.gpu) error("WebGPU not supported.");
      const adapter = await navigator.gpu.requestAdapter();
      return await adapter.requestDevice();
    };

    const timer = async (func, label = "") => {
      document.getElementById('time').innerHTML = "";
      const start = performance.now();
      const out = await func();
      const delta = (performance.now() - start).toFixed(1)
      console.log(`${delta} ms ${label}`);
      document.getElementById('time').innerHTML = `${delta} ms ${label}`;
      return out;
    }

    const getLabels = async () => (await fetch("https://raw.githubusercontent.com/anishathalye/imagenet-simple-labels/master/imagenet-simple-labels.json")).json();

    const getSavetensorBuffer = async () => new Uint8Array(await (await fetch("./net.safetensors")).arrayBuffer());

    const reorderChannelsAndRemoveAlpha = (data) => {
      const out = [];
      let i = 0;
      for (let c = 0; c < 3; c++) {
        for (let x = 0; x < 224 * 224; x++) {
          out[i] = data[x * 4 + c];
          i++;
        }
      }
      return out;
    };

    const runNetWResource = async (resource) => {
      resultText.innerHTML = "pending..."
      if (resource == "") error("sir. please type in a URL");
      const response = await fetch(resource)
      if (!response.ok) error("sir. that is not a good URL. try a new one");
      document.getElementById("imagebox").src = resource

      const img = new Image();
      img.crossOrigin = "Anonymous";
      img.onload = () => {
        URL.revokeObjectURL(img.src);
        ctx.drawImage(img, 0, 0, 224, 224);
        const data = ctx.getImageData(0, 0, 224, 224).data;
        runNet(data)
      };
      img.src = resource;
    }

    const loadLet = async () => {
      resultText.innerHTML = "loading..."
      labels = await getLabels();
      const safetensor = await getSavetensorBuffer();
      const device = await getDevice();
      net = await timer(() => setupNet(device, safetensor), "(compilation)");
      resultText.innerHTML = "ready"
    }

    const runNet = async (data) => {
      if (!net) error("Net not loaded yet.");

      const input = reorderChannelsAndRemoveAlpha(Array.from(data).map((pix) => (pix / 255.0) * 0.45 - 0.225));
      const out = await timer(() => net(new Float32Array(input)));

      const arr = Array.from(new Float32Array(out));
      const index = arr.indexOf(Math.max(...arr));

      resultText.textContent = labels[index];
    };

    loadLet();
  </script>
</body>
</html>
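The browser-side preprocessing above (reorderChannelsAndRemoveAlpha plus the (pix / 255) * 0.45 - 0.225 scaling) maps an RGBA canvas readback to the flattened CHW float input the net expects. A rough numpy equivalent, useful for testing the exported model outside the browser (a sketch, not part of the PR):

import numpy as np

def preprocess(rgba_u8):                      # rgba_u8: (224, 224, 4) uint8, as returned by getImageData
  x = rgba_u8.astype(np.float32) / 255.0 * 0.45 - 0.225
  x = x[..., :3]                              # drop the alpha channel
  return x.transpose(2, 0, 1).reshape(-1)     # HWC -> CHW, flattened like the JS array

print(preprocess(np.zeros((224, 224, 4), dtype=np.uint8)).shape)  # (150528,)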
setup.py | 1
@@ -25,6 +25,7 @@ setup(name='tinygrad',
        'llvm': ["llvmlite"],
        'cuda': ["pycuda"],
        'triton': ["triton>=2.0.0.dev20221202"],
        'webgpu': ["wgpu"],
        'metal': ["pyobjc-framework-Metal", "pyobjc-framework-Cocoa", "pyobjc-framework-libdispatch"],
        'linting': [
            "flake8",
@@ -46,6 +46,7 @@ class TestTrain(unittest.TestCase):
    train_one_step(model,X,Y)
    check_gc()

  @unittest.skipIf(Device.DEFAULT == "WEBGPU", "too many buffers for webgpu")
  def test_vit(self):
    model = ViT()
    X = np.zeros((BS,3,224,224), dtype=np.float32)
@@ -28,7 +28,7 @@ def _test_matmul_upcast(a:Tensor, b:Tensor, target_dtype:DType, target): _test_o

# for GPU, cl_khr_fp16 isn't supported (except now we don't need it!)
# for LLVM, it segfaults because it can't link to the casting function
@unittest.skipIf(getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"], "float16 broken in some CI backends")
@unittest.skipIf((getenv("CI", "") != "" and Device.DEFAULT in ["LLVM"]) or Device.DEFAULT == "WEBGPU", "float16 broken in some CI backends")
class TestHalfDtype(unittest.TestCase):
  def test_half_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.float16), np.float16, [1,2,3,4])

@@ -53,6 +53,7 @@ class TestHalfDtype(unittest.TestCase):
  def test_half_matmul_upcast_float(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.float16), Tensor.eye(2, dtype=dtypes.float32), dtypes.float32, [[1,2],[3,4]])
  def test_int8_matmul_upcast_half(self): _test_matmul_upcast(Tensor([[1,2],[3,4]], dtype=dtypes.int8), Tensor.eye(2, dtype=dtypes.float16), dtypes.float16, [[1,2],[3,4]])

@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu does not support int8")
class TestInt8Dtype(unittest.TestCase):
  def test_int8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.int8), np.int8, [1,2,3,4])
  def test_uint8_to_np(self): _test_to_np(Tensor([1,2,3,4], dtype=dtypes.uint8), np.uint8, [1,2,3,4])
@@ -95,7 +95,7 @@ class TestNN(unittest.TestCase):
    torch_z = torch_layer(torch_x)
    np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5)

  @unittest.skipIf(getenv("CI", "") != "" and WINDOWS, "runs out of memory in CI")
  @unittest.skipIf(getenv("CI", "") != "" and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI")
  def test_conv_transpose2d(self):
    BS, C1, H, W = 4, 16, 224, 224
    C2, K, S, P = 64, 7, 2, 1
@@ -197,12 +197,12 @@ class TestOps(unittest.TestCase):
    helper_test_op([(45,65)], lambda x: 2/x, lambda x: 2/x)
    helper_test_op([()], lambda x: x/2, lambda x: x/2)
    helper_test_op([()], lambda x: 2/x, lambda x: 2/x)
  @unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf")
  @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf")
  def test_mul_const_naninf(self):
    helper_test_op([(45,65)], lambda x: x*float("inf"), lambda x: x*float("inf"))
    helper_test_op([(45,65)], lambda x: x*-float("inf"), lambda x: x*-float("inf"))
    helper_test_op([(45,65)], lambda x: x*float("nan"), lambda x: x*float("nan"))
  @unittest.skipIf(Device.DEFAULT == "METAL", "METAL has issues with -inf")
  @unittest.skipIf(Device.DEFAULT in ["METAL", "WEBGPU"], "WEBGPU does not have support for inf/nan, METAL has issues with -inf")
  def test_div_const_naninf(self):
    helper_test_op([(45,65)], lambda x: x/float("inf"), lambda x: x/float("inf"))
    helper_test_op([(45,65)], lambda x: x/-float("inf"), lambda x: x/-float("inf"))
@@ -726,7 +726,7 @@ class TestOps(unittest.TestCase):
      lambda x,w: torch.nn.functional.conv_transpose3d(x,w).relu(),
      lambda x,w: Tensor.conv_transpose2d(x,w).relu(), atol=1e-4, grad_rtol=1e-5)

  @unittest.skipIf(IMAGE>0, "no conv1d on images")
  @unittest.skipIf((IMAGE>0 or (Device.DEFAULT == "WEBGPU" and getenv("CI","") != "")), "no conv1d on images")
  def test_conv1d(self):
    for bs in [1,8]:
      for cin in [1,3]:
@@ -19,7 +19,7 @@ class TestSpecific(unittest.TestCase):
    w = Tensor.randn(2048, 512)
    (x @ w).reshape(1, 128, 4).contiguous().realize()

  @unittest.skipIf(Device.DEFAULT == "LLVM", "Broken on LLVM")
  @unittest.skipIf(Device.DEFAULT in ["LLVM", "WEBGPU"], "Broken on LLVM and webgpu")
  def test_big_vec_mul(self):
    # from LLaMA
    # 0 buffer<4096, dtypes.float> [View((1024, 1, 1, 4), (4, 0, 0, 1), 0, None)]
@@ -131,7 +131,7 @@ class TestBigSpeed(unittest.TestCase):
  def test_large_conv_1x1(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=1, img_size_y=128, img_size_x=128)
  def test_large_conv_3x3(self): helper_test_conv(bs=32, in_chans=128, out_chans=128, kernel_size=3, img_size_y=130, img_size_x=130)

@unittest.skipIf(getenv("BIG") == 1, "only big tests")
@unittest.skipIf((getenv("BIG") == 1 or Device.DEFAULT == "WEBGPU"), "only big tests")
class TestSpeed(unittest.TestCase):
  def setUp(self):
    global prefix
test/test_webgpu.js | 51 (new file)
@@ -0,0 +1,51 @@
const puppeteer = require('puppeteer');
const { spawn } = require('child_process');
const res = spawn("python", ["-m", "http.server", "8000"], { shell: true });

async function timeout(time) {
  return new Promise((resolve) => setTimeout(resolve, time));
}

function cleanup(err) {
  res.kill();
  if(err != null) {
    console.error(err);
    process.exit(1);
  }
}

async function waitForText(selector, text) {
  let n = 0;
  let ready = false;
  while (n < 10) {
    const res = await (await selector.getProperty("textContent")).jsonValue();
    console.log(`waiting for text ${text} got ${res}`);
    if(res == text) {
      ready = true;
      break
    }
    await timeout(2000);
    n += 1
  }
  return ready;
}

puppeteer.launch({ headless: false, args: ["--enable-unsafe-webgpu"]}).then(async browser => {
  const page = await browser.newPage();
  page.on("console", message => console.log(`message from console ${message.text()}`))
    .on("pageerror", ({ message }) => console.log(`error from page ${message}`))

  const res = await page.goto("http://localhost:8000/examples/webgpu/index.html");
  if(res.status() != 200) throw new Error("Failed to load page");
  const textSelector = await page.waitForSelector("#result");
  const buttonSelector = await page.waitForSelector("input[type=button]");
  const ready = await waitForText(textSelector, "ready");
  if(!ready) throw new Error("Failed to load page");
  await buttonSelector.evaluate(e => e.click());
  const done = await waitForText(textSelector, "hen");
  if(!done) throw new Error("failed to get hen");
  browser.close();
  cleanup(null);
}).catch(err => {
  cleanup(err);
});
@@ -1,8 +1,8 @@
import pathlib
import unittest
import numpy as np
from tinygrad.tensor import Tensor
from tinygrad.state import safe_load, safe_save, torch_load, get_state_dict
from tinygrad.tensor import Tensor, Device
from tinygrad.state import safe_load, safe_save, get_state_dict, torch_load
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_disk import RawDiskBuffer
from tinygrad.helpers import Timing
@@ -46,7 +46,7 @@ class TestRawDiskBuffer(unittest.TestCase):
      tst = np.empty(test_size, np.uint8)
      with Timing("copy in ", lambda et_ns: f" {test_size/et_ns:.2f} GB/s"):
        np.copyto(tst, db.toCPU())

@unittest.skipIf(Device.DEFAULT == "WEBGPU", "webgpu doesn't support uint8 datatype")
class TestSafetensors(unittest.TestCase):
  def test_real_safetensors(self):
    import torch
tinygrad/codegen/cstyle.py

@@ -1,8 +1,8 @@
from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Set, Union
from typing import Final, Dict, ClassVar, List, Optional, NamedTuple, DefaultDict, Tuple, Union
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod
from tinygrad.ops import ASTRunner, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import ImageDType, dtypes, colored, getenv, prod, DType
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable
from tinygrad.lazy import LazyBuffer
@@ -13,6 +13,8 @@ render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops, ctx)}/{self.b})
render_cl[AndNode] = lambda self,ops,ctx: f"({'&&'.join(sorted([x.render(ops,ctx) for x in self.nodes]))})"

class CStyleLanguage(NamedTuple):
  size_prefix: str = "int"
  generic_var_prefix: str = ""
  kernel_prefix: str = ""
  buffer_prefix: str = ""
  buffer_suffix: str = ""
@@ -24,9 +26,20 @@ class CStyleLanguage(NamedTuple):
  float4: Optional[str] = None
  half_prekernel: Optional[str] = None
  uses_vload: bool = False
  external_local_bufs: bool = False
  code_for_op: Dict = {
    UnaryOps.EXP2: lambda x: f"exp2({x})",
    UnaryOps.LOG2: lambda x: f"log2({x})",
    UnaryOps.SIN: lambda x: f"sin({x})",
    UnaryOps.SQRT: lambda x: f"sqrt({x})",
    BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
    BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
    BinaryOps.MAX: lambda a,b: f"max({a},{b})",
    BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
  }

  # returns a str expression of the casted xs with the given type
  def render_cast(self, x:List[str], var_dtype) -> str:
  def render_cast(self, x:List[str], var_dtype:DType) -> str:
    assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
    assert self.float4 is not None, "cast is not supported on this platform"
    if var_dtype == dtypes._float4: return f"{self.float4}({','.join(x)})"
@@ -50,9 +63,30 @@ class CStyleLanguage(NamedTuple):
    if output_dtype.sz > 1:
      return f"({output_dtype.name})(*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})))"
    return f"{buf_name}[{idx.render(render_cl)}]"

  def render_local(self, name:str, size:int):
    return self.smem_prefix + f"float {name}[{size}];"

  def render_for(self, expr: str, _min:int, _max:int) -> str:
    return f"for (int {expr} = {_min}; {expr} <= {_max}; ++{expr}) {{"

  def render_conditional(self, cond: str, x:str, y:str) -> str:
    return f"({cond})?({x}):{y}"

  def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str,List[int],List[int]]:
    tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(x.dtype, ImageDType) for x in bufs) else ""
    buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
                ("const " if i > 0 else "")+self.buffer_prefix+x.dtype.name+"*"+self.buffer_suffix) for i,x in enumerate(bufs)
                if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
    prg = ''.join([f"{self.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
      [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + self.extra_args)] +
      [") {\n" + tmp] + ['\n'.join(kernel), "\n}"])
    if self.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{self.half_prekernel}", "\n", prg])

    return prg, global_size[::-1], local_size[::-1]

  # returns a str statement that does the store
  def render_store(self, buf_name, buf_dtype, var_name, var_dtype, idx, local=False) -> str:
  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
    if isinstance(buf_dtype, ImageDType):
      assert var_dtype == dtypes._float4, "images must be float4"
      return f"write_imagef({buf_name}, (int2)({idx[0].render(render_cl)}, {idx[1].render(render_cl)}), {var_name});"
@@ -62,18 +96,7 @@ class CStyleLanguage(NamedTuple):
      return f"*(({self.smem_prefix if local else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx.render(render_cl, strip_parens=True)})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
    return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"

code_for_op: Final[Dict[Op, Callable]] = {
  UnaryOps.EXP2: lambda x: f"exp2({x})",
  UnaryOps.LOG2: lambda x: f"log2({x})",
  UnaryOps.SIN: lambda x: f"sin({x})",
  UnaryOps.SQRT: lambda x: f"sqrt({x})",
  BinaryOps.ADD: lambda a,b: f"({a}+{b})", BinaryOps.SUB: lambda a,b: f"({a}-{b})",
  BinaryOps.MUL: lambda a,b: f"({a}*{b})", BinaryOps.DIV: lambda a,b: f"({a}/{b})",
  BinaryOps.MAX: lambda a,b: f"max({a},{b})",
  BinaryOps.CMPEQ: lambda a,b: f"({a}=={b})", FusedOps.MULACC: lambda a,b,c: f"(({a}*{b})+{c})"
}

def add_gl_dimension(args, i, var, local_size, xid):
def add_gl_dimension(prefix: str, args, i:int, var, local_size:List[int], xid:List[str]):
  # for M1 tensor core stuff, support > 3 dims
  if i >= 2 and len(args[0]) > len(xid):
    # do this on the x dim for warps
@@ -82,19 +105,14 @@ def add_gl_dimension(args, i, var, local_size, xid):
    lidx = Variable(xid[0], 0, prod(x.max+1 for x in args[0][2:])-1)
    lidx = (lidx//((lidx.max+1)//local_size[-1]))%(var.max+1)
    assert lidx.max == var.max and lidx.min == var.min
    return f"{{ int {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
    return f"{{ {prefix} {var.expr} = {lidx.render(render_cl)}; /* {var.max+1} */"
  local_size.append(var.max+1)
  return f"{{ int {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"
  return f"{{ {prefix} {var.expr} = {xid[min(len(xid), len(args[0]))-1-i]}; /* {var.max+1} */"

def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lang:CStyleLanguage) -> Tuple[str, List[int], List[int]]:
  prekernel: Set[str] = set()
  kernel = []
  global_size = []
  local_size = []
  kernel,global_size,local_size,prekernel = [],[],[],[]
  pend_close = None

  bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]

  depth = 0
  def kk(s): kernel.append(" "*depth+s)

@@ -108,12 +126,12 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
          kk("{")
      else:
        if args[1] == "global" and lang.gid:
          kk(add_gl_dimension(args, i, var, global_size, lang.gid))
          kk(add_gl_dimension(lang.size_prefix, args, i, var, global_size, lang.gid))
        elif args[1] == "local" and lang.lid:
          kk(add_gl_dimension(args, i, var, local_size, lang.lid))
          kk(add_gl_dimension(lang.size_prefix, args, i, var, local_size, lang.lid))
        else:
          if getenv("NOUNROLL"): kk("#pragma unroll(1)") # prevent loop unrolling
          kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
          kk(lang.render_for(var.expr, var.min, var.max))
        depth += 1
    elif uop == UOps.BARRIER:
      kk(lang.barrier)
@@ -139,10 +157,10 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
      kk(f"{vin[4].render()} = c.thread_elements()[0]; {vin[5].render()} = c.thread_elements()[1]; }}")
    elif uop == UOps.CONST:
      assert newvar is not None
      kk(f"{newvar.render(True)} = {lang.render_const(args, newvar.dtype)};")
      kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {lang.render_const(args, newvar.dtype)};")
    elif uop == UOps.ALU:
      assert newvar is not None
      kk(f"{newvar.render(newvar not in vin)} = {code_for_op[args](*[x.render() for x in vin])};")
      kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};")
    elif uop == UOps.LOAD:
      assert newvar is not None
      # valids are handled here
@@ -152,8 +170,8 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
        val = lang.render_const(bufs[args.i].realized._buf, newvar.dtype)
      else:
        val = lang.render_load(newvar.dtype, bufnames[args.i], bufs[args.i].dtype, args.idx, isinstance(bufs[args.i], LocalBuffer))
      if args.valid.min == 0 and args.valid.max == 1: val = f"({args.valid.render(render_cl)}) ? ({val}) : {lang.render_const(0.0, newvar.dtype)}"
      kk(f"{newvar.render(True)} = {val};")
      if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(0.0, newvar.dtype))
      kk(f"{lang.generic_var_prefix}{newvar.render(lang.generic_var_prefix == '')} = {val};")
    elif uop == UOps.STORE:
      assert args.valid.min == 1, "store must be valid"
      # TODO: instead of dtypes.float, a base type
@@ -161,20 +179,14 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
    elif uop == UOps.CAST and newvar is not None and newvar.dtype.sz > 1:
      kk(f"{newvar.render(True)} = {lang.render_cast([x.render() for x in vin], newvar.dtype)};")
    elif uop == UOps.DEFINE_LOCAL:
      kk(lang.smem_prefix + f"float {args[0]}[{args[1]}];")
      if lang.external_local_bufs:
        prekernel.append(lang.render_local(args[0], args[1]))
      else:
        kk(lang.render_local(args[0], args[1]))
    else:
      raise RuntimeError(f"failed to render {uop}")

  if any(isinstance(x.dtype, ImageDType) for x in bufs): prekernel.add("const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n")
  buftypes = [(i,f"{'read_only' if i > 0 else 'write_only'} image2d_t" if x.dtype.name.startswith('image') else
              ("const " if i > 0 else "")+lang.buffer_prefix+x.dtype.name+"*"+lang.buffer_suffix) for i,x in enumerate(bufs)
              if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)]
  prg = ''.join([f"{lang.kernel_prefix} void KERNEL_NAME_PLACEHOLDER(",] +
    [', '.join([f'{t} {bufnames[i]}' for i,t in buftypes] + lang.extra_args)] +
    [") {\n"] + list(prekernel) + ['\n'.join(kernel), "\n}"])

  if lang.half_prekernel and any(x.dtype == dtypes.float16 for x in bufs): prg = ''.join([f"{lang.half_prekernel}", "\n", prg])
  return prg, global_size, local_size
  return lang.render_kernel(kernel, bufs, bufnames, global_size, local_size, prekernel)

class CStyleCodegen(Linearizer):
  lang: ClassVar[CStyleLanguage] = CStyleLanguage()
@@ -202,5 +214,5 @@ class CStyleCodegen(Linearizer):
    CStyleCodegen.kernel_name_cache[prg] = function_name, display_name = self.function_name+suffix, self.display_name+colored(suffix, 'BLACK')

    return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name),
      global_size[::-1], local_size[::-1],
      global_size, local_size,
      op_estimate=self.info.flops, mem_estimate=self.mem_estimate, display_name=display_name)
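The thrust of the cstyle.py changes is that the backend-specific pieces now hang off CStyleLanguage hooks (size_prefix, generic_var_prefix, code_for_op, render_for, render_conditional, render_local, render_kernel), which is exactly what the WGSL backend below overrides. A quick illustration against the default language; a sketch only, with output strings following directly from the methods in this diff:

from tinygrad.codegen.cstyle import CStyleLanguage
from tinygrad.ops import BinaryOps

lang = CStyleLanguage()
print(lang.render_for("ridx0", 0, 3))                 # for (int ridx0 = 0; ridx0 <= 3; ++ridx0) {
print(lang.render_conditional("val0 == 1.0", "acc0", "0.0"))  # (val0 == 1.0)?(acc0):0.0
print(lang.render_local("temp", 16))                  # e.g. "float temp[16];" with the default (empty) smem_prefix
print(lang.code_for_op[BinaryOps.MAX]("a", "b"))      # max(a,b)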
tinygrad/codegen/linearizer.py

@@ -200,7 +200,7 @@ class Linearizer:
    should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType))
    return [x for x in self.sts[i].unit_stride_axes() if should_upcast and x >= self.shape_len-self.upcasted and self.sts[i].shape[x] > 1]

  def global_load(self, i, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
  def global_load(self, i:int, idxs:Sequence[VariableOrNum], const=None) -> List[Token]:
    expanded_nodes = [expand_node(idx) for idx in idxs]
    _idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
    upcast_dim = self.get_upcast_dim(i)
@@ -217,10 +217,10 @@ class Linearizer:
        localtype = dtypes._float4 if amt == 4 else dtypes._float2
        if idx.render() != ((idx//amt)*amt).render():
          idx, valid = self.sts[i].expr_idxs(_idx)
          localtype = dtypes.float
          localtype = dtypes.float32
      else:
        idx, valid = self.sts[i].expr_idxs(_idx)
        localtype = dtypes.float
        localtype = dtypes.float32
      key = f"{localtype}{idx.render()}{valid.render()}"
      if key not in cache:
        if isinstance(self.bufs[i].dtype, ImageDType): idx = to_image_idx(self.bufs[i].dtype.shape, idx, valid)
tinygrad/codegen/wgsl.py | 54 (new file)
@@ -0,0 +1,54 @@
from tinygrad.lazy import LazyBuffer
from tinygrad.codegen.cstyle import render_cl
from tinygrad.helpers import dtypes, DType
from tinygrad.codegen.linearizer import LocalBuffer
from tinygrad.codegen.cstyle import CStyleLanguage
from typing import List, Union
from tinygrad.runtime.lib import RawConst
from tinygrad.ops import UnaryOps, BinaryOps, FusedOps
import math
from typing import Tuple

type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "bool"}
class WGSLLanguage(CStyleLanguage):
  gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
  lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
  size_prefix = "let"
  barrier="workgroupBarrier();"
  generic_var_prefix = "var "
  external_local_bufs = True
  code_for_op = {
    UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})",
    BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})",
    BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPEQ: lambda x,y: f"f32({x}=={y})",
    FusedOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})",
  }

  def render_local(self, name: str, size: int):
    return f"var<workgroup> {name}: array<f32,{size}>;"

  def render_const(self, x:Union[float,int], var_dtype) -> str:
    if math.isinf(x): val = ("-" if x < 0 else "") + "0x1.fffffep+127f"
    else: val = f"{x}" + ("" if dtypes.is_int(var_dtype) else "f")
    return self.render_cast([val]*var_dtype.sz, var_dtype) if var_dtype.sz > 1 else val

  def render_kernel(self, kernel:List[str], bufs:List[Union[LocalBuffer,LazyBuffer]], bufnames:List[str], global_size:List[int], local_size:List[int], prekernel:List[str]) -> Tuple[str, List[int], List[int]]:
    local_size = local_size[::-1] if len(local_size) else [1]
    bind_it = iter(range(len(bufs)))
    prg = "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) var<storage,read_write> data{i}: array<{type_map[x.dtype]}>;" for i,x in enumerate(bufs) if not isinstance(x, LocalBuffer) and not isinstance(x.realized, RawConst)])
    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn KERNEL_NAME_PLACEHOLDER(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}"
    return prg, global_size[::-1] if len(global_size) else [1], local_size

  def render_for(self, expr:str, _min:int, _max:int) -> str:
    return f"for(var {expr} = {_min}; {expr} <= {_max}; {expr}++) {{"

  def render_conditional(self, cond:str, x:str, y:str) -> str:
    return f"select(f32({y}), {x}, bool({cond}))"

  def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
    return f"f32({super().render_load(output_dtype, buf_name, buf_dtype, idx, local)})"

  def render_store(self, buf_name:str, buf_dtype:DType, var_name:str, var_dtype:DType, idx, local=False) -> str:
    if buf_dtype != var_dtype:
      var_name = f"{type_map[buf_dtype]}({var_name})"
    return f"{buf_name}[{idx.render(render_cl)}] = {var_name};"
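The same calls against WGSLLanguage show the WGSL flavor of each hook; again a sketch, and importing this module does not require the wgpu package, so it runs on any machine with the repo checked out:

from tinygrad.codegen.wgsl import WGSLLanguage
from tinygrad.helpers import dtypes

lang = WGSLLanguage()
print(lang.render_for("ridx0", 0, 3))                    # for(var ridx0 = 0; ridx0 <= 3; ridx0++) {
print(lang.render_conditional("cond", "x", "y"))         # select(f32(y), x, bool(cond))
print(lang.render_local("temp", 16))                     # var<workgroup> temp: array<f32,16>;
print(lang.render_const(float("inf"), dtypes.float32))   # 0x1.fffffep+127f (inf is clamped, since WGSL has no inf literal)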
tinygrad/jit.py

@@ -18,7 +18,7 @@ class TinyJit:
  def __get__(self, obj, objtype): return functools.partial(self.__call__, obj)

  def __call__(self, *args, **kwargs) -> Any:
    if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
    if Device.DEFAULT not in ["GPU", "CLANG", "METAL", "CUDA", "HIP", "WEBGPU"]: return self.fxn(*args, **kwargs) # only jit on the GPU codegen
    # NOTE: this cast is needed since although we know realize will create a ".realized" DeviceBuffer, the type checker doesn't
    input_rawbuffers: Dict[Union[int, str], RawBuffer] = {cast(Union[int, str], k):cast(RawBuffer, v.realize().lazydata.realized) for k,v in itertools.chain(enumerate(args), kwargs.items()) if isinstance(v, Tensor)}
    assert len(input_rawbuffers) != 0, "no inputs to JIT"
tinygrad/runtime/ops_webgpu.py | 40 (new file)
@@ -0,0 +1,40 @@
import numpy as np
from wgpu.utils._device import get_default_device
from tinygrad.runtime.lib import RawBufferCopyIn
from tinygrad.helpers import dtypes, DType
from tinygrad.ops import Compiled
from tinygrad.codegen.cstyle import CStyleCodegen
from tinygrad.codegen.wgsl import WGSLLanguage
import wgpu

device = get_default_device()

class WebGPUProgram:
  def __init__(self, name: str, prg: str): self.name,self.prg = name,device.create_shader_module(code=prg)
  def __call__(self, global_size, local_size, *bufs, wait=False):
    binding_layouts = [{"binding": i, "visibility": wgpu.ShaderStage.COMPUTE, "buffer": {"type": wgpu.BufferBindingType.storage}} for i in range(len(bufs))]
    bindings = [{"binding": i, "resource": {"buffer": x._buf, "offset": 0, "size": x._buf.size}} for i, x in enumerate(bufs)]
    bind_group_layout = device.create_bind_group_layout(entries=binding_layouts)
    pipeline_layout = device.create_pipeline_layout(bind_group_layouts=[bind_group_layout])
    bind_group = device.create_bind_group(layout=bind_group_layout, entries=bindings)
    compute_pipeline = device.create_compute_pipeline(layout=pipeline_layout,compute={"module": self.prg, "entry_point": self.name},)
    command_encoder = device.create_command_encoder()
    compute_pass = command_encoder.begin_compute_pass()
    compute_pass.set_pipeline(compute_pipeline)
    compute_pass.set_bind_group(0, bind_group, [], 0, 999999) # last 2 not used
    compute_pass.dispatch_workgroups(*global_size) # x y z
    compute_pass.end()
    device.queue.submit([command_encoder.finish()])

class WGSLCodegen(CStyleCodegen):
  lang = WGSLLanguage()
  supports_float4: bool = False

class RawWebGPUBuffer(RawBufferCopyIn):
  def __init__(self, size:int, dtype:DType):
    assert dtype not in [dtypes.int8,dtypes.uint8,dtypes.int64,dtypes.uint64], f"dtype {dtype} not supported on WEBGPU"
    super().__init__(size, dtype, device.create_buffer(size=size*dtype.itemsize, usage=wgpu.BufferUsage.STORAGE | wgpu.BufferUsage.COPY_DST | wgpu.BufferUsage.COPY_SRC))
  def _copyin(self, x:np.ndarray): device.queue.write_buffer(self._buf, 0, np.ascontiguousarray(x))
  def toCPU(self) -> np.ndarray: return np.frombuffer(device.queue.read_buffer(self._buf, 0), dtype=np.dtype(self.dtype.np, metadata={"backing": self})) # type: ignore

WebGpuBuffer = Compiled(RawWebGPUBuffer, WGSLCodegen, WebGPUProgram)
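A small smoke test of the new runtime pieces in isolation; a sketch only, since it needs the wgpu package and a working adapter (the same requirement as WEBGPU=1 in CI) and it pokes the buffer methods defined above directly:

import numpy as np
from tinygrad.helpers import dtypes
from tinygrad.runtime.ops_webgpu import RawWebGPUBuffer

x = np.arange(16, dtype=np.float32)
buf = RawWebGPUBuffer(x.size, dtypes.float32)  # STORAGE | COPY_DST | COPY_SRC buffer on the default device
buf._copyin(x)                                 # queue.write_buffer under the hood
print(buf.toCPU())                             # round-trips back through queue.read_buffer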
tinygrad/state.py

@@ -24,7 +24,7 @@ def safe_save(tensors:Dict[str, Tensor], fn:str):
  pathlib.Path(fn).unlink(missing_ok=True)
  t = Tensor.empty(8+len(j)+offset, dtype=dtypes.uint8, device=f"disk:{fn}")
  t[0:1].cast(dtypes.int64).assign([len(j)])
  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8))
  t[8:8+len(j)].assign(Tensor(list(j.encode('utf-8')), dtype=dtypes.uint8, device="cpu"))
  for k,v in safe_load(t).items(): v.assign(tensors[k])

# state dict