mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
all models match
This commit is contained in:
@@ -409,19 +409,30 @@ class CLIPAttention:
|
||||
def _shape(self, tensor, seq_len: int, bsz: int):
|
||||
return tensor.reshape(bsz, seq_len, self.num_heads, self.head_dim).permute(0,2,1,3)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
def __call__(self, hidden_states, causal_attention_mask):
|
||||
bsz, tgt_len, embed_dim = hidden_states.shape
|
||||
|
||||
|
||||
query_states = self.q_proj(hidden_states) * self.scale
|
||||
key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
|
||||
value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
|
||||
|
||||
#print("ATTN", query_states.numpy())
|
||||
#print(hidden_states.shape, query_states.shape, key_states.shape, value_states.shape)
|
||||
|
||||
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
|
||||
query_states = self._shape(query_states, tgt_len, bsz).reshape(*proj_shape)
|
||||
key_states = key_states.reshape(*proj_shape)
|
||||
src_len = key_states.shape[1]
|
||||
value_states = value_states.reshape(*proj_shape)
|
||||
|
||||
#print(query_states.shape, key_states.shape)
|
||||
attn_weights = query_states @ key_states.permute(0,2,1)
|
||||
|
||||
#print(attn_weights.shape, causal_attention_mask.shape)
|
||||
attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask
|
||||
attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_weights = attn_weights.softmax()
|
||||
|
||||
attn_output = attn_weights @ value_states
|
||||
@@ -440,10 +451,10 @@ class CLIPEncoderLayer:
|
||||
self.mlp = CLIPMLP()
|
||||
self.layer_norm2 = Normalize(768, num_groups=None)
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
def __call__(self, hidden_states, causal_attention_mask):
|
||||
residual = hidden_states
|
||||
hidden_states = self.layer_norm1(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states)
|
||||
hidden_states = self.self_attn(hidden_states, causal_attention_mask)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
@@ -457,8 +468,13 @@ class CLIPEncoder:
|
||||
def __init__(self):
|
||||
self.layers = [CLIPEncoderLayer() for i in range(12)]
|
||||
|
||||
def __call__(self, hidden_states):
|
||||
return hidden_states.sequential(self.layers)
|
||||
def __call__(self, hidden_states, causal_attention_mask):
|
||||
for i,l in enumerate(self.layers):
|
||||
#if i == 2:
|
||||
# print(hidden_states.numpy())
|
||||
# break
|
||||
hidden_states = l(hidden_states, causal_attention_mask)
|
||||
return hidden_states
|
||||
|
||||
class CLIPTextEmbeddings:
|
||||
def __init__(self):
|
||||
@@ -484,14 +500,16 @@ class CLIPTextTransformer:
|
||||
|
||||
def __call__(self, input_ids):
|
||||
x = self.embeddings(input_ids, list(range(len(input_ids))))
|
||||
x = self.encoder(x)
|
||||
print(x.numpy())
|
||||
causal_attention_mask = np.triu(np.ones((1,1,77,77), dtype=np.float32) * -np.inf, k=1)
|
||||
x = self.encoder(x, Tensor(causal_attention_mask, device=x.device))
|
||||
return self.final_layer_norm(x)
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self):
|
||||
self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
|
||||
#self.model = namedtuple("DiffusionModel", ["diffusion_model"])(diffusion_model = UNetModel())
|
||||
#self.first_stage_model = AutoencoderKL()
|
||||
#self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
|
||||
self.cond_stage_model = namedtuple("CondStageModel", ["transformer"])(transformer = namedtuple("Transformer", ["text_model"])(text_model = CLIPTextTransformer()))
|
||||
|
||||
def __call__(self, x, timesteps, context):
|
||||
return self.model.diffusion_model(x, timesteps, context)
|
||||
@@ -537,12 +555,41 @@ if __name__ == "__main__":
|
||||
assert w.shape == v.shape
|
||||
w.assign(v.astype(np.float32))
|
||||
|
||||
# "a horse sized cat eating a bagel"
|
||||
phrase = [49406, 320, 4558, 9832, 2368, 4371, 320, 28777, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407]
|
||||
|
||||
"""
|
||||
# load apple latent space
|
||||
nz = Tensor(np.load("datasets/stable_diffusion_apple.npy"))
|
||||
|
||||
# run unet (without context)
|
||||
timesteps = Tensor([32])
|
||||
context = Tensor.zeros(1, 77, 768)
|
||||
nz = model(nz, timesteps, context)
|
||||
|
||||
# upsample latent space to image with autoencoder
|
||||
x = model.first_stage_model.post_quant_conv(nz)
|
||||
x = model.first_stage_model.decoder(x)
|
||||
|
||||
# make image correct size
|
||||
x = x.reshape(3,512,512).permute(1,2,0)
|
||||
dat = (x.detach().numpy().clip(0, 1)*255).astype(np.uint8)
|
||||
print(dat.shape)
|
||||
|
||||
# save image
|
||||
from PIL import Image
|
||||
im = Image.fromarray(dat)
|
||||
im.save("/tmp/rendered.png")
|
||||
exit(0)
|
||||
"""
|
||||
|
||||
"""
|
||||
outs = model.cond_stage_model.transformer.text_model([1,2,3])
|
||||
print(outs.numpy())
|
||||
print(outs.numpy().shape)
|
||||
"""
|
||||
|
||||
"""
|
||||
from ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
tmodel = UNetModel(
|
||||
image_size = 32,
|
||||
@@ -559,10 +606,11 @@ if __name__ == "__main__":
|
||||
use_checkpoint = True,
|
||||
legacy = False)
|
||||
prefix = "model.diffusion_model."
|
||||
"""
|
||||
|
||||
#from ldm.modules.encoders.modules import FrozenCLIPEmbedder
|
||||
#tmodel = FrozenCLIPEmbedder()
|
||||
#prefix = "cond_stage_model."
|
||||
from ldm.modules.encoders.modules import FrozenCLIPEmbedder
|
||||
tmodel = FrozenCLIPEmbedder()
|
||||
prefix = "cond_stage_model."
|
||||
|
||||
#from ldm.models.autoencoder import AutoencoderKL
|
||||
#tmodel = AutoencoderKL(
|
||||
@@ -590,22 +638,31 @@ if __name__ == "__main__":
|
||||
sd[k[len(prefix):]] = dat[k]
|
||||
print("loading", len(sd))
|
||||
tmodel.load_state_dict(sd, strict=True)
|
||||
tmodel = tmodel.cuda()
|
||||
|
||||
# load apple latent space
|
||||
nz = Tensor(np.load("datasets/stable_diffusion_apple.npy"))
|
||||
ret = tmodel("a horse sized cat eating a bagel")
|
||||
print(ret)
|
||||
|
||||
re = model.cond_stage_model.transformer.text_model(phrase)
|
||||
print(re.numpy())
|
||||
|
||||
exit(0)
|
||||
|
||||
# run one pass of unet
|
||||
tnz = torch.Tensor(nz.numpy())
|
||||
ttimesteps = torch.Tensor([0])
|
||||
tcontext = torch.zeros(1, 77, 768)
|
||||
timesteps = Tensor([10])
|
||||
context = Tensor.uniform(1, 77, 768)
|
||||
|
||||
ttimesteps = torch.Tensor(timesteps.numpy())
|
||||
tcontext = torch.Tensor(context.numpy())
|
||||
tnz = tmodel(tnz, ttimesteps, tcontext)
|
||||
timesteps = Tensor([0])
|
||||
context = Tensor.zeros(1, 77, 768)
|
||||
nz = model(nz, timesteps, context)
|
||||
|
||||
print(tnz)
|
||||
print(nz.numpy())
|
||||
|
||||
print("match", np.mean((tnz.detach().numpy() - nz.numpy())**2))
|
||||
|
||||
|
||||
exit(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user