simple runtime args (#2211)

* simple runtime args

* fix some tests

* fix abstractions and triton

* fix search
This commit is contained in:
George Hotz
2023-11-03 12:31:29 -07:00
committed by GitHub
parent 9ea0448103
commit f17bc16f46
16 changed files with 37 additions and 24 deletions

View File

@@ -4,7 +4,7 @@ os.environ['PYOPENCL_NO_CACHE'] = '1'
import pathlib
import numpy as np
import pyopencl as cl # type: ignore
from typing import Optional, List
from typing import Optional, List, Tuple
from tinygrad.helpers import DEBUG, getenv, prod, ImageDType, OSX, fromimport, diskcache
from tinygrad.ops import Compiled
from tinygrad.renderer.opencl import OpenCLRenderer
@@ -90,7 +90,7 @@ class CLProgram:
@staticmethod
def max_work_group_size(): return CL.cl_ctxs[0].devices[0].max_work_group_size
def __call__(self, global_size, local_size, *bufs, wait=False) -> Optional[float]:
def __call__(self, *bufs, global_size:Tuple[int,int,int], local_size:Optional[Tuple[int,int,int]]=None, wait=False) -> Optional[float]:
if not hasattr(self, 'argdtypes'): self.set_argdtypes(tuple(None if x.__class__ is CLBuffer else np.int32 for x in bufs))
cl_bufs, wait_for = [], []
for x in bufs: