mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
[SD-CLI] Make using ckpt_loc and hf_model_id easier
-- Currently we require users to specify the base model on which the custom model (.ckpt) was fine-tuned. Even for running a HuggingFace repo-id, we require users to go through the tedious process of adding entries to variants.json. -- This commit aims to address the above issues and will be treated as a starting point for a series of design changes which make using SHARK's SD easier. Signed-off-by: Abhishek Varma <abhishek@nod-labs.com>
This commit is contained in:
committed by
Abhishek Varma
parent
cb78cd8ac0
commit
6ed02f70ec
@@ -17,20 +17,21 @@ use the flag `--hf_model_id=` to specify the repo-id of the model to be used.
|
||||
python .\shark\examples\shark_inference\stable_diffusion\main.py --hf_model_id="Linaqruf/anything-v3.0" --max_length=77 --prompt="1girl, brown hair, green eyes, colorful, autumn, cumulonimbus clouds, lighting, blue sky, falling leaves, garden"
|
||||
```
|
||||
|
||||
## Run a custom model using a `.ckpt` file:
|
||||
## Run a custom model using a HuggingFace `.ckpt` file:
|
||||
* Install the following by running :-
|
||||
```shell
|
||||
pip install omegaconf safetensors pytorch_lightning
|
||||
```
|
||||
* Download a [.ckpt](https://huggingface.co/andite/anything-v4.0/resolve/main/anything-v4.0-pruned-fp32.ckpt) file in case you don't have a locally generated `.ckpt` file for StableDiffusion.
|
||||
|
||||
* Now pass the above `.ckpt` file to the `ckpt_loc` command-line argument using the following (note the `hf_model_id` flag, which states the base model from which the `.ckpt` model was fine-tuned) :-
|
||||
* Now pass the above `.ckpt` file to `ckpt_loc` command-line argument using the following :-
|
||||
```shell
|
||||
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file" --hf_model_id="CompVis/stable-diffusion-v1-4"
|
||||
python3.10 main.py --precision=fp16 --device=vulkan --prompt="tajmahal, oil on canvas, sunflowers, 4k, uhd" --max_length=64 --import_mlir --ckpt_loc="/path/to/.ckpt/file"
|
||||
```
|
||||
* We use a combination of 3 flags to make this feature work : `import_mlir`, `ckpt_loc` and `hf_model_id`, of which `import_mlir` needs to be present. In case `ckpt_loc` is not specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, you need to specify which base model's `.ckpt` you are using via `hf_model_id`.
|
||||
* We use a combination of 2 flags to make this feature work : `import_mlir` and `ckpt_loc`.
|
||||
* In case `ckpt_loc` is NOT specified then a [default](https://huggingface.co/stabilityai/stable-diffusion-2-1-base) HuggingFace repo-id is run via `hf_model_id`. So, you can use `import_mlir` and `hf_model_id` to run HuggingFace's StableDiffusion variants.
|
||||
|
||||
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images. And in case you want to use any variants from HuggingFace then add the mapping of the variant to their base model in [variants.json](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/resources/variants.json).
|
||||
* Use custom model `.ckpt` files from [HuggingFace-StableDiffusion](https://huggingface.co/models?other=stable-diffusion) to generate images.
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import CLIPTextModel
|
||||
from utils import compile_through_fx, get_opt_flags
|
||||
from resources import base_models, variants
|
||||
from resources import base_models
|
||||
from collections import defaultdict
|
||||
import torch
|
||||
import sys
|
||||
@@ -48,27 +48,9 @@ def get_input_info(model_info, max_len, width, height, batch_size):
|
||||
else:
|
||||
sys.exit("shape isn't specified correctly.")
|
||||
input_map[k].append(tensor)
|
||||
|
||||
return input_map
|
||||
|
||||
|
||||
# Returns the model configuration in a dict containing input parameters
|
||||
# for clip, unet and vae respectively.
|
||||
def get_model_configuration(model_id, max_len, width, height, batch_size):
|
||||
if model_id in base_models:
|
||||
return get_input_info(
|
||||
base_models[model_id], max_len, width, height, batch_size
|
||||
)
|
||||
elif model_id in variants:
|
||||
return get_input_info(
|
||||
base_models[variants[model_id]], max_len, width, height, batch_size
|
||||
)
|
||||
else:
|
||||
sys.exit(
|
||||
"The model info is not configured, please add the model_configuration in base_model.json if it's a base model, else add it in the variant.json"
|
||||
)
|
||||
|
||||
|
||||
class SharkifyStableDiffusionModel:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -82,13 +64,10 @@ class SharkifyStableDiffusionModel:
|
||||
use_base_vae: bool = False,
|
||||
):
|
||||
self.check_params(max_len, width, height)
|
||||
self.inputs = get_model_configuration(
|
||||
model_id,
|
||||
max_len,
|
||||
width // 8,
|
||||
height // 8,
|
||||
batch_size,
|
||||
)
|
||||
self.max_len = max_len
|
||||
self.height = height // 8
|
||||
self.width = width // 8
|
||||
self.batch_size = batch_size
|
||||
self.model_id = model_id if custom_weights == "" else custom_weights
|
||||
self.precision = precision
|
||||
self.base_vae = use_base_vae
|
||||
@@ -220,7 +199,33 @@ class SharkifyStableDiffusionModel:
|
||||
return shark_clip
|
||||
|
||||
def __call__(self):
|
||||
compiled_clip = self.get_clip()
|
||||
compiled_unet = self.get_unet()
|
||||
compiled_vae = self.get_vae()
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
from stable_args import args
|
||||
import traceback
|
||||
|
||||
for model_id in base_models:
|
||||
self.inputs = get_input_info(
|
||||
base_models[model_id],
|
||||
self.max_len,
|
||||
self.width,
|
||||
self.height,
|
||||
self.batch_size,
|
||||
)
|
||||
try:
|
||||
compiled_clip = self.get_clip()
|
||||
compiled_unet = self.get_unet()
|
||||
compiled_vae = self.get_vae()
|
||||
except Exception as e:
|
||||
if args.enable_stack_trace:
|
||||
traceback.print_exc()
|
||||
print("Retrying with a different base model configuration")
|
||||
continue
|
||||
# This is done just because in main.py we are basing the choice of tokenizer and scheduler
|
||||
# on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base
|
||||
# model and rely on retrying method to find the input configuration, we should also update
|
||||
# the knowledge of base model id accordingly into `args.hf_model_id`.
|
||||
if args.ckpt_loc != "":
|
||||
args.hf_model_id = model_id
|
||||
return compiled_clip, compiled_unet, compiled_vae
|
||||
sys.exit(
|
||||
"Cannot compile the model. Please use `enable_stack_trace` and create an issue at https://github.com/nod-ai/SHARK/issues"
|
||||
)
|
||||
|
||||
@@ -33,10 +33,5 @@ models_db = get_json_file("resources/model_db.json")
|
||||
# models and also helps in providing information for the variants.
|
||||
base_models = get_json_file("resources/base_model.json")
|
||||
|
||||
# The variant contains the mapping from variant to the base configuration
|
||||
# to get the required inputs.
|
||||
# If the input configuration doesn't match it should be registered standalone in the base configuration.
|
||||
variants = get_json_file("resources/variants.json")
|
||||
|
||||
# Contains optimization flags for different models.
|
||||
opt_flags = get_json_file("resources/opt_flags.json")
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
{
|
||||
"runwayml/stable-diffusion-v1-5": "CompVis/stable-diffusion-v1-4",
|
||||
"prompthero/openjourney": "CompVis/stable-diffusion-v1-4",
|
||||
"Linaqruf/anything-v3.0": "CompVis/stable-diffusion-v1-4",
|
||||
"stabilityai/stable-diffusion-2-1-base": "stabilityai/stable-diffusion-2-1",
|
||||
"dreamlike-art/dreamlike-diffusion-1.0": "CompVis/stable-diffusion-v1-4",
|
||||
"eimiss/EimisAnimeDiffusion_1.0v": "CompVis/stable-diffusion-v1-4",
|
||||
"claudfuen/photorealistic-fuen-v1": "CompVis/stable-diffusion-v1-4",
|
||||
"nitrosocke/Nitro-Diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"stabilityai/stable-diffusion-2-base": "stabilityai/stable-diffusion-2-1",
|
||||
"wavymulder/Analog-Diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"nitrosocke/redshift-diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"wavymulder/portraitplus": "CompVis/stable-diffusion-v1-4",
|
||||
"Linaqruf/anything-v3-better-vae": "CompVis/stable-diffusion-v1-4",
|
||||
"nitrosocke/Arcane-Diffusion": "CompVis/stable-diffusion-v1-4",
|
||||
"hakurei/waifu-diffusion": "stabilityai/stable-diffusion-2-1",
|
||||
"lambdalabs/sd-pokemon-diffusers": "CompVis/stable-diffusion-v1-4",
|
||||
"prompthero/openjourney-v2": "CompVis/stable-diffusion-v1-4",
|
||||
"andite/anything-v4.0": "CompVis/stable-diffusion-v1-4"
|
||||
}
|
||||
@@ -160,6 +160,13 @@ p.add_argument(
|
||||
help="The repo-id of hugging face.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--enable_stack_trace",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enable showing the stack trace when retrying the base model configuration",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
### IREE - Vulkan supported flags
|
||||
##############################################################################
|
||||
|
||||
Reference in New Issue
Block a user