diff --git a/examples/llama3.py b/examples/llama3.py
index 2ac5821195..1710823687 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -179,10 +179,10 @@ def build_transformer(model_path: Path, model_size="8B", quantize=None, device=N
   return model
 
 # default settings
-TEMPERATURE = 0.85
-TOP_K = 25
-TOP_P = 0.9
-ALPHA_F = 0.1
+TEMPERATURE = 0.95
+TOP_K = 0
+TOP_P = 0.0
+ALPHA_F = 0.0
 ALPHA_P = 0.0
 
 last_seen_toks = []
diff --git a/extra/models/llama.py b/extra/models/llama.py
index 88d33710f8..4204dd71da 100644
--- a/extra/models/llama.py
+++ b/extra/models/llama.py
@@ -109,6 +109,8 @@ def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):
   # if temperature is very low just use argmax
   if temp < 1e-6: return logits.argmax()
 
+  logits = logits.to(Device.DEFAULT)
+
   # alpha sampling
   if af or ap:
     if not hasattr(sample, "alpha_counter"):