Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-07 03:00:26 -04:00).
* Fix examples
* Remove training in parameters
* Simplify a bit
* Remove extra import
* Fix linter errors
* factor out Device
* NumPy-like semantics for Tensor.__getitem__ (#506)
* Rewrote Tensor.__getitem__ to fix negative indices and add support for np.newaxis/None
* Fixed pad2d
* mypy doesn't know about mlops methods
* normal python behavior for out-of-bounds slicing
* type: ignore
* inlined idxfix
* added comment for __getitem__
* Better comments, better tests, and fixed bug in np.newaxis
* update cpu and torch to hold buffers (#542)
* update cpu and torch to hold buffers
* save lines, and probably faster
* Mypy fun (#541)
* mypy fun
* things are just faster
* running fast
* mypy is fast
* compile.sh
* no gpu hack
* refactor ops_cpu and ops_torch to not subclass
* make weak buffer work
* tensor works
* fix test failing
* cpu/torch cleanups
* no or operator on dict in python 3.8
* that was junk
* fix warnings
* comment and touchup
* dyn add of math ops
* refactor ops_cpu and ops_torch to not share code
* nn/optim.py compiles now
* Reorder imports
* call mkdir only if directory doesn't exist

---------

Co-authored-by: George Hotz <geohot@gmail.com>
Co-authored-by: Mitchell Goff <mitchellgoffpc@gmail.com>
Co-authored-by: George Hotz <72895+geohot@users.noreply.github.com>
50 lines · 1.5 KiB · Python
import ast
|
|
import io
|
|
import numpy as np
|
|
from PIL import Image
|
|
from tinygrad.tensor import Tensor
|
|
from tinygrad.helpers import getenv
|
|
from models.vit import ViT
|
|
from extra.utils import fetch
|
|
# Reference only (never executed): copy a ViT .npz checkpoint from Google Cloud
# Storage into cache/. It needs TensorFlow for the gs:// filesystem.
# NOTE(review): presumably this is where the weights used by
# load_from_pretrained() came from -- confirm against models/vit.py.
#
#   import os
#   import tensorflow as tf
#   fn = "gs://vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz"
#   os.makedirs("cache", exist_ok=True)  # open() below fails if cache/ is missing
#   with tf.io.gfile.GFile(fn, "rb") as f:
#     dat = f.read()
#   with open("cache/" + fn.rsplit("/", 1)[1], "wb") as g:
#     g.write(dat)
# This script only runs a forward pass, so keep tinygrad out of training mode.
Tensor.training = False

# LARGE=1 selects the 768-dim / 12-head ViT; otherwise (the default) the tiny
# 192-dim / 3-head ViT is used.
m = ViT(embed_dim=768, num_heads=12) if getenv("LARGE", 0) == 1 else ViT(embed_dim=192, num_heads=3)
# Presumably loads the pretrained checkpoint for the chosen config (see models/vit.py).
m.load_from_pretrained()
# ImageNet class-index -> label mapping. The file is a Python dict literal, so
# parse it with ast.literal_eval (literals only, never eval()).
lbls = ast.literal_eval(
  fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt").decode('utf-8')
)
# Image to classify. An alternative test image:
#   https://upload.wikimedia.org/wikipedia/commons/4/41/Chicken.jpg
url = "https://repository-images.githubusercontent.com/296744635/39ba6700-082d-11eb-98b8-cb29fb7369c0"

# Force 3-channel RGB. Without this, grayscale ("L") and palette ("P") images
# decode to a 2-D array that np.moveaxis(img, [2,0,1], ...) below rejects, and
# CMYK images would silently feed C/M/Y in as R/G/B. This also makes the old
# [:3] alpha-channel drop unnecessary.
img = Image.open(io.BytesIO(fetch(url))).convert("RGB")

# Resize so the shorter side is exactly 224 px (aspect ratio kept), then
# center-crop a 224x224 window.
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*max(aspect_ratio,1.0)), int(224*max(1.0/aspect_ratio,1.0))))
img = np.array(img)
y0,x0=(np.asarray(img.shape)[:2]-224)//2
img = img[y0:y0+224, x0:x0+224]

# HWC uint8 -> NCHW float32 batch of one, mapped from [0, 255] to [-1, 1].
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32).reshape(1,3,224,224)
img = (img / 255.0 - 0.5) / 0.5

# Forward pass, then report the single highest-scoring ImageNet class.
out = m.forward(Tensor(img))
outnp = out.cpu().data.ravel()  # flatten the batch-of-one output to 1-D
choice = outnp.argmax()
print(out.shape, choice, outnp[choice], lbls[choice])