all models match

This commit is contained in:
George Hotz
2022-09-05 12:27:54 -07:00
parent b8bd34b5d2
commit 98d6264987

View File

@@ -409,19 +409,30 @@ class CLIPAttention:
def _shape(self, tensor, seq_len: int, bsz: int):
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)
def __call__(self, hidden_states):
def __call__(self, hidden_states, causal_attention_mask):
bsz, tgt_len, embed_dim = hidden_states.shape
query_states = self.q_proj(hidden_states) * self.scale
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
#print("ATTN", query_states.numpy())
#print(hidden_states.shape, query_states.shape, key_states.shape, value_states.shape)
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
key_states = key_states.reshape(*proj_shape)
src_len = key_states.shape[1]
value_states = value_states.reshape(*proj_shape)
#print(query_states.shape, key_states.shape)
attn_weights = query_states @ key_states.permute(0,2,1)
#print(attn_weights.shape, causal_attention_mask.shape)
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
attn_weights = attn_weights.softmax()
attn_output = attn_weights @ value_states
@@ -440,10 +451,10 @@ class CLIPEncoderLayer:
self.mlp = CLIPMLP()
self.layer_norm2 = Normalize(768, num_groups=None)
def __call__(self, hidden_states):
def __call__(self, hidden_states, causal_attention_mask):
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states)
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
hidden_states = residual + hidden_states
residual = hidden_states
@@ -457,8 +468,13 @@ class CLIPEncoder:
def __init__(self):
self.layers = [CLIPEncoderLayer() for i in range(12)]
def __call__(self, hidden_states):
return hidden_states.sequential(self.layers)
def __call__(self, hidden_states, causal_attention_mask):
for i,l in enumerate(self.layers):
#if i == 2:
# print(hidden_states.numpy())
# break
hidden_states = l(hidden_states, causal_attention_mask)
return hidden_states
class CLIPTextEmbeddings:
def __init__(self):
@@ -484,14 +500,16 @@ class CLIPTextTransformer:
def __call__(self, input_ids):
x = self.embeddings(input_ids, list(range(len(input_ids))))
x = self.encoder(x)
print(x.numpy())
causal_attention_mask = np.triu(np.ones((1,1,77,77), dtype=np.float32) * -np.inf, k=1)
x = self.encoder(x, Tensor(causal_attention_mask, device=x.device))
return self.final_layer_norm(x)
class StableDiffusion:
def __init__(self):
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
#self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
#self.first_stage_model = AutoencoderKL()
#self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
def __call__(self, x, timesteps, context):
return self.model.diffusion_model(x, timesteps, context)
@@ -537,12 +555,41 @@ if __name__ == "__main__":
assert w.shape == v.shape
w.assign(v.astype(np.float32))
# "a horse sized cat eating a bagel"
phrase = [49406, 320, 4558, 9832, 2368, 4371, 320, 28777, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
"""
# load apple latent space
nz = Tensor(np.load("datasets/stable_diffusion_apple.npy"))
# run unet (without context)
timesteps = Tensor([32])
context = Tensor.zeros(1, 77, 768)
nz = model(nz, timesteps, context)
# upsample latent space to image with autoencoder
x = model.first_stage_model.post_quant_conv(nz)
x = model.first_stage_model.decoder(x)
# make image correct size
x = x.reshape(3,512,512).permute(1,2,0)
dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
print(dat.shape)
# save image
from PIL import Image
im = Image.fromarray(dat)
im.save("/tmp/rendered.png")
exit(0)
"""
"""
outs = model.cond_stage_model.transformer.text_model([1,2,3])
print(outs.numpy())
print(outs.numpy().shape)
"""
"""
from ldm.modules.diffusionmodules.openaimodel import UNetModel
tmodel = UNetModel(
image_size = 32,
@@ -559,10 +606,11 @@ if __name__ == "__main__":
use_checkpoint = True,
legacy = False)
prefix = "model.diffusion_model."
"""
#from ldm.modules.encoders.modules import FrozenCLIPEmbedder
#tmodel = FrozenCLIPEmbedder()
#prefix = "cond_stage_model."
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
tmodel = FrozenCLIPEmbedder()
prefix = "cond_stage_model."
#from ldm.models.autoencoder import AutoencoderKL
#tmodel = AutoencoderKL(
@@ -590,22 +638,31 @@ if __name__ == "__main__":
sd[k[len(prefix):]] = dat[k]
print("loading", len(sd))
tmodel.load_state_dict(sd, strict=True)
tmodel = tmodel.cuda()
# load apple latent space
nz = Tensor(np.load("datasets/stable_diffusion_apple.npy"))
ret = tmodel("a horse sized cat eating a bagel")
print(ret)
re = model.cond_stage_model.transformer.text_model(phrase)
print(re.numpy())
exit(0)
# run one pass of unet
tnz = torch.Tensor(nz.numpy())
ttimesteps = torch.Tensor([0])
tcontext = torch.zeros(1, 77, 768)
timesteps = Tensor([10])
context = Tensor.uniform(1, 77, 768)
ttimesteps = torch.Tensor(timesteps.numpy())
tcontext = torch.Tensor(context.numpy())
tnz = tmodel(tnz, ttimesteps, tcontext)
timesteps = Tensor([0])
context = Tensor.zeros(1, 77, 768)
nz = model(nz, timesteps, context)
print(tnz)
print(nz.numpy())
print("match", np.mean((tnz.detach().numpy() - nz.numpy())**2))
exit(0)