mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
it thinks it's washer. it's cat. bad net. you do bad.
This commit is contained in:
@@ -116,18 +116,24 @@ if __name__ == "__main__":
|
||||
model = EfficientNet()
|
||||
model.load_weights_from_torch()
|
||||
|
||||
# load cat image
|
||||
# load cat image and preprocess
|
||||
from PIL import Image
|
||||
img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
|
||||
img = img.resize((224, 224))
|
||||
img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
|
||||
img = img.astype(np.float32).reshape(1,3,224,224)
|
||||
print(img.shape, img.dtype)
|
||||
img /= 256
|
||||
img -= np.array([0.485, 0.456, 0.406]).reshape((1,-1,1,1))
|
||||
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
|
||||
|
||||
# category labels
|
||||
lbls = fetch("https://gist.githubusercontent.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57/raw/aa66dd9dbf6b56649fa3fab83659b2acbf3cbfd1/map_clsloc.txt")
|
||||
lbls = dict([(int(x.split(" ")[1]), x.split(" ")[2]) for x in lbls.decode('utf-8').split("\n")])
|
||||
|
||||
# run the net
|
||||
import time
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(img))
|
||||
print("did inference in %.2f s" % (time.time()-st))
|
||||
print(np.argmax(out.data), np.max(out.data))
|
||||
print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])
|
||||
|
||||
|
||||
Reference in New Issue
Block a user