#1843 - Add Export Default settings button (#2016)

* #1843 - Add Export Default settings button

* #1843 reformating units test

---------

Co-authored-by: Richard Pastirčák <richard.pastircak@student.tuke.sk>
This commit is contained in:
Richard Pastirčák
2023-12-06 21:58:17 +01:00
committed by GitHub
parent 3322b7264f
commit 3af0c6c658
2 changed files with 270 additions and 28 deletions

View File

@@ -1,4 +1,6 @@
import json
import os
import warnings
import torch
import time
import sys
@@ -35,6 +37,34 @@ from apps.stable_diffusion.src.utils import (
resampler_list,
)
# Names of all interactive fields that can be edited by user
all_gradio_labels = [
"txt2img_custom_model",
"custom_vae",
"prompt",
"negative_prompt",
"lora_weights",
"lora_hf_id",
"scheduler",
"save_metadata_to_png",
"save_metadata_to_json",
"height",
"width",
"steps",
"guidance_scale",
"Low VRAM",
"use_hiresfix",
"resample_type",
"hiresfix_height",
"hiresfix_width",
"hiresfix_strength",
"batch_count",
"batch_size",
"repeatable_seeds",
"seed",
"device",
]
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
init_iree_metal_target_platform = args.iree_metal_target_platform
@@ -313,6 +343,81 @@ def resource_path(relative_path):
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# This function export values for all fields that can be edited by user to the settings.json file in ui folder
def export_settings(*values):
settings_list = list(zip(all_gradio_labels, values))
settings = {}
for label, value in settings_list:
settings[label] = value
settings = {"txt2img": settings}
with open("./ui/settings.json", "w") as json_file:
json.dump(settings, json_file, indent=4)
# This function loads all values for all fields that can be edited by user from the settings.json file in ui folder
def load_settings():
try:
with open("./ui/settings.json", "r") as json_file:
loaded_settings = json.load(json_file)["txt2img"]
except (FileNotFoundError, KeyError):
warnings.warn(
"Settings.json file not found or 'txt2img' key is missing. Using default values for fields."
)
loaded_settings = (
{}
) # json file not existing or the data wasn't saved yet
return [
loaded_settings.get(
"txt2img_custom_model",
os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
),
loaded_settings.get(
"custom_vae",
os.path.basename(args.custom_vae) if args.custom_vae else "None",
),
loaded_settings.get("prompt", args.prompts[0]),
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
loaded_settings.get("lora_weights", "None"),
loaded_settings.get("lora_hf_id", ""),
loaded_settings.get("scheduler", args.scheduler),
loaded_settings.get(
"save_metadata_to_png", args.write_metadata_to_png
),
loaded_settings.get(
"save_metadata_to_json", args.save_metadata_to_json
),
loaded_settings.get("height", args.height),
loaded_settings.get("width", args.width),
loaded_settings.get("steps", args.steps),
loaded_settings.get("guidance_scale", args.guidance_scale),
loaded_settings.get("Low VRAM", args.ondemand),
loaded_settings.get("use_hiresfix", args.use_hiresfix),
loaded_settings.get("resample_type", args.resample_type),
loaded_settings.get("hiresfix_height", args.hiresfix_height),
loaded_settings.get("hiresfix_width", args.hiresfix_width),
loaded_settings.get("hiresfix_strength", args.hiresfix_strength),
loaded_settings.get("batch_count", args.batch_count),
loaded_settings.get("batch_size", args.batch_size),
loaded_settings.get("repeatable_seeds", args.repeatable_seeds),
loaded_settings.get("seed", args.seed),
loaded_settings.get("device", available_devices[0]),
]
# This function loads the user's exported default settings on the start of program
def onload_load_settings():
loaded_data = load_settings()
structured_data = settings_list = list(zip(all_gradio_labels, loaded_data))
return dict(structured_data)
default_settings = onload_load_settings()
with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
@@ -338,9 +443,9 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
label=f"Models",
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
elem_id="custom_model",
value=os.path.basename(args.ckpt_loc)
if args.ckpt_loc
else "stabilityai/stable-diffusion-2-1-base",
value=default_settings.get(
"txt2img_custom_model"
),
choices=get_custom_model_files()
+ predefined_models,
allow_custom_value=True,
@@ -355,9 +460,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
label=f"VAE Models",
info=t2i_vae_info,
elem_id="custom_model",
value=os.path.basename(args.custom_vae)
if args.custom_vae
else "None",
value=default_settings.get("custom_vae"),
choices=["None"]
+ get_custom_model_files("vae"),
allow_custom_value=True,
@@ -374,7 +477,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=args.prompts[0],
value=default_settings.get("prompt"),
lines=2,
elem_id="prompt_box",
)
@@ -385,7 +488,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=args.negative_prompts[0],
value=default_settings.get("negative_prompt"),
lines=2,
elem_id="negative_prompt_box",
)
@@ -400,7 +503,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
elem_id="lora_weights",
value="None",
value=default_settings.get("lora_weights"),
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
)
@@ -410,7 +513,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
value=default_settings.get("lora_hf_id"),
label="HuggingFace Model ID",
lines=3,
)
@@ -424,33 +527,37 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value=args.scheduler,
value=default_settings.get("scheduler"),
choices=scheduler_list,
allow_custom_value=True,
)
with gr.Column():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=args.write_metadata_to_png,
value=default_settings.get(
"save_metadata_to_png"
),
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=args.save_metadata_to_json,
value=default_settings.get(
"save_metadata_to_json"
),
interactive=True,
)
with gr.Row():
height = gr.Slider(
384,
768,
value=args.height,
value=default_settings.get("height"),
step=8,
label="Height",
)
width = gr.Slider(
384,
768,
value=args.width,
value=default_settings.get("width"),
step=8,
label="Width",
)
@@ -475,18 +582,22 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1, 100, value=args.steps, step=1, label="Steps"
1,
100,
value=default_settings.get("steps"),
step=1,
label="Steps",
)
with gr.Column(scale=3):
guidance_scale = gr.Slider(
0,
50,
value=args.guidance_scale,
value=default_settings.get("guidance_scale"),
step=0.1,
label="CFG Scale",
)
ondemand = gr.Checkbox(
value=args.ondemand,
value=default_settings.get("Low VRAM"),
label="Low VRAM",
interactive=True,
)
@@ -495,7 +606,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
batch_count = gr.Slider(
1,
100,
value=args.batch_count,
value=default_settings.get("batch_count"),
step=1,
label="Batch Count",
interactive=True,
@@ -506,23 +617,23 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
4,
value=args.batch_size,
step=1,
label="Batch Size",
label=default_settings.get("batch_size"),
interactive=True,
)
repeatable_seeds = gr.Checkbox(
args.repeatable_seeds,
default_settings.get("repeatable_seeds"),
label="Repeatable Seeds",
)
with gr.Accordion(label="Hires Fix Options", open=False):
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
value=default_settings.get("use_hiresfix"),
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
value=default_settings.get("resample_type"),
choices=resampler_list,
label="Resample Type",
allow_custom_value=False,
@@ -530,34 +641,34 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
value=default_settings.get("hiresfix_height"),
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
value=default_settings.get("hiresfix_width"),
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
value=default_settings.get("hiresfix_strength"),
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
value=default_settings.get("seed"),
label="Seed",
info="An integer or a JSON list of integers, -1 for random",
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=available_devices[0],
value=default_settings.get("device"),
choices=available_devices,
allow_custom_value=True,
)
@@ -608,6 +719,75 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
txt2img_sendto_upscaler = gr.Button(
value="SendTo Upscaler"
)
with gr.Row():
with gr.Column(scale=2):
export_defaults = gr.Button(
value="Load Default Settings"
)
export_defaults.click(
fn=load_settings,
inputs=[],
outputs=[
txt2img_custom_model,
custom_vae,
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
height,
width,
steps,
guidance_scale,
ondemand,
use_hiresfix,
resample_type,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
batch_count,
batch_size,
repeatable_seeds,
seed,
device,
],
)
with gr.Column(scale=2):
export_defaults = gr.Button(
value="Export Default Settings"
)
export_defaults.click(
fn=export_settings,
inputs=[
txt2img_custom_model,
custom_vae,
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
height,
width,
steps,
guidance_scale,
ondemand,
use_hiresfix,
resample_type,
hiresfix_height,
hiresfix_width,
hiresfix_strength,
batch_count,
batch_size,
repeatable_seeds,
seed,
device,
],
outputs=[],
)
kwargs = dict(
fn=txt2img_inf,

View File

@@ -0,0 +1,62 @@
import unittest
from unittest.mock import mock_open, patch
from apps.stable_diffusion.web.ui.txt2img_ui import (
export_settings,
load_settings,
all_gradio_labels,
)
class TestExportSettings(unittest.TestCase):
@patch("builtins.open", new_callable=mock_open)
@patch("json.dump")
def test_export_settings(self, mock_json_dump, mock_file):
test_values = ["value1", "value2", "value3"]
expected_output = {
"txt2img": {
label: value
for label, value in zip(all_gradio_labels, test_values)
}
}
export_settings(*test_values)
mock_file.assert_called_once_with("./ui/settings.json", "w")
mock_json_dump.assert_called_once_with(
expected_output, mock_file(), indent=4
)
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch(
"builtins.open",
new_callable=mock_open,
read_data='{"txt2img": {"some_setting": "some_value"}}',
)
def test_load_settings_file_exists(self, mock_file, mock_json_load):
mock_json_load.return_value = {
"txt2img": {
"txt2img_custom_model": "custom_model_value",
"custom_vae": "custom_vae_value",
}
}
settings = load_settings()
self.assertEqual(settings[0], "custom_model_value")
self.assertEqual(settings[1], "custom_vae_value")
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch("builtins.open", side_effect=FileNotFoundError)
def test_load_settings_file_not_found(self, mock_file, mock_json_load):
settings = load_settings()
default_lora_weights = "None"
self.assertEqual(settings[4], default_lora_weights)
@patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
@patch("builtins.open", new_callable=mock_open, read_data="{}")
def test_load_settings_key_error(self, mock_file, mock_json_load):
mock_json_load.return_value = {}
settings = load_settings()
default_lora_weights = "None"
self.assertEqual(settings[4], default_lora_weights)