This commit is contained in:
Damian at mba
2022-10-21 15:07:11 +02:00
parent e574a1574f
commit 64051d081c
6 changed files with 22 additions and 201 deletions

View File

@@ -51,9 +51,8 @@ class Img2Img(Generator):
img_callback = step_callback,
unconditional_guidance_scale=cfg_scale,
unconditional_conditioning=uc,
init_latent = self.init_latent,
init_latent = self.init_latent, # changes how noising is performed in ksampler
extra_conditioning_info = extra_conditioning_info
# changes how noising is performed in ksampler
)
return self.sample_to_image(samples)

View File

@@ -29,9 +29,9 @@ work fine.
import torch
import numpy as np
from models.clipseg import CLIPDensePredT
from clipseg_models.clipseg import CLIPDensePredT
from einops import rearrange, repeat
from PIL import Image
from PIL import Image, ImageOps
from torchvision import transforms
CLIP_VERSION = 'ViT-B/16'
@@ -50,9 +50,14 @@ class SegmentedGrayscale(object):
discrete_heatmap = self.heatmap.lt(threshold).int()
return self._rescale(Image.fromarray(np.uint8(discrete_heatmap*255),mode='L'))
def to_transparent(self)->Image:
def to_transparent(self,invert:bool=False)->Image:
transparent_image = self.image.copy()
transparent_image.putalpha(self.to_grayscale())
gs = self.to_grayscale()
# The following line looks like a bug, but isn't.
# For img2img, we want the selected regions to be transparent,
# but to_grayscale() returns the opposite.
gs = ImageOps.invert(gs) if not invert else gs
transparent_image.putalpha(gs)
return transparent_image
# unscales and uncrops the 352x352 heatmap so that it matches the image again
@@ -79,7 +84,7 @@ class Txt2Mask(object):
self.model.load_state_dict(torch.load(CLIPSEG_WEIGHTS, map_location=torch.device('cpu')), strict=False)
@torch.no_grad()
def segment(self, image:Image, prompt:str) -> SegmentedGrayscale:
def segment(self, image, prompt:str) -> SegmentedGrayscale:
'''
Given a prompt string such as "a bagel", tries to identify the object in the
provided image and returns a SegmentedGrayscale object in which the brighter
@@ -94,6 +99,10 @@ class Txt2Mask(object):
transforms.Resize((CLIPSEG_SIZE, CLIPSEG_SIZE)), # must be multiple of 64...
])
if type(image) is str:
image = Image.open(image).convert('RGB')
image = ImageOps.exif_transpose(image)
img = self._scale_and_crop(image)
img = transform(img).unsqueeze(0)

View File

@@ -1,5 +1,4 @@
"""SAMPLING ONLY."""
from typing import Union
import torch
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@@ -29,7 +28,7 @@ class DDIMSampler(Sampler):
def p_sample(
self,
x,
c: Union[torch.Tensor, list],
c,
t,
index,
repeat_noise=False,

View File

@@ -8,7 +8,7 @@ import numpy as np
from einops import rearrange
from ldm.util import instantiate_from_config
#from ldm.modules.attention import LinearAttention
from ldm.modules.attention import LinearAttention
import psutil
@@ -151,10 +151,10 @@ class ResnetBlock(nn.Module):
return x + h
#class LinAttnBlock(LinearAttention):
# """to match AttnBlock usage"""
# def __init__(self, in_channels):
# super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):