mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
* fixed pylint, formatted python files with cblack on localhost
* Revert "fixed pylint, formatted python files with cblack on localhost"
This reverts commit 07e2b88466.
* dedented 4-space indentation, added linter
Co-authored-by: Iain Wong <iainwong@outlook.com>
96 lines
2.9 KiB
Python
96 lines
2.9 KiB
Python
# load weights from
|
|
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
|
# a rough copy of
|
|
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
|
import io
import os
import sys
import time

import numpy as np

from tinygrad.tensor import Tensor
from extra.utils import fetch, get_parameters
from extra.efficientnet import EfficientNet

# Run on the GPU whenever the GPU environment variable is set (to any value,
# including the empty string).
GPU = os.getenv("GPU", None) is not None

# Print small floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)
|
|
|
|
def infer(model, img):
  """Run `model` on a single PIL image and return (output, cropped image).

  The image is converted to RGB and resized so its shorter side is 224 px
  (aspect ratio preserved). It is then center-cropped to 224x224, scaled to
  [0, 1], and normalized per channel with the mean/std constants below before
  being passed to `model.forward`.

  Args:
    model: object exposing `forward(Tensor)` (an EfficientNet in this script).
    img: a PIL image of any mode; it is converted to RGB first, so grayscale,
      palette and RGBA inputs are all accepted.

  Returns:
    (out, retimg): `out` is whatever `model.forward` returns (brought back with
    `.cpu()` when GPU is set); `retimg` is the 224x224x3 RGB numpy crop that
    was fed to the network, handy for display.
  """
  # Normalize the channel layout up front: a grayscale ("L") image would give a
  # 2-D array and break the channel moveaxis/reshape below, and LA/palette
  # images would give the wrong number of channels.
  img = img.convert("RGB")

  # resize so the shorter side becomes 224 px, keeping the aspect ratio
  aspect_ratio = img.size[0] / img.size[1]
  img = img.resize((int(224*max(aspect_ratio, 1.0)), int(224*max(1.0/aspect_ratio, 1.0))))

  # center-crop to 224x224 (the numpy array is indexed height, width, channel)
  img = np.array(img)
  y0, x0 = (np.asarray(img.shape)[:2] - 224) // 2
  retimg = img = img[y0:y0+224, x0:x0+224]

  # To look at the crop:
  #   import matplotlib.pyplot as plt
  #   plt.imshow(img)
  #   plt.show()

  # low level preprocess: HWC -> CHW, add a batch axis, scale to [0, 1], then
  # subtract the per-channel mean and divide by the per-channel std
  img = np.moveaxis(img, [2, 0, 1], [0, 1, 2])
  img = img.astype(np.float32).reshape(1, 3, 224, 224)
  img /= 255.0
  img -= np.array([0.485, 0.456, 0.406]).reshape((1, -1, 1, 1))
  img /= np.array([0.229, 0.224, 0.225]).reshape((1, -1, 1, 1))

  # run the net; on GPU the input is moved over and the result moved back
  if GPU:
    out = model.forward(Tensor(img).cuda()).cpu()
  else:
    out = model.forward(Tensor(img))

  # To look at the outputs:
  #   import matplotlib.pyplot as plt
  #   plt.plot(out.data[0])
  #   plt.show()

  return out, retimg
|
|
|
|
if __name__ == "__main__":
  import ast
  from PIL import Image

  # Validate the command line before downloading weights and labels.
  if len(sys.argv) < 2:
    sys.exit("usage: %s <image path | http(s) url | webcam>" % sys.argv[0])
  url = sys.argv[1]

  # instantiate my net (NUM selects the EfficientNet variant, default 0)
  model = EfficientNet(int(os.getenv("NUM", "0")))
  model.load_weights_from_torch()
  if GPU:
    for param in get_parameters(model):
      param.cuda_()

  # category labels: a Python dict literal mapping class index -> name
  lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
  lbls = ast.literal_eval(lbls.decode('utf-8'))

  # load image and preprocess
  if url == 'webcam':
    import cv2
    SCALE = 3  # display the 224x224 crop enlarged by this factor
    cap = cv2.VideoCapture(0)
    cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)
    try:
      while True:
        _ = cap.grab()  # discard one frame to circumvent capture buffering
        ret, frame = cap.read()
        if not ret:
          # no frame available (camera missing/busy); stop instead of crashing on None
          print("failed to read a frame from the webcam", file=sys.stderr)
          break
        # OpenCV delivers BGR; reorder the channels to RGB for PIL
        img = Image.fromarray(frame[:, :, [2, 1, 0]])
        out, retimg = infer(model, img)
        idx = np.argmax(out.data)
        print(idx, np.max(out.data), lbls[idx])
        simg = cv2.resize(retimg, (224*SCALE, 224*SCALE))
        retimg = cv2.cvtColor(simg, cv2.COLOR_RGB2BGR)
        cv2.imshow('capture', retimg)
        if cv2.waitKey(1) & 0xFF == ord('q'):
          break
    finally:
      # release the camera and close windows even on errors or Ctrl-C
      cap.release()
      cv2.destroyAllWindows()
  else:
    src = io.BytesIO(fetch(url)) if url.startswith('http') else url
    # the context manager closes the underlying file once inference is done
    with Image.open(src) as img:
      st = time.time()
      out, _ = infer(model, img)
    idx = np.argmax(out.data)
    print(idx, np.max(out.data), lbls[idx])
    print("did inference in %.2f s" % (time.time()-st))
|