Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-01-09 15:08:02 -05:00
test that intermediate tensors created by a function have the same device as the input (#3338)
Run on TORCH since it's the fastest backend on CI. This caught a bug in multinomial, and updates fancy indexing and gather to move the indices Tensor to the same device as self.
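In short: the new test moves every op's inputs to a different device and checks that whatever the op creates internally follows them. A minimal sketch of the invariant being tested (the device name and the op chain are illustrative, not taken from the diff):

from tinygrad import Tensor

x = Tensor.randn(3, 3).to("TORCH")  # input moved off the default device
y = (x + 1).relu().sum()            # intermediates are created inside these calls
assert y.device == x.device         # they must end up on the input's device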
.github/workflows/test.yml (vendored, +2 lines)
@@ -125,6 +125,8 @@ jobs:
       run: pip install -e '.[testing]' --extra-index-url https://download.pytorch.org/whl/cpu
     - name: Run Pytest
       run: TORCH=1 python -m pytest -n=auto test/ --durations=20
+    - name: Run test_ops with inputs moved to a different device
+      run: TORCH=1 MOVE_TENSOR=TORCH:1 python -m pytest -n=auto test/test_ops.py --durations=20
     - name: Run ONNX
       run: TORCH=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
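The new step exports MOVE_TENSOR with a device string; helper_test_op (next hunk) reads it and moves the test inputs there. "TORCH:1" names a second TORCH device, distinct from the default "TORCH", so the ops run with their inputs on a non-default device. A rough sketch of how the variable is read, using the same getenv call the helper makes:

import os
os.environ["MOVE_TENSOR"] = "TORCH:1"   # what the workflow step sets
from tinygrad.helpers import getenv
mt = getenv("MOVE_TENSOR", "")          # string default -> string result
print(mt)                               # "TORCH:1"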
test/test_ops.py
@@ -19,6 +19,10 @@ def helper_test_op(shps, torch_fxn, tinygrad_fxn=None, atol=1e-6, rtol=1e-3, gra
   out = torch_fxn(*ts)
   torch_fp = time.monotonic() - st

+  # move inputs to a different device, test the device of intermediate tensors are correct
+  if mt:=getenv("MOVE_TENSOR", ""):
+    for t in tst: t.to_(mt)
+
   st = time.monotonic()
   ret = tinygrad_fxn(*tst).realize()
   tinygrad_fp = time.monotonic() - st
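Tensor.to_ moves a tensor in place (to returns a moved copy), so the already-built tinygrad inputs in tst can be relocated without regenerating the test data. A small sketch, with an illustrative device string:

from tinygrad import Tensor

t = Tensor([1.0, 2.0, 3.0])  # created on the default device
t.to_("TORCH")               # in-place move, as done for each input in tst
assert t.device == "TORCH"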
tinygrad/tensor.py
@@ -264,7 +264,7 @@ class Tensor:
     assert replacement or num_samples == 1, "no replacement only supports num_samples = 1"
     weight = self.unsqueeze(0) if self.ndim == 1 else self
     cdf = (cw := weight.cumsum(1).float()) / cw[:, -1].unsqueeze(1)
-    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1)
+    unif_samples = Tensor.rand(num_samples, cdf.shape[0], 1, device=self.device)
     indices = (unif_samples.expand((-1, -1, cdf.shape[1])) >= cdf).sum(2).permute((1, 0))
     return (indices.squeeze(0) if self.ndim == 1 else indices).cast(dtypes.default_int)
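The multinomial bug: Tensor.rand previously put the uniform samples on the default device, so sampling from weights living elsewhere mixed devices in the comparison against cdf. Passing device=self.device keeps everything on the weights' device. A sketch of the call this fix keeps working (device string illustrative):

from tinygrad import Tensor

w = Tensor([0.2, 0.3, 0.5], device="TORCH")
idx = w.multinomial(num_samples=4, replacement=True)
assert idx.device == w.device and idx.shape == (4,)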
@@ -356,6 +356,8 @@ class Tensor:

     # turn scalar Tensors into const val for int indexing if possible
     indices = [self._to_const_val(i) if isinstance(i, Tensor) else i for i in indices]
+    # move Tensor indices to the same device as self
+    indices = [i.to(self.device) if isinstance(i, Tensor) else i for i in indices]

     # filter ellipsis and fill with slice(None) or fill rest of indices with slice(None)
     ellipsis_idx = [dim for dim, i in enumerate(indices) if i is Ellipsis]
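With this change a Tensor used as a fancy index no longer needs to be created on the same device as the tensor it indexes; it is moved to self.device first. A sketch (device strings illustrative):

from tinygrad import Tensor

x = Tensor.arange(10).to("TORCH")  # data on a non-default device
i = Tensor([0, 3, 7])              # index tensor on the default device
y = x[i]                           # the index is moved to x.device internally
assert y.device == x.device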
@@ -451,7 +453,7 @@ class Tensor:
     assert idx.ndim == self.ndim, "self.ndim must equal idx.ndim"
     assert all(s >= i for s,i in zip(self.shape, idx.shape)), "all dim of idx.shape must be smaller than self.shape"
     if dim < 0: dim += self.ndim
-    idx = idx.transpose(ax1=dim, ax2=0).unsqueeze(-1)
+    idx = idx.to(self.device).transpose(ax1=dim, ax2=0).unsqueeze(-1)
     permarg = list(range(self.ndim))
     permarg = permarg[1:dim] + [permarg[0]] + permarg[dim+1:] + [permarg[dim]] if dim != 0 else permarg[1:] + [permarg[0]]
     return ((idx == Tensor.arange(self.shape[dim], requires_grad=False, device=self.device)) * self.permute(*permarg).shrink(
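gather gets the same treatment: the index tensor is moved to self.device before it is transposed and compared against the arange, which was already built on self.device in the return line. A sketch (device strings illustrative; keyword arguments used to avoid depending on parameter order):

from tinygrad import Tensor

x = Tensor([[1, 2], [3, 4]], device="TORCH")
idx = Tensor([[0, 0], [1, 0]])     # on the default device
out = x.gather(idx=idx, dim=1)     # idx is moved to x.device internally
assert out.device == x.device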