mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
fix strided convs, GPU env var for enet
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
|
||||
# a rough copy of
|
||||
# https://github.com/lukemelas/EfficientNet-PyTorch/blob/master/efficientnet_pytorch/model.py
|
||||
import os
|
||||
GPU = os.getenv("GPU", None) is not None
|
||||
import sys
|
||||
import io
|
||||
import numpy as np
|
||||
@@ -114,6 +116,8 @@ class EfficientNet:
|
||||
mv = eval(mk.replace(".bias", "_bias"))
|
||||
vnp = v.numpy().astype(np.float32)
|
||||
mv.data[:] = vnp if k != '_fc.weight' else vnp.T
|
||||
if GPU:
|
||||
mv.cuda_()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# instantiate my net
|
||||
@@ -154,7 +158,10 @@ if __name__ == "__main__":
|
||||
# run the net
|
||||
import time
|
||||
st = time.time()
|
||||
out = model.forward(Tensor(img))
|
||||
if GPU:
|
||||
out = model.forward(Tensor(img).cuda())
|
||||
else:
|
||||
out = model.forward(Tensor(img))
|
||||
|
||||
# if you want to look at the outputs
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user