diff --git a/extra/utils.py b/extra/utils.py index b4a121aef3..758da66d60 100644 --- a/extra/utils.py +++ b/extra/utils.py @@ -1,7 +1,7 @@ import pickle import numpy as np from tqdm import tqdm -from tinygrad.tensor import Tensor +import tempfile from tinygrad.helpers import prod, getenv def fetch(url): @@ -21,10 +21,10 @@ def download_file(url, fp, skip_if_exists=False): r = requests.get(url, stream=True) assert r.status_code == 200 progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url) - with open(fp+".tmp", "wb") as f: + with tempfile.NamedTemporaryFile(delete=False) as f: for chunk in r.iter_content(chunk_size=16384): progress_bar.update(f.write(chunk)) - os.rename(fp+".tmp", fp) + os.rename(f.name, fp) def my_unpickle(fb0): key_prelookup = {} diff --git a/test/test_yolo.py b/test/external_test_yolo.py similarity index 100% rename from test/test_yolo.py rename to test/external_test_yolo.py diff --git a/test/test_utils.py b/test/test_utils.py index 787d7627d5..4464421611 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,6 +1,8 @@ #!/usr/bin/env python +import io import unittest from extra.utils import fetch +from PIL import Image class TestUtils(unittest.TestCase): def test_fetch_bad_http(self): @@ -11,5 +13,10 @@ class TestUtils(unittest.TestCase): def test_fetch_small(self): assert(len(fetch('https://google.com'))>0) + def test_fetch_img(self): + img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190") + pimg = Image.open(io.BytesIO(img)) + assert pimg.size == (705, 1024) + if __name__ == '__main__': unittest.main() \ No newline at end of file