diff --git a/examples/gpt2.py b/examples/gpt2.py index ae51ab2332..2ec1810e02 100644 --- a/examples/gpt2.py +++ b/examples/gpt2.py @@ -134,7 +134,8 @@ class GPT2: transposed = ('attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight') for k in weights: if k.endswith(transposed): - weights[k] = weights[k].T + # TODO: it should not silently break without that .to(None) + weights[k] = weights[k].to(None).T # lm head and wte are tied weights['lm_head.weight'] = weights['wte.weight']