mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 06:58:11 -05:00
feat: example and extra tweaks (#6310)
This commit is contained in:
@@ -33,7 +33,7 @@ if __name__ == "__main__":
|
||||
def train_step() -> Tensor:
|
||||
with Tensor.train():
|
||||
opt.zero_grad()
|
||||
samples = Tensor.randint(512, high=X_train.shape[0])
|
||||
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
|
||||
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0
|
||||
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
|
||||
loss = model(Xt).sparse_categorical_crossentropy(Yt).backward()
|
||||
@@ -44,7 +44,7 @@ if __name__ == "__main__":
|
||||
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
|
||||
|
||||
test_acc = float('nan')
|
||||
for i in (t:=trange(70)):
|
||||
for i in (t:=trange(getenv("STEPS", 70))):
|
||||
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
|
||||
loss = train_step()
|
||||
if i%10 == 9: test_acc = get_test_acc().item()
|
||||
@@ -53,4 +53,4 @@ if __name__ == "__main__":
|
||||
# verify eval acc
|
||||
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
|
||||
if test_acc >= target: print(colored(f"{test_acc=} >= {target}", "green"))
|
||||
else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
|
||||
else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
|
||||
|
||||
@@ -2,9 +2,14 @@ import os
|
||||
if "DEBUG" not in os.environ: os.environ["DEBUG"] = "2"
|
||||
if "THREEFRY" not in os.environ: os.environ["THREEFRY"] = "1"
|
||||
|
||||
from tinygrad import Tensor, GlobalCounters
|
||||
from tinygrad import Tensor, GlobalCounters, Device
|
||||
from tinygrad.helpers import getenv
|
||||
|
||||
GPUS = getenv("SHARD", 1)
|
||||
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS))
|
||||
|
||||
for N in [10_000_000, 100_000_000, 1_000_000_000]:
|
||||
GlobalCounters.reset()
|
||||
Tensor.rand(N).realize()
|
||||
print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}")
|
||||
t = Tensor.rand(N) if GPUS <= 1 else Tensor.rand(N, device=devices)
|
||||
t.realize()
|
||||
print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}")
|
||||
|
||||
Reference in New Issue
Block a user