fix opencl types

This commit is contained in:
George Hotz
2023-02-10 23:18:39 -06:00
parent fed95119dc
commit 6f9b103878

View File

@@ -1,7 +1,7 @@
import functools, platform
import numpy as np
import pyopencl as cl # type: ignore
from typing import Dict, Optional, Tuple, List, ClassVar
from typing import Dict, Optional, Tuple, List, ClassVar, Final
from collections import defaultdict
from tinygrad.ops import DEBUG, GlobalCounters
from tinygrad.helpers import getenv
@@ -17,7 +17,7 @@ class CL:
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
cl_ctx : ClassVar[Optional[cl.Context]] = None
cl_queue : ClassVar[Optional[cl.CommandQueue]] = None
def __init__(self):
def __init__(self) -> None:
if CL.cl_queue is not None: return # already initted
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
if len(devices) == 0: # settle for CPU
@@ -50,7 +50,7 @@ class CLBuffer:
def copyout(self, a:np.ndarray): CL.enqueue_copy(a, self._cl, True)
class CLImage:
fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
def __init__(self, shape):
self._cl = cl.Image(CL().cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
@@ -64,7 +64,7 @@ class CLImage:
@functools.lru_cache(maxsize=None)
class CLProgram:
kernel_cnt : Dict[str, int] = defaultdict(int)
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate