mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
Fix the device of the Tensor.arange created inside Tensor.one_hot (#3199):
the auxiliary arange tensor should be created on the same device as self.
This commit is contained in:
@@ -437,5 +437,12 @@ class TestZeroShapeTensor(unittest.TestCase):
|
||||
np.testing.assert_equal(Tensor([]).sum().numpy(), 0)
|
||||
np.testing.assert_equal(Tensor([]).mean().numpy(), 0)
|
||||
|
||||
class TestTensorCreationDevice(unittest.TestCase):
  # Auxiliary tensors built inside an op (e.g. the arange inside one_hot)
  # must be created on the same device as the tensor that spawned them.
  def test_one_hot(self):
    # Move to CPU first, then one-hot encode; realize() forces execution so a
    # cross-device mismatch would surface here rather than stay lazy.
    Tensor([1, 2, 3]).to("CPU").one_hot(10).realize()
|
||||
|
||||
if __name__ == '__main__':
  # Run the test suite when this file is executed directly as a script.
  unittest.main()
|
||||
|
||||
@@ -901,7 +901,8 @@ class Tensor:
|
||||
if not Tensor.training or p == 0: return self
|
||||
return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
|
||||
|
||||
def one_hot(self, num_classes:int, **kwargs) -> Tensor:
  """Return the one-hot encoding of self over `num_classes` classes.

  Compares self (expanded with a trailing axis) against an index ramp;
  extra kwargs are forwarded to where, as in the original signature.
  Fix: the arange is created on self.device (was the default device,
  which broke one_hot for tensors living elsewhere — see #3199), and
  with requires_grad=False since the ramp is a constant, not a leaf.
  """
  return (self[..., None] == Tensor.arange(num_classes, requires_grad=False, device=self.device)).where(1, 0, **kwargs)
|
||||
def one_hot(self, num_classes:int) -> Tensor:
  """One-hot encode self into `num_classes` classes, on self's device."""
  # Build the class-index ramp on the same device as self so the comparison
  # never mixes devices; requires_grad=False because the ramp is a constant.
  ramp = Tensor.arange(num_classes, requires_grad=False, device=self.device)
  return (self[..., None] == ramp).where(1, 0)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
|
||||
# NOTE: it works if key, value have symbolic shape
|
||||
|
||||
Reference in New Issue
Block a user