fix GPU efficientnet example

This commit is contained in:
George Hotz
2021-05-26 17:29:35 -07:00
parent 1ae0e88627
commit b80cacb416
2 changed files with 5 additions and 10 deletions

View File

@@ -159,10 +159,11 @@ class EfficientNet:
mv = eval(mk.replace(".weight", ""))
except AttributeError:
mv = eval(mk.replace(".bias", "_bias"))
vnp = v.numpy().astype(np.float32) if USE_TORCH else v
vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32)
vnp = vnp if k != '_fc.weight' else vnp.T
vnp = vnp if vnp.shape != () else np.array([vnp])
if mv.shape == vnp.shape or vnp.shape == ():
mv.data[:] = vnp
if mv.shape == vnp.shape:
mv.assign(Tensor(vnp))
else:
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))