diff --git a/examples/llama.py b/examples/llama.py index 519eecd55e..eaa224c7fc 100755 --- a/examples/llama.py +++ b/examples/llama.py @@ -311,7 +311,7 @@ class LLaMa: else: weights = load(str(model_path)) if "model.embed_tokens.weight" in weights: - weights = convert_from_huggingface(weights, model, model_args["n_heads"], model_args["n_kv_heads"]) + weights = convert_from_huggingface(weights, model, model_args["n_heads"], model_args.get("n_kv_heads", model_args["n_heads"])) if quantize: weights = AbsmaxQuantizedLinear.quantize(weights)