mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
fix gpt2 on rangeify (#12335)
This commit is contained in:
@@ -134,7 +134,8 @@ class GPT2:
|
||||
transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
|
||||
for k in weights:
|
||||
if k.endswith(transposed):
|
||||
- weights[k] = weights[k].T
|
||||
+ # TODO: it should not silently break without that .to(None)
|
||||
+ weights[k] = weights[k].to(None).T
|
||||
# lm head and wte are tied
|
||||
weights['lm_head.weight'] = weights['wte.weight']
|
||||
|
||||
|
||||
Reference in New Issue
Block a user