mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
@@ -170,5 +170,9 @@ class TestRand(unittest.TestCase):
|
||||
Tensor.rand(2**17, 2**17).schedule()
|
||||
Tensor.rand(2**17, 2**17).schedule()
|
||||
|
||||
class TestTensorDevice(unittest.TestCase):
|
||||
def test_create_from_single_device_tuple(self):
|
||||
(Tensor([1.0], device=(Device.DEFAULT,)) + Tensor([2.0])).realize()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -20,7 +20,8 @@ from tinygrad.engine.allocations import transform_to_call
|
||||
|
||||
# TODO: this should be the only usage of Device
|
||||
def canonicalize_device(device:str|tuple|list|None) -> str|tuple[str, ...]:
|
||||
return tuple(Device.canonicalize(d) for d in device) if isinstance(device, (tuple, list)) else Device.canonicalize(device)
|
||||
if not isinstance(device, (tuple, list)): return Device.canonicalize(device)
|
||||
return canonical[0] if len(canonical:=tuple(Device.canonicalize(d) for d in device)) == 1 else canonical
|
||||
|
||||
# *** all in scope Tensors are here. this gets relevant UOps ***
|
||||
|
||||
|
||||
Reference in New Issue
Block a user