mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-10 07:28:15 -05:00
add name support to fetch (#2407)
* add name support
* use fetch in gpt2
* remove requests from main lib, networkx also optional
* umm, keep that assert
* updates to fetch
* i love the walrus so much
* stop bundling mnist with tinygrad
* err, https
* download cache names
* add DOWNLOAD_CACHE_VERSION
* need env.
* ugh, wrong path
* replace get_child
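For context, a minimal sketch of the new name argument to fetch, mirroring the calls in this diff (the exact cache location is an assumption; fetch downloads the URL into tinygrad's download cache and returns the path to the cached file):

# sketch: fetch with an explicit name keeps a readable cached filename
# instead of a URL-derived one (cache directory details are an assumption)
from tinygrad.helpers import fetch

bpe_path = fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz",
                 "bpe_simple_vocab_16e6.txt.gz")
print(bpe_path)  # path to the downloaded file inside the download cache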
@@ -9,9 +9,8 @@ from collections import namedtuple
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.ops import Device
-from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
+from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv, fetch
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
-from extra.utils import download_file
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
 from tinygrad.jit import TinyJit

@@ -405,10 +404,7 @@ class CLIPTextTransformer:

 # Clip tokenizer, taken from https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py (MIT license)
 @lru_cache()
-def default_bpe():
-  fn = Path(__file__).parents[1] / "weights/bpe_simple_vocab_16e6.txt.gz"
-  download_file("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", fn)
-  return fn
+def default_bpe(): return fetch("https://github.com/openai/CLIP/raw/main/clip/bpe_simple_vocab_16e6.txt.gz", "bpe_simple_vocab_16e6.txt.gz")

 def get_pairs(word):
   """Return set of symbol pairs in a word.
@@ -576,9 +572,6 @@ class StableDiffusion:
 # ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
 # cond_stage_model.transformer.text_model

-# this is sd-v1-4.ckpt
-FILENAME = Path(__file__).parents[1] / "weights/sd-v1-4.ckpt"
-
 if __name__ == "__main__":
   parser = argparse.ArgumentParser(description='Run Stable Diffusion', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
   parser.add_argument('--steps', type=int, default=5, help="Number of steps in diffusion")
@@ -595,8 +588,7 @@ if __name__ == "__main__":
   model = StableDiffusion()

   # load in weights
-  download_file('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', FILENAME)
-  load_state_dict(model, torch_load(FILENAME)['state_dict'], strict=False)
+  load_state_dict(model, torch_load(fetch('https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', 'sd-v1-4.ckpt'))['state_dict'], strict=False)

   if args.fp16:
     for l in get_state_dict(model).values():