[SD-CLI] Make using ckpt_loc and hf_model_id easier

-- Currently we require users to specify the base model on which the custom
   model (.ckpt) is tuned. Even for running a HuggingFace repo-id, we
   require users to go through the tedious process of adding entries to variants.json.

-- This commit aims to address the above issues and will be treated as a
   starting point for a series of design changes that make using SHARK's SD
   easier.

Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
Abhishek Varma
2023-01-24 14:20:30 +00:00
committed by Abhishek Varma
parent cb78cd8ac0
commit 6ed02f70ec
5 changed files with 48 additions and 60 deletions

View File

@@ -17,20 +17,21 @@ use the flag `--hf_model_id=` to specify the repo-id of the model to be used.
python .\shark\examples\shark_inference\stable_diffusion\main.py --hf_model_id="Linaqruf/anything-v3.0" --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
```
## Run a custom model using a `.ckpt` file:
## Run a custom model using a HuggingFace `.ckpt` file:
* Install the following by running :-
```shell
pip install omegaconf safetensors pytorch_lightning
```
* Download a [.ckpt](https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned-fp32.ckpt) file in case you don't have a locally generated `.ckpt` file for StableDiffusion.
* Now pass the above `.ckpt` file to `ckpt_loc` command-line argument using the following (note the `hf_model_id` flag which states the base model from which the `.ckpt` model was fine-tuned) :-
* Now pass the above `.ckpt` file to `ckpt_loc` command-line argument using the following :-
```shell
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file" --hf_model_id="CompVis/stable-diffusion-v1-4"
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file"
```
* We use a combination of 3 flags to make this feature work : `import_mlir`, `ckpt_loc` and `hf_model_id`, of which `import_mlir` needs to be present. In case `ckpt_loc` is not specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, you need to specify which base model's `.ckpt` you are using via `hf_model_id`.
* We use a combination of 2 flags to make this feature work : `import_mlir` and `ckpt_loc`.
* In case `ckpt_loc` is NOT specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, you can use `import_mlir` and `hf_model_id` to run HuggingFace's StableDiffusion variants.
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images. And in case you want to use any variants from HuggingFace then add the mapping of the variant to their base model in [variants.json](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/resources/variants.json).
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images.

View File

@@ -1,7 +1,7 @@
from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPTextModel
from utils import compile_through_fx, get_opt_flags
from resources import base_models, variants
from resources import base_models
from collections import defaultdict
import torch
import sys
@@ -48,27 +48,9 @@ def get_input_info(model_info, max_len, width, height, batch_size):
else:
sys.exit("shape isn't specified correctly.")
input_map[k].append(tensor)
return input_map
# Returns the model configuration in a dict containing input parameters
# for clip, unet and vae respectively.
def get_model_configuration(model_id, max_len, width, height, batch_size):
if model_id in base_models:
return get_input_info(
base_models[model_id], max_len, width, height, batch_size
)
elif model_id in variants:
return get_input_info(
base_models[variants[model_id]], max_len, width, height, batch_size
)
else:
sys.exit(
"The model info is not configured, please add the model_configuration in base_model.json if it's a base model, else add it in the variant.json"
)
class SharkifyStableDiffusionModel:
def __init__(
self,
@@ -82,13 +64,10 @@ class SharkifyStableDiffusionModel:
use_base_vae: bool = False,
):
self.check_params(max_len, width, height)
self.inputs = get_model_configuration(
model_id,
max_len,
width // 8,
height // 8,
batch_size,
)
self.max_len = max_len
self.height = height // 8
self.width = width // 8
self.batch_size = batch_size
self.model_id = model_id if custom_weights == "" else custom_weights
self.precision = precision
self.base_vae = use_base_vae
@@ -220,7 +199,33 @@ class SharkifyStableDiffusionModel:
return shark_clip
def __call__(self):
compiled_clip = self.get_clip()
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
return compiled_clip, compiled_unet, compiled_vae
from stable_args import args
import traceback
for model_id in base_models:
self.inputs = get_input_info(
base_models[model_id],
self.max_len,
self.width,
self.height,
self.batch_size,
)
try:
compiled_clip = self.get_clip()
compiled_unet = self.get_unet()
compiled_vae = self.get_vae()
except Exception as e:
if args.enable_stack_trace:
traceback.print_exc()
print("Retrying with a different base model configuration")
continue
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
# model and rely on retrying method to find the input configuration, we should also update
# the knowledge of base model id accordingly into `args.hf_model_id`.
if args.ckpt_loc != "":
args.hf_model_id = model_id
return compiled_clip, compiled_unet, compiled_vae
sys.exit(
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
)

View File

@@ -33,10 +33,5 @@ models_db = get_json_file("resources/model_db.json")
# models and also helps in providing information for the variants.
base_models = get_json_file("resources/base_model.json")
# The variant contains the mapping from variant to the base configuration
# to get the required inputs.
# If the input configuration doesn't match it should be registered standalone in the base configuration.
variants = get_json_file("resources/variants.json")
# Contains optimization flags for different models.
opt_flags = get_json_file("resources/opt_flags.json")

View File

@@ -1,20 +0,0 @@
{
"runwayml/stable-diffusion-v1-5": "CompVis/stable-diffusion-v1-4",
"prompthero/openjourney": "CompVis/stable-diffusion-v1-4",
"Linaqruf/anything-v3.0": "CompVis/stable-diffusion-v1-4",
"stabilityai/stable-diffusion-2-1-base": "stabilityai/stable-diffusion-2-1",
"dreamlike-art/dreamlike-diffusion-1.0": "CompVis/stable-diffusion-v1-4",
"eimiss/EimisAnimeDiffusion_1.0v": "CompVis/stable-diffusion-v1-4",
"claudfuen/photorealistic-fuen-v1": "CompVis/stable-diffusion-v1-4",
"nitrosocke/Nitro-Diffusion": "CompVis/stable-diffusion-v1-4",
"stabilityai/stable-diffusion-2-base": "stabilityai/stable-diffusion-2-1",
"wavymulder/Analog-Diffusion": "CompVis/stable-diffusion-v1-4",
"nitrosocke/redshift-diffusion": "CompVis/stable-diffusion-v1-4",
"wavymulder/portraitplus": "CompVis/stable-diffusion-v1-4",
"Linaqruf/anything-v3-better-vae": "CompVis/stable-diffusion-v1-4",
"nitrosocke/Arcane-Diffusion": "CompVis/stable-diffusion-v1-4",
"hakurei/waifu-diffusion": "stabilityai/stable-diffusion-2-1",
"lambdalabs/sd-pokemon-diffusers": "CompVis/stable-diffusion-v1-4",
"prompthero/openjourney-v2": "CompVis/stable-diffusion-v1-4",
"andite/anything-v4.0": "CompVis/stable-diffusion-v1-4"
}

View File

@@ -160,6 +160,13 @@ p.add_argument(
help="The repo-id of hugging face.",
)
p.add_argument(
"--enable_stack_trace",
default=False,
action=argparse.BooleanOptionalAction,
help="Enable showing the stack trace when retrying the base model configuration",
)
##############################################################################
### IREE - Vulkan supported flags
##############################################################################