From 9aa5e02229f6050cf9e71df9fe2e06beb71014ea Mon Sep 17 00:00:00 2001 From: qazal <77887910+Qazalin@users.noreply.github.com> Date: Wed, 15 May 2024 02:18:38 +0800 Subject: [PATCH] update llmc export (#4584) * update example * move train to optim * rename * b2 --- examples/llm.c/export.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py index fa0c64a9b0..ee1e9650e1 100755 --- a/examples/llm.c/export.py +++ b/examples/llm.c/export.py @@ -6,7 +6,7 @@ Device.DEFAULT = "CLANG" from train_gpt2 import GPT, GPTConfig from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen, to_function_name from tinygrad.engine.schedule import create_schedule -from tinygrad.engine.realize import run_schedule +from tinygrad.engine.realize import get_linearizer, run_schedule from tinygrad.engine.memory import memory_planner from tinygrad.ops import BufferOps, LoadOps @@ -24,6 +24,7 @@ if __name__ == "__main__": #B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64) B, T = 4, 64 + Tensor.training = True optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4) warmup_count = getenv("WARMUP", 3) for i in range(warmup_count): # TODO: why does it take three and not two to stablize @@ -46,7 +47,7 @@ if __name__ == "__main__": ast_dedup = dedup([si.ast for si in sched if si.ast[0].op is BufferOps.STORE]) srcs = {} for ast in ast_dedup: - k = Device["CLANG"].get_linearizer(*ast) + k = get_linearizer(Device["CLANG"].renderer, ast) k.linearize() src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops) srcs[ast] = (k.name, src) @@ -62,7 +63,7 @@ if __name__ == "__main__": if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}") if v.grad is not None: grad_state_dict['grad_'+k] = v.grad state_dict.update(grad_state_dict) - state_dict.update({'adam_b1': optimizer.b1, 'adam_b2': optimizer.b2, 'adam_t': optimizer.t, 'adam_lr': optimizer.lr}) + state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr}) inverse_state_dict = {v:k for k,v in state_dict.items()} for p,m,v in zip(optimizer.params, optimizer.m, optimizer.v): nm = inverse_state_dict[p]