clean up some functions in helpers [pr] (#13942)

This commit is contained in:
chenyu
2025-12-31 18:29:16 -05:00
committed by GitHub
parent e2987001ee
commit 0ed58c1fcd

View File

@@ -38,15 +38,11 @@ def ansilen(s:str): return len(ansistrip(s))
def make_tuple(x:int|Sequence[int], cnt:int) -> tuple[int, ...]: return (x,)*cnt if isinstance(x, int) else tuple(x)
def flatten(l:Iterable[Iterable[T]]): return [item for sublist in l for item in sublist]
def fully_flatten(l):
if hasattr(l, "__len__") and hasattr(l, "__getitem__") and not isinstance(l, str):
if hasattr(l, "shape") and l.shape == (): return [l[()]]
flattened = []
for li in l: flattened.extend(fully_flatten(li))
return flattened
return [l]
if not (hasattr(l, "__len__") and hasattr(l, "__getitem__")) or isinstance(l, str): return [l]
return [l[()]] if hasattr(l, "shape") and l.shape == () else [x for li in l for x in fully_flatten(li)]
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
def _is_balanced(s:str) -> bool: return (d := 0, all((d := d + (c == '(') - (c == ')')) >= 0 for c in s))[1] and d == 0
def strip_parens(fst:str) -> str: return fst[1:-1] if fst and fst[0]=='(' and fst[-1] == ')' and _is_balanced(fst[1:-1]) else fst
def strip_parens(fst:str) -> str: return fst[1:-1] if fst[:1]=='(' and fst[-1:]==')' and _is_balanced(fst[1:-1]) else fst
def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
def round_down(num:int, amt:int) -> int: return -round_up(-num, amt)
@@ -88,9 +84,7 @@ def word_wrap(x, wrap=80):
while len(ansistrip(x[:i])) < wrap and i < len(x): i += 1
return x[:i] + "\n" + word_wrap(x[i:], wrap)
def pad_bytes(b:bytes, align:int) -> bytes: return b + b'\x00' * ((align - (len(b) % align)) % align)
def panic(e:Exception|None=None):
if e is None: raise RuntimeError("PANIC!")
raise e
def panic(e:Exception|None=None): raise e if e is not None else RuntimeError("PANIC!")
@functools.cache
def canonicalize_strides(shape:tuple[T, ...], strides:tuple[T, ...]) -> tuple[T, ...]:
@@ -150,9 +144,7 @@ def getenv(key:str, default:Any=0): return type(default)(os.getenv(key, default)
def temp(x:str, append_user:bool=False) -> str:
return (pathlib.Path(tempfile.gettempdir()) / (f"{x}.{getpass.getuser()}" if append_user else x)).as_posix()
def stderr_log(msg):
sys.stderr.write(msg)
sys.stderr.flush()
def stderr_log(msg:str): print(msg, end='', file=sys.stderr, flush=True)
class Context(contextlib.ContextDecorator):
def __init__(self, **kwargs): self.kwargs = kwargs