mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
add cat support
This commit is contained in:
@@ -220,6 +220,10 @@ class TestOps(unittest.TestCase):
|
||||
lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz),
|
||||
lambda x: Tensor.avg_pool2d(x, kernel_size=ksz), rtol=1e-5)
|
||||
|
||||
def test_cat(self):
|
||||
for dim in range(-1, 2):
|
||||
helper_test_op([(45,65), (45,65)], lambda x,y: torch.cat((x,y), dim), lambda x,y: x.cat(y, dim), forward_only=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
np.random.seed(1337)
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -177,6 +177,20 @@ class Tensor:
|
||||
assert s.step is None or s.step == 1
|
||||
return self.slice(arg = arg + [(0,self.shape[i]) for i in range(len(arg), len(self.shape))])
|
||||
|
||||
def cat(self, y, dim=0):
|
||||
assert len(self.shape) == len(y.shape)
|
||||
dim = (dim + len(self.shape)) if dim < 0 else dim
|
||||
s1, s2 = [], []
|
||||
for i in range(len(self.shape)):
|
||||
if i != dim:
|
||||
assert self.shape[i] == y.shape[i]
|
||||
s1.append((0, self.shape[i]))
|
||||
s2.append((0, self.shape[i]))
|
||||
else:
|
||||
s1.append((0, self.shape[i]+y.shape[i]))
|
||||
s2.append((-self.shape[i], y.shape[i]))
|
||||
return self.slice(arg=s1) + y.slice(arg=s2)
|
||||
|
||||
def pad2d(self, padding):
|
||||
return self[:, :, -padding[2]:self.shape[2]+padding[3], -padding[0]:self.shape[3]+padding[1]]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user