diff --git a/test/test_rangeify.py b/test/test_rangeify.py index a7cb8a4a54..d15b993965 100644 --- a/test/test_rangeify.py +++ b/test/test_rangeify.py @@ -38,6 +38,13 @@ class TestRangeifyOpt(unittest.TestCase): Xsel, Ysel = X[sel], Y[sel] Tensor.realize(Xsel, Ysel) + def test_resnetconv(self): + conv1 = nn.Conv2d(3, 8, kernel_size=7, stride=2, bias=False, padding=3) + conv1.weight.replace(conv1.weight.empty_like()) + x = Tensor.empty(1, 3, 56, 56) + x = conv1(x).pad([1,1,1,1])+1 + x.realize() + @unittest.skipIf(RANGEIFY<1, "tests only for RANGEIFY") class TestRangeify(unittest.TestCase): def test_groupnorm(self): diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 81ef5013e7..76969ea11d 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -456,6 +456,13 @@ class Tensor(MathTrait): device = tuple(Device.canonicalize(d) for d in device) if isinstance(device, tuple) else Device.canonicalize(device) return Tensor(UOp.new_buffer(device, size, dtype), device, dtype, **kwargs).shrink(((0,prod(shape)),)).reshape(shape) + def empty_like(self, **kwargs) -> Tensor: + """ + Creates an empty tensor with the same shape as `self`. + If `dtype` is not specified, the dtype of `self` is used. + """ + return Tensor.empty(self.shape, dtype=kwargs.pop("dtype", self.dtype), device=kwargs.pop("device", self.device), **kwargs) + @staticmethod def from_blob(ptr:int, shape:tuple[int, ...], **kwargs) -> Tensor: """