WIP: bring cross-attention support to the PLMS and DDIM samplers

This commit is contained in:
Damian at mba
2022-10-18 22:09:06 +02:00
parent 09f62032ec
commit 54e6a68acb
6 changed files with 112 additions and 63 deletions

View File

@@ -192,6 +192,7 @@ class Sampler(object):
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
steps=S,
**kwargs
)
return samples, intermediates
@@ -216,6 +217,7 @@ class Sampler(object):
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
steps=None,
**kwargs
):
b = shape[0]
time_range = (
@@ -233,7 +235,7 @@ class Sampler(object):
dynamic_ncols=True,
)
old_eps = []
self.prepare_to_sample(t_enc=total_steps)
self.prepare_to_sample(t_enc=total_steps,**kwargs)
img = self.get_initial_image(x_T,shape,total_steps)
# probably don't need this at all
@@ -323,7 +325,7 @@ class Sampler(object):
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
x0 = init_latent
self.prepare_to_sample(t_enc=total_steps)
self.prepare_to_sample(t_enc=total_steps,**kwargs)
for i, step in enumerate(iterator):
index = total_steps - i - 1
@@ -414,5 +416,3 @@ class Sampler(object):
'''
return self.model.q_sample(x0,ts)