Compare commits


18 Commits

Author SHA1 Message Date
PhaneeshB
a53e714ddf fix combine mlir for llama2 2023-11-20 14:00:55 +05:30
gpetters94
80a33d427f Save intermediate values of controlnet (#1981) 2023-11-17 19:05:41 -05:00
Stefan Kapusniak
4125a26294 API/Docs: Fix incorrect cors arguments listing (#1983)
* Replace `api_cors_origin` in the api/koboldcpp doc with the correct
 `api_accept_origin` (see the example below)
2023-11-17 12:29:01 -06:00
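
A hypothetical invocation using the corrected argument; only the flag name comes
from the doc fix above, while the entry point, port, and origin value are
illustrative assumptions.

```sh
# Launch the web UI with the API enabled and accept a KoboldCpp-style client
# origin (flag name per the doc fix; path and origin are placeholders).
python apps/stable_diffusion/web/index.py --api --api_accept_origin="http://localhost:5001"
```
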
Ean Garvey
905d0103ff Revert "Re-enable SD tunings without matmuls. (#1976)" (#1979)
This reverts commit 70817bb50a.
2023-11-17 23:44:33 +05:30
Stefan Kapusniak
192b3b2c61 UI: Output gallery cleanups (#1959)
* Work around a gradio bug that causes the parameters frame to always show
scrollbars.
* Remove the original funky method of setting the number of image
columns in the gallery using _fn= javascript events. The version
of gradio we now have pinned allows doing this by setting the property
on the gallery directly, and also doesn't keep resetting the columns
when other events are fired (a minimal sketch follows this entry).
2023-11-15 22:20:42 -06:00
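
A minimal sketch of the direct approach described above, assuming the newly
pinned gradio exposes a columns parameter on Gallery (the component names are
illustrative):

```python
import gradio as gr

with gr.Blocks() as demo:
    # Set the column count directly on the gallery rather than patching the DOM
    # from a javascript event handler.
    gallery = gr.Gallery(label="Output", columns=4)
    image_columns = gr.Slider(1, 8, value=4, step=1, label="Columns")

    # The pinned gradio keeps this value, so it no longer has to be re-applied
    # every time some other event fires.
    image_columns.change(
        fn=lambda cols: gr.Gallery.update(columns=cols),
        inputs=[image_columns],
        outputs=[gallery],
    )
```
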
Stefan Kapusniak
8f9adc4a2a UI: Display top tag frequencies for selected LoRA (#1972)
* Adds a function to webui utils to read metadata from
.safetensors LoRA files and do limited parsing of the format written
out by the Kohya SS scripts (https://github.com/kohya-ss/sd-scripts)
to get tag frequency and trained model information (see the reading
sketch after this entry).
* Adds a new common_ui_events.py file for gradio event handlers
needed for multiple UI tabs, and adds an event handler for binding to
the change event of the LoRA selection boxes that outputs HTML
to display the LoRA tag frequency and model information.
* Adds an HTML gradio control to each of the SD tabs to show the
LoRA model name, and most frequently trained tags.
* Bind the change event of the LoRA selection box on each tab
to our new event handler, with the output set to the relevant HTML
control.
2023-11-15 22:19:54 -06:00
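
A minimal sketch of reading such metadata, assuming the Kohya convention of a
JSON `ss_tag_frequency` entry (and an `ss_sd_model_name` entry) inside the
file's `__metadata__` header; this is not the webui utils implementation itself:

```python
import json
import struct
from collections import Counter

def read_safetensors_metadata(path):
    """Return the __metadata__ dict from a .safetensors file's JSON header."""
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]  # little-endian u64 header size
        header = json.loads(f.read(header_len))
    return header.get("__metadata__", {})

def kohya_tag_frequencies(metadata):
    """Merge Kohya's per-dataset tag counts into one Counter (assumed key name)."""
    counts = Counter()
    raw = metadata.get("ss_tag_frequency")
    if raw:
        for dataset_tags in json.loads(raw).values():
            counts.update(dataset_tags)
    return counts

# Example: show the most frequently trained tags for one LoRA file.
meta = read_safetensors_metadata("my_lora.safetensors")
print(meta.get("ss_sd_model_name", "unknown base model"))
for tag, count in kohya_tag_frequencies(meta).most_common(10):
    print(f"{tag}: {count}")
```
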
Ean Garvey
70817bb50a Re-enable SD tunings without matmuls. (#1976) 2023-11-15 20:42:53 -06:00
jinchen62
dd37c26d36 Update brevitas quant api (#1975) 2023-11-15 10:04:07 -08:00
PhaneeshB
a708879c6c fix iree version mismatch 2023-11-15 01:24:42 +05:30
Ean Garvey
bb1b49eb6f Add --no-index to setup_venv.sh runtime pip install. 2023-11-14 21:44:20 +05:30
Ean Garvey
f6d41affd9 (SHARK Studio) Add Turbine-based llm chatbot. (#1933)
* Dan shark studio (#1970)

* Fix issue in Falcon-GPTQ

* initial webui and llama2

---------

Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>

* Fix formatting.

---------

Co-authored-by: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com>
Co-authored-by: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
2023-11-14 09:56:28 -06:00
Stefan Kapusniak
c2163488d8 SD/UI Restrict hires fix/img2img resamplers/schedulers (#1955)
* Restrict resamplers for img2img and high res fix to the ones that
PIL.Image actually supports, since it uses that to do the resampling.
Removed: Antialias, Affine, Cubic. Added: Hamming.
* Set the list of available schedulers to CPU-only ones when high res fix
is selected in the web UI, and back to all schedulers when high res fix
is deselected.
* Put hi res fix in its own Accordion in the txt2img UI instead of
grouping it with Advanced Options.
2023-11-13 16:08:24 -06:00
PhaneeshB
54bff4611d fix cli rocm device selection 2023-11-13 23:35:55 +05:30
PhaneeshB
11510d5111 add intra rocm vmfb differentiator 2023-11-13 23:35:55 +05:30
PhaneeshB
32cab73a29 add iree-rocm-target-chip only if added by user 2023-11-13 23:35:55 +05:30
PhaneeshB
392bade0bf enable non default rocm device selection for webui 2023-11-13 23:35:55 +05:30
Stefan Kapusniak
91df5f0613 API/Docs: Fix an image link in koboldcpp doc (#1954)
* Fix the image link for the koboldcpp style button, which was pointing to
the dialog image rather than the button image.
2023-11-13 11:14:29 -06:00
dependabot[bot]
df20cf9c8a Bump langchain in /apps/language_models/langchain (#1968)
Bumps [langchain](https://github.com/langchain-ai/langchain) from 0.0.325 to 0.0.329.
- [Release notes](https://github.com/langchain-ai/langchain/releases)
- [Commits](https://github.com/langchain-ai/langchain/compare/v0.0.325...v0.0.329)

---
updated-dependencies:
- dependency-name: langchain
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-12 19:46:00 -08:00
35 changed files with 1671 additions and 338 deletions

View File

@@ -112,7 +112,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cpu
pytest --benchmark=native --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
python build_tools/vicuna_testing.py
@@ -123,7 +123,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark=native --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k cuda
pytest --benchmark=native --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -146,7 +146,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark="native" --ci --ci_sha=${SHORT_SHA} --update_tank --tank_url="gs://shark_tank/nightly/" -k vulkan
pytest --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)

View File

@@ -25,7 +25,7 @@ from apps.stable_diffusion.src import args
# Brevitas
from typing import List, Tuple
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -101,7 +101,7 @@ class H2OGPTModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=128,

View File

@@ -65,7 +65,7 @@ tiktoken==0.4.0
openai==0.27.8
# optional for chat with PDF
langchain==0.0.325
langchain==0.0.329
pypdf==3.17.0
# avoid textract, requires old six
#textract==1.6.5

View File

@@ -244,7 +244,8 @@ class VicunaBase(SharkLLMBase):
print(f"[DEBUG] output_name = {output_name}")
maps1 = []
maps2 = []
constants = set()
constants_1 = set()
constants_2 = set()
f1 = []
f2 = []
@@ -255,7 +256,7 @@ class VicunaBase(SharkLLMBase):
if re.search("#map\d*\s*=", line):
maps1.append(line)
elif re.search("arith.constant", line):
constants.add(line)
constants_1.add(line)
elif not re.search("module", line):
line = re.sub("forward", "first_vicuna_forward", line)
f1.append(line)
@@ -281,7 +282,7 @@ class VicunaBase(SharkLLMBase):
elif "global_seed" in line:
continue
elif re.search("arith.constant", line):
constants.add(line)
constants_2.add(line)
elif not re.search("module", line):
line = re.sub("forward", "second_vicuna_forward", line)
f2.append(line)
@@ -304,15 +305,21 @@ class VicunaBase(SharkLLMBase):
module_end = "}"
global_vars = []
vnames = []
global_var_loading1 = []
global_var_loading2 = []
global_var_loading1 = dict()
global_var_loading2 = dict()
print(f"[DEBUG] processing constants")
counter = 0
constants = list(constants)
# in both 1 and 2
constants = [(e , "") for e in list(constants_1 & constants_2)]
# only in 1
constants.extend([(e, "_1") for e in list(constants_1.difference(constants_2))])
# only in 2
constants.extend([(e, "_2") for e in list(constants_2.difference(constants_1))])
del constants_1, constants_2
gc.collect()
while constants:
constant = constants.pop(0)
constant, vname_suf = constants.pop(0)
vname, vbody = constant.split("=")
vname = re.sub("%", "", vname)
vname = vname.strip()
@@ -322,35 +329,34 @@ class VicunaBase(SharkLLMBase):
print(constant)
vdtype = vbody.split(":")[-1].strip()
fixed_vdtype = vdtype
noinline = "{noinline}" if "tensor" in fixed_vdtype else ""
if "c1_i64" in vname:
print(constant)
counter += 1
if counter == 2:
counter = 0
print("detected duplicate")
continue
vnames.append(vname)
if "true" not in vname:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}"
f"ml_program.global private @{vname}{vname_suf}({vbody}) : {fixed_vdtype}"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : {fixed_vdtype}"
] = ""
else:
global_vars.append(
f"ml_program.global private @{vname}({vbody}) : i1"
)
global_var_loading1.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
)
global_var_loading2.append(
f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1"
f"ml_program.global private @{vname}{vname_suf}({vbody}) : i1"
)
if vname_suf != "_2":
global_var_loading1[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
] = ""
if vname_suf != "_1":
global_var_loading2[
f"\t\t%{vname} = ml_program.global_load_const @{vname}{vname_suf} : i1"
] = ""
del constants
gc.collect()
new_f1, new_f2 = [], []
@@ -358,7 +364,7 @@ class VicunaBase(SharkLLMBase):
for line in f1:
if "func.func" in line:
new_f1.append(line)
for global_var in global_var_loading1:
for global_var in global_var_loading1.keys():
new_f1.append(global_var)
else:
new_f1.append(line)
@@ -367,7 +373,7 @@ class VicunaBase(SharkLLMBase):
for line in f2:
if "func.func" in line:
new_f2.append(line)
for global_var in global_var_loading2:
for global_var in global_var_loading2.keys():
if (
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
in global_var
@@ -868,7 +874,7 @@ class ShardedVicuna(VicunaBase):
layer0, inputs0[0], inputs0[1], inputs0[2]
)
if self.precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
module0 = torch_mlir.compile(
ts_g,
@@ -1069,7 +1075,7 @@ class ShardedVicuna(VicunaBase):
)
if self.precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
print("Applying weight quantization..")
weight_bit_width = 4 if self.precision == "int4" else 8
@@ -1079,7 +1085,7 @@ class ShardedVicuna(VicunaBase):
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_granularity="per_group",
weight_group_size=self.weight_group_size,
quantize_weight_zero_point=False,
@@ -1259,6 +1265,7 @@ class UnshardedVicuna(VicunaBase):
max_num_tokens=512,
min_num_tokens=0,
device="cpu",
device_id=None,
vulkan_target_triple="",
precision="int8",
vicuna_mlir_path=None,
@@ -1269,7 +1276,6 @@ class UnshardedVicuna(VicunaBase):
download_vmfb=False,
cache_vicunas=False,
extra_args_cmd=[],
device_id=None,
debug=False,
) -> None:
super().__init__(
@@ -1288,9 +1294,7 @@ class UnshardedVicuna(VicunaBase):
print(f"[DEBUG] hf model name: {self.hf_model_path}")
self.max_sequence_length = 256
self.min_num_tokens = min_num_tokens
self.device = device
self.vulkan_target_triple = vulkan_target_triple
self.device_id = device_id
self.precision = precision
self.download_vmfb = download_vmfb
self.vicuna_vmfb_path = vicuna_vmfb_path
@@ -1299,12 +1303,24 @@ class UnshardedVicuna(VicunaBase):
self.low_device_memory = low_device_memory
self.weight_group_size = weight_group_size
self.debug = debug
# Sanity check for device, device_id pair
if "://" in device:
if device_id is not None:
print("[ERR] can't have both full device path and a device id.\n"
f"Device : {device} | device_id : {device_id}\n"
"proceeding with given Device ignoring device_id")
self.device, self.device_id = device.split("://")
if len(self.device_id) < 2:
self.device_id = int(self.device_id)
else:
self.device, self.device_id = device, device_id
if self.vicuna_mlir_path == None:
self.vicuna_mlir_path = self.get_model_path()
if self.vicuna_vmfb_path == None:
self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
self.tokenizer = self.get_tokenizer()
self.cache_vicunas = cache_vicunas
self.compile()
def get_model_path(self, suffix="mlir"):
@@ -1313,13 +1329,27 @@ class UnshardedVicuna(VicunaBase):
if suffix in ["mlirbc", "mlir"]:
return Path(f"{self.model_name}_{self.precision}.{suffix}")
target_triple = ""
if self.vulkan_target_triple != "":
target_triple = "_"
target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
# Need to distinguish between multiple vmfbs of the same model
# compiled for different devices of the same driver
# Driver - Differentiator
# Vulkan - target_triple
# ROCm - device_arch
differentiator = ""
if "vulkan" == self.device:
target_triple = ""
if self.vulkan_target_triple != "":
target_triple = "_"
target_triple += "_".join(self.vulkan_target_triple.split("-")[:-1])
differentiator = target_triple
elif "rocm" == self.device:
from shark.iree_utils.gpu_utils import get_rocm_device_arch
device_arch = get_rocm_device_arch(self.device_id if self.device_id is not None else 0, self.extra_args)
differentiator = '_' + device_arch
return Path(
f"{self.model_name}_{self.precision}_{safe_device}{target_triple}.{suffix}"
f"{self.model_name}_{self.precision}_{safe_device}{differentiator}.{suffix}"
)
def get_tokenizer(self):
@@ -1752,9 +1782,8 @@ class UnshardedVicuna(VicunaBase):
)
del first_module, second_module
print(self.device)
if "rocm" in self.device:
self.device = "rocm"
print(f"Compiling for device : {self.device}"
f"{'://' + str(self.device_id) if self.device_id is not None else ''}")
shark_module = SharkInference(
mlir_module=combined_module,
device=self.device,

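A rough sketch of the naming scheme the differentiator above produces, using
placeholder model, precision, and architecture values; the real code derives
these from the selected device as shown in the diff:

```python
from pathlib import Path

def vmfb_path(model_name, precision, device, differentiator, suffix="vmfb"):
    # differentiator is "" when no disambiguation is needed, a "_<triple>"-style
    # suffix for vulkan, and "_<arch>" (e.g. "_gfx1100") for rocm.
    return Path(f"{model_name}_{precision}_{device}{differentiator}.{suffix}")

print(vmfb_path("vicuna_7b", "int8", "vulkan", "_rdna3_unknown"))  # vulkan target triple
print(vmfb_path("vicuna_7b", "int8", "rocm", "_gfx1100"))          # rocm device arch
```
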
View File

@@ -5,7 +5,7 @@ from typing import List, Any
from transformers import StoppingCriteria
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -37,7 +37,7 @@ class VisionModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -52,7 +52,7 @@ class VisionModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -93,7 +93,7 @@ class FirstLlamaModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -157,7 +157,7 @@ class SecondLlamaModel(torch.nn.Module):
dtype=torch.float32,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -24,7 +24,9 @@ class FirstVicuna(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -36,7 +38,7 @@ class FirstVicuna(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -79,7 +81,9 @@ class SecondVicuna7B(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -91,7 +95,7 @@ class SecondVicuna7B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -329,7 +333,9 @@ class SecondVicuna13B(torch.nn.Module):
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -341,7 +347,7 @@ class SecondVicuna13B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -627,7 +633,9 @@ class SecondVicuna70B(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -639,7 +647,7 @@ class SecondVicuna70B(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -24,7 +24,9 @@ class FirstVicunaGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -36,7 +38,7 @@ class FirstVicunaGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -78,7 +80,9 @@ class SecondVicuna7BGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -90,7 +94,7 @@ class SecondVicuna7BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -327,7 +331,9 @@ class SecondVicuna13BGPU(torch.nn.Module):
torch.float32 if accumulates == "fp32" else torch.float16
)
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -339,7 +345,7 @@ class SecondVicuna13BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
@@ -625,7 +631,9 @@ class SecondVicuna70BGPU(torch.nn.Module):
)
print(f"[DEBUG] model_path : {model_path}")
if precision in ["int4", "int8"]:
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import (
quantize_model,
)
from brevitas_examples.llm.llm_quant.run_utils import (
get_model_impl,
)
@@ -637,7 +645,7 @@ class SecondVicuna70BGPU(torch.nn.Module):
dtype=self.accumulates,
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_type="asym",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,

View File

@@ -132,7 +132,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl

View File

@@ -0,0 +1,91 @@
from turbine_models.custom_models import stateless_llama
from shark.iree_utils.compile_utils import get_iree_compiled_module
from apps.shark_studio.api.utils import get_resource_path
import iree.runtime as ireert
import gc
import torch
llm_model_map = {
"llama2_7b": {
"initializer": stateless_llama.export_transformer_model,
"hf_model_name": "meta-llama/Llama-2-7b-chat-hf",
"stop_token": 2,
"max_tokens": 4096,
}
}
class LanguageModel:
def __init__(
self, model_name, hf_auth_token=None, device=None, precision="fp32"
):
print(llm_model_map[model_name])
self.hf_model_name = llm_model_map[model_name]["hf_model_name"]
self.torch_ir, self.tokenizer = llm_model_map[model_name][
"initializer"
](self.hf_model_name, hf_auth_token, compile_to="torch")
self.tempfile_name = get_resource_path("llm.torch.tempfile")
with open(self.tempfile_name, "w+") as f:
f.write(self.torch_ir)
del self.torch_ir
gc.collect()
self.device = device
self.precision = precision
self.max_tokens = llm_model_map[model_name]["max_tokens"]
self.iree_module_dict = None
self.compile()
def compile(self) -> None:
# this comes with keys: "vmfb", "config", and "temp_file_to_unlink".
self.iree_module_dict = get_iree_compiled_module(
self.tempfile_name, device=self.device, frontend="torch"
)
# TODO: delete the temp file
def chat(self, prompt):
history = []
for iter in range(self.max_tokens):
input_tensor = self.tokenizer(
prompt, return_tensors="pt"
).input_ids
device_inputs = [
ireert.asdevicearray(
self.iree_module_dict["config"], input_tensor
)
]
if iter == 0:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_initialize"](
*device_inputs
).to_host()[0][0]
)
else:
token = torch.tensor(
self.iree_module_dict["vmfb"]["run_forward"](
*device_inputs
).to_host()[0][0]
)
history.append(token)
yield self.tokenizer.decode(history)
if token == llm_model_map["llama2_7b"]["stop_token"]:
break
for i in range(len(history)):
if type(history[i]) != int:
history[i] = int(history[i])
result_output = self.tokenizer.decode(history)
yield result_output
if __name__ == "__main__":
lm = LanguageModel(
"llama2_7b",
hf_auth_token="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",
device="cpu-task",
)
print("model loaded")
for i in lm.chat("Hello, I am a robot."):
print(i)

View File

@@ -0,0 +1,14 @@
import os
import sys
def get_available_devices():
return ["cpu-task"]
def get_resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)

View File

@@ -0,0 +1,428 @@
from multiprocessing import Process, freeze_support
import os
import sys
import logging
from ui.chat import chat_element
if sys.platform == "darwin":
os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib"
# import before IREE to avoid MLIR library issues
import torch_mlir
# import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation
# from apps.stable_diffusion.src import args, clear_all
# import apps.stable_diffusion.web.utils.global_obj as global_obj
def launch_app(address):
from tkinter import Tk
import webview
window = Tk()
# get screen width and height of display and make it more reasonably
# sized as we aren't making it full-screen or maximized
width = int(window.winfo_screenwidth() * 0.81)
height = int(window.winfo_screenheight() * 0.91)
webview.create_window(
"SHARK AI Studio",
url=address,
width=width,
height=height,
text_select=True,
)
webview.start(private_mode=False, storage_path=os.getcwd())
if __name__ == "__main__":
# if args.debug:
logging.basicConfig(level=logging.DEBUG)
# required to do multiprocessing in a pyinstaller freeze
freeze_support()
# if args.api or "api" in args.ui.split(","):
# from apps.stable_diffusion.web.ui import (
# txt2img_api,
# img2img_api,
# upscaler_api,
# inpaint_api,
# outpaint_api,
# llm_chat_api,
# )
#
# from fastapi import FastAPI, APIRouter
# import uvicorn
#
# # init global sd pipeline and config
# global_obj._init()
#
# app = FastAPI()
# app.add_api_route("/sdapi/v1/txt2img", txt2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/img2img", img2img_api, methods=["post"])
# app.add_api_route("/sdapi/v1/inpaint", inpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/outpaint", outpaint_api, methods=["post"])
# app.add_api_route("/sdapi/v1/upscaler", upscaler_api, methods=["post"])
#
# # chat APIs needed for compatibility with multiple extensions using OpenAI API
# app.add_api_route(
# "/v1/chat/completions", llm_chat_api, methods=["post"]
# )
# app.add_api_route("/v1/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/chat/completions", llm_chat_api, methods=["post"])
# app.add_api_route("/completions", llm_chat_api, methods=["post"])
# app.add_api_route(
# "/v1/engines/codegen/completions", llm_chat_api, methods=["post"]
# )
# app.include_router(APIRouter())
# uvicorn.run(app, host="0.0.0.0", port=args.server_port)
# sys.exit(0)
#
# Setup to use shark_tmp for gradio's temporary image files and clear any
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.
# from apps.stable_diffusion.web.utils.gradio_configs import (
# config_gradio_tmp_imgs_folder,
# )
# config_gradio_tmp_imgs_folder()
import gradio as gr
# Create custom models folders if they don't exist
# from apps.stable_diffusion.web.ui.utils import create_custom_models_folders
# create_custom_models_folders()
def resource_path(relative_path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(
sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__))
)
return os.path.join(base_path, relative_path)
dark_theme = resource_path("ui/css/sd_dark_theme.css")
# from apps.stable_diffusion.web.ui import (
# txt2img_web,
# txt2img_custom_model,
# txt2img_gallery,
# txt2img_png_info_img,
# txt2img_status,
# txt2img_sendto_img2img,
# txt2img_sendto_inpaint,
# txt2img_sendto_outpaint,
# txt2img_sendto_upscaler,
## h2ogpt_upload,
## h2ogpt_web,
# img2img_web,
# img2img_custom_model,
# img2img_gallery,
# img2img_init_image,
# img2img_status,
# img2img_sendto_inpaint,
# img2img_sendto_outpaint,
# img2img_sendto_upscaler,
# inpaint_web,
# inpaint_custom_model,
# inpaint_gallery,
# inpaint_init_image,
# inpaint_status,
# inpaint_sendto_img2img,
# inpaint_sendto_outpaint,
# inpaint_sendto_upscaler,
# outpaint_web,
# outpaint_custom_model,
# outpaint_gallery,
# outpaint_init_image,
# outpaint_status,
# outpaint_sendto_img2img,
# outpaint_sendto_inpaint,
# outpaint_sendto_upscaler,
# upscaler_web,
# upscaler_custom_model,
# upscaler_gallery,
# upscaler_init_image,
# upscaler_status,
# upscaler_sendto_img2img,
# upscaler_sendto_inpaint,
# upscaler_sendto_outpaint,
## lora_train_web,
## model_web,
## model_config_web,
# hf_models,
# modelmanager_sendto_txt2img,
# modelmanager_sendto_img2img,
# modelmanager_sendto_inpaint,
# modelmanager_sendto_outpaint,
# modelmanager_sendto_upscaler,
# stablelm_chat,
# minigpt4_web,
# outputgallery_web,
# outputgallery_tab_select,
# outputgallery_watch,
# outputgallery_filename,
# outputgallery_sendto_txt2img,
# outputgallery_sendto_img2img,
# outputgallery_sendto_inpaint,
# outputgallery_sendto_outpaint,
# outputgallery_sendto_upscaler,
# )
# init global sd pipeline and config
# global_obj._init()
def register_button_click(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x[0]["name"] if len(x) != 0 else None,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_modelmanager_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
"None",
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
def register_outputgallery_button(button, selectedid, inputs, outputs):
button.click(
lambda x: (
x,
gr.Tabs.update(selected=selectedid),
),
inputs,
outputs,
)
with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Stable Diffusion"
) as sd_web:
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
# and that the order in the code here is the order they should
# appear in the ui, as the id value doesn't determine the order.
# Where possible, avoid changing the id of any tab that is the
# destination of one of the 'send to' buttons. If you do have to change
# that id, make sure you update the relevant register_button_click calls
# further down with the new id.
# with gr.TabItem(label="Text-to-Image", id=0):
# txt2img_web.render()
# with gr.TabItem(label="Image-to-Image", id=1):
# img2img_web.render()
# with gr.TabItem(label="Inpainting", id=2):
# inpaint_web.render()
# with gr.TabItem(label="Outpainting", id=3):
# outpaint_web.render()
# with gr.TabItem(label="Upscaler", id=4):
# upscaler_web.render()
# if args.output_gallery:
# with gr.TabItem(label="Output Gallery", id=5) as og_tab:
# outputgallery_web.render()
# # extra output gallery configuration
# outputgallery_tab_select(og_tab.select)
# outputgallery_watch(
# [
# txt2img_status,
# img2img_status,
# inpaint_status,
# outpaint_status,
# upscaler_status,
# ]
# )
## with gr.TabItem(label="Model Manager", id=6):
## model_web.render()
## with gr.TabItem(label="LoRA Training (Experimental)", id=7):
## lora_train_web.render()
with gr.TabItem(label="Chat Bot", id=0):
chat_element.render()
## with gr.TabItem(
## label="Generate Sharding Config (Experimental)", id=9
## ):
## model_config_web.render()
# with gr.TabItem(label="MultiModal (Experimental)", id=10):
# minigpt4_web.render()
# with gr.TabItem(label="DocuChat Upload", id=11):
# h2ogpt_upload.render()
# with gr.TabItem(label="DocuChat(Experimental)", id=12):
# h2ogpt_web.render()
# send to buttons
# register_button_click(
# txt2img_sendto_img2img,
# 1,
# [txt2img_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_inpaint,
# 2,
# [txt2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_outpaint,
# 3,
# [txt2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# txt2img_sendto_upscaler,
# 4,
# [txt2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_inpaint,
# 2,
# [img2img_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_outpaint,
# 3,
# [img2img_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# img2img_sendto_upscaler,
# 4,
# [img2img_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_img2img,
# 1,
# [inpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_outpaint,
# 3,
# [inpaint_gallery],
# [outpaint_init_image, tabs],
# )
# register_button_click(
# inpaint_sendto_upscaler,
# 4,
# [inpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_img2img,
# 1,
# [outpaint_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_inpaint,
# 2,
# [outpaint_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# outpaint_sendto_upscaler,
# 4,
# [outpaint_gallery],
# [upscaler_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_img2img,
# 1,
# [upscaler_gallery],
# [img2img_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_inpaint,
# 2,
# [upscaler_gallery],
# [inpaint_init_image, tabs],
# )
# register_button_click(
# upscaler_sendto_outpaint,
# 3,
# [upscaler_gallery],
# [outpaint_init_image, tabs],
# )
# if args.output_gallery:
# register_outputgallery_button(
# outputgallery_sendto_txt2img,
# 0,
# [outputgallery_filename],
# [txt2img_png_info_img, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_img2img,
# 1,
# [outputgallery_filename],
# [img2img_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_inpaint,
# 2,
# [outputgallery_filename],
# [inpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_outpaint,
# 3,
# [outputgallery_filename],
# [outpaint_init_image, tabs],
# )
# register_outputgallery_button(
# outputgallery_sendto_upscaler,
# 4,
# [outputgallery_filename],
# [upscaler_init_image, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_txt2img,
# 0,
# [hf_models],
# [txt2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_img2img,
# 1,
# [hf_models],
# [img2img_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_inpaint,
# 2,
# [hf_models],
# [inpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_outpaint,
# 3,
# [hf_models],
# [outpaint_custom_model, tabs],
# )
# register_modelmanager_button(
# modelmanager_sendto_upscaler,
# 4,
# [hf_models],
# [upscaler_custom_model, tabs],
# )
sd_web.queue()
# if args.ui == "app":
# t = Process(
# target=launch_app, args=[f"http://localhost:{args.server_port}"]
# )
# t.start()
sd_web.launch(
share=True,
inbrowser=True,
server_name="0.0.0.0",
server_port=11911, # args.server_port,
)

View File

View File

@@ -0,0 +1,517 @@
import gradio as gr
import os
from pathlib import Path
from datetime import datetime as dt
import json
import sys
from apps.shark_studio.api.utils import (
get_available_devices,
)
from apps.shark_studio.api.llm import (
llm_model_map,
LanguageModel,
)
def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]
language_model = None
# NOTE: Each `model_name` should have its own start message
start_message = {
"llama2_7b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_13b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"llama2_70b": (
"You are a helpful, respectful and honest assistant. Always answer "
"as helpfully as possible, while being safe. Your answers should not "
"include any harmful, unethical, racist, sexist, toxic, dangerous, or "
"illegal content. Please ensure that your responses are socially "
"unbiased and positive in nature. If a question does not make any "
"sense, or is not factually coherent, explain why instead of "
"answering something not correct. If you don't know the answer "
"to a question, please don't share false information."
),
"vicuna": (
"A chat between a curious user and an artificial intelligence "
"assistant. The assistant gives helpful, detailed, and "
"polite answers to the user's questions.\n"
),
}
def create_prompt(model_name, history, prompt_prefix):
return ""
system_message = ""
if prompt_prefix:
system_message = start_message[model_name]
if "llama2" in model_name:
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
conversation = "".join(
[f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]]
)
if prompt_prefix:
msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}"
else:
msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}"
elif model_name in ["vicuna"]:
conversation = "".join(
[
"".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]])
for item in history
]
)
msg = system_message + conversation
msg = msg.strip()
else:
conversation = "".join(
["".join([item[0], item[1]]) for item in history]
)
msg = system_message + conversation
msg = msg.strip()
return msg
def get_default_config():
return False
import torch
from transformers import AutoTokenizer
hf_model_path = "TheBloke/vicuna-7B-1.1-HF"
tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False)
compilation_prompt = "".join(["0" for _ in range(17)])
compilation_input_ids = tokenizer(
compilation_prompt,
return_tensors="pt",
).input_ids
compilation_input_ids = torch.tensor(compilation_input_ids).reshape(
[1, 19]
)
firstVicunaCompileInput = (compilation_input_ids,)
from apps.language_models.src.model_wrappers.vicuna_model import (
CombinedModel,
)
from shark.shark_generate_model_config import GenerateConfigFile
model = CombinedModel()
c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput)
c.split_into_layers()
# model_vmfb_key = ""
def chat_fn(
prompt_prefix,
history,
model,
device,
precision,
download_vmfb,
config_file,
cli=False,
progress=gr.Progress(),
):
global language_model
if language_model is None:
language_model = LanguageModel(
model, device=device, precision=precision
)
language_model.chat(prompt_prefix)
return "", ""
global past_key_values
global model_vmfb_key
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
else:
print("unrecognized device")
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
from apps.stable_diffusion.src import args
new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}"
if vicuna_model is None or new_model_vmfb_key != model_vmfb_key:
model_vmfb_key = new_model_vmfb_key
max_toks = 128 if model_name == "codegen" else 512
# get iree flags that need to be overridden, from commandline args
_extra_args = []
# vulkan target triple
vulkan_target_triple = args.iree_vulkan_target_triple
from shark.iree_utils.vulkan_utils import (
get_all_vulkan_devices,
get_vulkan_target_triple,
)
if device == "vulkan":
vulkaninfo_list = get_all_vulkan_devices()
if vulkan_target_triple == "":
# We already have the device_id extracted via WebUI, so we directly use
# that to find the target triple.
vulkan_target_triple = get_vulkan_target_triple(
vulkaninfo_list[device_id]
)
_extra_args.append(
f"-iree-vulkan-target-triple={vulkan_target_triple}"
)
if "rdna" in vulkan_target_triple:
flags_to_add = [
"--iree-spirv-index-bits=64",
]
_extra_args = _extra_args + flags_to_add
if device_id is None:
id = 0
for device in vulkaninfo_list:
target_triple = get_vulkan_target_triple(
vulkaninfo_list[id]
)
if target_triple == vulkan_target_triple:
device_id = id
break
id += 1
assert (
device_id
), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists"
print(f"Will use vulkan target triple : {vulkan_target_triple}")
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
compressed=True,
extra_args_cmd=_extra_args,
)
else:
# if config_file is None:
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
hf_auth_token=args.hf_auth_token,
device=device,
vulkan_target_triple=vulkan_target_triple,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=download_vmfb,
load_mlir_from_shark_tank=True,
extra_args_cmd=_extra_args,
device_id=device_id,
)
if vicuna_model is None:
sys.exit("Unable to instantiate the model object, exiting.")
prompt = create_prompt(model_name, history, prompt_prefix)
partial_text = ""
token_count = 0
total_time_ms = 0.001 # In order to avoid divide by zero error
prefill_time = 0
is_first = True
for text, msg, exec_time in progress.tqdm(
vicuna_model.generate(prompt, cli=cli),
desc="generating response",
):
if msg is None:
if is_first:
prefill_time = exec_time
is_first = False
else:
total_time_ms += exec_time
token_count += 1
partial_text += text + " "
history[-1][1] = partial_text
yield history, f"Prefill: {prefill_time:.2f}"
elif "formatted" in msg:
history[-1][1] = text
tokens_per_sec = (token_count / total_time_ms) * 1000
yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec"
else:
sys.exit(
"unexpected message from the vicuna generate call, exiting."
)
return history, ""
def llm_chat_api(InputData: dict):
return None
print(f"Input keys : {InputData.keys()}")
# print(f"model : {InputData['model']}")
is_chat_completion_api = (
"messages" in InputData.keys()
) # else it is the legacy `completion` api
# For Debugging input data from API
# if is_chat_completion_api:
# print(f"message -> role : {InputData['messages'][0]['role']}")
# print(f"message -> content : {InputData['messages'][0]['content']}")
# else:
# print(f"prompt : {InputData['prompt']}")
# print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now
global vicuna_model
model_name = (
InputData["model"] if "model" in InputData.keys() else "codegen"
)
model_path = llm_model_map[model_name]
device = "cpu-task"
precision = "fp16"
max_toks = (
None
if "max_tokens" not in InputData.keys()
else InputData["max_tokens"]
)
if max_toks is None:
max_toks = 128 if model_name == "codegen" else 512
# make it working for codegen first
from apps.language_models.scripts.vicuna import (
UnshardedVicuna,
)
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
else:
print("unrecognized device")
vicuna_model = UnshardedVicuna(
model_name,
hf_model_path=model_path,
device=device,
precision=precision,
max_num_tokens=max_toks,
download_vmfb=True,
load_mlir_from_shark_tank=True,
device_id=device_id,
)
# TODO: add role dict for different models
if is_chat_completion_api:
# TODO: add funtionality for multiple messages
prompt = create_prompt(
model_name, [(InputData["messages"][0]["content"], "")]
)
else:
prompt = InputData["prompt"]
print("prompt = ", prompt)
res = vicuna_model.generate(prompt)
res_op = None
for op in res:
res_op = op
if is_chat_completion_api:
choices = [
{
"index": 0,
"message": {
"role": "assistant",
"content": res_op, # since we are yeilding the result
},
"finish_reason": "stop", # or length
}
]
else:
choices = [
{
"text": res_op,
"index": 0,
"logprobs": None,
"finish_reason": "stop", # or length
}
]
end_time = dt.now().strftime("%Y%m%d%H%M%S%f")
return {
"id": end_time,
"object": "chat.completion"
if is_chat_completion_api
else "text_completion",
"created": int(end_time),
"choices": choices,
}
def view_json_file(file_obj):
content = ""
with open(file_obj.name, "r") as fopen:
content = fopen.read()
return content
with gr.Blocks(title="Chat") as chat_element:
with gr.Row():
model_choices = list(llm_model_map.keys())
model = gr.Dropdown(
label="Select Model",
value=model_choices[0],
choices=model_choices,
allow_custom_value=True,
)
supported_devices = get_available_devices()
enabled = True
if len(supported_devices) == 0:
supported_devices = ["cpu-task"]
supported_devices = [x for x in supported_devices if "sync" not in x]
device = gr.Dropdown(
label="Device",
value=supported_devices[0],
choices=supported_devices,
interactive=enabled,
allow_custom_value=True,
)
precision = gr.Radio(
label="Precision",
value="int4",
choices=[
# "int4",
# "int8",
# "fp16",
"fp32",
],
visible=False,
)
tokens_time = gr.Textbox(label="Tokens generated per second")
with gr.Column():
download_vmfb = gr.Checkbox(
label="Download vmfb from Shark tank if available",
value=True,
interactive=True,
)
prompt_prefix = gr.Checkbox(
label="Add System Prompt",
value=False,
interactive=True,
)
chatbot = gr.Chatbot(height=500)
with gr.Row():
with gr.Column():
msg = gr.Textbox(
label="Chat Message Box",
placeholder="Chat Message Box",
show_label=False,
interactive=enabled,
container=False,
)
with gr.Column():
with gr.Row():
submit = gr.Button("Submit", interactive=enabled)
stop = gr.Button("Stop", interactive=enabled)
clear = gr.Button("Clear", interactive=enabled)
with gr.Row(visible=False):
with gr.Group():
config_file = gr.File(
label="Upload sharding configuration", visible=False
)
json_view_button = gr.Button(label="View as JSON", visible=False)
json_view = gr.JSON(interactive=True, visible=False)
json_view_button.click(
fn=view_json_file, inputs=[config_file], outputs=[json_view]
)
submit_event = msg.submit(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
submit_click_event = submit.click(
fn=user,
inputs=[msg, chatbot],
outputs=[msg, chatbot],
show_progress=False,
queue=False,
).then(
fn=chat_fn,
inputs=[
prompt_prefix,
chatbot,
model,
device,
precision,
download_vmfb,
config_file,
],
outputs=[chatbot, tokens_time],
show_progress=False,
queue=True,
)
stop.click(
fn=None,
inputs=None,
outputs=None,
cancels=[submit_event, submit_click_event],
queue=False,
)
clear.click(lambda: None, None, [chatbot], queue=False)

View File

@@ -29,6 +29,10 @@ from apps.stable_diffusion.src.models import (
SharkifyStableDiffusionModel,
get_vae_encode,
)
from apps.stable_diffusion.src.utils import (
resamplers,
resampler_list,
)
class Image2ImagePipeline(StableDiffusionPipeline):
@@ -91,26 +95,12 @@ class Image2ImagePipeline(StableDiffusionPipeline):
# TODO: process with variable HxW combos
# Pre-process image
if resample_type == "Lanczos":
resample_type = Image.LANCZOS
elif resample_type == "Nearest Neighbor":
resample_type = Image.NEAREST
elif resample_type == "Bilinear":
resample_type = Image.BILINEAR
elif resample_type == "Bicubic":
resample_type = Image.BICUBIC
elif resample_type == "Adaptive":
resample_type = Image.ADAPTIVE
elif resample_type == "Antialias":
resample_type = Image.ANTIALIAS
elif resample_type == "Box":
resample_type = Image.BOX
elif resample_type == "Affine":
resample_type = Image.AFFINE
elif resample_type == "Cubic":
resample_type = Image.CUBIC
else: # Fallback to Lanczos
resample_type = Image.LANCZOS
resample_type = (
resamplers[resample_type]
if resample_type in resampler_list
# Fallback to Lanczos
else Image.Resampling.LANCZOS
)
image = image.resize((width, height), resample=resample_type)
image_arr = np.stack([np.array(i) for i in (image,)], axis=0)

View File

@@ -42,3 +42,7 @@ from apps.stable_diffusion.src.utils.utils import (
_compile_module,
)
from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint
from apps.stable_diffusion.src.utils.resamplers import (
resamplers,
resampler_list,
)

View File

@@ -0,0 +1,12 @@
import PIL.Image as Image
resamplers = {
"Lanczos": Image.Resampling.LANCZOS,
"Nearest Neighbor": Image.Resampling.NEAREST,
"Bilinear": Image.Resampling.BILINEAR,
"Bicubic": Image.Resampling.BICUBIC,
"Hamming": Image.Resampling.HAMMING,
"Box": Image.Resampling.BOX,
}
resampler_list = resamplers.keys()

View File

@@ -2,6 +2,8 @@ import argparse
import os
from pathlib import Path
from apps.stable_diffusion.src.utils.resamplers import resampler_list
def path_expand(s):
return Path(s).expanduser().resolve()
@@ -168,17 +170,7 @@ p.add_argument(
"--resample_type",
type=str,
default="Nearest Neighbor",
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
choices=resampler_list,
help="The resample type to use when resizing an image before being run "
"through stable diffusion.",
)
@@ -746,8 +738,9 @@ p.add_argument(
p.add_argument(
"--iree_rocm_target_chip",
type=str,
default="gfx1100",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Default gfx1100",
default="",
help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` "
"or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name",
)
args, unknown = p.parse_known_args()

View File

@@ -1,6 +1,10 @@
import numpy as np
from PIL import Image
import torch
import os
from pathlib import Path
import torchvision
import time
from apps.stable_diffusion.src.utils.stencils import (
CannyDetector,
OpenposeDetector,
@@ -10,6 +14,33 @@ from apps.stable_diffusion.src.utils.stencils import (
stencil = {}
def save_img(img):
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generated_imgs_todays_subdir,
)
subdir = Path(
get_generated_imgs_path(), get_generated_imgs_todays_subdir()
)
os.makedirs(subdir, exist_ok=True)
if isinstance(img, Image.Image):
img.save(
os.path.join(
subdir, "controlnet_" + str(int(time.time())) + ".png"
)
)
elif isinstance(img, np.ndarray):
img = Image.fromarray(img)
img.save(os.path.join(subdir, str(int(time.time())) + ".png"))
else:
converter = torchvision.transforms.ToPILImage()
for i in img:
converter(i).save(
os.path.join(subdir, str(int(time.time())) + ".png")
)
def HWC3(x):
assert x.dtype == np.uint8
if x.ndim == 2:
@@ -161,6 +192,7 @@ def hint_canny(
detected_map = stencil["canny"](
input_image, low_threshold, high_threshold
)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -176,6 +208,7 @@ def hint_openpose(
stencil["openpose"] = OpenposeDetector()
detected_map, _ = stencil["openpose"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map
@@ -187,6 +220,7 @@ def hint_scribble(image: Image.Image):
detected_map = np.zeros_like(input_image, dtype=np.uint8)
detected_map[np.min(input_image, axis=2) < 127] = 255
save_img(detected_map)
return detected_map
@@ -199,5 +233,6 @@ def hint_zoedepth(image: Image.Image):
stencil["depth"] = ZoeDetector()
detected_map = stencil["depth"](input_image)
save_img(detected_map)
detected_map = HWC3(detected_map)
return detected_map

View File

@@ -0,0 +1,55 @@
from apps.stable_diffusion.web.ui.utils import (
HSLHue,
hsl_color,
get_lora_metadata,
)
# Answers HTML to show the most frequent tags used when a LoRA was trained,
# taken from the metadata of its .safetensors file.
def lora_changed(lora_file):
# tag frequency percentage, that gets maximum amount of the staring hue
TAG_COLOR_THRESHOLD = 0.55
# tag frequency percentage, above which a tag is displayed
TAG_DISPLAY_THRESHOLD = 0.65
# template for the html used to display a tag
TAG_HTML_TEMPLATE = '<span class="lora-tag" style="border: 1px solid {color};">{tag}</span>'
if lora_file == "None":
return ["<div><i>No LoRA selected</i></div>"]
elif not lora_file.lower().endswith(".safetensors"):
return [
"<div><i>Only metadata queries for .safetensors files are currently supported</i></div>"
]
else:
metadata = get_lora_metadata(lora_file)
if metadata:
frequencies = metadata["frequencies"]
return [
"".join(
[
f'<div class="lora-model">Trained against weights in: {metadata["model"]}</div>'
]
+ [
TAG_HTML_TEMPLATE.format(
color=hsl_color(
(tag[1] - TAG_COLOR_THRESHOLD)
/ (1 - TAG_COLOR_THRESHOLD),
start=HSLHue.RED,
end=HSLHue.GREEN,
),
tag=tag[0],
)
for tag in frequencies
if tag[1] > TAG_DISPLAY_THRESHOLD
],
)
]
elif metadata is None:
return [
"<div><i>This LoRA does not publish tag frequency metadata</i></div>"
]
else:
return [
"<div><i>This LoRA has empty tag frequency metadata, or we could not parse it</i></div>"
]

View File

@@ -246,10 +246,39 @@ footer {
background-color: var(--block-label-background-fill);
}
/* lora tag pills */
.lora-tags {
border: 1px solid var(--border-color-primary);
color: var(--block-info-text-color) !important;
padding: var(--block-padding);
}
.lora-tag {
display: inline-block;
height: 2em;
color: rgb(212 212 212) !important;
margin-right: 5pt;
margin-bottom: 5pt;
padding: 2pt 5pt;
border-radius: 5pt;
white-space: nowrap;
}
.lora-model {
margin-bottom: var(--spacing-lg);
color: var(--block-info-text-color) !important;
line-height: var(--line-sm);
}
/* output gallery tab */
.output_parameters_dataframe table.table {
/* works around a gradio bug that always shows scrollbars */
overflow: clip auto;
}
.output_parameters_dataframe tbody td {
font-size: small;
line-height: var(--line-xs)
line-height: var(--line-xs);
}
.output_icon_button {

View File

@@ -5,6 +5,7 @@ import gradio as gr
import PIL
from math import ceil
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -14,6 +15,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
Image2ImagePipeline,
@@ -27,6 +29,7 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)
from apps.stable_diffusion.web.utils.common_label_calc import status_label
import numpy as np
@@ -435,6 +438,11 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -486,17 +494,7 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
@@ -647,3 +645,10 @@ with gr.Blocks(title="Image-to-Image") as img2img_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -4,6 +4,7 @@ import time
import sys
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -13,6 +14,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_paint_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.src import (
args,
InpaintPipeline,
@@ -319,6 +321,11 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -518,3 +525,10 @@ with gr.Blocks(title="Inpainting") as inpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -3,9 +3,8 @@ import torch
import time
import gradio as gr
from PIL import Image
import base64
from io import BytesIO
from fastapi.exceptions import HTTPException
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -323,6 +322,11 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -546,3 +550,10 @@ with gr.Blocks(title="Outpainting") as outpaint_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -91,7 +91,7 @@ with gr.Blocks() as outputgallery_web:
value=gallery_files.value,
visible=False,
show_label=True,
columns=2,
columns=4,
)
with gr.Column(scale=4):
@@ -204,6 +204,9 @@ with gr.Blocks() as outputgallery_web:
),
]
def on_image_columns_change(columns):
return gr.Gallery.update(columns=columns)
def on_select_subdir(subdir) -> list:
# evt.value is the subdirectory name
new_images = outputgallery_filenames(subdir)
@@ -365,53 +368,6 @@ with gr.Blocks() as outputgallery_web:
gr.update(),
)
# Unfortunately as of gradio 3.34.0 gr.update against Galleries doesn't
# support things set with .style, nor the elem_classes kwarg, so we have
# to directly set things up via JavaScript if we want the client to take
# notice of our changes to the number of columns after it decides to put
# them back to the original number when we change something
def js_set_columns_in_browser(timeout_length):
return f"""
(new_cols) => {{
setTimeout(() => {{
required_style = "auto ".repeat(new_cols).trim();
gallery = document.querySelector('#outputgallery_gallery .grid-container');
if (gallery) {{
gallery.style.gridTemplateColumns = required_style
}}
}}, {timeout_length});
return []; // prevents console error from gradio
}}
"""
# --- Wire handlers up to the actions
# Many actions reset the number of columns shown in the gallery on the
# browser end, so we have to set them back to what we think they should
# be after the initial action.
#
# None of the actions on this tab trigger inference, and we want the
# user to be able to do them whilst other tabs have ongoing inference
# running. Waiting in the queue behind inference jobs would mean the UI
# can't fully respond until the inference tasks complete,
# hence queue=False on all of these.
set_gallery_columns_immediate = dict(
fn=None,
inputs=[image_columns],
# gradio blanks the UI on Chrome on Linux on gallery select if
# I don't put an output here
outputs=[dev_null],
_js=js_set_columns_in_browser(0),
queue=False,
)
# setting columns after selecting a gallery item needs a real
# timeout length for the number of columns to actually be applied.
# Not really sure why, maybe something has to finish animating?
set_gallery_columns_delayed = dict(
set_gallery_columns_immediate, _js=js_set_columns_in_browser(250)
)
# clearing images when we need to completely change what's in the
# gallery avoids current images being shown replacing piecemeal and
# prevents weirdness and errors if the user selects an image during the
@@ -423,32 +379,35 @@ with gr.Blocks() as outputgallery_web:
queue=False,
)
image_columns.change(**set_gallery_columns_immediate)
subdirectories.select(**clear_gallery).then(
on_select_subdir,
[subdirectories],
[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
open_subdir.click(
on_open_subdir, inputs=[subdirectories], queue=False
).then(**set_gallery_columns_immediate)
open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False)
refresh.click(**clear_gallery).then(
on_refresh,
[subdirectories],
[subdirectories, subdirectory_paths, gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
image_columns.change(
fn=on_image_columns_change,
inputs=[image_columns],
outputs=[gallery],
queue=False,
)
gallery.select(
on_select_image,
[gallery_files],
[outputgallery_filename, image_parameters],
queue=False,
).then(**set_gallery_columns_delayed)
)
outputgallery_filename.change(
on_outputgallery_filename_change,
@@ -477,7 +436,7 @@ with gr.Blocks() as outputgallery_web:
open_subdir,
],
queue=False,
).then(**set_gallery_columns_immediate)
)
# We should have been passed a list of components on other tabs that update
# when a new image has generated on that tab, so set things up so the user
@@ -489,4 +448,4 @@ with gr.Blocks() as outputgallery_web:
inputs=[subdirectories, subdirectory_paths, component],
outputs=[gallery_files, gallery, logo],
queue=False,
).then(**set_gallery_columns_immediate)
)
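The new `on_image_columns_change` handler relies on the pinned gradio honouring the `columns` property directly. A minimal, self-contained gradio 3.x sketch of the same wiring (not the Studio UI itself) could look like this:

```python
# Minimal gradio 3.x sketch of driving a gallery's column count from a
# slider via gr.Gallery.update, as the new handler above does.
import gradio as gr


def on_image_columns_change(columns):
    return gr.Gallery.update(columns=columns)


with gr.Blocks() as demo:
    image_columns = gr.Slider(1, 8, value=4, step=1, label="Image columns")
    gallery = gr.Gallery(label="Images", columns=4)
    image_columns.change(
        fn=on_image_columns_change,
        inputs=[image_columns],
        outputs=[gallery],
        queue=False,  # keep the UI responsive while inference runs elsewhere
    )

# demo.launch()  # uncomment to try it locally
```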

View File

@@ -132,6 +132,27 @@ def get_default_config():
c.split_into_layers()
def clean_device_info(raw_device):
# return appropriate device and device_id for consumption by LLM pipeline
# Multiple devices only supported for vulkan and rocm (as of now).
# default device must be selected for all others
device_id = None
device = (
raw_device
if "=>" not in raw_device
else raw_device.split("=>")[1].strip()
)
if "://" in device:
device, device_id = device.split("://")
device_id = int(device_id) # using device index in webui
if device not in ["rocm", "vulkan"]:
device_id = None
return device, device_id
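To make the normalisation rules above concrete, here is a self-contained sketch with made-up device strings (illustrative only, not taken from the repository):

```python
# Self-contained restatement of the rules in clean_device_info above;
# the raw device strings here are illustrative only.
def _clean(raw_device: str):
    device_id = None
    device = (
        raw_device
        if "=>" not in raw_device
        else raw_device.split("=>")[1].strip()
    )
    if "://" in device:
        device, device_id = device.split("://")
        device_id = int(device_id)
    if device not in ("rocm", "vulkan"):
        device_id = None
    return device, device_id


assert _clean("AMD Radeon RX 7900 XTX => rocm://1") == ("rocm", 1)
assert _clean("vulkan://0") == ("vulkan", 0)
assert _clean("cpu-task") == ("cpu-task", None)  # index only kept for rocm/vulkan
```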
model_vmfb_key = ""
@@ -151,24 +172,8 @@ def chat(
global model_vmfb_key
global vicuna_model
device_id = None
model_name, model_path = list(map(str.strip, model.split("=>")))
device = device if "=>" not in device else device.split("=>")[1].strip()
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "rocm" in device:
device = "rocm"
elif "metal" in device:
device = "metal"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)
from apps.language_models.scripts.vicuna import ShardedVicuna
from apps.language_models.scripts.vicuna import UnshardedVicuna
@@ -223,10 +228,11 @@ def chat(
elif "rocm" in device:
# add iree rocm flags
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if args.iree_rocm_target_chip != "":
_extra_args.append(
f"--iree-rocm-target-chip={args.iree_rocm_target_chip}"
)
print(f"extra args = {_extra_args}")
if model_name == "vicuna4":
vicuna_model = ShardedVicuna(
@@ -325,19 +331,7 @@ def llm_chat_api(InputData: dict):
device_id = None
if vicuna_model == 0:
if "cuda" in device:
device = "cuda"
elif "sync" in device:
device = "cpu-sync"
elif "task" in device:
device = "cpu-task"
elif "vulkan" in device:
device_id = int(device.split("://")[1])
device = "vulkan"
elif "metal" in device:
device = "metal"
else:
print("unrecognized device")
device, device_id = clean_device_info(device)
vicuna_model = UnshardedVicuna(
model_name,

View File

@@ -5,15 +5,18 @@ import sys
import gradio as gr
from PIL import Image
from math import ceil
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
get_custom_model_path,
get_custom_model_files,
scheduler_list,
scheduler_list_cpu_only,
predefined_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.metadata import import_png_metadata
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
@@ -29,6 +32,7 @@ from apps.stable_diffusion.src import (
from apps.stable_diffusion.src.utils import (
get_generated_imgs_path,
get_generation_text_info,
resampler_list,
)
# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir.
@@ -394,6 +398,11 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -465,50 +474,6 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
label="Low VRAM",
interactive=True,
)
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=[
"Lanczos",
"Nearest Neighbor",
"Bilinear",
"Bicubic",
"Adaptive",
"Antialias",
"Box",
"Affine",
"Cubic",
],
label="Resample Type",
allow_custom_value=True,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
with gr.Column(scale=3):
batch_count = gr.Slider(
@@ -532,6 +497,41 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
args.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Accordion(label="Hires Fix Options", open=False):
with gr.Group():
with gr.Row():
use_hiresfix = gr.Checkbox(
value=args.use_hiresfix,
label="Use Hires Fix",
interactive=True,
)
resample_type = gr.Dropdown(
value=args.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=False,
)
hiresfix_height = gr.Slider(
384,
768,
value=args.hiresfix_height,
step=8,
label="Hires Fix Height",
)
hiresfix_width = gr.Slider(
384,
768,
value=args.hiresfix_width,
step=8,
label="Hires Fix Width",
)
hiresfix_strength = gr.Slider(
0,
1,
value=args.hiresfix_strength,
step=0.01,
label="Hires Fix Denoising Strength",
)
with gr.Row():
seed = gr.Textbox(
value=args.seed,
@@ -676,3 +676,30 @@ with gr.Blocks(title="Text-to-Image") as txt2img_web:
custom_vae,
],
)
# SharkEulerDiscrete doesn't work with img2img which hires_fix uses
def set_compatible_schedulers(hires_fix_selected):
if hires_fix_selected:
return gr.Dropdown.update(
choices=scheduler_list_cpu_only,
value="DEISMultistep",
)
else:
return gr.Dropdown.update(
choices=scheduler_list,
value="SharkEulerDiscrete",
)
use_hiresfix.change(
fn=set_compatible_schedulers,
inputs=[use_hiresfix],
outputs=[scheduler],
queue=False,
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)
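For reference, the same checkbox-driven dropdown swap as a standalone gradio 3.x sketch; the scheduler lists here are trimmed placeholders rather than the real `scheduler_list`/`scheduler_list_cpu_only`:

```python
# Standalone gradio 3.x sketch of swapping a dropdown's choices when a
# checkbox changes; the scheduler lists are trimmed placeholders.
import gradio as gr

scheduler_list_cpu_only = ["DEISMultistep", "DDIM"]  # placeholder subset
scheduler_list = scheduler_list_cpu_only + ["SharkEulerDiscrete"]


def set_compatible_schedulers(hires_fix_selected):
    if hires_fix_selected:
        return gr.Dropdown.update(
            choices=scheduler_list_cpu_only, value="DEISMultistep"
        )
    return gr.Dropdown.update(choices=scheduler_list, value="SharkEulerDiscrete")


with gr.Blocks() as demo:
    use_hiresfix = gr.Checkbox(value=False, label="Use Hires Fix")
    scheduler = gr.Dropdown(
        choices=scheduler_list, value="SharkEulerDiscrete", label="Scheduler"
    )
    use_hiresfix.change(
        fn=set_compatible_schedulers,
        inputs=[use_hiresfix],
        outputs=[scheduler],
        queue=False,
    )
```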

View File

@@ -3,6 +3,7 @@ import torch
import time
import gradio as gr
from PIL import Image
from apps.stable_diffusion.web.ui.utils import (
available_devices,
nodlogo_loc,
@@ -12,6 +13,7 @@ from apps.stable_diffusion.web.ui.utils import (
predefined_upscaler_models,
cancel_sd,
)
from apps.stable_diffusion.web.ui.common_ui_events import lora_changed
from apps.stable_diffusion.web.utils.common_label_calc import status_label
from apps.stable_diffusion.src import (
args,
@@ -340,6 +342,11 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
label="HuggingFace Model ID",
lines=3,
)
with gr.Row():
lora_tags = gr.HTML(
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
@@ -537,3 +544,10 @@ with gr.Blocks(title="Upscaler") as upscaler_web:
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
lora_weights.change(
fn=lora_changed,
inputs=[lora_weights],
outputs=[lora_tags],
queue=True,
)

View File

@@ -1,10 +1,16 @@
import os
import sys
from apps.stable_diffusion.src import get_available_devices
import glob
import math
import json
import safetensors
from pathlib import Path
from apps.stable_diffusion.src import args
from dataclasses import dataclass
from enum import IntEnum
from apps.stable_diffusion.src import get_available_devices
import apps.stable_diffusion.web.utils.global_obj as global_obj
from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import (
SD_STATE_CANCEL,
@@ -28,6 +34,15 @@ class Config:
ondemand: str # should this be expecting a bool instead?
class HSLHue(IntEnum):
RED = 0
YELLOW = 60
GREEN = 120
CYAN = 180
BLUE = 240
MAGENTA = 300
custom_model_filetypes = (
"*.ckpt",
"*.safetensors",
@@ -161,6 +176,69 @@ def get_custom_vae_or_lora_weights(weights, hf_id, model):
return use_weight
def hsl_color(alpha: float, start, end):
b = (end - start) * (alpha if alpha > 0 else 0)
result = b + start
# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"
def get_lora_metadata(lora_filename):
# get the metadata from the file
filename = get_custom_model_pathfile(lora_filename, "lora")
with safetensors.safe_open(filename, framework="pt", device="cpu") as f:
metadata = f.metadata()
# guard clause for if there isn't any metadata
if not metadata:
return None
# metadata is a dictionary of strings, the values of the keys we're
# interested in are actually json, and need to be loaded as such
tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}")))
dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}")))
tag_dirs = [dir for dir in tag_frequencies.keys()]
# gather the tag frequency information for all the datasets trained
all_frequencies = {}
for dataset in tag_dirs:
frequencies = sorted(
[entry for entry in tag_frequencies[dataset].items()],
reverse=True,
key=lambda x: x[1],
)
# get a figure for the total number of images processed for this dataset
        # either the count listed in its dataset_dirs entry, or the
        # highest tag frequency if that entry doesn't exist
        img_count = dataset_dirs.get(dataset, {}).get(
"img_count", frequencies[0][1]
)
# add the dataset frequencies to the overall frequencies replacing the
# frequency counts on the tags with a percentage/ratio
all_frequencies.update(
[(entry[0], entry[1] / img_count) for entry in frequencies]
)
trained_model_id = " ".join(
[
metadata.get("ss_sd_model_hash", ""),
metadata.get("ss_sd_model_name", ""),
metadata.get("ss_base_model_version", ""),
]
).strip()
    # return the trained model id and the combined tag frequencies, most frequent first
return {
"model": trained_model_id,
"frequencies": sorted(
all_frequencies.items(), reverse=True, key=lambda x: x[1]
),
}
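For inspecting this metadata by hand, a hedged standalone snippet that mirrors the `safe_open`/`ss_tag_frequency` handling above; the file name is a placeholder:

```python
# Standalone snippet for peeking at Kohya-style metadata in a LoRA
# .safetensors file; "my_lora.safetensors" is a placeholder path.
import json
import safetensors

with safetensors.safe_open("my_lora.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

tag_frequency = json.loads(metadata.get("ss_tag_frequency", "{}"))
for dataset, tags in tag_frequency.items():
    top_tags = sorted(tags.items(), key=lambda kv: kv[1], reverse=True)[:5]
    print(dataset, top_tags)

print(metadata.get("ss_sd_model_name", "<no model name recorded>"))
```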
def cancel_sd():
# Try catch it, as gc can delete global_obj.sd_obj while switching model
try:

View File

@@ -78,7 +78,10 @@ def test_loop(
os.mkdir("./test_images/golden")
get_inpaint_inputs()
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned", "--use_tuned"]
tuned_options = [
"--no-use_tuned",
"--use_tuned",
]
import_options = ["--import_mlir", "--no-import_mlir"]
prompt_text = "--prompt=cyberpunk forest by Salvador Dali"
inpaint_prompt_text = "--prompt=Face of a yellow cat, high resolution, sitting on a park bench"
@@ -112,6 +115,8 @@ def test_loop(
and use_tune == tuned_options[1]
):
continue
elif use_tune == tuned_options[1]:
continue
command = (
[
executable, # executable is the python from the venv used to run this

View File

@@ -22,33 +22,33 @@ This does mean however, that on a brand new fresh install of SHARK that has not
* Make sure you have suitable drivers for your graphics card installed. See the prerequisites section of the [README](https://github.com/nod-ai/SHARK#readme).
* Download the latest SHARK studio .exe from [here](https://github.com/nod-ai/SHARK/releases) or follow the instructions in the [README](https://github.com/nod-ai/SHARK#readme) for an advanced, Linux or Mac install.
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK default of `8080`, also include both the `--api_cors_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the if you want to run SHARK on a different port)
* Run SHARK from terminal/PowerShell with the `--api` flag. Since koboldcpp also expects both CORS support and the image generator to be running on port `7860` rather than SHARK's default of `8080`, also include both the `--api_accept_origin` flag with a suitable origin (use `="*"` to enable all origins) and `--server_port=7860` on the command line. (See the section on connecting to SHARK on a different address or port below if you want to run SHARK on a different port.)
```powershell
## Run the .exe in API mode, with CORS support, on the A1111 endpoint port:
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_cors_origin="*" --server_port=7860
.\node_ai_shark_studio_<date>_<ver>.exe --api --api_accept_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Windows:
.\setup_venv.ps1
python .\apps\stable_diffusion\web\index.py --api --api_cors_origin="*" --server_port=7860
python .\apps\stable_diffusion\web\index.py --api --api_accept_origin="*" --server_port=7860
## Run from the base directory of a source clone of SHARK on Linux:
./setup_venv.sh
source shark.venv/bin/activate
python ./apps/stable_diffusion/web/index.py --api --api_cors_origin="*" --server_port=7860
python ./apps/stable_diffusion/web/index.py --api --api_accept_origin="*" --server_port=7860
## An example giving improved performance on AMD cards using vulkan, that runs on the same port as A1111
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860
## Since the API respects most applicable SHARK command line arguments for options not specified here,
## or not yet implemented by the API, there may be others you want to set, as listed in `--help`
.\node_ai_shark_studio_20320901_2525.exe --help
## For instance, the example above, but with a custom VAE specified
.\node_ai_shark_studio_20320901_2525.exe --api --api_cors_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
.\node_ai_shark_studio_20320901_2525.exe --api --api_accept_origin="*" --device_allocator="caching" --server_port=7860 --custom_vae="clearvae_v23.safetensors"
## An example with multiple specific CORS origins
python apps/stable_diffusion/web/index.py --api --api_cors_origin="koboldcpp.example.com:7001" --api_cors_origin="koboldcpp.example.com:7002" --server_port=7860
python apps/stable_diffusion/web/index.py --api --api_accept_origin="koboldcpp.example.com:7001" --api_accept_origin="koboldcpp.example.com:7002" --server_port=7860
```
SHARK should start in server mode, and you should see something like this:
@@ -83,11 +83,11 @@ SHARK should start in server mode, and you should see something like this:
* There is one final piece of image generation configuration within Koboldcpp you might want to do. This is also in the generate images section of advanced settings. Here there is, not very obviously, a 'style' button:
![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)
![Selecting the 'styles' button](https://user-images.githubusercontent.com/121311569/280556694-55cd1c55-a059-4b54-9293-63d66a32368e.png)
This will bring up a dialog box where you can enter a short text that will be sent as a prefix to the prompt sent to SHARK:
![Entering extra image styles](https://github.com/one-lithe-rune/SHARK/assets/121311569/4aab9794-7a77-46d7-bdda-43df570ad19a)
![Entering extra image styles](https://user-images.githubusercontent.com/121311569/280556172-4aab9794-7a77-46d7-bdda-43df570ad19a.png)
## Connecting to SHARK on a different address or port

View File

@@ -17,6 +17,7 @@ pytest-forked
Pillow
parameterized
#shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main
# Add transformers, diffusers and scipy since they are the most commonly used
tokenizers==0.13.3
transformers
@@ -49,4 +50,4 @@ pefile
pyinstaller
# vicuna quantization
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea
brevitas @ git+https://github.com/Xilinx/brevitas.git@56edf56a3115d5ac04f19837b388fd7d3b1ff7ea

View File

@@ -111,7 +111,7 @@ else
fi
if [[ -z "${NO_BACKEND}" ]]; then
echo "Installing ${RUNTIME}..."
$PYTHON -m pip install --pre --upgrade --find-links ${RUNTIME} iree-compiler iree-runtime
$PYTHON -m pip install --pre --upgrade --no-index --find-links ${RUNTIME} iree-compiler iree-runtime
else
echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi

View File

@@ -16,7 +16,6 @@ import numpy as np
import os
import re
import tempfile
import time
from pathlib import Path
import iree.runtime as ireert
@@ -34,9 +33,9 @@ def get_iree_device_args(device, extra_args=[]):
print("Configuring for device:" + device)
device_uri = device.split("://")
if len(device_uri) > 1:
if device_uri[0] not in ["vulkan"]:
if device_uri[0] not in ["vulkan", "rocm"]:
print(
f"Specific device selection only supported for vulkan now."
f"Specific device selection only supported for vulkan and rocm."
f"Proceeding with {device} as device."
)
# device_uri can be device_num or device_path.
@@ -63,8 +62,7 @@ def get_iree_device_args(device, extra_args=[]):
+ data_tiling_flag
+ u_kernel_flag
+ stack_size_flag
+ ["--iree-flow-enable-quantized-matmul-reassociation"]
+ ["--iree-llvmcpu-enable-quantized-matmul-reassociation"]
+ ["--iree-global-opt-enable-quantized-matmul-reassociation"]
)
if device_uri[0] == "cuda":
from shark.iree_utils.gpu_utils import get_iree_gpu_args
@@ -83,7 +81,7 @@ def get_iree_device_args(device, extra_args=[]):
if device_uri[0] == "rocm":
from shark.iree_utils.gpu_utils import get_iree_rocm_args
return get_iree_rocm_args(extra_args=extra_args)
return get_iree_rocm_args(device_num=device_num, extra_args=extra_args)
return []
@@ -321,6 +319,8 @@ def compile_module_to_flatbuffer(
input_type = "tosa"
elif frontend in ["tm_tensor"]:
input_type = ireec.InputType.TM_TENSOR
elif frontend in ["torch", "pytorch"]:
input_type = "torch"
if compile_str:
flatbuffer_blob = ireec.compile_str(

View File

@@ -22,6 +22,8 @@ from subprocess import CalledProcessError
from shark.parser import shark_args
from shark.iree_utils._common import run_cmd
# TODO: refactor to rocm and cuda utils
# Get the default gpu args given the architecture.
@functools.cache
@@ -41,73 +43,83 @@ def get_iree_gpu_args():
return []
# Get the default gpu args given the architecture.
def get_iree_rocm_args(device_num=0, extra_args=[]):
ireert.flags.FUNCTION_INPUT_VALIDATION = False
rocm_flags = ["--iree-rocm-link-bc=true"]
def check_rocm_device_arch_in_args(extra_args):
    # Check if the target arch flag for the rocm device is present in extra_args
for flag in extra_args:
if "iree-rocm-target-chip" in flag:
flag_arch = flag.split("=")[1]
return flag_arch
return None
def get_rocm_device_arch(device_num=0, extra_args=[]):
# ROCM Device Arch selection:
# 1 : User given device arch using `--iree-rocm-target-chip` flag
# 2 : Device arch from `iree-run-module --dump_devices=rocm` for device on index <device_num>
# 3 : default arch : gfx1100
default_rocm_arch = "gfx_1100"
# Check if the target arch flag for rocm device present in extra_flags
flag_present = False
for flag in extra_args:
if "iree-rocm-target-chip" in flag:
flag_present = True
print(
f"User Specified rocm target device arch from flag : {flag.split('=')[1]} will be used"
)
arch_in_flag = check_rocm_device_arch_in_args(extra_args)
if arch_in_flag is not None:
print(
f"User Specified rocm target device arch from flag : {arch_in_flag} will be used"
)
return arch_in_flag
arch_in_device_dump = None
if not flag_present:
# get rocm arch from iree dump devices
def get_devices_info_from_dump(dump):
from os import linesep
dump_clean = list(
filter(
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
dump.split(linesep),
)
# get rocm arch from iree dump devices
def get_devices_info_from_dump(dump):
from os import linesep
dump_clean = list(
filter(
lambda s: "--device=rocm" in s or "gpu-arch-name:" in s,
dump.split(linesep),
)
arch_pairs = [
(
dump_clean[i].split("=")[1].strip(),
dump_clean[i + 1].split(":")[1].strip(),
)
for i in range(0, len(dump_clean), 2)
]
return arch_pairs
dump_device_info = None
try:
dump_device_info = run_cmd(
"iree-run-module --dump_devices=rocm", raise_err=True
)
except Exception as e:
print("could not execute `iree-run-module --dump_devices=rocm`")
if dump_device_info is not None:
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
if (
len(device_arch_pairs) > device_num
): # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
if arch_in_device_dump is not None:
print(f"Found ROCm device arch : {arch_in_device_dump}")
rocm_flags.append(f"--iree-rocm-target-chip={arch_in_device_dump}")
if not flag_present and arch_in_device_dump is None:
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
)
rocm_flags.append(f"--iree-rocm-target-chip={default_rocm_arch}")
arch_pairs = [
(
dump_clean[i].split("=")[1].strip(),
dump_clean[i + 1].split(":")[1].strip(),
)
for i in range(0, len(dump_clean), 2)
]
return arch_pairs
dump_device_info = None
try:
dump_device_info = run_cmd(
"iree-run-module --dump_devices=rocm", raise_err=True
)
except Exception as e:
print("could not execute `iree-run-module --dump_devices=rocm`")
if dump_device_info is not None:
device_arch_pairs = get_devices_info_from_dump(dump_device_info[0])
if len(device_arch_pairs) > device_num: # can find arch in the list
arch_in_device_dump = device_arch_pairs[device_num][1]
if arch_in_device_dump is not None:
print(f"Found ROCm device arch : {arch_in_device_dump}")
return arch_in_device_dump
    default_rocm_arch = "gfx1100"
print(
"Did not find ROCm architecture from `--iree-rocm-target-chip` flag"
"\n or from `iree-run-module --dump_devices=rocm` command."
f"\nUsing {default_rocm_arch} as ROCm arch for compilation."
)
return default_rocm_arch
# Get the default gpu args given the architecture.
def get_iree_rocm_args(device_num=0, extra_args=[]):
ireert.flags.FUNCTION_INPUT_VALIDATION = False
rocm_flags = ["--iree-rocm-link-bc=true"]
if check_rocm_device_arch_in_args(extra_args) is None:
rocm_arch = get_rocm_device_arch(device_num, extra_args)
rocm_flags.append(f"--iree-rocm-target-chip={rocm_arch}")
return rocm_flags
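To make the dump parsing concrete, a small self-contained sketch of the device/arch pairing done by `get_devices_info_from_dump`; the dump text is fabricated to match what the filter expects, not captured from the real tool:

```python
# Fabricated example of the device/arch pairing logic in
# get_devices_info_from_dump above; not real iree-run-module output.
sample_dump = "\n".join(
    [
        "# ===  Available ROCm devices  ===",
        "  --device=rocm://0",
        "     gpu-arch-name: gfx1100",
        "  --device=rocm://1",
        "     gpu-arch-name: gfx90a",
    ]
)
dump_clean = [
    line
    for line in sample_dump.split("\n")
    if "--device=rocm" in line or "gpu-arch-name:" in line
]
arch_pairs = [
    (dump_clean[i].split("=")[1].strip(), dump_clean[i + 1].split(":")[1].strip())
    for i in range(0, len(dump_clean), 2)
]
print(arch_pairs)  # [('rocm://0', 'gfx1100'), ('rocm://1', 'gfx90a')]

device_num = 1  # e.g. a second GPU selected from the web UI
print(arch_pairs[device_num][1])  # 'gfx90a'
```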

View File

@@ -7,7 +7,7 @@ import torch_mlir
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from typing import List, Tuple
from io import BytesIO
from brevitas_examples.llm.llm_quant.quantize import quantize_model
from brevitas_examples.common.generative.quantize import quantize_model
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
@@ -84,7 +84,7 @@ def compile_int_precision(
weight_quant_type="asym",
weight_bit_width=weight_bit_width,
weight_param_method="stats",
weight_scale_precision="float",
weight_scale_precision="float_scale",
weight_quant_granularity="per_group",
weight_group_size=weight_group_size,
quantize_weight_zero_point=False,