import numpy as np
"""
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
import tensorflow as tf
with tf.io.gfile.GFile(fn, "rb") as f:
  dat = f.read()
  with open("cache/"+ fn.rsplit("/", 1)[1], "wb") as g:
    g.write(dat)
"""

import io
from extra.utils import fetch

from tinygrad.tensor import Tensor

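# a ViT-Ti/16 style encoder block: parameters are kept as plain (weight, bias) tuples
# and applied through Tensor.linear instead of nn-style layer objects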
class ViTBlock:
  def __init__(self, embed_dim, num_heads, ff_dim):
    # Multi-Head Attention
    self.num_heads = num_heads
    self.head_size = embed_dim // num_heads
    assert self.head_size * self.num_heads == embed_dim

    # added bias
    self.query_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
    self.key_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
    self.value_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))

    self.final = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))

    self.ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
    self.ff2 = (Tensor.uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))

    self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
    self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))

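  # multi-head scaled dot-product attention; for ViT-Ti this is 3 heads of
  # head_size 64 over the 197 tokens of each image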
  def attn(self, x):
    embed_dim = self.num_heads * self.head_size

    query, key, value = [x.linear(y) \
      .reshape(shape=(x.shape[0], -1, self.num_heads, self.head_size)) \
      for y in [self.query_dense, self.key_dense, self.value_dense]]

    query = query.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)
    key = key.transpose(order=(0,2,3,1)) # (bs, num_heads, head_size, T)
    value = value.transpose(order=(0,2,1,3)) # (bs, num_heads, T, head_size)

    score = query.dot(key) * (1 / np.sqrt(self.head_size))
    weights = score.softmax() # (bs, num_heads, T, T)
    attention = weights.dot(value).transpose(order=(0,2,1,3)) # (bs, T, num_heads, head_size)

    return attention.reshape(shape=(x.shape[0], -1, embed_dim)).linear(self.final)

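  # pre-norm residual block, as in the ViT paper: x + Attn(LN(x)), then x + MLP(LN(x))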
  def __call__(self, x):
    x = x + self.attn(x.layernorm().linear(self.ln1)).dropout(0.1)
    x = x + x.layernorm().linear(self.ln2).linear(self.ff1).gelu().linear(self.ff2).dropout(0.1)
    return x

class ViT:
  def __init__(self, embed_dim=192):
    self.conv_weight = Tensor.uniform(embed_dim, 3, 16, 16)
    self.conv_bias = Tensor.zeros(embed_dim)
    self.cls_token = Tensor.ones(1, 1, embed_dim)
    self.tbs = [ViTBlock(embed_dim=embed_dim, num_heads=3, ff_dim=768) for i in range(12)]
    self.pos_embed = Tensor.ones(1, 197, embed_dim)
    self.head = (Tensor.uniform(embed_dim, 1000), Tensor.zeros(1000))
    self.norm = (Tensor.uniform(embed_dim), Tensor.zeros(embed_dim))

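  # the stride-16 16x16 conv cuts the 224x224 input into 14*14 = 196 patch tokens of size embed_dim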
  def patch_embed(self, x):
    x = x.conv2d(self.conv_weight, stride=16)
    x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
    return x

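  # prepend the cls token (196 -> 197 tokens, matching pos_embed), run the 12 blocks,
  # then classify from the cls token's output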
  def forward(self, x):
    pe = self.patch_embed(x)
    # TODO: expand cls_token for batch
    x = self.cls_token.cat(pe, dim=1) + self.pos_embed
    for l in self.tbs:
      x = l(x)
    x = x.layernorm().linear(self.norm)
    return x[:, 0].linear(self.head)

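# inference only: Tensor.training = False is intended to make the dropout(0.1) calls above no-ops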
Tensor.training = False
m = ViT()

# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
dat = np.load(io.BytesIO(fetch("https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz")))
#for x in dat.keys():
#  print(x, dat[x].shape, dat[x].dtype)

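# the .npz follows the flax ViT layout: 'embedding/kernel' is the patch-embedding conv
# stored as HWIO, hence the transpose to OIHW below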
m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
m.conv_bias.assign(dat['embedding/bias'])

m.norm[0].assign(dat['Transformer/encoder_norm/scale'])
m.norm[1].assign(dat['Transformer/encoder_norm/bias'])

m.head[0].assign(dat['head/kernel'])
m.head[1].assign(dat['head/bias'])

m.cls_token.assign(dat['cls'])
m.pos_embed.assign(dat['Transformer/posembed_input/pos_embedding'])

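# per-block weights: the attention kernels/biases are flattened to (192, 192) / (192,) so the
# per-head layout in the checkpoint matches the single dense matrices used above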
for i in range(12):
  m.tbs[i].query_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/kernel'].reshape(192, 192))
  m.tbs[i].query_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/query/bias'].reshape(192))
  m.tbs[i].key_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/kernel'].reshape(192, 192))
  m.tbs[i].key_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/key/bias'].reshape(192))
  m.tbs[i].value_dense[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/kernel'].reshape(192, 192))
  m.tbs[i].value_dense[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/value/bias'].reshape(192))
  m.tbs[i].final[0].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/kernel'].reshape(192, 192))
  m.tbs[i].final[1].assign(dat[f'Transformer/encoderblock_{i}/MultiHeadDotProductAttention_1/out/bias'].reshape(192))
  m.tbs[i].ff1[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/kernel'])
  m.tbs[i].ff1[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_0/bias'])
  m.tbs[i].ff2[0].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/kernel'])
  m.tbs[i].ff2[1].assign(dat[f'Transformer/encoderblock_{i}/MlpBlock_3/Dense_1/bias'])
  m.tbs[i].ln1[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/scale'])
  m.tbs[i].ln1[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_0/bias'])
  m.tbs[i].ln2[0].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/scale'])
  m.tbs[i].ln2[1].assign(dat[f'Transformer/encoderblock_{i}/LayerNorm_2/bias'])

# category labels
import ast
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))

#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"

# junk: ad-hoc preprocessing (resize, center-crop to 224x224, normalize to [-1, 1])
from PIL import Image
img = Image.open(io.BytesIO(fetch(url)))
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = np.array(img)
y0,x0=(np.asarray(img.shape)[:2]-224)//2
img = img[y0:y0+224, x0:x0+224]
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
img /= 255.0
img -= 0.5
img /= 0.5
#img[:] = 0

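# the block below (kept as a docstring) was a layer-by-layer numerical check against the
# timm reference implementation vit_tiny_patch16_224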
"""
|
|
import torch
|
|
from timm.models.vision_transformer import vit_tiny_patch16_224
|
|
mdl = vit_tiny_patch16_224(pretrained=True)
|
|
#out = mdl(torch.Tensor(img))
|
|
#choice = out.argmax(axis=1).item()
|
|
#print(out[0, choice], lbls[choice])
|
|
|
|
pe = m.patch_embed(Tensor(img))
|
|
x = m.cls_token.cat(pe, dim=1) + m.pos_embed
|
|
x = m.tbs[0](x)
|
|
#x = layernorm(x, 192).linear(m.tbs[0].ln1)
|
|
|
|
xp = mdl.patch_embed(torch.Tensor(img))
|
|
xp = torch.cat((mdl.cls_token, xp), dim=1) + mdl.pos_embed
|
|
xp = mdl.blocks[0](xp)
|
|
#xp = mdl.blocks[0].norm1(xp)
|
|
|
|
print(x.shape, xp.shape)
|
|
print(np.max(x.data), np.max(xp.detach().numpy()))
|
|
print(np.max(np.abs(x.data - xp.detach().numpy())))
|
|
|
|
exit(0)
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
#import matplotlib.pyplot as plt
#plt.imshow(np.transpose(img[0], (1,2,0)))
#plt.show()

out = m.forward(Tensor(img))
outnp = out.cpu().data.ravel()
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])

#lookup = dict([x.split(" ") for x in open("cache/classids.txt").read().strip().split("\n")])
#cls = open("cache/imagenet21k_wordnet_ids.txt").read().strip().split("\n")
#print(cls[choice], lookup[cls[choice]])