mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix GPU efficientnet example
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
import os
|
||||
GPU = os.getenv("GPU", None) is not None
|
||||
import sys
|
||||
import io
|
||||
import time
|
||||
@@ -37,10 +36,7 @@ def infer(model, img):
|
||||
img /= np.array([0.229, 0.224, 0.225]).reshape((1,-1,1,1))
|
||||
|
||||
# run the net
|
||||
if GPU:
|
||||
out = model.forward(Tensor(img).gpu()).cpu()
|
||||
else:
|
||||
out = model.forward(Tensor(img))
|
||||
out = model.forward(Tensor(img)).cpu()
|
||||
|
||||
# if you want to look at the outputs
|
||||
"""
|
||||
@@ -54,8 +50,6 @@ if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
model = EfficientNet(int(os.getenv("NUM", "0")))
|
||||
model.load_weights_from_torch()
|
||||
if GPU:
|
||||
[x.gpu_() for x in get_parameters(model)]
|
||||
|
||||
# category labels
|
||||
import ast
|
||||
|
||||
Reference in New Issue
Block a user