mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
clean up some functions in helpers [pr] (#13942)
This commit is contained in:
@@ -38,15 +38,11 @@ def ansilen(s:str): return len(ansistrip(s))
|
||||
def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
|
||||
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
|
||||
def fully_flatten(l):
|
||||
if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
|
||||
if hasattr(l, "shape") and l.shape == (): return [l[()]]
|
||||
flattened = []
|
||||
for li in l: flattened.extend(fully_flatten(li))
|
||||
return flattened
|
||||
return [l]
|
||||
if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
|
||||
return [l[()]] if hasattr(l, "shape") and l.shape == () else [x for li in l for x in fully_flatten(li)]
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
|
||||
def strip_parens(fst:str) -> str: return fst[1:-1] if fst and fst[0]=='(' and fst[-1] == ')' and _is_balanced(fst[1:-1]) else fst
|
||||
def strip_parens(fst:str) -> str: return fst[1:-1] if fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1]) else fst
|
||||
def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
|
||||
def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
|
||||
def round_down(num:int, amt:int) -> int: return -round_up(-num, amt)
|
||||
@@ -88,9 +84,7 @@ def word_wrap(x, wrap=80):
|
||||
while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
|
||||
return x[:i] + "\n" + word_wrap(x[i:], wrap)
|
||||
def pad_bytes(b:bytes, align:int) -> bytes: return b + b'\x00' * ((align - (len(b) % align)) % align)
|
||||
def panic(e:Exception|None=None):
|
||||
if e is None: raise RuntimeError("PANIC!")
|
||||
raise e
|
||||
def panic(e:Exception|None=None): raise e if e is not None else RuntimeError("PANIC!")
|
||||
|
||||
@functools.cache
|
||||
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:
|
||||
@@ -150,9 +144,7 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
|
||||
def temp(x:str, append_user:bool=False) -> str:
|
||||
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
|
||||
|
||||
def stderr_log(msg):
|
||||
sys.stderr.write(msg)
|
||||
sys.stderr.flush()
|
||||
def stderr_log(msg:str): print(msg, end='', file=sys.stderr, flush=True)
|
||||
|
||||
class Context(contextlib.ContextDecorator):
|
||||
def __init__(self, **kwargs): self.kwargs = kwargs
|
||||
|
||||
Reference in New Issue
Block a user