mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
universal disk cache (#2130)
* caching infra for tinygrad * non-str key * fix linter * no shelve in beam search * beam search caching * check tensor cores with beam too * pretty print * LATEBEAM in stable diffusion
This commit is contained in:
@@ -9,7 +9,7 @@ from collections import namedtuple
|
||||
from tqdm import tqdm
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.ops import Device
|
||||
from tinygrad.helpers import dtypes, GlobalCounters, Timing
|
||||
from tinygrad.helpers import dtypes, GlobalCounters, Timing, Context, getenv
|
||||
from tinygrad.nn import Conv2d, Linear, GroupNorm, LayerNorm, Embedding
|
||||
from extra.utils import download_file
|
||||
from tinygrad.nn.state import torch_load, load_state_dict, get_state_dict
|
||||
@@ -636,14 +636,15 @@ if __name__ == "__main__":
|
||||
if args.seed is not None: Tensor._seed = args.seed
|
||||
latent = Tensor.randn(1,4,64,64)
|
||||
|
||||
# this is diffusion
|
||||
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
|
||||
GlobalCounters.reset()
|
||||
t.set_description("%3d %3d" % (index, timestep))
|
||||
with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
|
||||
latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
|
||||
if args.timing: Device[Device.DEFAULT].synchronize()
|
||||
del do_step
|
||||
with Context(BEAM=getenv("LATEBEAM")):
|
||||
# this is diffusion
|
||||
for index, timestep in (t:=tqdm(list(enumerate(timesteps))[::-1])):
|
||||
GlobalCounters.reset()
|
||||
t.set_description("%3d %3d" % (index, timestep))
|
||||
with Timing("step in ", enabled=args.timing, on_exit=lambda _: f", using {GlobalCounters.mem_used/1e9:.2f} GB"):
|
||||
latent = do_step(latent, Tensor([timestep]), Tensor([index]), Tensor([args.guidance]))
|
||||
if args.timing: Device[Device.DEFAULT].synchronize()
|
||||
del do_step
|
||||
|
||||
# upsample latent space to image with autoencoder
|
||||
x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
|
||||
|
||||
Reference in New Issue
Block a user