mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
ENET WORKS
This commit is contained in:
@@ -55,7 +55,7 @@ class MBConvBlock:
|
||||
x = self._bn2(x.conv2d(self._project_conv))
|
||||
if x.shape == inputs.shape:
|
||||
x = x.add(inputs)
|
||||
return swish(x)
|
||||
return x
|
||||
|
||||
class EfficientNet:
|
||||
def __init__(self):
|
||||
@@ -128,9 +128,11 @@ if __name__ == "__main__":
|
||||
else:
|
||||
url = "https://c.files.bbci.co.uk/12A9B/production/_111434467_gettyimages-1143489763.jpg"
|
||||
img = Image.open(io.BytesIO(fetch(url)))
|
||||
img = img.resize((398, 224))
|
||||
aspect_ratio = img.size[0] / img.size[1]
|
||||
img = img.resize((int(224*aspect_ratio), 224))
|
||||
img = np.array(img)
|
||||
img = img[:, 87:-87]
|
||||
chapo = (img.shape[1]-224)//2
|
||||
img = img[:, chapo:chapo+224]
|
||||
img = np.moveaxis(img, [2,0,1], [0,1,2])
|
||||
img = img.astype(np.float32).reshape(1,3,224,224)
|
||||
img /= 256
|
||||
@@ -145,8 +147,9 @@ if __name__ == "__main__":
|
||||
"""
|
||||
|
||||
# category labels
|
||||
lbls = fetch("https://gist.githubusercontent.com/aaronpolhamus/964a4411c0906315deb9f4a3723aac57/raw/aa66dd9dbf6b56649fa3fab83659b2acbf3cbfd1/map_clsloc.txt")
|
||||
lbls = dict([(int(x.split(" ")[1]), x.split(" ")[2]) for x in lbls.decode('utf-8').split("\n")])
|
||||
import ast
|
||||
lbls = fetch("https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt")
|
||||
lbls = ast.literal_eval(lbls.decode('utf-8'))
|
||||
|
||||
# run the net
|
||||
import time
|
||||
|
||||
Reference in New Issue
Block a user