mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-05 05:44:58 -05:00
add support for Apple hardware using MPS acceleration
This commit is contained in:
@@ -18,6 +18,7 @@ from pytorch_lightning import seed_everything
|
||||
from ldm.util import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.models.diffusion.plms import PLMSSampler
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
|
||||
|
||||
def chunk(it, size):
|
||||
@@ -40,7 +41,7 @@ def load_model_from_config(config, ckpt, verbose=False):
|
||||
print("unexpected keys:")
|
||||
print(u)
|
||||
|
||||
model.cuda()
|
||||
model.to(choose_torch_device())
|
||||
model.eval()
|
||||
return model
|
||||
|
||||
@@ -199,7 +200,7 @@ def main():
|
||||
config = OmegaConf.load(f"{opt.config}")
|
||||
model = load_model_from_config(config, f"{opt.ckpt}")
|
||||
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
device = choose_torch_device()
|
||||
model = model.to(device)
|
||||
|
||||
if opt.plms:
|
||||
@@ -241,8 +242,10 @@ def main():
|
||||
print(f"target t_enc is {t_enc} steps")
|
||||
|
||||
precision_scope = autocast if opt.precision == "autocast" else nullcontext
|
||||
if device.type in ['mps', 'cpu']:
|
||||
precision_scope = nullcontext # have to use f32 on mps
|
||||
with torch.no_grad():
|
||||
with precision_scope("cuda"):
|
||||
with precision_scope(device.type):
|
||||
with model.ema_scope():
|
||||
tic = time.time()
|
||||
all_samples = list()
|
||||
|
||||
Reference in New Issue
Block a user