From e992ed10dcd5a02b40fb9bee60a1eccd67f5cf35 Mon Sep 17 00:00:00 2001 From: Ahmed Harmouche Date: Wed, 2 Jul 2025 17:38:45 +0200 Subject: [PATCH] WebGPU on Windows (#10890) * WebGPU on Windows * Fix dawn-python install * New test * pydeps * Minor fix * Only install dawn-python on windows webgpu --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com> --- .github/workflows/test.yml | 7 ++++--- docs/runtime.md | 2 +- test/test_ops.py | 7 +++++-- tinygrad/runtime/support/webgpu.py | 10 ++++++++-- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e98c5aad91..9795d6fe43 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -938,7 +938,7 @@ jobs: strategy: fail-fast: false matrix: - backend: [llvm, cpu] + backend: [llvm, cpu, webgpu] name: Windows (${{ matrix.backend }}) runs-on: windows-latest @@ -951,11 +951,12 @@ jobs: - name: Setup Environment uses: ./.github/actions/setup-tinygrad with: - key: windows-minimal + key: windows-${{ matrix.backend }}-minimal deps: testing_unit + pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }} - name: Set env shell: bash - run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1'}}" >> $GITHUB_ENV + run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV - name: Run unit tests if: matrix.backend=='llvm' run: python -m pytest -n=auto test/unit/ --ignore=test/unit/test_disk_tensor.py --ignore=test/unit/test_elf.py --ignore=test/unit/test_tar.py diff --git a/docs/runtime.md b/docs/runtime.md index 045ca91ce6..bc85d9bedf 100644 --- a/docs/runtime.md +++ b/docs/runtime.md @@ -12,7 +12,7 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra | [GPU (OpenCL)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_gpu.py) | Accelerates computations using OpenCL on GPUs | OpenCL 2.0 compatible device | | [CPU (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` | | [LLVM (LLVM IR)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_llvm.py) | Runs on CPU using the LLVM compiler infrastructure | llvm libraries installed and findable | -| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.1.6). | +| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.3.0). | ## Interoperability diff --git a/test/test_ops.py b/test/test_ops.py index c39b9c9504..f6e308c912 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1,4 +1,4 @@ -import time, math, unittest, functools, warnings +import time, math, unittest, functools, platform, warnings import numpy as np from typing import List, Callable import torch @@ -826,6 +826,7 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]]) helper_test_op(None, lambda x: x.sin(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]], atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3) + @unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend") def test_cos(self): helper_test_op([(45,65)], lambda x: x.cos()) helper_test_op([()], lambda x: x.cos()) @@ -833,6 +834,7 @@ class TestOps(unittest.TestCase): helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]]) helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]], atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3) + @unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend") def test_tan(self): # NOTE: backward has much higher diff with input close to pi/2 and -pi/2 helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5) @@ -1281,7 +1283,8 @@ class TestOps(unittest.TestCase): np.arange(64,128,dtype=np.float32).reshape(8,8)]) def test_small_gemm_eye(self): helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)]) - @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE, "not supported on these in CI/IMAGE") + @unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE + or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE") def test_gemm_fp16(self): helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3) def test_gemm(self): diff --git a/tinygrad/runtime/support/webgpu.py b/tinygrad/runtime/support/webgpu.py index 9b5362ee8b..11c6e10386 100644 --- a/tinygrad/runtime/support/webgpu.py +++ b/tinygrad/runtime/support/webgpu.py @@ -1,10 +1,16 @@ -import ctypes, ctypes.util, os, subprocess, platform +import ctypes, ctypes.util, os, subprocess, platform, sysconfig from tinygrad.helpers import OSX +WEBGPU_PATH: str | None + if OSX: if not os.path.exists(brew_prefix:=subprocess.check_output(['brew', '--prefix', 'dawn']).decode().strip()): raise FileNotFoundError('dawn library not found. Install it with `brew tap wpmed92/dawn && brew install dawn`') - WEBGPU_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libwebgpu_dawn.dylib') + WEBGPU_PATH = os.path.join(brew_prefix, 'lib', 'libwebgpu_dawn.dylib') +elif platform.system() == "Windows": + if not os.path.exists(pydawn_path:=os.path.join(sysconfig.get_paths()["purelib"], "pydawn")): + raise FileNotFoundError("dawn library not found. Install it with `pip install dawn-python`") + WEBGPU_PATH = os.path.join(pydawn_path, "lib", "libwebgpu_dawn.dll") else: if (WEBGPU_PATH:=ctypes.util.find_library('webgpu_dawn')) is None: raise FileNotFoundError("dawn library not found. " +