mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-04-23 03:00:31 -04:00
add support for Apple hardware using MPS acceleration
This commit is contained in:
@@ -4,6 +4,7 @@ import torch
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from ldm.dream.devices import choose_torch_device
|
||||
|
||||
from ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters,
|
||||
@@ -14,17 +15,17 @@ from ldm.modules.diffusionmodules.util import (
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
|
||||
def __init__(self, model, schedule='linear', device=None, **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
self.device = device
|
||||
self.device = device or choose_torch_device()
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device(self.device):
|
||||
attr = attr.to(torch.device(self.device))
|
||||
attr = attr.to(dtype=torch.float32, device=self.device)
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(
|
||||
|
||||
Reference in New Issue
Block a user