mirror of
https://github.com/invoke-ai/InvokeAI.git
synced 2026-02-19 00:14:18 -05:00
Apply black
This commit is contained in:
@@ -8,25 +8,26 @@ from main import instantiate_from_config
|
||||
from ldm.models.diffusion.ddim import DDIMSampler
|
||||
from ldm.invoke.devices import choose_torch_device
|
||||
|
||||
|
||||
def make_batch(image, mask, device):
|
||||
image = np.array(Image.open(image).convert("RGB"))
|
||||
image = image.astype(np.float32)/255.0
|
||||
image = image[None].transpose(0,3,1,2)
|
||||
image = image.astype(np.float32) / 255.0
|
||||
image = image[None].transpose(0, 3, 1, 2)
|
||||
image = torch.from_numpy(image)
|
||||
|
||||
mask = np.array(Image.open(mask).convert("L"))
|
||||
mask = mask.astype(np.float32)/255.0
|
||||
mask = mask[None,None]
|
||||
mask = mask.astype(np.float32) / 255.0
|
||||
mask = mask[None, None]
|
||||
mask[mask < 0.5] = 0
|
||||
mask[mask >= 0.5] = 1
|
||||
mask = torch.from_numpy(mask)
|
||||
|
||||
masked_image = (1-mask)*image
|
||||
masked_image = (1 - mask) * image
|
||||
|
||||
batch = {"image": image, "mask": mask, "masked_image": masked_image}
|
||||
for k in batch:
|
||||
batch[k] = batch[k].to(device=device)
|
||||
batch[k] = batch[k]*2.0-1.0
|
||||
batch[k] = batch[k] * 2.0 - 1.0
|
||||
return batch
|
||||
|
||||
|
||||
@@ -58,11 +59,10 @@ if __name__ == "__main__":
|
||||
|
||||
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
|
||||
model = instantiate_from_config(config.model)
|
||||
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
|
||||
strict=False)
|
||||
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False)
|
||||
|
||||
device = choose_torch_device()
|
||||
model = model.to(device)
|
||||
device = choose_torch_device()
|
||||
model = model.to(device)
|
||||
sampler = DDIMSampler(model)
|
||||
|
||||
os.makedirs(opt.outdir, exist_ok=True)
|
||||
@@ -74,25 +74,19 @@ if __name__ == "__main__":
|
||||
|
||||
# encode masked image and concat downsampled mask
|
||||
c = model.cond_stage_model.encode(batch["masked_image"])
|
||||
cc = torch.nn.functional.interpolate(batch["mask"],
|
||||
size=c.shape[-2:])
|
||||
cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
|
||||
c = torch.cat((c, cc), dim=1)
|
||||
|
||||
shape = (c.shape[1]-1,)+c.shape[2:]
|
||||
samples_ddim, _ = sampler.sample(S=opt.steps,
|
||||
conditioning=c,
|
||||
batch_size=c.shape[0],
|
||||
shape=shape,
|
||||
verbose=False)
|
||||
shape = (c.shape[1] - 1,) + c.shape[2:]
|
||||
samples_ddim, _ = sampler.sample(
|
||||
S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
|
||||
)
|
||||
x_samples_ddim = model.decode_first_stage(samples_ddim)
|
||||
|
||||
image = torch.clamp((batch["image"]+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
mask = torch.clamp((batch["mask"]+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
|
||||
min=0.0, max=1.0)
|
||||
image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
|
||||
|
||||
inpainted = (1-mask)*image+mask*predicted_image
|
||||
inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
|
||||
inpainted = (1 - mask) * image + mask * predicted_image
|
||||
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
|
||||
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
|
||||
|
||||
Reference in New Issue
Block a user