plms works, bugs quashed

- The plms sampler now works with custom inpainting model
- Quashed bug that was causing generation on normal models to fail (oops!)
- Can now generate non-square images with custom inpainting model

Credits for advice and assistance during porting:

@any-winter-4079 (http://github.com/any-winter-4079)
@db3000 (Danny Beer http://github.com/db3000)
This commit is contained in:
Lincoln Stein
2022-10-25 11:42:30 -04:00
parent b101be041b
commit e33971fe2c
5 changed files with 33 additions and 28 deletions

View File

@@ -439,3 +439,24 @@ class Sampler(object):
def conditioning_key(self)->str:
return self.model.model.conditioning_key
def make_cond_in(self, uncond, cond):
'''
This handles the choice between a conditional conditioning
that is a tensor (used by cross attention) vs one that is a dict
used by 'hybrid'
'''
if isinstance(cond, dict):
assert isinstance(uncond, dict)
cond_in = dict()
for k in cond:
if isinstance(cond[k], list):
cond_in[k] = [
torch.cat([uncond[k][i], cond[k][i]])
for i in range(len(cond[k]))
]
else:
cond_in[k] = torch.cat([uncond[k], cond[k]])
else:
cond_in = torch.cat([uncond, cond])
return cond_in