mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
stable_diffusion: add attn and layernorm
This commit is contained in:
@@ -3,21 +3,36 @@
|
||||
# this is sd-v1-4.ckpt
FILENAME = "/Users/kafka/fun/mps/stable-diffusion/models/ldm/stable-diffusion-v1/model.ckpt"

import os
import numpy as np
from extra.utils import fake_torch_load_zipped, get_child
from tinygrad.nn import Conv2d
from tinygrad.tensor import Tensor

# REAL is read from the environment (default 0) and forwarded as the loader's
# load_weights flag; later parts of the script also use it to decide whether
# to go on and render an image.
REAL = int(os.getenv("REAL", 0))

# Load the checkpoint exactly once. A previous unconditional load with
# load_weights=False was immediately overwritten by this one, so it only
# cost an extra open/parse of the checkpoint file.
dat = fake_torch_load_zipped(open(FILENAME, "rb"), load_weights=REAL)
|
||||
|
||||
class Normalize:
    """Group normalization over a 4-D tensor with a per-channel affine.

    Axis 1 (channels) is split into ``num_groups`` groups. Each group is
    normalized to zero mean and unit variance over its channels and both
    trailing axes, then scaled by ``weight`` and shifted by ``bias``.
    """

    def __init__(self, in_channels, num_groups=32):
        """Create the per-channel parameters.

        Args:
            in_channels: number of channels (size of axis 1) of the inputs.
            num_groups: number of channel groups normalized independently;
                the channel count must be divisible by it for the reshape
                in ``__call__`` to succeed.
        """
        self.weight = Tensor.uniform(in_channels)
        self.bias = Tensor.uniform(in_channels)
        self.num_groups = num_groups

    def __call__(self, x):
        """Return ``x`` group-normalized; the output has the input's shape.

        Args:
            x: 4-D tensor whose axis 1 is the channel axis.
        """
        print("norm", x.shape)
        batch, channels, height, width = x.shape

        # expose the groups as their own axis so one reduction handles them all
        grouped = x.reshape((batch, self.num_groups, channels // self.num_groups, height, width))

        # subtract mean
        centered = grouped - grouped.mean(axis=(2, 3, 4), keepdim=True)

        # divide stddev (eps keeps the division finite for constant groups)
        eps = 1e-5
        variance = (centered * centered).mean(axis=(2, 3, 4), keepdim=True)
        normed = centered.div(variance.add(eps).sqrt())

        # return to old shape
        normed = normed.reshape((batch, channels, height, width))

        # apply the per-channel scale and shift; without this the weight and
        # bias parameters held by the layer would have no effect on the output
        scale = self.weight.reshape((1, channels, 1, 1))
        shift = self.bias.reshape((1, channels, 1, 1))
        return normed * scale + shift
|
||||
|
||||
|
||||
class AttnBlock:
|
||||
def __init__(self, in_channels):
|
||||
@@ -27,10 +42,28 @@ class AttnBlock:
|
||||
self.v = Conv2d(in_channels, in_channels, 1)
|
||||
self.proj_out = Conv2d(in_channels, in_channels, 1)
|
||||
|
||||
# copied from AttnBlock in ldm repo
|
||||
def __call__(self, x):
    """Apply self-attention across spatial positions and add it to ``x``.

    The input is normalized, projected to queries/keys/values, and every
    spatial position attends to every other one; the attended values go
    through ``proj_out`` and are added back to the input (residual).

    Args:
        x: input tensor. The q/k/v projections of its normalized form
            must be 4-D, since their shape is unpacked as (b, c, h, w).

    Returns:
        ``x + self.proj_out(attended)``.

    NOTE(review): ``self.norm``, ``self.q`` and ``self.k`` are used here but
    are not assigned in the visible part of ``__init__`` - confirm they are
    set there.
    """
    print("attention:", x.shape)
    h_ = self.norm(x)
    q, k, v = self.q(h_), self.k(h_), self.v(h_)

    # compute attention
    b, c, h, w = q.shape
    q = q.reshape((b, c, h * w))
    q = q.permute((0, 2, 1))  # b,hw,c
    k = k.reshape((b, c, h * w))  # b,c,hw
    w_ = q @ k  # b,hw,hw: rows index query positions, columns key positions
    w_ = w_ * (c ** (-0.5))
    # softmax() is called without an axis; presumably it normalizes the last
    # (key) axis - confirm against the Tensor implementation
    w_ = w_.softmax()

    # attend to values
    v = v.reshape((b, c, h * w))
    w_ = w_.permute((0, 2, 1))  # keys first, so v @ w_ sums over key positions
    h_ = v @ w_  # b,c,hw
    h_ = h_.reshape((b, c, h, w))

    return x + self.proj_out(h_)
|
||||
|
||||
class ResnetBlock:
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
@@ -162,16 +195,27 @@ for k,v in dat['state_dict'].items():
|
||||
assert w.shape == v.shape
|
||||
|
||||
|
||||
"""
|
||||
IMG = "/Users/kafka/fun/mps/stable-diffusion/outputs/txt2img-samples/grid-0006.png"
|
||||
from PIL import Image
|
||||
img = Tensor(np.array(Image.open(IMG))).permute((2,0,1)).reshape((1,3,512,512))
|
||||
print(img.shape)
|
||||
x = model(img)
|
||||
print(x.shape)
|
||||
x = x[0]
|
||||
print(x.shape)
|
||||
"""
|
||||
# Push a random (1,4,64,64) latent through the first-stage decoder.
x = Tensor.uniform(1,4,64,64)
x = model.first_stage_model.decoder(x)

# The result is rebound below before it is used.
# NOTE(review): presumably kept so the decoder output is actually evaluated
# even when REAL is 0 - confirm before removing.
dat = x.numpy()
print(x.shape)

# drop the batch axis and move channels last: (3,512,512) -> (512,512,3)
x = x.reshape((3,512,512)).permute((1,2,0))
print(x.shape)

# Without REAL there is nothing to render, so stop after the shape checks.
# SystemExit is raised directly because the exit() builtin is only installed
# by the site module.
if not REAL: raise SystemExit(0)

# Scale to 8-bit and clip before the cast: an unclipped float -> uint8 cast
# wraps (or is undefined) for values outside [0, 256), e.g. 1.0 * 256.
# NOTE(review): this assumes the decoder output is already in [0, 1) -
# confirm; no rescaling from any other range is done here.
dat = np.clip(x.numpy()*256, 0, 255).astype(np.uint8)
print(dat.shape)

# PIL is imported here so it is only required when an image is written.
from PIL import Image
im = Image.fromarray(dat)
im.save("/tmp/rendered.png")
|
||||
|
||||
|
||||
# ** ldm.models.autoencoder.AutoencoderKL
|
||||
|
||||
Reference in New Issue
Block a user