mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
small helpers cleanups (#7977)
less lines for ceildiv and partition, and removed one # noqa: E501
This commit is contained in:
@@ -41,9 +41,7 @@ def fully_flatten(l):
|
||||
return [l]
|
||||
def fromimport(mod, frm): return getattr(__import__(mod, fromlist=[frm]), frm)
|
||||
def strip_parens(fst:str): return fst[1:-1] if fst[0] == '(' and fst[-1] == ')' and fst[1:-1].find('(') <= fst[1:-1].find(')') else fst
|
||||
def ceildiv(num, amt):
|
||||
ret = -(num//-amt)
|
||||
return ret if not isinstance(ret, float) else int(ret)
|
||||
def ceildiv(num, amt): return int(ret) if isinstance((ret:=-(num//-amt)), float) else ret
|
||||
def round_up(num:int, amt:int) -> int: return (num+amt-1)//amt * amt
|
||||
def data64(data:Any) -> Tuple[Any, Any]: return (data >> 32, data & 0xFFFFFFFF) # Any is sint
|
||||
def data64_le(data:Any) -> Tuple[Any, Any]: return (data & 0xFFFFFFFF, data >> 32) # Any is sint
|
||||
@@ -52,10 +50,9 @@ def merge_dicts(ds:Iterable[Dict[T,U]]) -> Dict[T,U]:
|
||||
assert len(kvs) == len(set(kv[0] for kv in kvs)), f"cannot merge, {kvs} contains different values for the same key"
|
||||
return {k:v for d in ds for k,v in d.items()}
|
||||
def partition(itr:Iterable[T], fxn:Callable[[T],bool]) -> Tuple[List[T], List[T]]:
|
||||
a:List[T] = []
|
||||
b:List[T] = []
|
||||
for s in itr: (a if fxn(s) else b).append(s)
|
||||
return a,b
|
||||
ret:Tuple[List[T], List[T]] = ([], [])
|
||||
for s in itr: (ret[0] if fxn(s) else ret[1]).append(s)
|
||||
return ret
|
||||
def unwrap(x:Optional[T]) -> T:
|
||||
assert x is not None
|
||||
return x
|
||||
@@ -268,7 +265,8 @@ def from_mv(mv:memoryview, to_type=ctypes.c_char):
|
||||
return ctypes.cast(ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type * len(mv))).contents
|
||||
def to_mv(ptr:int, sz:int) -> memoryview: return memoryview(ctypes.cast(ptr, ctypes.POINTER(ctypes.c_uint8 * sz)).contents).cast("B")
|
||||
def mv_address(mv:memoryview): return ctypes.addressof(ctypes.c_char.from_buffer(mv))
|
||||
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char): return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options]) # noqa: E501
|
||||
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char):
|
||||
return (ctypes.POINTER(to_type) * len(options))(*[ctypes.cast(ctypes.create_string_buffer(o), ctypes.POINTER(to_type)) for o in options])
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
|
||||
class CStruct(ctypes.Structure):
|
||||
|
||||
Reference in New Issue
Block a user