# https://arxiv.org/pdf/2112.10752.pdf
# https://github.com/ekagra-ranjan/huggingface-blog/blob/main/stable_diffusion.md
import os
import math
import numpy as np
import traceback
from tqdm import tqdm
from collections import namedtuple
from extra.utils import fake_torch_load_zipped, get_child
from tinygrad.nn import Conv2d
from tinygrad.tensor import Tensor
# TODO: rename to GroupNorm and put in nn.py
class Normalize:
def __init__(self, in_channels, num_groups=32):
self.weight = Tensor.empty(in_channels)
self.bias = Tensor.empty(in_channels)
self.num_groups = num_groups
def __call__(self, x):
# reshape for layernorm to work as group norm
# subtract mean and divide stddev
if self.num_groups is None: # just layernorm
x = x.layernorm()
else:
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm().reshape(x.shape)
# elementwise_affine on channels
if len(x.shape) == 4:
# HACK for channels in conv
return (x * self.weight.reshape(1, -1, 1, 1)) + self.bias.reshape(1, -1, 1, 1)
else:
return x.linear(self.weight, self.bias)
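# A quick cross-check sketch, assuming torch is installed (not used by this script):
#   import torch
#   ref = torch.nn.GroupNorm(32, C, affine=False)
#   # ref(torch.tensor(x.numpy())) should match the reshape+layernorm path above
#   # (up to eps), before the per-channel weight/bias are applied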
class AttnBlock:
def __init__(self, in_channels):
self.norm = Normalize(in_channels)
self.q = Conv2d(in_channels, in_channels, 1)
self.k = Conv2d(in_channels, in_channels, 1)
self.v = Conv2d(in_channels, in_channels, 1)
self.proj_out = Conv2d(in_channels, in_channels, 1)
# copied from AttnBlock in ldm repo
def __call__(self, x):
h_ = self.norm(x)
q,k,v = self.q(h_), self.k(h_), self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = q @ k
w_ = w_ * (c**(-0.5))
w_ = w_.softmax()
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1)
h_ = v @ w_
h_ = h_.reshape(b,c,h,w)
return x + self.proj_out(h_)
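# i.e. single-head self-attention over the h*w flattened pixel positions:
#   out = x + proj_out(softmax(q @ k.T / sqrt(c)) @ v)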
class ResnetBlock:
def __init__(self, in_channels, out_channels=None):
self.norm1 = Normalize(in_channels)
self.conv1 = Conv2d(in_channels, out_channels, 3, padding=1)
self.norm2 = Normalize(out_channels)
self.conv2 = Conv2d(out_channels, out_channels, 3, padding=1)
self.nin_shortcut = Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else lambda x: x
def __call__(self, x):
h = self.conv1(self.norm1(x).swish())
h = self.conv2(self.norm2(h).swish())
return self.nin_shortcut(x) + h
class Mid:
def __init__(self, block_in):
self.block_1 = ResnetBlock(block_in, block_in)
self.attn_1 = AttnBlock(block_in)
self.block_2 = ResnetBlock(block_in, block_in)
def __call__(self, x):
return x.sequential([self.block_1, self.attn_1, self.block_2])
class Decoder:
def __init__(self):
sz = [(128, 256), (256, 512), (512, 512), (512, 512)]
self.conv_in = Conv2d(4,512,3, padding=1)
self.mid = Mid(512)
arr = []
for i,s in enumerate(sz):
arr.append({"block":
[ResnetBlock(s[1], s[0]),
ResnetBlock(s[0], s[0]),
ResnetBlock(s[0], s[0])]})
if i != 0: arr[-1]['upsample'] = {"conv": Conv2d(s[0], s[0], 3, padding=1)}
self.up = arr
self.norm_out = Normalize(128)
self.conv_out = Conv2d(128, 3, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
x = self.mid(x)
for l in self.up[::-1]:
print("decode", x.shape)
for b in l['block']: x = b(x)
if 'upsample' in l:
# 2x nearest-neighbor upsample, like torch.nn.functional.interpolate(x, scale_factor=2)
bs,c,py,px = x.shape
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
x = l['upsample']['conv'](x)
return self.conv_out(self.norm_out(x).swish())
class Encoder:
def __init__(self):
sz = [(128, 128), (128, 256), (256, 512), (512, 512)]
self.conv_in = Conv2d(3,128,3, padding=1)
arr = []
for i,s in enumerate(sz):
arr.append({"block":
[ResnetBlock(s[0], s[1]),
ResnetBlock(s[1], s[1])]})
if i != 3: arr[-1]['downsample'] = {"conv": Conv2d(s[1], s[1], 3, stride=2, padding=(0,1,0,1))}
self.down = arr
self.mid = Mid(512)
self.norm_out = Normalize(512)
self.conv_out = Conv2d(512, 8, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
for l in self.down:
print("encode", x.shape)
for b in l['block']: x = b(x)
if 'downsample' in l: x = l['downsample']['conv'](x)
x = self.mid(x)
return self.conv_out(self.norm_out(x).swish())
class AutoencoderKL:
def __init__(self):
self.encoder = Encoder()
self.decoder = Decoder()
self.quant_conv = Conv2d(8, 8, 1)
self.post_quant_conv = Conv2d(4, 4, 1)
def __call__(self, x):
latent = self.encoder(x)
latent = self.quant_conv(latent)
latent = latent[:, 0:4] # encoder outputs 8 channels (4 means + 4 logvars); keep only the means
print("latent", latent.shape)
latent = self.post_quant_conv(latent)
return self.decoder(latent)
class Linear:
def __init__(self, in_features, out_features, bias=True):
self.weight = Tensor.empty(out_features, in_features)
self.bias = Tensor.empty(out_features) if bias else None
def __call__(self, x):
return x.linear(self.weight.transpose(), self.bias)
# not to be confused with ResnetBlock
class ResBlock:
def __init__(self, channels, emb_channels, out_channels):
self.in_layers = [
Normalize(channels),
Tensor.silu,
Conv2d(channels, out_channels, 3, padding=1)
]
self.emb_layers = [
Tensor.silu,
Linear(emb_channels, out_channels)
]
self.out_layers = [
Normalize(out_channels),
Tensor.silu,
lambda x: x,
Conv2d(out_channels, out_channels, 3, padding=1)
]
self.skip_connection = Conv2d(channels, out_channels, 1) if channels != out_channels else lambda x: x
def __call__(self, x, emb):
h = x.sequential(self.in_layers)
emb_out = emb.sequential(self.emb_layers)
h = h + emb_out
h = h.sequential(self.out_layers)
return self.skip_connection(x) + h
class CrossAttention:
def __init__(self, query_dim, context_dim, n_heads, d_head):
self.to_q = Linear(query_dim, n_heads*d_head, bias=False)
self.to_k = Linear(context_dim, n_heads*d_head, bias=False)
self.to_v = Linear(context_dim, n_heads*d_head, bias=False)
self.scale = d_head ** -0.5
self.num_heads = n_heads
self.head_size = d_head
self.to_out = [Linear(n_heads*d_head, query_dim)]
# multi-head scaled dot-product attention (self-attention when context is None)
def __call__(self, x, context=None):
context = x if context is None else context
q,k,v = self.to_q(x), self.to_k(context), self.to_v(context)
q = q.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
k = k.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,3,1) # (bs, num_heads, head_size, time)
v = v.reshape(x.shape[0], -1, self.num_heads, self.head_size).permute(0,2,1,3) # (bs, num_heads, time, head_size)
score = q.dot(k) * self.scale
#print("score", score.shape, score.numpy())
#exit(0)
weights = score.softmax() # (bs, num_heads, time, time)
attention = weights.dot(v).permute(0,2,1,3) # (bs, time, num_heads, head_size)
h_ = attention.reshape(shape=(x.shape[0], -1, self.num_heads * self.head_size))
return h_.sequential(self.to_out)
class GEGLU:
def __init__(self, dim_in, dim_out):
self.proj = Linear(dim_in, dim_out * 2)
self.dim_out = dim_out
def __call__(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * gate.gelu()
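# GEGLU from "GLU Variants Improve Transformer" (Shazeer 2020):
#   GEGLU(x) = (x @ W) * gelu(x @ V), with W and V fused into the single proj above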
class FeedForward:
def __init__(self, dim, mult=4):
self.net = [
GEGLU(dim, dim*mult),
lambda x: x,
Linear(dim*mult, dim)
]
def __call__(self, x):
return x.sequential(self.net)
class BasicTransformerBlock:
def __init__(self, dim, context_dim, n_heads, d_head):
self.attn1 = CrossAttention(dim, dim, n_heads, d_head)
self.ff = FeedForward(dim)
self.attn2 = CrossAttention(dim, context_dim, n_heads, d_head)
self.norm1 = Normalize(dim, num_groups=None)
self.norm2 = Normalize(dim, num_groups=None)
self.norm3 = Normalize(dim, num_groups=None)
def __call__(self, x, context=None):
x = self.attn1(self.norm1(x)) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer:
def __init__(self, channels, context_dim, n_heads, d_head):
self.norm = Normalize(channels)
assert channels == n_heads * d_head
self.proj_in = Conv2d(channels, n_heads * d_head, 1)
self.transformer_blocks = [BasicTransformerBlock(channels, context_dim, n_heads, d_head)]
self.proj_out = Conv2d(n_heads * d_head, channels, 1)
def __call__(self, x, context=None):
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = x.reshape(b, c, h*w).permute(0,2,1)
for block in self.transformer_blocks:
x = block(x, context=context)
x = x.permute(0,2,1).reshape(b, c, h, w)
ret = self.proj_out(x) + x_in
return ret
class Downsample:
def __init__(self, channels):
self.op = Conv2d(channels, channels, 3, stride=2, padding=1)
def __call__(self, x):
return self.op(x)
class Upsample:
def __init__(self, channels):
self.conv = Conv2d(channels, channels, 3, padding=1)
def __call__(self, x):
bs,c,py,px = x.shape
x = x.reshape(bs, c, py, 1, px, 1).expand(bs, c, py, 2, px, 2).reshape(bs, c, py*2, px*2)
return self.conv(x)
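# the reshape/expand/reshape above repeats every pixel into a 2x2 block, i.e.
# nearest-neighbor 2x upsampling. A minimal numpy sanity check (illustration only):
#   a = np.arange(4).reshape(1,1,2,2)
#   b = np.broadcast_to(a.reshape(1,1,2,1,2,1), (1,1,2,2,2,2)).reshape(1,1,4,4)
#   # b[0,0] is [[0,0,1,1],[0,0,1,1],[2,2,3,3],[2,2,3,3]]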
def timestep_embedding(timesteps, dim, max_period=10000):
half = dim // 2
freqs = np.exp(-math.log(max_period) * np.arange(0, half, dtype=np.float32) / half)
args = timesteps.numpy() * freqs
embedding = np.concatenate([np.cos(args), np.sin(args)])
return Tensor(embedding).reshape(1, -1)
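# standard sinusoidal timestep embedding (as in the transformer paper / OpenAI
# guided-diffusion): freqs[i] = max_period**(-i/half), and the cos half of the
# vector is concatenated with the sin half: [cos(t*freqs), sin(t*freqs)]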
class UNetModel:
def __init__(self):
self.time_embed = [
Linear(320, 1280),
Tensor.silu,
Linear(1280, 1280),
]
self.input_blocks = [
[Conv2d(4, 320, kernel_size=3, padding=1)],
# TODO: my head sizes and counts are a guess
[ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
[ResBlock(320, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
[Downsample(320)],
[ResBlock(320, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
[ResBlock(640, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
[Downsample(640)],
[ResBlock(640, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
[ResBlock(1280, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
[Downsample(1280)],
[ResBlock(1280, 1280, 1280)],
[ResBlock(1280, 1280, 1280)]
]
self.middle_block = [
ResBlock(1280, 1280, 1280),
SpatialTransformer(1280, 768, 8, 160),
ResBlock(1280, 1280, 1280)
]
self.output_blocks = [
[ResBlock(2560, 1280, 1280)],
[ResBlock(2560, 1280, 1280)],
[ResBlock(2560, 1280, 1280), Upsample(1280)],
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
[ResBlock(2560, 1280, 1280), SpatialTransformer(1280, 768, 8, 160)],
[ResBlock(1920, 1280, 1280), SpatialTransformer(1280, 768, 8, 160), Upsample(1280)],
[ResBlock(1920, 1280, 640), SpatialTransformer(640, 768, 8, 80)], # 6
[ResBlock(1280, 1280, 640), SpatialTransformer(640, 768, 8, 80)],
[ResBlock(960, 1280, 640), SpatialTransformer(640, 768, 8, 80), Upsample(640)],
[ResBlock(960, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
[ResBlock(640, 1280, 320), SpatialTransformer(320, 768, 8, 40)],
]
self.out = [
Normalize(320),
Tensor.silu,
Conv2d(320, 4, kernel_size=3, padding=1)
]
def __call__(self, x, timesteps=None, context=None):
# TODO: real time embedding
t_emb = timestep_embedding(timesteps, 320)
emb = t_emb.sequential(self.time_embed)
def run(x, bb):
if isinstance(bb, ResBlock): x = bb(x, emb)
elif isinstance(bb, SpatialTransformer): x = bb(x, context)
else: x = bb(x)
return x
saved_inputs = []
for i,b in enumerate(self.input_blocks):
print("input block", i)
for bb in b:
x = run(x, bb)
saved_inputs.append(x)
for bb in self.middle_block:
x = run(x, bb)
for i,b in enumerate(self.output_blocks):
print("output block", i)
x = x.cat(saved_inputs.pop(), dim=1)
for bb in b:
x = run(x, bb)
return x.sequential(self.out)
class CLIPMLP:
def __init__(self):
self.fc1 = Linear(768, 3072)
self.fc2 = Linear(3072, 768)
def __call__(self, hidden_states):
hidden_states = self.fc1(hidden_states)
hidden_states = hidden_states.quick_gelu()
hidden_states = self.fc2(hidden_states)
return hidden_states
class CLIPAttention:
def __init__(self):
self.embed_dim = 768
self.num_heads = 12
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.k_proj = Linear(self.embed_dim, self.embed_dim)
self.v_proj = Linear(self.embed_dim, self.embed_dim)
self.q_proj = Linear(self.embed_dim, self.embed_dim)
self.out_proj = Linear(self.embed_dim, self.embed_dim)
def _shape(self, tensor, seq_len: int, bsz: int):
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)
def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
#print("ATTN", query_states.numpy())
#print(hidden_states.shape, query_states.shape, key_states.shape, value_states.shape)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
key_states = key_states.reshape(*proj_shape)
src_len = key_states.shape[1]
value_states = value_states.reshape(*proj_shape)
attn_weights = query_states @ key_states.permute(0,2,1)
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.softmax()
attn_output = attn_weights @ value_states
attn_output = attn_output.reshape(bsz, self.num_heads, tgt_len, self.head_dim)
attn_output = attn_output.permute(0,2,1,3)
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim)
attn_output = self.out_proj(attn_output)
return attn_output
class CLIPEncoderLayer:
def __init__(self):
self.self_attn = CLIPAttention()
self.layer_norm1 = Normalize(768, num_groups=None)
self.mlp = CLIPMLP()
self.layer_norm2 = Normalize(768, num_groups=None)
def __call__(self, hidden_states, causal_attention_mask):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
return hidden_states
class CLIPEncoder:
def __init__(self):
self.layers = [CLIPEncoderLayer() for i in range(12)]
def __call__(self, hidden_states, causal_attention_mask):
for i,l in enumerate(self.layers):
hidden_states = l(hidden_states, causal_attention_mask)
return hidden_states
class CLIPTextEmbeddings:
def __init__(self):
self.position_ids = Tensor.empty(1, 77) # arange(77) buffer in the HF checkpoint; only exists so the weight loader can fill it
self.token_embedding = {"weight": Tensor.empty(49408, 768)}
self.position_embedding = {"weight": Tensor.empty(77, 768)}
def __call__(self, input_ids, position_ids):
# TODO: actually support batches
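# embedding lookup written as a matmul: a one-hot row times the weight matrix
# selects that row, so onehot(ids) @ W == W[ids] (an nn.Embedding-style gather)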
inputs = np.zeros((1, len(input_ids), 49408))
positions = np.zeros((1, len(position_ids), 77))
for i,x in enumerate(input_ids): inputs[0][i][x] = 1
for i,x in enumerate(position_ids): positions[0][i][x] = 1
inputs_embeds = Tensor(inputs, device=self.token_embedding['weight'].device) @ self.token_embedding['weight']
position_embeddings = Tensor(positions, device=self.position_embedding['weight'].device) @ self.position_embedding['weight']
return inputs_embeds + position_embeddings
class CLIPTextTransformer:
def __init__(self):
self.embeddings = CLIPTextEmbeddings()
self.encoder = CLIPEncoder()
self.final_layer_norm = Normalize(768, num_groups=None)
def __call__(self, input_ids):
x = self.embeddings(input_ids, list(range(len(input_ids))))
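# strictly upper-triangular -inf mask, so token i attends only to tokens <= i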
causal_attention_mask = np.triu(np.ones((1,1,77,77), dtype=np.float32) * -np.inf, k=1)
x = self.encoder(x, Tensor(causal_attention_mask, device=x.device))
return self.final_layer_norm(x)
class StableDiffusion:
def __init__(self):
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
self.first_stage_model = AutoencoderKL()
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
#def __call__(self, x, timesteps, context):
#return self.model.diffusion_model(x, timesteps, context)
#return self.first_stage_model(x)
# ** ldm.models.autoencoder.AutoencoderKL (done!)
# 3x512x512 <--> 4x64x64 (16384)
# decode torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
# section 4.3 of paper
# first_stage_model.encoder, first_stage_model.decoder
# ** ldm.modules.diffusionmodules.openaimodel.UNetModel
# the denoising UNet; this is what runs once per sampling step
# input: 4x64x64
# output: 4x64x64
# model.diffusion_model
# contains self- and cross-attention (the SpatialTransformer blocks)
# ** ldm.modules.encoders.modules.FrozenCLIPEmbedder
# cond_stage_model.transformer.text_model
# this is sd-v1-4.ckpt
#FILENAME = "/Users/kafka/fun/mps/stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt"
FILENAME = "/home/kafka/model.ckpt"
REAL = int(os.getenv("REAL", 0))
if __name__ == "__main__":
Tensor.no_init = True
# WTF!! no_grad breaks it
Tensor.no_grad = True
model = StableDiffusion()
# load in weights
dat = fake_torch_load_zipped(open(FILENAME, "rb"), load_weights=REAL)
for k,v in dat['state_dict'].items():
try:
w = get_child(model, k)
except (AttributeError, KeyError, IndexError):
#traceback.print_exc()
w = None
print(f"{str(v.shape):30s}", w, k)
if w is not None:
assert w.shape == v.shape
w.assign(v.astype(np.float32))
# "a horse sized cat eating a bagel"
# run through CLIP to get context
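# CLIP BPE ids: 49406 = <|startoftext|>, 49407 = <|endoftext|>, repeated as padding to length 77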
phrase = [49406, 320, 4558, 9832, 2368, 4371, 320, 28777, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
context = model.cond_stage_model.transformer.text_model(phrase)
print("got CLIP context", context.shape)
phrase = [49406, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
unconditional_context = model.cond_stage_model.transformer.text_model(phrase)
print("got unconditional CLIP context", unconditional_context.shape)
def get_model_output(latent, t):
# put into diffuser
timesteps = Tensor([t])
unconditional_latent = model.model.diffusion_model(latent, timesteps, unconditional_context)
conditional_latent = model.model.diffusion_model(latent, timesteps, context)
unconditional_guidance_scale = 7.5
e_t = unconditional_latent + unconditional_guidance_scale * (conditional_latent - unconditional_latent)
return e_t
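# hardcoded DDIM schedule for the 4 sampling timesteps below (t = 1, 251, 501, 751):
# alphas hold the cumulative products alpha_bar at those timesteps, sigmas are all 0
# (deterministic sampling, eta=0), and sqrt_one_minus_alphas[i] == sqrt(1 - alphas[i])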
alphas = [0.9983, 0.6722, 0.2750, 0.0557]
alphas_prev = [0.9991499781608582, 0.9982960224151611, 0.6721514463424683, 0.27499905228614807]
sigmas = [0,0,0,0]
sqrt_one_minus_alphas = [0.0413, 0.5726, 0.8515, 0.9717]
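# one DDIM step (Song et al. 2021); with sigma=0 the noise term vanishes:
#   pred_x0 = (x_t - sqrt(1 - a_t) * e_t) / sqrt(a_t)
#   x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * noise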
def get_x_prev_and_pred_x0(x, e_t, index):
temperature = 1
a_t, a_prev, sigma_t, sqrt_one_minus_at = alphas[index], alphas_prev[index], sigmas[index], sqrt_one_minus_alphas[index]
pred_x0 = (x - sqrt_one_minus_at * e_t) / math.sqrt(a_t)
# direction pointing to x_t
dir_xt = math.sqrt(1. - a_prev - sigma_t**2) * e_t
noise = sigma_t * Tensor.randn(*x.shape) * temperature
x_prev = math.sqrt(a_prev) * pred_x0 + dir_xt #+ noise
return x_prev, pred_x0
# start with random noise
latent = Tensor.randn(1,4,64,64)
# DDIM sampling loop: iterate the 4 timesteps from most to least noisy
for index, timestep in tqdm(list(enumerate([1, 251, 501, 751]))[::-1]):
print(index, timestep)
e_t = get_model_output(latent, timestep)
x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t, index)
#e_t_next = get_model_output(x_prev)
#e_t_prime = (e_t + e_t_next) / 2
#x_prev, pred_x0 = get_x_prev_and_pred_x0(latent, e_t_prime, index)
latent = x_prev
latent.realize()
print(latent.numpy())
# sanity check
#latent = Tensor(np.load("datasets/stable_diffusion_apple.npy"))
# decode the latent back to an image with the VAE decoder (1/0.18215 undoes the latent scale factor)
x = model.first_stage_model.post_quant_conv(1/0.18215 * latent)
x = model.first_stage_model.decoder(x)
# make image correct size and scale
x = (x + 1.0) / 2.0
x = x.reshape(3,512,512).permute(1,2,0)
dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
print(dat.shape)
# save image
from PIL import Image
im = Image.fromarray(dat)
im.save("/tmp/rendered.png")
exit(0)
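# ---------------------------------------------------------------------------
# everything below is unreachable development scratch (the exit(0) above always
# fires): torch cross-check harnesses that need the ldm repo on PYTHONPATH
# ---------------------------------------------------------------------------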
"""
# load apple latent space
nz = Tensor(np.load("datasets/stable_diffusion_apple.npy"))
# run unet (without context)
timesteps = Tensor([32])
context = Tensor.zeros(1, 77, 768)
nz = model(nz, timesteps, context)
# upsample latent space to image with autoencoder
x = model.first_stage_model.post_quant_conv(nz)
x = model.first_stage_model.decoder(x)
# make image correct size
x = x.reshape(3,512,512).permute(1,2,0)
dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
print(dat.shape)
# save image
from PIL import Image
im = Image.fromarray(dat)
im.save("/tmp/rendered.png")
exit(0)
"""
"""
outs = model.cond_stage_model.transformer.text_model([1,2,3])
print(outs.numpy())
print(outs.numpy().shape)
"""
"""
from ldm.modules.diffusionmodules.openaimodel import UNetModel
tmodel = UNetModel(
image_size = 32,
in_channels = 4,
out_channels = 4,
model_channels = 320,
attention_resolutions = [4, 2, 1],
num_res_blocks = 2,
channel_mult = [ 1, 2, 4, 4 ],
num_heads = 8,
use_spatial_transformer = True,
transformer_depth = 1,
context_dim = 768,
use_checkpoint = True,
legacy = False)
prefix = "model.diffusion_model."
"""
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
tmodel = FrozenCLIPEmbedder()
prefix = "cond_stage_model."
#from ldm.models.autoencoder import AutoencoderKL
#tmodel = AutoencoderKL(
# ddconfig = {
# "double_z": True,
# "z_channels": 4,
# "resolution": 256,
# "in_channels": 3,
# "out_ch": 3,
# "ch": 128,
# "ch_mult": [1,2,4,4],
# "num_res_blocks": 2,
# "attn_resolutions": []
# },
# lossconfig={"target": "torch.nn.Identity"},
# embed_dim=4)
#prefix = "first_stage_model."
import torch
ckpt = torch.load(FILENAME)
dat = ckpt['state_dict']
sd = {}
for k in dat:
if k.startswith(prefix):
sd[k[len(prefix):]] = dat[k]
print("loading", len(sd))
tmodel.load_state_dict(sd, strict=True)
tmodel = tmodel.cuda()
ret = tmodel("a horse sized cat eating a bagel")
print(ret)
re = model.cond_stage_model.transformer.text_model(phrase)
print(re.numpy())
exit(0)
# run one pass of unet
tnz = torch.Tensor(nz.numpy())
timesteps = Tensor([10])
context = Tensor.uniform(1, 77, 768)
ttimesteps = torch.Tensor(timesteps.numpy())
tcontext = torch.Tensor(context.numpy())
tnz = tmodel(tnz, ttimesteps, tcontext)
nz = model(nz, timesteps, context)
print(tnz)
print(nz.numpy())
print("match", np.mean((tnz.detach().numpy() - nz.numpy())**2))
exit(0)
del model.model
# clear unet
nz = nz.detach()
import gc
gc.collect()
import torch
torch.cuda.empty_cache()
"""
print(out)
print(out.numpy())
exit(0)
if not REAL: exit(0)
"""
# load image
#IMG = "/tmp/apple.png"
#from PIL import Image
#realimg = Tensor(np.array(Image.open(IMG))).permute((2,0,1)).reshape((1,3,512,512))*(1/255)
#print(realimg.shape)
#x = model(realimg)
# load latent space
x = model.first_stage_model.post_quant_conv(nz)
x = model.first_stage_model.decoder(x)
x = x.reshape(3,512,512).permute(1,2,0)
dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
print(dat.shape)
from PIL import Image
im = Image.fromarray(dat)
im.save("/tmp/rendered.png")
# torch junk
#IMG = "/Users/kafka/fun/mps/stable-diffusion/outputs/txt2img-samples/grid-0006.png"
#from PIL import Image
#realimg = Tensor(np.array(Image.open(IMG))).permute((2,0,1)).reshape((1,3,512,512))*(1/255)
#print(img.shape)
#x = model(img)
#nz = np.random.randn(*nz.shape) * 100
# PYTHONPATH="$PWD:/Users/kafka/fun/mps/stable-diffusion"
"""
from ldm.models.autoencoder import AutoencoderKL
import torch
ckpt = torch.load(FILENAME)
dat = ckpt['state_dict']
sd = {}
for k in dat:
if k.startswith("first_stage_model."):
sd[k[len("first_stage_model."):]] = dat[k]
print("loading", len(sd))
tmodel = AutoencoderKL(
ddconfig = {
"double_z": True,
"z_channels": 4,
"resolution": 256,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": [1,2,4,4],
"num_res_blocks": 2,
"attn_resolutions": []
},
lossconfig={"target": "torch.nn.Identity"},
embed_dim=4)
tmodel.load_state_dict(sd, strict=True)
nz = np.load("datasets/stable_diffusion_apple.npy")
zmodel = model.first_stage_model
x_torch = torch.tensor(nz)
x_tiny = Tensor(nz)
x_torch = tmodel.post_quant_conv(x_torch)
x_tiny = zmodel.post_quant_conv(x_tiny)
x_torch = tmodel.decoder.conv_in(x_torch)
x_tiny = zmodel.decoder.conv_in(x_tiny)
x_torch = tmodel.decoder.mid.block_1(x_torch, None)
x_tiny = zmodel.decoder.mid['block_1'](x_tiny)
"""
"""
x_torch = tmodel.decoder.mid.block_1.norm1(x_torch)
x_tiny = zmodel.decoder.mid['block_1'].norm1(x_tiny)
x_torch = x_torch * torch.sigmoid(x_torch)
x_tiny = x_tiny.swish()
print(zmodel.decoder.mid['block_1'].conv1.weight.shape)
print(x_tiny.shape)
x_torch = tmodel.decoder.mid.block_1.conv1(x_torch)
x_tiny = zmodel.decoder.mid['block_1'].conv1(x_tiny)
"""
"""
posterior = tmodel.encode(torch.tensor(realimg.numpy()))
z = posterior.mode()
print(z.shape)
#exit(0)
nz = z.detach().numpy()
np.save("/tmp/apple.npy", nz)
exit(0)
"""
#x, latent = tmodel(torch.tensor(realimg.numpy()))
#x = tmodel.decode(torch.tensor(nz))
#x = x.reshape(3,512,512).permute(1,2,0)
"""
x = Tensor.randn(1,4,64,64)
x = model.first_stage_model.post_quant_conv(x)
x = model.first_stage_model.decoder(x)
print(x.shape)
x = x.reshape((3,512,512)).permute((1,2,0))
print(x.shape)
if not REAL: exit(0)
"""
"""
#dat = (x.detach().numpy()*256).astype(np.uint8)
dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
print(dat.shape)
from PIL import Image
im = Image.fromarray(dat)
im.save("/tmp/rendered.png")
"""