track down llama bug

George Hotz
2023-08-21 15:14:21 -07:00
parent b02f77b354
commit 4ea00bad38


@@ -151,8 +151,12 @@ class Transformer:
   def __call__(self, tokens:Tensor, start_pos:int):
     _bsz, seqlen = tokens.shape
-    # get only the part we are using. TODO: removing contiguous resulted in a bug?
-    freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize()
+    # get only the part we are using.
+    # NOTE: if you remove contiguous here, it breaks because you can't put different ShapeTrackers into the compiled JIT
+    # NOTE: realize is not enough, since the realized buffer will have an offset that the kernel doesn't know about
+    # TODO: check that we didn't do this in the JIT and confirm the ShapeTrackers match the template
+    # TODO: support Variables in shrink
+    freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous()
     mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
     do_jit = getenv("JIT") and mask is None
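
Note on the first hunk: freqs_cis[:, start_pos:start_pos+seqlen] is a view whose buffer offset moves with start_pos, while the compiled JIT replays captured kernels that bake in a single ShapeTracker, so every call has to hand it a buffer with the same layout; .contiguous() copies the slice into a fresh zero-offset buffer, which is why realize() alone is not enough. Below is a minimal sketch of that constraint, assuming the 2023-era tinygrad API (tinygrad.jit.TinyJit); the names freqs and step are illustrative, not from the commit.

from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit

freqs = Tensor.rand(1, 1024, 64).realize()  # stand-in for self.freqs_cis

@TinyJit
def step(f: Tensor) -> Tensor:
  # kernels are captured on an early call and replayed on later ones, so each
  # call must pass a buffer with the same shape/strides/offset
  return (f * 2.0).realize()

for start_pos in range(3):
  sl = freqs[:, start_pos:start_pos+1]  # same shape every iteration, but a different offset
  step(sl.contiguous().realize())       # contiguous() copies into a zero-offset buffer the JIT can swap in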
@@ -433,6 +437,7 @@ After you are done speaking, output [EOS]. You are not Chad.
   last_break = len(outputted)
   for i in range(args.count):
     GlobalCounters.reset()
+    if args.profile and i == 2: profiler.enable()
     if args.timing: print("")
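
Note on the second hunk: the profiler is switched on only at the third generation step (i == 2), so JIT capture and other warm-up work stay out of the profile. A rough sketch of the same pattern, assuming profiler is a cProfile.Profile instance (the commit does not show how it is constructed); run_one_step and the argparse setup are illustrative stand-ins.

import argparse, cProfile, pstats, time

parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true")
parser.add_argument("--count", type=int, default=5)
args = parser.parse_args()

def run_one_step():
  # hypothetical stand-in for generating one token
  time.sleep(0.01)

profiler = cProfile.Profile()
for i in range(args.count):
  if args.profile and i == 2: profiler.enable()  # skip the first two (warm-up) iterations
  run_one_step()
if args.profile:
  profiler.disable()
  pstats.Stats(profiler).sort_stats("cumtime").print_stats(20)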