mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
WebGPU on Windows (#10890)
* WebGPU on Windows * Fix dawn-python install * New test * pydeps * Minor fix * Only install dawn-python on windows webgpu --------- Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
7
.github/workflows/test.yml
vendored
7
.github/workflows/test.yml
vendored
@@ -938,7 +938,7 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
backend: [llvm, cpu]
|
||||
backend: [llvm, cpu, webgpu]
|
||||
|
||||
name: Windows (${{ matrix.backend }})
|
||||
runs-on: windows-latest
|
||||
@@ -951,11 +951,12 @@ jobs:
|
||||
- name: Setup Environment
|
||||
uses: ./.github/actions/setup-tinygrad
|
||||
with:
|
||||
key: windows-minimal
|
||||
key: windows-${{ matrix.backend }}-minimal
|
||||
deps: testing_unit
|
||||
pydeps: ${{ matrix.backend == 'webgpu' && 'dawn-python' || '' }}
|
||||
- name: Set env
|
||||
shell: bash
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1'}}" >> $GITHUB_ENV
|
||||
run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'cpu' && 'CPU=1' || matrix.backend == 'webgpu' && 'WEBGPU=1'}}" >> $GITHUB_ENV
|
||||
- name: Run unit tests
|
||||
if: matrix.backend=='llvm'
|
||||
run: python -m pytest -n=auto test/unit/ --ignore=test/unit/test_disk_tensor.py --ignore=test/unit/test_elf.py --ignore=test/unit/test_tar.py
|
||||
|
||||
@@ -12,7 +12,7 @@ tinygrad supports various runtimes, enabling your code to scale across a wide ra
|
||||
| [GPU (OpenCL)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_gpu.py) | Accelerates computations using OpenCL on GPUs | OpenCL 2.0 compatible device |
|
||||
| [CPU (C Code)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_cpu.py) | Runs on CPU using the clang compiler | `clang` compiler in system `PATH` |
|
||||
| [LLVM (LLVM IR)](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_llvm.py) | Runs on CPU using the LLVM compiler infrastructure | llvm libraries installed and findable |
|
||||
| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.1.6). |
|
||||
| [WEBGPU](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/runtime/ops_webgpu.py) | Runs on GPU using the Dawn WebGPU engine (used in Google Chrome) | Dawn library installed and findable. Download binaries [here](https://github.com/wpmed92/pydawn/releases/tag/v0.3.0). |
|
||||
|
||||
## Interoperability
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import time, math, unittest, functools, warnings
|
||||
import time, math, unittest, functools, platform, warnings
|
||||
import numpy as np
|
||||
from typing import List, Callable
|
||||
import torch
|
||||
@@ -826,6 +826,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
|
||||
helper_test_op(None, lambda x: x.sin(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
|
||||
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend")
|
||||
def test_cos(self):
|
||||
helper_test_op([(45,65)], lambda x: x.cos())
|
||||
helper_test_op([()], lambda x: x.cos())
|
||||
@@ -833,6 +834,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
|
||||
helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
|
||||
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
|
||||
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend")
|
||||
def test_tan(self):
|
||||
# NOTE: backward has much higher diff with input close to pi/2 and -pi/2
|
||||
helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5)
|
||||
@@ -1281,7 +1283,8 @@ class TestOps(unittest.TestCase):
|
||||
np.arange(64,128,dtype=np.float32).reshape(8,8)])
|
||||
def test_small_gemm_eye(self):
|
||||
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE, "not supported on these in CI/IMAGE")
|
||||
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE
|
||||
or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE")
|
||||
def test_gemm_fp16(self):
|
||||
helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
|
||||
def test_gemm(self):
|
||||
|
||||
@@ -1,10 +1,16 @@
|
||||
import ctypes, ctypes.util, os, subprocess, platform
|
||||
import ctypes, ctypes.util, os, subprocess, platform, sysconfig
|
||||
from tinygrad.helpers import OSX
|
||||
|
||||
WEBGPU_PATH: str | None
|
||||
|
||||
if OSX:
|
||||
if not os.path.exists(brew_prefix:=subprocess.check_output(['brew', '--prefix', 'dawn']).decode().strip()):
|
||||
raise FileNotFoundError('dawn library not found. Install it with `brew tap wpmed92/dawn && brew install dawn`')
|
||||
WEBGPU_PATH: str|None = os.path.join(brew_prefix, 'lib', 'libwebgpu_dawn.dylib')
|
||||
WEBGPU_PATH = os.path.join(brew_prefix, 'lib', 'libwebgpu_dawn.dylib')
|
||||
elif platform.system() == "Windows":
|
||||
if not os.path.exists(pydawn_path:=os.path.join(sysconfig.get_paths()["purelib"], "pydawn")):
|
||||
raise FileNotFoundError("dawn library not found. Install it with `pip install dawn-python`")
|
||||
WEBGPU_PATH = os.path.join(pydawn_path, "lib", "libwebgpu_dawn.dll")
|
||||
else:
|
||||
if (WEBGPU_PATH:=ctypes.util.find_library('webgpu_dawn')) is None:
|
||||
raise FileNotFoundError("dawn library not found. " +
|
||||
|
||||
Reference in New Issue
Block a user