add support for Apple hardware using MPS acceleration

This commit is contained in:
Lincoln Stein
2022-08-31 00:33:23 -04:00
parent 1714816fe2
commit bdb0651eb2
16 changed files with 361 additions and 52 deletions

View File

@@ -4,6 +4,7 @@ import torch
import numpy as np
from tqdm import tqdm
from functools import partial
from ldm.dream.devices import choose_torch_device
from ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters,
@@ -14,17 +15,17 @@ from ldm.modules.diffusionmodules.util import (
class DDIMSampler(object):
def __init__(self, model, schedule='linear', device='cuda', **kwargs):
def __init__(self, model, schedule='linear', device=None, **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
self.device = device
self.device = device or choose_torch_device()
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device(self.device):
attr = attr.to(torch.device(self.device))
attr = attr.to(dtype=torch.float32, device=self.device)
setattr(self, name, attr)
def make_schedule(