mostly ported to new manager API; needs testing

This commit is contained in:
Lincoln Stein
2023-05-06 00:44:12 -04:00
parent af8c7c7d29
commit e0214a32bc
12 changed files with 353 additions and 332 deletions

View File

@@ -123,51 +123,51 @@ class InvokeAIGenerator(metaclass=ABCMeta):
generator_args.update(keyword_args)
model_info = self.model_info
model_name = model_info['model_name']
model:StableDiffusionGeneratorPipeline = model_info['model']
model_hash = model_info['hash']
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
model_name = model_info.name
model_hash = model_info.hash
with model_info.context as model:
scheduler: Scheduler = self.get_scheduler(
model=model,
scheduler_name=generator_args.get('scheduler')
)
uc, c, extra_conditioning_info = get_uc_and_c_and_ec(prompt,model=model)
gen_class = self._generator_class()
generator = gen_class(model, self.params.precision)
if self.params.variation_amount > 0:
generator.set_variation(generator_args.get('seed'),
generator_args.get('variation_amount'),
generator_args.get('with_variations')
)
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
if isinstance(model, DiffusionPipeline):
for component in [model.unet, model.vae]:
configure_model_padding(component,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
else:
configure_model_padding(model,
generator_args.get('seamless',False),
generator_args.get('seamless_axes')
)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)
iteration_count = range(iterations) if iterations else itertools.count(start=0, step=1)
for i in iteration_count:
results = generator.generate(prompt,
conditioning=(uc, c, extra_conditioning_info),
step_callback=step_callback,
sampler=scheduler,
**generator_args,
)
output = InvokeAIGeneratorOutput(
image=results[0][0],
seed=results[0][1],
attention_maps_images=results[0][2],
model_hash = model_hash,
params=Namespace(model_name=model_name,**generator_args),
)
if callback:
callback(output)
yield output
@classmethod
@@ -275,7 +275,6 @@ class Embiggen(Txt2Img):
from .embiggen import Embiggen
return Embiggen
class Generator:
downsampling_factor: int
latent_channels: int