Force WebGPU backend type [pr] (#9164)

* Force webgpu backend type

* Mypy fix

* Rename to WEBGPU_BACKEND

* Add it to env_vars docs

* Remove link
This commit is contained in:
Ahmed Harmouche
2025-02-19 10:19:39 +01:00
committed by GitHub
parent 4bc708a9b0
commit 0f94b98646
3 changed files with 10 additions and 4 deletions

View File

@@ -449,7 +449,7 @@ jobs:
WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
- name: Run selected webgpu tests
run: |
WEBGPU=1 python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit \
WEBGPU=1 WEBGPU_BACKEND="WGPUBackendType_Vulkan" python3 -m pytest -n=auto test/ --ignore=test/models --ignore=test/unit \
--ignore=test/test_copy_speed.py --ignore=test/test_rearrange_einops.py \
--ignore=test/test_fuzz_shape_ops.py --ignore=test/test_linearizer_failures.py --durations=20
- name: Run process replay tests
@@ -565,7 +565,7 @@ jobs:
key: osx-webgpu
webgpu: 'true'
- name: Build WEBGPU Efficientnet
run: WEBGPU=1 python3 -m examples.compile_efficientnet
run: WEBGPU=1 WEBGPU_BACKEND="WGPUBackendType_Metal" python3 -m examples.compile_efficientnet
- name: Clean npm cache
run: npm cache clean --force
- name: Install Puppeteer

View File

@@ -48,4 +48,5 @@ PROFILE | [1] | enable profiling. This feature is supported i
VISIBLE_DEVICES | [list[int]]| restricts the NV/AMD devices that are available. The format is a comma-separated list of identifiers (indexing starts with 0).
JIT | [0-2] | 0=disabled, 1=[jit enabled](quickstart.md#jit) (default), 2=jit enabled, but graphs are disabled
VIZ | [1] | 0=disabled, 1=[viz enabled](https://github.com/tinygrad/tinygrad/tree/master/tinygrad/viz)
ALLOW_TF32 | [1] | enable TensorFloat-32 tensor cores on Ampere or newer GPUs.
ALLOW_TF32 | [1] | enable TensorFloat-32 tensor cores on Ampere or newer GPUs.
WEBGPU_BACKEND | [WGPUBackendType_Metal, ...] | Force select a backend for WebGPU (Metal, DirectX, OpenGL, Vulkan...)

View File

@@ -5,6 +5,9 @@ from tinygrad.helpers import round_up, OSX
from tinygrad.runtime.autogen import webgpu
from typing import List, Any
import ctypes
import os
backend_types = {v: k for k, v in webgpu.WGPUBackendType__enumvalues.items() }
try:
instance = webgpu.wgpuCreateInstance(webgpu.WGPUInstanceDescriptor(features = webgpu.WGPUInstanceFeatures(timedWaitAnyEnable = True)))
@@ -193,7 +196,9 @@ class WebGpuDevice(Compiled):
# Requesting an adapter
adapter_res = _run(webgpu.wgpuInstanceRequestAdapterF, webgpu.WGPURequestAdapterCallbackInfo, webgpu.WGPURequestAdapterCallback,
webgpu.WGPURequestAdapterStatus__enumvalues, 1, 2, instance,
webgpu.WGPURequestAdapterOptions(powerPreference=webgpu.WGPUPowerPreference_HighPerformance))
webgpu.WGPURequestAdapterOptions(powerPreference=webgpu.WGPUPowerPreference_HighPerformance,
backendType=backend_types.get(os.getenv("WEBGPU_BACKEND", ""), 0)))
# Get supported features
supported_features = webgpu.WGPUSupportedFeatures()