stable_diffusion: add attn and layernorm

This commit is contained in:
George Hotz
2022-09-03 11:02:27 -07:00
parent 4dadd95e3c
commit 356732515b

View File

@@ -3,21 +3,36 @@
# this is sd-v1-4.ckpt
FILENAME = "/Users/kafka/fun/mps/stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt"
import os
import numpy as np
from extra.utils import fake_torch_load_zipped, get_child
from tinygrad.nn import Conv2d
from tinygrad.tensor import Tensor
# REAL=1 in the environment asks the loader to read the actual weights;
# the default (0) passes load_weights=0 so only the checkpoint structure is parsed.
REAL = int(os.getenv("REAL", 0))
# Parse the checkpoint once (previously it was parsed twice and the first result
# was discarded) and close the file handle afterwards.
# NOTE(review): assumes fake_torch_load_zipped finishes reading before it returns — confirm.
with open(FILENAME, "rb") as ckpt_file:
  dat = fake_torch_load_zipped(ckpt_file, load_weights=REAL)
class Normalize:
  """Group normalization for 4-D input with a per-channel scale and shift."""

  def __init__(self, in_channels, num_groups=32):
    """Create per-channel `weight` and `bias` of length `in_channels`.

    `num_groups` is the number of channel groups normalized independently;
    `in_channels` must be divisible by it for the reshape in `__call__` to work.
    """
    self.weight = Tensor.uniform(in_channels)
    self.bias = Tensor.uniform(in_channels)
    self.num_groups = num_groups

  def __call__(self, x):
    """Normalize `x` per channel group, then apply `weight` and `bias`.

    `x` is unpacked as a 4-D tensor (dim 0, channels, two trailing dims).
    The result has the same shape as `x`.
    """
    print("norm", x.shape)
    batch, channels, height, width = x.shape

    # split the channel axis into (groups, channels-per-group)
    x = x.reshape((batch, self.num_groups, channels//self.num_groups, height, width))

    # subtract mean
    x = x - x.mean(axis=(2,3,4), keepdim=True)

    # divide stddev (x is already centered, so mean(x*x) is the biased variance)
    # NOTE(review): the ldm reference Normalize appears to use eps=1e-6 — confirm
    eps = 1e-5
    x = x.div((x*x).mean(axis=(2,3,4), keepdim=True).add(eps).sqrt())

    # return to old shape
    x = x.reshape((batch, channels, height, width))

    # apply the per-channel affine parameters, which were previously stored but never used
    return x * self.weight.reshape((1, channels, 1, 1)) + self.bias.reshape((1, channels, 1, 1))
class AttnBlock:
def __init__(self, in_channels):
@@ -27,10 +42,28 @@ class AttnBlock:
self.v = Conv2d(in_channels, in_channels, 1)
self.proj_out = Conv2d(in_channels, in_channels, 1)
# copied from AttnBlock in ldm repo
def __call__(self, x):
  """Apply single-head self-attention over spatial positions, with a residual add.

  `x` is normalized by `self.norm`, projected by `self.q`/`self.k`/`self.v`,
  attended over the flattened h*w positions, passed through `self.proj_out`,
  and added back to `x`.
  """
  print("attention:", x.shape)
  # (the old `return x` stub that short-circuited everything below is removed)
  h_ = self.norm(x)
  q,k,v = self.q(h_), self.k(h_), self.v(h_)

  # compute attention
  b,c,h,w = q.shape
  q = q.reshape((b,c,h*w))
  q = q.permute((0,2,1))   # b,hw,c
  k = k.reshape((b,c,h*w)) # b,c,hw
  w_ = q @ k               # b,hw,hw: w_[b,i,j] = sum_c q[b,i,c] * k[b,c,j]
  w_ = w_ * (c**(-0.5))    # scale scores by 1/sqrt(c)
  w_ = w_.softmax()        # presumably normalizes over the last (key) axis — confirm Tensor.softmax default

  # attend to values
  v = v.reshape((b,c,h*w))
  w_ = w_.permute((0,2,1)) # b,hw (key j),hw (query i)
  h_ = v @ w_              # b,c,hw: h_[b,c,i] = sum_j v[b,c,j] * w_[b,i,j]
  h_ = h_.reshape((b,c,h,w))

  return x + self.proj_out(h_)
class ResnetBlock:
def __init__(self, in_channels, out_channels=None):
@@ -162,16 +195,27 @@ for k,v in dat['state_dict'].items():
assert w.shape == v.shape
"""
IMG = "/Users/kafka/fun/mps/stable-diffusion/outputs/txt2img-samples/grid-0006.png"
from PIL import Image
img = Tensor(np.array(Image.open(IMG))).permute((2,0,1)).reshape((1,3,512,512))
print(img.shape)
x = model(img)
print(x.shape)
x = x[0]
print(x.shape)
"""
# Run the first-stage decoder on a random (1,4,64,64) input and, when REAL is set,
# write the result to /tmp/rendered.png.
x = Tensor.uniform(1,4,64,64)
x = model.first_stage_model.decoder(x)
print(x.shape)

# drop the batch axis and move channels last so the result can become an image
x = x.reshape((3,512,512)).permute((1,2,0))
print(x.shape)

# without real weights the output is meaningless; stop before rendering
# (SystemExit does not depend on the `exit` helper injected by the site module)
if not REAL: raise SystemExit(0)

# clamp before the uint8 cast so out-of-range values saturate instead of wrapping around
# NOTE(review): the ldm reference pipeline maps decoder output from [-1,1] via (x+1)/2 — confirm the scale here
dat = np.clip(x.numpy()*256, 0, 255).astype(np.uint8)
print(dat.shape)

from PIL import Image
im = Image.fromarray(dat)
im.save("/tmp/rendered.png")
# ** ldm.models.autoencoder.AutoencoderKL