Compare commits

...

43 Commits

Author SHA1 Message Date
Anush Elangovan
6d6a9dcae8 Revert "Revert "Enable --device_allocator=caching""
This reverts commit 41ee65b377.
2023-02-09 23:00:32 -08:00
Anush Elangovan
41ee65b377 Revert "Enable --device_allocator=caching"
This reverts commit 83fe477066.
2023-02-09 23:00:06 -08:00
Anush Elangovan
83fe477066 Enable --device_allocator=caching 2023-02-09 22:58:46 -08:00
yzhang93
4ca84ee4ee Revert "Delete unnecessary arg setting (#978)" (#985)
This reverts commit 83c69ecd49.
2023-02-09 16:44:26 -08:00
Ean Garvey
c28cc4c919 Fix local_tank_cache handling in shark_downloader. (#981) 2023-02-09 14:52:03 -06:00
yzhang93
e9864cb3f7 Modify the annotation OTF to return bytecode module (#980) 2023-02-08 14:29:43 -08:00
yzhang93
83c69ecd49 Delete unnecessary arg setting (#978) 2023-02-08 10:30:18 -08:00
Prashant Kumar
3595b4aaff Incorporate latest changes in the shark_dynamo backend. 2023-02-08 20:37:30 +05:30
Abhishek Varma
3a9cfe113a Fix SD restart error in exe file (#975)
-- This commit fixes SD restart error in exe file by creating
   variants.json in CWD instead of a relative path.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
Co-authored-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-08 06:14:08 -08:00
yzhang93
c9966127da Fix iree flags to be able to run on rdna2 (#972) 2023-02-07 16:39:32 -08:00
Ean Garvey
51300d33a7 Remove non-SD args from generate_sharktank.py (#970) 2023-02-07 13:29:55 -06:00
Gaurav Shukla
5af124c5a5 [SD] Add batch count in stable diffusion
Signed-Off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-07 23:26:46 +05:30
Abhishek Varma
eeb20b531a Fix restart SD session error + override args.use_tuned temporarily
-- This commit fixes the session restart error for SD.
-- It also overrides `args.use_tuned` for `import_mlir`, and sets
   `use_tuned` as `False`.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-07 19:50:48 +05:30
cstueckrath
9dca842c22 Update .gitignore to exclude models (#967)
the models folder will be stashed along with other changes and most likely kill git doing so.
2023-02-07 01:48:36 -08:00
Ean Garvey
1eb9436836 Fix generate_sharktank args. 2023-02-07 14:06:07 +05:30
Ean Garvey
9604d9ce81 make --update_tank update only if hash mismatch 2023-02-07 14:06:07 +05:30
Ean Garvey
481d0553d8 Remove unnecessary repro_dir / shark_tmp usage 2023-02-07 14:06:07 +05:30
powderluv
60035cd63a Add css in exe (#963)
exe should now default to dark theme too
2023-02-06 15:26:08 -08:00
drumicube
d35f992ace Bring back the --runs options for the cmd command and fix wrong seed/model reported in json, csv and png (#962) 2023-02-06 15:16:50 -06:00
Daniel Garvey
157ae64f9d print to stdout for test visibility (#937)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-06 01:03:27 -08:00
powderluv
ffa17f6057 Update sd_dark_theme.css 2023-02-06 01:01:50 -08:00
drumicube
d695a43e32 Make the dark theme default while launching web server (#954) 2023-02-05 07:25:45 -08:00
powderluv
01f6b4e6f0 Update README.md 2023-02-04 23:40:13 -08:00
yzhang93
7cf31a6ae4 Fix iree-benchmark flag names (#952) 2023-02-04 22:24:18 -08:00
Quinn Dawkins
fbd6224b04 Revert "Revert pipelines (#948)" (#951)
This reverts commit 8115b26079.
Additionally fixes img2col by adding detach elementwise from named op
passes.
2023-02-04 22:44:08 -05:00
powderluv
8115b26079 Revert pipelines (#948)
* Revert "[SD] Modify the flags to use --iree-preprocessing-pass-pipeline (#914)"

This reverts commit a783c089a9.

* Revert "Fix iree flags due to the change in shark-runtime (#944)"

This reverts commit 1d38d49162.
2023-02-04 07:09:51 -08:00
powderluv
820586ac68 Update README.md 2023-02-04 01:01:11 -08:00
powderluv
4a7441ed07 Update profiling_with_iree.md 2023-02-04 00:47:57 -08:00
powderluv
383741f284 Update stable_diffusion_amd.md 2023-02-04 00:40:47 -08:00
powderluv
2bbc4e0e9f Update README.md 2023-02-04 00:35:40 -08:00
powderluv
a7237244b0 Send users to the .exe file first 2023-02-04 00:30:32 -08:00
yzhang93
1d38d49162 Fix iree flags due to the change in shark-runtime (#944) 2023-02-03 21:34:02 -08:00
yzhang93
a783c089a9 [SD] Modify the flags to use --iree-preprocessing-pass-pipeline (#914)
* [SD] Modify the flags to use --iree-preprocessing-pass-pipeline

* Fix flags in sd_annotation
2023-02-03 15:08:02 -08:00
powderluv
e7907dc532 Disable tuned models for sm_89 (#943)
Looks like tuning on A100 doesn't necessarily translate to 40xx.
2023-02-03 14:30:46 -08:00
powderluv
394413679d Fix ckpt_dir (#939) 2023-02-03 12:54:19 -08:00
powderluv
37189f14cb roll to 492 2023-02-03 11:59:18 -08:00
powderluv
0b1ee81901 Minor webui changes (#938) 2023-02-03 11:26:45 -08:00
Gaurav Shukla
00cf73f9b8 [SD] Merge model id dropdown and .ckpt dropdown (#936)
- use_tuned is set to False for custom checkpoints.

Signed-off-by: Gaurav Shukla <gaurav@nod-labs.com>
2023-02-03 10:43:33 -08:00
Abhishek Varma
5a5f285493 [apps-SD] Prepone loading of vmfbs + restructure the SD pipeline
-- This commit prepones loading of vmfbs, if present, for all sub-models.
-- It also involves restructuring the SD pipeline to achieve the loading
   of vmfbs smoothly and postpones processing of checkpoint files only when
   required.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
2023-02-03 20:21:24 +05:30
powderluv
7f2ea454b6 revert /base variants as they are different (#929)
sd2_1base is different from VAE base (for older cards)
2023-02-03 01:27:32 -08:00
Daniel Garvey
7c14002118 Map 2_1 to 2_1_base (#927)
* fix broken paths for older models

* adds a mapping from sd_2_1 to sd_2_1_base

we only have models in models_db for 2_1_base.
now that diffusers is fixed we can actually generate
2_1 itself, but until we add support for that in the tank
we should fetch 2_1_base for no-import_mlir

---------

Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 19:03:19 -08:00
powderluv
3e9554f0a1 roll to 487 2023-02-02 19:02:39 -08:00
Daniel Garvey
e11ffec544 fix broken paths for older models (#926)
Co-authored-by: dan <dan@nod-labs.com>
2023-02-02 15:48:19 -08:00
38 changed files with 798 additions and 640 deletions

View File

@@ -143,7 +143,7 @@ jobs:
then
export SHA=$(git log -1 --format='%h')
gsutil -m cp -r $GITHUB_WORKSPACE/gen_shark_tank/* gs://shark_tank/${DATE}_$SHA
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/latest/
gsutil -m cp -r gs://shark_tank/${DATE}_$SHA/* gs://shark_tank/nightly/
fi
rm -rf ./wheelhouse/nodai*

View File

@@ -111,7 +111,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k cpu
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cpu
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cpu_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cpu_latest.csv
@@ -121,7 +121,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} BENCHMARK=1 IMPORTER=1 ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k cuda
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k cuda
gsutil cp ./bench_results.csv gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv
gsutil cp gs://shark-public/builder/bench_results/${DATE}/bench_results_cuda_${SHORT_SHA}.csv gs://shark-public/builder/bench_results/latest/bench_results_cuda_latest.csv
# Disabled due to black image bug
@@ -136,7 +136,7 @@ jobs:
export DYLD_LIBRARY_PATH=/usr/local/lib/
echo $PATH
pip list | grep -E "torch|iree"
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k vulkan
pytest --ci --ci_sha=${SHORT_SHA} --local_tank_cache="/Volumes/builder/anush/shark_cache" -k vulkan --update_tank
- name: Validate Vulkan Models (a100)
if: matrix.suite == 'vulkan' && matrix.os == 'a100'
@@ -144,7 +144,7 @@ jobs:
cd $GITHUB_WORKSPACE
PYTHON=python${{ matrix.python-version }} ./setup_venv.sh
source shark.venv/bin/activate
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --local_tank_cache="./shark_tmp/shark_cache" -k vulkan
pytest --forked --benchmark --ci --ci_sha=${SHORT_SHA} --update_tank -k vulkan
python build_tools/stable_diffusion_testing.py --device=vulkan
- name: Validate Vulkan Models (Windows)
@@ -158,4 +158,5 @@ jobs:
if: matrix.suite == 'vulkan' && matrix.os == '7950x'
run: |
./setup_venv.ps1
./shark.venv/Scripts/activate
python build_tools/stable_diffusion_testing.py --device=vulkan

10
.gitignore vendored
View File

@@ -159,6 +159,9 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# vscode related
.vscode
# Shark related artefacts
*venv/
shark_tmp/
@@ -172,3 +175,10 @@ onnx_models/
# Generated images
generated_imgs/
# Custom model related artefacts
variants.json
models/
# models folder
apps/stable_diffusion/web/models/

View File

@@ -1,12 +1,47 @@
# SHARK
High Performance Machine Learning and Data Analytics for CPUs, GPUs, Accelerators and Heterogeneous Clusters
High Performance Machine Learning Distribution
[![Nightly Release](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/nightly.yml)
[![Validate torch-models on Shark Runtime](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml/badge.svg)](https://github.com/nod-ai/SHARK/actions/workflows/test-models.yml)
## Installation (Windows, Linux and macOS)
<details>
<summary>Prerequisites - Drivers </summary>
#### Install your Windows hardware drivers
* [AMD RDNA Users] Download this specific driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree). Latest drivers may not work.
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
#### Linux Drivers
* MESA / RADV drivers wont work with FP16. Please use the latest AMGPU-PRO drivers (non-pro OSS drivers also wont work) or the latest NVidia Linux Drivers.
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
</details>
### Quick Start for SHARK Stable Diffusion for Windows 10/11 Users
Install Driver from [Prerequisites](https://github.com/nod-ai/SHARK#install-your-hardware-drivers) above
Download the latest .exe https://github.com/nod-ai/SHARK/releases.
Double click the .exe and you should have the [UI]( http://localhost:8080/?__theme=dark) in the browser.
If you have custom models (ckpt, safetensors) put in a `models/` directory where the .exe is.
Enjoy.
Some known AMD Driver quirks and fixes with cursors are documented [here](https://github.com/nod-ai/SHARK/blob/main/apps/stable_diffusion/stable_diffusion_amd.md ).
<details>
<summary>Advanced Installation (Only for developers)</summary>
## Advanced Installation (Windows, Linux and macOS) for developers
## Check out the code
@@ -63,14 +98,6 @@ source shark.venv/bin/activate
### Run Stable Diffusion on your device - Commandline
#### Install your hardware drivers
* [AMD RDNA Users] Download the latest driver [here](https://www.amd.com/en/support/kb/release-notes/rn-rad-win-22-11-1-mril-iree)
* [macOS Users] Download and install the 1.3.216 Vulkan SDK from [here](https://sdk.lunarg.com/sdk/download/1.3.216.0/mac/vulkansdk-macos-1.3.216.0.dmg). Newer versions of the SDK will not work.
* [Nvidia Users] Download and install the latest CUDA / Vulkan drivers from [here](https://developer.nvidia.com/cuda-downloads)
Other users please ensure you have your latest vendor drivers and Vulkan SDK from [here](https://vulkan.lunarg.com/sdk/home) and if you are using vulkan check `vulkaninfo` works in a terminal window
#### Windows 10/11 Users
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
@@ -82,6 +109,7 @@ python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vu
```
You can replace `vulkan` with `cpu` to run on your CPU or with `cuda` to run on CUDA devices. If you have multiple vulkan devices you can address them with `--device=vulkan://1` etc
</details>
The output on a 7900XTX would like:
@@ -101,9 +129,6 @@ Here are some samples generated:
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.

View File

@@ -26,13 +26,13 @@ Run / Benchmark Command (FP32 - NCHW):
```shell
## Vulkan AMD:
iree-benchmark-module --module_file=/path/to/output/vmfb --entry_function=forward --device=vulkan --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CUDA:
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=cuda --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
## CPU:
iree-benchmark-module --module_file=/path/to/vmfb --entry_function=forward --device=local-task --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
@@ -40,5 +40,48 @@ Run via vulkan_gui for RGP Profiling:
To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb
```shell
./build/vulkan_gui/iree-vulkan-gui --module_file=/path/to/unet.vmfb --function_input=1x4x64x64xf32 --function_input=1xf32 --function_input=2x77x768xf32 --function_input=f32=1.0 --function_input=f32=1.0
./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0
```
</details>
<details>
<summary>Debug Commands</summary>
## Debug commands and other advanced usage follows.
```shell
python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text"
```
## dump all dispatch .spv and isa using amdllpc
```shell
python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa
```
## Compile and save the .vmfb (using vulkan fp16 as an example):
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb
```
## Capture an RGP trace
```shell
python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp
```
## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example):
```shell
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16
```
## Run the unet module with iree-benchmark-module (same config as above):
```shell
##if you want to use .npz inputs:
unzip ~/.local/shark_tank/<your unet>/inputs.npz
iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy
```
</details>

View File

@@ -18,6 +18,7 @@ from apps.stable_diffusion.src import (
Text2ImagePipeline,
get_schedulers,
set_init_device_flags,
utils,
)
@@ -59,8 +60,8 @@ if args.clear_all:
shutil.rmtree(os.path.join(home, ".local/shark_tank"))
# save output images and the inputs correspoding to it.
def save_output_img(output_img):
# save output images and the inputs corresponding to it.
def save_output_img(output_img, img_seed):
output_path = args.output_dir if args.output_dir else Path.cwd()
generated_imgs_path = Path(output_path, "generated_imgs")
generated_imgs_path.mkdir(parents=True, exist_ok=True)
@@ -68,9 +69,13 @@ def save_output_img(output_img):
prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15])
out_img_name = (
f"{prompt_slice}_{args.seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
f"{prompt_slice}_{img_seed}_{dt.now().strftime('%y%m%d_%H%M%S')}"
)
img_model = args.hf_model_id
if args.ckpt_loc:
img_model = os.path.basename(args.ckpt_loc)
if args.output_img_format == "jpg":
out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg")
output_img.save(out_img_path, quality=95, subsampling=0)
@@ -81,7 +86,7 @@ def save_output_img(output_img):
if args.write_metadata_to_png:
pngInfo.add_text(
"parameters",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {args.seed}, Size: {args.width}x{args.height}, Model: {args.hf_model_id}",
f"{args.prompts[0]}\nNegative prompt: {args.negative_prompts[0]}\nSteps:{args.steps}, Sampler: {args.scheduler}, CFG scale: {args.guidance_scale}, Seed: {img_seed}, Size: {args.width}x{args.height}, Model: {img_model}",
)
output_img.save(out_img_path, "PNG", pnginfo=pngInfo)
@@ -93,11 +98,11 @@ def save_output_img(output_img):
)
new_entry = {
"VARIANT": args.hf_model_id,
"VARIANT": img_model,
"SCHEDULER": args.scheduler,
"PROMPT": args.prompts[0],
"NEG_PROMPT": args.negative_prompts[0],
"SEED": args.seed,
"SEED": img_seed,
"CFG_SCALE": args.guidance_scale,
"PRECISION": args.precision,
"STEPS": args.steps,
@@ -133,11 +138,11 @@ def txt2img_inf(
steps: int,
guidance_scale: float,
seed: int,
batch_count: int,
batch_size: int,
scheduler: str,
model_id: str,
custom_model_id: str,
ckpt_loc: str,
custom_model: str,
hf_model_id: str,
precision: str,
device: str,
max_length: int,
@@ -151,13 +156,31 @@ def txt2img_inf(
args.prompts = [prompt]
args.negative_prompts = [negative_prompt]
args.guidance_scale = guidance_scale
args.seed = seed
args.steps = steps
args.scheduler = scheduler
args.hf_model_id = custom_model_id if custom_model_id else model_id
args.ckpt_loc = "" if ckpt_loc == "None" else ckpt_loc
# set ckpt_loc and hf_model_id.
types = (
".ckpt",
".safetensors",
) # the tuple of file types
args.ckpt_loc = ""
args.hf_model_id = ""
if custom_model == "None":
if not hf_model_id:
return (
None,
"Please provide either custom model or huggingface model ID, both must not be empty",
)
args.hf_model_id = hf_model_id
elif ".ckpt" in custom_model or ".safetensors" in custom_model:
args.ckpt_loc = custom_model
else:
args.hf_model_id = custom_model
args.save_metadata_to_json = save_metadata_to_json
args.write_metadata_to_png = save_metadata_to_png
dtype = torch.float32 if precision == "fp32" else torch.half
cpu_scheduling = not scheduler.startswith("Shark")
new_config_obj = Config(
@@ -181,6 +204,11 @@ def txt2img_inf(
args.use_tuned = True
args.import_mlir = False
set_init_device_flags()
model_id = (
args.hf_model_id
if args.hf_model_id
else "stabilityai/stable-diffusion-2-1-base"
)
schedulers = get_schedulers(model_id)
scheduler_obj = schedulers[scheduler]
txt2img_obj = Text2ImagePipeline.from_pretrained(
@@ -194,6 +222,7 @@ def txt2img_inf(
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
if not txt2img_obj:
@@ -203,30 +232,38 @@ def txt2img_inf(
start_time = time.time()
txt2img_obj.log = ""
generated_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
generated_imgs = []
seeds = []
img_seed = utils.sanitize_seed(seed)
for i in range(batch_count):
if i > 0:
img_seed = utils.sanitize_seed(-1)
out_imgs = txt2img_obj.generate_images(
prompt,
negative_prompt,
batch_size,
height,
width,
steps,
guidance_scale,
img_seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
save_output_img(out_imgs[0], img_seed)
generated_imgs.extend(out_imgs)
seeds.append(img_seed)
txt2img_obj.log += "\n"
total_time = time.time() - start_time
save_output_img(generated_imgs[0])
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seeds}"
text_output += f"\nsize={args.height}x{args.width}, batch-count={batch_count}, batch-size={args.batch_size}, max_length={args.max_length}"
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
@@ -239,6 +276,7 @@ if __name__ == "__main__":
set_init_device_flags()
schedulers = get_schedulers(args.hf_model_id)
scheduler_obj = schedulers[args.scheduler]
seed = args.seed
txt2img_obj = Text2ImagePipeline.from_pretrained(
scheduler_obj,
@@ -251,34 +289,43 @@ if __name__ == "__main__":
args.height,
args.width,
args.use_base_vae,
args.use_tuned,
)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
args.seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={args.seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
for run in range(args.runs):
if run > 0:
seed = -1
seed = utils.sanitize_seed(seed)
save_output_img(generated_imgs[0])
print(text_output)
start_time = time.time()
generated_imgs = txt2img_obj.generate_images(
args.prompts,
args.negative_prompts,
args.batch_size,
args.height,
args.width,
args.steps,
args.guidance_scale,
seed,
args.max_length,
dtype,
args.use_base_vae,
cpu_scheduling,
)
total_time = time.time() - start_time
text_output = f"prompt={args.prompts}"
text_output += f"\nnegative prompt={args.negative_prompts}"
text_output += (
f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}"
)
text_output += f"\nscheduler={args.scheduler}, device={args.device}"
text_output += f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}"
text_output += (
f", batch size={args.batch_size}, max_length={args.max_length}"
)
# TODO: if using --runs=x txt2img_obj.log will output on each display every iteration infos from the start
text_output += txt2img_obj.log
text_output += f"\nTotal image generation time: {total_time:.4f}sec"
save_output_img(generated_imgs[0], seed)
print(text_output)

View File

@@ -30,6 +30,7 @@ datas += [
( 'src/utils/resources/model_db.json', 'resources' ),
( 'src/utils/resources/opt_flags.json', 'resources' ),
( 'src/utils/resources/base_model.json', 'resources' ),
( 'web/css/*', 'css' ),
( 'web/logos/*', 'logos' )
]

View File

@@ -4,13 +4,16 @@ from collections import defaultdict
import torch
import traceback
import re
import os, sys, functools, operator
import sys
from apps.stable_diffusion.src.utils import (
compile_through_fx,
get_opt_flags,
base_models,
args,
get_vmfb_path_name,
fetch_or_delete_vmfbs,
preprocessCKPT,
get_path_to_diffusers_checkpoint,
fetch_and_update_base_model_id,
)
@@ -76,6 +79,12 @@ class SharkifyStableDiffusionModel:
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.custom_weights = custom_weights
if custom_weights != "":
assert custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
custom_weights = get_path_to_diffusers_checkpoint(custom_weights)
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
@@ -91,6 +100,8 @@ class SharkifyStableDiffusionModel:
+ precision
)
self.use_tuned = use_tuned
if use_tuned:
self.model_name = self.model_name + "_tuned"
# We need a better naming convention for the .vmfbs because despite
# using the custom model variant the .vmfb names remain the same and
# it'll always pick up the compiled .vmfb instead of compiling the
@@ -98,8 +109,6 @@ class SharkifyStableDiffusionModel:
# So, currently, we add `self.model_id` in the `self.model_name` of
# .vmfb file.
# TODO: Have a better way of naming the vmfbs using self.model_name.
import re
model_name = re.sub(r"\W+", "_", self.model_id)
if model_name[0] == "_":
model_name = model_name[1:]
@@ -208,39 +217,72 @@ class SharkifyStableDiffusionModel:
)
return shark_clip
# Compiles Clip, Unet and Vae with `base_model_id` as defining their input
# configiration.
def compile_all(self, base_model_id):
self.inputs = get_input_info(
base_models[base_model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
return compiled_clip, compiled_unet, compiled_vae
def __call__(self):
model_name = ["clip", "base_vae" if self.base_vae else "vae", "unet"]
vmfb_path = [
get_vmfb_path_name(model + self.model_name)[0]
for model in model_name
]
# Step 1:
# -- Fetch all vmfbs for the model, if present, else delete the lot.
vmfbs = fetch_or_delete_vmfbs(
self.model_name, self.base_vae, self.precision
)
if vmfbs[0]:
# -- If all vmfbs are indeed present, we also try and fetch the base
# model configuration for running SD with custom checkpoints.
if self.custom_weights != "":
args.hf_model_id = fetch_and_update_base_model_id(self.custom_weights)
if args.hf_model_id == "":
sys.exit("Base model configuration for the custom model is missing. Use `--clear_all` and re-run.")
print("Loaded vmfbs from cache and successfully fetched base model configuration.")
return vmfbs
# Step 2:
# -- If vmfbs weren't found, we try to see if the base model configuration
# for the required SD run is known to us and bypass the retry mechanism.
model_to_run = ""
if self.custom_weights != "":
model_to_run = self.custom_weights
assert self.custom_weights.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
preprocessCKPT(self.custom_weights)
else:
model_to_run = args.hf_model_id
base_model_fetched = fetch_and_update_base_model_id(model_to_run)
if base_model_fetched != "":
print("Compiling all the models with the fetched base model configuration.")
if args.ckpt_loc != "":
args.hf_model_id = base_model_fetched
return self.compile_all(base_model_fetched)
# Step 3:
# -- This is the retry mechanism where the base model's configuration is not
# known to us and figure that out by trial and error.
print("Inferring base model configuration.")
for model_id in base_models:
self.inputs = get_input_info(
base_models[model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
try:
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
compiled_clip = self.get_clip()
compiled_clip, compiled_unet, compiled_vae = self.compile_all(model_id)
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(
operator.__and__, vmfb_present
)
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
print("Retrying with a different base model configuration")
continue
# -- Once a successful compilation has taken place we'd want to store
# the base model's configuration inferred.
fetch_and_update_base_model_id(model_to_run, model_id)
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
@@ -249,5 +291,5 @@ class SharkifyStableDiffusionModel:
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
"Cannot compile the model. Please re-run the command with `--enable_stack_trace` flag and create an issue with detailed log at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -1,6 +1,11 @@
import sys
from transformers import CLIPTokenizer
from apps.stable_diffusion.src.utils import models_db, args, get_shark_model
from apps.stable_diffusion.src.utils import (
models_db,
args,
get_shark_model,
get_opt_flags,
)
hf_model_variant_map = {
@@ -8,7 +13,7 @@ hf_model_variant_map = {
"dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v2_1base"],
"prompthero/openjourney": ["openjourney", "v2_1base"],
"wavymulder/Analog-Diffusion": ["analogdiffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1"],
"stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"],
"stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"],
"CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"],
}
@@ -19,47 +24,14 @@ def get_variant_version(hf_model_id):
def get_params(bucket_key, model_key, model, is_tuned, precision):
iree_flags = []
if len(args.iree_vulkan_target_triple) > 0:
iree_flags.append(
f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}"
)
# Disable bindings fusion to work with moltenVK.
if sys.platform == "darwin":
iree_flags.append("-iree-stream-fuse-binding=false")
try:
bucket = models_db[0][bucket_key]
model_name = models_db[1][model_key]
iree_flags += models_db[2][model][is_tuned][precision][
"default_compilation_flags"
]
except KeyError:
raise Exception(
f"{bucket_key}/{model_key} is not present in the models database"
)
if (
"specified_compilation_flags"
in models_db[2][model][is_tuned][precision]
):
device = (
args.device
if "://" not in args.device
else args.device.split("://")[0]
)
if (
device
not in models_db[2][model][is_tuned][precision][
"specified_compilation_flags"
]
):
device = "default_device"
iree_flags += models_db[2][model][is_tuned][precision][
"specified_compilation_flags"
][device]
iree_flags = get_opt_flags(model, precision="fp16")
return bucket, model_name, iree_flags

View File

@@ -89,6 +89,7 @@ class Text2ImagePipeline(StableDiffusionPipeline):
neg_prompts = neg_prompts * batch_size
# seed generator to create the inital latent noise. Also handle out of range seeds.
# TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly?
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:

View File

@@ -24,7 +24,6 @@ from apps.stable_diffusion.src.models import (
from apps.stable_diffusion.src.utils import (
start_profiling,
end_profiling,
preprocessCKPT,
)
@@ -184,13 +183,11 @@ class StableDiffusionPipeline:
height: int,
width: int,
use_base_vae: bool,
use_tuned: bool,
):
if import_mlir:
if ckpt_loc != "":
assert ckpt_loc.lower().endswith(
(".ckpt", ".safetensors")
), "checkpoint files supported can be any of [.ckpt, .safetensors] type"
ckpt_loc = preprocessCKPT()
# TODO: Delet this when on-the-fly tuning of models work.
use_tuned = False
mlir_import = SharkifyStableDiffusionModel(
model_id,
ckpt_loc,
@@ -200,6 +197,7 @@ class StableDiffusionPipeline:
height=height,
width=width,
use_base_vae=use_base_vae,
use_tuned=use_tuned,
)
clip, unet, vae = mlir_import()
return cls(vae, clip, get_tokenizer(), unet, scheduler)

View File

@@ -12,7 +12,6 @@ from apps.stable_diffusion.src.utils.resources import (
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.utils import (
get_vmfb_path_name,
get_shark_model,
compile_through_fx,
set_iree_runtime_flags,
@@ -21,4 +20,8 @@ from apps.stable_diffusion.src.utils.utils import (
get_available_devices,
get_opt_flags,
preprocessCKPT,
fetch_or_delete_vmfbs,
fetch_and_update_base_model_id,
get_path_to_diffusers_checkpoint,
sanitize_seed,
)

View File

@@ -1,6 +1,6 @@
[
{
"stablediffusion/untuned":"gs://shark_tank/latest",
"stablediffusion/untuned":"gs://shark_tank/sd_untuned",
"stablediffusion/tuned":"gs://shark_tank/sd_tuned",
"stablediffusion/tuned/cuda":"gs://shark_tank/sd_tuned/cuda",
"anythingv3/untuned":"gs://shark_tank/sd_anythingv3",
@@ -33,14 +33,14 @@
"stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned":"vae2base_19dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/cuda":"vae2base_19dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/vae/fp16/length_77/untuned/base":"vae2base_8dec_fp16",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base":"vae2base_8dec_fp16_tuned",
"stablediffusion/v2_1base/vae/fp16/length_77/tuned/base/cuda":"vae2base_8dec_fp16_cuda_tuned",
"stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip64_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"stablediffusion/v2_1/vae/fp16/length_77/untuned/base":"vae2_8dec_fp16",
"stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip77_512_512_fp16_stabilityai_stable_diffusion_2_1_base",
"anythingv3/v2_1base/unet/fp16/length_77/untuned":"av3_unet_19dec_fp16",
"anythingv3/v2_1base/unet/fp16/length_77/tuned":"av3_unet_19dec_fp16_tuned",
@@ -78,100 +78,5 @@
"dreamlike/v2_1base/vae/fp32/length_77/untuned":"dl_vae_23dec_fp32",
"dreamlike/v2_1base/vae/fp32/length_77/untuned/base":"dl_vaebase_23dec_fp32",
"dreamlike/v2_1base/clip/fp32/length_77/untuned":"dl_clip_23dec_fp32_77"
},
{
"unet": {
"tuned": {
"fp16": {
"default_compilation_flags": []
},
"fp32": {
"default_compilation_flags": []
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32"
],
"specified_compilation_flags": {
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
}
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"vae": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
]
}
}
},
"clip": {
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
]
}
}
}
}
]

View File

@@ -11,19 +11,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32"
],
"specified_compilation_flags": {
"cuda": ["--iree-flow-enable-conv-nchw-to-nhwc-transform"],
"default_device": ["--iree-flow-enable-conv-img2col-transform"]
}
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
@@ -34,9 +27,9 @@
"default_compilation_flags": [],
"specified_compilation_flags": {
"cuda": [],
"default_device": ["--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"]
"default_device": [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
},
"fp32": {
@@ -44,9 +37,7 @@
"specified_compilation_flags": {
"cuda": [],
"default_device": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))"
]
}
}
@@ -54,16 +45,12 @@
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-detach-elementwise-from-named-ops,iree-flow-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}
@@ -72,28 +59,24 @@
"tuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
},
"untuned": {
"fp16": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
},
"fp32": {
"default_compilation_flags": [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops"
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
}
}

View File

@@ -1,4 +1,5 @@
import os
import io
from shark.model_annotation import model_annotation, create_context
from shark.iree_utils._common import iree_target_map, run_cmd
from shark.shark_downloader import (
@@ -26,7 +27,7 @@ def load_model_from_tank():
get_variant_version,
)
version, variant = get_variant_version(args.hf_model_id)
variant, version = get_variant_version(args.hf_model_id)
shark_args.local_tank_cache = args.local_tank_cache
bucket_key = f"{variant}/untuned"
@@ -62,7 +63,7 @@ def load_winograd_configs():
def load_lower_configs():
from apps.stable_diffusion.src.models import get_variant_version
version, variant = get_variant_version(args.hf_model_id)
variant, version = get_variant_version(args.hf_model_id)
config_bucket = "gs://shark_tank/sd_tuned/configs/"
config_version = version
@@ -97,22 +98,38 @@ def annotate_with_winograd(input_mlir, winograd_config_dir, model_name):
search_op="conv",
winograd=True,
)
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return winograd_model, out_file_path
bytecode_stream = io.BytesIO()
winograd_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
with open(out_file_path, "w") as f:
f.write(str(winograd_model))
f.close()
return bytecode, out_file_path
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
def dump_after_mlir(input_mlir, model_name, use_winograd):
if use_winograd:
dump_after = "iree-linalg-ext-convert-conv2d-to-winograd"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32},"
"iree-linalg-ext-convert-conv2d-to-winograd))' "
)
else:
dump_after = "iree-flow-pad-linalg-ops"
dump_after = "iree-preprocessing-pad-linalg-ops"
preprocess_flag = (
"--iree-preprocessing-pass-pipeline='builtin.module"
"(func.func(iree-flow-detach-elementwise-from-named-ops,"
"iree-flow-convert-1x1-filter-conv2d-to-matmul,"
"iree-preprocessing-convert-conv2d-to-img2col,"
"iree-preprocessing-pad-linalg-ops{pad-size=32}))' "
)
# Dump IR after padding/img2col/winograd passes
device_spec_args = ""
device = get_device()
if device == "cuda":
@@ -132,16 +149,22 @@ def annotate_with_lower_configs(
"--iree-input-type=tm_tensor "
f"--iree-hal-target-backends={iree_target_map(device)} "
f"{device_spec_args}"
f"{preprocess_flag}"
"--iree-stream-resource-index-bits=64 "
"--iree-vm-target-index-bits=64 "
"--iree-flow-enable-padding-linalg-ops "
"--iree-flow-linalg-ops-padding-size=32 "
"--iree-flow-enable-conv-img2col-transform "
f"--mlir-print-ir-after={dump_after} "
"--compile-to=flow "
f"2>{args.annotation_output}/dump_after_winograd.mlir "
)
# For Unet annotate the model with tuned lowering configs
def annotate_with_lower_configs(
input_mlir, lowering_config_dir, model_name, use_winograd
):
# Dump IR after padding/img2col/winograd passes
dump_after_mlir(input_mlir, model_name, use_winograd)
# Annotate the model with lowering configs in the config file
with create_context() as ctx:
tuned_model = model_annotation(
@@ -159,10 +182,15 @@ def annotate_with_lower_configs(
)
else:
out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir"
bytecode_stream = io.BytesIO()
tuned_model.operation.write_bytecode(bytecode_stream)
bytecode = bytecode_stream.getvalue()
with open(out_file_path, "w") as f:
f.write(str(tuned_model))
f.close()
return tuned_model, out_file_path
return bytecode, out_file_path
def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
@@ -198,7 +226,7 @@ def sd_model_annotation(mlir_model, model_name, model_from_tank=False):
mlir_model, lowering_config_dir, model_name, use_winograd
)
print(f"Saved the annotated mlir in {output_path}.")
return tuned_model, output_path
return tuned_model
if __name__ == "__main__":

View File

@@ -295,6 +295,14 @@ p.add_argument(
help="flag for removing the pregress bar animation during image generation",
)
p.add_argument(
"--ckpt_dir",
type=str,
default="",
help="Path to directory where all .ckpts are stored in order to populate them in the web UI",
)
p.add_argument(
"--share",
default=False,

View File

@@ -1,6 +1,9 @@
import os
import gc
import json
from pathlib import Path
import numpy as np
from random import randint
from shark.shark_inference import SharkInference
from shark.shark_importer import import_with_fx
from shark.iree_utils.vulkan_utils import (
@@ -11,7 +14,7 @@ from shark.iree_utils.gpu_utils import get_cuda_sm_cc
from apps.stable_diffusion.src.utils.stable_args import args
from apps.stable_diffusion.src.utils.resources import opt_flags
from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation
import sys
import sys, functools, operator
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
load_pipeline_from_original_stable_diffusion_ckpt,
)
@@ -54,11 +57,13 @@ def _compile_module(shark_module, model_name, extra_args=[]):
# Downloads the model from shark_tank and returns the shark_module.
def get_shark_model(tank_url, model_name, extra_args=[]):
from shark.shark_downloader import download_model
from shark.parser import shark_args
# Set local shark_tank cache directory.
shark_args.local_tank_cache = args.local_tank_cache
from shark.shark_downloader import download_model
if "cuda" in args.device:
shark_args.enable_tf32 = True
@@ -93,33 +98,26 @@ def compile_through_fx(
)
if use_tuned:
model_name = model_name + "_tuned"
tuned_model_path = f"{args.annotation_output}/{model_name}_torch.mlir"
if not os.path.exists(tuned_model_path):
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
tuned_model, tuned_model_path = sd_model_annotation(
mlir_module, model_name
)
del mlir_module, tuned_model
gc.collect()
with open(tuned_model_path, "rb") as f:
mlir_module = f.read()
f.close()
if "vae" in model_name.split("_")[0]:
args.annotation_model = "vae"
mlir_module = sd_model_annotation(mlir_module, model_name)
shark_module = SharkInference(
mlir_module,
device=args.device,
mlir_dialect="linalg",
)
del mlir_module
gc.collect()
return _compile_module(shark_module, model_name, extra_args)
def set_iree_runtime_flags():
vulkan_runtime_flags = [
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
f"--device_allocator=caching",
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
]
if args.enable_rgp:
@@ -235,8 +233,8 @@ def set_init_device_flags():
# Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices.
if (
args.hf_model_id
in ["prompthero/openjourney", "dreamlike-art/dreamlike-diffusion-1.0"]
args.hf_model_id == "prompthero/openjourney"
or args.ckpt_loc != ""
or args.precision != "fp16"
or args.height != 512
or args.width != 512
@@ -251,12 +249,7 @@ def set_init_device_flags():
):
args.use_tuned = False
elif "cuda" in args.device and get_cuda_sm_cc() not in [
"sm_80",
"sm_84",
"sm_86",
"sm_89",
]:
elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80"]:
args.use_tuned = False
elif args.use_base_vae and args.hf_model_id not in [
@@ -362,34 +355,107 @@ def get_opt_flags(model, precision="fp16"):
return iree_flags
def preprocessCKPT():
path = Path(args.ckpt_loc)
def get_path_to_diffusers_checkpoint(custom_weights):
path = Path(custom_weights)
diffusers_path = path.parent.absolute()
diffusers_directory_name = path.stem
complete_path_to_diffusers = diffusers_path / diffusers_directory_name
complete_path_to_diffusers.mkdir(parents=True, exist_ok=True)
print(
"Created directory : ",
diffusers_directory_name,
" at -> ",
diffusers_path,
)
path_to_diffusers = complete_path_to_diffusers.as_posix()
return path_to_diffusers
def preprocessCKPT(custom_weights):
path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights)
if next(Path(path_to_diffusers).iterdir(), None):
print("Checkpoint already loaded at : ", path_to_diffusers)
return
else:
print(
"Diffusers' checkpoint will be identified here : ",
path_to_diffusers,
)
from_safetensors = (
True if args.ckpt_loc.lower().endswith(".safetensors") else False
True if custom_weights.lower().endswith(".safetensors") else False
)
# EMA weights usually yield higher quality images for inference but non-EMA weights have
# been yielding better results in our case.
# TODO: Add an option `--ema` (`--no-ema`) for users to specify if they want to go for EMA
# weight extraction or not.
extract_ema = False
print("Loading pipeline from original stable diffusion checkpoint")
print(
"Loading diffusers' pipeline from original stable diffusion checkpoint"
)
pipe = load_pipeline_from_original_stable_diffusion_ckpt(
checkpoint_path=args.ckpt_loc,
checkpoint_path=custom_weights,
extract_ema=extract_ema,
from_safetensors=from_safetensors,
)
pipe.save_pretrained(path_to_diffusers)
print("Loading complete")
print("Custom model path is : ", path_to_diffusers)
return path_to_diffusers
def load_vmfb(vmfb_path, model, precision):
model = "vae" if "base_vae" in model else model
precision = "fp32" if "clip" in model else precision
extra_args = get_opt_flags(model, precision)
shark_module = SharkInference(mlir_module=None, device=args.device)
shark_module.load_module(vmfb_path, extra_args=extra_args)
return shark_module
# This utility returns vmfbs of Clip, Unet and Vae, in case all three of them
# are present; deletes them otherwise.
def fetch_or_delete_vmfbs(basic_model_name, use_base_vae, precision="fp32"):
model_name = ["clip", "unet", "base_vae" if use_base_vae else "vae"]
vmfb_path = [
get_vmfb_path_name(model + basic_model_name)[0] for model in model_name
]
vmfb_present = [os.path.isfile(vmfb) for vmfb in vmfb_path]
all_vmfb_present = functools.reduce(operator.__and__, vmfb_present)
compiled_models = [None] * 3
# We need to delete vmfbs only if some of the models were compiled.
if not all_vmfb_present:
for i in range(len(vmfb_path)):
if vmfb_present[i]:
os.remove(vmfb_path[i])
print("Deleted: ", vmfb_path[i])
else:
for i in range(len(vmfb_path)):
compiled_models[i] = load_vmfb(
vmfb_path[i], model_name[i], precision
)
return compiled_models
# `fetch_and_update_base_model_id` is a resource utility function which
# helps maintaining mapping of the model to run with its base model.
# If `base_model` is "", then this function tries to fetch the base model
# info for the `model_to_run`.
def fetch_and_update_base_model_id(model_to_run, base_model=""):
variants_path = os.path.join(os.getcwd(), "variants.json")
data = {model_to_run: base_model}
json_data = {}
if os.path.exists(variants_path):
with open(variants_path, "r", encoding="utf-8") as jsonFile:
json_data = json.load(jsonFile)
# Return with base_model's info if base_model is "".
if base_model == "":
if model_to_run in json_data:
base_model = json_data[model_to_run]
return base_model
elif base_model == "":
return base_model
# Update JSON data to contain an entry mapping model_to_run with base_model.
json_data.update(data)
with open(variants_path, "w", encoding="utf-8") as jsonFile:
json.dump(json_data, jsonFile)
# Generate and return a new seed if the provided one is not in the supported range (including -1)
def sanitize_seed(seed):
uint32_info = np.iinfo(np.uint32)
uint32_min, uint32_max = uint32_info.min, uint32_info.max
if seed < uint32_min or seed >= uint32_max:
seed = randint(uint32_min, uint32_max)
return seed

View File

@@ -1,6 +1,6 @@
# Stable Diffusion optimized for AMD RDNA2/RDNA3 GPUs
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all StableDiffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! If you still can't get it to work, we're sorry, and please be assured that we (Nod and AMD) are working hard to improve the user experience in coming months.
Before you start, please be aware that this is beta software that relies on a special AMD driver. Like all StableDiffusion GUIs published so far, you need some technical expertise to set it up. We apologize in advance if you bump into issues. If that happens, please don't hesitate to ask our Discord community for help! Please be assured that we (Nod and AMD) are working hard to improve the user experience in coming months.
If it works well for you, please "star" the following GitHub projects... this is one of the best ways to help and spread the word!
* https://github.com/nod-ai/SHARK
@@ -23,10 +23,10 @@ KNOWN ISSUES with this special AMD driver:
## Installation
Download the latest Windows SHARK SD binary [469 here](https://github.com/nod-ai/SHARK/releases/download/20230124.469/shark_sd_20230124_469.exe) in a folder of your choice. If you want nighly builds, you can look for them on the GitHub releases page.
Download the latest Windows SHARK SD binary [492 here](https://github.com/nod-ai/SHARK/releases/download/20230203.492/shark_sd_20230203_492.exe) in a folder of your choice. If you want nighly builds, you can look for them on the GitHub releases page.
Notes:
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use `--clean_all` flag once to clean all the old files.
* We recommend that you download this EXE in a new folder, whenever you download a new EXE version. If you download it in the same folder as a previous install, you must delete the old `*.vmfb` files. Those contain Vulkan dispatches compiled from MLIR which can be outdated if you run a new EXE from the same folder. You can use `--clear_all` flag once to clean all the old files.
* If you recently updated the driver or this binary (EXE file), we recommend you:
* clear all the local artifacts with `--clear_all` OR
* clear the Vulkan shader cache: For Windows users this can be done by clearing the contents of `C:\Users\%username%\AppData\Local\AMD\VkCache\`. On Linux the same cache is typically located at `~/.cache/AMD/VkCache/`.
@@ -56,84 +56,6 @@ Here are some samples generated:
![a photo of a crab playing a trumpet](https://user-images.githubusercontent.com/74956/204933258-252e7240-8548-45f7-8253-97647d38313d.jpg)
<details>
<summary>Advanced Installation </summary>
## Setup your Python Virtual Environment and Dependencies
<details>
<summary> Windows 10/11 Users </summary>
* Install the latest Python 3.10.x version from [here](https://www.python.org/downloads/windows/)
* Install Git for Windows from [here](https://git-scm.com/download/win)
#### Allow the install script to run in Powershell
```powershell
set-executionpolicy remotesigned
```
#### Setup venv and install necessary packages (torch-mlir, nodLabs/Shark, ...)
```powershell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
./setup_venv.ps1 #You can re-run this script to get the latest version
```
</details>
<details>
<summary>Linux</summary>
```shell
git clone https://github.com/nod-ai/SHARK.git
cd SHARK
./setup_venv.sh
source shark.venv/bin/activate
```
</details>
### Run Stable Diffusion on your device - WebUI
<details>
<summary>Windows 10/11 Users</summary>
```powershell
(shark.venv) PS C:\g\shark> cd .\apps\stable_diffusion\web\
(shark.venv) PS C:\g\shark\apps\stable_diffusion\web> python .\index.py
```
</details>
<details>
<summary>Linux Users</summary>
```shell
(shark.venv) > cd apps/stable_diffusion/web
(shark.venv) > python index.py
```
</details>
### Run Stable Diffusion on your device - Commandline
<details>
<summary>Windows 10/11 Users</summary>
```powershell
(shark.venv) PS C:\g\shark> python .\apps\stable_diffusion\scripts\txt2img.py --precision="fp16" --prompt="tajmahal, snow, sunflowers, oil on canvas" --device="vulkan"
```
</details>
<details>
<summary>Linux</summary>
```shell
python3.10 apps/stable_diffusion/scripts/txt2img.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd"
```
</details>
The output on a 7900XTX would like:
```shell
@@ -145,10 +67,4 @@ VAE Inference time (ms): 78.590
Total image generation time: 2.5788655281066895sec
```
For more options to the Stable Diffusion model read [this](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md)
</details>
<details>
<summary>Discord link</summary>
Find us on [SHARK Discord server](https://discord.gg/RUqY2h2s9u) if you have any trouble with running it on your hardware.
</details>

View File

@@ -1,5 +1,153 @@
/* Overwrite the Gradio default theme with their .dark theme declarations */
:root {
--color-focus-primary: var(--color-grey-700);
--color-focus-secondary: var(--color-grey-600);
--color-focus-ring: rgb(55 65 81);
--color-background-primary: var(--color-grey-950);
--color-background-secondary: var(--color-grey-900);
--color-background-tertiary: var(--color-grey-800);
--color-text-body: var(--color-grey-100);
--color-text-label: var(--color-grey-200);
--color-text-placeholder: var(--color-grey);
--color-text-subdued: var(--color-grey-400);
--color-text-link-base: var(--color-blue-500);
--color-text-link-hover: var(--color-blue-400);
--color-text-link-visited: var(--color-blue-600);
--color-text-link-active: var(--color-blue-500);
--color-text-code-background: var(--color-grey-800);
--color-text-code-border: color.border-primary;
--color-border-primary: var(--color-grey-700);
--color-border-secondary: var(--color-grey-600);
--color-border-highlight: var(--color-accent-base);
--color-accent-base: var(--color-orange-500);
--color-accent-light: var(--color-orange-300);
--color-accent-dark: var(--color-orange-700);
--color-functional-error-base: var(--color-red-400);
--color-functional-error-subdued: var(--color-red-300);
--color-functional-error-background: var(--color-background-primary);
--color-functional-info-base: var(--color-yellow);
--color-functional-info-subdued: var(--color-yellow-300);
--color-functional-success-base: var(--color-green);
--color-functional-success-subdued: var(--color-green-300);
--shadow-spread: 2px;
--api-background: linear-gradient(to bottom, rgba(255, 216, 180, .05), transparent);
--api-pill-background: var(--color-orange-400);
--api-pill-border: var(--color-orange-600);
--api-pill-text: var(--color-orange-900);
--block-border-color: var(--color-border-primary);
--block-background: var(--color-background-tertiary);
--uploadable-border-color-hover: var(--color-border-primary);
--uploadable-border-color-loaded: var(--color-functional-success);
--uploadable-text-color: var(--color-text-subdued);
--block_label-border-color: var(--color-border-primary);
--block_label-icon-color: var(--color-text-label);
--block_label-shadow: var(--shadow-drop);
--block_label-background: var(--color-background-secondary);
--icon_button-icon-color-base: var(--color-text-label);
--icon_button-icon-color-hover: var(--color-text-label);
--icon_button-background-base: var(--color-background-primary);
--icon_button-background-hover: var(--color-background-primary);
--icon_button-border-color-base: var(--color-background-primary);
--icon_button-border-color-hover: var(--color-border-secondary);
--input-text-color: var(--color-text-body);
--input-border-color-base: var(--color-border-primary);
--input-border-color-hover: var(--color-border-primary);
--input-border-color-focus: var(--color-border-primary);
--input-background-base: var(--color-background-tertiary);
--input-background-hover: var(--color-background-tertiary);
--input-background-focus: var(--color-background-tertiary);
--input-shadow: var(--shadow-inset);
--checkbox-border-color-base: var(--color-border-primary);
--checkbox-border-color-hover: var(--color-focus-primary);
--checkbox-border-color-focus: var(--color-blue-500);
--checkbox-background-base: var(--color-background-primary);
--checkbox-background-hover: var(--color-background-primary);
--checkbox-background-focus: var(--color-background-primary);
--checkbox-background-selected: var(--color-blue-600);
--checkbox-label-border-color-base: var(--color-border-primary);
--checkbox-label-border-color-hover: var(--color-border-primary);
--checkbox-label-border-color-focus: var(--color-border-secondary);
--checkbox-label-background-base: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-hover: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--checkbox-label-background-focus: linear-gradient(to top, var(--color-grey-900), var(--color-grey-800));
--form-seperator-color: var(--color-border-primary);
--button-primary-border-color-base: var(--color-orange-600);
--button-primary-border-color-hover: var(--color-orange-600);
--button-primary-border-color-focus: var(--color-orange-600);
--button-primary-text-color-base: white;
--button-primary-text-color-hover: white;
--button-primary-text-color-focus: white;
--button-primary-background-base: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-700));
--button-primary-background-hover: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-primary-background-focus: linear-gradient(to bottom right, var(--color-orange-700), var(--color-orange-500));
--button-secondary-border-color-base: var(--color-grey-600);
--button-secondary-border-color-hover: var(--color-grey-600);
--button-secondary-border-color-focus: var(--color-grey-600);
--button-secondary-text-color-base: white;
--button-secondary-text-color-hover: white;
--button-secondary-text-color-focus: white;
--button-secondary-background-base: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-700));
--button-secondary-background-hover: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-secondary-background-focus: linear-gradient(to bottom right, var(--color-grey-600), var(--color-grey-600));
--button-cancel-border-color-base: var(--color-red-600);
--button-cancel-border-color-hover: var(--color-red-600);
--button-cancel-border-color-focus: var(--color-red-600);
--button-cancel-text-color-base: white;
--button-cancel-text-color-hover: white;
--button-cancel-text-color-focus: white;
--button-cancel-background-base: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-700));
--button-cancel-background-focus: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-cancel-background-hover: linear-gradient(to bottom right, var(--color-red-700), var(--color-red-500));
--button-plain-border-color-base: var(--color-grey-600);
--button-plain-border-color-hover: var(--color-grey-500);
--button-plain-border-color-focus: var(--color-grey-500);
--button-plain-text-color-base: var(--color-text-body);
--button-plain-text-color-hover: var(--color-text-body);
--button-plain-text-color-focus: var(--color-text-body);
--button-plain-background-base: var(--color-grey-700);
--button-plain-background-hover: var(--color-grey-700);
--button-plain-background-focus: var(--color-grey-700);
--gallery-label-background-base: var(--color-grey-50);
--gallery-label-background-hover: var(--color-grey-50);
--gallery-label-border-color-base: var(--color-border-primary);
--gallery-label-border-color-hover: var(--color-border-primary);
--gallery-thumb-background-base: var(--color-grey-900);
--gallery-thumb-background-hover: var(--color-grey-900);
--gallery-thumb-border-color-base: var(--color-border-primary);
--gallery-thumb-border-color-hover: var(--color-accent-base);
--gallery-thumb-border-color-focus: var(--color-blue-500);
--gallery-thumb-border-color-selected: var(--color-accent-base);
--chatbot-border-border-color-base: transparent;
--chatbot-border-border-color-latest: transparent;
--chatbot-user-background-base: ;
--chatbot-user-background-latest: ;
--chatbot-user-text-color-base: white;
--chatbot-user-text-color-latest: white;
--chatbot-bot-background-base: ;
--chatbot-bot-background-latest: ;
--chatbot-bot-text-color-base: white;
--chatbot-bot-text-color-latest: white;
--label-gradient-from: var(--color-orange-400);
--label-gradient-to: var(--color-orange-600);
--table-odd-background: var(--color-grey-900);
--table-even-background: var(--color-grey-950);
--table-background-edit: transparent;
--dataset-gallery-background-base: var(--color-background-primary);
--dataset-gallery-background-hover: var(--color-grey-800);
--dataset-dataframe-border-base: var(--color-border-primary);
--dataset-dataframe-border-hover: var(--color-border-secondary);
--dataset-table-background-base: transparent;
--dataset-table-background-hover: var(--color-grey-700);
--dataset-table-border-base: var(--color-grey-800);
--dataset-table-border-hover: var(--color-grey-800);
}
/* SHARK theme customization */
.gradio-container {
background-color: black
background-color: var(--color-background-primary);
}
.container {
@@ -18,12 +166,12 @@
}
#demo_title {
background-color: black;
background-color: var(--color-background-primary);
border-radius: 0 !important;
border: 0;
padding-top: 50px;
padding-top: 15px;
padding-bottom: 0px;
width: 460px !important;
width: 350px !important;
}
#demo_title_outer {
@@ -35,25 +183,19 @@
}
#prompt_box textarea {
background-color: #1d1d1d !important
background-color: var(--color-background-primary) !important;
}
#prompt_examples {
margin: 0 !important
margin: 0 !important;
}
#prompt_examples svg {
display: none !important;
}
.gr-sample-textbox {
border-radius: 1rem !important;
border-color: rgb(31, 41, 55) !important;
border-width: 2px !important;
}
#ui_body {
background-color: #111111 !important;
background-color: var(--color-background-secondary) !important;
padding: 10px !important;
border-radius: 0.5em !important;
}

View File

@@ -58,25 +58,12 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
with gr.Row():
with gr.Column(scale=1, min_width=600):
with gr.Row():
model_id = gr.Dropdown(
label="Model ID",
value="stabilityai/stable-diffusion-2-1-base",
choices=[
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
],
ckpt_path = (
Path(args.ckpt_dir)
if args.ckpt_dir
else Path(Path.cwd(), "models")
)
custom_model_id = gr.Textbox(
placeholder="SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
)
with gr.Group():
ckpt_path = "models"
ckpt_path.mkdir(parents=True, exist_ok=True)
types = (
"*.ckpt",
"*.safetensors",
@@ -85,10 +72,23 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
for extn in types:
files = glob.glob(os.path.join(ckpt_path, extn))
ckpt_files.extend(files)
ckpt_loc = gr.Dropdown(
label="Place all checkpoints in models/",
custom_model = gr.Dropdown(
label=f"Models (Custom Model path: {ckpt_path})",
value="None",
choices=ckpt_files,
choices=ckpt_files
+ [
"Linaqruf/anything-v3.0",
"prompthero/openjourney",
"wavymulder/Analog-Diffusion",
"stabilityai/stable-diffusion-2-1",
"stabilityai/stable-diffusion-2-1-base",
"CompVis/stable-diffusion-v1-4",
],
)
hf_model_id = gr.Textbox(
placeholder="Select 'None' in the Models dropdown on the left and enter model ID here e.g: SG161222/Realistic_Vision_V1.3",
value="",
label="HuggingFace Model ID",
)
with gr.Group(elem_id="prompt_box_outer"):
@@ -104,7 +104,7 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
lines=1,
elem_id="prompt_box",
)
with gr.Accordion(label="Advance Options", open=False):
with gr.Accordion(label="Advanced Options", open=False):
with gr.Row():
scheduler = gr.Dropdown(
label="Scheduler",
@@ -119,9 +119,17 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
"SharkEulerDiscrete",
],
)
batch_size = gr.Slider(
1, 4, value=1, step=1, label="Number of Images"
)
with gr.Group():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=True,
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
interactive=True,
)
with gr.Row():
height = gr.Slider(
384, 786, value=512, step=8, label="Height"
@@ -159,14 +167,20 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
label="CFG Scale",
)
with gr.Row():
save_metadata_to_png = gr.Checkbox(
label="Save prompt information to PNG",
value=True,
batch_count = gr.Slider(
1,
10,
value=1,
step=1,
label="Batch Count",
interactive=True,
)
save_metadata_to_json = gr.Checkbox(
label="Save prompt information to JSON file",
value=False,
batch_size = gr.Slider(
1,
4,
value=1,
step=1,
label="Batch Size",
interactive=True,
)
with gr.Row():
@@ -213,55 +227,33 @@ with gr.Blocks(title="Stable Diffusion", css=demo_css) as shark_web:
value=output_dir,
interactive=False,
)
kwargs = dict(
fn=txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_count,
batch_size,
scheduler,
custom_model,
hf_model_id,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
prompt.submit(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
model_id,
custom_model_id,
ckpt_loc,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
stable_diffusion.click(
txt2img_inf,
inputs=[
prompt,
negative_prompt,
height,
width,
steps,
guidance_scale,
seed,
batch_size,
scheduler,
model_id,
custom_model_id,
ckpt_loc,
precision,
device,
max_length,
save_metadata_to_json,
save_metadata_to_png,
],
outputs=[gallery, std_output],
show_progress=args.progress_bar,
)
prompt.submit(**kwargs)
stable_diffusion.click(**kwargs)
shark_web.queue()
shark_web.launch(

View File

@@ -29,7 +29,7 @@ def compare_images(new_filename, golden_filename):
golden = np.array(Image.open(golden_filename)) / 255.0
diff = np.abs(new - golden)
mean = np.mean(diff)
if mean > 0.01:
if mean > 0.1:
subprocess.run(
["gsutil", "cp", new_filename, "gs://shark_tank/testdata/builder/"]
)

View File

@@ -2,4 +2,4 @@
IMPORTER=1 BENCHMARK=1 ./setup_venv.sh
source $GITHUB_WORKSPACE/shark.venv/bin/activate
python generate_sharktank.py --upload=False --ci_tank_dir=True
python generate_sharktank.py

View File

@@ -1,7 +0,0 @@
rm -rf ./test_images
mkdir test_images
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned
python shark/examples/shark_inference/stable_diffusion/main.py --device=vulkan --output_dir=./test_images --no-load_vmfb --no-use_tuned --beta_models=True
python build_tools/image_comparison.py -n ./test_images/*.png
exit $?

View File

@@ -23,8 +23,7 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
os.mkdir("./test_images")
os.mkdir("./test_images/golden")
hf_model_names = model_config_dicts[0].values()
tuned_options = ["--no-use_tuned"] #'use_tuned']
devices = ["vulkan"]
tuned_options = ["--no-use_tuned", "use_tuned"]
if beta:
extra_flags.append("--beta_models=True")
for model_name in hf_model_names:
@@ -33,15 +32,19 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"python",
"apps/stable_diffusion/scripts/txt2img.py",
"--device=" + device,
"--output_dir=./test_images/" + model_name,
"--prompt=cyberpunk forest by Salvador Dali",
"--output_dir="
+ os.path.join(os.getcwd(), "test_images", model_name),
"--hf_model_id=" + model_name,
use_tune,
]
command += extra_flags
generated_image = not subprocess.call(
command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL
command, stdout=subprocess.DEVNULL
)
if generated_image:
print(" ".join(command))
print("Successfully generated image")
os.makedirs(
"./test_images/golden/" + model_name, exist_ok=True
)
@@ -49,18 +52,16 @@ def test_loop(device="vulkan", beta=False, extra_flags=[]):
"gs://shark_tank/testdata/golden/" + model_name,
"./test_images/golden/" + model_name,
)
comparison = [
"python",
"build_tools/image_comparison.py",
"--golden_url=gs://shark_tank/testdata/golden/"
+ model_name
+ "/*.png",
"--newfile=./test_images/" + model_name + "/*.png",
]
test_file = glob("./test_images/" + model_name + "/*.png")[0]
test_file_path = os.path.join(
os.getcwd(), "test_images", model_name, "generated_imgs"
)
test_file = glob(test_file_path + "/*.png")[0]
golden_path = "./test_images/golden/" + model_name + "/*.png"
golden_file = glob(golden_path)[0]
compare_images(test_file, golden_file)
else:
print(" ".join(command))
print("failed to generate image for this configuration")
parser = argparse.ArgumentParser()

View File

@@ -2,18 +2,16 @@
"""SHARK Tank"""
# python generate_sharktank.py, you have to give a csv tile with [model_name, model_download_url]
# will generate local shark tank folder like this:
# HOME
# /.local
# /shark_tank
# /albert_lite_base
# /...model_name...
# /SHARK
# /gen_shark_tank
# /albert_lite_base
# /...model_name...
#
import os
import csv
import argparse
from shark.shark_importer import SharkImporter
from shark.parser import shark_args
import subprocess as sp
import hashlib
import numpy as np
@@ -267,16 +265,17 @@ if __name__ == "__main__":
# old_args = parser.parse_args()
home = str(Path.home())
if args.ci_tank_dir == True:
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
else:
WORKDIR = os.path.join(home, ".local/shark_tank/")
WORKDIR = os.path.join(os.path.dirname(__file__), "gen_shark_tank")
torch_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "torch_model_list.csv"
)
tf_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tf_model_list.csv"
)
tflite_model_csv = os.path.join(
os.path.dirname(__file__), "tank", "tflite", "tflite_model_list.csv"
)
if args.torch_model_csv:
save_torch_model(args.torch_model_csv)
if args.tf_model_csv:
save_tf_model(args.tf_model_csv)
if args.tflite_model_csv:
save_tflite_model(args.tflite_model_csv)
save_torch_model(torch_model_csv)
save_tf_model(tf_model_csv)
save_tflite_model(tflite_model_csv)

View File

@@ -1,6 +1,6 @@
import torchdynamo
import torch
import torch_mlir
import torch._dynamo as torchdynamo
from shark.sharkdynamo.utils import make_shark_compiler

View File

@@ -13,20 +13,15 @@ if BATCH_SIZE != 1:
unet_flag = [
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=32",
"--iree-flow-enable-conv-img2col-transform",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))"
]
vae_flag = [
"--iree-flow-enable-conv-nchw-to-nhwc-transform",
"--iree-flow-enable-padding-linalg-ops",
"--iree-flow-linalg-ops-padding-size=16",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc,iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
clip_flag = [
"--iree-flow-linalg-ops-padding-size=16",
"--iree-flow-enable-padding-linalg-ops",
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))"
]
bucket = "gs://shark_tank/stable_diffusion/"

View File

@@ -73,17 +73,17 @@ def build_benchmark_args(
path, "..", "..", "iree-benchmark-module"
)
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
fn_name = "forward"
if training == True:
# TODO: Replace name of train with actual train fn name.
fn_name = "train"
benchmark_cl.append(f"--entry_function={fn_name}")
benchmark_cl.append(f"--function={fn_name}")
benchmark_cl.append(f"--device={iree_device_map(device)}")
mlir_input_types = tensor_to_type_str(input_tensors, mlir_dialect)
for mlir_input in mlir_input_types:
benchmark_cl.append(f"--function_input={mlir_input}")
benchmark_cl.append(f"--input={mlir_input}")
if device == "cpu":
num_cpus = get_cpu_count()
if num_cpus is not None:
@@ -114,13 +114,13 @@ def build_benchmark_args_non_tensor_input(
benchmarker_path = os.path.join(
path, "..", "..", "iree-benchmark-module"
)
benchmark_cl = [benchmarker_path, f"--module_file={input_file}"]
benchmark_cl = [benchmarker_path, f"--module={input_file}"]
# TODO: The function named can be passed as one of the args.
if function_name:
benchmark_cl.append(f"--entry_function={function_name}")
benchmark_cl.append(f"--function={function_name}")
benchmark_cl.append(f"--device={iree_device_map(device)}")
for input in inputs:
benchmark_cl.append(f"--function_input={input}")
benchmark_cl.append(f"--input={input}")
if platform.system() != "Windows":
time_extractor = "| awk 'END{{print $2 $3}}'"
benchmark_cl.append(time_extractor)

View File

@@ -70,6 +70,7 @@ def get_iree_common_args():
return [
"--iree-stream-resource-index-bits=64",
"--iree-vm-target-index-bits=64",
"--iree-vm-bytecode-module-strip-source-map=true",
"--iree-util-zero-fill-elided-attrs",
]
@@ -80,11 +81,17 @@ def get_iree_common_args():
def get_model_specific_args():
ms_args = []
if shark_args.enable_conv_transform == True:
ms_args += ["--iree-flow-enable-conv-nchw-to-nhwc-transform"]
ms_args += [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-flow-convert-conv-nchw-to-nhwc))"
]
if shark_args.enable_img2col_transform == True:
ms_args += ["--iree-flow-enable-conv-img2col-transform"]
ms_args += [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-convert-conv2d-to-img2col))"
]
if shark_args.use_winograd == True:
ms_args += ["--iree-flow-enable-conv-winograd-transform"]
ms_args += [
"--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-linalg-ext-convert-conv2d-to-winograd))"
]
return ms_args

View File

@@ -22,7 +22,7 @@ from shark.parser import shark_args
# Get the default gpu args given the architecture.
def get_iree_gpu_args():
ireert.flags.FUNCTION_INPUT_VALIDATION = False
ireert.flags.parse_flags("--cuda_allow_inline_execution")
ireert.flags.parse_flags("--cuda_allow_inline_execution", "--device_allocator=caching")
# TODO: Give the user_interface to pass the sm_arch.
sm_arch = get_cuda_sm_cc()
if (

View File

@@ -139,9 +139,8 @@ def get_vulkan_triple_flag(device_name="", extra_args=[]):
def get_iree_vulkan_args(extra_args=[]):
# vulkan_flag = ["--iree-flow-demote-i64-to-i32"]
res_vulkan_flag = ["--device_allocator=caching"]
res_vulkan_flag = []
vulkan_triple_flag = None
for arg in extra_args:
if "-iree-vulkan-target-triple=" in arg:

View File

@@ -15,24 +15,6 @@
import argparse
import os
def dir_path(path):
if os.path.isdir(path):
return path
else:
os.mkdir(path)
return path
def dir_file(path):
if os.path.isfile(path):
return path
else:
raise argparse.ArgumentTypeError(
f"readable_file:{path} is not a valid file"
)
parser = argparse.ArgumentParser(description="SHARK runner.")
parser.add_argument(
"--device",
@@ -40,12 +22,6 @@ parser.add_argument(
default="cpu",
help="Device on which shark_runner runs. options are cpu, cuda, and vulkan",
)
parser.add_argument(
"--repro_dir",
help="Directory to which module files will be saved for reproduction or debugging.",
type=dir_path,
default="shark_tmp",
)
parser.add_argument(
"--enable_tf32",
type=bool,
@@ -83,10 +59,16 @@ parser.add_argument(
)
parser.add_argument(
"--update_tank",
default=False,
default=True,
action="store_true",
help="When enabled, SHARK downloader will update local shark_tank if local hash is different from latest upstream hash.",
)
parser.add_argument(
"--force_update_tank",
default=False,
action="store_true",
help="When enabled, SHARK downloader will force an update of local shark_tank artifacts for each request.",
)
parser.add_argument(
"--local_tank_cache",
default=None,

View File

@@ -82,7 +82,7 @@ class SharkBenchmarkRunner(SharkRunner):
self.vmfb_file = export_iree_module_to_vmfb(
mlir_module,
device,
shark_args.repro_dir,
".",
self.mlir_dialect,
extra_args=self.extra_args,
)

View File

@@ -79,23 +79,21 @@ input_type_to_np_dtype = {
# Save the model in the home local so it needn't be fetched everytime in the CI.
home = str(Path.home())
alt_path = os.path.join(os.path.dirname(__file__), "../gen_shark_tank/")
custom_path_list = None
if shark_args.local_tank_cache is not None:
custom_path_list = shark_args.local_tank_cache.split("/")
custom_path = shark_args.local_tank_cache
if os.path.exists(alt_path):
WORKDIR = alt_path
print(
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
)
if custom_path_list:
custom_path = os.path.join(*custom_path_list)
if custom_path is not None:
if not os.path.exists(custom_path):
os.mkdir(custom_path)
WORKDIR = custom_path
print(f"Using {WORKDIR} as local shark_tank cache directory.")
elif os.path.exists(alt_path):
WORKDIR = alt_path
print(
f"Using {WORKDIR} as shark_tank directory. Delete this directory if you aren't working from locally generated shark_tank."
)
else:
WORKDIR = os.path.join(home, ".local/shark_tank/")
print(
@@ -148,15 +146,14 @@ def download_model(
model_dir = os.path.join(WORKDIR, model_dir_name)
full_gs_url = tank_url.rstrip("/") + "/" + model_dir_name
if shark_args.update_tank == True:
print(f"Updating artifacts for model {model_name}...")
download_public_file(full_gs_url, model_dir)
elif not check_dir_exists(
if not check_dir_exists(
model_dir_name, frontend=frontend, dynamic=dyn_str
):
print(f"Downloading artifacts for model {model_name}...")
download_public_file(full_gs_url, model_dir)
elif shark_args.force_update_tank == True:
print(f"Force-updating artifacts for model {model_name}...")
download_public_file(full_gs_url, model_dir)
else:
if not _internet_connected():
print(
@@ -178,7 +175,11 @@ def download_model(
)
except FileNotFoundError:
upstream_hash = None
if local_hash != upstream_hash:
if local_hash != upstream_hash and shark_args.update_tank == True:
print(f"Updating artifacts for model {model_name}...")
download_public_file(full_gs_url, model_dir)
elif local_hash != upstream_hash:
print(
"Hash does not match upstream in gs://shark_tank/latest. If you want to use locally generated artifacts, this is working as intended. Otherwise, run with --update_tank."
)

View File

@@ -81,7 +81,7 @@ class SharkImporter:
self.return_str,
)
def _tf_mlir(self, func_name, save_dir="./shark_tmp/"):
def _tf_mlir(self, func_name, save_dir="."):
from iree.compiler import tf as tfc
return tfc.compile_module(
@@ -91,7 +91,7 @@ class SharkImporter:
output_file=save_dir,
)
def _tflite_mlir(self, func_name, save_dir="./shark_tmp/"):
def _tflite_mlir(self, func_name, save_dir="."):
from iree.compiler import tflite as tflitec
self.mlir_model = tflitec.compile_file(

View File

@@ -3,7 +3,7 @@ import time
from typing import List, Optional
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch._src.compile_utils import strip_overloads
from torch._functorch.compile_utils import strip_overloads
from shark.shark_inference import SharkInference
from torch._decomp import get_decompositions
@@ -119,14 +119,19 @@ def make_shark_compiler(use_tracing: bool, device: str, verbose=False):
example_inputs,
output_type=torch_mlir.OutputType.LINALG_ON_TENSORS,
)
import io
bytecode_stream = io.BytesIO()
linalg_module.operation.write_bytecode(bytecode_stream)
mlir_module = bytecode_stream.getvalue()
shark_module = SharkInference(
linalg_module, "forward", mlir_dialect="linalg", device=device
mlir_module, mlir_dialect="linalg", device=device
)
shark_module.compile()
def forward(*inputs):
result = shark_module.forward(inputs)
result = shark_module("forward", inputs)
result = tuple() if result is None else result
return (result,) if was_unwrapped else result

View File

@@ -65,7 +65,7 @@ def get_torch_mlir_module(
if jit_trace:
ignore_traced_shapes = True
tempfile.tempdir = shark_args.repro_dir
tempfile.tempdir = "."
mlir_module = torch_mlir.compile(
module,

View File

@@ -136,7 +136,7 @@ class SharkModuleTester:
def create_and_check_module(self, dynamic, device):
shark_args.local_tank_cache = self.local_tank_cache
shark_args.update_tank = self.update_tank
shark_args.force_update_tank = self.update_tank
if "nhcw-nhwc" in self.config["flags"] and not os.path.isfile(
".use-iree"
):
@@ -212,12 +212,11 @@ class SharkModuleTester:
)
def save_reproducers(self):
# Saves contents of IREE TempFileSaver temporary directory to ./shark_tmp/saved/<test_case>.
src = os.path.join(*self.temp_dir.split("/"))
saves = os.path.join(".", "shark_tmp", "saved")
trg = os.path.join(saves, self.tmp_prefix)
if not os.path.isdir(saves):
os.mkdir(saves)
# Saves contents of IREE TempFileSaver temporary directory to ./{temp_dir}/saved/<test_case>.
src = self.temp_dir
trg = os.path.join("reproducers", self.tmp_prefix)
if not os.path.isdir("reproducers"):
os.mkdir("reproducers")
if not os.path.isdir(trg):
os.mkdir(trg)
files = os.listdir(src)
@@ -227,10 +226,7 @@ class SharkModuleTester:
def upload_repro(self):
import subprocess
src = os.path.join(*self.temp_dir.split("/"))
repro_path = os.path.join(
".", "shark_tmp", "saved", self.tmp_prefix, "*"
)
repro_path = os.path.join("reproducers", self.tmp_prefix, "*")
bashCommand = f"gsutil cp -r {repro_path} gs://shark-public/builder/repro_artifacts/{self.ci_sha}/{self.tmp_prefix}/"
process = subprocess.run(bashCommand.split())
@@ -329,11 +325,8 @@ class SharkModuleTest(unittest.TestCase):
)
self.module_tester.tmp_prefix = safe_name.replace("/", "_")
if not os.path.isdir("shark_tmp"):
os.mkdir("shark_tmp")
tempdir = tempfile.TemporaryDirectory(
prefix=self.module_tester.tmp_prefix, dir="shark_tmp"
prefix=self.module_tester.tmp_prefix, dir="."
)
self.module_tester.temp_dir = tempdir.name