mirror of https://github.com/tinygrad/tinygrad.git
add SGD to beautiful_mnist (#13571)
@@ -21,7 +21,7 @@ if __name__ == "__main__":
   X_train, Y_train, X_test, Y_test = mnist(fashion=getenv("FASHION"))
 
   model = Model()
-  opt = (nn.optim.Adam if not getenv("MUON") else nn.optim.Muon)(nn.state.get_parameters(model))
+  opt = (nn.optim.Muon if getenv("MUON") else nn.optim.SGD if getenv("SGD") else nn.optim.Adam)(nn.state.get_parameters(model))
 
   @TinyJit
   @Tensor.train()
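Usage note (not part of the diff): the rewritten expression picks the optimizer from environment variables, with MUON taking precedence over SGD and Adam remaining the default, because "a if c1 else b if c2 else c" groups as "a if c1 else (b if c2 else c)". A minimal standalone sketch of the same selection chain, assuming the usual tinygrad helpers (the function name pick_opt is made up):

from tinygrad import nn
from tinygrad.helpers import getenv

def pick_opt():
  # MUON=1 -> Muon, else SGD=1 -> SGD, else Adam (the previous default)
  return nn.optim.Muon if getenv("MUON") else nn.optim.SGD if getenv("SGD") else nn.optim.Adam

So running the example as, e.g., SGD=1 python examples/beautiful_mnist.py trains with plain SGD, MUON=1 selects Muon, and with neither flag set the behavior is unchanged (Adam).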
@@ -27,7 +27,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
   # test lowering all the ScheduleItems to ExecItems
   kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
   if kernel_cnt != allowed:
-    print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
+    print(f"SCHEDULE ISSUE, expecting {allowed} got {kernel_cnt}")
     if DEBUG >= 3:
       for i,s in enumerate(sched):
         print("kernel", i+1)
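A note on the test fix above: the assertion compares kernel_cnt, which counts only schedule items that lower to a CompiledRunner, while the old message printed len(sched), so the two numbers can disagree when the schedule contains non-kernel items. A tiny hedged sketch of that counting distinction, using stand-in classes rather than tinygrad's real lower_schedule output:

class CompiledRunner: pass   # stand-in for a lowered compute kernel
class BufferCopy: pass       # stand-in for a non-kernel exec item

items = [("si0", CompiledRunner()), ("si1", BufferCopy()), ("si2", CompiledRunner())]
kernel_cnt = len([si for si, prg in items if isinstance(prg, CompiledRunner)])
assert kernel_cnt == 2 and len(items) == 3   # the printed count now matches what is asserted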
@@ -9,7 +9,7 @@ class Optimizer:
   Base class for all optimizers.
   """
   def __init__(self, params: list[Tensor], lr: float, fused=FUSE_OPTIM):
-    # if it's None, but being put into an optimizer, set it to True
+    # if requires_grad is None, but being put into an optimizer, set it to True
     for x in params:
       if x.requires_grad is None: x.requires_grad = True
 
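A small illustration of the behavior the reworded comment describes (variable names here are made up, not from the repo): a tensor created without an explicit requires_grad is flipped to trainable the moment it is handed to an optimizer.

from tinygrad import Tensor, nn

w = Tensor.zeros(4, 4)            # requires_grad defaults to None
opt = nn.optim.SGD([w], lr=0.01)  # Optimizer.__init__ sets None -> True (see the hunk above)
print(w.requires_grad)            # True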
@@ -51,6 +51,7 @@ class Optimizer:
             - help: Consider setting Tensor.training=True before calling Optimizer.step().""")
     if self.fused:
       # optimizer fusion just concatenates all the buffers, runs the _step, then splits them back up
+      # NOTE: contiguous is for speed
       out, extra = self._step([Tensor.cat(*[t.flatten() for t in self.params], dim=0)],
                               [Tensor.cat(*[unwrap(t.grad).contiguous().flatten() for t in self.params], dim=0)])
       updated_params = [out[0][self.pos_params[i]:self.pos_params[i+1]].reshape(tt.shape) for i, tt in enumerate(self.params)]
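As a reading aid for the fused path above, here is a minimal sketch of the concatenate/step/split bookkeeping, assuming only public Tensor methods (params, offsets, flat, and updated are illustrative names; offsets plays the role of self.pos_params, and the "update" is a fake constant step just to show the shape handling):

from itertools import accumulate
from tinygrad import Tensor

params = [Tensor.ones(3, 3), Tensor.ones(5)]

# cumulative offsets into the flattened buffer, like Optimizer.pos_params
offsets = [0, *accumulate(t.numel() for t in params)]   # [0, 9, 14]

# fusion: one flat buffer holding every parameter
flat = Tensor.cat(*[t.flatten() for t in params], dim=0)

# one fused update on the whole buffer (illustrative constant "gradient")
flat = flat - 0.1 * Tensor.ones(flat.shape[0])

# split the result back into the original parameter shapes
updated = [flat[offsets[i]:offsets[i+1]].reshape(t.shape) for i, t in enumerate(params)]
assert [u.shape for u in updated] == [t.shape for t in params]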