From 3072e098c0e482d4884bc956fd9fd1c61dcbcf3a Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Sun, 5 Mar 2023 12:08:12 -0800
Subject: [PATCH] local workgroup optimizer

---
 tinygrad/ops.py             | 12 +++++++++++-
 tinygrad/runtime/ops_gpu.py |  3 +++
 2 files changed, 14 insertions(+), 1 deletion(-)

diff --git a/tinygrad/ops.py b/tinygrad/ops.py
index 7f4d85d0b3..9df5bfe193 100644
--- a/tinygrad/ops.py
+++ b/tinygrad/ops.py
@@ -2,7 +2,7 @@ from __future__ import annotations
 import numpy as np
 from enum import Enum, auto
 from typing import Union, Type, NamedTuple, Tuple, Any, List, ClassVar, Optional, Callable, Dict, TypeVar, Set
-import functools, operator
+import functools, itertools, operator, random
 from tinygrad.helpers import prod, DEBUG, getenv
 from tinygrad.shape import ShapeTracker
 
@@ -112,8 +112,18 @@ class ASTRunner:
   def build(self, runtime):
     self.clprg = runtime(self.name, self.prg)
     return self
+  def timeit(self, bufs, local_override=None) -> float:
+    try: return self.clprg(self.global_size, local_override if local_override is not None else self.local_size, *bufs, wait=True)
+    except Exception: return float('inf')
+  def optimize_local_size(self, bufs) -> List[int]:
+    assert self.global_size is not None, "needs a global size to optimize local size"
+    MAX_WORKGROUP = self.clprg.max_work_group_size() if hasattr(self.clprg, 'max_work_group_size') else 1024
+    local_dims = [[x for x in set([sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, MAX_WORKGROUP]) if x<=sz] for sz in self.global_size]
+    local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= MAX_WORKGROUP] * 2  # try each valid size twice
+    return min([(self.timeit(bufs, local_size), local_size) for local_size in random.sample(local_sizes, len(local_sizes))])[1]
   def lower(self, bufs) -> List[RawBuffer]: return [x.raw() for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
   def __call__(self, bufs):
+    if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(bufs)
     et = self.clprg(self.global_size, self.local_size, *bufs, wait=DEBUG>=2)
     if et is not None: GlobalCounters.time_sum_s += et
     if DEBUG >= 1:
diff --git a/tinygrad/runtime/ops_gpu.py b/tinygrad/runtime/ops_gpu.py
index 43f1f44d70..38a0558cd1 100644
--- a/tinygrad/runtime/ops_gpu.py
+++ b/tinygrad/runtime/ops_gpu.py
@@ -76,6 +76,9 @@ class CLProgram:
       print(binary.decode('utf-8'))
     if self.argdtypes is not None: self.clprg.set_scalar_arg_dtypes(self.argdtypes)
 
+  @staticmethod
+  def max_work_group_size(): return CL.cl_ctx.devices[0].max_work_group_size
+
   def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
     e = self.clprg(CL.cl_queue, global_size, local_size, *[x._cl if isinstance(x, (CLBuffer, CLImage)) else x for x in bufs])
     if wait:
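
Note (commentary, not part of the patch): the sketch below is a minimal standalone rendering of the brute-force search that optimize_local_size performs, so the candidate enumeration can be followed without an OpenCL context. The free function and its time_kernel callable are hypothetical stand-ins for ASTRunner and ASTRunner.timeit; the per-dimension candidates, the work-group product cap, and the randomized timing pass mirror the patch. With the patch applied, the real optimizer only runs when the OPTLOCAL environment variable is set (getenv("OPTLOCAL") in ASTRunner.__call__) and the kernel was launched without an explicit local_size.

    # Standalone sketch of the local-size search (hypothetical names, not tinygrad API).
    import itertools, operator, random
    from functools import reduce

    def prod(xs): return reduce(operator.mul, xs, 1)

    def optimize_local_size(global_size, max_workgroup, time_kernel):
      # per-dimension candidates: powers of two (plus the dimension itself and the
      # device limit), capped so no candidate exceeds the global size in that dimension
      local_dims = [[x for x in {sz, 1, 2, 4, 8, 16, 32, 64, 128, 256, max_workgroup} if x <= sz] for sz in global_size]
      # cartesian product of the per-dimension candidates, keeping only shapes whose
      # total thread count fits in one work group; duplicated so each is timed twice
      local_sizes = [list(x) for x in itertools.product(*local_dims) if prod(x) <= max_workgroup] * 2
      random.shuffle(local_sizes)  # same effect as random.sample(local_sizes, len(local_sizes))
      # time every candidate and keep the fastest one
      return min((time_kernel(ls), ls) for ls in local_sizes)[1]

    # toy timing model just to exercise the search: pretend bigger work groups are faster
    print(optimize_local_size([1024, 1024], 256, lambda ls: 1.0 / prod(ls)))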