diff --git a/extra/models/llama.py b/extra/models/llama.py index 6a62280e01..3a089facbc 100644 --- a/extra/models/llama.py +++ b/extra/models/llama.py @@ -107,7 +107,7 @@ class TransformerBlock: def __call__(self, x:Tensor, start_pos:Union[Variable,int], freqs_cis:Tensor, mask:Optional[Tensor]): h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask) - return h + self.feed_forward(self.ffn_norm(h)) + return (h + self.feed_forward(self.ffn_norm(h))).contiguous() # standard openai sampling def sample(logits: Tensor, temp: float, k: int, p: float, af: float, ap: float):