From 8cbcd1b342e9d65eff682e9384d71af54f79b097 Mon Sep 17 00:00:00 2001
From: George Hotz <72895+geohot@users.noreply.github.com>
Date: Mon, 8 Jan 2024 09:10:07 -0800
Subject: [PATCH] Remove webgpu, back to 5k lines (#3040)

* remove webgpu

* max 5000 lines
---
 .github/workflows/test.yml                    | 84 +++++++++----
 README.md                                     |  2 -
 .../runtime => extra/backends}/ops_webgpu.py  |  0
 extra/{triton => backends}/triton.py          |  0
 setup.py                                      |  1 -
 tinygrad/lazy.py                              |  5 +-
 tinygrad/renderer/cstyle.py                   | 35 --------
 tinygrad/tensor.py                            |  1 -
 8 files changed, 44 insertions(+), 84 deletions(-)
 rename {tinygrad/runtime => extra/backends}/ops_webgpu.py (100%)
 rename extra/{triton => backends}/triton.py (100%)

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f4ac6b0b39..0deee41e1a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -63,8 +63,8 @@ jobs:
         source venv/bin/activate
         pip install $GITHUB_WORKSPACE
         python -c "from tinygrad.tensor import Tensor; print(Tensor([1,2,3,4,5]))"
-    - name: Repo line count <6000 lines
-      run: MAX_LINE_COUNT=6000 python sz.py
+    - name: Repo line count <5000 lines
+      run: MAX_LINE_COUNT=5000 python sz.py
 
   testcpuimagenet:
     name: CPU and ImageNet to C Tests
@@ -214,48 +214,48 @@ jobs:
       name: Test Beam Search
      run: PYTHONPATH="." GPU=1 IGNORE_BEAM_CACHE=1 python3 -m pytest extra/optimization/test_beam_search.py
 
-  testwebgpu:
-    name: WebGPU Tests
-    runs-on: macos-13
-    timeout-minutes: 20
-    steps:
-    - name: Checkout Code
-      uses: actions/checkout@v3
-    - name: Set up Python 3.11
-      uses: actions/setup-python@v4
-      with:
-        python-version: 3.11
-    - name: Cache python packages
-      uses: actions/cache@v3
-      with:
-        path: /Users/runner/Library/Python/3.11/lib/python/site-packages
-        key: webgpu-testing-user3-packages-${{ hashFiles('**/setup.py') }}
-    - name: Install Dependencies
-      run: pip install --user -e '.[webgpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
-    - name: Cache downloads
-      uses: actions/cache@v3
-      with:
-        path: ~/Library/Caches/tinygrad/downloads/
-        key: downloads-cache-webgpu-${{ env.DOWNLOAD_CACHE_VERSION }}
-    - name: Check Device.DEFAULT (WEBGPU) and print some source
-      run: |
-        WEBGPU=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
-        WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
+  #testwebgpu:
+  #  name: WebGPU Tests
+  #  runs-on: macos-13
+  #  timeout-minutes: 20
+  #  steps:
+  #  - name: Checkout Code
+  #    uses: actions/checkout@v3
+  #  - name: Set up Python 3.11
+  #    uses: actions/setup-python@v4
+  #    with:
+  #      python-version: 3.11
+  #  - name: Cache python packages
+  #    uses: actions/cache@v3
+  #    with:
+  #      path: /Users/runner/Library/Python/3.11/lib/python/site-packages
+  #      key: webgpu-testing-user3-packages-${{ hashFiles('**/setup.py') }}
+  #  - name: Install Dependencies
+  #    run: pip install --user -e '.[webgpu,testing]' --extra-index-url https://download.pytorch.org/whl/cpu
+  #  - name: Cache downloads
+  #    uses: actions/cache@v3
+  #    with:
+  #      path: ~/Library/Caches/tinygrad/downloads/
+  #      key: downloads-cache-webgpu-${{ env.DOWNLOAD_CACHE_VERSION }}
+  #  - name: Check Device.DEFAULT (WEBGPU) and print some source
+  #    run: |
+  #      WEBGPU=1 python -c "from tinygrad import Device; assert Device.DEFAULT == 'WEBGPU', Device.DEFAULT"
+  #      WEBGPU=1 DEBUG=4 FORWARD_ONLY=1 python3 test/test_ops.py TestOps.test_add
     #- name: Run webgpu pytest
     #  run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto
-    - name: Run selected webgpu tests
-      run: |
-        WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_ops.py test/test_dtype.py \
-          test/test_jit.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_linearizer.py \
-          test/test_linearizer_failures.py test/test_nn.py
-    - name: Build WEBGPU Efficientnet
-      run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
-    - name: Install Puppeteer
-      run: npm install puppeteer
-    - name: Run WEBGPU Efficientnet
-      run: node test/web/test_webgpu.js
-    - name: Test LLaMA compile speed
-      run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
+  #  - name: Run selected webgpu tests
+  #    run: |
+  #      WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m pytest -n=auto test/test_ops.py test/test_dtype.py \
+  #        test/test_jit.py test/test_symbolic_ops.py test/test_symbolic_jit.py test/test_linearizer.py \
+  #        test/test_linearizer_failures.py test/test_nn.py
+  #  - name: Build WEBGPU Efficientnet
+  #    run: WEBGPU=1 WGPU_BACKEND_TYPE=Metal python -m examples.compile_efficientnet
+  #  - name: Install Puppeteer
+  #    run: npm install puppeteer
+  #  - name: Run WEBGPU Efficientnet
+  #    run: node test/web/test_webgpu.js
+  #  - name: Test LLaMA compile speed
+  #    run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
 
   testmetal:
     name: Metal Tests
diff --git a/README.md b/README.md
index 332991c74d..3f6320fbea 100644
--- a/README.md
+++ b/README.md
@@ -82,10 +82,8 @@ tinygrad already supports numerous accelerators, including:
 - [x] [LLVM](tinygrad/runtime/ops_llvm.py)
 - [x] [METAL](tinygrad/runtime/ops_metal.py)
 - [x] [CUDA](tinygrad/runtime/ops_cuda.py)
-- [x] [Triton](extra/triton/triton.py)
 - [x] [PyTorch](tinygrad/runtime/ops_torch.py)
 - [x] [HIP](tinygrad/runtime/ops_hip.py)
-- [x] [WebGPU](tinygrad/runtime/ops_webgpu.py)
 
 And it is easy to add more! Your accelerator of choice only needs to support a total of ~25 low level ops.
 More information can be found in the [documentation for adding new accelerators](/docs/adding_new_accelerators.md).
diff --git a/tinygrad/runtime/ops_webgpu.py b/extra/backends/ops_webgpu.py
similarity index 100%
rename from tinygrad/runtime/ops_webgpu.py
rename to extra/backends/ops_webgpu.py
diff --git a/extra/triton/triton.py b/extra/backends/triton.py
similarity index 100%
rename from extra/triton/triton.py
rename to extra/backends/triton.py
diff --git a/setup.py b/setup.py
index 9e69516d13..3ac0bd12aa 100644
--- a/setup.py
+++ b/setup.py
@@ -28,7 +28,6 @@ setup(name='tinygrad',
         'llvm': ["llvmlite"],
         'arm': ["unicorn"],
         'triton': ["triton-nightly>=2.1.0.dev20231014192330"],
-        'webgpu': ["wgpu>=v0.12.0"],
         'linting': [
             "pylint",
             "mypy",
diff --git a/tinygrad/lazy.py b/tinygrad/lazy.py
index 38688ef0b9..a302f48a89 100644
--- a/tinygrad/lazy.py
+++ b/tinygrad/lazy.py
@@ -99,8 +99,7 @@ class LazyBuffer:
       return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st)
 
     # if it's a shrink, do the shrink before the copy with CONTIGUOUS
-    # TODO: why is this required on WEBGPU?
-    if prod(self.st.shape) < prod(self.base.st.shape) or device == "WEBGPU":
+    if prod(self.st.shape) < prod(self.base.st.shape):
       return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, LoadOps.COPY, srcs=(self.contiguous(),))
 
     # copy the base and apply the shapetracker on the new device
@@ -118,7 +117,7 @@ class LazyBuffer:
     if op == TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool"
     out_dtype = srcs[-1].dtype if op not in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else dtypes.bool
     ret = create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
-    return ret.cast(dtypes.float32) if (out_dtype == dtypes.bool and self.device == "WEBGPU") else ret
+    return ret
 
   # *** reduce ops ***
 
diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py
index ed8d58c68d..55ae4bff02 100644
--- a/tinygrad/renderer/cstyle.py
+++ b/tinygrad/renderer/cstyle.py
@@ -289,38 +289,3 @@ __device__ half16 make_half16(half x, half y, half z, half w, half a, half b, ha
 """
 type_map = {dtypes.bfloat16: "hip_bfloat16"}
 HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
-
-# TODO: how much of this can be merged with above?
-class WGSLLanguage(CStyleLanguage):
-  code_for_workitem = {"g": lambda x: f"i32(gindex.{'xyz'[x]})", "l": lambda x: f"i32(lindex.{'xyz'[x]})"}
-  size_prefix = "let"
-  barrier="workgroupBarrier();"
-  generic_var_prefix = "var "
-  external_local_bufs = True
-  code_for_op = { **CStyleLanguage().code_for_op,
-    BinaryOps.CMPLT: lambda x,y,dtype: f"f32({x}<{y})", BinaryOps.CMPEQ: lambda x,y,dtype: f"f32({x}=={y})",
-    TernaryOps.MULACC: lambda x,y,z,dtype: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c,dtype: f"select({c},{b},bool({a}))" }
-  # HACK: write bool as f32
-  type_map = {dtypes.float: "f32", dtypes.half: "f16", dtypes.int32: "i32", dtypes.uint32: "u32", dtypes.bool: "f32"}
-
-  def render_local(self, name: str, dtype:DType, size: int): return f"var<workgroup> {name}: array<{self.type_map[dtype]},{size}>;"
-
-  def render_const(self, x:Union[float,int], var_dtype) -> str:
-    if math.isnan(x): return "nan()"
-    elif math.isinf(x): return ("-" if x < 0 else "") + "inf(1.0)"
-    return f"({super().render_const(x, var_dtype)})"
-
-  def render_if(self, cond: str): return f"if (bool({cond})) {{"
-
-  def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,DType]], local_size:List[int], prekernel:List[str]) -> str:
-    local_size = local_size[::-1] if local_size else [1]
-    bind_it = iter(range(len(bufs)))
-    prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\nfn inf(a: f32) -> f32 { return a/0.0; }\n"
-    prg += "\n".join(prekernel+[f"@group(0) @binding({next(bind_it)}) {'var<storage,read_write>' if isinstance(dtype, PtrDType) else 'var<uniform>'} {name}: {f'array<{self.type_map[dtype]}>' if isinstance(dtype, PtrDType) else 'i32'};" for name,dtype in bufs]) # noqa: E501
-    prg += f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n" + "\n".join(kernel) + "\n}" # noqa: E501
-    return prg
-
-  def render_cast(self, x:List[str], var_dtype:DType, bitcast=False) -> str:
-    if self.type_map[var_dtype]: return f"bitcast<{self.type_map[var_dtype]}>({x[0]})" if bitcast else f"{self.type_map[var_dtype]}({x[0]})"
-    raise NotImplementedError(f"no cast for {var_dtype}")
-WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index c05614cf2b..95a1a61bc7 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -863,7 +863,6 @@ class Tensor:
   def __imatmul__(self, x) -> Tensor: return self.assign(self.matmul(x))
   def __ixor__(self, x) -> Tensor: return self.assign(self.xor(x))
 
-  # in webgpu bool cannot be used as a storage buffer type
   def __lt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, False))
   def __gt__(self, x) -> Tensor: return mlops.Less.apply(*self._broadcasted(x, True))
   def __ge__(self, x) -> Tensor: return (self
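
The first lazy.py hunk keeps the shrink-before-copy optimization but drops the unconditional WEBGPU branch: when a view is smaller than its base, the cross-device copy still materializes only the shrunk region via `.contiguous()` before the `LoadOps.COPY`; full-size views copy the base and reapply the ShapeTracker on the target. A minimal sketch of that path, assuming a second backend (here `CLANG`, a device name this repo ships, though any available one works):

```python
from tinygrad.tensor import Tensor

# A 1000-element buffer viewed through a 10-element shrink: since
# prod(view.shape) < prod(base.shape), the copy path runs .contiguous()
# first, so only the 10 shrunk elements travel to the target device.
big = Tensor.arange(1000).realize()
small = big[:10].to("CLANG")  # assumed target device; substitute any second backend
print(small.numpy())
```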
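The second lazy.py hunk is the behavioral core of the patch: with the WEBGPU special case gone, `alu()` returns comparison results as `dtypes.bool` on every backend instead of casting them to float32 (a workaround for WGSL, where, per the removed tensor.py comment, bool cannot be used as a storage buffer type). A minimal sketch of the post-patch behavior; the top-level `dtypes` import is an assumption about this commit, the `Tensor` import matches the workflow file above:

```python
from tinygrad.tensor import Tensor
from tinygrad import dtypes  # assumption: dtypes is re-exported at the top level here

# BinaryOps.CMPLT reaches LazyBuffer.alu() via Tensor.__lt__; out_dtype is
# dtypes.bool and, with the WEBGPU cast removed, is returned unchanged on
# every device.
a = Tensor([1.0, 2.0, 3.0])
b = Tensor([2.0, 2.0, 2.0])
mask = a < b
assert mask.dtype == dtypes.bool  # no float32 round-trip on any backend
print(mask.numpy())               # [ True False False]
```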