feat: example and extra tweaks (#6310)

This commit is contained in:
wozeparrot
2024-08-28 19:26:11 -07:00
committed by GitHub
parent ea5b7910b7
commit cb61cfce24
2 changed files with 11 additions and 6 deletions

View File

@@ -33,7 +33,7 @@ if __name__ == "__main__":
def train_step() -> Tensor:
with Tensor.train():
opt.zero_grad()
samples = Tensor.randint(512, high=X_train.shape[0])
samples = Tensor.randint(getenv("BS", 512), high=X_train.shape[0])
Xt, Yt = X_train[samples].shard_(GPUS, axis=0), Y_train[samples].shard_(GPUS, axis=0) # we shard the data on axis 0
# TODO: this "gather" of samples is very slow. will be under 5s when this is fixed
loss = model(Xt).sparse_categorical_crossentropy(Yt).backward()
@@ -44,7 +44,7 @@ if __name__ == "__main__":
def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100
test_acc = float('nan')
for i in (t:=trange(70)):
for i in (t:=trange(getenv("STEPS", 70))):
GlobalCounters.reset() # NOTE: this makes it nice for DEBUG=2 timing
loss = train_step()
if i%10 == 9: test_acc = get_test_acc().item()
@@ -53,4 +53,4 @@ if __name__ == "__main__":
# verify eval acc
if target := getenv("TARGET_EVAL_ACC_PCT", 0.0):
if test_acc >= target: print(colored(f"{test_acc=} >= {target}", "green"))
else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))
else: raise ValueError(colored(f"{test_acc=} < {target}", "red"))

View File

@@ -2,9 +2,14 @@ import os
if "DEBUG" not in os.environ: os.environ["DEBUG"] = "2"
if "THREEFRY" not in os.environ: os.environ["THREEFRY"] = "1"
from tinygrad import Tensor, GlobalCounters
from tinygrad import Tensor, GlobalCounters, Device
from tinygrad.helpers import getenv
GPUS = getenv("SHARD", 1)
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(GPUS))
for N in [10_000_000, 100_000_000, 1_000_000_000]:
GlobalCounters.reset()
Tensor.rand(N).realize()
print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}")
t = Tensor.rand(N) if GPUS <= 1 else Tensor.rand(N, device=devices)
t.realize()
print(f"N {N:>20_}, global_ops {GlobalCounters.global_ops:>20_}, global_mem {GlobalCounters.global_mem:>20_}")