mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
update llmc export (#4584)
* update example * move train to optim * rename * b2
This commit is contained in:
@@ -6,7 +6,7 @@ Device.DEFAULT = "CLANG"
|
||||
from train_gpt2 import GPT, GPTConfig
|
||||
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GRAPH, GlobalCounters, ansilen, to_function_name
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.realize import get_linearizer, run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.ops import BufferOps, LoadOps
|
||||
|
||||
@@ -24,6 +24,7 @@ if __name__ == "__main__":
|
||||
#B, T = Variable("B", 1, 128).bind(4), 64 #Variable("T", 1, 1024).bind(64)
|
||||
B, T = 4, 64
|
||||
|
||||
Tensor.training = True
|
||||
optimizer = nn.optim.Adam(nn.state.get_parameters(model), lr=1e-4)
|
||||
warmup_count = getenv("WARMUP", 3)
|
||||
for i in range(warmup_count): # TODO: why does it take three and not two to stablize
|
||||
@@ -46,7 +47,7 @@ if __name__ == "__main__":
|
||||
ast_dedup = dedup([si.ast for si in sched if si.ast[0].op is BufferOps.STORE])
|
||||
srcs = {}
|
||||
for ast in ast_dedup:
|
||||
k = Device["CLANG"].get_linearizer(*ast)
|
||||
k = get_linearizer(Device["CLANG"].renderer, ast)
|
||||
k.linearize()
|
||||
src = Device["CLANG"].renderer.render(to_function_name(k.name), k.uops)
|
||||
srcs[ast] = (k.name, src)
|
||||
@@ -62,7 +63,7 @@ if __name__ == "__main__":
|
||||
if v.lazydata.base.buffer not in used_buffers: print(f"UNUSED: {k}")
|
||||
if v.grad is not None: grad_state_dict['grad_'+k] = v.grad
|
||||
state_dict.update(grad_state_dict)
|
||||
state_dict.update({'adam_b1': optimizer.b1, 'adam_b2': optimizer.b2, 'adam_t': optimizer.t, 'adam_lr': optimizer.lr})
|
||||
state_dict.update({'adam_b1_t': optimizer.b1_t, 'adam_b2_t': optimizer.b2_t, 'adam_lr': optimizer.lr})
|
||||
inverse_state_dict = {v:k for k,v in state_dict.items()}
|
||||
for p,m,v in zip(optimizer.params, optimizer.m, optimizer.v):
|
||||
nm = inverse_state_dict[p]
|
||||
|
||||
Reference in New Issue
Block a user