# Files
# tinygrad/examples/vit.py
# 2021-11-29 18:54:57 -05:00
#
# 189 lines
# 7.5 KiB
# Python
#
import numpy as np
# Disabled helper, kept as a string literal so it never executes: reads the
# checkpoint `fn` from the gs://vit_models bucket with tensorflow's GFile and
# writes it to cache/<basename of fn>.
# NOTE: `fn` names the i21k pre-training checkpoint, not the imagenet2012
# fine-tuned .npz that the script actually fetches further down.
"""
fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
import tensorflow as tf
with tf.io.gfile.GFile(fn, "rb") as f:
dat = f.read()
with open("cache/"+ fn.rsplit("/", 1)[1], "wb") as g:
g.write(dat)
"""
import io
from extra.utils import fetch
from tinygrad.tensor import Tensor
class ViTBlock:
  """Pre-norm transformer encoder block used by ViT.

  __call__ computes
    x = x + dropout(attn(LN(x)))
    x = x + dropout(ff2(gelu(ff1(LN(x)))))
  Every projection / norm parameter is stored as a (weight, bias) tuple of
  Tensors and passed unchanged to Tensor.linear. The initial values are simple
  placeholders; pretrained values are written into them later with .assign.

  Args:
    embed_dim: model width; must be divisible by num_heads.
    num_heads: number of attention heads.
    ff_dim: hidden width of the MLP.
    dropout: rate passed to Tensor.dropout on both residual branches
      (default 0.1, the value that used to be hard-coded).

  Raises:
    ValueError: if embed_dim is not a multiple of num_heads.
  """
  def __init__(self, embed_dim, num_heads, ff_dim, dropout=0.1):
    # Multi-Head Attention
    # explicit check instead of `assert`, which is stripped under `python -O`
    if embed_dim % num_heads != 0:
      raise ValueError(f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})")
    self.num_heads = num_heads
    self.head_size = embed_dim // num_heads
    self.dropout_rate = dropout

    # (weight, bias) pairs; projections carry a bias term
    self.query_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
    self.key_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
    self.value_dense = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))
    self.final = (Tensor.uniform(embed_dim, embed_dim), Tensor.zeros(embed_dim))

    self.ff1 = (Tensor.uniform(embed_dim, ff_dim), Tensor.zeros(ff_dim))
    self.ff2 = (Tensor.uniform(ff_dim, embed_dim), Tensor.zeros(embed_dim))

    # LayerNorm (scale, bias)
    self.ln1 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))
    self.ln2 = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))

  def attn(self, x):
    """Multi-head self-attention; x is assumed (bs, T, embed_dim), output has the same shape."""
    bs = x.shape[0]
    embed_dim = self.num_heads * self.head_size

    # project, then split the last axis into heads: (bs, T, num_heads, head_size)
    query, key, value = [x.linear(y).reshape(shape=(bs, -1, self.num_heads, self.head_size))
                         for y in [self.query_dense, self.key_dense, self.value_dense]]
    query = query.transpose(order=(0,2,1,3))  # (bs, num_heads, T, head_size)
    key = key.transpose(order=(0,2,3,1))      # (bs, num_heads, head_size, T)
    value = value.transpose(order=(0,2,1,3))  # (bs, num_heads, T, head_size)

    # scaled dot-product attention
    score = query.dot(key) * (1 / np.sqrt(self.head_size))
    weights = score.softmax()  # (bs, num_heads, T, T)
    attention = weights.dot(value).transpose(order=(0,2,1,3))  # (bs, T, num_heads, head_size)

    # merge the heads back together and apply the output projection
    return attention.reshape(shape=(bs, -1, embed_dim)).linear(self.final)

  def __call__(self, x):
    """Apply the attention branch, then the MLP branch, each as a pre-norm residual."""
    x = x + self.attn(x.layernorm().linear(self.ln1)).dropout(self.dropout_rate)
    x = x + x.layernorm().linear(self.ln2).linear(self.ff1).gelu().linear(self.ff2).dropout(self.dropout_rate)
    return x
class ViT:
  """Vision Transformer image classifier.

  The defaults (embed_dim=192, 3 heads, ff_dim=768, 12 blocks, 16x16 patches
  on 224x224 RGB inputs, 1000 classes) are the ViT-Ti/16 configuration that
  the weight-loading code in this file expects. Parameters start as
  placeholders and are overwritten with pretrained values via .assign.

  Args:
    embed_dim: token width.
    num_heads: attention heads per block.
    ff_dim: MLP hidden width per block.
    layers: number of ViTBlocks.
    num_classes: output logits of the classification head.
    patch_size: side length of the square, non-overlapping patches.
    img_size: input image side length; fixes the position-embedding length.
  """
  def __init__(self, embed_dim=192, num_heads=3, ff_dim=768, layers=12,
               num_classes=1000, patch_size=16, img_size=224):
    self.patch_size = patch_size
    self.conv_weight = Tensor.uniform(embed_dim, 3, patch_size, patch_size)
    self.conv_bias = Tensor.zeros(embed_dim)
    self.cls_token = Tensor.ones(1, 1, embed_dim)
    self.tbs = [ViTBlock(embed_dim=embed_dim, num_heads=num_heads, ff_dim=ff_dim) for _ in range(layers)]
    # one position per patch plus one for the class token ((224//16)**2 + 1 = 197 by default)
    self.pos_embed = Tensor.ones(1, (img_size // patch_size) ** 2 + 1, embed_dim)
    self.head = (Tensor.uniform(embed_dim, num_classes), Tensor.zeros(num_classes))
    # final LayerNorm (scale, bias); scale starts at 1 like the per-block ln1/ln2
    self.norm = (Tensor.ones(embed_dim), Tensor.zeros(embed_dim))

  def patch_embed(self, x):
    """Embed non-overlapping patches of x (assumed (bs, 3, H, W)).

    Returns a (bs, num_patches, embed_dim) Tensor.
    """
    # a conv whose stride equals its kernel size projects each patch independently
    x = x.conv2d(self.conv_weight, stride=self.patch_size)
    x = x.add(self.conv_bias.reshape(shape=(1,-1,1,1)))
    # (bs, embed_dim, H/p, W/p) -> (bs, embed_dim, num_patches) -> (bs, num_patches, embed_dim)
    x = x.reshape(shape=(x.shape[0], x.shape[1], -1)).transpose(order=(0,2,1))
    return x

  def forward(self, x):
    """Return (bs, num_classes) logits for a batch of images x."""
    pe = self.patch_embed(x)
    bs = pe.shape[0]
    # cls_token is (1, 1, D): broadcast it to (bs, 1, D) so batches > 1 can be
    # concatenated with the patch tokens; bs == 1 keeps the original path
    ce = self.cls_token if bs == 1 else self.cls_token.add(Tensor.zeros(bs, 1, 1))
    x = ce.cat(pe, dim=1) + self.pos_embed
    for block in self.tbs:
      x = block(x)
    x = x.layernorm().linear(self.norm)
    # classify from the class-token position
    return x[:, 0].linear(self.head)
Tensor.training = False
m = ViT()
# take the width from the model so the reshapes below follow the ViT config
# instead of a hard-coded 192
embed_dim = m.conv_bias.shape[0]

# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
dat = np.load(io.BytesIO(fetch("https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz")))
#for x in dat.keys():
#  print(x, dat[x].shape, dat[x].dtype)

# the checkpoint kernel keeps the output-channel axis last; reorder it to
# (out, in, kh, kw) to match conv_weight
m.conv_weight.assign(np.transpose(dat['embedding/kernel'], (3,2,0,1)))
m.conv_bias.assign(dat['embedding/bias'])
m.norm[0].assign(dat['Transformer/encoder_norm/scale'])
m.norm[1].assign(dat['Transformer/encoder_norm/bias'])
m.head[0].assign(dat['head/kernel'])
m.head[1].assign(dat['head/bias'])
m.cls_token.assign(dat['cls'])
m.pos_embed.assign(dat['Transformer/posembed_input/pos_embedding'])

# walk the model's blocks instead of a hard-coded 12; a missing checkpoint
# entry now fails loudly with KeyError
for i, tb in enumerate(m.tbs):
  prefix = f'Transformer/encoderblock_{i}'
  mha = f'{prefix}/MultiHeadDotProductAttention_1'
  # attention params are stored in another (presumably per-head) shape;
  # flatten to (embed_dim, embed_dim) kernels and (embed_dim,) biases
  for layer, name in ((tb.query_dense, 'query'), (tb.key_dense, 'key'),
                      (tb.value_dense, 'value'), (tb.final, 'out')):
    layer[0].assign(dat[f'{mha}/{name}/kernel'].reshape(embed_dim, embed_dim))
    layer[1].assign(dat[f'{mha}/{name}/bias'].reshape(embed_dim))
  for layer, name in ((tb.ff1, 'MlpBlock_3/Dense_0'), (tb.ff2, 'MlpBlock_3/Dense_1')):
    layer[0].assign(dat[f'{prefix}/{name}/kernel'])
    layer[1].assign(dat[f'{prefix}/{name}/bias'])
  for layer, name in ((tb.ln1, 'LayerNorm_0'), (tb.ln2, 'LayerNorm_2')):
    layer[0].assign(dat[f'{prefix}/{name}/scale'])
    layer[1].assign(dat[f'{prefix}/{name}/bias'])
# category labels: the gist is a Python dict literal, indexed below by the argmax class id
import ast
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))

#url = "https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg"
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"

# preprocess: scale the short side to 224, center-crop 224x224,
# HWC -> NCHW, then map [0, 255] to [-1, 1]
from PIL import Image
# convert('RGB') so grayscale / palette / CMYK inputs also yield exactly three
# RGB channels (a 2-D grayscale array would break the moveaxis below); for RGBA
# it simply drops alpha, as the [:3] slice already did
img = Image.open(io.BytesIO(fetch(url))).convert('RGB')
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = np.array(img)
y0,x0=(np.asarray(img.shape)[:2]-224)//2
img = img[y0:y0+224, x0:x0+224]
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32)[:3].reshape(1,3,224,224)
img /= 255.0
img -= 0.5
img /= 0.5
#img[:] = 0
# Disabled debug snippet, kept as a string literal so it never executes.
# It builds timm's vit_tiny_patch16_224 with pretrained=True, pushes the same
# image through the patch embedding + first encoder block of both that model
# and this tinygrad one, prints the shapes, maxima and max absolute difference,
# then calls exit(0).
"""
import torch
from timm.models.vision_transformer import vit_tiny_patch16_224
mdl = vit_tiny_patch16_224(pretrained=True)
#out = mdl(torch.Tensor(img))
#choice = out.argmax(axis=1).item()
#print(out[0, choice], lbls[choice])
pe = m.patch_embed(Tensor(img))
x = m.cls_token.cat(pe, dim=1) + m.pos_embed
x = m.tbs[0](x)
#x = layernorm(x, 192).linear(m.tbs[0].ln1)
xp = mdl.patch_embed(torch.Tensor(img))
xp = torch.cat((mdl.cls_token, xp), dim=1) + mdl.pos_embed
xp = mdl.blocks[0](xp)
#xp = mdl.blocks[0].norm1(xp)
print(x.shape, xp.shape)
print(np.max(x.data), np.max(xp.detach().numpy()))
print(np.max(np.abs(x.data - xp.detach().numpy())))
exit(0)
"""
# optional: visualize the preprocessed crop (CHW -> HWC)
#import matplotlib.pyplot as plt
#plt.imshow(np.transpose(img[0], (1,2,0)))
#plt.show()

# run the classifier and report the top-1 class id, its logit, and its label
out = m.forward(Tensor(img))
logits = out.cpu().data.ravel()
choice = np.argmax(logits)
top_score = logits[choice]
print(out.shape, choice, top_score, lbls[choice])

# optional: map to ImageNet-21k wordnet ids instead
#lookup = dict([x.split(" ") for x in open("cache/classids.txt").read().strip().split("\n")])
#cls = open("cache/imagenet21k_wordnet_ids.txt").read().strip().split("\n")
#print(cls[choice], lookup[cls[choice]])