add configs for training unconditional/class-conditional ldms

This commit is contained in:
ablattmann
2021-12-22 15:57:23 +01:00
parent f8b4a07105
commit 171cf29fb5
13 changed files with 562 additions and 53 deletions

View File

@@ -259,3 +259,9 @@ class HybridConditioner(nn.Module):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()