mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
track down llama bug
This commit is contained in:
@@ -151,8 +151,12 @@ class Transformer:
|
||||
|
||||
def __call__(self, tokens:Tensor, start_pos:int):
|
||||
_bsz, seqlen = tokens.shape
|
||||
# get only the part we are using. TODO: removing contiguous resulted in a bug?
|
||||
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous().realize()
|
||||
# get only the part we are using.
|
||||
# NOTE: if you remove contiguous here, it breaks because you can't put different ShapeTrackers into the compiled JIT
|
||||
# NOTE: realize is not enough, since the realized buffer will have an offset that the kernel doesn't know about
|
||||
# TODO: check that we didn't do this in the JIT and confirm the ShapeTrackers match the template
|
||||
# TODO: support Variables in shrink
|
||||
freqs_cis = self.freqs_cis[:, start_pos:start_pos+seqlen].contiguous()
|
||||
mask = Tensor.full((1, 1, seqlen, start_pos + seqlen), float("-inf"), dtype=dtypes.float32).triu(start_pos+1).realize() if seqlen > 1 else None
|
||||
|
||||
do_jit = getenv("JIT") and mask is None
|
||||
@@ -433,6 +437,7 @@ After you are done speaking, output [EOS]. You are not Chad.
|
||||
|
||||
last_break = len(outputted)
|
||||
for i in range(args.count):
|
||||
GlobalCounters.reset()
|
||||
if args.profile and i == 2: profiler.enable()
|
||||
|
||||
if args.timing: print("")
|
||||
|
||||
Reference in New Issue
Block a user