mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
get parameters
This commit is contained in:
@@ -3,10 +3,13 @@ import time
|
||||
import numpy as np
|
||||
from extra.efficientnet import EfficientNet
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.utils import get_parameters
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.default_gpu = os.getenv("GPU") is not None
|
||||
model = EfficientNet(int(os.getenv("NUM", "0")))
|
||||
parameters = get_parameters(model)
|
||||
print(len(parameters))
|
||||
|
||||
BS = 16
|
||||
img = np.zeros((BS,3,224,224), dtype=np.float32)
|
||||
@@ -26,6 +29,10 @@ if __name__ == "__main__":
|
||||
y = Tensor(y)
|
||||
loss = out.logsoftmax().mul(y).mean()
|
||||
|
||||
# zero grad
|
||||
for p in parameters:
|
||||
p.grad = None
|
||||
|
||||
st = time.time()
|
||||
loss.backward()
|
||||
et = time.time()
|
||||
|
||||
Reference in New Issue
Block a user