teeny changes (#1647)

* teeny changes

* teeny updates
This commit is contained in:
George Hotz
2023-08-23 09:53:39 -07:00
committed by GitHub
parent a6d842af7a
commit a89363574d
3 changed files with 20 additions and 15 deletions

View File

@@ -329,7 +329,6 @@ def _realize_contiguous(buffer: LazyBuffer) -> None:
# no need to run an AST, this is already contiguous
buffer.realized = realized
else:
# TODO: remove UnaryOps.NOOP, replace with LoadOps.CONTIGUOUS. confusing with Compiled though
buffer.op = LazyOp(UnaryOps.NOOP, buffer.op.src)
def _realize_custom(buffer: LazyBuffer) -> None:

View File

@@ -3,8 +3,6 @@ import time, importlib, inspect, functools, pathlib
from enum import Enum, auto
from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast
from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored, dedup, merge_dicts
from tinygrad.shape.symbolic import Variable, sym_infer
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
if TYPE_CHECKING:
from tinygrad.lazy import LazyBuffer
@@ -80,15 +78,16 @@ class LazyOp:
# **************** Device ****************
class _Device:
def __init__(self) -> None:
self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
self.DEFAULT: str = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) or self._default_device()
def __init__(self) -> None: self._buffers: List[str] = [x.stem[len("ops_"):].upper() for x in (pathlib.Path(__file__).parent/"runtime").iterdir() if x.stem.startswith("ops_")]
def canonicalize(self, device:Optional[str]) -> str: return (device.split(":", 1)[0].upper() + ((":"+device.split(":", 1)[1]) if ':' in device else '')).replace(":0", "") if device is not None else self.DEFAULT
@functools.lru_cache(maxsize=None) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __getitem__(self, x:str) -> Union[Interpreted, Compiled]:
x = x.split(":")[0].upper()
return [cls for cname, cls in inspect.getmembers(importlib.import_module(f'tinygrad.runtime.ops_{x.lower()}')) if (cname.lower() == x.lower() + "buffer") and x in self._buffers][0]
def _default_device(self) -> str:
@functools.cached_property
def DEFAULT(self) -> str:
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None)
if device_from_env: return device_from_env
for device in ["METAL", "CUDA", "GPU"]:
try:
if self[device]: return device
@@ -123,6 +122,8 @@ class Interpreted:
return output.output_buffer
return ret
# --teenygrad--
class FlopCounter:
def __init__(self, tup:Tuple[Tuple[int, ...], DType, int]): self.shape, self.dtype, self.flops, self._buf = *tup, self
def consume_flops(self):
@@ -139,6 +140,9 @@ def get_lazyop_info(ast:LazyOp) -> FlopCounter: return InterpretedFlopCounter.ex
# **************** for Compiled Buffers ****************
from tinygrad.runtime.lib import RawBuffer, RawConst, buf_is_kernel_arg
from tinygrad.shape.symbolic import Variable, sym_infer
class ASTRunner:
def __init__(self, name, prg, global_size:Optional[List[int]]=None, local_size:Optional[List[int]]=None, op_estimate=0, mem_estimate=0, display_name:Optional[str]=None, runtime_args:Optional[dict]=None):
if DEBUG >= 4 and (runtime_args is None or 'binary' not in runtime_args or not runtime_args['binary']): print(prg)

View File

@@ -25,6 +25,16 @@ class RawBuffer: # pylint: disable=abstract-method
def fromCPU(cls:Type[_T], x:np.ndarray) -> _T: raise NotImplementedError("must be implemented")
def toCPU(self) -> np.ndarray: raise NotImplementedError("must be implemented")
class RawConst(RawBuffer): # pylint: disable=abstract-method
  """A buffer holding a single constant value (kept in self._buf) rather than real device memory."""
  def __repr__(self):
    return f"const<{self._buf}, {self.dtype}>"
  @property
  def key(self):
    # Identity used for caching/dedup: the constant's string form paired with its dtype.
    return (str(self._buf), self.dtype)
def buf_is_kernel_arg(x) -> bool:
  """Return True when x has been realized into a real buffer (not a RawConst), i.e. it must be passed to the kernel as an argument."""
  realized = x.realized
  if realized is None: return False
  return realized.__class__ is not RawConst
# --teenygrad--
class RawBufferCopyIn(RawBuffer):
def _copyin(self, x:np.ndarray) -> None: raise NotImplementedError("must be implemented")
@@ -62,14 +72,6 @@ class RawBufferTransfer(RawBuffer):
ret._transfer(x)
return ret
class RawConst(RawBuffer): # pylint: disable=abstract-method
  """Constant-valued buffer; the scalar lives in self._buf instead of backing storage."""
  def __repr__(self):
    return f"const<{self._buf}, {self.dtype}>"
  @property
  def key(self):
    # Dedup key combining the stringified constant with its dtype.
    return (str(self._buf), self.dtype)
def buf_is_kernel_arg(x) -> bool:
  """True iff x's realized backing exists and is not a RawConst (constants are inlined, not passed)."""
  r = x.realized
  return (r is not None) and (r.__class__ is not RawConst)
class LRUAllocator:
def __init__(self, dev_memsz=(4<<30)):
self.epoch = 0