diff --git a/test/testextra/test_hevc.py b/test/testextra/test_hevc.py index 331abd5f9e..289d907466 100644 --- a/test/testextra/test_hevc.py +++ b/test/testextra/test_hevc.py @@ -1,14 +1,14 @@ import unittest -from tinygrad import Tensor, Device +from tinygrad import Device +from tinygrad.helpers import fetch from extra.hevc.hevc import parse_hevc_file_headers, nv_gpu class TestHevc(unittest.TestCase): def test_hevc_parser(self): url = "https://github.com/haraschax/filedump/raw/09a497959f7fa6fd8dba501a25f2cdb3a41ecb12/comma_video.hevc" - hevc_tensor = Tensor.from_url(url, device="CPU") + dat = fetch(url, headers={"Range": f"bytes=0-{512<<10}"}).read_bytes() - dat = bytes(hevc_tensor.data()) opaque, frame_info, w, h, luma_w, luma_h, chroma_off = parse_hevc_file_headers(dat, device=Device.DEFAULT) def _test_common(frame, bts): diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index 8fbf11d074..07ea8cf914 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -381,14 +381,14 @@ def _ensure_downloads_dir() -> pathlib.Path: return pathlib.Path(cache_dir) / "downloads" def fetch(url:str, name:pathlib.Path|str|None=None, subdir:str|None=None, gunzip:bool=False, - allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path: + allow_caching=not getenv("DISABLE_HTTP_CACHE"), headers={}) -> pathlib.Path: if url.startswith(("/", ".")): return pathlib.Path(url) if name is not None and (isinstance(name, pathlib.Path) or '/' in name): fp = pathlib.Path(name) else: fp = _ensure_downloads_dir() / (subdir or "") / ((name or hashlib.md5(url.encode('utf-8')).hexdigest()) + (".gunzip" if gunzip else "")) if not fp.is_file() or not allow_caching: (_dir := fp.parent).mkdir(parents=True, exist_ok=True) - with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": "tinygrad 0.11.0"}), timeout=10) as r: - assert r.status == 200, r.status + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": "tinygrad 0.11.0", **headers}), timeout=10) as r: + assert r.status in {200, 206}, r.status length = int(r.headers.get('content-length', 0)) if not gunzip else None readfile = gzip.GzipFile(fileobj=r) if gunzip else r progress_bar:tqdm = tqdm(total=length, unit='B', unit_scale=True, desc=f"{url}", disable=CI)