dtypes nice and clean (#673)

* add dtype class

* dtypes

* buffers are lazy

* dtype is tracked by lazybuffer and GenericShape

* fix types in llvm

* llvm store

* dtype tests

* fix tests maybe

* fix flop counter

* fix CI

* CI fix and check format

* fix dtype and dtype check

* fix custom test

* fix test graph
This commit is contained in:
George Hotz
2023-03-10 16:56:07 -08:00
committed by GitHub
parent d26345595d
commit 1826ff6b89
20 changed files with 215 additions and 141 deletions

View File

@@ -2,15 +2,13 @@ from __future__ import annotations
import platform, functools
import numpy as np
import pyopencl as cl # type: ignore
from typing import Dict, Optional, List, ClassVar, Final
from collections import defaultdict
from tinygrad.helpers import IMAGE, DEBUG, getenv
from typing import Optional, List, Final
from tinygrad.helpers import IMAGE, DEBUG, getenv, dtypes
from tinygrad.ops import CompiledBuffer, GlobalCounters, RawBufferCopyInOut, RawBuffer
from tinygrad.codegen.gpu import GPUCodegen, GPULanguage
OSX = platform.system() == "Darwin"
OSX_TIMING_RATIO = (125/3) if OSX else 1.0 # see test/external_osx_profiling.py to determine this ratio. it's in like GPU clocks or something
CLCACHE = getenv("CLCACHE", 1)
FLOAT16 = getenv("FLOAT16", 0)
class _CL:
@@ -27,33 +25,18 @@ class _CL:
CL = _CL()
class CLBuffer(RawBufferCopyInOut):
# TODO: this can be in RawBuffer generically
BUFFER_CACHE : ClassVar[Dict[int, List[cl.Buffer]]] = defaultdict(list)
def __init__(self, size): # pylint: disable=super-init-not-called
self.size = size
if len(CLBuffer.BUFFER_CACHE[size]) > 0:
self._cl = CLBuffer.BUFFER_CACHE[size].pop()
else:
# TODO: on GPU OOM, clear the cache
self._cl = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, size)
GlobalCounters.mem_used += self._cl.size
def __del__(self):
if CLCACHE: CLBuffer.BUFFER_CACHE[self._cl.size].append(self._cl)
else: GlobalCounters.mem_used -= self._cl.size
def __init__(self, size, dtype):
super().__init__(size, dtype)
self._cl = cl.Buffer(CL.cl_ctx, cl.mem_flags.READ_WRITE, self._memsz)
def copyin(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, self._cl, x, is_blocking=False)
def copyout(self, x:np.ndarray): cl.enqueue_copy(CL.cl_queue, x, self._cl, is_blocking=True)
class CLImage(RawBuffer): # pylint: disable=abstract-method
fmt : Final = cl.ImageFormat(cl.channel_order.RGBA, cl.channel_type.HALF_FLOAT if FLOAT16 else cl.channel_type.FLOAT)
IMAGE : Final = True
def __init__(self, shape): # pylint: disable=super-init-not-called
self.size, self._cl = shape, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, CLImage.fmt, shape=(shape[1], shape[0]))
def __init__(self, shape, dtype=dtypes.float16 if getenv("FLOAT16") else dtypes.float32): # pylint: disable=super-init-not-called
fmt = cl.ImageFormat(cl.channel_order.RGBA, {dtypes.float16: cl.channel_type.HALF_FLOAT, dtypes.float32: cl.channel_type.FLOAT}[dtype])
self.size, self.dtype, self._cl = shape, dtype, cl.Image(CL.cl_ctx, cl.mem_flags.READ_WRITE, fmt, shape=(shape[1], shape[0]))
GlobalCounters.mem_used += self._cl.row_pitch * self._cl.height
def __del__(self): GlobalCounters.mem_used -= self._cl.row_pitch * self._cl.height
@functools.lru_cache(maxsize=None)
@@ -89,6 +72,7 @@ class CLProgram:
class CLCodegen(GPUCodegen):
  # OpenCL flavour of the GPU code generator: each keyword below is the
  # OpenCL C spelling passed to GPULanguage.
  lang = GPULanguage(
    kernel_prefix = "__kernel",
    buffer_prefix = "__global ",
    smem_prefix = "__local ",
    # pragma text that enables the cl_khr_fp16 extension
    half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable",
    barrier = "barrier(CLK_LOCAL_MEM_FENCE);",
    float4 = "(float4)",
    # one index expression per dimension (0, 1, 2)
    gid = [f'get_global_id({i})' for i in range(3)],
    lid = [f'get_local_id({i})' for i in range(3)],
  )
@@ -96,8 +80,8 @@ class GPUBuffer(CompiledBuffer):
raw_buffer_type = CLBuffer
# override this method for image
@classmethod
def create_raw_buffer(cls, shape, backing) -> RawBuffer:
if len(shape) == 3 and shape[2] == 4 and IMAGE >= 2 and backing is None: return CLImage(shape)
else: return super().create_raw_buffer(shape, backing)
def create_raw_buffer(cls, shape, backing, dtype) -> RawBuffer:
if len(shape) == 3 and shape[2] == 4 and IMAGE >= 2 and backing is None: return CLImage(shape) # NOTE: this is a hack. we don't pass in the dtype here, it's controlled by the FLOAT16 env var
else: return super().create_raw_buffer(shape, backing, dtype)
codegen_type = CLCodegen
runtime_type = CLProgram