Compare commits

...

22 Commits

Author SHA1 Message Date
Ean Garvey
fd3fdeef07 Update vulkan_utils.py 2024-01-09 12:39:11 -06:00
Ean Garvey
80906155f8 Revert to 20231212 2024-01-09 12:31:38 -06:00
Ean Garvey
1c094d96eb Update vulkan_utils.py 2024-01-09 12:17:37 -06:00
Ean Garvey
94805aa3c6 Update setup_venv.sh 2024-01-08 23:01:51 -06:00
Ean Garvey
cfb494edff Update setup_venv.ps1 2024-01-08 23:01:22 -06:00
Ean Garvey
10294bef33 Update setup.py 2024-01-08 23:01:02 -06:00
Ean Garvey
8051b33005 Update pyproject.toml 2024-01-08 23:00:42 -06:00
Ean Garvey
10af1f25c7 Update setup.py 2024-01-08 22:38:05 -06:00
Ean Garvey
4305153d21 Update pyproject.toml 2024-01-08 22:37:49 -06:00
Ean Garvey
ab5639ff66 Update pyproject.toml 2024-01-08 21:13:09 -06:00
Ean Garvey
4cf0383768 Update setup.py 2024-01-08 21:12:43 -06:00
Ean Garvey
01845d593d Remove linux builds from 1.0 nightly workflow
It seems that the VMs used for these workflows are no longer available. Removing Linux builds, since publishing the .exe is sufficient for the one-shot nightly workflows we trigger for SHARK-1.0.
2024-01-08 18:28:09 -06:00
Ean Garvey
5b15ceee35 Move IREE pins for linux. 2024-01-08 18:23:00 -06:00
Ean Garvey
04b21295ee Move IREE pins for windows. 2024-01-08 18:22:09 -06:00
Stefan Kapusniak
dda7e8a163 (Shark-1) UI: Fix 'keyword argument' error in txt2img (#2058)
* Removes the incorrect valid_base_models keyword argument left in txt2img_inf
2024-01-08 16:47:15 -06:00
Stefan Kapusniak
7fdd1952ae (Shark 1.0) UI/SD UX improvements for SDXL (#2057)
* SDXL Tab
  * Filter VAEs in the dropdown in the same manner as models
  * Set the default VAE selection to 'madebyollin/sdxl-vae-fp16-fix'
  * Set the default image size to 768x768 to match current Vulkan constraints
* SharkifySDModel Base Unet Model Determination
  * Always use the model_to_run as the base model for unet, if it is in
base_model.json, instead of potentially trying to compile for other base
models.
  * Allow SharkSDPipelines to define a 'favored_base_models' @classmethod,
returning a list of sane base model names for the pipeline. Exclude base
models not in that list from compilation attempts when trying to determine
a base unet model.
  * Add a 'favored_base_models' method for both the Normal and SDXL Txt2Img
Pipelines. Define the method as returning None in the base class. (A sketch
of the pattern follows this entry.)
2024-01-06 11:59:18 -06:00
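A minimal sketch of the 'favored_base_models' pattern described in the entry above, assuming the base/subclass split shown in the diffs further below; candidate_base_models is a hypothetical helper name, not part of the commit:

class StableDiffusionPipeline:
    @classmethod
    def favored_base_models(cls, model_id):
        # Base class: no restriction, so every entry in base_model.json
        # stays a candidate when determining the base unet model.
        return None

class Text2ImagePipeline(StableDiffusionPipeline):
    @classmethod
    def favored_base_models(cls, model_id):
        # Only these base models are tried during unet compilation.
        return [
            "stabilityai/stable-diffusion-2-1",
            "CompVis/stable-diffusion-v1-4",
        ]

def candidate_base_models(pipeline_cls, model_id, all_base_models):
    # Hypothetical helper: restrict compilation candidates when the
    # pipeline names a favored list, otherwise allow everything.
    favored = pipeline_cls.favored_base_models(model_id)
    return favored if favored is not None else list(all_base_models)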
Ean Garvey
0a6f6fad86 (1.0) Fix non-square controlnet dimensions. (#2056)
* Fix non-square controlnet dims

* Hide batch size slider in txt2img
2024-01-04 21:33:55 -06:00
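The non-square fix above is an axis-order correction: torch latents are laid out NCHW, so shape[2] is height and shape[3] is width, and the previous code read them the other way around. A small sketch with an illustrative tensor:

import torch

latent_model_input = torch.zeros((2, 4, 96, 64))  # N, C, H, W

height = latent_model_input.shape[2]  # 96, the vertical dimension
width = latent_model_input.shape[3]   # 64, the horizontal dimension

# Controlnet residual accumulators must follow the same H-before-W order.
control_acc = torch.zeros((2, 320, height, width), dtype=latent_model_input.dtype)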
Ean Garvey
6853a33728 Pin iree versions and fix quant matmul flags. (#2055)
Restricts the quantized matmul reassociation flags to CPU compiles of llama2 and pins the IREE versions for SHARK 1.0.
2024-01-04 14:22:54 -06:00
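A sketch of the flag restriction described above, assuming an extra_args list that is later handed to the IREE compiler; the surrounding pipeline class is elided:

device = "cpu-task"  # illustrative device string
extra_args = []

# Only CPU compiles of llama2 get the quantized matmul reassociation
# flags; other backends are left untouched.
if "cpu" in device:
    extra_args.extend(
        [
            "--iree-llvmcpu-enable-quantized-matmul-reassociation",
            "--iree-global-opt-enable-quantized-matmul-reassociation",
        ]
    )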
Stefan Kapusniak
3887d83f5d (Shark 1.0) UI: Upgrade to gradio 4.12.0 and fix breakage (#2051)
* (Shark 1.0) UI: Upgrade to gradio 4.12.0 and fix breakage

* Upgrade Shark 1.0 to gradio 4.12.0
* Add a JavaScript workaround for Gradio currently ignoring @media rules in custom CSS. This fixes the UI not showing at full width on desktop (>1536px wide).

* (Shark 1.0) UI: Re-enable gallery download buttons

* Re-enable gallery download buttons, as this is now working again in Gradio 4.12
2024-01-03 19:00:08 -06:00
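The workaround loads a JavaScript snippet through gr.Blocks, whose js parameter in Gradio 4.x runs the given code when the page loads; the diffs further below show the real wiring to sd_dark_theme.css and sd_gradio_workarounds.js. A minimal sketch with inline stand-ins for those files:

import gradio as gr

# Inline stand-ins; the real code passes resource paths to the theme CSS
# and to ui/js/sd_gradio_workarounds.js.
dark_theme = ".gradio-container-size-full { max-width: var(--size-full) !important; }"
gradio_workarounds = "() => { console.log('media-rule workaround loaded') }"

with gr.Blocks(
    css=dark_theme,          # custom CSS (holds the @media rules gradio drops)
    js=gradio_workarounds,   # JS run on load; toggles classes by window width
    analytics_enabled=False,
    title="SHARK AI Studio",
) as sd_web:
    gr.Markdown("...")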
Stefan Kapusniak
8d9b5b3afa SD/UI: Merge Lora Selection Boxes, Add LoRA Strength (#2052)
* Merge LoRA selection in the UI into a single control, rather than
one dropdown for LoRAs under ./models and another box for a Hugging Face ID.
* Add LoRA strength to the UI and pipeline parameters.
* Add a `--lora_strength` command line argument.
* Bake LoRA strength into the .vmfb naming when a LoRA is specified.
* Use the LoRA's embedded alpha values and (up tensor dimension *
LoRA strength) for the final alpha when applying LoRA weights, rather
than a hardcoded value of 0.75.
* Add the additional LoRA weight-application cases that are
present in the Kohya scripts.
* Include LoRA strength when reading and writing PNG metadata.
* Allow lora_strength to be set above 1.0 in the UI, so effects similar
to the prior (overdriven alpha) implementation can be obtained.
(A condensed sketch of the strength scaling follows this entry.)
2024-01-03 18:59:47 -06:00
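A condensed sketch of the strength application summarized above: the per-pair alpha embedded in the LoRA file is normalized by the up tensor's rank dimension, multiplied by the user-supplied strength, and used to scale the low-rank product merged into the layer weight. Shapes are illustrative:

import torch

lora_strength = 1.0  # from --lora_strength / the UI Number input

# One LoRA pair for a linear layer; the weight delta is up @ down.
up = torch.randn(640, 8)    # (out_features, rank)
down = torch.randn(8, 320)  # (rank, in_features)
embedded_alpha = 4.0        # the pair's '.alpha' entry, when present

alpha = embedded_alpha / up.shape[1]  # normalize by rank
scale = alpha * lora_strength

layer_weight = torch.zeros(640, 320)
layer_weight += torch.mm(up, down) * scale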
xzuyn
16c03e4b44 fix for TypeError: Image2ImagePipeline.generate_images() missing 1 required positional argument: 'images' (#2053) 2024-01-03 18:54:54 -06:00
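The TypeError fixed here is the usual symptom of a shared method gaining a required positional parameter while one call site lags behind; a minimal reproduction (the signature below is a stand-in, not the full pipeline API):

class Image2ImagePipeline:
    def generate_images(self, prompt, images):
        # 'images' became required, so every caller must now pass it,
        # even ones (like plain txt2img) that have no input image.
        return []

pipe = Image2ImagePipeline()
# pipe.generate_images("a prompt")  # TypeError: missing 'images'
pipe.generate_images("a prompt", images=None)  # the fix: pass images=None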
Stefan Kapusniak
17dab8334d Setup a separate .shark1.venv for Shark-1.0 (#2043)
* Updates setup_venv.ps1 to create and use ./shark1.venv/
* Updates setup_venv.sh to default to creating ./shark1.venv/
2023-12-16 23:18:29 -06:00
38 changed files with 580 additions and 393 deletions

View File

@@ -74,80 +74,3 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}
linux-build:
runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
backend: [IREE, SHARK]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}
- name: Setup pip cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*
- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

View File

@@ -2075,6 +2075,10 @@ class UnshardedVicuna(VicunaBase):
f"Compiling for device : {self.device}"
f"{'://' + str(self.device_id) if self.device_id is not None else ''}"
)
if "cpu" in self.device:
self.extra_args.extend("--iree-llvmcpu-enable-quantized-matmul-reassociation")
self.extra_args.extend("--iree-global-opt-enable-quantized-matmul-reassociation")
shark_module = SharkInference(
mlir_module=combined_module,
device=self.device,

View File

@@ -61,6 +61,7 @@ datas += [
("src/utils/resources/opt_flags.json", "resources"),
("src/utils/resources/base_model.json", "resources"),
("web/ui/css/*", "ui/css"),
("web/ui/js/*", "ui/js"),
("web/ui/logos/*", "logos"),
(
"../language_models/src/pipelines/minigpt4_utils/configs/*",

View File

@@ -159,8 +159,10 @@ class SharkifyStableDiffusionModel:
is_sdxl: bool = False,
stencils: list[str] = [],
use_lora: str = "",
lora_strength: float = 0.75,
use_quantize: str = None,
return_mlir: bool = False,
favored_base_models=None,
):
self.check_params(max_len, width, height)
self.max_len = max_len
@@ -190,6 +192,7 @@ class SharkifyStableDiffusionModel:
)
self.model_id = model_id if custom_weights == "" else custom_weights
self.favored_base_models = favored_base_models
self.custom_vae = custom_vae
self.precision = precision
self.base_vae = use_base_vae
@@ -216,8 +219,14 @@ class SharkifyStableDiffusionModel:
self.is_upscaler = is_upscaler
self.stencils = [get_stencil_model_id(x) for x in stencils]
if use_lora != "":
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
self.model_name = (
self.model_name
+ "_"
+ get_path_stem(use_lora)
+ f"@{int(lora_strength*100)}"
)
self.use_lora = use_lora
self.lora_strength = lora_strength
self.model_name = self.get_extended_name_for_all_model()
self.debug = debug
@@ -534,6 +543,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -542,7 +552,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
update_lora_weight(
self.unet, use_lora, "unet", lora_strength
)
self.in_channels = self.unet.config.in_channels
self.train(False)
@@ -818,6 +830,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.unet = UNet2DConditionModel.from_pretrained(
@@ -826,7 +839,9 @@ class SharkifyStableDiffusionModel:
low_cpu_mem_usage=low_cpu_mem_usage,
)
if use_lora != "":
update_lora_weight(self.unet, use_lora, "unet")
update_lora_weight(
self.unet, use_lora, "unet", lora_strength
)
self.in_channels = self.unet.config.in_channels
self.train(False)
if (
@@ -1058,6 +1073,7 @@ class SharkifyStableDiffusionModel:
model_id=self.model_id,
low_cpu_mem_usage=False,
use_lora=self.use_lora,
lora_strength=self.lora_strength,
):
super().__init__()
self.text_encoder = CLIPTextModel.from_pretrained(
@@ -1067,7 +1083,10 @@ class SharkifyStableDiffusionModel:
)
if use_lora != "":
update_lora_weight(
self.text_encoder, use_lora, "text_encoder"
self.text_encoder,
use_lora,
"text_encoder",
lora_strength,
)
def forward(self, input):
@@ -1271,6 +1290,10 @@ class SharkifyStableDiffusionModel:
compiled_unet = None
unet_inputs = base_models[model]
# if the model to run *is* a base model, then we should treat it as such
if self.model_to_run in unet_inputs:
self.base_model_id = self.model_to_run
if self.base_model_id != "":
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[self.base_model_id]
@@ -1279,7 +1302,16 @@ class SharkifyStableDiffusionModel:
model, use_large=use_large, base_model=self.base_model_id
)
else:
for model_id in unet_inputs:
# restrict base models to check if we were given a specific list of valid ones
allowed_base_model_ids = unet_inputs
if self.favored_base_models != None:
allowed_base_model_ids = self.favored_base_models
print(f"self.favored_base_models: {self.favored_base_models}")
print(f"allowed_base_model_ids: {allowed_base_model_ids}")
# try compiling with each base model until we find one that works (or not)
for model_id in allowed_base_model_ids:
self.base_model_id = model_id
self.inputs["unet"] = self.get_input_info_for(
unet_inputs[model_id]
@@ -1292,7 +1324,7 @@ class SharkifyStableDiffusionModel:
except Exception as e:
print(e)
print(
"Retrying with a different base model configuration"
f"Retrying with a different base model configuration, as {model_id} did not work"
)
continue
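Condensed, the retry logic above amounts to the following sketch; try_compile is a hypothetical callable standing in for the real unet compile step, and is assumed to raise on failure:

def determine_base_unet(unet_inputs, favored_base_models, try_compile):
    # Restrict candidates when a favored list is given, then try each
    # base model until one compiles (or none do).
    allowed = unet_inputs if favored_base_models is None else favored_base_models
    for model_id in allowed:
        try:
            return model_id, try_compile(model_id, unet_inputs[model_id])
        except Exception as e:
            print(e)
            print(
                f"Retrying with a different base model configuration, as {model_id} did not work"
            )
    return None, None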

View File

@@ -56,9 +56,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -51,9 +51,12 @@ class InpaintPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -52,9 +52,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.vae_encode = None
def load_vae_encode(self):

View File

@@ -64,10 +64,13 @@ class StencilPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
controlnet_names: list[str],
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.controlnet = [None] * len(controlnet_names)
self.controlnet_512 = [None] * len(controlnet_names)
self.controlnet_id = [str] * len(controlnet_names)
@@ -263,8 +266,8 @@ class StencilPipeline(StableDiffusionPipeline):
latent_model_input_1 = latent_model_input
# Multicontrolnet
width = latent_model_input_1.shape[2]
height = latent_model_input_1.shape[3]
height = latent_model_input_1.shape[2]
width = latent_model_input_1.shape[3]
dtype = latent_model_input_1.dtype
control_acc = (
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3

View File

@@ -49,9 +49,19 @@ class Text2ImagePipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
@classmethod
def favored_base_models(cls, model_id):
return [
"stabilityai/stable-diffusion-2-1",
"CompVis/stable-diffusion-v1-4",
]
def prepare_latents(
self,

View File

@@ -51,12 +51,28 @@ class Text2ImageSDXLPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
is_fp32_vae: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.is_fp32_vae = is_fp32_vae
@classmethod
def favored_base_models(cls, model_id):
if "turbo" in model_id:
return [
"stabilityai/sdxl-turbo",
"stabilityai/stable-diffusion-xl-base-1.0",
]
else:
return [
"stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
]
def prepare_latents(
self,
batch_size,

View File

@@ -94,9 +94,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
):
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
super().__init__(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
self.low_res_scheduler = low_res_scheduler
self.status = SD_STATE_IDLE

View File

@@ -65,6 +65,7 @@ class StableDiffusionPipeline:
sd_model: SharkifyStableDiffusionModel,
import_mlir: bool,
use_lora: str,
lora_strength: float,
ondemand: bool,
is_f32_vae: bool = False,
):
@@ -81,6 +82,7 @@ class StableDiffusionPipeline:
self.scheduler = scheduler
self.import_mlir = import_mlir
self.use_lora = use_lora
self.lora_strength = lora_strength
self.ondemand = ondemand
self.is_f32_vae = is_f32_vae
# TODO: Find a better workaround for fetching base_model_id early
@@ -92,6 +94,10 @@ class StableDiffusionPipeline:
self.unload_unet()
self.tokenizer = get_tokenizer()
@classmethod
def favored_base_models(cls, model_id):
# all base models can be candidate base models for unet compilation
return None
def load_clip(self):
if self.text_encoder is not None:
return
@@ -647,6 +653,7 @@ class StableDiffusionPipeline:
stencils: list[str] = [],
# stencil_images: list[Image] = []
use_lora: str = "",
lora_strength: float = 0.75,
ddpm_scheduler: DDPMScheduler = None,
use_quantize=None,
):
@@ -664,6 +671,9 @@ class StableDiffusionPipeline:
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
print(f"model_id", model_id)
print(f"ckpt_loc", ckpt_loc)
print(f"favored_base_models:", cls.favored_base_models(model_id))
sd_model = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
@@ -682,7 +692,11 @@ class StableDiffusionPipeline:
is_sdxl=is_sdxl,
stencils=stencils,
use_lora=use_lora,
lora_strength=lora_strength,
use_quantize=use_quantize,
favored_base_models=cls.favored_base_models(
model_id if model_id != "" else ckpt_loc
),
)
if cls.__name__ in ["UpscalerPipeline"]:
@@ -692,12 +706,19 @@ class StableDiffusionPipeline:
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
)
if cls.__name__ == "StencilPipeline":
return cls(
scheduler, sd_model, import_mlir, use_lora, ondemand, stencils
scheduler,
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
stencils,
)
if cls.__name__ == "Text2ImageSDXLPipeline":
is_fp32_vae = True if "16" not in custom_vae else False
@@ -706,11 +727,14 @@ class StableDiffusionPipeline:
sd_model,
import_mlir,
use_lora,
lora_strength,
ondemand,
is_fp32_vae,
)
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
return cls(
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
)
# #####################################################
# Implements text embeddings with weights from prompts

View File

@@ -435,6 +435,13 @@ p.add_argument(
"file (~3 MB).",
)
p.add_argument(
"--lora_strength",
type=float,
default=1.0,
help="Strength (alpha) scaling factor to use when applying LoRA weights",
)
p.add_argument(
"--use_quantize",
type=str,

View File

@@ -6,6 +6,7 @@ from PIL import PngImagePlugin
from PIL import Image
from datetime import datetime as dt
from csv import DictWriter
from dataclasses import dataclass
from pathlib import Path
import numpy as np
from random import (
@@ -638,30 +639,51 @@ def convert_original_vae(vae_checkpoint):
return converted_vae_checkpoint
def processLoRA(model, use_lora, splitting_prefix):
@dataclass
class LoRAweight:
up: torch.tensor
down: torch.tensor
mid: torch.tensor
alpha: torch.float32 = 1.0
def processLoRA(model, use_lora, splitting_prefix, lora_strength):
state_dict = ""
if ".safetensors" in use_lora:
state_dict = load_file(use_lora)
else:
state_dict = torch.load(use_lora)
alpha = 0.75
visited = []
# directly update weight in model
process_unet = "te" not in splitting_prefix
# gather the weights from the LoRA in a more convenient form, assumes
# everything will have an up.weight. Unsure if this is a safe assumption.
weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
if ".alpha" in key or key in visited:
continue
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
stem = key.split("up.weight")[0]
weight_key = stem.removesuffix(".lora_")
weight_key = weight_key.removesuffix("_lora_")
weight_key = weight_key.removesuffix(".lora_linear_layer.")
if weight_key not in weight_dict:
weight_dict[weight_key] = LoRAweight(
state_dict[f"{stem}up.weight"],
state_dict[f"{stem}down.weight"],
state_dict.get(f"{stem}mid.weight", None),
state_dict[f"{weight_key}.alpha"]
/ state_dict[f"{stem}up.weight"].shape[1]
if f"{weight_key}.alpha" in state_dict
else 1.0,
)
# Directly update weight in model
# Mostly adaptations of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
# and similar code in https://github.com/huggingface/diffusers/issues/3064
# TODO: handle mid weights (how do they even work?)
for key, lora_weight in weight_dict.items():
curr_layer = model
if ("text" not in key and process_unet) or (
"text" in key and not process_unet
):
layer_infos = (
key.split(".")[0].split(splitting_prefix)[-1].split("_")
)
else:
continue
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
# find the target layer
temp_name = layer_infos.pop(0)
@@ -678,46 +700,46 @@ def processLoRA(model, use_lora, splitting_prefix):
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = (
state_dict[pair_keys[0]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
)
weight = curr_layer.weight.data
scale = lora_weight.alpha * lora_strength
if len(weight.size()) == 2:
if len(lora_weight.up.shape) == 4:
weight_up = (
lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
)
weight_down = (
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
)
change = (
torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
)
else:
change = torch.mm(lora_weight.up, lora_weight.down)
elif lora_weight.down.size()[2:4] == (1, 1):
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
weight_down = (
state_dict[pair_keys[1]]
.squeeze(3)
.squeeze(2)
.to(torch.float32)
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
)
curr_layer.weight.data += alpha * torch.mm(
weight_up, weight_down
).unsqueeze(2).unsqueeze(3)
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
# update visited list
for item in pair_keys:
visited.append(item)
change = torch.nn.functional.conv2d(
lora_weight.down.permute(1, 0, 2, 3),
lora_weight.up,
).permute(1, 0, 2, 3)
curr_layer.weight.data += change * scale
return model
def update_lora_weight_for_unet(unet, use_lora):
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
extensions = [".bin", ".safetensors", ".pt"]
if not any([extension in use_lora for extension in extensions]):
# We assume if it is a HF ID with standalone LoRA weights.
unet.load_attn_procs(use_lora)
print(
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
)
return unet
main_file_name = get_path_stem(use_lora)
@@ -733,16 +755,21 @@ def update_lora_weight_for_unet(unet, use_lora):
try:
dir_name = os.path.dirname(use_lora)
unet.load_attn_procs(dir_name, weight_name=main_file_name)
print(
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
)
return unet
except:
return processLoRA(unet, use_lora, "lora_unet_")
print(f"updated unet weights manually from LoRA: {use_lora}")
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
def update_lora_weight(model, use_lora, model_name):
def update_lora_weight(model, use_lora, model_name, lora_strength):
if "unet" in model_name:
return update_lora_weight_for_unet(model, use_lora)
return update_lora_weight_for_unet(model, use_lora, lora_strength)
try:
return processLoRA(model, use_lora, "lora_te_")
print(f"updating CLIP weights from LoRA: {use_lora}")
return processLoRA(model, use_lora, "lora_te_", lora_strength)
except:
return None
@@ -898,7 +925,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
img_lora = None
if args.use_lora:
img_lora = Path(os.path.basename(args.use_lora)).stem
img_lora = f"{Path(os.path.basename(args.use_lora)).stem}:{args.lora_strength}"
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
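The weight-gathering step in the processLoRA rewrite above is easier to follow outside diff form; a toy sketch under the same assumptions (every pair has an up.weight, an optional '.alpha' entry is normalized by rank, and the key format shown is just one of the several suffixes the real code strips):

import torch
from dataclasses import dataclass
from typing import Optional

@dataclass
class LoRAweight:
    up: torch.Tensor
    down: torch.Tensor
    mid: Optional[torch.Tensor]
    alpha: float = 1.0

# Toy state_dict holding a single unet LoRA pair; real files hold many.
state_dict = {
    "lora_unet_down_blocks_0.lora_up.weight": torch.randn(640, 8),
    "lora_unet_down_blocks_0.lora_down.weight": torch.randn(8, 320),
    "lora_unet_down_blocks_0.alpha": torch.tensor(4.0),
}

weight_dict: dict[str, LoRAweight] = {}
for key in state_dict:
    if key.startswith("lora_unet_") and key.endswith("up.weight"):
        stem = key.split("up.weight")[0]          # "lora_unet_down_blocks_0.lora_"
        weight_key = stem.removesuffix(".lora_")  # "lora_unet_down_blocks_0"
        weight_dict[weight_key] = LoRAweight(
            state_dict[f"{stem}up.weight"],
            state_dict[f"{stem}down.weight"],
            state_dict.get(f"{stem}mid.weight", None),
            state_dict[f"{weight_key}.alpha"] / state_dict[f"{stem}up.weight"].shape[1]
            if f"{weight_key}.alpha" in state_dict
            else 1.0,
        )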

View File

@@ -12,7 +12,6 @@ from apps.stable_diffusion.web.api.utils import (
decode_base64_to_image,
get_model_from_request,
get_scheduler_from_request,
get_lora_params,
get_device,
GenerationInputData,
GenerationResponseData,
@@ -180,7 +179,6 @@ def txt2img_api(InputData: Txt2ImgInputData):
scheduler = get_scheduler_from_request(
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
)
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
print(
f"Prompt: {InputData.prompt}, "
@@ -208,8 +206,8 @@ def txt2img_api(InputData: Txt2ImgInputData):
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
use_hiresfix=InputData.enable_hr,
@@ -270,7 +268,6 @@ def img2img_api(
fallback_model="stabilityai/stable-diffusion-2-1-base",
)
scheduler = get_scheduler_from_request(InputData, "img2img")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
mask_image = (
@@ -308,8 +305,8 @@ def img2img_api(
use_stencil=InputData.use_stencil,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
resample_type=frozen_args.resample_type,
@@ -358,7 +355,6 @@ def inpaint_api(
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "inpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.image)
mask = decode_base64_to_image(InputData.mask)
@@ -393,8 +389,8 @@ def inpaint_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
@@ -448,7 +444,6 @@ def outpaint_api(
fallback_model="stabilityai/stable-diffusion-2-inpainting",
)
scheduler = get_scheduler_from_request(InputData, "outpaint")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
@@ -484,8 +479,8 @@ def outpaint_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)
@@ -531,7 +526,6 @@ def upscaler_api(
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
)
scheduler = get_scheduler_from_request(InputData, "upscaler")
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
init_image = decode_base64_to_image(InputData.init_images[0])
@@ -563,8 +557,8 @@ def upscaler_api(
max_length=frozen_args.max_length,
save_metadata_to_json=frozen_args.save_metadata_to_json,
save_metadata_to_png=frozen_args.write_metadata_to_png,
lora_weights=lora_weights,
lora_hf_id=lora_hf_id,
lora_weights=frozen_args.use_lora,
lora_strength=frozen_args.lora_strength,
ondemand=frozen_args.ondemand,
repeatable_seeds=False,
)

View File

@@ -191,17 +191,6 @@ def get_scheduler_from_request(
)
def get_lora_params(use_lora: str):
# TODO: since the inference functions in the webui, which we are
# still calling into for the api, jam these back together again before
# handing them off to the pipeline, we should remove this nonsense
# and unify their selection in the UI and command line args proper
if use_lora in get_custom_model_files("lora"):
return (use_lora, "")
return ("None", use_lora)
def get_device(device_str: str):
# first substring match in the list available devices, with first
# device when none are matched

View File

@@ -237,9 +237,13 @@ if __name__ == "__main__":
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
css=dark_theme,
js=gradio_workarounds,
analytics_enabled=False,
title="SHARK AI Studio",
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they

View File

@@ -55,3 +55,10 @@ def lora_changed(lora_file):
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]
def lora_strength_changed(strength):
if strength > 1.0:
return gr.Number(elem_classes="value-out-of-range")
else:
return gr.Number(elem_classes="")
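This handler pairs with the .value-out-of-range rule added to sd_dark_theme.css: returning a gr.Number carrying that class turns the input red when the strength exceeds 1.0, without rejecting the value. A sketch of the wiring, mirroring the .change() hookups in the tab diffs below:

import gradio as gr

def lora_strength_changed(strength):
    # Values above 1.0 stay allowed (to overdrive alpha) but are
    # flagged red via the custom CSS class.
    if strength > 1.0:
        return gr.Number(elem_classes="value-out-of-range")
    return gr.Number(elem_classes="")

with gr.Blocks() as demo:
    lora_strength = gr.Number(label="LoRA Strength", value=1.0, step=0.01)
    lora_strength.change(
        fn=lora_strength_changed,
        inputs=lora_strength,
        outputs=lora_strength,
        queue=False,
        show_progress=False,
    )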

View File

@@ -117,7 +117,7 @@ body {
height: 100% !important;
}
/* display in full width for desktop devices */
/* display in full width for desktop devices, but see below */
@media (min-width: 1536px)
{
.gradio-container {
@@ -125,6 +125,15 @@ body {
}
}
/* media rules in custom css don't appear to be applied in
gradio versions > 4.7, so we have to define a class which
we will manually need to add and remove using javascript.
Remove this once this is fixed in gradio.
*/
.gradio-container-size-full {
max-width: var(--size-full) !important;
}
.gradio-container .contain {
padding: 0 var(--size-4) !important;
}
@@ -182,6 +191,8 @@ footer {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
/* fix width and height of gallery items when on very large desktop screens, but see below */
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
@@ -193,6 +204,20 @@ footer {
max-height: 770px !important;
}
}
/* media rules in custom css don't appear to be applied in
gradio versions > 4.7, so we have to define classes which
we will manually need to add and remove using javascript.
Remove this once this is fixed in gradio.
*/
.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview {
min-height: calc(768px + 4px + var(--size-14)) !important;
max-height: calc(768px + 4px + var(--size-14)) !important;
}
.gallery-limit-height768 .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}
/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
@@ -244,6 +269,11 @@ footer {
padding-right: 8px;
}
/* number input value is out of range */
.value-out-of-range input[type="number"] {
color: red !important;
}
/* reduced animation load when generating */
.generating {
animation-play-state: paused !important;
@@ -280,7 +310,7 @@ footer {
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}

View File

@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -74,7 +77,7 @@ def img2img_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
resample_type: str,
@@ -141,9 +144,8 @@ def img2img_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -176,6 +178,7 @@ def img2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=stencils,
ondemand=ondemand,
)
@@ -228,6 +231,7 @@ def img2img_inf(
stencils=stencils,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -249,6 +253,7 @@ def img2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -592,20 +597,20 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
value="None",
choices=choices,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
value=512,
step=1,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
maximum=768,
value=512,
step=1,
step=8,
visible=False,
)
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=768,
value=512,
step=8,
visible=False,
)
make_canvas = gr.Button(
@@ -705,17 +710,17 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
maximum=1024,
maximum=768,
value=512,
step=1,
step=8,
visible=False,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
maximum=1024,
maximum=768,
value=512,
step=1,
step=8,
visible=False,
)
make_canvas = gr.Button(
@@ -806,28 +811,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
i2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
lora_weights = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
info=i2i_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -957,8 +959,6 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
elem_id="gallery",
columns=2,
object_fit="contain",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{i2i_model_info}\n"
@@ -1013,7 +1013,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
resample_type,
@@ -1054,3 +1054,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -21,7 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -109,7 +112,7 @@ def inpaint_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: int,
):
@@ -150,9 +153,8 @@ def inpaint_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -171,6 +173,7 @@ def inpaint_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -215,6 +218,7 @@ def inpaint_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -350,28 +354,25 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
inpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=inpaint_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -558,7 +559,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -622,3 +623,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -0,0 +1,49 @@
// workaround gradio after 4.7, not applying any @media rules from the custom .css file
() => {
console.log(`innerWidth: ${window.innerWidth}` )
// 1536px rules
const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')
function handleWidth1536(event) {
// display in full width for desktop devices
document.querySelectorAll(".gradio-container")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gradio-container-size-full");
} else {
node.classList.remove("gradio-container-size-full")
}
});
}
mediaQuery1536.addEventListener("change", handleWidth1536);
mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));
// 1921px rules
const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')
function handleWidth1921(event) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
/* Limit height to 768px_height + 2px_margin_height for the thumbnails */
document.querySelectorAll("#gallery")
.forEach( (node) => {
if (event.matches) {
node.classList.add("gallery-force-height768");
node.classList.add("gallery-limit-height768");
} else {
node.classList.remove("gallery-force-height768");
node.classList.remove("gallery-limit-height768");
}
});
}
mediaQuery1921.addEventListener("change", handleWidth1921);
mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));
}

View File

@@ -238,9 +238,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
max_length,
training_images_dir,
output_loc,
get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
),
get_custom_vae_or_lora_weights(lora_weights, "lora"),
],
outputs=[std_output],
show_progress="minimal" if args.progress_bar else "none",

View File

@@ -4,7 +4,10 @@ import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -60,7 +63,7 @@ def outpaint_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -100,9 +103,8 @@ def outpaint_inf(
if custom_vae != "None":
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
@@ -121,6 +123,7 @@ def outpaint_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -163,6 +166,7 @@ def outpaint_inf(
args.use_base_vae,
args.use_tuned,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -296,28 +300,25 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
outpaint_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=outpaint_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -469,8 +470,6 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
elem_id="gallery",
columns=[2],
object_fit="contain",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{outpaint_model_info}\n"
@@ -527,7 +526,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -556,3 +555,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -92,8 +92,6 @@ with gr.Blocks() as outputgallery_web:
visible=False,
show_label=True,
columns=4,
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
with gr.Column(scale=4):

View File

@@ -177,8 +177,6 @@ def chat(
)
_extra_args = _extra_args + [
"--iree-global-opt-enable-quantized-matmul-reassociation",
"--iree-llvmcpu-enable-quantized-matmul-reassociation",
"--iree-opt-const-eval=false",
"--iree-opt-data-tiling=false",
]

View File

@@ -15,7 +15,10 @@ from apps.stable_diffusion.web.ui.utils import (
cancel_sd,
set_model_default_configs,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -59,7 +62,7 @@ def txt2img_sdxl_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -105,9 +108,8 @@ def txt2img_sdxl_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -123,6 +125,7 @@ def txt2img_sdxl_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=None,
ondemand=ondemand,
)
@@ -150,6 +153,7 @@ def txt2img_sdxl_inf(
if args.hf_model_id
else "stabilityai/stable-diffusion-xl-base-1.0"
)
global_obj.set_schedulers(get_schedulers(model_id))
scheduler_obj = global_obj.get_scheduler(scheduler)
if global_obj.get_cfg_obj().ondemand:
@@ -171,6 +175,7 @@ def txt2img_sdxl_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
use_quantize=args.use_quantize,
ondemand=global_obj.get_cfg_obj().ondemand,
)
@@ -276,12 +281,14 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
label=f"VAE Models",
info=t2i_sdxl_vae_info,
elem_id="custom_model",
value="None",
value="madebyollin/sdxl-vae-fp16-fix",
choices=[
None,
"madebyollin/sdxl-vae-fp16-fix",
]
+ get_custom_model_files("vae"),
+ get_custom_model_files(
"vae", custom_checkpoint_type="sdxl"
),
allow_custom_value=True,
scale=4,
)
@@ -316,28 +323,25 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_sdxl_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_sdxl_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -374,7 +378,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
height = gr.Slider(
512,
1024,
value=1024,
value=768,
step=256,
label="Height",
visible=True,
@@ -383,7 +387,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
width = gr.Slider(
512,
1024,
value=1024,
value=768,
step=256,
label="Width",
visible=True,
@@ -478,8 +482,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
elem_id="gallery",
columns=[2],
object_fit="scale_down",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{t2i_sdxl_model_info}\n"
@@ -539,7 +541,7 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -609,7 +611,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
@@ -624,7 +625,6 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
height,
txt2img_sdxl_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
)
@@ -651,3 +651,11 @@ with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -18,7 +18,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -44,7 +47,7 @@ all_gradio_labels = [
"prompt",
"negative_prompt",
"lora_weights",
"lora_hf_id",
"lora_strength",
"scheduler",
"save_metadata_to_png",
"save_metadata_to_json",
@@ -91,7 +94,7 @@ def txt2img_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
use_hiresfix: bool,
@@ -138,9 +141,8 @@ def txt2img_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -156,6 +158,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -207,6 +210,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -256,6 +260,7 @@ def txt2img_inf(
width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -288,6 +293,7 @@ def txt2img_inf(
low_cpu_mem_usage=args.low_cpu_mem_usage,
debug=args.import_debug if args.import_mlir else False,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -311,6 +317,7 @@ def txt2img_inf(
cpu_scheduling,
args.max_embeddings_multiples,
stencils=[],
images=None,
control_mode=None,
resample_type=resample_type,
)
@@ -384,7 +391,7 @@ def load_settings():
loaded_settings.get("prompt", args.prompts[0]),
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
loaded_settings.get("lora_weights", "None"),
loaded_settings.get("lora_hf_id", ""),
loaded_settings.get("lora_strength", args.lora_strength),
loaded_settings.get("scheduler", args.scheduler),
loaded_settings.get(
"save_metadata_to_png", args.write_metadata_to_png
@@ -494,28 +501,25 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
t2i_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=t2i_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value=default_settings.get("lora_weights"),
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value=default_settings.get("lora_hf_id"),
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# number is checked on change so to allow 0.n values
# we have to allow 0 or you can't type 0.n in
minimum=0.0,
maximum=2.0,
value=default_settings.get("lora_strength"),
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -619,6 +623,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
step=1,
label=default_settings.get("batch_size"),
interactive=True,
visible=False,
)
repeatable_seeds = gr.Checkbox(
default_settings.get("repeatable_seeds"),
@@ -688,8 +693,6 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
elem_id="gallery",
columns=[2],
object_fit="contain",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{t2i_model_info}\n"
@@ -735,7 +738,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
@@ -768,7 +771,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
prompt,
negative_prompt,
lora_weights,
lora_hf_id,
lora_strength,
scheduler,
save_metadata_to_png,
save_metadata_to_json,
@@ -812,7 +815,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
use_hiresfix,
@@ -855,7 +858,6 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
custom_vae,
],
outputs=[
@@ -870,7 +872,7 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
height,
txt2img_custom_model,
lora_weights,
lora_hf_id,
lora_strength,
custom_vae,
],
)
@@ -901,3 +903,11 @@ with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -13,7 +13,10 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.common_ui_events import (
lora_changed,
lora_strength_changed,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -53,7 +56,7 @@ def upscaler_inf(
save_metadata_to_json: bool,
save_metadata_to_png: bool,
lora_weights: str,
lora_hf_id: str,
lora_strength: float,
ondemand: bool,
repeatable_seeds: bool,
):
@@ -100,9 +103,8 @@ def upscaler_inf(
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
args.use_lora = get_custom_vae_or_lora_weights(
lora_weights, lora_hf_id, "lora"
)
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
args.lora_strength = lora_strength
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
@@ -120,6 +122,7 @@ def upscaler_inf(
args.width,
device,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
stencils=[],
ondemand=ondemand,
)
@@ -159,6 +162,7 @@ def upscaler_inf(
args.use_tuned,
low_cpu_mem_usage=args.low_cpu_mem_usage,
use_lora=args.use_lora,
lora_strength=args.lora_strength,
ondemand=args.ondemand,
)
)
@@ -318,28 +322,25 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
)
with gr.Accordion(label="LoRA Options", open=False):
with gr.Row():
# janky fix for overflowing text
upscaler_lora_info = (
str(get_custom_model_path("lora"))
).replace("\\", "\n\\")
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
lora_weights = gr.Dropdown(
label=f"Standalone LoRA Weights",
info=upscaler_lora_info,
label=f"LoRA Weights",
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
elem_id="lora_weights",
value="None",
choices=["None"] + get_custom_model_files("lora"),
allow_custom_value=True,
scale=3,
)
lora_hf_id = gr.Textbox(
elem_id="lora_hf_id",
placeholder="Select 'None' in the Standalone LoRA "
"weights dropdown on the left if you want to use "
"a standalone HuggingFace model ID for LoRA here "
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
value="",
label="HuggingFace Model ID",
lines=3,
lora_strength = gr.Number(
label="LoRA Strength",
info="Will be baked into the .vmfb",
step=0.01,
# the number is validated on every change, so the minimum
# must be 0 or fractional values like 0.5 can't be typed in
minimum=0.0,
maximum=2.0,
value=args.lora_strength,
scale=1,
)
with gr.Row():
lora_tags = gr.HTML(
@@ -469,8 +470,6 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
elem_id="gallery",
columns=[2],
object_fit="contain",
# TODO: Re-enable download when fixed in Gradio
show_download_button=False,
)
std_output = gr.Textbox(
value=f"{upscaler_model_info}\n"
@@ -523,7 +522,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
save_metadata_to_json,
save_metadata_to_png,
lora_weights,
lora_hf_id,
lora_strength,
ondemand,
repeatable_seeds,
],
@@ -552,3 +551,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
outputs=[lora_tags],
queue=True,
)
lora_strength.change(
fn=lora_strength_changed,
inputs=lora_strength,
outputs=lora_strength,
queue=False,
show_progress=False,
)

View File

@@ -33,6 +33,7 @@ class Config:
width: int
device: str
use_lora: str
lora_strength: float
stencils: list[str]
ondemand: str # should this be expecting a bool instead?
@@ -180,14 +181,16 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
return sorted(ckpt_files, key=str.casefold)
def get_custom_vae_or_lora_weights(weights, hf_id, model):
use_weight = ""
if weights == "None" and not hf_id:
def get_custom_vae_or_lora_weights(weights, model):
if weights == "None":
use_weight = ""
elif not hf_id:
use_weight = get_custom_model_pathfile(weights, model)
else:
use_weight = hf_id
custom_weights = get_custom_model_pathfile(str(weights), model)
if os.path.isfile(custom_weights):
use_weight = custom_weights
else:
use_weight = weights
return use_weight
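The refactored helper resolves a single value instead of juggling a dropdown plus a separate HuggingFace textbox. Roughly, under the new logic (file and model names here are illustrative):

# "None" -> no weights at all
get_custom_vae_or_lora_weights("None", "lora")                 # ""
# A file that exists under the custom model path -> its full path
get_custom_vae_or_lora_weights("my-lora.safetensors", "lora")  # ".../models/lora/my-lora.safetensors"
# Anything else (e.g. a HuggingFace ID) is passed through unchanged
get_custom_vae_or_lora_weights("sayakpaul/sd-model-finetuned-lora-t4", "lora")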

View File

@@ -122,20 +122,26 @@ def find_vae_from_png_metadata(
def find_lora_from_png_metadata(
key: str, metadata: dict[str, str | int]
) -> tuple[str, str]:
lora_hf_id = ""
) -> tuple[str, float]:
lora_custom = ""
lora_strength = 1.0
if key in metadata:
lora_file = metadata[key]
split_metadata = metadata[key].split(":")
lora_file = split_metadata[0]
if len(split_metadata) == 2:
try:
lora_strength = float(split_metadata[1])
except ValueError:
pass
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
# If nothing matched, check for a vendor/hf_model_id style value
if not lora_custom and lora_file.count("/"):
lora_hf_id = lora_file
lora_custom = lora_file
# LoRA input is optional, should not print or throw an error if missing
return lora_custom, lora_hf_id
return lora_custom, lora_strength
def import_png_metadata(
@@ -150,7 +156,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_vae,
):
try:
@@ -160,9 +165,10 @@ def import_png_metadata(
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
"Model", metadata
)
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
"LoRA", metadata
)
(
custom_lora,
custom_lora_strength,
) = find_lora_from_png_metadata("LoRA", metadata)
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
negative_prompt = metadata["Negative prompt"]
@@ -177,12 +183,8 @@ def import_png_metadata(
elif "Model" in metadata and png_hf_model_id:
custom_model = png_hf_model_id
if "LoRA" in metadata and lora_custom_model:
custom_lora = lora_custom_model
hf_lora_id = ""
if "LoRA" in metadata and lora_hf_model_id:
if "LoRA" in metadata and not custom_lora:
custom_lora = "None"
hf_lora_id = lora_hf_model_id
if "VAE" in metadata and vae_custom_model:
custom_vae = vae_custom_model
@@ -215,6 +217,6 @@ def import_png_metadata(
height,
custom_model,
custom_lora,
hf_lora_id,
custom_lora_strength,
custom_vae,
)
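With this change the single "LoRA" metadata field carries both the weights reference and an optional strength, colon-separated. An illustrative round trip (the metadata dict is a made-up example):

metadata = {"LoRA": "my-lora.safetensors:0.75"}
lora_custom, lora_strength = find_lora_from_png_metadata("LoRA", metadata)
# lora_custom   -> the matching local LoRA if one is found; a bare
#                  "vendor/model" value falls through as a HuggingFace ID
# lora_strength -> 0.75 here; defaults to 1.0 when there is no ":<strength>"
#                  suffix, and a malformed suffix is silently ignored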

View File

@@ -6,8 +6,8 @@ requires = [
"numpy>=1.22.4",
"torch-mlir>=20230620.875",
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"iree-compiler==20231212.*",
"iree-runtime==20231212.*",
]
build-backend = "setuptools.build_meta"
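The `==20231212.*` specifier accepts any build published under that date's release series while rejecting everything else, which can be checked with the packaging library (the version numbers below are illustrative):

from packaging.specifiers import SpecifierSet

pin = SpecifierSet("==20231212.*")
print("20231212.123" in pin)  # True  - any build from the pinned series
print("20240109.200" in pin)  # False - newer nightlies are excluded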

View File

@@ -26,7 +26,7 @@ diffusers
accelerate
scipy
ftfy
gradio==4.8.0
gradio==4.12.0
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64

View File

@@ -11,8 +11,8 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
backend_deps = []
if "NO_BACKEND" in os.environ.keys():
backend_deps = [
"iree-compiler>=20221022.190",
"iree-runtime>=20221022.190",
"iree-compiler==20231212.*",
"iree-runtime==20231212.*",
]
setup(

View File

@@ -7,13 +7,13 @@
It checks the Python version installed and installs any required build
dependencies into a Python virtual environment.
If that environment does not exist, it creates it.
.PARAMETER update-src
pulls the latest version from git
.PARAMETER force
removes and recreates venv to force update of all dependencies
.EXAMPLE
.\setup_venv.ps1 --force
@@ -39,11 +39,11 @@ if ($arguments -eq "--force"){
Write-Host "deactivating..."
Deactivate
}
if (Test-Path .\shark.venv\) {
if (Test-Path .\shark1.venv\) {
Write-Host "removing and recreating venv..."
Remove-Item .\shark.venv -Force -Recurse
if (Test-Path .\shark.venv\) {
Remove-Item .\shark1.venv -Force -Recurse
if (Test-Path .\shark1.venv\) {
Write-Host 'could not remove .\shark1.venv - please try running ".\setup_venv.ps1 --force" again!'
exit 1
}
@@ -83,15 +83,15 @@ if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any
Write-Host "Installing Build Dependencies"
# make sure we really use 3.11 from list, even if it's not the default.
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
else {python -m venv .\shark.venv\}
.\shark.venv\Scripts\activate
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark1.venv\}
else {python -m venv .\shark1.venv\}
.\shark1.venv\Scripts\activate
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler==20231212.* iree-runtime==20231212.*
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
Write-Host "Source your venv with ./shark1.venv/Scripts/activate"

View File

@@ -5,7 +5,7 @@
# Environment variables used by the script.
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
# VENV_DIR=myshark.venv #create a venv called myshark.venv
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
# IMPORTER=1 #Install importer deps
# BENCHMARK=1 #Install benchmark deps
@@ -35,7 +35,7 @@ fi
if [[ "$SKIP_VENV" != "1" ]]; then
if [[ -z "${CONDA_PREFIX}" ]]; then
# Not a conda env. So create a new VENV dir
VENV_DIR=${VENV_DIR:-shark.venv}
VENV_DIR=${VENV_DIR:-shark1.venv}
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
@@ -111,7 +111,7 @@ else
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler==20231212.* iree-runtime==20231212.*
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi

View File

@@ -43,7 +43,6 @@ def get_iree_device_args(device, extra_args=[]):
get_iree_cpu_args()
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
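With the reassociation flag removed from the unconditional CPU argument list, it presumably has to be supplied per-compile where it is actually beneficial. A sketch of that pattern — only the flag string comes from this diff; the helper and its gating condition are hypothetical:

def quant_matmul_flags(device: str, model_name: str) -> list[str]:
    # Hypothetical gating: only CPU compiles of quantized llama2 get the flag.
    if device == "cpu" and "llama2" in model_name:
        return ["--iree-global-opt-enable-quantized-matmul-reassociation"]
    return []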

View File

@@ -139,8 +139,16 @@ def get_vulkan_target_triple(device_name):
triple = f"rdna3-780m-{system_os}"
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
triple = f"rdna3-w7900-{system_os}"
elif any(x in device_name for x in ("AMD", "Radeon")):
elif "7600" in device_name:
triple = f"rdna3-7600-{system_os}"
elif "7700" in device_name:
triple = f"rdna3-7700-{system_os}"
elif any(x in device_name for x in "AMD", "Radeon") and any(x in device_name for x in "6700", "6750"):
triple = f"rdna2-unknown-{system_os}"
elif any(x in device_name for x in "AMD", "Radeon") and "7" in device_name:
triple = f"rdna3-unknown-{system_os}"
elif any(x in device_name for x in "AMD", "Radeon"):
triple = f"rdna3-unknown-{system_os}"
# Intel Targets
elif any(x in device_name for x in ("A770", "A750")):
triple = f"arc-770-{system_os}"