Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-12 07:18:27 -05:00)

Compare commits: diffusers-...20240109.1 (10 commits)
| Author | SHA1 | Date |
|---|---|---|
| | e80bc9f857 | |
| | 773e6ebebf | |
| | dda7e8a163 | |
| | 7fdd1952ae | |
| | 0a6f6fad86 | |
| | 6853a33728 | |
| | 3887d83f5d | |
| | 8d9b5b3afa | |
| | 16c03e4b44 | |
| | 17dab8334d | |
.github/workflows/nightly.yml (vendored), 77 lines changed
@@ -74,80 +74,3 @@ jobs:
          GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
        with:
          release_id: ${{ steps.create_release.outputs.id }}
-
-  linux-build:
-
-    runs-on: a100
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: ["3.11"]
-        backend: [IREE, SHARK]
-
-    steps:
-    - uses: actions/checkout@v3
-    - name: Set up Python ${{ matrix.python-version }}
-      uses: actions/setup-python@v3
-      with:
-        python-version: ${{ matrix.python-version }}
-
-    - name: Setup pip cache
-      uses: actions/cache@v3
-      with:
-        path: ~/.cache/pip
-        key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-        restore-keys: |
-          ${{ runner.os }}-pip-
-
-    - name: Install dependencies
-      run: |
-        echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
-        python -m pip install --upgrade pip
-        python -m pip install flake8 pytest toml
-        if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
-    - name: Lint with flake8
-      run: |
-        # stop the build if there are Python syntax errors or undefined names
-        flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
-        # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-        flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
-    - name: Build and validate the IREE package
-      if: ${{ matrix.backend == 'IREE' }}
-      continue-on-error: true
-      run: |
-        cd $GITHUB_WORKSPACE
-        USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
-        source iree.venv/bin/activate
-        package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
-        SHARK_PACKAGE_VERSION=${package_version} \
-        pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
-        # Install the built wheel
-        pip install ./wheelhouse/nodai*
-        # Validate the Models
-        /bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
-        pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
-        tail -n 1 |
-        tee -a pytest_results.txt
-        if !(grep -Fxq " failed" pytest_results.txt)
-        then
-          export SHA=$(git log -1 --format='%h')
-          gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
-          gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
-        fi
-        rm -rf ./wheelhouse/nodai*
-
-    - name: Build and validate the SHARK Runtime package
-      if: ${{ matrix.backend == 'SHARK' }}
-      run: |
-        cd $GITHUB_WORKSPACE
-        ./setup_venv.sh
-        source shark.venv/bin/activate
-        package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
-        SHARK_PACKAGE_VERSION=${package_version} \
-        pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
-        # Install the built wheel
-        pip install ./wheelhouse/nodai*
-        # Validate the Models
-        pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
-        tail -n 1 |
-        tee -a pytest_results.txt
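The removed build steps stamp each wheel with a date plus the workflow run number (`printf '%(%Y%m%d)T.${{ github.run_number }}'`), which also explains the `...20240109.1` branch name above. A minimal Python sketch of the same versioning scheme, with the run number as an assumed input:

```python
from datetime import datetime

def nightly_package_version(run_number: int) -> str:
    # Mirrors printf '%(%Y%m%d)T.<run_number>' from the removed workflow step,
    # e.g. run 1 on 2024-01-09 yields "20240109.1".
    return f"{datetime.now().strftime('%Y%m%d')}.{run_number}"
```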
@@ -2075,6 +2075,10 @@ class UnshardedVicuna(VicunaBase):
            f"Compiling for device : {self.device}"
            f"{'://' + str(self.device_id) if self.device_id is not None else ''}"
        )
        if "cpu" in self.device:
-            self.extra_args.extend("--iree-llvmcpu-enable-quantized-matmul-reassociation")
+            self.extra_args.extend("--iree-global-opt-enable-quantized-matmul-reassociation")

        shark_module = SharkInference(
            mlir_module=combined_module,
            device=self.device,

@@ -61,6 +61,7 @@ datas += [
    ("src/utils/resources/opt_flags.json", "resources"),
    ("src/utils/resources/base_model.json", "resources"),
    ("web/ui/css/*", "ui/css"),
+    ("web/ui/js/*", "ui/js"),
    ("web/ui/logos/*", "logos"),
    (
        "../language_models/src/pipelines/minigpt4_utils/configs/*",

@@ -159,8 +159,10 @@ class SharkifyStableDiffusionModel:
        is_sdxl: bool = False,
        stencils: list[str] = [],
        use_lora: str = "",
+        lora_strength: float = 0.75,
        use_quantize: str = None,
        return_mlir: bool = False,
+        favored_base_models=None,
    ):
        self.check_params(max_len, width, height)
        self.max_len = max_len

@@ -190,6 +192,7 @@ class SharkifyStableDiffusionModel:
        )

        self.model_id = model_id if custom_weights == "" else custom_weights
+        self.favored_base_models = favored_base_models
        self.custom_vae = custom_vae
        self.precision = precision
        self.base_vae = use_base_vae

@@ -216,8 +219,14 @@ class SharkifyStableDiffusionModel:
        self.is_upscaler = is_upscaler
        self.stencils = [get_stencil_model_id(x) for x in stencils]
        if use_lora != "":
-            self.model_name = self.model_name + "_" + get_path_stem(use_lora)
+            self.model_name = (
+                self.model_name
+                + "_"
+                + get_path_stem(use_lora)
+                + f"@{int(lora_strength*100)}"
+            )
            self.use_lora = use_lora
+        self.lora_strength = lora_strength

        self.model_name = self.get_extended_name_for_all_model()
        self.debug = debug
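The hunk above bakes the LoRA strength into the cached model name, so artifacts compiled at different strengths no longer collide. A small sketch of the naming scheme; `Path(...).stem` stands in for the repo's `get_path_stem`:

```python
from pathlib import Path

def extended_model_name(model_name: str, use_lora: str, lora_strength: float) -> str:
    # Append the LoRA file stem plus the strength as an integer percentage,
    # matching the f"@{int(lora_strength*100)}" suffix added in the hunk.
    return f"{model_name}_{Path(use_lora).stem}@{int(lora_strength * 100)}"

print(extended_model_name("sd21", "loras/mylora.safetensors", 0.75))  # sd21_mylora@75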
@@ -534,6 +543,7 @@ class SharkifyStableDiffusionModel:
            model_id=self.model_id,
            low_cpu_mem_usage=False,
            use_lora=self.use_lora,
+            lora_strength=self.lora_strength,
        ):
            super().__init__()
            self.unet = UNet2DConditionModel.from_pretrained(

@@ -542,7 +552,9 @@ class SharkifyStableDiffusionModel:
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
            if use_lora != "":
-                update_lora_weight(self.unet, use_lora, "unet")
+                update_lora_weight(
+                    self.unet, use_lora, "unet", lora_strength
+                )
            self.in_channels = self.unet.config.in_channels
            self.train(False)

@@ -818,6 +830,7 @@ class SharkifyStableDiffusionModel:
            model_id=self.model_id,
            low_cpu_mem_usage=False,
            use_lora=self.use_lora,
+            lora_strength=self.lora_strength,
        ):
            super().__init__()
            self.unet = UNet2DConditionModel.from_pretrained(

@@ -826,7 +839,9 @@ class SharkifyStableDiffusionModel:
                low_cpu_mem_usage=low_cpu_mem_usage,
            )
            if use_lora != "":
-                update_lora_weight(self.unet, use_lora, "unet")
+                update_lora_weight(
+                    self.unet, use_lora, "unet", lora_strength
+                )
            self.in_channels = self.unet.config.in_channels
            self.train(False)
            if (

@@ -1058,6 +1073,7 @@ class SharkifyStableDiffusionModel:
            model_id=self.model_id,
            low_cpu_mem_usage=False,
            use_lora=self.use_lora,
+            lora_strength=self.lora_strength,
        ):
            super().__init__()
            self.text_encoder = CLIPTextModel.from_pretrained(

@@ -1067,7 +1083,10 @@ class SharkifyStableDiffusionModel:
            )
            if use_lora != "":
                update_lora_weight(
-                    self.text_encoder, use_lora, "text_encoder"
+                    self.text_encoder,
+                    use_lora,
+                    "text_encoder",
+                    lora_strength,
                )

        def forward(self, input):

@@ -1271,6 +1290,10 @@ class SharkifyStableDiffusionModel:
        compiled_unet = None
        unet_inputs = base_models[model]

+        # if the model to run *is* a base model, then we should treat it as such
+        if self.model_to_run in unet_inputs:
+            self.base_model_id = self.model_to_run
+
        if self.base_model_id != "":
            self.inputs["unet"] = self.get_input_info_for(
                unet_inputs[self.base_model_id]

@@ -1279,7 +1302,16 @@ class SharkifyStableDiffusionModel:
                model, use_large=use_large, base_model=self.base_model_id
            )
        else:
-            for model_id in unet_inputs:
+            # restrict base models to check if we were given a specific list of valid ones
+            allowed_base_model_ids = unet_inputs
+            if self.favored_base_models != None:
+                allowed_base_model_ids = self.favored_base_models
+
+            print(f"self.favored_base_models: {self.favored_base_models}")
+            print(f"allowed_base_model_ids: {allowed_base_model_ids}")
+
+            # try compiling with each base model until we find one that works (or not)
+            for model_id in allowed_base_model_ids:
                self.base_model_id = model_id
                self.inputs["unet"] = self.get_input_info_for(
                    unet_inputs[model_id]

@@ -1292,7 +1324,7 @@ class SharkifyStableDiffusionModel:
                except Exception as e:
                    print(e)
                    print(
-                        "Retrying with a different base model configuration"
+                        f"Retrying with a different base model configuration, as {model_id} did not work"
                    )
                    continue
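The new control flow tries each allowed base model until one compiles. A condensed sketch of that loop, with `compile_unet` as a hypothetical stand-in for the SharkInference compile step:

```python
def compile_with_fallback(unet_inputs: dict, favored_base_models=None):
    # Restrict candidates when a pipeline supplies favored_base_models,
    # as the hunk does; otherwise try every known base model.
    allowed_base_model_ids = favored_base_models or unet_inputs
    for model_id in allowed_base_model_ids:
        try:
            return compile_unet(model_id)  # hypothetical helper
        except Exception as e:
            print(e)
            print(f"Retrying with a different base model configuration, as {model_id} did not work")
    return None
```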
@@ -56,9 +56,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        super().__init__(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )
        self.vae_encode = None

    def load_vae_encode(self):

@@ -51,9 +51,12 @@ class InpaintPipeline(StableDiffusionPipeline):
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        super().__init__(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )
        self.vae_encode = None

    def load_vae_encode(self):

@@ -52,9 +52,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        super().__init__(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )
        self.vae_encode = None

    def load_vae_encode(self):

@@ -64,10 +64,13 @@ class StencilPipeline(StableDiffusionPipeline):
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
        controlnet_names: list[str],
    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        super().__init__(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )
        self.controlnet = [None] * len(controlnet_names)
        self.controlnet_512 = [None] * len(controlnet_names)
        self.controlnet_id = [str] * len(controlnet_names)

@@ -263,8 +266,8 @@ class StencilPipeline(StableDiffusionPipeline):
            latent_model_input_1 = latent_model_input

        # Multicontrolnet
-        width = latent_model_input_1.shape[2]
-        height = latent_model_input_1.shape[3]
+        height = latent_model_input_1.shape[2]
+        width = latent_model_input_1.shape[3]
        dtype = latent_model_input_1.dtype
        control_acc = (
            [torch.zeros((2, 320, height, width), dtype=dtype)] * 3

@@ -49,9 +49,19 @@ class Text2ImagePipeline(StableDiffusionPipeline):
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        super().__init__(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )
+
+    @classmethod
+    def favored_base_models(cls, model_id):
+        return [
+            "stabilityai/stable-diffusion-2-1",
+            "CompVis/stable-diffusion-v1-4",
+        ]

    def prepare_latents(
        self,
@@ -51,12 +51,28 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
        is_fp32_vae: bool,
    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        super().__init__(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )
        self.is_fp32_vae = is_fp32_vae

+    @classmethod
+    def favored_base_models(cls, model_id):
+        if "turbo" in model_id:
+            return [
+                "stabilityai/sdxl-turbo",
+                "stabilityai/stable-diffusion-xl-base-1.0",
+            ]
+        else:
+            return [
+                "stabilityai/stable-diffusion-xl-base-1.0",
+                "stabilityai/sdxl-turbo",
+            ]

    def prepare_latents(
        self,
        batch_size,

@@ -94,9 +94,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
    ):
-        super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        super().__init__(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )
        self.low_res_scheduler = low_res_scheduler
        self.status = SD_STATE_IDLE

@@ -65,6 +65,7 @@ class StableDiffusionPipeline:
        sd_model: SharkifyStableDiffusionModel,
        import_mlir: bool,
        use_lora: str,
+        lora_strength: float,
        ondemand: bool,
        is_f32_vae: bool = False,
    ):

@@ -81,6 +82,7 @@ class StableDiffusionPipeline:
        self.scheduler = scheduler
        self.import_mlir = import_mlir
        self.use_lora = use_lora
+        self.lora_strength = lora_strength
        self.ondemand = ondemand
        self.is_f32_vae = is_f32_vae
        # TODO: Find a better workaround for fetching base_model_id early

@@ -92,6 +94,10 @@ class StableDiffusionPipeline:
            self.unload_unet()
        self.tokenizer = get_tokenizer()

+    def favored_base_models(cls, model_id):
+        # all base models can be candidate base models for unet compilation
+        return None
+
    def load_clip(self):
        if self.text_encoder is not None:
            return

@@ -647,6 +653,7 @@ class StableDiffusionPipeline:
        stencils: list[str] = [],
        # stencil_images: list[Image] = []
        use_lora: str = "",
+        lora_strength: float = 0.75,
        ddpm_scheduler: DDPMScheduler = None,
        use_quantize=None,
    ):

@@ -664,6 +671,9 @@ class StableDiffusionPipeline:
        is_upscaler = cls.__name__ in ["UpscalerPipeline"]
        is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]

+        print(f"model_id", model_id)
+        print(f"ckpt_loc", ckpt_loc)
+        print(f"favored_base_models:", cls.favored_base_models(model_id))
        sd_model = SharkifyStableDiffusionModel(
            model_id,
            ckpt_loc,

@@ -682,7 +692,11 @@ class StableDiffusionPipeline:
            is_sdxl=is_sdxl,
            stencils=stencils,
            use_lora=use_lora,
+            lora_strength=lora_strength,
            use_quantize=use_quantize,
+            favored_base_models=cls.favored_base_models(
+                model_id if model_id != "" else ckpt_loc
+            ),
        )

        if cls.__name__ in ["UpscalerPipeline"]:

@@ -692,12 +706,19 @@ class StableDiffusionPipeline:
                sd_model,
                import_mlir,
                use_lora,
+                lora_strength,
                ondemand,
            )

        if cls.__name__ == "StencilPipeline":
            return cls(
-                scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
+                scheduler,
+                sd_model,
+                import_mlir,
+                use_lora,
+                lora_strength,
+                ondemand,
+                stencils,
            )
        if cls.__name__ == "Text2ImageSDXLPipeline":
            is_fp32_vae = True if "16" not in custom_vae else False

@@ -706,11 +727,14 @@ class StableDiffusionPipeline:
                sd_model,
                import_mlir,
                use_lora,
+                lora_strength,
                ondemand,
                is_fp32_vae,
            )

-        return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
+        return cls(
+            scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
+        )


# #####################################################
# Implements text embeddings with weights from prompts
@@ -435,6 +435,13 @@ p.add_argument(
    "file (~3 MB).",
)

+p.add_argument(
+    "--lora_strength",
+    type=float,
+    default=1.0,
+    help="Strength (alpha) scaling factor to use when applying LoRA weights",
+)
+
p.add_argument(
    "--use_quantize",
    type=str,

@@ -6,6 +6,7 @@ from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
+from dataclasses import dataclass
from pathlib import Path
import numpy as np
from random import (
@@ -638,30 +639,51 @@ def convert_original_vae(vae_checkpoint):
    return converted_vae_checkpoint


-def processLoRA(model, use_lora, splitting_prefix):
+@dataclass
+class LoRAweight:
+    up: torch.tensor
+    down: torch.tensor
+    mid: torch.tensor
+    alpha: torch.float32 = 1.0
+
+
+def processLoRA(model, use_lora, splitting_prefix, lora_strength):
    state_dict = ""
    if ".safetensors" in use_lora:
        state_dict = load_file(use_lora)
    else:
        state_dict = torch.load(use_lora)
-    alpha = 0.75
-    visited = []

-    # directly update weight in model
    process_unet = "te" not in splitting_prefix
+    # gather the weights from the LoRA in a more convenient form, assumes
+    # everything will have an up.weight. Unsure if this is a safe assumption.
+    weight_dict: dict[str, LoRAweight] = {}
    for key in state_dict:
-        if ".alpha" in key or key in visited:
-            continue
+        if key.startswith(splitting_prefix) and key.endswith("up.weight"):
+            stem = key.split("up.weight")[0]
+            weight_key = stem.removesuffix(".lora_")
+            weight_key = weight_key.removesuffix("_lora_")
+            weight_key = weight_key.removesuffix(".lora_linear_layer.")
+
+            if weight_key not in weight_dict:
+                weight_dict[weight_key] = LoRAweight(
+                    state_dict[f"{stem}up.weight"],
+                    state_dict[f"{stem}down.weight"],
+                    state_dict.get(f"{stem}mid.weight", None),
+                    state_dict[f"{weight_key}.alpha"]
+                    / state_dict[f"{stem}up.weight"].shape[1]
+                    if f"{weight_key}.alpha" in state_dict
+                    else 1.0,
+                )
+
+    # Directly update weight in model
+
+    # Mostly adaptations of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
+    # and similar code in https://github.com/huggingface/diffusers/issues/3064
+
+    # TODO: handle mid weights (how do they even work?)
+    for key, lora_weight in weight_dict.items():
        curr_layer = model
-        if ("text" not in key and process_unet) or (
-            "text" in key and not process_unet
-        ):
-            layer_infos = (
-                key.split(".")[0].split(splitting_prefix)[-1].split("_")
-            )
-        else:
-            continue
+        layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")

        # find the target layer
        temp_name = layer_infos.pop(0)

@@ -678,46 +700,46 @@ def processLoRA(model, use_lora, splitting_prefix):
        else:
            temp_name = layer_infos.pop(0)

-        pair_keys = []
-        if "lora_down" in key:
-            pair_keys.append(key.replace("lora_down", "lora_up"))
-            pair_keys.append(key)
-        else:
-            pair_keys.append(key)
-            pair_keys.append(key.replace("lora_up", "lora_down"))

-        # update weight
-        if len(state_dict[pair_keys[0]].shape) == 4:
-            weight_up = (
-                state_dict[pair_keys[0]]
-                .squeeze(3)
-                .squeeze(2)
-                .to(torch.float32)
-            )
+        weight = curr_layer.weight.data
+        scale = lora_weight.alpha * lora_strength
+        if len(weight.size()) == 2:
+            if len(lora_weight.up.shape) == 4:
+                weight_up = (
+                    lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
+                )
+                weight_down = (
+                    lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
+                )
+                change = (
+                    torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+                )
+            else:
+                change = torch.mm(lora_weight.up, lora_weight.down)
+        elif lora_weight.down.size()[2:4] == (1, 1):
+            weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
            weight_down = (
-                state_dict[pair_keys[1]]
-                .squeeze(3)
-                .squeeze(2)
-                .to(torch.float32)
+                lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
            )
-            curr_layer.weight.data += alpha * torch.mm(
-                weight_up, weight_down
-            ).unsqueeze(2).unsqueeze(3)
+            change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
        else:
-            weight_up = state_dict[pair_keys[0]].to(torch.float32)
-            weight_down = state_dict[pair_keys[1]].to(torch.float32)
-            curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
-        # update visited list
-        for item in pair_keys:
-            visited.append(item)
+            change = torch.nn.functional.conv2d(
+                lora_weight.down.permute(1, 0, 2, 3),
+                lora_weight.up,
+            ).permute(1, 0, 2, 3)

+        curr_layer.weight.data += change * scale

    return model


-def update_lora_weight_for_unet(unet, use_lora):
+def update_lora_weight_for_unet(unet, use_lora, lora_strength):
    extensions = [".bin", ".safetensors", ".pt"]
    if not any([extension in use_lora for extension in extensions]):
        # We assume it is a HF ID with standalone LoRA weights.
        unet.load_attn_procs(use_lora)
+        print(
+            f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
+        )
        return unet

    main_file_name = get_path_stem(use_lora)
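At its core, the rewritten loop applies the standard LoRA update: the base weight is shifted by the low-rank product, scaled by the checkpoint's alpha and the user-chosen strength. A minimal sketch of the linear (2-D) branch using plain torch, not the repo's full conv handling:

```python
import torch

def merge_lora_linear(weight, up, down, alpha=1.0, strength=0.75):
    # Same update as the hunk's 2-D case:
    #   W += (alpha * strength) * (up @ down)
    change = torch.mm(up.to(torch.float32), down.to(torch.float32))
    weight.data += change * (alpha * strength)
    return weight

w = torch.nn.Linear(8, 8).weight
merge_lora_linear(w, torch.randn(8, 4), torch.randn(4, 8))
```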
@@ -733,16 +755,21 @@ def update_lora_weight_for_unet(unet, use_lora):
    try:
        dir_name = os.path.dirname(use_lora)
        unet.load_attn_procs(dir_name, weight_name=main_file_name)
+        print(
+            f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
+        )
        return unet
    except:
-        return processLoRA(unet, use_lora, "lora_unet_")
+        print(f"updated unet weights manually from LoRA: {use_lora}")
+        return processLoRA(unet, use_lora, "lora_unet_", lora_strength)


-def update_lora_weight(model, use_lora, model_name):
+def update_lora_weight(model, use_lora, model_name, lora_strength):
    if "unet" in model_name:
-        return update_lora_weight_for_unet(model, use_lora)
+        return update_lora_weight_for_unet(model, use_lora, lora_strength)
    try:
-        return processLoRA(model, use_lora, "lora_te_")
+        print(f"updating CLIP weights from LoRA: {use_lora}")
+        return processLoRA(model, use_lora, "lora_te_", lora_strength)
    except:
        return None
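Callers now thread the strength through alongside the model kind. A hedged usage sketch of the updated signature (paths and values hypothetical):

```python
unet = update_lora_weight(unet, "loras/mylora.safetensors", "unet", 0.75)
clip = update_lora_weight(text_encoder, "loras/mylora.safetensors", "text_encoder", 0.75)
```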
@@ -898,7 +925,7 @@ def save_output_img(output_img, img_seed, extra_info=None):

    img_lora = None
    if args.use_lora:
-        img_lora = Path(os.path.basename(args.use_lora)).stem
+        img_lora = f"{Path(os.path.basename(args.use_lora)).stem}:{args.lora_strength}"

    if args.output_img_format == "jpg":
        out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")

@@ -12,7 +12,6 @@ from apps.stable_diffusion.web.api.utils import (
    decode_base64_to_image,
    get_model_from_request,
    get_scheduler_from_request,
-    get_lora_params,
    get_device,
    GenerationInputData,
    GenerationResponseData,

@@ -180,7 +179,6 @@ def txt2img_api(InputData: Txt2ImgInputData):
    scheduler = get_scheduler_from_request(
        InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
    )
-    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    print(
        f"Prompt: {InputData.prompt}, "

@@ -208,8 +206,8 @@ def txt2img_api(InputData: Txt2ImgInputData):
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
-        lora_weights=lora_weights,
-        lora_hf_id=lora_hf_id,
+        lora_weights=frozen_args.use_lora,
+        lora_strength=frozen_args.lora_strength,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
        use_hiresfix=InputData.enable_hr,

@@ -270,7 +268,6 @@ def img2img_api(
        fallback_model="stabilityai/stable-diffusion-2-1-base",
    )
    scheduler = get_scheduler_from_request(InputData, "img2img")
-    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    init_image = decode_base64_to_image(InputData.init_images[0])
    mask_image = (

@@ -308,8 +305,8 @@ def img2img_api(
        use_stencil=InputData.use_stencil,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
-        lora_weights=lora_weights,
-        lora_hf_id=lora_hf_id,
+        lora_weights=frozen_args.use_lora,
+        lora_strength=frozen_args.lora_strength,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
        resample_type=frozen_args.resample_type,

@@ -358,7 +355,6 @@ def inpaint_api(
        fallback_model="stabilityai/stable-diffusion-2-inpainting",
    )
    scheduler = get_scheduler_from_request(InputData, "inpaint")
-    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    init_image = decode_base64_to_image(InputData.image)
    mask = decode_base64_to_image(InputData.mask)

@@ -393,8 +389,8 @@ def inpaint_api(
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
-        lora_weights=lora_weights,
-        lora_hf_id=lora_hf_id,
+        lora_weights=frozen_args.use_lora,
+        lora_strength=frozen_args.lora_strength,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
    )

@@ -448,7 +444,6 @@ def outpaint_api(
        fallback_model="stabilityai/stable-diffusion-2-inpainting",
    )
    scheduler = get_scheduler_from_request(InputData, "outpaint")
-    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    init_image = decode_base64_to_image(InputData.init_images[0])

@@ -484,8 +479,8 @@ def outpaint_api(
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
-        lora_weights=lora_weights,
-        lora_hf_id=lora_hf_id,
+        lora_weights=frozen_args.use_lora,
+        lora_strength=frozen_args.lora_strength,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
    )

@@ -531,7 +526,6 @@ def upscaler_api(
        fallback_model="stabilityai/stable-diffusion-x4-upscaler",
    )
    scheduler = get_scheduler_from_request(InputData, "upscaler")
-    (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)

    init_image = decode_base64_to_image(InputData.init_images[0])

@@ -563,8 +557,8 @@ def upscaler_api(
        max_length=frozen_args.max_length,
        save_metadata_to_json=frozen_args.save_metadata_to_json,
        save_metadata_to_png=frozen_args.write_metadata_to_png,
-        lora_weights=lora_weights,
-        lora_hf_id=lora_hf_id,
+        lora_weights=frozen_args.use_lora,
+        lora_strength=frozen_args.lora_strength,
        ondemand=frozen_args.ondemand,
        repeatable_seeds=False,
    )

@@ -191,17 +191,6 @@ def get_scheduler_from_request(
    )


-def get_lora_params(use_lora: str):
-    # TODO: since the inference functions in the webui, which we are
-    # still calling into for the api, jam these back together again before
-    # handing them off to the pipeline, we should remove this nonsense
-    # and unify their selection in the UI and command line args proper
-    if use_lora in get_custom_model_files("lora"):
-        return (use_lora, "")
-
-    return ("None", use_lora)
-
-
def get_device(device_str: str):
    # first substring match in the list available devices, with first
    # device when none are matched
@@ -237,9 +237,13 @@ if __name__ == "__main__":
    )

    dark_theme = resource_path("ui/css/sd_dark_theme.css")
+    gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")

    with gr.Blocks(
-        css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
+        css=dark_theme,
+        js=gradio_workarounds,
+        analytics_enabled=False,
+        title="SHARK AI Studio",
    ) as sd_web:
        with gr.Tabs() as tabs:
            # NOTE: If adding, removing, or re-ordering tabs, make sure that they
@@ -55,3 +55,10 @@ def lora_changed(lora_file):
        return [
            "<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
        ]


+def lora_strength_changed(strength):
+    if strength > 1.0:
+        return gr.Number(elem_classes="value-out-of-range")
+    else:
+        return gr.Number(elem_classes="")
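The new handler flags strengths above 1.0 by returning a `gr.Number` carrying the `.value-out-of-range` CSS class added to the stylesheet below. It is wired to the input's change event exactly as in the `.change()` hookups added later in this diff:

```python
lora_strength.change(
    fn=lora_strength_changed,  # defined above; toggles the CSS class
    inputs=lora_strength,
    outputs=lora_strength,
    queue=False,
    show_progress=False,
)
```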
@@ -117,7 +117,7 @@ body {
  height: 100% !important;
}

-/* display in full width for desktop devices */
+/* display in full width for desktop devices, but see below */
@media (min-width: 1536px)
{
  .gradio-container {

@@ -125,6 +125,15 @@ body {
  }
}

+/* media rules in custom css don't appear to be applied in
+   gradio versions > 4.7, so we have to define a class which
+   we will manually need to add and remove using javascript.
+   Remove this once this is fixed in gradio.
+*/
+.gradio-container-size-full {
+  max-width: var(--size-full) !important;
+}

.gradio-container .contain {
  padding: 0 var(--size-4) !important;
}

@@ -182,6 +191,8 @@ footer {
  aspect-ratio: unset;
  max-height: calc(55vh - (2 * var(--spacing-lg)));
}

+/* fix width and height of gallery items when on very large desktop screens, but see below */
@media (min-width: 1921px) {
  /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
  #gallery .grid-wrap, #gallery .preview{

@@ -193,6 +204,20 @@ footer {
    max-height: 770px !important;
  }
}

+/* media rules in custom css don't appear to be applied in
+   gradio versions > 4.7, so we have to define classes which
+   we will manually need to add and remove using javascript.
+   Remove this once this is fixed in gradio.
+*/
+.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview {
+  min-height: calc(768px + 4px + var(--size-14)) !important;
+  max-height: calc(768px + 4px + var(--size-14)) !important;
+}
+.gallery-limit-height768 .thumbnail-item.thumbnail-lg {
+  max-height: 770px !important;
+}

/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
  object-fit: scale-down;

@@ -244,6 +269,11 @@ footer {
  padding-right: 8px;
}

+/* number input value is out of range */
+.value-out-of-range input[type="number"] {
+  color: red !important;
+}

/* reduced animation load when generating */
.generating {
  animation-play-state: paused !important;

@@ -280,7 +310,7 @@ footer {

/* output gallery tab */
.output_parameters_dataframe table.table {
  /* works around a gradio bug that always shows scrollbars */
  overflow: clip auto;
}
@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
    predefined_models,
    cancel_sd,
)
-from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
+from apps.stable_diffusion.web.ui.common_ui_events import (
+    lora_changed,
+    lora_strength_changed,
+)
from apps.stable_diffusion.src import (
    args,
    Image2ImagePipeline,

@@ -74,7 +77,7 @@ def img2img_inf(
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
    lora_weights: str,
-    lora_hf_id: str,
+    lora_strength: float,
    ondemand: bool,
    repeatable_seeds: bool,
    resample_type: str,

@@ -141,9 +144,8 @@ def img2img_inf(
    if custom_vae != "None":
        args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")

-    args.use_lora = get_custom_vae_or_lora_weights(
-        lora_weights, lora_hf_id, "lora"
-    )
+    args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
+    args.lora_strength = lora_strength

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

@@ -176,6 +178,7 @@ def img2img_inf(
            width,
            device,
            use_lora=args.use_lora,
+            lora_strength=args.lora_strength,
            stencils=stencils,
            ondemand=ondemand,
        )

@@ -228,6 +231,7 @@ def img2img_inf(
                stencils=stencils,
                debug=args.import_debug if args.import_mlir else False,
                use_lora=args.use_lora,
+                lora_strength=args.lora_strength,
                ondemand=args.ondemand,
            )
        )

@@ -249,6 +253,7 @@ def img2img_inf(
                low_cpu_mem_usage=args.low_cpu_mem_usage,
                debug=args.import_debug if args.import_mlir else False,
                use_lora=args.use_lora,
+                lora_strength=args.lora_strength,
                ondemand=args.ondemand,
            )
        )
@@ -592,20 +597,20 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                        value="None",
                        choices=choices,
                    )
-                    canvas_width = gr.Slider(
-                        label="Canvas Width",
-                        minimum=256,
-                        maximum=1024,
-                        value=512,
-                        step=1,
-                        visible=False,
-                    )
                    canvas_height = gr.Slider(
                        label="Canvas Height",
                        minimum=256,
-                        maximum=1024,
+                        maximum=768,
                        value=512,
-                        step=1,
+                        step=8,
                        visible=False,
                    )
+                    canvas_width = gr.Slider(
+                        label="Canvas Width",
+                        minimum=256,
+                        maximum=768,
+                        value=512,
+                        step=8,
+                        visible=False,
+                    )
                    make_canvas = gr.Button(
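The sliders move from `step=1` to `step=8` because Stable Diffusion's VAE downsamples by a factor of 8, so canvas dimensions must be multiples of 8 to map cleanly onto the latent grid. A small validation sketch of that constraint:

```python
def check_canvas_size(width: int, height: int) -> None:
    # SD latents are width/8 x height/8, so both sides must divide evenly.
    for name, value in (("width", width), ("height", height)):
        if value % 8 != 0:
            raise ValueError(f"{name}={value} is not a multiple of 8")

check_canvas_size(512, 768)  # ok; 512x513 would raise
```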
@@ -705,17 +710,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                    canvas_width = gr.Slider(
                        label="Canvas Width",
                        minimum=256,
-                        maximum=1024,
+                        maximum=768,
                        value=512,
-                        step=1,
+                        step=8,
                        visible=False,
                    )
                    canvas_height = gr.Slider(
                        label="Canvas Height",
                        minimum=256,
-                        maximum=1024,
+                        maximum=768,
                        value=512,
-                        step=1,
+                        step=8,
                        visible=False,
                    )
                    make_canvas = gr.Button(

@@ -806,28 +811,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:

                with gr.Accordion(label="LoRA Options", open=False):
                    with gr.Row():
-                        # janky fix for overflowing text
-                        i2i_lora_info = (
-                            str(get_custom_model_path("lora"))
-                        ).replace("\\", "\n\\")
-                        i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
                        lora_weights = gr.Dropdown(
+                            allow_custom_value=True,
-                            label=f"Standalone LoRA Weights",
-                            info=i2i_lora_info,
+                            label=f"LoRA Weights",
+                            info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
                            elem_id="lora_weights",
                            value="None",
                            choices=["None"] + get_custom_model_files("lora"),
-                            allow_custom_value=True,
+                            scale=3,
                        )
-                        lora_hf_id = gr.Textbox(
-                            elem_id="lora_hf_id",
-                            placeholder="Select 'None' in the Standalone LoRA "
-                            "weights dropdown on the left if you want to use "
-                            "a standalone HuggingFace model ID for LoRA here "
-                            "e.g: sayakpaul/sd-model-finetuned-lora-t4",
-                            value="",
-                            label="HuggingFace Model ID",
-                            lines=3,
+                        lora_strength = gr.Number(
+                            label="LoRA Strength",
+                            info="Will be baked into the .vmfb",
+                            step=0.01,
+                            # number is checked on change so to allow 0.n values
+                            # we have to allow 0 or you can't type 0.n in
+                            minimum=0.0,
+                            maximum=2.0,
+                            value=args.lora_strength,
+                            scale=1,
                        )
                with gr.Row():
                    lora_tags = gr.HTML(

@@ -957,8 +959,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
                        elem_id="gallery",
                        columns=2,
                        object_fit="contain",
-                        # TODO: Re-enable download when fixed in Gradio
-                        show_download_button=False,
                    )
                    std_output = gr.Textbox(
                        value=f"{i2i_model_info}\n"

@@ -1013,7 +1013,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
            save_metadata_to_json,
            save_metadata_to_png,
            lora_weights,
-            lora_hf_id,
+            lora_strength,
            ondemand,
            repeatable_seeds,
            resample_type,

@@ -1054,3 +1054,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
        outputs=[lora_tags],
        queue=True,
    )

+    lora_strength.change(
+        fn=lora_strength_changed,
+        inputs=lora_strength,
+        outputs=lora_strength,
+        queue=False,
+        show_progress=False,
+    )

@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
    predefined_paint_models,
    cancel_sd,
)
-from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
+from apps.stable_diffusion.web.ui.common_ui_events import (
+    lora_changed,
+    lora_strength_changed,
+)
from apps.stable_diffusion.src import (
    args,
    InpaintPipeline,

@@ -109,7 +112,7 @@ def inpaint_inf(
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
    lora_weights: str,
-    lora_hf_id: str,
+    lora_strength: float,
    ondemand: bool,
    repeatable_seeds: int,
):

@@ -150,9 +153,8 @@ def inpaint_inf(
    if custom_vae != "None":
        args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")

-    args.use_lora = get_custom_vae_or_lora_weights(
-        lora_weights, lora_hf_id, "lora"
-    )
+    args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
+    args.lora_strength = lora_strength

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

@@ -171,6 +173,7 @@ def inpaint_inf(
            width,
            device,
            use_lora=args.use_lora,
+            lora_strength=args.lora_strength,
            stencils=[],
            ondemand=ondemand,
        )

@@ -215,6 +218,7 @@ def inpaint_inf(
                low_cpu_mem_usage=args.low_cpu_mem_usage,
                debug=args.import_debug if args.import_mlir else False,
                use_lora=args.use_lora,
+                lora_strength=args.lora_strength,
                ondemand=args.ondemand,
            )
        )

@@ -350,28 +354,25 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
                    )
                with gr.Accordion(label="LoRA Options", open=False):
                    with gr.Row():
-                        # janky fix for overflowing text
-                        inpaint_lora_info = (
-                            str(get_custom_model_path("lora"))
-                        ).replace("\\", "\n\\")
-                        inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
                        lora_weights = gr.Dropdown(
-                            label=f"Standalone LoRA Weights",
-                            info=inpaint_lora_info,
+                            label=f"LoRA Weights",
+                            info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
                            elem_id="lora_weights",
                            value="None",
                            choices=["None"] + get_custom_model_files("lora"),
                            allow_custom_value=True,
+                            scale=3,
                        )
-                        lora_hf_id = gr.Textbox(
-                            elem_id="lora_hf_id",
-                            placeholder="Select 'None' in the Standalone LoRA "
-                            "weights dropdown on the left if you want to use "
-                            "a standalone HuggingFace model ID for LoRA here "
-                            "e.g: sayakpaul/sd-model-finetuned-lora-t4",
-                            value="",
-                            label="HuggingFace Model ID",
-                            lines=3,
+                        lora_strength = gr.Number(
+                            label="LoRA Strength",
+                            info="Will be baked into the .vmfb",
+                            step=0.01,
+                            # number is checked on change so to allow 0.n values
+                            # we have to allow 0 or you can't type 0.n in
+                            minimum=0.0,
+                            maximum=2.0,
+                            value=args.lora_strength,
+                            scale=1,
                        )
                with gr.Row():
                    lora_tags = gr.HTML(

@@ -558,7 +559,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
            save_metadata_to_json,
            save_metadata_to_png,
            lora_weights,
-            lora_hf_id,
+            lora_strength,
            ondemand,
            repeatable_seeds,
        ],

@@ -622,3 +623,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
        outputs=[lora_tags],
        queue=True,
    )

+    lora_strength.change(
+        fn=lora_strength_changed,
+        inputs=lora_strength,
+        outputs=lora_strength,
+        queue=False,
+        show_progress=False,
+    )
apps/stable_diffusion/web/ui/js/sd_gradio_workarounds.js (new file, 49 lines)
@@ -0,0 +1,49 @@
+// workaround for gradio after 4.7 not applying any @media rules from the custom .css file
+
+() => {
+    console.log(`innerWidth: ${window.innerWidth}`)
+
+    // 1536px rules
+
+    const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')
+
+    function handleWidth1536(event) {
+        // display in full width for desktop devices
+        document.querySelectorAll(".gradio-container")
+            .forEach( (node) => {
+                if (event.matches) {
+                    node.classList.add("gradio-container-size-full");
+                } else {
+                    node.classList.remove("gradio-container-size-full")
+                }
+            });
+    }
+
+    mediaQuery1536.addEventListener("change", handleWidth1536);
+    mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));
+
+    // 1921px rules
+
+    const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')
+
+    function handleWidth1921(event) {
+        /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
+        /* Limit height to 768px_height + 2px_margin_height for the thumbnails */
+        document.querySelectorAll("#gallery")
+            .forEach( (node) => {
+                if (event.matches) {
+                    node.classList.add("gallery-force-height768");
+                    node.classList.add("gallery-limit-height768");
+                } else {
+                    node.classList.remove("gallery-force-height768");
+                    node.classList.remove("gallery-limit-height768");
+                }
+            });
+    }
+
+    mediaQuery1921.addEventListener("change", handleWidth1921);
+    mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));
+}
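This script is handed to Gradio via the `js=` argument added to `gr.Blocks` in index.py above; Gradio runs the function on page load, which installs the `matchMedia` listeners and fires them once with the current width. A minimal sketch of that hookup, assuming `resource_path` resolves files as elsewhere in this app:

```python
import gradio as gr

gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")  # resource_path assumed

with gr.Blocks(css=dark_theme, js=gradio_workarounds) as sd_web:
    ...  # tabs and components as in index.py
```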
@@ -238,9 +238,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
            max_length,
            training_images_dir,
            output_loc,
-            get_custom_vae_or_lora_weights(
-                lora_weights, lora_hf_id, "lora"
-            ),
+            get_custom_vae_or_lora_weights(lora_weights, "lora"),
        ],
        outputs=[std_output],
        show_progress="minimal" if args.progress_bar else "none",

@@ -4,7 +4,10 @@ import time
import gradio as gr
from PIL import Image

-from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
+from apps.stable_diffusion.web.ui.common_ui_events import (
+    lora_changed,
+    lora_strength_changed,
+)
from apps.stable_diffusion.web.ui.utils import (
    available_devices,
    nodlogo_loc,

@@ -60,7 +63,7 @@ def outpaint_inf(
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
    lora_weights: str,
-    lora_hf_id: str,
+    lora_strength: float,
    ondemand: bool,
    repeatable_seeds: bool,
):

@@ -100,9 +103,8 @@ def outpaint_inf(
    if custom_vae != "None":
        args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")

-    args.use_lora = get_custom_vae_or_lora_weights(
-        lora_weights, lora_hf_id, "lora"
-    )
+    args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
+    args.lora_strength = lora_strength

    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

@@ -121,6 +123,7 @@ def outpaint_inf(
            width,
            device,
            use_lora=args.use_lora,
+            lora_strength=args.lora_strength,
            stencils=[],
            ondemand=ondemand,
        )

@@ -163,6 +166,7 @@ def outpaint_inf(
                args.use_base_vae,
                args.use_tuned,
                use_lora=args.use_lora,
+                lora_strength=args.lora_strength,
                ondemand=args.ondemand,
            )
        )

@@ -296,28 +300,25 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
                    )
                with gr.Accordion(label="LoRA Options", open=False):
                    with gr.Row():
-                        # janky fix for overflowing text
-                        outpaint_lora_info = (
-                            str(get_custom_model_path("lora"))
-                        ).replace("\\", "\n\\")
-                        outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
                        lora_weights = gr.Dropdown(
-                            label=f"Standalone LoRA Weights",
-                            info=outpaint_lora_info,
+                            label=f"LoRA Weights",
+                            info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
                            elem_id="lora_weights",
                            value="None",
                            choices=["None"] + get_custom_model_files("lora"),
                            allow_custom_value=True,
+                            scale=3,
                        )
-                        lora_hf_id = gr.Textbox(
-                            elem_id="lora_hf_id",
-                            placeholder="Select 'None' in the Standalone LoRA "
-                            "weights dropdown on the left if you want to use "
-                            "a standalone HuggingFace model ID for LoRA here "
-                            "e.g: sayakpaul/sd-model-finetuned-lora-t4",
-                            value="",
-                            label="HuggingFace Model ID",
-                            lines=3,
+                        lora_strength = gr.Number(
+                            label="LoRA Strength",
+                            info="Will be baked into the .vmfb",
+                            step=0.01,
+                            # number is checked on change so to allow 0.n values
+                            # we have to allow 0 or you can't type 0.n in
+                            minimum=0.0,
+                            maximum=2.0,
+                            value=args.lora_strength,
+                            scale=1,
                        )
                with gr.Row():
                    lora_tags = gr.HTML(
@@ -469,8 +470,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
                        elem_id="gallery",
                        columns=[2],
                        object_fit="contain",
-                        # TODO: Re-enable download when fixed in Gradio
-                        show_download_button=False,
                    )
                    std_output = gr.Textbox(
                        value=f"{outpaint_model_info}\n"

@@ -527,7 +526,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
            save_metadata_to_json,
            save_metadata_to_png,
            lora_weights,
-            lora_hf_id,
+            lora_strength,
            ondemand,
            repeatable_seeds,
        ],

@@ -556,3 +555,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
        outputs=[lora_tags],
        queue=True,
    )

+    lora_strength.change(
+        fn=lora_strength_changed,
+        inputs=lora_strength,
+        outputs=lora_strength,
+        queue=False,
+        show_progress=False,
+    )

@@ -92,8 +92,6 @@ with gr.Blocks() as outputgallery_web:
                visible=False,
                show_label=True,
                columns=4,
-                # TODO: Re-enable download when fixed in Gradio
-                show_download_button=False,
            )

        with gr.Column(scale=4):

@@ -177,8 +177,6 @@ def chat(
    )

    _extra_args = _extra_args + [
        "--iree-global-opt-enable-quantized-matmul-reassociation",
        "--iree-llvmcpu-enable-quantized-matmul-reassociation",
        "--iree-opt-const-eval=false",
        "--iree-opt-data-tiling=false",
    ]
@@ -15,7 +15,10 @@ from apps.stable_diffusion.web.ui.utils import (
    cancel_sd,
    set_model_default_configs,
)
-from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
+from apps.stable_diffusion.web.ui.common_ui_events import (
+    lora_changed,
+    lora_strength_changed,
+)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (

@@ -59,7 +62,7 @@ def txt2img_sdxl_inf(
    save_metadata_to_json: bool,
    save_metadata_to_png: bool,
    lora_weights: str,
-    lora_hf_id: str,
+    lora_strength: float,
    ondemand: bool,
    repeatable_seeds: bool,
):

@@ -105,9 +108,8 @@ def txt2img_sdxl_inf(
    args.save_metadata_to_json = save_metadata_to_json
    args.write_metadata_to_png = save_metadata_to_png

-    args.use_lora = get_custom_vae_or_lora_weights(
-        lora_weights, lora_hf_id, "lora"
-    )
+    args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
+    args.lora_strength = lora_strength

    dtype = torch.float32 if precision == "fp32" else torch.half
    cpu_scheduling = not scheduler.startswith("Shark")

@@ -123,6 +125,7 @@ def txt2img_sdxl_inf(
            width,
            device,
            use_lora=args.use_lora,
+            lora_strength=args.lora_strength,
            stencils=None,
            ondemand=ondemand,
        )

@@ -150,6 +153,7 @@ def txt2img_sdxl_inf(
        if args.hf_model_id
        else "stabilityai/stable-diffusion-xl-base-1.0"
    )

    global_obj.set_schedulers(get_schedulers(model_id))
    scheduler_obj = global_obj.get_scheduler(scheduler)
    if global_obj.get_cfg_obj().ondemand:

@@ -171,6 +175,7 @@ def txt2img_sdxl_inf(
                low_cpu_mem_usage=args.low_cpu_mem_usage,
                debug=args.import_debug if args.import_mlir else False,
                use_lora=args.use_lora,
+                lora_strength=args.lora_strength,
                use_quantize=args.use_quantize,
                ondemand=global_obj.get_cfg_obj().ondemand,
            )

@@ -276,12 +281,14 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
                        label=f"VAE Models",
                        info=t2i_sdxl_vae_info,
                        elem_id="custom_model",
-                        value="None",
+                        value="madebyollin/sdxl-vae-fp16-fix",
                        choices=[
                            None,
                            "madebyollin/sdxl-vae-fp16-fix",
                        ]
-                        + get_custom_model_files("vae"),
+                        + get_custom_model_files(
+                            "vae", custom_checkpoint_type="sdxl"
+                        ),
                        allow_custom_value=True,
                        scale=4,
                    )

@@ -316,28 +323,25 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
                    )
                with gr.Accordion(label="LoRA Options", open=False):
                    with gr.Row():
-                        # janky fix for overflowing text
-                        t2i_sdxl_lora_info = (
-                            str(get_custom_model_path("lora"))
-                        ).replace("\\", "\n\\")
-                        t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
                        lora_weights = gr.Dropdown(
-                            label=f"Standalone LoRA Weights",
-                            info=t2i_sdxl_lora_info,
+                            label=f"LoRA Weights",
+                            info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
                            elem_id="lora_weights",
                            value="None",
                            choices=["None"] + get_custom_model_files("lora"),
                            allow_custom_value=True,
+                            scale=3,
                        )
-                        lora_hf_id = gr.Textbox(
-                            elem_id="lora_hf_id",
-                            placeholder="Select 'None' in the Standalone LoRA "
-                            "weights dropdown on the left if you want to use "
-                            "a standalone HuggingFace model ID for LoRA here "
-                            "e.g: sayakpaul/sd-model-finetuned-lora-t4",
-                            value="",
-                            label="HuggingFace Model ID",
-                            lines=3,
+                        lora_strength = gr.Number(
+                            label="LoRA Strength",
+                            info="Will be baked into the .vmfb",
+                            step=0.01,
+                            # number is checked on change so to allow 0.n values
+                            # we have to allow 0 or you can't type 0.n in
+                            minimum=0.0,
+                            maximum=2.0,
+                            value=args.lora_strength,
+                            scale=1,
                        )
                with gr.Row():
                    lora_tags = gr.HTML(
@@ -374,7 +378,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
                    height = gr.Slider(
                        512,
                        1024,
-                        value=1024,
+                        value=768,
                        step=256,
                        label="Height",
                        visible=True,

@@ -383,7 +387,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
                    width = gr.Slider(
                        512,
                        1024,
-                        value=1024,
+                        value=768,
                        step=256,
                        label="Width",
                        visible=True,

@@ -478,8 +482,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
                        elem_id="gallery",
                        columns=[2],
                        object_fit="scale_down",
-                        # TODO: Re-enable download when fixed in Gradio
-                        show_download_button=False,
                    )
                    std_output = gr.Textbox(
                        value=f"{t2i_sdxl_model_info}\n"

@@ -539,7 +541,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
            save_metadata_to_json,
            save_metadata_to_png,
            lora_weights,
-            lora_hf_id,
+            lora_strength,
            ondemand,
            repeatable_seeds,
        ],

@@ -609,7 +611,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
            height,
            txt2img_sdxl_custom_model,
            lora_weights,
-            lora_hf_id,
            custom_vae,
        ],
        outputs=[

@@ -624,7 +625,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
            height,
            txt2img_sdxl_custom_model,
            lora_weights,
-            lora_hf_id,
            custom_vae,
        ],
    )

@@ -651,3 +651,11 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
        outputs=[lora_tags],
        queue=True,
    )

+    lora_strength.change(
+        fn=lora_strength_changed,
+        inputs=lora_strength,
+        outputs=lora_strength,
+        queue=False,
+        show_progress=False,
+    )
@@ -18,7 +18,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -44,7 +47,7 @@ all_gradio_labels = [
"prompt",
"negative_prompt",
"lora_weights",
"lora_hf_id",
"lora_strength",
"scheduler",
"save_metadata_to_png",
"save_metadata_to_json",
@@ -91,7 +94,7 @@ def txt2img_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
@@ -138,9 +141,8 @@ def txt2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png

args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength

dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -156,6 +158,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -207,6 +210,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -256,6 +260,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -288,6 +293,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -311,6 +317,7 @@ def txt2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
stencils=[],
images=None,
control_mode=None,
resample_type=resample_type,
)
@@ -384,7 +391,7 @@ def load_settings():
loaded_settings.get("prompt", args.prompts[0]),
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
loaded_settings.get("lora_weights", "None"),
loaded_settings.get("lora_hf_id", ""),
loaded_settings.get("lora_strength", args.lora_strength),
loaded_settings.get("scheduler", args.scheduler),
loaded_settings.get(
"save_metadata_to_png", args.write_metadata_to_png
@@ -494,28 +501,25 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value=default_settings.get("lora_weights"),
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value=default_settings.get("lora_hf_id"),
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=default_settings.get("lora_strength"),
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -619,6 +623,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
step=1,
label=default_settings.get("batch_size"),
interactive=True,
visible=False,
)
repeatable_seeds = gr.Checkbox(
default_settings.get("repeatable_seeds"),
@@ -688,8 +693,6 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
elem_id="gallery",
columns=[2],
object_fit="contain",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{t2i_model_info}\n"
@@ -735,7 +738,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
@@ -768,7 +771,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
@@ -812,7 +815,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
use_hiresfix,
@@ -855,7 +858,6 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
@@ -870,7 +872,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
lora_strength,
custom_vae,
],
)
@@ -901,3 +903,11 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
@@ -13,7 +13,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -53,7 +56,7 @@ def upscaler_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -100,9 +103,8 @@ def upscaler_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png

args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength

dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -120,6 +122,7 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -159,6 +162,7 @@ def upscaler_inf(
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -318,28 +322,25 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -469,8 +470,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="gallery",
columns=[2],
object_fit="contain",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{upscaler_model_info}\n"
@@ -523,7 +522,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -552,3 +551,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
outputs=[lora_tags],
queue=True,
)

lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)
@@ -33,6 +33,7 @@ class Config:
width: int
device: str
use_lora: str
lora_strength: float
stencils: list[str]
ondemand: str  # should this be expecting a bool instead?

@@ -180,14 +181,16 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
return sorted(ckpt_files, key=str.casefold)


def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
def get_custom_vae_or_lora_weights(weights, model):
if weights == "None":
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
custom_weights = get_custom_model_pathfile(str(weights), model)
if os.path.isfile(custom_weights):
use_weight = custom_weights
else:
use_weight = weights

return use_weight
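With this refactor the dropdown value alone drives resolution: a selection that names a file under the custom model directory resolves to that file's path, and anything else (for example a HuggingFace model ID typed into the dropdown, now possible via allow_custom_value=True) falls through unchanged, making the separate lora_hf_id textbox redundant. A hypothetical usage sketch, assuming the helper lives in apps.stable_diffusion.web.ui.utils as the surrounding imports suggest; the return values shown are illustrative:

# Illustrative only; paths depend on the local custom model directory.
from apps.stable_diffusion.web.ui.utils import get_custom_vae_or_lora_weights

get_custom_vae_or_lora_weights("None", "lora")
# -> ""  (no LoRA selected)
get_custom_vae_or_lora_weights("my-style.safetensors", "lora")
# -> "<models>/lora/my-style.safetensors" when that file exists locally
get_custom_vae_or_lora_weights("sayakpaul/sd-model-finetuned-lora-t4", "lora")
# -> "sayakpaul/sd-model-finetuned-lora-t4"  (no such local file, so the
#    value falls through and is treated as a HuggingFace model ID)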
@@ -122,20 +122,26 @@ def find_vae_from_png_metadata(

def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
) -> tuple[str, float]:
lora_custom = ""
lora_strength = 1.0

if key in metadata:
lora_file = metadata[key]
split_metadata = metadata[key].split(":")
lora_file = split_metadata[0]
if len(split_metadata) == 2:
try:
lora_strength = float(split_metadata[1])
except ValueError:
pass

lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing had matched, check vendor/hf_model_id
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
lora_custom = lora_file

# LoRA input is optional, should not print or throw an error if missing

return lora_custom, lora_hf_id
return lora_custom, lora_strength


def import_png_metadata(
@@ -150,7 +156,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
@@ -160,9 +165,10 @@ def import_png_metadata(
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
(
custom_lora,
custom_lora_strength,
) = find_lora_from_png_metadata("LoRA", metadata)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)

negative_prompt = metadata["Negative prompt"]
@@ -177,12 +183,8 @@ def import_png_metadata(
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id

if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
if "LoRA" in metadata and not custom_lora:
custom_lora = "None"
hf_lora_id = lora_hf_model_id

if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
@@ -215,6 +217,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_lora_strength,
custom_vae,
)
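The metadata change above folds the LoRA strength into the same PNG field as the LoRA name, using a colon-separated name:strength convention that defaults to 1.0. A small round-trip sketch of that convention; the encoder side is an assumption, since only the parser appears in this diff:

# Round-trip sketch of the "name:strength" convention. Only the decoding
# side is shown in the diff; the encoder here is an assumed mirror of it.
def encode_lora_field(lora: str, strength: float = 1.0) -> str:
    return lora if strength == 1.0 else f"{lora}:{strength}"

def decode_lora_field(value: str) -> tuple[str, float]:
    name, _, raw = value.partition(":")
    try:
        return name, float(raw) if raw else 1.0
    except ValueError:
        return name, 1.0  # malformed strength falls back to the default

assert decode_lora_field("my-style.safetensors:0.65") == ("my-style.safetensors", 0.65)
assert decode_lora_field("sayakpaul/sd-model-finetuned-lora-t4") == (
    "sayakpaul/sd-model-finetuned-lora-t4",
    1.0,
)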
@@ -6,8 +6,8 @@ requires = [

"numpy>=1.22.4",
"torch-mlir>=20230620.875",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"iree-compiler==20231212.*",
"iree-runtime==20231212.*",
]
build-backend = "setuptools.build_meta"


@@ -26,7 +26,7 @@ diffusers
accelerate
scipy
ftfy
gradio==4.8.0
gradio==4.12.0
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
4
setup.py
@@ -11,8 +11,8 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
backend_deps = []
if "NO_BACKEND" in os.environ.keys():
backend_deps = [
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"iree-compiler==20231212.*",
"iree-runtime==20231212.*",
]

setup(

@@ -7,13 +7,13 @@
It checks the Python version installed and installs any required build
dependencies into a Python virtual environment.
If that environment does not exist, it creates it.


.PARAMETER update-src
git pulls latest version

.PARAMETER force
removes and recreates venv to force update of all dependencies


.EXAMPLE
.\setup_venv.ps1 --force

@@ -39,11 +39,11 @@ if ($arguments -eq "--force"){
Write-Host "deactivating..."
Deactivate
}

if (Test-Path .\shark.venv\) {

if (Test-Path .\shark1.venv\) {
Write-Host "removing and recreating venv..."
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Remove-Item .\shark1.venv -Force -Recurse
if (Test-Path .\shark1.venv\) {
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
exit 1
}
@@ -83,15 +83,15 @@ if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any

Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark1.venv\}
else {python -m venv .\shark1.venv\}
.\shark1.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler==20231212.* iree-runtime==20231212.*
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
Write-Host "Source your venv with ./shark1.venv/Scripts/activate"

@@ -5,7 +5,7 @@
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -35,7 +35,7 @@ fi
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
VENV_DIR=${VENV_DIR:-shark1.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
@@ -111,7 +111,7 @@ else
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler==20231212.* iree-runtime==20231212.*
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi
@@ -39,12 +39,7 @@ def get_iree_device_args(device, extra_args=[]):
u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]

return (
get_iree_cpu_args()
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
return get_iree_cpu_args() + u_kernel_flag + stack_size_flag
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args


@@ -139,8 +139,15 @@ def get_vulkan_target_triple(device_name):
triple = f"rdna3-780m-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
elif "7600" in device_name:
triple = f"rdna3-7600-{system_os}"
elif "7700" in device_name:
triple = f"rdna3-7700-{system_os}"
elif any(x in device_name for x in ("Radeon 6", "RX 6", "PRO W6")):
triple = f"rdna2-unknown-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
triple = f"rdna3-unknown-{system_os}"

# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"
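The Vulkan triple hunk above is order-sensitive: the new RDNA3 (7600/7700) and RDNA2 branches must be tested before the generic ("AMD", "Radeon") catch-all, which previously classified every Radeon device as rdna3-unknown. A condensed first-match-wins sketch of that ordering, limited to the branches visible in the hunk; system_os is assumed to be a lowercase OS name like "linux":

# Condensed sketch of the branch order in get_vulkan_target_triple;
# only the rules visible in the hunk above are reproduced here.
def vulkan_triple_sketch(device_name: str, system_os: str = "linux") -> str | None:
    rules = [
        (lambda n: all(x in n for x in ("AMD", "PRO", "W7900")), "rdna3-w7900"),
        (lambda n: "7600" in n, "rdna3-7600"),
        (lambda n: "7700" in n, "rdna3-7700"),
        (lambda n: any(x in n for x in ("Radeon 6", "RX 6", "PRO W6")), "rdna2-unknown"),
        # The generic AMD catch-all must stay last or it shadows the rest.
        (lambda n: any(x in n for x in ("AMD", "Radeon")), "rdna3-unknown"),
    ]
    for matches, arch in rules:
        if matches(device_name):
            return f"{arch}-{system_os}"
    return None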