ENET WORKS

This commit is contained in:
George Hotz
2020-10-31 10:42:58 -07:00
parent 68cba88e8f
commit 06928cf3cc

View File

@@ -55,7 +55,7 @@ class MBConvBlock:
x = self._bn2(x.conv2d(self._project_conv))
if x.shape == inputs.shape:
x = x.add(inputs)
return swish(x)
return x
class EfficientNet:
def __init__(self):
@@ -128,9 +128,11 @@ if __name__ == "__main__":
else:
url = "https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg"
img = Image.open(io.BytesIO(fetch(url)))
img = img.resize((398, 224))
aspect_ratio = img.size[0] / img.size[1]
img = img.resize((int(224*aspect_ratio), 224))
img = np.array(img)
img = img[:, 87:-87]
chapo = (img.shape[1]-224)//2
img = img[:, chapo:chapo+224]
img = np.moveaxis(img, [2,0,1], [0,1,2])
img = img.astype(np.float32).reshape(1,3,224,224)
img /= 256
@@ -145,8 +147,9 @@ if __name__ == "__main__":
"""
# category labels
lbls = fetch("https://gist.githubusercontent.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57/raw/aa66dd9dbf6b56649fa3fab83659b2acbf3cbfd1/map_clsloc.txt")
lbls = dict([(int(x.split(" ")[1]), x.split(" ")[2]) for x in lbls.decode('utf-8').split("\n")])
import ast
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
lbls = ast.literal_eval(lbls.decode('utf-8'))
# run the net
import time