mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
correct patch JIT llama chat (#1500)
This commit is contained in:
@@ -127,10 +127,10 @@ class TransformerBlock:
|
||||
|
||||
def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]):
  """Run one transformer block: attention pre-projection, attention, post-projection.

  Picks the jitted or non-jitted pre/post path depending on `mask`:
  when mask is not None, x's shape is dynamic based on user input, so the
  jitted pre/post variants can't be used and the plain ones are called instead.
  NOTE(review): which of `_pre`/`pre` (and `_post`/`post`) is the TinyJit-wrapped
  one is defined elsewhere in this file — confirm against the class definition.
  """
  # mask is None -> fixed-shape decode step -> safe to take the `_pre` path;
  # mask is not None -> dynamic prompt shape -> take the `pre` path
  xq, xk, xv = self._pre(x, freqs_cis) if mask is None else self.pre(x, freqs_cis)
  # inner_attention can't be jitted because it's dynamic based on start_pos
  output = self.attention.inner_attention(xq, xk, xv, start_pos, mask)
  return self._post(x, output) if mask is None else self.post(x, output)
|
||||
class Transformer:
|
||||
def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None):
|
||||
|
||||
Reference in New Issue
Block a user