universal disk cache (#2130)

* caching infra for tinygrad

* non-str key

* fix linter

* no shelve in beam search

* beam search caching

* check tensor cores with beam too

* pretty print

* LATEBEAM in stable diffusion
Author: George Hotz
Date: 2023-10-22 10:56:57 -07:00
Committed by: GitHub
Parent: ace6b2a151
Commit: 6dc8eb5bfd
8 changed files with 125 additions and 29 deletions
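
The commit message above describes the new pieces: a universal disk cache for tinygrad, support for non-string cache keys, and beam-search results cached through it instead of Python's shelve module. The helpers-side diff is not part of this excerpt, so the following is only a rough sketch of the idea under stated assumptions, not this commit's code; the names cache_get, cache_put, diskcache, _CACHE_DB and the sqlite schema are illustrative. It shows how an sqlite-backed cache can take arbitrary picklable keys and memoize a function to disk:

import functools, os, pickle, sqlite3, tempfile

# sketch only: an sqlite-backed disk cache whose keys are pickled,
# so dicts and tuples work as keys just as well as strings
_CACHE_DB = os.path.join(tempfile.gettempdir(), "tinygrad_cache_sketch.db")

def _conn():
  conn = sqlite3.connect(_CACHE_DB)
  conn.execute("CREATE TABLE IF NOT EXISTS cache (key BLOB PRIMARY KEY, val BLOB)")
  return conn

def cache_get(key):
  row = _conn().execute("SELECT val FROM cache WHERE key=?", (pickle.dumps(key),)).fetchone()
  return pickle.loads(row[0]) if row is not None else None

def cache_put(key, val):
  conn = _conn()
  conn.execute("REPLACE INTO cache (key, val) VALUES (?, ?)", (pickle.dumps(key), pickle.dumps(val)))
  conn.commit()
  return val

def diskcache(func):
  # decorator: memoize func's return value to disk, keyed on its arguments
  # (None results are not cached in this sketch)
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    key = (func.__name__, args, tuple(sorted(kwargs.items())))
    if (hit := cache_get(key)) is not None: return hit
    return cache_put(key, func(*args, **kwargs))
  return wrapper

Keyed this way, an expensive search such as a beam search over kernel optimizations (including the tensor-core checks mentioned above) runs once per unique input and is answered from disk on every later run.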

examples/stable_diffusion.py

@@ -9,7 +9,7 @@ from collections import namedtuple
 from tqdm import tqdm
 from tinygrad.tensor import Tensor
 from tinygrad.ops import Device
-from tinygrad.helpers import dtypes, GlobalCounters, Timing
+from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
 from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
 from extra.utils import download_file
 from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
@@ -636,14 +636,15 @@ if __name__ == "__main__":
   if args.seed is not None: Tensor._seed = args.seed
   latent = Tensor.randn(1,4,64,64)
-  # this is diffusion
-  for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
-    GlobalCounters.reset()
-    t.set_description("%3d %3d" % (index, timestep))
-    with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
-      latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
-      if args.timing: Device[Device.DEFAULT].synchronize()
-  del do_step
+  with Context(BEAM=getenv("LATEBEAM")):
+    # this is diffusion
+    for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
+      GlobalCounters.reset()
+      t.set_description("%3d %3d" % (index, timestep))
+      with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
+        latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
+        if args.timing: Device[Device.DEFAULT].synchronize()
+    del do_step
   # upsample latent space to image with autoencoder
   x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
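
The hunk above changes scope rather than logic: wrapping only the denoising loop in Context(BEAM=getenv("LATEBEAM")) turns beam search on "late", for the repeated diffusion steps, without forcing BEAM on the rest of the script. As a rough sketch of that mechanism only, assuming a ContextVar-style registry in the spirit of tinygrad.helpers (the class bodies below are illustrative, not this commit's code), a scoped override saves the old values on entry and restores them on exit:

import os
from typing import Dict

class ContextVar:
  # a named integer knob (BEAM, DEBUG, ...) seeded from the environment, kept in a global registry
  _cache: Dict[str, "ContextVar"] = {}
  def __init__(self, key:str, default:int):
    self.key, self.value = key, int(os.getenv(key, default))
    ContextVar._cache[key] = self

class Context:
  # temporarily override ContextVars inside a with-block, restoring the previous values on exit
  def __init__(self, **kwargs): self.kwargs, self.saved = kwargs, {}
  def __enter__(self):
    for k, v in self.kwargs.items():
      self.saved[k], ContextVar._cache[k].value = ContextVar._cache[k].value, v
  def __exit__(self, *args):
    for k, v in self.saved.items(): ContextVar._cache[k].value = v

BEAM = ContextVar("BEAM", 0)

# usage mirroring the diff: BEAM is only raised while the block runs
print(BEAM.value)                                # whatever BEAM= is in the environment (default 0)
with Context(BEAM=int(os.getenv("LATEBEAM", 0))):
  print(BEAM.value)                              # the LATEBEAM value; code that checks BEAM here sees it
print(BEAM.value)                                # restored to the original value

Because the override is undone when the block exits, the autoencoder decode and image saving that follow the loop run with the original BEAM setting.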