mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix caching for fetch (#13544)
This commit is contained in:
@@ -163,6 +163,14 @@ class TestFetch(unittest.TestCase):
|
||||
fetch("https://csrc.nist.gov/CSRC/media/Projects/lightweight-cryptography/documents/finalist-round/updated-submissions/sparkle.zip",
|
||||
allow_caching=False)
|
||||
|
||||
def test_fetch_half_and_full_file(self):
|
||||
x = fetch("https://csrc.nist.gov/CSRC/media/Projects/lightweight-cryptography/documents/finalist-round/updated-submissions/sparkle.zip",
|
||||
headers={"Range": "bytes=0-10"}).read_bytes()
|
||||
assert len(x) == 11, f"{len(x) != 11}"
|
||||
x = fetch("https://csrc.nist.gov/CSRC/media/Projects/lightweight-cryptography/documents/finalist-round/updated-submissions/sparkle.zip",
|
||||
headers={"Range": "bytes=0-100"}).read_bytes()
|
||||
assert len(x) == 101, f"{len(x) != 101}"
|
||||
|
||||
class TestFullyFlatten(unittest.TestCase):
|
||||
def test_fully_flatten(self):
|
||||
self.assertEqual(fully_flatten([[1, 3], [1, 2]]), [1, 3, 1, 2])
|
||||
|
||||
@@ -381,10 +381,12 @@ def _ensure_downloads_dir() -> pathlib.Path:
|
||||
return pathlib.Path(cache_dir) / "downloads"
|
||||
|
||||
def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False,
|
||||
allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers={}) -> pathlib.Path:
|
||||
allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers:dict[str, str]={}) -> pathlib.Path:
|
||||
if url.startswith(("/", ".")): return pathlib.Path(url)
|
||||
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
|
||||
else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else ""))
|
||||
else:
|
||||
hh = "_"+hashlib.md5(("\n".join(f"{k.strip()}:{v.strip()}" for k,v in sorted(headers.items()))).encode("utf-8")).hexdigest() if headers else ""
|
||||
fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + hh + (".gunzip" if gunzip else ""))
|
||||
if not fp.is_file() or not allow_caching:
|
||||
(_dir := fp.parent).mkdir(parents=True, exist_ok=True)
|
||||
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": "tinygrad 0.11.0", **headers}), timeout=10) as r:
|
||||
|
||||
Reference in New Issue
Block a user