fix gpt2 on rangeify (#12335)

This commit is contained in:
George Hotz
2025-09-29 21:16:44 +10:00
committed by GitHub
parent 9513f025c5
commit baf3b60cfb

View File

@@ -134,7 +134,8 @@ class GPT2:
transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight')
for k in weights:
if k.endswith(transposed):
weights[k] = weights[k].T
# TODO: it should not silently break without that .to(None)
weights[k] = weights[k].to(None).T
# lm head and wte are tied
weights['lm_head.weight'] = weights['wte.weight']