mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
Transpose on GPU (#221)
* 2serious * load/save * fixing GPU * added DEBUG * needs BatchNorm or doesn't learn anything * old file not needed * added conv biases * added extra/training.py and checkpoint * assert in test only * save * padding * num_classes * checkpoint * checkpoints for padding * training was broken * merge * rotation augmentation * more aug * needs testing * streamline augment, augment is fast thus bicubic * tidying up * transformer eval * axis=-1 * transpose * test for permutation using torch.movedims * another test * line
This commit is contained in:
@@ -127,9 +127,10 @@ class TestOps(unittest.TestCase):
|
||||
def test_pad2d(self):
|
||||
helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4)), device=self.device)
|
||||
|
||||
@cpu_only
|
||||
def test_transpose(self):
|
||||
helper_test_op([(3,3,3)], lambda x: x.transpose(1,2), lambda x: x.transpose(order=(0,2,1)), device=self.device)
|
||||
helper_test_op([(21,22,23,24)], lambda x: x.movedim((3,0,2,1),(0,1,2,3)), lambda x: x.transpose(order=(3,0,2,1)), device=self.device)
|
||||
helper_test_op([(3,4,5,6)], lambda x: x.movedim((3,2,1,0),(0,1,2,3)), lambda x: x.transpose(order=(3,2,1,0)), device=self.device)
|
||||
|
||||
def test_reshape(self):
|
||||
helper_test_op([(4,3,6,6)], lambda x: torch.reshape(x, (-1,3,6,6)), lambda x: x.reshape(shape=(-1,3,6,6)), device=self.device)
|
||||
|
||||
Reference in New Issue
Block a user