Mirror of https://github.com/nod-ai/SHARK-Studio.git (synced 2026-01-12 23:38:12 -05:00)

Compare commits: 69 commits
Commits (SHA1):
fd3fdeef07, 80906155f8, 1c094d96eb, 94805aa3c6, cfb494edff, 10294bef33, 8051b33005, 10af1f25c7, 4305153d21, ab5639ff66,
4cf0383768, 01845d593d, 5b15ceee35, 04b21295ee, dda7e8a163, 7fdd1952ae, 0a6f6fad86, 6853a33728, 3887d83f5d, 8d9b5b3afa,
16c03e4b44, 17dab8334d, f692a012e1, 3cc643b2de, bf70e80d20, 7159698496, 7e12d1782a, bb5f133e1c, 3af0c6c658, 3322b7264f,
eeb7bdd143, 2d6f48821d, c74b55f24e, 1a723645fb, dfdd3b1f78, 6384780d16, db0c53ae59, ce9ce3a7c8, d72da3801f, 9c50edc664,
a1b7110550, ff15fd74f6, 552b2c3ee3, 795fc33001, 2910841fe6, 396a054856, 5c66948d4f, ed3dda94c0, d31d28b082, 78c607e1d3,
666e601dd9, ca58908e5b, 1f5b39f56e, 2da31c4109, da50a16242, ce38d49f05, 2f780f0d38, d051c3a4a7, 1b11c82c9d, 80a33d427f,
4125a26294, 905d0103ff, 192b3b2c61, 8f9adc4a2a, 70817bb50a, dd37c26d36, a708879c6c, bb1b49eb6f, f6d41affd9
.github/workflows/nightly.yml (vendored): 77 changed lines
@@ -74,80 +74,3 @@ jobs:
           GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
         with:
           release_id: ${{ steps.create_release.outputs.id }}
-
-  linux-build:
-
-    runs-on: a100
-    strategy:
-      fail-fast: false
-      matrix:
-        python-version: ["3.11"]
-        backend: [IREE, SHARK]
-
-    steps:
-      - uses: actions/checkout@v3
-      - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
-        with:
-          python-version: ${{ matrix.python-version }}
-
-      - name: Setup pip cache
-        uses: actions/cache@v3
-        with:
-          path: ~/.cache/pip
-          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
-          restore-keys: |
-            ${{ runner.os }}-pip-
-
-      - name: Install dependencies
-        run: |
-          echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
-          python -m pip install --upgrade pip
-          python -m pip install flake8 pytest toml
-          if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
-      - name: Lint with flake8
-        run: |
-          # stop the build if there are Python syntax errors or undefined names
-          flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
-          # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-          flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
-      - name: Build and validate the IREE package
-        if: ${{ matrix.backend == 'IREE' }}
-        continue-on-error: true
-        run: |
-          cd $GITHUB_WORKSPACE
-          USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
-          source iree.venv/bin/activate
-          package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
-          SHARK_PACKAGE_VERSION=${package_version} \
-          pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
-          # Install the built wheel
-          pip install ./wheelhouse/nodai*
-          # Validate the Models
-          /bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
-          pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
-          tail -n 1 |
-          tee -a pytest_results.txt
-          if !(grep -Fxq " failed" pytest_results.txt)
-          then
-            export SHA=$(git log -1 --format='%h')
-            gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
-            gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
-          fi
-          rm -rf ./wheelhouse/nodai*
-
-      - name: Build and validate the SHARK Runtime package
-        if: ${{ matrix.backend == 'SHARK' }}
-        run: |
-          cd $GITHUB_WORKSPACE
-          ./setup_venv.sh
-          source shark.venv/bin/activate
-          package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
-          SHARK_PACKAGE_VERSION=${package_version} \
-          pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
-          # Install the built wheel
-          pip install ./wheelhouse/nodai*
-          # Validate the Models
-          pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
-          tail -n 1 |
-          tee -a pytest_results.txt
.github/workflows/test-models.yml (vendored): 12 changed lines
@@ -112,7 +112,7 @@ jobs:
           cd $GITHUB_WORKSPACE
           PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
           source shark.venv/bin/activate
-          pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
+          pytest --benchmark=native --update_tank -k cpu
           gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
           gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
           python build_tools/vicuna_testing.py
@@ -121,9 +121,9 @@ jobs:
         if: matrix.suite == 'cuda'
         run: |
           cd $GITHUB_WORKSPACE
-          PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
+          PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
           source shark.venv/bin/activate
-          pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
+          pytest --benchmark=native --update_tank -k cuda
           gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
           gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
           # Disabled due to black image bug
@@ -144,10 +144,10 @@ jobs:
         if: matrix.suite == 'vulkan' && matrix.os == 'a100'
         run: |
           cd $GITHUB_WORKSPACE
-          PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
+          PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
           source shark.venv/bin/activate
-          pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
-          python build_tools/stable_diffusion_testing.py --device=vulkan
+          pytest --update_tank -k vulkan
+          python build_tools/stable_diffusion_testing.py --device=vulkan --no-exit_on_fail

       - name: Validate Vulkan Models (Windows)
         if: matrix.suite == 'vulkan' && matrix.os == '7950x'
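The hunks above drop the --forked, --ci, --ci_sha, and explicit --tank_url flags from the pytest invocations and standardize on IMPORTER=1 ./setup_venv.sh plus --update_tank. These are project-specific pytest options, so they have to be registered in a conftest.py; the sketch below only illustrates that mechanism. The option names come from the workflow above, while the defaults and help strings are assumptions rather than SHARK's actual conftest.

# Hypothetical conftest.py fragment (assumed, for illustration only).
def pytest_addoption(parser):
    # Option names mirror the flags used in the workflow above.
    parser.addoption("--benchmark", action="store", default=None,
                     help="run benchmarks with the given backend, e.g. 'native'")
    parser.addoption("--update_tank", action="store_true",
                     help="refresh cached model artifacts before running tests")
    parser.addoption("--ci", action="store_true",
                     help="enable CI-only behavior such as result uploads")
    parser.addoption("--ci_sha", action="store", default="latest",
                     help="short git SHA used to tag CI results")
    parser.addoption("--tank_url", action="store", default="gs://shark_tank/nightly/",
                     help="bucket to pull model artifacts from")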
@@ -25,7 +25,7 @@ from apps.stable_diffusion.src import args

 # Brevitas
 from typing import List, Tuple
-from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.common.generative.quantize import quantize_model
 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


@@ -101,7 +101,7 @@ class H2OGPTModel(torch.nn.Module):
            dtype=torch.float32,
            weight_bit_width=weight_bit_width,
            weight_param_method="stats",
-           weight_scale_precision="float",
+           weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=128,
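Both hunks in this file make the same two changes that recur throughout this comparison: the Brevitas quantize_model import moves from brevitas_examples.llm.llm_quant.quantize to brevitas_examples.common.generative.quantize, and weight_scale_precision changes from "float" to "float_scale". Below is a hedged sketch of the updated call, using only the keyword arguments visible in the hunks above; the model object, the bit width value, and the assumption that the model is the first positional argument are illustrative, and the exact signature may differ between Brevitas versions.

# Illustrative only; placeholder model and bit width, signature not verified here.
import torch
from transformers import AutoModelForCausalLM
from brevitas_examples.common.generative.quantize import quantize_model  # new import path

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder model
quantize_model(
    model,                                   # assumed first positional argument
    dtype=torch.float32,
    weight_bit_width=4,                      # placeholder, e.g. int4 weights
    weight_param_method="stats",
    weight_scale_precision="float_scale",    # was "float" before this changeset
    weight_quant_type="asym",
    weight_quant_granularity="per_group",
    weight_group_size=128,
)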
File diff suppressed because it is too large
@@ -69,91 +69,7 @@ class CompiledLMHeadEmbeddingLayer(torch.nn.Module):
|
||||
return torch.tensor(new_hidden_states)
|
||||
|
||||
|
||||
class DecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
output = self.model.forward(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
return (output[0], output[1][0], output[1][1])
|
||||
|
||||
|
||||
class CompiledDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import get_vmfb_from_path
|
||||
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
print("vmfb path for layer: ", self.falcon_vmfb_path)
|
||||
self.model = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=self.device_index,
|
||||
)
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
new_hidden_states, pkv1, pkv2 = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
return tuple(
|
||||
[
|
||||
torch.tensor(new_hidden_states),
|
||||
tuple(
|
||||
[
|
||||
torch.tensor(pkv1),
|
||||
torch.tensor(pkv2),
|
||||
]
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class EightDecoderLayer(torch.nn.Module):
|
||||
class FourWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
@@ -175,163 +91,78 @@ class EightDecoderLayer(torch.nn.Module):
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
if self.falcon_variant == "7b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
)
|
||||
elif self.falcon_variant == "40b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledEightDecoderLayer(torch.nn.Module):
|
||||
class CompiledFourWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision
|
||||
self, layer_id, device_idx, falcon_variant, device, precision, model
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
@@ -339,12 +170,14 @@ class CompiledEightDecoderLayer(torch.nn.Module):
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
alibi: torch.Tensor = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
@@ -354,24 +187,12 @@ class CompiledEightDecoderLayer(torch.nn.Module):
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import get_vmfb_from_path
|
||||
|
||||
self.falcon_vmfb_path = Path(
|
||||
f"falcon_{self.falcon_variant}_layer_{self.layer_id}_{self.precision}_{self.device}.vmfb"
|
||||
)
|
||||
print("vmfb path for layer: ", self.falcon_vmfb_path)
|
||||
self.model = get_vmfb_from_path(
|
||||
self.falcon_vmfb_path,
|
||||
self.device,
|
||||
"linalg",
|
||||
device_id=self.device_index,
|
||||
)
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.detach().numpy()
|
||||
attention_mask = attention_mask.to(torch.float32).detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
@@ -383,196 +204,452 @@ class CompiledEightDecoderLayer(torch.nn.Module):
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
del self.model
|
||||
|
||||
if self.falcon_variant == "7b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class TwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(self, decoder_layer_model, falcon_variant):
|
||||
super().__init__()
|
||||
self.model = decoder_layer_model
|
||||
self.falcon_variant = falcon_variant
|
||||
|
||||
def forward(self, hidden_states, attention_mask):
|
||||
new_pkvs = []
|
||||
for layer in self.model:
|
||||
outputs = layer(
|
||||
hidden_states=hidden_states,
|
||||
alibi=None,
|
||||
attention_mask=attention_mask,
|
||||
use_cache=True,
|
||||
)
|
||||
elif self.falcon_variant == "40b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
hidden_states = outputs[0]
|
||||
new_pkvs.append(
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
)
|
||||
elif self.falcon_variant == "180b":
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
outputs[-1][0],
|
||||
outputs[-1][1],
|
||||
)
|
||||
)
|
||||
|
||||
(
|
||||
(new_pkv00, new_pkv01),
|
||||
(new_pkv10, new_pkv11),
|
||||
(new_pkv20, new_pkv21),
|
||||
(new_pkv30, new_pkv31),
|
||||
(new_pkv40, new_pkv41),
|
||||
(new_pkv50, new_pkv51),
|
||||
(new_pkv60, new_pkv61),
|
||||
(new_pkv70, new_pkv71),
|
||||
(new_pkv80, new_pkv81),
|
||||
(new_pkv90, new_pkv91),
|
||||
(new_pkv100, new_pkv101),
|
||||
(new_pkv110, new_pkv111),
|
||||
(new_pkv120, new_pkv121),
|
||||
(new_pkv130, new_pkv131),
|
||||
(new_pkv140, new_pkv141),
|
||||
(new_pkv150, new_pkv151),
|
||||
(new_pkv160, new_pkv161),
|
||||
(new_pkv170, new_pkv171),
|
||||
(new_pkv180, new_pkv181),
|
||||
(new_pkv190, new_pkv191),
|
||||
(new_pkv200, new_pkv201),
|
||||
(new_pkv210, new_pkv211),
|
||||
(new_pkv220, new_pkv221),
|
||||
(new_pkv230, new_pkv231),
|
||||
(new_pkv240, new_pkv241),
|
||||
(new_pkv250, new_pkv251),
|
||||
(new_pkv260, new_pkv261),
|
||||
(new_pkv270, new_pkv271),
|
||||
(new_pkv280, new_pkv281),
|
||||
(new_pkv290, new_pkv291),
|
||||
(new_pkv300, new_pkv301),
|
||||
(new_pkv310, new_pkv311),
|
||||
(new_pkv320, new_pkv321),
|
||||
(new_pkv330, new_pkv331),
|
||||
(new_pkv340, new_pkv341),
|
||||
(new_pkv350, new_pkv351),
|
||||
(new_pkv360, new_pkv361),
|
||||
(new_pkv370, new_pkv371),
|
||||
(new_pkv380, new_pkv381),
|
||||
(new_pkv390, new_pkv391),
|
||||
) = new_pkvs
|
||||
result = (
|
||||
hidden_states,
|
||||
new_pkv00,
|
||||
new_pkv01,
|
||||
new_pkv10,
|
||||
new_pkv11,
|
||||
new_pkv20,
|
||||
new_pkv21,
|
||||
new_pkv30,
|
||||
new_pkv31,
|
||||
new_pkv40,
|
||||
new_pkv41,
|
||||
new_pkv50,
|
||||
new_pkv51,
|
||||
new_pkv60,
|
||||
new_pkv61,
|
||||
new_pkv70,
|
||||
new_pkv71,
|
||||
new_pkv80,
|
||||
new_pkv81,
|
||||
new_pkv90,
|
||||
new_pkv91,
|
||||
new_pkv100,
|
||||
new_pkv101,
|
||||
new_pkv110,
|
||||
new_pkv111,
|
||||
new_pkv120,
|
||||
new_pkv121,
|
||||
new_pkv130,
|
||||
new_pkv131,
|
||||
new_pkv140,
|
||||
new_pkv141,
|
||||
new_pkv150,
|
||||
new_pkv151,
|
||||
new_pkv160,
|
||||
new_pkv161,
|
||||
new_pkv170,
|
||||
new_pkv171,
|
||||
new_pkv180,
|
||||
new_pkv181,
|
||||
new_pkv190,
|
||||
new_pkv191,
|
||||
new_pkv200,
|
||||
new_pkv201,
|
||||
new_pkv210,
|
||||
new_pkv211,
|
||||
new_pkv220,
|
||||
new_pkv221,
|
||||
new_pkv230,
|
||||
new_pkv231,
|
||||
new_pkv240,
|
||||
new_pkv241,
|
||||
new_pkv250,
|
||||
new_pkv251,
|
||||
new_pkv260,
|
||||
new_pkv261,
|
||||
new_pkv270,
|
||||
new_pkv271,
|
||||
new_pkv280,
|
||||
new_pkv281,
|
||||
new_pkv290,
|
||||
new_pkv291,
|
||||
new_pkv300,
|
||||
new_pkv301,
|
||||
new_pkv310,
|
||||
new_pkv311,
|
||||
new_pkv320,
|
||||
new_pkv321,
|
||||
new_pkv330,
|
||||
new_pkv331,
|
||||
new_pkv340,
|
||||
new_pkv341,
|
||||
new_pkv350,
|
||||
new_pkv351,
|
||||
new_pkv360,
|
||||
new_pkv361,
|
||||
new_pkv370,
|
||||
new_pkv371,
|
||||
new_pkv380,
|
||||
new_pkv381,
|
||||
new_pkv390,
|
||||
new_pkv391,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class CompiledTwoWayShardingDecoderLayer(torch.nn.Module):
|
||||
def __init__(
|
||||
self, layer_id, device_idx, falcon_variant, device, precision, model
|
||||
):
|
||||
super().__init__()
|
||||
self.layer_id = layer_id
|
||||
self.device_index = device_idx
|
||||
self.falcon_variant = falcon_variant
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False,
|
||||
):
|
||||
import gc
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
if self.model is None:
|
||||
raise ValueError("Layer vmfb not found")
|
||||
|
||||
hidden_states = hidden_states.to(torch.float32).detach().numpy()
|
||||
attention_mask = attention_mask.to(torch.float32).detach().numpy()
|
||||
|
||||
if alibi is not None or layer_past is not None:
|
||||
raise ValueError("Past Key Values and alibi should be None")
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported Falcon variant: ", self.falcon_variant
|
||||
output = self.model(
|
||||
"forward",
|
||||
(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
),
|
||||
)
|
||||
|
||||
result = (
|
||||
torch.tensor(output[0]),
|
||||
(
|
||||
torch.tensor(output[1]),
|
||||
torch.tensor(output[2]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[3]),
|
||||
torch.tensor(output[4]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[5]),
|
||||
torch.tensor(output[6]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[7]),
|
||||
torch.tensor(output[8]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[9]),
|
||||
torch.tensor(output[10]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[11]),
|
||||
torch.tensor(output[12]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[13]),
|
||||
torch.tensor(output[14]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[15]),
|
||||
torch.tensor(output[16]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[17]),
|
||||
torch.tensor(output[18]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[19]),
|
||||
torch.tensor(output[20]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[21]),
|
||||
torch.tensor(output[22]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[23]),
|
||||
torch.tensor(output[24]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[25]),
|
||||
torch.tensor(output[26]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[27]),
|
||||
torch.tensor(output[28]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[29]),
|
||||
torch.tensor(output[30]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[31]),
|
||||
torch.tensor(output[32]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[33]),
|
||||
torch.tensor(output[34]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[35]),
|
||||
torch.tensor(output[36]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[37]),
|
||||
torch.tensor(output[38]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[39]),
|
||||
torch.tensor(output[40]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[41]),
|
||||
torch.tensor(output[42]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[43]),
|
||||
torch.tensor(output[44]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[45]),
|
||||
torch.tensor(output[46]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[47]),
|
||||
torch.tensor(output[48]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[49]),
|
||||
torch.tensor(output[50]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[51]),
|
||||
torch.tensor(output[52]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[53]),
|
||||
torch.tensor(output[54]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[55]),
|
||||
torch.tensor(output[56]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[57]),
|
||||
torch.tensor(output[58]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[59]),
|
||||
torch.tensor(output[60]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[61]),
|
||||
torch.tensor(output[62]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[63]),
|
||||
torch.tensor(output[64]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[65]),
|
||||
torch.tensor(output[66]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[67]),
|
||||
torch.tensor(output[68]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[69]),
|
||||
torch.tensor(output[70]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[71]),
|
||||
torch.tensor(output[72]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[73]),
|
||||
torch.tensor(output[74]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[75]),
|
||||
torch.tensor(output[76]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[77]),
|
||||
torch.tensor(output[78]),
|
||||
),
|
||||
(
|
||||
torch.tensor(output[79]),
|
||||
torch.tensor(output[80]),
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Any
 from transformers import StoppingCriteria


-from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.common.generative.quantize import quantize_model
 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl


@@ -37,7 +37,7 @@ class VisionModel(torch.nn.Module):
            dtype=torch.float32,
            weight_bit_width=weight_bit_width,
            weight_param_method="stats",
-           weight_scale_precision="float",
+           weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=weight_group_size,
@@ -52,7 +52,7 @@ class VisionModel(torch.nn.Module):
            dtype=torch.float32,
            weight_bit_width=weight_bit_width,
            weight_param_method="stats",
-           weight_scale_precision="float",
+           weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=weight_group_size,
@@ -93,7 +93,7 @@ class FirstLlamaModel(torch.nn.Module):
            dtype=torch.float32,
            weight_bit_width=weight_bit_width,
            weight_param_method="stats",
-           weight_scale_precision="float",
+           weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=weight_group_size,
@@ -157,7 +157,7 @@ class SecondLlamaModel(torch.nn.Module):
            dtype=torch.float32,
            weight_bit_width=weight_bit_width,
            weight_param_method="stats",
-           weight_scale_precision="float",
+           weight_scale_precision="float_scale",
            weight_quant_type="asym",
            weight_quant_granularity="per_group",
            weight_group_size=weight_group_size,

@@ -24,7 +24,9 @@ class FirstVicuna(torch.nn.Module):
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -36,7 +38,7 @@ class FirstVicuna(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -79,7 +81,9 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -91,7 +95,7 @@ class SecondVicuna7B(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -329,7 +333,9 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -341,7 +347,7 @@ class SecondVicuna13B(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -627,7 +633,9 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -639,7 +647,7 @@ class SecondVicuna70B(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
|
||||
@@ -24,7 +24,9 @@ class FirstVicunaGPU(torch.nn.Module):
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -36,7 +38,7 @@ class FirstVicunaGPU(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -78,7 +80,9 @@ class SecondVicuna7BGPU(torch.nn.Module):
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -90,7 +94,7 @@ class SecondVicuna7BGPU(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -327,7 +331,9 @@ class SecondVicuna13BGPU(torch.nn.Module):
|
||||
torch.float32 if accumulates == "fp32" else torch.float16
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -339,7 +345,7 @@ class SecondVicuna13BGPU(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
@@ -625,7 +631,9 @@ class SecondVicuna70BGPU(torch.nn.Module):
|
||||
)
|
||||
print(f"[DEBUG] model_path : {model_path}")
|
||||
if precision in ["int4", "int8"]:
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import (
|
||||
quantize_model,
|
||||
)
|
||||
from brevitas_examples.llm.llm_quant.run_utils import (
|
||||
get_model_impl,
|
||||
)
|
||||
@@ -637,7 +645,7 @@ class SecondVicuna70BGPU(torch.nn.Module):
|
||||
dtype=self.accumulates,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import time
|
||||
|
||||
|
||||
class FirstVicunaLayer(torch.nn.Module):
|
||||
@@ -66,7 +67,6 @@ class ShardedVicunaModel(torch.nn.Module):
|
||||
def __init__(self, model, layers, lmhead, embedding, norm):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
# assert len(layers) == len(model.model.layers)
|
||||
self.model.model.config.use_cache = True
|
||||
self.model.model.config.output_attentions = False
|
||||
self.layers = layers
|
||||
@@ -110,9 +110,11 @@ class LMHeadCompiled(torch.nn.Module):
|
||||
self.model = shark_module
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.detach()
|
||||
hidden_states_sample = hidden_states.detach()
|
||||
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -136,8 +138,9 @@ class VicunaNormCompiled(torch.nn.Module):
|
||||
hidden_states.detach()
|
||||
except:
|
||||
pass
|
||||
output = self.model("forward", (hidden_states,))
|
||||
output = self.model("forward", (hidden_states,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
@@ -158,15 +161,18 @@ class VicunaEmbeddingCompiled(torch.nn.Module):
|
||||
|
||||
def forward(self, input_ids):
|
||||
input_ids.detach()
|
||||
output = self.model("forward", (input_ids,))
|
||||
output = self.model("forward", (input_ids,), send_to_host=True)
|
||||
output = torch.tensor(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class CompiledVicunaLayer(torch.nn.Module):
|
||||
def __init__(self, shark_module):
|
||||
def __init__(self, shark_module, idx, breakpoints):
|
||||
super().__init__()
|
||||
self.model = shark_module
|
||||
self.idx = idx
|
||||
self.breakpoints = breakpoints
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -177,10 +183,11 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
output_attentions=False,
|
||||
use_cache=True,
|
||||
):
|
||||
if self.breakpoints is None:
|
||||
is_breakpoint = False
|
||||
else:
|
||||
is_breakpoint = self.idx + 1 in self.breakpoints
|
||||
if past_key_value is None:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
output = self.model(
|
||||
"first_vicuna_forward",
|
||||
(
|
||||
@@ -188,11 +195,17 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
attention_mask,
|
||||
position_ids,
|
||||
),
|
||||
send_to_host=is_breakpoint,
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
if is_breakpoint:
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
else:
|
||||
output0 = output[0]
|
||||
output1 = output[1]
|
||||
output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
@@ -202,11 +215,8 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
),
|
||||
)
|
||||
else:
|
||||
hidden_states = hidden_states.detach()
|
||||
attention_mask = attention_mask.detach()
|
||||
position_ids = position_ids.detach()
|
||||
pkv0 = past_key_value[0].detach()
|
||||
pkv1 = past_key_value[1].detach()
|
||||
pkv0 = past_key_value[0]
|
||||
pkv1 = past_key_value[1]
|
||||
output = self.model(
|
||||
"second_vicuna_forward",
|
||||
(
|
||||
@@ -216,11 +226,17 @@ class CompiledVicunaLayer(torch.nn.Module):
|
||||
pkv0,
|
||||
pkv1,
|
||||
),
|
||||
send_to_host=is_breakpoint,
|
||||
)
|
||||
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
if is_breakpoint:
|
||||
output0 = torch.tensor(output[0])
|
||||
output1 = torch.tensor(output[1])
|
||||
output2 = torch.tensor(output[2])
|
||||
else:
|
||||
output0 = output[0]
|
||||
output1 = output[1]
|
||||
output2 = output[2]
|
||||
|
||||
return (
|
||||
output0,
|
||||
|
||||
@@ -6,10 +6,10 @@ from apps.language_models.src.model_wrappers.falcon_sharded_model import (
|
||||
CompiledLNFEmbeddingLayer,
|
||||
LMHeadEmbeddingLayer,
|
||||
CompiledLMHeadEmbeddingLayer,
|
||||
DecoderLayer,
|
||||
EightDecoderLayer,
|
||||
CompiledDecoderLayer,
|
||||
CompiledEightDecoderLayer,
|
||||
FourWayShardingDecoderLayer,
|
||||
TwoWayShardingDecoderLayer,
|
||||
CompiledFourWayShardingDecoderLayer,
|
||||
CompiledTwoWayShardingDecoderLayer,
|
||||
ShardedFalconModel,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
@@ -94,6 +94,13 @@ parser.add_argument(
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Run model as sharded",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_shards",
|
||||
type=int,
|
||||
default=4,
|
||||
choices=[2, 4],
|
||||
help="Number of shards.",
|
||||
)
|
||||
|
||||
|
||||
class ShardedFalcon(SharkLLMBase):
|
||||
@@ -122,6 +129,10 @@ class ShardedFalcon(SharkLLMBase):
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
|
||||
if args.sharded and "180b" not in self.model_name:
|
||||
raise ValueError("Sharding supported only for Falcon-180B")
|
||||
|
||||
self.hf_auth_token = hf_auth_token
|
||||
self.max_padding_length = 100
|
||||
self.device = device
|
||||
@@ -131,7 +142,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.src_model = self.get_src_model()
|
||||
self.shark_model = self.compile(compressed=args.compressed)
|
||||
self.shark_model = self.compile()
|
||||
|
||||
def get_tokenizer(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
@@ -146,20 +157,17 @@ class ShardedFalcon(SharkLLMBase):
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"torch_dtype": torch.float32,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile_layer(
|
||||
@@ -225,7 +233,7 @@ class ShardedFalcon(SharkLLMBase):
|
||||
elif layer_id in ["ln_f", "lm_head"]:
|
||||
f16_input_mask = [True]
|
||||
elif "_" in layer_id or type(layer_id) == int:
|
||||
f16_input_mask = [True, False]
|
||||
f16_input_mask = [True, True]
|
||||
else:
|
||||
raise ValueError("Unsupported layer: ", layer_id)
|
||||
|
||||
@@ -288,28 +296,16 @@ class ShardedFalcon(SharkLLMBase):
|
||||
|
||||
return shark_module, device_idx
|
||||
|
||||
def compile(self, compressed=False):
|
||||
def compile(self):
|
||||
sample_input_ids = torch.zeros([100], dtype=torch.int64)
|
||||
sample_attention_mask = torch.zeros(
|
||||
[1, 1, 100, 100], dtype=torch.float32
|
||||
)
|
||||
num_group_layers = 1
|
||||
if "7b" in self.model_name:
|
||||
num_in_features = 4544
|
||||
if compressed:
|
||||
num_group_layers = 8
|
||||
elif "40b" in self.model_name:
|
||||
num_in_features = 8192
|
||||
if compressed:
|
||||
num_group_layers = 15
|
||||
else:
|
||||
num_in_features = 14848
|
||||
sample_attention_mask = sample_attention_mask.to(dtype=torch.bool)
|
||||
if compressed:
|
||||
num_group_layers = 20
|
||||
|
||||
num_group_layers = int(
|
||||
20 * (4 / args.num_shards)
|
||||
) # 4 is the number of default shards
|
||||
sample_hidden_states = torch.zeros(
|
||||
[1, 100, num_in_features], dtype=torch.float32
|
||||
[1, 100, 14848], dtype=torch.float32
|
||||
)
|
||||
|
||||
# Determine number of available devices
|
||||
@@ -319,6 +315,10 @@ class ShardedFalcon(SharkLLMBase):
|
||||
|
||||
haldriver = ireert.get_driver(self.device)
|
||||
num_devices = len(haldriver.query_available_devices())
|
||||
if num_devices < 2:
|
||||
raise ValueError(
|
||||
"Cannot run Falcon-180B on a single ROCM device."
|
||||
)
|
||||
|
||||
lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head)
|
||||
print("Compiling Layer lm_head")
|
||||
@@ -326,7 +326,9 @@ class ShardedFalcon(SharkLLMBase):
|
||||
lm_head,
|
||||
[sample_hidden_states],
|
||||
"lm_head",
|
||||
device_idx=0 % num_devices if self.device == "rocm" else None,
|
||||
device_idx=(0 % num_devices) % args.num_shards
|
||||
if self.device == "rocm"
|
||||
else None,
|
||||
)
|
||||
shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head)
|
||||
|
||||
@@ -338,7 +340,9 @@ class ShardedFalcon(SharkLLMBase):
|
||||
word_embedding,
|
||||
[sample_input_ids],
|
||||
"word_embeddings",
|
||||
device_idx=1 % num_devices if self.device == "rocm" else None,
|
||||
device_idx=(1 % num_devices) % args.num_shards
|
||||
if self.device == "rocm"
|
||||
else None,
|
||||
)
|
||||
shark_word_embedding = CompiledWordEmbeddingsLayer(
|
||||
shark_word_embedding
|
||||
@@ -350,7 +354,9 @@ class ShardedFalcon(SharkLLMBase):
|
||||
ln_f,
|
||||
[sample_hidden_states],
|
||||
"ln_f",
|
||||
device_idx=2 % num_devices if self.device == "rocm" else None,
|
||||
device_idx=(2 % num_devices) % args.num_shards
|
||||
if self.device == "rocm"
|
||||
else None,
|
||||
)
|
||||
shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f)
|
||||
|
||||
@@ -360,24 +366,21 @@ class ShardedFalcon(SharkLLMBase):
|
||||
):
|
||||
device_idx = i % num_devices if self.device == "rocm" else None
|
||||
layer_id = i
|
||||
pytorch_class = DecoderLayer
|
||||
compiled_class = CompiledDecoderLayer
|
||||
if compressed:
|
||||
layer_id = (
|
||||
str(i * num_group_layers)
|
||||
+ "_"
|
||||
+ str((i + 1) * num_group_layers)
|
||||
)
|
||||
pytorch_class = EightDecoderLayer
|
||||
compiled_class = CompiledEightDecoderLayer
|
||||
layer_id = (
|
||||
str(i * num_group_layers)
|
||||
+ "_"
|
||||
+ str((i + 1) * num_group_layers)
|
||||
)
|
||||
pytorch_class = FourWayShardingDecoderLayer
|
||||
compiled_class = CompiledFourWayShardingDecoderLayer
|
||||
if args.num_shards == 2:
|
||||
pytorch_class = TwoWayShardingDecoderLayer
|
||||
compiled_class = CompiledTwoWayShardingDecoderLayer
|
||||
|
||||
print("Compiling Layer {}".format(layer_id))
|
||||
if compressed:
|
||||
layer_i = self.src_model.transformer.h[
|
||||
i * num_group_layers : (i + 1) * num_group_layers
|
||||
]
|
||||
else:
|
||||
layer_i = self.src_model.transformer.h[i]
|
||||
layer_i = self.src_model.transformer.h[
|
||||
i * num_group_layers : (i + 1) * num_group_layers
|
||||
]
|
||||
|
||||
pytorch_layer_i = pytorch_class(
|
||||
layer_i, args.falcon_variant_to_use
|
||||
@@ -388,13 +391,13 @@ class ShardedFalcon(SharkLLMBase):
|
||||
layer_id,
|
||||
device_idx=device_idx,
|
||||
)
|
||||
del shark_module
|
||||
shark_layer_i = compiled_class(
|
||||
layer_id,
|
||||
device_idx,
|
||||
args.falcon_variant_to_use,
|
||||
self.device,
|
||||
self.precision,
|
||||
shark_module,
|
||||
)
|
||||
shark_layers.append(shark_layer_i)
|
||||
|
||||
@@ -668,20 +671,17 @@ class UnshardedFalcon(SharkLLMBase):
|
||||
def get_src_model(self):
|
||||
print("Loading src model: ", self.model_name)
|
||||
kwargs = {
|
||||
"torch_dtype": torch.float,
|
||||
"torch_dtype": torch.float32,
|
||||
"trust_remote_code": True,
|
||||
"token": self.hf_auth_token,
|
||||
}
|
||||
if self.precision == "int4":
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["load_gptq_on_cpu"] = True
|
||||
kwargs["device_map"] = "cpu"
|
||||
falcon_model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, **kwargs
|
||||
)
|
||||
if self.precision == "int4":
|
||||
falcon_model = falcon_model.to(torch.float32)
|
||||
return falcon_model
|
||||
|
||||
def compile(self):
|
||||
|
||||
@@ -132,7 +132,7 @@ import torch_mlir
 from torch_mlir.compiler_utils import run_pipeline_with_repro_report
 from typing import List, Tuple
 from io import BytesIO
-from brevitas_examples.llm.llm_quant.quantize import quantize_model
+from brevitas_examples.common.generative.quantize import quantize_model
 from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

@@ -4,13 +4,49 @@ from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from apps.language_models.utils import (
|
||||
get_torch_mlir_module_bytecode,
|
||||
get_vmfb_from_path,
|
||||
)
|
||||
from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase
|
||||
from apps.language_models.src.model_wrappers.stablelm_model import (
|
||||
StableLMModel,
|
||||
)
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="stablelm runner",
|
||||
description="runs a StableLM model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"]
|
||||
)
|
||||
parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda")
|
||||
parser.add_argument(
|
||||
"--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stablelm_mlir_path",
|
||||
default=None,
|
||||
help="path to StableLM's mlir file",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_precompiled_model",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="use the precompiled vmfb",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--load_mlir_from_shark_tank",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="download precompile mlir from shark tank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hf_auth_token",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication token for stablelm-3B model.",
|
||||
)
|
||||
|
||||
|
||||
class StopOnTokens(StoppingCriteria):
|
||||
@@ -29,7 +65,7 @@ class SharkStableLM(SharkLLMBase):
|
||||
self,
|
||||
model_name,
|
||||
hf_model_path="stabilityai/stablelm-tuned-alpha-3b",
|
||||
max_num_tokens=512,
|
||||
max_num_tokens=256,
|
||||
device="cuda",
|
||||
precision="fp32",
|
||||
debug="False",
|
||||
@@ -37,6 +73,14 @@ class SharkStableLM(SharkLLMBase):
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_len = 256
|
||||
self.device = device
|
||||
if precision != "int4" and args.hf_auth_token == None:
|
||||
raise ValueError(
|
||||
""" HF auth token required for StableLM-3B. Pass it using
|
||||
--hf_auth_token flag. You can ask for the access to the model
|
||||
here: https://huggingface.co/tiiuae/falcon-180B-chat."""
|
||||
)
|
||||
self.hf_auth_token = args.hf_auth_token
|
||||
|
||||
self.precision = precision
|
||||
self.debug = debug
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
@@ -50,9 +94,23 @@ class SharkStableLM(SharkLLMBase):
|
||||
return False
|
||||
|
||||
def get_src_model(self):
|
||||
kwargs = {}
|
||||
if self.precision == "int4":
|
||||
self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ"
|
||||
from transformers import GPTQConfig
|
||||
|
||||
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
|
||||
kwargs["quantization_config"] = quantization_config
|
||||
kwargs["device_map"] = "cpu"
|
||||
print("[DEBUG] Loading Model")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
self.hf_model_path, torch_dtype=torch.float32
|
||||
self.hf_model_path,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float32,
|
||||
use_auth_token=self.hf_auth_token,
|
||||
**kwargs,
|
||||
)
|
||||
print("[DEBUG] Model loaded successfully")
|
||||
return model
|
||||
|
||||
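For reference, the int4 branch above relies on transformers' GPTQ integration. A minimal, standalone sketch of that loading pattern (the model id mirrors the one hard-coded above; this is an illustration of the technique, not the pipeline's exact code path):

# Sketch: load a 4-bit GPTQ checkpoint on CPU with transformers.
import torch
from transformers import AutoModelForCausalLM, GPTQConfig

model_id = "TheBloke/stablelm-zephyr-3b-GPTQ"  # same id as used above
quantization_config = GPTQConfig(bits=4, disable_exllama=True)  # CPU-friendly path

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=True,
    torch_dtype=torch.float32,
    quantization_config=quantization_config,
    device_map="cpu",
)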
def get_model_inputs(self):
|
||||
@@ -61,9 +119,7 @@ class SharkStableLM(SharkLLMBase):
|
||||
return input_ids, attention_mask
|
||||
|
||||
def compile(self):
|
||||
tmp_model_name = (
|
||||
f"stableLM_linalg_{self.precision}_seqLen{self.max_sequence_len}"
|
||||
)
|
||||
tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}"
|
||||
|
||||
# device = "cuda" # "cpu"
|
||||
# TODO: vmfb and mlir name should include precision and device
|
||||
@@ -83,13 +139,19 @@ class SharkStableLM(SharkLLMBase):
|
||||
print(
|
||||
f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
if not mlir_path.exists():
|
||||
model = StableLMModel(self.get_src_model())
|
||||
model_inputs = self.get_model_inputs()
|
||||
ts_graph = get_torch_mlir_module_bytecode(model, model_inputs)
|
||||
from shark.shark_importer import import_with_fx
|
||||
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
model_inputs,
|
||||
is_f16=True if self.precision in ["fp16"] else False,
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*model_inputs],
|
||||
@@ -100,15 +162,16 @@ class SharkStableLM(SharkLLMBase):
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
f_ = open(tmp_model_name + ".mlir", "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir")
|
||||
f_.close()
|
||||
f_ = open(mlir_path, "wb")
|
||||
f_.write(bytecode)
|
||||
print("Saved mlir at: ", mlir_path)
|
||||
f_.close()
|
||||
del bytecode
|
||||
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode, device=self.device, mlir_dialect="tm_tensor"
|
||||
mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor"
|
||||
)
|
||||
shark_module.compile()
|
||||
|
||||
@@ -120,14 +183,22 @@ class SharkStableLM(SharkLLMBase):
|
||||
return shark_module
|
||||
|
||||
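As a rough end-to-end illustration of the compile() flow above (the .mlir filename is a placeholder following the naming scheme used in this file; this is a sketch, not the exact pipeline code):

# Sketch: compile a saved MLIR file with SharkInference and keep the vmfb around.
from pathlib import Path
from shark.shark_inference import SharkInference

mlir_path = Path("stablelm_zephyr_3b_linalg_fp16_seqLen256.mlir")  # placeholder name
shark_module = SharkInference(
    mlir_module=mlir_path, device="cuda", mlir_dialect="tm_tensor"
)
shark_module.compile()  # later runs can reload the generated vmfb instead of recompiling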
def get_tokenizer(self):
|
||||
tok = AutoTokenizer.from_pretrained(self.hf_model_path)
|
||||
tok = AutoTokenizer.from_pretrained(
|
||||
self.hf_model_path,
|
||||
use_auth_token=self.hf_auth_token,
|
||||
)
|
||||
tok.add_special_tokens({"pad_token": "<PAD>"})
|
||||
# print("[DEBUG] Sucessfully loaded the tokenizer to the memory")
|
||||
return tok
|
||||
|
||||
def generate(self, prompt):
|
||||
words_list = []
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
count = 0
|
||||
for i in range(self.max_num_tokens):
|
||||
count = count + 1
|
||||
params = {
|
||||
"new_text": prompt,
|
||||
}
|
||||
@@ -145,6 +216,12 @@ class SharkStableLM(SharkLLMBase):
|
||||
if detok == "":
|
||||
break
|
||||
prompt = prompt + detok
|
||||
end = time.time()
|
||||
        print(
            "\n\nAverage generation speed: {:.2f} tokens/second\n".format(
                count / (end - start)
            )
        )
|
||||
return words_list
|
||||
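The generation loop above decodes one token at a time and reports throughput at the end; a self-contained sketch of that structure (generate_new_token is replaced by a stub so the snippet runs on its own):

# Standalone sketch of the token-at-a-time decode loop above.
import time

def fake_generate_new_token(prompt):
    # Stub standing in for SharkStableLM.generate_new_token().
    return " token" if len(prompt) < 120 else ""

def greedy_decode(prompt, max_num_tokens=256):
    words_list, count, start = [], 0, time.time()
    for _ in range(max_num_tokens):
        count += 1
        detok = fake_generate_new_token(prompt)
        if detok == "":
            break
        words_list.append(detok)
        prompt = prompt + detok
    speed = count / (time.time() - start)
    print("Average generation speed: {:.2f} tokens/second".format(speed))
    return words_list

greedy_decode("The weather is always wonderful")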
|
||||
def generate_new_token(self, params):
|
||||
@@ -178,10 +255,46 @@ class SharkStableLM(SharkLLMBase):
|
||||
return ret_dict
|
||||
|
||||
|
||||
# Initialize a StopOnTokens object
|
||||
system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version)
|
||||
- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI.
|
||||
- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
||||
- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes.
|
||||
- StableLM will refuse to participate in anything that could harm a human.
|
||||
"""
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
|
||||
stable_lm = SharkStableLM(
|
||||
model_name="stablelm_zephyr_3b",
|
||||
hf_model_path="stabilityai/stablelm-zephyr-3b",
|
||||
device=args.device,
|
||||
precision=args.precision,
|
||||
)
|
||||
|
||||
default_prompt_text = "The weather is always wonderful"
|
||||
continue_execution = True
|
||||
|
||||
print("\n-----\nScript executing for the following config: \n")
|
||||
print("StableLM Model: ", stable_lm.hf_model_path)
|
||||
print("Precision: ", args.precision)
|
||||
print("Device: ", args.device)
|
||||
|
||||
while continue_execution:
|
||||
use_default_prompt = input(
|
||||
"\nDo you wish to use the default prompt text? Y/N ?: "
|
||||
)
|
||||
if use_default_prompt in ["Y", "y"]:
|
||||
prompt = default_prompt_text
|
||||
else:
|
||||
prompt = input("Please enter the prompt text: ")
|
||||
print("\nPrompt Text: ", prompt)
|
||||
|
||||
res_str = stable_lm.generate(prompt)
|
||||
torch.cuda.empty_cache()
|
||||
import gc
|
||||
|
||||
gc.collect()
|
||||
print(
|
||||
"\n\n-----\nHere's the complete formatted result: \n\n",
|
||||
prompt + "".join(res_str),
|
||||
)
|
||||
continue_execution = input(
|
||||
"\nDo you wish to run script one more time? Y/N ?: "
|
||||
)
|
||||
continue_execution = (
|
||||
True if continue_execution in ["Y", "y"] else False
|
||||
)
|
||||
|
||||
@@ -105,6 +105,7 @@ def main():
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
control_mode=args.control_mode,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
|
||||
apps/stable_diffusion/scripts/txt2img_sdxl.py (new file, 96 lines)
@@ -0,0 +1,96 @@
|
||||
import torch
|
||||
import time
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
clear_all,
|
||||
save_output_img,
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
if args.clear_all:
|
||||
clear_all()
|
||||
|
||||
# TODO: prompt_embeds and text_embeds form base_model.json requires fixing
|
||||
args.precision = "fp16"
|
||||
args.height = 1024
|
||||
args.width = 1024
|
||||
args.max_length = 77
|
||||
args.scheduler = "DDIM"
|
||||
    print(
        "Using the default supported configuration for SDXL: precision=fp16, width*height=1024*1024, max_length=77 and scheduler=DDIM"
    )
|
||||
dtype = torch.float32 if args.precision == "fp32" else torch.half
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
set_init_device_flags()
|
||||
schedulers = get_schedulers(args.hf_model_id)
|
||||
scheduler_obj = schedulers[args.scheduler]
|
||||
seed = args.seed
|
||||
txt2img_obj = Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=args.precision,
|
||||
max_length=args.max_length,
|
||||
batch_size=args.batch_size,
|
||||
height=args.height,
|
||||
width=args.width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
|
||||
seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds)
|
||||
for current_batch in range(args.batch_count):
|
||||
start_time = time.time()
|
||||
generated_imgs = txt2img_obj.generate_images(
|
||||
args.prompts,
|
||||
args.negative_prompts,
|
||||
args.batch_size,
|
||||
args.height,
|
||||
args.width,
|
||||
args.steps,
|
||||
args.guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = f"prompt={args.prompts}"
|
||||
text_output += f"\nnegative prompt={args.negative_prompts}"
|
||||
text_output += (
|
||||
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
|
||||
)
|
||||
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
|
||||
text_output += (
|
||||
f"\nsteps={args.steps}, guidance_scale={args.guidance_scale},"
|
||||
)
|
||||
text_output += (
|
||||
f"seed={seeds[current_batch]}, size={args.height}x{args.width}"
|
||||
)
|
||||
text_output += (
|
||||
f", batch size={args.batch_size}, max_length={args.max_length}"
|
||||
)
|
||||
# TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start
|
||||
text_output += txt2img_obj.log
|
||||
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
|
||||
|
||||
save_output_img(generated_imgs[0], seed)
|
||||
print(text_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,6 +19,9 @@ a = Analysis(
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
module_collection_mode={
|
||||
'gradio': 'py', # Collect gradio package as source .py files
|
||||
},
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ datas += copy_metadata("Pillow")
|
||||
datas += copy_metadata("sentencepiece")
|
||||
datas += copy_metadata("pyyaml")
|
||||
datas += copy_metadata("huggingface-hub")
|
||||
datas += copy_metadata("gradio")
|
||||
datas += collect_data_files("torch")
|
||||
datas += collect_data_files("tokenizers")
|
||||
datas += collect_data_files("tiktoken")
|
||||
@@ -60,6 +61,7 @@ datas += [
|
||||
("src/utils/resources/opt_flags.json", "resources"),
|
||||
("src/utils/resources/base_model.json", "resources"),
|
||||
("web/ui/css/*", "ui/css"),
|
||||
("web/ui/js/*", "ui/js"),
|
||||
("web/ui/logos/*", "logos"),
|
||||
(
|
||||
"../language_models/src/pipelines/minigpt4_utils/configs/*",
|
||||
@@ -75,6 +77,7 @@ datas += [
|
||||
# hidden imports for pyinstaller
|
||||
hiddenimports = ["shark", "shark.shark_inference", "apps"]
|
||||
hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x]
|
||||
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
|
||||
hiddenimports += [
|
||||
x for x in collect_submodules("diffusers") if "tests" not in x
|
||||
]
|
||||
@@ -85,4 +88,4 @@ hiddenimports += [
|
||||
if not any(kw in x for kw in blacklist)
|
||||
]
|
||||
hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x]
|
||||
hiddenimports += ["iree._runtime", "iree.compiler._mlir_libs._mlir.ir"]
|
||||
hiddenimports += ["iree._runtime"]
|
||||
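The hidden-import lists above all follow the same collect-and-filter pattern; as a generic sketch (the blacklist keywords here are illustrative only, not the spec file's actual list):

# Sketch: collect a package's submodules for PyInstaller, excluding unwanted ones.
from PyInstaller.utils.hooks import collect_submodules

blacklist = ["tests", "testing"]  # illustrative keywords
hiddenimports = [
    m for m in collect_submodules("iree") if not any(kw in m for kw in blacklist)
]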
|
||||
@@ -9,6 +9,7 @@ from apps.stable_diffusion.src.utils import (
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines import (
|
||||
Text2ImagePipeline,
|
||||
Text2ImageSDXLPipeline,
|
||||
Image2ImagePipeline,
|
||||
InpaintPipeline,
|
||||
OutpaintPipeline,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel
|
||||
from transformers import CLIPTextModel
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
import torch
|
||||
@@ -24,6 +24,8 @@ from apps.stable_diffusion.src.utils import (
|
||||
get_stencil_model_id,
|
||||
update_lora_weight,
|
||||
)
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_inference import SharkInference
|
||||
|
||||
|
||||
# These shapes are parameter dependent.
|
||||
@@ -55,6 +57,10 @@ def replace_shape_str(shape, max_len, width, height, batch_size):
|
||||
new_shape.append(math.ceil(height / div_val))
|
||||
elif "width" in shape[i]:
|
||||
new_shape.append(math.ceil(width / div_val))
|
||||
elif "+" in shape[i]:
|
||||
# Currently this case only hits for SDXL. So, in case any other
|
||||
# case requires this operator, change this.
|
||||
new_shape.append(height + width)
|
||||
else:
|
||||
new_shape.append(shape[i])
|
||||
return new_shape
|
||||
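To make the new "+" branch above concrete: symbolic shape entries from base_model.json are resolved against the requested image size, and the "+" case simply sums height and width. An illustrative walk-through (the example entry is an assumption; only the arithmetic mirrors the hunk):

# Illustrative only: resolving a symbolic shape entry via the "+" branch above.
height, width = 1024, 1024
symbolic_shape = ["2", "height+width"]

resolved = []
for dim in symbolic_shape:
    if "+" in dim:
        resolved.append(height + width)  # SDXL-only case handled in the hunk above
    else:
        resolved.append(int(dim))

print(resolved)  # [2, 2048]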
@@ -67,6 +73,70 @@ def check_compilation(model, model_name):
|
||||
)
|
||||
|
||||
|
||||
def shark_compile_after_ir(
|
||||
module_name,
|
||||
device,
|
||||
vmfb_path,
|
||||
precision,
|
||||
ir_path=None,
|
||||
):
|
||||
if ir_path:
|
||||
print(f"[DEBUG] mlir found at {ir_path.absolute()}")
|
||||
|
||||
module = SharkInference(
|
||||
mlir_module=ir_path,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
print(f"Will get extra flag for {module_name} and precision = {precision}")
|
||||
path = module.save_module(
|
||||
vmfb_path.parent.absolute(),
|
||||
vmfb_path.stem,
|
||||
extra_args=get_opt_flags(module_name, precision=precision),
|
||||
)
|
||||
print(f"Saved {module_name} vmfb at {path}")
|
||||
module.load_module(path)
|
||||
return module
|
||||
|
||||
|
||||
def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision):
|
||||
name_split = extended_model_name.split("_")
|
||||
if "vae" in model_name:
|
||||
name_split[5] = "fp32"
|
||||
extended_model_name_for_vmfb = "_".join(name_split)
|
||||
extended_model_name_for_mlir = "_".join(name_split[:-1])
|
||||
vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb")
|
||||
if "vulkan" in device:
|
||||
_device = args.iree_vulkan_target_triple
|
||||
_device = _device.replace("-", "_")
|
||||
vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb")
|
||||
if vmfb_path.exists():
|
||||
shark_module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
print(f"loading existing vmfb from: {vmfb_path}")
|
||||
shark_module.load_module(vmfb_path, extra_args=[])
|
||||
return shark_module, None
|
||||
mlir_path = Path(extended_model_name_for_mlir + ".mlir")
|
||||
if not mlir_path.exists():
|
||||
print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}")
|
||||
download_public_file(
|
||||
f"gs://shark_tank/SDXL/mlir/{mlir_path.name}",
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
return (
|
||||
shark_compile_after_ir(
|
||||
model_name, device, vmfb_path, precision, mlir_path
|
||||
),
|
||||
None,
|
||||
)
|
||||
return None, None
|
||||
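A usage sketch of the helper above (argument values are placeholders; the lookup order is: existing vmfb, then a local MLIR file, then a download attempt from gs://shark_tank/SDXL/mlir/, otherwise None):

# Sketch: reuse a cached SDXL artifact before importing the model from torch.
shark_module, _ = process_vmfb_ir_sdxl(
    extended_model_name="unet_..._fp16",  # placeholder extended name
    model_name="unet",
    device="vulkan",
    precision="fp16",
)
if shark_module is None:
    print("No cached vmfb/MLIR found; falling back to a fresh torch import.")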
|
||||
|
||||
class SharkifyStableDiffusionModel:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -86,13 +156,17 @@ class SharkifyStableDiffusionModel:
|
||||
generate_vmfb: bool = True,
|
||||
is_inpaint: bool = False,
|
||||
is_upscaler: bool = False,
|
||||
use_stencil: str = None,
|
||||
is_sdxl: bool = False,
|
||||
stencils: list[str] = [],
|
||||
use_lora: str = "",
|
||||
lora_strength: float = 0.75,
|
||||
use_quantize: str = None,
|
||||
return_mlir: bool = False,
|
||||
favored_base_models=None,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.max_len = max_len
|
||||
self.is_sdxl = is_sdxl
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
@@ -118,9 +192,7 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
# TODO: remove the following line when stable-diffusion-2-1 works
|
||||
if self.model_id == "stabilityai/stable-diffusion-2-1":
|
||||
self.model_id = "stabilityai/stable-diffusion-2-1-base"
|
||||
self.favored_base_models = favored_base_models
|
||||
self.custom_vae = custom_vae
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
@@ -136,6 +208,7 @@ class SharkifyStableDiffusionModel:
|
||||
+ "_"
|
||||
+ precision
|
||||
)
|
||||
self.model_namedata = self.model_name
|
||||
print(f"use_tuned? sharkify: {use_tuned}")
|
||||
self.use_tuned = use_tuned
|
||||
if use_tuned:
|
||||
@@ -144,12 +217,17 @@ class SharkifyStableDiffusionModel:
|
||||
self.low_cpu_mem_usage = low_cpu_mem_usage
|
||||
self.is_inpaint = is_inpaint
|
||||
self.is_upscaler = is_upscaler
|
||||
self.use_stencil = get_stencil_model_id(use_stencil)
|
||||
self.stencils = [get_stencil_model_id(x) for x in stencils]
|
||||
if use_lora != "":
|
||||
self.model_name = self.model_name + "_" + get_path_stem(use_lora)
|
||||
self.model_name = (
|
||||
self.model_name
|
||||
+ "_"
|
||||
+ get_path_stem(use_lora)
|
||||
+ f"@{int(lora_strength*100)}"
|
||||
)
|
||||
self.use_lora = use_lora
|
||||
self.lora_strength = lora_strength
|
||||
|
||||
print(self.model_name)
|
||||
self.model_name = self.get_extended_name_for_all_model()
|
||||
self.debug = debug
|
||||
self.sharktank_dir = sharktank_dir
|
||||
@@ -171,19 +249,22 @@ class SharkifyStableDiffusionModel:
|
||||
args.hf_model_id = self.base_model_id
|
||||
self.return_mlir = return_mlir
|
||||
|
||||
def get_extended_name_for_all_model(self):
|
||||
def get_extended_name_for_all_model(self, model_list=None):
|
||||
model_name = {}
|
||||
sub_model_list = [
|
||||
"clip",
|
||||
"clip2",
|
||||
"unet",
|
||||
"unet512",
|
||||
"stencil_unet",
|
||||
"stencil_unet_512",
|
||||
"vae",
|
||||
"vae_encode",
|
||||
"stencil_adaptor",
|
||||
"stencil_adaptor_512",
|
||||
"stencil_adapter",
|
||||
"stencil_adapter_512",
|
||||
]
|
||||
if model_list is not None:
|
||||
sub_model_list = model_list
|
||||
index = 0
|
||||
for model in sub_model_list:
|
||||
sub_model = model
|
||||
@@ -195,10 +276,24 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if self.base_vae:
|
||||
sub_model = "base_vae"
|
||||
if "stencil_adaptor" == model and self.use_stencil is not None:
|
||||
model_config = model_config + get_path_stem(self.use_stencil)
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
if "stencil_adapter" in model:
|
||||
stencil_names = []
|
||||
for i, stencil in enumerate(self.stencils):
|
||||
if stencil is not None:
|
||||
cnet_config = (
|
||||
self.model_namedata
|
||||
+ "_sd15_"
|
||||
+ stencil.split("_")[-1]
|
||||
)
|
||||
stencil_names.append(
|
||||
get_extended_name(sub_model + cnet_config)
|
||||
)
|
||||
|
||||
model_name[model] = stencil_names
|
||||
else:
|
||||
model_name[model] = get_extended_name(sub_model + model_config)
|
||||
index += 1
|
||||
|
||||
return model_name
|
||||
|
||||
def check_params(self, max_len, width, height):
|
||||
@@ -342,6 +437,105 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_vae, vae_mlir
|
||||
|
||||
def get_vae_sdxl(self):
|
||||
# TODO: Remove this after convergence with shark_tank. This should just be part of
|
||||
# opt_params.py.
|
||||
shark_module_or_none = process_vmfb_ir_sdxl(
|
||||
self.model_name["vae"], "vae", args.device, self.precision
|
||||
)
|
||||
if shark_module_or_none[0]:
|
||||
return shark_module_or_none
|
||||
|
||||
class VaeModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
base_vae=self.base_vae,
|
||||
custom_vae=self.custom_vae,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = None
|
||||
if custom_vae == "":
|
||||
print(f"Loading default vae, with target {model_id}")
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif not isinstance(custom_vae, dict):
|
||||
precision = "fp16" if "fp16" in custom_vae else None
|
||||
print(f"Loading custom vae, with target {custom_vae}")
|
||||
if os.path.exists(custom_vae):
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
custom_vae = "/".join(
|
||||
[
|
||||
custom_vae.split("/")[-2].split("\\")[-1],
|
||||
custom_vae.split("/")[-1],
|
||||
]
|
||||
)
|
||||
print("Using hub to get custom vae")
|
||||
try:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
variant=precision,
|
||||
)
|
||||
except:
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
custom_vae,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
print(f"Loading custom vae, with state {custom_vae}")
|
||||
self.vae = AutoencoderKL.from_pretrained(
|
||||
model_id,
|
||||
subfolder="vae",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.vae.load_state_dict(custom_vae)
|
||||
self.base_vae = base_vae
|
||||
|
||||
def forward(self, latents):
|
||||
image = self.vae.decode(latents / 0.13025, return_dict=False)[
|
||||
0
|
||||
]
|
||||
return image
|
||||
|
||||
vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
inputs = tuple(self.inputs["vae"])
|
||||
# Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL
|
||||
# pipeline.
|
||||
if not self.custom_vae:
|
||||
is_f16 = False
|
||||
elif "16" in self.custom_vae:
|
||||
is_f16 = True
|
||||
else:
|
||||
is_f16 = False
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"])
|
||||
if self.debug:
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
shark_vae, vae_mlir = compile_through_fx(
|
||||
vae,
|
||||
inputs,
|
||||
is_f16=is_f16,
|
||||
use_tuned=self.use_tuned,
|
||||
extended_model_name=self.model_name["vae"],
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("vae", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="vae",
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_vae, vae_mlir
|
||||
|
||||
def get_controlled_unet(self, use_large=False):
|
||||
class ControlledUnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -349,6 +543,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
lora_strength=self.lora_strength,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -357,8 +552,10 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
self.in_channels = self.unet.in_channels
|
||||
update_lora_weight(
|
||||
self.unet, use_lora, "unet", lora_strength
|
||||
)
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(
|
||||
@@ -380,25 +577,54 @@ class SharkifyStableDiffusionModel:
|
||||
control11,
|
||||
control12,
|
||||
control13,
|
||||
scale1,
|
||||
scale2,
|
||||
scale3,
|
||||
scale4,
|
||||
scale5,
|
||||
scale6,
|
||||
scale7,
|
||||
scale8,
|
||||
scale9,
|
||||
scale10,
|
||||
scale11,
|
||||
scale12,
|
||||
scale13,
|
||||
):
|
||||
# TODO: Average pooling
|
||||
db_res_samples = [
|
||||
control1,
|
||||
control2,
|
||||
control3,
|
||||
control4,
|
||||
control5,
|
||||
control6,
|
||||
control7,
|
||||
control8,
|
||||
control9,
|
||||
control10,
|
||||
control11,
|
||||
control12,
|
||||
]
|
||||
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
db_res_samples = tuple(
|
||||
[
|
||||
control1,
|
||||
control2,
|
||||
control3,
|
||||
control4,
|
||||
control5,
|
||||
control6,
|
||||
control7,
|
||||
control8,
|
||||
control9,
|
||||
control10,
|
||||
control11,
|
||||
control12,
|
||||
control1 * scale1,
|
||||
control2 * scale2,
|
||||
control3 * scale3,
|
||||
control4 * scale4,
|
||||
control5 * scale5,
|
||||
control6 * scale6,
|
||||
control7 * scale7,
|
||||
control8 * scale8,
|
||||
control9 * scale9,
|
||||
control10 * scale10,
|
||||
control11 * scale11,
|
||||
control12 * scale12,
|
||||
]
|
||||
)
|
||||
mb_res_samples = control13
|
||||
mb_res_samples = control13 * scale13
|
||||
latents = torch.cat([latent] * 2)
|
||||
unet_out = self.unet.forward(
|
||||
latents,
|
||||
@@ -446,6 +672,19 @@ class SharkifyStableDiffusionModel:
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
True,
|
||||
]
|
||||
shark_controlled_unet, controlled_unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
@@ -462,17 +701,19 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_controlled_unet, controlled_unet_mlir
|
||||
|
||||
def get_control_net(self, use_large=False):
|
||||
def get_control_net(self, stencil_id, use_large=False):
|
||||
stencil_id = get_stencil_model_id(stencil_id)
|
||||
adapter_id, base_model_safe_id, ext_model_name = (None, None, None)
|
||||
print(f"Importing ControlNet adapter from {stencil_id}")
|
||||
|
||||
class StencilControlNetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self, model_id=self.use_stencil, low_cpu_mem_usage=False
|
||||
):
|
||||
def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False):
|
||||
super().__init__()
|
||||
self.cnet = ControlNetModel.from_pretrained(
|
||||
model_id,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
self.in_channels = self.cnet.in_channels
|
||||
self.in_channels = self.cnet.config.in_channels
|
||||
self.train(False)
|
||||
|
||||
def forward(
|
||||
@@ -481,6 +722,19 @@ class SharkifyStableDiffusionModel:
|
||||
timestep,
|
||||
text_embedding,
|
||||
stencil_image_input,
|
||||
acc1,
|
||||
acc2,
|
||||
acc3,
|
||||
acc4,
|
||||
acc5,
|
||||
acc6,
|
||||
acc7,
|
||||
acc8,
|
||||
acc9,
|
||||
acc10,
|
||||
acc11,
|
||||
acc12,
|
||||
acc13,
|
||||
):
|
||||
# expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
|
||||
# TODO: guidance NOT NEEDED change in `get_input_info` later
|
||||
@@ -502,6 +756,20 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return tuple(
|
||||
list(down_block_res_samples) + [mid_block_res_sample]
|
||||
) + (
|
||||
acc1 + down_block_res_samples[0],
|
||||
acc2 + down_block_res_samples[1],
|
||||
acc3 + down_block_res_samples[2],
|
||||
acc4 + down_block_res_samples[3],
|
||||
acc5 + down_block_res_samples[4],
|
||||
acc6 + down_block_res_samples[5],
|
||||
acc7 + down_block_res_samples[6],
|
||||
acc8 + down_block_res_samples[7],
|
||||
acc9 + down_block_res_samples[8],
|
||||
acc10 + down_block_res_samples[9],
|
||||
acc11 + down_block_res_samples[10],
|
||||
acc12 + down_block_res_samples[11],
|
||||
acc13 + mid_block_res_sample,
|
||||
)
|
||||
|
||||
scnet = StencilControlNetModel(
|
||||
@@ -509,7 +777,25 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
|
||||
inputs = tuple(self.inputs["stencil_adaptor"])
|
||||
inputs = tuple(self.inputs["stencil_adapter"])
|
||||
model_name = "stencil_adapter_512" if use_large else "stencil_adapter"
|
||||
stencil_names = self.get_extended_name_for_all_model([model_name])
|
||||
ext_model_name = stencil_names[model_name]
|
||||
if isinstance(ext_model_name, list):
|
||||
desired_name = None
|
||||
print(ext_model_name)
|
||||
for i in ext_model_name:
|
||||
if stencil_id.split("_")[-1] in i:
|
||||
desired_name = i
|
||||
else:
|
||||
continue
|
||||
if desired_name:
|
||||
ext_model_name = desired_name
|
||||
else:
|
||||
raise Exception(
|
||||
f"Could not find extended configuration for {stencil_id}"
|
||||
)
|
||||
|
||||
if use_large:
|
||||
pad = (0, 0) * (len(inputs[2].shape) - 2)
|
||||
pad = pad + (0, 512 - inputs[2].shape[1])
|
||||
@@ -517,21 +803,15 @@ class SharkifyStableDiffusionModel:
|
||||
inputs[0],
|
||||
inputs[1],
|
||||
torch.nn.functional.pad(inputs[2], pad),
|
||||
inputs[3],
|
||||
*inputs[3:],
|
||||
)
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor_512"]
|
||||
)
|
||||
else:
|
||||
save_dir = os.path.join(
|
||||
self.sharktank_dir, self.model_name["stencil_adaptor"]
|
||||
)
|
||||
input_mask = [True, True, True, True]
|
||||
model_name = "stencil_adaptor" if use_large else "stencil_adaptor_512"
|
||||
save_dir = os.path.join(self.sharktank_dir, ext_model_name)
|
||||
input_mask = [True, True, True, True] + ([True] * 13)
|
||||
|
||||
shark_cnet, cnet_mlir = compile_through_fx(
|
||||
scnet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name[model_name],
|
||||
extended_model_name=ext_model_name,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
@@ -550,6 +830,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
lora_strength=self.lora_strength,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
@@ -558,7 +839,9 @@ class SharkifyStableDiffusionModel:
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(self.unet, use_lora, "unet")
|
||||
update_lora_weight(
|
||||
self.unet, use_lora, "unet", lora_strength
|
||||
)
|
||||
self.in_channels = self.unet.config.in_channels
|
||||
self.train(False)
|
||||
if (
|
||||
@@ -688,6 +971,101 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_unet, unet_mlir
|
||||
|
||||
def get_unet_sdxl(self):
|
||||
# TODO: Remove this after convergence with shark_tank. This should just be part of
|
||||
# opt_params.py.
|
||||
shark_module_or_none = process_vmfb_ir_sdxl(
|
||||
self.model_name["unet"], "unet", args.device, self.precision
|
||||
)
|
||||
if shark_module_or_none[0]:
|
||||
return shark_module_or_none
|
||||
|
||||
class UnetModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
):
|
||||
super().__init__()
|
||||
try:
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
variant="fp16",
|
||||
)
|
||||
except:
|
||||
self.unet = UNet2DConditionModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="unet",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
if (
|
||||
args.attention_slicing is not None
|
||||
and args.attention_slicing != "none"
|
||||
):
|
||||
if args.attention_slicing.isdigit():
|
||||
self.unet.set_attention_slice(
|
||||
int(args.attention_slicing)
|
||||
)
|
||||
else:
|
||||
self.unet.set_attention_slice(args.attention_slicing)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
latent,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
text_embeds,
|
||||
time_ids,
|
||||
guidance_scale,
|
||||
):
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": text_embeds,
|
||||
"time_ids": time_ids,
|
||||
}
|
||||
noise_pred = self.unet.forward(
|
||||
latent,
|
||||
timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=None,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (
|
||||
noise_pred_text - noise_pred_uncond
|
||||
)
|
||||
return noise_pred
|
||||
|
||||
unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage)
|
||||
is_f16 = True if self.precision == "fp16" else False
|
||||
inputs = tuple(self.inputs["unet"])
|
||||
save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"])
|
||||
input_mask = [True, True, True, True, True, True]
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_unet, unet_mlir = compile_through_fx(
|
||||
unet,
|
||||
inputs,
|
||||
extended_model_name=self.model_name["unet"],
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=input_mask,
|
||||
use_tuned=self.use_tuned,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("unet", precision=self.precision),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="unet",
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_unet, unet_mlir
|
||||
|
||||
def get_clip(self):
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(
|
||||
@@ -695,6 +1073,7 @@ class SharkifyStableDiffusionModel:
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
use_lora=self.use_lora,
|
||||
lora_strength=self.lora_strength,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
@@ -704,7 +1083,10 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
if use_lora != "":
|
||||
update_lora_weight(
|
||||
self.text_encoder, use_lora, "text_encoder"
|
||||
self.text_encoder,
|
||||
use_lora,
|
||||
"text_encoder",
|
||||
lora_strength,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
@@ -735,6 +1117,78 @@ class SharkifyStableDiffusionModel:
|
||||
)
|
||||
return shark_clip, clip_mlir
|
||||
|
||||
def get_clip_sdxl(self, clip_index=1):
|
||||
if clip_index == 1:
|
||||
extended_model_name = self.model_name["clip"]
|
||||
model_name = "clip"
|
||||
else:
|
||||
extended_model_name = self.model_name["clip2"]
|
||||
model_name = "clip2"
|
||||
# TODO: Remove this after convergence with shark_tank. This should just be part of
|
||||
# opt_params.py.
|
||||
shark_module_or_none = process_vmfb_ir_sdxl(
|
||||
extended_model_name, f"clip", args.device, self.precision
|
||||
)
|
||||
if shark_module_or_none[0]:
|
||||
return shark_module_or_none
|
||||
|
||||
class CLIPText(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_id=self.model_id,
|
||||
low_cpu_mem_usage=False,
|
||||
clip_index=1,
|
||||
):
|
||||
super().__init__()
|
||||
if clip_index == 1:
|
||||
self.text_encoder = CLIPTextModel.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
else:
|
||||
self.text_encoder = (
|
||||
CLIPTextModelWithProjection.from_pretrained(
|
||||
model_id,
|
||||
subfolder="text_encoder_2",
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
prompt_embeds = self.text_encoder(
|
||||
input,
|
||||
output_hidden_states=True,
|
||||
)
|
||||
# We are only ALWAYS interested in the pooled output of the final text encoder
|
||||
pooled_prompt_embeds = prompt_embeds[0]
|
||||
prompt_embeds = prompt_embeds.hidden_states[-2]
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
clip_model = CLIPText(
|
||||
low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index
|
||||
)
|
||||
save_dir = os.path.join(self.sharktank_dir, extended_model_name)
|
||||
if self.debug:
|
||||
os.makedirs(
|
||||
save_dir,
|
||||
exist_ok=True,
|
||||
)
|
||||
shark_clip, clip_mlir = compile_through_fx(
|
||||
clip_model,
|
||||
tuple(self.inputs["clip"]),
|
||||
extended_model_name=extended_model_name,
|
||||
debug=self.debug,
|
||||
generate_vmfb=self.generate_vmfb,
|
||||
save_dir=save_dir,
|
||||
extra_args=get_opt_flags("clip", precision="fp32"),
|
||||
base_model_id=self.base_model_id,
|
||||
model_name="clip",
|
||||
precision=self.precision,
|
||||
return_mlir=self.return_mlir,
|
||||
)
|
||||
return shark_clip, clip_mlir
|
||||
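For orientation, the two CLIP branches above (CLIPTextModel and CLIPTextModelWithProjection) supply the two pieces of conditioning that SDXL expects; a shape-only sketch of how they combine downstream (the dimensions are the standard SDXL ones, stated here as an assumption rather than taken from this diff):

# Shape-only sketch: combining the outputs of the two SDXL text encoders.
import torch

prompt_embeds_1 = torch.randn(1, 77, 768)    # hidden states from text_encoder
prompt_embeds_2 = torch.randn(1, 77, 1280)   # hidden states from text_encoder_2
pooled_prompt_embeds = torch.randn(1, 1280)  # pooled output of the final encoder

# The unet consumes the concatenated hidden states plus the pooled output
# (the latter is passed through added_cond_kwargs["text_embeds"] in get_unet_sdxl above).
prompt_embeds = torch.cat([prompt_embeds_1, prompt_embeds_2], dim=-1)
print(prompt_embeds.shape)  # torch.Size([1, 77, 2048])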
|
||||
def process_custom_vae(self):
|
||||
custom_vae = self.custom_vae.lower()
|
||||
if not custom_vae.endswith((".ckpt", ".safetensors")):
|
||||
@@ -767,7 +1221,9 @@ class SharkifyStableDiffusionModel:
|
||||
}
|
||||
return vae_dict
|
||||
|
||||
def compile_unet_variants(self, model, use_large=False):
|
||||
def compile_unet_variants(self, model, use_large=False, base_model=""):
|
||||
if self.is_sdxl:
|
||||
return self.get_unet_sdxl()
|
||||
if model == "unet":
|
||||
if self.is_upscaler:
|
||||
return self.get_unet_upscaler(use_large=use_large)
|
||||
@@ -809,21 +1265,53 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def sdxl_clip(self):
|
||||
try:
|
||||
self.inputs["clip"] = self.get_input_info_for(
|
||||
base_models["sdxl_clip"]
|
||||
)
|
||||
compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1)
|
||||
compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2)
|
||||
|
||||
            check_compilation(compiled_clip, "Clip")
            check_compilation(compiled_clip2, "Clip2")
|
||||
if self.return_mlir:
|
||||
return clip_mlir, clip_mlir2
|
||||
return compiled_clip, compiled_clip2
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def unet(self, use_large=False):
|
||||
try:
|
||||
model = "stencil_unet" if self.use_stencil is not None else "unet"
|
||||
stencil_count = 0
|
||||
for stencil in self.stencils:
|
||||
stencil_count += 1
|
||||
model = "stencil_unet" if stencil_count > 0 else "unet"
|
||||
compiled_unet = None
|
||||
unet_inputs = base_models[model]
|
||||
|
||||
# if the model to run *is* a base model, then we should treat it as such
|
||||
if self.model_to_run in unet_inputs:
|
||||
self.base_model_id = self.model_to_run
|
||||
|
||||
if self.base_model_id != "":
|
||||
self.inputs["unet"] = self.get_input_info_for(
|
||||
unet_inputs[self.base_model_id]
|
||||
)
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||
model, use_large=use_large
|
||||
model, use_large=use_large, base_model=self.base_model_id
|
||||
)
|
||||
else:
|
||||
for model_id in unet_inputs:
|
||||
# restrict base models to check if we were given a specific list of valid ones
|
||||
allowed_base_model_ids = unet_inputs
|
||||
if self.favored_base_models != None:
|
||||
allowed_base_model_ids = self.favored_base_models
|
||||
|
||||
print(f"self.favored_base_models: {self.favored_base_models}")
|
||||
print(f"allowed_base_model_ids: {allowed_base_model_ids}")
|
||||
|
||||
                # try compiling with each base model until we find one that works (or not)
|
||||
for model_id in allowed_base_model_ids:
|
||||
self.base_model_id = model_id
|
||||
self.inputs["unet"] = self.get_input_info_for(
|
||||
unet_inputs[model_id]
|
||||
@@ -831,12 +1319,12 @@ class SharkifyStableDiffusionModel:
|
||||
|
||||
try:
|
||||
compiled_unet, unet_mlir = self.compile_unet_variants(
|
||||
model, use_large=use_large
|
||||
model, use_large=use_large, base_model=model_id
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print(
|
||||
"Retrying with a different base model configuration"
|
||||
f"Retrying with a different base model configuration, as {model_id} did not work"
|
||||
)
|
||||
continue
|
||||
|
||||
@@ -870,7 +1358,10 @@ class SharkifyStableDiffusionModel:
|
||||
is_base_vae = self.base_vae
|
||||
if self.is_upscaler:
|
||||
self.base_vae = True
|
||||
compiled_vae, vae_mlir = self.get_vae()
|
||||
if self.is_sdxl:
|
||||
compiled_vae, vae_mlir = self.get_vae_sdxl()
|
||||
else:
|
||||
compiled_vae, vae_mlir = self.get_vae()
|
||||
self.base_vae = is_base_vae
|
||||
|
||||
check_compilation(compiled_vae, "Vae")
|
||||
@@ -880,18 +1371,18 @@ class SharkifyStableDiffusionModel:
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
def controlnet(self, use_large=False):
|
||||
def controlnet(self, stencil_id, use_large=False):
|
||||
try:
|
||||
self.inputs["stencil_adaptor"] = self.get_input_info_for(
|
||||
base_models["stencil_adaptor"]
|
||||
self.inputs["stencil_adapter"] = self.get_input_info_for(
|
||||
base_models["stencil_adapter"]
|
||||
)
|
||||
compiled_stencil_adaptor, controlnet_mlir = self.get_control_net(
|
||||
use_large=use_large
|
||||
compiled_stencil_adapter, controlnet_mlir = self.get_control_net(
|
||||
stencil_id, use_large=use_large
|
||||
)
|
||||
|
||||
check_compilation(compiled_stencil_adaptor, "Stencil")
|
||||
check_compilation(compiled_stencil_adapter, "Stencil")
|
||||
if self.return_mlir:
|
||||
return controlnet_mlir
|
||||
return compiled_stencil_adaptor
|
||||
return compiled_stencil_adapter
|
||||
except Exception as e:
|
||||
sys.exit(e)
|
||||
|
||||
@@ -123,8 +123,11 @@ def get_clip():
|
||||
return get_shark_model(bucket, model_name, iree_flags)
|
||||
|
||||
|
||||
def get_tokenizer():
|
||||
def get_tokenizer(subfolder="tokenizer", hf_model_id=None):
|
||||
if hf_model_id is not None:
|
||||
args.hf_model_id = hf_model_id
|
||||
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.hf_model_id, subfolder="tokenizer"
|
||||
args.hf_model_id, subfolder=subfolder
|
||||
)
|
||||
return tokenizer
|
||||
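With the subfolder parameter added above, the second SDXL tokenizer can be fetched through the same helper; a hypothetical call site (the model id is illustrative) would be:

# Sketch: fetching both SDXL tokenizers via the extended get_tokenizer().
tokenizer_1 = get_tokenizer()  # defaults to subfolder="tokenizer"
tokenizer_2 = get_tokenizer(
    subfolder="tokenizer_2",
    hf_model_id="stabilityai/stable-diffusion-xl-base-1.0",  # illustrative id
)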
|
||||
@@ -1,6 +1,9 @@
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import (
|
||||
Text2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import (
|
||||
Text2ImageSDXLPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import (
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
|
||||
@@ -56,9 +56,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
@@ -158,8 +161,11 @@ class Image2ImagePipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
stencils,
|
||||
images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints=[],
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
|
||||
@@ -51,9 +51,12 @@ class InpaintPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
|
||||
@@ -52,9 +52,12 @@ class OutpaintPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.vae_encode = None
|
||||
|
||||
def load_vae_encode(self):
|
||||
|
||||
@@ -25,12 +25,22 @@ from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import controlnet_hint_conversion
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
start_profiling,
|
||||
end_profiling,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
resamplers,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae_encode,
|
||||
)
|
||||
|
||||
|
||||
class StencilPipeline(StableDiffusionPipeline):
|
||||
@@ -54,29 +64,69 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
controlnet_names: list[str],
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
self.controlnet = None
|
||||
self.controlnet_512 = None
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.controlnet = [None] * len(controlnet_names)
|
||||
self.controlnet_512 = [None] * len(controlnet_names)
|
||||
self.controlnet_id = [str] * len(controlnet_names)
|
||||
self.controlnet_512_id = [str] * len(controlnet_names)
|
||||
self.controlnet_names = controlnet_names
|
||||
self.vae_encode = None
|
||||
|
||||
def load_controlnet(self):
|
||||
if self.controlnet is not None:
|
||||
def load_vae_encode(self):
|
||||
if self.vae_encode is not None:
|
||||
return
|
||||
self.controlnet = self.sd_model.controlnet()
|
||||
|
||||
def unload_controlnet(self):
|
||||
del self.controlnet
|
||||
self.controlnet = None
|
||||
if self.import_mlir or self.use_lora:
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
else:
|
||||
try:
|
||||
self.vae_encode = get_vae_encode()
|
||||
except:
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
self.vae_encode = self.sd_model.vae_encode()
|
||||
|
||||
def load_controlnet_512(self):
|
||||
if self.controlnet_512 is not None:
|
||||
def unload_vae_encode(self):
|
||||
del self.vae_encode
|
||||
self.vae_encode = None
|
||||
|
||||
def load_controlnet(self, index, model_name):
|
||||
if model_name is None:
|
||||
return
|
||||
self.controlnet_512 = self.sd_model.controlnet(use_large=True)
|
||||
if (
|
||||
self.controlnet[index] is not None
|
||||
and self.controlnet_id[index] is not None
|
||||
and self.controlnet_id[index] == model_name
|
||||
):
|
||||
return
|
||||
self.controlnet_id[index] = model_name
|
||||
self.controlnet[index] = self.sd_model.controlnet(model_name)
|
||||
|
||||
def unload_controlnet_512(self):
|
||||
del self.controlnet_512
|
||||
self.controlnet_512 = None
|
||||
def unload_controlnet(self, index):
|
||||
del self.controlnet[index]
|
||||
self.controlnet_id[index] = None
|
||||
self.controlnet[index] = None
|
||||
|
||||
def load_controlnet_512(self, index, model_name):
|
||||
if (
|
||||
self.controlnet_512[index] is not None
|
||||
and self.controlnet_512_id[index] == model_name
|
||||
):
|
||||
return
|
||||
self.controlnet_512_id[index] = model_name
|
||||
self.controlnet_512[index] = self.sd_model.controlnet(
|
||||
model_name, use_large=True
|
||||
)
|
||||
|
||||
def unload_controlnet_512(self, index):
|
||||
del self.controlnet_512[index]
|
||||
self.controlnet_512_id[index] = None
|
||||
self.controlnet_512[index] = None
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
@@ -103,6 +153,58 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def prepare_image_latents(
|
||||
self,
|
||||
image,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
generator,
|
||||
num_inference_steps,
|
||||
strength,
|
||||
dtype,
|
||||
resample_type,
|
||||
):
|
||||
# Pre process image -> get image encoded -> process latents
|
||||
|
||||
# TODO: process with variable HxW combos
|
||||
|
||||
# Pre-process image
|
||||
resample_type = (
|
||||
resamplers[resample_type]
|
||||
if resample_type in resampler_list
|
||||
# Fallback to Lanczos
|
||||
else Image.Resampling.LANCZOS
|
||||
)
|
||||
|
||||
image = image.resize((width, height), resample=resample_type)
|
||||
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)
|
||||
image_arr = image_arr / 255.0
|
||||
image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype)
|
||||
image_arr = 2 * (image_arr - 0.5)
|
||||
|
||||
# set scheduler steps
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
init_timestep = min(
|
||||
int(num_inference_steps * strength), num_inference_steps
|
||||
)
|
||||
t_start = max(num_inference_steps - init_timestep, 0)
|
||||
# timesteps reduced as per strength
|
||||
timesteps = self.scheduler.timesteps[t_start:]
|
||||
# new number of steps to be used as per strength will be
|
||||
# num_inference_steps = num_inference_steps - t_start
|
||||
|
||||
# image encode
|
||||
latents = self.encode_image((image_arr,))
|
||||
latents = torch.from_numpy(latents).to(dtype)
|
||||
# add noise to data
|
||||
noise = torch.randn(latents.shape, generator=generator, dtype=dtype)
|
||||
latents = self.scheduler.add_noise(
|
||||
latents, noise, timesteps[0].repeat(1)
|
||||
)
|
||||
|
||||
return latents, timesteps
|
||||
|
||||
def produce_stencil_latents(
|
||||
self,
|
||||
latents,
|
||||
@@ -111,8 +213,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
total_timesteps,
|
||||
dtype,
|
||||
cpu_scheduling,
|
||||
controlnet_hint=None,
|
||||
stencil_hints=[None],
|
||||
controlnet_conditioning_scale: float = 1.0,
|
||||
control_mode="Balanced", # Prompt, Balanced, or Controlnet
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
@@ -121,12 +224,24 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
latent_history = [latents]
|
||||
text_embeddings = torch.from_numpy(text_embeddings).to(dtype)
|
||||
text_embeddings_numpy = text_embeddings.detach().numpy()
|
||||
assert control_mode in ["Prompt", "Balanced", "Controlnet"]
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_unet()
|
||||
self.load_controlnet()
|
||||
else:
|
||||
self.load_unet_512()
|
||||
self.load_controlnet_512()
|
||||
|
||||
for i, name in enumerate(self.controlnet_names):
|
||||
use_names = []
|
||||
if name is not None:
|
||||
use_names.append(name)
|
||||
else:
|
||||
continue
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
self.load_controlnet(i, name)
|
||||
else:
|
||||
self.load_controlnet_512(i, name)
|
||||
self.controlnet_names = use_names
|
||||
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype)
|
||||
@@ -149,33 +264,93 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
).to(dtype)
|
||||
else:
|
||||
latent_model_input_1 = latent_model_input
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
# Multicontrolnet
|
||||
height = latent_model_input_1.shape[2]
|
||||
width = latent_model_input_1.shape[3]
|
||||
dtype = latent_model_input_1.dtype
|
||||
control_acc = (
|
||||
[torch.zeros((2, 320, height, width), dtype=dtype)] * 3
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 320, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 2), int(width / 2)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 640, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 4), int(width / 4)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 2
|
||||
+ [
|
||||
torch.zeros(
|
||||
(2, 1280, int(height / 8), int(width / 8)), dtype=dtype
|
||||
)
|
||||
]
|
||||
* 4
|
||||
)
|
||||
for i, controlnet_hint in enumerate(stencil_hints):
|
||||
                if controlnet_hint is None:
                    continue
|
||||
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
control = self.controlnet[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
control = self.controlnet_512[i](
|
||||
"forward",
|
||||
(
|
||||
latent_model_input_1,
|
||||
timestep,
|
||||
text_embeddings,
|
||||
controlnet_hint,
|
||||
*control_acc,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
control_acc = control[13:]
|
||||
control = control[:13]
|
||||
|
||||
timestep = timestep.detach().numpy()
|
||||
# Profiling Unet.
|
||||
profile_device = start_profiling(file_path="unet.rdc")
|
||||
# TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py.
|
||||
|
||||
dtype = latents.dtype
|
||||
if control_mode == "Balanced":
|
||||
control_scale = [
|
||||
torch.tensor(1.0, dtype=dtype) for _ in range(len(control))
|
||||
]
|
||||
elif control_mode == "Prompt":
|
||||
control_scale = [
|
||||
torch.tensor(0.825**x, dtype=dtype)
|
||||
for x in range(len(control))
|
||||
]
|
||||
elif control_mode == "Controlnet":
|
||||
control_scale = [
|
||||
torch.tensor(float(guidance_scale), dtype=dtype)
|
||||
for _ in range(len(control))
|
||||
]
|
||||
|
||||
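The three control modes above translate into per-residual weights for the thirteen ControlNet outputs; a condensed, standalone sketch of that mapping:

import torch

def controlnet_scales(control_mode, num_residuals, guidance_scale, dtype=torch.float32):
    # "Balanced": weight every residual equally.
    if control_mode == "Balanced":
        return [torch.tensor(1.0, dtype=dtype) for _ in range(num_residuals)]
    # "Prompt": geometrically decaying weights so the text prompt dominates.
    if control_mode == "Prompt":
        return [torch.tensor(0.825**x, dtype=dtype) for x in range(num_residuals)]
    # "Controlnet": scale every residual by the guidance scale.
    return [torch.tensor(float(guidance_scale), dtype=dtype) for _ in range(num_residuals)]

scales = controlnet_scales("Prompt", 13, 7.5)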
if text_embeddings.shape[1] <= self.model_max_length:
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
@@ -197,11 +372,23 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
control_scale[0],
|
||||
control_scale[1],
|
||||
control_scale[2],
|
||||
control_scale[3],
|
||||
control_scale[4],
|
||||
control_scale[5],
|
||||
control_scale[6],
|
||||
control_scale[7],
|
||||
control_scale[8],
|
||||
control_scale[9],
|
||||
control_scale[10],
|
||||
control_scale[11],
|
||||
control_scale[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
else:
|
||||
print(self.unet_512)
|
||||
noise_pred = self.unet_512(
|
||||
"forward",
|
||||
(
|
||||
@@ -222,6 +409,19 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
control[10],
|
||||
control[11],
|
||||
control[12],
|
||||
control_scale[0],
|
||||
control_scale[1],
|
||||
control_scale[2],
|
||||
control_scale[3],
|
||||
control_scale[4],
|
||||
control_scale[5],
|
||||
control_scale[6],
|
||||
control_scale[7],
|
||||
control_scale[8],
|
||||
control_scale[9],
|
||||
control_scale[10],
|
||||
control_scale[11],
|
||||
control_scale[12],
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
@@ -245,8 +445,9 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
self.unload_controlnet()
|
||||
self.unload_controlnet_512()
|
||||
for i in range(len(self.controlnet_names)):
|
||||
self.unload_controlnet(i)
|
||||
self.unload_controlnet_512(i)
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -255,6 +456,17 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def encode_image(self, input_image):
|
||||
self.load_vae_encode()
|
||||
vae_encode_start = time.time()
|
||||
latents = self.vae_encode("forward", input_image)
|
||||
vae_inf_time = (time.time() - vae_encode_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_vae_encode()
|
||||
self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}"
|
||||
|
||||
return latents
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
@@ -272,14 +484,48 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
use_stencil,
|
||||
stencils,
|
||||
stencil_images,
|
||||
resample_type,
|
||||
control_mode,
|
||||
preprocessed_hints,
|
||||
):
|
||||
# Control Embedding check & conversion
|
||||
# TODO: 1. Change `num_images_per_prompt`.
|
||||
controlnet_hint = controlnet_hint_conversion(
|
||||
image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
)
|
||||
# controlnet_hint = controlnet_hint_conversion(
|
||||
# image, use_stencil, height, width, dtype, num_images_per_prompt=1
|
||||
# )
|
||||
stencil_hints = []
|
||||
self.sd_model.stencils = stencils
|
||||
for i, hint in enumerate(preprocessed_hints):
|
||||
if hint is not None:
|
||||
hint = controlnet_hint_reshaping(
|
||||
hint,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
stencil_hints.append(hint)
|
||||
|
||||
for i, stencil in enumerate(stencils):
|
||||
if stencil == None:
|
||||
continue
|
||||
if len(stencil_hints) > i:
|
||||
if stencil_hints[i] is not None:
|
||||
print(f"Using preprocessed controlnet hint for {stencil}")
|
||||
continue
|
||||
image = stencil_images[i]
|
||||
stencil_hints.append(
|
||||
controlnet_hint_conversion(
|
||||
image,
|
||||
stencil,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt=1,
|
||||
)
|
||||
)
|
||||
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
@@ -307,17 +553,30 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
|
||||
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(torch.float32)
|
||||
|
||||
# Prepare initial latent.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
final_timesteps = self.scheduler.timesteps
|
||||
if image is not None:
|
||||
# Prepare input image latent
|
||||
init_latents, final_timesteps = self.prepare_image_latents(
|
||||
image=image,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
strength=strength,
|
||||
dtype=dtype,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
else:
|
||||
# Prepare initial latent.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
final_timesteps = self.scheduler.timesteps
|
||||
|
||||
# Get Image latents
|
||||
latents = self.produce_stencil_latents(
|
||||
@@ -327,7 +586,8 @@ class StencilPipeline(StableDiffusionPipeline):
|
||||
total_timesteps=final_timesteps,
|
||||
dtype=dtype,
|
||||
cpu_scheduling=cpu_scheduling,
|
||||
controlnet_hint=controlnet_hint,
|
||||
control_mode=control_mode,
|
||||
stencil_hints=stencil_hints,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images
|
||||
|
||||
@@ -18,7 +18,10 @@ from diffusers import (
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
@@ -46,9 +49,19 @@ class Text2ImagePipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def favored_base_models(cls, model_id):
|
||||
return [
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"CompVis/stable-diffusion-v1-4",
|
||||
]
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
|
||||
@@ -0,0 +1,236 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
from random import randint
|
||||
from typing import Union
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Text2ImageSDXLPipeline(StableDiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
],
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
is_fp32_vae: bool,
|
||||
):
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.is_fp32_vae = is_fp32_vae
|
||||
|
||||
@classmethod
|
||||
def favored_base_models(cls, model_id):
|
||||
if "turbo" in model_id:
|
||||
return [
|
||||
"stabilityai/sdxl-turbo",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
]
|
||||
else:
|
||||
return [
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
]
|
||||
|
||||
def prepare_latents(
self,
batch_size,
height,
width,
generator,
num_inference_steps,
dtype,
):
latents = torch.randn(
(
batch_size,
4,
height // 8,
width // 8,
),
generator=generator,
dtype=torch.float32,
).to(dtype)

self.scheduler.set_timesteps(num_inference_steps)
self.scheduler.is_scale_input_called = True
latents = latents * self.scheduler.init_noise_sigma
return latents
|
||||
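prepare_latents draws CPU-side Gaussian noise at latent resolution and pre-scales it with the scheduler's init_noise_sigma; a minimal sketch of the resulting shapes, assuming a 1024x1024 SDXL render with batch_size 1 (illustrative values, not from this changeset):

import torch

batch_size, height, width = 1, 1024, 1024
generator = torch.manual_seed(0)
latents = torch.randn(
    (batch_size, 4, height // 8, width // 8),
    generator=generator,
    dtype=torch.float32,
)
# latents.shape == (1, 4, 128, 128); the pipeline then multiplies this by
# scheduler.init_noise_sigma before the first denoising step.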
def _get_add_time_ids(
self, original_size, crops_coords_top_left, target_size, dtype
):
add_time_ids = list(
original_size + crops_coords_top_left + target_size
)

# self.unet.config.addition_time_embed_dim IS 256.
# self.text_encoder_2.config.projection_dim IS 1280.
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
expected_add_embed_dim = 2816
# self.unet.add_embedding.linear_1.in_features IS 2816.

if expected_add_embed_dim != passed_add_embed_dim:
raise ValueError(
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
)

add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
return add_time_ids
|
||||
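The added time ids are just six integers (original height/width, crop offsets, target height/width), so the check above reduces to 256 * 6 + 1280 == 2816; a small worked example, assuming the default SDXL values quoted in the comments:

import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

add_time_ids = list(original_size + crops_coords_top_left + target_size)  # 6 integers
passed_add_embed_dim = 256 * len(add_time_ids) + 1280
assert passed_add_embed_dim == 2816  # matches unet.add_embedding.linear_1.in_features
add_time_ids = torch.tensor([add_time_ids], dtype=torch.float16)  # shape (1, 6)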
def generate_images(
|
||||
self,
|
||||
prompts,
|
||||
neg_prompts,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
max_length,
|
||||
dtype,
|
||||
use_base_vae,
|
||||
cpu_scheduling,
|
||||
max_embeddings_multiples,
|
||||
):
|
||||
# prompts and negative prompts must be a list.
|
||||
if isinstance(prompts, str):
|
||||
prompts = [prompts]
|
||||
|
||||
if isinstance(neg_prompts, str):
|
||||
neg_prompts = [neg_prompts]
|
||||
|
||||
prompts = prompts * batch_size
|
||||
neg_prompts = neg_prompts * batch_size
|
||||
|
||||
# seed generator to create the initial latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
generator = torch.manual_seed(seed)
|
||||
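The block above silently replaces any seed outside the uint32 range with a random in-range one; a hedged sketch of the same behaviour as a standalone helper (normalize_seed is a hypothetical name, not part of this changeset):

import numpy as np
from random import randint

def normalize_seed(seed: int) -> int:
    # mirror the clamping above: out-of-range seeds fall back to a random uint32
    info = np.iinfo(np.uint32)
    if seed < info.min or seed >= info.max:
        seed = randint(info.min, info.max)
    return seed

# normalize_seed(42) -> 42, while normalize_seed(-1) or normalize_seed(2**32)
# return a freshly drawn in-range seed.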
# Get initial latents.
|
||||
init_latents = self.prepare_latents(
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
generator=generator,
|
||||
num_inference_steps=num_inference_steps,
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
# Get text embeddings.
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.encode_prompt_sdxl(
|
||||
prompt=prompts,
|
||||
num_images_per_prompt=1,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=neg_prompts,
|
||||
)
|
||||
|
||||
# Prepare timesteps.
|
||||
self.scheduler.set_timesteps(num_inference_steps)
|
||||
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# Prepare added time ids & embeddings.
|
||||
original_size = (height, width)
|
||||
target_size = (height, width)
|
||||
crops_coords_top_left = (0, 0)
|
||||
add_text_embeds = pooled_prompt_embeds
|
||||
add_time_ids = self._get_add_time_ids(
|
||||
original_size,
|
||||
crops_coords_top_left,
|
||||
target_size,
|
||||
dtype=prompt_embeds.dtype,
|
||||
)
|
||||
|
||||
prompt_embeds = torch.cat(
[negative_prompt_embeds, prompt_embeds], dim=0
)
add_text_embeds = torch.cat(
[negative_pooled_prompt_embeds, add_text_embeds], dim=0
)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

prompt_embeds = prompt_embeds
add_text_embeds = add_text_embeds.to(dtype)
add_time_ids = add_time_ids.repeat(batch_size * 1, 1)
|
||||
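Classifier-free guidance stacks the negative and positive embeddings along the batch axis, which is why the sdxl unet inputs declared later in this changeset carry a leading 2*batch_size; a rough shape sketch for batch_size 1 and a 77-token prompt (assumed values):

import torch

batch_size, max_len = 1, 77
prompt_embeds = torch.randn(batch_size, max_len, 2048)
negative_prompt_embeds = torch.randn(batch_size, max_len, 2048)
add_text_embeds = torch.randn(batch_size, 1280)
negative_pooled_prompt_embeds = torch.randn(batch_size, 1280)
add_time_ids = torch.randn(batch_size, 6)

prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)          # (2, 77, 2048)
add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)  # (2, 1280)
add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)                         # (2, 6)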
# guidance scale as a float32 tensor.
|
||||
guidance_scale = torch.tensor(guidance_scale).to(dtype)
|
||||
prompt_embeds = prompt_embeds.to(dtype)
|
||||
add_time_ids = add_time_ids.to(dtype)
|
||||
|
||||
# Get Image latents.
|
||||
latents = self.produce_img_latents_sdxl(
|
||||
init_latents,
|
||||
timesteps,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
prompt_embeds,
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
)
|
||||
|
||||
# Img latents -> PIL images.
|
||||
all_imgs = []
|
||||
self.load_vae()
|
||||
for i in range(0, latents.shape[0], batch_size):
|
||||
imgs = self.decode_latents_sdxl(
|
||||
latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae
|
||||
)
|
||||
all_imgs.extend(imgs)
|
||||
if self.ondemand:
|
||||
self.unload_vae()
|
||||
|
||||
return all_imgs
|
||||
@@ -94,9 +94,12 @@ class UpscalerPipeline(StableDiffusionPipeline):
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
):
|
||||
super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
super().__init__(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
self.low_res_scheduler = low_res_scheduler
|
||||
self.status = SD_STATE_IDLE
|
||||
|
||||
|
||||
@@ -20,7 +20,10 @@ from diffusers import (
|
||||
HeunDiscreteScheduler,
|
||||
)
|
||||
from shark.shark_inference import SharkInference
|
||||
from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler
|
||||
from apps.stable_diffusion.src.schedulers import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.models import (
|
||||
SharkifyStableDiffusionModel,
|
||||
get_vae,
|
||||
@@ -33,6 +36,8 @@ from apps.stable_diffusion.src.utils import (
|
||||
end_profiling,
|
||||
)
|
||||
import sys
|
||||
import gc
|
||||
from typing import List, Optional
|
||||
|
||||
SD_STATE_IDLE = "idle"
|
||||
SD_STATE_CANCEL = "cancel"
|
||||
@@ -50,6 +55,7 @@ class StableDiffusionPipeline:
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
SharkEulerDiscreteScheduler,
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
DEISMultistepScheduler,
|
||||
DDPMScheduler,
|
||||
DPMSolverSinglestepScheduler,
|
||||
@@ -59,21 +65,26 @@ class StableDiffusionPipeline:
|
||||
sd_model: SharkifyStableDiffusionModel,
|
||||
import_mlir: bool,
|
||||
use_lora: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
is_f32_vae: bool = False,
|
||||
):
|
||||
self.vae = None
|
||||
self.text_encoder = None
|
||||
self.text_encoder_2 = None
|
||||
self.unet = None
|
||||
self.unet_512 = None
|
||||
self.model_max_length = 77
|
||||
self.scheduler = scheduler
|
||||
# TODO: Implement using logging python utility.
|
||||
self.log = ""
|
||||
self.status = SD_STATE_IDLE
|
||||
self.sd_model = sd_model
|
||||
self.scheduler = scheduler
|
||||
self.import_mlir = import_mlir
|
||||
self.use_lora = use_lora
|
||||
self.lora_strength = lora_strength
|
||||
self.ondemand = ondemand
|
||||
self.is_f32_vae = is_f32_vae
|
||||
# TODO: Find a better workaround for fetching base_model_id early
|
||||
# enough for CLIPTokenizer.
|
||||
try:
|
||||
@@ -83,6 +94,10 @@ class StableDiffusionPipeline:
|
||||
self.unload_unet()
|
||||
self.tokenizer = get_tokenizer()
|
||||
|
||||
def favored_base_models(cls, model_id):
|
||||
# all base models can be candidate base models for unet compilation
|
||||
return None
|
||||
|
||||
def load_clip(self):
|
||||
if self.text_encoder is not None:
|
||||
return
|
||||
@@ -106,6 +121,34 @@ class StableDiffusionPipeline:
|
||||
del self.text_encoder
|
||||
self.text_encoder = None
|
||||
|
||||
def load_clip_sdxl(self):
|
||||
if self.text_encoder and self.text_encoder_2:
|
||||
return
|
||||
|
||||
if self.import_mlir or self.use_lora:
|
||||
if not self.import_mlir:
|
||||
print(
|
||||
"Warning: LoRA provided but import_mlir not specified. "
|
||||
"Importing MLIR anyways."
|
||||
)
|
||||
self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip()
|
||||
else:
|
||||
try:
|
||||
# TODO: Fix this for SDXL
|
||||
self.text_encoder = get_clip()
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("download pipeline failed, falling back to import_mlir")
|
||||
(
|
||||
self.text_encoder,
|
||||
self.text_encoder_2,
|
||||
) = self.sd_model.sdxl_clip()
|
||||
|
||||
def unload_clip_sdxl(self):
|
||||
del self.text_encoder, self.text_encoder_2
|
||||
self.text_encoder = None
|
||||
self.text_encoder_2 = None
|
||||
|
||||
def load_unet(self):
|
||||
if self.unet is not None:
|
||||
return
|
||||
@@ -159,6 +202,182 @@ class StableDiffusionPipeline:
|
||||
def unload_vae(self):
|
||||
del self.vae
|
||||
self.vae = None
|
||||
gc.collect()
|
||||
|
||||
def encode_prompt_sdxl(
|
||||
self,
|
||||
prompt: str,
|
||||
num_images_per_prompt: int = 1,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[str] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
hf_model_id: Optional[
|
||||
str
|
||||
] = "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
):
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Define tokenizers and text encoders
|
||||
self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id)
|
||||
self.load_clip_sdxl()
|
||||
tokenizers = (
|
||||
[self.tokenizer, self.tokenizer_2]
|
||||
if self.tokenizer is not None
|
||||
else [self.tokenizer_2]
|
||||
)
|
||||
text_encoders = (
|
||||
[self.text_encoder, self.text_encoder_2]
|
||||
if self.text_encoder is not None
|
||||
else [self.text_encoder_2]
|
||||
)
|
||||
|
||||
# textual inversion: process multi-vector tokens if necessary
|
||||
prompt_embeds_list = []
|
||||
prompts = [prompt, prompt]
|
||||
for prompt, tokenizer, text_encoder in zip(
|
||||
prompts, tokenizers, text_encoders
|
||||
):
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids
|
||||
untruncated_ids = tokenizer(
|
||||
prompt, padding="longest", return_tensors="pt"
|
||||
).input_ids
|
||||
|
||||
if untruncated_ids.shape[-1] >= text_input_ids.shape[
|
||||
-1
|
||||
] and not torch.equal(text_input_ids, untruncated_ids):
|
||||
removed_text = tokenizer.batch_decode(
|
||||
untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
|
||||
)
|
||||
print(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_encoder_output = text_encoder("forward", (text_input_ids,))
|
||||
prompt_embeds = torch.from_numpy(text_encoder_output[0])
|
||||
pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1])
|
||||
|
||||
prompt_embeds_list.append(prompt_embeds)
|
||||
|
||||
prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
zero_out_negative_prompt = (
|
||||
negative_prompt is None
|
||||
and self.config.force_zeros_for_empty_prompt
|
||||
)
|
||||
if (
|
||||
do_classifier_free_guidance
|
||||
and negative_prompt_embeds is None
|
||||
and zero_out_negative_prompt
|
||||
):
|
||||
negative_prompt_embeds = torch.zeros_like(prompt_embeds)
|
||||
negative_pooled_prompt_embeds = torch.zeros_like(
|
||||
pooled_prompt_embeds
|
||||
)
|
||||
elif do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt_2 = negative_prompt
|
||||
|
||||
uncond_tokens: List[str]
|
||||
if prompt is not None and type(prompt) is not type(
|
||||
negative_prompt
|
||||
):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif isinstance(negative_prompt, str):
|
||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
else:
|
||||
uncond_tokens = [negative_prompt, negative_prompt_2]
|
||||
|
||||
negative_prompt_embeds_list = []
|
||||
for negative_prompt, tokenizer, text_encoder in zip(
|
||||
uncond_tokens, tokenizers, text_encoders
|
||||
):
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = tokenizer(
|
||||
negative_prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_encoder_output = text_encoder(
|
||||
"forward", (uncond_input.input_ids,)
|
||||
)
|
||||
negative_prompt_embeds = torch.from_numpy(
|
||||
text_encoder_output[0]
|
||||
)
|
||||
negative_pooled_prompt_embeds = torch.from_numpy(
|
||||
text_encoder_output[1]
|
||||
)
|
||||
|
||||
negative_prompt_embeds_list.append(negative_prompt_embeds)
|
||||
|
||||
negative_prompt_embeds = torch.concat(
|
||||
negative_prompt_embeds_list, dim=-1
|
||||
)
|
||||
|
||||
if self.ondemand:
|
||||
self.unload_clip_sdxl()
|
||||
gc.collect()
|
||||
|
||||
# TODO: Look into dtype for text_encoder_2!
|
||||
prompt_embeds = prompt_embeds.to(dtype=torch.float16)
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(
|
||||
bs_embed * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt, 1
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt
|
||||
).view(bs_embed * num_images_per_prompt, -1)
|
||||
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(
|
||||
1, num_images_per_prompt
|
||||
).view(bs_embed * num_images_per_prompt, -1)
|
||||
|
||||
return (
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
|
||||
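encode_prompt_sdxl runs both tokenizer/encoder pairs and concatenates their hidden states along the feature axis, which is where the 2048-wide prompt embedding comes from; a minimal sketch, assuming the usual SDXL encoder widths of 768 and 1280 (not stated in this diff):

import torch

seq_len = 77
hidden_1 = torch.randn(1, seq_len, 768)   # from text_encoder (CLIP ViT-L)
hidden_2 = torch.randn(1, seq_len, 1280)  # from text_encoder_2 (OpenCLIP bigG)
prompt_embeds = torch.cat([hidden_1, hidden_2], dim=-1)  # (1, 77, 2048)
pooled_prompt_embeds = torch.randn(1, 1280)  # pooled output of the second encoder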
def encode_prompts(self, prompts, neg_prompts, max_length):
|
||||
# Tokenize text and get embeddings
|
||||
@@ -186,6 +405,7 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings
|
||||
@@ -298,6 +518,8 @@ class StableDiffusionPipeline:
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
self.unload_unet_512()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
@@ -306,6 +528,96 @@ class StableDiffusionPipeline:
|
||||
all_latents = torch.cat(latent_history, dim=0)
|
||||
return all_latents
|
||||
|
||||
def produce_img_latents_sdxl(
|
||||
self,
|
||||
latents,
|
||||
total_timesteps,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
prompt_embeds,
|
||||
cpu_scheduling,
|
||||
guidance_scale,
|
||||
dtype,
|
||||
mask=None,
|
||||
masked_image_latents=None,
|
||||
return_all_latents=False,
|
||||
):
|
||||
# return None
|
||||
self.status = SD_STATE_IDLE
|
||||
step_time_sum = 0
|
||||
extra_step_kwargs = {"generator": None}
|
||||
self.load_unet()
|
||||
for i, t in tqdm(enumerate(total_timesteps)):
|
||||
step_start_time = time.time()
|
||||
timestep = torch.tensor([t]).to(dtype).detach().numpy()
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
if isinstance(latents, np.ndarray):
|
||||
latents = torch.tensor(latents)
|
||||
latent_model_input = torch.cat([latents] * 2)
|
||||
|
||||
latent_model_input = self.scheduler.scale_model_input(
|
||||
latent_model_input, t
|
||||
)
|
||||
if mask is not None and masked_image_latents is not None:
|
||||
latent_model_input = torch.cat(
|
||||
[
|
||||
torch.from_numpy(np.asarray(latent_model_input)),
|
||||
mask,
|
||||
masked_image_latents,
|
||||
],
|
||||
dim=1,
|
||||
).to(dtype)
|
||||
|
||||
noise_pred = self.unet(
|
||||
"forward",
|
||||
(
|
||||
latent_model_input,
|
||||
timestep,
|
||||
prompt_embeds,
|
||||
add_text_embeds,
|
||||
add_time_ids,
|
||||
guidance_scale,
|
||||
),
|
||||
send_to_host=True,
|
||||
)
|
||||
if not isinstance(latents, torch.Tensor):
|
||||
latents = torch.from_numpy(latents).to("cpu")
|
||||
noise_pred = torch.from_numpy(noise_pred).to("cpu")
|
||||
|
||||
latents = self.scheduler.step(
|
||||
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
||||
)[0]
|
||||
latents = latents.detach().numpy()
|
||||
noise_pred = noise_pred.detach().numpy()
|
||||
|
||||
step_time = (time.time() - step_start_time) * 1000
|
||||
step_time_sum += step_time
|
||||
|
||||
if self.status == SD_STATE_CANCEL:
|
||||
break
|
||||
if self.ondemand:
|
||||
self.unload_unet()
|
||||
gc.collect()
|
||||
|
||||
avg_step_time = step_time_sum / len(total_timesteps)
|
||||
self.log += f"\nAverage step time: {avg_step_time}ms/it"
|
||||
|
||||
return latents
|
||||
|
||||
def decode_latents_sdxl(self, latents, is_fp32_vae):
# latents are in unet dtype here so switch if we want to use fp32
if is_fp32_vae:
print("Casting latents to float32 for VAE")
latents = latents.to(torch.float32)
images = self.vae("forward", (latents,))
images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()

images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

return pil_images
|
||||
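decode_latents_sdxl is the standard VAE postprocessing: map the [-1, 1] output to [0, 1], move channels last, and quantise to uint8; a self-contained sketch with a dummy array standing in for the VAE output:

import numpy as np
import torch
from PIL import Image

decoded = np.random.uniform(-1, 1, size=(1, 3, 512, 512)).astype(np.float32)  # stand-in VAE output
images = (torch.from_numpy(decoded) / 2 + 0.5).clamp(0, 1)
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
images = (images * 255).round().astype("uint8")
pil_images = [Image.fromarray(img[:, :, :3]) for img in images]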
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -338,8 +650,10 @@ class StableDiffusionPipeline:
|
||||
ondemand: bool,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
debug: bool = False,
|
||||
use_stencil: str = None,
|
||||
stencils: list[str] = [],
|
||||
# stencil_images: list[Image] = []
|
||||
use_lora: str = "",
|
||||
lora_strength: float = 0.75,
|
||||
ddpm_scheduler: DDPMScheduler = None,
|
||||
use_quantize=None,
|
||||
):
|
||||
@@ -355,7 +669,11 @@ class StableDiffusionPipeline:
|
||||
"OutpaintPipeline",
|
||||
]
|
||||
is_upscaler = cls.__name__ in ["UpscalerPipeline"]
|
||||
is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"]
|
||||
|
||||
print(f"model_id", model_id)
|
||||
print(f"ckpt_loc", ckpt_loc)
|
||||
print(f"favored_base_models:", cls.favored_base_models(model_id))
|
||||
sd_model = SharkifyStableDiffusionModel(
|
||||
model_id,
|
||||
ckpt_loc,
|
||||
@@ -371,9 +689,14 @@ class StableDiffusionPipeline:
|
||||
debug=debug,
|
||||
is_inpaint=is_inpaint,
|
||||
is_upscaler=is_upscaler,
|
||||
use_stencil=use_stencil,
|
||||
is_sdxl=is_sdxl,
|
||||
stencils=stencils,
|
||||
use_lora=use_lora,
|
||||
lora_strength=lora_strength,
|
||||
use_quantize=use_quantize,
|
||||
favored_base_models=cls.favored_base_models(
|
||||
model_id if model_id != "" else ckpt_loc
|
||||
),
|
||||
)
|
||||
|
||||
if cls.__name__ in ["UpscalerPipeline"]:
|
||||
@@ -383,10 +706,35 @@ class StableDiffusionPipeline:
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
)
|
||||
|
||||
return cls(scheduler, sd_model, import_mlir, use_lora, ondemand)
|
||||
if cls.__name__ == "StencilPipeline":
|
||||
return cls(
|
||||
scheduler,
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
stencils,
|
||||
)
|
||||
if cls.__name__ == "Text2ImageSDXLPipeline":
|
||||
is_fp32_vae = True if "16" not in custom_vae else False
|
||||
return cls(
|
||||
scheduler,
|
||||
sd_model,
|
||||
import_mlir,
|
||||
use_lora,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
is_fp32_vae,
|
||||
)
|
||||
|
||||
return cls(
|
||||
scheduler, sd_model, import_mlir, use_lora, lora_strength, ondemand
|
||||
)
|
||||
|
||||
# #####################################################
|
||||
# Implements text embeddings with weights from prompts
|
||||
@@ -498,9 +846,10 @@ class StableDiffusionPipeline:
|
||||
clip_inf_time = (time.time() - clip_inf_start) * 1000
|
||||
if self.ondemand:
|
||||
self.unload_clip()
|
||||
gc.collect()
|
||||
self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}"
|
||||
|
||||
return text_embeddings.numpy()
|
||||
return text_embeddings.numpy().astype(np.float16)
|
||||
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from diffusers import (
|
||||
LCMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDPMScheduler,
|
||||
@@ -15,9 +16,21 @@ from diffusers import (
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import (
|
||||
SharkEulerDiscreteScheduler,
|
||||
)
|
||||
from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import (
|
||||
SharkEulerAncestralDiscreteScheduler,
|
||||
)
|
||||
|
||||
|
||||
def get_schedulers(model_id):
# TODO: Robust scheduler setup on pipeline creation -- if we don't
# set batch_size here, the SHARK schedulers will
# compile with batch size = 1 regardless of whether the model
# outputs latents of a larger batch size, e.g. SDXL.
# However, obviously, searching for whether the base model ID
# contains "xl" is not very robust.

batch_size = 2 if "xl" in model_id.lower() else 1

schedulers = dict()
|
||||
schedulers["PNDM"] = PNDMScheduler.from_pretrained(
|
||||
model_id,
|
||||
@@ -39,6 +52,10 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers["LCMScheduler"] = LCMScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverMultistep"
|
||||
] = DPMSolverMultistepScheduler.from_pretrained(
|
||||
@@ -84,6 +101,12 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"SharkEulerAncestralDiscrete"
|
||||
] = SharkEulerAncestralDiscreteScheduler.from_pretrained(
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
|
||||
schedulers[
|
||||
"DPMSolverSinglestep"
|
||||
] = DPMSolverSinglestepScheduler.from_pretrained(
|
||||
@@ -100,5 +123,6 @@ def get_schedulers(model_id):
|
||||
model_id,
|
||||
subfolder="scheduler",
|
||||
)
schedulers["SharkEulerDiscrete"].compile()
schedulers["SharkEulerDiscrete"].compile(batch_size)
schedulers["SharkEulerAncestralDiscrete"].compile(batch_size)
return schedulers
|
||||
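As the TODO above admits, the "xl" substring check is only a stop-gap; all it decides is the batch size the two SHARK scheduler modules are compiled for. A small sketch of that selection (scheduler_compile_batch_size is a hypothetical helper name):

def scheduler_compile_batch_size(model_id: str) -> int:
    # SDXL pipelines hand the compiled scheduler larger latent batches, so compile for 2
    return 2 if "xl" in model_id.lower() else 1

assert scheduler_compile_batch_size("stabilityai/stable-diffusion-xl-base-1.0") == 2
assert scheduler_compile_batch_size("stabilityai/sdxl-turbo") == 2
assert scheduler_compile_batch_size("CompVis/stable-diffusion-v1-4") == 1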
@@ -0,0 +1,251 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
EulerAncestralDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
get_shark_model,
|
||||
args,
|
||||
)
|
||||
import torch
|
||||
|
||||
|
||||
class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
beta_start: float = 0.0001,
|
||||
beta_end: float = 0.02,
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
timestep_spacing: str = "linspace",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
beta_start,
|
||||
beta_end,
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
timestep_spacing,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we dont have to worry about batch size
|
||||
self.batch_size = None
|
||||
self.init_input_shape = None
|
||||
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"eulera": {
|
||||
"output": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"latent": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"sigma_from": torch.tensor(1).to(torch.float32),
|
||||
"sigma_to": torch.tensor(1).to(torch.float32),
|
||||
"noise": torch.randn(
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
example_latent = model_input["eulera"]["latent"]
|
||||
example_output = model_input["eulera"]["output"]
|
||||
example_noise = model_input["eulera"]["noise"]
|
||||
if args.precision == "fp16":
|
||||
example_latent = example_latent.half()
|
||||
example_output = example_output.half()
|
||||
example_noise = example_noise.half()
|
||||
example_sigma = model_input["eulera"]["sigma"]
|
||||
example_sigma_from = model_input["eulera"]["sigma_from"]
|
||||
example_sigma_to = model_input["eulera"]["sigma_to"]
|
||||
|
||||
class ScalingModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(
self, noise_pred, latent, sigma, sigma_from, sigma_to, noise
):
sigma_up = (
sigma_to**2
* (sigma_from**2 - sigma_to**2)
/ sigma_from**2
) ** 0.5
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
dt = sigma_down - sigma
pred_original_sample = latent - sigma * noise_pred
derivative = (latent - pred_original_sample) / sigma
prev_sample = latent + derivative * dt
return prev_sample + noise * sigma_up
|
||||
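For epsilon prediction the derivative above simplifies, since pred_original_sample = latent - sigma * noise_pred makes (latent - pred_original_sample) / sigma equal to noise_pred; an equivalent plain-torch sketch of the compiled step (illustrative only, not part of the module):

import torch

def euler_ancestral_epsilon_step(noise_pred, latent, sigma, sigma_from, sigma_to, noise):
    # ancestral split between the deterministic "down" step and re-injected noise
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    # derivative == noise_pred for epsilon prediction, so the update collapses to:
    return latent + noise_pred * (sigma_down - sigma) + noise * sigma_up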
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(
|
||||
self, noise_pred, sigma, sigma_from, sigma_to, latent, noise
|
||||
):
|
||||
sigma_up = (
|
||||
sigma_to**2
|
||||
* (sigma_from**2 - sigma_to**2)
|
||||
/ sigma_from**2
|
||||
) ** 0.5
|
||||
sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
|
||||
dt = sigma_down - sigma
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
prev_sample = latent + derivative * dt
|
||||
return prev_sample + noise * sigma_up
|
||||
|
||||
iree_flags = []
|
||||
if len(args.iree_vulkan_target_triple) > 0:
|
||||
iree_flags.append(
|
||||
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
|
||||
)
|
||||
|
||||
def _import(self):
|
||||
scaling_model = ScalingModel()
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(
|
||||
example_output,
|
||||
example_latent,
|
||||
example_sigma,
|
||||
example_sigma_from,
|
||||
example_sigma_to,
|
||||
example_noise,
|
||||
),
|
||||
extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
if args.import_mlir:
|
||||
_import(self)
|
||||
|
||||
else:
|
||||
try:
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_scale_model_input_" + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_a_step_"
|
||||
+ self.config.prediction_type
|
||||
+ args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
print(
|
||||
"failed to download model, falling back and using import_mlir"
|
||||
)
|
||||
args.import_mlir = True
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
sigma = self.sigmas[self.step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
sample,
|
||||
sigma,
|
||||
),
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
step_inputs = []
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
sigma_from = self.sigmas[self.step_index]
|
||||
sigma_to = self.sigmas[self.step_index + 1]
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
step_inputs = [
|
||||
noise_pred,
|
||||
latent,
|
||||
sigma,
|
||||
sigma_from,
|
||||
sigma_to,
|
||||
noise,
|
||||
]
|
||||
# TODO: deal with dynamic inputs in turbine flow.
|
||||
# update step index since we're done with the variable and will return with compiled module output.
|
||||
self._step_index += 1
|
||||
|
||||
if noise_pred.shape[0] < self.batch_size:
|
||||
for i in [0, 1, 5]:
|
||||
try:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i])
|
||||
except:
|
||||
step_inputs[i] = torch.tensor(step_inputs[i].to_host())
|
||||
step_inputs[i] = torch.cat(
|
||||
(step_inputs[i], step_inputs[i]), axis=0
|
||||
)
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=True,
|
||||
)
|
||||
|
||||
return self.step_model(
|
||||
"forward",
|
||||
tuple(step_inputs),
|
||||
send_to_host=False,
|
||||
)
|
||||
@@ -2,12 +2,9 @@ import sys
|
||||
import numpy as np
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
from diffusers.configuration_utils import register_to_config
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
compile_through_fx,
|
||||
@@ -27,6 +24,13 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
beta_schedule: str = "linear",
|
||||
trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
|
||||
prediction_type: str = "epsilon",
|
||||
interpolation_type: str = "linear",
|
||||
use_karras_sigmas: bool = False,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
timestep_spacing: str = "linspace",
|
||||
timestep_type: str = "discrete",
|
||||
steps_offset: int = 0,
|
||||
):
|
||||
super().__init__(
|
||||
num_train_timesteps,
|
||||
@@ -35,20 +39,29 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
beta_schedule,
|
||||
trained_betas,
|
||||
prediction_type,
|
||||
interpolation_type,
|
||||
use_karras_sigmas,
|
||||
sigma_min,
|
||||
sigma_max,
|
||||
timestep_spacing,
|
||||
timestep_type,
|
||||
steps_offset,
|
||||
)
|
||||
# TODO: make it dynamic so we dont have to worry about batch size
|
||||
self.batch_size = 1
|
||||
|
||||
def compile(self):
|
||||
def compile(self, batch_size=1):
|
||||
SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers"
|
||||
BATCH_SIZE = args.batch_size
|
||||
device = args.device.split(":", 1)[0].strip()
|
||||
self.batch_size = batch_size
|
||||
|
||||
model_input = {
|
||||
"euler": {
|
||||
"latent": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"output": torch.randn(
|
||||
BATCH_SIZE, 4, args.height // 8, args.width // 8
|
||||
batch_size, 4, args.height // 8, args.width // 8
|
||||
),
|
||||
"sigma": torch.tensor(1).to(torch.float32),
|
||||
"dt": torch.tensor(1).to(torch.float32),
|
||||
@@ -70,12 +83,32 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
def forward(self, latent, sigma):
|
||||
return latent / ((sigma**2 + 1) ** 0.5)
|
||||
|
||||
class SchedulerStepModel(torch.nn.Module):
|
||||
class SchedulerStepEpsilonModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = latent - sigma_hat * noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepSampleModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma_hat, latent, dt):
|
||||
pred_original_sample = noise_pred
|
||||
derivative = (latent - pred_original_sample) / sigma_hat
|
||||
return latent + derivative * dt
|
||||
|
||||
class SchedulerStepVPredictionModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, noise_pred, sigma, latent, dt):
|
||||
pred_original_sample = latent - sigma * noise_pred
|
||||
pred_original_sample = noise_pred * (
|
||||
-sigma / (sigma**2 + 1) ** 0.5
|
||||
) + (latent / (sigma**2 + 1))
|
||||
derivative = (latent - pred_original_sample) / sigma
|
||||
return latent + derivative * dt
|
||||
|
||||
@@ -90,16 +123,22 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
self.scaling_model, _ = compile_through_fx(
|
||||
model=scaling_model,
|
||||
inputs=(example_latent, example_sigma),
|
||||
extended_model_name=f"euler_scale_model_input_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
|
||||
step_model = SchedulerStepModel()
|
||||
pred_type_model_dict = {
|
||||
"epsilon": SchedulerStepEpsilonModel(),
|
||||
"v_prediction": SchedulerStepVPredictionModel(),
|
||||
"sample": SchedulerStepSampleModel(),
|
||||
"original_sample": SchedulerStepSampleModel(),
|
||||
}
|
||||
step_model = pred_type_model_dict[self.config.prediction_type]
|
||||
self.step_model, _ = compile_through_fx(
|
||||
step_model,
|
||||
(example_output, example_sigma, example_latent, example_dt),
|
||||
extended_model_name=f"euler_step_{BATCH_SIZE}_{args.height}_{args.width}_{device}_"
|
||||
extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_"
|
||||
+ args.precision,
|
||||
extra_args=iree_flags,
|
||||
)
|
||||
@@ -109,6 +148,11 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
|
||||
else:
|
||||
try:
|
||||
step_model_type = (
|
||||
"sample"
|
||||
if "sample" in self.config.prediction_type
|
||||
else self.config.prediction_type
|
||||
)
|
||||
self.scaling_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_scale_model_input_" + args.precision,
|
||||
@@ -116,7 +160,7 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
)
|
||||
self.step_model = get_shark_model(
|
||||
SCHEDULER_BUCKET,
|
||||
"euler_step_" + args.precision,
|
||||
"euler_step_" + step_model_type + args.precision,
|
||||
iree_flags,
|
||||
)
|
||||
except:
|
||||
@@ -127,8 +171,9 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
_import(self)
|
||||
|
||||
def scale_model_input(self, sample, timestep):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
sigma = self.sigmas[self.step_index]
|
||||
return self.scaling_model(
|
||||
"forward",
|
||||
(
|
||||
@@ -138,15 +183,61 @@ class SharkEulerDiscreteScheduler(EulerDiscreteScheduler):
|
||||
send_to_host=False,
|
||||
)
|
||||
|
||||
def step(self, noise_pred, timestep, latent):
|
||||
step_index = (self.timesteps == timestep).nonzero().item()
|
||||
sigma = self.sigmas[step_index]
|
||||
dt = self.sigmas[step_index + 1] - sigma
|
||||
def step(
|
||||
self,
|
||||
noise_pred,
|
||||
timestep,
|
||||
latent,
|
||||
s_churn: float = 0.0,
|
||||
s_tmin: float = 0.0,
|
||||
s_tmax: float = float("inf"),
|
||||
s_noise: float = 1.0,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: Optional[bool] = False,
|
||||
):
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
sigma = self.sigmas[self.step_index]
|
||||
|
||||
gamma = (
|
||||
min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1)
|
||||
if s_tmin <= sigma <= s_tmax
|
||||
else 0.0
|
||||
)
|
||||
|
||||
sigma_hat = sigma * (gamma + 1)
|
||||
|
||||
noise_pred = (
|
||||
torch.from_numpy(noise_pred)
|
||||
if isinstance(noise_pred, np.ndarray)
|
||||
else noise_pred
|
||||
)
|
||||
|
||||
noise = randn_tensor(
|
||||
torch.Size(noise_pred.shape),
|
||||
dtype=torch.float16,
|
||||
device="cpu",
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
eps = noise * s_noise
|
||||
|
||||
if gamma > 0:
|
||||
latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
|
||||
if self.config.prediction_type == "v_prediction":
|
||||
sigma_hat = sigma
|
||||
|
||||
dt = self.sigmas[self.step_index + 1] - sigma_hat
|
||||
|
||||
self._step_index += 1
|
||||
|
||||
return self.step_model(
|
||||
"forward",
|
||||
(
|
||||
noise_pred,
|
||||
sigma,
|
||||
sigma_hat,
|
||||
latent,
|
||||
dt,
|
||||
),
|
||||
|
||||
@@ -13,6 +13,7 @@ from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
|
||||
from apps.stable_diffusion.src.utils.stable_args import args
|
||||
from apps.stable_diffusion.src.utils.stencils.stencil_utils import (
|
||||
controlnet_hint_conversion,
|
||||
controlnet_hint_reshaping,
|
||||
get_stencil_model_id,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.utils import (
|
||||
|
||||
@@ -8,6 +8,15 @@
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"sdxl_clip": {
|
||||
"token" : {
|
||||
"shape" : [
|
||||
"1*batch_size",
|
||||
"max_len"
|
||||
],
|
||||
"dtype":"i64"
|
||||
}
|
||||
},
|
||||
"vae_encode": {
|
||||
"image" : {
|
||||
"shape" : [
|
||||
@@ -179,9 +188,95 @@
|
||||
"shape": [2],
|
||||
"dtype": "i64"
|
||||
}
|
||||
},
|
||||
"stabilityai/sdxl-turbo": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
4,
|
||||
"height",
|
||||
"width"
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"timesteps": {
|
||||
"shape": [
|
||||
1
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"prompt_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
"max_len",
|
||||
2048
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"text_embeds": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
1280
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"time_ids": {
|
||||
"shape": [
|
||||
"2*batch_size",
|
||||
6
|
||||
],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"guidance_scale": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
},
|
||||
"stencil_adaptor": {
|
||||
"stencil_adapter": {
|
||||
"latents": {
|
||||
"shape": [
|
||||
"1*batch_size",
|
||||
@@ -208,6 +303,58 @@
|
||||
"controlnet_hint": {
|
||||
"shape": [1, 3, "8*height", "8*width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc1": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc2": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc3": {
|
||||
"shape": [2, 320, "height", "width"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc4": {
|
||||
"shape": [2, 320, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc5": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc6": {
|
||||
"shape": [2, 640, "height/2", "width/2"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc7": {
|
||||
"shape": [2, 640, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc8": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc9": {
|
||||
"shape": [2, 1280, "height/4", "width/4"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc10": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc11": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc12": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"acc13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
}
|
||||
},
|
||||
"stencil_unet": {
|
||||
@@ -290,7 +437,59 @@
|
||||
"control13": {
|
||||
"shape": [2, 1280, "height/8", "width/8"],
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale1": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale2": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale3": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale4": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale5": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale6": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale7": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale8": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale9": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale10": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale11": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale12": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
},
|
||||
"scale13": {
|
||||
"shape": 1,
|
||||
"dtype": "f32"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
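The symbolic shapes in the new sdxl entries above resolve against runtime values, with height and width being latent-space dimensions (the stencil entries use 8*height and 8*width for the pixel-space hint); a sketch of the resolved unet inputs, assuming batch_size 1, max_len 77 and a 1024x1024 render:

batch_size, max_len, height, width = 1, 77, 128, 128  # 1024 // 8 = 128

sdxl_unet_input_shapes = {
    "latents": (2 * batch_size, 4, height, width),     # (2, 4, 128, 128), f32
    "timesteps": (1,),
    "prompt_embeds": (2 * batch_size, max_len, 2048),  # (2, 77, 2048)
    "text_embeds": (2 * batch_size, 1280),             # (2, 1280)
    "time_ids": (2 * batch_size, 6),                   # (2, 6)
    "guidance_scale": (1,),
}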
@@ -59,24 +59,28 @@
|
||||
"tuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
}
|
||||
},
|
||||
"untuned": {
|
||||
"fp16": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
},
|
||||
"fp32": {
|
||||
"default_compilation_flags": [
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
|
||||
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
|
||||
"--iree-opt-data-tiling=False"
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
[["A high tech solarpunk utopia in the Amazon rainforest"],
|
||||
["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"],
|
||||
["A pikachu fine dining with a view to the Eiffel Tower"],
|
||||
["A mecha robot in a favela in expressionist style"],
|
||||
["an insect robot preparing a delicious meal"],
|
||||
|
||||
@@ -85,7 +85,7 @@ p.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 769, 8),
|
||||
choices=range(128, 1025, 8),
|
||||
help="The height of the output image.",
|
||||
)
|
||||
|
||||
@@ -93,7 +93,7 @@ p.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=512,
|
||||
choices=range(128, 769, 8),
|
||||
choices=range(128, 1025, 8),
|
||||
help="The width of the output image.",
|
||||
)
|
||||
|
||||
@@ -420,6 +420,13 @@ p.add_argument(
|
||||
help="Enable the stencil feature.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--control_mode",
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
default="Balanced",
|
||||
help="How Controlnet injection should be prioritized.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_lora",
|
||||
type=str,
|
||||
@@ -428,6 +435,13 @@ p.add_argument(
|
||||
"file (~3 MB).",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--lora_strength",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Strength (alpha) scaling factor to use when applying LoRA weights",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--use_quantize",
|
||||
type=str,
|
||||
@@ -460,6 +474,13 @@ p.add_argument(
|
||||
"Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count"
|
||||
"Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--autogen",
|
||||
type=bool,
|
||||
default="False",
|
||||
help="Only used for a gradio workaround.",
|
||||
)
|
||||
##############################################################################
|
||||
# IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
@@ -587,6 +608,13 @@ p.add_argument(
|
||||
help="Controls constant folding in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--data_tiling",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Controls data tiling in iree-compile for all SD models.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Web UI flags
|
||||
##############################################################################
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import torch
|
||||
import os
|
||||
from pathlib import Path
|
||||
import torchvision
|
||||
import time
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
@@ -10,6 +14,31 @@ from apps.stable_diffusion.src.utils.stencils import (
|
||||
stencil = {}
|
||||
|
||||
|
||||
def save_img(img):
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generated_imgs_todays_subdir,
|
||||
)
|
||||
|
||||
subdir = Path(get_generated_imgs_path(), "preprocessed_control_hints")
|
||||
os.makedirs(subdir, exist_ok=True)
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(
|
||||
os.path.join(
|
||||
subdir, "controlnet_" + str(int(time.time())) + ".png"
|
||||
)
|
||||
)
|
||||
elif isinstance(img, np.ndarray):
|
||||
img = Image.fromarray(img)
|
||||
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
|
||||
else:
|
||||
converter = torchvision.transforms.ToPILImage()
|
||||
for i in img:
|
||||
converter(i).save(
|
||||
os.path.join(subdir, str(int(time.time())) + ".png")
|
||||
)
|
||||
|
||||
|
||||
def HWC3(x):
|
||||
assert x.dtype == np.uint8
|
||||
if x.ndim == 2:
|
||||
@@ -29,7 +58,7 @@ def HWC3(x):
|
||||
return y
|
||||
|
||||
|
||||
def controlnet_hint_shaping(
|
||||
def controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt=1
|
||||
):
|
||||
channels = 3
|
||||
@@ -48,10 +77,12 @@ def controlnet_hint_shaping(
|
||||
)
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({channels}, {height}, {width}),"
|
||||
+ f" (1, {channels}, {height}, {width}) or ({num_images_per_prompt}, "
|
||||
+ f"{channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
return controlnet_hint_reshaping(
|
||||
Image.fromarray(controlnet_hint.detach().numpy()),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
elif isinstance(controlnet_hint, np.ndarray):
|
||||
# np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot)
|
||||
@@ -78,29 +109,38 @@ def controlnet_hint_shaping(
|
||||
) # b h w c -> b c h w
|
||||
return controlnet_hint
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptble shape of `stencil` are any of ({width}, {channels}), "
|
||||
+ f"({height}, {width}, {channels}), "
|
||||
+ f"(1, {height}, {width}, {channels}) or "
|
||||
+ f"({num_images_per_prompt}, {channels}, {height}, {width}) but is {controlnet_hint.shape}"
|
||||
return controlnet_hint_reshaping(
|
||||
Image.fromarray(controlnet_hint),
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
num_images_per_prompt,
|
||||
)
|
||||
|
||||
elif isinstance(controlnet_hint, Image.Image):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
if controlnet_hint.size == (width, height):
|
||||
controlnet_hint = controlnet_hint.convert(
|
||||
"RGB"
|
||||
) # make sure 3 channel RGB format
|
||||
controlnet_hint = np.array(controlnet_hint) # to numpy
|
||||
controlnet_hint = np.array(controlnet_hint).astype(
|
||||
np.float16
|
||||
) # to numpy
|
||||
controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR
|
||||
return controlnet_hint_shaping(
|
||||
controlnet_hint, height, width, num_images_per_prompt
|
||||
return controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable image size of `stencil` is ({width}, {height}) but is {controlnet_hint.size}"
|
||||
)
|
||||
(hint_w, hint_h) = controlnet_hint.size
|
||||
left = int((hint_w - width) / 2)
|
||||
right = left + height
|
||||
controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h))
|
||||
controlnet_hint = controlnet_hint.resize((width, height))
|
||||
return controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Acceptable type of `stencil` are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}"
|
||||
)
|
||||
|
||||
|
||||
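When a PIL hint does not already match the target size, the added branch above crops it to the target framing and then resizes. A compact sketch of that behaviour (the crop extent here uses the target width, and it assumes the hint is at least as large as the target):

from PIL import Image

def fit_hint(hint: Image.Image, width: int, height: int) -> Image.Image:
    # Centre-crop horizontally to the requested framing, then resize to the
    # exact (width, height) the pipeline expects.
    hint_w, hint_h = hint.size
    left = (hint_w - width) // 2
    cropped = hint.crop((left, 0, left + width, hint_h))
    return cropped.resize((width, height))
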
@@ -110,20 +150,26 @@ def controlnet_hint_conversion(
|
||||
controlnet_hint = None
|
||||
match use_stencil:
|
||||
case "canny":
|
||||
print("Detecting edge with canny")
|
||||
print(
|
||||
"Converting controlnet hint to edge detection mask with canny preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_canny(image)
|
||||
case "openpose":
|
||||
print("Detecting human pose")
|
||||
print(
|
||||
"Detecting human pose in controlnet hint with openpose preprocessor."
|
||||
)
|
||||
controlnet_hint = hint_openpose(image)
|
||||
case "scribble":
|
||||
print("Working with scribble")
|
||||
print("Using your scribble as a controlnet hint.")
|
||||
controlnet_hint = hint_scribble(image)
|
||||
case "zoedepth":
|
||||
print("Working with ZoeDepth")
|
||||
print(
|
||||
"Converting controlnet hint to a depth mapping with ZoeDepth."
|
||||
)
|
||||
controlnet_hint = hint_zoedepth(image)
|
||||
case _:
|
||||
return None
|
||||
controlnet_hint = controlnet_hint_shaping(
|
||||
controlnet_hint = controlnet_hint_reshaping(
|
||||
controlnet_hint, height, width, dtype, num_images_per_prompt
|
||||
)
|
||||
return controlnet_hint
|
||||
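controlnet_hint_conversion is essentially a dispatch from stencil name to preprocessor followed by the reshape. A condensed sketch of that flow (the hint_* helpers are assumed to accept just the image, using their default thresholds):

def convert_hint_sketch(image, use_stencil, height, width, dtype, num_images_per_prompt=1):
    # Map the stencil name to its preprocessor, then shape the result the
    # way the stencil pipeline expects.
    preprocessors = {
        "canny": hint_canny,
        "openpose": hint_openpose,
        "scribble": hint_scribble,
        "zoedepth": hint_zoedepth,
    }
    preprocess = preprocessors.get(use_stencil)
    if preprocess is None:
        return None
    return controlnet_hint_reshaping(
        preprocess(image), height, width, dtype, num_images_per_prompt
    )
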
@@ -161,6 +207,7 @@ def hint_canny(
|
||||
detected_map = stencil["canny"](
|
||||
input_image, low_threshold, high_threshold
|
||||
)
|
||||
save_img(detected_map)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
@@ -176,6 +223,7 @@ def hint_openpose(
|
||||
stencil["openpose"] = OpenposeDetector()
|
||||
|
||||
detected_map, _ = stencil["openpose"](input_image)
|
||||
save_img(detected_map)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
@@ -187,6 +235,7 @@ def hint_scribble(image: Image.Image):
|
||||
|
||||
detected_map = np.zeros_like(input_image, dtype=np.uint8)
|
||||
detected_map[np.min(input_image, axis=2) < 127] = 255
|
||||
save_img(detected_map)
|
||||
return detected_map
@@ -199,5 +248,6 @@ def hint_zoedepth(image: Image.Image):
|
||||
stencil["depth"] = ZoeDetector()
|
||||
|
||||
detected_map = stencil["depth"](input_image)
|
||||
save_img(detected_map)
|
||||
detected_map = HWC3(detected_map)
|
||||
return detected_map
|
||||
|
||||
@@ -30,9 +30,15 @@ class ZoeDetector:
|
||||
pretrained=False,
|
||||
force_reload=False,
|
||||
)
|
||||
model.load_state_dict(
|
||||
torch.load(modelpath, map_location=model.device)["model"]
|
||||
)
|
||||
|
||||
# Hack to fix the ZoeDepth import issue
|
||||
model_keys = model.state_dict().keys()
|
||||
loaded_dict = torch.load(modelpath, map_location=model.device)["model"]
|
||||
loaded_keys = loaded_dict.keys()
|
||||
for key in loaded_keys - model_keys:
|
||||
loaded_dict.pop(key)
|
||||
|
||||
model.load_state_dict(loaded_dict)
|
||||
model.eval()
|
||||
self.model = model
|
||||
|
||||
|
||||
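The ZoeDepth workaround above boils down to filtering the checkpoint down to keys the model actually declares before calling load_state_dict. The same idea in standalone form (a sketch, not the ZoeDepth-specific code):

import torch

def load_filtered_state_dict(model: torch.nn.Module, checkpoint: dict) -> None:
    # Drop checkpoint entries the model does not declare, so that
    # load_state_dict does not fail on unexpected keys.
    model_keys = set(model.state_dict().keys())
    filtered = {k: v for k, v in checkpoint.items() if k in model_keys}
    model.load_state_dict(filtered)
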
@@ -6,6 +6,7 @@ from PIL import PngImagePlugin
|
||||
from PIL import Image
|
||||
from datetime import datetime as dt
|
||||
from csv import DictWriter
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
from random import (
|
||||
@@ -118,7 +119,7 @@ def compile_through_fx(
|
||||
is_f16=False,
|
||||
f16_input_mask=None,
|
||||
use_tuned=False,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
save_dir="",
|
||||
debug=False,
|
||||
generate_vmfb=True,
|
||||
extra_args=None,
|
||||
@@ -541,6 +542,8 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags.append(
|
||||
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807"
|
||||
)
|
||||
if args.data_tiling == False:
|
||||
iree_flags.append("--iree-opt-data-tiling=False")
|
||||
|
||||
if "default_compilation_flags" in opt_flags[model][is_tuned][precision]:
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
@@ -563,6 +566,10 @@ def get_opt_flags(model, precision="fp16"):
|
||||
iree_flags += opt_flags[model][is_tuned][precision][
|
||||
"specified_compilation_flags"
|
||||
][device]
|
||||
if "vae" not in model:
|
||||
# Due to lack of support for multi-reduce, we always collapse reduction
|
||||
# dims before dispatch formation right now.
|
||||
iree_flags += ["--iree-flow-collapse-reduction-dims"]
|
||||
return iree_flags
|
||||
|
||||
|
||||
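The new --data_tiling option added earlier feeds directly into this flag list. A minimal sketch of that mapping (standalone parser, for illustration only):

import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--data_tiling",
    default=False,
    action=argparse.BooleanOptionalAction,
)
args = p.parse_args(["--no-data_tiling"])

iree_flags = []
# Mirror the get_opt_flags branch: only emit the override when tiling is off.
if args.data_tiling is False:
    iree_flags.append("--iree-opt-data-tiling=False")
print(iree_flags)  # ['--iree-opt-data-tiling=False']
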
@@ -632,30 +639,51 @@ def convert_original_vae(vae_checkpoint):
|
||||
return converted_vae_checkpoint
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix):
|
||||
@dataclass
|
||||
class LoRAweight:
|
||||
up: torch.tensor
|
||||
down: torch.tensor
|
||||
mid: torch.tensor
|
||||
alpha: torch.float32 = 1.0
|
||||
|
||||
|
||||
def processLoRA(model, use_lora, splitting_prefix, lora_strength):
|
||||
state_dict = ""
|
||||
if ".safetensors" in use_lora:
|
||||
state_dict = load_file(use_lora)
|
||||
else:
|
||||
state_dict = torch.load(use_lora)
|
||||
alpha = 0.75
|
||||
visited = []
|
||||
|
||||
# directly update weight in model
|
||||
process_unet = "te" not in splitting_prefix
|
||||
# gather the weights from the LoRA in a more convenient form, assumes
|
||||
# everything will have an up.weight. Unsure if this is a safe assumption.
|
||||
weight_dict: dict[str, LoRAweight] = {}
|
||||
for key in state_dict:
|
||||
if ".alpha" in key or key in visited:
|
||||
continue
|
||||
if key.startswith(splitting_prefix) and key.endswith("up.weight"):
|
||||
stem = key.split("up.weight")[0]
|
||||
weight_key = stem.removesuffix(".lora_")
|
||||
weight_key = weight_key.removesuffix("_lora_")
|
||||
weight_key = weight_key.removesuffix(".lora_linear_layer.")
|
||||
|
||||
if weight_key not in weight_dict:
|
||||
weight_dict[weight_key] = LoRAweight(
|
||||
state_dict[f"{stem}up.weight"],
|
||||
state_dict[f"{stem}down.weight"],
|
||||
state_dict.get(f"{stem}mid.weight", None),
|
||||
state_dict[f"{weight_key}.alpha"]
|
||||
/ state_dict[f"{stem}up.weight"].shape[1]
|
||||
if f"{weight_key}.alpha" in state_dict
|
||||
else 1.0,
|
||||
)
|
||||
|
||||
# Directly update weight in model
|
||||
|
||||
# Mostly adaptations of https://github.com/kohya-ss/sd-scripts/blob/main/networks/merge_lora.py
|
||||
# and similar code in https://github.com/huggingface/diffusers/issues/3064
|
||||
|
||||
# TODO: handle mid weights (how do they even work?)
|
||||
for key, lora_weight in weight_dict.items():
|
||||
curr_layer = model
|
||||
if ("text" not in key and process_unet) or (
|
||||
"text" in key and not process_unet
|
||||
):
|
||||
layer_infos = (
|
||||
key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
)
|
||||
else:
|
||||
continue
|
||||
layer_infos = key.split(".")[0].split(splitting_prefix)[-1].split("_")
|
||||
|
||||
# find the target layer
|
||||
temp_name = layer_infos.pop(0)
|
||||
@@ -672,46 +700,46 @@ def processLoRA(model, use_lora, splitting_prefix):
|
||||
else:
|
||||
temp_name = layer_infos.pop(0)
|
||||
|
||||
pair_keys = []
|
||||
if "lora_down" in key:
|
||||
pair_keys.append(key.replace("lora_down", "lora_up"))
|
||||
pair_keys.append(key)
|
||||
else:
|
||||
pair_keys.append(key)
|
||||
pair_keys.append(key.replace("lora_up", "lora_down"))
|
||||
|
||||
# update weight
|
||||
if len(state_dict[pair_keys[0]].shape) == 4:
|
||||
weight_up = (
|
||||
state_dict[pair_keys[0]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
)
|
||||
weight = curr_layer.weight.data
|
||||
scale = lora_weight.alpha * lora_strength
|
||||
if len(weight.size()) == 2:
|
||||
if len(lora_weight.up.shape) == 4:
|
||||
weight_up = (
|
||||
lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
)
|
||||
weight_down = (
|
||||
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
)
|
||||
change = (
|
||||
torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
)
|
||||
else:
|
||||
change = torch.mm(lora_weight.up, lora_weight.down)
|
||||
elif lora_weight.down.size()[2:4] == (1, 1):
|
||||
weight_up = lora_weight.up.squeeze(3).squeeze(2).to(torch.float32)
|
||||
weight_down = (
|
||||
state_dict[pair_keys[1]]
|
||||
.squeeze(3)
|
||||
.squeeze(2)
|
||||
.to(torch.float32)
|
||||
lora_weight.down.squeeze(3).squeeze(2).to(torch.float32)
|
||||
)
|
||||
curr_layer.weight.data += alpha * torch.mm(
|
||||
weight_up, weight_down
|
||||
).unsqueeze(2).unsqueeze(3)
|
||||
change = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
|
||||
else:
|
||||
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
||||
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
||||
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
|
||||
# update visited list
|
||||
for item in pair_keys:
|
||||
visited.append(item)
|
||||
change = torch.nn.functional.conv2d(
|
||||
lora_weight.down.permute(1, 0, 2, 3),
|
||||
lora_weight.up,
|
||||
).permute(1, 0, 2, 3)
|
||||
|
||||
curr_layer.weight.data += change * scale
|
||||
|
||||
return model
|
||||
|
||||
|
||||
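The merge in processLoRA is the standard low-rank update, weight += strength * alpha * (up @ down), with the convolutional cases handled by squeezing or a 1x1 conv. A minimal sketch of the plain linear case (hypothetical tensors, not the repository's API):

import torch

def merge_lora_linear(weight, up, down, alpha=1.0, strength=0.75):
    # LoRA stores the weight delta as a rank-r factorisation up @ down;
    # alpha comes from the checkpoint, strength from the user.
    assert up.shape[1] == down.shape[0], "rank dimensions must match"
    return weight + (alpha * strength) * torch.mm(up, down)

# Example: a 320x320 layer with a rank-4 LoRA.
w = torch.zeros(320, 320)
up, down = torch.randn(320, 4), torch.randn(4, 320)
merged = merge_lora_linear(w, up, down)
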
def update_lora_weight_for_unet(unet, use_lora):
|
||||
def update_lora_weight_for_unet(unet, use_lora, lora_strength):
|
||||
extensions = [".bin", ".safetensors", ".pt"]
|
||||
if not any([extension in use_lora for extension in extensions]):
|
||||
# We assume if it is a HF ID with standalone LoRA weights.
|
||||
unet.load_attn_procs(use_lora)
|
||||
print(
|
||||
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
|
||||
)
|
||||
return unet
|
||||
|
||||
main_file_name = get_path_stem(use_lora)
|
||||
@@ -727,16 +755,21 @@ def update_lora_weight_for_unet(unet, use_lora):
|
||||
try:
|
||||
dir_name = os.path.dirname(use_lora)
|
||||
unet.load_attn_procs(dir_name, weight_name=main_file_name)
|
||||
print(
|
||||
f"updated unet weights via diffusers load_attn_procs from LoRA: {use_lora}"
|
||||
)
|
||||
return unet
|
||||
except:
|
||||
return processLoRA(unet, use_lora, "lora_unet_")
|
||||
print(f"updated unet weights manually from LoRA: {use_lora}")
|
||||
return processLoRA(unet, use_lora, "lora_unet_", lora_strength)
|
||||
|
||||
|
||||
def update_lora_weight(model, use_lora, model_name):
|
||||
def update_lora_weight(model, use_lora, model_name, lora_strength):
|
||||
if "unet" in model_name:
|
||||
return update_lora_weight_for_unet(model, use_lora)
|
||||
return update_lora_weight_for_unet(model, use_lora, lora_strength)
|
||||
try:
|
||||
return processLoRA(model, use_lora, "lora_te_")
|
||||
print(f"updating CLIP weights from LoRA: {use_lora}")
|
||||
return processLoRA(model, use_lora, "lora_te_", lora_strength)
|
||||
except:
|
||||
return None
|
||||
|
||||
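After the signature change, callers have to thread the strength through explicitly. A hypothetical call site (paths and values are placeholders):

# The unet branch goes through load_attn_procs or processLoRA with "lora_unet_";
# anything else is treated as a text encoder and uses "lora_te_".
unet = update_lora_weight(unet, "models/lora/style.safetensors", "unet", 0.75)
text_encoder = update_lora_weight(
    text_encoder, "models/lora/style.safetensors", "clip", 0.75
)
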
@@ -892,7 +925,7 @@ def save_output_img(output_img, img_seed, extra_info=None):
|
||||
|
||||
img_lora = None
|
||||
if args.use_lora:
|
||||
img_lora = Path(os.path.basename(args.use_lora)).stem
|
||||
img_lora = f"{Path(os.path.basename(args.use_lora)).stem}:{args.lora_strength}"
|
||||
|
||||
if args.output_img_format == "jpg":
|
||||
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
|
||||
@@ -1002,8 +1035,7 @@ def get_generation_text_info(seeds, device):
|
||||
# Both width and height should be in the range of [128, 768] and multiple of 8.
|
||||
# This utility function performs the transformation on the input image while
|
||||
# also maintaining the aspect ratio before sending it to the stencil pipeline.
|
||||
def resize_stencil(image: Image.Image):
|
||||
width, height = image.size
|
||||
def resize_stencil(image: Image.Image, width, height):
|
||||
aspect_ratio = width / height
|
||||
min_size = min(width, height)
|
||||
if min_size < 128:
|
||||
|
||||
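The resize_stencil hunk is cut off here, but the comment above it states the contract: keep the aspect ratio while forcing both sides into [128, 768] as multiples of 8. A rough sketch of such a transform (an assumption about the behaviour, not the actual implementation):

from PIL import Image

def resize_to_sd_bounds(image: Image.Image, lo: int = 128, hi: int = 768) -> Image.Image:
    # Scale down to fit the upper bound, scale up to meet the lower bound,
    # then snap both sides to multiples of 8 (extreme inputs may still be clamped).
    w, h = image.size
    scale = min(hi / max(w, h), 1.0)
    scale = max(scale, lo / min(w, h))
    new_w = max(lo, min(hi, int(w * scale) // 8 * 8))
    new_h = max(lo, min(hi, int(h * scale) // 8 * 8))
    return image.resize((new_w, new_h))
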
@@ -19,6 +19,9 @@ a = Analysis(
|
||||
win_private_assemblies=False,
|
||||
cipher=block_cipher,
|
||||
noarchive=False,
|
||||
module_collection_mode={
|
||||
'gradio': 'py', # Collect gradio package as source .py files
|
||||
},
|
||||
)
|
||||
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)
|
||||
|
||||
|
||||
@@ -12,7 +12,6 @@ from apps.stable_diffusion.web.api.utils import (
|
||||
decode_base64_to_image,
|
||||
get_model_from_request,
|
||||
get_scheduler_from_request,
|
||||
get_lora_params,
|
||||
get_device,
|
||||
GenerationInputData,
|
||||
GenerationResponseData,
|
||||
@@ -180,7 +179,6 @@ def txt2img_api(InputData: Txt2ImgInputData):
|
||||
scheduler = get_scheduler_from_request(
|
||||
InputData, "txt2img_hires" if InputData.enable_hr else "txt2img"
|
||||
)
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
print(
|
||||
f"Prompt: {InputData.prompt}, "
|
||||
@@ -208,8 +206,8 @@ def txt2img_api(InputData: Txt2ImgInputData):
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
use_hiresfix=InputData.enable_hr,
|
||||
@@ -270,7 +268,6 @@ def img2img_api(
|
||||
fallback_model="stabilityai/stable-diffusion-2-1-base",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "img2img")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
mask_image = (
|
||||
@@ -308,8 +305,8 @@ def img2img_api(
|
||||
use_stencil=InputData.use_stencil,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
resample_type=frozen_args.resample_type,
|
||||
@@ -358,7 +355,6 @@ def inpaint_api(
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "inpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.image)
|
||||
mask = decode_base64_to_image(InputData.mask)
|
||||
@@ -374,7 +370,8 @@ def inpaint_api(
|
||||
res = inpaint_inf(
|
||||
InputData.prompt,
|
||||
InputData.negative_prompt,
|
||||
{"image": init_image, "mask": mask},
|
||||
init_image,
|
||||
mask,
|
||||
InputData.height,
|
||||
InputData.width,
|
||||
InputData.is_full_res,
|
||||
@@ -392,8 +389,8 @@ def inpaint_api(
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
@@ -447,7 +444,6 @@ def outpaint_api(
|
||||
fallback_model="stabilityai/stable-diffusion-2-inpainting",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "outpaint")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
@@ -483,8 +479,8 @@ def outpaint_api(
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
@@ -530,7 +526,6 @@ def upscaler_api(
|
||||
fallback_model="stabilityai/stable-diffusion-x4-upscaler",
|
||||
)
|
||||
scheduler = get_scheduler_from_request(InputData, "upscaler")
|
||||
(lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora)
|
||||
|
||||
init_image = decode_base64_to_image(InputData.init_images[0])
|
||||
|
||||
@@ -562,8 +557,8 @@ def upscaler_api(
|
||||
max_length=frozen_args.max_length,
|
||||
save_metadata_to_json=frozen_args.save_metadata_to_json,
|
||||
save_metadata_to_png=frozen_args.write_metadata_to_png,
|
||||
lora_weights=lora_weights,
|
||||
lora_hf_id=lora_hf_id,
|
||||
lora_weights=frozen_args.use_lora,
|
||||
lora_strength=frozen_args.lora_strength,
|
||||
ondemand=frozen_args.ondemand,
|
||||
repeatable_seeds=False,
|
||||
)
|
||||
|
||||
@@ -191,17 +191,6 @@ def get_scheduler_from_request(
|
||||
)
|
||||
|
||||
|
||||
def get_lora_params(use_lora: str):
|
||||
# TODO: since the inference functions in the webui, which we are
|
||||
# still calling into for the api, jam these back together again before
|
||||
# handing them off to the pipeline, we should remove this nonsense
|
||||
# and unify their selection in the UI and command line args proper
|
||||
if use_lora in get_custom_model_files("lora"):
|
||||
return (use_lora, "")
|
||||
|
||||
return ("None", use_lora)
|
||||
|
||||
|
||||
def get_device(device_str: str):
|
||||
# first substring match in the list available devices, with first
|
||||
# device when none are matched
|
||||
|
||||
@@ -75,17 +75,20 @@ if __name__ == "__main__":
|
||||
# Setup to use shark_tmp for gradio's temporary image files and clear any
|
||||
# existing temporary images there if they exist. Then we can import gradio.
|
||||
# It has to be in this order or gradio ignores what we've set up.
|
||||
from apps.stable_diffusion.web.utils.gradio_configs import (
|
||||
config_gradio_tmp_imgs_folder,
|
||||
from apps.stable_diffusion.web.utils.tmp_configs import (
|
||||
config_tmp,
|
||||
shark_tmp,
|
||||
)
|
||||
|
||||
config_gradio_tmp_imgs_folder()
|
||||
config_tmp()
|
||||
import gradio as gr
|
||||
|
||||
# Create custom models folders if they don't exist
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
create_custom_models_folders,
|
||||
nodicon_loc,
|
||||
mask_editor_value_for_gallery_data,
|
||||
mask_editor_value_for_image_file,
|
||||
)
|
||||
|
||||
create_custom_models_folders()
|
||||
@@ -97,8 +100,6 @@ if __name__ == "__main__":
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
from apps.stable_diffusion.web.ui import (
|
||||
txt2img_web,
|
||||
txt2img_custom_model,
|
||||
@@ -109,6 +110,16 @@ if __name__ == "__main__":
|
||||
txt2img_sendto_inpaint,
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
# SDXL
|
||||
txt2img_sdxl_web,
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_png_info_img,
|
||||
txt2img_sdxl_status,
|
||||
txt2img_sdxl_sendto_img2img,
|
||||
txt2img_sdxl_sendto_inpaint,
|
||||
txt2img_sdxl_sendto_outpaint,
|
||||
txt2img_sdxl_sendto_upscaler,
|
||||
# h2ogpt_upload,
|
||||
# h2ogpt_web,
|
||||
img2img_web,
|
||||
@@ -145,7 +156,7 @@ if __name__ == "__main__":
|
||||
upscaler_sendto_outpaint,
|
||||
# lora_train_web,
|
||||
# model_web,
|
||||
# model_config_web,
|
||||
model_config_web,
|
||||
hf_models,
|
||||
modelmanager_sendto_txt2img,
|
||||
modelmanager_sendto_img2img,
|
||||
@@ -159,6 +170,7 @@ if __name__ == "__main__":
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
@@ -168,11 +180,21 @@ if __name__ == "__main__":
|
||||
# init global sd pipeline and config
|
||||
global_obj._init()
|
||||
|
||||
def register_button_click(button, selectedid, inputs, outputs):
|
||||
def register_sendto_click(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x[0]["name"] if len(x) != 0 else None,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
x.root[0].image.path if len(x.root) != 0 else None,
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_sendto_editor_click(button, selectedid, inputs, outputs):
|
||||
button.click(
|
||||
lambda x: (
|
||||
mask_editor_value_for_gallery_data(x),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
@@ -183,24 +205,45 @@ if __name__ == "__main__":
|
||||
lambda x: (
|
||||
"None",
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
queue=False,
|
||||
)
|
||||
|
||||
def register_outputgallery_button(button, selectedid, inputs, outputs):
|
||||
def register_outputgallery_sendto_button(
|
||||
button, selectedid, inputs, outputs
|
||||
):
|
||||
button.click(
|
||||
lambda x: (
|
||||
x,
|
||||
gr.Tabs.update(selected=selectedid),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
|
||||
|
||||
def register_outputgallery_sendto_editor_button(
|
||||
button, selectedid, inputs, outputs
|
||||
):
|
||||
button.click(
|
||||
lambda x: (
|
||||
mask_editor_value_for_image_file(x),
|
||||
gr.Tabs(selected=selectedid),
|
||||
),
|
||||
inputs,
|
||||
outputs,
|
||||
)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")
|
||||
|
||||
with gr.Blocks(
|
||||
css=dark_theme, analytics_enabled=False, title="SHARK AI Studio"
|
||||
css=dark_theme,
|
||||
js=gradio_workarounds,
|
||||
analytics_enabled=False,
|
||||
title="SHARK AI Studio",
|
||||
) as sd_web:
|
||||
with gr.Tabs() as tabs:
|
||||
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
|
||||
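The register_* helpers above all wrap the same gradio pattern: a click handler that returns the selected image (or its file path) together with a gr.Tabs update that switches to the receiving tab. A stripped-down sketch of the idea (component names are hypothetical):

import gradio as gr

def make_sendto(button, target_tab_id, source_gallery, target_image, tabs):
    # Copy the first gallery item's file path into the target image
    # component and switch the Tabs container to the receiving tab.
    button.click(
        lambda gallery: (
            gallery.root[0].image.path if len(gallery.root) != 0 else None,
            gr.Tabs(selected=target_tab_id),
        ),
        inputs=[source_gallery],
        outputs=[target_image, tabs],
    )
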
@@ -225,34 +268,37 @@ if __name__ == "__main__":
|
||||
if args.output_gallery:
|
||||
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
|
||||
outputgallery_web.render()
|
||||
|
||||
# extra output gallery configuration
|
||||
outputgallery_tab_select(og_tab.select)
|
||||
outputgallery_watch(
|
||||
[
|
||||
txt2img_status,
|
||||
img2img_status,
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
# with gr.TabItem(label="Model Manager", id=6):
|
||||
# model_web.render()
|
||||
# with gr.TabItem(label="LoRA Training (Experimental)", id=7):
|
||||
# lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot", id=8):
|
||||
stablelm_chat.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
minigpt4_web.render()
|
||||
# with gr.TabItem(
|
||||
# label="Generate Sharding Config (Experimental)", id=9
|
||||
# ):
|
||||
# model_config_web.render()
|
||||
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
|
||||
# minigpt4_web.render()
|
||||
# with gr.TabItem(label="DocuChat Upload", id=11):
|
||||
# h2ogpt_upload.render()
|
||||
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
|
||||
# h2ogpt_web.render()
|
||||
with gr.TabItem(label="Text-to-Image (SDXL)", id=13):
|
||||
txt2img_sdxl_web.render()
|
||||
|
||||
# extra output gallery configuration
|
||||
outputgallery_tab_select(og_tab.select)
|
||||
outputgallery_watch(
|
||||
[
|
||||
txt2img_status,
|
||||
img2img_status,
|
||||
inpaint_status,
|
||||
outpaint_status,
|
||||
upscaler_status,
|
||||
txt2img_sdxl_status,
|
||||
],
|
||||
)
|
||||
|
||||
actual_port = app.usable_port()
|
||||
if actual_port != args.server_port:
|
||||
@@ -264,133 +310,139 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
# send to buttons
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
txt2img_sendto_img2img,
|
||||
1,
|
||||
[txt2img_gallery],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_editor_click(
|
||||
txt2img_sendto_inpaint,
|
||||
2,
|
||||
[txt2img_gallery],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
txt2img_sendto_outpaint,
|
||||
3,
|
||||
[txt2img_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
txt2img_sendto_upscaler,
|
||||
4,
|
||||
[txt2img_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_editor_click(
|
||||
img2img_sendto_inpaint,
|
||||
2,
|
||||
[img2img_gallery],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
img2img_sendto_outpaint,
|
||||
3,
|
||||
[img2img_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
img2img_sendto_upscaler,
|
||||
4,
|
||||
[img2img_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
inpaint_sendto_img2img,
|
||||
1,
|
||||
[inpaint_gallery],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
inpaint_sendto_outpaint,
|
||||
3,
|
||||
[inpaint_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
inpaint_sendto_upscaler,
|
||||
4,
|
||||
[inpaint_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
outpaint_sendto_img2img,
|
||||
1,
|
||||
[outpaint_gallery],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_editor_click(
|
||||
outpaint_sendto_inpaint,
|
||||
2,
|
||||
[outpaint_gallery],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
outpaint_sendto_upscaler,
|
||||
4,
|
||||
[outpaint_gallery],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
upscaler_sendto_img2img,
|
||||
1,
|
||||
[upscaler_gallery],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_editor_click(
|
||||
upscaler_sendto_inpaint,
|
||||
2,
|
||||
[upscaler_gallery],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_button_click(
|
||||
register_sendto_click(
|
||||
upscaler_sendto_outpaint,
|
||||
3,
|
||||
[upscaler_gallery],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
if args.output_gallery:
|
||||
register_outputgallery_button(
|
||||
register_outputgallery_sendto_button(
|
||||
outputgallery_sendto_txt2img,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_png_info_img, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
register_outputgallery_sendto_button(
|
||||
outputgallery_sendto_img2img,
|
||||
1,
|
||||
[outputgallery_filename],
|
||||
[img2img_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
register_outputgallery_sendto_editor_button(
|
||||
outputgallery_sendto_inpaint,
|
||||
2,
|
||||
[outputgallery_filename],
|
||||
[inpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
register_outputgallery_sendto_button(
|
||||
outputgallery_sendto_outpaint,
|
||||
3,
|
||||
[outputgallery_filename],
|
||||
[outpaint_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_button(
|
||||
register_outputgallery_sendto_button(
|
||||
outputgallery_sendto_upscaler,
|
||||
4,
|
||||
[outputgallery_filename],
|
||||
[upscaler_init_image, tabs],
|
||||
)
|
||||
register_outputgallery_sendto_button(
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
0,
|
||||
[outputgallery_filename],
|
||||
[txt2img_sdxl_png_info_img, tabs],
|
||||
)
|
||||
register_modelmanager_button(
|
||||
modelmanager_sendto_txt2img,
|
||||
0,
|
||||
|
||||
@@ -10,6 +10,18 @@ from apps.stable_diffusion.web.ui.txt2img_ui import (
|
||||
txt2img_sendto_outpaint,
|
||||
txt2img_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import (
|
||||
txt2img_sdxl_inf,
|
||||
txt2img_sdxl_web,
|
||||
txt2img_sdxl_custom_model,
|
||||
txt2img_sdxl_gallery,
|
||||
txt2img_sdxl_status,
|
||||
txt2img_sdxl_png_info_img,
|
||||
txt2img_sdxl_sendto_img2img,
|
||||
txt2img_sdxl_sendto_inpaint,
|
||||
txt2img_sdxl_sendto_outpaint,
|
||||
txt2img_sdxl_sendto_upscaler,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.img2img_ui import (
|
||||
img2img_inf,
|
||||
img2img_web,
|
||||
@@ -76,6 +88,7 @@ from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_watch,
|
||||
outputgallery_filename,
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
|
||||
64 apps/stable_diffusion/web/ui/common_ui_events.py Normal file
@@ -0,0 +1,64 @@
|
||||
import gradio as gr
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
HSLHue,
|
||||
hsl_color,
|
||||
get_lora_metadata,
|
||||
)
|
||||
|
||||
|
||||
# Returns HTML showing the most frequent tags used when a LoRA was trained,
|
||||
# taken from the metadata of its .safetensors file.
|
||||
def lora_changed(lora_file):
|
||||
# tag frequency percentage that gets the maximum amount of the starting hue
|
||||
TAG_COLOR_THRESHOLD = 0.55
|
||||
# tag frequency percentage, above which a tag is displayed
|
||||
TAG_DISPLAY_THRESHOLD = 0.65
|
||||
# template for the html used to display a tag
|
||||
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
|
||||
|
||||
if lora_file == "None":
|
||||
return ["<div><i>No LoRA selected</i></div>"]
|
||||
elif not lora_file.lower().endswith(".safetensors"):
|
||||
return [
|
||||
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
|
||||
]
|
||||
else:
|
||||
metadata = get_lora_metadata(lora_file)
|
||||
if metadata:
|
||||
frequencies = metadata["frequencies"]
|
||||
return [
|
||||
"".join(
|
||||
[
|
||||
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
|
||||
]
|
||||
+ [
|
||||
TAG_HTML_TEMPLATE.format(
|
||||
color=hsl_color(
|
||||
(tag[1] - TAG_COLOR_THRESHOLD)
|
||||
/ (1 - TAG_COLOR_THRESHOLD),
|
||||
start=HSLHue.RED,
|
||||
end=HSLHue.GREEN,
|
||||
),
|
||||
tag=tag[0],
|
||||
)
|
||||
for tag in frequencies
|
||||
if tag[1] > TAG_DISPLAY_THRESHOLD
|
||||
],
|
||||
)
|
||||
]
|
||||
elif metadata is None:
|
||||
return [
|
||||
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
|
||||
]
|
||||
else:
|
||||
return [
|
||||
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
|
||||
]
|
||||
|
||||
|
||||
def lora_strength_changed(strength):
|
||||
if strength > 1.0:
|
||||
return gr.Number(elem_classes="value-out-of-range")
|
||||
else:
|
||||
return gr.Number(elem_classes="")
|
||||
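lora_changed colours each tag by rescaling its training frequency between the two thresholds before handing it to hsl_color. The rescaling and filtering on their own (threshold values taken from the code above, sample frequencies are made up):

TAG_COLOR_THRESHOLD = 0.55
TAG_DISPLAY_THRESHOLD = 0.65

def tag_color_position(frequency: float) -> float:
    # Map the frequency so TAG_COLOR_THRESHOLD lands at 0.0 (start hue)
    # and 1.0 stays at 1.0 (end hue).
    return (frequency - TAG_COLOR_THRESHOLD) / (1 - TAG_COLOR_THRESHOLD)

frequencies = [("1girl", 0.98), ("portrait", 0.70), ("sketch", 0.40)]
shown = [(tag, tag_color_position(f)) for tag, f in frequencies if f > TAG_DISPLAY_THRESHOLD]
print(shown)  # only tags above the display threshold, with their hue positions
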
@@ -105,7 +105,19 @@ body {
|
||||
background-color: var(--background-fill-primary);
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices */
|
||||
.generating.svelte-zlszon.svelte-zlszon {
|
||||
border: none;
|
||||
}
|
||||
|
||||
.generating {
|
||||
border: none !important;
|
||||
}
|
||||
|
||||
#chatbot {
|
||||
height: 100% !important;
|
||||
}
|
||||
|
||||
/* display in full width for desktop devices, but see below */
|
||||
@media (min-width: 1536px)
|
||||
{
|
||||
.gradio-container {
|
||||
@@ -113,12 +125,17 @@ body {
|
||||
}
|
||||
}
|
||||
|
||||
.gradio-container .contain {
|
||||
padding: 0 var(--size-4) !important;
|
||||
/* media rules in custom css don't appear to be applied in
   gradio versions > 4.7, so we have to define a class which
   we will need to manually add and remove using javascript.
   Remove this once this is fixed in gradio.
*/
|
||||
.gradio-container-size-full {
|
||||
max-width: var(--size-full) !important;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
.gradio-container .contain {
|
||||
padding: 0 var(--size-4) !important;
|
||||
}
|
||||
|
||||
#top_logo {
|
||||
@@ -128,6 +145,10 @@ body {
|
||||
border: 0;
|
||||
}
|
||||
|
||||
#ui_title {
|
||||
padding: var(--size-2) 0 0 var(--size-1);
|
||||
}
|
||||
|
||||
#demo_title_outer {
|
||||
border-radius: 0;
|
||||
}
|
||||
@@ -170,6 +191,8 @@ footer {
|
||||
aspect-ratio: unset;
|
||||
max-height: calc(55vh - (2 * var(--spacing-lg)));
|
||||
}
|
||||
|
||||
/* fix width and height of gallery items when on very large desktop screens, but see below */
|
||||
@media (min-width: 1921px) {
|
||||
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
|
||||
#gallery .grid-wrap, #gallery .preview{
|
||||
@@ -181,6 +204,20 @@ footer {
|
||||
max-height: 770px !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* media rules in custom css don't appear to be applied in
   gradio versions > 4.7, so we have to define classes which
   we will need to manually add and remove using javascript.
   Remove this once this is fixed in gradio.
*/
|
||||
.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview {
|
||||
min-height: calc(768px + 4px + var(--size-14)) !important;
|
||||
max-height: calc(768px + 4px + var(--size-14)) !important;
|
||||
}
|
||||
.gallery-limit-height768 .thumbnail-item.thumbnail-lg {
|
||||
max-height: 770px !important;
|
||||
}
|
||||
|
||||
/* Don't upscale when viewing in solo image mode */
|
||||
#gallery .preview img {
|
||||
object-fit: scale-down;
|
||||
@@ -222,18 +259,19 @@ footer {
|
||||
display:none;
|
||||
}
|
||||
|
||||
/* Hide the download icon from the nod logo */
|
||||
#top_logo button {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* workarounds for container=false not currently working for dropdowns */
|
||||
.dropdown_no_container {
|
||||
padding: 0 !important;
|
||||
}
|
||||
|
||||
#output_subdir_container :first-child {
|
||||
border: none;
|
||||
#output_subdir_container {
|
||||
background-color: var(--block-background-fill);
|
||||
padding-right: 8px;
|
||||
}
|
||||
|
||||
/* number input value is out of range */
|
||||
.value-out-of-range input[type="number"] {
|
||||
color: red !important;
|
||||
}
|
||||
|
||||
/* reduced animation load when generating */
|
||||
@@ -246,16 +284,54 @@ footer {
|
||||
background-color: var(--block-label-background-fill);
|
||||
}
|
||||
|
||||
/* lora tag pills */
|
||||
.lora-tags {
|
||||
border: 1px solid var(--border-color-primary);
|
||||
color: var(--block-info-text-color) !important;
|
||||
padding: var(--block-padding);
|
||||
}
|
||||
|
||||
.lora-tag {
|
||||
display: inline-block;
|
||||
height: 2em;
|
||||
color: rgb(212 212 212) !important;
|
||||
margin-right: 5pt;
|
||||
margin-bottom: 5pt;
|
||||
padding: 2pt 5pt;
|
||||
border-radius: 5pt;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.lora-model {
|
||||
margin-bottom: var(--spacing-lg);
|
||||
color: var(--block-info-text-color) !important;
|
||||
line-height: var(--line-sm);
|
||||
}
|
||||
|
||||
/* output gallery tab */
|
||||
.output_parameters_dataframe table.table {
|
||||
/* works around a gradio bug that always shows scrollbars */
|
||||
overflow: clip auto;
|
||||
}
|
||||
|
||||
.output_parameters_dataframe .cell-wrap span {
|
||||
/* inadequate workaround for gradio issue #6086 */
|
||||
user-select:text !important;
|
||||
-moz-user-select:text !important;
|
||||
-webkit-user-select:text !important;
|
||||
-o-user-select:text !important;
|
||||
-ms-user-select:text !important;
|
||||
}
|
||||
|
||||
.output_parameters_dataframe tbody td {
|
||||
font-size: small;
|
||||
line-height: var(--line-xs)
|
||||
line-height: var(--line-xs);
|
||||
}
|
||||
|
||||
.output_icon_button {
|
||||
max-width: 30px;
|
||||
align-self: end;
|
||||
padding-bottom: 8px;
|
||||
padding-bottom: 16px !important;
|
||||
}
|
||||
|
||||
.outputgallery_sendto {
|
||||
@@ -272,6 +348,11 @@ footer {
|
||||
object-fit: contain !important;
|
||||
}
|
||||
|
||||
/* use the whole gallery area for previews */
|
||||
#outputgallery_gallery .preview {
|
||||
width: inherit;
|
||||
}
|
||||
|
||||
/* centered logo for when there are no images */
|
||||
#top_logo.logo_centered {
|
||||
height: 100%;
|
||||
|
||||
@@ -5,6 +5,13 @@ import gradio as gr
|
||||
import PIL
|
||||
from math import ceil
|
||||
from PIL import Image
|
||||
|
||||
from gradio.components.image_editor import (
|
||||
Brush,
|
||||
Eraser,
|
||||
EditorData,
|
||||
EditorValue,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -14,6 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Image2ImagePipeline,
|
||||
@@ -29,6 +40,11 @@ from apps.stable_diffusion.src.utils import (
|
||||
get_generation_text_info,
|
||||
resampler_list,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils.stencils import (
|
||||
CannyDetector,
|
||||
OpenposeDetector,
|
||||
ZoeDetector,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
import numpy as np
|
||||
|
||||
@@ -58,14 +74,17 @@ def img2img_inf(
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
use_stencil: str,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
resample_type: str,
|
||||
control_mode: str,
|
||||
stencils: list,
|
||||
images: list,
|
||||
preprocessed_hints: list,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
@@ -87,14 +106,25 @@ def img2img_inf(
|
||||
args.img_path = "not none"
|
||||
args.ondemand = ondemand
|
||||
|
||||
if image_dict is None:
|
||||
for i, stencil in enumerate(stencils):
|
||||
if images[i] is None and stencil is not None:
|
||||
continue
|
||||
if images[i] is not None:
|
||||
if isinstance(images[i], dict):
|
||||
images[i] = images[i]["composite"]
|
||||
images[i] = images[i].convert("RGB")
|
||||
|
||||
if image_dict is None and images[0] is None:
|
||||
return None, "An Initial Image is required"
|
||||
if use_stencil == "scribble":
|
||||
image = image_dict["mask"].convert("RGB")
|
||||
elif isinstance(image_dict, PIL.Image.Image):
|
||||
if isinstance(image_dict, PIL.Image.Image):
|
||||
image = image_dict.convert("RGB")
|
||||
else:
|
||||
elif image_dict:
|
||||
image = image_dict["image"].convert("RGB")
|
||||
else:
|
||||
# TODO: enable t2i + controlnets
|
||||
image = None
|
||||
if image:
|
||||
image, _, _ = resize_stencil(image, width, height)
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
@@ -114,19 +144,18 @@ def img2img_inf(
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
use_stencil = None if use_stencil == "None" else use_stencil
|
||||
args.use_stencil = use_stencil
|
||||
if use_stencil is not None:
|
||||
args.scheduler = "DDIM"
|
||||
stencil_count = 0
|
||||
for stencil in stencils:
|
||||
if stencil is not None:
|
||||
stencil_count += 1
|
||||
if stencil_count > 0:
|
||||
args.hf_model_id = "runwayml/stable-diffusion-v1-5"
|
||||
image, width, height = resize_stencil(image)
|
||||
elif "Shark" in args.scheduler:
|
||||
print(
|
||||
f"Shark schedulers are not supported. Switching to EulerDiscrete "
|
||||
@@ -136,6 +165,7 @@ def img2img_inf(
|
||||
cpu_scheduling = not args.scheduler.startswith("Shark")
|
||||
args.precision = precision
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
print(stencils)
|
||||
new_config_obj = Config(
|
||||
"img2img",
|
||||
args.hf_model_id,
|
||||
@@ -148,13 +178,19 @@ def img2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=use_stencil,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=stencils,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
or any(
|
||||
global_obj.get_cfg_obj().stencils[idx] != stencil
|
||||
for idx, stencil in enumerate(stencils)
|
||||
)
|
||||
):
|
||||
print("clearing config because you changed something important")
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.batch_count = batch_count
|
||||
@@ -170,12 +206,12 @@ def img2img_inf(
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-2-1-base"
|
||||
else "runwayml/stable-diffusion-v1-5"
|
||||
)
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(args.scheduler)
|
||||
|
||||
if use_stencil is not None:
|
||||
if stencil_count > 0:
|
||||
args.use_tuned = False
|
||||
global_obj.set_sd_obj(
|
||||
StencilPipeline.from_pretrained(
|
||||
@@ -192,9 +228,10 @@ def img2img_inf(
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_stencil=use_stencil,
|
||||
stencils=stencils,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -216,6 +253,7 @@ def img2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -249,8 +287,11 @@ def img2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil=use_stencil,
|
||||
stencils,
|
||||
images,
|
||||
resample_type=resample_type,
|
||||
control_mode=control_mode,
|
||||
preprocessed_hints=preprocessed_hints,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
@@ -270,12 +311,18 @@ def img2img_inf(
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Image-to-Image", current_batch + 1, batch_count, batch_size
|
||||
)
|
||||
), stencils, images
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
return generated_imgs, text_output, "", stencils, images
|
||||
|
||||
|
||||
with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
# Stencils
|
||||
# TODO: Add more stencils here
|
||||
STENCIL_COUNT = 2
|
||||
stencils = gr.State([None] * STENCIL_COUNT)
|
||||
images = gr.State([None] * STENCIL_COUNT)
|
||||
preprocessed_hints = gr.State([None] * STENCIL_COUNT)
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -284,6 +331,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
@@ -291,6 +339,13 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
# TODO: make this import image prompt info if it exists
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
interactive=True,
|
||||
sources=["upload"],
|
||||
)
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_model_info = (
|
||||
@@ -337,104 +392,449 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
# TODO: make this import image prompt info if it exists
|
||||
img2img_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
height=300,
|
||||
)
|
||||
with gr.Accordion(label="Multistencil Options", open=False):
|
||||
choices = [
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
]
|
||||
|
||||
def cnet_preview(
|
||||
model,
|
||||
input_image,
|
||||
index,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
):
|
||||
if isinstance(input_image, PIL.Image.Image):
|
||||
img_dict = {
|
||||
"background": None,
|
||||
"layers": [None],
|
||||
"composite": input_image,
|
||||
}
|
||||
input_image = EditorValue(img_dict)
|
||||
images[index] = input_image
|
||||
if model:
|
||||
stencils[index] = model
|
||||
match model:
|
||||
case "canny":
|
||||
canny = CannyDetector()
|
||||
result = canny(
|
||||
np.array(input_image["composite"]),
|
||||
100,
|
||||
200,
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(
|
||||
result
|
||||
)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "openpose":
|
||||
openpose = OpenposeDetector()
|
||||
result = openpose(
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(
|
||||
result[0]
|
||||
)
|
||||
return (
|
||||
Image.fromarray(result[0]),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "zoedepth":
|
||||
zoedepth = ZoeDetector()
|
||||
result = zoedepth(
|
||||
np.array(input_image["composite"])
|
||||
)
|
||||
preprocessed_hints[index] = Image.fromarray(
|
||||
result
|
||||
)
|
||||
return (
|
||||
Image.fromarray(result),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case "scribble":
|
||||
preprocessed_hints[index] = input_image[
|
||||
"composite"
|
||||
]
|
||||
return (
|
||||
input_image["composite"],
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
case _:
|
||||
preprocessed_hints[index] = None
|
||||
return (
|
||||
None,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
)
|
||||
|
||||
def import_original(original_img, width, height):
|
||||
resized_img, _, _ = resize_stencil(
|
||||
original_img, width, height
|
||||
)
|
||||
img_dict = {
|
||||
"background": resized_img,
|
||||
"layers": [resized_img],
|
||||
"composite": None,
|
||||
}
|
||||
return gr.ImageEditor(
|
||||
value=EditorValue(img_dict),
|
||||
crop_size=(width, height),
|
||||
)
|
||||
|
||||
def create_canvas(width, height):
|
||||
data = Image.fromarray(
|
||||
np.zeros(
|
||||
shape=(height, width, 3),
|
||||
dtype=np.uint8,
|
||||
)
|
||||
+ 255
|
||||
)
|
||||
img_dict = {
|
||||
"background": data,
|
||||
"layers": [data],
|
||||
"composite": None,
|
||||
}
|
||||
return EditorValue(img_dict)
|
||||
|
||||
def update_cn_input(
|
||||
model,
|
||||
width,
|
||||
height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
index,
|
||||
):
|
||||
if model is None:
|
||||
stencils[index] = None
|
||||
images[index] = None
|
||||
preprocessed_hints[index] = None
|
||||
return [
|
||||
gr.ImageEditor(value=None, visible=False),
|
||||
gr.Image(value=None),
|
||||
gr.Slider(visible=False),
|
||||
gr.Slider(visible=False),
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=False),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
]
|
||||
elif model == "scribble":
|
||||
return [
|
||||
gr.ImageEditor(
|
||||
visible=True,
|
||||
interactive=True,
|
||||
show_label=False,
|
||||
image_mode="RGB",
|
||||
type="pil",
|
||||
brush=Brush(
|
||||
colors=["#000000"],
|
||||
color_mode="fixed",
|
||||
default_size=2,
|
||||
),
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=True,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=True, label="Canvas Width"),
|
||||
gr.Slider(visible=True, label="Canvas Height"),
|
||||
gr.Button(visible=True),
|
||||
gr.Button(visible=False),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
gr.ImageEditor(
|
||||
visible=True,
|
||||
image_mode="RGB",
|
||||
type="pil",
|
||||
interactive=True,
|
||||
),
|
||||
gr.Image(
|
||||
visible=True,
|
||||
show_label=False,
|
||||
interactive=True,
|
||||
show_download_button=False,
|
||||
),
|
||||
gr.Slider(visible=True, label="Input Width"),
|
||||
gr.Slider(visible=True, label="Input Height"),
|
||||
gr.Button(visible=False),
|
||||
gr.Button(visible=True),
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
]
|
||||
|
||||
with gr.Accordion(label="Stencil Options", open=False):
|
||||
with gr.Row():
|
||||
use_stencil = gr.Dropdown(
|
||||
elem_id="stencil_model",
|
||||
label="Stencil model",
|
||||
value="None",
|
||||
choices=[
|
||||
"None",
|
||||
"canny",
|
||||
"openpose",
|
||||
"scribble",
|
||||
"zoedepth",
|
||||
with gr.Column():
|
||||
cnet_1 = gr.Button(
|
||||
value="Generate controlnet input"
|
||||
)
|
||||
cnet_1_model = gr.Dropdown(
|
||||
label="Controlnet 1",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=768,
|
||||
value=512,
|
||||
step=8,
|
||||
visible=False,
|
||||
)
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=768,
|
||||
value=512,
|
||||
step=8,
|
||||
visible=False,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
use_input_img_1 = gr.Button(
|
||||
value="Use Original Image",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
cnet_1_image = gr.ImageEditor(
|
||||
visible=False,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
)
|
||||
cnet_1_output = gr.Image(
|
||||
value=None,
|
||||
visible=True,
|
||||
label="Preprocessed Hint",
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
use_input_img_1.click(
|
||||
import_original,
|
||||
[img2img_init_image, canvas_width, canvas_height],
|
||||
[cnet_1_image],
|
||||
)
|
||||
|
||||
cnet_1_model.change(
|
||||
fn=(
|
||||
lambda m, w, h, s, i, p: update_cn_input(
|
||||
m, w, h, s, i, p, 0
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_1_model,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_1_image,
|
||||
cnet_1_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
use_input_img_1,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
create_canvas,
|
||||
[canvas_width, canvas_height],
|
||||
[
|
||||
cnet_1_image,
|
||||
],
|
||||
)
|
||||
gr.on(
|
||||
triggers=[cnet_1.click],
|
||||
fn=(
|
||||
lambda a, b, s, i, p: cnet_preview(
|
||||
a, b, 0, s, i, p
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_1_model,
|
||||
cnet_1_image,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_1_output,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
|
||||
def show_canvas(choice):
|
||||
if choice == "scribble":
|
||||
return (
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Slider.update(visible=True),
|
||||
gr.Button.update(visible=True),
|
||||
)
|
||||
else:
|
||||
return (
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Slider.update(visible=False),
|
||||
gr.Button.update(visible=False),
|
||||
)
|
||||
|
||||
def create_canvas(w, h):
|
||||
return np.zeros(shape=(h, w, 3), dtype=np.uint8) + 255
|
||||
|
||||
with gr.Row():
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
with gr.Column():
|
||||
cnet_2 = gr.Button(
|
||||
value="Generate controlnet input"
|
||||
)
|
||||
cnet_2_model = gr.Dropdown(
|
||||
label="Controlnet 2",
|
||||
value="None",
|
||||
choices=choices,
|
||||
)
|
||||
canvas_width = gr.Slider(
|
||||
label="Canvas Width",
|
||||
minimum=256,
|
||||
maximum=768,
|
||||
value=512,
|
||||
step=8,
|
||||
visible=False,
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=768,
|
||||
value=512,
|
||||
step=8,
|
||||
visible=False,
|
||||
)
|
||||
make_canvas = gr.Button(
|
||||
value="Make Canvas!",
|
||||
visible=False,
|
||||
)
|
||||
use_input_img_2 = gr.Button(
|
||||
value="Use Original Image",
|
||||
visible=False,
|
||||
)
|
||||
cnet_2_image = gr.ImageEditor(
|
||||
visible=False,
|
||||
image_mode="RGB",
|
||||
interactive=True,
|
||||
type="pil",
|
||||
show_label=True,
|
||||
label="Input Image",
|
||||
)
|
||||
canvas_height = gr.Slider(
|
||||
label="Canvas Height",
|
||||
minimum=256,
|
||||
maximum=1024,
|
||||
value=512,
|
||||
step=1,
|
||||
visible=False,
|
||||
use_input_img_2.click(
|
||||
import_original,
|
||||
[img2img_init_image, canvas_width, canvas_height],
|
||||
[cnet_2_image],
|
||||
)
|
||||
create_button = gr.Button(
|
||||
label="Start",
|
||||
value="Open drawing canvas!",
|
||||
visible=False,
|
||||
)
|
||||
create_button.click(
|
||||
fn=create_canvas,
|
||||
inputs=[canvas_width, canvas_height],
|
||||
outputs=[img2img_init_image],
|
||||
)
|
||||
use_stencil.change(
|
||||
fn=show_canvas,
|
||||
inputs=use_stencil,
|
||||
outputs=[canvas_width, canvas_height, create_button],
|
||||
cnet_2_output = gr.Image(
|
||||
value=None,
|
||||
visible=True,
|
||||
label="Preprocessed Hint",
|
||||
interactive=True,
|
||||
)
|
||||
cnet_2_model.change(
|
||||
fn=(
|
||||
lambda m, w, h, s, i, p: update_cn_input(
|
||||
m, w, h, s, i, p, 1
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_2_model,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_2_image,
|
||||
cnet_2_output,
|
||||
canvas_width,
|
||||
canvas_height,
|
||||
make_canvas,
|
||||
use_input_img_2,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
make_canvas.click(
|
||||
create_canvas,
|
||||
[canvas_width, canvas_height],
|
||||
[
|
||||
cnet_2_image,
|
||||
],
|
||||
)
|
||||
cnet_2.click(
|
||||
fn=(
|
||||
lambda a, b, s, i, p: cnet_preview(
|
||||
a, b, 1, s, i, p
|
||||
)
|
||||
),
|
||||
inputs=[
|
||||
cnet_2_model,
|
||||
cnet_2_image,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
cnet_2_output,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
)
|
||||
control_mode = gr.Radio(
|
||||
choices=["Prompt", "Balanced", "Controlnet"],
|
||||
value="Balanced",
|
||||
label="Control Mode",
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
i2i_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
i2i_lora_info = f"LoRA Path: {i2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
allow_custom_value=True,
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=i2i_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
@@ -610,16 +1010,25 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
use_stencil,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
resample_type,
|
||||
control_mode,
|
||||
stencils,
|
||||
images,
|
||||
preprocessed_hints,
|
||||
],
|
||||
outputs=[
|
||||
img2img_gallery,
|
||||
std_output,
|
||||
img2img_status,
|
||||
stencils,
|
||||
images,
|
||||
],
|
||||
outputs=[img2img_gallery, std_output, img2img_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
)
|
||||
|
||||
@@ -638,3 +1047,18 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,15 @@ import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
import PIL.ImageOps
|
||||
from PIL import Image
|
||||
|
||||
from gradio.components.image_editor import (
|
||||
Brush,
|
||||
Eraser,
|
||||
EditorData,
|
||||
EditorValue,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -13,6 +21,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_paint_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
InpaintPipeline,
|
||||
@@ -35,11 +47,53 @@ init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
def set_image_states(editor_data):
|
||||
input_mask = editor_data["layers"][0]
|
||||
|
||||
# inpaint_inf wants white mask on black background (?), whilst ImageEditor
|
||||
# delivers black mask on transparent (0 opacity) background
|
||||
inference_mask = Image.new(
|
||||
mode="RGB", size=input_mask.size, color=(255, 255, 255)
|
||||
)
|
||||
inference_mask.paste(input_mask, input_mask)
|
||||
inference_mask = PIL.ImageOps.invert(inference_mask)
|
||||
|
||||
return (
|
||||
# we set the ImageEditor data again, because it likes to clear
|
||||
# the image layers (which include the mask) if the user hasn't
|
||||
# used the upload button, and we sent it an image
|
||||
# TODO: work out what is going wrong in that case so we don't have
|
||||
# to do this
|
||||
{
|
||||
"background": editor_data["background"],
|
||||
"layers": [input_mask],
|
||||
"composite": None,
|
||||
},
|
||||
editor_data["background"],
|
||||
input_mask,
|
||||
inference_mask,
|
||||
)
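
A small, self-contained sketch of the mask conversion performed in set_image_states above, using made-up 64x64 test data rather than real editor output (an illustration, not part of the diff): the editor layer arrives as black strokes on a transparent background, and the inpaint pipeline wants a white mask on black.

import PIL.ImageOps
from PIL import Image, ImageDraw

# fake "editor layer": black strokes on a fully transparent background
layer = Image.new("RGBA", (64, 64), (0, 0, 0, 0))
ImageDraw.Draw(layer).rectangle([16, 16, 48, 48], fill=(0, 0, 0, 255))

# paste the strokes onto white using the layer's own alpha as the mask, then invert
inference_mask = Image.new("RGB", layer.size, (255, 255, 255))
inference_mask.paste(layer, layer)
inference_mask = PIL.ImageOps.invert(inference_mask)

print(inference_mask.getpixel((32, 32)))  # (255, 255, 255): painted area is now white
print(inference_mask.getpixel((0, 0)))    # (0, 0, 0): untouched area stays black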
|
||||
|
||||
|
||||
def reload_image_editor(editor_image, editor_mask):
|
||||
# we set the ImageEditor data again, because it likes to clear
|
||||
# the image layers (which include the mask) if the user hasn't
|
||||
# used the upload button, and we sent it the image
|
||||
# TODO: work out what is going wrong in that case so we don't have
|
||||
# to do this
|
||||
return {
|
||||
"background": editor_image,
|
||||
"layers": [editor_mask],
|
||||
"composite": None,
|
||||
}
|
||||
|
||||
|
||||
# Exposed to UI.
|
||||
def inpaint_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
image_dict,
|
||||
image,
|
||||
mask_image,
|
||||
height: int,
|
||||
width: int,
|
||||
inpaint_full_res: bool,
|
||||
@@ -58,7 +112,7 @@ def inpaint_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: int,
|
||||
):
|
||||
@@ -99,9 +153,8 @@ def inpaint_inf(
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -120,7 +173,8 @@ def inpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -164,6 +218,7 @@ def inpaint_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -173,8 +228,6 @@ def inpaint_inf(
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
image = image_dict["image"]
|
||||
mask_image = image_dict["mask"]
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
@@ -221,6 +274,9 @@ def inpaint_inf(
|
||||
|
||||
|
||||
with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
editor_image = gr.State()
|
||||
editor_mask = gr.State()
|
||||
inference_mask = gr.State()
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -229,6 +285,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
@@ -236,6 +293,16 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
inpaint_init_image = gr.Sketchpad(
|
||||
label="Masked Image",
|
||||
type="pil",
|
||||
sources=("clipboard", "upload"),
|
||||
interactive=True,
|
||||
brush=Brush(
|
||||
colors=["#000000"],
|
||||
color_mode="fixed",
|
||||
),
|
||||
)
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_model_info = (
|
||||
@@ -285,39 +352,32 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
inpaint_init_image = gr.Image(
|
||||
label="Masked Image",
|
||||
source="upload",
|
||||
tool="sketch",
|
||||
type="pil",
|
||||
height=350,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
inpaint_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=inpaint_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
@@ -441,6 +501,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
elem_id="gallery",
|
||||
columns=[2],
|
||||
object_fit="contain",
|
||||
# TODO: Re-enable download when fixed in Gradio
|
||||
show_download_button=False,
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"{inpaint_model_info}\n"
|
||||
@@ -477,7 +539,8 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
inpaint_init_image,
|
||||
editor_image,
|
||||
inference_mask,
|
||||
height,
|
||||
width,
|
||||
inpaint_full_res,
|
||||
@@ -496,7 +559,7 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
@@ -507,14 +570,64 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
|
||||
fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=inpaint_status,
|
||||
show_progress="none",
|
||||
)
|
||||
set_image_states_args = dict(
|
||||
fn=set_image_states,
|
||||
inputs=[inpaint_init_image],
|
||||
outputs=[
|
||||
inpaint_init_image,
|
||||
editor_image,
|
||||
editor_mask,
|
||||
inference_mask,
|
||||
],
|
||||
show_progress="none",
|
||||
)
|
||||
reload_image_editor_args = dict(
|
||||
fn=reload_image_editor,
|
||||
inputs=[editor_image, editor_mask],
|
||||
outputs=[inpaint_init_image],
|
||||
show_progress="none",
|
||||
)
|
||||
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
# all these trigger generation
|
||||
prompt_submit = (
|
||||
prompt.submit(**set_image_states_args)
|
||||
.then(**status_kwargs)
|
||||
.then(**kwargs)
|
||||
.then(**reload_image_editor_args)
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = (
|
||||
negative_prompt.submit(**set_image_states_args)
|
||||
.then(**status_kwargs)
|
||||
.then(**kwargs)
|
||||
.then(**reload_image_editor_args)
|
||||
)
|
||||
generate_click = (
|
||||
stable_diffusion.click(**set_image_states_args)
|
||||
.then(**status_kwargs)
|
||||
.then(**kwargs)
|
||||
.then(**reload_image_editor_args)
|
||||
)
|
||||
|
||||
# Attempts to cancel generation
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
# Updates LoRA information when one is selected
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
49  apps/stable_diffusion/web/ui/js/sd_gradio_workarounds.js  Normal file
@@ -0,0 +1,49 @@
// workaround gradio after 4.7, not applying any @media rules from the custom .css file

() => {
    console.log(`innerWidth: ${window.innerWidth}`)

    // 1536px rules

    const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')

    function handleWidth1536(event) {
        // display in full width for desktop devices
        document.querySelectorAll(".gradio-container")
            .forEach( (node) => {
                if (event.matches) {
                    node.classList.add("gradio-container-size-full");
                } else {
                    node.classList.remove("gradio-container-size-full")
                }
            });
    }

    mediaQuery1536.addEventListener("change", handleWidth1536);
    mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));

    // 1921px rules

    const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')

    function handleWidth1921(event) {
        /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
        /* Limit height to 768px_height + 2px_margin_height for the thumbnails */
        document.querySelectorAll("#gallery")
            .forEach( (node) => {
                if (event.matches) {
                    node.classList.add("gallery-force-height768");
                    node.classList.add("gallery-limit-height768");
                } else {
                    node.classList.remove("gallery-force-height768");
                    node.classList.remove("gallery-limit-height768");
                }
            });
    }

    mediaQuery1921.addEventListener("change", handleWidth1921);
    mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));
}

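For orientation, a script like the one above runs once when the page loads. Below is a hedged Python sketch (an assumption, not necessarily how SHARK-Studio actually wires it) of attaching such a workaround file to a Gradio Blocks app via the js argument available in recent Gradio releases.

import gradio as gr

# read the workaround script shipped with the web UI (path from this repo's layout)
with open(
    "apps/stable_diffusion/web/ui/js/sd_gradio_workarounds.js", encoding="utf-8"
) as f:
    workarounds_js = f.read()

# assumption: a Gradio version where gr.Blocks accepts a `js` function string
# that is executed in the browser on page load
with gr.Blocks(js=workarounds_js) as demo:
    gr.Markdown("Page using the @media workaround script above.")

if __name__ == "__main__":
    demo.launch()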
@@ -23,6 +23,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
@@ -237,9 +238,7 @@ with gr.Blocks(title="Lora Training") as lora_train_web:
|
||||
max_length,
|
||||
training_images_dir,
|
||||
output_loc,
|
||||
get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
),
|
||||
get_custom_vae_or_lora_weights(lora_weights, "lora"),
|
||||
],
|
||||
outputs=[std_output],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
|
||||
@@ -104,8 +104,8 @@ with gr.Blocks() as model_web:
|
||||
civit_models = gr.Gallery(
|
||||
label="Civitai Model Gallery",
|
||||
value=None,
|
||||
interactive=True,
|
||||
visible=False,
|
||||
show_download_button=False,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False) as sendto_btns:
|
||||
|
||||
@@ -3,9 +3,11 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from fastapi.exceptions import HTTPException
|
||||
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -61,7 +63,7 @@ def outpaint_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
@@ -101,9 +103,8 @@ def outpaint_inf(
|
||||
if custom_vae != "None":
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
@@ -122,7 +123,8 @@ def outpaint_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -164,6 +166,7 @@ def outpaint_inf(
|
||||
args.use_base_vae,
|
||||
args.use_tuned,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -237,6 +240,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
@@ -244,6 +248,9 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
outpaint_init_image = gr.Image(
|
||||
label="Input Image", type="pil", sources=["upload"]
|
||||
)
|
||||
with gr.Row():
|
||||
outpaint_model_info = (
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
@@ -291,37 +298,32 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
outpaint_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
height=300,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
outpaint_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=outpaint_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
@@ -524,7 +526,7 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
@@ -546,3 +548,18 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -80,27 +80,26 @@ with gr.Blocks() as outputgallery_web:
|
||||
label="Getting subdirectories...",
|
||||
value=nod_logo,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
visible=True,
|
||||
show_label=True,
|
||||
elem_id="top_logo",
|
||||
elem_classes="logo_centered",
|
||||
)
|
||||
|
||||
gallery = gr.Gallery(
|
||||
label="",
|
||||
value=gallery_files.value,
|
||||
visible=False,
|
||||
show_label=True,
|
||||
columns=2,
|
||||
columns=4,
|
||||
)
|
||||
|
||||
with gr.Column(scale=4):
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Group():
|
||||
with gr.Row(elem_id="output_subdir_container"):
|
||||
with gr.Column(
|
||||
scale=15,
|
||||
min_width=160,
|
||||
elem_id="output_subdir_container",
|
||||
):
|
||||
subdirectories = gr.Dropdown(
|
||||
label=f"Subdirectories of {output_dir}",
|
||||
@@ -108,7 +107,7 @@ with gr.Blocks() as outputgallery_web:
|
||||
choices=subdirectory_paths.value,
|
||||
value="",
|
||||
interactive=True,
|
||||
elem_classes="dropdown_no_container",
|
||||
# elem_classes="dropdown_no_container",
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column(
|
||||
@@ -148,10 +147,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
) as parameters_accordian:
|
||||
image_parameters = gr.DataFrame(
|
||||
headers=["Parameter", "Value"],
|
||||
col_count=2,
|
||||
col_count=(2, "fixed"),
|
||||
row_count=(1, "fixed"),
|
||||
wrap=True,
|
||||
elem_classes="output_parameters_dataframe",
|
||||
value=[["Status", "No image selected"]],
|
||||
interactive=False,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="Send To", open=True):
|
||||
@@ -162,6 +163,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
outputgallery_sendto_txt2img_sdxl = gr.Button(
|
||||
value="Txt2Img XL",
|
||||
interactive=False,
|
||||
elem_classes="outputgallery_sendto",
|
||||
size="sm",
|
||||
)
|
||||
|
||||
outputgallery_sendto_img2img = gr.Button(
|
||||
value="Img2Img",
|
||||
@@ -195,15 +202,18 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
def on_clear_gallery():
|
||||
return [
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=[],
|
||||
visible=False,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
visible=True,
|
||||
),
|
||||
]
|
||||
|
||||
def on_image_columns_change(columns):
|
||||
return gr.Gallery(columns=columns)
|
||||
|
||||
def on_select_subdir(subdir) -> list:
|
||||
# evt.value is the subdirectory name
|
||||
new_images = outputgallery_filenames(subdir)
|
||||
@@ -212,12 +222,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
)
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -251,16 +261,16 @@ with gr.Blocks() as outputgallery_web:
|
||||
)
|
||||
|
||||
return [
|
||||
gr.Dropdown.update(
|
||||
gr.Dropdown(
|
||||
choices=refreshed_subdirs,
|
||||
value=new_subdir,
|
||||
),
|
||||
refreshed_subdirs,
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images, label=new_label, visible=len(new_images) > 0
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -286,12 +296,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
|
||||
return [
|
||||
new_images,
|
||||
gr.Gallery.update(
|
||||
gr.Gallery(
|
||||
value=new_images,
|
||||
label=new_label,
|
||||
visible=len(new_images) > 0,
|
||||
),
|
||||
gr.Image.update(
|
||||
gr.Image(
|
||||
label=new_label,
|
||||
visible=len(new_images) == 0,
|
||||
),
|
||||
@@ -316,12 +326,18 @@ with gr.Blocks() as outputgallery_web:
|
||||
else:
|
||||
return [
|
||||
filename,
|
||||
list(map(list, params["parameters"].items())),
|
||||
gr.DataFrame(
|
||||
value=list(map(list, params["parameters"].items())),
|
||||
row_count=(len(params["parameters"]), "fixed"),
|
||||
),
|
||||
]
|
||||
|
||||
return [
|
||||
filename,
|
||||
[["Status", "No parameters found"]],
|
||||
gr.DataFrame(
|
||||
value=[["Status", "No parameters found"]],
|
||||
row_count=(1, "fixed"),
|
||||
),
|
||||
]
|
||||
|
||||
def on_outputgallery_filename_change(filename: str) -> list:
|
||||
@@ -329,12 +345,12 @@ with gr.Blocks() as outputgallery_web:
|
||||
return [
|
||||
# disable or enable each of the sendto button based on whether
|
||||
# an image is selected
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button.update(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
gr.Button(interactive=exists),
|
||||
]
|
||||
|
||||
# The first time our tab is selected we need to do an initial refresh
|
||||
@@ -365,53 +381,6 @@ with gr.Blocks() as outputgallery_web:
|
||||
gr.update(),
|
||||
)
|
||||
|
||||
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
    return f"""
        (new_cols) => {{
            setTimeout(() => {{
                required_style = "auto ".repeat(new_cols).trim();
                gallery = document.querySelector('#outputgallery_gallery .grid-container');
                if (gallery) {{
                    gallery.style.gridTemplateColumns = required_style
                }}
            }}, {timeout_length});
            return []; // prevents console error from gradio
        }}
    """
|
||||
|
||||
# --- Wire handlers up to the actions
|
||||
|
||||
# Many actions reset the number of columns shown in the gallery on the
|
||||
# browser end, so we have to set them back to what we think they should
|
||||
# be after the initial action.
|
||||
#
|
||||
# None of the actions on this tab trigger inference, and we want the
|
||||
# user to be able to do them whilst other tabs have ongoing inference
|
||||
# running. Waiting in the queue behind inference jobs would mean the UI
|
||||
# can't fully respond until the inference tasks complete,
|
||||
# hence queue=False on all of these.
|
||||
set_gallery_columns_immediate = dict(
|
||||
fn=None,
|
||||
inputs=[image_columns],
|
||||
# gradio blanks the UI on Chrome on Linux on gallery select if
|
||||
# I don't put an output here
|
||||
outputs=[dev_null],
|
||||
_js=js_set_columns_in_browser(0),
|
||||
queue=False,
|
||||
)
|
||||
|
||||
# setting columns after selecting a gallery item needs a real
|
||||
# timeout length for the number of columns to actually be applied.
|
||||
# Not really sure why, maybe something has to finish animating?
|
||||
set_gallery_columns_delayed = dict(
|
||||
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
|
||||
)
|
||||
|
||||
# clearing images when we need to completely change what's in the
|
||||
# gallery avoids current images being shown replacing piecemeal and
|
||||
# prevents weirdness and errors if the user selects an image during the
|
||||
@@ -423,38 +392,42 @@ with gr.Blocks() as outputgallery_web:
|
||||
queue=False,
|
||||
)
|
||||
|
||||
image_columns.change(**set_gallery_columns_immediate)
|
||||
|
||||
subdirectories.select(**clear_gallery).then(
|
||||
on_select_subdir,
|
||||
[subdirectories],
|
||||
[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
)
|
||||
|
||||
open_subdir.click(
|
||||
on_open_subdir, inputs=[subdirectories], queue=False
|
||||
).then(**set_gallery_columns_immediate)
|
||||
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
|
||||
|
||||
refresh.click(**clear_gallery).then(
|
||||
on_refresh,
|
||||
[subdirectories],
|
||||
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
)
|
||||
|
||||
image_columns.change(
|
||||
fn=on_image_columns_change,
|
||||
inputs=[image_columns],
|
||||
outputs=[gallery],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
gallery.select(
|
||||
on_select_image,
|
||||
[gallery_files],
|
||||
[outputgallery_filename, image_parameters],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_delayed)
|
||||
)
|
||||
|
||||
outputgallery_filename.change(
|
||||
on_outputgallery_filename_change,
|
||||
[outputgallery_filename],
|
||||
[
|
||||
outputgallery_sendto_txt2img,
|
||||
outputgallery_sendto_txt2img_sdxl,
|
||||
outputgallery_sendto_img2img,
|
||||
outputgallery_sendto_inpaint,
|
||||
outputgallery_sendto_outpaint,
|
||||
@@ -477,16 +450,17 @@ with gr.Blocks() as outputgallery_web:
|
||||
open_subdir,
|
||||
],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
)
|
||||
|
||||
# We should have been passed a list of components on other tabs that update
|
||||
# when a new image has generated on that tab, so set things up so the user
|
||||
# will see that new image if they are looking at today's subdirectory
|
||||
def outputgallery_watch(components: gr.Textbox):
|
||||
def outputgallery_watch(components: gr.Textbox, queued_components=[]):
|
||||
for component in components:
|
||||
component.change(
|
||||
on_new_image,
|
||||
inputs=[subdirectories, subdirectory_paths, component],
|
||||
outputs=[gallery_files, gallery, logo],
|
||||
queue=False,
|
||||
).then(**set_gallery_columns_immediate)
|
||||
queue=component in queued_components,
|
||||
show_progress="none",
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.utils import available_devices
|
||||
from shark.iree_utils.compile_utils import clean_device_info
|
||||
from datetime import datetime as dt
|
||||
import json
|
||||
import sys
|
||||
@@ -132,27 +133,6 @@ def get_default_config():
|
||||
c.split_into_layers()
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
    # return appropriate device and device_id for consumption by LLM pipeline
    # Multiple devices only supported for vulkan and rocm (as of now).
    # default device must be selected for all others

    device_id = None
    device = (
        raw_device
        if "=>" not in raw_device
        else raw_device.split("=>")[1].strip()
    )
    if "://" in device:
        device, device_id = device.split("://")
        device_id = int(device_id)  # using device index in webui

    if device not in ["rocm", "vulkan"]:
        device_id = None

    return device, device_id
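
A few hedged examples of what the helper above returns; the exact device strings the web UI passes in are assumptions here (the "<name> => <driver>://<index>" shape is inferred from the parsing). In this diff the local copy is dropped in favour of the shared implementation imported from shark.iree_utils.compile_utils.

print(clean_device_info("AMD Radeon RX 7900 XTX => rocm://1"))  # ('rocm', 1)
print(clean_device_info("vulkan://0"))                          # ('vulkan', 0)
print(clean_device_info("NVIDIA GeForce RTX 4090 => cuda"))     # ('cuda', None)
print(clean_device_info("cpu-task"))                            # ('cpu-task', None)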
|
||||
|
||||
|
||||
model_vmfb_key = ""
|
||||
|
||||
|
||||
@@ -161,7 +141,9 @@ def chat(
|
||||
prompt_prefix,
|
||||
history,
|
||||
model,
|
||||
device,
|
||||
backend,
|
||||
devices,
|
||||
sharded,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
@@ -173,7 +155,8 @@ def chat(
|
||||
global vicuna_model
|
||||
|
||||
model_name, model_path = list(map(str.strip, model.split("=>")))
|
||||
device, device_id = clean_device_info(device)
|
||||
device, device_id = clean_device_info(devices[0])
|
||||
no_of_devices = len(devices)
|
||||
|
||||
from apps.language_models.scripts.vicuna import ShardedVicuna
|
||||
from apps.language_models.scripts.vicuna import UnshardedVicuna
|
||||
@@ -193,6 +176,11 @@ def chat(
|
||||
get_vulkan_target_triple,
|
||||
)
|
||||
|
||||
_extra_args = _extra_args + [
|
||||
"--iree-opt-const-eval=false",
|
||||
"--iree-opt-data-tiling=false",
|
||||
]
|
||||
|
||||
if device == "vulkan":
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if vulkan_target_triple == "":
|
||||
@@ -234,7 +222,7 @@ def chat(
|
||||
)
|
||||
print(f"extra args = {_extra_args}")
|
||||
|
||||
if model_name == "vicuna4":
|
||||
if sharded:
|
||||
vicuna_model = ShardedVicuna(
|
||||
model_name,
|
||||
hf_model_path=model_path,
|
||||
@@ -243,6 +231,7 @@ def chat(
|
||||
max_num_tokens=max_toks,
|
||||
compressed=True,
|
||||
extra_args_cmd=_extra_args,
|
||||
n_devices=no_of_devices,
|
||||
)
|
||||
else:
|
||||
# if config_file is None:
|
||||
@@ -270,13 +259,14 @@ def chat(
|
||||
total_time_ms = 0.001 # In order to avoid divide by zero error
|
||||
prefill_time = 0
|
||||
is_first = True
|
||||
for text, msg, exec_time in progress.tqdm(
|
||||
vicuna_model.generate(prompt, cli=cli),
|
||||
desc="generating response",
|
||||
):
|
||||
# for text, msg, exec_time in progress.tqdm(
|
||||
# vicuna_model.generate(prompt, cli=cli),
|
||||
# desc="generating response",
|
||||
# ):
|
||||
for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli):
|
||||
if msg is None:
|
||||
if is_first:
|
||||
prefill_time = exec_time
|
||||
prefill_time = exec_time / 1000
|
||||
is_first = False
|
||||
else:
|
||||
total_time_ms += exec_time
|
||||
@@ -397,6 +387,16 @@ def view_json_file(file_obj):
|
||||
return content
|
||||
|
||||
|
||||
filtered_devices = dict()
|
||||
|
||||
|
||||
def change_backend(backend):
|
||||
new_choices = gr.Dropdown(
|
||||
choices=filtered_devices[backend], label=f"{backend} devices"
|
||||
)
|
||||
return new_choices
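
Further down in this file the chatbot UI buckets the reported devices per backend by simple substring matching, and change_backend above swaps the dropdown choices accordingly. A self-contained sketch of that filtering follows, with made-up device strings (assumptions, not real probe output).

supported_devices = [
    "AMD Radeon RX 7900 XTX => rocm://0",
    "AMD Radeon RX 7900 XTX => vulkan://0",
    "NVIDIA GeForce RTX 4090 => cuda",
    "cpu-task",
]
backend_list = ["cpu", "cuda", "vulkan", "rocm"]
filtered_devices = {b: [d for d in supported_devices if b in d] for b in backend_list}
print(filtered_devices["vulkan"])  # ['AMD Radeon RX 7900 XTX => vulkan://0']
print(filtered_devices["cpu"])     # ['cpu-task']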
|
||||
|
||||
|
||||
with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
with gr.Row():
|
||||
model_choices = list(
|
||||
@@ -413,15 +413,22 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
# show cpu-task device first in list for chatbot
|
||||
supported_devices = supported_devices[-1:] + supported_devices[:-1]
|
||||
supported_devices = [x for x in supported_devices if "sync" not in x]
|
||||
backend_list = ["cpu", "cuda", "vulkan", "rocm"]
|
||||
for x in backend_list:
|
||||
filtered_devices[x] = [y for y in supported_devices if x in y]
|
||||
print(filtered_devices)
|
||||
|
||||
backend = gr.Radio(
|
||||
label="backend",
|
||||
value="cpu",
|
||||
choices=backend_list,
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
label="Device",
|
||||
value=supported_devices[0]
|
||||
if enabled
|
||||
else "Only CUDA Supported for now",
|
||||
choices=supported_devices,
|
||||
interactive=enabled,
|
||||
label="cpu devices",
|
||||
choices=filtered_devices["cpu"],
|
||||
interactive=True,
|
||||
allow_custom_value=True,
|
||||
# multiselect=True,
|
||||
multiselect=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
@@ -445,18 +452,23 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
sharded = gr.Checkbox(
|
||||
label="Shard Model",
|
||||
value=False,
|
||||
interactive=True,
|
||||
)
|
||||
|
||||
with gr.Row(visible=False):
|
||||
with gr.Group():
|
||||
config_file = gr.File(
|
||||
label="Upload sharding configuration", visible=False
|
||||
)
|
||||
json_view_button = gr.Button(label="View as JSON", visible=False)
|
||||
json_view = gr.JSON(interactive=True, visible=False)
|
||||
json_view_button = gr.Button(value="View as JSON", visible=False)
|
||||
json_view = gr.JSON(visible=False)
|
||||
json_view_button.click(
|
||||
fn=view_json_file, inputs=[config_file], outputs=[json_view]
|
||||
)
|
||||
chatbot = gr.Chatbot(height=500)
|
||||
chatbot = gr.Chatbot(elem_id="chatbot")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
msg = gr.Textbox(
|
||||
@@ -472,6 +484,13 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
stop = gr.Button("Stop", interactive=enabled)
|
||||
clear = gr.Button("Clear", interactive=enabled)
|
||||
|
||||
backend.change(
|
||||
fn=change_backend,
|
||||
inputs=[backend],
|
||||
outputs=[device],
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
submit_event = msg.submit(
|
||||
fn=user,
|
||||
inputs=[msg, chatbot],
|
||||
@@ -484,7 +503,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
backend,
|
||||
device,
|
||||
sharded,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
@@ -505,7 +526,9 @@ with gr.Blocks(title="Chatbot") as stablelm_chat:
|
||||
prompt_prefix,
|
||||
chatbot,
|
||||
model,
|
||||
backend,
|
||||
device,
|
||||
sharded,
|
||||
precision,
|
||||
download_vmfb,
|
||||
config_file,
|
||||
|
||||
661  apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py  Normal file
@@ -0,0 +1,661 @@
|
||||
import os
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from math import ceil
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
get_custom_model_path,
|
||||
get_custom_model_files,
|
||||
scheduler_list,
|
||||
predefined_sdxl_models,
|
||||
cancel_sd,
|
||||
set_model_default_configs,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
Text2ImageSDXLPipeline,
|
||||
get_schedulers,
|
||||
set_init_device_flags,
|
||||
utils,
|
||||
save_output_img,
|
||||
prompt_examples,
|
||||
Image2ImagePipeline,
|
||||
)
|
||||
from apps.stable_diffusion.src.utils import (
|
||||
get_generated_imgs_path,
|
||||
get_generation_text_info,
|
||||
)
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_iree_metal_target_platform = args.iree_metal_target_platform
|
||||
init_use_tuned = args.use_tuned
|
||||
init_import_mlir = args.import_mlir
|
||||
|
||||
|
||||
def txt2img_sdxl_inf(
|
||||
prompt: str,
|
||||
negative_prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
steps: int,
|
||||
guidance_scale: float,
|
||||
seed: str | int,
|
||||
batch_count: int,
|
||||
batch_size: int,
|
||||
scheduler: str,
|
||||
model_id: str,
|
||||
custom_vae: str,
|
||||
precision: str,
|
||||
device: str,
|
||||
max_length: int,
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
get_custom_model_pathfile,
|
||||
get_custom_vae_or_lora_weights,
|
||||
Config,
|
||||
)
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
)
|
||||
|
||||
if precision != "fp16":
|
||||
print("currently we support fp16 for SDXL")
|
||||
precision = "fp16"
|
||||
|
||||
args.prompts = [prompt]
|
||||
args.negative_prompts = [negative_prompt]
|
||||
args.guidance_scale = guidance_scale
|
||||
args.steps = steps
|
||||
args.scheduler = scheduler
|
||||
args.ondemand = ondemand
|
||||
|
||||
# set ckpt_loc and hf_model_id.
|
||||
args.ckpt_loc = ""
|
||||
args.hf_model_id = ""
|
||||
args.custom_vae = ""
|
||||
|
||||
# .safetensor or .chkpt on the custom model path
|
||||
if model_id in get_custom_model_files():
|
||||
args.ckpt_loc = get_custom_model_pathfile(model_id)
|
||||
# civitai download
|
||||
elif "civitai" in model_id:
|
||||
args.ckpt_loc = model_id
|
||||
# either predefined or huggingface
|
||||
else:
|
||||
args.hf_model_id = model_id
|
||||
|
||||
if custom_vae:
|
||||
args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae")
|
||||
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
new_config_obj = Config(
|
||||
"txt2img_sdxl",
|
||||
args.hf_model_id,
|
||||
args.ckpt_loc,
|
||||
args.custom_vae,
|
||||
precision,
|
||||
batch_size,
|
||||
max_length,
|
||||
height,
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=None,
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
not global_obj.get_sd_obj()
|
||||
or global_obj.get_cfg_obj() != new_config_obj
|
||||
):
|
||||
global_obj.clear_cache()
|
||||
global_obj.set_cfg_obj(new_config_obj)
|
||||
args.precision = precision
|
||||
args.batch_count = batch_count
|
||||
args.batch_size = batch_size
|
||||
args.max_length = max_length
|
||||
args.height = height
|
||||
args.width = width
|
||||
args.device = device.split("=>", 1)[1].strip()
|
||||
args.iree_vulkan_target_triple = init_iree_vulkan_target_triple
|
||||
args.iree_metal_target_platform = init_iree_metal_target_platform
|
||||
args.use_tuned = init_use_tuned
|
||||
args.import_mlir = init_import_mlir
|
||||
args.img_path = None
|
||||
set_init_device_flags()
|
||||
model_id = (
|
||||
args.hf_model_id
|
||||
if args.hf_model_id
|
||||
else "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
)
|
||||
|
||||
global_obj.set_schedulers(get_schedulers(model_id))
|
||||
scheduler_obj = global_obj.get_scheduler(scheduler)
|
||||
if global_obj.get_cfg_obj().ondemand:
|
||||
print("Running txt2img in memory efficient mode.")
|
||||
global_obj.set_sd_obj(
|
||||
Text2ImageSDXLPipeline.from_pretrained(
|
||||
scheduler=scheduler_obj,
|
||||
import_mlir=args.import_mlir,
|
||||
model_id=args.hf_model_id,
|
||||
ckpt_loc=args.ckpt_loc,
|
||||
precision=precision,
|
||||
max_length=max_length,
|
||||
batch_size=batch_size,
|
||||
height=height,
|
||||
width=width,
|
||||
use_base_vae=args.use_base_vae,
|
||||
use_tuned=args.use_tuned,
|
||||
custom_vae=args.custom_vae,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
use_quantize=args.use_quantize,
|
||||
ondemand=global_obj.get_cfg_obj().ondemand,
|
||||
)
|
||||
)
|
||||
|
||||
global_obj.set_sd_scheduler(scheduler)
|
||||
|
||||
start_time = time.time()
|
||||
global_obj.get_sd_obj().log = ""
|
||||
generated_imgs = []
|
||||
text_output = ""
|
||||
try:
|
||||
seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds)
|
||||
except TypeError as error:
|
||||
raise gr.Error(str(error)) from None
|
||||
|
||||
for current_batch in range(batch_count):
|
||||
out_imgs = global_obj.get_sd_obj().generate_images(
|
||||
prompt,
|
||||
negative_prompt,
|
||||
batch_size,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seeds[current_batch],
|
||||
args.max_length,
|
||||
dtype,
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
text_output = get_generation_text_info(
|
||||
seeds[: current_batch + 1], device
|
||||
)
|
||||
text_output += "\n" + global_obj.get_sd_obj().log
|
||||
text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec"
|
||||
|
||||
if global_obj.get_sd_status() == SD_STATE_CANCEL:
|
||||
break
|
||||
else:
|
||||
save_output_img(out_imgs[0], seeds[current_batch])
|
||||
generated_imgs.extend(out_imgs)
|
||||
yield generated_imgs, text_output, status_label(
|
||||
"Text-to-Image-SDXL",
|
||||
current_batch + 1,
|
||||
batch_count,
|
||||
batch_size,
|
||||
)
|
||||
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
theme = gr.themes.Glass(
|
||||
primary_hue="slate",
|
||||
secondary_hue="gray",
|
||||
)
|
||||
|
||||
with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, elem_id="demo_title_outer"):
|
||||
gr.Image(
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
)
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Row():
|
||||
t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_sdxl_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
choices=predefined_sdxl_models
|
||||
+ get_custom_model_files(
|
||||
custom_checkpoint_type="sdxl"
|
||||
),
|
||||
allow_custom_value=True,
|
||||
scale=11,
|
||||
)
|
||||
t2i_sdxl_vae_info = (
|
||||
str(get_custom_model_path("vae"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_sdxl_vae_info = (
|
||||
f"VAE Path: {t2i_sdxl_vae_info}"
|
||||
)
|
||||
custom_vae = gr.Dropdown(
|
||||
label=f"VAE Models",
|
||||
info=t2i_sdxl_vae_info,
|
||||
elem_id="custom_model",
|
||||
value="madebyollin/sdxl-vae-fp16-fix",
|
||||
choices=[
|
||||
None,
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
]
|
||||
+ get_custom_model_files(
|
||||
"vae", custom_checkpoint_type="sdxl"
|
||||
),
|
||||
allow_custom_value=True,
|
||||
scale=4,
|
||||
)
|
||||
txt2img_sdxl_png_info_img = gr.Image(
|
||||
scale=1,
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
visible=True,
|
||||
sources=["upload"],
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
txt2img_sdxl_autogen = gr.Checkbox(
|
||||
label="Auto-Generate Images",
|
||||
value=False,
|
||||
visible=False,
|
||||
)
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
show_copy_button=True,
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value="EulerDiscrete",
|
||||
choices=[
|
||||
"DDIM",
|
||||
"EulerAncestralDiscrete",
|
||||
"EulerDiscrete",
|
||||
"LCMScheduler",
|
||||
],
|
||||
allow_custom_value=True,
|
||||
visible=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=768,
|
||||
step=256,
|
||||
label="Height",
|
||||
visible=True,
|
||||
interactive=True,
|
||||
)
|
||||
width = gr.Slider(
|
||||
512,
|
||||
1024,
|
||||
value=768,
|
||||
step=256,
|
||||
label="Width",
|
||||
visible=True,
|
||||
interactive=True,
|
||||
)
|
||||
precision = gr.Radio(
|
||||
label="Precision",
|
||||
value="fp16",
|
||||
choices=[
|
||||
"fp16",
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
max_length = gr.Radio(
|
||||
label="Max Length",
|
||||
value=77,
|
||||
choices=[
|
||||
64,
|
||||
77,
|
||||
],
|
||||
visible=False,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
step=0.1,
|
||||
label="Guidance Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
batch_size = gr.Slider(
|
||||
1,
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
interactive=False,
|
||||
visible=False,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Accordion(label="Prompt Examples!", open=False):
|
||||
ex = gr.Examples(
|
||||
examples=prompt_examples,
|
||||
inputs=prompt,
|
||||
cache_examples=False,
|
||||
elem_id="prompt_examples",
|
||||
)
|
||||
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Group():
|
||||
txt2img_sdxl_gallery = gr.Gallery(
|
||||
label="Generated images",
|
||||
show_label=False,
|
||||
elem_id="gallery",
|
||||
columns=[2],
|
||||
object_fit="scale_down",
|
||||
)
|
||||
std_output = gr.Textbox(
|
||||
value=f"{t2i_sdxl_model_info}\n"
|
||||
f"Images will be saved at "
|
||||
f"{get_generated_imgs_path()}",
|
||||
lines=1,
|
||||
elem_id="std_output",
|
||||
show_label=False,
|
||||
)
|
||||
txt2img_sdxl_status = gr.Textbox(visible=False)
|
||||
with gr.Row():
|
||||
stable_diffusion = gr.Button("Generate Image(s)")
|
||||
random_seed = gr.Button("Randomize Seed")
|
||||
random_seed.click(
|
||||
lambda: -1,
|
||||
inputs=[],
|
||||
outputs=[seed],
|
||||
queue=False,
|
||||
)
|
||||
stop_batch = gr.Button("Stop Batch")
|
||||
with gr.Row():
|
||||
txt2img_sdxl_sendto_img2img = gr.Button(
|
||||
value="Send To Img2Img",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_inpaint = gr.Button(
|
||||
value="Send To Inpaint",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_outpaint = gr.Button(
|
||||
value="Send To Outpaint",
|
||||
visible=False,
|
||||
)
|
||||
txt2img_sdxl_sendto_upscaler = gr.Button(
|
||||
value="Send To Upscaler",
|
||||
visible=False,
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=txt2img_sdxl_inf,
|
||||
inputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
seed,
|
||||
batch_count,
|
||||
batch_size,
|
||||
scheduler,
|
||||
txt2img_sdxl_custom_model,
|
||||
custom_vae,
|
||||
precision,
|
||||
device,
|
||||
max_length,
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status],
|
||||
show_progress="minimal" if args.progress_bar else "none",
|
||||
queue=True,
|
||||
)
|
||||
|
||||
status_kwargs = dict(
|
||||
fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs),
|
||||
inputs=[batch_count, batch_size],
|
||||
outputs=txt2img_sdxl_status,
|
||||
concurrency_limit=1,
|
||||
)
|
||||
|
||||
def autogen_changed(checked):
|
||||
if checked:
|
||||
args.autogen = True
|
||||
else:
|
||||
args.autogen = False
|
||||
|
||||
def check_last_input(prompt):
|
||||
if not prompt.endswith(" "):
|
||||
return True
|
||||
elif not args.autogen:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
auto_gen_kwargs = dict(
|
||||
fn=check_last_input,
|
||||
inputs=[negative_prompt],
|
||||
outputs=[txt2img_sdxl_status],
|
||||
concurrency_limit=1,
|
||||
)
|
||||
|
||||
txt2img_sdxl_autogen.change(
|
||||
fn=autogen_changed,
|
||||
inputs=[txt2img_sdxl_autogen],
|
||||
outputs=None,
|
||||
)
|
||||
prompt_submit = prompt.submit(**status_kwargs).then(**kwargs)
|
||||
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(
|
||||
**kwargs
|
||||
)
|
||||
generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs)
|
||||
stop_batch.click(
|
||||
fn=cancel_sd,
|
||||
cancels=[
|
||||
prompt_submit,
|
||||
neg_prompt_submit,
|
||||
generate_click,
|
||||
],
|
||||
)
|
||||
|
||||
txt2img_sdxl_png_info_img.change(
|
||||
fn=import_png_metadata,
|
||||
inputs=[
|
||||
txt2img_sdxl_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
txt2img_sdxl_png_info_img,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
seed,
|
||||
width,
|
||||
height,
|
||||
txt2img_sdxl_custom_model,
|
||||
lora_weights,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
txt2img_sdxl_custom_model.change(
|
||||
fn=set_model_default_configs,
|
||||
inputs=[
|
||||
txt2img_sdxl_custom_model,
|
||||
],
|
||||
outputs=[
|
||||
prompt,
|
||||
negative_prompt,
|
||||
steps,
|
||||
scheduler,
|
||||
guidance_scale,
|
||||
width,
|
||||
height,
|
||||
custom_vae,
|
||||
txt2img_sdxl_autogen,
|
||||
],
|
||||
)
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
@@ -1,10 +1,13 @@
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
import torch
|
||||
import time
|
||||
import sys
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
from math import ceil
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -15,6 +18,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
@@ -33,6 +40,34 @@ from apps.stable_diffusion.src.utils import (
|
||||
resampler_list,
|
||||
)
|
||||
|
||||
# Names of all interactive fields that can be edited by the user
|
||||
all_gradio_labels = [
|
||||
"txt2img_custom_model",
|
||||
"custom_vae",
|
||||
"prompt",
|
||||
"negative_prompt",
|
||||
"lora_weights",
|
||||
"lora_strength",
|
||||
"scheduler",
|
||||
"save_metadata_to_png",
|
||||
"save_metadata_to_json",
|
||||
"height",
|
||||
"width",
|
||||
"steps",
|
||||
"guidance_scale",
|
||||
"Low VRAM",
|
||||
"use_hiresfix",
|
||||
"resample_type",
|
||||
"hiresfix_height",
|
||||
"hiresfix_width",
|
||||
"hiresfix_strength",
|
||||
"batch_count",
|
||||
"batch_size",
|
||||
"repeatable_seeds",
|
||||
"seed",
|
||||
"device",
|
||||
]
|
||||
|
||||
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
|
||||
init_iree_vulkan_target_triple = args.iree_vulkan_target_triple
|
||||
init_iree_metal_target_platform = args.iree_metal_target_platform
|
||||
@@ -59,7 +94,7 @@ def txt2img_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
use_hiresfix: bool,
|
||||
@@ -106,9 +141,8 @@ def txt2img_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -124,7 +158,8 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -175,6 +210,7 @@ def txt2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -224,7 +260,8 @@ def txt2img_inf(
|
||||
width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil="None",
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
|
||||
@@ -256,6 +293,7 @@ def txt2img_inf(
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
debug=args.import_debug if args.import_mlir else False,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -278,7 +316,9 @@ def txt2img_inf(
|
||||
args.use_base_vae,
|
||||
cpu_scheduling,
|
||||
args.max_embeddings_multiples,
|
||||
use_stencil="None",
|
||||
stencils=[],
|
||||
images=None,
|
||||
control_mode=None,
|
||||
resample_type=resample_type,
|
||||
)
|
||||
total_time = time.time() - start_time
|
||||
@@ -300,7 +340,92 @@ def txt2img_inf(
|
||||
return generated_imgs, text_output, ""
|
||||
|
||||
|
||||
with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
def resource_path(relative_path):
|
||||
"""Get absolute path to resource, works for dev and for PyInstaller"""
|
||||
base_path = getattr(
|
||||
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
|
||||
)
|
||||
return os.path.join(base_path, relative_path)
|
||||
|
||||
|
||||
dark_theme = resource_path("ui/css/sd_dark_theme.css")
|
||||
|
||||
|
||||
# This function exports the values of all user-editable fields to the settings.json file in the ui folder
|
||||
def export_settings(*values):
|
||||
settings_list = list(zip(all_gradio_labels, values))
|
||||
settings = {}
|
||||
|
||||
for label, value in settings_list:
|
||||
settings[label] = value
|
||||
|
||||
settings = {"txt2img": settings}
|
||||
with open("./ui/settings.json", "w") as json_file:
|
||||
json.dump(settings, json_file, indent=4)
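# Editor's note (a sketch, not part of the original change): with the labels in
# all_gradio_labels, the exported ./ui/settings.json ends up shaped roughly like
#   {"txt2img": {"txt2img_custom_model": "...", "prompt": "...", "seed": "-1", ...}}
# i.e. a single "txt2img" object whose keys are the entries of all_gradio_labels.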
|
||||
|
||||
|
||||
# This function loads the values of all user-editable fields from the settings.json file in the ui folder
|
||||
def load_settings():
|
||||
try:
|
||||
with open("./ui/settings.json", "r") as json_file:
|
||||
loaded_settings = json.load(json_file)["txt2img"]
|
||||
except (FileNotFoundError, KeyError):
|
||||
warnings.warn(
|
||||
"Settings.json file not found or 'txt2img' key is missing. Using default values for fields."
|
||||
)
|
||||
loaded_settings = (
|
||||
{}
|
||||
) # json file not existing or the data wasn't saved yet
|
||||
|
||||
return [
|
||||
loaded_settings.get(
|
||||
"txt2img_custom_model",
|
||||
os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
),
|
||||
loaded_settings.get(
|
||||
"custom_vae",
|
||||
os.path.basename(args.custom_vae) if args.custom_vae else "None",
|
||||
),
|
||||
loaded_settings.get("prompt", args.prompts[0]),
|
||||
loaded_settings.get("negative_prompt", args.negative_prompts[0]),
|
||||
loaded_settings.get("lora_weights", "None"),
|
||||
loaded_settings.get("lora_strength", args.lora_strength),
|
||||
loaded_settings.get("scheduler", args.scheduler),
|
||||
loaded_settings.get(
|
||||
"save_metadata_to_png", args.write_metadata_to_png
|
||||
),
|
||||
loaded_settings.get(
|
||||
"save_metadata_to_json", args.save_metadata_to_json
|
||||
),
|
||||
loaded_settings.get("height", args.height),
|
||||
loaded_settings.get("width", args.width),
|
||||
loaded_settings.get("steps", args.steps),
|
||||
loaded_settings.get("guidance_scale", args.guidance_scale),
|
||||
loaded_settings.get("Low VRAM", args.ondemand),
|
||||
loaded_settings.get("use_hiresfix", args.use_hiresfix),
|
||||
loaded_settings.get("resample_type", args.resample_type),
|
||||
loaded_settings.get("hiresfix_height", args.hiresfix_height),
|
||||
loaded_settings.get("hiresfix_width", args.hiresfix_width),
|
||||
loaded_settings.get("hiresfix_strength", args.hiresfix_strength),
|
||||
loaded_settings.get("batch_count", args.batch_count),
|
||||
loaded_settings.get("batch_size", args.batch_size),
|
||||
loaded_settings.get("repeatable_seeds", args.repeatable_seeds),
|
||||
loaded_settings.get("seed", args.seed),
|
||||
loaded_settings.get("device", available_devices[0]),
|
||||
]
|
||||
|
||||
|
||||
# This function loads the user's exported default settings at program start
|
||||
def onload_load_settings():
|
||||
loaded_data = load_settings()
|
||||
structured_data = list(zip(all_gradio_labels, loaded_data))
|
||||
return dict(structured_data)
|
||||
|
||||
|
||||
default_settings = onload_load_settings()
|
||||
with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web:
|
||||
with gr.Row(elem_id="ui_title"):
|
||||
nod_logo = Image.open(nodlogo_loc)
|
||||
with gr.Row():
|
||||
@@ -309,6 +434,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
@@ -317,20 +443,20 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=10):
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
txt2img_custom_model = gr.Dropdown(
|
||||
label=f"Models",
|
||||
info="Select, or enter HuggingFace Model ID or Civitai model download URL",
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.ckpt_loc)
|
||||
if args.ckpt_loc
|
||||
else "stabilityai/stable-diffusion-2-1-base",
|
||||
value=default_settings.get(
|
||||
"txt2img_custom_model"
|
||||
),
|
||||
choices=get_custom_model_files()
|
||||
+ predefined_models,
|
||||
allow_custom_value=True,
|
||||
scale=2,
|
||||
scale=11,
|
||||
)
|
||||
# janky fix for overflowing text
|
||||
t2i_vae_info = (
|
||||
@@ -341,93 +467,101 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
label=f"VAE Models",
|
||||
info=t2i_vae_info,
|
||||
elem_id="custom_model",
|
||||
value=os.path.basename(args.custom_vae)
|
||||
if args.custom_vae
|
||||
else "None",
|
||||
value=default_settings.get("custom_vae"),
|
||||
choices=["None"]
|
||||
+ get_custom_model_files("vae"),
|
||||
allow_custom_value=True,
|
||||
scale=4,
|
||||
)
|
||||
txt2img_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
visible=True,
|
||||
sources=["upload"],
|
||||
scale=1,
|
||||
)
|
||||
with gr.Column(scale=1, min_width=170):
|
||||
txt2img_png_info_img = gr.Image(
|
||||
label="Import PNG info",
|
||||
elem_id="txt2img_prompt_image",
|
||||
type="pil",
|
||||
tool="None",
|
||||
visible=True,
|
||||
)
|
||||
|
||||
with gr.Group(elem_id="prompt_box_outer"):
|
||||
prompt = gr.Textbox(
|
||||
label="Prompt",
|
||||
value=args.prompts[0],
|
||||
value=default_settings.get("prompt"),
|
||||
lines=2,
|
||||
elem_id="prompt_box",
|
||||
)
|
||||
# TODO: coming soon
|
||||
autogen = gr.Checkbox(
|
||||
label="Continuous Generation",
|
||||
visible=False,
|
||||
)
|
||||
negative_prompt = gr.Textbox(
|
||||
label="Negative Prompt",
|
||||
value=args.negative_prompts[0],
|
||||
value=default_settings.get("negative_prompt"),
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
t2i_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
t2i_lora_info = f"LoRA Path: {t2i_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=t2i_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
value=default_settings.get("lora_weights"),
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=default_settings.get("lora_strength"),
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
scheduler = gr.Dropdown(
|
||||
elem_id="scheduler",
|
||||
label="Scheduler",
|
||||
value=args.scheduler,
|
||||
value=default_settings.get("scheduler"),
|
||||
choices=scheduler_list,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
with gr.Column():
|
||||
save_metadata_to_png = gr.Checkbox(
|
||||
label="Save prompt information to PNG",
|
||||
value=args.write_metadata_to_png,
|
||||
value=default_settings.get(
|
||||
"save_metadata_to_png"
|
||||
),
|
||||
interactive=True,
|
||||
)
|
||||
save_metadata_to_json = gr.Checkbox(
|
||||
label="Save prompt information to JSON file",
|
||||
value=args.save_metadata_to_json,
|
||||
value=default_settings.get(
|
||||
"save_metadata_to_json"
|
||||
),
|
||||
interactive=True,
|
||||
)
|
||||
with gr.Row():
|
||||
height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.height,
|
||||
value=default_settings.get("height"),
|
||||
step=8,
|
||||
label="Height",
|
||||
)
|
||||
width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.width,
|
||||
value=default_settings.get("width"),
|
||||
step=8,
|
||||
label="Width",
|
||||
)
|
||||
@@ -452,18 +586,22 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
with gr.Row():
|
||||
with gr.Column(scale=3):
|
||||
steps = gr.Slider(
|
||||
1, 100, value=args.steps, step=1, label="Steps"
|
||||
1,
|
||||
100,
|
||||
value=default_settings.get("steps"),
|
||||
step=1,
|
||||
label="Steps",
|
||||
)
|
||||
with gr.Column(scale=3):
|
||||
guidance_scale = gr.Slider(
|
||||
0,
|
||||
50,
|
||||
value=args.guidance_scale,
|
||||
value=default_settings.get("guidance_scale"),
|
||||
step=0.1,
|
||||
label="CFG Scale",
|
||||
)
|
||||
ondemand = gr.Checkbox(
|
||||
value=args.ondemand,
|
||||
value=default_settings.get("Low VRAM"),
|
||||
label="Low VRAM",
|
||||
interactive=True,
|
||||
)
|
||||
@@ -472,7 +610,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
batch_count = gr.Slider(
|
||||
1,
|
||||
100,
|
||||
value=args.batch_count,
|
||||
value=default_settings.get("batch_count"),
|
||||
step=1,
|
||||
label="Batch Count",
|
||||
interactive=True,
|
||||
@@ -483,23 +621,24 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
4,
|
||||
value=args.batch_size,
|
||||
step=1,
|
||||
label="Batch Size",
|
||||
label=default_settings.get("batch_size"),
|
||||
interactive=True,
|
||||
visible=False,
|
||||
)
|
||||
repeatable_seeds = gr.Checkbox(
|
||||
args.repeatable_seeds,
|
||||
default_settings.get("repeatable_seeds"),
|
||||
label="Repeatable Seeds",
|
||||
)
|
||||
with gr.Accordion(label="Hires Fix Options", open=False):
|
||||
with gr.Group():
|
||||
with gr.Row():
|
||||
use_hiresfix = gr.Checkbox(
|
||||
value=args.use_hiresfix,
|
||||
value=default_settings.get("use_hiresfix"),
|
||||
label="Use Hires Fix",
|
||||
interactive=True,
|
||||
)
|
||||
resample_type = gr.Dropdown(
|
||||
value=args.resample_type,
|
||||
value=default_settings.get("resample_type"),
|
||||
choices=resampler_list,
|
||||
label="Resample Type",
|
||||
allow_custom_value=False,
|
||||
@@ -507,34 +646,34 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
hiresfix_height = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_height,
|
||||
value=default_settings.get("hiresfix_height"),
|
||||
step=8,
|
||||
label="Hires Fix Height",
|
||||
)
|
||||
hiresfix_width = gr.Slider(
|
||||
384,
|
||||
768,
|
||||
value=args.hiresfix_width,
|
||||
value=default_settings.get("hiresfix_width"),
|
||||
step=8,
|
||||
label="Hires Fix Width",
|
||||
)
|
||||
hiresfix_strength = gr.Slider(
|
||||
0,
|
||||
1,
|
||||
value=args.hiresfix_strength,
|
||||
value=default_settings.get("hiresfix_strength"),
|
||||
step=0.01,
|
||||
label="Hires Fix Denoising Strength",
|
||||
)
|
||||
with gr.Row():
|
||||
seed = gr.Textbox(
|
||||
value=args.seed,
|
||||
value=default_settings.get("seed"),
|
||||
label="Seed",
|
||||
info="An integer or a JSON list of integers, -1 for random",
|
||||
)
|
||||
device = gr.Dropdown(
|
||||
elem_id="device",
|
||||
label="Device",
|
||||
value=available_devices[0],
|
||||
value=default_settings.get("device"),
|
||||
choices=available_devices,
|
||||
allow_custom_value=True,
|
||||
)
|
||||
@@ -585,6 +724,75 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
txt2img_sendto_upscaler = gr.Button(
|
||||
value="SendTo Upscaler"
|
||||
)
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
export_defaults = gr.Button(
|
||||
value="Load Default Settings"
|
||||
)
|
||||
export_defaults.click(
|
||||
fn=load_settings,
|
||||
inputs=[],
|
||||
outputs=[
|
||||
txt2img_custom_model,
|
||||
custom_vae,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
lora_weights,
|
||||
lora_strength,
|
||||
scheduler,
|
||||
save_metadata_to_png,
|
||||
save_metadata_to_json,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
ondemand,
|
||||
use_hiresfix,
|
||||
resample_type,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
batch_count,
|
||||
batch_size,
|
||||
repeatable_seeds,
|
||||
seed,
|
||||
device,
|
||||
],
|
||||
)
|
||||
with gr.Column(scale=2):
|
||||
export_defaults = gr.Button(
|
||||
value="Export Default Settings"
|
||||
)
|
||||
export_defaults.click(
|
||||
fn=export_settings,
|
||||
inputs=[
|
||||
txt2img_custom_model,
|
||||
custom_vae,
|
||||
prompt,
|
||||
negative_prompt,
|
||||
lora_weights,
|
||||
lora_strength,
|
||||
scheduler,
|
||||
save_metadata_to_png,
|
||||
save_metadata_to_json,
|
||||
height,
|
||||
width,
|
||||
steps,
|
||||
guidance_scale,
|
||||
ondemand,
|
||||
use_hiresfix,
|
||||
resample_type,
|
||||
hiresfix_height,
|
||||
hiresfix_width,
|
||||
hiresfix_strength,
|
||||
batch_count,
|
||||
batch_size,
|
||||
repeatable_seeds,
|
||||
seed,
|
||||
device,
|
||||
],
|
||||
outputs=[],
|
||||
)
|
||||
|
||||
kwargs = dict(
|
||||
fn=txt2img_inf,
|
||||
@@ -607,7 +815,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
use_hiresfix,
|
||||
@@ -650,7 +858,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
custom_vae,
|
||||
],
|
||||
outputs=[
|
||||
@@ -665,7 +872,7 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
height,
|
||||
txt2img_custom_model,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
custom_vae,
|
||||
],
|
||||
)
|
||||
@@ -673,12 +880,12 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
|
||||
def set_compatible_schedulers(hires_fix_selected):
|
||||
if hires_fix_selected:
|
||||
return gr.Dropdown.update(
|
||||
return gr.Dropdown(
|
||||
choices=scheduler_list_cpu_only,
|
||||
value="DEISMultistep",
|
||||
)
|
||||
else:
|
||||
return gr.Dropdown.update(
|
||||
return gr.Dropdown(
|
||||
choices=scheduler_list,
|
||||
value="SharkEulerDiscrete",
|
||||
)
|
||||
@@ -689,3 +896,18 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
|
||||
outputs=[scheduler],
|
||||
queue=False,
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@ import torch
|
||||
import time
|
||||
import gradio as gr
|
||||
from PIL import Image
|
||||
|
||||
from apps.stable_diffusion.web.ui.utils import (
|
||||
available_devices,
|
||||
nodlogo_loc,
|
||||
@@ -12,6 +13,10 @@ from apps.stable_diffusion.web.ui.utils import (
|
||||
predefined_upscaler_models,
|
||||
cancel_sd,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.common_ui_events import (
|
||||
lora_changed,
|
||||
lora_strength_changed,
|
||||
)
|
||||
from apps.stable_diffusion.web.utils.common_label_calc import status_label
|
||||
from apps.stable_diffusion.src import (
|
||||
args,
|
||||
@@ -51,7 +56,7 @@ def upscaler_inf(
|
||||
save_metadata_to_json: bool,
|
||||
save_metadata_to_png: bool,
|
||||
lora_weights: str,
|
||||
lora_hf_id: str,
|
||||
lora_strength: float,
|
||||
ondemand: bool,
|
||||
repeatable_seeds: bool,
|
||||
):
|
||||
@@ -98,9 +103,8 @@ def upscaler_inf(
|
||||
args.save_metadata_to_json = save_metadata_to_json
|
||||
args.write_metadata_to_png = save_metadata_to_png
|
||||
|
||||
args.use_lora = get_custom_vae_or_lora_weights(
|
||||
lora_weights, lora_hf_id, "lora"
|
||||
)
|
||||
args.use_lora = get_custom_vae_or_lora_weights(lora_weights, "lora")
|
||||
args.lora_strength = lora_strength
|
||||
|
||||
dtype = torch.float32 if precision == "fp32" else torch.half
|
||||
cpu_scheduling = not scheduler.startswith("Shark")
|
||||
@@ -118,7 +122,8 @@ def upscaler_inf(
|
||||
args.width,
|
||||
device,
|
||||
use_lora=args.use_lora,
|
||||
use_stencil=None,
|
||||
lora_strength=args.lora_strength,
|
||||
stencils=[],
|
||||
ondemand=ondemand,
|
||||
)
|
||||
if (
|
||||
@@ -157,6 +162,7 @@ def upscaler_inf(
|
||||
args.use_tuned,
|
||||
low_cpu_mem_usage=args.low_cpu_mem_usage,
|
||||
use_lora=args.use_lora,
|
||||
lora_strength=args.lora_strength,
|
||||
ondemand=args.ondemand,
|
||||
)
|
||||
)
|
||||
@@ -253,6 +259,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=50,
|
||||
@@ -260,6 +267,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
with gr.Row(elem_id="ui_body"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=1, min_width=600):
|
||||
upscaler_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
sources=["upload"],
|
||||
)
|
||||
with gr.Row():
|
||||
upscaler_model_info = (
|
||||
f"Custom Model Path: {str(get_custom_model_path())}"
|
||||
@@ -308,37 +320,32 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
lines=2,
|
||||
elem_id="negative_prompt_box",
|
||||
)
|
||||
|
||||
upscaler_init_image = gr.Image(
|
||||
label="Input Image",
|
||||
type="pil",
|
||||
height=300,
|
||||
)
|
||||
|
||||
with gr.Accordion(label="LoRA Options", open=False):
|
||||
with gr.Row():
|
||||
# janky fix for overflowing text
|
||||
upscaler_lora_info = (
|
||||
str(get_custom_model_path("lora"))
|
||||
).replace("\\", "\n\\")
|
||||
upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}"
|
||||
lora_weights = gr.Dropdown(
|
||||
label=f"Standalone LoRA Weights",
|
||||
info=upscaler_lora_info,
|
||||
label=f"LoRA Weights",
|
||||
info=f"Select from LoRA in {str(get_custom_model_path('lora'))}, or enter HuggingFace Model ID",
|
||||
elem_id="lora_weights",
|
||||
value="None",
|
||||
choices=["None"] + get_custom_model_files("lora"),
|
||||
allow_custom_value=True,
|
||||
scale=3,
|
||||
)
|
||||
lora_hf_id = gr.Textbox(
|
||||
elem_id="lora_hf_id",
|
||||
placeholder="Select 'None' in the Standalone LoRA "
|
||||
"weights dropdown on the left if you want to use "
|
||||
"a standalone HuggingFace model ID for LoRA here "
|
||||
"e.g: sayakpaul/sd-model-finetuned-lora-t4",
|
||||
value="",
|
||||
label="HuggingFace Model ID",
|
||||
lines=3,
|
||||
lora_strength = gr.Number(
|
||||
label="LoRA Strength",
|
||||
info="Will be baked into the .vmfb",
|
||||
step=0.01,
|
||||
# number is checked on change so to allow 0.n values
|
||||
# we have to allow 0 or you can't type 0.n in
|
||||
minimum=0.0,
|
||||
maximum=2.0,
|
||||
value=args.lora_strength,
|
||||
scale=1,
|
||||
)
|
||||
with gr.Row():
|
||||
lora_tags = gr.HTML(
|
||||
value="<div><i>No LoRA selected</i></div>",
|
||||
elem_classes="lora-tags",
|
||||
)
|
||||
with gr.Accordion(label="Advanced Options", open=False):
|
||||
with gr.Row():
|
||||
@@ -515,7 +522,7 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
save_metadata_to_json,
|
||||
save_metadata_to_png,
|
||||
lora_weights,
|
||||
lora_hf_id,
|
||||
lora_strength,
|
||||
ondemand,
|
||||
repeatable_seeds,
|
||||
],
|
||||
@@ -537,3 +544,18 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
|
||||
fn=cancel_sd,
|
||||
cancels=[prompt_submit, neg_prompt_submit, generate_click],
|
||||
)
|
||||
|
||||
lora_weights.change(
|
||||
fn=lora_changed,
|
||||
inputs=[lora_weights],
|
||||
outputs=[lora_tags],
|
||||
queue=True,
|
||||
)
|
||||
|
||||
lora_strength.change(
|
||||
fn=lora_strength_changed,
|
||||
inputs=lora_strength,
|
||||
outputs=lora_strength,
|
||||
queue=False,
|
||||
show_progress=False,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,19 @@
|
||||
import os
|
||||
import sys
|
||||
from apps.stable_diffusion.src import get_available_devices
|
||||
import glob
|
||||
import math
|
||||
import json
|
||||
import safetensors
|
||||
import gradio as gr
|
||||
import PIL.Image as Image
|
||||
|
||||
from pathlib import Path
|
||||
from apps.stable_diffusion.src import args
|
||||
from dataclasses import dataclass
|
||||
from enum import IntEnum
|
||||
from gradio.components.image_editor import EditorValue
|
||||
|
||||
from apps.stable_diffusion.src import get_available_devices
|
||||
import apps.stable_diffusion.web.utils.global_obj as global_obj
|
||||
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
|
||||
SD_STATE_CANCEL,
|
||||
@@ -24,10 +33,20 @@ class Config:
|
||||
width: int
|
||||
device: str
|
||||
use_lora: str
|
||||
use_stencil: str
|
||||
lora_strength: float
|
||||
stencils: list[str]
|
||||
ondemand: str # should this be expecting a bool instead?
|
||||
|
||||
|
||||
class HSLHue(IntEnum):
|
||||
RED = 0
|
||||
YELLOW = 60
|
||||
GREEN = 120
|
||||
CYAN = 180
|
||||
BLUE = 240
|
||||
MAGENTA = 300
|
||||
|
||||
|
||||
custom_model_filetypes = (
|
||||
"*.ckpt",
|
||||
"*.safetensors",
|
||||
@@ -49,9 +68,11 @@ scheduler_list_cpu_only = [
|
||||
"DPMSolverSinglestep",
|
||||
"DDPM",
|
||||
"HeunDiscrete",
|
||||
"LCMScheduler",
|
||||
]
|
||||
scheduler_list = scheduler_list_cpu_only + [
|
||||
"SharkEulerDiscrete",
|
||||
"SharkEulerAncestralDiscrete",
|
||||
]
|
||||
|
||||
predefined_models = [
|
||||
@@ -72,6 +93,10 @@ predefined_paint_models = [
|
||||
predefined_upscaler_models = [
|
||||
"stabilityai/stable-diffusion-x4-upscaler",
|
||||
]
|
||||
predefined_sdxl_models = [
|
||||
"stabilityai/sdxl-turbo",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
]
|
||||
|
||||
|
||||
def resource_path(relative_path):
|
||||
@@ -125,6 +150,12 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
)
|
||||
]
|
||||
match custom_checkpoint_type:
|
||||
case "sdxl":
|
||||
files = [
|
||||
val
|
||||
for val in files
|
||||
if any(x in val for x in ["XL", "xl", "Xl"])
|
||||
]
|
||||
case "inpainting":
|
||||
files = [
|
||||
val
|
||||
@@ -150,17 +181,82 @@ def get_custom_model_files(model="models", custom_checkpoint_type=""):
|
||||
return sorted(ckpt_files, key=str.casefold)
|
||||
|
||||
|
||||
def get_custom_vae_or_lora_weights(weights, hf_id, model):
|
||||
use_weight = ""
|
||||
if weights == "None" and not hf_id:
|
||||
def get_custom_vae_or_lora_weights(weights, model):
|
||||
if weights == "None":
|
||||
use_weight = ""
|
||||
elif not hf_id:
|
||||
use_weight = get_custom_model_pathfile(weights, model)
|
||||
else:
|
||||
use_weight = hf_id
|
||||
custom_weights = get_custom_model_pathfile(str(weights), model)
|
||||
if os.path.isfile(custom_weights):
|
||||
use_weight = custom_weights
|
||||
else:
|
||||
use_weight = weights
|
||||
|
||||
return use_weight
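# Editor's sketch of the behaviour above (assuming get_custom_model_pathfile
# joins the custom-model folder for `model` with the given filename; example
# values are hypothetical):
#   get_custom_vae_or_lora_weights("None", "lora")                 -> ""
#   get_custom_vae_or_lora_weights("my_lora.safetensors", "lora")  -> full local path, if that file exists
#   get_custom_vae_or_lora_weights("user/some-hf-lora", "lora")    -> "user/some-hf-lora" (treated as an HF model ID)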
|
||||
|
||||
|
||||
def hsl_color(alpha: float, start, end):
|
||||
b = (end - start) * (alpha if alpha > 0 else 0)
|
||||
result = b + start
|
||||
|
||||
# Return a CSS HSL string
|
||||
return f"hsl({math.floor(result)}, 80%, 35%)"
|
||||
|
||||
|
||||
def get_lora_metadata(lora_filename):
|
||||
# get the metadata from the file
|
||||
filename = get_custom_model_pathfile(lora_filename, "lora")
|
||||
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
|
||||
metadata = f.metadata()
|
||||
|
||||
# guard clause for if there isn't any metadata
|
||||
if not metadata:
|
||||
return None
|
||||
|
||||
# metadata is a dictionary of strings, the values of the keys we're
|
||||
# interested in are actually json, and need to be loaded as such
|
||||
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
|
||||
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
|
||||
tag_dirs = [dir for dir in tag_frequencies.keys()]
|
||||
|
||||
# gather the tag frequency information for all the datasets trained
|
||||
all_frequencies = {}
|
||||
for dataset in tag_dirs:
|
||||
frequencies = sorted(
|
||||
[entry for entry in tag_frequencies[dataset].items()],
|
||||
reverse=True,
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
|
||||
# get a figure for the total number of images processed for this dataset
|
||||
# either the number listed in its dataset_dirs entry, or
|
||||
# the highest frequency's number if that doesn't exist
|
||||
img_count = dataset_dirs.get(dataset, {}).get(
|
||||
"img_count", frequencies[0][1]
|
||||
)
|
||||
|
||||
# add the dataset frequencies to the overall frequencies replacing the
|
||||
# frequency counts on the tags with a percentage/ratio
|
||||
all_frequencies.update(
|
||||
[(entry[0], entry[1] / img_count) for entry in frequencies]
|
||||
)
|
||||
|
||||
trained_model_id = " ".join(
|
||||
[
|
||||
metadata.get("ss_sd_model_hash", ""),
|
||||
metadata.get("ss_sd_model_name", ""),
|
||||
metadata.get("ss_base_model_version", ""),
|
||||
]
|
||||
).strip()
|
||||
|
||||
# return the tag frequencies across all datasets, sorted most frequent first
|
||||
return {
|
||||
"model": trained_model_id,
|
||||
"frequencies": sorted(
|
||||
all_frequencies.items(), reverse=True, key=lambda x: x[1]
|
||||
),
|
||||
}
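# Editor's sketch of the returned structure (values illustrative only):
#   {"model": "<hash> <model name> <base version>",
#    "frequencies": [("tag_a", 0.92), ("tag_b", 0.41), ...]}  # sorted, most frequent first
# or None when the safetensors file carries no metadata.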
|
||||
|
||||
|
||||
def cancel_sd():
|
||||
# Try catch it, as gc can delete global_obj.sd_obj while switching model
|
||||
try:
|
||||
@@ -169,6 +265,116 @@ def cancel_sd():
|
||||
pass
|
||||
|
||||
|
||||
def set_model_default_configs(model_ckpt_or_id, jsonconfig=None):
|
||||
import gradio as gr
|
||||
|
||||
config_modelname = default_config_exists(model_ckpt_or_id)
|
||||
if jsonconfig:
|
||||
return get_config_from_json(jsonconfig)
|
||||
elif config_modelname:
|
||||
return default_configs[config_modelname]
|
||||
# TODO: Use HF metadata to setup pipeline if available
|
||||
# elif is_valid_hf_id(model_ckpt_or_id):
|
||||
# return get_HF_default_configs(model_ckpt_or_id)
|
||||
else:
|
||||
# We don't have default metadata to setup a good config. Do not change configs.
|
||||
return [
|
||||
gr.Textbox(label="Prompt", interactive=True, visible=True),
|
||||
gr.Textbox(label="Negative Prompt", interactive=True),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.update(),
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate",
|
||||
visible=False,
|
||||
interactive=False,
|
||||
value=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def get_config_from_json(model_ckpt_or_id, jsonconfig):
|
||||
# TODO: make this work properly. It is currently not user-exposed.
|
||||
cfgdata = json.load(jsonconfig)
|
||||
return [
|
||||
cfgdata["prompt_box_behavior"],
|
||||
cfgdata["neg_prompt_box_behavior"],
|
||||
cfgdata["steps"],
|
||||
cfgdata["scheduler"],
|
||||
cfgdata["guidance_scale"],
|
||||
cfgdata["width"],
|
||||
cfgdata["height"],
|
||||
cfgdata["custom_vae"],
|
||||
]
|
||||
|
||||
|
||||
def default_config_exists(model_ckpt_or_id):
|
||||
if model_ckpt_or_id in default_configs.keys():
|
||||
return model_ckpt_or_id
|
||||
elif "turbo" in model_ckpt_or_id.lower():
|
||||
return "stabilityai/sdxl-turbo"
|
||||
else:
|
||||
return None
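# Editor's examples (hypothetical model names): a key already present in
# default_configs is returned unchanged, anything containing "turbo" falls
# back to the sdxl-turbo config, and everything else returns None:
#   default_config_exists("stabilityai/sdxl-turbo")     -> "stabilityai/sdxl-turbo"
#   default_config_exists("MyModel-TURBO.safetensors")  -> "stabilityai/sdxl-turbo"
#   default_config_exists("some-other-model")           -> None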
|
||||
|
||||
|
||||
def mask_editor_value_for_image_file(filepath):
|
||||
image = Image.open(filepath)
|
||||
mask = Image.new(mode="RGBA", size=image.size, color=(0, 0, 0, 0))
|
||||
return {"background": image, "layers": [mask], "composite": image}
|
||||
|
||||
|
||||
def mask_editor_value_for_gallery_data(gallery_data):
|
||||
filepath = (
|
||||
gallery_data.root[0].image.path
|
||||
if len(gallery_data.root) != 0
|
||||
else None
|
||||
)
|
||||
|
||||
if filepath and os.path.isfile(filepath):
|
||||
return mask_editor_value_for_image_file(filepath)
|
||||
|
||||
return EditorValue()
|
||||
|
||||
|
||||
default_configs = {
|
||||
"stabilityai/sdxl-turbo": [
|
||||
gr.Textbox(label="", interactive=False, value=None, visible=False),
|
||||
gr.Textbox(
|
||||
label="Prompt",
|
||||
value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette",
|
||||
),
|
||||
gr.Slider(0, 10, value=2),
|
||||
"EulerAncestralDiscrete",
|
||||
gr.Slider(0, value=0),
|
||||
512,
|
||||
512,
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate", visible=False, interactive=True, value=False
|
||||
),
|
||||
],
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": [
|
||||
gr.Textbox(label="Prompt", interactive=True, visible=True),
|
||||
gr.Textbox(label="Negative Prompt", interactive=True),
|
||||
40,
|
||||
"EulerDiscrete",
|
||||
7.5,
|
||||
gr.Slider(value=768, interactive=True),
|
||||
gr.Slider(value=768, interactive=True),
|
||||
"madebyollin/sdxl-vae-fp16-fix",
|
||||
gr.Checkbox(
|
||||
label="Auto-Generate",
|
||||
visible=False,
|
||||
interactive=False,
|
||||
value=False,
|
||||
),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
nodlogo_loc = resource_path("logos/nod-logo.png")
|
||||
nodicon_loc = resource_path("logos/nod-icon.png")
|
||||
available_devices = get_available_devices()
|
||||
|
||||
@@ -122,20 +122,26 @@ def find_vae_from_png_metadata(
|
||||
|
||||
def find_lora_from_png_metadata(
|
||||
key: str, metadata: dict[str, str | int]
|
||||
) -> tuple[str, str]:
|
||||
lora_hf_id = ""
|
||||
) -> tuple[str, float]:
|
||||
lora_custom = ""
|
||||
lora_strength = 1.0
|
||||
|
||||
if key in metadata:
|
||||
lora_file = metadata[key]
|
||||
split_metadata = metadata[key].split(":")
|
||||
lora_file = split_metadata[0]
|
||||
if len(split_metadata) == 2:
|
||||
try:
|
||||
lora_strength = float(split_metadata[1])
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora")
|
||||
# If nothing had matched, check vendor/hf_model_id
|
||||
if not lora_custom and lora_file.count("/"):
|
||||
lora_hf_id = lora_file
|
||||
lora_custom = lora_file
|
||||
|
||||
# LoRA input is optional, should not print or throw an error if missing
|
||||
|
||||
return lora_custom, lora_hf_id
|
||||
return lora_custom, lora_strength
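# Editor's note on the "name:strength" convention handled above (the example
# value is hypothetical): a PNG "LoRA" entry such as "my_lora.safetensors:0.75"
# yields ("my_lora.safetensors", 0.75) once the file is matched, while an entry
# without ":" keeps the default strength of 1.0.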
|
||||
|
||||
|
||||
def import_png_metadata(
|
||||
@@ -150,7 +156,6 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_vae,
|
||||
):
|
||||
try:
|
||||
@@ -160,9 +165,10 @@ def import_png_metadata(
|
||||
(png_custom_model, png_hf_model_id) = find_model_from_png_metadata(
|
||||
"Model", metadata
|
||||
)
|
||||
(lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata(
|
||||
"LoRA", metadata
|
||||
)
|
||||
(
|
||||
custom_lora,
|
||||
custom_lora_strength,
|
||||
) = find_lora_from_png_metadata("LoRA", metadata)
|
||||
vae_custom_model = find_vae_from_png_metadata("VAE", metadata)
|
||||
|
||||
negative_prompt = metadata["Negative prompt"]
|
||||
@@ -177,12 +183,8 @@ def import_png_metadata(
|
||||
elif "Model" in metadata and png_hf_model_id:
|
||||
custom_model = png_hf_model_id
|
||||
|
||||
if "LoRA" in metadata and lora_custom_model:
|
||||
custom_lora = lora_custom_model
|
||||
hf_lora_id = ""
|
||||
if "LoRA" in metadata and lora_hf_model_id:
|
||||
if "LoRA" in metadata and not custom_lora:
|
||||
custom_lora = "None"
|
||||
hf_lora_id = lora_hf_model_id
|
||||
|
||||
if "VAE" in metadata and vae_custom_model:
|
||||
custom_vae = vae_custom_model
|
||||
@@ -215,6 +217,6 @@ def import_png_metadata(
|
||||
height,
|
||||
custom_model,
|
||||
custom_lora,
|
||||
hf_lora_id,
|
||||
custom_lora_strength,
|
||||
custom_vae,
|
||||
)
|
||||
|
||||
@@ -5,11 +5,25 @@ from time import time
|
||||
shark_tmp = os.path.join(os.getcwd(), "shark_tmp/")
|
||||
|
||||
|
||||
def config_gradio_tmp_imgs_folder():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
def clear_tmp_mlir():
|
||||
cleanup_start = time()
|
||||
print(
|
||||
"Clearing .mlir temporary files from a prior run. This may take some time..."
|
||||
)
|
||||
mlir_files = [
|
||||
filename
|
||||
for filename in os.listdir(shark_tmp)
|
||||
if os.path.isfile(os.path.join(shark_tmp, filename))
|
||||
and filename.endswith(".mlir")
|
||||
]
|
||||
for filename in mlir_files:
|
||||
os.remove(shark_tmp + filename)
|
||||
print(
|
||||
f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds."
|
||||
)
|
||||
|
||||
|
||||
def clear_tmp_imgs():
|
||||
# tell gradio to use a directory under shark_tmp for its temporary
|
||||
# image files unless somewhere else has been set
|
||||
if "GRADIO_TEMP_DIR" not in os.environ:
|
||||
@@ -52,3 +66,12 @@ def config_gradio_tmp_imgs_folder():
|
||||
)
|
||||
else:
|
||||
print("No temporary images files to clear.")
|
||||
|
||||
|
||||
def config_tmp():
|
||||
# create shark_tmp if it does not exist
|
||||
if not os.path.exists(shark_tmp):
|
||||
os.mkdir(shark_tmp)
|
||||
|
||||
clear_tmp_mlir()
|
||||
clear_tmp_imgs()
|
||||
@@ -78,7 +78,10 @@ def test_loop(
|
||||
os.mkdir("./test_images/golden")
|
||||
get_inpaint_inputs()
|
||||
hf_model_names = model_config_dicts[0].values()
|
||||
tuned_options = ["--no-use_tuned", "--use_tuned"]
|
||||
tuned_options = [
|
||||
"--no-use_tuned",
|
||||
"--use_tuned",
|
||||
]
|
||||
import_options = ["--import_mlir", "--no-import_mlir"]
|
||||
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
|
||||
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
|
||||
@@ -112,6 +115,8 @@ def test_loop(
|
||||
and use_tune == tuned_options[1]
|
||||
):
|
||||
continue
|
||||
elif use_tune == tuned_options[1]:
|
||||
continue
|
||||
command = (
|
||||
[
|
||||
executable, # executable is the python from the venv used to run this
|
||||
|
||||
@@ -23,6 +23,7 @@ with gr.Blocks(title="Dataset Annotation Tool", css=demo_css) as shark_web:
|
||||
value=nod_logo,
|
||||
show_label=False,
|
||||
interactive=False,
|
||||
show_download_button=False,
|
||||
elem_id="top_logo",
|
||||
width=150,
|
||||
height=100,
|
||||
|
||||
@@ -22,33 +22,33 @@ This does mean however, that on a brand new fresh install of SHARK that has not
|
||||
|
||||
* Make sure you have suitable drivers for your graphics card installed. See the prerequisites section of the [README](https://github.com/nod-ai/SHARK#readme).
|
||||
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
|
||||
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
|
||||
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See below if you want to run SHARK on a different port.)
|
||||
|
||||
```powershell
|
||||
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
|
||||
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
|
||||
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
|
||||
|
||||
## Run from the base directory of a source clone of SHARK on Windows:
|
||||
.\setup_venv.ps1
|
||||
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
|
||||
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
|
||||
|
||||
## Run from the base directory of a source clone of SHARK on Linux:
|
||||
./setup_venv.sh
|
||||
source shark.venv/bin/activate
|
||||
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
|
||||
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
|
||||
|
||||
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
|
||||
|
||||
## Since the api respects most applicable SHARK command line arguments for options not specified,
|
||||
## or currently unimplemented by API, there might be some you want to set, as listed in `--help`
|
||||
.\node_ai_shark_studio_20320901_2525.exe --help
|
||||
|
||||
## For instance, the example above, but with a custom VAE specified
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
|
||||
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
|
||||
|
||||
## An example with multiple specific CORS origins
|
||||
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
|
||||
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
|
||||
```
|
||||
|
||||
SHARK should start in server mode, and you should see something like this:
|
||||
|
||||
@@ -6,8 +6,8 @@ requires = [
|
||||
|
||||
"numpy>=1.22.4",
|
||||
"torch-mlir>=20230620.875",
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
"iree-compiler==20231212.*",
|
||||
"iree-runtime==20231212.*",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ sacremoses
|
||||
sentencepiece
|
||||
|
||||
# web dependencies.
|
||||
gradio
|
||||
gradio==3.44.3
|
||||
altair
|
||||
scipy
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ diffusers
|
||||
accelerate
|
||||
scipy
|
||||
ftfy
|
||||
gradio==3.44.3
|
||||
gradio==4.12.0
|
||||
altair
|
||||
omegaconf
|
||||
# 0.3.2 doesn't have binaries for arm64
|
||||
@@ -50,4 +50,8 @@ pefile
|
||||
pyinstaller
|
||||
|
||||
# vicuna quantization
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
|
||||
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
|
||||
|
||||
# For quantized GPTQ models
|
||||
optimum
|
||||
auto_gptq
|
||||
|
||||
setup.py
@@ -11,8 +11,8 @@ PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
|
||||
backend_deps = []
|
||||
if "NO_BACKEND" in os.environ.keys():
|
||||
backend_deps = [
|
||||
"iree-compiler>=20221022.190",
|
||||
"iree-runtime>=20221022.190",
|
||||
"iree-compiler==20231212.*",
|
||||
"iree-runtime==20231212.*",
|
||||
]
|
||||
|
||||
setup(
|
||||
|
||||
@@ -7,13 +7,13 @@
|
||||
It checks the Python version installed and installs any required build
|
||||
dependencies into a Python virtual environment.
|
||||
If that environment does not exist, it creates it.
|
||||
|
||||
|
||||
.PARAMETER update-src
|
||||
git pulls latest version
|
||||
|
||||
.PARAMETER force
|
||||
removes and recreates venv to force update of all dependencies
|
||||
|
||||
|
||||
.EXAMPLE
|
||||
.\setup_venv.ps1 --force
|
||||
|
||||
@@ -39,11 +39,11 @@ if ($arguments -eq "--force"){
|
||||
Write-Host "deactivating..."
|
||||
Deactivate
|
||||
}
|
||||
|
||||
if (Test-Path .\shark.venv\) {
|
||||
|
||||
if (Test-Path .\shark1.venv\) {
|
||||
Write-Host "removing and recreating venv..."
|
||||
Remove-Item .\shark.venv -Force -Recurse
|
||||
if (Test-Path .\shark.venv\) {
|
||||
Remove-Item .\shark1.venv -Force -Recurse
|
||||
if (Test-Path .\shark1.venv\) {
|
||||
Write-Host 'could not remove .\shark-venv - please try running ".\setup_venv.ps1 --force" again!'
|
||||
exit 1
|
||||
}
|
||||
@@ -83,15 +83,15 @@ if (!($PyVer -like "*3.11*") -and !($p -like "*3.11*")) # if 3.11 is not in any
|
||||
|
||||
Write-Host "Installing Build Dependencies"
|
||||
# make sure we really use 3.11 from list, even if it's not the default.
|
||||
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark.venv\}
|
||||
else {python -m venv .\shark.venv\}
|
||||
.\shark.venv\Scripts\activate
|
||||
if ($NULL -ne $PyVer) {py -3.11 -m venv .\shark1.venv\}
|
||||
else {python -m venv .\shark1.venv\}
|
||||
.\shark1.venv\Scripts\activate
|
||||
python -m pip install --upgrade pip
|
||||
pip install wheel
|
||||
pip install -r requirements.txt
|
||||
pip install --pre torch-mlir torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
|
||||
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
|
||||
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler==20231212.* iree-runtime==20231212.*
|
||||
Write-Host "Building SHARK..."
|
||||
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
|
||||
Write-Host "Build and installation completed successfully"
|
||||
Write-Host "Source your venv with ./shark.venv/Scripts/activate"
|
||||
Write-Host "Source your venv with ./shark1.venv/Scripts/activate"
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
# Environment variables used by the script.
|
||||
# PYTHON=$PYTHON3.10 ./setup_venv.sh #pass a version of $PYTHON to use
|
||||
# VENV_DIR=myshark.venv #create a venv called myshark.venv
|
||||
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
|
||||
# SKIP_VENV=1 #Don't create and activate a Python venv. Use the current environment.
|
||||
# USE_IREE=1 #use stock IREE instead of Nod.ai's SHARK build
|
||||
# IMPORTER=1 #Install importer deps
|
||||
# BENCHMARK=1 #Install benchmark deps
|
||||
@@ -35,7 +35,7 @@ fi
|
||||
if [[ "$SKIP_VENV" != "1" ]]; then
|
||||
if [[ -z "${CONDA_PREFIX}" ]]; then
|
||||
# Not a conda env. So create a new VENV dir
|
||||
VENV_DIR=${VENV_DIR:-shark.venv}
|
||||
VENV_DIR=${VENV_DIR:-shark1.venv}
|
||||
echo "Using pip venv.. Setting up venv dir: $VENV_DIR"
|
||||
$PYTHON -m venv "$VENV_DIR" || die "Could not create venv."
|
||||
source "$VENV_DIR/bin/activate" || die "Could not activate venv"
|
||||
@@ -111,7 +111,7 @@ else
|
||||
fi
|
||||
if [[ -z "${NO_BACKEND}" ]]; then
|
||||
echo "Installing ${RUNTIME}..."
|
||||
$PYTHON -m pip install --pre --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
|
||||
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler==20231212.* iree-runtime==20231212.*
|
||||
else
|
||||
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
|
||||
fi
|
||||
|
||||
@@ -31,60 +31,63 @@ from .benchmark_utils import *
|
||||
# Get the iree-compile arguments given device.
|
||||
def get_iree_device_args(device, extra_args=[]):
|
||||
print("Configuring for device:" + device)
|
||||
device_uri = device.split("://")
|
||||
if len(device_uri) > 1:
|
||||
if device_uri[0] not in ["vulkan", "rocm"]:
|
||||
print(
|
||||
f"Specific device selection only supported for vulkan and rocm."
|
||||
f"Proceeding with {device} as device."
|
||||
)
|
||||
# device_uri can be device_num or device_path.
|
||||
# assuming number of devices for a single driver will be not be >99
|
||||
if len(device_uri[1]) <= 2:
|
||||
# expected to be device index in range 0 - 99
|
||||
device_num = int(device_uri[1])
|
||||
else:
|
||||
# expected to be device path
|
||||
device_num = device_uri[1]
|
||||
|
||||
else:
|
||||
device_num = 0
|
||||
device, device_num = clean_device_info(device)
|
||||
|
||||
if "cpu" in device:
|
||||
from shark.iree_utils.cpu_utils import get_iree_cpu_args
|
||||
|
||||
data_tiling_flag = ["--iree-opt-data-tiling"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-microkernels"]
|
||||
u_kernel_flag = ["--iree-llvmcpu-enable-ukernels"]
|
||||
stack_size_flag = ["--iree-llvmcpu-stack-allocation-limit=256000"]
|
||||
|
||||
return (
|
||||
get_iree_cpu_args()
|
||||
+ data_tiling_flag
|
||||
+ u_kernel_flag
|
||||
+ stack_size_flag
|
||||
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
|
||||
)
|
||||
if device_uri[0] == "cuda":
|
||||
if device == "cuda":
|
||||
from shark.iree_utils.gpu_utils import get_iree_gpu_args
|
||||
|
||||
return get_iree_gpu_args()
|
||||
if device_uri[0] == "vulkan":
|
||||
if device == "vulkan":
|
||||
from shark.iree_utils.vulkan_utils import get_iree_vulkan_args
|
||||
|
||||
return get_iree_vulkan_args(
|
||||
device_num=device_num, extra_args=extra_args
|
||||
)
|
||||
if device_uri[0] == "metal":
|
||||
if device == "metal":
|
||||
from shark.iree_utils.metal_utils import get_iree_metal_args
|
||||
|
||||
return get_iree_metal_args(extra_args=extra_args)
|
||||
if device_uri[0] == "rocm":
|
||||
if device == "rocm":
|
||||
from shark.iree_utils.gpu_utils import get_iree_rocm_args
|
||||
|
||||
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
|
||||
return []
|
||||
|
||||
|
||||
def clean_device_info(raw_device):
|
||||
# return appropriate device and device_id for consumption by Studio pipeline
|
||||
# Multiple devices only supported for vulkan and rocm (as of now).
|
||||
# default device must be selected for all others
|
||||
|
||||
device_id = None
|
||||
device = (
|
||||
raw_device
|
||||
if "=>" not in raw_device
|
||||
else raw_device.split("=>")[1].strip()
|
||||
)
|
||||
if "://" in device:
|
||||
device, device_id = device.split("://")
|
||||
if len(device_id) <= 2:
|
||||
device_id = int(device_id)
|
||||
|
||||
if device not in ["rocm", "vulkan"]:
|
||||
device_id = None
|
||||
if device in ["rocm", "vulkan"] and device_id == None:
|
||||
device_id = 0
|
||||
return device, device_id
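# Editor's examples of the mapping above (device strings are hypothetical):
#   clean_device_info("cuda")                    -> ("cuda", None)
#   clean_device_info("vulkan")                  -> ("vulkan", 0)
#   clean_device_info("vulkan://1")              -> ("vulkan", 1)
#   clean_device_info("AMD Radeon => rocm://0")  -> ("rocm", 0)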
|
||||
|
||||
|
||||
# Get the iree-compiler arguments given frontend.
|
||||
def get_iree_frontend_args(frontend):
|
||||
if frontend in ["torch", "pytorch", "linalg", "tm_tensor"]:
|
||||
@@ -351,11 +354,15 @@ def get_iree_module(
|
||||
device = iree_device_map(device)
|
||||
print("registering device id: ", device_idx)
|
||||
haldriver = ireert.get_driver(device)
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
hal_device_id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = hal_device_id
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
vm_module = ireert.VmModule.from_buffer(
|
||||
@@ -394,15 +401,16 @@ def load_vmfb_using_mmap(
|
||||
haldriver = ireert.get_driver(device)
|
||||
dl.log(f"ireert.get_driver()")
|
||||
|
||||
hal_device_id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
haldevice = haldriver.create_device(
|
||||
haldriver.query_available_devices()[device_idx]["device_id"],
|
||||
hal_device_id,
|
||||
allocators=shark_args.device_allocator,
|
||||
)
|
||||
dl.log(f"ireert.create_device()")
|
||||
config = ireert.Config(device=haldevice)
|
||||
config.id = haldriver.query_available_devices()[device_idx][
|
||||
"device_id"
|
||||
]
|
||||
config.id = hal_device_id
|
||||
dl.log(f"ireert.Config()")
|
||||
else:
|
||||
config = get_iree_runtime_config(device)
|
||||
|
||||
@@ -95,6 +95,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
|
||||
print("could not execute `iree-run-module --dump_devices=rocm`")
|
||||
|
||||
if dump_device_info is not None:
|
||||
device_num = 0 if device_num is None else device_num
|
||||
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
|
||||
if len(device_arch_pairs) > device_num: # can find arch in the list
|
||||
arch_in_device_dump = device_arch_pairs[device_num][1]
|
||||
@@ -103,7 +104,7 @@ def get_rocm_device_arch(device_num=0, extra_args=[]):
|
||||
print(f"Found ROCm device arch : {arch_in_device_dump}")
|
||||
return arch_in_device_dump
|
||||
|
||||
default_rocm_arch = "gfx_1100"
|
||||
default_rocm_arch = "gfx1100"
|
||||
print(
|
||||
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
|
||||
"\n or from `iree-run-module --dump_devices=rocm` command."
|
||||
|
||||
@@ -38,15 +38,24 @@ def get_all_vulkan_devices():
|
||||
|
||||
@functools.cache
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: {vulkaninfo_list[device_num]}")
|
||||
return vulkaninfo_list[device_num]
|
||||
if isinstance(device_num, int):
|
||||
vulkaninfo_list = get_all_vulkan_devices()
|
||||
|
||||
if len(vulkaninfo_list) == 0:
|
||||
raise ValueError("No device name found in VulkanInfo!")
|
||||
if len(vulkaninfo_list) > 1:
|
||||
print("Following devices found:")
|
||||
for i, dname in enumerate(vulkaninfo_list):
|
||||
print(f"{i}. {dname}")
|
||||
print(f"Choosing device: vulkan://{device_num}")
|
||||
vulkan_device_name = vulkaninfo_list[device_num]
|
||||
else:
|
||||
from iree.runtime import get_driver
|
||||
|
||||
vulkan_device_driver = get_driver(device_num)
|
||||
vulkan_device_name = vulkan_device_driver.query_available_devices()[0]
|
||||
print(vulkan_device_name)
|
||||
return vulkan_device_name
|
||||
|
||||
|
||||
def get_os_name():
|
||||
@@ -130,8 +139,16 @@ def get_vulkan_target_triple(device_name):
|
||||
triple = f"rdna3-780m-{system_os}"
|
||||
elif all(x in device_name for x in ("AMD", "PRO", "W7900")):
|
||||
triple = f"rdna3-w7900-{system_os}"
|
||||
elif any(x in device_name for x in ("AMD", "Radeon")):
|
||||
elif "7600" in device_name:
|
||||
triple = f"rdna3-7600-{system_os}"
|
||||
elif "7700" in device_name:
|
||||
triple = f"rdna3-7700-{system_os}"
|
||||
elif any(x in device_name for x in "AMD", "Radeon") and any(x in device_name for x in "6700", "6750"):
|
||||
triple = f"rdna2-unknown-{system_os}"
|
||||
elif any(x in device_name for x in "AMD", "Radeon") and "7" in device_name:
|
||||
triple = f"rdna3-unknown-{system_os}"
|
||||
elif any(x in device_name for x in "AMD", "Radeon"):
|
||||
triple = f"rdna3-unknown-{system_os}"
|
||||
# Intel Targets
|
||||
elif any(x in device_name for x in ("A770", "A750")):
|
||||
triple = f"arc-770-{system_os}"
|
||||
@@ -174,6 +191,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
# res_vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
|
||||
|
||||
res_vulkan_flag = []
|
||||
res_vulkan_flag += [
|
||||
"--iree-stream-resource-max-allocation-size=3221225472"
|
||||
]
|
||||
vulkan_triple_flag = None
|
||||
for arg in extra_args:
|
||||
if "-iree-vulkan-target-triple=" in arg:
|
||||
@@ -195,7 +215,9 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
@functools.cache
|
||||
def get_iree_vulkan_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_validation_layers={'true' if shark_args.vulkan_validation_layers else 'false'}",
|
||||
f"--vulkan_validation_layers={'true' if shark_args.vulkan_debug_utils else 'false'}",
|
||||
f"--vulkan_debug_verbosity={'4' if shark_args.vulkan_debug_utils else '0'}"
|
||||
f"--vulkan-robust-buffer-access={'true' if shark_args.vulkan_debug_utils else 'false'}",
|
||||
]
|
||||
return vulkan_runtime_flags
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from typing import List, Tuple
|
||||
from io import BytesIO
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.common.generative.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
@@ -84,7 +84,7 @@ def compile_int_precision(
|
||||
weight_quant_type="asym",
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_scale_precision="float_scale",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
|
||||
@@ -800,15 +800,17 @@ def save_mlir(
    model_name,
    mlir_dialect="linalg",
    frontend="torch",
    dir=tempfile.gettempdir(),
    dir="",
):
    model_name_mlir = (
        model_name + "_" + frontend + "_" + mlir_dialect + ".mlir"
    )
    if dir == "":
        dir = tempfile.gettempdir()
        dir = os.path.join(".", "shark_tmp")
    mlir_path = os.path.join(dir, model_name_mlir)
    print(f"saving {model_name_mlir} to {dir}")
    if not os.path.exists(dir):
        os.makedirs(dir)
    if frontend == "torch":
        with open(mlir_path, "wb") as mlir_file:
            mlir_file.write(mlir_module)

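With dir left at its new empty-string default, the artifact appears to land under ./shark_tmp rather than the system temp directory (assuming the shark_tmp assignment is the branch that survives this hunk). A hypothetical call, also assuming the elided first parameter is the serialized module bytes, as the write() in the body suggests:

import os

save_mlir(b"module {}", model_name="tiny", frontend="torch", mlir_dialect="linalg")
# Expected location after the change:
print(os.path.exists(os.path.join("shark_tmp", "tiny_torch_linalg.mlir")))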
62
shark/tests/test_txt2img_ui.py
Normal file
@@ -0,0 +1,62 @@
import unittest
from unittest.mock import mock_open, patch

from apps.stable_diffusion.web.ui.txt2img_ui import (
    export_settings,
    load_settings,
    all_gradio_labels,
)


class TestExportSettings(unittest.TestCase):
    @patch("builtins.open", new_callable=mock_open)
    @patch("json.dump")
    def test_export_settings(self, mock_json_dump, mock_file):
        test_values = ["value1", "value2", "value3"]
        expected_output = {
            "txt2img": {
                label: value
                for label, value in zip(all_gradio_labels, test_values)
            }
        }

        export_settings(*test_values)
        mock_file.assert_called_once_with("./ui/settings.json", "w")
        mock_json_dump.assert_called_once_with(
            expected_output, mock_file(), indent=4
        )

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch(
        "builtins.open",
        new_callable=mock_open,
        read_data='{"txt2img": {"some_setting": "some_value"}}',
    )
    def test_load_settings_file_exists(self, mock_file, mock_json_load):
        mock_json_load.return_value = {
            "txt2img": {
                "txt2img_custom_model": "custom_model_value",
                "custom_vae": "custom_vae_value",
            }
        }

        settings = load_settings()
        self.assertEqual(settings[0], "custom_model_value")
        self.assertEqual(settings[1], "custom_vae_value")

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch("builtins.open", side_effect=FileNotFoundError)
    def test_load_settings_file_not_found(self, mock_file, mock_json_load):
        settings = load_settings()

        default_lora_weights = "None"
        self.assertEqual(settings[4], default_lora_weights)

    @patch("apps.stable_diffusion.web.ui.txt2img_ui.json.load")
    @patch("builtins.open", new_callable=mock_open, read_data="{}")
    def test_load_settings_key_error(self, mock_file, mock_json_load):
        mock_json_load.return_value = {}

        settings = load_settings()
        default_lora_weights = "None"
        self.assertEqual(settings[4], default_lora_weights)
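The new tests only exercise the settings export/import helpers with mocked file I/O, so they need no GPU. One way to run just this module, assuming it is launched from the repository root so the apps.stable_diffusion imports resolve:

import unittest

# Discover and run only the new txt2img UI settings tests.
suite = unittest.defaultTestLoader.discover("shark/tests", pattern="test_txt2img_ui.py")
unittest.TextTestRunner(verbosity=2).run(suite)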
@@ -1,21 +1,19 @@
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
bert-base-uncased_fp16,linalg,torch,1e-1,1e-1,default,None,True,True,True,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,True,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,True,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
bert-base-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
bert-large-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
facebook/deit-small-distilled-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"Fails during iree-compile.",""
google/vit-base-patch16-224,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/311",""
microsoft/beit-base-patch16-224-pt22k-ft22k,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/390","macos"
microsoft/MiniLM-L12-H384-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"",""
google/mobilebert-uncased,linalg,torch,1e-2,1e-3,default,None,False,False,False,"https://github.com/nod-ai/SHARK/issues/344","macos"
mobilenet_v3_small,linalg,torch,1e-1,1e-2,default,nhcw-nhwc,False,False,False,"https://github.com/nod-ai/SHARK/issues/388, https://github.com/nod-ai/SHARK/issues/1487","macos"
nvidia/mit-b0,linalg,torch,1e-2,1e-3,default,None,True,True,True,"https://github.com/nod-ai/SHARK/issues/343,https://github.com/nod-ai/SHARK/issues/1487","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,False,"","macos"
resnet101,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
resnet18,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
resnet50,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
resnet50_fp16,linalg,torch,1e-2,1e-2,default,nhcw-nhwc/img2col,True,True,True,"Numerics issues, awaiting cuda-independent fp16 integration",""
squeezenet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,False,False,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"","macos"
wide_resnet50_2,linalg,torch,1e-2,1e-3,default,nhcw-nhwc/img2col,True,True,True,"","macos"
mnasnet1_0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,False,False,False,"","macos"
efficientnet_b0,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
efficientnet_b7,linalg,torch,1e-2,1e-3,default,nhcw-nhwc,True,True,True,"https://github.com/nod-ai/SHARK/issues/1487","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported.","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"Inputs for seq2seq models in torch currently unsupported","macos"
t5-base,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"
t5-large,linalg,torch,1e-2,1e-3,default,None,True,True,True,"","macos"

@@ -50,7 +50,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
        is_decompose = row[5]

        tracing_required = False if tracing_required == "False" else True
        is_dynamic = False if is_dynamic == "False" else True
        is_dynamic = False
        print("generating artifacts for: " + torch_model_name)
        model = None
        input = None
@@ -104,7 +104,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
                model_name=torch_model_name,
                mlir_type=mlir_type,
                is_dynamic=False,
                tracing_required=tracing_required,
                tracing_required=True,
            )
        else:
            mlir_importer = SharkImporter(
@@ -114,7 +114,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
            )
            mlir_importer.import_debug(
                is_dynamic=False,
                tracing_required=tracing_required,
                tracing_required=True,
                dir=torch_model_dir,
                model_name=torch_model_name,
                mlir_type=mlir_type,
@@ -123,7 +123,7 @@ def save_torch_model(torch_model_list, local_tank_cache, import_args):
        if is_dynamic:
            mlir_importer.import_debug(
                is_dynamic=True,
                tracing_required=tracing_required,
                tracing_required=True,
                dir=torch_model_dir,
                model_name=torch_model_name + "_dynamic",
                mlir_type=mlir_type,
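The net effect of these edits is that artifact generation always traces the model and no longer emits dynamic-shape variants. A rough sketch of the resulting import call for a toy model; the SharkImporter constructor arguments follow its usual model/inputs/frontend pattern, which this diff does not itself show:

import tempfile

import torch
from shark.shark_importer import SharkImporter

# Hypothetical toy model standing in for an entry of the model list.
model = torch.nn.Linear(4, 4).eval()
sample_input = torch.randn(1, 4)

importer = SharkImporter(model, (sample_input,), frontend="torch")
importer.import_debug(
    is_dynamic=False,       # dynamic variants are disabled by this change
    tracing_required=True,  # tracing is now always requested
    dir=tempfile.mkdtemp(),
    model_name="tiny_linear",
    mlir_type="linalg",
)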