correct patch JIT llama chat (#1500)

chenyu
2023-08-08 16:52:09 -07:00
committed by GitHub
parent 7c2ea85bb0
commit 827d13e64e


@@ -127,10 +127,10 @@ class TransformerBlock:
   def __call__(self, x:Tensor, start_pos:int, freqs_cis:Tensor, mask:Optional[Tensor]):
     # if mask is not None, x's shape is dynamic based on user input and pre/post can't be jitted
-    xq, xk, xv = self._pre(x, freqs_cis) if mask is not None else self.pre(x, freqs_cis)
+    xq, xk, xv = self._pre(x, freqs_cis) if mask is None else self.pre(x, freqs_cis)
     # inner_attention can't be jitted because it's dynamic based on start_pos
     output = self.attention.inner_attention(xq, xk, xv, start_pos, mask)
-    return self._post(x, output) if mask is not None else self.post(x, output)
+    return self._post(x, output) if mask is None else self.post(x, output)
 
 class Transformer:
   def __init__(self, dim, multiple_of, n_heads, n_layers, norm_eps, vocab_size, linear=Linear, max_batch_size=32, max_seq_len=1024, ffn_dim_multiplier=None):
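
Why the condition flips: the comment in the hunk already states that when mask is not None (prompt processing), x's shape depends on the user input, so the JIT-compiled pre/post variants can't be used; when mask is None (token-by-token generation), shapes are fixed and the JITted path is safe. The old code had the test backwards, selecting the JITted _pre/_post exactly in the dynamic-shape case. Below is a minimal sketch of the dispatch pattern only, not the actual llama.py code: it assumes _pre/_post are the TinyJit-wrapped variants of pre/post (e.g. self._pre = TinyJit(self.pre)), that TinyJit imports from tinygrad.jit as it did around this commit, and the class/method names and shapes are hypothetical.

    # Sketch: dispatch between a JITted and a plain path based on whether input
    # shapes are static. Block and fwd are illustrative names, not from llama.py.
    from typing import Optional
    from tinygrad.tensor import Tensor
    from tinygrad.jit import TinyJit  # assumed import path for tinygrad circa this commit

    class Block:
      def __init__(self, dim=64):
        self.w = Tensor.uniform(dim, dim)
        # TinyJit captures the kernels of the wrapped call and replays them,
        # so every JITted call must see the same input shapes.
        self._fwd = TinyJit(self.fwd)

      def fwd(self, x:Tensor) -> Tensor:
        return (x @ self.w).realize()  # realize so the JIT has concrete buffers to capture

      def __call__(self, x:Tensor, mask:Optional[Tensor]) -> Tensor:
        # mask is None during single-token generation, where x has a fixed shape
        # (e.g. (bsz, 1, dim)), so the JITted path is safe. With a mask, the
        # prompt length varies per input, so fall back to the plain path.
        return self._fwd(x) if mask is None else self.fwd(x)

In the example itself, the same `mask is None` test now selects the JITted _pre/_post for the decode loop and the un-JITted pre/post whenever a prompt of arbitrary length is fed through with a mask.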