it thinks it's washer. it's cat. bad net. you do bad.

This commit is contained in:
George Hotz
2020-10-30 08:28:05 -07:00
parent c14473f87d
commit 71aedc2309

View File

@@ -116,18 +116,24 @@ if __name__ == "__main__":
model = EfficientNet()
model.load_weights_from_torch()
# load cat image
# load cat image and preprocess
from PIL import Image
img = Image.open(io.BytesIO(fetch("https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg")))
img = img.resize((224, 224))
img = np.moveaxis(np.array(img), [2,0,1], [0,1,2])
img = img.astype(np.float32).reshape(1,3,224,224)
print(img.shape, img.dtype)
img /= 256
img -= np.array([0.485, 0.456, 0.406]).reshape((1,-1,1,1))
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
# category labels
lbls = fetch("https://gist.githubusercontent.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57/raw/aa66dd9dbf6b56649fa3fab83659b2acbf3cbfd1/map_clsloc.txt")
lbls = dict([(int(x.split(" ")[1]), x.split(" ")[2]) for x in lbls.decode('utf-8').split("\n")])
# run the net
import time
st = time.time()
out = model.forward(Tensor(img))
print("did inference in %.2f s" % (time.time()-st))
print(np.argmax(out.data), np.max(out.data))
print(np.argmax(out.data), np.max(out.data), lbls[np.argmax(out.data)])