add name support to fetch (#2407)

* add name support

* use fetch in gpt2

* remove requests from main lib, networkx also optional

* umm, keep that assert

* updates to fetch

* i love the walrus so much

* stop bundling mnist with tinygrad

* err, https

* download cache names

* add DOWNLOAD_CACHE_VERSION

* need env.

* ugh, wrong path

* replace get_child
This commit is contained in:
George Hotz
2023-11-23 14:16:17 -08:00
committed by GitHub
parent 397c093656
commit 095e2ced61
16 changed files with 73 additions and 79 deletions

View File

@@ -9,9 +9,8 @@ from collections import namedtuple
from tqdm import tqdm
from tinygrad.tensor import Tensor
from tinygrad.ops import Device
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
from extra.utils import download_file
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
from tinygrad.jit import TinyJit
@@ -405,10 +404,7 @@ class CLIPTextTransformer:
# Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
@lru_cache()
def default_bpe():
fn = Path(__file__).parents[1] / "weights/bpe_simple_vocab_16e6.txt.gz"
download_file("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", fn)
return fn
def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")
def get_pairs(word):
"""Return set of symbol pairs in a word.
@@ -576,9 +572,6 @@ class StableDiffusion:
# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
# cond_stage_model.transformer.text_model
# this is sd-v1-4.ckpt
FILENAME = Path(__file__).parents[1] / "weights/sd-v1-4.ckpt"
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
@@ -595,8 +588,7 @@ if __name__ == "__main__":
model = StableDiffusion()
# load in weights
download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)
if args.fp16:
for l in get_state_dict(model).values():