mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
fix GPU efficientnet example
This commit is contained in:
@@ -159,10 +159,11 @@ class EfficientNet:
|
||||
mv = eval(mk.replace(".weight", ""))
|
||||
except AttributeError:
|
||||
mv = eval(mk.replace(".bias", "_bias"))
|
||||
vnp = v.numpy().astype(np.float32) if USE_TORCH else v
|
||||
vnp = v.numpy().astype(np.float32) if USE_TORCH else v.astype(np.float32)
|
||||
vnp = vnp if k != '_fc.weight' else vnp.T
|
||||
vnp = vnp if vnp.shape != () else np.array([vnp])
|
||||
|
||||
if mv.shape == vnp.shape or vnp.shape == ():
|
||||
mv.data[:] = vnp
|
||||
if mv.shape == vnp.shape:
|
||||
mv.assign(Tensor(vnp))
|
||||
else:
|
||||
print("MISMATCH SHAPE IN %s, %r %r" % (k, mv.shape, vnp.shape))
|
||||
|
||||
Reference in New Issue
Block a user