fix GPU efficientnet example

This commit is contained in:
George Hotz
2021-05-26 17:29:35 -07:00
parent 1ae0e88627
commit b80cacb416
2 changed files with 5 additions and 10 deletions

View File

@@ -3,7 +3,6 @@
# a rough copy of # a rough copy of
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
import os import os
GPU = os.getenv("GPU", None) is not None
import sys import sys
import io import io
import time import time
@@ -37,10 +36,7 @@ def infer(model, img):
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1)) img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
# run the net # run the net
if GPU: out = model.forward(Tensor(img)).cpu()
out = model.forward(Tensor(img).gpu()).cpu()
else:
out = model.forward(Tensor(img))
# if you want to look at the outputs # if you want to look at the outputs
""" """
@@ -54,8 +50,6 @@ if __name__ == "__main__":
# instantiate my net # instantiate my net
model = EfficientNet(int(os.getenv("NUM", "0"))) model = EfficientNet(int(os.getenv("NUM", "0")))
model.load_weights_from_torch() model.load_weights_from_torch()
if GPU:
[x.gpu_() for x in get_parameters(model)]
# category labels # category labels
import ast import ast

View File

@@ -159,10 +159,11 @@ class EfficientNet:
mv = eval(mk.replace(".weight", "")) mv = eval(mk.replace(".weight", ""))
except AttributeError: except AttributeError:
mv = eval(mk.replace(".bias", "_bias")) mv = eval(mk.replace(".bias", "_bias"))
vnp = v.numpy().astype(np.float32) if USE_TORCH else v vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32)
vnp = vnp if k != '_fc.weight' else vnp.T vnp = vnp if k != '_fc.weight' else vnp.T
vnp = vnp if vnp.shape != () else np.array([vnp])
if mv.shape == vnp.shape or vnp.shape == (): if mv.shape == vnp.shape:
mv.data[:] = vnp mv.assign(Tensor(vnp))
else: else:
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape)) print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))