mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
SD - Implement seed arrays for batch runs (#1690)
* SD Scripts and UI tabs that support batch_count can now take a string containing a JSON array, or a list of integers, as their seed input. * Each batch in a run will now take the seed specified at the corresponding array index if one exists. If there is no seed at that index, the seed value will be treated as -1 and a random seed will be assigned at that position. If an integer rather than a list or json array has been, everything works as before. * UI seed input controls are now Textboxes with info lines about the seed formats allowed. * UI error handling updated to be more helpful if the seed input is invalid.
This commit is contained in:
@@ -28,6 +28,7 @@ from apps.stable_diffusion.src.utils.utils import (
|
||||
fetch_and_update_base_model_id,
|
||||
get_path_to_diffusers_checkpoint,
|
||||
sanitize_seed,
|
||||
parse_seed_input,
|
||||
batch_seeds,
|
||||
get_path_stem,
|
||||
get_extended_name,
|
||||
|
||||
@@ -66,9 +66,9 @@ p.add_argument(
|
||||
|
||||
p.add_argument(
|
||||
"--seed",
|
||||
type=int,
|
||||
type=str,
|
||||
default=-1,
|
||||
help="The seed to use. -1 for a random one.",
|
||||
help="The seed or list of seeds to use. -1 for a random one.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
|
||||
@@ -727,7 +727,8 @@ def fetch_and_update_base_model_id(model_to_run, base_model=""):
|
||||
|
||||
# Generate and return a new seed if the provided one is not in the
|
||||
# supported range (including -1)
|
||||
def sanitize_seed(seed):
|
||||
def sanitize_seed(seed: int | str):
|
||||
seed = int(seed)
|
||||
uint32_info = np.iinfo(np.uint32)
|
||||
uint32_min, uint32_max = uint32_info.min, uint32_info.max
|
||||
if seed < uint32_min or seed >= uint32_max:
|
||||
@@ -735,20 +736,48 @@ def sanitize_seed(seed):
|
||||
return seed
|
||||
|
||||
|
||||
# Generate a set of seeds, using as the first seed of the set,
|
||||
# optionally using it as the rng seed for subsequent seeds in the set
|
||||
def batch_seeds(seed, batch_count, repeatable=False):
|
||||
# use the passed seed as the initial seed of the batch
|
||||
seeds = [sanitize_seed(seed)]
|
||||
# take a seed expression in an input format and convert it to
|
||||
# a list of integers, where possible
|
||||
def parse_seed_input(seed_input: str | list | int):
|
||||
if isinstance(seed_input, str):
|
||||
try:
|
||||
seed_input = json.loads(seed_input)
|
||||
except (ValueError, TypeError):
|
||||
seed_input = None
|
||||
|
||||
if isinstance(seed_input, int):
|
||||
return [seed_input]
|
||||
|
||||
if isinstance(seed_input, list) and all(
|
||||
type(seed) is int for seed in seed_input
|
||||
):
|
||||
return seed_input
|
||||
|
||||
raise TypeError(
|
||||
"Seed input must be an integer or an array of integers in JSON format"
|
||||
)
|
||||
|
||||
|
||||
# Generate a set of seeds from an input expression for batch_count batches,
|
||||
# optionally using that input as the rng seed for any randomly generated seeds.
|
||||
def batch_seeds(
|
||||
seed_input: str | list | int, batch_count: int, repeatable=False
|
||||
):
|
||||
# turn the input into a list if possible
|
||||
seeds = parse_seed_input(seed_input)
|
||||
|
||||
# slice or pad the list to be of batch_count length
|
||||
seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds))
|
||||
|
||||
if repeatable:
|
||||
# use the initial seed as the rng generator seed
|
||||
# set seed for the rng based on what we have so far
|
||||
saved_random_state = random_getstate()
|
||||
seed_random(seed)
|
||||
if all(seed < 0 for seed in seeds):
|
||||
seeds[0] = sanitize_seed(seeds[0])
|
||||
seed_random(str(seeds))
|
||||
|
||||
# generate the additional seeds
|
||||
for i in range(1, batch_count):
|
||||
seeds.append(sanitize_seed(-1))
|
||||
# generate any seeds that are unspecified
|
||||
seeds = [sanitize_seed(seed) for seed in seeds]
|
||||
|
||||
if repeatable:
|
||||
# reset the rng back to normal
|
||||
|
||||
@@ -50,7 +50,7 @@ def img2img_inf(
|
||||
steps: int,
|
||||
strength: float,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -230,10 +230,12 @@ def img2img_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = []
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
extra_info = {"STRENGTH": strength}
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
@@ -617,8 +619,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
|
||||
@@ -49,7 +49,7 @@ def inpaint_inf(
|
||||
inpaint_full_res_padding: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -181,10 +181,13 @@ def inpaint_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
@@ -514,8 +517,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from apps.stable_diffusion.scripts import lora_train
|
||||
from apps.stable_diffusion.src import prompt_examples, args
|
||||
from apps.stable_diffusion.src import prompt_examples, args, utils
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -168,7 +168,9 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
value=utils.parse_seed_input(args.seed)[0],
|
||||
precision=0,
|
||||
label="Seed",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
|
||||
@@ -49,7 +49,7 @@ def outpaint_inf(
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -178,7 +178,10 @@ def outpaint_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
left = True if "left" in directions else False
|
||||
right = True if "right" in directions else False
|
||||
@@ -542,8 +545,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
|
||||
@@ -46,7 +46,7 @@ def txt2img_inf(
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -178,8 +178,11 @@ def txt2img_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
@@ -481,8 +484,10 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
|
||||
@@ -42,7 +42,7 @@ def upscaler_inf(
|
||||
steps: int,
|
||||
noise_level: int,
|
||||
guidance_scale: float,
|
||||
seed: int,
|
||||
seed: str,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
@@ -177,8 +177,11 @@ def upscaler_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
extra_info = {"NOISE LEVEL": noise_level}
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
low_res_img = image
|
||||
@@ -534,8 +537,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Number(
|
||||
value=args.seed, precision=0, label="Seed"
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
|
||||
Reference in New Issue
Block a user