diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 586d8ed798..96698ec669 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,7 +1,7 @@ name: Unit Tests env: # increment this when downloads substantially change to avoid the internet - DOWNLOAD_CACHE_VERSION: '1' + DOWNLOAD_CACHE_VERSION: '3' on: push: diff --git a/.gitignore b/.gitignore index eca5fc91ad..c336736fd3 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ extra/datasets/open-images-v6-mlperf extra/datasets/kits/ extra/datasets/COCO/ extra/datasets/audio* +extra/weights venv examples/**/net.*[js,json] examples/**/*.safetensors diff --git a/examples/vgg7_helpers/waifu2x.py b/examples/vgg7_helpers/waifu2x.py index 67c75d875d..91f3a2273d 100644 --- a/examples/vgg7_helpers/waifu2x.py +++ b/examples/vgg7_helpers/waifu2x.py @@ -4,8 +4,7 @@ import numpy from tinygrad.tensor import Tensor from PIL import Image -from pathlib import Path -from extra.utils import download_file +from tinygrad.helpers import fetch # File Formats @@ -122,13 +121,8 @@ class Vgg7: """ Downloads a nagadomi/waifu2x JSON weight file and loads it. """ - fn = Path(__file__).parents[2] / ("weights/vgg_7_" + intent + "_" + subtype + "_model.json") - download_file("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json", fn) - import json - with open(fn, "rb") as f: - data = json.load(f) - + data = json.loads(fetch("https://github.com/nagadomi/waifu2x/raw/master/models/vgg_7/" + intent + "/" + subtype + "_model.json").read_bytes()) self.load_waifu2x_json(data) def load_waifu2x_json(self, data: list): diff --git a/examples/vit.py b/examples/vit.py index cd86e0fa96..bf9a8f5d31 100644 --- a/examples/vit.py +++ b/examples/vit.py @@ -1,11 +1,9 @@ import ast -import io import numpy as np from PIL import Image from tinygrad.tensor import Tensor -from tinygrad.helpers import getenv +from tinygrad.helpers import getenv, fetch from extra.models.vit import ViT -from extra.utils import fetch """ fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz" import tensorflow as tf @@ -24,14 +22,13 @@ else: m.load_from_pretrained() # category labels -lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt") -lbls = ast.literal_eval(lbls.decode('utf-8')) +lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text()) #url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg" url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0" # junk -img = Image.open(io.BytesIO(fetch(url))) +img = Image.open(fetch(url)) aspect_ratio = img.size[0] / img.size[1] img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0)))) img = np.array(img) diff --git a/examples/yolov3.py b/examples/yolov3.py index 9a984e60f2..ce137e6c80 100755 --- a/examples/yolov3.py +++ b/examples/yolov3.py @@ -8,10 +8,10 @@ import numpy as np from PIL import Image from tinygrad.tensor import Tensor from tinygrad.nn import BatchNorm2d, Conv2d -from extra.utils import fetch +from tinygrad.helpers import fetch def show_labels(prediction, confidence=0.5, num_classes=80): - coco_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names') + coco_labels = 
fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_bytes() coco_labels = coco_labels.decode('utf-8').split('\n') prediction = prediction.detach().numpy() conf_mask = (prediction[:,:,4] > confidence) diff --git a/examples/yolov8.py b/examples/yolov8.py index 27bd1f681a..c01dc5c134 100644 --- a/examples/yolov8.py +++ b/examples/yolov8.py @@ -2,14 +2,13 @@ from tinygrad.nn import Conv2d, BatchNorm2d from tinygrad.tensor import Tensor import numpy as np from itertools import chain -from extra.utils import get_child, fetch, download_file from pathlib import Path import cv2 from collections import defaultdict -import time, io, sys +import time, sys +from tinygrad.helpers import fetch from tinygrad.nn.state import safe_load, load_state_dict - #Model architecture from https://github.com/ultralytics/ultralytics/issues/189 #The upsampling class has been taken from this pull request https://github.com/tinygrad/tinygrad/pull/784 by dc-dc-dc. Now 2(?) models use upsampling. (retinet and this) @@ -400,7 +399,7 @@ if __name__ == '__main__': output_folder_path = Path('./outputs_yolov8') output_folder_path.mkdir(parents=True, exist_ok=True) #absolute image path or URL - image_location = [np.frombuffer(io.BytesIO(fetch(img_path)).read(), np.uint8)] + image_location = [np.frombuffer(fetch(img_path).read_bytes(), np.uint8)] image = [cv2.imdecode(image_location[0], 1)] out_paths = [(output_folder_path / f"{Path(img_path).stem}_output{Path(img_path).suffix}").as_posix()] if not isinstance(image[0], np.ndarray): @@ -412,10 +411,7 @@ if __name__ == '__main__': depth, width, ratio = get_variant_multiples(yolo_variant) yolo_infer = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) - weights_location = Path(__file__).parents[1] / "weights" / f'yolov8{yolo_variant}.safetensors' - download_file(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors', weights_location) - - state_dict = safe_load(weights_location) + state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{yolo_variant}.safetensors')) load_state_dict(yolo_infer, state_dict) st = time.time() @@ -425,8 +421,7 @@ if __name__ == '__main__': post_predictions = postprocess(preds=predictions, img=pre_processed_image, orig_imgs=image) #v8 and v3 have same 80 class names for Object Detection - class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names') - class_labels = class_labels.decode('utf-8').split('\n') + class_labels = fetch('https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names').read_text().split("\n") draw_bounding_boxes_and_save(orig_img_paths=image_location, output_img_paths=out_paths, all_predictions=post_predictions, class_labels=class_labels) diff --git a/extra/datasets/coco.py b/extra/datasets/coco.py index 5b35974587..0952e37701 100644 --- a/extra/datasets/coco.py +++ b/extra/datasets/coco.py @@ -2,7 +2,7 @@ import json import pathlib import zipfile import numpy as np -from extra.utils import download_file +from tinygrad.helpers import fetch import pycocotools._mask as _mask from examples.mask_rcnn import Masker from pycocotools.coco import COCO @@ -19,16 +19,14 @@ def create_dict(key_row, val_row, rows): return {row[key_row]:row[val_row] for r if not pathlib.Path(BASEDIR/'val2017').is_dir(): - fn = BASEDIR/'val2017.zip' - download_file('http://images.cocodataset.org/zips/val2017.zip',fn) + fn = fetch('http://images.cocodataset.org/zips/val2017.zip') with zipfile.ZipFile(fn, 'r') as 
zip_ref: zip_ref.extractall(BASEDIR) fn.unlink() if not pathlib.Path(BASEDIR/'annotations').is_dir(): - fn = BASEDIR/'annotations_trainval2017.zip' - download_file('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',fn) + fn = fetch('http://images.cocodataset.org/annotations/annotations_trainval2017.zip') with zipfile.ZipFile(fn, 'r') as zip_ref: zip_ref.extractall(BASEDIR) fn.unlink() diff --git a/extra/datasets/imagenet_download.py b/extra/datasets/imagenet_download.py index 7eca9dd261..5c01c72f52 100644 --- a/extra/datasets/imagenet_download.py +++ b/extra/datasets/imagenet_download.py @@ -1,5 +1,5 @@ # Python version of https://gist.github.com/antoinebrl/7d00d5cb6c95ef194c737392ef7e476a -from extra.utils import download_file +from tinygrad.helpers import fetch from pathlib import Path from tqdm import tqdm import tarfile, os @@ -40,12 +40,12 @@ if __name__ == "__main__": os.makedirs(Path(__file__).parent / "imagenet", exist_ok=True) os.makedirs(Path(__file__).parent / "imagenet" / "val", exist_ok=True) os.makedirs(Path(__file__).parent / "imagenet" / "train", exist_ok=True) - download_file("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent / "imagenet" / "imagenet_class_index.json") - download_file("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent / "imagenet"/ "imagenet_2012_validation_synset_labels.txt") - download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar") # 7GB + fetch("https://raw.githubusercontent.com/raghakot/keras-vis/master/resources/imagenet_class_index.json", Path(__file__).parent / "imagenet" / "imagenet_class_index.json") + fetch("https://raw.githubusercontent.com/tensorflow/models/master/research/slim/datasets/imagenet_2012_validation_synset_labels.txt", Path(__file__).parent / "imagenet"/ "imagenet_2012_validation_synset_labels.txt") + fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar") # 7GB imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_val.tar", Path(__file__).parent / "imagenet" / "val") imagenet_prepare_val() if os.getenv('IMGNET_TRAIN', None) is not None: - download_file("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar") #138GB! + fetch("https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar") #138GB! 
imagenet_extract(Path(__file__).parent / "imagenet" / "ILSVRC2012_img_train.tar", Path(__file__).parent / "imagenet" / "train") imagenet_prepare_train() diff --git a/extra/datasets/openimages.py b/extra/datasets/openimages.py index 8e411e067d..97bd2e846f 100644 --- a/extra/datasets/openimages.py +++ b/extra/datasets/openimages.py @@ -1,12 +1,11 @@ import os import math import json -from extra.utils import OSX import numpy as np from PIL import Image import pathlib import boto3, botocore -from extra.utils import download_file +from tinygrad.helpers import fetch from tqdm import tqdm import pandas as pd import concurrent.futures @@ -114,11 +113,11 @@ def fetch_openimages(output_fn): data_dir.mkdir(parents=True, exist_ok=True) annotations_fn = annotations_dir / BBOX_ANNOTATIONS_URL.split('/')[-1] - download_file(BBOX_ANNOTATIONS_URL, annotations_fn) + fetch(BBOX_ANNOTATIONS_URL, annotations_fn) annotations = pd.read_csv(annotations_fn) classmap_fn = annotations_dir / MAP_CLASSES_URL.split('/')[-1] - download_file(MAP_CLASSES_URL, classmap_fn) + fetch(MAP_CLASSES_URL, classmap_fn) class_map = pd.read_csv(classmap_fn, names=["LabelName", "DisplayName"]) image_list = get_image_list(class_map, annotations) diff --git a/extra/datasets/squad.py b/extra/datasets/squad.py index 1b95385ddf..7be38bff41 100644 --- a/extra/datasets/squad.py +++ b/extra/datasets/squad.py @@ -3,12 +3,12 @@ import os from pathlib import Path from transformers import BertTokenizer import numpy as np -from extra.utils import download_file +from tinygrad.helpers import fetch BASEDIR = Path(__file__).parent / "squad" def init_dataset(): os.makedirs(BASEDIR, exist_ok=True) - download_file("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json") + fetch("https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json", BASEDIR / "dev-v1.1.json") with open(BASEDIR / "dev-v1.1.json") as f: data = json.load(f)["data"] diff --git a/extra/models/bert.py b/extra/models/bert.py index e0d32dc167..dd155126cc 100644 --- a/extra/models/bert.py +++ b/extra/models/bert.py @@ -1,6 +1,6 @@ from tinygrad.tensor import Tensor from tinygrad.nn import Linear, LayerNorm, Embedding -from extra.utils import download_file, get_child +from tinygrad.helpers import fetch, get_child from pathlib import Path @@ -11,9 +11,9 @@ class BertForQuestionAnswering: def load_from_pretrained(self): fn = Path(__file__).parents[1] / "weights/bert_for_qa.pt" - download_file("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn) + fetch("https://zenodo.org/record/3733896/files/model.pytorch?download=1", fn) fn_vocab = Path(__file__).parents[1] / "weights/bert_vocab.txt" - download_file("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab) + fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", fn_vocab) import torch with open(fn, "rb") as f: diff --git a/extra/models/convnext.py b/extra/models/convnext.py index 3aa08ed90f..591112ad11 100644 --- a/extra/models/convnext.py +++ b/extra/models/convnext.py @@ -1,6 +1,6 @@ from tinygrad.tensor import Tensor from tinygrad.nn import Conv2d, LayerNorm, LayerNorm2d, Linear -from tinygrad.helpers import fetch +from tinygrad.helpers import fetch, get_child class Block: def __init__(self, dim): @@ -44,7 +44,6 @@ versions = { def get_model(version, load_weights=False): model = ConvNeXt(**versions[version]) if load_weights: - from extra.utils import get_child from tinygrad.nn.state import torch_load weights = 
torch_load(fetch(f'https://dl.fbaipublicfiles.com/convnext/convnext_{version}_1k_224_ema.pth'))['model'] for k,v in weights.items(): diff --git a/extra/models/mask_rcnn.py b/extra/models/mask_rcnn.py index ce35e9e4df..e352ad758c 100644 --- a/extra/models/mask_rcnn.py +++ b/extra/models/mask_rcnn.py @@ -5,13 +5,11 @@ import numpy as np from pathlib import Path from tinygrad import nn from tinygrad.tensor import Tensor -from tinygrad.helpers import dtypes -from extra.utils import get_child, download_file +from tinygrad.helpers import dtypes, get_child, fetch from tinygrad.nn.state import torch_load from extra.models.resnet import ResNet from extra.models.retinanet import nms as _box_nms - USE_NP_GATHER = os.getenv('FULL_TINYGRAD', '0') == '0' def rint(tensor): @@ -1240,7 +1238,7 @@ class MaskRCNN: def load_from_pretrained(self): fn = Path('./') / "weights/maskrcnn.pt" - download_file("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn) + fetch("https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_R_50_FPN_1x.pth", fn) state_dict = torch_load(fn)['model'] loaded_keys = [] diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 8f14cbba74..517f1ec9e1 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -1,8 +1,7 @@ import tinygrad.nn as nn from tinygrad.tensor import Tensor from tinygrad.nn.state import torch_load -from tinygrad.helpers import fetch -from extra.utils import get_child +from tinygrad.helpers import fetch, get_child class BasicBlock: expansion = 1 diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py index 363e01c86e..91d758ff5f 100644 --- a/extra/models/retinanet.py +++ b/extra/models/retinanet.py @@ -1,8 +1,7 @@ import math -from tinygrad.helpers import flatten +from tinygrad.helpers import flatten, get_child import tinygrad.nn as nn from extra.models.resnet import ResNet -from extra.utils import get_child import numpy as np def nms(boxes, scores, thresh=0.5): diff --git a/extra/models/rnnt.py b/extra/models/rnnt.py index 88523dcce9..589ac75c04 100644 --- a/extra/models/rnnt.py +++ b/extra/models/rnnt.py @@ -1,8 +1,8 @@ from tinygrad.tensor import Tensor from tinygrad.jit import TinyJit from tinygrad.nn import Linear, Embedding +from tinygrad.helpers import fetch import numpy as np -from extra.utils import download_file from pathlib import Path @@ -61,7 +61,7 @@ class RNNT: def load_from_pretrained(self): fn = Path(__file__).parents[1] / "weights/rnnt.pt" - download_file("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn) + fetch("https://zenodo.org/record/3662521/files/DistributedDataParallel_1576581068.9962234-epoch-100.pt?download=1", fn) import torch with open(fn, "rb") as f: diff --git a/extra/models/unet3d.py b/extra/models/unet3d.py index b67225f9fe..1b2558c871 100644 --- a/extra/models/unet3d.py +++ b/extra/models/unet3d.py @@ -2,7 +2,7 @@ from pathlib import Path import torch from tinygrad import nn from tinygrad.tensor import Tensor -from extra.utils import download_file, get_child +from tinygrad.helpers import fetch, get_child class DownsampleBlock: def __init__(self, c0, c1, stride=2): @@ -47,7 +47,7 @@ class UNet3D: def load_from_pretrained(self): fn = Path(__file__).parents[1] / "weights" / "unet-3d.ckpt" - download_file("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn) + fetch("https://zenodo.org/record/5597155/files/3dunet_kits19_pytorch.ptc?download=1", fn) state_dict = torch.jit.load(fn, 
map_location=torch.device("cpu")).state_dict() for k, v in state_dict.items(): obj = get_child(self, k) diff --git a/extra/models/vit.py b/extra/models/vit.py index 6c7882d219..a465708474 100644 --- a/extra/models/vit.py +++ b/extra/models/vit.py @@ -1,5 +1,6 @@ import numpy as np from tinygrad.tensor import Tensor +from tinygrad.helpers import fetch from extra.models.transformer import TransformerBlock class ViT: @@ -29,9 +30,6 @@ class ViT: return x[:, 0].linear(*self.head) def load_from_pretrained(m): - import io - from extra.utils import fetch - # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py if m.embed_dim == 192: url = "https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz" @@ -39,7 +37,7 @@ class ViT: url = "https://storage.googleapis.com/vit_models/augreg/B_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_224.npz" else: raise Exception("no pretrained weights for configuration") - dat = np.load(io.BytesIO(fetch(url))) + dat = np.load(fetch(url)) #for x in dat.keys(): # print(x, dat[x].shape, dat[x].dtype) diff --git a/extra/utils.py b/extra/utils.py deleted file mode 100644 index d7c39ff255..0000000000 --- a/extra/utils.py +++ /dev/null @@ -1,50 +0,0 @@ -# type: ignore -import pickle, hashlib, zipfile, io, requests, struct, tempfile, platform, concurrent.futures -import numpy as np -from tqdm import tqdm -from pathlib import Path -from collections import defaultdict -from typing import Union - -from tinygrad.helpers import prod, getenv, DEBUG, dtypes, get_child -from tinygrad.helpers import GlobalCounters -from tinygrad.tensor import Tensor -from tinygrad.lazy import LazyBuffer -from tinygrad import Device -from tinygrad.shape.view import strides_for_shape -OSX = platform.system() == "Darwin" -WINDOWS = platform.system() == "Windows" - -def temp(x:str) -> str: return (Path(tempfile.gettempdir()) / x).as_posix() - -def fetch(url): - if url.startswith("/") or url.startswith("."): - with open(url, "rb") as f: - return f.read() - fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest()) - download_file(url, fp, skip_if_exists=not getenv("NOCACHE")) - with open(fp, "rb") as f: - return f.read() - -def fetch_as_file(url): - if url.startswith("/") or url.startswith("."): - with open(url, "rb") as f: - return f.read() - fp = temp(hashlib.md5(url.encode('utf-8')).hexdigest()) - download_file(url, fp, skip_if_exists=not getenv("NOCACHE")) - return fp - -def download_file(url, fp, skip_if_exists=True): - if skip_if_exists and Path(fp).is_file() and Path(fp).stat().st_size > 0: - return - r = requests.get(url, stream=True, timeout=10) - assert r.status_code == 200 - progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url) - (path := Path(fp).parent).mkdir(parents=True, exist_ok=True) - with tempfile.NamedTemporaryFile(dir=path, delete=False) as f: - for chunk in r.iter_content(chunk_size=16384): - progress_bar.update(f.write(chunk)) - f.close() - Path(f.name).rename(fp) - - diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 441cb46711..f12325cd1e 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -11,11 +11,10 @@ OPENPILOT_MODEL = "https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/mod import onnx from typing import Tuple, List -from extra.utils import fetch from extra.onnx import get_run_onnx from tinygrad.graph import 
print_tree, log_schedule_item from tinygrad import Tensor, Device -from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, DEBUG, getenv, ImageDType, GRAPH +from tinygrad.helpers import dtypes, partition, GlobalCounters, Context, fetch, getenv, ImageDType, GRAPH from tinygrad.realize import run_schedule from tinygrad.ops import LoadOps, ScheduleItem from tinygrad.features.image import fix_schedule_for_images @@ -135,7 +134,7 @@ def thneed_test_onnx(onnx_data, output_fn): print("thneed self-test passed!") if __name__ == "__main__": - onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL) + onnx_data = fetch(sys.argv[1] if len(sys.argv) > 1 else OPENPILOT_MODEL).read_bytes() # quick test for ONNX issues #thneed_test_onnx(onnx_data, None) diff --git a/test/external/external_model_benchmark.py b/test/external/external_model_benchmark.py index 51d2d1682b..4c95113897 100644 --- a/test/external/external_model_benchmark.py +++ b/test/external/external_model_benchmark.py @@ -6,9 +6,8 @@ import onnx from onnx.helper import tensor_dtype_to_np_dtype import onnxruntime as ort from onnx2torch import convert -from extra.utils import download_file from extra.onnx import get_run_onnx -from tinygrad.helpers import OSX, DEBUG +from tinygrad.helpers import OSX, DEBUG, fetch from tinygrad.tensor import Tensor from tinygrad import Device @@ -49,8 +48,7 @@ def benchmark_model(m, validate_outs=False): global open_csv, CSV CSV = {"model": m} - fn = BASE / MODELS[m].split("/")[-1] - download_file(MODELS[m], fn) + fn = fetch(MODELS[m]) onnx_model = onnx.load(fn) output_names = [out.name for out in onnx_model.graph.output] excluded = {inp.name for inp in onnx_model.graph.initializer} diff --git a/test/external/external_test_yolo.py b/test/external/external_test_yolo.py index e4370fe616..f28f23aa5f 100644 --- a/test/external/external_test_yolo.py +++ b/test/external/external_test_yolo.py @@ -4,15 +4,15 @@ from pathlib import Path import cv2 from examples.yolov3 import Darknet, infer, show_labels -from extra.utils import fetch +from tinygrad.helpers import fetch -chicken_img = cv2.imread(str(Path(__file__).parent / 'efficientnet/Chicken.jpg')) -car_img = cv2.imread(str(Path(__file__).parent / 'efficientnet/car.jpg')) +chicken_img = cv2.imread(str(Path(__file__).parent.parent / 'models/efficientnet/Chicken.jpg')) +car_img = cv2.imread(str(Path(__file__).parent.parent / 'models/efficientnet/car.jpg')) class TestYOLO(unittest.TestCase): @classmethod def setUpClass(cls): - cls.model = Darknet(fetch("https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg")) + cls.model = Darknet(fetch("https://raw.githubusercontent.com/pjreddie/darknet/master/cfg/yolov3.cfg").read_bytes()) print("Loading weights file (237MB). 
This might take a while…") cls.model.load_weights("https://pjreddie.com/media/files/yolov3.weights") diff --git a/test/external/external_test_yolov8.py b/test/external/external_test_yolov8.py index 22e0e6e8ec..c98a4266e0 100644 --- a/test/external/external_test_yolov8.py +++ b/test/external/external_test_yolov8.py @@ -1,37 +1,31 @@ import numpy as np -from extra.utils import fetch, download_file from examples.yolov8 import YOLOv8, get_variant_multiples, preprocess, postprocess, label_predictions -from pathlib import Path import unittest import io, cv2 import onnxruntime as ort import ultralytics from tinygrad.nn.state import safe_load, load_state_dict +from tinygrad.helpers import fetch class TestYOLOv8(unittest.TestCase): def test_all_load_weights(self): for variant in ['n', 's', 'm', 'l', 'x']: - weights_location = Path(__file__).parents[2] / "weights" / f'yolov8{variant}.safetensors' - download_file(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors', weights_location) - depth, width, ratio = get_variant_multiples(variant) TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) - state_dict = safe_load(weights_location) + state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors')) load_state_dict(TinyYolov8, state_dict) print(f'successfully loaded weights for yolov{variant}') def test_predictions(self): test_image_urls = ['https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg', 'https://www.aljazeera.com/wp-content/uploads/2022/10/2022-04-28T192650Z_1186456067_UP1EI4S1I0P14_RTRMADP_3_SOCCER-ENGLAND-MUN-CHE-REPORT.jpg'] variant = 'n' - weights_location = Path(__file__).parents[2] / "weights" / f'yolov8{variant}.safetensors' depth, width, ratio = get_variant_multiples(variant) TinyYolov8 = YOLOv8(w=width, r=ratio, d=depth, num_classes=80) - state_dict = safe_load(weights_location) + state_dict = safe_load(fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors')) load_state_dict(TinyYolov8, state_dict) for i in range(len(test_image_urls)): - img_stream = io.BytesIO(fetch(test_image_urls[i])) - img = cv2.imdecode(np.frombuffer(img_stream.read(), np.uint8), 1) + img = cv2.imdecode(np.frombuffer(fetch(test_image_urls[i]).read_bytes(), np.uint8), 1) test_image = preprocess([img]) predictions = TinyYolov8(test_image) post_predictions = postprocess(preds=predictions, img=test_image, orig_imgs=[img]) @@ -40,11 +34,10 @@ class TestYOLOv8(unittest.TestCase): def test_forward_pass_torch_onnx(self): variant = 'n' - weights_location_onnx = Path(__file__).parents[2] / "weights" / f'yolov8{variant}.onnx' - weights_location_pt = Path(__file__).parents[2] / "weights" / f'yolov8{variant}.pt' - weights_location = Path(__file__).parents[2] / "weights" / f'yolov8{variant}.safetensors' + weights_location = fetch(f'https://gitlab.com/r3sist/yolov8_weights/-/raw/master/yolov8{variant}.safetensors') + weights_location_pt = fetch(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', name=f"yolov8{variant}.pt") # it needs the pt extension + weights_location_onnx = weights_location_pt.parent / f"yolov8{variant}.onnx" - download_file(f'https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8{variant}.pt', weights_location_pt) # the ultralytics export prints a lot of unneccesary things if not weights_location_onnx.is_file(): model = ultralytics.YOLO(model=weights_location_pt, task='Detect') @@ -55,7 +48,7 @@ class 
TestYOLOv8(unittest.TestCase): state_dict = safe_load(weights_location) load_state_dict(TinyYolov8, state_dict) - image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg')).read(), np.uint8)] + image_location = [np.frombuffer(io.BytesIO(fetch('https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg').read_bytes()).read(), np.uint8)] orig_image = [cv2.imdecode(image_location[0], 1)] input_image = preprocess(orig_image) diff --git a/test/extra/test_utils.py b/test/extra/test_utils.py deleted file mode 100644 index 4b47c01269..0000000000 --- a/test/extra/test_utils.py +++ /dev/null @@ -1,104 +0,0 @@ -#!/usr/bin/env python -import io, unittest -import os -import tempfile -from unittest.mock import patch, MagicMock - -import torch -import numpy as np -from tinygrad.helpers import CI -from extra.utils import fetch, temp, download_file -from tinygrad.nn.state import torch_load -from PIL import Image - -@unittest.skipIf(CI, "no internet tests in CI") -class TestFetch(unittest.TestCase): - def test_fetch_bad_http(self): - self.assertRaises(AssertionError, fetch, 'http://www.google.com/404') - - def test_fetch_small(self): - assert(len(fetch('https://google.com'))>0) - - def test_fetch_img(self): - img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190") - pimg = Image.open(io.BytesIO(img)) - assert pimg.size == (705, 1024) - -class TestFetchRelative(unittest.TestCase): - def setUp(self): - self.working_dir = os.getcwd() - self.tempdir = tempfile.TemporaryDirectory() - os.chdir(self.tempdir.name) - with open('test_file.txt', 'x') as f: - f.write("12345") - - def tearDown(self): - os.chdir(self.working_dir) - self.tempdir.cleanup() - - #test ./ - def test_fetch_relative_dotslash(self): - self.assertEqual(b'12345', fetch("./test_file.txt")) - - #test ../ - def test_fetch_relative_dotdotslash(self): - os.mkdir('test_file_path') - os.chdir('test_file_path') - self.assertEqual(b'12345', fetch("../test_file.txt")) - -class TestDownloadFile(unittest.TestCase): - def setUp(self): - from pathlib import Path - self.test_file = Path(temp("test_download_file/test_file.txt")) - - def tearDown(self): - os.remove(self.test_file) - os.removedirs(self.test_file.parent) - - @patch('requests.get') - def test_download_file_with_mkdir(self, mock_requests): - mock_response = MagicMock() - mock_response.iter_content.return_value = [b'1234', b'5678'] - mock_response.status_code = 200 - mock_response.headers = {'content-length': '8'} - mock_requests.return_value = mock_response - self.assertFalse(self.test_file.parent.exists()) - download_file("https://www.mock.com/fake.txt", self.test_file, skip_if_exists=False) - self.assertTrue(self.test_file.parent.exists()) - self.assertTrue(self.test_file.is_file()) - self.assertEqual('12345678', self.test_file.read_text()) - -class TestUtils(unittest.TestCase): - def test_fake_torch_load_zipped(self): self._test_fake_torch_load_zipped() - def test_fake_torch_load_zipped_float16(self): self._test_fake_torch_load_zipped(isfloat16=True) - def _test_fake_torch_load_zipped(self, isfloat16=False): - class LayerWithOffset(torch.nn.Module): - def __init__(self): - super(LayerWithOffset, self).__init__() - d = torch.randn(16) - self.param1 = torch.nn.Parameter( - d.as_strided([2, 2], [1, 2], storage_offset=5) - ) - self.param2 = torch.nn.Parameter( - d.as_strided([2, 2], [1, 2], storage_offset=4) - ) - - model = torch.nn.Sequential( - torch.nn.Linear(4, 8), - torch.nn.Linear(8, 
3), - LayerWithOffset() - ) - if isfloat16: model = model.half() - - path = temp(f"test_load_{isfloat16}.pt") - torch.save(model.state_dict(), path) - model2 = torch_load(path) - - for name, a in model.state_dict().items(): - b = model2[name] - a, b = a.numpy(), b.numpy() - assert a.shape == b.shape - assert a.dtype == b.dtype - assert np.array_equal(a, b) -if __name__ == '__main__': - unittest.main() diff --git a/test/models/test_onnx.py b/test/models/test_onnx.py index feffdabc3f..845dc36f29 100644 --- a/test/models/test_onnx.py +++ b/test/models/test_onnx.py @@ -1,14 +1,12 @@ #!/usr/bin/env python import os import time -import io import unittest import numpy as np import onnx -from extra.utils import fetch, temp from extra.onnx import get_run_onnx from tinygrad.tensor import Tensor -from tinygrad.helpers import CI +from tinygrad.helpers import CI, fetch, temp import pytest pytestmark = [pytest.mark.exclude_gpu, pytest.mark.exclude_clang] @@ -27,8 +25,7 @@ np.random.seed(1337) class TestOnnxModel(unittest.TestCase): def test_benchmark_openpilot_model(self): - dat = fetch(OPENPILOT_MODEL) - onnx_model = onnx.load(io.BytesIO(dat)) + onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) run_onnx = get_run_onnx(onnx_model) def get_inputs(): np_inputs = { @@ -71,8 +68,7 @@ class TestOnnxModel(unittest.TestCase): ps.print_stats(30) def test_openpilot_model(self): - dat = fetch(OPENPILOT_MODEL) - onnx_model = onnx.load(io.BytesIO(dat)) + onnx_model = onnx.load(fetch(OPENPILOT_MODEL)) run_onnx = get_run_onnx(onnx_model) print("got run_onnx") inputs = { @@ -103,26 +99,21 @@ class TestOnnxModel(unittest.TestCase): np.testing.assert_allclose(torch_out, tinygrad_out, atol=1e-4, rtol=1e-2) def test_efficientnet(self): - dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx") input_name, input_new = "images:0", True - self._test_model(dat, input_name, input_new) + self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/efficientnet-lite4/model/efficientnet-lite4-11.onnx"), input_name, input_new) def test_shufflenet(self): - dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx") - print(f"shufflenet downloaded : {len(dat)/1e6:.2f} MB") input_name, input_new = "gpu_0/data_0", False - self._test_model(dat, input_name, input_new) + self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/shufflenet/model/shufflenet-9.onnx"), input_name, input_new) @unittest.skip("test is very slow") def test_resnet(self): # NOTE: many onnx models can't be run right now due to max pool with strides != kernel_size - dat = fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx") - print(f"resnet downloaded : {len(dat)/1e6:.2f} MB") input_name, input_new = "data", False - self._test_model(dat, input_name, input_new) + self._test_model(fetch("https://github.com/onnx/models/raw/main/vision/classification/resnet/model/resnet18-v2-7.onnx"), input_name, input_new) - def _test_model(self, dat, input_name, input_new, debug=False): - onnx_model = onnx.load(io.BytesIO(dat)) + def _test_model(self, fn, input_name, input_new, debug=False): + onnx_model = onnx.load(fn) print("onnx loaded") from test.models.test_efficientnet import chicken_img, car_img, preprocess, _LABELS run_onnx = get_run_onnx(onnx_model) diff --git a/test/test_dtype.py b/test/test_dtype.py index 64668eb514..9b233a6dce 100644 --- 
a/test/test_dtype.py +++ b/test/test_dtype.py @@ -1,6 +1,6 @@ import unittest import numpy as np -from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX +from tinygrad.helpers import CI, DTYPES_DICT, getenv, DType, DEBUG, ImageDType, PtrDType, OSX, temp from tinygrad import Device from tinygrad.tensor import Tensor, dtypes from typing import Any, List @@ -112,7 +112,6 @@ class TestBFloat16DType(unittest.TestCase): assert tuple(back.numpy().tolist()) == (9984., -1, -1000, -9984, 20) def test_bf16_disk_write_read(self): - from extra.utils import temp t = Tensor([10000, -1, -1000, -10000, 20]).cast(dtypes.float32) t.to(f"disk:{temp('f32')}").realize() diff --git a/test/test_nn.py b/test/test_nn.py index 8839810a91..5804f84c77 100755 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -1,7 +1,6 @@ #!/usr/bin/env python import unittest import numpy as np -from extra.utils import WINDOWS from tinygrad.helpers import CI from tinygrad.jit import TinyJit from tinygrad.tensor import Tensor, Device @@ -167,7 +166,7 @@ class TestNN(unittest.TestCase): Tensor.wino = old_wino - @unittest.skipIf(CI and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI") + @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI") def test_conv_transpose1d(self): BS, C1, W = 4, 16, 224//4 C2, K, S, P = 64, 7, 2, 1 @@ -188,7 +187,7 @@ class TestNN(unittest.TestCase): torch_z = torch_layer(torch_x) np.testing.assert_allclose(z.numpy(), torch_z.detach().numpy(), atol=5e-4, rtol=1e-5) - @unittest.skipIf(CI and (WINDOWS or Device.DEFAULT == "WEBGPU"), "runs out of memory in CI") + @unittest.skipIf(CI and Device.DEFAULT == "WEBGPU", "runs out of memory in CI") def test_conv_transpose2d(self): BS, C1, H, W = 4, 16, 224//4, 224//4 C2, K, S, P = 64, 7, 2, 1 diff --git a/test/test_tensor.py b/test/test_tensor.py index 9f110b6c95..6e8b4da611 100644 --- a/test/test_tensor.py +++ b/test/test_tensor.py @@ -3,9 +3,8 @@ import torch import unittest, copy import mmap from tinygrad.tensor import Tensor, Device -from tinygrad.helpers import dtypes +from tinygrad.helpers import dtypes, temp from extra.gradcheck import numerical_jacobian, jacobian, gradcheck -from extra.utils import temp x_init = np.random.randn(1,3).astype(np.float32) U_init = np.random.randn(3,3).astype(np.float32) diff --git a/test/unit/test_disk_tensor.py b/test/unit/test_disk_tensor.py index ae9801b798..bd410b1d93 100644 --- a/test/unit/test_disk_tensor.py +++ b/test/unit/test_disk_tensor.py @@ -3,14 +3,13 @@ import unittest import numpy as np from tinygrad.tensor import Tensor, Device from tinygrad.nn.state import safe_load, safe_save, get_state_dict, torch_load -from tinygrad.helpers import dtypes +from tinygrad.helpers import dtypes, fetch, temp from tinygrad.runtime.ops_disk import RawDiskBuffer from tinygrad.helpers import Timing -from extra.utils import fetch_as_file, temp def compare_weights_both(url): import torch - fn = fetch_as_file(url) + fn = fetch(url) tg_weights = get_state_dict(torch_load(fn)) torch_weights = get_state_dict(torch.load(fn), tensor_type=torch.Tensor) assert list(tg_weights.keys()) == list(torch_weights.keys()) @@ -88,7 +87,7 @@ class TestSafetensors(unittest.TestCase): def test_huggingface_enet_safetensors(self): # test a real file - fn = fetch_as_file("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors") + fn = fetch("https://huggingface.co/timm/mobilenetv3_small_075.lamb_in1k/resolve/main/model.safetensors") state_dict = 
safe_load(fn) assert len(state_dict.keys()) == 244 assert 'blocks.2.2.se.conv_reduce.weight' in state_dict diff --git a/tinygrad/helpers.py b/tinygrad/helpers.py index a524d7192f..248141feca 100644 --- a/tinygrad/helpers.py +++ b/tinygrad/helpers.py @@ -51,6 +51,7 @@ def get_child(obj, key): def to_function_name(s:str): return ''.join([c if c in (string.ascii_letters+string.digits+'_') else f'{ord(c):02X}' for c in ansistrip(s)]) @functools.lru_cache(maxsize=None) def getenv(key:str, default=0): return type(default)(os.getenv(key, default)) +def temp(x:str) -> str: return (pathlib.Path(tempfile.gettempdir()) / x).as_posix() class Context(contextlib.ContextDecorator): stack: ClassVar[List[dict[str, int]]] = [{}] @@ -235,10 +236,11 @@ def diskcache(func): # *** http support *** -def fetch(url:str, name:Optional[str]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path: - fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) +def fetch(url:str, name:Optional[Union[pathlib.Path, str]]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path: + if url.startswith("/") or url.startswith("."): return pathlib.Path(url) + fp = pathlib.Path(name) if name is not None and (isinstance(name, pathlib.Path) or '/' in name) else pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest()) if not fp.is_file() or not allow_caching: - with request.urlopen(url, timeout=10) as r: + with request.urlopen(url, timeout=15) as r: assert r.status == 200 total_length = int(r.headers.get('content-length', 0)) progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
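Taken together, the patch retires extra/utils.py (the bytes-returning fetch, fetch_as_file, and download_file) in favor of the single tinygrad.helpers.fetch, which returns a pathlib.Path. The snippet below is a minimal sketch of the calling conventions used at the updated call sites; the URLs are copied from the diff, and anything beyond what the helpers.py hunk shows (for example, who creates parent directories) is an assumption, not a guarantee.

from pathlib import Path
from tinygrad.helpers import fetch

# default: download into the tinygrad cache dir (keyed by the md5 of the URL) and return the Path;
# a second run reuses the cached file unless DISABLE_HTTP_CACHE is set
labels = fetch("https://raw.githubusercontent.com/pjreddie/darknet/master/data/coco.names").read_text().split("\n")

# binary payloads are read from the returned Path instead of arriving directly as bytes
raw = fetch("https://raw.githubusercontent.com/ultralytics/yolov5/master/data/images/bus.jpg").read_bytes()

# passing a Path (or a name containing "/") as the second argument pins the destination,
# replacing the old download_file(url, fn); the parent directory is created by the caller here
weights_dir = Path("weights"); weights_dir.mkdir(exist_ok=True)
vocab_fn = fetch("https://zenodo.org/record/3733896/files/vocab.txt?download=1", weights_dir / "bert_vocab.txt")

# local paths ("/..." or "./...") are passed through untouched, so a call site can accept a URL or a file
local = fetch("./weights/bert_vocab.txt")  # no network access, just Path("./weights/bert_vocab.txt")

Because the helper returns a path, consumers that accept filenames (onnx.load, np.load, Image.open, safe_load, zipfile.ZipFile, torch_load) can take the result directly, which is what lets the diff drop the io.BytesIO wrappers, while text and byte consumers switch from .decode(...) to .read_text()/.read_bytes().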
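The diff also deletes test/extra/test_utils.py, which covered the old bytes-returning fetch and download_file, and nothing shown here indicates where (or whether) that coverage moves. If the relative-path cases were ported to the new Path-returning helper, a minimal version could look like the following sketch (the file name and test class are hypothetical):

import os, tempfile, unittest
from tinygrad.helpers import fetch

class TestFetchLocalPaths(unittest.TestCase):
  # ports the ./ and ../ cases from the deleted TestFetchRelative to the Path-returning API
  def setUp(self):
    self.prev_cwd = os.getcwd()
    self.tmp = tempfile.TemporaryDirectory()
    os.chdir(self.tmp.name)
    with open("test_file.txt", "x") as f: f.write("12345")

  def tearDown(self):
    os.chdir(self.prev_cwd)
    self.tmp.cleanup()

  def test_fetch_relative_dotslash(self):
    self.assertEqual("12345", fetch("./test_file.txt").read_text())

  def test_fetch_relative_dotdotslash(self):
    os.mkdir("sub"); os.chdir("sub")
    self.assertEqual("12345", fetch("../test_file.txt").read_text())

if __name__ == "__main__":
  unittest.main()

Only the local-path branch is exercised here, since that is the part of the new fetch fully visible in the helpers.py hunk above.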