mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
with Tensor.train() (#1935)
* add with.train * remove the rest TODOs * fix pyflake * fix pyflake error * fix mypy
This commit is contained in:
@@ -5,15 +5,15 @@ from tinygrad.nn import Conv2d, BatchNorm2d
|
||||
from tinygrad.nn.state import get_parameters
|
||||
|
||||
if __name__ == "__main__":
|
||||
Tensor.training = True
|
||||
with Tensor.train():
|
||||
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
BS, C1, H, W = 4, 16, 224, 224
|
||||
C2, K, S, P = 64, 7, 2, 1
|
||||
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
bn = BatchNorm2d(C2, track_running_stats=False)
|
||||
for t in get_parameters([x, conv, bn]): t.realize()
|
||||
x = Tensor.uniform(BS, C1, H, W)
|
||||
conv = Conv2d(C1, C2, kernel_size=K, stride=S, padding=P)
|
||||
bn = BatchNorm2d(C2, track_running_stats=False)
|
||||
for t in get_parameters([x, conv, bn]): t.realize()
|
||||
|
||||
print("running network")
|
||||
x.sequential([conv, bn]).numpy()
|
||||
print("running network")
|
||||
x.sequential([conv, bn]).numpy()
|
||||
|
||||
Reference in New Issue
Block a user