test intermediate tensors created by function have same device as input (#3338)

Run on TORCH since it's the fastest backend on CI.
Caught a bug in multinomial, and updated fancy indexing and gather to move the indices Tensor to the same device as self.
Author: chenyu
Date: 2024-02-07 09:24:36 -05:00 (committed by GitHub)
commit 0d2dacb549 (parent 1732f1ba83)
3 changed files with 10 additions and 2 deletions

.github/workflows/test.yml

@@ -125,6 +125,8 @@ jobs:
       run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
     - name: Run Pytest
       run: TORCH=1 python -m pytest -n=auto test/ --durations=20
+    - name: Run test_ops with inputs moved to a different device
+      run: TORCH=1 MOVE_TENSOR=TORCH:1 python -m pytest -n=auto test/test_ops.py --durations=20
     - name: Run ONNX
       run: TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
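
For context, MOVE_TENSOR takes a tinygrad device string; TORCH:1 names a second instance of the TORCH backend, so the moved inputs no longer sit on the default device. A minimal sketch of the move itself (assuming the TORCH backend is installed):

# minimal sketch of what MOVE_TENSOR triggers; "TORCH:1" is a device string for a second TORCH device
from tinygrad.tensor import Tensor

t = Tensor([1.0, 2.0, 3.0])   # created on the default device (TORCH when TORCH=1 is set)
t.to_("TORCH:1")              # in-place move, the same call helper_test_op makes below
print(t.device)               # TORCH:1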

test/test_ops.py

@@ -19,6 +19,10 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   out = torch_fxn(*ts)
   torch_fp = time.monotonic() - st
+  # move inputs to a different device, and check that intermediate tensors are created on the correct device
+  if mt:=getenv("MOVE_TENSOR", ""):
+    for t in tst: t.to_(mt)
   st = time.monotonic()
   ret = tinygrad_fxn(*tst).realize()
   tinygrad_fp = time.monotonic() - st
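
A standalone sketch of what this helper change exercises; check_device and the sample inputs are illustrative, not from the repo, and a TORCH backend is assumed:

from tinygrad.tensor import Tensor

def check_device(fxn, *vals, move_to="TORCH:1"):
  ts = [Tensor(v) for v in vals]
  for t in ts: t.to_(move_to)   # move inputs, as helper_test_op does under MOVE_TENSOR
  # if an op created an intermediate Tensor on the default device, combining it
  # with the moved inputs would fail once the result is realized
  return fxn(*ts).realize()

out = check_device(lambda a, b: a.matmul(b), [[1.0, 2.0]], [[3.0], [4.0]])
print(out.device)   # TORCH:1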

tinygrad/tensor.py

@@ -264,7 +264,7 @@ class Tensor:
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
     cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
-    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
+    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
     return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
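
The multinomial bug: the uniform samples were created on the default device, so comparing them against a cdf living on another device mixed devices and failed. A small sketch of the fixed behavior (assuming the TORCH backend is installed):

from tinygrad.tensor import Tensor

weights = Tensor([0.1, 0.2, 0.7], device="TORCH")
# unif_samples now inherits weights.device, so the whole computation stays on TORCH
samples = weights.multinomial(num_samples=5, replacement=True)
print(samples.device)   # TORCH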
@@ -356,6 +356,8 @@ class Tensor:
     # turn scalar Tensors into const val for int indexing if possible
     indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
+    # move Tensor indices to the same device as self
+    indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]
     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
     ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
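
With this change, Tensor indices are moved to self.device before they are used, so indexing with an index tensor from another backend works. A short sketch, assuming two installed backends (the device names here are just placeholders):

from tinygrad.tensor import Tensor

a = Tensor([[1, 2], [3, 4]], device="TORCH")
idx = Tensor([1, 0], device="CLANG")   # index tensor living on a different device
# idx is moved to a.device inside __getitem__, so everything runs on TORCH
print(a[idx].numpy())                  # rows reordered: [[3 4], [1 2]]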
@@ -451,7 +453,7 @@ class Tensor:
     assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
     assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
     if dim < 0: dim += self.ndim
-    idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
+    idx = idx.to(self.device).transpose(ax1=dim, ax2=0).unsqueeze(-1)
     permarg = list(range(self.ndim))
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
     return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
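
Same idea for gather: idx is moved onto self.device before the one-hot comparison with Tensor.arange, which is already created on self.device. A sketch, assuming two installed backends and the gather(idx, dim) argument order this version of Tensor.gather uses:

from tinygrad.tensor import Tensor

src = Tensor([[1, 2], [3, 4]], device="TORCH")
idx = Tensor([[0, 0], [1, 0]], device="CLANG")   # indices on another device
out = src.gather(idx, dim=1)   # out[i][j] = src[i][idx[i][j]]
print(out.numpy())             # [[1 1], [4 3]]
print(out.device)              # TORCH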