mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
This reverts commit dc4d7f2d55.
This commit is contained in:
@@ -209,13 +209,6 @@ class TestFetch(unittest.TestCase):
|
||||
headers={"Range": "bytes=0-100"}).read_bytes()
|
||||
assert len(x) == 101, f"{len(x) != 101}"
|
||||
|
||||
def test_fetch_retries(self):
    """fetch() must retry a failing request and surface the error after `retries` attempts."""
    from unittest.mock import patch
    # Every urlopen call fails with a timeout, so fetch can never succeed.
    with patch('urllib.request.urlopen', side_effect=TimeoutError()) as mocked_urlopen:
        with self.assertRaises(TimeoutError):
            fetch('http://example.com/test', allow_caching=False, retries=2)
        # Exactly `retries` attempts were made before the exception propagated.
        self.assertEqual(mocked_urlopen.call_count, 2)
|
||||
class TestFullyFlatten(unittest.TestCase):
    """Tests for the fully_flatten helper."""

    def test_fully_flatten(self):
        # A two-level nested list collapses into one flat list, preserving order.
        expected = [1, 3, 1, 2]
        self.assertEqual(fully_flatten([[1, 3], [1, 2]]), expected)
||||
@@ -380,7 +380,7 @@ def _ensure_downloads_dir() -> pathlib.Path:
|
||||
return pathlib.Path(cache_dir) / "downloads"
|
||||
|
||||
def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False,
|
||||
allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers:dict[str, str]={}, retries:int=3) -> pathlib.Path:
|
||||
allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers:dict[str, str]={}) -> pathlib.Path:
|
||||
import urllib.request
|
||||
if url.startswith(("/", ".")): return pathlib.Path(url)
|
||||
if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name)
|
||||
@@ -389,25 +389,17 @@ def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip
|
||||
fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + hh + (".gunzip" if gunzip else ""))
|
||||
if not fp.is_file() or not allow_caching:
|
||||
(_dir := fp.parent).mkdir(parents=True, exist_ok=True)
|
||||
assert retries > 0
|
||||
for retry in range(retries):
|
||||
try:
|
||||
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": "tinygrad 0.12.0", **headers}), timeout=10) as r:
|
||||
assert r.status in {200, 206}, r.status
|
||||
length = int(r.headers.get('content-length', 0)) if not gunzip else None
|
||||
readfile = gzip.GzipFile(fileobj=r) if gunzip else r
|
||||
progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
|
||||
with tempfile.NamedTemporaryFile(dir=_dir, delete=False) as f:
|
||||
while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
|
||||
f.close()
|
||||
pathlib.Path(f.name).rename(fp)
|
||||
progress_bar.update(close=True)
|
||||
if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
|
||||
break # success so don't retry
|
||||
except Exception as e:
|
||||
if retry+1 == retries: raise e
|
||||
if DEBUG >= 2: print(f'Request {retry+1} failed: {e}. Retrying...')
|
||||
time.sleep(0.1 * 2**retry) # exponential backoff
|
||||
with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": "tinygrad 0.12.0", **headers}), timeout=10) as r:
|
||||
assert r.status in {200, 206}, r.status
|
||||
length = int(r.headers.get('content-length', 0)) if not gunzip else None
|
||||
readfile = gzip.GzipFile(fileobj=r) if gunzip else r
|
||||
progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)
|
||||
with tempfile.NamedTemporaryFile(dir=_dir, delete=False) as f:
|
||||
while chunk := readfile.read(16384): progress_bar.update(f.write(chunk))
|
||||
f.close()
|
||||
pathlib.Path(f.name).rename(fp)
|
||||
progress_bar.update(close=True)
|
||||
if length and (file_size:=os.stat(fp).st_size) < length: raise RuntimeError(f"fetch size incomplete, {file_size} < {length}")
|
||||
return fp
|
||||
|
||||
# *** Exec helpers
|
||||
|
||||
Reference in New Issue
Block a user