Tensor.pad_to and Tensor.shrink_to (#12210)

most of the time i want this instead of spelling out the args

also add more input validation to shrink
This commit is contained in:
chenyu
2025-09-16 12:24:55 -04:00
committed by GitHub
parent 122a50fe8c
commit 84d2d047ea
3 changed files with 19 additions and 4 deletions

View File

@@ -109,7 +109,7 @@ class TextDecoder:
def forward(self, x:Tensor, pos:Union[Variable, Literal[0]], encoded_audio:Tensor):
seqlen = x.shape[-1]
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None, None))
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None))
for block in self.blocks: x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
return self.output_tok(x)