mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix GPU efficientnet example
This commit is contained in:
@@ -3,7 +3,6 @@
|
|||||||
# a rough copy of
|
# a rough copy of
|
||||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||||
import os
|
import os
|
||||||
GPU = os.getenv("GPU", None) is not None
|
|
||||||
import sys
|
import sys
|
||||||
import io
|
import io
|
||||||
import time
|
import time
|
||||||
@@ -37,10 +36,7 @@ def infer(model, img):
|
|||||||
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
|
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
|
||||||
|
|
||||||
# run the net
|
# run the net
|
||||||
if GPU:
|
out = model.forward(Tensor(img)).cpu()
|
||||||
out = model.forward(Tensor(img).gpu()).cpu()
|
|
||||||
else:
|
|
||||||
out = model.forward(Tensor(img))
|
|
||||||
|
|
||||||
# if you want to look at the outputs
|
# if you want to look at the outputs
|
||||||
"""
|
"""
|
||||||
@@ -54,8 +50,6 @@ if __name__ == "__main__":
|
|||||||
# instantiate my net
|
# instantiate my net
|
||||||
model = EfficientNet(int(os.getenv("NUM", "0")))
|
model = EfficientNet(int(os.getenv("NUM", "0")))
|
||||||
model.load_weights_from_torch()
|
model.load_weights_from_torch()
|
||||||
if GPU:
|
|
||||||
[x.gpu_() for x in get_parameters(model)]
|
|
||||||
|
|
||||||
# category labels
|
# category labels
|
||||||
import ast
|
import ast
|
||||||
|
|||||||
@@ -159,10 +159,11 @@ class EfficientNet:
|
|||||||
mv = eval(mk.replace(".weight", ""))
|
mv = eval(mk.replace(".weight", ""))
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
mv = eval(mk.replace(".bias", "_bias"))
|
mv = eval(mk.replace(".bias", "_bias"))
|
||||||
vnp = v.numpy().astype(np.float32) if USE_TORCH else v
|
vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32)
|
||||||
vnp = vnp if k != '_fc.weight' else vnp.T
|
vnp = vnp if k != '_fc.weight' else vnp.T
|
||||||
|
vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||||
|
|
||||||
if mv.shape == vnp.shape or vnp.shape == ():
|
if mv.shape == vnp.shape:
|
||||||
mv.data[:] = vnp
|
mv.assign(Tensor(vnp))
|
||||||
else:
|
else:
|
||||||
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
||||||
|
|||||||
Reference in New Issue
Block a user