realize sched

This commit is contained in:
George Hotz
2025-10-14 14:44:56 +08:00
parent fe683bafa6
commit 30ff87eab4

View File

@@ -19,10 +19,9 @@ class TestOuterworldTrain(unittest.TestCase):
steps = UOp.range(16, -1)
opt.zero_grad()
loss = (layer(X[steps]) - Y[steps]).square().mean().backward()
opt.schedule_step() # TODO: does this need to know anything about steps?
sched = opt.schedule_step() # TODO: does this need to know anything about steps?
# NOTE: this can't work. the inputs to layer are not the assign, need to run twice for the fixed point?
all_losses = loss.reshape(1).expand(steps).contiguous()
all_losses.realize()
all_losses = Tensor.realize(loss.reshape(1).expand(steps).contiguous(), *sched)
print(all_losses.numpy())
#@unittest.skip("TODO: understand assign")