Apply black

commit 218b6d0546
parent 2183dba5c5
Author: Martin Kristiansen
Date:   2023-07-27 10:54:01 -04:00

148 changed files with 5486 additions and 6296 deletions
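
Every change in this diff is mechanical output of the Black formatter: string quotes are normalized to double quotes, bare float literals such as 2. become 2.0, and call sites are collapsed or exploded against a fixed line length. For context, a minimal sketch of reproducing one such rewrite programmatically; the line length configured for this commit is not stated in the diff, so 120 below is an assumption based on the width of the reflowed lines.

    import black

    # A single-quoted key, a bare float, and a magic trailing comma.
    src = "x = {'a': 1.,}\n"

    # line_length=120 is an assumption; Black's default is 88.
    formatted = black.format_str(src, mode=black.Mode(line_length=120))
    print(formatted)
    # x = {
    #     "a": 1.0,
    # }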


@@ -18,7 +18,7 @@ from pytorch_lightning import seed_everything
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device
def chunk(it, size):
@@ -55,7 +55,7 @@ def load_img(path):
image = np.array(image).astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
return 2.*image - 1.
return 2.0 * image - 1.0
def main():
@@ -66,33 +66,24 @@ def main():
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render",
)
parser.add_argument(
"--init-img",
type=str,
nargs="?",
help="path to the input image"
)
parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/img2img-samples"
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/img2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action='store_true',
action="store_true",
help="do not save individual samples. For speed measurements.",
)
@@ -105,12 +96,12 @@ def main():
parser.add_argument(
"--plms",
action='store_true',
action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--fixed_code",
action='store_true',
action="store_true",
help="if enabled, uses the same starting code across all samples ",
)
@@ -187,11 +178,7 @@ def main():
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
opt = parser.parse_args()
@@ -232,18 +219,18 @@ def main():
assert os.path.isfile(opt.init_img)
init_image = load_img(opt.init_img).to(device)
init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)
init_image = repeat(init_image, "1 ... -> b ...", b=batch_size)
init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space
sampler.make_schedule(ddim_num_steps=opt.ddim_steps, ddim_eta=opt.ddim_eta, verbose=False)
assert 0. <= opt.strength <= 1., 'can only work with strength in [0.0, 1.0]'
assert 0.0 <= opt.strength <= 1.0, "can only work with strength in [0.0, 1.0]"
t_enc = int(opt.strength * opt.ddim_steps)
print(f"target t_enc is {t_enc} steps")
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
if device.type in ["mps", "cpu"]:
precision_scope = nullcontext # have to use f32 on mps
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
@@ -259,37 +246,42 @@ def main():
c = model.get_learned_conditioning(prompts)
# encode (scaled latent)
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))
z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc] * batch_size).to(device))
# decode it
samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,)
samples = sampler.decode(
z_enc,
c,
t_enc,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
)
x_samples = model.decode_first_stage(samples)
x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
if not opt.skip_save:
for x_sample in x_samples:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
all_samples.append(x_samples)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
if __name__ == "__main__":
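
The mix of collapsed and exploded calls in the file above follows Black's magic trailing comma rule: a call whose last argument is followed by a comma keeps one argument per line, while a call without one is collapsed whenever it fits the line length. Both outcomes appear verbatim in this diff:

    # Collapsed: fits on one line and has no trailing comma.
    parser.add_argument("--init-img", type=str, nargs="?", help="path to the input image")

    # Exploded: the comma after the last keyword argument pins this layout.
    samples = sampler.decode(
        z_enc,
        c,
        t_enc,
        unconditional_guidance_scale=opt.scale,
        unconditional_conditioning=uc,
    )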


@@ -8,25 +8,26 @@ from main import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.invoke.devices import choose_torch_device
def make_batch(image, mask, device):
image = np.array(Image.open(image).convert("RGB"))
image = image.astype(np.float32)/255.0
image = image[None].transpose(0,3,1,2)
image = image.astype(np.float32) / 255.0
image = image[None].transpose(0, 3, 1, 2)
image = torch.from_numpy(image)
mask = np.array(Image.open(mask).convert("L"))
mask = mask.astype(np.float32)/255.0
mask = mask[None,None]
mask = mask.astype(np.float32) / 255.0
mask = mask[None, None]
mask[mask < 0.5] = 0
mask[mask >= 0.5] = 1
mask = torch.from_numpy(mask)
masked_image = (1-mask)*image
masked_image = (1 - mask) * image
batch = {"image": image, "mask": mask, "masked_image": masked_image}
for k in batch:
batch[k] = batch[k].to(device=device)
batch[k] = batch[k]*2.0-1.0
batch[k] = batch[k] * 2.0 - 1.0
return batch
@@ -58,11 +59,10 @@ if __name__ == "__main__":
config = OmegaConf.load("models/ldm/inpainting_big/config.yaml")
model = instantiate_from_config(config.model)
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"],
strict=False)
model.load_state_dict(torch.load("models/ldm/inpainting_big/last.ckpt")["state_dict"], strict=False)
device = choose_torch_device()
model = model.to(device)
sampler = DDIMSampler(model)
os.makedirs(opt.outdir, exist_ok=True)
@@ -74,25 +74,19 @@ if __name__ == "__main__":
# encode masked image and concat downsampled mask
c = model.cond_stage_model.encode(batch["masked_image"])
cc = torch.nn.functional.interpolate(batch["mask"],
size=c.shape[-2:])
cc = torch.nn.functional.interpolate(batch["mask"], size=c.shape[-2:])
c = torch.cat((c, cc), dim=1)
shape = (c.shape[1]-1,)+c.shape[2:]
samples_ddim, _ = sampler.sample(S=opt.steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False)
shape = (c.shape[1] - 1,) + c.shape[2:]
samples_ddim, _ = sampler.sample(
S=opt.steps, conditioning=c, batch_size=c.shape[0], shape=shape, verbose=False
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
image = torch.clamp((batch["image"]+1.0)/2.0,
min=0.0, max=1.0)
mask = torch.clamp((batch["mask"]+1.0)/2.0,
min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim+1.0)/2.0,
min=0.0, max=1.0)
image = torch.clamp((batch["image"] + 1.0) / 2.0, min=0.0, max=1.0)
mask = torch.clamp((batch["mask"] + 1.0) / 2.0, min=0.0, max=1.0)
predicted_image = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
inpainted = (1-mask)*image+mask*predicted_image
inpainted = inpainted.cpu().numpy().transpose(0,2,3,1)[0]*255
inpainted = (1 - mask) * image + mask * predicted_image
inpainted = inpainted.cpu().numpy().transpose(0, 2, 3, 1)[0] * 255
Image.fromarray(inpainted.astype(np.uint8)).save(outpath)
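
The final composite above keeps the original pixels wherever the mask is 0 and the model's prediction wherever it is 1. A self-contained sketch of the same blend on placeholder tensors:

    import torch

    image = torch.rand(1, 3, 8, 8)                  # original image in [0, 1]
    predicted = torch.rand(1, 3, 8, 8)              # model output, same range
    mask = (torch.rand(1, 1, 8, 8) > 0.5).float()   # binary mask, 1 = inpaint

    # Same convex blend as in the script: untouched where mask == 0,
    # predicted where mask == 1.
    inpainted = (1 - mask) * image + mask * predicted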


@@ -59,29 +59,24 @@ def load_model_from_config(config, ckpt, verbose=False):
class Searcher(object):
def __init__(self, database, retriever_version='ViT-L/14'):
def __init__(self, database, retriever_version="ViT-L/14"):
assert database in DATABASES
# self.database = self.load_database(database)
self.database_name = database
self.searcher_savedir = f'data/rdm/searchers/{self.database_name}'
self.database_path = f'data/rdm/retrieval_databases/{self.database_name}'
self.searcher_savedir = f"data/rdm/searchers/{self.database_name}"
self.database_path = f"data/rdm/retrieval_databases/{self.database_name}"
self.retriever = self.load_retriever(version=retriever_version)
self.database = {'embedding': [],
'img_id': [],
'patch_coords': []}
self.database = {"embedding": [], "img_id": [], "patch_coords": []}
self.load_database()
self.load_searcher()
def train_searcher(self, k,
metric='dot_product',
searcher_savedir=None):
print('Start training searcher')
searcher = scann.scann_ops_pybind.builder(self.database['embedding'] /
np.linalg.norm(self.database['embedding'], axis=1)[:, np.newaxis],
k, metric)
def train_searcher(self, k, metric="dot_product", searcher_savedir=None):
print("Start training searcher")
searcher = scann.scann_ops_pybind.builder(
self.database["embedding"] / np.linalg.norm(self.database["embedding"], axis=1)[:, np.newaxis], k, metric
)
self.searcher = searcher.score_brute_force().build()
print('Finish training searcher')
print("Finish training searcher")
if searcher_savedir is not None:
print(f'Save trained searcher under "{searcher_savedir}"')
@@ -91,36 +86,40 @@ class Searcher(object):
def load_single_file(self, saved_embeddings):
compressed = np.load(saved_embeddings)
self.database = {key: compressed[key] for key in compressed.files}
print('Finished loading of clip embeddings.')
print("Finished loading of clip embeddings.")
def load_multi_files(self, data_archive):
out_data = {key: [] for key in self.database}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
for key in d.files:
out_data[key].append(d[key])
return out_data
def load_database(self):
print(f'Load saved patch embedding from "{self.database_path}"')
file_content = glob.glob(os.path.join(self.database_path, '*.npz'))
file_content = glob.glob(os.path.join(self.database_path, "*.npz"))
if len(file_content) == 1:
self.load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(self.load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
prefetched_data = parallel_data_prefetch(
self.load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
)
self.database = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in
self.database}
self.database = {
key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in self.database
}
else:
raise ValueError(f'No npz-files in specified path "{self.database_path}" is this directory existing?')
print(f'Finished loading of retrieval database of length {self.database["embedding"].shape[0]}.')
def load_retriever(self, version='ViT-L/14', ):
def load_retriever(
self,
version="ViT-L/14",
):
model = FrozenClipImageEmbedder(model=version)
if torch.cuda.is_available():
model.cuda()
@@ -128,14 +127,14 @@ class Searcher(object):
return model
def load_searcher(self):
print(f'load searcher for database {self.database_name} from {self.searcher_savedir}')
print(f"load searcher for database {self.database_name} from {self.searcher_savedir}")
self.searcher = scann.scann_ops_pybind.load_searcher(self.searcher_savedir)
print('Finished loading searcher.')
print("Finished loading searcher.")
def search(self, x, k):
if self.searcher is None and self.database['embedding'].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, 'Cannot search with uninitialized searcher'
if self.searcher is None and self.database["embedding"].shape[0] < 2e4:
self.train_searcher(k) # quickly fit searcher on the fly for small databases
assert self.searcher is not None, "Cannot search with uninitialized searcher"
if isinstance(x, torch.Tensor):
x = x.detach().cpu().numpy()
if len(x.shape) == 3:
@@ -146,17 +145,19 @@ class Searcher(object):
nns, distances = self.searcher.search_batched(query_embeddings, final_num_neighbors=k)
end = time.time()
out_embeddings = self.database['embedding'][nns]
out_img_ids = self.database['img_id'][nns]
out_pc = self.database['patch_coords'][nns]
out_embeddings = self.database["embedding"][nns]
out_img_ids = self.database["img_id"][nns]
out_pc = self.database["patch_coords"][nns]
out = {'nn_embeddings': out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
'img_ids': out_img_ids,
'patch_coords': out_pc,
'queries': x,
'exec_time': end - start,
'nns': nns,
'q_embeddings': query_embeddings}
out = {
"nn_embeddings": out_embeddings / np.linalg.norm(out_embeddings, axis=-1)[..., np.newaxis],
"img_ids": out_img_ids,
"patch_coords": out_pc,
"queries": x,
"exec_time": end - start,
"nns": nns,
"q_embeddings": query_embeddings,
}
return out
@@ -173,20 +174,16 @@ if __name__ == "__main__":
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
@@ -206,7 +203,7 @@ if __name__ == "__main__":
parser.add_argument(
"--plms",
action='store_true',
action="store_true",
help="use plms sampling",
)
@@ -287,14 +284,14 @@ if __name__ == "__main__":
parser.add_argument(
"--database",
type=str,
default='artbench-surrealism',
default="artbench-surrealism",
choices=DATABASES,
help="The database used for the search, only applied when --use_neighbors=True",
)
parser.add_argument(
"--use_neighbors",
default=False,
action='store_true',
action="store_true",
help="Include neighbors in addition to text prompt for conditioning",
)
parser.add_argument(
@@ -358,41 +355,43 @@ if __name__ == "__main__":
uc = None
if searcher is not None:
nn_dict = searcher(c, opt.knn)
c = torch.cat([c, torch.from_numpy(nn_dict['nn_embeddings']).cuda()], dim=1)
c = torch.cat([c, torch.from_numpy(nn_dict["nn_embeddings"]).cuda()], dim=1)
if opt.scale != 1.0:
uc = torch.zeros_like(c)
if isinstance(prompts, tuple):
prompts = list(prompts)
shape = [16, opt.H // 16, opt.W // 16] # note: currently hardcoded for f16 model
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
samples_ddim, _ = sampler.sample(
S=opt.ddim_steps,
conditioning=c,
batch_size=c.shape[0],
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
)
x_samples_ddim = model.decode_first_stage(samples_ddim)
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
all_samples.append(x_samples_ddim)
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
print(f"Your samples are ready and waiting for you here: \n{outpath} \nEnjoy.")


@@ -25,15 +25,19 @@ from pytorch_lightning.utilities import rank_zero_info
from ldm.data.base import Txt2ImgIterableBaseDataset
from ldm.util import instantiate_from_config
def fix_func(orig):
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
def new_func(*args, **kw):
device = kw.get("device", "mps")
kw["device"]="cpu"
kw["device"] = "cpu"
return orig(*args, **kw).to(device)
return new_func
return orig
torch.rand = fix_func(torch.rand)
torch.rand_like = fix_func(torch.rand_like)
torch.randn = fix_func(torch.randn)
@@ -43,18 +47,19 @@ torch.randint_like = fix_func(torch.randint_like)
torch.bernoulli = fix_func(torch.bernoulli)
torch.multinomial = fix_func(torch.multinomial)
def load_model_from_config(config, ckpt, verbose=False):
print(f'Loading model from {ckpt}')
pl_sd = torch.load(ckpt, map_location='cpu')
sd = pl_sd['state_dict']
print(f"Loading model from {ckpt}")
pl_sd = torch.load(ckpt, map_location="cpu")
sd = pl_sd["state_dict"]
config.model.params.ckpt_path = ckpt
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print('missing keys:')
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print('unexpected keys:')
print("unexpected keys:")
print(u)
if torch.cuda.is_available():
@@ -66,132 +71,130 @@ def get_parser(**parser_kwargs):
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
if v.lower() in ("yes", "true", "t", "y", "1"):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
elif v.lower() in ("no", "false", "f", "n", "0"):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
raise argparse.ArgumentTypeError("Boolean value expected.")
parser = argparse.ArgumentParser(**parser_kwargs)
parser.add_argument(
'-n',
'--name',
"-n",
"--name",
type=str,
const=True,
default='',
nargs='?',
help='postfix for logdir',
default="",
nargs="?",
help="postfix for logdir",
)
parser.add_argument(
'-r',
'--resume',
"-r",
"--resume",
type=str,
const=True,
default='',
nargs='?',
help='resume from logdir or checkpoint in logdir',
default="",
nargs="?",
help="resume from logdir or checkpoint in logdir",
)
parser.add_argument(
'-b',
'--base',
nargs='*',
metavar='base_config.yaml',
help='paths to base configs. Loaded from left-to-right. '
'Parameters can be overwritten or added with command-line options of the form `--key value`.',
"-b",
"--base",
nargs="*",
metavar="base_config.yaml",
help="paths to base configs. Loaded from left-to-right. "
"Parameters can be overwritten or added with command-line options of the form `--key value`.",
default=list(),
)
parser.add_argument(
'-t',
'--train',
"-t",
"--train",
type=str2bool,
const=True,
default=False,
nargs='?',
help='train',
nargs="?",
help="train",
)
parser.add_argument(
'--no-test',
"--no-test",
type=str2bool,
const=True,
default=False,
nargs='?',
help='disable test',
nargs="?",
help="disable test",
)
parser.add_argument("-p", "--project", help="name of new or path to existing project")
parser.add_argument(
'-p', '--project', help='name of new or path to existing project'
)
parser.add_argument(
'-d',
'--debug',
"-d",
"--debug",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=False,
help='enable post-mortem debugging',
help="enable post-mortem debugging",
)
parser.add_argument(
'-s',
'--seed',
"-s",
"--seed",
type=int,
default=23,
help='seed for seed_everything',
help="seed for seed_everything",
)
parser.add_argument(
'-f',
'--postfix',
"-f",
"--postfix",
type=str,
default='',
help='post-postfix for default name',
default="",
help="post-postfix for default name",
)
parser.add_argument(
'-l',
'--logdir',
"-l",
"--logdir",
type=str,
default='logs',
help='directory for logging dat shit',
default="logs",
help="directory for logging dat shit",
)
parser.add_argument(
'--scale_lr',
"--scale_lr",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=True,
help='scale base-lr by ngpu * batch_size * n_accumulate',
help="scale base-lr by ngpu * batch_size * n_accumulate",
)
parser.add_argument(
'--datadir_in_name',
"--datadir_in_name",
type=str2bool,
nargs='?',
nargs="?",
const=True,
default=True,
help='Prepend the final directory in the data_root to the output directory name',
help="Prepend the final directory in the data_root to the output directory name",
)
parser.add_argument(
'--actual_resume',
"--actual_resume",
type=str,
default='',
help='Path to model to actually resume from',
default="",
help="Path to model to actually resume from",
)
parser.add_argument(
'--data_root',
"--data_root",
type=str,
required=True,
help='Path to directory with training images',
help="Path to directory with training images",
)
parser.add_argument(
'--embedding_manager_ckpt',
"--embedding_manager_ckpt",
type=str,
default='',
help='Initialize embedding manager from a checkpoint',
default="",
help="Initialize embedding manager from a checkpoint",
)
parser.add_argument(
'--init_word',
"--init_word",
type=str,
help='Word to use as source for initial token embedding.',
help="Word to use as source for initial token embedding.",
)
return parser
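
The str2bool converter above, combined with nargs="?" and const=True, lets each boolean option be passed as a bare flag, with an explicit value, or omitted entirely. A minimal self-contained sketch of the pattern:

    import argparse

    def str2bool(v):
        if isinstance(v, bool):
            return v
        if v.lower() in ("yes", "true", "t", "y", "1"):
            return True
        elif v.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    p = argparse.ArgumentParser()
    p.add_argument("--train", type=str2bool, nargs="?", const=True, default=False)

    print(p.parse_args(["--train"]).train)           # True  (bare flag -> const)
    print(p.parse_args(["--train", "false"]).train)  # False (value parsed by str2bool)
    print(p.parse_args([]).train)                    # False (default)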
@@ -226,9 +229,7 @@ def worker_init_fn(_):
if isinstance(dataset, Txt2ImgIterableBaseDataset):
split_size = dataset.num_records // worker_info.num_workers
# reset num_records to the true number to retain reliable length information
dataset.sample_ids = dataset.valid_ids[
worker_id * split_size : (worker_id + 1) * split_size
]
dataset.sample_ids = dataset.valid_ids[worker_id * split_size : (worker_id + 1) * split_size]
current_id = np.random.choice(len(np.random.get_state()[1]), 1)
return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
else:
@@ -252,25 +253,19 @@ class DataModuleFromConfig(pl.LightningDataModule):
super().__init__()
self.batch_size = batch_size
self.dataset_configs = dict()
self.num_workers = (
num_workers if num_workers is not None else batch_size * 2
)
self.num_workers = num_workers if num_workers is not None else batch_size * 2
self.use_worker_init_fn = use_worker_init_fn
if train is not None:
self.dataset_configs['train'] = train
self.dataset_configs["train"] = train
self.train_dataloader = self._train_dataloader
if validation is not None:
self.dataset_configs['validation'] = validation
self.val_dataloader = partial(
self._val_dataloader, shuffle=shuffle_val_dataloader
)
self.dataset_configs["validation"] = validation
self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader)
if test is not None:
self.dataset_configs['test'] = test
self.test_dataloader = partial(
self._test_dataloader, shuffle=shuffle_test_loader
)
self.dataset_configs["test"] = test
self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader)
if predict is not None:
self.dataset_configs['predict'] = predict
self.dataset_configs["predict"] = predict
self.predict_dataloader = self._predict_dataloader
self.wrap = wrap
@@ -279,24 +274,19 @@ class DataModuleFromConfig(pl.LightningDataModule):
instantiate_from_config(data_cfg)
def setup(self, stage=None):
self.datasets = dict(
(k, instantiate_from_config(self.dataset_configs[k]))
for k in self.dataset_configs
)
self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs)
if self.wrap:
for k in self.datasets:
self.datasets[k] = WrappedDataset(self.datasets[k])
def _train_dataloader(self):
is_iterable_dataset = isinstance(
self.datasets['train'], Txt2ImgIterableBaseDataset
)
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets['train'],
self.datasets["train"],
batch_size=self.batch_size,
num_workers=self.num_workers,
shuffle=False if is_iterable_dataset else True,
@@ -304,15 +294,12 @@ class DataModuleFromConfig(pl.LightningDataModule):
)
def _val_dataloader(self, shuffle=False):
if (
isinstance(self.datasets['validation'], Txt2ImgIterableBaseDataset)
or self.use_worker_init_fn
):
if isinstance(self.datasets["validation"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets['validation'],
self.datasets["validation"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
@@ -320,9 +307,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
)
def _test_dataloader(self, shuffle=False):
is_iterable_dataset = isinstance(
self.datasets['train'], Txt2ImgIterableBaseDataset
)
is_iterable_dataset = isinstance(self.datasets["train"], Txt2ImgIterableBaseDataset)
if is_iterable_dataset or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
@@ -332,7 +317,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
shuffle = shuffle and (not is_iterable_dataset)
return DataLoader(
self.datasets['test'],
self.datasets["test"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
@@ -340,15 +325,12 @@ class DataModuleFromConfig(pl.LightningDataModule):
)
def _predict_dataloader(self, shuffle=False):
if (
isinstance(self.datasets['predict'], Txt2ImgIterableBaseDataset)
or self.use_worker_init_fn
):
if isinstance(self.datasets["predict"], Txt2ImgIterableBaseDataset) or self.use_worker_init_fn:
init_fn = worker_init_fn
else:
init_fn = None
return DataLoader(
self.datasets['predict'],
self.datasets["predict"],
batch_size=self.batch_size,
num_workers=self.num_workers,
worker_init_fn=init_fn,
@@ -356,9 +338,7 @@ class DataModuleFromConfig(pl.LightningDataModule):
class SetupCallback(Callback):
def __init__(
self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config
):
def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
super().__init__()
self.resume = resume
self.now = now
@@ -370,8 +350,8 @@ class SetupCallback(Callback):
def on_keyboard_interrupt(self, trainer, pl_module):
if trainer.global_rank == 0:
print('Summoning checkpoint.')
ckpt_path = os.path.join(self.ckptdir, 'last.ckpt')
print("Summoning checkpoint.")
ckpt_path = os.path.join(self.ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def on_pretrain_routine_start(self, trainer, pl_module):
@@ -381,36 +361,31 @@ class SetupCallback(Callback):
os.makedirs(self.ckptdir, exist_ok=True)
os.makedirs(self.cfgdir, exist_ok=True)
if 'callbacks' in self.lightning_config:
if (
'metrics_over_trainsteps_checkpoint'
in self.lightning_config['callbacks']
):
if "callbacks" in self.lightning_config:
if "metrics_over_trainsteps_checkpoint" in self.lightning_config["callbacks"]:
os.makedirs(
os.path.join(self.ckptdir, 'trainstep_checkpoints'),
os.path.join(self.ckptdir, "trainstep_checkpoints"),
exist_ok=True,
)
print('Project config')
print("Project config")
print(OmegaConf.to_yaml(self.config))
OmegaConf.save(
self.config,
os.path.join(self.cfgdir, '{}-project.yaml'.format(self.now)),
os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)),
)
print('Lightning config')
print("Lightning config")
print(OmegaConf.to_yaml(self.lightning_config))
OmegaConf.save(
OmegaConf.create({'lightning': self.lightning_config}),
os.path.join(
self.cfgdir, '{}-lightning.yaml'.format(self.now)
),
OmegaConf.create({"lightning": self.lightning_config}),
os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now)),
)
else:
# ModelCheckpoint callback created log directory --- remove it
if not self.resume and os.path.exists(self.logdir):
dst, name = os.path.split(self.logdir)
dst = os.path.join(dst, 'child_runs', name)
dst = os.path.join(dst, "child_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
try:
os.rename(self.logdir, dst)
@@ -435,10 +410,8 @@ class ImageLogger(Callback):
self.rescale = rescale
self.batch_freq = batch_frequency
self.max_images = max_images
self.logger_log_images = { }
self.log_steps = [
2**n for n in range(int(np.log2(self.batch_freq)) + 1)
]
self.logger_log_images = {}
self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
if not increase_log_steps:
self.log_steps = [self.batch_freq]
self.clamp = clamp
@@ -448,10 +421,8 @@ class ImageLogger(Callback):
self.log_first_step = log_first_step
@rank_zero_only
def log_local(
self, save_dir, split, images, global_step, current_epoch, batch_idx
):
root = os.path.join(save_dir, 'images', split)
def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
root = os.path.join(save_dir, "images", split)
for k in images:
grid = torchvision.utils.make_grid(images[k], nrow=4)
if self.rescale:
@@ -459,22 +430,16 @@ class ImageLogger(Callback):
grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
grid = grid.numpy()
grid = (grid * 255).astype(np.uint8)
filename = '{}_gs-{:06}_e-{:06}_b-{:06}.png'.format(
k, global_step, current_epoch, batch_idx
)
filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(k, global_step, current_epoch, batch_idx)
path = os.path.join(root, filename)
os.makedirs(os.path.split(path)[0], exist_ok=True)
Image.fromarray(grid).save(path)
def log_img(self, pl_module, batch, batch_idx, split='train'):
check_idx = (
batch_idx if self.log_on_batch_idx else pl_module.global_step
)
def log_img(self, pl_module, batch, batch_idx, split="train"):
check_idx = batch_idx if self.log_on_batch_idx else pl_module.global_step
if (
self.check_frequency(check_idx)
and hasattr( # batch_idx % self.batch_freq == 0
pl_module, 'log_images'
)
and hasattr(pl_module, "log_images") # batch_idx % self.batch_freq == 0
and callable(pl_module.log_images)
and self.max_images > 0
):
@@ -485,9 +450,7 @@ class ImageLogger(Callback):
pl_module.eval()
with torch.no_grad():
images = pl_module.log_images(
batch, split=split, **self.log_images_kwargs
)
images = pl_module.log_images(batch, split=split, **self.log_images_kwargs)
for k in images:
N = min(images[k].shape[0], self.max_images)
@@ -506,18 +469,16 @@ class ImageLogger(Callback):
batch_idx,
)
logger_log_images = self.logger_log_images.get(
logger, lambda *args, **kwargs: None
)
logger_log_images = self.logger_log_images.get(logger, lambda *args, **kwargs: None)
logger_log_images(pl_module, images, pl_module.global_step, split)
if is_train:
pl_module.train()
def check_frequency(self, check_idx):
if (
(check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)
) and (check_idx > 0 or self.log_first_step):
if ((check_idx % self.batch_freq) == 0 or (check_idx in self.log_steps)) and (
check_idx > 0 or self.log_first_step
):
try:
self.log_steps.pop(0)
except IndexError as e:
@@ -526,23 +487,15 @@ class ImageLogger(Callback):
return True
return False
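
check_frequency above fires on every multiple of batch_freq and, early in training, on the precomputed power-of-two steps in log_steps. For the batch_frequency of 750 used later in this file, that warm-up schedule is:

    import numpy as np

    batch_freq = 750
    log_steps = [2**n for n in range(int(np.log2(batch_freq)) + 1)]
    print(log_steps)  # [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]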
def on_train_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
if not self.disabled and (
pl_module.global_step > 0 or self.log_first_step
):
self.log_img(pl_module, batch, batch_idx, split='train')
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
if not self.disabled and (pl_module.global_step > 0 or self.log_first_step):
self.log_img(pl_module, batch, batch_idx, split="train")
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None
):
def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=None):
if not self.disabled and pl_module.global_step > 0:
self.log_img(pl_module, batch, batch_idx, split='val')
if hasattr(pl_module, 'calibrate_grad_norm'):
if (
pl_module.calibrate_grad_norm and batch_idx % 25 == 0
) and batch_idx > 0:
self.log_img(pl_module, batch, batch_idx, split="val")
if hasattr(pl_module, "calibrate_grad_norm"):
if (pl_module.calibrate_grad_norm and batch_idx % 25 == 0) and batch_idx > 0:
self.log_gradients(trainer, pl_module, batch_idx=batch_idx)
@@ -562,19 +515,17 @@ class CUDACallback(Callback):
try:
epoch_time = trainer.training_type_plugin.reduce(epoch_time)
rank_zero_info(f'Average Epoch time: {epoch_time:.2f} seconds')
rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")
if torch.cuda.is_available():
max_memory = (
torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
)
max_memory = torch.cuda.max_memory_allocated(trainer.root_gpu) / 2**20
max_memory = trainer.training_type_plugin.reduce(max_memory)
rank_zero_info(f'Average Peak memory {max_memory:.2f}MiB')
rank_zero_info(f"Average Peak memory {max_memory:.2f}MiB")
except AttributeError:
pass
class ModeSwapCallback(Callback):
def __init__(self, swap_step=2000):
super().__init__()
self.is_frozen = False
@@ -589,7 +540,8 @@ class ModeSwapCallback(Callback):
self.is_frozen = False
trainer.optimizers = [pl_module.configure_opt_model()]
if __name__ == '__main__':
if __name__ == "__main__":
# custom parser to specify config files, train, test and debug mode,
# postfix, resume.
# `--key value` arguments are interpreted as arguments to the trainer.
@@ -631,7 +583,7 @@ if __name__ == '__main__':
# params:
# key: value
now = datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')
now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
# add cwd for convenience and to make classes in this file available when
# running as `python main.py`
@@ -644,50 +596,47 @@ if __name__ == '__main__':
opt, unknown = parser.parse_known_args()
if opt.name and opt.resume:
raise ValueError(
'-n/--name and -r/--resume cannot be specified both.'
'If you want to resume training in a new log folder, '
'use -n/--name in combination with --resume_from_checkpoint'
"-n/--name and -r/--resume cannot be specified both."
"If you want to resume training in a new log folder, "
"use -n/--name in combination with --resume_from_checkpoint"
)
if opt.resume:
if not os.path.exists(opt.resume):
raise ValueError('Cannot find {}'.format(opt.resume))
raise ValueError("Cannot find {}".format(opt.resume))
if os.path.isfile(opt.resume):
paths = opt.resume.split('/')
paths = opt.resume.split("/")
# idx = len(paths)-paths[::-1].index("logs")+1
# logdir = "/".join(paths[:idx])
logdir = '/'.join(paths[:-2])
logdir = "/".join(paths[:-2])
ckpt = opt.resume
else:
assert os.path.isdir(opt.resume), opt.resume
logdir = opt.resume.rstrip('/')
ckpt = os.path.join(logdir, 'checkpoints', 'last.ckpt')
logdir = opt.resume.rstrip("/")
ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")
opt.resume_from_checkpoint = ckpt
base_configs = sorted(
glob.glob(os.path.join(logdir, 'configs/*.yaml'))
)
base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
opt.base = base_configs + opt.base
_tmp = logdir.split('/')
_tmp = logdir.split("/")
nowname = _tmp[-1]
else:
if opt.name:
name = '_' + opt.name
name = "_" + opt.name
elif opt.base:
cfg_fname = os.path.split(opt.base[0])[-1]
cfg_name = os.path.splitext(cfg_fname)[0]
name = '_' + cfg_name
name = "_" + cfg_name
else:
name = ''
name = ""
if opt.datadir_in_name:
now = os.path.basename(os.path.normpath(opt.data_root)) + now
nowname = now + name + opt.postfix
logdir = os.path.join(opt.logdir, nowname)
ckptdir = os.path.join(logdir, 'checkpoints')
cfgdir = os.path.join(logdir, 'configs')
ckptdir = os.path.join(logdir, "checkpoints")
cfgdir = os.path.join(logdir, "configs")
seed_everything(opt.seed)
try:
@@ -695,19 +644,19 @@ if __name__ == '__main__':
configs = [OmegaConf.load(cfg) for cfg in opt.base]
cli = OmegaConf.from_dotlist(unknown)
config = OmegaConf.merge(*configs, cli)
lightning_config = config.pop('lightning', OmegaConf.create())
lightning_config = config.pop("lightning", OmegaConf.create())
# merge trainer cli with config
trainer_config = lightning_config.get('trainer', OmegaConf.create())
trainer_config = lightning_config.get("trainer", OmegaConf.create())
# default to ddp
trainer_config['accelerator'] = 'auto'
trainer_config["accelerator"] = "auto"
for k in nondefault_trainer_args(opt):
trainer_config[k] = getattr(opt, k)
if not 'gpus' in trainer_config:
del trainer_config['accelerator']
if not "gpus" in trainer_config:
del trainer_config["accelerator"]
cpu = True
else:
gpuinfo = trainer_config['gpus']
print(f'Running on GPUs {gpuinfo}')
gpuinfo = trainer_config["gpus"]
print(f"Running on GPUs {gpuinfo}")
cpu = False
trainer_opt = argparse.Namespace(**trainer_config)
lightning_config.trainer = trainer_config
@@ -715,9 +664,7 @@ if __name__ == '__main__':
# model
# config.model.params.personalization_config.params.init_word = opt.init_word
config.model.params.personalization_config.params.embedding_manager_ckpt = (
opt.embedding_manager_ckpt
)
config.model.params.personalization_config.params.embedding_manager_ckpt = opt.embedding_manager_ckpt
if opt.init_word:
config.model.params.personalization_config.params.initializer_words = [opt.init_word]
@@ -731,142 +678,128 @@ if __name__ == '__main__':
trainer_kwargs = dict()
# default logger configs
def_logger = 'csv'
def_logger_target = 'CSVLogger'
def_logger = "csv"
def_logger_target = "CSVLogger"
default_logger_cfgs = {
'wandb': {
'target': 'pytorch_lightning.loggers.WandbLogger',
'params': {
'name': nowname,
'save_dir': logdir,
'offline': opt.debug,
'id': nowname,
"wandb": {
"target": "pytorch_lightning.loggers.WandbLogger",
"params": {
"name": nowname,
"save_dir": logdir,
"offline": opt.debug,
"id": nowname,
},
},
def_logger: {
'target': 'pytorch_lightning.loggers.' + def_logger_target,
'params': {
'name': def_logger,
'save_dir': logdir,
"target": "pytorch_lightning.loggers." + def_logger_target,
"params": {
"name": def_logger,
"save_dir": logdir,
},
},
}
default_logger_cfg = default_logger_cfgs[def_logger]
if 'logger' in lightning_config:
if "logger" in lightning_config:
logger_cfg = lightning_config.logger
else:
logger_cfg = OmegaConf.create()
logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
trainer_kwargs['logger'] = instantiate_from_config(logger_cfg)
trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)
# modelcheckpoint - use TrainResult/EvalResult(checkpoint_on=metric) to
# specify which metric is used to determine best models
default_modelckpt_cfg = {
'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
'dirpath': ckptdir,
'filename': '{epoch:06}',
'verbose': True,
'save_last': True,
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": ckptdir,
"filename": "{epoch:06}",
"verbose": True,
"save_last": True,
},
}
if hasattr(model, 'monitor'):
print(f'Monitoring {model.monitor} as checkpoint metric.')
default_modelckpt_cfg['params']['monitor'] = model.monitor
default_modelckpt_cfg['params']['save_top_k'] = 1
if hasattr(model, "monitor"):
print(f"Monitoring {model.monitor} as checkpoint metric.")
default_modelckpt_cfg["params"]["monitor"] = model.monitor
default_modelckpt_cfg["params"]["save_top_k"] = 1
if 'modelcheckpoint' in lightning_config:
if "modelcheckpoint" in lightning_config:
modelckpt_cfg = lightning_config.modelcheckpoint
else:
modelckpt_cfg = OmegaConf.create()
modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)
print(f'Merged modelckpt-cfg: \n{modelckpt_cfg}')
if version.parse(pl.__version__) < version.parse('1.4.0'):
trainer_kwargs['checkpoint_callback'] = instantiate_from_config(
modelckpt_cfg
)
print(f"Merged modelckpt-cfg: \n{modelckpt_cfg}")
if version.parse(pl.__version__) < version.parse("1.4.0"):
trainer_kwargs["checkpoint_callback"] = instantiate_from_config(modelckpt_cfg)
# add callback which sets up log directory
default_callbacks_cfg = {
'setup_callback': {
'target': 'main.SetupCallback',
'params': {
'resume': opt.resume,
'now': now,
'logdir': logdir,
'ckptdir': ckptdir,
'cfgdir': cfgdir,
'config': config,
'lightning_config': lightning_config,
"setup_callback": {
"target": "main.SetupCallback",
"params": {
"resume": opt.resume,
"now": now,
"logdir": logdir,
"ckptdir": ckptdir,
"cfgdir": cfgdir,
"config": config,
"lightning_config": lightning_config,
},
},
'image_logger': {
'target': 'main.ImageLogger',
'params': {
'batch_frequency': 750,
'max_images': 4,
'clamp': True,
"image_logger": {
"target": "main.ImageLogger",
"params": {
"batch_frequency": 750,
"max_images": 4,
"clamp": True,
},
},
'learning_rate_logger': {
'target': 'main.LearningRateMonitor',
'params': {
'logging_interval': 'step',
"learning_rate_logger": {
"target": "main.LearningRateMonitor",
"params": {
"logging_interval": "step",
# "log_momentum": True
},
},
'cuda_callback': {'target': 'main.CUDACallback'},
"cuda_callback": {"target": "main.CUDACallback"},
}
if version.parse(pl.__version__) >= version.parse('1.4.0'):
default_callbacks_cfg.update(
{'checkpoint_callback': modelckpt_cfg}
)
if version.parse(pl.__version__) >= version.parse("1.4.0"):
default_callbacks_cfg.update({"checkpoint_callback": modelckpt_cfg})
if 'callbacks' in lightning_config:
if "callbacks" in lightning_config:
callbacks_cfg = lightning_config.callbacks
else:
callbacks_cfg = OmegaConf.create()
if 'metrics_over_trainsteps_checkpoint' in callbacks_cfg:
if "metrics_over_trainsteps_checkpoint" in callbacks_cfg:
print(
'Caution: Saving checkpoints every n train steps without deleting. This might require some free space.'
"Caution: Saving checkpoints every n train steps without deleting. This might require some free space."
)
default_metrics_over_trainsteps_ckpt_dict = {
'metrics_over_trainsteps_checkpoint': {
'target': 'pytorch_lightning.callbacks.ModelCheckpoint',
'params': {
'dirpath': os.path.join(
ckptdir, 'trainstep_checkpoints'
),
'filename': '{epoch:06}-{step:09}',
'verbose': True,
'save_top_k': -1,
'every_n_train_steps': 10000,
'save_weights_only': True,
"metrics_over_trainsteps_checkpoint": {
"target": "pytorch_lightning.callbacks.ModelCheckpoint",
"params": {
"dirpath": os.path.join(ckptdir, "trainstep_checkpoints"),
"filename": "{epoch:06}-{step:09}",
"verbose": True,
"save_top_k": -1,
"every_n_train_steps": 10000,
"save_weights_only": True,
},
}
}
default_callbacks_cfg.update(
default_metrics_over_trainsteps_ckpt_dict
)
default_callbacks_cfg.update(default_metrics_over_trainsteps_ckpt_dict)
callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
if 'ignore_keys_callback' in callbacks_cfg and hasattr(
trainer_opt, 'resume_from_checkpoint'
):
callbacks_cfg.ignore_keys_callback.params[
'ckpt_path'
] = trainer_opt.resume_from_checkpoint
elif 'ignore_keys_callback' in callbacks_cfg:
del callbacks_cfg['ignore_keys_callback']
if "ignore_keys_callback" in callbacks_cfg and hasattr(trainer_opt, "resume_from_checkpoint"):
callbacks_cfg.ignore_keys_callback.params["ckpt_path"] = trainer_opt.resume_from_checkpoint
elif "ignore_keys_callback" in callbacks_cfg:
del callbacks_cfg["ignore_keys_callback"]
trainer_kwargs['callbacks'] = [
instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
]
trainer_kwargs['max_steps'] = trainer_opt.max_steps
trainer_kwargs["callbacks"] = [instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg]
trainer_kwargs["max_steps"] = trainer_opt.max_steps
if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
trainer_opt.accelerator = 'mps'
if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
trainer_opt.accelerator = "mps"
trainer_opt.detect_anomaly = False
trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)
@@ -882,11 +815,9 @@ if __name__ == '__main__':
# lightning still takes care of proper multiprocessing though
data.prepare_data()
data.setup()
print('#### Data #####')
print("#### Data #####")
for k in data.datasets:
print(
f'{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}'
)
print(f"{k}, {data.datasets[k].__class__.__name__}, {len(data.datasets[k])}")
# configure learning rate
bs, base_lr = (
@@ -894,24 +825,20 @@ if __name__ == '__main__':
config.model.base_learning_rate,
)
if not cpu:
gpus = str(lightning_config.trainer.gpus).strip(', ').split(',')
gpus = str(lightning_config.trainer.gpus).strip(", ").split(",")
ngpu = len(gpus)
else:
ngpu = 1
if 'accumulate_grad_batches' in lightning_config.trainer:
accumulate_grad_batches = (
lightning_config.trainer.accumulate_grad_batches
)
if "accumulate_grad_batches" in lightning_config.trainer:
accumulate_grad_batches = lightning_config.trainer.accumulate_grad_batches
else:
accumulate_grad_batches = 1
print(f'accumulate_grad_batches = {accumulate_grad_batches}')
lightning_config.trainer.accumulate_grad_batches = (
accumulate_grad_batches
)
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
if opt.scale_lr:
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print(
'Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)'.format(
"Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
model.learning_rate,
accumulate_grad_batches,
ngpu,
@@ -921,15 +848,15 @@ if __name__ == '__main__':
)
else:
model.learning_rate = base_lr
print('++++ NOT USING LR SCALING ++++')
print(f'Setting learning rate to {model.learning_rate:.2e}')
print("++++ NOT USING LR SCALING ++++")
print(f"Setting learning rate to {model.learning_rate:.2e}")
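
A worked instance of the opt.scale_lr rule above, with all inputs assumed purely for illustration:

    # Assumed values, for illustration only.
    accumulate_grad_batches = 2
    ngpu = 4
    bs = 8          # batch size
    base_lr = 1.0e-4

    lr = accumulate_grad_batches * ngpu * bs * base_lr
    print(f"{lr:.2e}")  # 6.40e-03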
# allow checkpointing via USR1
def melk(*args, **kwargs):
# run all checkpoint hooks
if trainer.global_rank == 0:
print('Summoning checkpoint.')
ckpt_path = os.path.join(ckptdir, 'last.ckpt')
print("Summoning checkpoint.")
ckpt_path = os.path.join(ckptdir, "last.ckpt")
trainer.save_checkpoint(ckpt_path)
def divein(*args, **kwargs):
@@ -964,7 +891,7 @@ if __name__ == '__main__':
# move newly created debug project to debug_runs
if opt.debug and not opt.resume and trainer.global_rank == 0:
dst, name = os.path.split(logdir)
dst = os.path.join(dst, 'debug_runs', name)
dst = os.path.join(dst, "debug_runs", name)
os.makedirs(os.path.split(dst)[0], exist_ok=True)
os.rename(logdir, dst)
# if trainer.global_rank == 0:


@@ -7,21 +7,30 @@ from functools import partial
import torch
def get_placeholder_loop(placeholder_string, embedder, use_bert):
new_placeholder = None
while True:
if new_placeholder is None:
new_placeholder = input(f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: ")
new_placeholder = input(
f"Placeholder string {placeholder_string} was already used. Please enter a replacement string: "
)
else:
new_placeholder = input(f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: ")
new_placeholder = input(
f"Placeholder string '{new_placeholder}' maps to more than a single token. Please enter another string: "
)
token = get_bert_token_for_string(embedder.tknz_fn, new_placeholder) if use_bert else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
token = (
get_bert_token_for_string(embedder.tknz_fn, new_placeholder)
if use_bert
else get_clip_token_for_string(embedder.tokenizer, new_placeholder)
)
if token is not None:
return new_placeholder, token
def get_clip_token_for_string(tokenizer, string):
batch_encoding = tokenizer(
string,
@@ -30,7 +39,7 @@ def get_clip_token_for_string(tokenizer, string):
return_length=True,
return_overflowing_tokens=False,
padding="max_length",
return_tensors="pt"
return_tensors="pt",
)
tokens = batch_encoding["input_ids"]
@@ -40,6 +49,7 @@ def get_clip_token_for_string(tokenizer, string):
return None
def get_bert_token_for_string(tokenizer, string):
token = tokenizer(string)
if torch.count_nonzero(token) == 3:
@@ -49,22 +59,17 @@ def get_bert_token_for_string(tokenizer, string):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--root_dir",
type=str,
default='.',
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'."
default=".",
help="Path to the InvokeAI install directory containing 'models', 'outputs' and 'configs'.",
)
parser.add_argument(
"--manager_ckpts",
type=str,
nargs="+",
required=True,
help="Paths to a set of embedding managers to be merged."
"--manager_ckpts", type=str, nargs="+", required=True, help="Paths to a set of embedding managers to be merged."
)
parser.add_argument(
@@ -75,13 +80,14 @@ if __name__ == "__main__":
)
parser.add_argument(
"-sd", "--use_bert",
"-sd",
"--use_bert",
action="store_true",
help="Flag to denote that we are not merging stable diffusion embeddings"
help="Flag to denote that we are not merging stable diffusion embeddings",
)
args = parser.parse_args()
Globals.root=args.root_dir
Globals.root = args.root_dir
if args.use_bert:
embedder = BERTEmbedder(n_embed=1280, n_layer=32).cuda()


@@ -10,12 +10,13 @@ from PIL import Image
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
rescale = lambda x: (x + 1.) / 2.
rescale = lambda x: (x + 1.0) / 2.0
def custom_to_pil(x):
x = x.detach().cpu()
x = torch.clamp(x, -1., 1.)
x = (x + 1.) / 2.
x = torch.clamp(x, -1.0, 1.0)
x = (x + 1.0) / 2.0
x = x.permute(1, 2, 0).numpy()
x = (255 * x).astype(np.uint8)
x = Image.fromarray(x)
@@ -51,49 +52,51 @@ def logs2pil(logs, keys=["sample"]):
@torch.no_grad()
def convsample(model, shape, return_intermediates=True,
verbose=True,
make_prog_row=False):
def convsample(model, shape, return_intermediates=True, verbose=True, make_prog_row=False):
if not make_prog_row:
return model.p_sample_loop(None, shape,
return_intermediates=return_intermediates, verbose=verbose)
return model.p_sample_loop(None, shape, return_intermediates=return_intermediates, verbose=verbose)
else:
return model.progressive_denoising(
None, shape, verbose=True
)
return model.progressive_denoising(None, shape, verbose=True)
@torch.no_grad()
def convsample_ddim(model, steps, shape, eta=1.0
):
def convsample_ddim(model, steps, shape, eta=1.0):
ddim = DDIMSampler(model)
bs = shape[0]
shape = shape[1:]
samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, eta=eta, verbose=False,)
samples, intermediates = ddim.sample(
steps,
batch_size=bs,
shape=shape,
eta=eta,
verbose=False,
)
return samples, intermediates
@torch.no_grad()
def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=None, eta=1.0,):
def make_convolutional_sample(
model,
batch_size,
vanilla=False,
custom_steps=None,
eta=1.0,
):
log = dict()
shape = [batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size]
shape = [
batch_size,
model.model.diffusion_model.in_channels,
model.model.diffusion_model.image_size,
model.model.diffusion_model.image_size,
]
with model.ema_scope("Plotting"):
t0 = time.time()
if vanilla:
sample, progrow = convsample(model, shape,
make_prog_row=True)
sample, progrow = convsample(model, shape, make_prog_row=True)
else:
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape,
eta=eta)
sample, intermediates = convsample_ddim(model, steps=custom_steps, shape=shape, eta=eta)
t1 = time.time()
@@ -101,32 +104,32 @@ def make_convolutional_sample(model, batch_size, vanilla=False, custom_steps=Non
log["sample"] = x_sample
log["time"] = t1 - t0
log['throughput'] = sample.shape[0] / (t1 - t0)
log["throughput"] = sample.shape[0] / (t1 - t0)
print(f'Throughput for this batch: {log["throughput"]}')
return log
def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None, n_samples=50000, nplog=None):
if vanilla:
print(f'Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.')
print(f"Using Vanilla DDPM sampling with {model.num_timesteps} sampling steps.")
else:
print(f'Using DDIM sampling with {custom_steps} sampling steps and eta={eta}')
print(f"Using DDIM sampling with {custom_steps} sampling steps and eta={eta}")
tstart = time.time()
n_saved = len(glob.glob(os.path.join(logdir,'*.png')))-1
n_saved = len(glob.glob(os.path.join(logdir, "*.png"))) - 1
# path = logdir
if model.cond_stage_model is None:
all_images = []
print(f"Running unconditional sampling for {n_samples} samples")
for _ in trange(n_samples // batch_size, desc="Sampling Batches (unconditional)"):
logs = make_convolutional_sample(model, batch_size=batch_size,
vanilla=vanilla, custom_steps=custom_steps,
eta=eta)
logs = make_convolutional_sample(
model, batch_size=batch_size, vanilla=vanilla, custom_steps=custom_steps, eta=eta
)
n_saved = save_logs(logs, logdir, n_saved=n_saved, key="sample")
all_images.extend([custom_to_np(logs["sample"])])
if n_saved >= n_samples:
print(f'Finish after generating {n_saved} samples')
print(f"Finish after generating {n_saved} samples")
break
all_img = np.concatenate(all_images, axis=0)
all_img = all_img[:n_samples]
@@ -135,7 +138,7 @@ def run(model, logdir, batch_size=50, vanilla=False, custom_steps=None, eta=None
np.savez(nppath, all_img)
else:
raise NotImplementedError('Currently only sampling for unconditional models supported.')
raise NotImplementedError("Currently only sampling for unconditional models supported.")
print(f"sampling of {n_saved} images finished in {(time.time() - tstart) / 60.:.2f} minutes.")
@@ -168,58 +171,33 @@ def get_parser():
nargs="?",
help="load from logdir or checkpoint in logdir",
)
parser.add_argument(
"-n",
"--n_samples",
type=int,
nargs="?",
help="number of samples to draw",
default=50000
)
parser.add_argument("-n", "--n_samples", type=int, nargs="?", help="number of samples to draw", default=50000)
parser.add_argument(
"-e",
"--eta",
type=float,
nargs="?",
help="eta for ddim sampling (0.0 yields deterministic sampling)",
default=1.0
default=1.0,
)
parser.add_argument(
"-v",
"--vanilla_sample",
default=False,
action='store_true',
action="store_true",
help="vanilla sampling (default option is DDIM sampling)?",
)
parser.add_argument("-l", "--logdir", type=str, nargs="?", help="extra logdir", default="none")
parser.add_argument(
"-l",
"--logdir",
type=str,
nargs="?",
help="extra logdir",
default="none"
)
parser.add_argument(
"-c",
"--custom_steps",
type=int,
nargs="?",
help="number of steps for ddim and fastdpm sampling",
default=50
)
parser.add_argument(
"--batch_size",
type=int,
nargs="?",
help="the bs",
default=10
"-c", "--custom_steps", type=int, nargs="?", help="number of steps for ddim and fastdpm sampling", default=50
)
parser.add_argument("--batch_size", type=int, nargs="?", help="the bs", default=10)
return parser
def load_model_from_config(config, sd):
model = instantiate_from_config(config)
model.load_state_dict(sd,strict=False)
model.load_state_dict(sd, strict=False)
model.cuda()
model.eval()
return model
@@ -233,8 +211,7 @@ def load_model(config, ckpt, gpu, eval_mode):
else:
pl_sd = {"state_dict": None}
global_step = None
model = load_model_from_config(config.model,
pl_sd["state_dict"])
model = load_model_from_config(config.model, pl_sd["state_dict"])
return model, global_step
@@ -253,9 +230,9 @@ if __name__ == "__main__":
if os.path.isfile(opt.resume):
# paths = opt.resume.split("/")
try:
logdir = '/'.join(opt.resume.split('/')[:-1])
logdir = "/".join(opt.resume.split("/")[:-1])
# idx = len(paths)-paths[::-1].index("logs")+1
print(f'Logdir is {logdir}')
print(f"Logdir is {logdir}")
except ValueError:
paths = opt.resume.split("/")
idx = -2 # take a guess: path/to/logdir/checkpoints/model.ckpt
@@ -278,7 +255,8 @@ if __name__ == "__main__":
if opt.logdir != "none":
locallog = logdir.split(os.sep)[-1]
if locallog == "": locallog = logdir.split(os.sep)[-2]
if locallog == "":
locallog = logdir.split(os.sep)[-2]
print(f"Switching logdir from '{logdir}' to '{os.path.join(opt.logdir, locallog)}'")
logdir = os.path.join(opt.logdir, locallog)
@@ -301,13 +279,19 @@ if __name__ == "__main__":
sampling_file = os.path.join(logdir, "sampling_config.yaml")
sampling_conf = vars(opt)
with open(sampling_file, 'w') as f:
with open(sampling_file, "w") as f:
yaml.dump(sampling_conf, f, default_flow_style=False)
print(sampling_conf)
run(model, imglogdir, eta=opt.eta,
vanilla=opt.vanilla_sample, n_samples=opt.n_samples, custom_steps=opt.custom_steps,
batch_size=opt.batch_size, nplog=numpylogdir)
run(
model,
imglogdir,
eta=opt.eta,
vanilla=opt.vanilla_sample,
n_samples=opt.n_samples,
custom_steps=opt.custom_steps,
batch_size=opt.batch_size,
nplog=numpylogdir,
)
print("done.")
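
The --eta flag above controls the stochastic term in DDIM sampling: eta=0 gives fully deterministic trajectories, while eta=1 injects DDPM-level noise. A hedged sketch of the per-step noise scale from the DDIM paper (the alpha-bar arguments are placeholders):

    import math

    def ddim_sigma(eta, alpha_bar_t, alpha_bar_prev):
        # sigma_t from Song et al., "Denoising Diffusion Implicit Models";
        # identically zero when eta == 0.
        return eta * math.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar_t)) * math.sqrt(
            1 - alpha_bar_t / alpha_bar_prev
        )

    print(ddim_sigma(0.0, 0.5, 0.7))  # 0.0 -> deterministic step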


@@ -13,21 +13,26 @@ def search_bruteforce(searcher):
return searcher.score_brute_force().build()
def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search):
return searcher.tree(num_leaves=num_leaves,
num_leaves_to_search=num_leaves_to_search,
training_sample_size=partioning_trainsize). \
score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
def search_partioned_ah(
searcher, dims_per_block, aiq_threshold, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
):
return (
searcher.tree(
num_leaves=num_leaves, num_leaves_to_search=num_leaves_to_search, training_sample_size=partioning_trainsize
)
.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold)
.reorder(reorder_k)
.build()
)
def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k):
return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(
reorder_k).build()
return (
searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build()
)
def load_datapool(dpath):
def load_single_file(saved_embeddings):
compressed = np.load(saved_embeddings)
database = {key: compressed[key] for key in compressed.files}
@@ -35,23 +40,26 @@ def load_datapool(dpath):
def load_multi_files(data_archive):
database = {key: [] for key in data_archive[0].files}
for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'):
for d in tqdm(data_archive, desc=f"Loading datapool from {len(data_archive)} individual files."):
for key in d.files:
database[key].append(d[key])
return database
print(f'Load saved patch embedding from "{dpath}"')
file_content = glob.glob(os.path.join(dpath, '*.npz'))
file_content = glob.glob(os.path.join(dpath, "*.npz"))
if len(file_content) == 1:
data_pool = load_single_file(file_content[0])
elif len(file_content) > 1:
data = [np.load(f) for f in file_content]
prefetched_data = parallel_data_prefetch(load_multi_files, data,
n_proc=min(len(data), cpu_count()), target_data_type='dict')
prefetched_data = parallel_data_prefetch(
load_multi_files, data, n_proc=min(len(data), cpu_count()), target_data_type="dict"
)
data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()}
data_pool = {
key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()
}
else:
raise ValueError(f'No npz files found in "{dpath}". Does this directory exist?')
@@ -59,16 +67,17 @@ def load_datapool(dpath):
return data_pool
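For context, `load_single_file` assumes the embeddings were written with `numpy.savez`; a minimal round-trip sketch (the file name and embedding width are hypothetical, while the "embedding" key matches what `train_searcher` reads):

import numpy as np

np.savez("pool.npz", embedding=np.random.rand(100, 512).astype(np.float32))
compressed = np.load("pool.npz")
database = {key: compressed[key] for key in compressed.files}
print(database["embedding"].shape)  # (100, 512)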
def train_searcher(opt,
metric='dot_product',
partioning_trainsize=None,
reorder_k=None,
# todo tune
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
num_leaves_to_search=None,):
def train_searcher(
opt,
metric="dot_product",
partioning_trainsize=None,
reorder_k=None,
# todo tune
aiq_thld=0.2,
dims_per_block=2,
num_leaves=None,
num_leaves_to_search=None,
):
data_pool = load_datapool(opt.database)
k = opt.knn
@@ -77,71 +86,83 @@ def train_searcher(opt,
# normalize
# embeddings =
searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric)
pool_size = data_pool['embedding'].shape[0]
searcher = scann.scann_ops_pybind.builder(
data_pool["embedding"] / np.linalg.norm(data_pool["embedding"], axis=1)[:, np.newaxis], k, metric
)
pool_size = data_pool["embedding"].shape[0]
print(*(['#'] * 100))
print('Initializing scaNN searcher with the following values:')
print(f'k: {k}')
print(f'metric: {metric}')
print(f'reorder_k: {reorder_k}')
print(f'anisotropic_quantization_threshold: {aiq_thld}')
print(f'dims_per_block: {dims_per_block}')
print(*(['#'] * 100))
print('Start training searcher....')
print(f'N samples in pool is {pool_size}')
print(*(["#"] * 100))
print("Initializing scaNN searcher with the following values:")
print(f"k: {k}")
print(f"metric: {metric}")
print(f"reorder_k: {reorder_k}")
print(f"anisotropic_quantization_threshold: {aiq_thld}")
print(f"dims_per_block: {dims_per_block}")
print(*(["#"] * 100))
print("Start training searcher....")
print(f"N samples in pool is {pool_size}")
# this reflects the recommended design choices proposed at
# https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md
if pool_size < 2e4:
print('Using brute force search.')
print("Using brute force search.")
searcher = search_bruteforce(searcher)
elif 2e4 <= pool_size < 1e5:
print('Using asymmetric hashing search and reordering.')
print("Using asymmetric hashing search and reordering.")
searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
else:
print('Using partitioning, asymmetric hashing search and reordering.')
print("Using partitioning, asymmetric hashing search and reordering.")
if not partioning_trainsize:
partioning_trainsize = data_pool['embedding'].shape[0] // 10
partioning_trainsize = data_pool["embedding"].shape[0] // 10
if not num_leaves:
num_leaves = int(np.sqrt(pool_size))
if not num_leaves_to_search:
num_leaves_to_search = max(num_leaves // 20, 1)
print('Partitioning params:')
print(f'num_leaves: {num_leaves}')
print(f'num_leaves_to_search: {num_leaves_to_search}')
print("Partitioning params:")
print(f"num_leaves: {num_leaves}")
print(f"num_leaves_to_search: {num_leaves_to_search}")
# self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k)
searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k,
partioning_trainsize, num_leaves, num_leaves_to_search)
searcher = search_partioned_ah(
searcher, dims_per_block, aiq_thld, reorder_k, partioning_trainsize, num_leaves, num_leaves_to_search
)
print('Finish training searcher')
print("Finish training searcher")
searcher_savedir = opt.target_path
os.makedirs(searcher_savedir, exist_ok=True)
searcher.serialize(searcher_savedir)
print(f'Saved trained searcher under "{searcher_savedir}"')
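For a sense of scale, the fallback defaults above work out as follows for a hypothetical pool of one million embeddings (a worked sketch, not output of this script; the misspelled `partioning_trainsize` name is kept to match the code):

import numpy as np

pool_size = 1_000_000                            # hypothetical embedding count
num_leaves = int(np.sqrt(pool_size))             # 1000 partitions
num_leaves_to_search = max(num_leaves // 20, 1)  # probe 50 of them per query
partioning_trainsize = pool_size // 10           # 100_000 samples to fit partitions
print(num_leaves, num_leaves_to_search, partioning_trainsize)  # 1000 50 100000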
if __name__ == '__main__':
if __name__ == "__main__":
sys.path.append(os.getcwd())
parser = argparse.ArgumentParser()
parser.add_argument('--database',
'-d',
default='data/rdm/retrieval_databases/openimages',
type=str,
help='path to folder containing the clip feature of the database')
parser.add_argument('--target_path',
'-t',
default='data/rdm/searchers/openimages',
type=str,
help='path to the target folder where the searcher shall be stored.')
parser.add_argument('--knn',
'-k',
default=20,
type=int,
help='number of nearest neighbors for which the searcher shall be optimized')
parser.add_argument(
"--database",
"-d",
default="data/rdm/retrieval_databases/openimages",
type=str,
help="path to folder containing the clip feature of the database",
)
parser.add_argument(
"--target_path",
"-t",
default="data/rdm/searchers/openimages",
type=str,
help="path to the target folder where the searcher shall be stored.",
)
parser.add_argument(
"--knn",
"-k",
default=20,
type=int,
help="number of nearest neighbors, for which the searcher shall be optimized",
)
opt, _ = parser.parse_known_args()
opt, _ = parser.parse_known_args()
train_searcher(opt,)
train_searcher(
opt,
)
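Spelled out, the call above is equivalent to running the script with its defaults made explicit; a sketch (the script's file name is not visible in this diff):

opt, _ = parser.parse_known_args(
    ["-d", "data/rdm/retrieval_databases/openimages", "-t", "data/rdm/searchers/openimages", "-k", "20"]
)
train_searcher(opt)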

View File

@@ -15,10 +15,11 @@ from contextlib import contextmanager, nullcontext
import k_diffusion as K
import torch.nn as nn
from ldm.util import instantiate_from_config
from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler
from ldm.invoke.devices import choose_torch_device
from ldm.invoke.devices import choose_torch_device
def chunk(it, size):
it = iter(it)
@@ -53,23 +54,19 @@ def main():
type=str,
nargs="?",
default="a painting of a virus monster playing guitar",
help="the prompt to render"
help="the prompt to render",
)
parser.add_argument(
"--outdir",
type=str,
nargs="?",
help="dir to write results to",
default="outputs/txt2img-samples"
"--outdir", type=str, nargs="?", help="dir to write results to", default="outputs/txt2img-samples"
)
parser.add_argument(
"--skip_grid",
action='store_true',
action="store_true",
help="do not save a grid, only individual samples. Helpful when evaluating lots of samples",
)
parser.add_argument(
"--skip_save",
action='store_true',
action="store_true",
help="do not save individual samples. For speed measurements.",
)
parser.add_argument(
@@ -80,22 +77,22 @@ def main():
)
parser.add_argument(
"--plms",
action='store_true',
action="store_true",
help="use plms sampling",
)
parser.add_argument(
"--klms",
action='store_true',
action="store_true",
help="use klms sampling",
)
parser.add_argument(
"--laion400m",
action='store_true',
action="store_true",
help="uses the LAION400M model",
)
parser.add_argument(
"--fixed_code",
action='store_true',
action="store_true",
help="if enabled, uses the same starting code across samples ",
)
parser.add_argument(
@@ -176,11 +173,7 @@ def main():
help="the seed (for reproducible sampling)",
)
parser.add_argument(
"--precision",
type=str,
help="evaluate at this precision",
choices=["full", "autocast"],
default="autocast"
"--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast"
)
opt = parser.parse_args()
@@ -190,17 +183,17 @@ def main():
opt.ckpt = "models/ldm/text2img-large/model.ckpt"
opt.outdir = "outputs/txt2img-samples-laion400m"
config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
seed_everything(opt.seed)
device = torch.device(choose_torch_device())
model = model.to(device)
model = model.to(device)
#for klms
# for klms
model_wrap = K.external.CompVisDenoiser(model)
class CFGDenoiser(nn.Module):
def __init__(self, model):
super().__init__()
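The rest of `CFGDenoiser` falls outside this hunk; in k-diffusion wrappers of this vintage its `forward` is the standard classifier-free-guidance combine, sketched below (a reconstruction under that assumption, not the file's verbatim code):

def forward(self, x, sigma, uncond, cond, cond_scale):
    # batch the unconditional and conditional branches into one model call
    x_in = torch.cat([x] * 2)
    sigma_in = torch.cat([sigma] * 2)
    cond_in = torch.cat([uncond, cond])
    uncond_out, cond_out = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
    # push the prediction away from the unconditional branch by cond_scale
    return uncond_out + (cond_out - uncond_out) * cond_scale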
@@ -232,10 +225,10 @@ def main():
print(f"reading prompts from {opt.from_file}")
with open(opt.from_file, "r") as f:
data = f.read().splitlines()
if (len(data) >= batch_size):
if len(data) >= batch_size:
data = list(chunk(data, batch_size))
else:
while (len(data) < batch_size):
while len(data) < batch_size:
data.append(data[-1])
data = [data]
@@ -247,14 +240,14 @@ def main():
start_code = None
if opt.fixed_code:
shape = [opt.n_samples, opt.C, opt.H // opt.f, opt.W // opt.f]
if device.type == 'mps':
start_code = torch.randn(shape, device='cpu').to(device)
if device.type == "mps":
start_code = torch.randn(shape, device="cpu").to(device)
else:
start_code = torch.randn(shape, device=device)  # bug fix: assign the draw, otherwise --fixed_code is a no-op off-mps
precision_scope = autocast if opt.precision=="autocast" else nullcontext
if device.type in ['mps', 'cpu']:
precision_scope = nullcontext # have to use f32 on mps
precision_scope = autocast if opt.precision == "autocast" else nullcontext
if device.type in ["mps", "cpu"]:
precision_scope = nullcontext # have to use f32 on mps
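# editor's sketch: contextlib.nullcontext accepts (and ignores) an optional
# positional argument, which is why it can stand in for autocast(device.type)
# here without changing the call site below:
#     from contextlib import nullcontext
#     with nullcontext("mps"):
#         pass  # full f32, no autocasting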
with torch.no_grad():
with precision_scope(device.type):
with model.ema_scope():
@@ -271,23 +264,25 @@ def main():
shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
if not opt.klms:
samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code)
samples_ddim, _ = sampler.sample(
S=opt.ddim_steps,
conditioning=c,
batch_size=opt.n_samples,
shape=shape,
verbose=False,
unconditional_guidance_scale=opt.scale,
unconditional_conditioning=uc,
eta=opt.ddim_eta,
x_T=start_code,
)
else:
sigmas = model_wrap.get_sigmas(opt.ddim_steps)
if start_code is not None:  # bug fix: truthiness of a multi-element tensor raises a RuntimeError
x = start_code
else:
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
x = torch.randn([opt.n_samples, *shape], device=device) * sigmas[0] # for GPU draw
model_wrap_cfg = CFGDenoiser(model_wrap)
extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
extra_args = {"cond": c, "uncond": uc, "cond_scale": opt.scale}
samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
x_samples_ddim = model.decode_first_stage(samples_ddim)
@@ -295,9 +290,10 @@ def main():
if not opt.skip_save:
for x_sample in x_samples_ddim:
x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
Image.fromarray(x_sample.astype(np.uint8)).save(
os.path.join(sample_path, f"{base_count:05}.png"))
os.path.join(sample_path, f"{base_count:05}.png")
)
base_count += 1
if not opt.skip_grid:
@@ -306,18 +302,17 @@ def main():
if not opt.skip_grid:
# additionally, save as grid
grid = torch.stack(all_samples, 0)
grid = rearrange(grid, 'n b c h w -> (n b) c h w')
grid = rearrange(grid, "n b c h w -> (n b) c h w")
grid = make_grid(grid, nrow=n_rows)
# to image
grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f'grid-{grid_count:04}.png'))
grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy()
Image.fromarray(grid.astype(np.uint8)).save(os.path.join(outpath, f"grid-{grid_count:04}.png"))
grid_count += 1
toc = time.time()
print(f"Your samples are ready and waiting for you here: \n{outpath} \n"
f" \nEnjoy.")
print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.")
if __name__ == "__main__":