Remove accelerate library

This library is not required to use k-diffusion
Make k-diffusion wrapper closer to the other samplers
This commit is contained in:
BlueAmulet
2022-08-25 11:04:57 -06:00
parent 1c8ecacddf
commit 39b55ae016
5 changed files with 13 additions and 26 deletions

View File

@@ -2,7 +2,6 @@
import k_diffusion as K
import torch
import torch.nn as nn
import accelerate
class CFGDenoiser(nn.Module):
def __init__(self, model):
@@ -17,12 +16,11 @@ class CFGDenoiser(nn.Module):
return uncond + (cond - uncond) * cond_scale
class KSampler(object):
def __init__(self,model,schedule="lms", **kwargs):
def __init__(self, model, schedule="lms", device="cuda", **kwargs):
super().__init__()
self.model = K.external.CompVisDenoiser(model)
self.accelerator = accelerate.Accelerator()
self.device = self.accelerator.device
self.model = K.external.CompVisDenoiser(model)
self.schedule = schedule
self.device = device
def forward(self, x, sigma, uncond, cond, cond_scale):
x_in = torch.cat([x] * 2)
@@ -67,8 +65,5 @@ class KSampler(object):
x = torch.randn([batch_size, *shape], device=self.device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(self.model)
extra_args = {'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': unconditional_guidance_scale}
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not self.accelerator.is_main_process),
return (K.sampling.__dict__[f'sample_{self.schedule}'](model_wrap_cfg, x, sigmas, extra_args=extra_args),
None)
def gather(samples_ddim):
return self.accelerator.gather(samples_ddim)

View File

@@ -467,17 +467,17 @@ The vast majority of these arguments default to reasonable values.
elif self.sampler_name == 'ddim':
self.sampler = DDIMSampler(self.model, device=self.device)
elif self.sampler_name == 'k_dpm_2_a':
self.sampler = KSampler(self.model,'dpm_2_ancestral')
self.sampler = KSampler(self.model, 'dpm_2_ancestral', device=self.device)
elif self.sampler_name == 'k_dpm_2':
self.sampler = KSampler(self.model,'dpm_2')
self.sampler = KSampler(self.model, 'dpm_2', device=self.device)
elif self.sampler_name == 'k_euler_a':
self.sampler = KSampler(self.model,'euler_ancestral')
self.sampler = KSampler(self.model, 'euler_ancestral', device=self.device)
elif self.sampler_name == 'k_euler':
self.sampler = KSampler(self.model,'euler')
self.sampler = KSampler(self.model, 'euler', device=self.device)
elif self.sampler_name == 'k_heun':
self.sampler = KSampler(self.model,'heun')
self.sampler = KSampler(self.model, 'heun', device=self.device)
elif self.sampler_name == 'k_lms':
self.sampler = KSampler(self.model,'lms')
self.sampler = KSampler(self.model, 'lms', device=self.device)
else:
msg = f'unsupported sampler {self.sampler_name}, defaulting to plms'
self.sampler = PLMSSampler(self.model, device=self.device)