Fix cl import in the copy_speed test and cifar example (#2586)

* fix CL import

* update test to only run on GPU

* update hlb_cifar too
This commit is contained in:
qazal
2023-12-03 12:22:07 -05:00
committed by GitHub
parent 3226b3d96b
commit ab2d4d8d29
2 changed files with 6 additions and 5 deletions

View File

@@ -430,8 +430,8 @@ if __name__ == "__main__":
from tinygrad.runtime.ops_hip import HIP
devices = [f"hip:{i}" for i in range(HIP.device_count)]
else:
from tinygrad.runtime.ops_gpu import CL
devices = [f"gpu:{i}" for i in range(len(CL.devices))]
from tinygrad.runtime.ops_gpu import CLDevice
devices = [f"gpu:{i}" for i in range(len(CLDevice.device_ids))]
world_size = len(devices)
# ensure that the batch size is divisible by the number of devices