mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
fix dll.bind caching (#14168)
This commit is contained in:
committed by
GitHub
parent
f9ca072b61
commit
5abc262e22
@@ -165,14 +165,14 @@ class DLL(ctypes.CDLL):
|
||||
if DEBUG >= 3: print(f"loading {nm} failed: {e}")
|
||||
elif DEBUG >= 3: print(f"loading {nm} failed: not found on system")
|
||||
|
||||
@functools.cache
|
||||
def _get_func(self, name:str, args:tuple, res):
|
||||
(fn:=getattr(self, name)).argtypes, fn.restype = args, res
|
||||
return fn
|
||||
|
||||
def bind(self, fn):
|
||||
restype, argtypes = del_an((hints:=get_type_hints(fn, include_extras=True)).pop('return', None)), tuple(del_an(h) for h in hints.values())
|
||||
return lambda *args: self._get_func(fn.__name__, argtypes, restype)(*args)
|
||||
cfunc = None
|
||||
def wrapper(*args):
|
||||
nonlocal cfunc
|
||||
if cfunc is None: (cfunc:=getattr(self, fn.__name__)).argtypes, cfunc.restype = argtypes, restype
|
||||
return cfunc(*args)
|
||||
return wrapper
|
||||
|
||||
def __getattr__(self, nm):
|
||||
if not self.loaded: raise AttributeError(f"failed to load library {self.nm}: " + (self.emsg or f"try setting {self.nm.upper()+'_PATH'}?"))
|
||||
|
||||
Reference in New Issue
Block a user