diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0123e1b54c..a1f8051a12 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -125,6 +125,8 @@ jobs: run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu - name: Run Pytest run: TORCH=1 python -m pytest -n=auto test/ --durations=20 + - name: Run test_ops with inputs moved to a different device + run: TORCH=1 MOVE_TENSOR=TORCH:1 python -m pytest -n=auto test/test_ops.py --durations=20 - name: Run ONNX run: TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20 diff --git a/test/test_ops.py b/test/test_ops.py index 2cbc6a3506..270cf7690a 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -19,6 +19,10 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra out = torch_fxn(*ts) torch_fp = time.monotonic() - st + # move inputs to a different device, test the device of intermediate tensors are correct + if mt:=getenv("MOVE_TENSOR", ""): + for t in tst: t.to_(mt) + st = time.monotonic() ret = tinygrad_fxn(*tst).realize() tinygrad_fp = time.monotonic() - st diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 3f0456a4f9..06a37e080a 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -264,7 +264,7 @@ class Tensor: assert replacement or num_samples == 1, "no replacement only supports num_samples = 1" weight = self.unsqueeze(0) if self.ndim == 1 else self cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1) - unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1) + unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device) indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0)) return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int) @@ -356,6 +356,8 @@ class Tensor: # turn scalar Tensors into const val for int indexing if possible indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices] + # move Tensor indices to the same device as self + indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices] # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None) ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis] @@ -451,7 +453,7 @@ class Tensor: assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim" assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape" if dim < 0: dim += self.ndim - idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1) + idx = idx.to(self.device).transpose(ax1=dim, ax2=0).unsqueeze(-1) permarg = list(range(self.ndim)) permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]] return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(