mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix opencl types
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import functools, platform
|
||||
import numpy as np
|
||||
import pyopencl as cl # type: ignore
|
||||
from typing import Dict, Optional, Tuple, List, ClassVar
|
||||
from typing import Dict, Optional, Tuple, List, ClassVar, Final
|
||||
from collections import defaultdict
|
||||
from tinygrad.ops import DEBUG, GlobalCounters
|
||||
from tinygrad.helpers import getenv
|
||||
@@ -17,7 +17,7 @@ class CL:
|
||||
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
|
||||
cl_ctx : ClassVar[Optional[cl.Context]] = None
|
||||
cl_queue : ClassVar[Optional[cl.CommandQueue]] = None
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
if CL.cl_queue is not None: return # already initted
|
||||
devices : List[cl.Device] = sum([x.get_devices(device_type=cl.device_type.GPU) for x in cl.get_platforms()], [])
|
||||
if len(devices) == 0: # settle for CPU
|
||||
@@ -50,7 +50,7 @@ class CLBuffer:
|
||||
def copyout(self, a:np.ndarray): CL.enqueue_copy(a, self._cl, True)
|
||||
|
||||
class CLImage:
|
||||
fmt = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
|
||||
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
|
||||
|
||||
def __init__(self, shape):
|
||||
self._cl = cl.Image(CL().cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
|
||||
@@ -64,7 +64,7 @@ class CLImage:
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
class CLProgram:
|
||||
kernel_cnt : Dict[str, int] = defaultdict(int)
|
||||
kernel_cnt : Final[Dict[str, int]] = defaultdict(int)
|
||||
def __init__(self, name:str, prg:str, options:Tuple[str, ...]=tuple(), argdtypes=None, rename=True, binary=False, op_estimate=0, mem_estimate=0):
|
||||
self.name = f"{name}{('_N'+str(CLProgram.kernel_cnt[name])) if CLProgram.kernel_cnt[name] else str()}" if rename else name
|
||||
self.prg, self.options, self.argdtypes, self.op_estimate, self.mem_estimate = prg.replace(f"{name}(", f"{self.name}(") if rename else prg, options, argdtypes, op_estimate, mem_estimate
|
||||
|
||||
Reference in New Issue
Block a user