WebGPU on Windows (#10890)

* WebGPU on Windows

* Fix dawn-python install

* New test

* pydeps

* Minor fix

* Only install dawn-python on windows webgpu

---------

Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
This commit is contained in:
Ahmed Harmouche
2025-07-02 17:38:45 +02:00
committed by GitHub
parent e67a6d2310
commit e992ed10dc
4 changed files with 18 additions and 8 deletions

View File

@@ -1,4 +1,4 @@
import time, math, unittest, functools, warnings
import time, math, unittest, functools, platform, warnings
import numpy as np
from typing import List, Callable
import torch
@@ -826,6 +826,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
helper_test_op(None, lambda x: x.sin(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend")
def test_cos(self):
helper_test_op([(45,65)], lambda x: x.cos())
helper_test_op([()], lambda x: x.cos())
@@ -833,6 +834,7 @@ class TestOps(unittest.TestCase):
helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf, 0.0]])
helper_test_op(None, lambda x: x.cos(), vals=[[1e1, 1e2, 1e3, 1e4, 1e5, 1e6, -1e1, -1e2, -1e3, -1e4, -1e5, -1e6]],
atol=3e-3, rtol=3e-3, grad_atol=3e-3, grad_rtol=3e-3)
@unittest.skipIf(Device.DEFAULT == "WEBGPU" and platform.system() == "Windows", "Not accurate enough with DirectX backend")
def test_tan(self):
# NOTE: backward has much higher diff with input close to pi/2 and -pi/2
helper_test_op([(45,65)], lambda x: x.tan(), low=-1.5, high=1.5)
@@ -1281,7 +1283,8 @@ class TestOps(unittest.TestCase):
np.arange(64,128,dtype=np.float32).reshape(8,8)])
def test_small_gemm_eye(self):
helper_test_op(None, lambda x,y: x.matmul(y), lambda x,y: x@y, vals=[np.eye(8).astype(np.float32), np.eye(8).astype(np.float32)])
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE, "not supported on these in CI/IMAGE")
@unittest.skipIf(CI and Device.DEFAULT in ["NV", "LLVM", "GPU", "CUDA"] or IMAGE
or (Device.DEFAULT == "WEBGPU" and platform.system() == "Windows"), "not supported on these in CI/IMAGE")
def test_gemm_fp16(self):
helper_test_op([(64,64), (64,64)], lambda x,y: x.half().matmul(y.half()), atol=5e-3, rtol=5e-3)
def test_gemm(self):