add support for Apple hardware using MPS acceleration

This commit is contained in:
Lincoln Stein
2022-08-31 00:33:23 -04:00
parent 1714816fe2
commit bdb0651eb2
16 changed files with 361 additions and 52 deletions

11
ldm/dream/devices.py Normal file
View File

@@ -0,0 +1,11 @@
import torch
def choose_torch_device() -> str:
'''Convenience routine for guessing which GPU device to run model on'''
if torch.cuda.is_available():
return 'cuda'
if torch.backends.mps.is_available():
return 'mps'
return 'cpu'