reduce collapse generic (#10045)

* reduce collapse generic

* new arange folder

* new range folding

* correct with sym

* all tests pass

* indexing ops passes

* failing tests

* fix tests, remove unused

* revert that

* torch indexing is fast

* skip on webgpu

* touchups

* comments
This commit is contained in:
George Hotz
2025-04-26 09:13:24 -04:00
committed by GitHub
parent 5cdc96409e
commit ea5dddc537
7 changed files with 105 additions and 58 deletions

View File

@@ -54,7 +54,7 @@ if __name__ == "__main__":
return loss
test_acc = float('nan')
for i in (t:=trange(70)):
for i in (t:=trange(getenv("STEPS", 70))):
samples = torch.randint(0, X_train.shape[0], (512,)) # putting this in JIT didn't work well
loss = step(samples)
if i%10 == 9: test_acc = ((model(X_test).argmax(axis=-1) == Y_test).sum() * 100 / X_test.shape[0]).item()