diff --git a/examples/efficientnet.py b/examples/efficientnet.py index 7e8afbb5b3..dd1d015ab4 100644 --- a/examples/efficientnet.py +++ b/examples/efficientnet.py @@ -3,7 +3,6 @@ # a rough copy of # https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py import os -GPU = os.getenv("GPU", None) is not None import sys import io import time @@ -37,10 +36,7 @@ def infer(model, img): img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1)) # run the net - if GPU: - out = model.forward(Tensor(img).gpu()).cpu() - else: - out = model.forward(Tensor(img)) + out = model.forward(Tensor(img)).cpu() # if you want to look at the outputs """ @@ -54,8 +50,6 @@ if __name__ == "__main__": # instantiate my net model = EfficientNet(int(os.getenv("NUM", "0"))) model.load_weights_from_torch() - if GPU: - [x.gpu_() for x in get_parameters(model)] # category labels import ast diff --git a/extra/efficientnet.py b/extra/efficientnet.py index 7af1243f09..9e60ca1944 100644 --- a/extra/efficientnet.py +++ b/extra/efficientnet.py @@ -159,10 +159,11 @@ class EfficientNet: mv = eval(mk.replace(".weight", "")) except AttributeError: mv = eval(mk.replace(".bias", "_bias")) - vnp = v.numpy().astype(np.float32) if USE_TORCH else v + vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32) vnp = vnp if k != '_fc.weight' else vnp.T + vnp = vnp if vnp.shape != () else np.array([vnp]) - if mv.shape == vnp.shape or vnp.shape == (): - mv.data[:] = vnp + if mv.shape == vnp.shape: + mv.assign(Tensor(vnp)) else: print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))