mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
More fixes for demo.
This commit is contained in:
@@ -101,7 +101,7 @@ class StableDiffusion:
|
||||
external_weights: str = "safetensors",
|
||||
progress=gr.Progress(),
|
||||
):
|
||||
progress(0, desc="Initializing pipeline...")
|
||||
progress(None, desc="Initializing pipeline...")
|
||||
self.ui_device = device
|
||||
self.precision = precision
|
||||
self.compiled_pipeline = False
|
||||
@@ -181,7 +181,7 @@ class StableDiffusion:
|
||||
external_weights=external_weights,
|
||||
custom_vae=custom_vae,
|
||||
)
|
||||
progress(1, desc="Pipeline initialized!...")
|
||||
progress(None, desc="Pipeline initialized!...")
|
||||
gc.collect()
|
||||
|
||||
def prepare_pipe(
|
||||
@@ -245,18 +245,18 @@ class StableDiffusion:
|
||||
"diffusion_pytorch_model.safetensors",
|
||||
)
|
||||
weights[key] = save_irpa(vae_weights_path, "vae.")
|
||||
progress(0, desc=f"Preparing pipeline for {self.ui_device}...")
|
||||
progress(None, desc=f"Preparing pipeline for {self.ui_device}...")
|
||||
|
||||
vmfbs, weights = self.sd_pipe.check_prepared(
|
||||
mlirs, vmfbs, weights, interactive=False
|
||||
)
|
||||
progress(1, desc=f"Artifacts ready!")
|
||||
progress(0, desc=f"Loading pipeline on device {self.ui_device}...")
|
||||
progress(None, desc=f"Artifacts ready!")
|
||||
progress(None, desc=f"Loading pipeline on device {self.ui_device}...")
|
||||
|
||||
self.sd_pipe.load_pipeline(
|
||||
vmfbs, weights, self.rt_device, self.compiled_pipeline
|
||||
)
|
||||
progress(1, desc="Pipeline loaded!")
|
||||
progress(None, desc="Pipeline loaded! Generating images...")
|
||||
return
|
||||
|
||||
def generate_images(
|
||||
@@ -271,9 +271,9 @@ class StableDiffusion:
|
||||
resample_type,
|
||||
control_mode,
|
||||
hints,
|
||||
progress=gr.Progress(track_tqdm=True),
|
||||
progress=gr.Progress()
|
||||
):
|
||||
progress(0, desc="Generating images...")
|
||||
|
||||
img = self.sd_pipe.generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
@@ -282,7 +282,6 @@ class StableDiffusion:
|
||||
seed,
|
||||
return_imgs=True,
|
||||
)
|
||||
progress(1, desc="Image generation complete!")
|
||||
return img
|
||||
|
||||
|
||||
@@ -453,6 +452,8 @@ def shark_sd_fn(
|
||||
generated_imgs = []
|
||||
if seed == -1:
|
||||
seed = randint(0, sys.maxsize)
|
||||
progress(None, desc=f"Generating...")
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
start_time = time.time()
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
|
||||
|
||||
@@ -35,10 +35,7 @@ p.add_argument(
|
||||
"--prompt",
|
||||
nargs="+",
|
||||
default=[
|
||||
"a photo taken of the front of a super-car drifting on a road near "
|
||||
"mountains at high speeds with smoke coming off the tires, front "
|
||||
"angle, front point of view, trees in the mountains of the "
|
||||
"background, ((sharp focus))"
|
||||
"A hi-res photo of a red street racer drifting around a curve on a mountain, high altitude, at night, tokyo in the background, 8k"
|
||||
],
|
||||
help="Text of which images to be generated.",
|
||||
)
|
||||
@@ -62,7 +59,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--steps",
|
||||
type=int,
|
||||
default=50,
|
||||
default=2,
|
||||
help="The number of steps to do the sampling.",
|
||||
)
|
||||
|
||||
@@ -100,7 +97,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--guidance_scale",
|
||||
type=float,
|
||||
default=7.5,
|
||||
default=0,
|
||||
help="The value to be used for guidance scaling.",
|
||||
)
|
||||
|
||||
@@ -346,7 +343,7 @@ p.add_argument(
|
||||
p.add_argument(
|
||||
"--batch_count",
|
||||
type=int,
|
||||
default=1,
|
||||
default=4,
|
||||
help="Number of batches to be generated with random seeds in " "single execution.",
|
||||
)
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ sd_default_models = [
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
]
|
||||
sd_default_models.extend(get_checkpoints(model_type="scripts"))
|
||||
|
||||
|
||||
def view_json_file(file_path):
|
||||
@@ -200,7 +201,7 @@ def save_sd_cfg(config: dict, save_name: str):
|
||||
filepath += ".json"
|
||||
with open(filepath, mode="w") as f:
|
||||
f.write(json.dumps(config))
|
||||
return "..."
|
||||
return save_name
|
||||
|
||||
|
||||
def create_canvas(width, height):
|
||||
@@ -284,7 +285,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
label="\U000026F0\U0000FE0F Base Model",
|
||||
info="Select or enter HF model ID",
|
||||
elem_id="custom_model",
|
||||
value="stabilityai/stable-diffusion-2-1-base",
|
||||
value="stabilityai/sdxl-turbo",
|
||||
choices=sd_default_models,
|
||||
allow_custom_value=True,
|
||||
) # base_model_id
|
||||
@@ -410,21 +411,21 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
seed = gr.Textbox(
|
||||
value=cmd_opts.seed,
|
||||
label="\U0001F331\U0000FE0F Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
info="An integer, -1 for random",
|
||||
show_copy_button=True,
|
||||
)
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="\U0001F4C5\U0000FE0F Scheduler",
|
||||
info="\U000E0020", # forces same height as seed
|
||||
value="EulerDiscrete",
|
||||
value="EulerAncestralDiscrete",
|
||||
choices=scheduler_model_map.keys(),
|
||||
allow_custom_value=False,
|
||||
)
|
||||
with gr.Row():
|
||||
steps = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
50,
|
||||
value=cmd_opts.steps,
|
||||
step=1,
|
||||
label="\U0001F3C3\U0000FE0F Steps",
|
||||
@@ -485,17 +486,17 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
minimum=512,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=8,
|
||||
step=512,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
minimum=512,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=8,
|
||||
step=512,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
@@ -616,7 +617,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
visible=False, # DEMO
|
||||
)
|
||||
compiled_pipeline = gr.Checkbox(
|
||||
False,
|
||||
True,
|
||||
label="Faster txt2img (SDXL only)",
|
||||
)
|
||||
with gr.Row():
|
||||
@@ -627,7 +628,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop")
|
||||
stop_batch = gr.Button("Stop", visible=False)
|
||||
with gr.Tab(label="Config", id=102) as sd_tab_config:
|
||||
with gr.Column(elem_classes=["sd-right-panel"]):
|
||||
with gr.Row(elem_classes=["fill"]):
|
||||
@@ -653,7 +654,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
if cmd_opts.configs_path
|
||||
else get_configs_path()
|
||||
),
|
||||
height=75,
|
||||
height=200,
|
||||
)
|
||||
with gr.Column(scale=1):
|
||||
save_sd_config = gr.Button(
|
||||
@@ -664,13 +665,13 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
size="sm",
|
||||
components=sd_json,
|
||||
)
|
||||
with gr.Row():
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
)
|
||||
#with gr.Row():
|
||||
sd_config_name = gr.Textbox(
|
||||
value="Config Name",
|
||||
info="Name of the file this config will be saved to.",
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
)
|
||||
load_sd_config.change(
|
||||
fn=load_sd_cfg,
|
||||
inputs=[sd_json, load_sd_config],
|
||||
@@ -758,6 +759,7 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
|
||||
outputs=[
|
||||
sd_json,
|
||||
],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
|
||||
@@ -88,6 +88,8 @@ def get_checkpoints_path(model_type=""):
|
||||
def get_checkpoints(model_type="checkpoints"):
|
||||
ckpt_files = []
|
||||
file_types = checkpoints_filetypes
|
||||
if model_type == "scripts":
|
||||
file_types = ["shark_*.py"]
|
||||
if model_type == "lora":
|
||||
file_types = file_types + ("*.pt", "*.bin")
|
||||
for extn in file_types:
|
||||
|
||||
@@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install --pre -r requirements.txt
|
||||
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_compiler-20240528.279-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_runtime-20240528.279-cp311-cp311-win_amd64.whl
|
||||
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240601.282/iree_compiler-20240601.282-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240601.282/iree_runtime-20240601.282-cp311-cp311-win_amd64.whl
|
||||
pip install -e .
|
||||
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
|
||||
Reference in New Issue
Block a user