Mirror of https://github.com/nod-ai/SHARK-Studio.git, synced 2026-01-08 21:38:04 -05:00

Merge branch 'main' into llm-rest-api
.github/workflows/nightly.yml (vendored, 83 lines changed)
@@ -50,10 +50,10 @@ jobs:
shell: powershell
run: |
./setup_venv.ps1
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip wheel -v -w dist . --pre -f https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
python process_skipfiles.py
pyinstaller .\apps\stable_diffusion\shark_sd.spec
$env:SHARK_PACKAGE_VERSION=${{ env.package_version }}
pip install -e .
pyinstaller .\apps\shark_studio\shark_studio.spec
mv ./dist/nodai_shark_studio.exe ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe
signtool sign /f c:\g\shark_02152023.cer /fd certHash /csp "eToken Base Cryptographic Provider" /k "${{ secrets.CI_CERT }}" ./dist/nodai_shark_studio_${{ env.package_version_ }}.exe

@@ -74,80 +74,3 @@ jobs:
GITHUB_TOKEN: ${{ secrets.NODAI_INVOCATION_TOKEN }}
with:
release_id: ${{ steps.create_release.outputs.id }}

linux-build:

runs-on: a100
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
backend: [IREE, SHARK]

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v3
with:
python-version: ${{ matrix.python-version }}

- name: Setup pip cache
uses: actions/cache@v3
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-

- name: Install dependencies
run: |
echo "DATE=$(date +'%Y-%m-%d')" >> $GITHUB_ENV
python -m pip install --upgrade pip
python -m pip install flake8 pytest toml
if [ -f requirements.txt ]; then pip install -r requirements.txt -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html; fi
- name: Lint with flake8
run: |
# stop the build if there are Python syntax errors or undefined names
flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude shark.venv,lit.cfg.py
# exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics --exclude shark.venv,lit.cfg.py
- name: Build and validate the IREE package
if: ${{ matrix.backend == 'IREE' }}
continue-on-error: true
run: |
cd $GITHUB_WORKSPACE
USE_IREE=1 VENV_DIR=iree.venv ./setup_venv.sh
source iree.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://openxla.github.io/iree/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
/bin/bash "$GITHUB_WORKSPACE/build_tools/populate_sharktank_ci.sh"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./gen_shark_tank/" -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt
if !(grep -Fxq " failed" pytest_results.txt)
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*

- name: Build and validate the SHARK Runtime package
if: ${{ matrix.backend == 'SHARK' }}
run: |
cd $GITHUB_WORKSPACE
./setup_venv.sh
source shark.venv/bin/activate
package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
SHARK_PACKAGE_VERSION=${package_version} \
pip wheel -v -w wheelhouse . --pre -f https://download.pytorch.org/whl/nightly/torch -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
# Install the built wheel
pip install ./wheelhouse/nodai*
# Validate the Models
pytest --ci --ci_sha=${SHORT_SHA} -k "not metal" |
tail -n 1 |
tee -a pytest_results.txt

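Both build steps above stamp the wheel with a date-based version: bash printf's %(...)T directive formats the current time, so SHARK_PACKAGE_VERSION becomes YYYYMMDD.<run number>. As a minimal Python sketch of the same versioning scheme (run_number here stands in for the CI-provided ${{ github.run_number }} and is not part of this repo):

from datetime import date

def nightly_package_version(run_number: int) -> str:
    # Rough equivalent of: package_version="$(printf '%(%Y%m%d)T.${{ github.run_number }}')"
    return f"{date.today():%Y%m%d}.{run_number}"

print(nightly_package_version(123))  # e.g. "20240305.123"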
@@ -372,7 +372,7 @@ For a complete list of the models supported in SHARK, please refer to [tank/READ

* [Upstream IREE issues](https://github.com/google/iree/issues): Feature requests,
bugs, and other work tracking
* [Upstream IREE Discord server](https://discord.gg/26P4xW4): Daily development
* [Upstream IREE Discord server](https://discord.gg/wEWh6Z9nMU): Daily development
discussions with the core team and collaborators
* [iree-discuss email list](https://groups.google.com/forum/#!forum/iree-discuss):
Announcements, general and low-priority discussion

@@ -8,11 +8,12 @@ from threading import Thread

from apps.shark_studio.modules.timer import startup_timer

# from apps.shark_studio.web.utils.tmp_configs import (
#     config_tmp,
#     clear_tmp_mlir,
#     clear_tmp_imgs,
# )
from apps.shark_studio.web.utils.tmp_configs import (
config_tmp,
clear_tmp_mlir,
clear_tmp_imgs,
shark_tmp,
)


def imports():
@@ -47,9 +48,9 @@ def initialize():
# existing temporary images there if they exist. Then we can import gradio.
# It has to be in this order or gradio ignores what we've set up.

# config_tmp()
config_tmp()
# clear_tmp_mlir()
# clear_tmp_imgs()
clear_tmp_imgs()

from apps.shark_studio.web.utils.file_utils import (
create_checkpoint_folders,
@@ -82,8 +83,8 @@ def dumpstacks():
code.append(f"""File: "{filename}", line {lineno}, in {name}""")
if line:
code.append(" " + line.strip())

print("\n".join(code))
with open(os.path.join(shark_tmp, "stack_dump.log"), "w") as f:
f.write("\n".join(code))


def setup_middleware(app):

@@ -393,7 +393,6 @@ def llm_chat_api(InputData: dict):
print("prompt = ", prompt)

for res_op, _ in llm_model.chat(prompt):

if is_chat_completion_api:
choices = [
{

@@ -38,41 +38,12 @@ from diffusers.image_processor import VaeImageProcessor
sd_model_map = {
"clip": {
"initializer": clip.export_clip_model,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
"vae_encode": {
"initializer": vae.export_vae_model,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
"unet": {
"initializer": unet.export_unet_model,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))",
],
},
"vae_decode": {
"initializer": vae.export_vae_model,
"ireec_flags": [
"--iree-flow-collapse-reduction-dims",
"--iree-opt-const-expr-hoisting=False",
"--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))",
],
},
}

@@ -96,6 +67,7 @@ class StableDiffusion(SharkPipelineBase):
num_loras: int = 0,
import_ir: bool = True,
is_controlled: bool = False,
hf_auth_token=None,
):
self.model_max_length = 77
self.batch_size = batch_size
@@ -111,9 +83,7 @@ class StableDiffusion(SharkPipelineBase):
"clip": {"hf_model_name": base_model_id},
"unet": {
"hf_model_name": base_model_id,
"unet_model": unet.UnetModel(
hf_model_name=base_model_id, hf_auth_token=None
),
"unet_model": unet.UnetModel(hf_model_name=base_model_id),
"batch_size": batch_size,
# "is_controlled": is_controlled,
# "num_loras": num_loras,
@@ -125,8 +95,7 @@ class StableDiffusion(SharkPipelineBase):
"vae_encode": {
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
custom_vae=custom_vae,
hf_model_name=custom_vae if custom_vae else base_model_id,
),
"batch_size": batch_size,
"height": height,
@@ -136,8 +105,7 @@ class StableDiffusion(SharkPipelineBase):
"vae_decode": {
"hf_model_name": base_model_id,
"vae_model": vae.VaeModel(
hf_model_name=base_model_id,
custom_vae=custom_vae,
hf_model_name=custom_vae if custom_vae else base_model_id,
),
"batch_size": batch_size,
"height": height,
@@ -149,9 +117,10 @@ class StableDiffusion(SharkPipelineBase):
pipe_id_list = [
safe_name(base_model_id),
str(batch_size),
str(static_kwargs["unet"]["max_length"]),
str(self.model_max_length),
f"{str(height)}x{str(width)}",
precision,
self.device,
]
if num_loras > 0:
pipe_id_list.append(str(num_loras) + "lora")
@@ -183,7 +152,7 @@ class StableDiffusion(SharkPipelineBase):
custom_weights_params, _ = process_custom_pipe_weights(custom_weights)
if submodel not in ["clip", "clip2"]:
self.static_kwargs[submodel][
"external_weight_file"
"external_weights"
] = custom_weights_params
else:
self.static_kwargs[submodel]["external_weight_path"] = os.path.join(

@@ -9,7 +9,6 @@ from random import (

from pathlib import Path
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts

from cpuinfo import get_cpu_info

# TODO: migrate these utils to studio
@@ -125,7 +124,6 @@ def set_iree_runtime_flags():
vulkan_runtime_flags += [
f"--device_allocator=caching:device_local={cmd_opts.device_allocator_heap_key}",
]

set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags)



@@ -90,6 +90,8 @@ class SharkPipelineBase:
)

weights_path = self.get_io_params(submodel)
if weights_path:
ireec_flags.append("--iree-opt-const-eval=False")

self.iree_module_dict[submodel] = get_iree_compiled_module(
self.tempfiles[submodel],

@@ -615,7 +615,7 @@ p.add_argument(
p.add_argument(
"--ckpt_dir",
type=str,
default="",
default="../models",
help="Path to directory where all .ckpts are stored in order to populate "
"them in the web UI.",
)

apps/shark_studio/shark_studio.spec (new file, 48 lines)
@@ -0,0 +1,48 @@
# -*- mode: python ; coding: utf-8 -*-
from apps.shark_studio.studio_imports import pathex, datas, hiddenimports

binaries = []

block_cipher = None

a = Analysis(
['web/index.py'],
pathex=pathex,
binaries=binaries,
datas=datas,
hiddenimports=hiddenimports,
hookspath=[],
hooksconfig={},
runtime_hooks=[],
excludes=[],
win_no_prefer_redirects=False,
win_private_assemblies=False,
cipher=block_cipher,
noarchive=False,
module_collection_mode={
'gradio': 'py', # Collect gradio package as source .py files
},
)
pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher)

exe = EXE(
pyz,
a.scripts,
a.binaries,
a.zipfiles,
a.datas,
[],
name='nodai_shark_studio',
debug=False,
bootloader_ignore_signals=False,
strip=False,
upx=False,
upx_exclude=[],
runtime_tmpdir=None,
console=True,
disable_windowed_traceback=False,
argv_emulation=False,
target_arch=None,
codesign_identity=None,
entitlements_file=None,
)
apps/shark_studio/studio_imports.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from PyInstaller.utils.hooks import collect_data_files
from PyInstaller.utils.hooks import copy_metadata
from PyInstaller.utils.hooks import collect_submodules

import sys

sys.setrecursionlimit(sys.getrecursionlimit() * 5)

# python path for pyinstaller
pathex = [
".",
]

# datafiles for pyinstaller
datas = []
datas += copy_metadata("torch")
datas += copy_metadata("tokenizers")
datas += copy_metadata("tqdm")
datas += copy_metadata("regex")
datas += copy_metadata("requests")
datas += copy_metadata("packaging")
datas += copy_metadata("filelock")
datas += copy_metadata("numpy")
datas += copy_metadata("importlib_metadata")
datas += copy_metadata("omegaconf")
datas += copy_metadata("safetensors")
datas += copy_metadata("Pillow")
datas += copy_metadata("sentencepiece")
datas += copy_metadata("pyyaml")
datas += copy_metadata("huggingface-hub")
datas += copy_metadata("gradio")
datas += copy_metadata("scipy")
datas += collect_data_files("torch")
datas += collect_data_files("tokenizers")
datas += collect_data_files("accelerate")
datas += collect_data_files("diffusers")
datas += collect_data_files("transformers")
datas += collect_data_files("gradio")
datas += collect_data_files("gradio_client")
datas += collect_data_files("iree", include_py_files=True)
datas += collect_data_files("shark", include_py_files=True)
datas += collect_data_files("tqdm")
datas += collect_data_files("tkinter")
datas += collect_data_files("sentencepiece")
datas += collect_data_files("jsonschema")
datas += collect_data_files("jsonschema_specifications")
datas += collect_data_files("cpuinfo")
datas += collect_data_files("scipy", include_py_files=True)
datas += [
("web/ui/css/*", "ui/css"),
("web/ui/js/*", "ui/js"),
("web/ui/logos/*", "logos"),
]


# hidden imports for pyinstaller
hiddenimports = ["shark", "apps"]
hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x]
hiddenimports += [x for x in collect_submodules("diffusers") if "tests" not in x]
blacklist = ["tests", "convert"]
hiddenimports += [
x
for x in collect_submodules("transformers")
if not any(kw in x for kw in blacklist)
]
hiddenimports += [x for x in collect_submodules("iree") if "test" not in x]
hiddenimports += ["iree._runtime"]
hiddenimports += [x for x in collect_submodules("scipy") if "test" not in x]
@@ -7,9 +7,25 @@
import logging
import unittest
import json
from apps.shark_studio.api.llm import LanguageModel

import gc
from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file
from apps.shark_studio.web.utils.file_utils import get_resource_path

# class SDAPITest(unittest.TestCase):
# def testSDSimple(self):
# from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
# import apps.shark_studio.web.utils.globals as global_obj

# global_obj._init()

# sd_json = view_json_file(get_resource_path("../configs/default_sd_config.json"))
# sd_kwargs = json.loads(sd_json)
# for arg in vars(cmd_opts):
# if arg in sd_kwargs:
# sd_kwargs[arg] = getattr(cmd_opts, arg)
# for i in shark_sd_fn_dict_input(sd_kwargs):
# print(i)

from apps.shark_studio.api.llm import LanguageModel, llm_chat_api
from apps.shark_studio.api.sd import shark_sd_fn_dict_input, view_json_file

@@ -38,8 +38,7 @@ def llm_chat_test(verbose=False):


if __name__ == "__main__":

# "Exercises the Stable Diffusion REST API of Shark. Make sure "
# "Exercises the chatbot REST API of Shark. Make sure "
# "Shark is running in API mode on 127.0.0.1:8080 before running"
# "this script."


@@ -26,7 +26,6 @@ from apps.shark_studio.api.llm import llm_chat_api

def decode_base64_to_image(encoding):
if encoding.startswith("http://") or encoding.startswith("https://"):

headers = {}
response = requests.get(encoding, timeout=30, headers=headers)
try:

@@ -1 +1,28 @@
{"prompt": ["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"], "negative_prompt": ["watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"], "sd_init_image": [null], "height": 512, "width": 512, "steps": 50, "strength": 0.8, "guidance_scale": 7.5, "seed": "-1", "batch_count": 1, "batch_size": 1, "scheduler": "EulerDiscrete", "base_model_id": "stabilityai/stable-diffusion-2-1-base", "custom_weights": "None", "custom_vae": "None", "precision": "fp16", "device": "AMD Radeon RX 7900 XTX => vulkan://0", "ondemand": false, "repeatable_seeds": false, "resample_type": "Nearest Neighbor", "controlnets": {}, "embeddings": {}}
{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "AMD Radeon RX 7900 XTX => vulkan://0",
"ondemand": false,
"repeatable_seeds": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}

@@ -1,4 +1,8 @@
from multiprocessing import Process, freeze_support

freeze_support()
from PIL import Image

import os
import time
import sys
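This hunk moves freeze_support() up to module import time (a later hunk removes the old call inside webui()). That ordering matters for the PyInstaller build: on Windows, multiprocessing re-launches the frozen executable to spawn workers, and freeze_support() must run before any other work so a child process can hand control off immediately. A minimal sketch of the documented pattern (hypothetical worker function, not from this repo; note the commit calls freeze_support() even earlier, at import):

from multiprocessing import Process, freeze_support

def worker():
    print("running in a child process")

if __name__ == "__main__":
    freeze_support()  # no-op in a normal interpreter; required first in a frozen app
    p = Process(target=worker)
    p.start()
    p.join()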
@@ -71,6 +75,10 @@ def launch_webui(address):

def webui():
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
nodicon_loc,
nodlogo_loc,
)

launch_api = cmd_opts.api
initialize.initialize()
@@ -81,12 +89,6 @@ def webui():

# required to do multiprocessing in a pyinstaller freeze
freeze_support()

# if args.api or "api" in args.ui.split(","):
# from apps.shark_studio.api.llm import (
# chat,
# )
# from apps.shark_studio.web.api import sdapi
#
# from fastapi import FastAPI, APIRouter
# from fastapi.middleware.cors import CORSMiddleware
@@ -134,6 +136,7 @@ def webui():
return os.path.join(base_path, relative_path)

dark_theme = resource_path("ui/css/sd_dark_theme.css")
gradio_workarounds = resource_path("ui/js/sd_gradio_workarounds.js")

# from apps.shark_studio.web.ui import load_ui_from_script

@@ -158,8 +161,19 @@ def webui():
)

with gr.Blocks(
css=dark_theme, analytics_enabled=False, title="Shark Studio 2.0 Beta"
css=dark_theme,
js=gradio_workarounds,
analytics_enabled=False,
title="Shark Studio 2.0 Beta",
) as studio_web:
nod_logo = Image.open(nodlogo_loc)
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="tab_bar_logo",
show_download_button=False,
)
with gr.Tabs() as tabs:
# NOTE: If adding, removing, or re-ordering tabs, make sure that they
# have a unique id that doesn't clash with any of the other tabs,
@@ -185,10 +199,11 @@ def webui():
# )
# t.start()
studio_web.launch(
share=True,
share=cmd_opts.share,
inbrowser=True,
server_name="0.0.0.0",
server_port=cmd_opts.server_port,
favicon_path=nodicon_loc,
)



@@ -13,9 +13,6 @@ import apps.shark_studio.web.utils.globals as global_obj

B_SYS, E_SYS = "<s>", "</s>"

B_SYS, E_SYS = "<s>", "</s>"


def user(message, history):
# Append the user's message to the conversation history
return "", history + [[message, ""]]

@@ -117,7 +117,7 @@ body {
height: 100% !important;
}

/* display in full width for desktop devices */
/* display in full width for desktop devices, but see below */
@media (min-width: 1536px)
{
.gradio-container {
@@ -182,6 +182,7 @@ footer {
aspect-ratio: unset;
max-height: calc(55vh - (2 * var(--spacing-lg)));
}
/* fix width and height of gallery items when on very large desktop screens, but see below */
@media (min-width: 1921px) {
/* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
#gallery .grid-wrap, #gallery .preview{
@@ -193,6 +194,20 @@ footer {
max-height: 770px !important;
}
}

/* media rules in custom css don't appear to be applied in
gradio versions > 4.7, so we have to define classes which
we will manually need to add and remove using javascript.
Remove this once this is fixed in gradio.
*/
.gallery-force-height768 .grid-wrap, .gallery-force-height768 .preview {
min-height: calc(768px + 4px + var(--size-14)) !important;
max-height: calc(768px + 4px + var(--size-14)) !important;
}
.gallery-limit-height768 .thumbnail-item.thumbnail-lg {
max-height: 770px !important;
}

/* Don't upscale when viewing in solo image mode */
#gallery .preview img {
object-fit: scale-down;
@@ -303,6 +318,15 @@ footer {
min-height: 89vh !important;
}

.sd-right-panel {
height: calc(100vmin - var(--size-32) - var(--size-10)) !important;
overflow-y: scroll;
}

.sd-right-panel .fill {
flex: 1;
}

/* don't stretch non-square images to be square, breaking their aspect ratio */
#outputgallery_gallery .thumbnail-item.thumbnail-lg > img {
object-fit: contain !important;
@@ -314,7 +338,7 @@ footer {
width: 100%;
}

#top_logo.logo_centered img{
#top_logo.logo_centered img {
object-fit: scale-down;
position: absolute;
width: 80%;
@@ -322,3 +346,19 @@ footer {
left: 50%;
transform: translate(-50%, -50%);
}

#tab_bar_logo {
overflow: visible !important;
border-width: 0 !important;
height: 0px !important;
padding: 0;
margin: 0;
}

#tab_bar_logo .image-container {
object-fit: scale-down;
position: absolute !important;
top: 14px;
right: 0px;
height: 36px;
}

apps/shark_studio/web/ui/js/sd_gradio_workarounds.js (new file, 49 lines)
@@ -0,0 +1,49 @@
// workaround gradio after 4.7 not applying any @media rules from the custom .css file

() => {
    console.log(`innerWidth: ${window.innerWidth}`)

    // 1536px rules

    const mediaQuery1536 = window.matchMedia('(min-width: 1536px)')

    function handleWidth1536(event) {

        // display in full width for desktop devices
        document.querySelectorAll(".gradio-container")
            .forEach( (node) => {
                if (event.matches) {
                    node.classList.add("gradio-container-size-full");
                } else {
                    node.classList.remove("gradio-container-size-full")
                }
            });
    }

    mediaQuery1536.addEventListener("change", handleWidth1536);
    mediaQuery1536.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1536}));

    // 1921px rules

    const mediaQuery1921 = window.matchMedia('(min-width: 1921px)')

    function handleWidth1921(event) {

        /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */
        /* Limit height to 768px_height + 2px_margin_height for the thumbnails */
        document.querySelectorAll("#gallery")
            .forEach( (node) => {
                if (event.matches) {
                    node.classList.add("gallery-force-height768");
                    node.classList.add("gallery-limit-height768");
                } else {
                    node.classList.remove("gallery-force-height768");
                    node.classList.remove("gallery-limit-height768");
                }
            });
    }

    mediaQuery1921.addEventListener("change", handleWidth1921);
    mediaQuery1921.dispatchEvent(new MediaQueryListEvent("change", {matches: window.innerWidth >= 1921}));

}
@@ -14,6 +14,7 @@ from apps.shark_studio.web.utils.file_utils import (
get_checkpoints_path,
get_checkpoints,
get_configs_path,
write_default_sd_config,
)
from apps.shark_studio.api.sd import (
sd_model_map,
@@ -33,6 +34,8 @@ from apps.shark_studio.modules.img_processing import (
from apps.shark_studio.modules.shared_cmd_opts import cmd_opts
from apps.shark_studio.web.ui.utils import (
nodlogo_loc,
none_to_str_none,
str_none_to_none,
)
from apps.shark_studio.web.utils.state import (
status_label,
@@ -122,7 +125,7 @@ def pull_sd_configs(
controlnets,
embeddings,
):
sd_args = locals()
sd_args = str_none_to_none(locals())
sd_cfg = {}
for arg in sd_args:
if arg in [
@@ -138,11 +141,12 @@ def pull_sd_configs(
sd_cfg[arg] = {}
else:
sd_cfg[arg] = sd_args[arg]
return sd_cfg

return json.dumps(sd_cfg)


def load_sd_cfg(sd_json: dict, load_sd_config: str):
new_sd_config = json.loads(view_json_file(load_sd_config))
new_sd_config = none_to_str_none(json.loads(view_json_file(load_sd_config)))
if sd_json:
for key in new_sd_config:
sd_json[key] = new_sd_config[key]
@@ -226,53 +230,93 @@ def import_original(original_img, width, height):
return EditorValue(img_dict)


def base_model_changed(base_model_id):
new_choices = get_checkpoints(
os.path.join("checkpoints", os.path.basename(str(base_model_id)))
) + get_checkpoints(model_type="checkpoints")

return gr.Dropdown(
value=new_choices[0] if len(new_choices) > 0 else "None",
choices=["None"] + new_choices,
)


with gr.Blocks(title="Stable Diffusion") as sd_element:
with gr.Row(elem_id="ui_title"):
nod_logo = Image.open(nodlogo_loc)
with gr.Row(variant="compact", equal_height=True):
with gr.Column(
scale=1,
elem_id="demo_title_outer",
):
gr.Image(
value=nod_logo,
show_label=False,
interactive=False,
elem_id="top_logo",
width=150,
height=50,
show_download_button=False,
)
with gr.Column(elem_id="ui_body"):
with gr.Row():
with gr.Column(scale=2, min_width=600):
with gr.Row(equal_height=True):
with gr.Column(scale=3):
sd_model_info = (
f"Checkpoint Path: {str(get_checkpoints_path())}"
with gr.Accordion(
label="\U0001F4D0\U0000FE0F Device Settings", open=False
):
device = gr.Dropdown(
elem_id="device",
label="Device",
value=global_obj.get_device_list()[0],
choices=global_obj.get_device_list(),
allow_custom_value=False,
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
)
base_model_id = gr.Dropdown(
label="Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_default_models,
) # base_model_id
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
sd_model_info = f"Checkpoint Path: {str(get_checkpoints_path())}"
base_model_id = gr.Dropdown(
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
choices=sd_default_models,
) # base_model_id
with gr.Row():
height = gr.Slider(
384,
768,
value=cmd_opts.height,
step=8,
label="\U00002195\U0000FE0F Height",
)
width = gr.Slider(
384,
768,
value=cmd_opts.width,
step=8,
label="\U00002194\U0000FE0F Width",
)
with gr.Accordion(
label="\U00002696\U0000FE0F Model Weights", open=False
):
with gr.Column():
custom_weights = gr.Dropdown(
label="Custom Weights",
label="Checkpoint Weights",
info="Select or enter HF model ID",
elem_id="custom_model",
value="None",
allow_custom_value=True,
choices=["None"] + get_checkpoints(base_model_id),
) #
with gr.Column(scale=2):
choices=["None"]
+ get_checkpoints(os.path.basename(str(base_model_id))),
) # custom_weights
base_model_id.change(
fn=base_model_changed,
inputs=[base_model_id],
outputs=[custom_weights],
)
sd_vae_info = (str(get_checkpoints_path("vae"))).replace(
"\\", "\n\\"
)
sd_vae_info = f"VAE Path: {sd_vae_info}"
custom_vae = gr.Dropdown(
label=f"Custom VAE Models",
label=f"VAE Model",
info=sd_vae_info,
elem_id="custom_model",
value=(
@@ -284,49 +328,9 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
allow_custom_value=True,
scale=1,
)
with gr.Row():
ondemand = gr.Checkbox(
value=cmd_opts.lowvram,
label="Low VRAM",
interactive=True,
)
precision = gr.Radio(
label="Precision",
value=cmd_opts.precision,
choices=[
"fp16",
"fp32",
],
visible=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="Prompt",
value=cmd_opts.prompt[0],
lines=2,
elem_id="prompt_box",
)
negative_prompt = gr.Textbox(
label="Negative Prompt",
value=cmd_opts.negative_prompt[0],
lines=2,
elem_id="negative_prompt_box",
)

with gr.Accordion(label="Input Image", open=False):
# TODO: make this import image prompt info if it exists
sd_init_image = gr.Image(
label="Input Image",
type="pil",
height=300,
interactive=True,
)
with gr.Accordion(label="Embeddings options", open=True, render=True):
sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
"\\", "\n\\"
)
with gr.Row():
embeddings_config = gr.JSON(min_width=50, scale=1)
sd_lora_info = (str(get_checkpoints_path("loras"))).replace(
"\\", "\n\\"
)
lora_opt = gr.Dropdown(
allow_custom_value=True,
label=f"Standalone LoRA Weights",
@@ -341,106 +345,83 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
value="<div><i>No LoRA selected</i></div>",
elem_classes="lora-tags",
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
show_progress=False,
).then(
fn=update_embeddings_json,
inputs=[lora_opt],
outputs=[embeddings_config],
show_progress=False,
embeddings_config = gr.JSON(
label="Embeddings Options", min_width=50, scale=1
)
gr.on(
triggers=[lora_opt.change],
fn=lora_changed,
inputs=[lora_opt],
outputs=[lora_tags],
queue=True,
show_progress=False,
).then(
fn=update_embeddings_json,
inputs=[lora_opt],
outputs=[embeddings_config],
show_progress=False,
)
with gr.Accordion(
label="\U0001F9EA\U0000FE0F Input Image Processing", open=False
):
strength = gr.Slider(
0,
1,
value=cmd_opts.strength,
step=0.01,
label="Denoising Strength",
)
with gr.Accordion(label="Advanced Options", open=True):
with gr.Row():
scheduler = gr.Dropdown(
elem_id="scheduler",
label="Scheduler",
value="EulerDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
height = gr.Slider(
384,
768,
value=cmd_opts.height,
step=8,
label="Height",
)
width = gr.Slider(
384,
768,
value=cmd_opts.width,
step=8,
label="Width",
)
with gr.Row():
with gr.Column(scale=3):
steps = gr.Slider(
1,
100,
value=cmd_opts.steps,
step=1,
label="Steps",
)
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=True,
)
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Repeatable Seeds",
)
with gr.Column(scale=3):
strength = gr.Slider(
0,
1,
value=cmd_opts.strength,
step=0.01,
label="Denoising Strength",
)
resample_type = gr.Dropdown(
value=cmd_opts.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
guidance_scale = gr.Slider(
0,
50,
value=cmd_opts.guidance_scale,
step=0.1,
label="CFG Scale",
)
with gr.Row():
resample_type = gr.Dropdown(
value=cmd_opts.resample_type,
choices=resampler_list,
label="Resample Type",
allow_custom_value=True,
)
with gr.Group(elem_id="prompt_box_outer"):
prompt = gr.Textbox(
label="\U00002795\U0000FE0F Prompt",
value=cmd_opts.prompt[0],
lines=2,
elem_id="prompt_box",
show_copy_button=True,
)
negative_prompt = gr.Textbox(
label="\U00002796\U0000FE0F Negative Prompt",
value=cmd_opts.negative_prompt[0],
lines=2,
elem_id="negative_prompt_box",
show_copy_button=True,
)
with gr.Row(equal_height=True):
seed = gr.Textbox(
value=cmd_opts.seed,
label="Seed",
label="\U0001F331\U0000FE0F Seed",
info="An integer or a JSON list of integers, -1 for random",
show_copy_button=True,
)
device = gr.Dropdown(
elem_id="device",
label="Device",
value=global_obj.get_device_list()[0],
choices=global_obj.get_device_list(),
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value="EulerDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Slider(
1,
100,
value=cmd_opts.steps,
step=1,
label="\U0001F3C3\U0000FE0F Steps",
)
guidance_scale = gr.Slider(
0,
50,
value=cmd_opts.guidance_scale,
step=0.1,
label="\U0001F5C3\U0000FE0F CFG Scale",
)
with gr.Accordion(
label="Controlnet Options",
open=False,
@@ -530,16 +511,6 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
"Submit",
size="sm",
)
use_input_img.click(
fn=import_original,
inputs=[
sd_init_image,
canvas_width,
canvas_height,
],
outputs=[cnet_input],
queue=False,
)
make_canvas.click(
fn=create_canvas,
inputs=[canvas_width, canvas_height],
@@ -572,156 +543,227 @@ with gr.Blocks(title="Stable Diffusion") as sd_element:
queue=False,
)
with gr.Column(scale=3, min_width=600):
with gr.Group():
sd_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
object_fit="fit",
preview=True,
with gr.Tabs() as sd_tabs:
sd_element.load(
# Workaround for Gradio issue #7085
# TODO: revert to setting selected= in gr.Tabs declaration
# once this is resolved in Gradio
lambda: gr.Tabs(selected=101),
outputs=[sd_tabs],
)
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=False,
)
sd_element.load(logger.read_sd_logs, None, std_output, every=1)
sd_status = gr.Textbox(visible=False)
with gr.Row():
stable_diffusion = gr.Button("Generate Image(s)")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop Batch")
with gr.Group():
with gr.Column(scale=3):
sd_json = gr.JSON(
value=view_json_file(
os.path.join(
with gr.Tab(label="Input Image", id=100) as sd_tab_init_image:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
# TODO: make this import image prompt info if it exists
sd_init_image = gr.Image(
type="pil",
interactive=True,
show_label=False,
)
use_input_img.click(
fn=import_original,
inputs=[
sd_init_image,
canvas_width,
canvas_height,
],
outputs=[cnet_input],
queue=False,
)
with gr.Tab(label="Generate Images", id=101) as sd_tab_gallery:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
sd_gallery = gr.Gallery(
label="Generated images",
show_label=False,
elem_id="gallery",
columns=2,
object_fit="fit",
preview=True,
)
with gr.Row():
std_output = gr.Textbox(
value=f"{sd_model_info}\n"
f"Images will be saved at "
f"{get_generated_imgs_path()}",
lines=2,
elem_id="std_output",
show_label=True,
label="Log",
show_copy_button=True,
)
sd_element.load(
logger.read_sd_logs, None, std_output, every=1
)
sd_status = gr.Textbox(visible=False)
with gr.Row():
batch_count = gr.Slider(
1,
100,
value=cmd_opts.batch_count,
step=1,
label="Batch Count",
interactive=True,
)
batch_size = gr.Slider(
1,
4,
value=cmd_opts.batch_size,
step=1,
label="Batch Size",
interactive=True,
visible=True,
)
repeatable_seeds = gr.Checkbox(
cmd_opts.repeatable_seeds,
label="Use Repeatable Seeds for Batches",
)
with gr.Row():
stable_diffusion = gr.Button("Start")
random_seed = gr.Button("Randomize Seed")
random_seed.click(
lambda: -1,
inputs=[],
outputs=[seed],
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop")
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Path(get_configs_path()).mkdir(
parents=True, exist_ok=True
)
default_config_file = os.path.join(
get_configs_path(),
"default_sd_config.json",
)
)
)
with gr.Column(scale=1):
clear_sd_config = gr.ClearButton(
value="Clear Config", size="sm", components=sd_json
)
with gr.Row():
save_sd_config = gr.Button(value="Save Config", size="sm")
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
)
load_sd_config = gr.FileExplorer(
label="Load Config",
file_count="single",
root=(
cmd_opts.configs_path
if cmd_opts.configs_path
else get_configs_path()
),
height=75,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
ondemand,
repeatable_seeds,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
write_default_sd_config(default_config_file)
sd_json = gr.JSON(
elem_classes=["fill"],
value=view_json_file(default_config_file),
)
with gr.Row():
with gr.Column(scale=3):
load_sd_config = gr.FileExplorer(
label="Load Config",
file_count="single",
root_dir=(
cmd_opts.configs_path
if cmd_opts.configs_path
else get_configs_path()
),
height=75,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
value="Save Config", size="sm"
)
clear_sd_config = gr.ClearButton(
value="Clear Config",
size="sm",
components=sd_json,
)
with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
outputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
ondemand,
repeatable_seeds,
resample_type,
cnet_config,
embeddings_config,
sd_json,
],
)
save_sd_config.click(
fn=save_sd_cfg,
inputs=[sd_json, sd_config_name],
outputs=[sd_config_name],
)

pull_kwargs = dict(
fn=pull_sd_configs,
inputs=[
prompt,
negative_prompt,
sd_init_image,
height,
width,
steps,
strength,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
base_model_id,
custom_weights,
custom_vae,
precision,
device,
ondemand,
repeatable_seeds,
resample_type,
cnet_config,
embeddings_config,
],
outputs=[
sd_json,
],
)

status_kwargs = dict(
fn=lambda bc, bs: status_label("Stable Diffusion", 0, bc, bs),
inputs=[batch_count, batch_size],
outputs=sd_status,
)

gen_kwargs = dict(
fn=shark_sd_fn_dict_input,
inputs=[sd_json],
outputs=[
sd_gallery,
sd_status,
],
)

prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs)
generate_click = (
stable_diffusion.click(**status_kwargs)
.then(**pull_kwargs)
.then(**gen_kwargs)
)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)
prompt_submit = prompt.submit(**status_kwargs).then(**pull_kwargs)
neg_prompt_submit = negative_prompt.submit(**status_kwargs).then(**pull_kwargs)
generate_click = (
stable_diffusion.click(**status_kwargs).then(**pull_kwargs).then(**gen_kwargs)
)
stop_batch.click(
fn=cancel_sd,
cancels=[prompt_submit, neg_prompt_submit, generate_click],
)

@@ -29,3 +29,15 @@ def hsl_color(alpha: float, start, end):

# Return a CSS HSL string
return f"hsl({math.floor(result)}, 80%, 35%)"


def none_to_str_none(props: dict):
for key in props:
props[key] = "None" if props[key] == None else props[key]
return props


def str_none_to_none(props: dict):
for key in props:
props[key] = None if props[key] == "None" else props[key]
return props

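These two helpers exist because the Gradio dropdowns in the UI use the literal string "None" as a sentinel, while the JSON config and the pipeline expect a real None/null (compare the old default_sd_config.json, which stored "None", with the new one, which stores null). A short round-trip sketch of how they compose (hypothetical input dict; note both helpers mutate the dict they are given, so copies are passed here):

ui_state = {"custom_vae": "None", "custom_weights": "model.safetensors"}
cfg = str_none_to_none(dict(ui_state))
assert cfg["custom_vae"] is None        # the "None" string becomes a real None
back = none_to_str_none(dict(cfg))
assert back["custom_vae"] == "None"     # and maps back to the UI sentinel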
@@ -11,6 +11,40 @@ checkpoints_filetypes = (
"*.safetensors",
)

default_sd_config = r"""{
"prompt": [
"a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smoke coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"
],
"negative_prompt": [
"watermark, signature, logo, text, lowres, ((monochrome, grayscale)), blurry, ugly, blur, oversaturated, cropped"
],
"sd_init_image": [null],
"height": 512,
"width": 512,
"steps": 50,
"strength": 0.8,
"guidance_scale": 7.5,
"seed": "-1",
"batch_count": 1,
"batch_size": 1,
"scheduler": "EulerDiscrete",
"base_model_id": "stabilityai/stable-diffusion-2-1-base",
"custom_weights": null,
"custom_vae": null,
"precision": "fp16",
"device": "AMD Radeon RX 7900 XTX => vulkan://0",
"ondemand": false,
"repeatable_seeds": false,
"resample_type": "Nearest Neighbor",
"controlnets": {},
"embeddings": {}
}"""


def write_default_sd_config(path):
with open(path, "w") as f:
f.write(default_sd_config)


def safe_name(name):
return name.replace("/", "_").replace("-", "_")
@@ -21,11 +55,14 @@ def get_path_stem(path):
return path.stem


def get_resource_path(relative_path):
def get_resource_path(path):
"""Get absolute path to resource, works for dev and for PyInstaller"""
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, relative_path)).resolve(strict=False)
return result
if os.path.isabs(path):
return path
else:
base_path = getattr(sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)))
result = Path(os.path.join(base_path, path)).resolve(strict=False)
return result


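The rewritten get_resource_path now passes absolute paths through untouched and only resolves relative ones against sys._MEIPASS (set by PyInstaller inside a frozen bundle) or this module's own directory in a normal checkout. That is what lets options like --ckpt_dir accept either form, e.g. (illustrative calls, paths are hypothetical):

get_resource_path("../models")        # relative: resolved next to file_utils.py, or inside the frozen bundle
get_resource_path("/data/sd/models")  # absolute: returned as-is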
def get_configs_path() -> Path:
@@ -48,36 +85,37 @@ def get_generated_imgs_todays_subdir() -> str:


def create_checkpoint_folders():
dir = ["vae", "lora", "../vmfb"]
if not cmd_opts.ckpt_dir:
dir.insert(0, "models")
else:
if not os.path.isdir(cmd_opts.ckpt_dir):
dir = ["checkpoints", "vae", "lora", "vmfb"]
if not os.path.isdir(cmd_opts.ckpt_dir):
try:
os.makedirs(cmd_opts.ckpt_dir)
except OSError:
sys.exit(
f"Invalid --ckpt_dir argument, "
f"{cmd_opts.ckpt_dir} folder does not exists."
f"{cmd_opts.ckpt_dir} folder does not exist, and cannot be created."
)

for root in dir:
Path(get_checkpoints_path(root)).mkdir(parents=True, exist_ok=True)


def get_checkpoints_path(model=""):
return get_resource_path(f"../models/{model}")
def get_checkpoints_path(model_type=""):
return get_resource_path(os.path.join(cmd_opts.ckpt_dir, model_type))


def get_checkpoints(model="models"):
def get_checkpoints(model_type="checkpoints"):
ckpt_files = []
file_types = checkpoints_filetypes
if model == "lora":
if model_type == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
files = [
os.path.basename(x)
for x in glob.glob(os.path.join(get_checkpoints_path(model), extn))
for x in glob.glob(os.path.join(get_checkpoints_path(model_type), extn))
]
ckpt_files.extend(files)
return sorted(ckpt_files, key=str.casefold)


def get_checkpoint_pathfile(checkpoint_name, model="models"):
return os.path.join(get_checkpoints_path(model), checkpoint_name)
def get_checkpoint_pathfile(checkpoint_name, model_type="checkpoints"):
return os.path.join(get_checkpoints_path(model_type), checkpoint_name)

@@ -1,3 +1,3 @@
# SHARK Annotator
gradio==3.34.0
gradio==4.19.2
jsonlines

@@ -5,6 +5,7 @@
from distutils.sysconfig import get_python_lib
import fileinput
from pathlib import Path
import os

# Temporary workaround for transformers/__init__.py.
path_to_transformers_hook = Path(
@@ -16,51 +17,16 @@ else:
with open(path_to_transformers_hook, "w") as f:
f.write("module_collection_mode = 'pyz+py'")

path_to_skipfiles = Path(get_python_lib() + "/torch/_dynamo/skipfiles.py")
paths_to_skipfiles = [Path(get_python_lib() + "/torch/_dynamo/skipfiles.py"), Path(get_python_lib() + "/torch/_dynamo/trace_rules.py")]

modules_to_comment = ["abc,", "os,", "posixpath,", "_collections_abc,"]
startMonitoring = 0
for line in fileinput.input(path_to_skipfiles, inplace=True):
if "SKIP_DIRS = " in line:
startMonitoring = 1
print(line, end="")
elif startMonitoring in [1, 2]:
if "]" in line:
startMonitoring += 1
for path in paths_to_skipfiles:
if not os.path.isfile(path):
continue
for line in fileinput.input(path, inplace=True):
if "[_module_dir(m) for m in BUILTIN_SKIPLIST]" in line and "x.__name__ for x in BUILTIN_SKIPLIST" not in line:
print(f"{line.rstrip()} + [x.__name__ for x in BUILTIN_SKIPLIST]")
elif "(_module_dir(m) for m in BUILTIN_SKIPLIST)" in line and "x.__name__ for x in BUILTIN_SKIPLIST" not in line:
print(line, end="")
print(f"SKIP_DIRS.extend(filter(None, (x.__name__ for x in BUILTIN_SKIPLIST)))")
else:
flag = True
for module in modules_to_comment:
if module in line:
if not line.startswith("#"):
print(f"#{line}", end="")
else:
print(f"{line[1:]}", end="")
flag = False
break
if flag:
print(line, end="")
else:
print(line, end="")

# For getting around scikit-image's packaging, lazy_loader has had a patch merged but it is yet to be released.
# Refer: https://github.com/scientific-python/lazy_loader
path_to_lazy_loader = Path(get_python_lib() + "/lazy_loader/__init__.py")

for line in fileinput.input(path_to_lazy_loader, inplace=True):
if 'stubfile = filename if filename.endswith("i")' in line:
print(
' stubfile = (filename if filename.endswith("i") else f"{os.path.splitext(filename)[0]}.pyi")',
end="",
)
else:
print(line, end="")

# For getting around timm's packaging.
# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505
path_to_timm_activations = Path(get_python_lib() + "/timm/layers/activations_jit.py")
for line in fileinput.input(path_to_timm_activations, inplace=True):
if "@torch.jit.script" in line:
print("@torch.jit._script_if_tracing", end="\n")
else:
print(line, end="")
print(line, end="")

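All of the patches in process_skipfiles.py rely on the same technique: fileinput.input(path, inplace=True) redirects stdout into the file being read, so every print() rewrites the line that was just read. A minimal standalone sketch of that pattern (the target file and substitution here are hypothetical, not one of the patches above):

import fileinput

def patch_line(path, needle, replacement):
    # Each print() here replaces the line just read from `path`.
    for line in fileinput.input(path, inplace=True):
        if needle in line:
            print(replacement)
        else:
            print(line, end="")

patch_line("example.py", "@torch.jit.script", "@torch.jit._script_if_tracing")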
@@ -1,34 +0,0 @@
-f https://download.pytorch.org/whl/nightly/cpu/
--pre

numpy
torch
torchvision

tqdm

#iree-compiler | iree-runtime should already be installed

transformers
#jax[cpu]

# tflitehub dependencies.
Pillow

# web dependencies.
gradio
altair

# Testing and support.
#lit
#pyyaml

#ONNX and ORT for benchmarking
#--extra-index-url https://test.pypi.org/simple/
#protobuf
#coloredlogs
#flatbuffers
#sympy
#psutil
#onnx-weekly
#ort-nightly

@@ -1,41 +0,0 @@
-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
--pre

numpy>1.22.4
pytorch-triton
torchvision
tabulate

tqdm

#iree-compiler | iree-runtime should already be installed
iree-tools-xla

# Modelling and JAX.
gin-config
transformers
diffusers
#jax[cpu]
Pillow

# Testing and support.
lit
pyyaml
python-dateutil
sacremoses
sentencepiece

# web dependencies.
gradio==3.44.3
altair
scipy

#ONNX and ORT for benchmarking
#--extra-index-url https://test.pypi.org/simple/
#protobuf
#coloredlogs
#flatbuffers
#sympy
#psutil
#onnx-weekly
#ort-nightly

@@ -5,8 +5,9 @@
setuptools
wheel

shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine#egg=shark-turbine&subdirectory=core
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine#egg=turbine-models&subdirectory=models
torch==2.3.0.dev20240305
shark-turbine @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=core
turbine-models @ git+https://github.com/nod-ai/SHARK-Turbine.git@main#subdirectory=models

# SHARK Runner
tqdm
@@ -26,29 +27,15 @@ parameterized
accelerate
scipy
ftfy
gradio==4.8.0
gradio==4.19.2
altair
omegaconf
# 0.3.2 doesn't have binaries for arm64
safetensors==0.3.1
opencv-python
scikit-image
pytorch_lightning # for runwayml models
tk
pywebview
sentencepiece
py-cpuinfo
tiktoken # for codegen
joblib # for langchain
timm # for MiniGPT4
langchain
einops # for zoedepth
pydantic==2.4.1 # pin until pyinstaller-hooks-contrib works with beta versions
mpmath==1.3.0

# Keep PyInstaller at the end. Windows Defender sometimes flags it, but most users can continue even if it errors.
pefile
pyinstaller

# For quantized GPTQ models
optimum
auto_gptq

2
setup.py
2
setup.py
@@ -7,7 +7,7 @@ import glob
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "0.0.5"
PACKAGE_VERSION = os.environ.get("SHARK_PACKAGE_VERSION") or "2.0.0"
backend_deps = []

setup(

@@ -7,13 +7,13 @@
It checks the Python version installed and installs any required build
dependencies into a Python virtual environment.
If that environment does not exist, it creates it.


.PARAMETER update-src
git pulls the latest version

.PARAMETER force
removes and recreates the venv to force an update of all dependencies


.EXAMPLE
.\setup_venv.ps1 --force

@@ -39,7 +39,7 @@ if ($arguments -eq "--force"){
    Write-Host "deactivating..."
    Deactivate
}


if (Test-Path .\shark.venv\) {
    Write-Host "removing and recreating venv..."
    Remove-Item .\shark.venv -Force -Recurse
@@ -89,9 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install -r requirements.txt
pip install --pre torch-mlir torchvision torch --extra-index-url https://download.pytorch.org/whl/nightly/cpu -f https://llvm.github.io/torch-mlir/package-index/
pip install --upgrade -f https://nod-ai.github.io/SRT/pip-release-links.html iree-compiler iree-runtime
Write-Host "Building SHARK..."
pip install -e . -f https://llvm.github.io/torch-mlir/package-index/ -f https://nod-ai.github.io/SRT/pip-release-links.html
Write-Host "Build and installation completed successfully"
# remove this when the Windows DLL issues from LLVM changes are fixed
pip install --force-reinstall https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_compiler-20240326.843-cp311-cp311-win_amd64.whl https://github.com/openxla/iree/releases/download/candidate-20240326.843/iree_runtime-20240326.843-cp311-cp311-win_amd64.whl

Write-Host "Source your venv with ./shark.venv/Scripts/activate"

@@ -49,58 +49,20 @@ Red=`tput setaf 1`
Green=`tput setaf 2`
Yellow=`tput setaf 3`

# Assume no binary torch-mlir.
# Currently available for macOS m1 & intel (3.11) and Linux (3.8, 3.10, 3.11)
torch_mlir_bin=false
if [[ $(uname -s) = 'Darwin' ]]; then
  echo "${Yellow}Apple macOS detected"
  if [[ $(uname -m) == 'arm64' ]]; then
    echo "${Yellow}Apple M1 Detected"
    hash rustc 2>/dev/null
    if [ $? -eq 0 ];then
      echo "${Green}rustc found to compile HF tokenizers"
    else
      echo "${Red}Could not find rustc" >&2
      echo "${Red}Please run:"
      echo "${Red}curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh"
      exit 1
    fi
  fi
  echo "${Yellow}Run the following commands to setup your SSL certs for your Python version if you see SSL errors with tests"
  echo "${Yellow}/Applications/Python\ 3.XX/Install\ Certificates.command"
  if [ "$PYTHON_VERSION_X_Y" == "3.11" ]; then
    torch_mlir_bin=true
  fi
elif [[ $(uname -s) = 'Linux' ]]; then
  echo "${Yellow}Linux detected"
  if [ "$PYTHON_VERSION_X_Y" == "3.8" ] || [ "$PYTHON_VERSION_X_Y" == "3.10" ] || [ "$PYTHON_VERSION_X_Y" == "3.11" ] ; then
    torch_mlir_bin=true
  fi
else
  echo "${Red}OS not detected. Pray and Play"
fi

# Upgrade pip and install requirements.
$PYTHON -m pip install --upgrade pip || die "Could not upgrade pip"
$PYTHON -m pip install --upgrade -r "$TD/requirements.txt"
if [ "$torch_mlir_bin" = true ]; then
  if [[ $(uname -s) = 'Darwin' ]]; then
    echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
    $PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
    $PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
  else
    $PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
    if [ $? -eq 0 ];then
      echo "Successfully Installed torch-mlir"
    else
      echo "Could not install torch-mlir" >&2
    fi
  fi
if [[ $(uname -s) = 'Darwin' ]]; then
  echo "MacOS detected. Installing torch-mlir from .whl, to avoid dependency problems with torch."
  $PYTHON -m pip uninstall -y timm #TEMP FIX FOR MAC
  $PYTHON -m pip install --pre --no-cache-dir torch-mlir -f https://llvm.github.io/torch-mlir/package-index/ -f https://download.pytorch.org/whl/nightly/torch/
else
  echo "${Red}No binaries found for Python $PYTHON_VERSION_X_Y on $(uname -s)"
  echo "${Yellow}Python 3.11 supported on macOS and 3.8, 3.10 and 3.11 on Linux"
  echo "${Red}Please build torch-mlir from source in your environment"
  exit 1
  $PYTHON -m pip install --pre torch-mlir -f https://llvm.github.io/torch-mlir/package-index/
  if [ $? -eq 0 ];then
    echo "Successfully Installed torch-mlir"
  else
    echo "Could not install torch-mlir" >&2
  fi
fi
if [[ -z "${USE_IREE}" ]]; then
  rm .use-iree
@@ -116,19 +78,6 @@ else
  echo "Not installing a backend, please make sure to add your backend to PYTHONPATH"
fi

if [[ ! -z "${IMPORTER}" ]]; then
  echo "${Yellow}Installing importer tools.."
  if [[ $(uname -s) = 'Linux' ]]; then
    echo "${Yellow}Linux detected.. installing Linux importer tools"
    # Always get the importer tools from upstream IREE
    $PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer.txt" -f https://openxla.github.io/iree/pip-release-links.html --extra-index-url https://download.pytorch.org/whl/nightly/cpu
  elif [[ $(uname -s) = 'Darwin' ]]; then
    echo "${Yellow}macOS detected.. installing macOS importer tools"
    # Conda seems to have some problems installing these packages; hopefully they get resolved upstream.
    $PYTHON -m pip install --no-warn-conflicts --upgrade -r "$TD/requirements-importer-macos.txt" -f ${RUNTIME} --extra-index-url https://download.pytorch.org/whl/nightly/cpu
  fi
fi

if [[ $(uname -s) = 'Darwin' ]]; then
  PYTORCH_URL=https://download.pytorch.org/whl/nightly/torch/
else

@@ -91,7 +91,7 @@ _IREE_TARGET_MAP = {
    "cpu-task": "llvm-cpu",
    "cpu-sync": "llvm-cpu",
    "cuda": "cuda",
    "vulkan": "vulkan",
    "vulkan": "vulkan-spirv",
    "metal": "metal",
    "rocm": "rocm",
    "intel-gpu": "opencl-spirv",

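For context, the values in this map are IREE compiler target backend names. A hedged sketch of how an entry might feed IREE's Python API (the MLIR snippet and usage are illustrative, not taken from this repo):

from iree.compiler import compile_str

# Illustrative module; compiled for the backend string the map now yields
# for "vulkan".
MLIR = """
func.func @add(%a: f32, %b: f32) -> f32 {
  %0 = arith.addf %a, %b : f32
  return %0 : f32
}
"""
vmfb = compile_str(MLIR, target_backends=["vulkan-spirv"])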
@@ -122,9 +122,7 @@ def check_device_drivers(device):
        )
        return True
    except RuntimeError as re:
        print(
            f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}"
        )
        print(f"[ERR] Failed to get driver for {device} with error:\n{repr(re)}")
        return True

    # Unknown device. We assume drivers are installed.

@@ -113,8 +113,8 @@ def get_iree_frontend_args(frontend):
# Common args to be used given any frontend or device.
def get_iree_common_args(debug=False):
    common_args = [
        "--iree-vm-bytecode-module-strip-source-map=true",
        "--iree-util-zero-fill-elided-attrs",
        "--mlir-elide-elementsattrs-if-larger=10",
    ]
    if debug == True:
        common_args.extend(

@@ -33,7 +33,7 @@ def get_vulkan_target_env(vulkan_target_triple):
    device_type = get_device_type(triple)
    # get capabilities
    capabilities = get_vulkan_target_capabilities(triple)
    target_env = f"#vk.target_env<{version}, r({revision}), {extensions}, {vendor}:{device_type}, #vk.caps< {capabilities} >>"
    target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>"
    return target_env

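The rewrite above moves the attribute from the old #vk dialect syntax to upstream MLIR's #spirv form. To show the shape of the generated string, here is the same f-string fed hypothetical values (all of them made up for illustration):

# All values below are hypothetical, chosen only to show the output shape.
version = "v1.6"
revision = 1
extensions = "[SPV_KHR_16bit_storage, SPV_KHR_variable_pointers]"
vendor = "AMD"
device_type = "DiscreteGPU"
capabilities = "max_compute_shared_memory_size = 65536, subgroup_size = 64"

target_env = f"<#spirv.vce<{version}, r({revision}), {extensions}>, {vendor}:{device_type}, #spirv.resource_limits< {capabilities} >>"
print(target_env)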
@@ -63,62 +63,62 @@ def get_extensions(triple):
    arch, product, os = triple
    if arch == "m1":
        ext = [
            "VK_KHR_16bit_storage",
            "VK_KHR_8bit_storage",
            "VK_KHR_shader_float16_int8",
            "VK_KHR_storage_buffer_storage_class",
            "VK_KHR_variable_pointers",
            "SPV_KHR_16bit_storage",
            "SPV_KHR_8bit_storage",
            "SPV_KHR_shader_float16_int8",
            "SPV_KHR_storage_buffer_storage_class",
            "SPV_KHR_variable_pointers",
        ]
        return make_ext_list(ext_list=ext)

    if arch == "valhall":
        ext = [
            "VK_KHR_16bit_storage",
            "VK_KHR_8bit_storage",
            "VK_KHR_shader_float16_int8",
            "VK_KHR_spirv_1_4",
            "VK_KHR_storage_buffer_storage_class",
            "VK_KHR_variable_pointers",
            "SPV_KHR_16bit_storage",
            "SPV_KHR_8bit_storage",
            "SPV_KHR_shader_float16_int8",
            "SPV_KHR_spirv_1_4",
            "SPV_KHR_storage_buffer_storage_class",
            "SPV_KHR_variable_pointers",
        ]
        return make_ext_list(ext_list=ext)

    if arch == "adreno":
        ext = [
            "VK_KHR_16bit_storage",
            "VK_KHR_shader_float16_int8",
            "VK_KHR_spirv_1_4",
            "VK_KHR_storage_buffer_storage_class",
            "VK_KHR_variable_pointers",
            "SPV_KHR_16bit_storage",
            "SPV_KHR_shader_float16_int8",
            "SPV_KHR_spirv_1_4",
            "SPV_KHR_storage_buffer_storage_class",
            "SPV_KHR_variable_pointers",
        ]
        if os == "android31":
            ext.append("VK_KHR_8bit_storage")
            ext.append("SPV_KHR_8bit_storage")
        return make_ext_list(ext_list=ext)

    if get_vendor(triple) == "SwiftShader":
        ext = ["VK_KHR_storage_buffer_storage_class"]
        ext = ["SPV_KHR_storage_buffer_storage_class"]
        return make_ext_list(ext_list=ext)

    if arch == "unknown":
        ext = [
            "VK_KHR_storage_buffer_storage_class",
            "VK_KHR_variable_pointers",
            "SPV_KHR_storage_buffer_storage_class",
            "SPV_KHR_variable_pointers",
        ]
        return make_ext_list(ext_list=ext)

    ext = [
        "VK_KHR_16bit_storage",
        "VK_KHR_8bit_storage",
        "VK_KHR_shader_float16_int8",
        "VK_KHR_spirv_1_4",
        "VK_KHR_storage_buffer_storage_class",
        "VK_KHR_variable_pointers",
        "SPV_KHR_16bit_storage",
        "SPV_KHR_8bit_storage",
        "SPV_KHR_shader_float16_int8",
        "SPV_KHR_spirv_1_4",
        "SPV_KHR_storage_buffer_storage_class",
        "SPV_KHR_variable_pointers",
        "VK_EXT_subgroup_size_control",
    ]

    if get_vendor(triple) == "NVIDIA" or arch == "rdna3":
        ext.append("VK_KHR_cooperative_matrix")
        ext.append("SPV_KHR_cooperative_matrix")
    if get_vendor(triple) == ["NVIDIA", "AMD", "Intel"]:
        ext.append("VK_KHR_shader_integer_dot_product")
        ext.append("SPV_KHR_shader_integer_dot_product")
    return make_ext_list(ext_list=ext)

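make_ext_list is not shown in this diff. For readers following along, a hypothetical sketch of the kind of helper it appears to be, judging from how the extension lists are consumed (name and behavior are assumptions):

def make_ext_list(ext_list):
    # Hypothetical: join extension names into the bracketed, comma-separated
    # form interpolated into the target-env string.
    return "[" + ", ".join(ext_list) + "]"

print(make_ext_list(["SPV_KHR_16bit_storage", "SPV_KHR_variable_pointers"]))
# prints: [SPV_KHR_16bit_storage, SPV_KHR_variable_pointers]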
@@ -186,13 +186,13 @@ def get_vulkan_target_capabilities(triple):
        "Quad": 128,
        "PartitionedNV": 256,
    }
    cap["maxComputeSharedMemorySize"] = 16384
    cap["maxComputeWorkGroupInvocations"] = 128
    cap["maxComputeWorkGroupSize"] = [128, 128, 64]
    cap["subgroupSize"] = 32
    cap["max_compute_shared_memory_size"] = 16384
    cap["max_compute_workgroup_invocations"] = 128
    cap["max_compute_workgroup_size"] = [128, 128, 64]
    cap["subgroup_size"] = 32
    cap["subgroupFeatures"] = ["Basic"]
    cap["minSubgroupSize"] = None
    cap["maxSubgroupSize"] = None
    cap["min_subgroup_size"] = None
    cap["max_subgroup_size"] = None
    cap["shaderFloat16"] = False
    cap["shaderFloat64"] = False
    cap["shaderInt8"] = False
@@ -209,13 +209,13 @@ def get_vulkan_target_capabilities(triple):
    cap["coopmatCases"] = None

    if arch in ["rdna1", "rdna2", "rdna3"]:
        cap["maxComputeSharedMemorySize"] = 65536
        cap["maxComputeWorkGroupInvocations"] = 1024
        cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
        cap["max_compute_shared_memory_size"] = 65536
        cap["max_compute_workgroup_invocations"] = 1024
        cap["max_compute_workgroup_size"] = [1024, 1024, 1024]

        cap["subgroupSize"] = 64
        cap["minSubgroupSize"] = 32
        cap["maxSubgroupSize"] = 64
        cap["subgroup_size"] = 64
        cap["min_subgroup_size"] = 32
        cap["max_subgroup_size"] = 64
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -244,7 +244,8 @@ def get_vulkan_target_capabilities(triple):
        if arch == "rdna3":
            # TODO: Get scope value
            cap["coopmatCases"] = [
                "mSize = 16, nSize = 16, kSize = 16, aType = f16, bType = f16, cType = f16, resultType = f16, accSat = false, scope = #vk.scope<Subgroup>"
                "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f16, result_type = f16, acc_sat = false, scope = <Subgroup>",
                "m_size = 16, n_size = 16, k_size = 16, a_type = f16, b_type = f16, c_type = f32, result_type = f32, acc_sat = false, scope = <Subgroup>"
            ]

        if product == "rx5700xt":
@@ -252,11 +253,11 @@ def get_vulkan_target_capabilities(triple):
        cap["storagePushConstant8"] = False

    elif arch in ["rgcn5", "rgcn4", "rgcn3"]:
        cap["maxComputeSharedMemorySize"] = 65536
        cap["maxComputeWorkGroupInvocations"] = 1024
        cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
        cap["max_compute_shared_memory_size"] = 65536
        cap["max_compute_workgroup_invocations"] = 1024
        cap["max_compute_workgroup_size"] = [1024, 1024, 1024]

        cap["subgroupSize"] = 64
        cap["subgroup_size"] = 64
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -267,8 +268,8 @@ def get_vulkan_target_capabilities(triple):
            "Clustered",
            "Quad",
        ]
        cap["minSubgroupSize"] = 64
        cap["maxSubgroupSize"] = 64
        cap["min_subgroup_size"] = 64
        cap["max_subgroup_size"] = 64

        if arch == "rgcn5":
            cap["shaderFloat16"] = True
@@ -290,11 +291,11 @@ def get_vulkan_target_capabilities(triple):
        cap["variablePointersStorageBuffer"] = True

    elif arch == "m1":
        cap["maxComputeSharedMemorySize"] = 32768
        cap["maxComputeWorkGroupInvocations"] = 1024
        cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
        cap["max_compute_shared_memory_size"] = 32768
        cap["max_compute_workgroup_invocations"] = 1024
        cap["max_compute_workgroup_size"] = [1024, 1024, 1024]

        cap["subgroupSize"] = 32
        cap["subgroup_size"] = 32
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -321,11 +322,11 @@ def get_vulkan_target_capabilities(triple):
        cap["variablePointersStorageBuffer"] = True

    elif arch == "valhall":
        cap["maxComputeSharedMemorySize"] = 32768
        cap["maxComputeWorkGroupInvocations"] = 512
        cap["maxComputeWorkGroupSize"] = [512, 512, 512]
        cap["max_compute_shared_memory_size"] = 32768
        cap["max_compute_workgroup_invocations"] = 512
        cap["max_compute_workgroup_size"] = [512, 512, 512]

        cap["subgroupSize"] = 16
        cap["subgroup_size"] = 16
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -352,11 +353,11 @@ def get_vulkan_target_capabilities(triple):
        cap["variablePointersStorageBuffer"] = True

    elif arch == "arc":
        cap["maxComputeSharedMemorySize"] = 32768
        cap["maxComputeWorkGroupInvocations"] = 1024
        cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
        cap["max_compute_shared_memory_size"] = 32768
        cap["max_compute_workgroup_invocations"] = 1024
        cap["max_compute_workgroup_size"] = [1024, 1024, 64]

        cap["subgroupSize"] = 32
        cap["subgroup_size"] = 32
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -385,8 +386,8 @@ def get_vulkan_target_capabilities(triple):

    elif arch == "cpu":
        if product == "swiftshader":
            cap["maxComputeSharedMemorySize"] = 16384
            cap["subgroupSize"] = 4
            cap["max_compute_shared_memory_size"] = 16384
            cap["subgroup_size"] = 4
            cap["subgroupFeatures"] = [
                "Basic",
                "Vote",
@@ -397,13 +398,13 @@ def get_vulkan_target_capabilities(triple):
        ]

    elif arch in ["pascal"]:
        cap["maxComputeSharedMemorySize"] = 49152
        cap["maxComputeWorkGroupInvocations"] = 1536
        cap["maxComputeWorkGroupSize"] = [1536, 1024, 64]
        cap["max_compute_shared_memory_size"] = 49152
        cap["max_compute_workgroup_invocations"] = 1536
        cap["max_compute_workgroup_size"] = [1536, 1024, 64]

        cap["subgroupSize"] = 32
        cap["minSubgroupSize"] = 32
        cap["maxSubgroupSize"] = 32
        cap["subgroup_size"] = 32
        cap["min_subgroup_size"] = 32
        cap["max_subgroup_size"] = 32
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -431,13 +432,13 @@ def get_vulkan_target_capabilities(triple):
        cap["variablePointersStorageBuffer"] = True

    elif arch in ["ampere", "turing"]:
        cap["maxComputeSharedMemorySize"] = 49152
        cap["maxComputeWorkGroupInvocations"] = 1024
        cap["maxComputeWorkGroupSize"] = [1024, 1024, 1024]
        cap["max_compute_shared_memory_size"] = 49152
        cap["max_compute_workgroup_invocations"] = 1024
        cap["max_compute_workgroup_size"] = [1024, 1024, 1024]

        cap["subgroupSize"] = 32
        cap["minSubgroupSize"] = 32
        cap["maxSubgroupSize"] = 32
        cap["subgroup_size"] = 32
        cap["min_subgroup_size"] = 32
        cap["max_subgroup_size"] = 32
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -471,11 +472,11 @@ def get_vulkan_target_capabilities(triple):
        ]

    elif arch == "adreno":
        cap["maxComputeSharedMemorySize"] = 32768
        cap["maxComputeWorkGroupInvocations"] = 1024
        cap["maxComputeWorkGroupSize"] = [1024, 1024, 64]
        cap["max_compute_shared_memory_size"] = 32768
        cap["max_compute_workgroup_invocations"] = 1024
        cap["max_compute_workgroup_size"] = [1024, 1024, 64]

        cap["subgroupSize"] = 64
        cap["subgroup_size"] = 64
        cap["subgroupFeatures"] = [
            "Basic",
            "Vote",
@@ -491,14 +492,14 @@ def get_vulkan_target_capabilities(triple):
        cap["shaderInt16"] = True

        cap["storageBuffer16BitAccess"] = True
        if os == "andorid31":
        if os == "android31":
            cap["uniformAndStorageBuffer8BitAccess"] = True

        cap["variablePointers"] = True
        cap["variablePointersStorageBuffer"] = True

    elif arch == "unknown":
        cap["subgroupSize"] = 64
        cap["subgroup_size"] = 64
        cap["variablePointers"] = False
        cap["variablePointersStorageBuffer"] = False
    else:
@@ -521,14 +522,14 @@ def get_vulkan_target_capabilities(triple):
            res += f"{k} = {'unit' if v == True else None}, "
        elif isinstance(v, list):
            if k == "subgroupFeatures":
                res += f"subgroupFeatures = {get_subgroup_val(v)}: i32, "
            elif k == "maxComputeWorkGroupSize":
                res += f"maxComputeWorkGroupSize = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
                res += f"subgroup_features = {get_subgroup_val(v)}: i32, "
            elif k == "max_compute_workgroup_size":
                res += f"max_compute_workgroup_size = dense<{get_comma_sep_str(v)}>: vector<{len(v)}xi32>, "
            elif k == "coopmatCases":
                cmc = ""
                for case in v:
                    cmc += f"#vk.coop_matrix_props<{case}>, "
                res += f"cooperativeMatrixPropertiesKHR = [{cmc[:-2]}], "
                    cmc += f"#spirv.coop_matrix_props_khr<{case}>, "
                res += f"cooperative_matrix_properties_khr = [{cmc[:-2]}], "
            else:
                res += f"{k} = {get_comma_sep_str(v)}, "
        else:

@@ -144,6 +144,8 @@ def get_vulkan_target_triple(device_name):
    # Intel Targets
    elif any(x in device_name for x in ("A770", "A750")):
        triple = f"arc-770-{system_os}"
    elif "v620" in device_name:
        triple = f"rdna2-v620-{system_os}"

    # Adreno Targets
    elif all(x in device_name for x in ("Adreno", "740")):
@@ -169,7 +171,7 @@ def get_vulkan_triple_flag(device_name="", device_num=0, extra_args=[]):
        print(
            f"Found vulkan device {vulkan_device}. Using target triple {triple}"
        )
        return f"-iree-vulkan-target-triple={triple}"
        return f"--iree-vulkan-target-triple={triple}"
    print(
        """Optimized kernel for your target device is not added yet.
        Contact SHARK Admin on discord[https://discord.com/invite/RUqY2h2s9u]
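The one-character change above matters because current iree-compile builds expect GNU-style double-dash options. A hedged sketch of where the returned flag ends up (file names and the triple are placeholders):

import subprocess

# Placeholder paths; the triple is one this helper could plausibly return.
subprocess.run(
    [
        "iree-compile",
        "model.mlir",
        "--iree-hal-target-backends=vulkan-spirv",
        "--iree-vulkan-target-triple=rdna2-unknown-linux",
        "-o", "model.vmfb",
    ],
    check=True,
)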
@@ -184,7 +186,8 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):

    res_vulkan_flag = []
    res_vulkan_flag += [
        "--iree-stream-resource-max-allocation-size=3221225472"
        "--iree-stream-resource-max-allocation-size=3221225472",
        "--iree-flow-inline-constants-max-byte-length=0"
    ]
    vulkan_triple_flag = None
    for arg in extra_args:
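Worth noting in the hunk above: the comma added after the first literal is load-bearing. Without it, Python's implicit concatenation of adjacent string literals would fuse the two flags into one malformed argument. A standalone illustration:

# Adjacent string literals concatenate implicitly.
flags = [
    "--iree-stream-resource-max-allocation-size=3221225472"
    "--iree-flow-inline-constants-max-byte-length=0"
]
print(len(flags))  # 1, not 2: a single fused, invalid flag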
@@ -197,10 +200,8 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
        vulkan_triple_flag = get_vulkan_triple_flag(
            device_num=device_num, extra_args=extra_args
        )
        res_vulkan_flag += [vulkan_triple_flag]

    if vulkan_triple_flag is not None:
        vulkan_target_env = get_vulkan_target_env_flag(vulkan_triple_flag)
        res_vulkan_flag.append(vulkan_target_env)
    return res_vulkan_flag