mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Tensor.pad_to and Tensor.shrink_to (#12210)
most of the time i want this instead of spelling out the args also add more input validation to shrink
This commit is contained in:
@@ -109,7 +109,7 @@ class TextDecoder:
|
||||
|
||||
def forward(self, x:Tensor, pos:Union[Variable, Literal[0]], encoded_audio:Tensor):
|
||||
seqlen = x.shape[-1]
|
||||
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None, None))
|
||||
x = self.token_embedding(x) + self.positional_embedding.shrink(((pos, pos+seqlen), None))
|
||||
for block in self.blocks: x = block(x, xa=encoded_audio, mask=self.mask, len=pos)
|
||||
return self.output_tok(x)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user