update llmc export (#4584)

* update example

* move train to optim

* rename

* b2
This commit is contained in:
qazal
2024-05-15 02:18:38 +08:00
committed by GitHub
parent 355e1c135c
commit 9aa5e02229

View File

@@ -6,7 +6,7 @@ Device.DEFAULT = "CLANG"
from train_gpt2 import GPT, GPTConfig
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen, to_function_name
from tinygrad.engine.schedule import create_schedule
from tinygrad.engine.realize import run_schedule
from tinygrad.engine.realize import get_linearizer, run_schedule
from tinygrad.engine.memory import memory_planner
from tinygrad.ops import BufferOps, LoadOps
@@ -24,6 +24,7 @@ if __name__ == "__main__":
#B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
B, T = 4, 64
Tensor.training = True
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
warmup_count = getenv("WARMUP", 3)
for i in range(warmup_count): # TODO: why does it take three and not two to stablize
@@ -46,7 +47,7 @@ if __name__ == "__main__":
ast_dedup = dedup([si.ast for si in sched if si.ast[0].op is BufferOps.STORE])
srcs = {}
for ast in ast_dedup:
k = Device["CLANG"].get_linearizer(*ast)
k = get_linearizer(Device["CLANG"].renderer, ast)
k.linearize()
src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
srcs[ast] = (k.name, src)
@@ -62,7 +63,7 @@ if __name__ == "__main__":
if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}")
if v.grad is not None: grad_state_dict['grad_'+k] = v.grad
state_dict.update(grad_state_dict)
state_dict.update({'adam_b1': optimizer.b1, 'adam_b2': optimizer.b2, 'adam_t': optimizer.t, 'adam_lr': optimizer.lr})
state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr})
inverse_state_dict = {v:k for k,v in state_dict.items()}
for p,m,v in zip(optimizer.params, optimizer.m, optimizer.v):
nm = inverse_state_dict[p]