lines from helpers

This commit is contained in:
George Hotz
2023-03-03 23:07:46 -08:00
parent 81cda2b672
commit 893f136fe0
3 changed files with 12 additions and 11 deletions

5
extra/helpers.py Normal file
View File

@@ -0,0 +1,5 @@
import time
class Timing(object):
def __enter__(self): self.st = time.monotonic_ns()
def __exit__(self, exc_type, exc_val, exc_tb): print(f"{(time.monotonic_ns()-self.st)*1e-6:.2f} ms")

View File

@@ -1,4 +1,4 @@
import os, math, functools, time
import os, math, functools
from typing import Tuple, Union, List
def dedup(x): return list(dict.fromkeys(x)) # retains list order
@@ -11,15 +11,8 @@ def partition(lst, fxn): return [x for x in lst if fxn(x)], [x for x in lst if n
def make_pair(x:Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else x
def flatten(l): return [item for sublist in l for item in sublist]
class Timing(object):
def __enter__(self): self.st = time.monotonic_ns()
def __exit__(self, exc_type, exc_val, exc_tb): print(f"{(time.monotonic_ns()-self.st)*1e-6:.2f} ms")
@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
DEBUG = getenv("DEBUG", 0)
IMAGE = getenv("IMAGE", 0)
def shape_to_axis(old_shape, new_shape):
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
return tuple([i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b])

View File

@@ -1,8 +1,11 @@
import numpy as np
import operator
from typing import ClassVar, Callable, Dict
from typing import ClassVar, Callable, Dict, Tuple
from tinygrad.ops import UnaryOps, BinaryOps, MovementOps, ReduceOps, FusedOps, InterpretedBuffer, Op
from tinygrad.helpers import shape_to_axis
def shape_to_axis(old_shape:Tuple[int, ...], new_shape:Tuple[int, ...]) -> Tuple[int, ...]:
assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
return tuple(i for i,(a,b) in enumerate(zip(old_shape, new_shape)) if a != b)
base_fxn_for_op : Dict[Op, Callable] = {
UnaryOps.NEG: lambda x: -x, UnaryOps.NOT: lambda x: (1.0 - x),