Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-09 15:08:02 -05:00)
add name support to fetch (#2407)
* add name support
* use fetch in gpt2
* remove requests from main lib, networkx also optional
* umm, keep that assert
* updates to fetch
* i love the walrus so much
* stop bundling mnist with tinygrad
* err, https
* download cache names
* add DOWNLOAD_CACHE_VERSION
* need env.
* ugh, wrong path
* replace get_child
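For context, a minimal sketch of the call styles this commit enables (the example.com URLs are illustrative, not from this diff):

  from tinygrad.helpers import fetch

  # unnamed: cached under the md5 hex digest of the URL
  weights_path = fetch("https://example.com/weights.bin")

  # named: cached under a stable, human-readable filename instead
  weights_path = fetch("https://example.com/weights.bin", "weights.bin")

  # fetch returns a pathlib.Path, so callers read contents explicitly
  labels = fetch("https://example.com/labels.txt").read_text()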
.github/workflows/benchmark.yml (vendored, 8 changes)
@@ -36,8 +36,8 @@ jobs:
       shell: bash
     - name: Run LLaMA
      run: |
-        JIT=0 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
-        JIT=1 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+        JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
+        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
       shell: bash
     - name: Run GPT2
       run: |
@@ -121,8 +121,8 @@ jobs:
       shell: bash
     - name: Run LLaMA
       run: |
-        JIT=0 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
-        JIT=1 python3 examples/llama.py --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
+        JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
+        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
       shell: bash
     - name: Run GPT2
       run: |
.github/workflows/test.yml (vendored, 18 changes)
@@ -1,4 +1,7 @@
 name: Unit Tests
+env:
+  # increment this when downloads substantially change to avoid the internet
+  DOWNLOAD_CACHE_VERSION: '1'
 
 on:
   push:
@@ -85,7 +88,7 @@ jobs:
       uses: actions/cache@v3
       with:
         path: ~/.cache/tinygrad/downloads/
-        key: downloads-cache
+        key: downloads-cache-cpu-${{ env.DOWNLOAD_CACHE_VERSION }}
     - name: Install Dependencies
       run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
     - name: Run Pytest
@@ -118,7 +121,7 @@ jobs:
       uses: actions/cache@v3
       with:
         path: ~/.cache/tinygrad/downloads/
-        key: downloads-cache
+        key: downloads-cache-torch-${{ env.DOWNLOAD_CACHE_VERSION }}
     - name: Install Dependencies
       run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
     - name: Run Pytest
@@ -155,6 +158,11 @@ jobs:
       with:
         path: ${{ env.Python3_ROOT_DIR }}/lib/python3.11/site-packages
         key: testing-packages-${{ hashFiles('**/setup.py') }}
+    - name: Cache downloads
+      uses: actions/cache@v3
+      with:
+        path: ~/.cache/tinygrad/downloads/
+        key: downloads-cache-${{ matrix.task }}-${{ env.DOWNLOAD_CACHE_VERSION }}
     - name: Install Dependencies
       run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
     - if: ${{ matrix.task == 'optimage' }}
@@ -229,7 +237,7 @@ jobs:
       uses: actions/cache@v3
       with:
         path: ~/Library/Caches/tinygrad/downloads/
-        key: downloads-cache
+        key: downloads-cache-metal-${{ env.DOWNLOAD_CACHE_VERSION }}
     - name: Test LLaMA compile speed
       run: PYTHONPATH="." METAL=1 python test/external/external_test_speed_llama.py
     #- name: Run dtype test
@@ -293,8 +301,8 @@ jobs:
     - name: Cache downloads
       uses: actions/cache@v3
       with:
-        path: ~/Library/Caches/tinygrad/downloads/
-        key: downloads-cache
+        path: ~/.cache/tinygrad/downloads/
+        key: downloads-cache-${{ matrix.backend }}-${{ env.DOWNLOAD_CACHE_VERSION }}
     - name: Set env
       run: printf "${{ matrix.backend == 'llvm' && 'LLVM=1' || matrix.backend == 'clang' && 'CLANG=1' || matrix.backend == 'gpu' && 'GPU=1' || matrix.backend == 'cuda' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\n' || matrix.backend == 'PTX' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nPTX=1' || matrix.backend == 'triton' && 'FORWARD_ONLY=1\nJIT=1\nOPT=2\nCUDA=1\nCUDACPU=1\nTRITON=1\nTRITON_PTXAS_PATH=/usr/bin/ptxas'}}" >> $GITHUB_ENV
     - name: Install OpenCL
examples/compile_efficientnet.py
@@ -2,9 +2,8 @@ from pathlib import Path
 from extra.models.efficientnet import EfficientNet
 from tinygrad.tensor import Tensor
 from tinygrad.nn.state import safe_save
-from extra.utils import fetch
 from extra.export_model import export_model
-from tinygrad.helpers import getenv
+from tinygrad.helpers import getenv, fetch
 import ast
 
 if __name__ == "__main__":
@@ -21,11 +20,10 @@ if __name__ == "__main__":
   else:
     cprog = [prg]
   # image library!
-  cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").decode('utf-8').replace("half", "_half")]
+  cprog += ["#define STB_IMAGE_IMPLEMENTATION", fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text().replace("half", "_half")]
 
   # imagenet labels, move to datasets?
-  lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
-  lbls = ast.literal_eval(lbls.decode('utf-8'))
+  lbls = ast.literal_eval(fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").read_text())
   lbls = ['"'+lbls[i]+'"' for i in range(1000)]
   inputs = "\n".join([f"float {inp}[{inp_size}];" for inp,inp_size in inp_sizes.items()])
   outputs = "\n".join([f"float {out}[{out_size}];" for out,out_size in out_sizes.items()])
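Note the call-site pattern above: since fetch now returns a pathlib.Path rather than raw bytes, .decode('utf-8') becomes .read_text(). A minimal sketch of the new idiom, reusing the stb_image URL from this hunk:

  from tinygrad.helpers import fetch

  # downloads on the first call, then serves from the local download cache
  header_src = fetch("https://raw.githubusercontent.com/nothings/stb/master/stb_image.h").read_text()
  assert "#define" in header_src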
examples/gpt2.py
@@ -10,8 +10,7 @@ from tinygrad.shape.symbolic import Variable
 from tinygrad.jit import TinyJit
 import tiktoken
 from tinygrad.nn.state import torch_load, load_state_dict
-from extra.utils import fetch_as_file
-from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv
+from tinygrad.helpers import GlobalCounters, Timing, DEBUG, getenv, fetch
 
 MAX_CONTEXT = 128
 
@@ -106,7 +105,7 @@ class GPT2:
     tokenizer = tiktoken.get_encoding("gpt2")
 
     model = Transformer(**MODEL_PARAMS[model_size])
-    weights = torch_load(fetch_as_file(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
+    weights = torch_load(fetch(f'https://huggingface.co/{model_size}/resolve/main/pytorch_model.bin'))
    # special treatment for the Conv1D weights we need to transpose
    transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight']
    for k in weights.keys():
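Because torch_load takes a file path, the pathlib.Path returned by fetch composes with it directly, which is what lets the old fetch_as_file helper go away. A minimal sketch with one concrete size ('gpt2' stands in for the model_size variable above):

  from tinygrad.helpers import fetch
  from tinygrad.nn.state import torch_load

  # first call downloads to the cache; later calls just load from disk
  weights = torch_load(fetch('https://huggingface.co/gpt2/resolve/main/pytorch_model.bin'))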
examples/stable_diffusion.py
@@ -9,9 +9,8 @@ from collections import namedtuple
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.ops import Device
-from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
+from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
-from extra.utils import download_file
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
 from tinygrad.jit import TinyJit
 
@@ -405,10 +404,7 @@ class CLIPTextTransformer:
 
 # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
 @lru_cache()
-def default_bpe():
-  fn = Path(__file__).parents[1] / "weights/bpe_simple_vocab_16e6.txt.gz"
-  download_file("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", fn)
-  return fn
+def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
 
 def get_pairs(word):
   """Return set of symbol pairs in a word.
@@ -576,9 +572,6 @@ class StableDiffusion:
   # ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
   # cond_stage_model.transformer.text_model
 
-# this is sd-v1-4.ckpt
-FILENAME = Path(__file__).parents[1] / "weights/sd-v1-4.ckpt"
-
 if __name__ == "__main__":
   parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
@@ -595,8 +588,7 @@ if __name__ == "__main__":
   model = StableDiffusion()
 
   # load in weights
-  download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
-  load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
+  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
 
   if args.fp16:
     for l in get_state_dict(model).values():
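The second argument here is the name support from the commit title: large artifacts land in the cache under a readable filename instead of an md5 hash, so they are easy to locate on disk and are shared by every call site that passes the same name. A sketch with the checkpoint URL from this hunk (the printed path assumes the Linux cache location from the CI config above):

  from tinygrad.helpers import fetch

  ckpt = fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt')
  print(ckpt)  # e.g. ~/.cache/tinygrad/downloads/sd-v1-4.ckpt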
examples/webgpu/stable_diffusion/compile.py
@@ -4,7 +4,7 @@ from examples.stable_diffusion import StableDiffusion
 from tinygrad.nn.state import get_state_dict, safe_save, safe_load_metadata, torch_load, load_state_dict
 from tinygrad.tensor import Tensor
 from tinygrad.ops import Device
-from extra.utils import download_file
+from tinygrad.helpers import fetch
 from typing import NamedTuple, Any, List
 from pathlib import Path
 import argparse
@@ -28,8 +28,6 @@ def convert_f32_to_f16(input_file, output_file):
     front_float16_values.tofile(f)
     rest_float32_values.tofile(f)
 
-FILENAME = Path(__file__).parent.parent.parent.parent / "weights/sd-v1-4.ckpt"
-
 def split_safetensor(fn):
   _, json_len, metadata = safe_load_metadata(fn)
   text_model_offset = 3772703308
@@ -40,7 +38,7 @@ def split_safetensor(fn):
     if (metadata[k]["data_offsets"][0] < text_model_offset):
       metadata[k]["data_offsets"][0] = int(metadata[k]["data_offsets"][0]/2)
       metadata[k]["data_offsets"][1] = int(metadata[k]["data_offsets"][1]/2)
 
   last_offset = 0
   part_end_offsets = []
@@ -51,7 +49,7 @@ def split_safetensor(fn):
       break
     part_offset = offset - last_offset
     if (part_offset >= chunk_size):
       part_end_offsets.append(8+json_len+offset)
       last_offset = offset
@@ -60,7 +58,7 @@ def split_safetensor(fn):
   net_bytes = bytes(open(fn, 'rb').read())
   part_end_offsets.append(text_model_start+8+json_len)
   cur_pos = 0
   for i, end_pos in enumerate(part_end_offsets):
     with open(f'./net_part{i}.safetensors', "wb+") as f:
       f.write(net_bytes[cur_pos:end_pos])
@@ -68,7 +66,7 @@ def split_safetensor(fn):
   with open(f'./net_textmodel.safetensors', "wb+") as f:
     f.write(net_bytes[text_model_start+8+json_len:])
   return part_end_offsets
 
 if __name__ == "__main__":
@@ -81,8 +79,7 @@ if __name__ == "__main__":
   model = StableDiffusion()
 
   # load in weights
-  download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
-  load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
+  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
 
   class Step(NamedTuple):
     name: str = ""
@@ -90,11 +87,11 @@ if __name__ == "__main__":
     forward: Any = None
 
   sub_steps = [
     Step(name = "textModel", input = [Tensor.randn(1, 77)], forward = model.cond_stage_model.transformer.text_model),
     Step(name = "diffusor", input = [Tensor.randn(1, 77, 768), Tensor.randn(1, 77, 768), Tensor.randn(1,4,64,64), Tensor.rand(1), Tensor.randn(1), Tensor.randn(1), Tensor.randn(1)], forward = model),
     Step(name = "decoder", input = [Tensor.randn(1,4,64,64)], forward = model.decode)
   ]
 
   prg = ""
 
   def compile_step(model, step: Step):
@@ -109,7 +106,7 @@ if __name__ == "__main__":
     gpu_write_bufs = '\n    '.join([f"const gpuWriteBuffer{i} = device.createBuffer({{size:input{i}.size, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.MAP_WRITE }});" for i,(_,value) in enumerate(special_names.items()) if "output" not in value])
     input_writer = '\n    '.join([f"await gpuWriteBuffer{i}.mapAsync(GPUMapMode.WRITE);\n    new Float32Array(gpuWriteBuffer{i}.getMappedRange()).set(" + f'data{i});' + f"\n    gpuWriteBuffer{i}.unmap();\ncommandEncoder.copyBufferToBuffer(gpuWriteBuffer{i}, 0, input{i}, 0, gpuWriteBuffer{i}.size);" for i,(_,value) in enumerate(special_names.items()) if value != "output0"])
     return f"""\n    var {step.name} = function() {{
     {kernel_code}
     return {{
@@ -117,7 +114,7 @@ if __name__ == "__main__":
       const metadata = getTensorMetadata(safetensor[0]);
      {bufs}
      {gpu_write_bufs}
       const gpuReadBuffer = device.createBuffer({{ size: output0.size, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ }});
@@ -140,7 +137,7 @@ if __name__ == "__main__":
       gpuReadBuffer.unmap();
       return resultBuffer;
     }}
   }}
 """
extra/datasets/__init__.py
@@ -1,16 +1,15 @@
 import os, gzip, tarfile, pickle
 import numpy as np
 from pathlib import Path
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes
+from tinygrad.helpers import dtypes, fetch
 
 def fetch_mnist(tensors=False):
   parse = lambda file: np.frombuffer(gzip.open(file).read(), dtype=np.uint8).copy()
-  dirname = Path(__file__).parent.resolve()
-  X_train = parse(dirname / "mnist/train-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
-  Y_train = parse(dirname / "mnist/train-labels-idx1-ubyte.gz")[8:]
-  X_test = parse(dirname / "mnist/t10k-images-idx3-ubyte.gz")[0x10:].reshape((-1, 28*28)).astype(np.float32)
-  Y_test = parse(dirname / "mnist/t10k-labels-idx1-ubyte.gz")[8:]
+  BASE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"   # http://yann.lecun.com/exdb/mnist/ lacks https
+  X_train = parse(fetch(f"{BASE_URL}train-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
+  Y_train = parse(fetch(f"{BASE_URL}train-labels-idx1-ubyte.gz"))[8:]
+  X_test = parse(fetch(f"{BASE_URL}t10k-images-idx3-ubyte.gz"))[0x10:].reshape((-1, 28*28)).astype(np.float32)
+  Y_test = parse(fetch(f"{BASE_URL}t10k-labels-idx1-ubyte.gz"))[8:]
   if tensors: return Tensor(X_train).reshape(-1, 1, 28, 28), Tensor(Y_train), Tensor(X_test).reshape(-1, 1, 28, 28), Tensor(Y_test)
   else: return X_train, Y_train, X_test, Y_test
4 binary files not shown (the bundled MNIST .gz archives under extra/datasets/mnist/ are deleted by this commit).
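With the bundled archives gone, fetch_mnist downloads the four files on first use and serves them from the cache afterwards; usage is unchanged. A quick sketch (import path assumes the file above is extra/datasets/__init__.py, as inferred from its contents):

  from extra.datasets import fetch_mnist

  X_train, Y_train, X_test, Y_test = fetch_mnist()
  print(X_train.shape, Y_train.shape)  # (60000, 784) (60000,)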
extra/models/efficientnet.py
@@ -1,7 +1,8 @@
 import math
 from tinygrad.tensor import Tensor
 from tinygrad.nn import BatchNorm2d
-from extra.utils import get_child
+from tinygrad.helpers import get_child, fetch
+from tinygrad.nn.state import torch_load
 
 class MBConvBlock:
   def __init__(self, kernel_size, strides, expand_ratio, input_filters, output_filters, se_ratio, has_se, track_running_stats=True):
@@ -142,9 +143,7 @@ class EfficientNet:
       7: "https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth"
     }
 
-    from extra.utils import fetch_as_file
-    from tinygrad.nn.state import torch_load
-    b0 = torch_load(fetch_as_file(model_urls[self.number]))
+    b0 = torch_load(fetch(model_urls[self.number]))
     for k,v in b0.items():
       if k.endswith("num_batches_tracked"): continue
       for cat in ['_conv_head', '_conv_stem', '_depthwise_conv', '_expand_conv', '_fc', '_project_conv', '_se_reduce', '_se_expand']:
extra/utils.py
@@ -6,7 +6,7 @@ from pathlib import Path
 from collections import defaultdict
 from typing import Union
 
-from tinygrad.helpers import prod, getenv, DEBUG, dtypes
+from tinygrad.helpers import prod, getenv, DEBUG, dtypes, get_child
 from tinygrad.helpers import GlobalCounters
 from tinygrad.tensor import Tensor
 from tinygrad.lazy import LazyBuffer
@@ -47,13 +47,4 @@ def download_file(url, fp, skip_if_exists=True):
       f.close()
       Path(f.name).rename(fp)
 
-def get_child(parent, key):
-  obj = parent
-  for k in key.split('.'):
-    if k.isnumeric():
-      obj = obj[int(k)]
-    elif isinstance(obj, dict):
-      obj = obj[k]
-    else:
-      obj = getattr(obj, k)
-  return obj
 
setup.py (4 changes)
@@ -19,7 +19,7 @@ setup(name='tinygrad',
         "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License"
       ],
-      install_requires=["numpy", "requests", "tqdm", "networkx", "pyopencl",
+      install_requires=["numpy", "tqdm", "pyopencl",
                         "pyobjc-framework-Metal; platform_system=='Darwin'",
                         "pyobjc-framework-Cocoa; platform_system=='Darwin'",
                         "pyobjc-framework-libdispatch; platform_system=='Darwin'"],
@@ -55,6 +55,8 @@ setup(name='tinygrad',
           "sentencepiece",
           "tiktoken",
           "librosa",
+          "requests",
+          "networkx",
         ]
       },
       include_package_data=True)
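With requests and networkx demoted to the [testing] extra, any remaining in-tree use has to import them lazily so a bare install works without them. A minimal sketch of that pattern (the helper name and body are illustrative, not from this diff):

  def count_graph_nodes(edges):
    import networkx as nx  # optional dependency: imported only when this helper runs
    g = nx.DiGraph()
    g.add_edges_from(edges)
    return g.number_of_nodes()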
test/unit/test_helpers.py
@@ -150,13 +150,13 @@ class TestRoundUp(unittest.TestCase):
 
 class TestFetch(unittest.TestCase):
   def test_fetch_bad_http(self):
-    self.assertRaises(AssertionError, fetch, 'http://www.google.com/404')
+    self.assertRaises(Exception, fetch, 'http://www.google.com/404')
 
   def test_fetch_small(self):
-    assert(len(fetch('https://google.com').read_bytes())>0)
+    assert(len(fetch('https://google.com', allow_caching=False).read_bytes())>0)
 
   def test_fetch_img(self):
-    img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190")
+    img = fetch("https://media.istockphoto.com/photos/hen-picture-id831791190", allow_caching=False)
     with Image.open(img) as pimg:
       assert pimg.size == (705, 1024)
tinygrad/helpers.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
-import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, requests, tempfile, pathlib
+import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib
 import numpy as np
+from urllib import request
 from tqdm import tqdm
 from typing import Dict, Tuple, Union, List, NamedTuple, Final, ClassVar, Optional, Iterable, Any, TypeVar, TYPE_CHECKING, Callable
 if TYPE_CHECKING: # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
@@ -39,6 +40,12 @@ def partition(lst:List[T], fxn:Callable[[T],bool]):
 def unwrap(x:Optional[T]) -> T:
   assert x is not None
   return x
+def get_child(obj, key):
+  for k in key.split('.'):
+    if k.isnumeric(): obj = obj[int(k)]
+    elif isinstance(obj, dict): obj = obj[k]
+    else: obj = getattr(obj, k)
+  return obj
 
 @functools.lru_cache(maxsize=None)
 def getenv(key:str, default=0): return type(default)(os.getenv(key, default))
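get_child, now promoted into helpers, walks a dotted key through attribute access, dict lookup, and integer indexing, which is how torch-style state-dict keys like "layers.0.weight" are resolved against a module tree. A self-contained sketch (the toy classes are illustrative):

  from tinygrad.helpers import get_child

  class Block:
    def __init__(self): self.weight = 3
  class Net:
    def __init__(self): self.layers = [Block(), Block()]

  assert get_child(Net(), "layers.1.weight") == 3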
@@ -225,15 +232,17 @@ def diskcache(func):
 
 # *** http support ***
 
-def fetch(url:str) -> pathlib.Path:
-  fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / hashlib.md5(url.encode('utf-8')).hexdigest()
-  if not fp.is_file():
-    r = requests.get(url, stream=True, timeout=10)
-    assert r.status_code == 200
-    progress_bar = tqdm(total=int(r.headers.get('content-length', 0)), unit='B', unit_scale=True, desc=url)
-    (path := fp.parent).mkdir(parents=True, exist_ok=True)
-    with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
-      for chunk in r.iter_content(chunk_size=16384): progress_bar.update(f.write(chunk))
-      f.close()
-      pathlib.Path(f.name).rename(fp)
+def fetch(url:str, name:Optional[str]=None, allow_caching=not getenv("DISABLE_HTTP_CACHE")) -> pathlib.Path:
+  fp = pathlib.Path(_cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest())
+  if not fp.is_file() or not allow_caching:
+    with request.urlopen(url, timeout=10) as r:
+      assert r.status == 200
+      total_length = int(r.headers.get('content-length', 0))
+      progress_bar = tqdm(total=total_length, unit='B', unit_scale=True, desc=url)
+      (path := fp.parent).mkdir(parents=True, exist_ok=True)
+      with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
+        while chunk := r.read(16384): progress_bar.update(f.write(chunk))
+        f.close()
+        if (file_size:=os.stat(f.name).st_size) < total_length: raise RuntimeError(f"fetch size incomplete, {file_size} < {total_length}")
+        pathlib.Path(f.name).rename(fp)
   return fp
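The cache location is computed the same way on every call, so the on-disk layout can be reasoned about externally. A sketch of the naming rule (cache_dir stands in for helpers' _cache_dir, e.g. ~/.cache on Linux or ~/Library/Caches on macOS per the CI paths above):

  import hashlib, pathlib

  def cache_path(cache_dir: str, url: str, name=None) -> pathlib.Path:
    # a name wins; otherwise the md5 hex digest of the URL is the filename
    return pathlib.Path(cache_dir) / "tinygrad" / "downloads" / (name if name else hashlib.md5(url.encode('utf-8')).hexdigest())

Setting DISABLE_HTTP_CACHE=1 flips the allow_caching default so every call re-downloads; the tests above get the same per-call effect with allow_caching=False.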