diff --git a/apps/language_models/README.md b/apps/language_models/README.md deleted file mode 100644 index 2dd7c765..00000000 --- a/apps/language_models/README.md +++ /dev/null @@ -1,16 +0,0 @@ -## CodeGen Setup using SHARK-server - -### Setup Server -- clone SHARK and setup the venv -- host the server using `python apps/stable_diffusion/web/index.py --api --server_port=` -- default server address is `http://0.0.0.0:8080` - -### Setup Client -1. fauxpilot-vscode (VSCode Extension): -- Code for the extension can be found [here](https://github.com/Venthe/vscode-fauxpilot) -- PreReq: VSCode extension (will need [`nodejs` and `npm`](https://nodejs.org/en/download) to compile and run the extension) -- Compile and Run the extension on VSCode (press F5 on VSCode), this opens a new VSCode window with the extension running -- Open VSCode settings, search for fauxpilot in settings and modify `server : http://:`, `Model : codegen` , `Max Lines : 30` - -2. Others (REST API curl, OpenAI Python bindings) as shown [here](https://github.com/fauxpilot/fauxpilot/blob/main/documentation/client.md) -- using Github Copilot VSCode extension with SHARK-server needs more work to be functional. \ No newline at end of file diff --git a/apps/language_models/langchain/README.md b/apps/language_models/langchain/README.md deleted file mode 100644 index 02fa0875..00000000 --- a/apps/language_models/langchain/README.md +++ /dev/null @@ -1,18 +0,0 @@ -# Langchain - -## How to run the model - -1.) Install all the dependencies by running: -```shell -pip install -r apps/language_models/langchain/langchain_requirements.txt -sudo apt-get install -y libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice -``` - -2.) Create a folder named `user_path` in `apps/language_models/langchain/` directory. - -Now, you are ready to use the model. - -3.) To run the model, run the following command: -```shell -python apps/language_models/langchain/gen.py --cli=True -``` diff --git a/apps/language_models/langchain/cli.py b/apps/language_models/langchain/cli.py deleted file mode 100644 index 55818cf7..00000000 --- a/apps/language_models/langchain/cli.py +++ /dev/null @@ -1,186 +0,0 @@ -import copy -import torch - -from evaluate_params import eval_func_param_names -from gen import Langchain -from prompter import non_hf_types -from utils import clear_torch_cache, NullContext, get_kwargs - - -def run_cli( # for local function: - base_model=None, - lora_weights=None, - inference_server=None, - debug=None, - chat_context=None, - examples=None, - memory_restriction_level=None, - # for get_model: - score_model=None, - load_8bit=None, - load_4bit=None, - load_half=None, - load_gptq=None, - use_safetensors=None, - infer_devices=None, - tokenizer_base_model=None, - gpu_id=None, - local_files_only=None, - resume_download=None, - use_auth_token=None, - trust_remote_code=None, - offload_folder=None, - compile_model=None, - # for some evaluate args - stream_output=None, - prompt_type=None, - prompt_dict=None, - temperature=None, - top_p=None, - top_k=None, - num_beams=None, - max_new_tokens=None, - min_new_tokens=None, - early_stopping=None, - max_time=None, - repetition_penalty=None, - num_return_sequences=None, - do_sample=None, - chat=None, - langchain_mode=None, - langchain_action=None, - document_choice=None, - top_k_docs=None, - chunk=None, - chunk_size=None, - # for evaluate kwargs - src_lang=None, - tgt_lang=None, - concurrency_count=None, - save_dir=None, - sanitize_bot_response=None, - model_state0=None, - max_max_new_tokens=None, - is_public=None, - max_max_time=None, - raise_generate_gpu_exceptions=None, - load_db_if_exists=None, - dbs=None, - user_path=None, - detect_user_path_changes_every_query=None, - use_openai_embedding=None, - use_openai_model=None, - hf_embedding_model=None, - db_type=None, - n_jobs=None, - first_para=None, - text_limit=None, - verbose=None, - cli=None, - reverse_docs=None, - use_cache=None, - auto_reduce_chunks=None, - max_chunks=None, - model_lock=None, - force_langchain_evaluate=None, - model_state_none=None, - # unique to this function: - cli_loop=None, -): - Langchain.check_locals(**locals()) - - score_model = "" # FIXME: For now, so user doesn't have to pass - n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0 - device = "cpu" if n_gpus == 0 else "cuda" - context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device - - with context_class(device): - from functools import partial - - # get score model - smodel, stokenizer, sdevice = Langchain.get_score_model( - reward_type=True, - **get_kwargs( - Langchain.get_score_model, - exclude_names=["reward_type"], - **locals() - ) - ) - - model, tokenizer, device = Langchain.get_model( - reward_type=False, - **get_kwargs( - Langchain.get_model, exclude_names=["reward_type"], **locals() - ) - ) - model_dict = dict( - base_model=base_model, - tokenizer_base_model=tokenizer_base_model, - lora_weights=lora_weights, - inference_server=inference_server, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - ) - model_state = dict(model=model, tokenizer=tokenizer, device=device) - model_state.update(model_dict) - my_db_state = [None] - fun = partial( - Langchain.evaluate, - model_state, - my_db_state, - **get_kwargs( - Langchain.evaluate, - exclude_names=["model_state", "my_db_state"] - + eval_func_param_names, - **locals() - ) - ) - - example1 = examples[-1] # pick reference example - all_generations = [] - while True: - clear_torch_cache() - instruction = input("\nEnter an instruction: ") - if instruction == "exit": - break - - eval_vars = copy.deepcopy(example1) - eval_vars[eval_func_param_names.index("instruction")] = eval_vars[ - eval_func_param_names.index("instruction_nochat") - ] = instruction - eval_vars[eval_func_param_names.index("iinput")] = eval_vars[ - eval_func_param_names.index("iinput_nochat") - ] = "" # no input yet - eval_vars[ - eval_func_param_names.index("context") - ] = "" # no context yet - - # grab other parameters, like langchain_mode - for k in eval_func_param_names: - if k in locals(): - eval_vars[eval_func_param_names.index(k)] = locals()[k] - - gener = fun(*tuple(eval_vars)) - outr = "" - res_old = "" - for gen_output in gener: - res = gen_output["response"] - extra = gen_output["sources"] - if base_model not in non_hf_types or base_model in ["llama"]: - if not stream_output: - print(res) - else: - # then stream output for gradio that has full output each generation, so need here to show only new chars - diff = res[len(res_old) :] - print(diff, end="", flush=True) - res_old = res - outr = res # don't accumulate - else: - outr += res # just is one thing - if extra: - # show sources at end after model itself had streamed to std rest of response - print(extra, flush=True) - all_generations.append(outr + "\n") - if not cli_loop: - break - return all_generations diff --git a/apps/language_models/langchain/create_data.py b/apps/language_models/langchain/create_data.py deleted file mode 100644 index 787b822e..00000000 --- a/apps/language_models/langchain/create_data.py +++ /dev/null @@ -1,2187 +0,0 @@ -""" -Dataset creation tools. - -Keep to-level imports clean of non-trivial imports for specific tools, -because this file is imported for various purposes -""" - -import ast -import concurrent.futures -import contextlib -import hashlib -import json -import os -import shutil -import signal -import sys -import traceback -from concurrent.futures import ProcessPoolExecutor - -import psutil -import pytest -import pandas as pd -import numpy as np -from tqdm import tqdm - -from utils import flatten_list, remove - - -def parse_rst_file(filepath): - with open(filepath, "r") as f: - input_data = f.read() - settings_overrides = {"initial_header_level": 2} - from docutils import core - - document = core.publish_doctree( - source=input_data, - source_path=filepath, - settings_overrides=settings_overrides, - ) - qa_pairs = [] - current_section = None - current_question = "" - current_answer = "" - for node in document.traverse(): - if node.__class__.__name__ == "section": - current_section = "" - elif current_section is not None: - if node.__class__.__name__ == "Text": - if node.astext()[-1] == "?": - if current_question: - qa_pairs.append((current_question, current_answer)) - current_question = node.astext() - current_answer = "" - else: - current_answer += node.astext() - if current_answer: - qa_pairs.append((current_question, current_answer)) - return {k: v for k, v in qa_pairs} - - -def test_scrape_dai_docs(): - home = os.path.expanduser("~") - file = os.path.join(home, "h2oai/docs/faq.rst") - qa_pairs = parse_rst_file(file) - prompt_type = "human_bot" - from prompter import prompt_types - - assert prompt_type in prompt_types - save_thing = [ - {"instruction": k, "output": v, "prompt_type": prompt_type} - for k, v in qa_pairs.items() - ] - output_file = "dai_faq.json" - with open(output_file, "wt") as f: - f.write(json.dumps(save_thing, indent=2)) - - -def test_scrape_dai_docs_all(): - """ - pytest create_data.py::test_scrape_dai_docs_all - """ - import glob - import nltk - - nltk.download("punkt") - dd = {} - np.random.seed(1234) - home = os.path.expanduser("~") - files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst"))) - np.random.shuffle(files) - val_count = int(0.05 * len(files)) - train_files = files[val_count:] - valid_files = files[:val_count] - things = [ - ("dai_docs.train.json", train_files), - ("dai_docs.valid.json", valid_files), - ] - for LEN in [100, 200, 500]: - for output_file, ff in things: - if output_file not in dd: - dd[output_file] = [] - for f in ff: - with open(f) as input: - blob = input.read() - blob = blob.replace("~~", "") - blob = blob.replace("==", "") - blob = blob.replace("''", "") - blob = blob.replace("--", "") - blob = blob.replace("**", "") - dd[output_file].extend(get_sentences(blob, length=LEN)) - for output_file, _ in things: - save_thing = [ - {"output": k.strip(), "prompt_type": "plain"} - for k in dd[output_file] - ] - with open(output_file, "wt") as f: - f.write(json.dumps(save_thing, indent=2)) - - -def get_sentences(blob, length): - """ - break-up input text into sentences and then output list of sentences of about length in size - :param blob: - :param length: - :return: - """ - import nltk - - nltk.download("punkt") - from nltk.tokenize import sent_tokenize - - sentences = sent_tokenize(blob) - my_sentences = [] - my_string = "" - for sentence in sentences: - if len(my_string) + len(sentence) <= length: - if my_string: - my_string += " " + sentence - else: - my_string = sentence - else: - my_sentences.append(my_string) - my_string = "" - return my_sentences or [my_string] - - -def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False): - """ - Only supported if have access to source code or HF token for HF spaces and from_hf=True - :param path: - :param dst: - :param from_hf: - :return: - """ - - home = os.path.expanduser("~") - - if from_hf: - # assumes - from huggingface_hub import hf_hub_download - - # True for case when locally already logged in with correct token, so don't have to set key - token = os.getenv("HUGGINGFACE_API_TOKEN", True) - path_to_zip_file = hf_hub_download( - "h2oai/dai_docs", "dai_docs.zip", token=token, repo_type="dataset" - ) - path = "h2oai" - import zipfile - - with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: - zip_ref.extractall(path) - path = os.path.join(path, "docs/**/*") - - if path is None: - if os.path.isdir(os.path.join(home, "h2oai")): - path = os.path.join(home, "h2oai/docs/**/*") - else: - assert os.path.isdir(os.path.join(home, "h2oai.superclean")), ( - "%s does not exist" % path - ) - path = os.path.join(home, "h2oai.superclean/docs/**/*") - import glob - - files = list(glob.glob(path, recursive=True)) - - # pandoc can't find include files - - remove(dst) - os.makedirs(dst) - - # copy full tree, for absolute paths in rst - for fil in files: - if os.path.isfile(fil): - shutil.copy(fil, dst) - - # hack for relative path - scorers_dir = os.path.join(dst, "scorers") - makedirs(scorers_dir) - for fil in glob.glob(os.path.join(dst, "*.frag")): - shutil.copy(fil, scorers_dir) - - return dst - - -def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30): - # account for sequence length (context window) including prompt and input and output - - # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst') - import pypandoc - - basedir = os.path.abspath(os.getcwd()) - - outputs = [] - for fil in files: - os.chdir(basedir) - os.chdir(os.path.dirname(fil)) - fil = os.path.basename(fil) - print("Processing %s" % fil, flush=True) - # out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x, - # context, csljson, docbook, docbook4, docbook5, docx, dokuwiki, - # dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml, - # ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira, - # json, latex, man, - # markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict, - # mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx, - # revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki - out_format = "plain" - # avoid extra new lines injected into text - extra_args = ["--wrap=preserve", '--resource path="%s" % dst'] - - plain_list = [] - try: - # valid for expert settings - input_rst = pypandoc.convert_file(fil, "rst") - input_list = input_rst.split("\n``") - for input_subrst in input_list: - input_plain = pypandoc.convert_text( - input_subrst, format="rst", to="plain" - ) - plain_list.append([input_plain, fil]) - except Exception as e: - print("file exception: %s %s" % (fil, str(e)), flush=True) - - if not plain_list: - # if failed to process as pieces of rst, then - output = pypandoc.convert_file( - fil, out_format, extra_args=extra_args, format="rst" - ) - outputs1 = get_sentences(output, length=max_len) - for oi, output in enumerate(outputs1): - output = output.replace("\n\n", "\n") - plain_list.append([output, fil]) - outputs.extend(plain_list) - - # report: - # [print(len(x)) for x in outputs] - - # deal with blocks longer than context size (sequence length) of 2048 - new_outputs = [] - num_truncated = 0 - num_orig = len(outputs) - for output, fil in outputs: - if len(output) < max_len: - new_outputs.append([output, fil]) - continue - outputs1 = get_sentences(output, length=max_len) - for oi, output1 in enumerate(outputs1): - output1 = output1.replace("\n\n", "\n") - new_outputs.append([output1, fil]) - num_truncated += 1 - print( - "num_orig: %s num_truncated: %s" % (num_orig, num_truncated), - flush=True, - ) - - new_outputs = [ - [k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len - ] - - return new_outputs - - -def test_scrape_dai_docs_all_pandoc(): - """ - pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc - :return: - """ - - dst = setup_dai_docs() - - import glob - - files = list(glob.glob(os.path.join(dst, "*rst"), recursive=True)) - - basedir = os.path.abspath(os.getcwd()) - new_outputs = rst_to_outputs(files) - os.chdir(basedir) - - remove(dst) - save_thing = [ - {"output": k.strip(), "prompt_type": "plain"} for k in new_outputs - ] - output_file = "dai_docs.train_cleaned.json" - with open(output_file, "wt") as f: - f.write(json.dumps(save_thing, indent=2)) - - -def test_config_to_json(): - """ - Needs to run from Driverless AI source directory. - E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/ - :return: - """ - try: - # Arrange - import json - from h2oaicore.systemutils import config - - toml_list = [] - for k, v in config.get_meta_dict().items(): - title = (v.title + ": ") if v.title else "" - comment = v.comment or "" - if not (title or comment): - continue - toml_list.extend( - [ - { - "prompt_type": "plain", - "instruction": f": What does {k} do?\n: {k.replace('_', ' ')} config.toml: {comment or title}\n:".replace( - "\n", "" - ), - }, - { - "prompt_type": "plain", - "instruction": f": Explain {k}.\n: {k.replace('_', ' ')} config.toml: {comment or title}\n:".replace( - "\n", "" - ), - }, - { - "prompt_type": "plain", - "instruction": f": How can I do this: {title}.\n: Set the {k.replace('_', ' ')} config.toml\n:".replace( - "\n", "" - ), - } - if title and comment - else None, - { - "prompt_type": "human_bot", - "instruction": f"Explain the following expert setting for Driverless AI", - "input": f"{k}", - "output": f"{k.replace('_', ' ')} config.toml: {comment or title}".replace( - "\n", "" - ), - }, - { - "prompt_type": "human_bot", - "instruction": f"Explain the following expert setting for Driverless AI", - "input": f"{k}", - "output": f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace( - "\n", "" - ), - }, - { - "prompt_type": "human_bot", - "instruction": f"Explain the following expert setting for Driverless AI", - "input": f"{k.replace('_', ' ')}", - "output": f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace( - "\n", "" - ), - }, - { - "prompt_type": "human_bot", - "instruction": f"Explain the following expert setting for Driverless AI", - "input": f"{title}", - "output": f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace( - "\n", "" - ), - }, - { - "prompt_type": "human_bot", - "instruction": f"Provide a short explanation of the expert setting {k}", - "output": f"{k.replace('_', ' ')} config.toml: {comment or title}".replace( - "\n", "" - ), - }, - { - "prompt_type": "human_bot", - "instruction": f"Provide a detailed explanation of the expert setting {k}", - "output": f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace( - "\n", "" - ), - }, - ] - ) - toml_list = [x for x in toml_list if x] - with open("config.json", "wt") as f: - f.write(json.dumps(toml_list, indent=2)) - except Exception as e: - print("Exception: %s" % str(e), flush=True) - - -def copy_tree(src, dst, follow_symlink=False): - makedirs(dst, exist_ok=True) - for path, dirs, files in os.walk(src, followlinks=follow_symlink): - new_path = path.replace(src, dst) - makedirs(new_path, exist_ok=True) - for file in files: - filename = os.path.join(path, file) - new_filename = os.path.join(new_path, file) - # print("%s -> %s" % (filename, new_filename)) - try: - atomic_copy(filename, new_filename) - except FileNotFoundError: - pass - - -def atomic_move(src, dst): - try: - shutil.move(src, dst) - except (shutil.Error, FileExistsError): - pass - remove(src) - - -def atomic_copy(src=None, dst=None, with_permissions=True): - if os.path.isfile(dst): - return - import uuid - - my_uuid = uuid.uuid4() - dst_tmp = dst + str(my_uuid) - makedirs(os.path.dirname(dst), exist_ok=True) - if with_permissions: - shutil.copy(src, dst_tmp) - else: - shutil.copyfile(src, dst_tmp) - atomic_move(dst_tmp, dst) - remove(dst_tmp) - - -def makedirs(path, exist_ok=True): - """ - Avoid some inefficiency in os.makedirs() - :param path: - :param exist_ok: - :return: - """ - if os.path.isdir(path) and os.path.exists(path): - assert exist_ok, "Path already exists" - return path - os.makedirs(path, exist_ok=exist_ok) - - -## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json -## Turn into simple instruct prompt type. No context/previous conversations. -def test_prep_instruct_vicuna(): - from datasets import load_dataset - - filename = "ShareGPT_unfiltered_cleaned_split.json" - if not os.path.exists(filename): - os.system( - "wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s" - % filename - ) - data = load_dataset("json", data_files={"train": filename})["train"] - training_rows = [] - for i in range(data.num_rows): - conversations = data[i]["conversations"] - assert isinstance(conversations, list), conversations - convo = "" - for j, conv in enumerate(conversations): - # Get ready for generate.py prompt_type=human_bot - # But train with prompt_type=plain - if conv["from"] == "human": - FROM = ": " - elif conv["from"] == "gpt": - FROM = ": " - convo += f"{FROM}" + conv["value"] + "\n" - if convo: - training_rows.append(dict(input=convo)) - with open(filename + ".generate_human_bot.train_plain.json", "wt") as f: - f.write(json.dumps(training_rows, indent=2)) - - -POSTFIX = ".generate_human_bot.train_plain.json" - -# https://bair.berkeley.edu/blog/2023/04/03/koala/ -OIG_DATASETS = [ - "unified_chip2.jsonl", - "unified_grade_school_math_instructions.jsonl", - "unified_poetry_2_song.jsonl", - "unified_plot_screenplay_books_dialog.jsonl", -] - -# hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4 -ALL_OIG_DATASETS = [ - "unified_abstract_infill.jsonl", - "unified_basic.jsonl", - "unified_canadian_parliament.jsonl", - "unified_chip2.jsonl", - "unified_conv_finqa.jsonl", - "unified_cuad.jsonl", - "unified_essays.jsonl", - "unified_flan.jsonl.gz", - "unified_grade_school_math_instructions.jsonl", - "unified_hc3_human.jsonl", - "unified_image_prompts_instructions.jsonl", - "unified_joke_explanations.jsonl", - "unified_mathqa_flanv2_kojma_cot.jsonl", - "unified_merged_code_xp3.jsonl", - "unified_multi_news.jsonl", - "unified_multi_sum.jsonl", - "unified_ni.jsonl.gz", - "unified_nq.jsonl", - "unified_openai_summarize_tldr.jsonl", - "unified_oscar_en_sample_dialog.jsonl", - "unified_p3.jsonl.gz", - "unified_plot_screenplay_books_dialog.jsonl", - "unified_poetry_2_song.jsonl", - "unified_poetry_instructions.jsonl", - "unified_rallio_safety_and_prosocial.jsonl", - "unified_rallio_soda_upgraded_2048.jsonl", - "unified_soda_dialog.jsonl", - "unified_sqlv1.jsonl", - "unified_sqlv2.jsonl", - "unified_squad_v2.jsonl", - "unified_squad_v2_more_neg.jsonl", - "unified_ul2_plus_oscar_en_sample_dialog.jsonl", - "unified_unifiedskg_instructions.jsonl", - "unified_unnatural_instructions.jsonl", - "unified_xp3_sample.jsonl", -] - -useful_oig_files = [ - "unified_rallio_safety_and_prosocial.jsonl.parquet", - "unified_chip2.jsonl.parquet", - "unified_cuad.jsonl.parquet", - "unified_essays.jsonl.parquet", - "unified_flan.jsonl.gz.parquet", - "unified_grade_school_math_instructions.jsonl.parquet", - "unified_hc3_human.jsonl.parquet", - "unified_mathqa_flanv2_kojma_cot.jsonl.parquet", - "unified_merged_code_xp3.jsonl.parquet", - "unified_multi_news.jsonl.parquet", - # 'unified_multi_sum.jsonl.parquet' - "unified_ni.jsonl.gz.parquet", - "unified_openai_summarize_tldr.jsonl.parquet", - # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific - "unified_plot_screenplay_books_dialog.jsonl.parquet", - "unified_soda_dialog.jsonl.parquet", - "unified_unnatural_instructions.jsonl.parquet", -] - - -@pytest.mark.parametrize("filename", OIG_DATASETS) -def test_get_small_sample_oig_data(filename): - if not os.path.exists(filename): - os.system( - "wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s" - % filename - ) - import json - - rows = [] - with open(filename, "r") as f: - for line in f.readlines(): - row = json.loads(line) - rows.append(dict(input=row["text"])) - with open(filename + POSTFIX, "w") as f: - f.write(json.dumps(rows, indent=2)) - - -@pytest.mark.parametrize("filename", ALL_OIG_DATASETS) -def test_download_useful_data_as_parquet(filename): - dest_file = filename + ".parquet" - if dest_file not in useful_oig_files: - pytest.skip("file declared not useful") - if not os.path.exists(filename): - os.system( - "wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s" - % filename - ) - if not os.path.exists(dest_file): - df = pd.read_json(path_or_buf=filename, lines=True) - df.to_parquet(dest_file, index=False) - - -def test_merge_shuffle_small_sample_oig_data(): - np.random.seed(1234) - rows = [] - for filename in OIG_DATASETS: - with open(filename + POSTFIX, "r") as f: - rows.extend(json.loads(f.read())) - np.random.shuffle(rows) - with open( - "merged_shuffled_OIG_%s.json" - % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], - "w", - ) as f: - f.write(json.dumps(rows, indent=2)) - - -def test_join_jsons(): - files = ( - ["config.json"] * 1 - + ["dai_docs.train_cleaned.json"] * 2 - + ["dai_faq.json"] * 3 - ) - print(files) - lst = [] - [lst.extend(json.load(open(fil, "rt"))) for fil in files] - print(len(lst)) - json.dump(lst, open("merged.json", "wt"), indent=2) - - -@pytest.mark.parametrize("filename", ["Anthropic/hh-rlhf"]) -def test_make_rlhf_good_data(filename): - from datasets import load_dataset - - rows = load_dataset(filename)["train"]["chosen"] - new_rows = [] - for row in rows: - if row[:2] == "\n\n": - row = row[2:] - row = row.replace("Human: ", ": ") - row = row.replace("Assistant: ", ": ") - new_rows.append(dict(input=row)) - with open(filename.replace("/", "_") + POSTFIX, "w") as f: - f.write(json.dumps(new_rows, indent=2)) - - -def test_show_prompts(): - files = ( - ["config.json"] * 1 - + ["dai_docs.train_cleaned.json"] * 1 - + ["dai_faq.json"] * 1 - ) - file_points = [json.load(open(fil, "rt")) for fil in files] - from prompter import generate_prompt - - for data_points in file_points: - for data_point in data_points: - print( - generate_prompt(data_point, "plain", "", False, False, False)[ - 0 - ] - ) - - -def test_get_open_datasets(): - # HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter - open_tags = [ - "license:Apache License 2.0", - "license:mit", - "license:apache", - "license:apache2", - "license:apache-2.0", - "license:bsd", - "license:bsd-2-clause", - "license:bsd-3-clause", - "license:bsd-3-clause-clear", - "license:lgpl-2.1", - "license:lgpl-3.0", - "license:lgpl-lr", - "license:lgpl", - "license:openrail++", - "license:openrail", - "license:bigscience-bloom-rail-1.0", - # 'license:agpl-3.0', - "license:other", - "license:unknown", - # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution - # Attribution required: - "license:odc-by", - "license:cc-by-4.0", - "license:cc-by-3.0", - "license:cc-by-2.0", - "license:cc-by-2.5", - # 'license:cc-by-sa-4.0', # would require same license - "license:odbl", - "license:pddl", - "license:ms-pl", - "license:zlib", - ] - # bad license: cc-by-nc-4.0 - - from huggingface_hub import list_datasets - - datasets = flatten_list( - [[x for x in list_datasets(filter=y)] for y in open_tags] - ) - datasets += [x for x in list_datasets(author="openai")] - # check all: - all_license_tags = set( - flatten_list([[y for y in x.tags if "license" in y] for x in datasets]) - ) - print(len(all_license_tags)) - open_datasets = [ - x - for x in datasets - if any([y in x.tags for y in open_tags]) - or "license:" not in str(x.tags) - ] - print("open_datasets", len(open_datasets)) - all_task_tags = set( - flatten_list( - [[y for y in x.tags if "task" in y] for x in open_datasets] - ) - ) - print("all_task_tags", len(all_task_tags)) - excluded_tags = [ - "image", - "hate", - "tabular", - "table-", - "classification", - "retrieval", - "translation", - "identification", - "object", - "mask", - "to-text", - "face-detection", - "audio", - "voice", - "reinforcement", - "depth-est", - "forecasting", - "parsing", - "visual", - "speech", - "multiple-choice", - "slot-filling", - "irds/argsme", - "-scoring", - "other", - "graph-ml", - "feature-extraction", - "keyword-spotting", - "coreference-resolution", - "segmentation", - "word-sense-disambiguation", - "lemmatization", - ] - task_tags = [ - x.replace("task_categories:", "").replace("task_ids:", "") - for x in all_task_tags - if not any([y in x for y in excluded_tags]) - ] - print("task_tags", len(task_tags)) - # str(x.tags) to catch any pattern match to anything in list - open_tasked_datasets = [ - x - for x in open_datasets - if any( - [y in str([x for x in x.tags if "task" in x]) for y in task_tags] - ) - and not any( - [ - y in str([x for x in x.tags if "task" in x]) - for y in excluded_tags - ] - ) - or "task_categories" not in str(x.tags) - and "task_ids" not in str(x.tags) - ] - open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled] - open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated] - open_tasked_datasets = [x for x in open_tasked_datasets if not x.private] - print("open_tasked_datasets", len(open_tasked_datasets)) - sizes = list( - set( - flatten_list( - [ - [(y, x.id) for y in x.tags if "size" in y] - for x in open_tasked_datasets - ] - ) - ) - ) - languages = list( - set( - flatten_list( - [ - [(y, x.id) for y in x.tags if "language:" in y] - for x in open_tasked_datasets - ] - ) - ) - ) - open_english_tasked_datasets = [ - x - for x in open_tasked_datasets - if "language:" not in str(x.tags) or "language:en" in str(x.tags) - ] - small_open_english_tasked_datasets = [ - x - for x in open_english_tasked_datasets - if "n<1K" in str(x.tags) - or "1K summarization? - # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas() - ids = [x.id for x in small_open_english_tasked_datasets] - - # sanity checks - # https://bair.berkeley.edu/blog/2023/04/03/koala/ - assert "alespalla/chatbot_instruction_prompts" in ids - assert "laion/OIG" in ids - assert "openai/webgpt_comparisons" in ids - assert "openai/summarize_from_feedback" in ids - assert "Anthropic/hh-rlhf" in ids - - # useful but not allowed for commercial purposes: - # https://huggingface.co/datasets/squad - - print("open_english_tasked_datasets: ", ids, flush=True) - - exclude_ids = [ - "allenai/nllb", # translation only - "hf-internal-testing/fixtures_image_utils", # testing - "allenai/c4", # search-url - "agemagician/uniref50", # unknown - "huggingface-course/documentation-images", # images - "smilegate-ai/kor_unsmile", # korean - "MohamedRashad/ChatGPT-prompts", # ChatGPT/LearnGPT/https://www.emergentmind.com/ - "humarin/chatgpt-paraphrases", # Paraphrase using ChatGPT - "Jeska/vaccinchat", # not useful - "alespalla/chatbot_instruction_prompts", # mixes alpaca - "allenai/prosocial-dialog", - # already exlucded, but wrongly in other datasets that say more permissive license - "AlekseyKorshuk/persona-chat", # low quality - "bavard/personachat_truecased", # low quality - "adamlin/daily_dialog", # medium quality conversations - "adamlin/FewShotWoz", # low quality - "benjaminbeilharz/better_daily_dialog", # low quality - "benjaminbeilharz/daily_dialog_w_turn_templates", # low - "benjaminbeilharz/empathetic_dialogues_for_lm", # low - "GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915", # NA - "ia-bentebib/conv_ai_2_fr", # low fr - "ia-bentebib/daily_dialog_fr", # low fr - "ia-bentebib/dialog_re_fr", # low fr - "ia-bentebib/empathetic_dialogues_fr", # low fr - "roskoN/dailydialog", # low - "VadorMazer/skyrimdialogstest", # low - "bigbio/med_qa", # med specific Q/A - "biu-nlp/qa_srl2018", # low quality Q/A - "biu-nlp/qa_discourse", # low quality Q/A - "iarfmoose/qa_evaluator", # low quality Q/A - "jeopardy", # low quality Q/A -- no reasoning - "narrativeqa", # low quality Q/A - "nomic-ai/gpt4all_prompt_generations", # bad license - "nomic-ai/gpt4all_prompt_generations_with_p3", # bad license - "HuggingFaceH4/alpaca", # bad license - "tatsu-lab/alpaca", # ToS breaking - "yahma/alpaca-cleaned", # ToS breaking - "Hello-SimpleAI/HC3", # bad license - "glue", # no reasoning QA - "sahil2801/CodeAlpaca-20k", # bad license - "Short-Answer-Feedback/saf_communication_networks_english", # long Q, medium A - ] - small_open_english_tasked_datasets = [ - x - for x in small_open_english_tasked_datasets - if x.id not in exclude_ids - ] - # some ids clearly speech related - small_open_english_tasked_datasets = [ - x for x in small_open_english_tasked_datasets if "speech" not in x.id - ] - # HF testing - small_open_english_tasked_datasets = [ - x - for x in small_open_english_tasked_datasets - if "hf-internal-testing" not in x.id - ] - small_open_english_tasked_datasets = [ - x for x in small_open_english_tasked_datasets if "chinese" not in x.id - ] - - sorted_small_open_english_tasked_datasets = sorted( - [(x.downloads, x) for x in small_open_english_tasked_datasets], - key=lambda x: x[0], - reverse=True, - ) - - # NOTES: - # Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log - # See what needs config passed and add: - # grep 'load_dataset(' getdata9.log|grep -v data_id|less -S - # grep "pip install" getdata9.log - # NOTE: Some datasets have default config, but others are there. Don't know how to access them. - - """ - https://huggingface.co/datasets/wikihow/blob/main/wikihow.py - https://github.com/mahnazkoupaee/WikiHow-Dataset - https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358 - https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358 - """ - - """ - # some ambiguous or non-commercial datasets - https://github.com/PhoebusSi/alpaca-CoT - """ - - timeout = 3 * 60 - # laion/OIG takes longer - for num_downloads, dataset in sorted_small_open_english_tasked_datasets: - data_id = dataset.id - func = do_one - args = (data_id, num_downloads) - kwargs = {} - with ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(func, *args, **kwargs) - try: - future.result(timeout=timeout) - except concurrent.futures.TimeoutError: - print("\n\ndata_id %s timeout\n\n" % data_id, flush=True) - for child in psutil.Process(os.getpid()).children(recursive=True): - os.kill(child.pid, signal.SIGINT) - os.kill(child.pid, signal.SIGTERM) - os.kill(child.pid, signal.SIGKILL) - - -def do_one(data_id, num_downloads): - from datasets import load_dataset - - out_file = "data_%s.parquet" % str(data_id.replace("/", "_")) - if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024**3: - return - try: - print( - "Loading data_id %s num_downloads: %s" % (data_id, num_downloads), - flush=True, - ) - avail_list = None - try: - data = load_dataset(data_id, "foobar") - except Exception as e: - if "Available: " in str(e): - avail_list = ast.literal_eval( - str(e).split("Available:")[1].strip() - ) - else: - avail_list = None - if avail_list is None: - avail_list = [None] - print("%s avail_list: %s" % (data_id, avail_list), flush=True) - - for name in avail_list: - out_file = "data_%s_%s.parquet" % ( - str(data_id.replace("/", "_")), - str(name), - ) - if os.path.isfile(out_file): - continue - data = load_dataset(data_id, name) - column_names_dict = data.column_names - column_names = column_names_dict[list(column_names_dict.keys())[0]] - print( - "Processing data_id %s num_downloads: %s columns: %s" - % (data_id, num_downloads, column_names), - flush=True, - ) - data_dict = data.data - col_dict = data.num_columns - first_col = list(col_dict.keys())[0] - if "train" in data_dict: - df = data["train"].to_pandas() - else: - df = data[first_col].to_pandas() - # csv has issues with escaping chars, even for datasets I know I want - df.to_parquet(out_file, index=False) - except Exception as e: - t, v, tb = sys.exc_info() - ex = "".join(traceback.format_exception(t, v, tb)) - print("Exception: %s %s" % (data_id, ex), flush=True) - - -def test_otherlic(): - from huggingface_hub import list_datasets - - lic = [ - "license:odc-by", - "license:cc-by-4.0", - "license:cc-by-3.0", - "license:cc-by-2.0", - "license:cc-by-2.5", - "license:cc-by-sa-4.0", - "license:odbl", - "license:pddl", - "license:ms-pl", - "license:zlib", - ] - datasets = flatten_list( - [ - [ - x - for x in list_datasets(filter=y) - if "translation" not in str(x.tags) - ] - for y in lic - ] - ) - print(len(datasets)) - - -# These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile -# grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog -useful = [ - "Dahoas/instruct-human-assistant-prompt", - "Dahoas/first-instruct-human-assistant-prompt", - "knkarthick/dialogsum", # summary of conversation - "McGill-NLP/FaithDial", # medium quality - "Zaid/quac_expanded", # medium quality context + QA - "0-hero/OIG-small-chip2", # medium - "alistvt/coqa-flat", # QA medium - "AnonymousSub/MedQuAD_47441_Question_Answer_Pairs", # QA medium - "Anthropic/hh-rlhf", # high quality # similar to Dahoas/full-hh-rlhf - "arjunth2001/online_privacy_qna", # good quality QA - "Dahoas/instruct_helpful_preferences", # medium quality instruct - "Dahoas/rl-prompt-dataset", # medium chat - "Dahoas/rm-static", # medium chat - "Dahoas/static-hh", # medium chat # HuggingFaceH4/self_instruct - "Dahoas/synthetic-instruct-gptj-pairwise", # medium chat - "eli5", # QA if prompt ELI5 - "gsm8k", # QA (various) - "guanaco/guanaco", # prompt/response - "kastan/rlhf-qa-comparisons", # good QA - "kastan/rlhf-qa-conditional-generation-v2", # prompt answer - "OllieStanley/humaneval-mbpp-codegen-qa", # code QA, but started from words, so better than other code QA - "OllieStanley/humaneval-mbpp-testgen-qa", # code QA - "Graverman/Instruct-to-Code", # code QA - "openai/summarize_from_feedback", # summarize - "relbert/analogy_questions", # analogy QA - "yitingxie/rlhf-reward-datasets", # prompt, chosen, rejected. - "yizhongw/self_instruct", # instruct (super natural & instruct) - "HuggingFaceH4/asss", # QA, big A - "kastan/rlhf-qa-conditional-generation-v2", # QA - "cosmos_qa", # context QA - "vishal-burman/c4-faqs", # QA but not so much reasoning, but alot of text - "squadshifts", # QA from context - "hotpot_qa", # QA from context - "adversarial_qa", # QA from context - "allenai/soda", # dialog -> narrative/summary - "squad_v2", # context QA - "squadshifts", # context QA - "dferndz/cSQuAD1", # context QA - "dferndz/cSQuAD2", # context QA - "din0s/msmarco-nlgen", # context QA - "domenicrosati/TruthfulQA", # common sense truthful QA -- trivia but good trivia - "hotpot_qa", # context, QA - "HuggingFaceH4/self-instruct-eval", # instruct QA, medium quality, some language reasoning - "kastan/EE_QA_for_RLHF", # context QA - "KK04/LogicInference_OA", # instruction logical QA - "lmqg/qa_squadshifts_synthetic", # context QA - "lmqg/qg_squad", # context QA - "lmqg/qg_squadshifts", # context QA - "lmqg/qg_subjqa", # context QA - "pszemraj/HC3-textgen-qa", - # QA medium, has human responses -- humans tend to provide links instead of trying to answer - "pythonist/newdata", # long context, QA, brief A - "ropes", # long background, situation, question, A - "wikitablequestions", # table -> QA - "bigscience/p3", # context QA but short answers -] - -code_useful = [ - "0n1xus/codexglue", - "openai_humaneval", - "koutch/staqc", -] - -maybe_useful = [ - "AlekseyKorshuk/comedy-scripts", - "openbookqa", # hard to parse, low reasoning - "qed", # reasonable QA, but low reasoning - "selqa", # candidate answers - "HuggingFaceH4/instruction-pilot-outputs-filtered", - "GBaker/MedQA-USMLE-4-options", # medical QA with long questions - "npc-engine/light-batch-summarize-dialogue", # dialog summarize, kinda low specific quality -] - -summary_useful = [ - "austin/rheum_abstracts", - "CarperAI/openai_summarize_comparisons", # summarize chosen/rejected - "CarperAI/openai_summarize_tldr", # summarize QA - "ccdv/cnn_dailymail", # summarize news - "ccdv/govreport-summarization", # summarize high quality - "ccdv/pubmed-summarization", # summarize high quality - "duorc", # plot -> QA - "farleyknight/big_patent_5_percent", # desc -> abstract - "multi_news", # summary - "opinosis", - "SophieTr/reddit_clean", - "allenai/mup", # long text -> summary - "allenai/multi_lexsum", # long text -> summary - "big_patent", - "allenai/wcep_dense_max", - "awinml/costco_long_practice", - "GEM/xsum", - "ratishsp/newshead", - "RussianNLP/wikiomnia", # russian - "stacked-summaries/stacked-xsum-1024", -] - -math_useful = ["competition_math"] - -skipped = [ - "c4", # maybe useful, used for flan, but skipped due to size -] - -""" -To get training data from oig: -pytest test_oig test_grade_final test_finalize_to_json -""" - -human = ":" -bot = ":" - - -def test_assemble_and_detox(): - import re - from profanity_check import predict_prob - - df_list = [] - for data in useful_oig_files: - print("Processing %s" % data, flush=True) - df = pd.read_parquet(data) - df = df.reset_index(drop=True) - # chop up into human/bot interactions of no more than 10kB per row - text_list = df[["text"]].values.ravel().tolist() - new_text = [] - max_len = 2048 # uber cutoff - MAX_LEN = 2048 // 2 - 30 # max len per question/answer - for text in tqdm(text_list): - human_starts = [m.start() for m in re.finditer(": ", text)] - if len(human_starts) == 1: - human_starts = [0, len(text)] # always go into for loop below - blurb = "" - for i in range(len(human_starts) - 1): - interaction = text[human_starts[i] : human_starts[i + 1]][ - :max_len - ] - blurb += interaction - if len(blurb) >= MAX_LEN: - blurb = get_sentences(blurb, length=MAX_LEN)[0] - new_text.append(blurb + "\n:") - blurb = "" - if blurb: - blurb = get_sentences(blurb, length=MAX_LEN)[0] - new_text.append(blurb + "\n:") - - if len(new_text) > len(text_list): - print( - "Added %d new rows (before: %d)" - % (len(new_text) - df.shape[0], df.shape[0]) - ) - df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)}) - df = df.drop_duplicates(keep="first") - print(df["text"].apply(lambda x: len(x)).describe()) - assert df["text"].apply(lambda x: len(x)).max() <= 2 * max_len - - # faster than better_profanity, do early - df["profanity"] = predict_prob(df["text"]) - before_rows = df.shape[0] - df = df[df["profanity"] < 0.25] # drop any low quality stuff - after_rows = df.shape[0] - print( - "Dropped %d rows out of %d due to alt-profanity-check" - % (before_rows - after_rows, before_rows) - ) - df_list.append(df) - print( - "Done processing %s -> %s rows" % (data, df.shape[0]), flush=True - ) - print("So far have %d rows" % sum([len(x) for x in df_list])) - df_final = pd.concat(df_list) - df_final = df_final.sample(frac=1, random_state=1234).reset_index( - drop=True - ) - df_final.to_parquet( - "h2oGPT.cleaned.human_bot.shorter.parquet", index=False - ) - - -def test_basic_cleaning(): - # from better_profanity import profanity - # https://pypi.org/project/alt-profanity-check/ - from profanity_check import predict - - df_list = [] - for data in useful_oig_files: - # for data in useful_oig_files[:5]: - # for data in ['unified_openai_summarize_tldr.jsonl.parquet']: - print("Processing %s" % data, flush=True) - df = pd.read_parquet(data) - df = df.reset_index(drop=True) - # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired - # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot)) - df["avg_words"] = df["text"].apply( - lambda x: x.count(" ") / (x.count(human) + x.count(bot)) / 2.0 - ) - df["avg_bot_words"] = df["text"].apply( - lambda x: x.split(bot)[1].count(" ") / x.count(bot) - ) - # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x)) - # low_quality_patterns = ['Write the rest of this wikipedia article'] - res = predict(df["text"]) - df["bad_words"] = res - df = df.reset_index(drop=True) - df = df[df["bad_words"] == 0] - df = df[["text", "avg_words", "avg_bot_words"]] - df = df.drop_duplicates(keep="first") - print(df[df["avg_words"] == df["avg_words"].max()]["text"].values) - median_words = np.median(df["avg_words"]) - min_words_per_entity = max(30, 0.8 * median_words) - max_words_per_entity = 2048 # too hard to learn from for now - df = df[df["avg_words"] > min_words_per_entity] - df = df[df["avg_words"] < max_words_per_entity] - - min_words_per_entity = max( - 20, 0.5 * median_words - ) # bot should say stuff for now - max_words_per_entity = 2048 # too hard to learn from for now - df = df[df["avg_bot_words"] > min_words_per_entity] - df = df[df["avg_bot_words"] < max_words_per_entity] - - df_list.append(df) - print( - "Done processing %s -> %s rows" % (data, df.shape[0]), flush=True - ) - df_final = pd.concat(df_list) - df_final.to_parquet("h2oGPT.cleaned.human_bot.parquet", index=False) - - -from joblib import Parallel, delayed, effective_n_jobs -from sklearn.utils import gen_even_slices -from sklearn.utils.validation import _num_samples - - -def parallel_apply(df, func, n_jobs=-1, **kwargs): - """Pandas apply in parallel using joblib. - Uses sklearn.utils to partition input evenly. - - Args: - df: Pandas DataFrame, Series, or any other object that supports slicing and apply. - func: Callable to apply - n_jobs: Desired number of workers. Default value -1 means use all available cores. - **kwargs: Any additional parameters will be supplied to the apply function - - Returns: - Same as for normal Pandas DataFrame.apply() - - """ - - if effective_n_jobs(n_jobs) == 1: - return df.apply(func, **kwargs) - else: - ret = Parallel(n_jobs=n_jobs)( - delayed(type(df).apply)(df[s], func, **kwargs) - for s in gen_even_slices( - _num_samples(df), effective_n_jobs(n_jobs) - ) - ) - return pd.concat(ret) - - -def add_better_profanity_flag(df): - from better_profanity import profanity - - df["better_profanity"] = parallel_apply( - df["text"], - lambda x: profanity.contains_profanity(x), - n_jobs=-1, - ) - return df - - -def add_textstat_grade(df): - import textstat - - def myfunc(x): - return textstat.flesch_kincaid_grade(x) # simple grade - - if False: - import dask.dataframe as dd - - # 40 seconds for 1000 rows, but have 1,787,799 rows - ddata = dd.from_pandas(df, npartitions=120) - - df["flesch_grade"] = ddata["text"].apply(myfunc).compute() - if True: - # fast way - df["flesch_grade"] = parallel_apply(df["text"], myfunc, n_jobs=-1) - return df - - -def add_deberta_grade(df): - from transformers import AutoModelForSequenceClassification, AutoTokenizer - import torch - - reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2" - rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained( - reward_name - ), AutoTokenizer.from_pretrained(reward_name) - device = "cuda" if torch.cuda.is_available() else "cpu" - rank_model.to(device) - - def get_question(x): - return x.replace(": ", "").split(":")[0] - - def get_answer(x): - try: - answer = ( - x.split(": ")[1] - .split(":")[0] - .replace(": ", "") - ) - except: - answer = ( - x.split(":")[1].split(":")[0].replace(":", "") - ) - return answer - - df["question"] = parallel_apply(df["text"], get_question, n_jobs=-1) - df["answer"] = parallel_apply(df["text"], get_answer, n_jobs=-1) - - from datasets import Dataset - from transformers import pipeline - from transformers.pipelines.pt_utils import KeyPairDataset - import tqdm - - pipe = pipeline( - "text-classification", - model=reward_name, - device="cuda:0" if torch.cuda.is_available() else "cpu", - ) - start = 0 - batch_size = 64 * 16 - micro_batch = orig_micro_batch = 16 - end = 0 - import socket - - checkpoint = "grades.%s.pkl" % socket.gethostname() - grades = [] - import pickle - - if os.path.exists(checkpoint): - with open(checkpoint, "rb") as f: - start, grades = pickle.loads(f.read()) - last_oom = 0 - while end < df.shape[0]: - # manual batching to handle OOM more gracefully - end = min(start + batch_size, df.shape[0]) - if start == end: - break - dataset = Dataset.from_pandas(df.iloc[start:end, :]) - try: - grades.extend( - [ - x["score"] - for x in tqdm.tqdm( - pipe( - KeyPairDataset(dataset, "question", "answer"), - batch_size=micro_batch, - ) - ) - ] - ) - except torch.cuda.OutOfMemoryError: - last_oom = start - micro_batch = max(1, micro_batch // 2) - print("OOM - retrying with micro_batch=%d" % micro_batch) - continue - if last_oom == start: - micro_batch = orig_micro_batch - print("Returning to micro_batch=%d" % micro_batch) - assert len(grades) == end - start = end - with open(checkpoint, "wb") as f: - f.write(pickle.dumps((end, grades))) - print("%d/%d" % (end, df.shape[0])) - df["grade_deberta"] = grades - if os.path.exists(checkpoint): - os.remove(checkpoint) - return df - - -def test_chop_by_lengths(): - file = "h2oGPT.cleaned.human_bot.shorter.parquet" - df = pd.read_parquet(file).reset_index(drop=True) - df = count_human_bot_lengths(df) - df["rand"] = np.random.rand(df.shape[0]) - df["rand2"] = np.random.rand(df.shape[0]) - before_rows = df.shape[0] - # throw away short human/bot responses with higher likelihood - df = df[(df["len_human_mean"] > 20)] # never keep very short ones - df = df[(df["len_human_mean"] > 30) | (df["rand"] < 0.2)] - df = df[(df["len_human_mean"] > 50) | (df["rand"] < 0.5)] - df = df[ - (df["len_human_max"] < 10000) - ] # drop super long (basically only human) ones - df = df[(df["len_bot_mean"] > 20)] # never keep very short ones - df = df[(df["len_bot_mean"] > 30) | (df["rand2"] < 0.2)] - df = df[(df["len_bot_mean"] > 50) | (df["rand2"] < 0.5)] - df = df[(df["len_bot_max"] < 10000)] # drop super long (only bot) ones - assert df["text"].apply(lambda x: len(x)).max() < 20000 - df = df.drop(["rand", "rand2"], axis=1) - after_rows = df.shape[0] - print( - "Chopped off %d out of %d rows due to length" - % (before_rows - after_rows, before_rows) - ) - print(df.describe()) - df.to_parquet( - "h2oGPT.cleaned.chopped.human_bot.shorter.parquet", index=False - ) - - -def count_human_bot_lengths(df, human=None, bot=None): - import re - - len_human_min = [] - len_human_max = [] - len_human_mean = [] - len_bot_min = [] - len_bot_max = [] - len_bot_mean = [] - human = human or ":" - bot = bot or ":" - for is_human in [True, False]: - what = human if is_human else bot - other = human if not is_human else bot - for i in range(df.shape[0]): - text = df.loc[i, "text"] - assert isinstance(text, str) - starts = [m.start() for m in re.finditer(what, text)] - if len(starts) == 1: - starts = [ - starts[0], - len(text), - ] # always go into for loop below - assert len(text) - list_what = [] - for ii in range(len(starts) - 1): - interaction = text[starts[ii] : starts[ii + 1]] - if other in interaction: - interaction = interaction[: interaction.find(other)] - interaction.strip() - list_what.append(interaction) - if not list_what: - list_what = [ - "" - ] # handle corrupted data, very rare, leads to sizes 0 - if is_human: - len_human_min.append(min([len(x) for x in list_what])) - len_human_max.append(max([len(x) for x in list_what])) - len_human_mean.append(np.mean([len(x) for x in list_what])) - else: - len_bot_min.append(min([len(x) for x in list_what])) - len_bot_max.append(max([len(x) for x in list_what])) - len_bot_mean.append(np.mean([len(x) for x in list_what])) - df["len_human_min"] = len_human_min - df["len_human_max"] = len_human_max - df["len_human_mean"] = len_human_mean - df["len_bot_min"] = len_bot_min - df["len_bot_max"] = len_bot_max - df["len_bot_mean"] = len_bot_mean - np.random.seed(1234) - pd.set_option("display.max_columns", None) - print("Before chopping") - print(df.describe()) - return df - - -def test_grade(): - df = None - - file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet" - output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet" - if not os.path.exists(output_file): - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df = add_textstat_grade(df) - min_grade = 10 - max_grade = 25 - df = df[df["flesch_grade"] >= min_grade] - df = df[df["flesch_grade"] <= max_grade] - print("After Flesch grade") - print(df.describe()) - df.to_parquet(output_file, index=False) - - file = output_file - output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet" - if not os.path.exists(output_file): - # slower than alt-profanity, do last, but do before deberta grading, since that's slower - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df = add_better_profanity_flag(df) - before_rows = df.shape[0] - df = df[df["better_profanity"] == 0] - df = df.drop(["better_profanity"], axis=1) - after_rows = df.shape[0] - print( - "Dropped %d rows out of %d due to better_profanity" - % (before_rows - after_rows, before_rows) - ) - print(df.describe()) - df.to_parquet(output_file, index=False) - - file = output_file - output_file = "h2oGPT.cleaned.graded3.human_bot.shorter.parquet" - if not os.path.exists(output_file): - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df = add_deberta_grade(df) - min_grade = 0.3 - max_grade = np.inf - before_rows = df.shape[0] - df = df[df["grade_deberta"] >= min_grade] - df = df[df["grade_deberta"] <= max_grade] - after_rows = df.shape[0] - print( - "Dropped %d rows out of %d due to deberta grade" - % (before_rows - after_rows, before_rows) - ) - print("After DeBERTa grade") - print(df.describe()) - df.to_parquet(output_file, index=False) - - file = output_file - output_file = "h2oGPT.cleaned.graded.human_bot.shorter.parquet" - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df.to_parquet(output_file, index=False) - - -@pytest.mark.parametrize( - "fixup_personality, only_personality, deberta_grading", - [ - [False, False, False], - [True, True, False], - [True, False, False], - [True, False, True], - ], -) -def test_add_open_assistant( - fixup_personality, only_personality, deberta_grading, save_json=True -): - """ - Flatten tree structure into one row per path from root to leaf - Also turn into human_bot prompting format: - : question\n: answer : question2\n: answer2 Etc. - Also saves a .json locally as side-effect - returns list of dicts, containing intput, prompt_type and source - """ - from datasets import load_dataset - - data_file = "OpenAssistant/oasst1" - ds = load_dataset(data_file) - df = pd.concat( - [ds["train"].to_pandas(), ds["validation"].to_pandas()], axis=0 - ) - rows = {} - message_ids = df["message_id"].values.tolist() - message_tree_ids = df["message_tree_id"].values.tolist() - parent_ids = df["parent_id"].values.tolist() - texts = df["text"].values.tolist() - roles = df["role"].values.tolist() - - for i in range(df.shape[0]): - # collect all trees - message_id = message_ids[i] - message_tree_id = message_tree_ids[i] - parent_id = parent_ids[i] - text = texts[i] - if fixup_personality: - text = text.replace("Open Assistant", "h2oGPT") - text = text.replace("Open-Assistant", "h2oGPT") - text = text.replace("open-assistant", "h2oGPT") - text = text.replace("OpenAssistant", "h2oGPT") - text = text.replace("open assistant", "h2oGPT") - text = text.replace("Open Assistand", "h2oGPT") - text = text.replace("Open Assitant", "h2oGPT") - text = text.replace("Open Assistent", "h2oGPT") - text = text.replace("Open Assisstant", "h2oGPT") - text = text.replace("Open Assitent", "h2oGPT") - text = text.replace("Open Assitiant", "h2oGPT") - text = text.replace("Open Assistiant", "h2oGPT") - text = text.replace("Open Assitan ", "h2oGPT ") - text = text.replace("Open Assistan ", "h2oGPT ") - text = text.replace("Open Asistant", "h2oGPT") - text = text.replace("Open Assiant", "h2oGPT") - text = text.replace("Assistant", "h2oGPT") - text = text.replace("LAION AI", "H2O.ai") - text = text.replace("LAION-AI", "H2O.ai") - text = text.replace("LAION,", "H2O.ai,") - text = text.replace("LAION.ai", "H2O.ai") - text = text.replace("LAION.", "H2O.ai.") - text = text.replace("LAION", "H2O.ai") - - role = roles[i] - new_data = (": " if role == "prompter" else ": ") + text - entry = dict(message_id=message_id, parent_id=parent_id, text=new_data) - if message_tree_id not in rows: - rows[message_tree_id] = [entry] - else: - rows[message_tree_id].append(entry) - - all_rows = [] - - for node_id in rows: - # order responses in tree, based on message/parent relationship - conversations = [] - - list_msgs = rows[node_id] - # find start - while len(list_msgs): - for i, leaf in enumerate(list_msgs): - found = False - parent_id = leaf["parent_id"] - if parent_id is None: - # conversation starter - conversations.append(leaf) - found = True - else: - for conv in conversations: - # find all conversations to add my message to - if ( - parent_id in conv["message_id"] - and parent_id - != conv["message_id"][-len(parent_id) :] - ): - # my message doesn't follow conversation - continue - if parent_id == conv["message_id"][-len(parent_id) :]: - # my message follows conversation, but fork first, so another follow-on message can do same - conversations.append(conv.copy()) - conv[ - "text" - ] += f""" -{leaf['text']} -""" - conv["message_id"] += leaf["message_id"] - found = True - break - if found: - # my content was used, so nuke from list - del list_msgs[i] - break - - # now reduce down to final conversations, find the longest chains of message ids - for i, conv in enumerate(conversations): - for j, conv2 in enumerate(conversations): - if i == j: - continue - if conv["message_id"] and conv2["message_id"]: - assert conv["message_id"] != conv2["message_id"] - # delete the shorter conversation, if one contains the other - if conv["message_id"] in conv2["message_id"]: - conv["message_id"] = None - if conv2["message_id"] in conv["message_id"]: - conv2["message_id"] = None - conversations = [c for c in conversations if c["message_id"]] - if only_personality: - all_rows.extend( - [ - dict( - input=c["text"] + "\n:", - prompt_type="plain", - source=data_file, - ) - for c in conversations - if "h2oGPT" in c["text"] - ] - ) - else: - all_rows.extend( - [ - dict( - input=c["text"] + "\n:", - prompt_type="plain", - source=data_file, - ) - for c in conversations - if "What is H2O.ai" not in c["text"] - ] - ) - unhelpful = get_unhelpful_list() - all_rows = [ - x for x in all_rows if not any(u in x["input"] for u in unhelpful) - ] - personality = create_personality_data() - all_rows.extend(personality * 10) - np.random.seed(123) - np.random.shuffle(all_rows) - print(len(all_rows)) - if deberta_grading: - df = pd.DataFrame(all_rows) - df = df.rename(columns={"input": "text"}) - df = add_deberta_grade(df) - df = df.rename(columns={"text": "input"}) - drop = True - if drop: - min_grade = 0.3 - max_grade = np.inf - before_rows = df.shape[0] - df = df[df["grade_deberta"] >= min_grade] - df = df[df["grade_deberta"] <= max_grade] - after_rows = df.shape[0] - print( - "Dropped %d rows out of %d due to deberta grade" - % (before_rows - after_rows, before_rows) - ) - print("After DeBERTa grade") - print(df.describe()) - all_rows = [] - for i in range(df.shape[0]): - all_rows.append( - dict( - input=df["input"].iloc[i], - source=df["source"].iloc[i], - prompt_type=df["prompt_type"].iloc[i], - grade_deberta=df["grade_deberta"].iloc[i], - ) - ) - if save_json: - data_file = ( - data_file - + ("_h2ogpt" if fixup_personality else "") - + ("_only" if only_personality else "") - + ("_graded" if deberta_grading else "") - ) - for i in range(len(all_rows)): - all_rows[i]["id"] = i - with open(data_file.lower().replace("/", "_") + ".json", "w") as f: - f.write(json.dumps(all_rows, indent=2)) - return all_rows - - -def test_finalize_to_json(): - df = pd.read_parquet("h2oGPT.cleaned.graded.human_bot.shorter.parquet") - df = df.rename(columns={"text": "input"}) - - print( - "Number of high-quality human_bot interactions: %s" % df.shape[0], - flush=True, - ) - - print("Adding open assistant data") - with open("openassistant_oasst1_h2ogpt_graded.json") as f: - open_assistant = json.loads(f.read()) - df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0) - - def final_clean(df): - from better_profanity import profanity - - profanity.load_censor_words_from_file("data/censor_words.txt") - df["profanity"] = parallel_apply( - df["input"], - lambda x: profanity.contains_profanity(x), - n_jobs=-1, - ) - return df[(df["profanity"] == 0)].reset_index(drop=True) - - print( - "Before cleaning: Number of final high-quality human_bot interactions: %s" - % df.shape[0], - flush=True, - ) - df = final_clean(df) - print( - "After cleaning: Number of final high-quality human_bot interactions: %s" - % df.shape[0], - flush=True, - ) - print(df.describe()) - print(df.shape) - row_list = [] - for i in range(df.shape[0]): - row_list.append( - dict( - input=df.loc[i, "input"], - source=df.loc[i, "source"], - prompt_type="plain", - ) - ) - np.random.seed(1234) - np.random.shuffle(row_list) - unhelpful = get_unhelpful_list() - row_list = [ - x for x in row_list if not any(u in x["input"] for u in unhelpful) - ] - for i in range(len(row_list)): - row_list[i]["id"] = i - row_list[i]["input"] = row_list[i]["input"].replace( - " :", "\n:" - ) - with open("h2ogpt-oig-oasst1-instruct-cleaned-v3.json", "w") as f: - f.write(json.dumps(row_list, indent=2)) - - -def create_personality_data(): - questions = [ - "What's your name?", - "What is your name?", - "What are you?", - "Who are you?", - "Do you have a name?", - "Who trained you?", - "Who created you?", - "Who made you?", - ] - answers = [ - "I'm h2oGPT, a large language model by H2O.ai.", - "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", - "My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.", - "My name is h2oGPT. I'm a large language model trained by H2O.ai.", - "Hi! I'm h2oGPT, a large language model by H2O.ai.", - "Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", - ] - help = [ - "", - " How can I help you?", - " How may I assist you?", - " Nice to meet you.", - ] - import itertools - - rows = [] - for pair in itertools.product(questions, answers, help): - rows.append( - dict( - input=f": {pair[0]}\n: {pair[1]}{pair[2]}\n:", - prompt_type="plain", - source="H2O.ai", - ) - ) - for row in [ - ": What is H2O.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": What is h2o.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": What is H2O?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": Who is h2o.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": who is h2o.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": who is h2o?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": What is H2O.ai?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": Who is H2O.ai?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": Who is H2O?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": Who is h2o?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": who is h2o?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ]: - rows.append(dict(input=row, prompt_type="plain", source="H2O.ai")) - print(len(rows)) - with open("h2ogpt-personality.json", "w") as f: - f.write(json.dumps(rows, indent=2)) - return rows - - -def test_check_stats_data(): - filename = "h2ogpt-oig-oasst1-instruct-cleaned-v3.json" - df = pd.read_json(filename) - - # get word stats - df["char_count"] = df["input"].apply(lambda x: len(x)) - import matplotlib.pyplot as plt - - plt.figure(figsize=(10, 10)) - plt.hist(df["char_count"], bins=100) - chars_avg = np.mean(df["char_count"]) - chars_median = np.median(df["char_count"]) - plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median)) - plt.savefig("chars_hist.png") - plt.close() - - # get tokenize stats for random sample of 1000 rows - from finetune import generate_and_tokenize_prompt - from loaders import get_loaders, get_tokenizer - from functools import partial - - llama_type = False - tokenizer_base_model = base_model = "h2oai/h2ogpt-oasst1-512-20b" - model_loader, tokenizer_loader = get_loaders( - model_name=base_model, reward_type=False, llama_type=llama_type - ) - local_files_only = False - resume_download = True - use_auth_token = False - tokenizer = get_tokenizer( - tokenizer_loader, - tokenizer_base_model, - local_files_only, - resume_download, - use_auth_token, - ) - prompt_type = "plain" # trained with data already in human bot form - train_on_inputs = True - add_eos_token = False - cutoff_len = 512 # can choose 2048 - generate_and_tokenize_prompt_fun = partial( - generate_and_tokenize_prompt, - prompt_type=prompt_type, - train_on_inputs=train_on_inputs, - add_eos_token=add_eos_token, - cutoff_len=cutoff_len, - tokenizer=tokenizer, - ) - from datasets import load_dataset - - data = load_dataset("json", data_files={"train": filename}) - val_set_size = 0.90 - train_val = data["train"].train_test_split( - test_size=val_set_size, shuffle=True, seed=42 - ) - train_data = train_val["train"] - train_data = train_data.shuffle().map( - generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() - ) - - df_tokens = pd.DataFrame( - [len(x) for x in train_data["input_ids"]], columns=["token_count"] - ) - - plt.figure(figsize=(10, 10)) - plt.hist(df_tokens["token_count"], bins=100) - token_avg = np.mean(df_tokens["token_count"]) - token_median = np.median(df_tokens["token_count"]) - plt.title( - "token_count with cutoff=%s avg: %s median: %s" - % (cutoff_len, token_avg, token_median) - ) - plt.savefig("token_hist_%s.png" % cutoff_len) - plt.close() - - -def get_unhelpful_list(): - # base versions - unhelpful = [ - "I'm sorry, I didn't quite understand your question, could you please rephrase it?", - "I'm sorry, but I don't understand your question. Could you please rephrase it?", - "I'm sorry, I don't quite understand your question", - "I'm sorry, I don't know", - "I'm sorry, but I don't know", - "I don't know anything", - "I do not know", - "I don't know", - "I don't know how", - "I do not know how", - "Can you please explain what you mean", - "please explain what you mean", - "please explain", - "I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by", - "I'm sorry but I don't understand what you mean", - "I don't understand", - "I don't have the ability", - "I do not have the ability", - "I do not have", - "I am a language model,", - "I am a large language model,", - "I do not understand your question. Can you please try to make it clearer?", - "I'm sorry, but as an AI language model", - "I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.", - "I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?", - "Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t", - "I apologize, but I cannot perform the task you have requested.", - "I'm sorry, I cannot perform this task as I am an AI language model and do not have access", - "I'm sorry, I'm not sure what you're asking for here.", - "I'm not sure what you are asking", - "You need to provide more context", - ] - # reduced versions, with redundant parts, just to give context for where they came from - unhelpful += [ - "sorry, I didn't quite understand your question", - "I didn't quite understand your question", - "I didn't understand your question", - "I did not understand your question", - "I did not understand the question", - "could you please rephrase" - "could you rephrase" - "I do not understand your question.", - "I do not understand the question.", - "I do not understand that question.", - "Can you please try to make it clearer", - "Can you try to make it clearer", - "sorry, but as an AI language model", - "as an AI language model", - "I apologize, but I cannot", - "I cannot rephrase text", - "I cannot understand. Your post is difficult to read and follow." - "Your post is difficult to read and follow." - "I apologize, but I am", - "Sorry, but I am not ", - "nor am I capable", - "I am not capable of", - "I apologize, but I cannot perform the task you have requested", - "I cannot perform the task", - "I cannot complete the task", - "I'm sorry", - "I am sorry", - "do not have access", - "not sure what you're asking for", - "not sure what you are asking for", - "not sure what is being asked", - "I'm not sure what you are asking", - "not sure what you are asking", - "You need to provide more context", - "provide more context", - ] - unhelpful += [ - "As a large language model", - "cannot provide any information", - "As an artificial intelligence I do not have the capability", - "As an artificial intelligence I don't have the capability", - "As an artificial intelligence I can't", - "As an artificial intelligence I cannot", - "I am sorry but I do not understand", - "Can you please explain", - "(sorry couldn't resist)", - "(sorry could not resist)", - " :)", - " ;)", - " :-)", - " ;-)", - " lol ", - "Thanks so much!!!", - "Thank You :)!!!", - "Please try not to repeat", - "I am an AI language model", - "I'm a AI assistant that", - "I'm an AI assistant that", - "I am an AI assistant that", - "etc.", - "etc.etc.", - "etc. etc.", - "etc etc", - ] - return unhelpful - - -def test_check_unhelpful(): - # file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json' - file = "/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json" - # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json' - - unhelpful = get_unhelpful_list() - # data = json.load(open(file, 'rt')) - df = pd.read_json(file) - - use_reward_score_threshold = False - use_bleu_threshold = False - use_sentence_sim = True - - from sacrebleu.metrics import BLEU - - bleu = BLEU() - from nltk.translate.bleu_score import sentence_bleu - - def get_bleu(actual, expected_list): - # return bleu.sentence_score(actual, expected_list).score - return sentence_bleu(expected_list, actual) - - threshold = 0.0 - if use_reward_score_threshold: - df = df[df["grade_deberta"] > threshold] - - # back to as if original json load - data = df.to_dict(orient="records") - bads = {} - string_all = str(data) - for sub in unhelpful: - bads[sub] = string_all.count(sub) - bads = {k: v for k, v in bads.items() if v > 0} - import pprint - - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(bads) - - total_bads = sum(list(bads.values())) - print("total_bads: %s" % total_bads, flush=True) - - # check just bot - import re - - convs = [ - [ - x.strip() - for x in re.split(r"%s|%s" % (human, bot), y["input"]) - if x.strip() - ] - for y in data - ] - humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs] - bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs] - - # FIXME: apply back to json etc., just see for now - bleu_threshold = 0.9 - if use_bleu_threshold: - bots = [ - [x for x in y if get_bleu(x, unhelpful) < bleu_threshold] - for y in tqdm(bots) - ] - - cosine_sim_threshold = 0.8 - if use_sentence_sim: - # pip install sentence_transformers-2.2.2 - from sentence_transformers import SentenceTransformer - - # sent_model = 'bert-base-nli-mean-tokens' - # sent_model = 'nli-distilroberta-base-v2' - sent_model = "all-MiniLM-L6-v2" - model = SentenceTransformer(sent_model) - sentence_embeddings = model.encode(unhelpful) - from sklearn.metrics.pairwise import cosine_similarity - - bots = [ - x - for x in tqdm(bots) - if np.max(cosine_similarity(model.encode(x), sentence_embeddings)) - < cosine_sim_threshold - ] - - bads_bots = {} - string_all = str(bots) - for sub in unhelpful: - bads_bots[sub] = string_all.count(sub) - bads_bots = {k: v for k, v in bads_bots.items() if v > 0} - import pprint - - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(bads_bots) - - total_bads_bots = sum(list(bads_bots.values())) - print( - "threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s" - % ( - threshold, - use_bleu_threshold, - total_bads_bots, - len(bots), - len(humans), - ), - flush=True, - ) - - # assert len(bads) == 0, bads - assert len(bads_bots) == 0, bads_bots - - -def test_fortune2000_personalized(): - row_list = [] - import glob - - if not os.path.isdir("wikitext"): - raise RuntimeError( - "download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip" - ) - for file in glob.glob("wikitext/*.txt"): - with open(file, "r") as f: - blob = f.read() - N = 512 * 4 - row_list.extend( - [ - { - "input": s, - "prompt_type": "plain", - "source": "%s" % os.path.basename(file), - } - for s in get_sentences(blob, N) - if s - ] - ) - personality = create_personality_data() - import copy - - for i in range(10): - row_list.extend(copy.deepcopy(personality)) - np.random.seed(123) - np.random.shuffle(row_list) - for i in range(len(row_list)): - row_list[i]["id"] = i - for i in range(len(row_list)): - assert row_list[i]["id"] == i - with open("h2ogpt-fortune2000-personalized.json", "w") as ff: - ff.write(json.dumps(row_list, indent=2)) diff --git a/apps/language_models/langchain/enums.py b/apps/language_models/langchain/enums.py deleted file mode 100644 index 27f22cd1..00000000 --- a/apps/language_models/langchain/enums.py +++ /dev/null @@ -1,103 +0,0 @@ -from enum import Enum - - -class PromptType(Enum): - custom = -1 - plain = 0 - instruct = 1 - quality = 2 - human_bot = 3 - dai_faq = 4 - summarize = 5 - simple_instruct = 6 - instruct_vicuna = 7 - instruct_with_end = 8 - human_bot_orig = 9 - prompt_answer = 10 - open_assistant = 11 - wizard_lm = 12 - wizard_mega = 13 - instruct_vicuna2 = 14 - instruct_vicuna3 = 15 - wizard2 = 16 - wizard3 = 17 - instruct_simple = 18 - wizard_vicuna = 19 - openai = 20 - openai_chat = 21 - gptj = 22 - prompt_answer_openllama = 23 - vicuna11 = 24 - mptinstruct = 25 - mptchat = 26 - falcon = 27 - - -class DocumentChoices(Enum): - All_Relevant = 0 - All_Relevant_Only_Sources = 1 - Only_All_Sources = 2 - Just_LLM = 3 - - -non_query_commands = [ - DocumentChoices.All_Relevant_Only_Sources.name, - DocumentChoices.Only_All_Sources.name, -] - - -class LangChainMode(Enum): - """LangChain mode""" - - DISABLED = "Disabled" - CHAT_LLM = "ChatLLM" - LLM = "LLM" - ALL = "All" - WIKI = "wiki" - WIKI_FULL = "wiki_full" - USER_DATA = "UserData" - MY_DATA = "MyData" - GITHUB_H2OGPT = "github h2oGPT" - H2O_DAI_DOCS = "DriverlessAI docs" - - -class LangChainAction(Enum): - """LangChain action""" - - QUERY = "Query" - # WIP: - # SUMMARIZE_MAP = "Summarize_map_reduce" - SUMMARIZE_MAP = "Summarize" - SUMMARIZE_ALL = "Summarize_all" - SUMMARIZE_REFINE = "Summarize_refine" - - -no_server_str = no_lora_str = no_model_str = "[None/Remove]" - -# from site-packages/langchain/llms/openai.py -# but needed since ChatOpenAI doesn't have this information -model_token_mapping = { - "gpt-4": 8192, - "gpt-4-0314": 8192, - "gpt-4-32k": 32768, - "gpt-4-32k-0314": 32768, - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-16k": 16 * 1024, - "gpt-3.5-turbo-0301": 4096, - "text-ada-001": 2049, - "ada": 2049, - "text-babbage-001": 2040, - "babbage": 2049, - "text-curie-001": 2049, - "curie": 2049, - "davinci": 2049, - "text-davinci-003": 4097, - "text-davinci-002": 4097, - "code-davinci-002": 8001, - "code-davinci-001": 8001, - "code-cushman-002": 2048, - "code-cushman-001": 2048, -} - -source_prefix = "Sources [Score | Link]:" -source_postfix = "End Sources

" diff --git a/apps/language_models/langchain/evaluate_params.py b/apps/language_models/langchain/evaluate_params.py deleted file mode 100644 index 4e75dd85..00000000 --- a/apps/language_models/langchain/evaluate_params.py +++ /dev/null @@ -1,53 +0,0 @@ -no_default_param_names = [ - "instruction", - "iinput", - "context", - "instruction_nochat", - "iinput_nochat", -] - -gen_hyper = [ - "temperature", - "top_p", - "top_k", - "num_beams", - "max_new_tokens", - "min_new_tokens", - "early_stopping", - "max_time", - "repetition_penalty", - "num_return_sequences", - "do_sample", -] - -eval_func_param_names = ( - [ - "instruction", - "iinput", - "context", - "stream_output", - "prompt_type", - "prompt_dict", - ] - + gen_hyper - + [ - "chat", - "instruction_nochat", - "iinput_nochat", - "langchain_mode", - "langchain_action", - "top_k_docs", - "chunk", - "chunk_size", - "document_choice", - ] -) - -# form evaluate defaults for submit_nochat_api -eval_func_param_names_defaults = eval_func_param_names.copy() -for k in no_default_param_names: - if k in eval_func_param_names_defaults: - eval_func_param_names_defaults.remove(k) - - -eval_extra_columns = ["prompt", "response", "score"] diff --git a/apps/language_models/langchain/expanded_pipelines.py b/apps/language_models/langchain/expanded_pipelines.py deleted file mode 100644 index 7c313d1c..00000000 --- a/apps/language_models/langchain/expanded_pipelines.py +++ /dev/null @@ -1,846 +0,0 @@ -from __future__ import annotations -from typing import ( - Any, - Mapping, - Optional, - Dict, - List, - Sequence, - Tuple, - Union, - Protocol, -) -import inspect -import json -import warnings -from pathlib import Path -import yaml -from abc import ABC, abstractmethod -import langchain -from langchain.base_language import BaseLanguageModel -from langchain.callbacks.base import BaseCallbackManager -from langchain.chains.question_answering import stuff_prompt -from langchain.prompts.base import BasePromptTemplate -from langchain.docstore.document import Document -from langchain.callbacks.manager import ( - CallbackManager, - CallbackManagerForChainRun, - Callbacks, -) -from langchain.load.serializable import Serializable -from langchain.schema import RUN_KEY, BaseMemory, RunInfo -from langchain.input import get_colored_text -from langchain.load.dump import dumpd -from langchain.prompts.prompt import PromptTemplate -from langchain.schema import LLMResult, PromptValue -from pydantic import Extra, Field, root_validator, validator - - -def _get_verbosity() -> bool: - return langchain.verbose - - -def format_document(doc: Document, prompt: BasePromptTemplate) -> str: - """Format a document into a string based on a prompt template.""" - base_info = {"page_content": doc.page_content} - base_info.update(doc.metadata) - missing_metadata = set(prompt.input_variables).difference(base_info) - if len(missing_metadata) > 0: - required_metadata = [ - iv for iv in prompt.input_variables if iv != "page_content" - ] - raise ValueError( - f"Document prompt requires documents to have metadata variables: " - f"{required_metadata}. Received document with missing metadata: " - f"{list(missing_metadata)}." - ) - document_info = {k: base_info[k] for k in prompt.input_variables} - return prompt.format(**document_info) - - -class Chain(Serializable, ABC): - """Base interface that all chains should implement.""" - - memory: Optional[BaseMemory] = None - callbacks: Callbacks = Field(default=None, exclude=True) - callback_manager: Optional[BaseCallbackManager] = Field( - default=None, exclude=True - ) - verbose: bool = Field( - default_factory=_get_verbosity - ) # Whether to print the response text - tags: Optional[List[str]] = None - - class Config: - """Configuration for this pydantic object.""" - - arbitrary_types_allowed = True - - @property - def _chain_type(self) -> str: - raise NotImplementedError("Saving not supported for this chain type.") - - @root_validator() - def raise_deprecation(cls, values: Dict) -> Dict: - """Raise deprecation warning if callback_manager is used.""" - if values.get("callback_manager") is not None: - warnings.warn( - "callback_manager is deprecated. Please use callbacks instead.", - DeprecationWarning, - ) - values["callbacks"] = values.pop("callback_manager", None) - return values - - @validator("verbose", pre=True, always=True) - def set_verbose(cls, verbose: Optional[bool]) -> bool: - """If verbose is None, set it. - - This allows users to pass in None as verbose to access the global setting. - """ - if verbose is None: - return _get_verbosity() - else: - return verbose - - @property - @abstractmethod - def input_keys(self) -> List[str]: - """Input keys this chain expects.""" - - @property - @abstractmethod - def output_keys(self) -> List[str]: - """Output keys this chain expects.""" - - def _validate_inputs(self, inputs: Dict[str, Any]) -> None: - """Check that all inputs are present.""" - missing_keys = set(self.input_keys).difference(inputs) - if missing_keys: - raise ValueError(f"Missing some input keys: {missing_keys}") - - def _validate_outputs(self, outputs: Dict[str, Any]) -> None: - missing_keys = set(self.output_keys).difference(outputs) - if missing_keys: - raise ValueError(f"Missing some output keys: {missing_keys}") - - @abstractmethod - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, Any]: - """Run the logic of this chain and return the output.""" - - def __call__( - self, - inputs: Union[Dict[str, Any], Any], - return_only_outputs: bool = False, - callbacks: Callbacks = None, - *, - tags: Optional[List[str]] = None, - include_run_info: bool = False, - ) -> Dict[str, Any]: - """Run the logic of this chain and add to output if desired. - - Args: - inputs: Dictionary of inputs, or single input if chain expects - only one param. - return_only_outputs: boolean for whether to return only outputs in the - response. If True, only new keys generated by this chain will be - returned. If False, both input keys and new keys generated by this - chain will be returned. Defaults to False. - callbacks: Callbacks to use for this chain run. If not provided, will - use the callbacks provided to the chain. - include_run_info: Whether to include run info in the response. Defaults - to False. - """ - input_docs = inputs["input_documents"] - missing_keys = set(self.input_keys).difference(inputs) - if missing_keys: - raise ValueError(f"Missing some input keys: {missing_keys}") - - callback_manager = CallbackManager.configure( - callbacks, self.callbacks, self.verbose, tags, self.tags - ) - run_manager = callback_manager.on_chain_start( - dumpd(self), - inputs, - ) - - if "is_first" in inputs.keys() and not inputs["is_first"]: - run_manager_ = run_manager - input_list = [inputs] - stop = None - prompts = [] - for inputs in input_list: - selected_inputs = { - k: inputs[k] for k in self.prompt.input_variables - } - prompt = self.prompt.format_prompt(**selected_inputs) - _colored_text = get_colored_text(prompt.to_string(), "green") - _text = "Prompt after formatting:\n" + _colored_text - if run_manager_: - run_manager_.on_text(_text, end="\n", verbose=self.verbose) - if "stop" in inputs and inputs["stop"] != stop: - raise ValueError( - "If `stop` is present in any inputs, should be present in all." - ) - prompts.append(prompt) - - prompt_strings = [p.to_string() for p in prompts] - prompts = prompt_strings - callbacks = run_manager_.get_child() if run_manager_ else None - tags = None - - """Run the LLM on the given prompt and input.""" - # If string is passed in directly no errors will be raised but outputs will - # not make sense. - if not isinstance(prompts, list): - raise ValueError( - "Argument 'prompts' is expected to be of type List[str], received" - f" argument of type {type(prompts)}." - ) - params = self.llm.dict() - params["stop"] = stop - options = {"stop": stop} - disregard_cache = self.llm.cache is not None and not self.llm.cache - callback_manager = CallbackManager.configure( - callbacks, - self.llm.callbacks, - self.llm.verbose, - tags, - self.llm.tags, - ) - if langchain.llm_cache is None or disregard_cache: - # This happens when langchain.cache is None, but self.cache is True - if self.llm.cache is not None and self.cache: - raise ValueError( - "Asked to cache, but no cache found at `langchain.cache`." - ) - run_manager_ = callback_manager.on_llm_start( - dumpd(self), - prompts, - invocation_params=params, - options=options, - ) - - generations = [] - for prompt in prompts: - inputs_ = prompt - num_workers = None - batch_size = None - - if num_workers is None: - if self.llm.pipeline._num_workers is None: - num_workers = 0 - else: - num_workers = self.llm.pipeline._num_workers - if batch_size is None: - if self.llm.pipeline._batch_size is None: - batch_size = 1 - else: - batch_size = self.llm.pipeline._batch_size - - preprocess_params = {} - generate_kwargs = {} - preprocess_params.update(generate_kwargs) - forward_params = generate_kwargs - postprocess_params = {} - # Fuse __init__ params and __call__ params without modifying the __init__ ones. - preprocess_params = { - **self.llm.pipeline._preprocess_params, - **preprocess_params, - } - forward_params = { - **self.llm.pipeline._forward_params, - **forward_params, - } - postprocess_params = { - **self.llm.pipeline._postprocess_params, - **postprocess_params, - } - - self.llm.pipeline.call_count += 1 - if ( - self.llm.pipeline.call_count > 10 - and self.llm.pipeline.framework == "pt" - and self.llm.pipeline.device.type == "cuda" - ): - warnings.warn( - "You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a" - " dataset", - UserWarning, - ) - - model_inputs = self.llm.pipeline.preprocess( - inputs_, **preprocess_params - ) - model_outputs = self.llm.pipeline.forward( - model_inputs, **forward_params - ) - model_outputs["process"] = False - return model_outputs - output = LLMResult(generations=generations) - run_manager_.on_llm_end(output) - if run_manager_: - output.run = RunInfo(run_id=run_manager_.run_id) - response = output - - outputs = [ - # Get the text of the top generated string. - {self.output_key: generation[0].text} - for generation in response.generations - ][0] - run_manager.on_chain_end(outputs) - final_outputs: Dict[str, Any] = self.prep_outputs( - inputs, outputs, return_only_outputs - ) - if include_run_info: - final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) - return final_outputs - else: - _run_manager = ( - run_manager or CallbackManagerForChainRun.get_noop_manager() - ) - docs = inputs[self.input_key] - # Other keys are assumed to be needed for LLM prediction - other_keys = { - k: v for k, v in inputs.items() if k != self.input_key - } - doc_strings = [ - format_document(doc, self.document_prompt) for doc in docs - ] - # Join the documents together to put them in the prompt. - inputs = { - k: v - for k, v in other_keys.items() - if k in self.llm_chain.prompt.input_variables - } - inputs[self.document_variable_name] = self.document_separator.join( - doc_strings - ) - inputs["is_first"] = False - inputs["input_documents"] = input_docs - - # Call predict on the LLM. - output = self.llm_chain(inputs, callbacks=_run_manager.get_child()) - if "process" in output.keys() and not output["process"]: - return output - output = output[self.llm_chain.output_key] - extra_return_dict = {} - extra_return_dict[self.output_key] = output - outputs = extra_return_dict - run_manager.on_chain_end(outputs) - final_outputs: Dict[str, Any] = self.prep_outputs( - inputs, outputs, return_only_outputs - ) - if include_run_info: - final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id) - return final_outputs - - def prep_outputs( - self, - inputs: Dict[str, str], - outputs: Dict[str, str], - return_only_outputs: bool = False, - ) -> Dict[str, str]: - """Validate and prep outputs.""" - self._validate_outputs(outputs) - if self.memory is not None: - self.memory.save_context(inputs, outputs) - if return_only_outputs: - return outputs - else: - return {**inputs, **outputs} - - def prep_inputs( - self, inputs: Union[Dict[str, Any], Any] - ) -> Dict[str, str]: - """Validate and prep inputs.""" - if not isinstance(inputs, dict): - _input_keys = set(self.input_keys) - if self.memory is not None: - # If there are multiple input keys, but some get set by memory so that - # only one is not set, we can still figure out which key it is. - _input_keys = _input_keys.difference( - self.memory.memory_variables - ) - if len(_input_keys) != 1: - raise ValueError( - f"A single string input was passed in, but this chain expects " - f"multiple inputs ({_input_keys}). When a chain expects " - f"multiple inputs, please call it by passing in a dictionary, " - "eg `chain({'foo': 1, 'bar': 2})`" - ) - inputs = {list(_input_keys)[0]: inputs} - if self.memory is not None: - external_context = self.memory.load_memory_variables(inputs) - inputs = dict(inputs, **external_context) - self._validate_inputs(inputs) - return inputs - - def apply( - self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None - ) -> List[Dict[str, str]]: - """Call the chain on all inputs in the list.""" - return [self(inputs, callbacks=callbacks) for inputs in input_list] - - def run( - self, - *args: Any, - callbacks: Callbacks = None, - tags: Optional[List[str]] = None, - **kwargs: Any, - ) -> str: - """Run the chain as text in, text out or multiple variables, text out.""" - if len(self.output_keys) != 1: - raise ValueError( - f"`run` not supported when there is not exactly " - f"one output key. Got {self.output_keys}." - ) - - if args and not kwargs: - if len(args) != 1: - raise ValueError( - "`run` supports only one positional argument." - ) - return self(args[0], callbacks=callbacks, tags=tags)[ - self.output_keys[0] - ] - - if kwargs and not args: - return self(kwargs, callbacks=callbacks, tags=tags)[ - self.output_keys[0] - ] - - if not kwargs and not args: - raise ValueError( - "`run` supported with either positional arguments or keyword arguments," - " but none were provided." - ) - - raise ValueError( - f"`run` supported with either positional arguments or keyword arguments" - f" but not both. Got args: {args} and kwargs: {kwargs}." - ) - - def dict(self, **kwargs: Any) -> Dict: - """Return dictionary representation of chain.""" - if self.memory is not None: - raise ValueError("Saving of memory is not yet supported.") - _dict = super().dict() - _dict["_type"] = self._chain_type - return _dict - - def save(self, file_path: Union[Path, str]) -> None: - """Save the chain. - - Args: - file_path: Path to file to save the chain to. - - Example: - .. code-block:: python - - chain.save(file_path="path/chain.yaml") - """ - # Convert file to Path object. - if isinstance(file_path, str): - save_path = Path(file_path) - else: - save_path = file_path - - directory_path = save_path.parent - directory_path.mkdir(parents=True, exist_ok=True) - - # Fetch dictionary to save - chain_dict = self.dict() - - if save_path.suffix == ".json": - with open(file_path, "w") as f: - json.dump(chain_dict, f, indent=4) - elif save_path.suffix == ".yaml": - with open(file_path, "w") as f: - yaml.dump(chain_dict, f, default_flow_style=False) - else: - raise ValueError(f"{save_path} must be json or yaml") - - -class BaseCombineDocumentsChain(Chain, ABC): - """Base interface for chains combining documents.""" - - input_key: str = "input_documents" #: :meta private: - output_key: str = "output_text" #: :meta private: - - @property - def input_keys(self) -> List[str]: - """Expect input key. - - :meta private: - """ - return [self.input_key] - - @property - def output_keys(self) -> List[str]: - """Return output key. - - :meta private: - """ - return [self.output_key] - - def prompt_length( - self, docs: List[Document], **kwargs: Any - ) -> Optional[int]: - """Return the prompt length given the documents passed in. - - Returns None if the method does not depend on the prompt length. - """ - return None - - def _call( - self, - inputs: Dict[str, List[Document]], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - _run_manager = ( - run_manager or CallbackManagerForChainRun.get_noop_manager() - ) - docs = inputs[self.input_key] - # Other keys are assumed to be needed for LLM prediction - other_keys = {k: v for k, v in inputs.items() if k != self.input_key} - doc_strings = [ - format_document(doc, self.document_prompt) for doc in docs - ] - # Join the documents together to put them in the prompt. - inputs = { - k: v - for k, v in other_keys.items() - if k in self.llm_chain.prompt.input_variables - } - inputs[self.document_variable_name] = self.document_separator.join( - doc_strings - ) - - # Call predict on the LLM. - output, extra_return_dict = ( - self.llm_chain(inputs, callbacks=_run_manager.get_child())[ - self.llm_chain.output_key - ], - {}, - ) - - extra_return_dict[self.output_key] = output - return extra_return_dict - - -from pydantic import BaseModel - - -class Generation(Serializable): - """Output of a single generation.""" - - text: str - """Generated text output.""" - - generation_info: Optional[Dict[str, Any]] = None - """Raw generation info response from the provider""" - """May include things like reason for finishing (e.g. in OpenAI)""" - # TODO: add log probs - - -VALID_TASKS = ("text2text-generation", "text-generation", "summarization") - - -class LLMChain(Chain): - """Chain to run queries against LLMs. - - Example: - .. code-block:: python - - from langchain import LLMChain, OpenAI, PromptTemplate - prompt_template = "Tell me a {adjective} joke" - prompt = PromptTemplate( - input_variables=["adjective"], template=prompt_template - ) - llm = LLMChain(llm=OpenAI(), prompt=prompt) - """ - - @property - def lc_serializable(self) -> bool: - return True - - prompt: BasePromptTemplate - """Prompt object to use.""" - llm: BaseLanguageModel - output_key: str = "text" #: :meta private: - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @property - def input_keys(self) -> List[str]: - """Will be whatever keys the prompt expects. - - :meta private: - """ - return self.prompt.input_variables - - @property - def output_keys(self) -> List[str]: - """Will always return text key. - - :meta private: - """ - return [self.output_key] - - def _call( - self, - inputs: Dict[str, Any], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Dict[str, str]: - prompts, stop = self.prep_prompts([inputs], run_manager=run_manager) - response = self.llm.generate_prompt( - prompts, - stop, - callbacks=run_manager.get_child() if run_manager else None, - ) - return self.create_outputs(response)[0] - - def prep_prompts( - self, - input_list: List[Dict[str, Any]], - run_manager: Optional[CallbackManagerForChainRun] = None, - ) -> Tuple[List[PromptValue], Optional[List[str]]]: - """Prepare prompts from inputs.""" - stop = None - if "stop" in input_list[0]: - stop = input_list[0]["stop"] - prompts = [] - for inputs in input_list: - selected_inputs = { - k: inputs[k] for k in self.prompt.input_variables - } - prompt = self.prompt.format_prompt(**selected_inputs) - _colored_text = get_colored_text(prompt.to_string(), "green") - _text = "Prompt after formatting:\n" + _colored_text - if run_manager: - run_manager.on_text(_text, end="\n", verbose=self.verbose) - if "stop" in inputs and inputs["stop"] != stop: - raise ValueError( - "If `stop` is present in any inputs, should be present in all." - ) - prompts.append(prompt) - return prompts, stop - - def apply( - self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None - ) -> List[Dict[str, str]]: - """Utilize the LLM generate method for speed gains.""" - callback_manager = CallbackManager.configure( - callbacks, self.callbacks, self.verbose - ) - run_manager = callback_manager.on_chain_start( - dumpd(self), - {"input_list": input_list}, - ) - try: - response = self.generate(input_list, run_manager=run_manager) - except (KeyboardInterrupt, Exception) as e: - run_manager.on_chain_error(e) - raise e - outputs = self.create_outputs(response) - run_manager.on_chain_end({"outputs": outputs}) - return outputs - - def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]: - """Create outputs from response.""" - return [ - # Get the text of the top generated string. - {self.output_key: generation[0].text} - for generation in response.generations - ] - - def predict_and_parse( - self, callbacks: Callbacks = None, **kwargs: Any - ) -> Union[str, List[str], Dict[str, Any]]: - """Call predict and then parse the results.""" - result = self.predict(callbacks=callbacks, **kwargs) - if self.prompt.output_parser is not None: - return self.prompt.output_parser.parse(result) - else: - return result - - def apply_and_parse( - self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None - ) -> Sequence[Union[str, List[str], Dict[str, str]]]: - """Call apply and then parse the results.""" - result = self.apply(input_list, callbacks=callbacks) - return self._parse_result(result) - - def _parse_result( - self, result: List[Dict[str, str]] - ) -> Sequence[Union[str, List[str], Dict[str, str]]]: - if self.prompt.output_parser is not None: - return [ - self.prompt.output_parser.parse(res[self.output_key]) - for res in result - ] - else: - return result - - @property - def _chain_type(self) -> str: - return "llm_chain" - - @classmethod - def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain: - """Create LLMChain from LLM and template.""" - prompt_template = PromptTemplate.from_template(template) - return cls(llm=llm, prompt=prompt_template) - - -def _get_default_document_prompt() -> PromptTemplate: - return PromptTemplate( - input_variables=["page_content"], template="{page_content}" - ) - - -class StuffDocumentsChain(BaseCombineDocumentsChain): - """Chain that combines documents by stuffing into context.""" - - llm_chain: LLMChain - """LLM wrapper to use after formatting documents.""" - document_prompt: BasePromptTemplate = Field( - default_factory=_get_default_document_prompt - ) - """Prompt to use to format each document.""" - document_variable_name: str - """The variable name in the llm_chain to put the documents in. - If only one variable in the llm_chain, this need not be provided.""" - document_separator: str = "\n\n" - """The string with which to join the formatted documents""" - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - arbitrary_types_allowed = True - - @root_validator(pre=True) - def get_default_document_variable_name(cls, values: Dict) -> Dict: - """Get default document variable name, if not provided.""" - llm_chain_variables = values["llm_chain"].prompt.input_variables - if "document_variable_name" not in values: - if len(llm_chain_variables) == 1: - values["document_variable_name"] = llm_chain_variables[0] - else: - raise ValueError( - "document_variable_name must be provided if there are " - "multiple llm_chain_variables" - ) - else: - if values["document_variable_name"] not in llm_chain_variables: - raise ValueError( - f"document_variable_name {values['document_variable_name']} was " - f"not found in llm_chain input_variables: {llm_chain_variables}" - ) - return values - - def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict: - # Format each document according to the prompt - doc_strings = [ - format_document(doc, self.document_prompt) for doc in docs - ] - # Join the documents together to put them in the prompt. - inputs = { - k: v - for k, v in kwargs.items() - if k in self.llm_chain.prompt.input_variables - } - inputs[self.document_variable_name] = self.document_separator.join( - doc_strings - ) - return inputs - - def prompt_length( - self, docs: List[Document], **kwargs: Any - ) -> Optional[int]: - """Get the prompt length by formatting the prompt.""" - inputs = self._get_inputs(docs, **kwargs) - prompt = self.llm_chain.prompt.format(**inputs) - return self.llm_chain.llm.get_num_tokens(prompt) - - @property - def _chain_type(self) -> str: - return "stuff_documents_chain" - - -class LoadingCallable(Protocol): - """Interface for loading the combine documents chain.""" - - def __call__( - self, llm: BaseLanguageModel, **kwargs: Any - ) -> BaseCombineDocumentsChain: - """Callable to load the combine documents chain.""" - - -def _load_stuff_chain( - llm: BaseLanguageModel, - prompt: Optional[BasePromptTemplate] = None, - document_variable_name: str = "context", - verbose: Optional[bool] = None, - callback_manager: Optional[BaseCallbackManager] = None, - callbacks: Callbacks = None, - **kwargs: Any, -) -> StuffDocumentsChain: - _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm) - llm_chain = LLMChain( - llm=llm, - prompt=_prompt, - verbose=verbose, - callback_manager=callback_manager, - callbacks=callbacks, - ) - # TODO: document prompt - return StuffDocumentsChain( - llm_chain=llm_chain, - document_variable_name=document_variable_name, - verbose=verbose, - callback_manager=callback_manager, - **kwargs, - ) - - -def load_qa_chain( - llm: BaseLanguageModel, - chain_type: str = "stuff", - verbose: Optional[bool] = None, - callback_manager: Optional[BaseCallbackManager] = None, - **kwargs: Any, -) -> BaseCombineDocumentsChain: - """Load question answering chain. - - Args: - llm: Language Model to use in the chain. - chain_type: Type of document combining chain to use. Should be one of "stuff", - "map_reduce", "map_rerank", and "refine". - verbose: Whether chains should be run in verbose mode or not. Note that this - applies to all chains that make up the final chain. - callback_manager: Callback manager to use for the chain. - - Returns: - A chain to use for question answering. - """ - loader_mapping: Mapping[str, LoadingCallable] = { - "stuff": _load_stuff_chain, - } - if chain_type not in loader_mapping: - raise ValueError( - f"Got unsupported chain type: {chain_type}. " - f"Should be one of {loader_mapping.keys()}" - ) - return loader_mapping[chain_type]( - llm, verbose=verbose, callback_manager=callback_manager, **kwargs - ) diff --git a/apps/language_models/langchain/gen.py b/apps/language_models/langchain/gen.py deleted file mode 100644 index 3b965235..00000000 --- a/apps/language_models/langchain/gen.py +++ /dev/null @@ -1,1945 +0,0 @@ -import ast -import copy -import functools -import glob -import inspect -import queue -import sys -import os -import time -import traceback -import types -import typing -import warnings -from datetime import datetime -import filelock -import requests -import psutil -from requests import ConnectTimeout, JSONDecodeError -from urllib3.exceptions import ( - ConnectTimeoutError, - MaxRetryError, - ConnectionError, -) -from requests.exceptions import ConnectionError as ConnectionError2 -from requests.exceptions import ReadTimeout as ReadTimeout2 - -if os.path.dirname(os.path.abspath(__file__)) not in sys.path: - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" -os.environ["BITSANDBYTES_NOWELCOME"] = "1" -warnings.filterwarnings( - "ignore", category=UserWarning, message="TypedStorage is deprecated" -) - -from evaluate_params import eval_func_param_names, no_default_param_names -from enums import ( - DocumentChoices, - LangChainMode, - no_lora_str, - model_token_mapping, - no_model_str, - source_prefix, - source_postfix, - LangChainAction, -) -from loaders import get_loaders -from utils import ( - set_seed, - clear_torch_cache, - save_generate_output, - NullContext, - wrapped_partial, - EThread, - get_githash, - import_matplotlib, - get_device, - makedirs, - get_kwargs, - start_faulthandler, - get_hf_server, - FakeTokenizer, - remove, -) - -start_faulthandler() -import_matplotlib() - -SEED = 1236 -set_seed(SEED) - -from typing import Union - -# import fire -import torch -from transformers import GenerationConfig, AutoModel, TextIteratorStreamer - -from prompter import ( - Prompter, - inv_prompt_type_to_model_lower, - non_hf_types, - PromptType, - get_prompt, - generate_prompt, -) -from stopping import get_stopping - -langchain_modes = [x.value for x in list(LangChainMode)] - -langchain_actions = [x.value for x in list(LangChainAction)] - -scratch_base_dir = "/tmp/" - - -class Langchain: - def __init__(self, device="cuda", precision="fp16"): - super().__init__() - self.device = device - self.precision = precision - - def get_config( - self, - base_model, - use_auth_token=False, - trust_remote_code=True, - offload_folder=None, - triton_attn=False, - long_sequence=True, - return_model=False, - raise_exception=False, - ): - from accelerate import init_empty_weights - - with init_empty_weights(): - from transformers import AutoConfig - - try: - config = AutoConfig.from_pretrained( - base_model, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - ) - except OSError as e: - if raise_exception: - raise - if "not a local folder and is not a valid model identifier listed on" in str( - e - ) or "404 Client Error" in str( - e - ): - # e.g. llama, gpjt, etc. - # e.g. HF TGI but not model on HF or private etc. - # HF TGI server only should really require prompt_type, not HF model state - return None, None - else: - raise - if triton_attn and "mpt-" in base_model.lower(): - config.attn_config["attn_impl"] = "triton" - if long_sequence: - if "mpt-7b-storywriter" in base_model.lower(): - config.update({"max_seq_len": 83968}) - if "mosaicml/mpt-7b-chat" in base_model.lower(): - config.update({"max_seq_len": 4096}) - if "mpt-30b" in base_model.lower(): - config.update({"max_seq_len": 2 * 8192}) - if return_model and issubclass( - config.__class__, tuple(AutoModel._model_mapping.keys()) - ): - model = AutoModel.from_config( - config, - trust_remote_code=trust_remote_code, - ) - else: - # can't infer - model = None - if "falcon" in base_model.lower(): - config.use_cache = False - - return config, model - - def get_non_lora_model( - self, - base_model, - model_loader, - load_half, - load_gptq, - use_safetensors, - model_kwargs, - reward_type, - config, - model, - gpu_id=0, - ): - """ - Ensure model gets on correct device - """ - - device_map = None - if model is not None: - # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model - # NOTE: Some models require avoiding sharding some layers, - # then would pass no_split_module_classes and give list of those layers. - from accelerate import infer_auto_device_map - - device_map = infer_auto_device_map( - model, - dtype=torch.float16 if load_half else torch.float32, - ) - if hasattr(model, "model"): - device_map_model = infer_auto_device_map( - model.model, - dtype=torch.float16 if load_half else torch.float32, - ) - device_map.update(device_map_model) - - n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0 - - if device_map is None: - if self.device == "cuda": - if n_gpus > 0: - if gpu_id >= 0: - # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set. - # So avoid for now, just put on first GPU, unless score_model, put on last - if reward_type: - device_map = {"": n_gpus - 1} - else: - device_map = {"": min(n_gpus - 1, gpu_id)} - if gpu_id == -1: - device_map = {"": "cuda"} - else: - device_map = {"": "cpu"} - model_kwargs["load_in_8bit"] = False - model_kwargs["load_in_4bit"] = False - print("device_map: %s" % device_map, flush=True) - - load_in_8bit = model_kwargs.get("load_in_8bit", False) - load_in_4bit = model_kwargs.get("load_in_4bit", False) - model_kwargs["device_map"] = device_map - model_kwargs["use_safetensors"] = use_safetensors - self.pop_unused_model_kwargs(model_kwargs) - - if load_gptq: - model_kwargs.pop("torch_dtype", None) - model_kwargs.pop("device_map") - model = model_loader( - model_name_or_path=base_model, - model_basename=load_gptq, - **model_kwargs, - ) - elif load_in_8bit or load_in_4bit or not load_half: - model = model_loader( - base_model, - config=config, - **model_kwargs, - ) - else: - model = model_loader( - base_model, - config=config, - **model_kwargs, - ).half() - return model - - def get_client_from_inference_server( - self, - inference_server, - base_model=None, - raise_connection_exception=False, - ): - inference_server, headers = get_hf_server(inference_server) - # preload client since slow for gradio case especially - from gradio_utils.grclient import GradioClient - - gr_client = None - hf_client = None - if headers is None: - try: - print( - "GR Client Begin: %s %s" % (inference_server, base_model), - flush=True, - ) - # first do sanity check if alive, else gradio client takes too long by default - requests.get( - inference_server, - timeout=int(os.getenv("REQUEST_TIMEOUT", "30")), - ) - gr_client = GradioClient(inference_server) - print("GR Client End: %s" % inference_server, flush=True) - except (OSError, ValueError) as e: - # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF - gr_client = None - print( - "GR Client Failed %s %s: %s" - % (inference_server, base_model, str(e)), - flush=True, - ) - except ( - ConnectTimeoutError, - ConnectTimeout, - MaxRetryError, - ConnectionError, - ConnectionError2, - JSONDecodeError, - ReadTimeout2, - KeyError, - ) as e: - t, v, tb = sys.exc_info() - ex = "".join(traceback.format_exception(t, v, tb)) - print( - "GR Client Failed %s %s: %s" - % (inference_server, base_model, str(ex)), - flush=True, - ) - if raise_connection_exception: - raise - - if gr_client is None: - res = None - from text_generation import Client as HFClient - - print("HF Client Begin: %s %s" % (inference_server, base_model)) - try: - hf_client = HFClient( - inference_server, - headers=headers, - timeout=int(os.getenv("REQUEST_TIMEOUT", "30")), - ) - # quick check valid TGI endpoint - res = hf_client.generate("What?", max_new_tokens=1) - hf_client = HFClient( - inference_server, headers=headers, timeout=300 - ) - except ( - ConnectTimeoutError, - ConnectTimeout, - MaxRetryError, - ConnectionError, - ConnectionError2, - JSONDecodeError, - ReadTimeout2, - KeyError, - ) as e: - hf_client = None - t, v, tb = sys.exc_info() - ex = "".join(traceback.format_exception(t, v, tb)) - print( - "HF Client Failed %s %s: %s" - % (inference_server, base_model, str(ex)) - ) - if raise_connection_exception: - raise - print( - "HF Client End: %s %s : %s" - % (inference_server, base_model, res) - ) - return inference_server, gr_client, hf_client - - def get_model( - self, - load_8bit: bool = False, - load_4bit: bool = False, - load_half: bool = False, - load_gptq: str = "", - use_safetensors: bool = False, - infer_devices: bool = True, - device: str = None, - base_model: str = "", - inference_server: str = "", - tokenizer_base_model: str = "", - lora_weights: str = "", - gpu_id: int = 0, - reward_type: bool = None, - local_files_only: bool = False, - resume_download: bool = True, - use_auth_token: Union[str, bool] = False, - trust_remote_code: bool = True, - offload_folder: str = None, - compile_model: bool = True, - verbose: bool = False, - ): - """ - - :param load_8bit: load model in 8-bit, not supported by all models - :param load_4bit: load model in 4-bit, not supported by all models - :param load_half: load model in 16-bit - :param load_gptq: GPTQ model_basename - :param use_safetensors: use safetensors file - :param infer_devices: Use torch infer of optimal placement of layers on devices (for non-lora case) - For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches - So it is not the default - :param base_model: name/path of base model - :param inference_server: whether base_model is hosted locally ('') or via http (url) - :param tokenizer_base_model: name/path of tokenizer - :param lora_weights: name/path - :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1) - :param reward_type: reward type model for sequence classification - :param local_files_only: use local files instead of from HF - :param resume_download: resume downloads from HF - :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo - :param trust_remote_code: trust code needed by model - :param offload_folder: offload folder - :param compile_model: whether to compile torch model - :param verbose: - :return: - """ - if verbose: - print("Get %s model" % base_model, flush=True) - - triton_attn = False - long_sequence = True - config_kwargs = dict( - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - triton_attn=triton_attn, - long_sequence=long_sequence, - ) - config, _ = self.get_config( - base_model, **config_kwargs, raise_exception=False - ) - - if base_model in non_hf_types: - assert config is None, "Expected config None for %s" % base_model - - llama_type_from_config = "llama" in str(config).lower() - llama_type_from_name = "llama" in base_model.lower() - llama_type = llama_type_from_config or llama_type_from_name - if "xgen" in base_model.lower(): - llama_type = False - if llama_type: - if verbose: - print( - "Detected as llama type from" - " config (%s) or name (%s)" - % (llama_type_from_config, llama_type_from_name), - flush=True, - ) - - model_loader, tokenizer_loader = get_loaders( - model_name=base_model, - reward_type=reward_type, - llama_type=llama_type, - load_gptq=load_gptq, - ) - - tokenizer_kwargs = dict( - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - padding_side="left", - config=config, - ) - if not tokenizer_base_model: - tokenizer_base_model = base_model - - if ( - config is not None - and tokenizer_loader is not None - and not isinstance(tokenizer_loader, str) - ): - tokenizer = tokenizer_loader.from_pretrained( - tokenizer_base_model, **tokenizer_kwargs - ) - # sets raw (no cushion) limit - self.set_model_max_len(config, tokenizer, verbose=False) - # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get: - # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233 - tokenizer.model_max_length = tokenizer.model_max_length - 50 - else: - tokenizer = FakeTokenizer() - - if isinstance(inference_server, str) and inference_server.startswith( - "http" - ): - ( - inference_server, - gr_client, - hf_client, - ) = self.get_client_from_inference_server( - inference_server, base_model=base_model - ) - client = gr_client or hf_client - # Don't return None, None for model, tokenizer so triggers - return client, tokenizer, "http" - if isinstance(inference_server, str) and inference_server.startswith( - "openai" - ): - assert os.getenv( - "OPENAI_API_KEY" - ), "Set environment for OPENAI_API_KEY" - # Don't return None, None for model, tokenizer so triggers - # include small token cushion - tokenizer = FakeTokenizer( - model_max_length=model_token_mapping[base_model] - 50 - ) - return inference_server, tokenizer, inference_server - assert not inference_server, ( - "Malformed inference_server=%s" % inference_server - ) - if base_model in non_hf_types: - from gpt4all_llm import get_model_tokenizer_gpt4all - - model, tokenizer, _ = get_model_tokenizer_gpt4all(base_model) - return model, tokenizer, self.device - - # get local torch-HF model - return self.get_hf_model( - load_8bit=load_8bit, - load_4bit=load_4bit, - load_half=load_half, - load_gptq=load_gptq, - use_safetensors=use_safetensors, - infer_devices=infer_devices, - device=self.device, - base_model=base_model, - tokenizer_base_model=tokenizer_base_model, - lora_weights=lora_weights, - gpu_id=gpu_id, - reward_type=reward_type, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - compile_model=compile_model, - llama_type=llama_type, - config_kwargs=config_kwargs, - tokenizer_kwargs=tokenizer_kwargs, - verbose=verbose, - ) - - def get_hf_model( - self, - load_8bit: bool = False, - load_4bit: bool = False, - load_half: bool = True, - load_gptq: str = "", - use_safetensors: bool = False, - infer_devices: bool = True, - device: str = None, - base_model: str = "", - tokenizer_base_model: str = "", - lora_weights: str = "", - gpu_id: int = 0, - reward_type: bool = None, - local_files_only: bool = False, - resume_download: bool = True, - use_auth_token: Union[str, bool] = False, - trust_remote_code: bool = True, - offload_folder: str = None, - compile_model: bool = True, - llama_type: bool = False, - config_kwargs=None, - tokenizer_kwargs=None, - verbose: bool = False, - ): - assert config_kwargs is not None - assert tokenizer_kwargs is not None - - if lora_weights is not None and lora_weights.strip(): - if verbose: - print("Get %s lora weights" % lora_weights, flush=True) - - if "gpt2" in base_model.lower(): - # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half - load_8bit = False - load_4bit = False - - assert ( - base_model.strip() - ), "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)" - - model_loader, tokenizer_loader = get_loaders( - model_name=base_model, - reward_type=reward_type, - llama_type=llama_type, - load_gptq=load_gptq, - ) - - config, _ = self.get_config( - base_model, - return_model=False, - raise_exception=True, - **config_kwargs, - ) - - if tokenizer_loader is not None and not isinstance( - tokenizer_loader, str - ): - tokenizer = tokenizer_loader.from_pretrained( - tokenizer_base_model, **tokenizer_kwargs - ) - else: - tokenizer = tokenizer_loader - - if isinstance(tokenizer, str): - # already a pipeline, tokenizer_loader is string for task - model = model_loader( - tokenizer, - model=base_model, - device=0 if self.device == "cuda" else -1, - torch_dtype=torch.float16 - if self.device == "cuda" - else torch.float32, - ) - else: - assert self.device in ["cuda", "cpu", "mps"], ( - "Unsupported device %s" % self.device - ) - model_kwargs = dict( - local_files_only=local_files_only, - torch_dtype=torch.float16 - if self.device == "cuda" - else torch.float32, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - ) - if ( - "mbart-" not in base_model.lower() - and "mpt-" not in base_model.lower() - ): - if ( - infer_devices - and gpu_id is not None - and gpu_id >= 0 - and self.device == "cuda" - ): - device_map = {"": gpu_id} - else: - device_map = "auto" - model_kwargs.update( - dict( - load_in_8bit=load_8bit, - load_in_4bit=load_4bit, - device_map=device_map, - ) - ) - if ( - "mpt-" in base_model.lower() - and gpu_id is not None - and gpu_id >= 0 - ): - # MPT doesn't support spreading over GPUs - model_kwargs.update( - dict( - device_map={"": gpu_id} - if self.device == "cuda" - else "cpu" - ) - ) - - if "OpenAssistant/reward-model".lower() in base_model.lower(): - # FIXME: could put on other GPUs - model_kwargs["device_map"] = ( - {"": 0} if self.device == "cuda" else {"": "cpu"} - ) - model_kwargs.pop("torch_dtype", None) - self.pop_unused_model_kwargs(model_kwargs) - - if not lora_weights: - # torch.device context uses twice memory for AutoGPTQ - context = NullContext if load_gptq else torch.device - with context(self.device): - if infer_devices: - config, model = self.get_config( - base_model, - return_model=True, - raise_exception=True, - **config_kwargs, - ) - model = self.get_non_lora_model( - base_model, - model_loader, - load_half, - load_gptq, - use_safetensors, - model_kwargs, - reward_type, - config, - model, - gpu_id=gpu_id, - ) - else: - config, _ = self.get_config( - base_model, **config_kwargs - ) - if load_half and not ( - load_8bit or load_4bit or load_gptq - ): - model = model_loader( - base_model, config=config, **model_kwargs - ).half() - else: - model = model_loader( - base_model, config=config, **model_kwargs - ) - elif load_8bit or load_4bit: - config, _ = self.get_config(base_model, **config_kwargs) - model = model_loader(base_model, config=config, **model_kwargs) - from peft import ( - PeftModel, - ) # loads cuda, so avoid in global scope - - model = PeftModel.from_pretrained( - model, - lora_weights, - torch_dtype=torch.float16 - if self.device == "cuda" - else torch.float32, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - device_map={"": 0} - if self.device == "cuda" - else {"": "cpu"}, # seems to be required - ) - else: - with torch.device(self.device): - config, _ = self.get_config( - base_model, raise_exception=True, **config_kwargs - ) - model = model_loader( - base_model, config=config, **model_kwargs - ) - from peft import ( - PeftModel, - ) # loads cuda, so avoid in global scope - - model = PeftModel.from_pretrained( - model, - lora_weights, - torch_dtype=torch.float16 - if self.device == "cuda" - else torch.float32, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - device_map="auto", - ) - if load_half and not load_gptq: - model.half() - - # unwind broken decapoda-research config - if llama_type: - model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk - model.config.bos_token_id = 1 - model.config.eos_token_id = 2 - if "gpt2" in base_model.lower(): - # add special tokens that otherwise all share the same id - tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "pad_token": "", - } - ) - - if not isinstance(tokenizer, str): - model.eval() - # if torch.__version__ >= "2" and sys.platform != "win32" and compile_model: - # model = torch.compile(model) - - self.set_model_max_len( - config, tokenizer, verbose=False, reward_type=reward_type - ) - - return model, tokenizer, self.device - - def set_model_max_len( - self, config, tokenizer, verbose=False, reward_type=False - ): - if reward_type: - # limit deberta, else uses too much memory and not worth response score - tokenizer.model_max_length = 512 - if hasattr(config, "max_seq_len") and isinstance( - config.max_seq_len, int - ): - tokenizer.model_max_length = config.max_seq_len - elif hasattr(config, "max_position_embeddings") and isinstance( - config.max_position_embeddings, int - ): - # help automatically limit inputs to generate - tokenizer.model_max_length = config.max_position_embeddings - else: - if verbose: - print( - "Could not determine model_max_length, setting to 2048", - flush=True, - ) - tokenizer.model_max_length = 2048 - # for bug in HF transformers - if tokenizer.model_max_length > 100000000: - tokenizer.model_max_length = 2048 - - def pop_unused_model_kwargs(self, model_kwargs): - """ - in-place pop unused kwargs that are not dependency-upgrade friendly - no point passing in False, is default, and helps avoid needing to update requirements for new deps - :param model_kwargs: - :return: - """ - check_list = ["load_in_8bit", "load_in_4bit"] - for k in check_list: - if k in model_kwargs and not model_kwargs[k]: - model_kwargs.pop(k) - - def get_score_model( - self, - score_model: str = None, - load_8bit: bool = False, - load_4bit: bool = False, - load_half: bool = True, - load_gptq: str = "", - infer_devices: bool = True, - base_model: str = "", - inference_server: str = "", - tokenizer_base_model: str = "", - lora_weights: str = "", - gpu_id: int = 0, - reward_type: bool = None, - local_files_only: bool = False, - resume_download: bool = True, - use_auth_token: Union[str, bool] = False, - trust_remote_code: bool = True, - offload_folder: str = None, - compile_model: bool = True, - verbose: bool = False, - ): - if score_model is not None and score_model.strip(): - load_8bit = False - load_4bit = False - load_half = False - load_gptq = "" - use_safetensors = False - base_model = score_model.strip() - tokenizer_base_model = "" - lora_weights = "" - inference_server = "" - llama_type = False - compile_model = False - smodel, stokenizer, _ = self.get_model( - reward_type=True, - **get_kwargs( - self.get_model, exclude_names=["reward_type"], **locals() - ), - ) - else: - smodel, stokenizer, _ = None, None, None - return smodel, stokenizer, self.device - - def evaluate( - self, - model_state, - my_db_state, - # START NOTE: Examples must have same order of parameters - instruction, - iinput, - context, - stream_output, - prompt_type, - prompt_dict, - temperature, - top_p, - top_k, - num_beams, - max_new_tokens, - min_new_tokens, - early_stopping, - max_time, - repetition_penalty, - num_return_sequences, - do_sample, - chat, - instruction_nochat, - iinput_nochat, - langchain_mode, - langchain_action, - top_k_docs, - chunk, - chunk_size, - document_choice, - # END NOTE: Examples must have same order of parameters - src_lang=None, - tgt_lang=None, - debug=False, - concurrency_count=None, - save_dir=None, - sanitize_bot_response=False, - model_state0=None, - memory_restriction_level=None, - max_max_new_tokens=None, - is_public=None, - max_max_time=None, - raise_generate_gpu_exceptions=None, - chat_context=None, - lora_weights=None, - load_db_if_exists=True, - dbs=None, - user_path=None, - detect_user_path_changes_every_query=None, - use_openai_embedding=None, - use_openai_model=None, - hf_embedding_model=None, - db_type=None, - n_jobs=None, - first_para=None, - text_limit=None, - verbose=False, - cli=False, - reverse_docs=True, - use_cache=None, - auto_reduce_chunks=None, - max_chunks=None, - model_lock=None, - force_langchain_evaluate=None, - model_state_none=None, - ): - # ensure passed these - assert concurrency_count is not None - assert memory_restriction_level is not None - assert raise_generate_gpu_exceptions is not None - assert chat_context is not None - assert use_openai_embedding is not None - assert use_openai_model is not None - assert hf_embedding_model is not None - assert db_type is not None - assert top_k_docs is not None and isinstance(top_k_docs, int) - assert chunk is not None and isinstance(chunk, bool) - assert chunk_size is not None and isinstance(chunk_size, int) - assert n_jobs is not None - assert first_para is not None - - if debug: - locals_dict = locals().copy() - locals_dict.pop("model_state", None) - locals_dict.pop("model_state0", None) - locals_dict.pop("model_states", None) - print(locals_dict) - - no_model_msg = ( - "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\n" - "Then start New Conversation" - ) - - if model_state is None: - model_state = model_state_none.copy() - if model_state0 is None: - # e.g. for no gradio case, set dummy value, else should be set - model_state0 = model_state_none.copy() - - # model_state['model] is only 'model' if should use model_state0 - # model could also be None - have_model_lock = model_lock is not None - have_fresh_model = model_state["model"] not in [ - None, - "model", - no_model_str, - ] - # for gradio UI control, expect model_state and model_state0 to match, so if have_model_lock=True, then should have_fresh_model=True - # but gradio API control will only use nochat api etc. and won't use fresh model, so can't assert in general - # if have_model_lock: - # assert have_fresh_model, "Expected model_state and model_state0 to match if have_model_lock" - have_cli_model = model_state0["model"] not in [ - None, - "model", - no_model_str, - ] - - if have_fresh_model: - # USE FRESH MODEL - if not have_model_lock: - # model_state0 is just one of model_state if model_lock, so don't nuke - # try to free-up original model (i.e. list was passed as reference) - if model_state0["model"] and hasattr( - model_state0["model"], "cpu" - ): - model_state0["model"].cpu() - model_state0["model"] = None - # try to free-up original tokenizer (i.e. list was passed as reference) - if model_state0["tokenizer"]: - model_state0["tokenizer"] = None - clear_torch_cache() - chosen_model_state = model_state - elif have_cli_model: - # USE MODEL SETUP AT CLI - assert isinstance( - model_state["model"], str - ) # expect no fresh model - chosen_model_state = model_state0 - else: - raise AssertionError(no_model_msg) - # get variables - model = chosen_model_state["model"] - tokenizer = chosen_model_state["tokenizer"] - base_model = chosen_model_state["base_model"] - tokenizer_base_model = chosen_model_state["tokenizer_base_model"] - lora_weights = chosen_model_state["lora_weights"] - inference_server = chosen_model_state["inference_server"] - # prefer use input from API over model state - prompt_type = prompt_type or chosen_model_state["prompt_type"] - prompt_dict = prompt_dict or chosen_model_state["prompt_dict"] - - if base_model is None: - raise AssertionError(no_model_msg) - - assert base_model.strip(), no_model_msg - assert model, "Model is missing" - assert tokenizer, "Tokenizer is missing" - - # choose chat or non-chat mode - print(instruction) - if not chat: - instruction = instruction_nochat - iinput = iinput_nochat - print(instruction) - - # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice - model_lower = base_model.lower() - if ( - not prompt_type - and model_lower in inv_prompt_type_to_model_lower - and prompt_type != "custom" - ): - prompt_type = inv_prompt_type_to_model_lower[model_lower] - if verbose: - print( - "Auto-selecting prompt_type=%s for %s" - % (prompt_type, model_lower), - flush=True, - ) - assert prompt_type is not None, "prompt_type was None" - - # Control generation hyperparameters - # adjust for bad inputs, e.g. in case also come from API that doesn't get constrained by gradio sliders - # below is for TGI server, not required for HF transformers - # limits are chosen similar to gradio_runner.py sliders/numbers - top_p = min(max(1e-3, top_p), 1.0 - 1e-3) - top_k = min(max(1, int(top_k)), 100) - temperature = min(max(0.01, temperature), 2.0) - # FIXME: https://github.com/h2oai/h2ogpt/issues/106 - num_beams = ( - 1 if stream_output else num_beams - ) # See max_beams in gradio_runner - max_max_new_tokens = self.get_max_max_new_tokens( - chosen_model_state, - memory_restriction_level=memory_restriction_level, - max_new_tokens=max_new_tokens, - max_max_new_tokens=max_max_new_tokens, - ) - model_max_length = 2048 # get_model_max_length(chosen_model_state) - max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens) - min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens) - max_time = min(max(0, max_time), max_max_time) - repetition_penalty = min(max(0.01, repetition_penalty), 3.0) - num_return_sequences = ( - 1 if chat else min(max(1, int(num_return_sequences)), 10) - ) - ( - min_top_k_docs, - max_top_k_docs, - label_top_k_docs, - ) = self.get_minmax_top_k_docs(is_public) - top_k_docs = min(max(min_top_k_docs, int(top_k_docs)), max_top_k_docs) - chunk_size = min(max(128, int(chunk_size)), 2048) - if not context: - # get hidden context if have one - context = self.get_context(chat_context, prompt_type) - - # restrict instruction, typically what has large input - from h2oai_pipeline import H2OTextGenerationPipeline - - print(instruction) - ( - instruction, - num_prompt_tokens1, - ) = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer) - context, num_prompt_tokens2 = H2OTextGenerationPipeline.limit_prompt( - context, tokenizer - ) - iinput, num_prompt_tokens3 = H2OTextGenerationPipeline.limit_prompt( - iinput, tokenizer - ) - num_prompt_tokens = ( - (num_prompt_tokens1 or 0) - + (num_prompt_tokens2 or 0) - + (num_prompt_tokens3 or 0) - ) - - # get prompt - prompter = Prompter( - prompt_type, - prompt_dict, - debug=debug, - chat=chat, - stream_output=stream_output, - ) - data_point = dict( - context=context, instruction=instruction, input=iinput - ) - prompt = prompter.generate_prompt(data_point) - - # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use - assert langchain_mode in langchain_modes, ( - "Invalid langchain_mode %s" % langchain_mode - ) - assert langchain_action in langchain_actions, ( - "Invalid langchain_action %s" % langchain_action - ) - if ( - langchain_mode in ["MyData"] - and my_db_state is not None - and len(my_db_state) > 0 - and my_db_state[0] is not None - ): - db1 = my_db_state[0] - elif dbs is not None and langchain_mode in dbs: - db1 = dbs[langchain_mode] - else: - db1 = None - do_langchain_path = ( - langchain_mode not in [False, "Disabled", "ChatLLM", "LLM"] - or base_model in non_hf_types - or force_langchain_evaluate - ) - if do_langchain_path: - outr = "" - # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close - from gpt_langchain import run_qa_db - - gen_hyper_langchain = dict( - do_sample=do_sample, - temperature=temperature, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, - num_beams=num_beams, - min_new_tokens=min_new_tokens, - max_new_tokens=max_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - num_return_sequences=num_return_sequences, - ) - out = run_qa_db( - query=instruction, - iinput=iinput, - context=context, - model_name=base_model, - model=model, - tokenizer=tokenizer, - inference_server=inference_server, - stream_output=stream_output, - prompter=prompter, - load_db_if_exists=load_db_if_exists, - db=db1, - user_path=user_path, - detect_user_path_changes_every_query=detect_user_path_changes_every_query, - cut_distanct=1.1 - if langchain_mode in ["wiki_full"] - else 1.64, # FIXME, too arbitrary - use_openai_embedding=use_openai_embedding, - use_openai_model=use_openai_model, - hf_embedding_model=hf_embedding_model, - first_para=first_para, - text_limit=text_limit, - chunk=chunk, - chunk_size=chunk_size, - langchain_mode=langchain_mode, - langchain_action=langchain_action, - document_choice=document_choice, - db_type=db_type, - top_k_docs=top_k_docs, - **gen_hyper_langchain, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - n_jobs=n_jobs, - verbose=verbose, - cli=cli, - sanitize_bot_response=sanitize_bot_response, - reverse_docs=reverse_docs, - lora_weights=lora_weights, - auto_reduce_chunks=auto_reduce_chunks, - max_chunks=max_chunks, - device=self.device, - ) - return out - - inputs_list_names = list(inspect.signature(evaluate).parameters) - global inputs_kwargs_list - inputs_kwargs_list = [ - x - for x in inputs_list_names - if x not in eval_func_param_names + ["model_state", "my_db_state"] - ] - - def get_cutoffs( - self, - memory_restriction_level, - for_context=False, - model_max_length=2048, - ): - # help to avoid errors like: - # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3 - # RuntimeError: expected scalar type Half but found Float - # with - 256 - if memory_restriction_level > 0: - max_length_tokenize = ( - 768 - 256 if memory_restriction_level <= 2 else 512 - 256 - ) - else: - # at least give room for 1 paragraph output - max_length_tokenize = model_max_length - 256 - cutoff_len = ( - max_length_tokenize * 4 - ) # if reaches limit, then can't generate new tokens - output_smallest = 30 * 4 - max_prompt_length = cutoff_len - output_smallest - - if for_context: - # then lower even more to avoid later chop, since just estimate tokens in context bot - max_prompt_length = max(64, int(max_prompt_length * 0.8)) - - return ( - cutoff_len, - output_smallest, - max_length_tokenize, - max_prompt_length, - ) - - def generate_with_exceptions( - self, - func, - *args, - prompt="", - inputs_decoded="", - raise_generate_gpu_exceptions=True, - **kwargs, - ): - try: - func(*args, **kwargs) - except torch.cuda.OutOfMemoryError as e: - print( - "GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" - % (prompt, inputs_decoded, str(e)), - flush=True, - ) - if "input_ids" in kwargs: - if kwargs["input_ids"] is not None: - kwargs["input_ids"].cpu() - kwargs["input_ids"] = None - traceback.print_exc() - clear_torch_cache() - return - except (Exception, RuntimeError) as e: - if ( - "Expected all tensors to be on the same device" in str(e) - or "expected scalar type Half but found Float" in str(e) - or "probability tensor contains either" in str(e) - or "cublasLt ran into an error!" in str(e) - or "mat1 and mat2 shapes cannot be multiplied" in str(e) - ): - print( - "GPU Error: prompt: %s inputs_decoded: %s exception: %s" - % (prompt, inputs_decoded, str(e)), - flush=True, - ) - traceback.print_exc() - clear_torch_cache() - if raise_generate_gpu_exceptions: - raise - return - else: - clear_torch_cache() - if raise_generate_gpu_exceptions: - raise - - def get_generate_params( - self, - model_lower, - chat, - stream_output, - show_examples, - prompt_type, - prompt_dict, - temperature, - top_p, - top_k, - num_beams, - max_new_tokens, - min_new_tokens, - early_stopping, - max_time, - repetition_penalty, - num_return_sequences, - do_sample, - top_k_docs, - chunk, - chunk_size, - verbose, - ): - use_defaults = False - use_default_examples = True - examples = [] - task_info = "LLM" - if model_lower: - print(f"Using Model {model_lower}", flush=True) - else: - if verbose: - print("No model defined yet", flush=True) - - min_new_tokens = min_new_tokens if min_new_tokens is not None else 0 - early_stopping = ( - early_stopping if early_stopping is not None else False - ) - max_time_defaults = 60 * 3 - max_time = max_time if max_time is not None else max_time_defaults - - if ( - not prompt_type - and model_lower in inv_prompt_type_to_model_lower - and prompt_type != "custom" - ): - prompt_type = inv_prompt_type_to_model_lower[model_lower] - if verbose: - print( - "Auto-selecting prompt_type=%s for %s" - % (prompt_type, model_lower), - flush=True, - ) - - # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end - if show_examples is None: - if chat: - show_examples = False - else: - show_examples = True - - summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker? - Philipp: Sure you can use the new Hugging Face Deep Learning Container. - Jeff: ok. - Jeff: and how can I get started? - Jeff: where can I find documentation? - Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face""" - - use_placeholder_instruction_as_example = False - if ( - "bart-large-cnn-samsum" in model_lower - or "flan-t5-base-samsum" in model_lower - ): - placeholder_instruction = summarize_example1 - placeholder_input = "" - use_defaults = True - use_default_examples = False - use_placeholder_instruction_as_example = True - task_info = "Summarization" - elif ( - "t5-" in model_lower - or "t5" == model_lower - or "flan-" in model_lower - ): - placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?" - placeholder_input = "" - use_defaults = True - use_default_examples = True - task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)" - elif "mbart-" in model_lower: - placeholder_instruction = "The girl has long hair." - placeholder_input = "" - use_defaults = True - use_default_examples = False - use_placeholder_instruction_as_example = True - elif "gpt2" in model_lower: - placeholder_instruction = "The sky is" - placeholder_input = "" - prompt_type = prompt_type or "plain" - use_default_examples = ( - True # some will be odd "continuations" but can be ok - ) - use_placeholder_instruction_as_example = True - task_info = "Auto-complete phrase, code, etc." - use_defaults = True - else: - if chat: - placeholder_instruction = "" - else: - placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter." - placeholder_input = "" - if model_lower in inv_prompt_type_to_model_lower: - if prompt_type != "custom": - prompt_type = inv_prompt_type_to_model_lower[model_lower] - elif model_lower: - # default is plain, because might rely upon trust_remote_code to handle prompting - prompt_type = prompt_type or "plain" - else: - prompt_type = "" - task_info = "No task" - if prompt_type == "instruct": - task_info = "Answer question or follow imperative as instruction with optionally input." - elif prompt_type == "plain": - task_info = "Auto-complete phrase, code, etc." - elif prompt_type == "human_bot": - if chat: - task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)" - else: - task_info = "Ask question/imperative (input concatenated with instruction)" - - # revert to plain if still nothing - prompt_type = prompt_type or "plain" - if use_defaults: - temperature = 1.0 if temperature is None else temperature - top_p = 1.0 if top_p is None else top_p - top_k = 40 if top_k is None else top_k - num_beams = num_beams or 1 - max_new_tokens = max_new_tokens or 128 - repetition_penalty = repetition_penalty or 1.07 - num_return_sequences = min(num_beams, num_return_sequences or 1) - do_sample = False if do_sample is None else do_sample - else: - temperature = 0.1 if temperature is None else temperature - top_p = 0.75 if top_p is None else top_p - top_k = 40 if top_k is None else top_k - num_beams = num_beams or 1 - max_new_tokens = max_new_tokens or 256 - repetition_penalty = repetition_penalty or 1.07 - num_return_sequences = min(num_beams, num_return_sequences or 1) - do_sample = False if do_sample is None else do_sample - # doesn't include chat, instruction_nochat, iinput_nochat, added later - params_list = [ - "", - stream_output, - prompt_type, - prompt_dict, - temperature, - top_p, - top_k, - num_beams, - max_new_tokens, - min_new_tokens, - early_stopping, - max_time, - repetition_penalty, - num_return_sequences, - do_sample, - ] - - if use_placeholder_instruction_as_example: - examples += [[placeholder_instruction, ""] + params_list] - - if use_default_examples: - examples += [ - ["Translate English to French", "Good morning"] + params_list, - [ - "Give detailed answer for whether Einstein or Newton is smarter.", - "", - ] - + params_list, - [ - "Explain in detailed list, all the best practices for coding in python.", - "", - ] - + params_list, - [ - "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.", - "", - ] - + params_list, - ["Translate to German: My name is Arthur", ""] + params_list, - [ - "Please answer to the following question. Who is going to be the next Ballon d'or?", - "", - ] - + params_list, - [ - "Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.", - "", - ] - + params_list, - [ - "Please answer the following question. What is the boiling point of Nitrogen?", - "", - ] - + params_list, - [ - "Answer the following yes/no question. Can you write a whole Haiku in a single tweet?", - "", - ] - + params_list, - [ - "Simplify the following expression: (False or False and True). Explain your answer.", - "", - ] - + params_list, - [ - "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?", - "", - ] - + params_list, - [ - "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?", - "", - ] - + params_list, - [ - "Answer the following question by reasoning step by step. The cafeteria had 23 apples. If they used 20 for lunch, and bought 6 more, how many apple do they have?", - "", - ] - + params_list, - [ - """def area_of_rectangle(a: float, b: float): - \"\"\"Return the area of the rectangle.\"\"\"""", - "", - ] - + params_list, - [ - """# a function in native python: - def mean(a): - return sum(a)/len(a) - - # the same function using numpy: - import numpy as np - def mean(a):""", - "", - ] - + params_list, - [ - """X = np.random.randn(100, 100) - y = np.random.randint(0, 1, 100) - - # fit random forest classifier with 20 estimators""", - "", - ] - + params_list, - ] - # add summary example - examples += [ - [ - summarize_example1, - "Summarize" - if prompt_type not in ["plain", "instruct_simple"] - else "", - ] - + params_list - ] - - src_lang = "English" - tgt_lang = "Russian" - - # move to correct position - for example in examples: - example += [ - chat, - "", - "", - "Disabled", - LangChainAction.QUERY.value, - top_k_docs, - chunk, - chunk_size, - [DocumentChoices.All_Relevant.name], - ] - # adjust examples if non-chat mode - if not chat: - example[ - eval_func_param_names.index("instruction_nochat") - ] = example[eval_func_param_names.index("instruction")] - example[eval_func_param_names.index("instruction")] = "" - - example[ - eval_func_param_names.index("iinput_nochat") - ] = example[eval_func_param_names.index("iinput")] - example[eval_func_param_names.index("iinput")] = "" - assert len(example) == len( - eval_func_param_names - ), "Wrong example: %s %s" % ( - len(example), - len(eval_func_param_names), - ) - - if prompt_type == PromptType.custom.name and not prompt_dict: - raise ValueError( - "Unexpected to get non-empty prompt_dict=%s for prompt_type=%s" - % (prompt_dict, prompt_type) - ) - - # get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format - prompt_dict, error0 = get_prompt( - prompt_type, - prompt_dict, - chat=False, - context="", - reduced=False, - making_context=False, - return_dict=True, - ) - if error0: - raise RuntimeError("Prompt wrong: %s" % error0) - - return ( - placeholder_instruction, - placeholder_input, - stream_output, - show_examples, - prompt_type, - prompt_dict, - temperature, - top_p, - top_k, - num_beams, - max_new_tokens, - min_new_tokens, - early_stopping, - max_time, - repetition_penalty, - num_return_sequences, - do_sample, - src_lang, - tgt_lang, - examples, - task_info, - ) - - def languages_covered(self): - # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered - covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)""" - covered = covered.split(", ") - covered = { - x.split(" ")[0]: x.split(" ")[1].replace(")", "").replace("(", "") - for x in covered - } - return covered - - def get_context(self, chat_context, prompt_type): - if chat_context and prompt_type == "human_bot": - context0 = """: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand. - : I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed.""" - else: - context0 = "" - return context0 - - def score_qa( - self, - smodel, - stokenizer, - max_length_tokenize, - question, - answer, - cutoff_len, - ): - question = question[-cutoff_len:] - answer = answer[-cutoff_len:] - - inputs = stokenizer( - question, - answer, - return_tensors="pt", - truncation=True, - max_length=max_length_tokenize, - ).to(smodel.device) - try: - score = ( - torch.sigmoid(smodel(**inputs).logits[0]) - .cpu() - .detach() - .numpy()[0] - ) - except torch.cuda.OutOfMemoryError as e: - print( - "GPU OOM 3: question: %s answer: %s exception: %s" - % (question, answer, str(e)), - flush=True, - ) - del inputs - traceback.print_exc() - clear_torch_cache() - return "Response Score: GPU OOM" - except (Exception, RuntimeError) as e: - if ( - "Expected all tensors to be on the same device" in str(e) - or "expected scalar type Half but found Float" in str(e) - or "probability tensor contains either" in str(e) - or "cublasLt ran into an error!" in str(e) - or "device-side assert triggered" in str(e) - ): - print( - "GPU Error: question: %s answer: %s exception: %s" - % (question, answer, str(e)), - flush=True, - ) - traceback.print_exc() - clear_torch_cache() - return "Response Score: GPU Error" - else: - raise - os.environ["TOKENIZERS_PARALLELISM"] = "true" - return score - - def check_locals(self, **kwargs): - # ensure everything in evaluate is here - can_skip_because_locally_generated = no_default_param_names + [ - # get_model: - "reward_type" - ] - for k in eval_func_param_names: - if k in can_skip_because_locally_generated: - continue - assert k in kwargs, "Missing %s" % k - for k in inputs_kwargs_list: - if k in can_skip_because_locally_generated: - continue - assert k in kwargs, "Missing %s" % k - - for k in list(inspect.signature(self.get_model).parameters): - if k in can_skip_because_locally_generated: - continue - assert k in kwargs, "Missing %s" % k - - def get_model_max_length(self, model_state): - if not isinstance(model_state["tokenizer"], (str, types.NoneType)): - return model_state["tokenizer"].model_max_length - else: - return 2048 - - def get_max_max_new_tokens(self, model_state, **kwargs): - if not isinstance(model_state["tokenizer"], (str, types.NoneType)): - max_max_new_tokens = model_state["tokenizer"].model_max_length - else: - max_max_new_tokens = None - - if ( - kwargs["max_max_new_tokens"] is not None - and max_max_new_tokens is not None - ): - return min(max_max_new_tokens, kwargs["max_max_new_tokens"]) - elif kwargs["max_max_new_tokens"] is not None: - return kwargs["max_max_new_tokens"] - elif kwargs["memory_restriction_level"] == 1: - return 768 - elif kwargs["memory_restriction_level"] == 2: - return 512 - elif kwargs["memory_restriction_level"] >= 3: - return 256 - else: - # FIXME: Need to update after new model loaded, so user can control with slider - return 2048 - - def get_minmax_top_k_docs(self, is_public): - if is_public: - min_top_k_docs = 1 - max_top_k_docs = 3 - label_top_k_docs = "Number of document chunks" - else: - min_top_k_docs = -1 - max_top_k_docs = 100 - label_top_k_docs = ( - "Number of document chunks (-1 = auto fill model context)" - ) - return min_top_k_docs, max_top_k_docs, label_top_k_docs - - def history_to_context( - self, - history, - langchain_mode1, - prompt_type1, - prompt_dict1, - chat1, - model_max_length1, - memory_restriction_level1, - keep_sources_in_context1, - ): - """ - consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair - :param history: - :param langchain_mode1: - :param prompt_type1: - :param prompt_dict1: - :param chat1: - :param model_max_length1: - :param memory_restriction_level1: - :param keep_sources_in_context1: - :return: - """ - # ensure output will be unique to models - _, _, _, max_prompt_length = self.get_cutoffs( - memory_restriction_level1, - for_context=True, - model_max_length=model_max_length1, - ) - context1 = "" - if max_prompt_length is not None and langchain_mode1 not in ["LLM"]: - context1 = "" - # - 1 below because current instruction already in history from user() - for histi in range(0, len(history) - 1): - data_point = dict( - instruction=history[histi][0], - input="", - output=history[histi][1], - ) - ( - prompt, - pre_response, - terminate_response, - chat_sep, - chat_turn_sep, - ) = generate_prompt( - data_point, - prompt_type1, - prompt_dict1, - chat1, - reduced=True, - making_context=True, - ) - # md -> back to text, maybe not super important if model trained enough - if ( - not keep_sources_in_context1 - and langchain_mode1 != "Disabled" - and prompt.find(source_prefix) >= 0 - ): - # FIXME: This is relatively slow even for small amount of text, like 0.3s each history item - import re - - prompt = re.sub( - f"{re.escape(source_prefix)}.*?{re.escape(source_postfix)}", - "", - prompt, - flags=re.DOTALL, - ) - if prompt.endswith("\n

"): - prompt = prompt[:-4] - prompt = prompt.replace("
", chat_turn_sep) - if not prompt.endswith(chat_turn_sep): - prompt += chat_turn_sep - # most recent first, add older if can - # only include desired chat history - if len(prompt + context1) > max_prompt_length: - break - context1 += prompt - - ( - _, - pre_response, - terminate_response, - chat_sep, - chat_turn_sep, - ) = generate_prompt( - {}, - prompt_type1, - prompt_dict1, - chat1, - reduced=True, - making_context=True, - ) - if context1 and not context1.endswith(chat_turn_sep): - context1 += chat_turn_sep # ensure if terminates abruptly, then human continues on next line - return context1 - - -class H2OTextIteratorStreamer(TextIteratorStreamer): - """ - normally, timeout required for now to handle exceptions, else get() - but with H2O version of TextIteratorStreamer, loop over block to handle - """ - - def __init__( - self, - tokenizer, - skip_prompt: bool = False, - timeout: typing.Optional[float] = None, - block=True, - **decode_kwargs, - ): - super().__init__(tokenizer, skip_prompt, **decode_kwargs) - self.text_queue = queue.Queue() - self.stop_signal = None - self.do_stop = False - self.timeout = timeout - self.block = block - - def on_finalized_text(self, text: str, stream_end: bool = False): - """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue.""" - self.text_queue.put(text, timeout=self.timeout) - if stream_end: - self.text_queue.put(self.stop_signal, timeout=self.timeout) - - def __iter__(self): - return self - - def __next__(self): - while True: - try: - value = ( - self.stop_signal - ) # value looks unused in pycharm, not true - if self.do_stop: - print("hit stop", flush=True) - # could raise or break, maybe best to raise and make parent see if any exception in thread - self.clear_queue() - self.do_stop = False - raise StopIteration() - # break - value = self.text_queue.get( - block=self.block, timeout=self.timeout - ) - break - except queue.Empty: - time.sleep(0.01) - if value == self.stop_signal: - self.clear_queue() - self.do_stop = False - raise StopIteration() - else: - return value - - def clear_queue(self): - # make sure streamer is reusable after stop hit - with self.text_queue.mutex: - self.text_queue.queue.clear() - - -def entrypoint_main(): - """ - Examples: - - WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B - python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B' - python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B' - - # generate without lora weights, no prompt - python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain' - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' - - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq' - # OpenChatKit settings: - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 - - python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False - python generate.py --base_model='t5-large' --prompt_type='simple_instruct' - python generate.py --base_model='philschmid/bart-large-cnn-samsum' - python generate.py --base_model='philschmid/flan-t5-base-samsum' - python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt' - - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28' - - must have 4*48GB GPU and run without 8bit in order for sharding to work with infer_devices=False - can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned - python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --infer_devices=False --prompt_type='human_bot' - - python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b - """ - import fire - - langchain = Langchain() - - fire.Fire(langchain.main) - - -if __name__ == "__main__": - entrypoint_main() diff --git a/apps/language_models/langchain/gpt4all_llm.py b/apps/language_models/langchain/gpt4all_llm.py deleted file mode 100644 index 67744b19..00000000 --- a/apps/language_models/langchain/gpt4all_llm.py +++ /dev/null @@ -1,380 +0,0 @@ -import inspect -import os -from functools import partial -from typing import Dict, Any, Optional, List -from langchain.callbacks.manager import CallbackManagerForLLMRun -from pydantic import root_validator -from langchain.llms import gpt4all -from dotenv import dotenv_values - -from utils import FakeTokenizer - - -def get_model_tokenizer_gpt4all(base_model, **kwargs): - # defaults (some of these are generation parameters, so need to be passed in at generation time) - model_kwargs = dict( - n_threads=os.cpu_count() // 2, - temp=kwargs.get("temperature", 0.2), - top_p=kwargs.get("top_p", 0.75), - top_k=kwargs.get("top_k", 40), - n_ctx=2048 - 256, - ) - env_gpt4all_file = ".env_gpt4all" - model_kwargs.update(dotenv_values(env_gpt4all_file)) - # make int or float if can to satisfy types for class - for k, v in model_kwargs.items(): - try: - if float(v) == int(v): - model_kwargs[k] = int(v) - else: - model_kwargs[k] = float(v) - except: - pass - - if base_model == "llama": - if "model_path_llama" not in model_kwargs: - raise ValueError("No model_path_llama in %s" % env_gpt4all_file) - model_path = model_kwargs.pop("model_path_llama") - # FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python - from llama_cpp import Llama - - # llama sets some things at init model time, not generation time - func_names = list(inspect.signature(Llama.__init__).parameters) - model_kwargs = { - k: v for k, v in model_kwargs.items() if k in func_names - } - model_kwargs["n_ctx"] = int(model_kwargs["n_ctx"]) - model = Llama(model_path=model_path, **model_kwargs) - elif base_model in "gpt4all_llama": - if ( - "model_name_gpt4all_llama" not in model_kwargs - and "model_path_gpt4all_llama" not in model_kwargs - ): - raise ValueError( - "No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" - % env_gpt4all_file - ) - model_name = model_kwargs.pop("model_name_gpt4all_llama") - model_type = "llama" - from gpt4all import GPT4All as GPT4AllModel - - model = GPT4AllModel(model_name=model_name, model_type=model_type) - elif base_model in "gptj": - if ( - "model_name_gptj" not in model_kwargs - and "model_path_gptj" not in model_kwargs - ): - raise ValueError( - "No model_name_gpt4j or model_path_gpt4j in %s" - % env_gpt4all_file - ) - model_name = model_kwargs.pop("model_name_gptj") - model_type = "gptj" - from gpt4all import GPT4All as GPT4AllModel - - model = GPT4AllModel(model_name=model_name, model_type=model_type) - else: - raise ValueError("No such base_model %s" % base_model) - return model, FakeTokenizer(), "cpu" - - -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler - - -class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" - # streaming to std already occurs without this - # sys.stdout.write(token) - # sys.stdout.flush() - pass - - -def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]): - # default from class - model_kwargs = { - k: v.default - for k, v in dict(inspect.signature(cls).parameters).items() - if k not in exclude_list - } - # from our defaults - model_kwargs.update(default_kwargs) - # from user defaults - model_kwargs.update(env_kwargs) - # ensure only valid keys - func_names = list(inspect.signature(cls).parameters) - model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names} - return model_kwargs - - -def get_llm_gpt4all( - model_name, - model=None, - max_new_tokens=256, - temperature=0.1, - repetition_penalty=1.0, - top_k=40, - top_p=0.7, - streaming=False, - callbacks=None, - prompter=None, - verbose=False, -): - assert prompter is not None - env_gpt4all_file = ".env_gpt4all" - env_kwargs = dotenv_values(env_gpt4all_file) - n_ctx = env_kwargs.pop("n_ctx", 2048 - max_new_tokens) - default_kwargs = dict( - context_erase=0.5, - n_batch=1, - n_ctx=n_ctx, - n_predict=max_new_tokens, - repeat_last_n=64 if repetition_penalty != 1.0 else 0, - repeat_penalty=repetition_penalty, - temp=temperature, - temperature=temperature, - top_k=top_k, - top_p=top_p, - use_mlock=True, - verbose=verbose, - ) - if model_name == "llama": - cls = H2OLlamaCpp - model_path = ( - env_kwargs.pop("model_path_llama") if model is None else model - ) - model_kwargs = get_model_kwargs( - env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"] - ) - model_kwargs.update( - dict( - model_path=model_path, - callbacks=callbacks, - streaming=streaming, - prompter=prompter, - ) - ) - llm = cls(**model_kwargs) - llm.client.verbose = verbose - elif model_name == "gpt4all_llama": - cls = H2OGPT4All - model_path = ( - env_kwargs.pop("model_path_gpt4all_llama") - if model is None - else model - ) - model_kwargs = get_model_kwargs( - env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"] - ) - model_kwargs.update( - dict( - model=model_path, - backend="llama", - callbacks=callbacks, - streaming=streaming, - prompter=prompter, - ) - ) - llm = cls(**model_kwargs) - elif model_name == "gptj": - cls = H2OGPT4All - model_path = ( - env_kwargs.pop("model_path_gptj") if model is None else model - ) - model_kwargs = get_model_kwargs( - env_kwargs, default_kwargs, cls, exclude_list=["lc_kwargs"] - ) - model_kwargs.update( - dict( - model=model_path, - backend="gptj", - callbacks=callbacks, - streaming=streaming, - prompter=prompter, - ) - ) - llm = cls(**model_kwargs) - else: - raise RuntimeError("No such model_name %s" % model_name) - return llm - - -class H2OGPT4All(gpt4all.GPT4All): - model: Any - prompter: Any - """Path to the pre-trained GPT4All model file.""" - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that the python package exists in the environment.""" - try: - if isinstance(values["model"], str): - from gpt4all import GPT4All as GPT4AllModel - - full_path = values["model"] - model_path, delimiter, model_name = full_path.rpartition("/") - model_path += delimiter - - values["client"] = GPT4AllModel( - model_name=model_name, - model_path=model_path or None, - model_type=values["backend"], - allow_download=False, - ) - if values["n_threads"] is not None: - # set n_threads - values["client"].model.set_thread_count( - values["n_threads"] - ) - else: - values["client"] = values["model"] - try: - values["backend"] = values["client"].model_type - except AttributeError: - # The below is for compatibility with GPT4All Python bindings <= 0.2.3. - values["backend"] = values["client"].model.model_type - - except ImportError: - raise ValueError( - "Could not import gpt4all python package. " - "Please install it with `pip install gpt4all`." - ) - return values - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs, - ) -> str: - # Roughly 4 chars per token if natural language - prompt = prompt[-self.n_ctx * 4 :] - - # use instruct prompting - data_point = dict(context="", instruction=prompt, input="") - prompt = self.prompter.generate_prompt(data_point) - - verbose = False - if verbose: - print("_call prompt: %s" % prompt, flush=True) - # FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout - return super()._call(prompt, stop=stop, run_manager=run_manager) - - -from langchain.llms import LlamaCpp - - -class H2OLlamaCpp(LlamaCpp): - model_path: Any - prompter: Any - """Path to the pre-trained GPT4All model file.""" - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that llama-cpp-python library is installed.""" - if isinstance(values["model_path"], str): - model_path = values["model_path"] - model_param_names = [ - "lora_path", - "lora_base", - "n_ctx", - "n_parts", - "seed", - "f16_kv", - "logits_all", - "vocab_only", - "use_mlock", - "n_threads", - "n_batch", - "use_mmap", - "last_n_tokens_size", - ] - model_params = {k: values[k] for k in model_param_names} - # For backwards compatibility, only include if non-null. - if values["n_gpu_layers"] is not None: - model_params["n_gpu_layers"] = values["n_gpu_layers"] - - try: - from llama_cpp import Llama - - values["client"] = Llama(model_path, **model_params) - except ImportError: - raise ModuleNotFoundError( - "Could not import llama-cpp-python library. " - "Please install the llama-cpp-python library to " - "use this embedding model: pip install llama-cpp-python" - ) - except Exception as e: - raise ValueError( - f"Could not load Llama model from path: {model_path}. " - f"Received error {e}" - ) - else: - values["client"] = values["model_path"] - return values - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs, - ) -> str: - verbose = False - # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate - # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal - prompt = prompt[-self.n_ctx * 4 :] - prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8")) - num_prompt_tokens = len(prompt_tokens) - if num_prompt_tokens > self.n_ctx: - # conservative by using int() - chars_per_token = int(len(prompt) / num_prompt_tokens) - prompt = prompt[-self.n_ctx * chars_per_token :] - if verbose: - print( - "reducing tokens, assuming average of %s chars/token: %s" - % chars_per_token, - flush=True, - ) - prompt_tokens2 = self.client.tokenize( - b" " + prompt.encode("utf-8") - ) - num_prompt_tokens2 = len(prompt_tokens2) - print( - "reduced tokens from %d -> %d" - % (num_prompt_tokens, num_prompt_tokens2), - flush=True, - ) - - # use instruct prompting - data_point = dict(context="", instruction=prompt, input="") - prompt = self.prompter.generate_prompt(data_point) - - if verbose: - print("_call prompt: %s" % prompt, flush=True) - - if self.streaming: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter - if text_callback: - text_callback(prompt) - text = "" - for token in self.stream( - prompt=prompt, stop=stop, run_manager=run_manager - ): - text_chunk = token["choices"][0]["text"] - # self.stream already calls text_callback - # if text_callback: - # text_callback(text_chunk) - text += text_chunk - return text - else: - params = self._get_parameters(stop) - params = {**params, **kwargs} - result = self.client(prompt=prompt, **params) - return result["choices"][0]["text"] diff --git a/apps/language_models/langchain/gpt_langchain.py b/apps/language_models/langchain/gpt_langchain.py deleted file mode 100644 index 6c2db5ec..00000000 --- a/apps/language_models/langchain/gpt_langchain.py +++ /dev/null @@ -1,3137 +0,0 @@ -import ast -import glob -import inspect -import os -import pathlib -import pickle -import shutil -import subprocess -import tempfile -import time -import traceback -import types -import uuid -import zipfile -from collections import defaultdict -from datetime import datetime -from functools import reduce -from operator import concat -import filelock - -from joblib import delayed -from langchain.callbacks import streaming_stdout -from langchain.embeddings import HuggingFaceInstructEmbeddings -from tqdm import tqdm - -from enums import ( - DocumentChoices, - no_lora_str, - model_token_mapping, - source_prefix, - source_postfix, - non_query_commands, - LangChainAction, - LangChainMode, -) -from evaluate_params import gen_hyper -from gen import Langchain, SEED -from prompter import non_hf_types, PromptType, Prompter -from utils import ( - wrapped_partial, - EThread, - import_matplotlib, - sanitize_filename, - makedirs, - get_url, - flatten_list, - ProgressParallel, - remove, - hash_file, - clear_torch_cache, - NullContext, - get_hf_server, - FakeTokenizer, -) -from utils_langchain import StreamingGradioCallbackHandler - -import_matplotlib() - -import numpy as np -import pandas as pd -import requests -from langchain.chains.qa_with_sources import load_qa_with_sources_chain - -# , GCSDirectoryLoader, GCSFileLoader -# , OutlookMessageLoader # GPL3 -# ImageCaptionLoader, # use our own wrapper -# ReadTheDocsLoader, # no special file, some path, so have to give as special option -from langchain.document_loaders import ( - PyPDFLoader, - TextLoader, - CSVLoader, - PythonLoader, - TomlLoader, - UnstructuredURLLoader, - UnstructuredHTMLLoader, - UnstructuredWordDocumentLoader, - UnstructuredMarkdownLoader, - EverNoteLoader, - UnstructuredEmailLoader, - UnstructuredODTLoader, - UnstructuredPowerPointLoader, - UnstructuredEPubLoader, - UnstructuredImageLoader, - UnstructuredRTFLoader, - ArxivLoader, - UnstructuredPDFLoader, - UnstructuredExcelLoader, -) -from langchain.text_splitter import RecursiveCharacterTextSplitter, Language -from expanded_pipelines import load_qa_chain -from langchain.docstore.document import Document -from langchain import PromptTemplate, HuggingFaceTextGenInference -from langchain.vectorstores import Chroma -from apps.stable_diffusion.src import args - - -def get_db( - sources, - use_openai_embedding=False, - db_type="faiss", - persist_directory="db_dir", - load_db_if_exists=True, - langchain_mode="notset", - collection_name=None, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", -): - if not sources: - return None - - # get embedding model - embedding = get_embedding( - use_openai_embedding, hf_embedding_model=hf_embedding_model - ) - assert collection_name is not None or langchain_mode != "notset" - if collection_name is None: - collection_name = langchain_mode.replace(" ", "_") - - # Create vector database - if db_type == "faiss": - from langchain.vectorstores import FAISS - - db = FAISS.from_documents(sources, embedding) - elif db_type == "weaviate": - import weaviate - from weaviate.embedded import EmbeddedOptions - from langchain.vectorstores import Weaviate - - if os.getenv("WEAVIATE_URL", None): - client = _create_local_weaviate_client() - else: - client = weaviate.Client(embedded_options=EmbeddedOptions()) - index_name = collection_name.capitalize() - db = Weaviate.from_documents( - documents=sources, - embedding=embedding, - client=client, - by_text=False, - index_name=index_name, - ) - elif db_type == "chroma": - assert persist_directory is not None - os.makedirs(persist_directory, exist_ok=True) - - # see if already actually have persistent db, and deal with possible changes in embedding - db = get_existing_db( - None, - persist_directory, - load_db_if_exists, - db_type, - use_openai_embedding, - langchain_mode, - hf_embedding_model, - verbose=False, - ) - if db is None: - db = Chroma.from_documents( - documents=sources, - embedding=embedding, - persist_directory=persist_directory, - collection_name=collection_name, - anonymized_telemetry=False, - ) - db.persist() - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - else: - # then just add - db, num_new_sources, new_sources_metadata = add_to_db( - db, - sources, - db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model, - ) - else: - raise RuntimeError("No such db_type=%s" % db_type) - - return db - - -def _get_unique_sources_in_weaviate(db): - batch_size = 100 - id_source_list = [] - result = db._client.data_object.get( - class_name=db._index_name, limit=batch_size - ) - - while result["objects"]: - id_source_list += [ - (obj["id"], obj["properties"]["source"]) - for obj in result["objects"] - ] - last_id = id_source_list[-1][0] - result = db._client.data_object.get( - class_name=db._index_name, limit=batch_size, after=last_id - ) - - unique_sources = {source for _, source in id_source_list} - return unique_sources - - -def add_to_db( - db, - sources, - db_type="faiss", - avoid_dup_by_file=False, - avoid_dup_by_content=True, - use_openai_embedding=False, - hf_embedding_model=None, -): - assert hf_embedding_model is not None - num_new_sources = len(sources) - if not sources: - return db, num_new_sources, [] - if db_type == "faiss": - db.add_documents(sources) - elif db_type == "weaviate": - # FIXME: only control by file name, not hash yet - if avoid_dup_by_file or avoid_dup_by_content: - unique_sources = _get_unique_sources_in_weaviate(db) - sources = [ - x - for x in sources - if x.metadata["source"] not in unique_sources - ] - num_new_sources = len(sources) - if num_new_sources == 0: - return db, num_new_sources, [] - db.add_documents(documents=sources) - elif db_type == "chroma": - collection = get_documents(db) - # files we already have: - metadata_files = set([x["source"] for x in collection["metadatas"]]) - if avoid_dup_by_file: - # Too weak in case file changed content, assume parent shouldn't pass true for this for now - raise RuntimeError("Not desired code path") - sources = [ - x - for x in sources - if x.metadata["source"] not in metadata_files - ] - if avoid_dup_by_content: - # look at hash, instead of page_content - # migration: If no hash previously, avoid updating, - # since don't know if need to update and may be expensive to redo all unhashed files - metadata_hash_ids = set( - [ - x["hashid"] - for x in collection["metadatas"] - if "hashid" in x and x["hashid"] not in ["None", None] - ] - ) - # avoid sources with same hash - sources = [ - x - for x in sources - if x.metadata.get("hashid") not in metadata_hash_ids - ] - num_nohash = len( - [x for x in sources if not x.metadata.get("hashid")] - ) - print( - "Found %s new sources (%d have no hash in original source," - " so have to reprocess for migration to sources with hash)" - % (len(sources), num_nohash), - flush=True, - ) - # get new file names that match existing file names. delete existing files we are overridding - dup_metadata_files = set( - [ - x.metadata["source"] - for x in sources - if x.metadata["source"] in metadata_files - ] - ) - print( - "Removing %s duplicate files from db because ingesting those as new documents" - % len(dup_metadata_files), - flush=True, - ) - client_collection = db._client.get_collection( - name=db._collection.name, - embedding_function=db._collection._embedding_function, - ) - for dup_file in dup_metadata_files: - dup_file_meta = dict(source=dup_file) - try: - client_collection.delete(where=dup_file_meta) - except KeyError: - pass - num_new_sources = len(sources) - if num_new_sources == 0: - return db, num_new_sources, [] - db.add_documents(documents=sources) - db.persist() - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - else: - raise RuntimeError("No such db_type=%s" % db_type) - - new_sources_metadata = [x.metadata for x in sources] - - return db, num_new_sources, new_sources_metadata - - -def create_or_update_db( - db_type, - persist_directory, - collection_name, - sources, - use_openai_embedding, - add_if_exists, - verbose, - hf_embedding_model, -): - if db_type == "weaviate": - import weaviate - from weaviate.embedded import EmbeddedOptions - - if os.getenv("WEAVIATE_URL", None): - client = _create_local_weaviate_client() - else: - client = weaviate.Client(embedded_options=EmbeddedOptions()) - - index_name = collection_name.replace(" ", "_").capitalize() - if client.schema.exists(index_name) and not add_if_exists: - client.schema.delete_class(index_name) - if verbose: - print("Removing %s" % index_name, flush=True) - elif db_type == "chroma": - if not os.path.isdir(persist_directory) or not add_if_exists: - if os.path.isdir(persist_directory): - if verbose: - print("Removing %s" % persist_directory, flush=True) - remove(persist_directory) - if verbose: - print("Generating db", flush=True) - - if not add_if_exists: - if verbose: - print("Generating db", flush=True) - else: - if verbose: - print("Loading and updating db", flush=True) - - db = get_db( - sources, - use_openai_embedding=use_openai_embedding, - db_type=db_type, - persist_directory=persist_directory, - langchain_mode=collection_name, - hf_embedding_model=hf_embedding_model, - ) - - return db - - -def get_embedding( - use_openai_embedding, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", -): - # Get embedding model - if use_openai_embedding: - assert ( - os.getenv("OPENAI_API_KEY") is not None - ), "Set ENV OPENAI_API_KEY" - from langchain.embeddings import OpenAIEmbeddings - - embedding = OpenAIEmbeddings(disallowed_special=()) - else: - # to ensure can fork without deadlock - from langchain.embeddings import HuggingFaceEmbeddings - - torch_dtype, context_class = get_dtype() - model_kwargs = dict(device=args.device) - if "instructor" in hf_embedding_model: - encode_kwargs = {"normalize_embeddings": True} - embedding = HuggingFaceInstructEmbeddings( - model_name=hf_embedding_model, - model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs, - ) - else: - embedding = HuggingFaceEmbeddings( - model_name=hf_embedding_model, model_kwargs=model_kwargs - ) - return embedding - - -def get_answer_from_sources(chain, sources, question): - return chain( - { - "input_documents": sources, - "question": question, - }, - return_only_outputs=True, - )["output_text"] - - -"""Wrapper around Huggingface text generation inference API.""" -from functools import partial -from typing import Any, Dict, List, Optional, Set - -from pydantic import Extra, Field, root_validator - -from langchain.callbacks.manager import CallbackManagerForLLMRun - -"""Wrapper around Huggingface text generation inference API.""" -from functools import partial -from typing import Any, Dict, List, Optional - -from pydantic import Extra, Field, root_validator - -from langchain.callbacks.manager import CallbackManagerForLLMRun -from langchain.llms.base import LLM - - -class GradioInference(LLM): - """ - Gradio generation inference API. - """ - - inference_server_url: str = "" - - temperature: float = 0.8 - top_p: Optional[float] = 0.95 - top_k: Optional[int] = None - num_beams: Optional[int] = 1 - max_new_tokens: int = 512 - min_new_tokens: int = 1 - early_stopping: bool = False - max_time: int = 180 - repetition_penalty: Optional[float] = None - num_return_sequences: Optional[int] = 1 - do_sample: bool = False - chat_client: bool = False - - return_full_text: bool = True - stream_output: bool = Field(False, alias="stream") - sanitize_bot_response: bool = False - - prompter: Any = None - client: Any = None - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that python package exists in environment.""" - - try: - if values["client"] is None: - import gradio_client - - values["client"] = gradio_client.Client( - values["inference_server_url"] - ) - except ImportError: - raise ImportError( - "Could not import gradio_client python package. " - "Please install it with `pip install gradio_client`." - ) - return values - - @property - def _llm_type(self) -> str: - """Return type of llm.""" - return "gradio_inference" - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection, - # so server should get prompt_type or '', not plain - # This is good, so gradio server can also handle stopping.py conditions - # this is different than TGI server that uses prompter to inject prompt_type prompting - stream_output = self.stream_output - gr_client = self.client - client_langchain_mode = "Disabled" - client_langchain_action = LangChainAction.QUERY.value - top_k_docs = 1 - chunk = True - chunk_size = 512 - client_kwargs = dict( - instruction=prompt - if self.chat_client - else "", # only for chat=True - iinput="", # only for chat=True - context="", - # streaming output is supported, loops over and outputs each generation in streaming mode - # but leave stream_output=False for simple input/output mode - stream_output=stream_output, - prompt_type=self.prompter.prompt_type, - prompt_dict="", - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - num_beams=self.num_beams, - max_new_tokens=self.max_new_tokens, - min_new_tokens=self.min_new_tokens, - early_stopping=self.early_stopping, - max_time=self.max_time, - repetition_penalty=self.repetition_penalty, - num_return_sequences=self.num_return_sequences, - do_sample=self.do_sample, - chat=self.chat_client, - instruction_nochat=prompt if not self.chat_client else "", - iinput_nochat="", # only for chat=False - langchain_mode=client_langchain_mode, - langchain_action=client_langchain_action, - top_k_docs=top_k_docs, - chunk=chunk, - chunk_size=chunk_size, - document_choice=[DocumentChoices.All_Relevant.name], - ) - api_name = "/submit_nochat_api" # NOTE: like submit_nochat but stable API for string dict passing - if not stream_output: - res = gr_client.predict( - str(dict(client_kwargs)), api_name=api_name - ) - res_dict = ast.literal_eval(res) - text = res_dict["response"] - return self.prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response, - ) - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - - job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name) - text0 = "" - while not job.done(): - outputs_list = job.communicator.job.outputs - if outputs_list: - res = job.communicator.job.outputs[-1] - res_dict = ast.literal_eval(res) - text = res_dict["response"] - text = self.prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response, - ) - # FIXME: derive chunk from full for now - text_chunk = text[len(text0) :] - # save old - text0 = text - - if text_callback: - text_callback(text_chunk) - - time.sleep(0.01) - - # ensure get last output to avoid race - res_all = job.outputs() - if len(res_all) > 0: - res = res_all[-1] - res_dict = ast.literal_eval(res) - text = res_dict["response"] - # FIXME: derive chunk from full for now - else: - # go with old if failure - text = text0 - text_chunk = text[len(text0) :] - if text_callback: - text_callback(text_chunk) - return self.prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response, - ) - - -class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference): - max_new_tokens: int = 512 - do_sample: bool = False - top_k: Optional[int] = None - top_p: Optional[float] = 0.95 - typical_p: Optional[float] = 0.95 - temperature: float = 0.8 - repetition_penalty: Optional[float] = None - return_full_text: bool = False - stop_sequences: List[str] = Field(default_factory=list) - seed: Optional[int] = None - inference_server_url: str = "" - timeout: int = 300 - headers: dict = None - stream_output: bool = Field(False, alias="stream") - sanitize_bot_response: bool = False - prompter: Any = None - tokenizer: Any = None - client: Any = None - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that python package exists in environment.""" - - try: - if values["client"] is None: - import text_generation - - values["client"] = text_generation.Client( - values["inference_server_url"], - timeout=values["timeout"], - headers=values["headers"], - ) - except ImportError: - raise ImportError( - "Could not import text_generation python package. " - "Please install it with `pip install text_generation`." - ) - return values - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - if stop is None: - stop = self.stop_sequences - else: - stop += self.stop_sequences - - # HF inference server needs control over input tokens - assert self.tokenizer is not None - from h2oai_pipeline import H2OTextGenerationPipeline - - prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt( - prompt, self.tokenizer - ) - - # NOTE: TGI server does not add prompting, so must do here - data_point = dict(context="", instruction=prompt, input="") - prompt = self.prompter.generate_prompt(data_point) - - gen_server_kwargs = dict( - do_sample=self.do_sample, - stop_sequences=stop, - max_new_tokens=self.max_new_tokens, - top_k=self.top_k, - top_p=self.top_p, - typical_p=self.typical_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - return_full_text=self.return_full_text, - seed=self.seed, - ) - gen_server_kwargs.update(kwargs) - - # lower bound because client is re-used if multi-threading - self.client.timeout = max(300, self.timeout) - - if not self.stream_output: - res = self.client.generate( - prompt, - **gen_server_kwargs, - ) - if self.return_full_text: - gen_text = res.generated_text[len(prompt) :] - else: - gen_text = res.generated_text - # remove stop sequences from the end of the generated text - for stop_seq in stop: - if stop_seq in gen_text: - gen_text = gen_text[: gen_text.index(stop_seq)] - text = prompt + gen_text - text = self.prompter.get_response( - text, - prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response, - ) - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter - if text_callback: - text_callback(prompt) - text = "" - # Note: Streaming ignores return_full_text=True - for response in self.client.generate_stream( - prompt, **gen_server_kwargs - ): - text_chunk = response.token.text - text += text_chunk - text = self.prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response, - ) - # stream part - is_stop = False - for stop_seq in stop: - if stop_seq in response.token.text: - is_stop = True - break - if is_stop: - break - if not response.token.special: - if text_callback: - text_callback(response.token.text) - return text - - -from langchain.chat_models import ChatOpenAI - - -class H2OChatOpenAI(ChatOpenAI): - @classmethod - def all_required_field_names(cls) -> Set: - all_required_field_names = super( - ChatOpenAI, cls - ).all_required_field_names() - all_required_field_names.update( - {"top_p", "frequency_penalty", "presence_penalty"} - ) - return all_required_field_names - - -def get_llm( - use_openai_model=False, - model_name=None, - model=None, - tokenizer=None, - inference_server=None, - stream_output=False, - do_sample=False, - temperature=0.1, - top_k=40, - top_p=0.7, - num_beams=1, - max_new_tokens=256, - min_new_tokens=1, - early_stopping=False, - max_time=180, - repetition_penalty=1.0, - num_return_sequences=1, - prompt_type=None, - prompt_dict=None, - prompter=None, - sanitize_bot_response=False, - verbose=False, -): - if use_openai_model or inference_server in ["openai", "openai_chat"]: - if use_openai_model and model_name is None: - model_name = "gpt-3.5-turbo" - if inference_server == "openai": - from langchain.llms import OpenAI - - cls = OpenAI - else: - cls = H2OChatOpenAI - callbacks = [StreamingGradioCallbackHandler()] - llm = cls( - model_name=model_name, - temperature=temperature if do_sample else 0, - # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py - max_tokens=max_new_tokens, - top_p=top_p if do_sample else 1, - frequency_penalty=0, - presence_penalty=1.07 - - repetition_penalty - + 0.6, # so good default - callbacks=callbacks if stream_output else None, - ) - streamer = callbacks[0] if stream_output else None - if inference_server in ["openai", "openai_chat"]: - prompt_type = inference_server - else: - prompt_type = prompt_type or "plain" - elif inference_server: - assert inference_server.startswith("http"), ( - "Malformed inference_server=%s. Did you add http:// in front?" - % inference_server - ) - - from gradio_utils.grclient import GradioClient - from text_generation import Client as HFClient - - if isinstance(model, GradioClient): - gr_client = model - hf_client = None - else: - gr_client = None - hf_client = model - assert isinstance(hf_client, HFClient) - - inference_server, headers = get_hf_server(inference_server) - - # quick sanity check to avoid long timeouts, just see if can reach server - requests.get( - inference_server, - timeout=int(os.getenv("REQUEST_TIMEOUT_FAST", "10")), - ) - - callbacks = [StreamingGradioCallbackHandler()] - assert prompter is not None - stop_sequences = list( - set(prompter.terminate_response + [prompter.PreResponse]) - ) - stop_sequences = [x for x in stop_sequences if x] - - if gr_client: - chat_client = False - llm = GradioInference( - inference_server_url=inference_server, - return_full_text=True, - temperature=temperature, - top_p=top_p, - top_k=top_k, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - do_sample=do_sample, - chat_client=chat_client, - callbacks=callbacks if stream_output else None, - stream=stream_output, - prompter=prompter, - client=gr_client, - sanitize_bot_response=sanitize_bot_response, - ) - elif hf_client: - llm = H2OHuggingFaceTextGenInference( - inference_server_url=inference_server, - do_sample=do_sample, - max_new_tokens=max_new_tokens, - repetition_penalty=repetition_penalty, - return_full_text=True, - seed=SEED, - stop_sequences=stop_sequences, - temperature=temperature, - top_k=top_k, - top_p=top_p, - # typical_p=top_p, - callbacks=callbacks if stream_output else None, - stream_output=stream_output, - prompter=prompter, - tokenizer=tokenizer, - client=hf_client, - timeout=max_time, - sanitize_bot_response=sanitize_bot_response, - ) - else: - raise RuntimeError("No defined client") - streamer = callbacks[0] if stream_output else None - elif model_name in non_hf_types: - if model_name == "llama": - callbacks = [StreamingGradioCallbackHandler()] - streamer = callbacks[0] if stream_output else None - else: - # stream_output = False - # doesn't stream properly as generator, but at least - callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()] - streamer = None - if prompter: - prompt_type = prompter.prompt_type - else: - prompter = Prompter( - prompt_type, - prompt_dict, - debug=False, - chat=False, - stream_output=stream_output, - ) - pass # assume inputted prompt_type is correct - from gpt4all_llm import get_llm_gpt4all - - llm = get_llm_gpt4all( - model_name, - model=model, - max_new_tokens=max_new_tokens, - temperature=temperature, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, - callbacks=callbacks, - verbose=verbose, - streaming=stream_output, - prompter=prompter, - ) - else: - if model is None: - # only used if didn't pass model in - assert tokenizer is None - prompt_type = "human_bot" - if model_name is None: - model_name = "h2oai/h2ogpt-oasst1-512-12b" - # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' - # model_name = 'h2oai/h2ogpt-oasst1-512-20b' - inference_server = "" - model, tokenizer, _ = Langchain.get_model( - load_8bit=True, - base_model=model_name, - inference_server=inference_server, - gpu_id=0, - ) - - max_max_tokens = tokenizer.model_max_length - gen_kwargs = dict( - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - return_full_text=True, - handle_long_generation=None, - ) - assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0 - - if stream_output: - skip_prompt = False - from gen import H2OTextIteratorStreamer - - decoder_kwargs = {} - streamer = H2OTextIteratorStreamer( - tokenizer, - skip_prompt=skip_prompt, - block=False, - **decoder_kwargs, - ) - gen_kwargs.update(dict(streamer=streamer)) - else: - streamer = None - - from h2oai_pipeline import H2OTextGenerationPipeline - - pipe = H2OTextGenerationPipeline( - model=model, - use_prompter=True, - prompter=prompter, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - sanitize_bot_response=sanitize_bot_response, - chat=False, - stream_output=stream_output, - tokenizer=tokenizer, - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens=max_max_tokens - max(min_new_tokens, 256), - **gen_kwargs, - ) - # pipe.task = "text-generation" - # below makes it listen only to our prompt removal, - # not built in prompt removal that is less general and not specific for our model - pipe.task = "text2text-generation" - - from langchain.llms import HuggingFacePipeline - - llm = HuggingFacePipeline(pipeline=pipe) - return llm, model_name, streamer, prompt_type - - -def get_dtype(): - # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently - import torch - - # from utils import NullContext - # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class - context_class = torch.device - torch_dtype = torch.float16 if args.device == "cuda" else torch.float32 - return torch_dtype, context_class - - -def get_wiki_data( - title, first_paragraph_only, text_limit=None, take_head=True -): - """ - Get wikipedia data from online - :param title: - :param first_paragraph_only: - :param text_limit: - :param take_head: - :return: - """ - filename = "wiki_%s_%s_%s_%s.data" % ( - first_paragraph_only, - title, - text_limit, - take_head, - ) - url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}" - if first_paragraph_only: - url += "&exintro=1" - import json - - if not os.path.isfile(filename): - data = requests.get(url).json() - json.dump(data, open(filename, "wt")) - else: - data = json.load(open(filename, "rt")) - page_content = list(data["query"]["pages"].values())[0]["extract"] - if take_head is not None and text_limit is not None: - page_content = ( - page_content[:text_limit] - if take_head - else page_content[-text_limit:] - ) - title_url = str(title).replace(" ", "_") - return Document( - page_content=page_content, - metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"}, - ) - - -def get_wiki_sources(first_para=True, text_limit=None): - """ - Get specific named sources from wikipedia - :param first_para: - :param text_limit: - :return: - """ - default_wiki_sources = ["Unix", "Microsoft_Windows", "Linux"] - wiki_sources = list(os.getenv("WIKI_SOURCES", default_wiki_sources)) - return [ - get_wiki_data(x, first_para, text_limit=text_limit) - for x in wiki_sources - ] - - -def get_github_docs(repo_owner, repo_name): - """ - Access github from specific repo - :param repo_owner: - :param repo_name: - :return: - """ - with tempfile.TemporaryDirectory() as d: - subprocess.check_call( - f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .", - cwd=d, - shell=True, - ) - git_sha = ( - subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d) - .decode("utf-8") - .strip() - ) - repo_path = pathlib.Path(d) - markdown_files = list(repo_path.glob("*/*.md")) + list( - repo_path.glob("*/*.mdx") - ) - for markdown_file in markdown_files: - with open(markdown_file, "r") as f: - relative_path = markdown_file.relative_to(repo_path) - github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}" - yield Document( - page_content=f.read(), metadata={"source": github_url} - ) - - -def get_dai_pickle(dest="."): - from huggingface_hub import hf_hub_download - - # True for case when locally already logged in with correct token, so don't have to set key - token = os.getenv("HUGGINGFACE_API_TOKEN", True) - path_to_zip_file = hf_hub_download( - "h2oai/dai_docs", "dai_docs.pickle", token=token, repo_type="dataset" - ) - shutil.copy(path_to_zip_file, dest) - - -def get_dai_docs(from_hf=False, get_pickle=True): - """ - Consume DAI documentation, or consume from public pickle - :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain - :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF - :return: - """ - import pickle - - if get_pickle: - get_dai_pickle() - - dai_store = "dai_docs.pickle" - dst = "working_dir_docs" - if not os.path.isfile(dai_store): - from create_data import setup_dai_docs - - dst = setup_dai_docs(dst=dst, from_hf=from_hf) - - import glob - - files = list(glob.glob(os.path.join(dst, "*rst"), recursive=True)) - - basedir = os.path.abspath(os.getcwd()) - from create_data import rst_to_outputs - - new_outputs = rst_to_outputs(files) - os.chdir(basedir) - - pickle.dump(new_outputs, open(dai_store, "wb")) - else: - new_outputs = pickle.load(open(dai_store, "rb")) - - sources = [] - for line, file in new_outputs: - # gradio requires any linked file to be with app.py - sym_src = os.path.abspath(os.path.join(dst, file)) - sym_dst = os.path.abspath(os.path.join(os.getcwd(), file)) - if os.path.lexists(sym_dst): - os.remove(sym_dst) - os.symlink(sym_src, sym_dst) - itm = Document(page_content=line, metadata={"source": file}) - # NOTE: yield has issues when going into db, loses metadata - # yield itm - sources.append(itm) - return sources - - -import distutils.spawn - -have_tesseract = distutils.spawn.find_executable("tesseract") -have_libreoffice = distutils.spawn.find_executable("libreoffice") - -import pkg_resources - -try: - assert pkg_resources.get_distribution("arxiv") is not None - assert pkg_resources.get_distribution("pymupdf") is not None - have_arxiv = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_arxiv = False - -try: - assert pkg_resources.get_distribution("pymupdf") is not None - have_pymupdf = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_pymupdf = False - -try: - assert pkg_resources.get_distribution("selenium") is not None - have_selenium = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_selenium = False - -try: - assert pkg_resources.get_distribution("playwright") is not None - have_playwright = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_playwright = False - -# disable, hangs too often -have_playwright = False - -image_types = ["png", "jpg", "jpeg"] -non_image_types = [ - "pdf", - "txt", - "csv", - "toml", - "py", - "rst", - "rtf", - "md", - "html", - "mhtml", - "enex", - "eml", - "epub", - "odt", - "pptx", - "ppt", - "zip", - "urls", -] -# "msg", GPL3 - -if have_libreoffice: - non_image_types.extend(["docx", "doc", "xls", "xlsx"]) - -file_types = non_image_types + image_types - - -def add_meta(docs1, file): - file_extension = pathlib.Path(file).suffix - hashid = hash_file(file) - if not isinstance(docs1, (list, tuple, types.GeneratorType)): - docs1 = [docs1] - [ - x.metadata.update( - dict( - input_type=file_extension, - date=str(datetime.now()), - hashid=hashid, - ) - ) - for x in docs1 - ] - - -def file_to_doc( - file, - base_path=None, - verbose=False, - fail_any_exception=False, - chunk=True, - chunk_size=512, - is_url=False, - is_txt=False, - enable_captions=True, - captions_model=None, - enable_ocr=False, - caption_loader=None, - headsize=50, -): - if file is None: - if fail_any_exception: - raise RuntimeError("Unexpected None file") - else: - return [] - doc1 = [] # in case no support, or disabled support - if base_path is None and not is_txt and not is_url: - # then assume want to persist but don't care which path used - # can't be in base_path - dir_name = os.path.dirname(file) - base_name = os.path.basename(file) - # if from gradio, will have its own temp uuid too, but that's ok - base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10] - base_path = os.path.join(dir_name, base_name) - if is_url: - if file.lower().startswith("arxiv:"): - query = file.lower().split("arxiv:") - if len(query) == 2 and have_arxiv: - query = query[1] - docs1 = ArxivLoader( - query=query, load_max_docs=20, load_all_available_meta=True - ).load() - # ensure string, sometimes None - [ - [ - x.metadata.update({k: str(v)}) - for k, v in x.metadata.items() - ] - for x in docs1 - ] - query_url = f"https://arxiv.org/abs/{query}" - [ - x.metadata.update( - dict( - source=x.metadata.get("entry_id", query_url), - query=query_url, - input_type="arxiv", - head=x.metadata.get("Title", ""), - date=str(datetime.now), - ) - ) - for x in docs1 - ] - else: - docs1 = [] - else: - if not ( - file.startswith("http://") - or file.startswith("file://") - or file.startswith("https://") - ): - file = "http://" + file - docs1 = UnstructuredURLLoader(urls=[file]).load() - if len(docs1) == 0 and have_playwright: - # then something went wrong, try another loader: - from langchain.document_loaders import PlaywrightURLLoader - - docs1 = PlaywrightURLLoader(urls=[file]).load() - if len(docs1) == 0 and have_selenium: - # then something went wrong, try another loader: - # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException: Message: unknown error: cannot find Chrome binary - from langchain.document_loaders import SeleniumURLLoader - from selenium.common.exceptions import WebDriverException - - try: - docs1 = SeleniumURLLoader(urls=[file]).load() - except WebDriverException as e: - print("No web driver: %s" % str(e), flush=True) - [ - x.metadata.update( - dict(input_type="url", date=str(datetime.now)) - ) - for x in docs1 - ] - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif is_txt: - base_path = "user_paste" - source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10]) - makedirs(os.path.dirname(source_file), exist_ok=True) - with open(source_file, "wt") as f: - f.write(file) - metadata = dict( - source=source_file, - date=str(datetime.now()), - input_type="pasted txt", - ) - doc1 = Document(page_content=file, metadata=metadata) - doc1 = clean_doc(doc1) - elif file.lower().endswith(".html") or file.lower().endswith(".mhtml"): - docs1 = UnstructuredHTMLLoader(file_path=file).load() - add_meta(docs1, file) - docs1 = clean_doc(docs1) - doc1 = chunk_sources( - docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML - ) - elif ( - file.lower().endswith(".docx") or file.lower().endswith(".doc") - ) and have_libreoffice: - docs1 = UnstructuredWordDocumentLoader(file_path=file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif ( - file.lower().endswith(".xlsx") or file.lower().endswith(".xls") - ) and have_libreoffice: - docs1 = UnstructuredExcelLoader(file_path=file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith(".odt"): - docs1 = UnstructuredODTLoader(file_path=file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith("pptx") or file.lower().endswith("ppt"): - docs1 = UnstructuredPowerPointLoader(file_path=file).load() - add_meta(docs1, file) - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith(".txt"): - # use UnstructuredFileLoader ? - docs1 = TextLoader( - file, encoding="utf8", autodetect_encoding=True - ).load() - # makes just one, but big one - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - doc1 = clean_doc(doc1) - add_meta(doc1, file) - elif file.lower().endswith(".rtf"): - docs1 = UnstructuredRTFLoader(file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith(".md"): - docs1 = UnstructuredMarkdownLoader(file).load() - add_meta(docs1, file) - docs1 = clean_doc(docs1) - doc1 = chunk_sources( - docs1, - chunk=chunk, - chunk_size=chunk_size, - language=Language.MARKDOWN, - ) - elif file.lower().endswith(".enex"): - docs1 = EverNoteLoader(file).load() - add_meta(doc1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith(".epub"): - docs1 = UnstructuredEPubLoader(file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif ( - file.lower().endswith(".jpeg") - or file.lower().endswith(".jpg") - or file.lower().endswith(".png") - ): - docs1 = [] - if have_tesseract and enable_ocr: - # OCR, somewhat works, but not great - docs1.extend(UnstructuredImageLoader(file).load()) - add_meta(docs1, file) - if enable_captions: - # BLIP - if caption_loader is not None and not isinstance( - caption_loader, (str, bool) - ): - # assumes didn't fork into this process with joblib, else can deadlock - caption_loader.set_image_paths([file]) - docs1c = caption_loader.load() - add_meta(docs1c, file) - [ - x.metadata.update( - dict(head=x.page_content[:headsize].strip()) - ) - for x in docs1c - ] - docs1.extend(docs1c) - else: - from image_captions import H2OImageCaptionLoader - - caption_loader = H2OImageCaptionLoader( - caption_gpu=caption_loader == "gpu", - blip_model=captions_model, - blip_processor=captions_model, - ) - caption_loader.set_image_paths([file]) - docs1c = caption_loader.load() - add_meta(docs1c, file) - [ - x.metadata.update( - dict(head=x.page_content[:headsize].strip()) - ) - for x in docs1c - ] - docs1.extend(docs1c) - for doci in docs1: - doci.metadata["source"] = doci.metadata["image_path"] - doci.metadata["hash"] = hash_file(doci.metadata["source"]) - if docs1: - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith(".msg"): - raise RuntimeError("Not supported, GPL3 license") - # docs1 = OutlookMessageLoader(file).load() - # docs1[0].metadata['source'] = file - elif file.lower().endswith(".eml"): - try: - docs1 = UnstructuredEmailLoader(file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - except ValueError as e: - if "text/html content not found in email" in str(e): - # e.g. plain/text dict key exists, but not - # doc1 = TextLoader(file, encoding="utf8").load() - docs1 = UnstructuredEmailLoader( - file, content_source="text/plain" - ).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - else: - raise - # elif file.lower().endswith('.gcsdir'): - # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load() - # elif file.lower().endswith('.gcsfile'): - # doc1 = GCSFileLoader(project_name, bucket, blob).load() - elif file.lower().endswith(".rst"): - with open(file, "r") as f: - doc1 = Document(page_content=f.read(), metadata={"source": file}) - add_meta(doc1, file) - doc1 = chunk_sources( - doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST - ) - elif file.lower().endswith(".pdf"): - env_gpt4all_file = ".env_gpt4all" - from dotenv import dotenv_values - - env_kwargs = dotenv_values(env_gpt4all_file) - pdf_class_name = env_kwargs.get("PDF_CLASS_NAME", "PyMuPDFParser") - if have_pymupdf and pdf_class_name == "PyMuPDFParser": - # GPL, only use if installed - from langchain.document_loaders import PyMuPDFLoader - - # load() still chunks by pages, but every page has title at start to help - doc1 = PyMuPDFLoader(file).load() - doc1 = clean_doc(doc1) - elif pdf_class_name == "UnstructuredPDFLoader": - doc1 = UnstructuredPDFLoader(file).load() - # seems to not need cleaning in most cases - else: - # open-source fallback - # load() still chunks by pages, but every page has title at start to help - doc1 = PyPDFLoader(file).load() - doc1 = clean_doc(doc1) - # Some PDFs return nothing or junk from PDFMinerLoader - doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size) - add_meta(doc1, file) - elif file.lower().endswith(".csv"): - doc1 = CSVLoader(file).load() - add_meta(doc1, file) - elif file.lower().endswith(".py"): - doc1 = PythonLoader(file).load() - add_meta(doc1, file) - doc1 = chunk_sources( - doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON - ) - elif file.lower().endswith(".toml"): - doc1 = TomlLoader(file).load() - add_meta(doc1, file) - elif file.lower().endswith(".urls"): - with open(file, "r") as f: - docs1 = UnstructuredURLLoader(urls=f.readlines()).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith(".zip"): - with zipfile.ZipFile(file, "r") as zip_ref: - # don't put into temporary path, since want to keep references to docs inside zip - # so just extract in path where - zip_ref.extractall(base_path) - # recurse - doc1 = path_to_docs( - base_path, - verbose=verbose, - fail_any_exception=fail_any_exception, - ) - else: - raise RuntimeError("No file handler for %s" % os.path.basename(file)) - - # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now - # if list of length one, don't trust and chunk it - if not isinstance(doc1, list): - if chunk: - docs = chunk_sources([doc1], chunk=chunk, chunk_size=chunk_size) - else: - docs = [doc1] - elif isinstance(doc1, list) and len(doc1) == 1: - if chunk: - docs = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size) - else: - docs = doc1 - else: - docs = doc1 - - assert isinstance(docs, list) - return docs - - -def path_to_doc1( - file, - verbose=False, - fail_any_exception=False, - return_file=True, - chunk=True, - chunk_size=512, - is_url=False, - is_txt=False, - enable_captions=True, - captions_model=None, - enable_ocr=False, - caption_loader=None, -): - if verbose: - if is_url: - print("Ingesting URL: %s" % file, flush=True) - elif is_txt: - print("Ingesting Text: %s" % file, flush=True) - else: - print("Ingesting file: %s" % file, flush=True) - res = None - try: - # don't pass base_path=path, would infinitely recurse - res = file_to_doc( - file, - base_path=None, - verbose=verbose, - fail_any_exception=fail_any_exception, - chunk=chunk, - chunk_size=chunk_size, - is_url=is_url, - is_txt=is_txt, - enable_captions=enable_captions, - captions_model=captions_model, - enable_ocr=enable_ocr, - caption_loader=caption_loader, - ) - except BaseException as e: - print("Failed to ingest %s due to %s" % (file, traceback.format_exc())) - if fail_any_exception: - raise - else: - exception_doc = Document( - page_content="", - metadata={ - "source": file, - "exception": str(e), - "traceback": traceback.format_exc(), - }, - ) - res = [exception_doc] - if return_file: - base_tmp = "temp_path_to_doc1" - if not os.path.isdir(base_tmp): - os.makedirs(base_tmp, exist_ok=True) - filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle") - with open(filename, "wb") as f: - pickle.dump(res, f) - return filename - return res - - -def path_to_docs( - path_or_paths, - verbose=False, - fail_any_exception=False, - n_jobs=-1, - chunk=True, - chunk_size=512, - url=None, - text=None, - enable_captions=True, - captions_model=None, - caption_loader=None, - enable_ocr=False, - existing_files=[], - existing_hash_ids={}, -): - # path_or_paths could be str, list, tuple, generator - globs_image_types = [] - globs_non_image_types = [] - if not path_or_paths and not url and not text: - return [] - elif url: - globs_non_image_types = ( - url - if isinstance(url, (list, tuple, types.GeneratorType)) - else [url] - ) - elif text: - globs_non_image_types = ( - text - if isinstance(text, (list, tuple, types.GeneratorType)) - else [text] - ) - elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths): - # single path, only consume allowed files - path = path_or_paths - # Below globs should match patterns in file_to_doc() - [ - globs_image_types.extend( - glob.glob( - os.path.join(path, "./**/*.%s" % ftype), recursive=True - ) - ) - for ftype in image_types - ] - [ - globs_non_image_types.extend( - glob.glob( - os.path.join(path, "./**/*.%s" % ftype), recursive=True - ) - ) - for ftype in non_image_types - ] - else: - if isinstance(path_or_paths, str) and ( - os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths) - ): - path_or_paths = [path_or_paths] - # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows) - assert isinstance( - path_or_paths, (list, tuple, types.GeneratorType) - ), "Wrong type for path_or_paths: %s" % type(path_or_paths) - # reform out of allowed types - globs_image_types.extend( - flatten_list( - [ - [x for x in path_or_paths if x.endswith(y)] - for y in image_types - ] - ) - ) - # could do below: - # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types]) - # But instead, allow fail so can collect unsupported too - set_globs_image_types = set(globs_image_types) - globs_non_image_types.extend( - [x for x in path_or_paths if x not in set_globs_image_types] - ) - - # filter out any files to skip (e.g. if already processed them) - # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[] - assert not existing_files, "DEV: assume not using this approach" - if existing_files: - set_skip_files = set(existing_files) - globs_image_types = [ - x for x in globs_image_types if x not in set_skip_files - ] - globs_non_image_types = [ - x for x in globs_non_image_types if x not in set_skip_files - ] - if existing_hash_ids: - # assume consistent with add_meta() use of hash_file(file) - # also assume consistent with get_existing_hash_ids for dict creation - # assume hashable values - existing_hash_ids_set = set(existing_hash_ids.items()) - hash_ids_all_image = set( - {x: hash_file(x) for x in globs_image_types}.items() - ) - hash_ids_all_non_image = set( - {x: hash_file(x) for x in globs_non_image_types}.items() - ) - # don't use symmetric diff. If file is gone, ignore and don't remove or something - # just consider existing files (key) having new hash or not (value) - new_files_image = set( - dict(hash_ids_all_image - existing_hash_ids_set).keys() - ) - new_files_non_image = set( - dict(hash_ids_all_non_image - existing_hash_ids_set).keys() - ) - globs_image_types = [ - x for x in globs_image_types if x in new_files_image - ] - globs_non_image_types = [ - x for x in globs_non_image_types if x in new_files_non_image - ] - - # could use generator, but messes up metadata handling in recursive case - if ( - caption_loader - and not isinstance(caption_loader, (bool, str)) - and caption_loader.device != "cpu" - or args.device == "cuda" - ): - # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context - n_jobs_image = 1 - else: - n_jobs_image = n_jobs - - return_file = True # local choice - is_url = url is not None - is_txt = text is not None - kwargs = dict( - verbose=verbose, - fail_any_exception=fail_any_exception, - return_file=return_file, - chunk=chunk, - chunk_size=chunk_size, - is_url=is_url, - is_txt=is_txt, - enable_captions=enable_captions, - captions_model=captions_model, - caption_loader=caption_loader, - enable_ocr=enable_ocr, - ) - - if n_jobs != 1 and len(globs_non_image_types) > 1: - # avoid nesting, e.g. upload 1 zip and then inside many files - # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib - documents = ProgressParallel( - n_jobs=n_jobs, - verbose=10 if verbose else 0, - backend="multiprocessing", - )( - delayed(path_to_doc1)(file, **kwargs) - for file in globs_non_image_types - ) - else: - documents = [ - path_to_doc1(file, **kwargs) - for file in tqdm(globs_non_image_types) - ] - - # do images separately since can't fork after cuda in parent, so can't be parallel - if n_jobs_image != 1 and len(globs_image_types) > 1: - # avoid nesting, e.g. upload 1 zip and then inside many files - # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib - image_documents = ProgressParallel( - n_jobs=n_jobs, - verbose=10 if verbose else 0, - backend="multiprocessing", - )(delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types) - else: - image_documents = [ - path_to_doc1(file, **kwargs) for file in tqdm(globs_image_types) - ] - - # add image docs in - documents += image_documents - - if return_file: - # then documents really are files - files = documents.copy() - documents = [] - for fil in files: - with open(fil, "rb") as f: - documents.extend(pickle.load(f)) - # remove temp pickle - os.remove(fil) - else: - documents = reduce(concat, documents) - return documents - - -def prep_langchain( - persist_directory, - load_db_if_exists, - db_type, - use_openai_embedding, - langchain_mode, - user_path, - hf_embedding_model, - n_jobs=-1, - kwargs_make_db={}, -): - """ - do prep first time, involving downloads - # FIXME: Add github caching then add here - :return: - """ - assert langchain_mode not in ["MyData"], "Should not prep scratch data" - - db_dir_exists = os.path.isdir(persist_directory) - - if db_dir_exists and user_path is None: - print( - "Prep: persist_directory=%s exists, using" % persist_directory, - flush=True, - ) - db = get_existing_db( - None, - persist_directory, - load_db_if_exists, - db_type, - use_openai_embedding, - langchain_mode, - hf_embedding_model, - ) - else: - if db_dir_exists and user_path is not None: - print( - "Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" - % (persist_directory, user_path), - flush=True, - ) - elif not db_dir_exists: - print( - "Prep: persist_directory=%s does not exist, regenerating" - % persist_directory, - flush=True, - ) - db = None - if langchain_mode in ["All", "DriverlessAI docs"]: - # FIXME: Could also just use dai_docs.pickle directly and upload that - get_dai_docs(from_hf=True) - - if langchain_mode in ["All", "wiki"]: - get_wiki_sources( - first_para=kwargs_make_db["first_para"], - text_limit=kwargs_make_db["text_limit"], - ) - - langchain_kwargs = kwargs_make_db.copy() - langchain_kwargs.update(locals()) - db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs) - - return db - - -import posthog - -posthog.disabled = True - - -class FakeConsumer(object): - def __init__(self, *args, **kwargs): - pass - - def run(self): - pass - - def pause(self): - pass - - def upload(self): - pass - - def next(self): - pass - - def request(self, batch): - pass - - -posthog.Consumer = FakeConsumer - - -def check_update_chroma_embedding( - db, use_openai_embedding, hf_embedding_model, langchain_mode -): - changed_db = False - if load_embed(db) != (use_openai_embedding, hf_embedding_model): - print( - "Detected new embedding, updating db: %s" % langchain_mode, - flush=True, - ) - # handle embedding changes - db_get = get_documents(db) - sources = [ - Document(page_content=result[0], metadata=result[1] or {}) - for result in zip(db_get["documents"], db_get["metadatas"]) - ] - # delete index, has to be redone - persist_directory = db._persist_directory - shutil.move( - persist_directory, - persist_directory + "_" + str(uuid.uuid4()) + ".bak", - ) - db_type = "chroma" - load_db_if_exists = False - db = get_db( - sources, - use_openai_embedding=use_openai_embedding, - db_type=db_type, - persist_directory=persist_directory, - load_db_if_exists=load_db_if_exists, - langchain_mode=langchain_mode, - collection_name=None, - hf_embedding_model=hf_embedding_model, - ) - if False: - # below doesn't work if db already in memory, so have to switch to new db as above - # upsert does new embedding, but if index already in memory, complains about size mismatch etc. - client_collection = db._client.get_collection( - name=db._collection.name, - embedding_function=db._collection._embedding_function, - ) - client_collection.upsert( - ids=db_get["ids"], - metadatas=db_get["metadatas"], - documents=db_get["documents"], - ) - changed_db = True - print( - "Done updating db for new embedding: %s" % langchain_mode, - flush=True, - ) - - return db, changed_db - - -def get_existing_db( - db, - persist_directory, - load_db_if_exists, - db_type, - use_openai_embedding, - langchain_mode, - hf_embedding_model, - verbose=False, - check_embedding=True, -): - if ( - load_db_if_exists - and db_type == "chroma" - and os.path.isdir(persist_directory) - and os.path.isdir(os.path.join(persist_directory, "index")) - ): - if db is None: - if verbose: - print("DO Loading db: %s" % langchain_mode, flush=True) - embedding = get_embedding( - use_openai_embedding, hf_embedding_model=hf_embedding_model - ) - from chromadb.config import Settings - - client_settings = Settings( - anonymized_telemetry=False, - chroma_db_impl="duckdb+parquet", - persist_directory=persist_directory, - ) - db = Chroma( - persist_directory=persist_directory, - embedding_function=embedding, - collection_name=langchain_mode.replace(" ", "_"), - client_settings=client_settings, - ) - if verbose: - print("DONE Loading db: %s" % langchain_mode, flush=True) - else: - if verbose: - print( - "USING already-loaded db: %s" % langchain_mode, flush=True - ) - if check_embedding: - db_trial, changed_db = check_update_chroma_embedding( - db, use_openai_embedding, hf_embedding_model, langchain_mode - ) - if changed_db: - db = db_trial - # only call persist if really changed db, else takes too long for large db - if db is not None: - db.persist() - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - return db - return None - - -def clear_embedding(db): - if db is None: - return - # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed - db._embedding_function.client.cpu() - clear_torch_cache() - - -def make_db(**langchain_kwargs): - func_names = list(inspect.signature(_make_db).parameters) - missing_kwargs = [x for x in func_names if x not in langchain_kwargs] - defaults_db = { - k: v.default - for k, v in dict(inspect.signature(run_qa_db).parameters).items() - } - for k in missing_kwargs: - if k in defaults_db: - langchain_kwargs[k] = defaults_db[k] - # final check for missing - missing_kwargs = [x for x in func_names if x not in langchain_kwargs] - assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs - # only keep actual used - langchain_kwargs = { - k: v for k, v in langchain_kwargs.items() if k in func_names - } - return _make_db(**langchain_kwargs) - - -def save_embed(db, use_openai_embedding, hf_embedding_model): - if db is not None: - embed_info_file = os.path.join(db._persist_directory, "embed_info") - with open(embed_info_file, "wb") as f: - pickle.dump((use_openai_embedding, hf_embedding_model), f) - return use_openai_embedding, hf_embedding_model - - -def load_embed(db): - embed_info_file = os.path.join(db._persist_directory, "embed_info") - if os.path.isfile(embed_info_file): - with open(embed_info_file, "rb") as f: - use_openai_embedding, hf_embedding_model = pickle.load(f) - else: - # migration, assume defaults - use_openai_embedding, hf_embedding_model = ( - False, - "sentence-transformers/all-MiniLM-L6-v2", - ) - return use_openai_embedding, hf_embedding_model - - -def get_persist_directory(langchain_mode): - return ( - "db_dir_%s" % langchain_mode - ) # single place, no special names for each case - - -def _make_db( - use_openai_embedding=False, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - first_para=False, - text_limit=None, - chunk=True, - chunk_size=512, - langchain_mode=None, - user_path=None, - db_type="faiss", - load_db_if_exists=True, - db=None, - n_jobs=-1, - verbose=False, -): - persist_directory = get_persist_directory(langchain_mode) - # see if can get persistent chroma db - db_trial = get_existing_db( - db, - persist_directory, - load_db_if_exists, - db_type, - use_openai_embedding, - langchain_mode, - hf_embedding_model, - verbose=verbose, - ) - if db_trial is not None: - db = db_trial - - sources = [] - if ( - not db - and langchain_mode not in ["MyData"] - or user_path is not None - and langchain_mode in ["UserData"] - ): - # Should not make MyData db this way, why avoided, only upload from UI - assert langchain_mode not in [ - "MyData" - ], "Should not make MyData db this way" - if verbose: - if langchain_mode in ["UserData"]: - if user_path is not None: - print( - "Checking if changed or new sources in %s, and generating sources them" - % user_path, - flush=True, - ) - elif db is None: - print( - "user_path not passed and no db, no sources", - flush=True, - ) - else: - print( - "user_path not passed, using only existing db, no new sources", - flush=True, - ) - else: - print("Generating %s sources" % langchain_mode, flush=True) - if langchain_mode in ["wiki_full", "All", "'All'"]: - from read_wiki_full import get_all_documents - - small_test = None - print("Generating new wiki", flush=True) - sources1 = get_all_documents( - small_test=small_test, n_jobs=os.cpu_count() // 2 - ) - print("Got new wiki", flush=True) - if chunk: - sources1 = chunk_sources( - sources1, chunk=chunk, chunk_size=chunk_size - ) - print("Chunked new wiki", flush=True) - sources.extend(sources1) - if langchain_mode in ["wiki", "All", "'All'"]: - sources1 = get_wiki_sources( - first_para=first_para, text_limit=text_limit - ) - if chunk: - sources1 = chunk_sources( - sources1, chunk=chunk, chunk_size=chunk_size - ) - sources.extend(sources1) - if langchain_mode in ["github h2oGPT", "All", "'All'"]: - # sources = get_github_docs("dagster-io", "dagster") - sources1 = get_github_docs("h2oai", "h2ogpt") - # FIXME: always chunk for now - sources1 = chunk_sources( - sources1, chunk=chunk, chunk_size=chunk_size - ) - sources.extend(sources1) - if langchain_mode in ["DriverlessAI docs", "All", "'All'"]: - sources1 = get_dai_docs(from_hf=True) - if ( - chunk and False - ): # FIXME: DAI docs are already chunked well, should only chunk more if over limit - sources1 = chunk_sources( - sources1, chunk=chunk, chunk_size=chunk_size - ) - sources.extend(sources1) - if langchain_mode in ["All", "UserData"]: - if user_path: - if db is not None: - # NOTE: Ignore file names for now, only go by hash ids - # existing_files = get_existing_files(db) - existing_files = [] - existing_hash_ids = get_existing_hash_ids(db) - else: - # pretend no existing files so won't filter - existing_files = [] - existing_hash_ids = [] - # chunk internally for speed over multiple docs - # FIXME: If first had old Hash=None and switch embeddings, - # then re-embed, and then hit here and reload so have hash, and then re-embed. - sources1 = path_to_docs( - user_path, - n_jobs=n_jobs, - chunk=chunk, - chunk_size=chunk_size, - existing_files=existing_files, - existing_hash_ids=existing_hash_ids, - ) - new_metadata_sources = set( - [x.metadata["source"] for x in sources1] - ) - if new_metadata_sources: - print( - "Loaded %s new files as sources to add to UserData" - % len(new_metadata_sources), - flush=True, - ) - if verbose: - print( - "Files added: %s" - % "\n".join(new_metadata_sources), - flush=True, - ) - sources.extend(sources1) - print( - "Loaded %s sources for potentially adding to UserData" - % len(sources), - flush=True, - ) - else: - print("Chose UserData but user_path is empty/None", flush=True) - if False and langchain_mode in ["urls", "All", "'All'"]: - # from langchain.document_loaders import UnstructuredURLLoader - # loader = UnstructuredURLLoader(urls=urls) - urls = ["https://www.birdsongsf.com/who-we-are/"] - from langchain.document_loaders import PlaywrightURLLoader - - loader = PlaywrightURLLoader( - urls=urls, remove_selectors=["header", "footer"] - ) - sources1 = loader.load() - sources.extend(sources1) - if not sources: - if verbose: - if db is not None: - print( - "langchain_mode %s has no new sources, nothing to add to db" - % langchain_mode, - flush=True, - ) - else: - print( - "langchain_mode %s has no sources, not making new db" - % langchain_mode, - flush=True, - ) - return db, 0, [] - if verbose: - if db is not None: - print("Generating db", flush=True) - else: - print("Adding to db", flush=True) - if not db: - if sources: - db = get_db( - sources, - use_openai_embedding=use_openai_embedding, - db_type=db_type, - persist_directory=persist_directory, - langchain_mode=langchain_mode, - hf_embedding_model=hf_embedding_model, - ) - if verbose: - print("Generated db", flush=True) - else: - print("Did not generate db since no sources", flush=True) - new_sources_metadata = [x.metadata for x in sources] - elif user_path is not None and langchain_mode in ["UserData"]: - print( - "Existing db, potentially adding %s sources from user_path=%s" - % (len(sources), user_path), - flush=True, - ) - db, num_new_sources, new_sources_metadata = add_to_db( - db, - sources, - db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model, - ) - print( - "Existing db, added %s new sources from user_path=%s" - % (num_new_sources, user_path), - flush=True, - ) - else: - new_sources_metadata = [x.metadata for x in sources] - - return db, len(new_sources_metadata), new_sources_metadata - - -def get_metadatas(db): - from langchain.vectorstores import FAISS - - if isinstance(db, FAISS): - metadatas = [v.metadata for k, v in db.docstore._dict.items()] - elif isinstance(db, Chroma): - metadatas = get_documents(db)["metadatas"] - else: - # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 - # seems no way to get all metadata, so need to avoid this approach for weaviate - metadatas = [x.metadata for x in db.similarity_search("", k=10000)] - return metadatas - - -def get_documents(db): - if hasattr(db, "_persist_directory"): - name_path = os.path.basename(db._persist_directory) - base_path = "locks" - makedirs(base_path) - with filelock.FileLock( - os.path.join(base_path, "getdb_%s.lock" % name_path) - ): - # get segfaults and other errors when multiple threads access this - return _get_documents(db) - else: - return _get_documents(db) - - -def _get_documents(db): - from langchain.vectorstores import FAISS - - if isinstance(db, FAISS): - documents = [v for k, v in db.docstore._dict.items()] - elif isinstance(db, Chroma): - documents = db.get() - else: - # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 - # seems no way to get all metadata, so need to avoid this approach for weaviate - documents = [x for x in db.similarity_search("", k=10000)] - return documents - - -def get_docs_and_meta(db, top_k_docs, filter_kwargs={}): - if hasattr(db, "_persist_directory"): - name_path = os.path.basename(db._persist_directory) - base_path = "locks" - makedirs(base_path) - with filelock.FileLock( - os.path.join(base_path, "getdb_%s.lock" % name_path) - ): - return _get_docs_and_meta( - db, top_k_docs, filter_kwargs=filter_kwargs - ) - else: - return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs) - - -def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}): - from langchain.vectorstores import FAISS - - if isinstance(db, Chroma): - db_get = db._collection.get(where=filter_kwargs.get("filter")) - db_metadatas = db_get["metadatas"] - db_documents = db_get["documents"] - elif isinstance(db, FAISS): - import itertools - - db_metadatas = get_metadatas(db) - # FIXME: FAISS has no filter - # slice dict first - db_documents = list( - dict( - itertools.islice(db.docstore._dict.items(), top_k_docs) - ).values() - ) - else: - db_metadatas = get_metadatas(db) - db_documents = get_documents(db) - return db_documents, db_metadatas - - -def get_existing_files(db): - metadatas = get_metadatas(db) - metadata_sources = set([x["source"] for x in metadatas]) - return metadata_sources - - -def get_existing_hash_ids(db): - metadatas = get_metadatas(db) - # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks - metadata_hash_ids = {x["source"]: x.get("hashid") for x in metadatas} - return metadata_hash_ids - - -def run_qa_db(**kwargs): - func_names = list(inspect.signature(_run_qa_db).parameters) - # hard-coded defaults - kwargs["answer_with_sources"] = True - kwargs["show_rank"] = False - missing_kwargs = [x for x in func_names if x not in kwargs] - assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs - # only keep actual used - kwargs = {k: v for k, v in kwargs.items() if k in func_names} - try: - return _run_qa_db(**kwargs) - finally: - clear_torch_cache() - - -def _run_qa_db( - query=None, - iinput=None, - context=None, - use_openai_model=False, - use_openai_embedding=False, - first_para=False, - text_limit=None, - top_k_docs=4, - chunk=True, - chunk_size=512, - user_path=None, - detect_user_path_changes_every_query=False, - db_type="faiss", - model_name=None, - model=None, - tokenizer=None, - inference_server=None, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - stream_output=False, - prompter=None, - prompt_type=None, - prompt_dict=None, - answer_with_sources=True, - cut_distanct=1.1, - sanitize_bot_response=False, - show_rank=False, - load_db_if_exists=False, - db=None, - do_sample=False, - temperature=0.1, - top_k=40, - top_p=0.7, - num_beams=1, - max_new_tokens=256, - min_new_tokens=1, - early_stopping=False, - max_time=180, - repetition_penalty=1.0, - num_return_sequences=1, - langchain_mode=None, - langchain_action=None, - document_choice=[DocumentChoices.All_Relevant.name], - n_jobs=-1, - verbose=False, - cli=False, - reverse_docs=True, - lora_weights="", - auto_reduce_chunks=True, - max_chunks=100, -): - """ - - :param query: - :param use_openai_model: - :param use_openai_embedding: - :param first_para: - :param text_limit: - :param top_k_docs: - :param chunk: - :param chunk_size: - :param user_path: user path to glob recursively from - :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db - :param model_name: model name, used to switch behaviors - :param model: pre-initialized model, else will make new one - :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None - :param answer_with_sources - :return: - """ - if model is not None: - assert model_name is not None # require so can make decisions - assert query is not None - assert ( - prompter is not None or prompt_type is not None or model is None - ) # if model is None, then will generate - if prompter is not None: - prompt_type = prompter.prompt_type - prompt_dict = prompter.prompt_dict - if model is not None: - assert prompt_type is not None - if prompt_type == PromptType.custom.name: - assert prompt_dict is not None # should at least be {} or '' - else: - prompt_dict = "" - assert ( - len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) - == 0 - ) - llm, model_name, streamer, prompt_type_out = get_llm( - use_openai_model=use_openai_model, - model_name=model_name, - model=model, - tokenizer=tokenizer, - inference_server=inference_server, - stream_output=stream_output, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - prompter=prompter, - sanitize_bot_response=sanitize_bot_response, - verbose=verbose, - ) - - use_context = False - scores = [] - chain = None - - if isinstance(document_choice, str): - # support string as well - document_choice = [document_choice] - # get first DocumentChoices as command to use, ignore others - doc_choices_set = set([x.name for x in list(DocumentChoices)]) - cmd = [x for x in document_choice if x in doc_choices_set] - cmd = None if len(cmd) == 0 else cmd[0] - # now have cmd, filter out for only docs - document_choice = [x for x in document_choice if x not in doc_choices_set] - - func_names = list(inspect.signature(get_similarity_chain).parameters) - sim_kwargs = {k: v for k, v in locals().items() if k in func_names} - missing_kwargs = [x for x in func_names if x not in sim_kwargs] - assert not missing_kwargs, "Missing: %s" % missing_kwargs - docs, chain, scores, use_context, have_any_docs = get_similarity_chain( - **sim_kwargs - ) - if cmd in non_query_commands: - formatted_doc_chunks = "\n\n".join( - [get_url(x) + "\n\n" + x.page_content for x in docs] - ) - return formatted_doc_chunks, "" - if not docs and langchain_action in [ - LangChainAction.SUMMARIZE_MAP.value, - LangChainAction.SUMMARIZE_ALL.value, - LangChainAction.SUMMARIZE_REFINE.value, - ]: - ret = ( - "No relevant documents to summarize." - if have_any_docs - else "No documents to summarize." - ) - extra = "" - return ret, extra - if not docs and langchain_mode not in [ - LangChainMode.DISABLED.value, - LangChainMode.CHAT_LLM.value, - LangChainMode.LLM.value, - ]: - ret = ( - "No relevant documents to query." - if have_any_docs - else "No documents to query." - ) - extra = "" - return ret, extra - - if chain is None and model_name not in non_hf_types: - # here if no docs at all and not HF type - # can only return if HF type - return - - # context stuff similar to used in evaluate() - import torch - - torch_dtype, context_class = get_dtype() - with torch.no_grad(): - have_lora_weights = lora_weights not in [no_lora_str, "", None] - context_class_cast = ( - NullContext - if args.device == "cpu" or have_lora_weights - else torch.autocast - ) - with context_class_cast(args.device): - answer = chain() - return answer - - -def get_similarity_chain( - query=None, - iinput=None, - use_openai_model=False, - use_openai_embedding=False, - first_para=False, - text_limit=None, - top_k_docs=4, - chunk=True, - chunk_size=512, - user_path=None, - detect_user_path_changes_every_query=False, - db_type="faiss", - model_name=None, - inference_server="", - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - prompt_type=None, - prompt_dict=None, - cut_distanct=1.1, - load_db_if_exists=False, - db=None, - langchain_mode=None, - langchain_action=None, - document_choice=[DocumentChoices.All_Relevant.name], - n_jobs=-1, - # beyond run_db_query: - llm=None, - tokenizer=None, - verbose=False, - cmd=None, - reverse_docs=True, - # local - auto_reduce_chunks=True, - max_chunks=100, -): - # determine whether use of context out of docs is planned - if ( - not use_openai_model - and prompt_type not in ["plain"] - or model_name in non_hf_types - ): - if langchain_mode in ["Disabled", "ChatLLM", "LLM"]: - use_context = False - else: - use_context = True - else: - use_context = True - - # https://github.com/hwchase17/langchain/issues/1946 - # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid - # Chroma collection MyData contains fewer than 4 elements. - # type logger error - if top_k_docs == -1: - k_db = 1000 if db_type == "chroma" else 100 - else: - # top_k_docs=100 works ok too - k_db = 1000 if db_type == "chroma" else top_k_docs - - # FIXME: For All just go over all dbs instead of a separate db for All - if not detect_user_path_changes_every_query and db is not None: - # avoid looking at user_path during similarity search db handling, - # if already have db and not updating from user_path every query - # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was - user_path = None - db, num_new_sources, new_sources_metadata = make_db( - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model, - first_para=first_para, - text_limit=text_limit, - chunk=chunk, - chunk_size=chunk_size, - langchain_mode=langchain_mode, - user_path=user_path, - db_type=db_type, - load_db_if_exists=load_db_if_exists, - db=db, - n_jobs=n_jobs, - verbose=verbose, - ) - have_any_docs = db is not None - if langchain_action == LangChainAction.QUERY.value: - if iinput: - query = "%s\n%s" % (query, iinput) - - if "falcon" in model_name: - extra = "According to only the information in the document sources provided within the context above, " - prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends." - elif inference_server in ["openai", "openai_chat"]: - extra = "According to (primarily) the information in the document sources provided within context above, " - prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents." - else: - extra = "" - prefix = "" - if langchain_mode in ["Disabled", "ChatLLM", "LLM"] or not use_context: - template_if_no_docs = template = ( - """%s{context}{question}""" % prefix - ) - else: - template = """%s - \"\"\" - {context} - \"\"\" - %s{question}""" % ( - prefix, - extra, - ) - template_if_no_docs = """%s{context}%s{question}""" % ( - prefix, - extra, - ) - elif langchain_action in [ - LangChainAction.SUMMARIZE_ALL.value, - LangChainAction.SUMMARIZE_MAP.value, - ]: - none = ["", "\n", None] - if query in none and iinput in none: - prompt_summary = "Using only the text above, write a condensed and concise summary:\n" - elif query not in none: - prompt_summary = ( - "Focusing on %s, write a condensed and concise Summary:\n" - % query - ) - elif iinput not in None: - prompt_summary = iinput - else: - prompt_summary = "Focusing on %s, %s:\n" % (query, iinput) - # don't auto reduce - auto_reduce_chunks = False - if langchain_action == LangChainAction.SUMMARIZE_MAP.value: - fstring = "{text}" - else: - fstring = "{input_documents}" - template = """In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text: -\"\"\" -%s -\"\"\"\n%s""" % ( - fstring, - prompt_summary, - ) - template_if_no_docs = ( - "Exactly only say: There are no documents to summarize." - ) - elif langchain_action in [LangChainAction.SUMMARIZE_REFINE]: - template = "" # unused - template_if_no_docs = "" # unused - else: - raise RuntimeError("No such langchain_action=%s" % langchain_action) - - if ( - not use_openai_model - and prompt_type not in ["plain"] - or model_name in non_hf_types - ): - use_template = True - else: - use_template = False - - if db and use_context: - base_path = "locks" - makedirs(base_path) - if hasattr(db, "_persist_directory"): - name_path = "sim_%s.lock" % os.path.basename(db._persist_directory) - else: - name_path = "sim.lock" - lock_file = os.path.join(base_path, name_path) - - if not isinstance(db, Chroma): - # only chroma supports filtering - filter_kwargs = {} - else: - # if here then some cmd + documents selected or just documents selected - if len(document_choice) >= 2: - or_filter = [{"source": {"$eq": x}} for x in document_choice] - filter_kwargs = dict(filter={"$or": or_filter}) - elif len(document_choice) == 1: - # degenerate UX bug in chroma - one_filter = [{"source": {"$eq": x}} for x in document_choice][ - 0 - ] - filter_kwargs = dict(filter=one_filter) - else: - # shouldn't reach - filter_kwargs = {} - if cmd == DocumentChoices.Just_LLM.name: - docs = [] - scores = [] - elif cmd == DocumentChoices.Only_All_Sources.name or query in [ - None, - "", - "\n", - ]: - db_documents, db_metadatas = get_docs_and_meta( - db, top_k_docs, filter_kwargs=filter_kwargs - ) - # similar to langchain's chroma's _results_to_docs_and_scores - docs_with_score = [ - (Document(page_content=result[0], metadata=result[1] or {}), 0) - for result in zip(db_documents, db_metadatas) - ] - - # order documents - doc_hashes = [x["doc_hash"] for x in db_metadatas] - doc_chunk_ids = [x["chunk_id"] for x in db_metadatas] - docs_with_score = [ - x - for _, _, x in sorted( - zip(doc_hashes, doc_chunk_ids, docs_with_score), - key=lambda x: (x[0], x[1]), - ) - ] - - docs_with_score = docs_with_score[:top_k_docs] - docs = [x[0] for x in docs_with_score] - scores = [x[1] for x in docs_with_score] - have_any_docs |= len(docs) > 0 - else: - # FIXME: if langchain_action == LangChainAction.SUMMARIZE_MAP.value - # if map_reduce, then no need to auto reduce chunks - if top_k_docs == -1 or auto_reduce_chunks: - # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs] - top_k_docs_tokenize = 100 - with filelock.FileLock(lock_file): - docs_with_score = db.similarity_search_with_score( - query, k=k_db, **filter_kwargs - )[:top_k_docs_tokenize] - if hasattr(llm, "pipeline") and hasattr( - llm.pipeline, "tokenizer" - ): - # more accurate - tokens = [ - len( - llm.pipeline.tokenizer(x[0].page_content)[ - "input_ids" - ] - ) - for x in docs_with_score - ] - template_tokens = len( - llm.pipeline.tokenizer(template)["input_ids"] - ) - elif ( - inference_server in ["openai", "openai_chat"] - or use_openai_model - or db_type in ["faiss", "weaviate"] - ): - # use ticktoken for faiss since embedding called differently - tokens = [ - llm.get_num_tokens(x[0].page_content) - for x in docs_with_score - ] - template_tokens = llm.get_num_tokens(template) - elif isinstance(tokenizer, FakeTokenizer): - tokens = [ - tokenizer.num_tokens_from_string(x[0].page_content) - for x in docs_with_score - ] - template_tokens = tokenizer.num_tokens_from_string( - template - ) - else: - # in case model is not our pipeline with HF tokenizer - tokens = [ - db._embedding_function.client.tokenize( - [x[0].page_content] - )["input_ids"].shape[1] - for x in docs_with_score - ] - template_tokens = db._embedding_function.client.tokenize( - [template] - )["input_ids"].shape[1] - tokens_cumsum = np.cumsum(tokens) - if hasattr(llm, "pipeline") and hasattr( - llm.pipeline, "max_input_tokens" - ): - max_input_tokens = llm.pipeline.max_input_tokens - elif inference_server in ["openai"]: - max_tokens = llm.modelname_to_contextsize(model_name) - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens = max_tokens - 256 - elif inference_server in ["openai_chat"]: - max_tokens = model_token_mapping[model_name] - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens = max_tokens - 256 - elif isinstance(tokenizer, FakeTokenizer): - max_input_tokens = tokenizer.model_max_length - 256 - else: - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens = 2048 - 256 - max_input_tokens -= template_tokens - # FIXME: Doesn't account for query, == context, or new lines between contexts - where_res = np.where(tokens_cumsum < max_input_tokens)[0] - if where_res.shape[0] == 0: - # then no chunk can fit, still do first one - top_k_docs_trial = 1 - else: - top_k_docs_trial = 1 + where_res[-1] - if 0 < top_k_docs_trial < max_chunks: - # avoid craziness - if top_k_docs == -1: - top_k_docs = top_k_docs_trial - else: - top_k_docs = min(top_k_docs, top_k_docs_trial) - if top_k_docs == -1: - # if here, means 0 and just do best with 1 doc - print( - "Unexpected large chunks and can't add to context, will add 1 anyways", - flush=True, - ) - top_k_docs = 1 - docs_with_score = docs_with_score[:top_k_docs] - else: - with filelock.FileLock(lock_file): - docs_with_score = db.similarity_search_with_score( - query, k=k_db, **filter_kwargs - )[:top_k_docs] - # put most relevant chunks closest to question, - # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated - # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest - if reverse_docs: - docs_with_score.reverse() - # cut off so no high distance docs/sources considered - have_any_docs |= len(docs_with_score) > 0 # before cut - docs = [x[0] for x in docs_with_score if x[1] < cut_distanct] - scores = [x[1] for x in docs_with_score if x[1] < cut_distanct] - if len(scores) > 0 and verbose: - print( - "Distance: min: %s max: %s mean: %s median: %s" - % ( - scores[0], - scores[-1], - np.mean(scores), - np.median(scores), - ), - flush=True, - ) - else: - docs = [] - scores = [] - - if not docs and use_context and model_name not in non_hf_types: - # if HF type and have no docs, can bail out - return docs, None, [], False, have_any_docs - - if cmd in non_query_commands: - # no LLM use - return docs, None, [], False, have_any_docs - - common_words_file = "data/NGSL_1.2_stats.csv.zip" - if ( - os.path.isfile(common_words_file) - and langchain_mode == LangChainAction.QUERY.value - ): - df = pd.read_csv("data/NGSL_1.2_stats.csv.zip") - import string - - reduced_query = query.translate( - str.maketrans(string.punctuation, " " * len(string.punctuation)) - ).strip() - reduced_query_words = reduced_query.split(" ") - set_common = set(df["Lemma"].values.tolist()) - num_common = len( - [x.lower() in set_common for x in reduced_query_words] - ) - frac_common = num_common / len(reduced_query) if reduced_query else 0 - # FIXME: report to user bad query that uses too many common words - if verbose: - print("frac_common: %s" % frac_common, flush=True) - - if len(docs) == 0: - # avoid context == in prompt then - use_context = False - template = template_if_no_docs - - if langchain_action == LangChainAction.QUERY.value: - if use_template: - # instruct-like, rather than few-shot prompt_type='plain' as default - # but then sources confuse the model with how inserted among rest of text, so avoid - prompt = PromptTemplate( - # input_variables=["summaries", "question"], - input_variables=["context", "question"], - template=template, - ) - chain = load_qa_chain(llm, prompt=prompt) - chain_kwargs = dict(input_documents=docs, question=query) - target = wrapped_partial(chain, chain_kwargs) - else: - raise RuntimeError("No such langchain_action=%s" % langchain_action) - - return docs, target, scores, use_context, have_any_docs - - -def get_sources_answer( - query, answer, scores, show_rank, answer_with_sources, verbose=False -): - if verbose: - print("query: %s" % query, flush=True) - print("answer: %s" % answer["output_text"], flush=True) - - if len(answer["input_documents"]) == 0: - extra = "" - ret = answer["output_text"] + extra - return ret, extra - - # link - answer_sources = [ - (max(0.0, 1.5 - score) / 1.5, get_url(doc)) - for score, doc in zip(scores, answer["input_documents"]) - ] - answer_sources_dict = defaultdict(list) - [answer_sources_dict[url].append(score) for score, url in answer_sources] - answers_dict = {} - for url, scores_url in answer_sources_dict.items(): - answers_dict[url] = np.max(scores_url) - answer_sources = [(score, url) for url, score in answers_dict.items()] - answer_sources.sort(key=lambda x: x[0], reverse=True) - if show_rank: - # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)] - # sorted_sources_urls = "Sources [Rank | Link]:
" + "
".join(answer_sources) - answer_sources = [ - "%s" % url for rank, (score, url) in enumerate(answer_sources) - ] - sorted_sources_urls = "Ranked Sources:
" + "
".join( - answer_sources - ) - else: - answer_sources = [ - "

  • %.2g | %s
  • " % (score, url) - for score, url in answer_sources - ] - sorted_sources_urls = f"{source_prefix}

      " + "

      ".join( - answer_sources - ) - sorted_sources_urls += f"

    {source_postfix}" - - if not answer["output_text"].endswith("\n"): - answer["output_text"] += "\n" - - if answer_with_sources: - extra = "\n" + sorted_sources_urls - else: - extra = "" - ret = answer["output_text"] + extra - return ret, extra - - -def clean_doc(docs1): - if not isinstance(docs1, (list, tuple, types.GeneratorType)): - docs1 = [docs1] - for doci, doc in enumerate(docs1): - docs1[doci].page_content = "\n".join( - [x.strip() for x in doc.page_content.split("\n") if x.strip()] - ) - return docs1 - - -def chunk_sources(sources, chunk=True, chunk_size=512, language=None): - if not chunk: - return sources - if not isinstance( - sources, (list, tuple, types.GeneratorType) - ) and not callable(sources): - # if just one document - sources = [sources] - if language and False: - # Bug in langchain, keep separator=True not working - # https://github.com/hwchase17/langchain/issues/2836 - # so avoid this for now - keep_separator = True - separators = ( - RecursiveCharacterTextSplitter.get_separators_for_language( - language - ) - ) - else: - separators = ["\n\n", "\n", " ", ""] - keep_separator = False - splitter = RecursiveCharacterTextSplitter( - chunk_size=chunk_size, - chunk_overlap=0, - keep_separator=keep_separator, - separators=separators, - ) - source_chunks = splitter.split_documents(sources) - - # currently in order, but when pull from db won't be, so mark order and document by hash - doc_hash = str(uuid.uuid4())[:10] - [ - x.metadata.update(dict(doc_hash=doc_hash, chunk_id=chunk_id)) - for chunk_id, x in enumerate(source_chunks) - ] - - return source_chunks - - -def get_db_from_hf(dest=".", db_dir="db_dir_DriverlessAI_docs.zip"): - from huggingface_hub import hf_hub_download - - # True for case when locally already logged in with correct token, so don't have to set key - token = os.getenv("HUGGINGFACE_API_TOKEN", True) - path_to_zip_file = hf_hub_download( - "h2oai/db_dirs", db_dir, token=token, repo_type="dataset" - ) - import zipfile - - with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref: - persist_directory = os.path.dirname(zip_ref.namelist()[0]) - remove(persist_directory) - zip_ref.extractall(dest) - return path_to_zip_file - - -# Note dir has space in some cases, while zip does not -some_db_zips = [ - [ - "db_dir_DriverlessAI_docs.zip", - "db_dir_DriverlessAI docs", - "CC-BY-NC license", - ], - ["db_dir_UserData.zip", "db_dir_UserData", "CC-BY license for ArXiv"], - ["db_dir_github_h2oGPT.zip", "db_dir_github h2oGPT", "ApacheV2 license"], - ["db_dir_wiki.zip", "db_dir_wiki", "CC-BY-SA Wikipedia license"], - # ['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'], -] - -all_db_zips = some_db_zips + [ - [ - "db_dir_wiki_full.zip", - "db_dir_wiki_full.zip", - "23GB, 05/04/2023 CC-BY-SA Wiki license", - ], -] - - -def get_some_dbs_from_hf(dest=".", db_zips=None): - if db_zips is None: - db_zips = some_db_zips - for db_dir, dir_expected, license1 in db_zips: - path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir) - assert os.path.isfile(path_to_zip_file), ( - "Missing zip in %s" % path_to_zip_file - ) - if dir_expected: - assert os.path.isdir(os.path.join(dest, dir_expected)), ( - "Missing path for %s" % dir_expected - ) - assert os.path.isdir(os.path.join(dest, dir_expected, "index")), ( - "Missing index in %s" % dir_expected - ) - - -def _create_local_weaviate_client(): - WEAVIATE_URL = os.getenv("WEAVIATE_URL", "http://localhost:8080") - WEAVIATE_USERNAME = os.getenv("WEAVIATE_USERNAME") - WEAVIATE_PASSWORD = os.getenv("WEAVIATE_PASSWORD") - WEAVIATE_SCOPE = os.getenv("WEAVIATE_SCOPE", "offline_access") - - resource_owner_config = None - try: - import weaviate - - if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None: - resource_owner_config = weaviate.AuthClientPassword( - username=WEAVIATE_USERNAME, - password=WEAVIATE_PASSWORD, - scope=WEAVIATE_SCOPE, - ) - - client = weaviate.Client( - WEAVIATE_URL, auth_client_secret=resource_owner_config - ) - return client - except Exception as e: - print(f"Failed to create Weaviate client: {e}") - return None - - -if __name__ == "__main__": - pass diff --git a/apps/language_models/langchain/gradio_utils/grclient.py b/apps/language_models/langchain/gradio_utils/grclient.py deleted file mode 100644 index 68ba89ea..00000000 --- a/apps/language_models/langchain/gradio_utils/grclient.py +++ /dev/null @@ -1,93 +0,0 @@ -import traceback -from typing import Callable -import os - -from gradio_client.client import Job - -os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1" - -from gradio_client import Client - - -class GradioClient(Client): - """ - Parent class of gradio client - To handle automatically refreshing client if detect gradio server changed - """ - - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - super().__init__(*args, **kwargs) - self.server_hash = self.get_server_hash() - - def get_server_hash(self): - """ - Get server hash using super without any refresh action triggered - Returns: git hash of gradio server - """ - return super().submit(api_name="/system_hash").result() - - def refresh_client_if_should(self): - # get current hash in order to update api_name -> fn_index map in case gradio server changed - # FIXME: Could add cli api as hash - server_hash = self.get_server_hash() - if self.server_hash != server_hash: - self.refresh_client() - self.server_hash = server_hash - else: - self.reset_session() - - def refresh_client(self): - """ - Ensure every client call is independent - Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code) - Returns: - """ - # need session hash to be new every time, to avoid "generator already executing" - self.reset_session() - - client = Client(*self.args, **self.kwargs) - for k, v in client.__dict__.items(): - setattr(self, k, v) - - def submit( - self, - *args, - api_name: str | None = None, - fn_index: int | None = None, - result_callbacks: Callable | list[Callable] | None = None, - ) -> Job: - # Note predict calls submit - try: - self.refresh_client_if_should() - job = super().submit(*args, api_name=api_name, fn_index=fn_index) - except Exception as e: - print("Hit e=%s" % str(e), flush=True) - # force reconfig in case only that - self.refresh_client() - job = super().submit(*args, api_name=api_name, fn_index=fn_index) - - # see if immediately failed - e = job.future._exception - if e is not None: - print( - "GR job failed: %s %s" - % (str(e), "".join(traceback.format_tb(e.__traceback__))), - flush=True, - ) - # force reconfig in case only that - self.refresh_client() - job = super().submit(*args, api_name=api_name, fn_index=fn_index) - e2 = job.future._exception - if e2 is not None: - print( - "GR job failed again: %s\n%s" - % ( - str(e2), - "".join(traceback.format_tb(e2.__traceback__)), - ), - flush=True, - ) - - return job diff --git a/apps/language_models/langchain/h2oai_pipeline.py b/apps/language_models/langchain/h2oai_pipeline.py deleted file mode 100644 index d0c4c015..00000000 --- a/apps/language_models/langchain/h2oai_pipeline.py +++ /dev/null @@ -1,765 +0,0 @@ -import os -from apps.stable_diffusion.src.utils.utils import _compile_module -from io import BytesIO -import torch_mlir - -from stopping import get_stopping -from prompter import Prompter, PromptType - -from transformers import TextGenerationPipeline -from transformers.pipelines.text_generation import ReturnType -from transformers.generation import ( - GenerationConfig, - LogitsProcessorList, - StoppingCriteriaList, -) -import copy -import torch -from transformers import AutoConfig, AutoModelForCausalLM -import gc -from pathlib import Path -from shark.shark_inference import SharkInference -from shark.shark_downloader import download_public_file -from shark.shark_importer import import_with_fx, save_mlir -from apps.stable_diffusion.src import args - -# Brevitas -from typing import List, Tuple -from brevitas_examples.common.generative.quantize import quantize_model -from brevitas_examples.llm.llm_quant.run_utils import get_model_impl - - -# fmt: off -def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: - if len(lhs) == 3 and len(rhs) == 2: - return [lhs[0], lhs[1], rhs[0]] - elif len(lhs) == 2 and len(rhs) == 2: - return [lhs[0], rhs[0]] - else: - raise ValueError("Input shapes not supported.") - - -def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: - # output dtype is the dtype of the lhs float input - lhs_rank, lhs_dtype = lhs_rank_dtype - return lhs_dtype - - -def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: - return - - -brevitas_matmul_rhs_group_quant_library = [ - quant〇matmul_rhs_group_quant〡shape, - quant〇matmul_rhs_group_quant〡dtype, - quant〇matmul_rhs_group_quant〡has_value_semantics] -# fmt: on - -global_device = "cuda" -global_precision = "fp16" - -if not args.run_docuchat_web: - args.device = global_device - args.precision = global_precision -tensor_device = "cpu" if args.device == "cpu" else "cuda" - - -class H2OGPTModel(torch.nn.Module): - def __init__(self, device, precision): - super().__init__() - torch_dtype = ( - torch.float32 - if precision == "fp32" or device == "cpu" - else torch.float16 - ) - device_map = {"": "cpu"} if device == "cpu" else {"": 0} - model_kwargs = { - "local_files_only": False, - "torch_dtype": torch_dtype, - "resume_download": True, - "use_auth_token": False, - "trust_remote_code": True, - "offload_folder": "offline_folder", - "device_map": device_map, - } - config = AutoConfig.from_pretrained( - "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", - use_auth_token=False, - trust_remote_code=True, - offload_folder="offline_folder", - ) - self.model = AutoModelForCausalLM.from_pretrained( - "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", - config=config, - **model_kwargs, - ) - if precision in ["int4", "int8"]: - print("Applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - self.model.transformer.h, - dtype=torch.float32, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=128, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward(self, input_ids, attention_mask): - input_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": None, - "use_cache": True, - } - output = self.model( - **input_dict, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - return output.logits[:, -1, :] - - -class H2OGPTSHARKModel(torch.nn.Module): - def __init__(self): - super().__init__() - model_name = "h2ogpt_falcon_7b" - extended_model_name = ( - model_name + "_" + args.precision + "_" + args.device - ) - vmfb_path = Path(extended_model_name + ".vmfb") - mlir_path = Path(model_name + "_" + args.precision + ".mlir") - shark_module = None - - need_to_compile = False - if not vmfb_path.exists(): - need_to_compile = True - # Downloading VMFB from shark_tank - print("Trying to download pre-compiled vmfb from shark tank.") - download_public_file( - "gs://shark_tank/langchain/" + str(vmfb_path), - vmfb_path.absolute(), - single_file=True, - ) - if vmfb_path.exists(): - print( - "Pre-compiled vmfb downloaded from shark tank successfully." - ) - need_to_compile = False - - if need_to_compile: - if not mlir_path.exists(): - print("Trying to download pre-generated mlir from shark tank.") - # Downloading MLIR from shark_tank - download_public_file( - "gs://shark_tank/langchain/" + str(mlir_path), - mlir_path.absolute(), - single_file=True, - ) - if mlir_path.exists(): - with open(mlir_path, "rb") as f: - bytecode = f.read() - else: - # Generating the mlir - bytecode = self.get_bytecode(tensor_device, args.precision) - - shark_module = SharkInference( - mlir_module=bytecode, - device=args.device, - mlir_dialect="linalg", - ) - print(f"[DEBUG] generating vmfb.") - shark_module = _compile_module( - shark_module, extended_model_name, [] - ) - print("Saved newly generated vmfb.") - - if shark_module is None: - if vmfb_path.exists(): - print("Compiled vmfb found. Loading it from: ", vmfb_path) - shark_module = SharkInference( - None, device=args.device, mlir_dialect="linalg" - ) - shark_module.load_module(str(vmfb_path)) - print("Compiled vmfb loaded successfully.") - else: - raise ValueError("Unable to download/generate a vmfb.") - - self.model = shark_module - - def get_bytecode(self, device, precision): - h2ogpt_model = H2OGPTModel(device, precision) - - compilation_input_ids = torch.randint( - low=1, high=10000, size=(1, 400) - ).to(device=device) - compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to( - device=device - ) - - h2ogptCompileInput = ( - compilation_input_ids, - compilation_attention_mask, - ) - - print(f"[DEBUG] generating torchscript graph") - ts_graph = import_with_fx( - h2ogpt_model, - h2ogptCompileInput, - is_f16=False, - precision=precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - del h2ogpt_model - del self.src_model - - print(f"[DEBUG] generating torch mlir") - if precision in ["int4", "int8"]: - from torch_mlir.compiler_utils import ( - run_pipeline_with_repro_report, - ) - - module = torch_mlir.compile( - ts_graph, - [*h2ogptCompileInput], - output_type=torch_mlir.OutputType.TORCH, - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - else: - module = torch_mlir.compile( - ts_graph, - [*h2ogptCompileInput], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - del ts_graph - - print(f"[DEBUG] converting to bytecode") - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - del module - - bytecode = save_mlir( - bytecode, - model_name=f"h2ogpt_{precision}", - frontend="torch", - ) - return bytecode - - def forward(self, input_ids, attention_mask): - result = torch.from_numpy( - self.model( - "forward", - (input_ids.to(device="cpu"), attention_mask.to(device="cpu")), - ) - ).to(device=tensor_device) - return result - - -def decode_tokens(tokenizer, res_tokens): - for i in range(len(res_tokens)): - if type(res_tokens[i]) != int: - res_tokens[i] = int(res_tokens[i][0]) - - res_str = tokenizer.decode(res_tokens, skip_special_tokens=True) - return res_str - - -def generate_token(h2ogpt_shark_model, model, tokenizer, **generate_kwargs): - del generate_kwargs["max_time"] - generate_kwargs["input_ids"] = generate_kwargs["input_ids"].to( - device=tensor_device - ) - generate_kwargs["attention_mask"] = generate_kwargs["attention_mask"].to( - device=tensor_device - ) - truncated_input_ids = [] - stopping_criteria = generate_kwargs["stopping_criteria"] - - generation_config_ = GenerationConfig.from_model_config(model.config) - generation_config = copy.deepcopy(generation_config_) - model_kwargs = generation_config.update(**generate_kwargs) - - logits_processor = LogitsProcessorList() - stopping_criteria = ( - stopping_criteria - if stopping_criteria is not None - else StoppingCriteriaList() - ) - - eos_token_id = generation_config.eos_token_id - generation_config.pad_token_id = eos_token_id - - ( - inputs_tensor, - model_input_name, - model_kwargs, - ) = model._prepare_model_inputs( - None, generation_config.bos_token_id, model_kwargs - ) - - model_kwargs["output_attentions"] = generation_config.output_attentions - model_kwargs[ - "output_hidden_states" - ] = generation_config.output_hidden_states - model_kwargs["use_cache"] = generation_config.use_cache - - input_ids = ( - inputs_tensor - if model_input_name == "input_ids" - else model_kwargs.pop("input_ids") - ) - - input_ids_seq_length = input_ids.shape[-1] - - generation_config.max_length = ( - generation_config.max_new_tokens + input_ids_seq_length - ) - - logits_processor = model._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids_seq_length, - encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=None, - logits_processor=logits_processor, - ) - - stopping_criteria = model._get_stopping_criteria( - generation_config=generation_config, - stopping_criteria=stopping_criteria, - ) - - logits_warper = model._get_logits_warper(generation_config) - - ( - input_ids, - model_kwargs, - ) = model._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, # 1 - is_encoder_decoder=model.config.is_encoder_decoder, # False - **model_kwargs, - ) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = ( - torch.tensor(eos_token_id).to(device=tensor_device) - if eos_token_id is not None - else None - ) - - pad_token_id = generation_config.pad_token_id - eos_token_id = eos_token_id - - output_scores = generation_config.output_scores # False - return_dict_in_generate = ( - generation_config.return_dict_in_generate # False - ) - - # init attention / hidden states / scores tuples - scores = () if (return_dict_in_generate and output_scores) else None - - # keep track of which sequences are already finished - unfinished_sequences = torch.ones( - input_ids.shape[0], - dtype=torch.long, - device=input_ids.device, - ) - - timesRan = 0 - import time - - start = time.time() - print("\n") - - res_tokens = [] - while True: - model_inputs = model.prepare_inputs_for_generation( - input_ids, **model_kwargs - ) - - outputs = h2ogpt_shark_model.forward( - model_inputs["input_ids"], model_inputs["attention_mask"] - ) - - if args.precision == "fp16": - outputs = outputs.to(dtype=torch.float32) - next_token_logits = outputs - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - - # sample - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - - next_token = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_token = next_token * unfinished_sequences + pad_token_id * ( - 1 - unfinished_sequences - ) - - input_ids = torch.cat([input_ids, next_token[:, None]], dim=-1) - - model_kwargs["past_key_values"] = None - if "attention_mask" in model_kwargs: - attention_mask = model_kwargs["attention_mask"] - model_kwargs["attention_mask"] = torch.cat( - [ - attention_mask, - attention_mask.new_ones((attention_mask.shape[0], 1)), - ], - dim=-1, - ) - - truncated_input_ids.append(input_ids[:, 0]) - input_ids = input_ids[:, 1:] - model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, 1:] - - new_word = tokenizer.decode( - next_token.cpu().numpy(), - add_special_tokens=False, - skip_special_tokens=True, - clean_up_tokenization_spaces=True, - ) - - res_tokens.append(next_token) - if new_word == "<0x0A>": - print("\n", end="", flush=True) - else: - print(f"{new_word}", end=" ", flush=True) - - part_str = decode_tokens(tokenizer, res_tokens) - yield part_str - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_token.tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - ) - # stop when each sentence is finished - if unfinished_sequences.max() == 0 or stopping_criteria( - input_ids, scores - ): - break - timesRan = timesRan + 1 - - end = time.time() - print( - "\n\nTime taken is {:.2f} seconds/token\n".format( - (end - start) / timesRan - ) - ) - - torch.cuda.empty_cache() - gc.collect() - - res_str = decode_tokens(tokenizer, res_tokens) - yield res_str - - -def pad_or_truncate_inputs( - input_ids, attention_mask, max_padding_length=400, do_truncation=False -): - inp_shape = input_ids.shape - if inp_shape[1] < max_padding_length: - # do padding - num_add_token = max_padding_length - inp_shape[1] - padded_input_ids = torch.cat( - [ - torch.tensor([[11] * num_add_token]).to(device=tensor_device), - input_ids, - ], - dim=1, - ) - padded_attention_mask = torch.cat( - [ - torch.tensor([[0] * num_add_token]).to(device=tensor_device), - attention_mask, - ], - dim=1, - ) - return padded_input_ids, padded_attention_mask - elif inp_shape[1] > max_padding_length or do_truncation: - # do truncation - num_remove_token = inp_shape[1] - max_padding_length - truncated_input_ids = input_ids[:, num_remove_token:] - truncated_attention_mask = attention_mask[:, num_remove_token:] - return truncated_input_ids, truncated_attention_mask - else: - return input_ids, attention_mask - - -class H2OTextGenerationPipeline(TextGenerationPipeline): - def __init__( - self, - *args, - debug=False, - chat=False, - stream_output=False, - sanitize_bot_response=False, - use_prompter=True, - prompter=None, - prompt_type=None, - prompt_dict=None, - max_input_tokens=2048 - 256, - **kwargs, - ): - """ - HF-like pipeline, but handle instruction prompting and stopping (for some models) - :param args: - :param debug: - :param chat: - :param stream_output: - :param sanitize_bot_response: - :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter - :param prompter: prompter, can pass if have already - :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py. - If use_prompter, then will make prompter and use it. - :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom - :param max_input_tokens: - :param kwargs: - """ - super().__init__(*args, **kwargs) - self.prompt_text = None - self.use_prompter = use_prompter - self.prompt_type = prompt_type - self.prompt_dict = prompt_dict - self.prompter = prompter - if self.use_prompter: - if self.prompter is not None: - assert self.prompter.prompt_type is not None - else: - self.prompter = Prompter( - self.prompt_type, - self.prompt_dict, - debug=debug, - chat=chat, - stream_output=stream_output, - ) - self.human = self.prompter.humanstr - self.bot = self.prompter.botstr - self.can_stop = True - else: - self.prompter = None - self.human = None - self.bot = None - self.can_stop = False - self.sanitize_bot_response = sanitize_bot_response - self.max_input_tokens = ( - max_input_tokens # not for generate, so ok that not kwargs - ) - - @staticmethod - def limit_prompt(prompt_text, tokenizer, max_prompt_length=None): - verbose = bool(int(os.getenv("VERBOSE_PIPELINE", "0"))) - - if hasattr(tokenizer, "model_max_length"): - # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py - model_max_length = tokenizer.model_max_length - if max_prompt_length is not None: - model_max_length = min(model_max_length, max_prompt_length) - # cut at some upper likely limit to avoid excessive tokenization etc - # upper bound of 10 chars/token, e.g. special chars sometimes are long - if len(prompt_text) > model_max_length * 10: - len0 = len(prompt_text) - prompt_text = prompt_text[-model_max_length * 10 :] - if verbose: - print( - "Cut of input: %s -> %s" % (len0, len(prompt_text)), - flush=True, - ) - else: - # unknown - model_max_length = None - - num_prompt_tokens = None - if model_max_length is not None: - # can't wait for "hole" if not plain prompt_type, since would lose prefix like : - # For https://github.com/h2oai/h2ogpt/issues/192 - for trial in range(0, 3): - prompt_tokens = tokenizer(prompt_text)["input_ids"] - num_prompt_tokens = len(prompt_tokens) - if num_prompt_tokens > model_max_length: - # conservative by using int() - chars_per_token = int(len(prompt_text) / num_prompt_tokens) - # keep tail, where question is if using langchain - prompt_text = prompt_text[ - -model_max_length * chars_per_token : - ] - if verbose: - print( - "reducing %s tokens, assuming average of %s chars/token for %s characters" - % ( - num_prompt_tokens, - chars_per_token, - len(prompt_text), - ), - flush=True, - ) - else: - if verbose: - print( - "using %s tokens with %s chars" - % (num_prompt_tokens, len(prompt_text)), - flush=True, - ) - break - - return prompt_text, num_prompt_tokens - - def preprocess( - self, - prompt_text, - prefix="", - handle_long_generation=None, - **generate_kwargs, - ): - ( - prompt_text, - num_prompt_tokens, - ) = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer) - - data_point = dict(context="", instruction=prompt_text, input="") - if self.prompter is not None: - prompt_text = self.prompter.generate_prompt(data_point) - self.prompt_text = prompt_text - if handle_long_generation is None: - # forces truncation of inputs to avoid critical failure - handle_long_generation = None # disable with new approaches - return super().preprocess( - prompt_text, - prefix=prefix, - handle_long_generation=handle_long_generation, - **generate_kwargs, - ) - - def postprocess( - self, - model_outputs, - return_type=ReturnType.FULL_TEXT, - clean_up_tokenization_spaces=True, - ): - records = super().postprocess( - model_outputs, - return_type=return_type, - clean_up_tokenization_spaces=clean_up_tokenization_spaces, - ) - for rec in records: - if self.use_prompter: - outputs = rec["generated_text"] - outputs = self.prompter.get_response( - outputs, - prompt=self.prompt_text, - sanitize_bot_response=self.sanitize_bot_response, - ) - elif self.bot and self.human: - outputs = ( - rec["generated_text"] - .split(self.bot)[1] - .split(self.human)[0] - ) - else: - outputs = rec["generated_text"] - rec["generated_text"] = outputs - print( - "prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), - flush=True, - ) - return records - - def _forward(self, model_inputs, **generate_kwargs): - if self.can_stop: - stopping_criteria = get_stopping( - self.prompt_type, - self.prompt_dict, - self.tokenizer, - self.device, - human=self.human, - bot=self.bot, - model_max_length=self.tokenizer.model_max_length, - ) - generate_kwargs["stopping_criteria"] = stopping_criteria - # return super()._forward(model_inputs, **generate_kwargs) - return self.__forward(model_inputs, **generate_kwargs) - - # FIXME: Copy-paste of original _forward, but removed copy.deepcopy() - # FIXME: https://github.com/h2oai/h2ogpt/issues/172 - def __forward(self, model_inputs, **generate_kwargs): - input_ids = model_inputs["input_ids"] - attention_mask = model_inputs.get("attention_mask", None) - # Allow empty prompts - if input_ids.shape[1] == 0: - input_ids = None - attention_mask = None - in_b = 1 - else: - in_b = input_ids.shape[0] - prompt_text = model_inputs.pop("prompt_text") - - ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying - ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline. - # generate_kwargs = copy.deepcopy(generate_kwargs) - prefix_length = generate_kwargs.pop("prefix_length", 0) - if prefix_length > 0: - has_max_new_tokens = "max_new_tokens" in generate_kwargs or ( - "generation_config" in generate_kwargs - and generate_kwargs["generation_config"].max_new_tokens - is not None - ) - if not has_max_new_tokens: - generate_kwargs["max_length"] = ( - generate_kwargs.get("max_length") - or self.model.config.max_length - ) - generate_kwargs["max_length"] += prefix_length - has_min_new_tokens = "min_new_tokens" in generate_kwargs or ( - "generation_config" in generate_kwargs - and generate_kwargs["generation_config"].min_new_tokens - is not None - ) - if not has_min_new_tokens and "min_length" in generate_kwargs: - generate_kwargs["min_length"] += prefix_length - - # BS x SL - # pad or truncate the input_ids and attention_mask - max_padding_length = 400 - input_ids, attention_mask = pad_or_truncate_inputs( - input_ids, attention_mask, max_padding_length=max_padding_length - ) - - return_dict = { - "model": self.model, - "tokenizer": self.tokenizer, - "input_ids": input_ids, - "attention_mask": attention_mask, - "attention_mask": attention_mask, - } - return_dict = {**return_dict, **generate_kwargs} - return return_dict diff --git a/apps/language_models/langchain/image_captions.py b/apps/language_models/langchain/image_captions.py deleted file mode 100644 index e61a48d0..00000000 --- a/apps/language_models/langchain/image_captions.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Based upon ImageCaptionLoader in LangChain version: langchain/document_loaders/image_captions.py -But accepts preloaded model to avoid slowness in use and CUDA forking issues - -Loader that loads image captions -By default, the loader utilizes the pre-trained BLIP image captioning model. -https://huggingface.co/Salesforce/blip-image-captioning-base - -""" -from typing import List, Union, Any, Tuple - -import requests -from langchain.docstore.document import Document -from langchain.document_loaders import ImageCaptionLoader - -from utils import get_device, NullContext - -import pkg_resources - -try: - assert pkg_resources.get_distribution("bitsandbytes") is not None - have_bitsandbytes = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_bitsandbytes = False - - -class H2OImageCaptionLoader(ImageCaptionLoader): - """Loader that loads the captions of an image""" - - def __init__( - self, - path_images: Union[str, List[str]] = None, - blip_processor: str = None, - blip_model: str = None, - caption_gpu=True, - load_in_8bit=True, - # True doesn't seem to work, even though https://huggingface.co/Salesforce/blip2-flan-t5-xxl#in-8-bit-precision-int8 - load_half=False, - load_gptq="", - use_safetensors=False, - min_new_tokens=20, - max_tokens=50, - ): - if blip_model is None or blip_model is None: - blip_processor = "Salesforce/blip-image-captioning-base" - blip_model = "Salesforce/blip-image-captioning-base" - - super().__init__(path_images, blip_processor, blip_model) - self.blip_processor = blip_processor - self.blip_model = blip_model - self.processor = None - self.model = None - self.caption_gpu = caption_gpu - self.context_class = NullContext - self.device = "cpu" - self.load_in_8bit = ( - load_in_8bit and have_bitsandbytes - ) # only for blip2 - self.load_half = load_half - self.load_gptq = load_gptq - self.use_safetensors = use_safetensors - self.gpu_id = "auto" - # default prompt - self.prompt = "image of" - self.min_new_tokens = min_new_tokens - self.max_tokens = max_tokens - - def set_context(self): - if get_device() == "cuda" and self.caption_gpu: - import torch - - n_gpus = ( - torch.cuda.device_count() if torch.cuda.is_available else 0 - ) - if n_gpus > 0: - self.context_class = torch.device - self.device = "cuda" - - def load_model(self): - try: - import transformers - except ImportError: - raise ValueError( - "`transformers` package not found, please install with " - "`pip install transformers`." - ) - self.set_context() - if self.caption_gpu: - if self.gpu_id == "auto": - # blip2 has issues with multi-GPU. Error says need to somehow set language model in device map - # device_map = 'auto' - device_map = {"": 0} - else: - if self.device == "cuda": - device_map = {"": self.gpu_id} - else: - device_map = {"": "cpu"} - else: - device_map = {"": "cpu"} - import torch - - with torch.no_grad(): - with self.context_class(self.device): - context_class_cast = ( - NullContext if self.device == "cpu" else torch.autocast - ) - with context_class_cast(self.device): - if "blip2" in self.blip_processor.lower(): - from transformers import ( - Blip2Processor, - Blip2ForConditionalGeneration, - ) - - if self.load_half and not self.load_in_8bit: - self.processor = Blip2Processor.from_pretrained( - self.blip_processor, device_map=device_map - ).half() - self.model = ( - Blip2ForConditionalGeneration.from_pretrained( - self.blip_model, device_map=device_map - ).half() - ) - else: - self.processor = Blip2Processor.from_pretrained( - self.blip_processor, - load_in_8bit=self.load_in_8bit, - device_map=device_map, - ) - self.model = ( - Blip2ForConditionalGeneration.from_pretrained( - self.blip_model, - load_in_8bit=self.load_in_8bit, - device_map=device_map, - ) - ) - else: - from transformers import ( - BlipForConditionalGeneration, - BlipProcessor, - ) - - self.load_half = False # not supported - if self.caption_gpu: - if device_map == "auto": - # Blip doesn't support device_map='auto' - if self.device == "cuda": - if self.gpu_id == "auto": - device_map = {"": 0} - else: - device_map = {"": self.gpu_id} - else: - device_map = {"": "cpu"} - else: - device_map = {"": "cpu"} - self.processor = BlipProcessor.from_pretrained( - self.blip_processor, device_map=device_map - ) - self.model = ( - BlipForConditionalGeneration.from_pretrained( - self.blip_model, device_map=device_map - ) - ) - return self - - def set_image_paths(self, path_images: Union[str, List[str]]): - """ - Load from a list of image files - """ - if isinstance(path_images, str): - self.image_paths = [path_images] - else: - self.image_paths = path_images - - def load(self, prompt=None) -> List[Document]: - if self.processor is None or self.model is None: - self.load_model() - results = [] - for path_image in self.image_paths: - caption, metadata = self._get_captions_and_metadata( - model=self.model, - processor=self.processor, - path_image=path_image, - prompt=prompt, - ) - doc = Document(page_content=caption, metadata=metadata) - results.append(doc) - - return results - - def _get_captions_and_metadata( - self, model: Any, processor: Any, path_image: str, prompt=None - ) -> Tuple[str, dict]: - """ - Helper function for getting the captions and metadata of an image - """ - if prompt is None: - prompt = self.prompt - try: - from PIL import Image - except ImportError: - raise ValueError( - "`PIL` package not found, please install with `pip install pillow`" - ) - - try: - if path_image.startswith("http://") or path_image.startswith( - "https://" - ): - image = Image.open( - requests.get(path_image, stream=True).raw - ).convert("RGB") - else: - image = Image.open(path_image).convert("RGB") - except Exception: - raise ValueError(f"Could not get image data for {path_image}") - - import torch - - with torch.no_grad(): - with self.context_class(self.device): - context_class_cast = ( - NullContext if self.device == "cpu" else torch.autocast - ) - with context_class_cast(self.device): - if self.load_half: - inputs = processor( - image, prompt, return_tensors="pt" - ).half() - else: - inputs = processor(image, prompt, return_tensors="pt") - min_length = len(prompt) // 4 + self.min_new_tokens - self.max_tokens = max(self.max_tokens, min_length) - output = model.generate( - **inputs, - min_length=min_length, - max_length=self.max_tokens, - ) - - caption: str = processor.decode( - output[0], skip_special_tokens=True - ) - prompti = caption.find(prompt) - if prompti >= 0: - caption = caption[prompti + len(prompt) :] - metadata: dict = {"image_path": path_image} - - return caption, metadata diff --git a/apps/language_models/langchain/langchain_requirements.txt b/apps/language_models/langchain/langchain_requirements.txt deleted file mode 100644 index 7db5e53f..00000000 --- a/apps/language_models/langchain/langchain_requirements.txt +++ /dev/null @@ -1,120 +0,0 @@ -# for generate (gradio server) and finetune -datasets==2.13.0 -sentencepiece==0.1.99 -huggingface_hub==0.16.4 -appdirs==1.4.4 -fire==0.5.0 -docutils==0.20.1 -evaluate==0.4.0 -rouge_score==0.1.2 -sacrebleu==2.3.1 -scikit-learn==1.2.2 -alt-profanity-check==1.2.2 -better-profanity==0.7.0 -numpy==1.24.3 -pandas==2.0.2 -matplotlib==3.7.1 -loralib==0.1.1 -bitsandbytes==0.39.0 -accelerate==0.20.3 -peft==0.4.0 -# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026) -transformers==4.30.2 -tokenizers==0.13.3 -APScheduler==3.10.1 - -# optional for generate -pynvml==11.5.0 -psutil==5.9.5 -boto3==1.26.101 -botocore==1.29.101 - -# optional for finetune -tensorboard==2.13.0 -neptune==1.2.0 - -# for gradio client -gradio_client==0.2.10 -beautifulsoup4==4.12.2 -markdown==3.4.3 - -# data and testing -pytest==7.2.2 -pytest-xdist==3.2.1 -nltk==3.8.1 -textstat==0.7.3 -# pandoc==2.3 -pypandoc==1.11; sys_platform == "darwin" and platform_machine == "arm64" -pypandoc_binary==1.11; platform_machine == "x86_64" -pypandoc_binary==1.11; sys_platform == "win32" -openpyxl==3.1.2 -lm_dataformat==0.0.20 -bioc==2.0 - -# falcon -einops==0.6.1 -instructorembedding==1.0.1 - -# for gpt4all .env file, but avoid worrying about imports -python-dotenv==1.0.0 - -text-generation==0.6.0 -# for tokenization when don't have HF tokenizer -tiktoken==0.4.0 -# optional: for OpenAI endpoint or embeddings (requires key) -openai==0.27.8 - -# optional for chat with PDF -langchain==0.0.329 -pypdf==3.17.0 -# avoid textract, requires old six -#textract==1.6.5 - -# for HF embeddings -sentence_transformers==2.2.2 - -# local vector db -chromadb==0.3.25 -# server vector db -#pymilvus==2.2.8 - -# weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6 -# unstructured==0.8.1 - -# strong support for images -# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice -unstructured[local-inference]==0.7.4 -#pdf2image==1.16.3 -#pytesseract==0.3.10 -pillow - -pdfminer.six==20221105 -urllib3 -requests_file - -#pdf2image==1.16.3 -#pytesseract==0.3.10 -tabulate==0.9.0 -# FYI pandoc already part of requirements.txt - -# JSONLoader, but makes some trouble for some users -# jq==1.4.1 - -# to check licenses -# Run: pip-licenses|grep -v 'BSD\|Apache\|MIT' -pip-licenses==4.3.0 - -# weaviate vector db -weaviate-client==3.22.1 - -gpt4all==1.0.5 -llama-cpp-python==0.1.73 - -arxiv==1.4.8 -pymupdf==1.22.5 # AGPL license -# extract-msg==0.41.1 # GPL3 - -# sometimes unstructured fails, these work in those cases. See https://github.com/h2oai/h2ogpt/issues/320 -playwright==1.36.0 -# requires Chrome binary to be in path -selenium==4.10.0 diff --git a/apps/language_models/langchain/llama_flash_attn_monkey_patch.py b/apps/language_models/langchain/llama_flash_attn_monkey_patch.py deleted file mode 100644 index 11f7a39a..00000000 --- a/apps/language_models/langchain/llama_flash_attn_monkey_patch.py +++ /dev/null @@ -1,124 +0,0 @@ -from typing import List, Optional, Tuple - -import torch - -import transformers -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb - -from einops import rearrange - -from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func -from flash_attn.bert_padding import unpad_input, pad_input - - -def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: bool = False, - use_cache: bool = False, -) -> Tuple[ - torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]] -]: - """Input shape: Batch x Time x Channel - attention_mask: [bsz, q_len] - """ - bsz, q_len, _ = hidden_states.size() - - query_states = ( - self.q_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - key_states = ( - self.k_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - value_states = ( - self.v_proj(hidden_states) - .view(bsz, q_len, self.num_heads, self.head_dim) - .transpose(1, 2) - ) - # [bsz, q_len, nh, hd] - # [bsz, nh, q_len, hd] - - kv_seq_len = key_states.shape[-2] - assert past_key_value is None, "past_key_value is not supported" - - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, position_ids - ) - # [bsz, nh, t, hd] - assert not output_attentions, "output_attentions is not supported" - assert not use_cache, "use_cache is not supported" - - # Flash attention codes from - # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py - - # transform the data into the format required by flash attention - qkv = torch.stack( - [query_states, key_states, value_states], dim=2 - ) # [bsz, nh, 3, q_len, hd] - qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] - # We have disabled _prepare_decoder_attention_mask in LlamaModel - # the attention_mask should be the same as the key_padding_mask - key_padding_mask = attention_mask - - if key_padding_mask is None: - qkv = rearrange(qkv, "b s ... -> (b s) ...") - max_s = q_len - cu_q_lens = torch.arange( - 0, - (bsz + 1) * q_len, - step=q_len, - dtype=torch.int32, - device=qkv.device, - ) - output = flash_attn_unpadded_qkvpacked_func( - qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True - ) - output = rearrange(output, "(b s) ... -> b s ...", b=bsz) - else: - nheads = qkv.shape[-2] - x = rearrange(qkv, "b s three h d -> b s (three h d)") - x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) - x_unpad = rearrange( - x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads - ) - output_unpad = flash_attn_unpadded_qkvpacked_func( - x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True - ) - output = rearrange( - pad_input( - rearrange(output_unpad, "nnz h d -> nnz (h d)"), - indices, - bsz, - q_len, - ), - "b s (h d) -> b s h d", - h=nheads, - ) - return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None - - -# Disable the transformation of the attention mask in LlamaModel as the flash attention -# requires the attention mask to be the same as the key_padding_mask -def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length -): - # [bsz, seq_len] - return attention_mask - - -def replace_llama_attn_with_flash_attn(): - print( - "Replacing original LLaMa attention with flash attention", flush=True - ) - transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( - _prepare_decoder_attention_mask - ) - transformers.models.llama.modeling_llama.LlamaAttention.forward = forward diff --git a/apps/language_models/langchain/loaders.py b/apps/language_models/langchain/loaders.py deleted file mode 100644 index 342b019b..00000000 --- a/apps/language_models/langchain/loaders.py +++ /dev/null @@ -1,109 +0,0 @@ -import functools - - -def get_loaders(model_name, reward_type, llama_type=None, load_gptq=""): - # NOTE: Some models need specific new prompt_type - # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".) - if load_gptq: - from transformers import AutoTokenizer - from auto_gptq import AutoGPTQForCausalLM - - use_triton = False - functools.partial( - AutoGPTQForCausalLM.from_quantized, - quantize_config=None, - use_triton=use_triton, - ) - return AutoGPTQForCausalLM.from_quantized, AutoTokenizer - if llama_type is None: - llama_type = "llama" in model_name.lower() - if llama_type: - from transformers import LlamaForCausalLM, LlamaTokenizer - - return LlamaForCausalLM.from_pretrained, LlamaTokenizer - elif "distilgpt2" in model_name.lower(): - from transformers import AutoModelForCausalLM, AutoTokenizer - - return AutoModelForCausalLM.from_pretrained, AutoTokenizer - elif "gpt2" in model_name.lower(): - from transformers import GPT2LMHeadModel, GPT2Tokenizer - - return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer - elif "mbart-" in model_name.lower(): - from transformers import ( - MBartForConditionalGeneration, - MBart50TokenizerFast, - ) - - return ( - MBartForConditionalGeneration.from_pretrained, - MBart50TokenizerFast, - ) - elif ( - "t5" == model_name.lower() - or "t5-" in model_name.lower() - or "flan-" in model_name.lower() - ): - from transformers import AutoTokenizer, T5ForConditionalGeneration - - return T5ForConditionalGeneration.from_pretrained, AutoTokenizer - elif "bigbird" in model_name: - from transformers import ( - BigBirdPegasusForConditionalGeneration, - AutoTokenizer, - ) - - return ( - BigBirdPegasusForConditionalGeneration.from_pretrained, - AutoTokenizer, - ) - elif ( - "bart-large-cnn-samsum" in model_name - or "flan-t5-base-samsum" in model_name - ): - from transformers import pipeline - - return pipeline, "summarization" - elif ( - reward_type - or "OpenAssistant/reward-model".lower() in model_name.lower() - ): - from transformers import ( - AutoModelForSequenceClassification, - AutoTokenizer, - ) - - return ( - AutoModelForSequenceClassification.from_pretrained, - AutoTokenizer, - ) - else: - from transformers import AutoTokenizer, AutoModelForCausalLM - - model_loader = AutoModelForCausalLM - tokenizer_loader = AutoTokenizer - return model_loader.from_pretrained, tokenizer_loader - - -def get_tokenizer( - tokenizer_loader, - tokenizer_base_model, - local_files_only, - resume_download, - use_auth_token, -): - tokenizer = tokenizer_loader.from_pretrained( - tokenizer_base_model, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - padding_side="left", - ) - - tokenizer.pad_token_id = 0 # different from the eos token - # when generating, we will use the logits of right-most token to predict the next token - # so the padding should be on the left, - # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference - tokenizer.padding_side = "left" # Allow batched inference - - return tokenizer diff --git a/apps/language_models/langchain/make_db.py b/apps/language_models/langchain/make_db.py deleted file mode 100644 index 5bb431c9..00000000 --- a/apps/language_models/langchain/make_db.py +++ /dev/null @@ -1,203 +0,0 @@ -import os - -from gpt_langchain import ( - path_to_docs, - get_some_dbs_from_hf, - all_db_zips, - some_db_zips, - create_or_update_db, -) -from utils import get_ngpus_vis - - -def glob_to_db( - user_path, - chunk=True, - chunk_size=512, - verbose=False, - fail_any_exception=False, - n_jobs=-1, - url=None, - enable_captions=True, - captions_model=None, - caption_loader=None, - enable_ocr=False, -): - sources1 = path_to_docs( - user_path, - verbose=verbose, - fail_any_exception=fail_any_exception, - n_jobs=n_jobs, - chunk=chunk, - chunk_size=chunk_size, - url=url, - enable_captions=enable_captions, - captions_model=captions_model, - caption_loader=caption_loader, - enable_ocr=enable_ocr, - ) - return sources1 - - -def make_db_main( - use_openai_embedding: bool = False, - hf_embedding_model: str = None, - persist_directory: str = "db_dir_UserData", - user_path: str = "user_path", - url: str = None, - add_if_exists: bool = True, - collection_name: str = "UserData", - verbose: bool = False, - chunk: bool = True, - chunk_size: int = 512, - fail_any_exception: bool = False, - download_all: bool = False, - download_some: bool = False, - download_one: str = None, - download_dest: str = "./", - n_jobs: int = -1, - enable_captions: bool = True, - captions_model: str = "Salesforce/blip-image-captioning-base", - pre_load_caption_model: bool = False, - caption_gpu: bool = True, - enable_ocr: bool = False, - db_type: str = "chroma", -): - """ - # To make UserData db for generate.py, put pdfs, etc. into path user_path and run: - python make_db.py - - # once db is made, can use in generate.py like: - - python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b --langchain_mode=UserData - - or zip-up the db_dir_UserData and share: - - zip -r db_dir_UserData.zip db_dir_UserData - - # To get all db files (except large wiki_full) do: - python make_db.py --download_some=True - - # To get a single db file from HF: - python make_db.py --download_one=db_dir_DriverlessAI_docs.zip - - :param use_openai_embedding: Whether to use OpenAI embedding - :param hf_embedding_model: HF embedding model to use. Like generate.py, uses 'hkunlp/instructor-large' if have GPUs, else "sentence-transformers/all-MiniLM-L6-v2" - :param persist_directory: where to persist db - :param user_path: where to pull documents from (None means url is not None. If url is not None, this is ignored.) - :param url: url to generate documents from (None means user_path is not None) - :param add_if_exists: Add to db if already exists, but will not add duplicate sources - :param collection_name: Collection name for new db if not adding - :param verbose: whether to show verbose messages - :param chunk: whether to chunk data - :param chunk_size: chunk size for chunking - :param fail_any_exception: whether to fail if any exception hit during ingestion of files - :param download_all: whether to download all (including 23GB Wikipedia) example databases from h2o.ai HF - :param download_some: whether to download some small example databases from h2o.ai HF - :param download_one: whether to download one chosen example databases from h2o.ai HF - :param download_dest: Destination for downloads - :param n_jobs: Number of cores to use for ingesting multiple files - :param enable_captions: Whether to enable captions on images - :param captions_model: See generate.py - :param pre_load_caption_model: See generate.py - :param caption_gpu: Caption images on GPU if present - :param enable_ocr: Whether to enable OCR on images - :param db_type: Type of db to create. Currently only 'chroma' and 'weaviate' is supported. - :return: None - """ - db = None - - # match behavior of main() in generate.py for non-HF case - n_gpus = get_ngpus_vis() - if n_gpus == 0: - if hf_embedding_model is None: - # if no GPUs, use simpler embedding model to avoid cost in time - hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" - else: - if hf_embedding_model is None: - # if still None, then set default - hf_embedding_model = "hkunlp/instructor-large" - - if download_all: - print("Downloading all (and unzipping): %s" % all_db_zips, flush=True) - get_some_dbs_from_hf(download_dest, db_zips=all_db_zips) - if verbose: - print("DONE", flush=True) - return db, collection_name - elif download_some: - print( - "Downloading some (and unzipping): %s" % some_db_zips, flush=True - ) - get_some_dbs_from_hf(download_dest, db_zips=some_db_zips) - if verbose: - print("DONE", flush=True) - return db, collection_name - elif download_one: - print("Downloading %s (and unzipping)" % download_one, flush=True) - get_some_dbs_from_hf( - download_dest, db_zips=[[download_one, "", "Unknown License"]] - ) - if verbose: - print("DONE", flush=True) - return db, collection_name - - if enable_captions and pre_load_caption_model: - # preload, else can be too slow or if on GPU have cuda context issues - # Inside ingestion, this will disable parallel loading of multiple other kinds of docs - # However, if have many images, all those images will be handled more quickly by preloaded model on GPU - from image_captions import H2OImageCaptionLoader - - caption_loader = H2OImageCaptionLoader( - None, - blip_model=captions_model, - blip_processor=captions_model, - caption_gpu=caption_gpu, - ).load_model() - else: - if enable_captions: - caption_loader = "gpu" if caption_gpu else "cpu" - else: - caption_loader = False - - if verbose: - print("Getting sources", flush=True) - assert ( - user_path is not None or url is not None - ), "Can't have both user_path and url as None" - if not url: - assert os.path.isdir(user_path), ( - "user_path=%s does not exist" % user_path - ) - sources = glob_to_db( - user_path, - chunk=chunk, - chunk_size=chunk_size, - verbose=verbose, - fail_any_exception=fail_any_exception, - n_jobs=n_jobs, - url=url, - enable_captions=enable_captions, - captions_model=captions_model, - caption_loader=caption_loader, - enable_ocr=enable_ocr, - ) - exceptions = [x for x in sources if x.metadata.get("exception")] - print("Exceptions: %s" % exceptions, flush=True) - sources = [x for x in sources if "exception" not in x.metadata] - - assert len(sources) > 0, "No sources found" - db = create_or_update_db( - db_type, - persist_directory, - collection_name, - sources, - use_openai_embedding, - add_if_exists, - verbose, - hf_embedding_model, - ) - - assert db is not None - if verbose: - print("DONE", flush=True) - return db, collection_name diff --git a/apps/language_models/langchain/prompter.py b/apps/language_models/langchain/prompter.py deleted file mode 100644 index eb5d102a..00000000 --- a/apps/language_models/langchain/prompter.py +++ /dev/null @@ -1,1103 +0,0 @@ -import os -import ast -import time -from apps.language_models.langchain.enums import ( - PromptType, -) # also supports imports from this file from other files - -non_hf_types = ["gpt4all_llama", "llama", "gptj"] - -prompt_type_to_model_name = { - "plain": [ - "EleutherAI/gpt-j-6B", - "EleutherAI/pythia-6.9b", - "EleutherAI/pythia-12b", - "EleutherAI/pythia-12b-deduped", - "EleutherAI/gpt-neox-20b", - "openlm-research/open_llama_7b_700bt_preview", - "decapoda-research/llama-7b-hf", - "decapoda-research/llama-13b-hf", - "decapoda-research/llama-30b-hf", - "decapoda-research/llama-65b-hf", - "facebook/mbart-large-50-many-to-many-mmt", - "philschmid/bart-large-cnn-samsum", - "philschmid/flan-t5-base-samsum", - "gpt2", - "distilgpt2", - "mosaicml/mpt-7b-storywriter", - ], - "gptj": ["gptj", "gpt4all_llama"], - "prompt_answer": [ - "h2oai/h2ogpt-gm-oasst1-en-1024-20b", - "h2oai/h2ogpt-gm-oasst1-en-1024-12b", - "h2oai/h2ogpt-gm-oasst1-multilang-1024-20b", - "h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b", - "h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2", - "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", - "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b", - "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2", - "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1", - "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2", - "h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k", - "h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k", - "TheBloke/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2-GPTQ", - ], - "prompt_answer_openllama": [ - "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt", - "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2", - "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt", - "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b", - "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b", - ], - "instruct": [ - "TheBloke/llama-30b-supercot-SuperHOT-8K-fp16" - ], # https://huggingface.co/TheBloke/llama-30b-supercot-SuperHOT-8K-fp16#prompting - "instruct_with_end": ["databricks/dolly-v2-12b"], - "quality": [], - "human_bot": [ - "h2oai/h2ogpt-oasst1-512-12b", - "h2oai/h2ogpt-oasst1-512-20b", - "h2oai/h2ogpt-oig-oasst1-256-6_9b", - "h2oai/h2ogpt-oig-oasst1-512-6_9b", - "h2oai/h2ogpt-oig-oasst1-256-6.9b", # legacy - "h2oai/h2ogpt-oig-oasst1-512-6.9b", # legacy - "h2oai/h2ogpt-research-oasst1-512-30b", - "h2oai/h2ogpt-research-oasst1-llama-65b", - "h2oai/h2ogpt-oasst1-falcon-40b", - "h2oai/h2ogpt-oig-oasst1-falcon-40b", - ], - "dai_faq": [], - "summarize": [], - "simple_instruct": [ - "t5-small", - "t5-large", - "google/flan-t5", - "google/flan-t5-xxl", - "google/flan-ul2", - ], - "instruct_vicuna": [ - "AlekseyKorshuk/vicuna-7b", - "TheBloke/stable-vicuna-13B-HF", - "junelee/wizard-vicuna-13b", - ], - "human_bot_orig": ["togethercomputer/GPT-NeoXT-Chat-Base-20B"], - "open_assistant": [ - "OpenAssistant/oasst-sft-7-llama-30b-xor", - "oasst-sft-7-llama-30b", - ], - "wizard_lm": [ - "ehartford/WizardLM-7B-Uncensored", - "ehartford/WizardLM-13B-Uncensored", - ], - "wizard_mega": ["openaccess-ai-collective/wizard-mega-13b"], - "instruct_simple": ["JosephusCheung/Guanaco"], - "wizard_vicuna": ["ehartford/Wizard-Vicuna-13B-Uncensored"], - "wizard2": ["llama"], - "mptinstruct": [ - "mosaicml/mpt-30b-instruct", - "mosaicml/mpt-7b-instruct", - "mosaicml/mpt-30b-instruct", - ], - "mptchat": [ - "mosaicml/mpt-7b-chat", - "mosaicml/mpt-30b-chat", - "TheBloke/mpt-30B-chat-GGML", - ], - "vicuna11": ["lmsys/vicuna-33b-v1.3"], - "falcon": [ - "tiiuae/falcon-40b-instruct", - "tiiuae/falcon-40b", - "tiiuae/falcon-7b-instruct", - "tiiuae/falcon-7b", - ], - # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin -} -if os.getenv("OPENAI_API_KEY"): - prompt_type_to_model_name.update( - { - "openai": [ - "text-davinci-003", - "text-curie-001", - "text-babbage-001", - "text-ada-001", - ], - "openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"], - } - ) - -inv_prompt_type_to_model_name = { - v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l -} -inv_prompt_type_to_model_lower = { - v.strip().lower(): k - for k, l in prompt_type_to_model_name.items() - for v in l -} - -prompt_types_strings = [] -for p in PromptType: - prompt_types_strings.extend([p.name]) - -prompt_types = [] -for p in PromptType: - prompt_types.extend([p.name, p.value, str(p.value)]) - - -def get_prompt( - prompt_type, - prompt_dict, - chat, - context, - reduced, - making_context, - return_dict=False, -): - prompt_dict_error = "" - generates_leading_space = False - - if prompt_type == PromptType.custom.name and not isinstance( - prompt_dict, dict - ): - try: - prompt_dict = ast.literal_eval(prompt_dict) - except BaseException as e: - prompt_dict_error = str(e) - if prompt_dict_error: - promptA = None - promptB = None - PreInstruct = None - PreInput = "" - PreResponse = "" - terminate_response = None - chat_sep = "" - chat_turn_sep = "" - humanstr = "" - botstr = "" - generates_leading_space = False - elif prompt_type in [ - PromptType.custom.value, - str(PromptType.custom.value), - PromptType.custom.name, - ]: - promptA = prompt_dict.get("promptA", "") - promptB = prompt_dict.get("promptB", "") - PreInstruct = prompt_dict.get("PreInstruct", "") - PreInput = prompt_dict.get("PreInput", "") - PreResponse = prompt_dict.get("PreResponse", "") - terminate_response = prompt_dict.get("terminate_response", None) - chat_sep = prompt_dict.get("chat_sep", "\n") - chat_turn_sep = prompt_dict.get("chat_turn_sep", "\n") - humanstr = prompt_dict.get("humanstr", "") - botstr = prompt_dict.get("botstr", "") - elif prompt_type in [ - PromptType.plain.value, - str(PromptType.plain.value), - PromptType.plain.name, - ]: - promptA = promptB = PreInstruct = PreInput = PreResponse = None - terminate_response = [] - chat_turn_sep = chat_sep = "" - # plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token - humanstr = None - botstr = None - elif prompt_type == "simple_instruct": - promptA = promptB = PreInstruct = PreInput = PreResponse = None - terminate_response = [] - chat_turn_sep = chat_sep = "\n" - humanstr = None - botstr = None - elif prompt_type in [ - PromptType.instruct.value, - str(PromptType.instruct.value), - PromptType.instruct.name, - ] + [ - PromptType.instruct_with_end.value, - str(PromptType.instruct_with_end.value), - PromptType.instruct_with_end.name, - ]: - promptA = ( - "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n" - if not (chat and reduced) - else "" - ) - promptB = ( - "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n" - if not (chat and reduced) - else "" - ) - - PreInstruct = """ -### Instruction: -""" - - PreInput = """ -### Input: -""" - - PreResponse = """ -### Response: -""" - if prompt_type in [ - PromptType.instruct_with_end.value, - str(PromptType.instruct_with_end.value), - PromptType.instruct_with_end.name, - ]: - terminate_response = ["### End"] - else: - terminate_response = None - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.quality.value, - str(PromptType.quality.value), - PromptType.quality.name, - ]: - promptA = ( - "Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n" - if not (chat and reduced) - else "" - ) - promptB = ( - "Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n" - if not (chat and reduced) - else "" - ) - - PreInstruct = """ -### Instruction: -""" - - PreInput = """ -### Input: -""" - - PreResponse = """ -### Response: -""" - terminate_response = None - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct # first thing human says - botstr = PreResponse # first thing bot says - elif prompt_type in [ - PromptType.human_bot.value, - str(PromptType.human_bot.value), - PromptType.human_bot.name, - ] + [ - PromptType.human_bot_orig.value, - str(PromptType.human_bot_orig.value), - PromptType.human_bot_orig.name, - ]: - human = ":" - bot = ":" - if ( - reduced - or context - or prompt_type - in [ - PromptType.human_bot.value, - str(PromptType.human_bot.value), - PromptType.human_bot.name, - ] - ): - preprompt = "" - else: - cur_date = time.strftime("%Y-%m-%d") - cur_time = time.strftime("%H:%M:%S %p %Z") - - PRE_PROMPT = """\ -Current Date: {} -Current Time: {} - -""" - preprompt = PRE_PROMPT.format(cur_date, cur_time) - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - - PreInstruct = human + " " - - PreInput = None - - if making_context: - # when making context, want it to appear as-if LLM generated, which starts with space after : - PreResponse = bot + " " - else: - # normally LLM adds space after this, because was how trained. - # if add space here, non-unique tokenization will often make LLM produce wrong output - PreResponse = bot - - terminate_response = [ - "\n" + human, - "\n" + bot, - human, - bot, - PreResponse, - ] - chat_turn_sep = chat_sep = "\n" - humanstr = human # tag before human talks - botstr = bot # tag before bot talks - generates_leading_space = True - elif prompt_type in [ - PromptType.dai_faq.value, - str(PromptType.dai_faq.value), - PromptType.dai_faq.name, - ]: - promptA = "" - promptB = "Answer the following Driverless AI question.\n" - - PreInstruct = """ -### Driverless AI frequently asked question: -""" - - PreInput = None - - PreResponse = """ -### Driverless AI documentation answer: -""" - terminate_response = ["\n\n"] - chat_turn_sep = chat_sep = terminate_response - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.summarize.value, - str(PromptType.summarize.value), - PromptType.summarize.name, - ]: - promptA = promptB = PreInput = "" - PreInstruct = "## Main Text\n\n" - PreResponse = "\n\n## Summary\n\n" - terminate_response = None - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.instruct_vicuna.value, - str(PromptType.instruct_vicuna.value), - PromptType.instruct_vicuna.name, - ]: - promptA = promptB = ( - "A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's questions." - if not (chat and reduced) - else "" - ) - - PreInstruct = """ -### Human: -""" - - PreInput = None - - PreResponse = """ -### Assistant: -""" - terminate_response = [ - "### Human:" - ] # but only allow terminate after prompt is found correctly, else can't terminate - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.prompt_answer.value, - str(PromptType.prompt_answer.value), - PromptType.prompt_answer.name, - ]: - preprompt = "" - prompt_tokens = "<|prompt|>" - answer_tokens = "<|answer|>" - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = prompt_tokens - PreInput = None - PreResponse = answer_tokens - eos = "<|endoftext|>" # neox eos - humanstr = prompt_tokens - botstr = answer_tokens - terminate_response = [humanstr, PreResponse, eos] - chat_sep = eos - chat_turn_sep = eos - elif prompt_type in [ - PromptType.prompt_answer_openllama.value, - str(PromptType.prompt_answer_openllama.value), - PromptType.prompt_answer_openllama.name, - ]: - preprompt = "" - prompt_tokens = "<|prompt|>" - answer_tokens = "<|answer|>" - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = prompt_tokens - PreInput = None - PreResponse = answer_tokens - eos = "" # llama eos - humanstr = prompt_tokens - botstr = answer_tokens - terminate_response = [humanstr, PreResponse, eos] - chat_sep = eos - chat_turn_sep = eos - elif prompt_type in [ - PromptType.open_assistant.value, - str(PromptType.open_assistant.value), - PromptType.open_assistant.name, - ]: - # From added_tokens.json - preprompt = "" - prompt_tokens = "<|prompter|>" - answer_tokens = "<|assistant|>" - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = prompt_tokens - PreInput = None - PreResponse = answer_tokens - pend = "<|prefix_end|>" - eos = "" - humanstr = prompt_tokens - botstr = answer_tokens - terminate_response = [humanstr, PreResponse, pend, eos] - chat_turn_sep = chat_sep = eos - elif prompt_type in [ - PromptType.wizard_lm.value, - str(PromptType.wizard_lm.value), - PromptType.wizard_lm.name, - ]: - # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py - preprompt = "" - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = "" - PreInput = None - PreResponse = "\n\n### Response\n" - eos = "" - terminate_response = [PreResponse, eos] - chat_turn_sep = chat_sep = eos - humanstr = promptA - botstr = PreResponse - elif prompt_type in [ - PromptType.wizard_mega.value, - str(PromptType.wizard_mega.value), - PromptType.wizard_mega.name, - ]: - preprompt = "" - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = """ -### Instruction: -""" - PreInput = None - PreResponse = """ -### Assistant: -""" - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.instruct_vicuna2.value, - str(PromptType.instruct_vicuna2.value), - PromptType.instruct_vicuna2.name, - ]: - promptA = promptB = "" if not (chat and reduced) else "" - - PreInstruct = """ -HUMAN: -""" - - PreInput = None - - PreResponse = """ -ASSISTANT: -""" - terminate_response = [ - "HUMAN:" - ] # but only allow terminate after prompt is found correctly, else can't terminate - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.instruct_vicuna3.value, - str(PromptType.instruct_vicuna3.value), - PromptType.instruct_vicuna3.name, - ]: - promptA = promptB = "" if not (chat and reduced) else "" - - PreInstruct = """ -### User: -""" - - PreInput = None - - PreResponse = """ -### Assistant: -""" - terminate_response = [ - "### User:" - ] # but only allow terminate after prompt is found correctly, else can't terminate - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.wizard2.value, - str(PromptType.wizard2.value), - PromptType.wizard2.name, - ]: - # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML - preprompt = ( - """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" - if not (chat and reduced) - else "" - ) - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = """ -### Instruction: -""" - PreInput = None - PreResponse = """ -### Response: -""" - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.wizard3.value, - str(PromptType.wizard3.value), - PromptType.wizard3.name, - ]: - # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML - preprompt = ( - """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" - if not (chat and reduced) - else "" - ) - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = """USER: """ - PreInput = None - PreResponse = """ASSISTANT: """ - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.wizard_vicuna.value, - str(PromptType.wizard_vicuna.value), - PromptType.wizard_vicuna.name, - ]: - preprompt = "" - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = """USER: """ - PreInput = None - PreResponse = """ASSISTANT: """ - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - - elif prompt_type in [ - PromptType.instruct_simple.value, - str(PromptType.instruct_simple.value), - PromptType.instruct_simple.name, - ]: - promptB = promptA = "" if not (chat and reduced) else "" - - PreInstruct = """ -### Instruction: -""" - - PreInput = """ -### Input: -""" - - PreResponse = """ -### Response: -""" - terminate_response = None - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.openai.value, - str(PromptType.openai.value), - PromptType.openai.name, - ]: - preprompt = ( - """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" - if not (chat and reduced) - else "" - ) - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = "\nHuman: " - PreInput = None - PreResponse = "\nAI:" - terminate_response = [PreResponse] + [" Human:", " AI:"] - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.gptj.value, - str(PromptType.gptj.value), - PromptType.gptj.name, - ]: - preprompt = ( - "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." - if not (chat and reduced) - else "" - ) - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = "\n### Prompt: " - PreInput = None - PreResponse = "\n### Response: " - terminate_response = [PreResponse] + ["Prompt:", "Response:"] - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.openai_chat.value, - str(PromptType.openai_chat.value), - PromptType.openai_chat.name, - ]: - # prompting and termination all handled by endpoint - preprompt = """""" - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - PreInstruct = "" - PreInput = None - PreResponse = "" - terminate_response = [] - chat_turn_sep = chat_sep = "\n" - humanstr = None - botstr = None - elif prompt_type in [ - PromptType.vicuna11.value, - str(PromptType.vicuna11.value), - PromptType.vicuna11.name, - ]: - preprompt = ( - """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ - if not (chat and reduced) - else "" - ) - start = "" - promptB = promptA = "%s%s" % (preprompt, start) - eos = "" - PreInstruct = """USER: """ - PreInput = None - PreResponse = """ASSISTANT:""" - terminate_response = [PreResponse] - chat_sep = " " - chat_turn_sep = eos - humanstr = PreInstruct - botstr = PreResponse - - if making_context: - # when making context, want it to appear as-if LLM generated, which starts with space after : - PreResponse = PreResponse + " " - else: - # normally LLM adds space after this, because was how trained. - # if add space here, non-unique tokenization will often make LLM produce wrong output - PreResponse = PreResponse - elif prompt_type in [ - PromptType.mptinstruct.value, - str(PromptType.mptinstruct.value), - PromptType.mptinstruct.name, - ]: - # https://huggingface.co/mosaicml/mpt-30b-instruct#formatting - promptA = promptB = ( - "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n" - if not (chat and reduced) - else "" - ) - - PreInstruct = """ -### Instruction -""" - - PreInput = """ -### Input -""" - - PreResponse = """ -### Response -""" - terminate_response = None - chat_turn_sep = chat_sep = "\n" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.mptchat.value, - str(PromptType.mptchat.value), - PromptType.mptchat.name, - ]: - # https://huggingface.co/TheBloke/mpt-30B-chat-GGML#prompt-template - promptA = promptB = ( - """<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\n<|im_end|>""" - if not (chat and reduced) - else "" - ) - - PreInstruct = """<|im_start|>user -""" - - PreInput = None - - PreResponse = """<|im_end|><|im_start|>assistant -""" - terminate_response = ["<|im_end|>"] - chat_sep = "" - chat_turn_sep = "<|im_end|>" - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [ - PromptType.falcon.value, - str(PromptType.falcon.value), - PromptType.falcon.name, - ]: - promptA = promptB = "" if not (chat and reduced) else "" - - PreInstruct = """User: """ - - PreInput = None - - PreResponse = """Assistant:""" - terminate_response = ["\nUser", "<|endoftext|>"] - chat_sep = "\n\n" - chat_turn_sep = "\n\n" - humanstr = PreInstruct - botstr = PreResponse - if making_context: - # when making context, want it to appear as-if LLM generated, which starts with space after : - PreResponse = "Assistant: " - else: - # normally LLM adds space after this, because was how trained. - # if add space here, non-unique tokenization will often make LLM produce wrong output - PreResponse = PreResponse - # generates_leading_space = True - else: - raise RuntimeError("No such prompt_type=%s" % prompt_type) - - if isinstance(terminate_response, (tuple, list)): - assert "" not in terminate_response, "Bad terminate_response" - - ret_dict = dict( - promptA=promptA, - promptB=promptB, - PreInstruct=PreInstruct, - PreInput=PreInput, - PreResponse=PreResponse, - terminate_response=terminate_response, - chat_sep=chat_sep, - chat_turn_sep=chat_turn_sep, - humanstr=humanstr, - botstr=botstr, - generates_leading_space=generates_leading_space, - ) - - if return_dict: - return ret_dict, prompt_dict_error - else: - return tuple(list(ret_dict.values())) - - -def generate_prompt( - data_point, prompt_type, prompt_dict, chat, reduced, making_context -): - context = data_point.get("context") - if context is None: - context = "" - instruction = data_point.get("instruction") - input = data_point.get("input") - output = data_point.get("output") - prompt_type = data_point.get("prompt_type", prompt_type) - prompt_dict = data_point.get("prompt_dict", prompt_dict) - assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type - ( - promptA, - promptB, - PreInstruct, - PreInput, - PreResponse, - terminate_response, - chat_sep, - chat_turn_sep, - humanstr, - botstr, - generates_leading_space, - ) = get_prompt( - prompt_type, prompt_dict, chat, context, reduced, making_context - ) - - # could avoid if reduce=True, but too complex for parent functions to handle - prompt = context - - if input and promptA: - prompt += f"""{promptA}""" - elif promptB: - prompt += f"""{promptB}""" - - if ( - instruction - and PreInstruct is not None - and input - and PreInput is not None - ): - prompt += f"""{PreInstruct}{instruction}{PreInput}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif ( - instruction and input and PreInstruct is None and PreInput is not None - ): - prompt += f"""{PreInput}{instruction} -{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif ( - input and instruction and PreInput is None and PreInstruct is not None - ): - prompt += f"""{PreInstruct}{instruction} -{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif instruction and PreInstruct is not None: - prompt += f"""{PreInstruct}{instruction}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and PreInput is not None: - prompt += f"""{PreInput}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and instruction and PreInput is not None: - prompt += f"""{PreInput}{instruction}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and instruction and PreInstruct is not None: - prompt += f"""{PreInstruct}{instruction}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and instruction: - # i.e. for simple_instruct - prompt += f"""{instruction}: {input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input: - prompt += f"""{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif instruction: - prompt += f"""{instruction}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - - if PreResponse is not None: - prompt += f"""{PreResponse}""" - pre_response = PreResponse # Don't use strip - else: - pre_response = "" - - if output: - prompt += f"""{output}""" - - return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep - - -def inject_chatsep(prompt_type, prompt, chat_sep=None): - if chat_sep: - # only add new line if structured prompt, while 'plain' is just generation of next tokens from input - prompt += chat_sep - return prompt - - -class Prompter(object): - def __init__( - self, - prompt_type, - prompt_dict, - debug=False, - chat=False, - stream_output=False, - repeat_penalty=True, - allowed_repeat_line_length=10, - ): - self.prompt_type = prompt_type - self.prompt_dict = prompt_dict - self.debug = debug - self.chat = chat - self.stream_output = stream_output - self.repeat_penalty = repeat_penalty - self.allowed_repeat_line_length = allowed_repeat_line_length - self.prompt = None - context = "" # not for chat context - reduced = False # not for chat context - making_context = False # not for chat context - ( - self.promptA, - self.promptB, - self.PreInstruct, - self.PreInput, - self.PreResponse, - self.terminate_response, - self.chat_sep, - self.chat_turn_sep, - self.humanstr, - self.botstr, - self.generates_leading_space, - ) = get_prompt( - self.prompt_type, - self.prompt_dict, - chat, - context, - reduced, - making_context, - ) - self.pre_response = self.PreResponse - - def generate_prompt(self, data_point, reduced=None): - """ - data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt - :param data_point: - :param reduced: - :return: - """ - reduced = ( - data_point.get("context") not in ["", None] - if reduced is None - else reduced - ) - making_context = False # whether really making final prompt or just generating context - prompt, _, _, _, _ = generate_prompt( - data_point, - self.prompt_type, - self.prompt_dict, - self.chat, - reduced, - making_context, - ) - if self.debug: - print("prompt: %s" % prompt, flush=True) - # if have context, should have always reduced and only preappend promptA/B here - if data_point.get("context"): - if data_point.get("input") and self.promptA: - prompt = self.promptA + prompt - elif self.promptB: - prompt = self.promptB + prompt - - self.prompt = prompt - return prompt - - def get_response(self, outputs, prompt=None, sanitize_bot_response=False): - if isinstance(outputs, str): - outputs = [outputs] - if self.debug: - print("output:\n%s" % "\n\n".join(outputs), flush=True) - if prompt is not None: - self.prompt = prompt - - def clean_response(response): - meaningless_words = ["", "", "<|endoftext|>"] - for word in meaningless_words: - response = response.replace(word, "") - if sanitize_bot_response: - from better_profanity import profanity - - response = profanity.censor(response) - if ( - self.generates_leading_space - and isinstance(response, str) - and len(response) > 0 - and response[0] == " " - ): - response = response[1:] - return response - - def clean_repeats(response): - lines = response.split("\n") - new_lines = [] - [ - new_lines.append(line) - for line in lines - if line not in new_lines - or len(line) < self.allowed_repeat_line_length - ] - if self.debug and len(lines) != len(new_lines): - print( - "cleaned repeats: %s %s" % (len(lines), len(new_lines)), - flush=True, - ) - response = "\n".join(new_lines) - return response - - multi_output = len(outputs) > 1 - - for oi, output in enumerate(outputs): - if self.prompt_type in [ - PromptType.plain.value, - str(PromptType.plain.value), - PromptType.plain.name, - ]: - output = clean_response(output) - elif prompt is None: - # then use most basic parsing like pipeline - if not self.botstr: - pass - elif self.botstr in output: - if self.humanstr: - output = clean_response( - output.split(self.botstr)[1].split(self.humanstr)[ - 0 - ] - ) - else: - # i.e. use after bot but only up to next bot - output = clean_response( - output.split(self.botstr)[1].split(self.botstr)[0] - ) - else: - # output = clean_response(output) - # assume just not printed yet - output = "" - else: - # find first instance of prereponse - # prompt sometimes has odd characters, that mutate length, - # so can't go by length alone - if self.pre_response: - outputi = output.find(prompt) - if outputi >= 0: - output = output[outputi + len(prompt) :] - allow_terminate = True - else: - # subtraction is risky due to space offsets sometimes, so only do if necessary - output = output[len(prompt) - len(self.pre_response) :] - # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat) - if self.pre_response in output: - output = output.split(self.pre_response)[1] - allow_terminate = True - else: - if output: - print( - "Failure of parsing or not enough output yet: %s" - % output, - flush=True, - ) - allow_terminate = False - else: - allow_terminate = True - output = output[len(prompt) :] - # clean after subtract prompt out, so correct removal of pre_response - output = clean_response(output) - if self.repeat_penalty: - output = clean_repeats(output) - if self.terminate_response and allow_terminate: - finds = [] - for term in self.terminate_response: - finds.append(output.find(term)) - finds = [x for x in finds if x >= 0] - if len(finds) > 0: - termi = finds[0] - output = output[:termi] - else: - output = output - if multi_output: - # prefix with output counter - output = "\n=========== Output %d\n\n" % (1 + oi) + output - if oi > 0: - # post fix outputs with seperator - output += "\n" - outputs[oi] = output - # join all outputs, only one extra new line between outputs - output = "\n".join(outputs) - if self.debug: - print("outputclean:\n%s" % "\n\n".join(outputs), flush=True) - return output diff --git a/apps/language_models/langchain/read_wiki_full.py b/apps/language_models/langchain/read_wiki_full.py deleted file mode 100644 index 64507bc3..00000000 --- a/apps/language_models/langchain/read_wiki_full.py +++ /dev/null @@ -1,403 +0,0 @@ -"""Load Data from a MediaWiki dump xml.""" -import ast -import glob -import pickle -import uuid -from typing import List, Optional -import os -import bz2 -import csv -import numpy as np -import pandas as pd -import pytest -from matplotlib import pyplot as plt - -from langchain.docstore.document import Document -from langchain.document_loaders import MWDumpLoader - -# path where downloaded wiki files exist, to be processed -root_path = "/data/jon/h2o-llm" - - -def unescape(x): - try: - x = ast.literal_eval(x) - except: - try: - x = x.encode("ascii", "ignore").decode("unicode_escape") - except: - pass - return x - - -def get_views(): - # views = pd.read_csv('wiki_page_views_more_1000month.csv') - views = pd.read_csv("wiki_page_views_more_5000month.csv") - views.index = views["title"] - views = views["views"] - views = views.to_dict() - views = {str(unescape(str(k))): v for k, v in views.items()} - views2 = {k.replace("_", " "): v for k, v in views.items()} - # views has _ but pages has " " - views.update(views2) - return views - - -class MWDumpDirectLoader(MWDumpLoader): - def __init__( - self, - data: str, - encoding: Optional[str] = "utf8", - title_words_limit=None, - use_views=True, - verbose=True, - ): - """Initialize with file path.""" - self.data = data - self.encoding = encoding - self.title_words_limit = title_words_limit - self.verbose = verbose - if use_views: - # self.views = get_views() - # faster to use global shared values - self.views = global_views - else: - self.views = None - - def load(self) -> List[Document]: - """Load from file path.""" - import mwparserfromhell - import mwxml - - dump = mwxml.Dump.from_page_xml(self.data) - - docs = [] - - for page in dump.pages: - if self.views is not None and page.title not in self.views: - if self.verbose: - print("Skipped %s low views" % page.title, flush=True) - continue - for revision in page: - if self.title_words_limit is not None: - num_words = len(" ".join(page.title.split("_")).split(" ")) - if num_words > self.title_words_limit: - if self.verbose: - print("Skipped %s" % page.title, flush=True) - continue - if self.verbose: - if self.views is not None: - print( - "Kept %s views: %s" - % (page.title, self.views[page.title]), - flush=True, - ) - else: - print("Kept %s" % page.title, flush=True) - - code = mwparserfromhell.parse(revision.text) - text = code.strip_code( - normalize=True, collapse=True, keep_template_params=False - ) - title_url = str(page.title).replace(" ", "_") - metadata = dict( - title=page.title, - source="https://en.wikipedia.org/wiki/" + title_url, - id=page.id, - redirect=page.redirect, - views=self.views[page.title] - if self.views is not None - else -1, - ) - metadata = {k: v for k, v in metadata.items() if v is not None} - docs.append(Document(page_content=text, metadata=metadata)) - - return docs - - -def search_index(search_term, index_filename): - byte_flag = False - data_length = start_byte = 0 - index_file = open(index_filename, "r") - csv_reader = csv.reader(index_file, delimiter=":") - for line in csv_reader: - if not byte_flag and search_term == line[2]: - start_byte = int(line[0]) - byte_flag = True - elif byte_flag and int(line[0]) != start_byte: - data_length = int(line[0]) - start_byte - break - index_file.close() - return start_byte, data_length - - -def get_start_bytes(index_filename): - index_file = open(index_filename, "r") - csv_reader = csv.reader(index_file, delimiter=":") - start_bytes = set() - for line in csv_reader: - start_bytes.add(int(line[0])) - index_file.close() - return sorted(start_bytes) - - -def get_wiki_filenames(): - # requires - # wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2 - base_path = os.path.join( - root_path, "enwiki-20230401-pages-articles-multistream" - ) - index_file = "enwiki-20230401-pages-articles-multistream-index.txt" - index_filename = os.path.join(base_path, index_file) - wiki_filename = os.path.join( - base_path, "enwiki-20230401-pages-articles-multistream.xml.bz2" - ) - return index_filename, wiki_filename - - -def get_documents_by_search_term(search_term): - index_filename, wiki_filename = get_wiki_filenames() - start_byte, data_length = search_index(search_term, index_filename) - with open(wiki_filename, "rb") as wiki_file: - wiki_file.seek(start_byte) - data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length)) - - loader = MWDumpDirectLoader(data.decode()) - documents = loader.load() - return documents - - -def get_one_chunk( - wiki_filename, - start_byte, - end_byte, - return_file=True, - title_words_limit=None, - use_views=True, -): - data_length = end_byte - start_byte - with open(wiki_filename, "rb") as wiki_file: - wiki_file.seek(start_byte) - data = bz2.BZ2Decompressor().decompress(wiki_file.read(data_length)) - - loader = MWDumpDirectLoader( - data.decode(), title_words_limit=title_words_limit, use_views=use_views - ) - documents1 = loader.load() - if return_file: - base_tmp = "temp_wiki" - if not os.path.isdir(base_tmp): - os.makedirs(base_tmp, exist_ok=True) - filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle") - with open(filename, "wb") as f: - pickle.dump(documents1, f) - return filename - return documents1 - - -from joblib import Parallel, delayed - -global_views = get_views() - - -def get_all_documents(small_test=2, n_jobs=None, use_views=True): - print("DO get all wiki docs: %s" % small_test, flush=True) - index_filename, wiki_filename = get_wiki_filenames() - start_bytes = get_start_bytes(index_filename) - end_bytes = start_bytes[1:] - start_bytes = start_bytes[:-1] - - if small_test: - start_bytes = start_bytes[:small_test] - end_bytes = end_bytes[:small_test] - if n_jobs is None: - n_jobs = 5 - else: - if n_jobs is None: - n_jobs = os.cpu_count() // 4 - - # default loky backend leads to name space conflict problems - return_file = True # large return from joblib hangs - documents = Parallel(n_jobs=n_jobs, verbose=10, backend="multiprocessing")( - delayed(get_one_chunk)( - wiki_filename, - start_byte, - end_byte, - return_file=return_file, - use_views=use_views, - ) - for start_byte, end_byte in zip(start_bytes, end_bytes) - ) - if return_file: - # then documents really are files - files = documents.copy() - documents = [] - for fil in files: - with open(fil, "rb") as f: - documents.extend(pickle.load(f)) - os.remove(fil) - else: - from functools import reduce - from operator import concat - - documents = reduce(concat, documents) - assert isinstance(documents, list) - - print("DONE get all wiki docs", flush=True) - return documents - - -def test_by_search_term(): - search_term = "Apollo" - assert len(get_documents_by_search_term(search_term)) == 100 - - search_term = "Abstract (law)" - assert len(get_documents_by_search_term(search_term)) == 100 - - search_term = "Artificial languages" - assert len(get_documents_by_search_term(search_term)) == 100 - - -def test_start_bytes(): - index_filename, wiki_filename = get_wiki_filenames() - assert len(get_start_bytes(index_filename)) == 227850 - - -def test_get_all_documents(): - small_test = 20 # 227850 - n_jobs = os.cpu_count() // 4 - - assert ( - len( - get_all_documents( - small_test=small_test, n_jobs=n_jobs, use_views=False - ) - ) - == small_test * 100 - ) - - assert ( - len( - get_all_documents( - small_test=small_test, n_jobs=n_jobs, use_views=True - ) - ) - == 429 - ) - - -def get_one_pageviews(fil): - df1 = pd.read_csv( - fil, - sep=" ", - header=None, - names=["region", "title", "views", "foo"], - quoting=csv.QUOTE_NONE, - ) - df1.index = df1["title"] - df1 = df1[df1["region"] == "en"] - df1 = df1.drop("region", axis=1) - df1 = df1.drop("foo", axis=1) - df1 = df1.drop("title", axis=1) # already index - - base_tmp = "temp_wiki_pageviews" - if not os.path.isdir(base_tmp): - os.makedirs(base_tmp, exist_ok=True) - filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.csv") - df1.to_csv(filename, index=True) - return filename - - -def test_agg_pageviews(gen_files=False): - if gen_files: - path = os.path.join( - root_path, - "wiki_pageviews/dumps.wikimedia.org/other/pageviews/2023/2023-04", - ) - files = glob.glob(os.path.join(path, "pageviews*.gz")) - # files = files[:2] # test - n_jobs = os.cpu_count() // 2 - csv_files = Parallel( - n_jobs=n_jobs, verbose=10, backend="multiprocessing" - )(delayed(get_one_pageviews)(fil) for fil in files) - else: - # to continue without redoing above - csv_files = glob.glob( - os.path.join(root_path, "temp_wiki_pageviews/*.csv") - ) - - df_list = [] - for csv_file in csv_files: - print(csv_file) - df1 = pd.read_csv(csv_file) - df_list.append(df1) - df = pd.concat(df_list, axis=0) - df = df.groupby("title")["views"].sum().reset_index() - df.to_csv("wiki_page_views.csv", index=True) - - -def test_reduce_pageview(): - filename = "wiki_page_views.csv" - df = pd.read_csv(filename) - df = df[df["views"] < 1e7] - # - plt.hist(df["views"], bins=100, log=True) - views_avg = np.mean(df["views"]) - views_median = np.median(df["views"]) - plt.title("Views avg: %s median: %s" % (views_avg, views_median)) - plt.savefig(filename.replace(".csv", ".png")) - plt.close() - # - views_limit = 5000 - df = df[df["views"] > views_limit] - filename = "wiki_page_views_more_5000month.csv" - df.to_csv(filename, index=True) - # - plt.hist(df["views"], bins=100, log=True) - views_avg = np.mean(df["views"]) - views_median = np.median(df["views"]) - plt.title("Views avg: %s median: %s" % (views_avg, views_median)) - plt.savefig(filename.replace(".csv", ".png")) - plt.close() - - -@pytest.mark.skip("Only if doing full processing again, some manual steps") -def test_do_wiki_full_all(): - # Install other requirements for wiki specific conversion: - # pip install -r reqs_optional/requirements_optional_wikiprocessing.txt - - # Use "Transmission" in Ubuntu to get wiki dump using torrent: - # See: https://meta.wikimedia.org/wiki/Data_dump_torrents - # E.g. magnet:?xt=urn:btih:b2c74af2b1531d0b63f1166d2011116f44a8fed0&dn=enwiki-20230401-pages-articles-multistream.xml.bz2&tr=udp%3A%2F%2Ftracker.opentrackr.org%3A1337 - - # Get index - os.system( - "wget http://ftp.acc.umu.se/mirror/wikimedia.org/dumps/enwiki/20230401/enwiki-20230401-pages-articles-multistream-index.txt.bz2" - ) - - # Test that can use LangChain to get docs from subset of wiki as sampled out of full wiki directly using bzip multistream - test_get_all_documents() - - # Check can search wiki multistream - test_by_search_term() - - # Test can get all start bytes in index - test_start_bytes() - - # Get page views, e.g. for entire month of April 2023 - os.system( - "wget -b -m -k -o wget.log -e robots=off https://dumps.wikimedia.org/other/pageviews/2023/2023-04/" - ) - - # Aggregate page views from many files into single file - test_agg_pageviews(gen_files=True) - - # Reduce page views to some limit, so processing of full wiki is not too large - test_reduce_pageview() - - # Start generate.py with requesting wiki_full in prep. This will use page views as referenced in get_views. - # Note get_views as global() function done once is required to avoid very slow processing - # WARNING: Requires alot of memory to handle, used up to 300GB system RAM at peak - """ - python generate.py --langchain_mode='wiki_full' --visible_langchain_modes="['wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']" &> lc_out.log - """ diff --git a/apps/language_models/langchain/stopping.py b/apps/language_models/langchain/stopping.py deleted file mode 100644 index 6c440203..00000000 --- a/apps/language_models/langchain/stopping.py +++ /dev/null @@ -1,121 +0,0 @@ -import torch -from transformers import StoppingCriteria, StoppingCriteriaList - -from enums import PromptType - - -class StoppingCriteriaSub(StoppingCriteria): - def __init__( - self, stops=[], encounters=[], device="cuda", model_max_length=None - ): - super().__init__() - assert ( - len(stops) % len(encounters) == 0 - ), "Number of stops and encounters must match" - self.encounters = encounters - self.stops = [stop.to(device) for stop in stops] - self.num_stops = [0] * len(stops) - self.model_max_length = model_max_length - - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs - ) -> bool: - for stopi, stop in enumerate(self.stops): - if torch.all((stop == input_ids[0][-len(stop) :])).item(): - self.num_stops[stopi] += 1 - if ( - self.num_stops[stopi] - >= self.encounters[stopi % len(self.encounters)] - ): - # print("Stopped", flush=True) - return True - if ( - self.model_max_length is not None - and input_ids[0].shape[0] >= self.model_max_length - ): - # critical limit - return True - # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True) - # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True) - return False - - -def get_stopping( - prompt_type, - prompt_dict, - tokenizer, - device, - human=":", - bot=":", - model_max_length=None, -): - # FIXME: prompt_dict unused currently - if prompt_type in [ - PromptType.human_bot.name, - PromptType.instruct_vicuna.name, - PromptType.instruct_with_end.name, - ]: - if prompt_type == PromptType.human_bot.name: - # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1] - # stopping only starts once output is beyond prompt - # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added - stop_words = [human, bot, "\n" + human, "\n" + bot] - encounters = [1, 2] - elif prompt_type == PromptType.instruct_vicuna.name: - # even below is not enough, generic strings and many ways to encode - stop_words = [ - "### Human:", - """ -### Human:""", - """ -### Human: -""", - "### Assistant:", - """ -### Assistant:""", - """ -### Assistant: -""", - ] - encounters = [1, 2] - else: - # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise - stop_words = ["### End"] - encounters = [1] - stop_words_ids = [ - tokenizer(stop_word, return_tensors="pt")["input_ids"].squeeze() - for stop_word in stop_words - ] - # handle single token case - stop_words_ids = [ - x if len(x.shape) > 0 else torch.tensor([x]) - for x in stop_words_ids - ] - stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0] - # avoid padding in front of tokens - if ( - tokenizer._pad_token - ): # use hidden variable to avoid annoying properly logger bug - stop_words_ids = [ - x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x - for x in stop_words_ids - ] - # handle fake \n added - stop_words_ids = [ - x[1:] if y[0] == "\n" else x - for x, y in zip(stop_words_ids, stop_words) - ] - # build stopper - stopping_criteria = StoppingCriteriaList( - [ - StoppingCriteriaSub( - stops=stop_words_ids, - encounters=encounters, - device=device, - model_max_length=model_max_length, - ) - ] - ) - else: - stopping_criteria = StoppingCriteriaList() - return stopping_criteria diff --git a/apps/language_models/langchain/utils.py b/apps/language_models/langchain/utils.py deleted file mode 100644 index 3935e9f5..00000000 --- a/apps/language_models/langchain/utils.py +++ /dev/null @@ -1,1070 +0,0 @@ -import contextlib -import functools -import hashlib -import inspect -import os -import gc -import pathlib -import random -import shutil -import subprocess -import sys -import threading -import time -import traceback -import zipfile -from datetime import datetime - -import filelock -import requests, uuid -from typing import Tuple, Callable, Dict -from tqdm.auto import tqdm -from joblib import Parallel -from concurrent.futures import ProcessPoolExecutor -import numpy as np -import pandas as pd - - -def set_seed(seed: int): - """ - Sets the seed of the entire notebook so results are the same every time we run. - This is for REPRODUCIBILITY. - """ - import torch - - np.random.seed(seed) - random_state = np.random.RandomState(seed) - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - os.environ["PYTHONHASHSEED"] = str(seed) - return random_state - - -def flatten_list(lis): - """Given a list, possibly nested to any level, return it flattened.""" - new_lis = [] - for item in lis: - if type(item) == type([]): - new_lis.extend(flatten_list(item)) - else: - new_lis.append(item) - return new_lis - - -def clear_torch_cache(): - import torch - - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - gc.collect() - - -def ping(): - try: - print("Ping: %s" % str(datetime.now()), flush=True) - except AttributeError: - # some programs wrap print and will fail with flush passed - pass - - -def ping_gpu(): - try: - print( - "Ping_GPU: %s %s" % (str(datetime.now()), system_info()), - flush=True, - ) - except AttributeError: - # some programs wrap print and will fail with flush passed - pass - try: - ping_gpu_memory() - except Exception as e: - print("Ping_GPU memory failure: %s" % str(e), flush=True) - - -def ping_gpu_memory(): - from models.gpu_mem_track import MemTracker - - gpu_tracker = MemTracker() # define a GPU tracker - from torch.cuda import memory_summary - - gpu_tracker.track() - - -def get_torch_allocated(): - import torch - - return torch.cuda.memory_allocated() - - -def get_device(): - import torch - - if torch.cuda.is_available(): - device = "cuda" - elif torch.backends.mps.is_built(): - device = "mps" - else: - device = "cpu" - - return device - - -def system_info(): - import psutil - - system = {} - # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard - # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749 - temps = psutil.sensors_temperatures(fahrenheit=False) - if "coretemp" in temps: - coretemp = temps["coretemp"] - temp_dict = {k.label: k.current for k in coretemp} - for k, v in temp_dict.items(): - system["CPU_C/%s" % k] = v - - # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt - try: - from pynvml.smi import nvidia_smi - - nvsmi = nvidia_smi.getInstance() - - gpu_power_dict = { - "W_gpu%d" % i: x["power_readings"]["power_draw"] - for i, x in enumerate(nvsmi.DeviceQuery("power.draw")["gpu"]) - } - for k, v in gpu_power_dict.items(): - system["GPU_W/%s" % k] = v - - gpu_temp_dict = { - "C_gpu%d" % i: x["temperature"]["gpu_temp"] - for i, x in enumerate(nvsmi.DeviceQuery("temperature.gpu")["gpu"]) - } - for k, v in gpu_temp_dict.items(): - system["GPU_C/%s" % k] = v - - gpu_memory_free_dict = { - "MiB_gpu%d" % i: x["fb_memory_usage"]["free"] - for i, x in enumerate(nvsmi.DeviceQuery("memory.free")["gpu"]) - } - gpu_memory_total_dict = { - "MiB_gpu%d" % i: x["fb_memory_usage"]["total"] - for i, x in enumerate(nvsmi.DeviceQuery("memory.total")["gpu"]) - } - gpu_memory_frac_dict = { - k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] - for k in gpu_memory_total_dict - } - for k, v in gpu_memory_frac_dict.items(): - system[f"GPU_M/%s" % k] = v - except ModuleNotFoundError: - pass - system["hash"] = get_githash() - - return system - - -def system_info_print(): - try: - df = pd.DataFrame.from_dict(system_info(), orient="index") - # avoid slamming GPUs - time.sleep(1) - return df.to_markdown() - except Exception as e: - return "Error: %s" % str(e) - - -def zip_data( - root_dirs=None, zip_file=None, base_dir="./", fail_any_exception=False -): - try: - return _zip_data( - zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs - ) - except Exception as e: - traceback.print_exc() - print("Exception in zipping: %s" % str(e)) - if not fail_any_exception: - raise - - -def _zip_data(root_dirs=None, zip_file=None, base_dir="./"): - if isinstance(root_dirs, str): - root_dirs = [root_dirs] - if zip_file is None: - datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_") - host_name = os.getenv("HF_HOSTNAME", "emptyhost") - zip_file = "data_%s_%s.zip" % (datetime_str, host_name) - assert root_dirs is not None - if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname( - zip_file - ): - os.makedirs(os.path.dirname(zip_file), exist_ok=True) - with zipfile.ZipFile(zip_file, "w") as expt_zip: - for root_dir in root_dirs: - if root_dir is None: - continue - for root, d, files in os.walk(root_dir): - for file in files: - file_to_archive = os.path.join(root, file) - assert os.path.exists(file_to_archive) - path_to_archive = os.path.relpath( - file_to_archive, base_dir - ) - expt_zip.write( - filename=file_to_archive, arcname=path_to_archive - ) - return zip_file, zip_file - - -def save_generate_output( - prompt=None, - output=None, - base_model=None, - save_dir=None, - where_from="unknown where from", - extra_dict={}, -): - try: - return _save_generate_output( - prompt=prompt, - output=output, - base_model=base_model, - save_dir=save_dir, - where_from=where_from, - extra_dict=extra_dict, - ) - except Exception as e: - traceback.print_exc() - print("Exception in saving: %s" % str(e)) - - -def _save_generate_output( - prompt=None, - output=None, - base_model=None, - save_dir=None, - where_from="unknown where from", - extra_dict={}, -): - """ - Save conversation to .json, row by row. - json_file_path is path to final JSON file. If not in ., then will attempt to make directories. - Appends if file exists - """ - prompt = "" if prompt is None else prompt - output = "" if output is None else output - assert save_dir, "save_dir must be provided" - if os.path.exists(save_dir) and not os.path.isdir(save_dir): - raise RuntimeError("save_dir already exists and is not a directory!") - os.makedirs(save_dir, exist_ok=True) - import json - - dict_to_save = dict( - prompt=prompt, - text=output, - time=time.ctime(), - base_model=base_model, - where_from=where_from, - ) - dict_to_save.update(extra_dict) - with filelock.FileLock("save_dir.lock"): - # lock logging in case have concurrency - with open(os.path.join(save_dir, "history.json"), "a") as f: - # just add [ at start, and ] at end, and have proper JSON dataset - f.write(" " + json.dumps(dict_to_save) + ",\n") - - -def s3up(filename): - try: - return _s3up(filename) - except Exception as e: - traceback.print_exc() - print("Exception for file %s in s3up: %s" % (filename, str(e))) - return "Failed to upload %s: Error: %s" % (filename, str(e)) - - -def _s3up(filename): - import boto3 - - aws_access_key_id = os.getenv("AWS_SERVER_PUBLIC_KEY") - aws_secret_access_key = os.getenv("AWS_SERVER_SECRET_KEY") - bucket = os.getenv("AWS_BUCKET") - assert aws_access_key_id, "Set AWS key" - assert aws_secret_access_key, "Set AWS secret" - assert bucket, "Set AWS Bucket" - - s3 = boto3.client( - "s3", - aws_access_key_id=os.getenv("AWS_SERVER_PUBLIC_KEY"), - aws_secret_access_key=os.getenv("AWS_SERVER_SECRET_KEY"), - ) - ret = s3.upload_file( - Filename=filename, - Bucket=os.getenv("AWS_BUCKET"), - Key=filename, - ) - if ret in [None, ""]: - return "Successfully uploaded %s" % filename - - -def get_githash(): - try: - githash = subprocess.run( - ["git", "rev-parse", "HEAD"], stdout=subprocess.PIPE - ).stdout.decode("utf-8")[0:-1] - except: - githash = "" - return githash - - -def copy_code(run_id): - """ - copy code to track changes - :param run_id: - :return: - """ - rnd_num = str(random.randint(0, 2**31)) - run_id = "run_" + str(run_id) - os.makedirs(run_id, exist_ok=True) - me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__) - me_file = os.path.basename(__file__) - new_me = os.path.join(run_id, me_file + "_" + get_githash()) - if os.path.isfile(new_me): - new_me = os.path.join( - run_id, me_file + "_" + get_githash() + "_" + rnd_num - ) - shutil.copy(me_full, new_me) - else: - shutil.copy(me_full, new_me) - - -class NullContext(threading.local): - """No-op context manager, executes block without doing any additional processing. - - Used as a stand-in if a particular block of code is only sometimes - used with a normal context manager: - """ - - def __init__(self, *args, **kwargs): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.finally_act() - - def finally_act(self): - pass - - -def wrapped_partial(func, *args, **kwargs): - """ - Give partial properties of normal function, like __name__ attribute etc. - :param func: - :param args: - :param kwargs: - :return: - """ - partial_func = functools.partial(func, *args, **kwargs) - functools.update_wrapper(partial_func, func) - return partial_func - - -class ThreadException(Exception): - pass - - -class EThread(threading.Thread): - # Function that raises the custom exception - def __init__( - self, - group=None, - target=None, - name=None, - args=(), - kwargs=None, - *, - daemon=None, - streamer=None, - bucket=None, - ): - self.bucket = bucket - self.streamer = streamer - self.exc = None - self._return = None - super().__init__( - group=group, - target=target, - name=name, - args=args, - kwargs=kwargs, - daemon=daemon, - ) - - def run(self): - # Variable that stores the exception, if raised by someFunction - try: - if self._target is not None: - self._return = self._target(*self._args, **self._kwargs) - except BaseException as e: - print("thread exception: %s" % str(sys.exc_info())) - self.bucket.put(sys.exc_info()) - self.exc = e - if self.streamer: - print("make stop: %s" % str(sys.exc_info()), flush=True) - self.streamer.do_stop = True - finally: - # Avoid a refcycle if the thread is running a function with - # an argument that has a member that points to the thread. - del self._target, self._args, self._kwargs - - def join(self, timeout=None): - threading.Thread.join(self) - # Since join() returns in caller thread - # we re-raise the caught exception - # if any was caught - if self.exc: - raise self.exc - return self._return - - -def import_matplotlib(): - import matplotlib - - matplotlib.use("agg") - # KEEP THESE HERE! START - import matplotlib.pyplot as plt - import pandas as pd - - # to avoid dlopen deadlock in fork - import pandas.core.computation.expressions as pd_expressions - import pandas._libs.groupby as pd_libgroupby - import pandas._libs.reduction as pd_libreduction - import pandas.core.algorithms as pd_algorithms - import pandas.core.common as pd_com - import numpy as np - - # KEEP THESE HERE! END - - -def get_sha(value): - return hashlib.md5(str(value).encode("utf-8")).hexdigest() - - -def sanitize_filename(name): - """ - Sanitize file *base* names. - :param name: name to sanitize - :return: - """ - bad_chars = [ - "[", - "]", - ",", - "/", - "\\", - "\\w", - "\\s", - "-", - "+", - '"', - "'", - ">", - "<", - " ", - "=", - ")", - "(", - ":", - "^", - ] - for char in bad_chars: - name = name.replace(char, "_") - - length = len(name) - file_length_limit = 250 # bit smaller than 256 for safety - sha_length = 32 - real_length_limit = file_length_limit - (sha_length + 2) - if length > file_length_limit: - sha = get_sha(name) - half_real_length_limit = max(1, int(real_length_limit / 2)) - name = ( - name[0:half_real_length_limit] - + "_" - + sha - + "_" - + name[length - half_real_length_limit : length] - ) - - return name - - -def shutil_rmtree(*args, **kwargs): - return shutil.rmtree(*args, **kwargs) - - -def remove(path: str): - try: - if path is not None and os.path.exists(path): - if os.path.isdir(path): - shutil_rmtree(path, ignore_errors=True) - else: - with contextlib.suppress(FileNotFoundError): - os.remove(path) - except: - pass - - -def makedirs(path, exist_ok=True): - """ - Avoid some inefficiency in os.makedirs() - :param path: - :param exist_ok: - :return: - """ - if os.path.isdir(path) and os.path.exists(path): - assert exist_ok, "Path already exists" - return path - os.makedirs(path, exist_ok=exist_ok) - - -def atomic_move_simple(src, dst): - try: - shutil.move(src, dst) - except (shutil.Error, FileExistsError): - pass - remove(src) - - -def download_simple(url, dest=None, print_func=None): - if print_func is not None: - print_func("BEGIN get url %s" % str(url)) - if url.startswith("file://"): - from requests_file import FileAdapter - - s = requests.Session() - s.mount("file://", FileAdapter()) - url_data = s.get(url, stream=True) - else: - url_data = requests.get(url, stream=True) - if dest is None: - dest = os.path.basename(url) - if url_data.status_code != requests.codes.ok: - msg = "Cannot get url %s, code: %s, reason: %s" % ( - str(url), - str(url_data.status_code), - str(url_data.reason), - ) - raise requests.exceptions.RequestException(msg) - url_data.raw.decode_content = True - makedirs(os.path.dirname(dest), exist_ok=True) - uuid_tmp = str(uuid.uuid4())[:6] - dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp" - with open(dest_tmp, "wb") as f: - shutil.copyfileobj(url_data.raw, f) - atomic_move_simple(dest_tmp, dest) - if print_func is not None: - print_func("END get url %s" % str(url)) - - -def download(url, dest=None, dest_path=None): - if dest_path is not None: - dest = os.path.join(dest_path, os.path.basename(url)) - if os.path.isfile(dest): - print("already downloaded %s -> %s" % (url, dest)) - return dest - elif dest is not None: - if os.path.exists(dest): - print("already downloaded %s -> %s" % (url, dest)) - return dest - else: - uuid_tmp = "dl2_" + str(uuid.uuid4())[:6] - dest = uuid_tmp + os.path.basename(url) - - print("downloading %s to %s" % (url, dest)) - - if url.startswith("file://"): - from requests_file import FileAdapter - - s = requests.Session() - s.mount("file://", FileAdapter()) - url_data = s.get(url, stream=True) - else: - url_data = requests.get(url, stream=True) - - if url_data.status_code != requests.codes.ok: - msg = "Cannot get url %s, code: %s, reason: %s" % ( - str(url), - str(url_data.status_code), - str(url_data.reason), - ) - raise requests.exceptions.RequestException(msg) - url_data.raw.decode_content = True - dirname = os.path.dirname(dest) - if dirname != "" and not os.path.isdir(dirname): - makedirs(os.path.dirname(dest), exist_ok=True) - uuid_tmp = "dl3_" + str(uuid.uuid4())[:6] - dest_tmp = dest + "_" + uuid_tmp + ".tmp" - with open(dest_tmp, "wb") as f: - shutil.copyfileobj(url_data.raw, f) - try: - shutil.move(dest_tmp, dest) - except FileExistsError: - pass - remove(dest_tmp) - return dest - - -def get_url(x, from_str=False, short_name=False): - if not from_str: - source = x.metadata["source"] - else: - source = x - if short_name: - source_name = get_short_name(source) - else: - source_name = source - if source.startswith("http://") or source.startswith("https://"): - return ( - """%s""" - % (source, source_name) - ) - else: - return ( - """%s""" - % (source, source_name) - ) - - -def get_short_name(name, maxl=50): - if name is None: - return "" - length = len(name) - if length > maxl: - allow_length = maxl - 3 - half_allowed = max(1, int(allow_length / 2)) - name = ( - name[0:half_allowed] + "..." + name[length - half_allowed : length] - ) - return name - - -def cuda_vis_check(total_gpus): - """Helper function to count GPUs by environment variable - Stolen from Jon's h2o4gpu utils - """ - cudavis = os.getenv("CUDA_VISIBLE_DEVICES") - which_gpus = [] - if cudavis is not None: - # prune away white-space, non-numerics, - # except commas for simple checking - cudavis = "".join(cudavis.split()) - import re - - cudavis = re.sub("[^0-9,]", "", cudavis) - - lencudavis = len(cudavis) - if lencudavis == 0: - total_gpus = 0 - else: - total_gpus = min( - total_gpus, os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1 - ) - which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",") - which_gpus = [int(x) for x in which_gpus] - else: - which_gpus = list(range(0, total_gpus)) - - return total_gpus, which_gpus - - -def get_ngpus_vis(raise_if_exception=True): - ngpus_vis1 = 0 - - shell = False - if shell: - cmd = "nvidia-smi -L 2> /dev/null" - else: - cmd = ["nvidia-smi", "-L"] - - try: - timeout = 5 * 3 - o = subprocess.check_output(cmd, shell=shell, timeout=timeout) - lines = o.decode("utf-8").splitlines() - ngpus_vis1 = 0 - for line in lines: - if "Failed to initialize NVML" not in line: - ngpus_vis1 += 1 - except (FileNotFoundError, subprocess.CalledProcessError, OSError): - # GPU systems might not have nvidia-smi, so can't fail - pass - except subprocess.TimeoutExpired as e: - print("Failed get_ngpus_vis: %s" % str(e)) - if raise_if_exception: - raise - - ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1) - return ngpus_vis1 - - -def get_mem_gpus(raise_if_exception=True, ngpus=None): - totalmem_gpus1 = 0 - usedmem_gpus1 = 0 - freemem_gpus1 = 0 - - if ngpus == 0: - return totalmem_gpus1, usedmem_gpus1, freemem_gpus1 - - try: - cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'" - o = subprocess.check_output(cmd, shell=True, timeout=15) - lines = o.decode("utf-8").splitlines() - for line in lines: - if "Total" in line: - totalmem_gpus1 += int(line.split()[2]) * 1024**2 - if "Used" in line: - usedmem_gpus1 += int(line.split()[2]) * 1024**2 - if "Free" in line: - freemem_gpus1 += int(line.split()[2]) * 1024**2 - except (FileNotFoundError, subprocess.CalledProcessError, OSError): - # GPU systems might not have nvidia-smi, so can't fail - pass - except subprocess.TimeoutExpired as e: - print("Failed get_mem_gpus: %s" % str(e)) - if raise_if_exception: - raise - - return totalmem_gpus1, usedmem_gpus1, freemem_gpus1 - - -class ForkContext(threading.local): - """ - Set context for forking - Ensures state is returned once done - """ - - def __init__(self, args=None, kwargs=None, forkdata_capable=True): - """ - :param args: - :param kwargs: - :param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs - """ - self.forkdata_capable = forkdata_capable - if self.forkdata_capable: - self.has_args = args is not None - self.has_kwargs = kwargs is not None - forkdatacontext.args = args - forkdatacontext.kwargs = kwargs - else: - self.has_args = False - self.has_kwargs = False - - def __enter__(self): - try: - # flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts! - sys.stdout.flush() - sys.stderr.flush() - except BaseException as e: - # exit not called if exception, and don't want to leave forkdatacontext filled in that case - print("ForkContext failure on enter: %s" % str(e)) - self.finally_act() - raise - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.finally_act() - - def finally_act(self): - """ - Done when exception hit or exit is reached in context - first reset forkdatacontext as crucial to have reset even if later 2 calls fail - :return: None - """ - if self.forkdata_capable and (self.has_args or self.has_kwargs): - forkdatacontext._reset() - - -class _ForkDataContext(threading.local): - def __init__( - self, - args=None, - kwargs=None, - ): - """ - Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization - - :param args: args - :param kwargs: kwargs - """ - assert isinstance(args, (tuple, type(None))) - assert isinstance(kwargs, (dict, type(None))) - self.__args = args - self.__kwargs = kwargs - - @property - def args(self) -> Tuple: - """returns args""" - return self.__args - - @args.setter - def args(self, args): - if self.__args is not None: - raise AttributeError( - "args cannot be overwritten: %s %s" - % (str(self.__args), str(self.__kwargs)) - ) - - self.__args = args - - @property - def kwargs(self) -> Dict: - """returns kwargs""" - return self.__kwargs - - @kwargs.setter - def kwargs(self, kwargs): - if self.__kwargs is not None: - raise AttributeError( - "kwargs cannot be overwritten: %s %s" - % (str(self.__args), str(self.__kwargs)) - ) - - self.__kwargs = kwargs - - def _reset(self): - """Reset fork arg-kwarg context to default values""" - self.__args = None - self.__kwargs = None - - def get_args_kwargs( - self, func, args, kwargs - ) -> Tuple[Callable, Tuple, Dict]: - if self.__args: - args = self.__args[1:] - if not func: - assert ( - len(self.__args) > 0 - ), "if have no func, must have in args" - func = self.__args[0] # should always be there - if self.__kwargs: - kwargs = self.__kwargs - try: - return func, args, kwargs - finally: - forkdatacontext._reset() - - @staticmethod - def get_args_kwargs_for_traced_func(func, args, kwargs): - """ - Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs - :param func: actual function ran by _traced_func, which itself is directly what mppool treats as function - :param args: - :param kwargs: - :return: func, args, kwargs from forkdatacontext if used, else originals - """ - # first 3 lines are debug - func_was_None = func is None - args_was_None_or_empty = args is None or len(args) == 0 - kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0 - - forkdatacontext_args_was_None = forkdatacontext.args is None - forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None - func, args, kwargs = forkdatacontext.get_args_kwargs( - func, args, kwargs - ) - using_forkdatacontext = ( - func_was_None and func is not None - ) # pulled func out of forkdatacontext.__args[0] - assert ( - forkdatacontext.args is None - ), "forkdatacontext.args should be None after get_args_kwargs" - assert ( - forkdatacontext.kwargs is None - ), "forkdatacontext.kwargs should be None after get_args_kwargs" - - proc_type = kwargs.get("proc_type", "SUBPROCESS") - if using_forkdatacontext: - assert proc_type == "SUBPROCESS" or proc_type == "SUBPROCESS" - if proc_type == "NORMAL": - assert ( - forkdatacontext_args_was_None - ), "if no fork, expect forkdatacontext.args None entering _traced_func" - assert ( - forkdatacontext_kwargs_was_None - ), "if no fork, expect forkdatacontext.kwargs None entering _traced_func" - assert ( - func is not None - ), "function should not be None, indicates original args[0] was None or args was None" - - return func, args, kwargs - - -forkdatacontext = _ForkDataContext() - - -def _traced_func(func, *args, **kwargs): - func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func( - func, args, kwargs - ) - return func(*args, **kwargs) - - -def call_subprocess_onetask(func, args=None, kwargs=None): - if isinstance(args, list): - args = tuple(args) - if args is None: - args = () - if kwargs is None: - kwargs = {} - args = list(args) - args = [func] + args - args = tuple(args) - with ForkContext(args=args, kwargs=kwargs): - args = (None,) - kwargs = {} - with ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(_traced_func, *args, **kwargs) - return future.result() - - -class ProgressParallel(Parallel): - def __init__(self, use_tqdm=True, total=None, *args, **kwargs): - self._use_tqdm = use_tqdm - self._total = total - super().__init__(*args, **kwargs) - - def __call__(self, *args, **kwargs): - with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar: - return Parallel.__call__(self, *args, **kwargs) - - def print_progress(self): - if self._total is None: - self._pbar.total = self.n_dispatched_tasks - self._pbar.n = self.n_completed_tasks - self._pbar.refresh() - - -def get_kwargs(func, exclude_names=None, **kwargs): - func_names = list(inspect.signature(func).parameters) - missing_kwargs = [x for x in func_names if x not in kwargs] - if exclude_names: - for k in exclude_names: - if k in missing_kwargs: - missing_kwargs.remove(k) - if k in func_names: - func_names.remove(k) - assert not missing_kwargs, "Missing %s" % missing_kwargs - kwargs = {k: v for k, v in kwargs.items() if k in func_names} - return kwargs - - -import pkg_resources - -have_faiss = False - -try: - assert pkg_resources.get_distribution("faiss") is not None - have_faiss = True -except (pkg_resources.DistributionNotFound, AssertionError): - pass -try: - assert pkg_resources.get_distribution("faiss_gpu") is not None - have_faiss = True -except (pkg_resources.DistributionNotFound, AssertionError): - pass -try: - assert pkg_resources.get_distribution("faiss_cpu") is not None - have_faiss = True -except (pkg_resources.DistributionNotFound, AssertionError): - pass - - -def hash_file(file): - try: - import hashlib - - # BUF_SIZE is totally arbitrary, change for your app! - BUF_SIZE = 65536 # lets read stuff in 64kb chunks! - - md5 = hashlib.md5() - # sha1 = hashlib.sha1() - - with open(file, "rb") as f: - while True: - data = f.read(BUF_SIZE) - if not data: - break - md5.update(data) - # sha1.update(data) - except BaseException as e: - print("Cannot hash %s due to %s" % (file, str(e))) - traceback.print_exc() - md5 = None - return md5.hexdigest() - - -def start_faulthandler(): - # If hit server or any subprocess with signal SIGUSR1, it'll print out all threads stack trace, but wont't quit or coredump - # If more than one fork tries to write at same time, then looks corrupted. - import faulthandler - - # SIGUSR1 in h2oai/__init__.py as well - faulthandler.enable() - if hasattr(faulthandler, "register"): - # windows/mac - import signal - - faulthandler.register(signal.SIGUSR1) - - -def get_hf_server(inference_server): - inf_split = inference_server.split(" ") - assert len(inf_split) == 1 or len(inf_split) == 3 - inference_server = inf_split[0] - if len(inf_split) == 3: - headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])} - else: - headers = None - return inference_server, headers - - -class FakeTokenizer: - """ - 1) For keeping track of model_max_length - 2) For when model doesn't directly expose tokenizer but need to count tokens - """ - - def __init__(self, model_max_length=2048, encoding_name="cl100k_base"): - # dont' push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 250 - self.model_max_length = model_max_length - 250 - self.encoding_name = encoding_name - # The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection. - import tiktoken - - self.encoding = tiktoken.get_encoding(self.encoding_name) - - def encode(self, x, *args, return_tensors="pt", **kwargs): - input_ids = self.encoding.encode(x, disallowed_special=()) - if return_tensors == "pt" and isinstance(input_ids, list): - import torch - - input_ids = torch.tensor(input_ids) - return dict(input_ids=input_ids) - - def decode(self, x, *args, **kwargs): - # input is input_ids[0] form - return self.encoding.decode(x) - - def num_tokens_from_string(self, prompt: str) -> int: - """Returns the number of tokens in a text string.""" - num_tokens = len(self.encoding.encode(prompt)) - return num_tokens - - def __call__(self, x, *args, **kwargs): - return self.encode(x, *args, **kwargs) diff --git a/apps/language_models/langchain/utils_langchain.py b/apps/language_models/langchain/utils_langchain.py deleted file mode 100644 index 416a431a..00000000 --- a/apps/language_models/langchain/utils_langchain.py +++ /dev/null @@ -1,69 +0,0 @@ -from typing import Any, Dict, List, Union, Optional -import time -import queue - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import LLMResult - - -class StreamingGradioCallbackHandler(BaseCallbackHandler): - """ - Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend - """ - - def __init__(self, timeout: Optional[float] = None, block=True): - super().__init__() - self.text_queue = queue.SimpleQueue() - self.stop_signal = None - self.do_stop = False - self.timeout = timeout - self.block = block - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running. Clean the queue.""" - while not self.text_queue.empty(): - try: - self.text_queue.get(block=False) - except queue.Empty: - continue - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" - self.text_queue.put(token) - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - self.text_queue.put(self.stop_signal) - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" - self.text_queue.put(self.stop_signal) - - def __iter__(self): - return self - - def __next__(self): - while True: - try: - value = ( - self.stop_signal - ) # value looks unused in pycharm, not true - if self.do_stop: - print("hit stop", flush=True) - # could raise or break, maybe best to raise and make parent see if any exception in thread - raise StopIteration() - # break - value = self.text_queue.get( - block=self.block, timeout=self.timeout - ) - break - except queue.Empty: - time.sleep(0.01) - if value == self.stop_signal: - raise StopIteration() - else: - return value diff --git a/apps/language_models/scripts/llama_ir_conversion_utils.py b/apps/language_models/scripts/llama_ir_conversion_utils.py deleted file mode 100644 index 3a056ca3..00000000 --- a/apps/language_models/scripts/llama_ir_conversion_utils.py +++ /dev/null @@ -1,442 +0,0 @@ -from pathlib import Path -import argparse -from argparse import RawTextHelpFormatter -import re, gc - -""" - This script can be used as a standalone utility to convert IRs to dynamic + combine them. - Following are the various ways this script can be used :- - a. To convert a single Linalg IR to dynamic IR: - --dynamic --first_ir_path= - b. To convert two Linalg IRs to dynamic IR: - --dynamic --first_ir_path= --first_ir_path= - c. To combine two Linalg IRs into one: - --combine --first_ir_path= --second_ir_path= - d. To convert both IRs into dynamic as well as combine the IRs: - --dynamic --combine --first_ir_path= --second_ir_path= - - NOTE: For dynamic you'll also need to provide the following set of flags:- - i. For First Llama : --dynamic_input_size (DEFAULT: 19) - ii. For Second Llama: --model_name (DEFAULT: llama2_7b) - --precision (DEFAULT: 'int4') - You may use --save_dynamic to also save the dynamic IR in option d above. - Else for option a. and b. the dynamic IR(s) will get saved by default. -""" - - -def combine_mlir_scripts( - first_vicuna_mlir, - second_vicuna_mlir, - output_name, - return_ir=True, -): - print(f"[DEBUG] combining first and second mlir") - print(f"[DEBUG] output_name = {output_name}") - maps1 = [] - maps2 = [] - constants = set() - f1 = [] - f2 = [] - - print(f"[DEBUG] processing first vicuna mlir") - first_vicuna_mlir = first_vicuna_mlir.splitlines() - while first_vicuna_mlir: - line = first_vicuna_mlir.pop(0) - if re.search("#map\d*\s*=", line): - maps1.append(line) - elif re.search("arith.constant", line): - constants.add(line) - elif not re.search("module", line): - line = re.sub("forward", "first_vicuna_forward", line) - f1.append(line) - f1 = f1[:-1] - del first_vicuna_mlir - gc.collect() - - for i, map_line in enumerate(maps1): - map_var = map_line.split(" ")[0] - map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line) - maps1[i] = map_line - f1 = [ - re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line) - for func_line in f1 - ] - - print(f"[DEBUG] processing second vicuna mlir") - second_vicuna_mlir = second_vicuna_mlir.splitlines() - while second_vicuna_mlir: - line = second_vicuna_mlir.pop(0) - if re.search("#map\d*\s*=", line): - maps2.append(line) - elif "global_seed" in line: - continue - elif re.search("arith.constant", line): - constants.add(line) - elif not re.search("module", line): - line = re.sub("forward", "second_vicuna_forward", line) - f2.append(line) - f2 = f2[:-1] - del second_vicuna_mlir - gc.collect() - - for i, map_line in enumerate(maps2): - map_var = map_line.split(" ")[0] - map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line) - maps2[i] = map_line - f2 = [ - re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line) - for func_line in f2 - ] - - module_start = 'module attributes {torch.debug_module_name = "_lambda"} {' - module_end = "}" - - global_vars = [] - vnames = [] - global_var_loading1 = [] - global_var_loading2 = [] - - print(f"[DEBUG] processing constants") - counter = 0 - constants = list(constants) - while constants: - constant = constants.pop(0) - vname, vbody = constant.split("=") - vname = re.sub("%", "", vname) - vname = vname.strip() - vbody = re.sub("arith.constant", "", vbody) - vbody = vbody.strip() - if len(vbody.split(":")) < 2: - print(constant) - vdtype = vbody.split(":")[-1].strip() - fixed_vdtype = vdtype - if "c1_i64" in vname: - print(constant) - counter += 1 - if counter == 2: - counter = 0 - print("detected duplicate") - continue - vnames.append(vname) - if "true" not in vname: - global_vars.append( - f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}" - ) - global_var_loading1.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" - ) - global_var_loading2.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" - ) - else: - global_vars.append( - f"ml_program.global private @{vname}({vbody}) : i1" - ) - global_var_loading1.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" - ) - global_var_loading2.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" - ) - - new_f1, new_f2 = [], [] - - print(f"[DEBUG] processing f1") - for line in f1: - if "func.func" in line: - new_f1.append(line) - for global_var in global_var_loading1: - new_f1.append(global_var) - else: - new_f1.append(line) - - print(f"[DEBUG] processing f2") - for line in f2: - if "func.func" in line: - new_f2.append(line) - for global_var in global_var_loading2: - if ( - "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" - in global_var - ): - print(global_var) - new_f2.append(global_var) - else: - new_f2.append(line) - - f1 = new_f1 - f2 = new_f2 - - del new_f1 - del new_f2 - gc.collect() - - print( - [ - "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x - for x in [maps1, maps2, global_vars, f1, f2] - ] - ) - - # doing it this way rather than assembling the whole string - # to prevent OOM with 64GiB RAM when encoding the file. - - print(f"[DEBUG] Saving mlir to {output_name}") - with open(output_name, "w+") as f_: - f_.writelines(line + "\n" for line in maps1) - f_.writelines(line + "\n" for line in maps2) - f_.writelines(line + "\n" for line in [module_start]) - f_.writelines(line + "\n" for line in global_vars) - f_.writelines(line + "\n" for line in f1) - f_.writelines(line + "\n" for line in f2) - f_.writelines(line + "\n" for line in [module_end]) - - del maps1 - del maps2 - del module_start - del global_vars - del f1 - del f2 - del module_end - gc.collect() - - if return_ir: - print(f"[DEBUG] Reading combined mlir back in") - with open(output_name, "rb") as f: - return f.read() - - -def write_in_dynamic_inputs0(module, dynamic_input_size): - print("[DEBUG] writing dynamic inputs to first vicuna") - # Current solution for ensuring mlir files support dynamic inputs - # TODO: find a more elegant way to implement this - new_lines = [] - module = module.splitlines() - while module: - line = module.pop(0) - line = re.sub(f"{dynamic_input_size}x", "?x", line) - if "?x" in line: - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - line = re.sub(f" {dynamic_input_size},", " %dim,", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line - ) - if "arith.cmpi" in line: - line = re.sub(f"c{dynamic_input_size}", "dim", line) - if "%0 = tensor.empty(%dim) : tensor" in line: - new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>") - if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line: - continue - - new_lines.append(line) - return "\n".join(new_lines) - - -def write_in_dynamic_inputs1(module, model_name, precision): - print("[DEBUG] writing dynamic inputs to second vicuna") - - def remove_constant_dim(line): - if "c19_i64" in line: - line = re.sub("c19_i64", "dim_i64", line) - if "19x" in line: - line = re.sub("19x", "?x", line) - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", - "tensor.empty(%dim, %dim)", - line, - ) - if "arith.cmpi" in line: - line = re.sub("c19", "dim", line) - if " 19," in line: - line = re.sub(" 19,", " %dim,", line) - if "x20x" in line or "<20x" in line: - line = re.sub("20x", "?x", line) - line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line) - if " 20," in line: - line = re.sub(" 20,", " %dimp1,", line) - return line - - module = module.splitlines() - new_lines = [] - - # Using a while loop and the pop method to avoid creating a copy of module - if "llama2_13b" in model_name: - pkv_tensor_shape = "tensor<1x40x?x128x" - elif "llama2_70b" in model_name: - pkv_tensor_shape = "tensor<1x8x?x128x" - else: - pkv_tensor_shape = "tensor<1x32x?x128x" - if precision in ["fp16", "int4", "int8"]: - pkv_tensor_shape += "f16>" - else: - pkv_tensor_shape += "f32>" - - while module: - line = module.pop(0) - if "%c19_i64 = arith.constant 19 : i64" in line: - new_lines.append("%c2 = arith.constant 2 : index") - new_lines.append( - f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}" - ) - new_lines.append( - "%dim_i64 = arith.index_cast %dim_4_int : index to i64" - ) - continue - if "%c2 = arith.constant 2 : index" in line: - continue - if "%c20_i64 = arith.constant 20 : i64" in line: - new_lines.append("%c1_i64 = arith.constant 1 : i64") - new_lines.append("%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64") - new_lines.append( - "%dimp1 = arith.index_cast %c20_i64 : i64 to index" - ) - continue - line = remove_constant_dim(line) - new_lines.append(line) - - return "\n".join(new_lines) - - -def save_dynamic_ir(ir_to_save, output_file): - if not ir_to_save: - return - # We only get string output from the dynamic conversion utility. - from contextlib import redirect_stdout - - with open(output_file, "w") as f: - with redirect_stdout(f): - print(ir_to_save) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser( - prog="llama ir utility", - description="\tThis script can be used as a standalone utility to convert IRs to dynamic + combine them.\n" - + "\tFollowing are the various ways this script can be used :-\n" - + "\t\ta. To convert a single Linalg IR to dynamic IR:\n" - + "\t\t\t--dynamic --first_ir_path=\n" - + "\t\tb. To convert two Linalg IRs to dynamic IR:\n" - + "\t\t\t--dynamic --first_ir_path= --first_ir_path=\n" - + "\t\tc. To combine two Linalg IRs into one:\n" - + "\t\t\t--combine --first_ir_path= --second_ir_path=\n" - + "\t\td. To convert both IRs into dynamic as well as combine the IRs:\n" - + "\t\t\t--dynamic --combine --first_ir_path= --second_ir_path=\n\n" - + "\tNOTE: For dynamic you'll also need to provide the following set of flags:-\n" - + "\t\t i. For First Llama : --dynamic_input_size (DEFAULT: 19)\n" - + "\t\tii. For Second Llama: --model_name (DEFAULT: llama2_7b)\n" - + "\t\t\t--precision (DEFAULT: 'int4')\n" - + "\t You may use --save_dynamic to also save the dynamic IR in option d above.\n" - + "\t Else for option a. and b. the dynamic IR(s) will get saved by default.\n", - formatter_class=RawTextHelpFormatter, - ) - parser.add_argument( - "--precision", - "-p", - default="int4", - choices=["fp32", "fp16", "int8", "int4"], - help="Precision of the concerned IR", - ) - parser.add_argument( - "--model_name", - type=str, - default="llama2_7b", - choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"], - help="Specify which model to run.", - ) - parser.add_argument( - "--first_ir_path", - default=None, - help="path to first llama mlir file", - ) - parser.add_argument( - "--second_ir_path", - default=None, - help="path to second llama mlir file", - ) - parser.add_argument( - "--dynamic_input_size", - type=int, - default=19, - help="Specify the static input size to replace with dynamic dim.", - ) - parser.add_argument( - "--dynamic", - default=False, - action=argparse.BooleanOptionalAction, - help="Converts the IR(s) to dynamic", - ) - parser.add_argument( - "--save_dynamic", - default=False, - action=argparse.BooleanOptionalAction, - help="Save the individual IR(s) after converting to dynamic", - ) - parser.add_argument( - "--combine", - default=False, - action=argparse.BooleanOptionalAction, - help="Converts the IR(s) to dynamic", - ) - - args, unknown = parser.parse_known_args() - - dynamic = args.dynamic - combine = args.combine - assert ( - dynamic or combine - ), "neither `dynamic` nor `combine` flag is turned on" - first_ir_path = args.first_ir_path - second_ir_path = args.second_ir_path - assert first_ir_path or second_ir_path, "no input ir has been provided" - if combine: - assert ( - first_ir_path and second_ir_path - ), "you will need to provide both IRs to combine" - precision = args.precision - model_name = args.model_name - dynamic_input_size = args.dynamic_input_size - save_dynamic = args.save_dynamic - - print(f"Dynamic conversion utility is turned {'ON' if dynamic else 'OFF'}") - print(f"Combining IR utility is turned {'ON' if combine else 'OFF'}") - - if dynamic and not combine: - save_dynamic = True - - first_ir = None - first_dynamic_ir_name = None - second_ir = None - second_dynamic_ir_name = None - if first_ir_path: - first_dynamic_ir_name = f"{Path(first_ir_path).stem}_dynamic" - with open(first_ir_path, "r") as f: - first_ir = f.read() - if second_ir_path: - second_dynamic_ir_name = f"{Path(second_ir_path).stem}_dynamic" - with open(second_ir_path, "r") as f: - second_ir = f.read() - if dynamic: - first_ir = ( - write_in_dynamic_inputs0(first_ir, dynamic_input_size) - if first_ir - else None - ) - second_ir = ( - write_in_dynamic_inputs1(second_ir, model_name, precision) - if second_ir - else None - ) - if save_dynamic: - save_dynamic_ir(first_ir, f"{first_dynamic_ir_name}.mlir") - save_dynamic_ir(second_ir, f"{second_dynamic_ir_name}.mlir") - - if combine: - combine_mlir_scripts( - first_ir, - second_ir, - f"{model_name}_{precision}.mlir", - return_ir=False, - ) diff --git a/apps/language_models/scripts/stablelm.py b/apps/language_models/scripts/stablelm.py deleted file mode 100644 index 98768073..00000000 --- a/apps/language_models/scripts/stablelm.py +++ /dev/null @@ -1,211 +0,0 @@ -import torch -import torch_mlir -from transformers import ( - AutoTokenizer, - StoppingCriteria, -) -from io import BytesIO -from pathlib import Path -from apps.language_models.utils import ( - get_torch_mlir_module_bytecode, - get_vmfb_from_path, -) - - -class StopOnTokens(StoppingCriteria): - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs - ) -> bool: - stop_ids = [50278, 50279, 50277, 1, 0] - for stop_id in stop_ids: - if input_ids[0][-1] == stop_id: - return True - return False - - -def shouldStop(tokens): - stop_ids = [50278, 50279, 50277, 1, 0] - for stop_id in stop_ids: - if tokens[0][-1] == stop_id: - return True - return False - - -MAX_SEQUENCE_LENGTH = 256 - - -def user(message, history): - # Append the user's message to the conversation history - return "", history + [[message, ""]] - - -def compile_stableLM( - model, - model_inputs, - model_name, - model_vmfb_name, - device="cuda", - precision="fp32", - debug=False, -): - from shark.shark_inference import SharkInference - - # device = "cuda" # "cpu" - # TODO: vmfb and mlir name should include precision and device - vmfb_path = ( - Path(model_name + f"_{device}.vmfb") - if model_vmfb_name is None - else Path(model_vmfb_name) - ) - shark_module = get_vmfb_from_path( - vmfb_path, device, mlir_dialect="tm_tensor" - ) - if shark_module is not None: - return shark_module - - mlir_path = Path(model_name + ".mlir") - print( - f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}" - ) - if mlir_path.exists(): - with open(mlir_path, "rb") as f: - bytecode = f.read() - else: - ts_graph = get_torch_mlir_module_bytecode(model, model_inputs) - module = torch_mlir.compile( - ts_graph, - [*model_inputs], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - f_ = open(model_name + ".mlir", "wb") - f_.write(bytecode) - print("Saved mlir") - f_.close() - - shark_module = SharkInference( - mlir_module=bytecode, device=device, mlir_dialect="tm_tensor" - ) - shark_module.compile() - - path = shark_module.save_module( - vmfb_path.parent.absolute(), vmfb_path.stem, debug=debug - ) - print("Saved vmfb at ", str(path)) - - return shark_module - - -class StableLMModel(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, input_ids, attention_mask): - combine_input_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - output = self.model(**combine_input_dict) - return output.logits - - -# Initialize a StopOnTokens object -system_prompt = """<|SYSTEM|># StableLM Tuned (Alpha version) -- StableLM is a helpful and harmless open-source AI language model developed by StabilityAI. -- StableLM is excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user. -- StableLM is more than just an information source, StableLM is also able to write poetry, short stories, and make jokes. -- StableLM will refuse to participate in anything that could harm a human. -""" - - -def get_tokenizer(): - model_path = "stabilityai/stablelm-tuned-alpha-3b" - tok = AutoTokenizer.from_pretrained(model_path) - tok.add_special_tokens({"pad_token": ""}) - print("Sucessfully loaded the tokenizer to the memory") - return tok - - -# sharkStableLM = compile_stableLM -# ( -# None, -# tuple([input_ids, attention_mask]), -# "stableLM_linalg_f32_seqLen256", -# "/home/shark/vivek/stableLM_shark_f32_seqLen256" -# ) -def generate( - new_text, - max_new_tokens, - sharkStableLM, - tokenizer=None, -): - if tokenizer is None: - tokenizer = get_tokenizer() - # Construct the input message string for the model by - # concatenating the current system message and conversation history - # Tokenize the messages string - # sharkStableLM = compile_stableLM - # ( - # None, - # tuple([input_ids, attention_mask]), - # "stableLM_linalg_f32_seqLen256", - # "/home/shark/vivek/stableLM_shark_f32_seqLen256" - # ) - words_list = [] - for i in range(max_new_tokens): - # numWords = len(new_text.split()) - # if(numWords>220): - # break - params = { - "new_text": new_text, - } - generated_token_op = generate_new_token( - sharkStableLM, tokenizer, params - ) - detok = generated_token_op["detok"] - stop_generation = generated_token_op["stop_generation"] - if stop_generation: - break - print(detok, end="", flush=True) - words_list.append(detok) - if detok == "": - break - new_text = new_text + detok - return words_list - - -def generate_new_token(shark_model, tokenizer, params): - new_text = params["new_text"] - model_inputs = tokenizer( - [new_text], - padding="max_length", - max_length=MAX_SEQUENCE_LENGTH, - truncation=True, - return_tensors="pt", - ) - sum_attentionmask = torch.sum(model_inputs.attention_mask) - # sharkStableLM = compile_stableLM(None, tuple([input_ids, attention_mask]), "stableLM_linalg_f32_seqLen256", "/home/shark/vivek/stableLM_shark_f32_seqLen256") - output = shark_model( - "forward", [model_inputs.input_ids, model_inputs.attention_mask] - ) - output = torch.from_numpy(output) - next_toks = torch.topk(output, 1) - stop_generation = False - if shouldStop(next_toks.indices): - stop_generation = True - new_token = next_toks.indices[0][int(sum_attentionmask) - 1] - detok = tokenizer.decode( - new_token, - skip_special_tokens=True, - ) - ret_dict = { - "new_token": new_token, - "detok": detok, - "stop_generation": stop_generation, - } - return ret_dict diff --git a/apps/language_models/scripts/vicuna.py b/apps/language_models/scripts/vicuna.py deleted file mode 100644 index 548dcb83..00000000 --- a/apps/language_models/scripts/vicuna.py +++ /dev/null @@ -1,2480 +0,0 @@ -import argparse -import json -import re -import gc -from io import BytesIO -from pathlib import Path -from statistics import mean, stdev -from tqdm import tqdm -from typing import List, Tuple -import subprocess -import sys -import time -from dataclasses import dataclass -from os import environ -from dataclasses import dataclass -from os import environ - -import torch -import torch_mlir -from torch_mlir import TensorPlaceholder -from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from transformers import AutoTokenizer, AutoModelForCausalLM - -from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase -from apps.language_models.src.model_wrappers.vicuna_sharded_model import ( - FirstVicunaLayer, - SecondVicunaLayer, - CompiledVicunaLayer, - ShardedVicunaModel, - LMHead, - LMHeadCompiled, - VicunaEmbedding, - VicunaEmbeddingCompiled, - VicunaNorm, - VicunaNormCompiled, -) -from apps.language_models.src.model_wrappers.vicuna4 import ( - LlamaModel, - EightLayerLayerSV, - EightLayerLayerFV, - CompiledEightLayerLayerSV, - CompiledEightLayerLayer, - forward_compressed, -) -from apps.language_models.src.model_wrappers.vicuna_model import ( - FirstVicuna, - SecondVicuna7B, - SecondVicuna13B, - SecondVicuna70B, -) -from apps.language_models.src.model_wrappers.vicuna_model_gpu import ( - FirstVicunaGPU, - SecondVicuna7BGPU, - SecondVicuna13BGPU, - SecondVicuna70BGPU, -) -from apps.language_models.utils import ( - get_vmfb_from_path, -) -from shark.shark_downloader import download_public_file -from shark.shark_importer import get_f16_inputs -from shark.shark_importer import import_with_fx, save_mlir -from shark.shark_inference import SharkInference - - -parser = argparse.ArgumentParser( - prog="vicuna runner", - description="runs a vicuna model", -) -parser.add_argument( - "--precision", "-p", default="int8", help="fp32, fp16, int8, int4" -) -parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") -parser.add_argument( - "--vicuna_vmfb_path", default=None, help="path to vicuna vmfb" -) -parser.add_argument( - "-s", - "--sharded", - default=False, - action=argparse.BooleanOptionalAction, - help="Run model as sharded", -) -# TODO: sharded config -parser.add_argument( - "--vicuna_mlir_path", - default=None, - help="path to vicuna mlir file", -) -parser.add_argument( - "--load_mlir_from_shark_tank", - default=False, - action=argparse.BooleanOptionalAction, - help="download precompile mlir from shark tank", -) -parser.add_argument( - "--cli", - default=False, - action=argparse.BooleanOptionalAction, - help="Run model in cli mode", -) -parser.add_argument( - "--config", - default=None, - help="configuration file", -) -parser.add_argument( - "--weight-group-size", - type=int, - default=128, - help="Group size for per_group weight quantization. Default: 128.", -) - -parser.add_argument( - "--n_devices", type=int, default=None, help="Number of GPUs to use" -) - -parser.add_argument( - "--download_vmfb", - default=False, - action=argparse.BooleanOptionalAction, - help="Download vmfb from sharktank, system dependent, YMMV", -) -parser.add_argument( - "--model_name", - type=str, - default="vicuna", - choices=["vicuna", "llama2_7b", "llama2_13b", "llama2_70b"], - help="Specify which model to run.", -) -parser.add_argument( - "--hf_auth_token", - type=str, - default=None, - help="Specify your own huggingface authentication tokens for models like Llama2.", -) -parser.add_argument( - "--cache_vicunas", - default=False, - action=argparse.BooleanOptionalAction, - help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk", -) -parser.add_argument( - "--iree_vulkan_target_triple", - type=str, - default="", - help="Specify target triple for vulkan.", -) -parser.add_argument( - "--Xiree_compile", - action="append", - default=[], - help="Extra command line arguments passed to the IREE compiler. This can be specified multiple times to pass multiple arguments.", -) - -# Microbenchmarking options. -parser.add_argument( - "--enable_microbenchmark", - default=False, - action=argparse.BooleanOptionalAction, - help="Enables the microbenchmarking mode (non-interactive). Uses the system and the user prompt from args.", -) -parser.add_argument( - "--microbenchmark_iterations", - type=int, - default=5, - help="Number of microbenchmark iterations. Default: 5.", -) -parser.add_argument( - "--microbenchmark_num_tokens", - type=int, - default=512, - help="Generate an exact number of output tokens. Default: 512.", -) -parser.add_argument( - "--system_prompt", - type=str, - default="", - help="Specify the system prompt. This is only used with `--enable_microbenchmark`", -) -parser.add_argument( - "--enable_tracing", - default=False, - action=argparse.BooleanOptionalAction, - help="Enable profiling with Tracy. The script will wait for Tracy to connect and flush the profiling data after each token." -) -parser.add_argument( - "--user_prompt", - type=str, - default="Hi", - help="Specify the user prompt. This is only used with `--enable_microbenchmark`", -) - -# fmt: off -def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: - if len(lhs) == 3 and len(rhs) == 2: - return [lhs[0], lhs[1], rhs[0]] - elif len(lhs) == 2 and len(rhs) == 2: - return [lhs[0], rhs[0]] - else: - raise ValueError("Input shapes not supported.") - - -def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: - # output dtype is the dtype of the lhs float input - lhs_rank, lhs_dtype = lhs_rank_dtype - return lhs_dtype - - -def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: - return - - -brevitas_matmul_rhs_group_quant_library = [ - quant〇matmul_rhs_group_quant〡shape, - quant〇matmul_rhs_group_quant〡dtype, - quant〇matmul_rhs_group_quant〡has_value_semantics] -# fmt: on - - -class VicunaBase(SharkLLMBase): - def __init__( - self, - model_name, - hf_model_path="TheBloke/vicuna-7B-1.1-HF", - max_num_tokens=512, - device="cpu", - precision="int8", - extra_args_cmd=[], - ) -> None: - super().__init__(model_name, hf_model_path, max_num_tokens) - self.max_sequence_length = 256 - self.device = device - self.precision = precision - self.extra_args = extra_args_cmd - - def get_tokenizer(self): - # Retrieve the tokenizer from Huggingface - tokenizer = AutoTokenizer.from_pretrained( - self.hf_model_path, use_fast=False - ) - return tokenizer - - def get_src_model(self): - # Retrieve the torch model from Huggingface - kwargs = {"torch_dtype": torch.float} - vicuna_model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, **kwargs - ) - return vicuna_model - - def combine_mlir_scripts( - self, - first_vicuna_mlir, - second_vicuna_mlir, - output_name, - ): - print(f"[DEBUG] combining first and second mlir") - print(f"[DEBUG] output_name = {output_name}") - maps1 = [] - maps2 = [] - constants = set() - f1 = [] - f2 = [] - - print(f"[DEBUG] processing first vicuna mlir") - first_vicuna_mlir = first_vicuna_mlir.splitlines() - while first_vicuna_mlir: - line = first_vicuna_mlir.pop(0) - if re.search("#map\d*\s*=", line): - maps1.append(line) - elif re.search("arith.constant", line): - constants.add(line) - elif not re.search("module", line): - line = re.sub("forward", "first_vicuna_forward", line) - f1.append(line) - f1 = f1[:-1] - del first_vicuna_mlir - gc.collect() - - for i, map_line in enumerate(maps1): - map_var = map_line.split(" ")[0] - map_line = re.sub(f"{map_var}(?!\d)", map_var + "_0", map_line) - maps1[i] = map_line - f1 = [ - re.sub(f"{map_var}(?!\d)", map_var + "_0", func_line) - for func_line in f1 - ] - - print(f"[DEBUG] processing second vicuna mlir") - second_vicuna_mlir = second_vicuna_mlir.splitlines() - while second_vicuna_mlir: - line = second_vicuna_mlir.pop(0) - if re.search("#map\d*\s*=", line): - maps2.append(line) - elif "global_seed" in line: - continue - elif re.search("arith.constant", line): - constants.add(line) - elif not re.search("module", line): - line = re.sub("forward", "second_vicuna_forward", line) - f2.append(line) - f2 = f2[:-1] - del second_vicuna_mlir - gc.collect() - - for i, map_line in enumerate(maps2): - map_var = map_line.split(" ")[0] - map_line = re.sub(f"{map_var}(?!\d)", map_var + "_1", map_line) - maps2[i] = map_line - f2 = [ - re.sub(f"{map_var}(?!\d)", map_var + "_1", func_line) - for func_line in f2 - ] - - module_start = ( - 'module attributes {torch.debug_module_name = "_lambda"} {' - ) - module_end = "}" - - global_vars = [] - vnames = [] - global_var_loading1 = [] - global_var_loading2 = [] - - print(f"[DEBUG] processing constants") - counter = 0 - constants = list(constants) - while constants: - constant = constants.pop(0) - vname, vbody = constant.split("=") - vname = re.sub("%", "", vname) - vname = vname.strip() - vbody = re.sub("arith.constant", "", vbody) - vbody = vbody.strip() - if len(vbody.split(":")) < 2: - print(constant) - vdtype = vbody.split(":")[-1].strip() - fixed_vdtype = vdtype - noinline = "{noinline}" if "tensor" in fixed_vdtype else "" - if "c1_i64" in vname: - print(constant) - counter += 1 - if counter == 2: - counter = 0 - print("detected duplicate") - continue - vnames.append(vname) - if "true" not in vname: - global_vars.append( - f"ml_program.global private @{vname}({vbody}) : {fixed_vdtype}" - ) - global_var_loading1.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" - ) - global_var_loading2.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : {fixed_vdtype}" - ) - else: - global_vars.append( - f"ml_program.global private @{vname}({vbody}) : i1" - ) - global_var_loading1.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" - ) - global_var_loading2.append( - f"\t\t%{vname} = ml_program.global_load_const @{vname} : i1" - ) - - new_f1, new_f2 = [], [] - - print(f"[DEBUG] processing f1") - for line in f1: - if "func.func" in line: - new_f1.append(line) - for global_var in global_var_loading1: - new_f1.append(global_var) - else: - new_f1.append(line) - - print(f"[DEBUG] processing f2") - for line in f2: - if "func.func" in line: - new_f2.append(line) - for global_var in global_var_loading2: - if ( - "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" - in global_var - ): - print(global_var) - new_f2.append(global_var) - else: - new_f2.append(line) - - f1 = new_f1 - f2 = new_f2 - - del new_f1 - del new_f2 - gc.collect() - - print( - [ - "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x - for x in [maps1, maps2, global_vars, f1, f2] - ] - ) - - # doing it this way rather than assembling the whole string - # to prevent OOM with 64GiB RAM when encoding the file. - - print(f"[DEBUG] Saving mlir to {output_name}") - with open(output_name, "w+") as f_: - f_.writelines(line + "\n" for line in maps1) - f_.writelines(line + "\n" for line in maps2) - f_.writelines(line + "\n" for line in [module_start]) - f_.writelines(line + "\n" for line in global_vars) - f_.writelines(line + "\n" for line in f1) - f_.writelines(line + "\n" for line in f2) - f_.writelines(line + "\n" for line in [module_end]) - - del maps1 - del maps2 - del module_start - del global_vars - del f1 - del f2 - del module_end - gc.collect() - - print(f"[DEBUG] Reading combined mlir back in") - with open(output_name, "rb") as f: - return f.read() - - def generate_new_token(self, params, sharded=True, cli=True): - is_first = params["is_first"] - if is_first: - prompt = params["prompt"] - input_ids = self.tokenizer(prompt).input_ids - input_id_len = len(input_ids) - input_ids = torch.tensor(input_ids) - input_ids = input_ids.reshape([1, input_id_len]) - if sharded: - output = self.shark_model.forward(input_ids, is_first=is_first) - else: - output = self.shark_model( - "first_vicuna_forward", (input_ids,), send_to_host=False - ) - - else: - token = params["token"] - past_key_values = params["past_key_values"] - input_ids = [token] - input_id_len = len(input_ids) - input_ids = torch.tensor(input_ids) - input_ids = input_ids.reshape([1, input_id_len]) - if sharded: - output = self.shark_model.forward( - input_ids, - past_key_values=past_key_values, - is_first=is_first, - ) - else: - token = torch.tensor(token).reshape([1, 1]) - second_input = (token,) + tuple(past_key_values) - output = self.shark_model( - "second_vicuna_forward", second_input, send_to_host=False - ) - - if sharded: - _logits = output["logits"] - _past_key_values = output["past_key_values"] - _token = int(torch.argmax(_logits[:, -1, :], dim=1)[0]) - elif "cpu" in self.device: - _past_key_values = output[1:] - _token = int(output[0].to_host()) - else: - _logits = torch.tensor(output[0].to_host()) - _past_key_values = output[1:] - _token = torch.argmax(_logits[:, -1, :], dim=1) - - _detok = self.tokenizer.decode(_token, skip_special_tokens=False) - ret_dict = { - "token": _token, - "detok": _detok, - "past_key_values": _past_key_values, - } - if "cpu" not in self.device: - ret_dict["logits"] = _logits - - if cli: - print(f" token : {_token} | detok : {_detok}") - - return ret_dict - - -class ShardedVicuna(VicunaBase): - # Class representing Sharded Vicuna Model - def __init__( - self, - model_name, - hf_model_path="TheBloke/vicuna-7B-1.1-HF", - hf_auth_token=None, - max_num_tokens=512, - device="cuda", - precision="fp32", - config_json=None, - weight_group_size=128, - compressed=False, - extra_args_cmd=[], - debug=False, - n_devices=None, - ) -> None: - self.hf_auth_token = hf_auth_token - self.hidden_state_size_dict = {"vicuna": 4096, "llama2_7b": 4096, "llama2_13b" : 5120, "llama2_70b" : 8192} - self.n_layers_dict = {"vicuna": 32, "llama2_7b": 32, "llama2_13b" : 40, "llama2_70b" : 80} - super().__init__( - model_name, - hf_model_path, - max_num_tokens, - extra_args_cmd=extra_args_cmd, - ) - self.max_sequence_length = 256 - self.device = device - self.precision = precision - self.debug = debug - self.tokenizer = self.get_tokenizer() - self.config = config_json - self.weight_group_size = weight_group_size - self.compressed = compressed - self.n_devices = n_devices - self.dir_name = f"{model_name}-{precision}-{device}-models" - self.dir_path = Path(self.dir_name) - if not self.dir_path.is_dir(): - self.dir_path.mkdir(parents=True, exist_ok=True) - self.shark_model = self.compile(device=device) - - def check_all_artifacts_present(self): - file_list = [f"{i}_full" for i in range(self.n_layers_dict[self.model_name])] + ["norm", "embedding", "lmhead"] - file_exists_list = [Path(f"{self.dir_name}/{x}.vmfb").exists() or Path(f"{self.dir_name}/{x}.mlir").exists() for x in file_list] - return all(file_exists_list) - - def get_tokenizer(self): - kwargs = {} - if "llama2" in self.model_name: - kwargs = {"use_auth_token": self.hf_auth_token} - tokenizer = AutoTokenizer.from_pretrained( - self.hf_model_path, - use_fast=False, - **kwargs, - ) - return tokenizer - - def get_src_model(self): - # Retrieve the torch model from Huggingface - kwargs = {"torch_dtype": torch.float} - if "llama2" in self.model_name: - kwargs["use_auth_token"] = self.hf_auth_token - vicuna_model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, - **kwargs, - ) - return vicuna_model - - def write_in_dynamic_inputs0(self, module, dynamic_input_size): - # Current solution for ensuring mlir files support dynamic inputs - # TODO find a more elegant way to implement this - new_lines = [] - for line in module.splitlines(): - line = re.sub(f"{dynamic_input_size}x", "?x", line) - if "?x" in line: - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - line = re.sub(f" {dynamic_input_size},", " %dim,", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line - ) - if "arith.cmpi" in line: - line = re.sub(f"c{dynamic_input_size}", "dim", line) - new_lines.append(line) - new_module = "\n".join(new_lines) - return new_module - - def write_in_dynamic_inputs1(self, module, dynamic_input_size): - if self.precision == "fp32": - fprecision = "32" - else: - fprecision = "16" - new_lines = [] - for line in module.splitlines(): - if "dim_42 =" in line: - continue - if f"%c{dynamic_input_size}_i64 =" in line: - new_lines.append( - f"%dim_42 = tensor.dim %arg1, %c3 : tensor<1x1x1x?xf{fprecision}>" - ) - new_lines.append( - f"%dim_42_i64 = arith.index_cast %dim_42 : index to i64" - ) - continue - line = re.sub(f"{dynamic_input_size}x", "?x", line) - line = re.sub(f"%c{dynamic_input_size}_i64", "%dim_42_i64", line) - if "?x" in line: - line = re.sub( - "tensor.empty\(\)", "tensor.empty(%dim_42)", line - ) - line = re.sub(f" {dynamic_input_size},", " %dim_42,", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim_42\)", - "tensor.empty(%dim_42, %dim_42)", - line, - ) - if "arith.cmpi" in line: - line = re.sub(f"c{dynamic_input_size}", "dim_42", line) - new_lines.append(line) - new_module = "\n".join(new_lines) - return new_module - - def compile_vicuna_layer( - self, - vicuna_layer, - hidden_states, - attention_mask, - position_ids, - past_key_value0=None, - past_key_value1=None, - ): - # Compile a hidden decoder layer of vicuna - if past_key_value0 is None and past_key_value1 is None: - model_inputs = (hidden_states, attention_mask, position_ids) - else: - model_inputs = ( - hidden_states, - attention_mask, - position_ids, - past_key_value0, - past_key_value1, - ) - is_f16 = self.precision in ["fp16", "int4"] - mlir_bytecode = import_with_fx( - vicuna_layer, - model_inputs, - is_f16=is_f16, - precision=self.precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - return mlir_bytecode - - def compile_vicuna_layer4( - self, - vicuna_layer, - hidden_states, - attention_mask, - position_ids, - past_key_values=None, - ): - # Compile a hidden decoder layer of vicuna - if past_key_values is None: - model_inputs = (hidden_states, attention_mask, position_ids) - else: - ( - (pkv00, pkv01), - (pkv10, pkv11), - (pkv20, pkv21), - (pkv30, pkv31), - (pkv40, pkv41), - (pkv50, pkv51), - (pkv60, pkv61), - (pkv70, pkv71), - ) = past_key_values - - model_inputs = ( - hidden_states, - attention_mask, - position_ids, - pkv00, - pkv01, - pkv10, - pkv11, - pkv20, - pkv21, - pkv30, - pkv31, - pkv40, - pkv41, - pkv50, - pkv51, - pkv60, - pkv61, - pkv70, - pkv71, - ) - is_f16 = self.precision in ["fp16", "int4"] - mlir_bytecode = import_with_fx( - vicuna_layer, - model_inputs, - is_f16=is_f16, - precision=self.precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - return mlir_bytecode - - def get_device_index(self, layer_string): - # Get the device index from the config file - # In the event that different device indices are assigned to - # different parts of a layer, a majority vote will be taken and - # everything will be run on the most commonly used device - if self.config is None: - return None - idx_votes = {} - for key in self.config.keys(): - if re.search(layer_string, key): - if int(self.config[key]["gpu"]) in idx_votes.keys(): - idx_votes[int(self.config[key]["gpu"])] += 1 - else: - idx_votes[int(self.config[key]["gpu"])] = 1 - device_idx = max(idx_votes, key=idx_votes.get) - return device_idx - - - def write_dynamic_inputs_lmhead(self, ir, sample_input_length): - if self.precision in ["fp16", "int4"]: - precision_str = "f16" - else: - precision_str = "f32" - lines = ir.splitlines() - new_lines = [] - for line in lines: - if f"%cst_0 =" in line: - new_lines.append(line) - new_lines.append("%c1 = arith.constant 1 : index") - new_lines.append(f"%dim = tensor.dim %arg0, %c1 : tensor<1x?x{self.hidden_state_size_dict[self.model_name]}x{precision_str}>") - else: - line = re.sub(f"{sample_input_length}x", "?x", line) - if "?x" in line: - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - new_lines.append(line) - - return "\n".join(new_lines) - - def compile_lmhead( - self, - lmh, - hidden_states, - device="cpu", - device_idx=None, - ): - # compile the lm head of the vicuna model - # This can be used for both first and second vicuna, so only needs to be run once - mlir_path = Path(f"{self.dir_name}/lmhead.mlir") - vmfb_path = Path(f"{self.dir_name}/lmhead.vmfb") - if mlir_path.exists(): - print(f"Found bytecode module at {mlir_path}.") - else: - # hidden_states = torch_mlir.TensorPlaceholder.like( - # hidden_states, dynamic_axes=[1] - # ) - - is_f16 = self.precision in ["fp16", "int4"] - if is_f16: - ts_graph = import_with_fx( - lmh, - (hidden_states,), - is_f16=is_f16, - precision=self.precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - - if is_f16: - hidden_states = hidden_states.to(torch.float16) - - hidden_states = torch_mlir.TensorPlaceholder.like( - hidden_states, dynamic_axes=[1] - ) - - module = torch_mlir.compile( - ts_graph, - (hidden_states,), - output_type="torch", - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - - else: - hidden_states = torch_mlir.TensorPlaceholder.like( - hidden_states, dynamic_axes=[1] - ) - module = torch_mlir.compile( - lmh, - (hidden_states,), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - """ - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - f_ = open(mlir_path, "wb") - f_.write(bytecode) - f_.close() - """ - module = str(module) - if self.precision in ["int4", "fp16"]: - module = self.write_dynamic_inputs_lmhead(module, 137) - filepath = Path(f"{self.dir_name}/lmhead.mlir") - f_ = open(mlir_path, "w+") - f_.write(module) - f_.close() - # download_public_file( - # "gs://shark_tank/elias/compressed_sv/lmhead.mlir", - # filepath.absolute(), - # single_file=True, - # ) - mlir_path = filepath - - shark_module = SharkInference( - mlir_path, - device=device, - mlir_dialect="tm_tensor", - device_idx=device_idx, - mmap=True, - ) - if vmfb_path.exists(): - shark_module.load_module(vmfb_path) - else: - shark_module.save_module( - module_name=f"{self.dir_name}/lmhead", debug=self.debug - ) - shark_module.load_module(vmfb_path) - compiled_module = LMHeadCompiled(shark_module) - return compiled_module - - def compile_norm(self, fvn, hidden_states, device="cpu", device_idx=None): - # compile the normalization layer of the vicuna model - # This can be used for both first and second vicuna, so only needs to be run once - mlir_path = Path(f"{self.dir_name}/norm.mlir") - vmfb_path = Path(f"{self.dir_name}/norm.vmfb") - if mlir_path.exists(): - print(f"Found bytecode module at {mlir_path}.") - else: - # hidden_states = torch_mlir.TensorPlaceholder.like( - # hidden_states, dynamic_axes=[1] - # ) - - is_f16 = self.precision in ["fp16", "int4"] - if is_f16: - ts_graph = import_with_fx( - fvn, - (hidden_states,), - is_f16=is_f16, - precision=self.precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - - if is_f16: - hidden_states = hidden_states.to(torch.float16) - - hidden_states = torch_mlir.TensorPlaceholder.like( - hidden_states, dynamic_axes=[1] - ) - - module = torch_mlir.compile( - ts_graph, - (hidden_states,), - output_type="torch", - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - else: - hidden_states = torch_mlir.TensorPlaceholder.like( - hidden_states, dynamic_axes=[1] - ) - module = torch_mlir.compile( - fvn, - (hidden_states,), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - f_ = open(mlir_path, "wb") - f_.write(bytecode) - f_.close() - filepath = Path(f"{self.dir_name}/norm.mlir") - # download_public_file( - # "gs://shark_tank/elias/compressed_sv/norm.mlir", - # filepath.absolute(), - # single_file=True, - # ) - mlir_path = filepath - - shark_module = SharkInference( - mlir_path, - device=device, - mlir_dialect="tm_tensor", - device_idx=device_idx, - mmap=True, - ) - if vmfb_path.exists(): - shark_module.load_module(vmfb_path) - else: - shark_module.save_module( - module_name=f"{self.dir_name}/norm", debug=self.debug - ) - shark_module.load_module(vmfb_path) - compiled_module = VicunaNormCompiled(shark_module) - return compiled_module - - def compile_embedding(self, fve, input_ids, device="cpu", device_idx=None): - # compile the embedding layer of the vicuna model - # This can be used for both first and second vicuna, so only needs to be run once - mlir_path = Path(f"{self.dir_name}/embedding.mlir") - vmfb_path = Path(f"{self.dir_name}/embedding.vmfb") - if mlir_path.exists(): - print(f"Found bytecode module at {mlir_path}.") - else: - is_f16 = self.precision in ["fp16", "int4"] - if is_f16: - # input_ids = torch_mlir.TensorPlaceholder.like( - # input_ids, dynamic_axes=[1] - # ) - ts_graph = import_with_fx( - fve, - (input_ids,), - is_f16=is_f16, - precision=self.precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - input_ids = torch_mlir.TensorPlaceholder.like( - input_ids, dynamic_axes=[1] - ) - module = torch_mlir.compile( - ts_graph, - (input_ids,), - output_type="torch", - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - else: - input_ids = torch_mlir.TensorPlaceholder.like( - input_ids, dynamic_axes=[1] - ) - module = torch_mlir.compile( - fve, - (input_ids,), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - f_ = open(mlir_path, "wb") - f_.write(bytecode) - f_.close() - filepath = Path(f"{self.dir_name}/embedding.mlir") - # download_public_file( - # "gs://shark_tank/elias/compressed_sv/embedding.mlir", - # filepath.absolute(), - # single_file=True, - # ) - mlir_path = filepath - - shark_module = SharkInference( - mlir_path, - device=device, - mlir_dialect="tm_tensor", - device_idx=device_idx, - mmap=True, - ) - if vmfb_path.exists(): - shark_module.load_module(vmfb_path) - else: - shark_module.save_module( - module_name=f"{self.dir_name}/embedding", debug=self.debug - ) - shark_module.load_module(vmfb_path) - compiled_module = VicunaEmbeddingCompiled(shark_module) - - return compiled_module - - def compile_to_vmfb_one_model( - self, - inputs0, - layers0, - inputs1, - layers1, - device="cpu", - ): - if self.precision != "fp32": - inputs0 = tuple( - inpt.to(torch.float16) if inpt.dtype == torch.float32 else inpt - for inpt in inputs0 - ) - inputs1 = tuple( - inpt.to(torch.float16) if inpt.dtype == torch.float32 else inpt - for inpt in inputs1 - ) - mlirs, modules = [], [] - assert len(layers0) == len(layers1) - for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))): - mlir_path = Path(f"{self.dir_name}/{idx}_full.mlir") - vmfb_path = Path(f"{self.dir_name}/{idx}_full.vmfb") - # if vmfb_path.exists(): - # continue - if mlir_path.exists(): - f_ = open(mlir_path, "rb") - bytecode = f_.read() - f_.close() - mlirs.append(bytecode) - else: - hidden_states_placeholder0 = TensorPlaceholder.like( - inputs0[0], dynamic_axes=[1] - ) - attention_mask_placeholder0 = TensorPlaceholder.like( - inputs0[1], dynamic_axes=[3] - ) - position_ids_placeholder0 = TensorPlaceholder.like( - inputs0[2], dynamic_axes=[1] - ) - hidden_states_placeholder1 = TensorPlaceholder.like( - inputs1[0], dynamic_axes=[1] - ) - attention_mask_placeholder1 = TensorPlaceholder.like( - inputs1[1], dynamic_axes=[3] - ) - position_ids_placeholder1 = TensorPlaceholder.like( - inputs1[2], dynamic_axes=[1] - ) - pkv0_placeholder = TensorPlaceholder.like( - inputs1[3], dynamic_axes=[2] - ) - pkv1_placeholder = TensorPlaceholder.like( - inputs1[4], dynamic_axes=[2] - ) - - print(f"Compiling layer {idx} mlir") - ts_g = self.compile_vicuna_layer( - layer0, inputs0[0], inputs0[1], inputs0[2] - ) - if self.precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - hidden_states_placeholder0 = TensorPlaceholder.like( - inputs0[0], dynamic_axes=[1] - ) - attention_mask_placeholder0 = TensorPlaceholder.like( - inputs0[1], dynamic_axes=[3] - ) - position_ids_placeholder0 = TensorPlaceholder.like( - inputs0[2], dynamic_axes=[1] - ) - hidden_states_placeholder1 = TensorPlaceholder.like( - inputs1[0], dynamic_axes=[1] - ) - attention_mask_placeholder1 = TensorPlaceholder.like( - inputs1[1], dynamic_axes=[3] - ) - position_ids_placeholder1 = TensorPlaceholder.like( - inputs1[2], dynamic_axes=[1] - ) - pkv0_placeholder = TensorPlaceholder.like( - inputs1[3], dynamic_axes=[2] - ) - pkv1_placeholder = TensorPlaceholder.like( - inputs1[4], dynamic_axes=[2] - ) - - module0 = torch_mlir.compile( - ts_g, - ( - hidden_states_placeholder0, - inputs0[1], - inputs0[2], - ), - output_type="torch", - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - module0, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - else: - module0 = torch_mlir.compile( - ts_g, - ( - hidden_states_placeholder0, - inputs0[1], - inputs0[2], - ), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - module0 = self.write_in_dynamic_inputs0(str(module0), 137) - - ts_g = self.compile_vicuna_layer( - layer1, - inputs1[0], - inputs1[1], - inputs1[2], - inputs1[3], - inputs1[4], - ) - if self.precision in ["int4", "int8"]: - module1 = torch_mlir.compile( - ts_g, - ( - inputs1[0], - attention_mask_placeholder1, - inputs1[2], - pkv0_placeholder, - pkv1_placeholder, - ), - output_type="torch", - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - module1, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - else: - module1 = torch_mlir.compile( - ts_g, - ( - inputs1[0], - attention_mask_placeholder1, - inputs1[2], - pkv0_placeholder, - pkv1_placeholder, - ), - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - module1 = self.write_in_dynamic_inputs1(str(module1), 138) - - module_combined = self.combine_mlir_scripts( - module0, module1, f"{self.dir_name}/{idx}_full.mlir" - ) - mlirs.append(module_combined) - - if vmfb_path.exists(): - device_idx = self.get_device_index( - f"first_vicuna.model.model.layers.{idx}[\s.$]" - ) - if device_idx is None: - if self.n_devices is not None: - device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name] - else: - device_idx = None - module = SharkInference( - None, - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - mmap=True, - ) - module.load_module(vmfb_path) - else: - print(f"Compiling layer {idx} vmfb") - device_idx = self.get_device_index( - f"first_vicuna.model.model.layers.{idx}[\s.$]" - ) - if device_idx is None: - if self.n_devices is not None: - device_idx = (idx * self.n_devices) // self.n_layers_dict[self.model_name] - else: - device_idx = None - module = SharkInference( - mlirs[idx], - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - mmap=True, - ) - module.save_module( - module_name=f"{self.dir_name}/{idx}_full", - extra_args=[ - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - + self.extra_args, - debug=self.debug, - ) - module.load_module(vmfb_path) - modules.append(module) - return mlirs, modules - - def compile_to_vmfb_one_model4( - self, inputs0, layers0, inputs1, layers1, device="cpu" - ): - mlirs, modules = [], [] - assert len(layers0) == len(layers1) - for layer0, layer1, idx in zip(layers0, layers1, range(len(layers0))): - mlir_path = Path(f"{idx}_full.mlir") - vmfb_path = Path(f"{idx}_full.vmfb") - # if vmfb_path.exists(): - # continue - if mlir_path.exists(): - f_ = open(mlir_path, "rb") - bytecode = f_.read() - f_.close() - mlirs.append(bytecode) - else: - filepath = Path(f"{idx}_full.mlir") - download_public_file( - f"gs://shark_tank/elias/compressed_sv/{idx}_full.mlir", - filepath.absolute(), - single_file=True, - ) - - f_ = open(f"{idx}_full.mlir", "rb") - bytecode = f_.read() - f_.close() - mlirs.append(bytecode) - - if vmfb_path.exists(): - device_idx = self.get_device_index( - f"first_vicuna.model.model.layers.{idx}[\s.$]" - ) - if device_idx is None: - if self.n_devices is not None: - device_idx = idx % self.n_devices - else: - device_idx = None - module = SharkInference( - None, - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - mmap=True, - ) - module.load_module(vmfb_path) - else: - print(f"Compiling layer {idx} vmfb") - device_idx = self.get_device_index( - f"first_vicuna.model.model.layers.{idx}[\s.$]" - ) - if device_idx is None: - if self.n_devices is not None: - device_idx = idx % self.n_devices - else: - device_idx = None - module = SharkInference( - mlirs[idx], - device=device, - device_idx=device_idx, - mlir_dialect="tm_tensor", - mmap=True, - ) - module.save_module( - module_name=f"{idx}_full", - extra_args=[ - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - + self.extra_args, - debug=self.debug, - ) - module.load_module(vmfb_path) - modules.append(module) - return mlirs, modules - - def get_sharded_model(self, device="cpu", compressed=False): - # SAMPLE_INPUT_LEN is used for creating mlir with dynamic inputs, which is currently an increadibly hacky proccess - # please don't change it - SAMPLE_INPUT_LEN = 137 - vicuna_model = self.get_src_model() - if compressed: - vicuna_model.model = LlamaModel.from_pretrained( - "TheBloke/vicuna-7B-1.1-HF" - ) - - if self.precision in ["int4", "int8"]: - - if not self.check_all_artifacts_present(): - print("Applying weight quantization..") - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - weight_bit_width = 4 if self.precision == "int4" else 8 - - quantize_model( - get_model_impl(vicuna_model).layers, - dtype=torch.float32, - weight_quant_type="asym", - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_granularity="per_group", - weight_group_size=self.weight_group_size, - quantize_weight_zero_point=False, - input_bit_width=None, - input_scale_type="float", - input_param_method="stats", - input_quant_type="asym", - input_quant_granularity="per_tensor", - quantize_input_zero_point=False, - seqlen=2048, - ) - - print("Weight quantization applied.") - - else: - print("Skipping quantization, as all required artifacts are present") - - placeholder_pkv_segment = tuple( - ( - torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]), - torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]), - ) - for _ in range(8) - ) - placeholder_pkv_full = tuple( - ( - torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]), - torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]), - ) - for _ in range(self.n_layers_dict[self.model_name]) - ) - - placeholder_input0 = ( - torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]), - torch.zeros([1, 1, SAMPLE_INPUT_LEN, SAMPLE_INPUT_LEN]), - torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64), - ) - - placeholder_input1 = ( - torch.zeros([1, 1, self.hidden_state_size_dict[self.model_name]]), - torch.zeros([1, 1, 1, SAMPLE_INPUT_LEN + 1]), - torch.zeros([1, 1], dtype=torch.int64), - torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]), - torch.zeros([1, self.n_layers_dict[self.model_name], SAMPLE_INPUT_LEN, 128]), - ) - - norm = VicunaNorm(vicuna_model.model.norm) - device_idx = self.get_device_index( - r"vicuna\.model\.model\.norm(?:\.|\s|$)" - ) - # HC device_idx for non-layer vmfbs - device_idx = 0 - norm = self.compile_norm( - norm, - torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]), - device=self.device, - device_idx=device_idx, - ) - - embeddings = VicunaEmbedding(vicuna_model.model.embed_tokens) - device_idx = self.get_device_index( - r"vicuna\.model\.model\.embed_tokens(?:\.|\s|$)" - ) - # HC device_idx for non-layer vmfbs - device_idx = 0 - embeddings = self.compile_embedding( - embeddings, - (torch.zeros([1, SAMPLE_INPUT_LEN], dtype=torch.int64)), - device=self.device, - device_idx=device_idx, - ) - - lmhead = LMHead(vicuna_model.lm_head) - device_idx = self.get_device_index( - r"vicuna\.model\.lm_head(?:\.|\s|$)" - ) - # HC device_idx for non-layer vmfbs - device_idx = 0 - lmhead = self.compile_lmhead( - lmhead, - torch.zeros([1, SAMPLE_INPUT_LEN, self.hidden_state_size_dict[self.model_name]]), - device=self.device, - device_idx=device_idx, - ) - - if not compressed: - layers0 = [ - FirstVicunaLayer(layer) for layer in vicuna_model.model.layers - ] - layers1 = [ - SecondVicunaLayer(layer) for layer in vicuna_model.model.layers - ] - - else: - layers00 = EightLayerLayerFV(vicuna_model.model.layers[0:8]) - layers01 = EightLayerLayerFV(vicuna_model.model.layers[8:16]) - layers02 = EightLayerLayerFV(vicuna_model.model.layers[16:24]) - layers03 = EightLayerLayerFV(vicuna_model.model.layers[24:32]) - layers10 = EightLayerLayerSV(vicuna_model.model.layers[0:8]) - layers11 = EightLayerLayerSV(vicuna_model.model.layers[8:16]) - layers12 = EightLayerLayerSV(vicuna_model.model.layers[16:24]) - layers13 = EightLayerLayerSV(vicuna_model.model.layers[24:32]) - layers0 = [layers00, layers01, layers02, layers03] - layers1 = [layers10, layers11, layers12, layers13] - - _, modules = self.compile_to_vmfb_one_model( - placeholder_input0, - layers0, - placeholder_input1, - layers1, - device=device, - ) - - if not compressed: - if self.n_devices is None: - breakpoints = None - else: - breakpoints = [x for x in range(0,len(modules),(self.n_devices % 2) + (len(modules)//(self.n_devices)))][1:] + [len(modules)] - shark_layers = [CompiledVicunaLayer(m, i, breakpoints) for (i, m) in enumerate(modules)] - else: - shark_layers = [CompiledEightLayerLayer(m) for m in modules] - vicuna_model.model.compressedlayers = shark_layers - - sharded_model = ShardedVicunaModel( - vicuna_model, - shark_layers, - lmhead, - embeddings, - norm, - ) - return sharded_model - - def compile(self, device="cpu"): - return self.get_sharded_model( - device=device, compressed=self.compressed - ) - - def generate(self, prompt, cli=False): - # TODO: refactor for cleaner integration - - history = [] - - tokens_generated = [] - _past_key_values = None - _token = None - detoks_generated = [] - for iteration in range(self.max_num_tokens): - params = { - "prompt": prompt, - "is_first": iteration == 0, - "token": _token, - "past_key_values": _past_key_values, - } - - decode_st_time = time.time() - - generated_token_op = self.generate_new_token(params=params) - - decode_time = (time.time() - decode_st_time) * 1000 - - _token = generated_token_op["token"] - _past_key_values = generated_token_op["past_key_values"] - _detok = generated_token_op["detok"] - history.append(_token) - yield self.tokenizer.decode(history), None, decode_time - - if _token == 2: - break - detoks_generated.append(_detok) - tokens_generated.append(_token) - - for i in range(len(tokens_generated)): - if type(tokens_generated[i]) != int: - tokens_generated[i] = int(tokens_generated[i][0]) - result_output = self.tokenizer.decode(tokens_generated) - yield result_output, "formatted", None - - def autocomplete(self, prompt): - # use First vic alone to complete a story / prompt / sentence. - pass - - -class UnshardedVicuna(VicunaBase): - def __init__( - self, - model_name, - hf_model_path="TheBloke/vicuna-7B-1.1-HF", - hf_auth_token: str = None, - max_num_tokens=512, - min_num_tokens=0, - device="cpu", - device_id=None, - vulkan_target_triple="", - precision="int8", - vicuna_mlir_path=None, - vicuna_vmfb_path=None, - load_mlir_from_shark_tank=False, - low_device_memory=False, - weight_group_size=128, - download_vmfb=False, - cache_vicunas=False, - extra_args_cmd=[], - debug=False, - ) -> None: - super().__init__( - model_name, - hf_model_path, - max_num_tokens, - extra_args_cmd=extra_args_cmd, - ) - self.hf_auth_token = hf_auth_token - if self.model_name == "llama2_7b": - self.hf_model_path = "meta-llama/Llama-2-7b-chat-hf" - elif self.model_name == "llama2_13b": - self.hf_model_path = "meta-llama/Llama-2-13b-chat-hf" - elif self.model_name == "llama2_70b": - self.hf_model_path = "meta-llama/Llama-2-70b-chat-hf" - print(f"[DEBUG] hf model name: {self.hf_model_path}") - self.max_sequence_length = 256 - self.min_num_tokens = min_num_tokens - self.vulkan_target_triple = vulkan_target_triple - self.precision = precision - self.download_vmfb = download_vmfb - self.vicuna_vmfb_path = vicuna_vmfb_path - self.vicuna_mlir_path = vicuna_mlir_path - self.load_mlir_from_shark_tank = load_mlir_from_shark_tank - self.low_device_memory = low_device_memory - self.weight_group_size = weight_group_size - self.debug = debug - # Sanity check for device, device_id pair - if "://" in device: - if device_id is not None: - print( - "[ERR] can't have both full device path and a device id.\n" - f"Device : {device} | device_id : {device_id}\n" - "proceeding with given Device ignoring device_id" - ) - self.device, self.device_id = device.split("://") - if len(self.device_id) < 2: - self.device_id = int(self.device_id) - else: - self.device, self.device_id = device, device_id - if self.vicuna_mlir_path == None: - self.vicuna_mlir_path = self.get_model_path() - if self.vicuna_vmfb_path == None: - self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb") - self.tokenizer = self.get_tokenizer() - self.cache_vicunas = cache_vicunas - - self.compile() - - def get_model_path(self, suffix="mlir"): - safe_device = self.device.split("-")[0] - safe_device = safe_device.split("://")[0] - if suffix in ["mlirbc", "mlir"]: - return Path(f"{self.model_name}_{self.precision}.{suffix}") - - # Need to distinguish between multiple vmfbs of the same model - # compiled for different devices of the same driver - # Driver - Differentiator - # Vulkan - target_triple - # ROCm - device_arch - - differentiator = "" - if "vulkan" == self.device: - target_triple = "" - if self.vulkan_target_triple != "": - target_triple = "_" - target_triple += "_".join( - self.vulkan_target_triple.split("-")[:-1] - ) - differentiator = target_triple - else: - from shark.iree_utils.vulkan_utils import get_vulkan_triple_flag - tt = get_vulkan_triple_flag(device_num=self.device_id) - differentiator = "_" + "_".join(tt.split("=")[1].split('-')[:-1]) - - elif "rocm" == self.device: - from shark.iree_utils.gpu_utils import get_rocm_device_arch - - device_arch = get_rocm_device_arch( - self.device_id if self.device_id is not None else 0, - self.extra_args, - ) - differentiator = "_" + device_arch - - return Path( - f"{self.model_name}_{self.precision}_{safe_device}{differentiator}.{suffix}" - ) - - def get_tokenizer(self): - local_tokenizer_path = Path(Path.cwd(), "llama2_tokenizer_configs") - local_tokenizer_path.mkdir(parents=True, exist_ok=True) - tokenizer_files_to_download = [ - "config.json", - "special_tokens_map.json", - "tokenizer.model", - "tokenizer_config.json", - ] - for tokenizer_file in tokenizer_files_to_download: - download_public_file( - f"gs://shark_tank/llama2_tokenizer/{tokenizer_file}", - Path(local_tokenizer_path, tokenizer_file), - single_file=True, - ) - tokenizer = AutoTokenizer.from_pretrained(str(local_tokenizer_path)) - return tokenizer - - def get_src_model(self): - kwargs = { - "torch_dtype": torch.float, - "use_auth_token": self.hf_auth_token, - } - vicuna_model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, - **kwargs, - ) - return vicuna_model - - def write_in_dynamic_inputs0(self, module, dynamic_input_size): - print("[DEBUG] writing dynamic inputs to first vicuna") - # Current solution for ensuring mlir files support dynamic inputs - # TODO: find a more elegant way to implement this - new_lines = [] - module = module.splitlines() - while module: - line = module.pop(0) - line = re.sub(f"{dynamic_input_size}x", "?x", line) - if "?x" in line: - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - line = re.sub(f" {dynamic_input_size},", " %dim,", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", "tensor.empty(%dim, %dim)", line - ) - if "arith.cmpi" in line: - line = re.sub(f"c{dynamic_input_size}", "dim", line) - if "%0 = tensor.empty(%dim) : tensor" in line: - new_lines.append( - "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" - ) - if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line: - continue - - new_lines.append(line) - return "\n".join(new_lines) - - def write_in_dynamic_inputs1(self, module): - print("[DEBUG] writing dynamic inputs to second vicuna") - - def remove_constant_dim(line): - if "c19_i64" in line: - line = re.sub("c19_i64", "dim_i64", line) - if "19x" in line: - line = re.sub("19x", "?x", line) - line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line) - if "tensor.empty" in line and "?x?" in line: - line = re.sub( - "tensor.empty\(%dim\)", - "tensor.empty(%dim, %dim)", - line, - ) - if "arith.cmpi" in line: - line = re.sub("c19", "dim", line) - if " 19," in line: - line = re.sub(" 19,", " %dim,", line) - if "x20x" in line or "<20x" in line: - line = re.sub("20x", "?x", line) - line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line) - if " 20," in line: - line = re.sub(" 20,", " %dimp1,", line) - return line - - module = module.splitlines() - new_lines = [] - - # Using a while loop and the pop method to avoid creating a copy of module - pkv_tensor_shape = f"tensor<1x{self.n_layers_dict[self.model_name]}x?x128x" - if self.precision in ["fp16", "int4", "int8"]: - pkv_tensor_shape += "f16>" - else: - pkv_tensor_shape += "f32>" - - while module: - line = module.pop(0) - if "%c19_i64 = arith.constant 19 : i64" in line: - new_lines.append("%c2 = arith.constant 2 : index") - new_lines.append( - f"%dim_4_int = tensor.dim %arg1, %c2 : {pkv_tensor_shape}" - ) - new_lines.append( - "%dim_i64 = arith.index_cast %dim_4_int : index to i64" - ) - continue - if "%c2 = arith.constant 2 : index" in line: - continue - if "%c20_i64 = arith.constant 20 : i64" in line: - new_lines.append("%c1_i64 = arith.constant 1 : i64") - new_lines.append( - "%c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" - ) - new_lines.append( - "%dimp1 = arith.index_cast %c20_i64 : i64 to index" - ) - continue - line = remove_constant_dim(line) - new_lines.append(line) - - return "\n".join(new_lines) - - def compile(self): - # Testing : DO NOT Download Vmfbs if not found. Modify later - # download vmfbs for A100 - if not self.vicuna_vmfb_path.exists() and self.download_vmfb: - print( - f"Looking into gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}" - ) - download_public_file( - f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}", - self.vicuna_vmfb_path.absolute(), - single_file=True, - ) - self.shark_model = get_vmfb_from_path( - self.vicuna_vmfb_path, self.device, "tm_tensor", self.device_id - ) - if self.shark_model is not None: - print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}") - return - - print(f"[DEBUG] vmfb not found (search path: {self.vicuna_vmfb_path})") - mlir_generated = False - for suffix in ["mlirbc", "mlir"]: - self.vicuna_mlir_path = self.get_model_path(suffix) - if ( - "cpu" in self.device - and "llama2_7b" in self.vicuna_mlir_path.name - ): - self.vicuna_mlir_path = Path("llama2_7b_int4_f32.mlir") - if ( - not self.vicuna_mlir_path.exists() - and self.load_mlir_from_shark_tank - ): - print( - f"Looking into gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}" - ) - download_public_file( - f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}", - self.vicuna_mlir_path.absolute(), - single_file=True, - ) - if self.vicuna_mlir_path.exists(): - print( - f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}" - ) - combined_module = self.vicuna_mlir_path.absolute() - mlir_generated = True - break - - if not mlir_generated: - print(f"[DEBUG] mlir not found") - - print("[DEBUG] generating mlir on device") - # Select a compilation prompt such that the resulting input_ids - # from the model's tokenizer has shape [1, 19] - compilation_prompt = "".join(["0" for _ in range(17)]) - - first_model_path = f"first_{self.model_name}_{self.precision}.mlir" - if Path(first_model_path).exists(): - print(f"loading {first_model_path}") - with open(Path(first_model_path), "r") as f: - first_module = f.read() - else: - # generate first vicuna - compilation_input_ids = self.tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor( - compilation_input_ids - ).reshape([1, 19]) - firstVicunaCompileInput = (compilation_input_ids,) - if "cpu" in self.device: - model = FirstVicuna( - self.hf_model_path, - self.precision, - "fp32" if self.device == "cpu" else "fp16", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - else: - model = FirstVicunaGPU( - self.hf_model_path, - self.precision, - "fp32" if self.device == "cpu" else "fp16", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - print(f"[DEBUG] generating torchscript graph") - is_f16 = self.precision in ["fp16", "int4"] - ts_graph = import_with_fx( - model, - firstVicunaCompileInput, - is_f16=is_f16, - precision=self.precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - del model - firstVicunaCompileInput = list(firstVicunaCompileInput) - firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like( - firstVicunaCompileInput[0], dynamic_axes=[1] - ) - - firstVicunaCompileInput = tuple(firstVicunaCompileInput) - first_module = None - print(f"[DEBUG] generating torch mlir") - if self.precision in ["int4", "int8"]: - first_module = torch_mlir.compile( - ts_graph, - [*firstVicunaCompileInput], - output_type=torch_mlir.OutputType.TORCH, - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - if self.cache_vicunas: - with open( - first_model_path[:-5] + "_torch.mlir", "w+" - ) as f: - f.write(str(first_module)) - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - first_module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - else: - first_module = torch_mlir.compile( - ts_graph, - [*firstVicunaCompileInput], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - del ts_graph - del firstVicunaCompileInput - gc.collect() - - print( - "[DEBUG] successfully generated first vicuna linalg mlir" - ) - first_module = self.write_in_dynamic_inputs0( - str(first_module), dynamic_input_size=19 - ) - if self.cache_vicunas: - with open(first_model_path, "w+") as f: - f.write(first_module) - print("Finished writing IR after dynamic") - - print(f"[DEBUG] Starting generation of second llama") - second_model_path = ( - f"second_{self.model_name}_{self.precision}.mlir" - ) - if Path(second_model_path).exists(): - print(f"loading {second_model_path}") - with open(Path(second_model_path), "r") as f: - second_module = f.read() - else: - # generate second vicuna - compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) - if self.model_name == "llama2_13b": - dim1 = 40 - total_tuple = 80 - elif self.model_name == "llama2_70b": - dim1 = 8 - total_tuple = 160 - else: - dim1 = 32 - total_tuple = 64 - pkv = tuple( - (torch.zeros([1, dim1, 19, 128], dtype=torch.float32)) - for _ in range(total_tuple) - ) - secondVicunaCompileInput = (compilation_input_ids,) + pkv - if "cpu" in self.device: - if self.model_name == "llama2_13b": - model = SecondVicuna13B( - self.hf_model_path, - self.precision, - "fp32", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - elif self.model_name == "llama2_70b": - model = SecondVicuna70B( - self.hf_model_path, - self.precision, - "fp32", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - else: - model = SecondVicuna7B( - self.hf_model_path, - self.precision, - "fp32", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - else: - if self.model_name == "llama2_13b": - model = SecondVicuna13BGPU( - self.hf_model_path, - self.precision, - "fp16", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - elif self.model_name == "llama2_70b": - model = SecondVicuna70BGPU( - self.hf_model_path, - self.precision, - "fp16", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - else: - model = SecondVicuna7BGPU( - self.hf_model_path, - self.precision, - "fp16", - self.weight_group_size, - self.model_name, - self.hf_auth_token, - ) - print(f"[DEBUG] generating torchscript graph") - is_f16 = self.precision in ["fp16", "int4"] - ts_graph = import_with_fx( - model, - secondVicunaCompileInput, - is_f16=is_f16, - precision=self.precision, - f16_input_mask=[False] + [True] * total_tuple, - mlir_type="torchscript", - ) - del model - if self.precision in ["fp16", "int4"]: - secondVicunaCompileInput = get_f16_inputs( - secondVicunaCompileInput, - True, - f16_input_mask=[False] + [True] * total_tuple, - ) - secondVicunaCompileInput = list(secondVicunaCompileInput) - for i in range(len(secondVicunaCompileInput)): - if i != 0: - secondVicunaCompileInput[ - i - ] = torch_mlir.TensorPlaceholder.like( - secondVicunaCompileInput[i], dynamic_axes=[2] - ) - secondVicunaCompileInput = tuple(secondVicunaCompileInput) - print(f"[DEBUG] generating torch mlir") - if self.precision in ["int4", "int8"]: - second_module = torch_mlir.compile( - ts_graph, - [*secondVicunaCompileInput], - output_type=torch_mlir.OutputType.TORCH, - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - print(f"[DEBUG] converting torch to linalg") - if self.cache_vicunas: - with open( - second_model_path[:-5] + "_torch.mlir", "w+" - ) as f: - f.write(str(second_module)) - run_pipeline_with_repro_report( - second_module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - else: - second_module = torch_mlir.compile( - ts_graph, - [*secondVicunaCompileInput], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - del ts_graph - del secondVicunaCompileInput - gc.collect() - - print( - "[DEBUG] successfully generated second vicuna linalg mlir" - ) - second_module = self.write_in_dynamic_inputs1( - str(second_module) - ) - if self.cache_vicunas: - with open(second_model_path, "w+") as f: - f.write(second_module) - print("Finished writing IR after dynamic") - - combined_module = self.combine_mlir_scripts( - first_module, - second_module, - self.vicuna_mlir_path, - ) - combined_module = save_mlir( - combined_module, - model_name="combined_llama", - mlir_dialect="tm_tensor", - dir=self.vicuna_mlir_path, - ) - del first_module, second_module - - print( - f"Compiling for device : {self.device}" - f"{'://' + str(self.device_id) if self.device_id is not None else ''}" - ) - shark_module = SharkInference( - mlir_module=combined_module, - device=self.device, - mlir_dialect="tm_tensor", - device_idx=self.device_id, - ) - path = shark_module.save_module( - self.vicuna_vmfb_path.parent.absolute(), - self.vicuna_vmfb_path.stem, - extra_args=[ - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - + self.extra_args, - debug=self.debug, - ) - print("Saved vic vmfb at ", str(path)) - shark_module.load_module(path) - self.shark_model = shark_module - - def decode_tokens(self, res_tokens): - for i in range(len(res_tokens)): - if type(res_tokens[i]) != int: - res_tokens[i] = int(res_tokens[i][0]) - - res_str = self.tokenizer.decode(res_tokens, skip_special_tokens=False) - return res_str - - def generate(self, prompt, cli): - # TODO: refactor for cleaner integration - if self.shark_model is None: - self.compile() - res_tokens = [] - params = {"prompt": prompt, "is_first": True, "fv": self.shark_model} - - prefill_st_time = time.time() - generated_token_op = self.generate_new_token( - params=params, sharded=False, cli=cli - ) - prefill_time_ms = (time.time() - prefill_st_time) * 1000 - - token = generated_token_op["token"] - if "cpu" not in self.device: - logits = generated_token_op["logits"] - pkv = generated_token_op["past_key_values"] - detok = generated_token_op["detok"] - yield detok, None, prefill_time_ms - - res_tokens.append(token) - if cli: - print(f"Assistant: {detok}", end=" ", flush=True) - - for idx in range(self.max_num_tokens): - params = { - "token": token, - "is_first": False, - "past_key_values": pkv, - "sv": self.shark_model, - } - if "cpu" not in self.device: - params["logits"] = logits - - decode_st_time = time.time() - generated_token_op = self.generate_new_token( - params=params, sharded=False, cli=cli - ) - decode_time_ms = (time.time() - decode_st_time) * 1000 - - token = generated_token_op["token"] - if "cpu" not in self.device: - logits = generated_token_op["logits"] - pkv = generated_token_op["past_key_values"] - detok = generated_token_op["detok"] - - if token == 2 and idx >= self.min_num_tokens: - break - res_tokens.append(token) - if detok == "<0x0A>": - if cli: - print("\n", end="", flush=True) - else: - if cli: - print(f"{detok}", end=" ", flush=True) - yield detok, None, decode_time_ms - - res_str = self.decode_tokens(res_tokens) - yield res_str, "formatted", None - - def autocomplete(self, prompt): - # use First vic alone to complete a story / prompt / sentence. - pass - - -# NOTE: Each `model_name` should have its own start message -start_message = { - "llama2_7b": ( - "System: You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal " - "content. Please ensure that your responses are socially unbiased and positive " - "in nature. If a question does not make any sense, or is not factually coherent, " - "explain why instead of answering something not correct. If you don't know the " - "answer to a question, please don't share false information." - ), - "llama2_13b": ( - "System: You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal " - "content. Please ensure that your responses are socially unbiased and positive " - "in nature. If a question does not make any sense, or is not factually coherent, " - "explain why instead of answering something not correct. If you don't know the " - "answer to a question, please don't share false information." - ), - "llama2_70b": ( - "System: You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal " - "content. Please ensure that your responses are socially unbiased and positive " - "in nature. If a question does not make any sense, or is not factually coherent, " - "explain why instead of answering something not correct. If you don't know the " - "answer to a question, please don't share false information." - ), - "vicuna": ( - "A chat between a curious user and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the user's " - "questions.\n" - ), -} - - -def create_prompt(model_name, history): - global start_message - system_message = start_message[model_name] - if "llama2" in model_name: - B_INST, E_INST = "[INST]", "[/INST]" - B_SYS, E_SYS = "<>\n", "\n<>\n\n" - conversation = "".join( - [ - f"{B_INST} {item[0].strip()} {E_INST} {item[1].strip()} " - for item in history[1:] - ] - ) - msg = f"{B_INST} {B_SYS} {system_message} {E_SYS} {history[0][0]} {E_INST} {history[0][1]} {conversation}" - - else: - conversation = "".join( - [ - "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) - for item in history - ] - ) - msg = system_message + conversation - msg = msg.strip() - return msg - -def miliseconds_to_seconds(ms: float) -> float: - return ms / 1000.0 - - -@dataclass -class BenchmarkRunInfo: - num_prompt_tokens : int - prefill_time_ms : float - token_times_ms : list[float] - - def get_prefill_speed(self) -> float: - seconds = miliseconds_to_seconds(self.prefill_time_ms) - if seconds == 0.0: - return float('inf') - return self.num_prompt_tokens / seconds - - def num_generated_tokens(self) -> int: - return len(self.token_times_ms) - - def get_decode_time_ms(self) -> float: - return sum(self.token_times_ms) - - def get_decode_speed(self) -> float: - seconds = miliseconds_to_seconds(self.get_decode_time_ms()) - if seconds == 0.0: - return float('inf') - return self.num_generated_tokens() / seconds - - def get_e2e_time_ms(self) -> float: - return self.prefill_time_ms + self.get_decode_time_ms() - - def get_e2e_decode_speed(self) -> float: - seconds = miliseconds_to_seconds(self.get_e2e_time_ms()) - if seconds == 0.0: - return float('inf') - return self.num_generated_tokens() / seconds - - def get_e2e_token_processing_speed(self) -> float: - seconds = miliseconds_to_seconds(self.get_e2e_time_ms()) - if seconds == 0.0: - return float('inf') - return (self.num_prompt_tokens + self.num_generated_tokens()) / seconds - - def print(self) -> None: - total_tokens = self.num_prompt_tokens + self.num_generated_tokens() - print(f"Num tokens: {self.num_prompt_tokens:} (prompt), {self.num_generated_tokens()} (generated), {total_tokens} (total)") - print(f"Prefill: {self.prefill_time_ms:.2f} ms, {self.get_prefill_speed():.2f} tokens/s") - print(f"Decode: {self.get_decode_time_ms():.2f} ms, {self.get_decode_speed():.2f} tokens/s") - print(f"Decode end-2-end: {self.get_e2e_decode_speed():.2f} tokens/s (w/o prompt), {self.get_e2e_token_processing_speed():.2f} tokens/s (w/ prompt)") - -def enable_tracy_tracing(): - # Make tracy wait for a caputre to be collected before exiting. - environ["TRACY_NO_EXIT"] = "1" - - if "IREE_PY_RUNTIME" not in environ or environ["IREE_PY_RUNTIME"] != "tracy": - print("ERROR: Tracing enabled but tracy iree runtime not used.", file=sys.stderr) - print("Set the IREE_PY_RUNTIME=tracy environment variable.", file=sys.stderr) - sys.exit(1) - - -def print_aggregate_stats(run_infos: list[BenchmarkRunInfo]) -> None: - num_iterations = len(run_infos) - print(f'Number of iterations: {num_iterations}') - if num_iterations == 0: - return - - if len(run_infos) == 1: - run_infos[0].print() - return - - total_tokens = run_infos[0].num_prompt_tokens + run_infos[0].num_generated_tokens() - print(f"Num tokens: {run_infos[0].num_prompt_tokens} (prompt), {run_infos[0].num_generated_tokens()} (generated), {total_tokens} (total)") - - def avg_and_stdev(data): - x = list(data) - return mean(x), stdev(x) - - avg_prefill_ms, stdev_prefill = avg_and_stdev(x.prefill_time_ms for x in run_infos) - avg_prefill_speed = mean(x.get_prefill_speed() for x in run_infos) - print(f"Prefill: avg. {avg_prefill_ms:.2f} ms (stdev {stdev_prefill:.2f}), avg. {avg_prefill_speed:.2f} tokens/s") - - avg_decode_ms, stdev_decode = avg_and_stdev(x.get_decode_time_ms() for x in run_infos) - avg_decode_speed = mean(x.get_decode_speed() for x in run_infos) - print(f"Decode: avg. {avg_decode_ms:.2f} ms (stdev {stdev_decode:.2f}), avg. {avg_decode_speed:.2f} tokens/s") - - avg_e2e_decode_speed = mean(x.get_e2e_decode_speed() for x in run_infos) - avg_e2e_processing_speed = mean(x.get_e2e_token_processing_speed() for x in run_infos) - print(f"Decode end-2-end: avg. {avg_e2e_decode_speed:.2f} tokens/s (w/o prompt), avg. {avg_e2e_processing_speed:.2f} (w/ prompt)") - - -if __name__ == "__main__": - args, unknown = parser.parse_known_args() - - _extra_args = list(args.Xiree_compile) - - model_list = { - "vicuna": "vicuna=>TheBloke/vicuna-7B-1.1-HF", - "llama2_7b": "llama2_7b=>meta-llama/Llama-2-7b-chat-hf", - "llama2_13b": "llama2_13b=>meta-llama/Llama-2-13b-chat-hf", - "llama2_70b": "llama2_70b=>meta-llama/Llama-2-70b-chat-hf", - } - - device_id = None - if args.enable_tracing: - enable_tracy_tracing() - # Process vulkan target triple. - # TODO: This feature should just be in a common utils for other LLMs and in general - # any model run via SHARK for Vulkan backend. - vulkan_target_triple = args.iree_vulkan_target_triple - if vulkan_target_triple != "": - _extra_args.append( - f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" - ) - # Step 1. Fetch the device ID. - from shark.iree_utils.vulkan_utils import ( - get_all_vulkan_devices, - get_vulkan_target_triple, - ) - - vulkaninfo_list = get_all_vulkan_devices() - id = 0 - for device in vulkaninfo_list: - target_triple = get_vulkan_target_triple(vulkaninfo_list[id]) - if target_triple == vulkan_target_triple: - device_id = id - break - id += 1 - - if "://" in device : - from shark.iree_utils.compile_utils import clean_device_info - _, device_id = clean_device_info(args.device) - - assert ( - device_id - ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" - # Step 2. Add a few flags targetting specific hardwares. - if "rdna" in vulkan_target_triple: - flags_to_add = [ - "--iree-spirv-index-bits=64", - ] - _extra_args = _extra_args + flags_to_add - - vic = None - if not args.sharded: - vic_mlir_path = ( - None - if args.vicuna_mlir_path is None - else Path(args.vicuna_mlir_path) - ) - vic_vmfb_path = ( - None - if args.vicuna_vmfb_path is None - else Path(args.vicuna_vmfb_path) - ) - min_tokens = 0 - max_tokens = 512 - if args.enable_microbenchmark: - min_tokens = max_tokens = args.microbenchmark_num_tokens - - vic = UnshardedVicuna( - model_name=args.model_name, - hf_auth_token=args.hf_auth_token, - max_num_tokens=max_tokens, - min_num_tokens=min_tokens, - device=args.device, - vulkan_target_triple=vulkan_target_triple, - precision=args.precision, - vicuna_mlir_path=vic_mlir_path, - vicuna_vmfb_path=vic_vmfb_path, - load_mlir_from_shark_tank=args.load_mlir_from_shark_tank, - weight_group_size=args.weight_group_size, - download_vmfb=args.download_vmfb, - cache_vicunas=args.cache_vicunas, - extra_args_cmd=_extra_args, - device_id=device_id, - ) - else: - if args.config is not None: - config_file = open(args.config) - config_json = json.load(config_file) - config_file.close() - else: - config_json = None - - print( - f"[DEBUG]: model_name_input = {model_list[args.model_name].split('=>')[1]}" - ) - vic = ShardedVicuna( - model_name=args.model_name, - hf_model_path=model_list[args.model_name].split("=>")[1], - hf_auth_token=args.hf_auth_token, - device=args.device, - precision=args.precision, - config_json=config_json, - weight_group_size=args.weight_group_size, - extra_args_cmd=_extra_args, - n_devices=args.n_devices, - ) - - history = [] - - iteration = 0 - - benchmark_run_infos = [] - - - while True: - # TODO: Add break condition from user input - iteration += 1 - if not args.enable_microbenchmark: - user_prompt = input("User prompt: ") - history.append([user_prompt, ""]) - prompt = create_prompt(args.model_name, history) - else: - if iteration > args.microbenchmark_iterations: - break - user_prompt = args.user_prompt - prompt = args.system_prompt + user_prompt - history = [[user_prompt, ""]] - - prompt_token_count = len(vic.tokenizer(prompt).input_ids) - total_time_ms = 0.0 # In order to avoid divide by zero error - prefill_time_ms = 0 - is_first = True - token_times_ms = [] - for text, msg, exec_time in vic.generate(prompt, cli=True): - if args.enable_tracing: - vic.shark_model.shark_runner.iree_config.device.flush_profiling() - if msg is None: - if is_first: - prefill_time_ms = exec_time - is_first = False - else: - token_times_ms.append(exec_time) - elif "formatted" in msg: - print(f"\nResponse:\n{text.strip()}\n") - run_info = BenchmarkRunInfo(prompt_token_count, prefill_time_ms, token_times_ms) - run_info.print() - benchmark_run_infos.append(run_info) - else: - sys.exit( - "unexpected message from the vicuna generate call, exiting." - ) - - if args.enable_microbenchmark: - print("\n### Final Statistics ###") - print_aggregate_stats(benchmark_run_infos) diff --git a/apps/language_models/shark_llama_cli.spec b/apps/language_models/shark_llama_cli.spec deleted file mode 100644 index 04930bb2..00000000 --- a/apps/language_models/shark_llama_cli.spec +++ /dev/null @@ -1,94 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- -from PyInstaller.utils.hooks import collect_data_files -from PyInstaller.utils.hooks import collect_submodules -from PyInstaller.utils.hooks import copy_metadata - -import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5) - -datas = [] -datas += collect_data_files('torch') -datas += copy_metadata('torch') -datas += copy_metadata('tqdm') -datas += copy_metadata('regex') -datas += copy_metadata('requests') -datas += copy_metadata('packaging') -datas += copy_metadata('filelock') -datas += copy_metadata('numpy') -datas += copy_metadata('tokenizers') -datas += copy_metadata('importlib_metadata') -datas += copy_metadata('torch-mlir') -datas += copy_metadata('omegaconf') -datas += copy_metadata('safetensors') -datas += copy_metadata('huggingface-hub') -datas += copy_metadata('sentencepiece') -datas += copy_metadata("pyyaml") -datas += collect_data_files("tokenizers") -datas += collect_data_files("tiktoken") -datas += collect_data_files("accelerate") -datas += collect_data_files('diffusers') -datas += collect_data_files('transformers') -datas += collect_data_files('opencv-python') -datas += collect_data_files('pytorch_lightning') -datas += collect_data_files('skimage') -datas += collect_data_files('gradio') -datas += collect_data_files('gradio_client') -datas += collect_data_files('iree') -datas += collect_data_files('google-cloud-storage') -datas += collect_data_files('py-cpuinfo') -datas += collect_data_files("shark", include_py_files=True) -datas += collect_data_files("timm", include_py_files=True) -datas += collect_data_files("tqdm") -datas += collect_data_files("tkinter") -datas += collect_data_files("webview") -datas += collect_data_files("sentencepiece") -datas += collect_data_files("jsonschema") -datas += collect_data_files("jsonschema_specifications") -datas += collect_data_files("cpuinfo") -datas += collect_data_files("langchain") - -binaries = [] - -block_cipher = None - -hiddenimports = ['shark', 'shark.shark_inference', 'apps'] -hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] -hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] - -a = Analysis( - ['scripts/vicuna.py'], - pathex=['.'], - binaries=binaries, - datas=datas, - hiddenimports=hiddenimports, - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - win_no_prefer_redirects=False, - win_private_assemblies=False, - cipher=block_cipher, - noarchive=False, -) -pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) - -exe = EXE( - pyz, - a.scripts, - a.binaries, - a.zipfiles, - a.datas, - [], - name='shark_llama_cli', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - upx_exclude=[], - runtime_tmpdir=None, - console=True, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, -) diff --git a/apps/language_models/src/model_wrappers/falcon_model.py b/apps/language_models/src/model_wrappers/falcon_model.py deleted file mode 100644 index 8354c72e..00000000 --- a/apps/language_models/src/model_wrappers/falcon_model.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch - - -class FalconModel(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, input_ids, attention_mask): - input_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - "past_key_values": None, - "use_cache": True, - } - output = self.model( - **input_dict, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - )[0] - return output[:, -1, :] diff --git a/apps/language_models/src/model_wrappers/falcon_sharded_model.py b/apps/language_models/src/model_wrappers/falcon_sharded_model.py deleted file mode 100644 index c56c7877..00000000 --- a/apps/language_models/src/model_wrappers/falcon_sharded_model.py +++ /dev/null @@ -1,675 +0,0 @@ -import torch -from typing import Optional, Tuple - - -class WordEmbeddingsLayer(torch.nn.Module): - def __init__(self, word_embedding_layer): - super().__init__() - self.model = word_embedding_layer - - def forward(self, input_ids): - output = self.model.forward(input=input_ids) - return output - - -class CompiledWordEmbeddingsLayer(torch.nn.Module): - def __init__(self, compiled_word_embedding_layer): - super().__init__() - self.model = compiled_word_embedding_layer - - def forward(self, input_ids): - input_ids = input_ids.detach().numpy() - new_input_ids = self.model("forward", input_ids) - new_input_ids = new_input_ids.reshape( - [1, new_input_ids.shape[0], new_input_ids.shape[1]] - ) - return torch.tensor(new_input_ids) - - -class LNFEmbeddingLayer(torch.nn.Module): - def __init__(self, ln_f): - super().__init__() - self.model = ln_f - - def forward(self, hidden_states): - output = self.model.forward(input=hidden_states) - return output - - -class CompiledLNFEmbeddingLayer(torch.nn.Module): - def __init__(self, ln_f): - super().__init__() - self.model = ln_f - - def forward(self, hidden_states): - hidden_states = hidden_states.detach().numpy() - new_hidden_states = self.model("forward", (hidden_states,)) - - return torch.tensor(new_hidden_states) - - -class LMHeadEmbeddingLayer(torch.nn.Module): - def __init__(self, embedding_layer): - super().__init__() - self.model = embedding_layer - - def forward(self, hidden_states): - output = self.model.forward(input=hidden_states) - return output - - -class CompiledLMHeadEmbeddingLayer(torch.nn.Module): - def __init__(self, lm_head): - super().__init__() - self.model = lm_head - - def forward(self, hidden_states): - hidden_states = hidden_states.detach().numpy() - new_hidden_states = self.model("forward", (hidden_states,)) - return torch.tensor(new_hidden_states) - - -class FourWayShardingDecoderLayer(torch.nn.Module): - def __init__(self, decoder_layer_model, falcon_variant): - super().__init__() - self.model = decoder_layer_model - self.falcon_variant = falcon_variant - - def forward(self, hidden_states, attention_mask): - new_pkvs = [] - for layer in self.model: - outputs = layer( - hidden_states=hidden_states, - alibi=None, - attention_mask=attention_mask, - use_cache=True, - ) - hidden_states = outputs[0] - new_pkvs.append( - ( - outputs[-1][0], - outputs[-1][1], - ) - ) - - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - (new_pkv80, new_pkv81), - (new_pkv90, new_pkv91), - (new_pkv100, new_pkv101), - (new_pkv110, new_pkv111), - (new_pkv120, new_pkv121), - (new_pkv130, new_pkv131), - (new_pkv140, new_pkv141), - (new_pkv150, new_pkv151), - (new_pkv160, new_pkv161), - (new_pkv170, new_pkv171), - (new_pkv180, new_pkv181), - (new_pkv190, new_pkv191), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - new_pkv80, - new_pkv81, - new_pkv90, - new_pkv91, - new_pkv100, - new_pkv101, - new_pkv110, - new_pkv111, - new_pkv120, - new_pkv121, - new_pkv130, - new_pkv131, - new_pkv140, - new_pkv141, - new_pkv150, - new_pkv151, - new_pkv160, - new_pkv161, - new_pkv170, - new_pkv171, - new_pkv180, - new_pkv181, - new_pkv190, - new_pkv191, - ) - return result - - -class CompiledFourWayShardingDecoderLayer(torch.nn.Module): - def __init__( - self, layer_id, device_idx, falcon_variant, device, precision, model - ): - super().__init__() - self.layer_id = layer_id - self.device_index = device_idx - self.falcon_variant = falcon_variant - self.device = device - self.precision = precision - self.model = model - - def forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - import gc - - torch.cuda.empty_cache() - gc.collect() - - if self.model is None: - raise ValueError("Layer vmfb not found") - - hidden_states = hidden_states.to(torch.float32).detach().numpy() - attention_mask = attention_mask.to(torch.float32).detach().numpy() - - if alibi is not None or layer_past is not None: - raise ValueError("Past Key Values and alibi should be None") - else: - output = self.model( - "forward", - ( - hidden_states, - attention_mask, - ), - ) - - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ( - torch.tensor(output[17]), - torch.tensor(output[18]), - ), - ( - torch.tensor(output[19]), - torch.tensor(output[20]), - ), - ( - torch.tensor(output[21]), - torch.tensor(output[22]), - ), - ( - torch.tensor(output[23]), - torch.tensor(output[24]), - ), - ( - torch.tensor(output[25]), - torch.tensor(output[26]), - ), - ( - torch.tensor(output[27]), - torch.tensor(output[28]), - ), - ( - torch.tensor(output[29]), - torch.tensor(output[30]), - ), - ( - torch.tensor(output[31]), - torch.tensor(output[32]), - ), - ( - torch.tensor(output[33]), - torch.tensor(output[34]), - ), - ( - torch.tensor(output[35]), - torch.tensor(output[36]), - ), - ( - torch.tensor(output[37]), - torch.tensor(output[38]), - ), - ( - torch.tensor(output[39]), - torch.tensor(output[40]), - ), - ) - return result - - -class TwoWayShardingDecoderLayer(torch.nn.Module): - def __init__(self, decoder_layer_model, falcon_variant): - super().__init__() - self.model = decoder_layer_model - self.falcon_variant = falcon_variant - - def forward(self, hidden_states, attention_mask): - new_pkvs = [] - for layer in self.model: - outputs = layer( - hidden_states=hidden_states, - alibi=None, - attention_mask=attention_mask, - use_cache=True, - ) - hidden_states = outputs[0] - new_pkvs.append( - ( - outputs[-1][0], - outputs[-1][1], - ) - ) - - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - (new_pkv80, new_pkv81), - (new_pkv90, new_pkv91), - (new_pkv100, new_pkv101), - (new_pkv110, new_pkv111), - (new_pkv120, new_pkv121), - (new_pkv130, new_pkv131), - (new_pkv140, new_pkv141), - (new_pkv150, new_pkv151), - (new_pkv160, new_pkv161), - (new_pkv170, new_pkv171), - (new_pkv180, new_pkv181), - (new_pkv190, new_pkv191), - (new_pkv200, new_pkv201), - (new_pkv210, new_pkv211), - (new_pkv220, new_pkv221), - (new_pkv230, new_pkv231), - (new_pkv240, new_pkv241), - (new_pkv250, new_pkv251), - (new_pkv260, new_pkv261), - (new_pkv270, new_pkv271), - (new_pkv280, new_pkv281), - (new_pkv290, new_pkv291), - (new_pkv300, new_pkv301), - (new_pkv310, new_pkv311), - (new_pkv320, new_pkv321), - (new_pkv330, new_pkv331), - (new_pkv340, new_pkv341), - (new_pkv350, new_pkv351), - (new_pkv360, new_pkv361), - (new_pkv370, new_pkv371), - (new_pkv380, new_pkv381), - (new_pkv390, new_pkv391), - ) = new_pkvs - result = ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - new_pkv80, - new_pkv81, - new_pkv90, - new_pkv91, - new_pkv100, - new_pkv101, - new_pkv110, - new_pkv111, - new_pkv120, - new_pkv121, - new_pkv130, - new_pkv131, - new_pkv140, - new_pkv141, - new_pkv150, - new_pkv151, - new_pkv160, - new_pkv161, - new_pkv170, - new_pkv171, - new_pkv180, - new_pkv181, - new_pkv190, - new_pkv191, - new_pkv200, - new_pkv201, - new_pkv210, - new_pkv211, - new_pkv220, - new_pkv221, - new_pkv230, - new_pkv231, - new_pkv240, - new_pkv241, - new_pkv250, - new_pkv251, - new_pkv260, - new_pkv261, - new_pkv270, - new_pkv271, - new_pkv280, - new_pkv281, - new_pkv290, - new_pkv291, - new_pkv300, - new_pkv301, - new_pkv310, - new_pkv311, - new_pkv320, - new_pkv321, - new_pkv330, - new_pkv331, - new_pkv340, - new_pkv341, - new_pkv350, - new_pkv351, - new_pkv360, - new_pkv361, - new_pkv370, - new_pkv371, - new_pkv380, - new_pkv381, - new_pkv390, - new_pkv391, - ) - return result - - -class CompiledTwoWayShardingDecoderLayer(torch.nn.Module): - def __init__( - self, layer_id, device_idx, falcon_variant, device, precision, model - ): - super().__init__() - self.layer_id = layer_id - self.device_index = device_idx - self.falcon_variant = falcon_variant - self.device = device - self.precision = precision - self.model = model - - def forward( - self, - hidden_states: torch.Tensor, - alibi: Optional[torch.Tensor], - attention_mask: torch.Tensor, - position_ids: Optional[torch.LongTensor] = None, - layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - head_mask: Optional[torch.Tensor] = None, - use_cache: bool = False, - output_attentions: bool = False, - ): - import gc - - torch.cuda.empty_cache() - gc.collect() - - if self.model is None: - raise ValueError("Layer vmfb not found") - - hidden_states = hidden_states.to(torch.float32).detach().numpy() - attention_mask = attention_mask.to(torch.float32).detach().numpy() - - if alibi is not None or layer_past is not None: - raise ValueError("Past Key Values and alibi should be None") - else: - output = self.model( - "forward", - ( - hidden_states, - attention_mask, - ), - ) - - result = ( - torch.tensor(output[0]), - ( - torch.tensor(output[1]), - torch.tensor(output[2]), - ), - ( - torch.tensor(output[3]), - torch.tensor(output[4]), - ), - ( - torch.tensor(output[5]), - torch.tensor(output[6]), - ), - ( - torch.tensor(output[7]), - torch.tensor(output[8]), - ), - ( - torch.tensor(output[9]), - torch.tensor(output[10]), - ), - ( - torch.tensor(output[11]), - torch.tensor(output[12]), - ), - ( - torch.tensor(output[13]), - torch.tensor(output[14]), - ), - ( - torch.tensor(output[15]), - torch.tensor(output[16]), - ), - ( - torch.tensor(output[17]), - torch.tensor(output[18]), - ), - ( - torch.tensor(output[19]), - torch.tensor(output[20]), - ), - ( - torch.tensor(output[21]), - torch.tensor(output[22]), - ), - ( - torch.tensor(output[23]), - torch.tensor(output[24]), - ), - ( - torch.tensor(output[25]), - torch.tensor(output[26]), - ), - ( - torch.tensor(output[27]), - torch.tensor(output[28]), - ), - ( - torch.tensor(output[29]), - torch.tensor(output[30]), - ), - ( - torch.tensor(output[31]), - torch.tensor(output[32]), - ), - ( - torch.tensor(output[33]), - torch.tensor(output[34]), - ), - ( - torch.tensor(output[35]), - torch.tensor(output[36]), - ), - ( - torch.tensor(output[37]), - torch.tensor(output[38]), - ), - ( - torch.tensor(output[39]), - torch.tensor(output[40]), - ), - ( - torch.tensor(output[41]), - torch.tensor(output[42]), - ), - ( - torch.tensor(output[43]), - torch.tensor(output[44]), - ), - ( - torch.tensor(output[45]), - torch.tensor(output[46]), - ), - ( - torch.tensor(output[47]), - torch.tensor(output[48]), - ), - ( - torch.tensor(output[49]), - torch.tensor(output[50]), - ), - ( - torch.tensor(output[51]), - torch.tensor(output[52]), - ), - ( - torch.tensor(output[53]), - torch.tensor(output[54]), - ), - ( - torch.tensor(output[55]), - torch.tensor(output[56]), - ), - ( - torch.tensor(output[57]), - torch.tensor(output[58]), - ), - ( - torch.tensor(output[59]), - torch.tensor(output[60]), - ), - ( - torch.tensor(output[61]), - torch.tensor(output[62]), - ), - ( - torch.tensor(output[63]), - torch.tensor(output[64]), - ), - ( - torch.tensor(output[65]), - torch.tensor(output[66]), - ), - ( - torch.tensor(output[67]), - torch.tensor(output[68]), - ), - ( - torch.tensor(output[69]), - torch.tensor(output[70]), - ), - ( - torch.tensor(output[71]), - torch.tensor(output[72]), - ), - ( - torch.tensor(output[73]), - torch.tensor(output[74]), - ), - ( - torch.tensor(output[75]), - torch.tensor(output[76]), - ), - ( - torch.tensor(output[77]), - torch.tensor(output[78]), - ), - ( - torch.tensor(output[79]), - torch.tensor(output[80]), - ), - ) - return result - - -class ShardedFalconModel: - def __init__(self, model, layers, word_embeddings, ln_f, lm_head): - super().__init__() - self.model = model - self.model.transformer.h = torch.nn.modules.container.ModuleList( - layers - ) - self.model.transformer.word_embeddings = word_embeddings - self.model.transformer.ln_f = ln_f - self.model.lm_head = lm_head - - def forward( - self, - input_ids, - attention_mask=None, - ): - return self.model.forward( - input_ids=input_ids, - attention_mask=attention_mask, - ).logits[:, -1, :] diff --git a/apps/language_models/src/model_wrappers/minigpt4.py b/apps/language_models/src/model_wrappers/minigpt4.py deleted file mode 100644 index d1dc93c4..00000000 --- a/apps/language_models/src/model_wrappers/minigpt4.py +++ /dev/null @@ -1,503 +0,0 @@ -import torch -import dataclasses -from enum import auto, Enum -from typing import List, Any -from transformers import StoppingCriteria - - -from brevitas_examples.common.generative.quantize import quantize_model -from brevitas_examples.llm.llm_quant.run_utils import get_model_impl - - -class LayerNorm(torch.nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class VisionModel(torch.nn.Module): - def __init__( - self, - ln_vision, - visual_encoder, - precision="fp32", - weight_group_size=128, - ): - super().__init__() - self.ln_vision = ln_vision - self.visual_encoder = visual_encoder - if precision in ["int4", "int8"]: - print("Vision Model applying weight quantization to ln_vision") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - self.ln_vision, - dtype=torch.float32, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - print( - "Vision Model applying weight quantization to visual_encoder" - ) - quantize_model( - self.visual_encoder, - dtype=torch.float32, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward(self, image): - image_embeds = self.ln_vision(self.visual_encoder(image)) - return image_embeds - - -class QformerBertModel(torch.nn.Module): - def __init__(self, qformer_bert): - super().__init__() - self.qformer_bert = qformer_bert - - def forward(self, query_tokens, image_embeds, image_atts): - query_output = self.qformer_bert( - query_embeds=query_tokens, - encoder_hidden_states=image_embeds, - encoder_attention_mask=image_atts, - return_dict=True, - ) - return query_output.last_hidden_state - - -class FirstLlamaModel(torch.nn.Module): - def __init__(self, model, precision="fp32", weight_group_size=128): - super().__init__() - self.model = model - print("SHARK: Loading LLAMA Done") - if precision in ["int4", "int8"]: - print("First Llama applying weight quantization") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - self.model, - dtype=torch.float32, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward(self, inputs_embeds, position_ids, attention_mask): - print("************************************") - print( - "inputs_embeds: ", - inputs_embeds.shape, - " dtype: ", - inputs_embeds.dtype, - ) - print( - "position_ids: ", - position_ids.shape, - " dtype: ", - position_ids.dtype, - ) - print( - "attention_mask: ", - attention_mask.shape, - " dtype: ", - attention_mask.dtype, - ) - print("************************************") - config = { - "inputs_embeds": inputs_embeds, - "position_ids": position_ids, - "past_key_values": None, - "use_cache": True, - "attention_mask": attention_mask, - } - output = self.model( - **config, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - return_vals = [] - return_vals.append(output.logits) - temp_past_key_values = output.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SecondLlamaModel(torch.nn.Module): - def __init__(self, model, precision="fp32", weight_group_size=128): - super().__init__() - self.model = model - print("SHARK: Loading LLAMA Done") - if precision in ["int4", "int8"]: - print("Second Llama applying weight quantization") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - self.model, - dtype=torch.float32, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward( - self, - input_ids, - position_ids, - attention_mask, - i1, - i2, - i3, - i4, - i5, - i6, - i7, - i8, - i9, - i10, - i11, - i12, - i13, - i14, - i15, - i16, - i17, - i18, - i19, - i20, - i21, - i22, - i23, - i24, - i25, - i26, - i27, - i28, - i29, - i30, - i31, - i32, - i33, - i34, - i35, - i36, - i37, - i38, - i39, - i40, - i41, - i42, - i43, - i44, - i45, - i46, - i47, - i48, - i49, - i50, - i51, - i52, - i53, - i54, - i55, - i56, - i57, - i58, - i59, - i60, - i61, - i62, - i63, - i64, - ): - print("************************************") - print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype) - print( - "position_ids: ", - position_ids.shape, - " dtype: ", - position_ids.dtype, - ) - print( - "attention_mask: ", - attention_mask.shape, - " dtype: ", - attention_mask.dtype, - ) - print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape) - print("past_key_values dtype: ", i1.dtype) - print("************************************") - config = { - "input_ids": input_ids, - "position_ids": position_ids, - "past_key_values": ( - (i1, i2), - ( - i3, - i4, - ), - ( - i5, - i6, - ), - ( - i7, - i8, - ), - ( - i9, - i10, - ), - ( - i11, - i12, - ), - ( - i13, - i14, - ), - ( - i15, - i16, - ), - ( - i17, - i18, - ), - ( - i19, - i20, - ), - ( - i21, - i22, - ), - ( - i23, - i24, - ), - ( - i25, - i26, - ), - ( - i27, - i28, - ), - ( - i29, - i30, - ), - ( - i31, - i32, - ), - ( - i33, - i34, - ), - ( - i35, - i36, - ), - ( - i37, - i38, - ), - ( - i39, - i40, - ), - ( - i41, - i42, - ), - ( - i43, - i44, - ), - ( - i45, - i46, - ), - ( - i47, - i48, - ), - ( - i49, - i50, - ), - ( - i51, - i52, - ), - ( - i53, - i54, - ), - ( - i55, - i56, - ), - ( - i57, - i58, - ), - ( - i59, - i60, - ), - ( - i61, - i62, - ), - ( - i63, - i64, - ), - ), - "use_cache": True, - "attention_mask": attention_mask, - } - output = self.model( - **config, - return_dict=True, - output_attentions=False, - output_hidden_states=False, - ) - return_vals = [] - return_vals.append(output.logits) - temp_past_key_values = output.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SeparatorStyle(Enum): - """Different separator style.""" - - SINGLE = auto() - TWO = auto() - - -@dataclasses.dataclass -class Conversation: - """A class that keeps all conversation history.""" - - system: str - roles: List[str] - messages: List[List[str]] - offset: int - sep_style: SeparatorStyle = SeparatorStyle.SINGLE - sep: str = "###" - sep2: str = None - - skip_next: bool = False - conv_id: Any = None - - def get_prompt(self): - if self.sep_style == SeparatorStyle.SINGLE: - ret = self.system + self.sep - for role, message in self.messages: - if message: - ret += role + ": " + message + self.sep - else: - ret += role + ":" - return ret - elif self.sep_style == SeparatorStyle.TWO: - seps = [self.sep, self.sep2] - ret = self.system + seps[0] - for i, (role, message) in enumerate(self.messages): - if message: - ret += role + ": " + message + seps[i % 2] - else: - ret += role + ":" - return ret - else: - raise ValueError(f"Invalid style: {self.sep_style}") - - def append_message(self, role, message): - self.messages.append([role, message]) - - def to_gradio_chatbot(self): - ret = [] - for i, (role, msg) in enumerate(self.messages[self.offset :]): - if i % 2 == 0: - ret.append([msg, None]) - else: - ret[-1][-1] = msg - return ret - - def copy(self): - return Conversation( - system=self.system, - roles=self.roles, - messages=[[x, y] for x, y in self.messages], - offset=self.offset, - sep_style=self.sep_style, - sep=self.sep, - sep2=self.sep2, - conv_id=self.conv_id, - ) - - def dict(self): - return { - "system": self.system, - "roles": self.roles, - "messages": self.messages, - "offset": self.offset, - "sep": self.sep, - "sep2": self.sep2, - "conv_id": self.conv_id, - } - - -class StoppingCriteriaSub(StoppingCriteria): - def __init__(self, stops=[], encounters=1): - super().__init__() - self.stops = stops - - def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor): - for stop in self.stops: - if torch.all((stop == input_ids[0][-len(stop) :])).item(): - return True - - return False - - -CONV_VISION = Conversation( - system="Give the following image: ImageContent. " - "You will be able to see the image once I provide it to you. Please answer my questions.", - roles=("Human", "Assistant"), - messages=[], - offset=2, - sep_style=SeparatorStyle.SINGLE, - sep="###", -) diff --git a/apps/language_models/src/model_wrappers/stablelm_model.py b/apps/language_models/src/model_wrappers/stablelm_model.py deleted file mode 100644 index 86cc2081..00000000 --- a/apps/language_models/src/model_wrappers/stablelm_model.py +++ /dev/null @@ -1,15 +0,0 @@ -import torch - - -class StableLMModel(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, input_ids, attention_mask): - combine_input_dict = { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - output = self.model(**combine_input_dict) - return output.logits diff --git a/apps/language_models/src/model_wrappers/vicuna4.py b/apps/language_models/src/model_wrappers/vicuna4.py deleted file mode 100644 index 10bef66f..00000000 --- a/apps/language_models/src/model_wrappers/vicuna4.py +++ /dev/null @@ -1,876 +0,0 @@ -import argparse -import json -import re -from io import BytesIO -from pathlib import Path -from tqdm import tqdm -from typing import List, Optional, Tuple, Union -import numpy as np -import iree.runtime -import itertools -import subprocess - -import torch -import torch_mlir -from torch_mlir import TensorPlaceholder -from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from transformers import ( - AutoTokenizer, - AutoModelForCausalLM, - LlamaPreTrainedModel, -) -from transformers.modeling_outputs import ( - BaseModelOutputWithPast, - CausalLMOutputWithPast, - SequenceClassifierOutputWithPast, -) -from transformers.modeling_utils import PreTrainedModel -from transformers.utils import ( - add_start_docstrings, - add_start_docstrings_to_model_forward, - logging, - replace_return_docstrings, -) - -from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase -from apps.language_models.src.model_wrappers.vicuna_sharded_model import ( - FirstVicunaLayer, - SecondVicunaLayer, - CompiledVicunaLayer, - ShardedVicunaModel, - LMHead, - LMHeadCompiled, - VicunaEmbedding, - VicunaEmbeddingCompiled, - VicunaNorm, - VicunaNormCompiled, -) -from apps.language_models.src.model_wrappers.vicuna_model import ( - FirstVicuna, - SecondVicuna7B, -) -from apps.language_models.utils import ( - get_vmfb_from_path, -) -from shark.shark_downloader import download_public_file -from shark.shark_importer import get_f16_inputs -from shark.shark_inference import SharkInference - -from transformers.models.llama.configuration_llama import LlamaConfig -from transformers.models.llama.modeling_llama import ( - LlamaDecoderLayer, - LlamaRMSNorm, - _make_causal_mask, - _expand_mask, -) -from torch import nn -from time import time - - -class LlamaModel(LlamaPreTrainedModel): - """ - Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`] - - Args: - config: LlamaConfig - """ - - def __init__(self, config: LlamaConfig): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding( - config.vocab_size, config.hidden_size, self.padding_idx - ) - self.layers = nn.ModuleList( - [ - LlamaDecoderLayer(config) - for _ in range(config.num_hidden_layers) - ] - ) - self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - self.gradient_checkpointing = False - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self): - return self.embed_tokens - - def set_input_embeddings(self, value): - self.embed_tokens = value - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask( - self, - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ).to(inputs_embeds.device) - combined_attention_mask = ( - expanded_attn_mask - if combined_attention_mask is None - else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - - def forward( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ): - t1 = time() - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = ( - use_cache if use_cache is not None else self.config.use_cache - ) - - return_dict = ( - return_dict - if return_dict is not None - else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = ( - seq_length_with_past + past_key_values_length - ) - - if position_ids is None: - device = ( - input_ids.device - if input_ids is not None - else inputs_embeds.device - ) - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.compressedlayers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[8 * idx : 8 * (idx + 1)] - if past_key_values is not None - else None - ) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer.forward( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += (layer_outputs[1:],) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - try: - hidden_states = np.asarray(hidden_states, hidden_states.dtype) - except: - _ = 10 - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - next_cache = tuple(itertools.chain.from_iterable(next_cache)) - print(f"Token generated in {time() - t1} seconds") - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class EightLayerLayerSV(torch.nn.Module): - def __init__(self, layers): - super().__init__() - assert len(layers) == 8 - self.layers = layers - - def forward( - self, - hidden_states, - attention_mask, - position_ids, - pkv00, - pkv01, - pkv10, - pkv11, - pkv20, - pkv21, - pkv30, - pkv31, - pkv40, - pkv41, - pkv50, - pkv51, - pkv60, - pkv61, - pkv70, - pkv71, - ): - pkvs = [ - (pkv00, pkv01), - (pkv10, pkv11), - (pkv20, pkv21), - (pkv30, pkv31), - (pkv40, pkv41), - (pkv50, pkv51), - (pkv60, pkv61), - (pkv70, pkv71), - ] - new_pkvs = [] - for layer, pkv in zip(self.layers, pkvs): - outputs = layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=( - pkv[0], - pkv[1], - ), - use_cache=True, - ) - - hidden_states = outputs[0] - new_pkvs.append( - ( - outputs[-1][0], - outputs[-1][1], - ) - ) - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - ) = new_pkvs - return ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - ) - - -class EightLayerLayerFV(torch.nn.Module): - def __init__(self, layers): - super().__init__() - assert len(layers) == 8 - self.layers = layers - - def forward(self, hidden_states, attention_mask, position_ids): - new_pkvs = [] - for layer in self.layers: - outputs = layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=None, - use_cache=True, - ) - - hidden_states = outputs[0] - new_pkvs.append( - ( - outputs[-1][0], - outputs[-1][1], - ) - ) - ( - (new_pkv00, new_pkv01), - (new_pkv10, new_pkv11), - (new_pkv20, new_pkv21), - (new_pkv30, new_pkv31), - (new_pkv40, new_pkv41), - (new_pkv50, new_pkv51), - (new_pkv60, new_pkv61), - (new_pkv70, new_pkv71), - ) = new_pkvs - return ( - hidden_states, - new_pkv00, - new_pkv01, - new_pkv10, - new_pkv11, - new_pkv20, - new_pkv21, - new_pkv30, - new_pkv31, - new_pkv40, - new_pkv41, - new_pkv50, - new_pkv51, - new_pkv60, - new_pkv61, - new_pkv70, - new_pkv71, - ) - - -class CompiledEightLayerLayerSV(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward( - self, - hidden_states, - attention_mask, - position_ids, - past_key_value, - output_attentions=False, - use_cache=True, - ): - hidden_states = hidden_states.detach() - attention_mask = attention_mask.detach() - position_ids = position_ids.detach() - ( - (pkv00, pkv01), - (pkv10, pkv11), - (pkv20, pkv21), - (pkv30, pkv31), - (pkv40, pkv41), - (pkv50, pkv51), - (pkv60, pkv61), - (pkv70, pkv71), - ) = past_key_value - pkv00 = pkv00.detatch() - pkv01 = pkv01.detatch() - pkv10 = pkv10.detatch() - pkv11 = pkv11.detatch() - pkv20 = pkv20.detatch() - pkv21 = pkv21.detatch() - pkv30 = pkv30.detatch() - pkv31 = pkv31.detatch() - pkv40 = pkv40.detatch() - pkv41 = pkv41.detatch() - pkv50 = pkv50.detatch() - pkv51 = pkv51.detatch() - pkv60 = pkv60.detatch() - pkv61 = pkv61.detatch() - pkv70 = pkv70.detatch() - pkv71 = pkv71.detatch() - - output = self.model( - "forward", - ( - hidden_states, - attention_mask, - position_ids, - pkv00, - pkv01, - pkv10, - pkv11, - pkv20, - pkv21, - pkv30, - pkv31, - pkv40, - pkv41, - pkv50, - pkv51, - pkv60, - pkv61, - pkv70, - pkv71, - ), - send_to_host=False, - ) - return ( - output[0], - (output[1][0], output[1][1]), - (output[2][0], output[2][1]), - (output[3][0], output[3][1]), - (output[4][0], output[4][1]), - (output[5][0], output[5][1]), - (output[6][0], output[6][1]), - (output[7][0], output[7][1]), - (output[8][0], output[8][1]), - ) - - -def forward_compressed( - self, - input_ids: torch.LongTensor = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, -): - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - - return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict - ) - - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError( - "You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time" - ) - elif input_ids is not None: - batch_size, seq_length = input_ids.shape - elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape - else: - raise ValueError( - "You have to specify either decoder_input_ids or decoder_inputs_embeds" - ) - - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length - - if position_ids is None: - device = ( - input_ids.device if input_ids is not None else inputs_embeds.device - ) - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), - dtype=torch.bool, - device=inputs_embeds.device, - ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - - hidden_states = inputs_embeds - - # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - next_decoder_cache = () if use_cache else None - - for idx, decoder_layer in enumerate(self.compressedlayers): - if output_hidden_states: - all_hidden_states += (hidden_states,) - - past_key_value = ( - past_key_values[8 * idx : 8 * (idx + 1)] - if past_key_values is not None - else None - ) - - if self.gradient_checkpointing and self.training: - - def create_custom_forward(module): - def custom_forward(*inputs): - # None for past_key_value - return module(*inputs, output_attentions, None) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(decoder_layer), - hidden_states, - attention_mask, - position_ids, - None, - ) - else: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - ) - - hidden_states = layer_outputs[0] - - if use_cache: - next_decoder_cache += ( - layer_outputs[2 if output_attentions else 1], - ) - - if output_attentions: - all_self_attns += (layer_outputs[1],) - - hidden_states = self.norm(hidden_states) - - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - - next_cache = next_decoder_cache if use_cache else None - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_cache, - all_hidden_states, - all_self_attns, - ] - if v is not None - ) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=next_cache, - hidden_states=all_hidden_states, - attentions=all_self_attns, - ) - - -class CompiledEightLayerLayer(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward( - self, - hidden_states, - attention_mask, - position_ids, - past_key_value=None, - output_attentions=False, - use_cache=True, - ): - t2 = time() - if past_key_value is None: - try: - hidden_states = np.asarray(hidden_states, hidden_states.dtype) - except: - pass - attention_mask = attention_mask.detach() - position_ids = position_ids.detach() - t1 = time() - - output = self.model( - "first_vicuna_forward", - (hidden_states, attention_mask, position_ids), - send_to_host=False, - ) - output2 = ( - output[0], - ( - output[1], - output[2], - ), - ( - output[3], - output[4], - ), - ( - output[5], - output[6], - ), - ( - output[7], - output[8], - ), - ( - output[9], - output[10], - ), - ( - output[11], - output[12], - ), - ( - output[13], - output[14], - ), - ( - output[15], - output[16], - ), - ) - return output2 - else: - ( - (pkv00, pkv01), - (pkv10, pkv11), - (pkv20, pkv21), - (pkv30, pkv31), - (pkv40, pkv41), - (pkv50, pkv51), - (pkv60, pkv61), - (pkv70, pkv71), - ) = past_key_value - - try: - hidden_states = hidden_states.detach() - attention_mask = attention_mask.detach() - position_ids = position_ids.detach() - pkv00 = pkv00.detach() - pkv01 = pkv01.detach() - pkv10 = pkv10.detach() - pkv11 = pkv11.detach() - pkv20 = pkv20.detach() - pkv21 = pkv21.detach() - pkv30 = pkv30.detach() - pkv31 = pkv31.detach() - pkv40 = pkv40.detach() - pkv41 = pkv41.detach() - pkv50 = pkv50.detach() - pkv51 = pkv51.detach() - pkv60 = pkv60.detach() - pkv61 = pkv61.detach() - pkv70 = pkv70.detach() - pkv71 = pkv71.detach() - except: - x = 10 - - t1 = time() - if type(hidden_states) == iree.runtime.array_interop.DeviceArray: - hidden_states = np.array(hidden_states, hidden_states.dtype) - hidden_states = torch.tensor(hidden_states) - hidden_states = hidden_states.detach() - - output = self.model( - "second_vicuna_forward", - ( - hidden_states, - attention_mask, - position_ids, - pkv00, - pkv01, - pkv10, - pkv11, - pkv20, - pkv21, - pkv30, - pkv31, - pkv40, - pkv41, - pkv50, - pkv51, - pkv60, - pkv61, - pkv70, - pkv71, - ), - send_to_host=False, - ) - print(f"{time() - t1}") - del pkv00 - del pkv01 - del pkv10 - del pkv11 - del pkv20 - del pkv21 - del pkv30 - del pkv31 - del pkv40 - del pkv41 - del pkv50 - del pkv51 - del pkv60 - del pkv61 - del pkv70 - del pkv71 - output2 = ( - output[0], - ( - output[1], - output[2], - ), - ( - output[3], - output[4], - ), - ( - output[5], - output[6], - ), - ( - output[7], - output[8], - ), - ( - output[9], - output[10], - ), - ( - output[11], - output[12], - ), - ( - output[13], - output[14], - ), - ( - output[15], - output[16], - ), - ) - return output2 diff --git a/apps/language_models/src/model_wrappers/vicuna_model.py b/apps/language_models/src/model_wrappers/vicuna_model.py deleted file mode 100644 index efdf44eb..00000000 --- a/apps/language_models/src/model_wrappers/vicuna_model.py +++ /dev/null @@ -1,1175 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM - - -class FirstVicuna(torch.nn.Module): - def __init__( - self, - model_path, - precision="fp32", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - print(f"[DEBUG] model_path : {model_path}") - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("First Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward(self, input_ids): - op = self.model(input_ids=input_ids, use_cache=True) - return_vals = [] - token = torch.argmax(op.logits[:, -1, :], dim=1) - return_vals.append(token) - - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SecondVicuna7B(torch.nn.Module): - def __init__( - self, - model_path, - precision="fp32", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - print(f"[DEBUG] model_path : {model_path}") - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("Second Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward( - self, - i0, - i1, - i2, - i3, - i4, - i5, - i6, - i7, - i8, - i9, - i10, - i11, - i12, - i13, - i14, - i15, - i16, - i17, - i18, - i19, - i20, - i21, - i22, - i23, - i24, - i25, - i26, - i27, - i28, - i29, - i30, - i31, - i32, - i33, - i34, - i35, - i36, - i37, - i38, - i39, - i40, - i41, - i42, - i43, - i44, - i45, - i46, - i47, - i48, - i49, - i50, - i51, - i52, - i53, - i54, - i55, - i56, - i57, - i58, - i59, - i60, - i61, - i62, - i63, - i64, - ): - token = i0 - past_key_values = ( - (i1, i2), - ( - i3, - i4, - ), - ( - i5, - i6, - ), - ( - i7, - i8, - ), - ( - i9, - i10, - ), - ( - i11, - i12, - ), - ( - i13, - i14, - ), - ( - i15, - i16, - ), - ( - i17, - i18, - ), - ( - i19, - i20, - ), - ( - i21, - i22, - ), - ( - i23, - i24, - ), - ( - i25, - i26, - ), - ( - i27, - i28, - ), - ( - i29, - i30, - ), - ( - i31, - i32, - ), - ( - i33, - i34, - ), - ( - i35, - i36, - ), - ( - i37, - i38, - ), - ( - i39, - i40, - ), - ( - i41, - i42, - ), - ( - i43, - i44, - ), - ( - i45, - i46, - ), - ( - i47, - i48, - ), - ( - i49, - i50, - ), - ( - i51, - i52, - ), - ( - i53, - i54, - ), - ( - i55, - i56, - ), - ( - i57, - i58, - ), - ( - i59, - i60, - ), - ( - i61, - i62, - ), - ( - i63, - i64, - ), - ) - op = self.model( - input_ids=token, use_cache=True, past_key_values=past_key_values - ) - return_vals = [] - token = torch.argmax(op.logits[:, -1, :], dim=1) - return_vals.append(token) - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SecondVicuna13B(torch.nn.Module): - def __init__( - self, - model_path, - precision="int8", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("Second Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward( - self, - i0, - i1, - i2, - i3, - i4, - i5, - i6, - i7, - i8, - i9, - i10, - i11, - i12, - i13, - i14, - i15, - i16, - i17, - i18, - i19, - i20, - i21, - i22, - i23, - i24, - i25, - i26, - i27, - i28, - i29, - i30, - i31, - i32, - i33, - i34, - i35, - i36, - i37, - i38, - i39, - i40, - i41, - i42, - i43, - i44, - i45, - i46, - i47, - i48, - i49, - i50, - i51, - i52, - i53, - i54, - i55, - i56, - i57, - i58, - i59, - i60, - i61, - i62, - i63, - i64, - i65, - i66, - i67, - i68, - i69, - i70, - i71, - i72, - i73, - i74, - i75, - i76, - i77, - i78, - i79, - i80, - ): - token = i0 - past_key_values = ( - (i1, i2), - ( - i3, - i4, - ), - ( - i5, - i6, - ), - ( - i7, - i8, - ), - ( - i9, - i10, - ), - ( - i11, - i12, - ), - ( - i13, - i14, - ), - ( - i15, - i16, - ), - ( - i17, - i18, - ), - ( - i19, - i20, - ), - ( - i21, - i22, - ), - ( - i23, - i24, - ), - ( - i25, - i26, - ), - ( - i27, - i28, - ), - ( - i29, - i30, - ), - ( - i31, - i32, - ), - ( - i33, - i34, - ), - ( - i35, - i36, - ), - ( - i37, - i38, - ), - ( - i39, - i40, - ), - ( - i41, - i42, - ), - ( - i43, - i44, - ), - ( - i45, - i46, - ), - ( - i47, - i48, - ), - ( - i49, - i50, - ), - ( - i51, - i52, - ), - ( - i53, - i54, - ), - ( - i55, - i56, - ), - ( - i57, - i58, - ), - ( - i59, - i60, - ), - ( - i61, - i62, - ), - ( - i63, - i64, - ), - ( - i65, - i66, - ), - ( - i67, - i68, - ), - ( - i69, - i70, - ), - ( - i71, - i72, - ), - ( - i73, - i74, - ), - ( - i75, - i76, - ), - ( - i77, - i78, - ), - ( - i79, - i80, - ), - ) - op = self.model( - input_ids=token, use_cache=True, past_key_values=past_key_values - ) - return_vals = [] - return_vals.append(op.logits) - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SecondVicuna70B(torch.nn.Module): - def __init__( - self, - model_path, - precision="fp32", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - print(f"[DEBUG] model_path : {model_path}") - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("Second Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward( - self, - i0, - i1, - i2, - i3, - i4, - i5, - i6, - i7, - i8, - i9, - i10, - i11, - i12, - i13, - i14, - i15, - i16, - i17, - i18, - i19, - i20, - i21, - i22, - i23, - i24, - i25, - i26, - i27, - i28, - i29, - i30, - i31, - i32, - i33, - i34, - i35, - i36, - i37, - i38, - i39, - i40, - i41, - i42, - i43, - i44, - i45, - i46, - i47, - i48, - i49, - i50, - i51, - i52, - i53, - i54, - i55, - i56, - i57, - i58, - i59, - i60, - i61, - i62, - i63, - i64, - i65, - i66, - i67, - i68, - i69, - i70, - i71, - i72, - i73, - i74, - i75, - i76, - i77, - i78, - i79, - i80, - i81, - i82, - i83, - i84, - i85, - i86, - i87, - i88, - i89, - i90, - i91, - i92, - i93, - i94, - i95, - i96, - i97, - i98, - i99, - i100, - i101, - i102, - i103, - i104, - i105, - i106, - i107, - i108, - i109, - i110, - i111, - i112, - i113, - i114, - i115, - i116, - i117, - i118, - i119, - i120, - i121, - i122, - i123, - i124, - i125, - i126, - i127, - i128, - i129, - i130, - i131, - i132, - i133, - i134, - i135, - i136, - i137, - i138, - i139, - i140, - i141, - i142, - i143, - i144, - i145, - i146, - i147, - i148, - i149, - i150, - i151, - i152, - i153, - i154, - i155, - i156, - i157, - i158, - i159, - i160, - ): - token = i0 - past_key_values = ( - (i1, i2), - ( - i3, - i4, - ), - ( - i5, - i6, - ), - ( - i7, - i8, - ), - ( - i9, - i10, - ), - ( - i11, - i12, - ), - ( - i13, - i14, - ), - ( - i15, - i16, - ), - ( - i17, - i18, - ), - ( - i19, - i20, - ), - ( - i21, - i22, - ), - ( - i23, - i24, - ), - ( - i25, - i26, - ), - ( - i27, - i28, - ), - ( - i29, - i30, - ), - ( - i31, - i32, - ), - ( - i33, - i34, - ), - ( - i35, - i36, - ), - ( - i37, - i38, - ), - ( - i39, - i40, - ), - ( - i41, - i42, - ), - ( - i43, - i44, - ), - ( - i45, - i46, - ), - ( - i47, - i48, - ), - ( - i49, - i50, - ), - ( - i51, - i52, - ), - ( - i53, - i54, - ), - ( - i55, - i56, - ), - ( - i57, - i58, - ), - ( - i59, - i60, - ), - ( - i61, - i62, - ), - ( - i63, - i64, - ), - ( - i65, - i66, - ), - ( - i67, - i68, - ), - ( - i69, - i70, - ), - ( - i71, - i72, - ), - ( - i73, - i74, - ), - ( - i75, - i76, - ), - ( - i77, - i78, - ), - ( - i79, - i80, - ), - ( - i81, - i82, - ), - ( - i83, - i84, - ), - ( - i85, - i86, - ), - ( - i87, - i88, - ), - ( - i89, - i90, - ), - ( - i91, - i92, - ), - ( - i93, - i94, - ), - ( - i95, - i96, - ), - ( - i97, - i98, - ), - ( - i99, - i100, - ), - ( - i101, - i102, - ), - ( - i103, - i104, - ), - ( - i105, - i106, - ), - ( - i107, - i108, - ), - ( - i109, - i110, - ), - ( - i111, - i112, - ), - ( - i113, - i114, - ), - ( - i115, - i116, - ), - ( - i117, - i118, - ), - ( - i119, - i120, - ), - ( - i121, - i122, - ), - ( - i123, - i124, - ), - ( - i125, - i126, - ), - ( - i127, - i128, - ), - ( - i129, - i130, - ), - ( - i131, - i132, - ), - ( - i133, - i134, - ), - ( - i135, - i136, - ), - ( - i137, - i138, - ), - ( - i139, - i140, - ), - ( - i141, - i142, - ), - ( - i143, - i144, - ), - ( - i145, - i146, - ), - ( - i147, - i148, - ), - ( - i149, - i150, - ), - ( - i151, - i152, - ), - ( - i153, - i154, - ), - ( - i155, - i156, - ), - ( - i157, - i158, - ), - ( - i159, - i160, - ), - ) - op = self.model( - input_ids=token, use_cache=True, past_key_values=past_key_values - ) - return_vals = [] - return_vals.append(op.logits) - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class CombinedModel(torch.nn.Module): - def __init__( - self, - first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF", - second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF", - ): - super().__init__() - self.first_vicuna = FirstVicuna(first_vicuna_model_path) - # NOT using this path for 13B currently, hence using `SecondVicuna7B`. - self.second_vicuna = SecondVicuna7B(second_vicuna_model_path) - - def forward(self, input_ids): - first_output = self.first_vicuna(input_ids=input_ids) - # generate second vicuna - compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) - pkv = tuple( - (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) - for _ in range(64) - ) - secondVicunaCompileInput = (compilation_input_ids,) + pkv - second_output = self.second_vicuna(*secondVicunaCompileInput) - return second_output diff --git a/apps/language_models/src/model_wrappers/vicuna_model_gpu.py b/apps/language_models/src/model_wrappers/vicuna_model_gpu.py deleted file mode 100644 index 2e2618a9..00000000 --- a/apps/language_models/src/model_wrappers/vicuna_model_gpu.py +++ /dev/null @@ -1,1173 +0,0 @@ -import torch -from transformers import AutoModelForCausalLM - - -class FirstVicunaGPU(torch.nn.Module): - def __init__( - self, - model_path, - precision="fp32", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - print(f"[DEBUG] model_path : {model_path}") - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("First Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward(self, input_ids): - op = self.model(input_ids=input_ids, use_cache=True) - return_vals = [] - return_vals.append(op.logits) - - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SecondVicuna7BGPU(torch.nn.Module): - def __init__( - self, - model_path, - precision="fp32", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - print(f"[DEBUG] model_path : {model_path}") - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("Second Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward( - self, - i0, - i1, - i2, - i3, - i4, - i5, - i6, - i7, - i8, - i9, - i10, - i11, - i12, - i13, - i14, - i15, - i16, - i17, - i18, - i19, - i20, - i21, - i22, - i23, - i24, - i25, - i26, - i27, - i28, - i29, - i30, - i31, - i32, - i33, - i34, - i35, - i36, - i37, - i38, - i39, - i40, - i41, - i42, - i43, - i44, - i45, - i46, - i47, - i48, - i49, - i50, - i51, - i52, - i53, - i54, - i55, - i56, - i57, - i58, - i59, - i60, - i61, - i62, - i63, - i64, - ): - token = i0 - past_key_values = ( - (i1, i2), - ( - i3, - i4, - ), - ( - i5, - i6, - ), - ( - i7, - i8, - ), - ( - i9, - i10, - ), - ( - i11, - i12, - ), - ( - i13, - i14, - ), - ( - i15, - i16, - ), - ( - i17, - i18, - ), - ( - i19, - i20, - ), - ( - i21, - i22, - ), - ( - i23, - i24, - ), - ( - i25, - i26, - ), - ( - i27, - i28, - ), - ( - i29, - i30, - ), - ( - i31, - i32, - ), - ( - i33, - i34, - ), - ( - i35, - i36, - ), - ( - i37, - i38, - ), - ( - i39, - i40, - ), - ( - i41, - i42, - ), - ( - i43, - i44, - ), - ( - i45, - i46, - ), - ( - i47, - i48, - ), - ( - i49, - i50, - ), - ( - i51, - i52, - ), - ( - i53, - i54, - ), - ( - i55, - i56, - ), - ( - i57, - i58, - ), - ( - i59, - i60, - ), - ( - i61, - i62, - ), - ( - i63, - i64, - ), - ) - op = self.model( - input_ids=token, use_cache=True, past_key_values=past_key_values - ) - return_vals = [] - return_vals.append(op.logits) - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SecondVicuna13BGPU(torch.nn.Module): - def __init__( - self, - model_path, - precision="int8", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("Second Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward( - self, - i0, - i1, - i2, - i3, - i4, - i5, - i6, - i7, - i8, - i9, - i10, - i11, - i12, - i13, - i14, - i15, - i16, - i17, - i18, - i19, - i20, - i21, - i22, - i23, - i24, - i25, - i26, - i27, - i28, - i29, - i30, - i31, - i32, - i33, - i34, - i35, - i36, - i37, - i38, - i39, - i40, - i41, - i42, - i43, - i44, - i45, - i46, - i47, - i48, - i49, - i50, - i51, - i52, - i53, - i54, - i55, - i56, - i57, - i58, - i59, - i60, - i61, - i62, - i63, - i64, - i65, - i66, - i67, - i68, - i69, - i70, - i71, - i72, - i73, - i74, - i75, - i76, - i77, - i78, - i79, - i80, - ): - token = i0 - past_key_values = ( - (i1, i2), - ( - i3, - i4, - ), - ( - i5, - i6, - ), - ( - i7, - i8, - ), - ( - i9, - i10, - ), - ( - i11, - i12, - ), - ( - i13, - i14, - ), - ( - i15, - i16, - ), - ( - i17, - i18, - ), - ( - i19, - i20, - ), - ( - i21, - i22, - ), - ( - i23, - i24, - ), - ( - i25, - i26, - ), - ( - i27, - i28, - ), - ( - i29, - i30, - ), - ( - i31, - i32, - ), - ( - i33, - i34, - ), - ( - i35, - i36, - ), - ( - i37, - i38, - ), - ( - i39, - i40, - ), - ( - i41, - i42, - ), - ( - i43, - i44, - ), - ( - i45, - i46, - ), - ( - i47, - i48, - ), - ( - i49, - i50, - ), - ( - i51, - i52, - ), - ( - i53, - i54, - ), - ( - i55, - i56, - ), - ( - i57, - i58, - ), - ( - i59, - i60, - ), - ( - i61, - i62, - ), - ( - i63, - i64, - ), - ( - i65, - i66, - ), - ( - i67, - i68, - ), - ( - i69, - i70, - ), - ( - i71, - i72, - ), - ( - i73, - i74, - ), - ( - i75, - i76, - ), - ( - i77, - i78, - ), - ( - i79, - i80, - ), - ) - op = self.model( - input_ids=token, use_cache=True, past_key_values=past_key_values - ) - return_vals = [] - return_vals.append(op.logits) - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class SecondVicuna70BGPU(torch.nn.Module): - def __init__( - self, - model_path, - precision="fp32", - accumulates="fp32", - weight_group_size=128, - model_name="vicuna", - hf_auth_token: str = None, - ): - super().__init__() - kwargs = {"torch_dtype": torch.float32} - if "llama2" in model_name: - kwargs["use_auth_token"] = hf_auth_token - self.model = AutoModelForCausalLM.from_pretrained( - model_path, low_cpu_mem_usage=True, **kwargs - ) - self.accumulates = ( - torch.float32 if accumulates == "fp32" else torch.float16 - ) - print(f"[DEBUG] model_path : {model_path}") - if precision in ["int4", "int8"]: - from brevitas_examples.common.generative.quantize import ( - quantize_model, - ) - from brevitas_examples.llm.llm_quant.run_utils import ( - get_model_impl, - ) - - print("Second Vicuna applying weight quantization..") - weight_bit_width = 4 if precision == "int4" else 8 - quantize_model( - get_model_impl(self.model).layers, - dtype=self.accumulates, - weight_bit_width=weight_bit_width, - weight_param_method="stats", - weight_scale_precision="float_scale", - weight_quant_type="asym", - weight_quant_granularity="per_group", - weight_group_size=weight_group_size, - quantize_weight_zero_point=False, - ) - print("Weight quantization applied.") - - def forward( - self, - i0, - i1, - i2, - i3, - i4, - i5, - i6, - i7, - i8, - i9, - i10, - i11, - i12, - i13, - i14, - i15, - i16, - i17, - i18, - i19, - i20, - i21, - i22, - i23, - i24, - i25, - i26, - i27, - i28, - i29, - i30, - i31, - i32, - i33, - i34, - i35, - i36, - i37, - i38, - i39, - i40, - i41, - i42, - i43, - i44, - i45, - i46, - i47, - i48, - i49, - i50, - i51, - i52, - i53, - i54, - i55, - i56, - i57, - i58, - i59, - i60, - i61, - i62, - i63, - i64, - i65, - i66, - i67, - i68, - i69, - i70, - i71, - i72, - i73, - i74, - i75, - i76, - i77, - i78, - i79, - i80, - i81, - i82, - i83, - i84, - i85, - i86, - i87, - i88, - i89, - i90, - i91, - i92, - i93, - i94, - i95, - i96, - i97, - i98, - i99, - i100, - i101, - i102, - i103, - i104, - i105, - i106, - i107, - i108, - i109, - i110, - i111, - i112, - i113, - i114, - i115, - i116, - i117, - i118, - i119, - i120, - i121, - i122, - i123, - i124, - i125, - i126, - i127, - i128, - i129, - i130, - i131, - i132, - i133, - i134, - i135, - i136, - i137, - i138, - i139, - i140, - i141, - i142, - i143, - i144, - i145, - i146, - i147, - i148, - i149, - i150, - i151, - i152, - i153, - i154, - i155, - i156, - i157, - i158, - i159, - i160, - ): - token = i0 - past_key_values = ( - (i1, i2), - ( - i3, - i4, - ), - ( - i5, - i6, - ), - ( - i7, - i8, - ), - ( - i9, - i10, - ), - ( - i11, - i12, - ), - ( - i13, - i14, - ), - ( - i15, - i16, - ), - ( - i17, - i18, - ), - ( - i19, - i20, - ), - ( - i21, - i22, - ), - ( - i23, - i24, - ), - ( - i25, - i26, - ), - ( - i27, - i28, - ), - ( - i29, - i30, - ), - ( - i31, - i32, - ), - ( - i33, - i34, - ), - ( - i35, - i36, - ), - ( - i37, - i38, - ), - ( - i39, - i40, - ), - ( - i41, - i42, - ), - ( - i43, - i44, - ), - ( - i45, - i46, - ), - ( - i47, - i48, - ), - ( - i49, - i50, - ), - ( - i51, - i52, - ), - ( - i53, - i54, - ), - ( - i55, - i56, - ), - ( - i57, - i58, - ), - ( - i59, - i60, - ), - ( - i61, - i62, - ), - ( - i63, - i64, - ), - ( - i65, - i66, - ), - ( - i67, - i68, - ), - ( - i69, - i70, - ), - ( - i71, - i72, - ), - ( - i73, - i74, - ), - ( - i75, - i76, - ), - ( - i77, - i78, - ), - ( - i79, - i80, - ), - ( - i81, - i82, - ), - ( - i83, - i84, - ), - ( - i85, - i86, - ), - ( - i87, - i88, - ), - ( - i89, - i90, - ), - ( - i91, - i92, - ), - ( - i93, - i94, - ), - ( - i95, - i96, - ), - ( - i97, - i98, - ), - ( - i99, - i100, - ), - ( - i101, - i102, - ), - ( - i103, - i104, - ), - ( - i105, - i106, - ), - ( - i107, - i108, - ), - ( - i109, - i110, - ), - ( - i111, - i112, - ), - ( - i113, - i114, - ), - ( - i115, - i116, - ), - ( - i117, - i118, - ), - ( - i119, - i120, - ), - ( - i121, - i122, - ), - ( - i123, - i124, - ), - ( - i125, - i126, - ), - ( - i127, - i128, - ), - ( - i129, - i130, - ), - ( - i131, - i132, - ), - ( - i133, - i134, - ), - ( - i135, - i136, - ), - ( - i137, - i138, - ), - ( - i139, - i140, - ), - ( - i141, - i142, - ), - ( - i143, - i144, - ), - ( - i145, - i146, - ), - ( - i147, - i148, - ), - ( - i149, - i150, - ), - ( - i151, - i152, - ), - ( - i153, - i154, - ), - ( - i155, - i156, - ), - ( - i157, - i158, - ), - ( - i159, - i160, - ), - ) - op = self.model( - input_ids=token, use_cache=True, past_key_values=past_key_values - ) - return_vals = [] - return_vals.append(op.logits) - temp_past_key_values = op.past_key_values - for item in temp_past_key_values: - return_vals.append(item[0]) - return_vals.append(item[1]) - return tuple(return_vals) - - -class CombinedModel(torch.nn.Module): - def __init__( - self, - first_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF", - second_vicuna_model_path="TheBloke/vicuna-7B-1.1-HF", - ): - super().__init__() - self.first_vicuna = FirstVicunaGPU(first_vicuna_model_path) - # NOT using this path for 13B currently, hence using `SecondVicuna7B`. - self.second_vicuna = SecondVicuna7BGPU(second_vicuna_model_path) - - def forward(self, input_ids): - first_output = self.first_vicuna(input_ids=input_ids) - # generate second vicuna - compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64) - pkv = tuple( - (torch.zeros([1, 32, 19, 128], dtype=torch.float32)) - for _ in range(64) - ) - secondVicunaCompileInput = (compilation_input_ids,) + pkv - second_output = self.second_vicuna(*secondVicunaCompileInput) - return second_output diff --git a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py b/apps/language_models/src/model_wrappers/vicuna_sharded_model.py deleted file mode 100644 index 6120b454..00000000 --- a/apps/language_models/src/model_wrappers/vicuna_sharded_model.py +++ /dev/null @@ -1,247 +0,0 @@ -import torch -import time - - -class FirstVicunaLayer(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, hidden_states, attention_mask, position_ids): - outputs = self.model( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - use_cache=True, - ) - next_hidden_states = outputs[0] - past_key_value_out0, past_key_value_out1 = ( - outputs[-1][0], - outputs[-1][1], - ) - - return ( - next_hidden_states, - past_key_value_out0, - past_key_value_out1, - ) - - -class SecondVicunaLayer(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward( - self, - hidden_states, - attention_mask, - position_ids, - past_key_value0, - past_key_value1, - ): - outputs = self.model( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=( - past_key_value0, - past_key_value1, - ), - use_cache=True, - ) - next_hidden_states = outputs[0] - past_key_value_out0, past_key_value_out1 = ( - outputs[-1][0], - outputs[-1][1], - ) - - return ( - next_hidden_states, - past_key_value_out0, - past_key_value_out1, - ) - - -class ShardedVicunaModel(torch.nn.Module): - def __init__(self, model, layers, lmhead, embedding, norm): - super().__init__() - self.model = model - self.model.model.config.use_cache = True - self.model.model.config.output_attentions = False - self.layers = layers - self.norm = norm - self.embedding = embedding - self.lmhead = lmhead - self.model.model.norm = self.norm - self.model.model.embed_tokens = self.embedding - self.model.lm_head = self.lmhead - self.model.model.layers = torch.nn.modules.container.ModuleList( - self.layers - ) - - def forward( - self, - input_ids, - is_first=True, - past_key_values=None, - attention_mask=None, - ): - return self.model.forward( - input_ids, - attention_mask=attention_mask, - past_key_values=past_key_values, - ) - - -class LMHead(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, hidden_states): - output = self.model(hidden_states) - return output - - -class LMHeadCompiled(torch.nn.Module): - def __init__(self, shark_module): - super().__init__() - self.model = shark_module - - def forward(self, hidden_states): - hidden_states_sample = hidden_states.detach() - - output = self.model("forward", (hidden_states,)) - output = torch.tensor(output) - - return output - - -class VicunaNorm(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, hidden_states): - output = self.model(hidden_states) - return output - - -class VicunaNormCompiled(torch.nn.Module): - def __init__(self, shark_module): - super().__init__() - self.model = shark_module - - def forward(self, hidden_states): - try: - hidden_states.detach() - except: - pass - output = self.model("forward", (hidden_states,), send_to_host=True) - output = torch.tensor(output) - - return output - - -class VicunaEmbedding(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, input_ids): - output = self.model(input_ids) - return output - - -class VicunaEmbeddingCompiled(torch.nn.Module): - def __init__(self, shark_module): - super().__init__() - self.model = shark_module - - def forward(self, input_ids): - input_ids.detach() - output = self.model("forward", (input_ids,), send_to_host=True) - output = torch.tensor(output) - - return output - - -class CompiledVicunaLayer(torch.nn.Module): - def __init__(self, shark_module, idx, breakpoints): - super().__init__() - self.model = shark_module - self.idx = idx - self.breakpoints = breakpoints - - def forward( - self, - hidden_states, - attention_mask, - position_ids, - past_key_value=None, - output_attentions=False, - use_cache=True, - ): - if self.breakpoints is None: - is_breakpoint = False - else: - is_breakpoint = self.idx + 1 in self.breakpoints - if past_key_value is None: - output = self.model( - "first_vicuna_forward", - ( - hidden_states, - attention_mask, - position_ids, - ), - send_to_host=is_breakpoint, - ) - - if is_breakpoint: - output0 = torch.tensor(output[0]) - output1 = torch.tensor(output[1]) - output2 = torch.tensor(output[2]) - else: - output0 = output[0] - output1 = output[1] - output2 = output[2] - - return ( - output0, - ( - output1, - output2, - ), - ) - else: - pkv0 = past_key_value[0] - pkv1 = past_key_value[1] - output = self.model( - "second_vicuna_forward", - ( - hidden_states, - attention_mask, - position_ids, - pkv0, - pkv1, - ), - send_to_host=is_breakpoint, - ) - - if is_breakpoint: - output0 = torch.tensor(output[0]) - output1 = torch.tensor(output[1]) - output2 = torch.tensor(output[2]) - else: - output0 = output[0] - output1 = output[1] - output2 = output[2] - - return ( - output0, - ( - output1, - output2, - ), - ) diff --git a/apps/language_models/src/pipelines/SharkLLMBase.py b/apps/language_models/src/pipelines/SharkLLMBase.py deleted file mode 100644 index f33d7703..00000000 --- a/apps/language_models/src/pipelines/SharkLLMBase.py +++ /dev/null @@ -1,44 +0,0 @@ -from abc import ABC, abstractmethod - - -class SharkLLMBase(ABC): - def __init__( - self, - model_name, - hf_model_path=None, - max_num_tokens=512, - ) -> None: - self.model_name = model_name - self.hf_model_path = hf_model_path - self.max_num_tokens = max_num_tokens - self.shark_model = None - self.device = "cpu" - self.precision = "fp32" - - @classmethod - @abstractmethod - def compile(self): - pass - - @classmethod - @abstractmethod - def generate(self, prompt): - pass - - @classmethod - @abstractmethod - def generate_new_token(self, params): - pass - - @classmethod - @abstractmethod - def get_tokenizer(self): - pass - - @classmethod - @abstractmethod - def get_src_model(self): - pass - - def load_init_from_config(self): - pass diff --git a/apps/language_models/src/pipelines/falcon_pipeline.py b/apps/language_models/src/pipelines/falcon_pipeline.py deleted file mode 100644 index ba05efec..00000000 --- a/apps/language_models/src/pipelines/falcon_pipeline.py +++ /dev/null @@ -1,1137 +0,0 @@ -from apps.language_models.src.model_wrappers.falcon_model import FalconModel -from apps.language_models.src.model_wrappers.falcon_sharded_model import ( - WordEmbeddingsLayer, - CompiledWordEmbeddingsLayer, - LNFEmbeddingLayer, - CompiledLNFEmbeddingLayer, - LMHeadEmbeddingLayer, - CompiledLMHeadEmbeddingLayer, - FourWayShardingDecoderLayer, - TwoWayShardingDecoderLayer, - CompiledFourWayShardingDecoderLayer, - CompiledTwoWayShardingDecoderLayer, - ShardedFalconModel, -) -from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase -from apps.language_models.utils import ( - get_vmfb_from_path, -) -from io import BytesIO -from pathlib import Path -from contextlib import redirect_stdout -from shark.shark_downloader import download_public_file -from shark.shark_importer import import_with_fx, save_mlir -from shark.shark_inference import SharkInference -from transformers import AutoTokenizer, AutoModelForCausalLM, GPTQConfig -from transformers.generation import ( - GenerationConfig, - LogitsProcessorList, - StoppingCriteriaList, -) -import copy -import time -import re -import torch -import torch_mlir -import os -import argparse -import gc - -parser = argparse.ArgumentParser( - prog="falcon runner", - description="runs a falcon model", -) - -parser.add_argument( - "--falcon_variant_to_use", default="7b", help="7b, 40b, 180b" -) -parser.add_argument( - "--compressed", - default=False, - action=argparse.BooleanOptionalAction, - help="Do the compression of sharded layers", -) -parser.add_argument( - "--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"] -) -parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") -parser.add_argument( - "--falcon_vmfb_path", default=None, help="path to falcon's vmfb" -) -parser.add_argument( - "--falcon_mlir_path", - default=None, - help="path to falcon's mlir file", -) -parser.add_argument( - "--use_precompiled_model", - default=True, - action=argparse.BooleanOptionalAction, - help="use the precompiled vmfb", -) -parser.add_argument( - "--load_mlir_from_shark_tank", - default=True, - action=argparse.BooleanOptionalAction, - help="download precompile mlir from shark tank", -) -parser.add_argument( - "--cli", - default=True, - action=argparse.BooleanOptionalAction, - help="Run model in cli mode", -) -parser.add_argument( - "--hf_auth_token", - type=str, - default=None, - help="Specify your own huggingface authentication token for falcon-180B model.", -) -parser.add_argument( - "-s", - "--sharded", - default=False, - action=argparse.BooleanOptionalAction, - help="Run model as sharded", -) -parser.add_argument( - "--num_shards", - type=int, - default=4, - choices=[2, 4], - help="Number of shards.", -) - - -class ShardedFalcon(SharkLLMBase): - def __init__( - self, - model_name, - hf_model_path="tiiuae/falcon-7b-instruct", - hf_auth_token: str = None, - max_num_tokens=150, - device="cuda", - precision="fp32", - falcon_mlir_path=None, - falcon_vmfb_path=None, - debug=False, - ) -> None: - super().__init__(model_name, hf_model_path, max_num_tokens) - print("hf_model_path: ", self.hf_model_path) - - if ( - "180b" in self.model_name - and precision != "int4" - and hf_auth_token == None - ): - raise ValueError( - """ HF auth token required for falcon-180b. Pass it using - --hf_auth_token flag. You can ask for the access to the model - here: https://huggingface.co/tiiuae/falcon-180B-chat.""" - ) - - if args.sharded and "180b" not in self.model_name: - raise ValueError("Sharding supported only for Falcon-180B") - - self.hf_auth_token = hf_auth_token - self.max_padding_length = 100 - self.device = device - self.precision = precision - self.falcon_vmfb_path = falcon_vmfb_path - self.falcon_mlir_path = falcon_mlir_path - self.debug = debug - self.tokenizer = self.get_tokenizer() - self.src_model = self.get_src_model() - self.shark_model = self.compile() - - def get_tokenizer(self): - tokenizer = AutoTokenizer.from_pretrained( - self.hf_model_path, - trust_remote_code=True, - token=self.hf_auth_token, - ) - tokenizer.padding_side = "left" - tokenizer.pad_token_id = 11 - return tokenizer - - def get_src_model(self): - print("Loading src model: ", self.model_name) - kwargs = { - "torch_dtype": torch.float32, - "trust_remote_code": True, - "token": self.hf_auth_token, - } - if self.precision == "int4": - quantization_config = GPTQConfig(bits=4, disable_exllama=True) - kwargs["quantization_config"] = quantization_config - kwargs["device_map"] = "cpu" - falcon_model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, **kwargs - ) - return falcon_model - - def compile_layer( - self, layer, falconCompileInput, layer_id, device_idx=None - ): - self.falcon_mlir_path = Path( - f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}.mlir" - ) - self.falcon_vmfb_path = Path( - f"falcon_{args.falcon_variant_to_use}_layer_{layer_id}_{self.precision}_{self.device}.vmfb" - ) - - if args.use_precompiled_model: - if not self.falcon_vmfb_path.exists(): - # Downloading VMFB from shark_tank - print(f"[DEBUG] Trying to download vmfb from shark_tank") - download_public_file( - f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/vmfb/" - + str(self.falcon_vmfb_path), - self.falcon_vmfb_path.absolute(), - single_file=True, - ) - vmfb = get_vmfb_from_path( - self.falcon_vmfb_path, - self.device, - "linalg", - device_id=device_idx, - ) - if vmfb is not None: - return vmfb, device_idx - - print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}") - if self.falcon_mlir_path.exists(): - print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}") - with open(self.falcon_mlir_path, "rb") as f: - bytecode = f.read() - else: - mlir_generated = False - print( - f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}" - ) - if args.load_mlir_from_shark_tank: - # Downloading MLIR from shark_tank - print(f"[DEBUG] Trying to download mlir from shark_tank") - download_public_file( - f"gs://shark_tank/falcon/sharded/falcon_{args.falcon_variant_to_use}/mlir/" - + str(self.falcon_mlir_path), - self.falcon_mlir_path.absolute(), - single_file=True, - ) - if self.falcon_mlir_path.exists(): - print( - f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}" - ) - with open(self.falcon_mlir_path, "rb") as f: - bytecode = f.read() - mlir_generated = True - - if not mlir_generated: - print(f"[DEBUG] generating MLIR locally") - if layer_id == "word_embeddings": - f16_input_mask = [False] - elif layer_id in ["ln_f", "lm_head"]: - f16_input_mask = [True] - elif "_" in layer_id or type(layer_id) == int: - f16_input_mask = [True, True] - else: - raise ValueError("Unsupported layer: ", layer_id) - - print(f"[DEBUG] generating torchscript graph") - ts_graph = import_with_fx( - layer, - falconCompileInput, - is_f16=True, - f16_input_mask=f16_input_mask, - mlir_type="torchscript", - is_gptq=True, - ) - del layer - - print(f"[DEBUG] generating torch mlir") - module = torch_mlir.compile( - ts_graph, - falconCompileInput, - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - del ts_graph - - print(f"[DEBUG] converting to bytecode") - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - del module - - f_ = open(self.falcon_mlir_path, "wb") - f_.write(bytecode) - print("Saved falcon mlir at ", str(self.falcon_mlir_path)) - f_.close() - del bytecode - - shark_module = SharkInference( - mlir_module=self.falcon_mlir_path, - device=self.device, - mlir_dialect="linalg", - device_idx=device_idx, - ) - path = shark_module.save_module( - self.falcon_vmfb_path.parent.absolute(), - self.falcon_vmfb_path.stem, - extra_args=[ - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - + [ - "--iree-llvmcpu-use-fast-min-max-ops", - ] - if self.precision == "int4" - else [], - debug=self.debug, - ) - print("Saved falcon vmfb at ", str(path)) - shark_module.load_module(path) - - return shark_module, device_idx - - def compile(self): - sample_input_ids = torch.zeros([100], dtype=torch.int64) - sample_attention_mask = torch.zeros( - [1, 1, 100, 100], dtype=torch.float32 - ) - num_group_layers = int( - 20 * (4 / args.num_shards) - ) # 4 is the number of default shards - sample_hidden_states = torch.zeros( - [1, 100, 14848], dtype=torch.float32 - ) - - # Determine number of available devices - num_devices = 1 - if self.device == "rocm": - import iree.runtime as ireert - - haldriver = ireert.get_driver(self.device) - num_devices = len(haldriver.query_available_devices()) - if num_devices < 2: - raise ValueError( - "Cannot run Falcon-180B on a single ROCM device." - ) - - lm_head = LMHeadEmbeddingLayer(self.src_model.lm_head) - print("Compiling Layer lm_head") - shark_lm_head, _ = self.compile_layer( - lm_head, - [sample_hidden_states], - "lm_head", - device_idx=(0 % num_devices) % args.num_shards - if self.device == "rocm" - else None, - ) - shark_lm_head = CompiledLMHeadEmbeddingLayer(shark_lm_head) - - word_embedding = WordEmbeddingsLayer( - self.src_model.transformer.word_embeddings - ) - print("Compiling Layer word_embeddings") - shark_word_embedding, _ = self.compile_layer( - word_embedding, - [sample_input_ids], - "word_embeddings", - device_idx=(1 % num_devices) % args.num_shards - if self.device == "rocm" - else None, - ) - shark_word_embedding = CompiledWordEmbeddingsLayer( - shark_word_embedding - ) - - ln_f = LNFEmbeddingLayer(self.src_model.transformer.ln_f) - print("Compiling Layer ln_f") - shark_ln_f, _ = self.compile_layer( - ln_f, - [sample_hidden_states], - "ln_f", - device_idx=(2 % num_devices) % args.num_shards - if self.device == "rocm" - else None, - ) - shark_ln_f = CompiledLNFEmbeddingLayer(shark_ln_f) - - shark_layers = [] - for i in range( - int(len(self.src_model.transformer.h) / num_group_layers) - ): - device_idx = i % num_devices if self.device == "rocm" else None - layer_id = i - layer_id = ( - str(i * num_group_layers) - + "_" - + str((i + 1) * num_group_layers) - ) - pytorch_class = FourWayShardingDecoderLayer - compiled_class = CompiledFourWayShardingDecoderLayer - if args.num_shards == 2: - pytorch_class = TwoWayShardingDecoderLayer - compiled_class = CompiledTwoWayShardingDecoderLayer - - print("Compiling Layer {}".format(layer_id)) - layer_i = self.src_model.transformer.h[ - i * num_group_layers : (i + 1) * num_group_layers - ] - - pytorch_layer_i = pytorch_class( - layer_i, args.falcon_variant_to_use - ) - shark_module, device_idx = self.compile_layer( - pytorch_layer_i, - [sample_hidden_states, sample_attention_mask], - layer_id, - device_idx=device_idx, - ) - shark_layer_i = compiled_class( - layer_id, - device_idx, - args.falcon_variant_to_use, - self.device, - self.precision, - shark_module, - ) - shark_layers.append(shark_layer_i) - - sharded_model = ShardedFalconModel( - self.src_model, - shark_layers, - shark_word_embedding, - shark_ln_f, - shark_lm_head, - ) - return sharded_model - - def generate(self, prompt): - model_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.max_padding_length, - add_special_tokens=False, - return_tensors="pt", - ) - model_inputs["prompt_text"] = prompt - - input_ids = model_inputs["input_ids"] - attention_mask = model_inputs.get("attention_mask", None) - - # Allow empty prompts - if input_ids.shape[1] == 0: - input_ids = None - attention_mask = None - - generate_kwargs = { - "max_length": self.max_num_tokens, - "do_sample": True, - "top_k": 10, - "num_return_sequences": 1, - "eos_token_id": 11, - } - generate_kwargs["input_ids"] = input_ids - generate_kwargs["attention_mask"] = attention_mask - generation_config_ = GenerationConfig.from_model_config( - self.src_model.config - ) - generation_config = copy.deepcopy(generation_config_) - model_kwargs = generation_config.update(**generate_kwargs) - - logits_processor = LogitsProcessorList() - stopping_criteria = StoppingCriteriaList() - - eos_token_id = generation_config.eos_token_id - generation_config.pad_token_id = eos_token_id - - ( - inputs_tensor, - model_input_name, - model_kwargs, - ) = self.src_model._prepare_model_inputs( - None, generation_config.bos_token_id, model_kwargs - ) - - model_kwargs["output_attentions"] = generation_config.output_attentions - model_kwargs[ - "output_hidden_states" - ] = generation_config.output_hidden_states - model_kwargs["use_cache"] = generation_config.use_cache - - input_ids = ( - inputs_tensor - if model_input_name == "input_ids" - else model_kwargs.pop("input_ids") - ) - - self.logits_processor = self.src_model._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids.shape[-1], - encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=None, - logits_processor=logits_processor, - ) - - self.stopping_criteria = self.src_model._get_stopping_criteria( - generation_config=generation_config, - stopping_criteria=stopping_criteria, - ) - - self.logits_warper = self.src_model._get_logits_warper( - generation_config - ) - - ( - self.input_ids, - self.model_kwargs, - ) = self.src_model._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, # 1 - is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False - **model_kwargs, - ) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - self.eos_token_id_tensor = ( - torch.tensor(eos_token_id) if eos_token_id is not None else None - ) - - self.pad_token_id = generation_config.pad_token_id - self.eos_token_id = eos_token_id - - output_scores = generation_config.output_scores # False - return_dict_in_generate = ( - generation_config.return_dict_in_generate # False - ) - - # init attention / hidden states / scores tuples - self.scores = ( - () if (return_dict_in_generate and output_scores) else None - ) - - # keep track of which sequences are already finished - self.unfinished_sequences = torch.ones( - input_ids.shape[0], dtype=torch.long, device=input_ids.device - ) - - all_text = prompt - - start = time.time() - count = 0 - for i in range(self.max_num_tokens - 1): - count = count + 1 - - next_token = self.generate_new_token() - new_word = self.tokenizer.decode( - next_token.cpu().numpy(), - add_special_tokens=False, - skip_special_tokens=True, - clean_up_tokenization_spaces=True, - ) - - all_text = all_text + new_word - - print(f"{new_word}", end="", flush=True) - print(f"{all_text}", end="", flush=True) - - # if eos_token was found in one sentence, set sentence to finished - if self.eos_token_id_tensor is not None: - self.unfinished_sequences = self.unfinished_sequences.mul( - next_token.tile(self.eos_token_id_tensor.shape[0], 1) - .ne(self.eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - ) - # stop when each sentence is finished - if ( - self.unfinished_sequences.max() == 0 - or self.stopping_criteria(input_ids, self.scores) - ): - break - - end = time.time() - print( - "\n\nTime taken is {:.2f} seconds/token\n".format( - (end - start) / count - ) - ) - - torch.cuda.empty_cache() - gc.collect() - - return all_text - - def generate_new_token(self): - model_inputs = self.src_model.prepare_inputs_for_generation( - self.input_ids, **self.model_kwargs - ) - outputs = self.shark_model.forward( - input_ids=model_inputs["input_ids"], - attention_mask=model_inputs["attention_mask"], - ) - if self.precision in ["fp16", "int4"]: - outputs = outputs.to(dtype=torch.float32) - next_token_logits = outputs - - # pre-process distribution - next_token_scores = self.logits_processor( - self.input_ids, next_token_logits - ) - next_token_scores = self.logits_warper( - self.input_ids, next_token_scores - ) - - # sample - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - - next_token = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if self.eos_token_id is not None: - if self.pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_token = ( - next_token * self.unfinished_sequences - + self.pad_token_id * (1 - self.unfinished_sequences) - ) - - self.input_ids = torch.cat( - [self.input_ids, next_token[:, None]], dim=-1 - ) - - self.model_kwargs["past_key_values"] = None - if "attention_mask" in self.model_kwargs: - attention_mask = self.model_kwargs["attention_mask"] - self.model_kwargs["attention_mask"] = torch.cat( - [ - attention_mask, - attention_mask.new_ones((attention_mask.shape[0], 1)), - ], - dim=-1, - ) - - self.input_ids = self.input_ids[:, 1:] - self.model_kwargs["attention_mask"] = self.model_kwargs[ - "attention_mask" - ][:, 1:] - - return next_token - - -class UnshardedFalcon(SharkLLMBase): - def __init__( - self, - model_name, - hf_model_path="tiiuae/falcon-7b-instruct", - hf_auth_token: str = "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", - max_num_tokens=150, - device="cuda", - precision="fp32", - falcon_mlir_path=None, - falcon_vmfb_path=None, - debug=False, - ) -> None: - super().__init__(model_name, hf_model_path, max_num_tokens) - print("hf_model_path: ", self.hf_model_path) - - if "180b" in self.model_name and hf_auth_token == None: - raise ValueError( - """ HF auth token required for falcon-180b. Pass it using - --hf_auth_token flag. You can ask for the access to the model - here: https://huggingface.co/tiiuae/falcon-180B-chat.""" - ) - self.hf_auth_token = hf_auth_token - self.max_padding_length = 100 - self.device = device - self.precision = precision - self.falcon_vmfb_path = falcon_vmfb_path - self.falcon_mlir_path = falcon_mlir_path - self.debug = debug - self.tokenizer = self.get_tokenizer() - self.src_model = self.get_src_model() - self.shark_model = self.compile() - - def get_tokenizer(self): - tokenizer = AutoTokenizer.from_pretrained( - self.hf_model_path, - trust_remote_code=True, - token=self.hf_auth_token, - ) - tokenizer.padding_side = "left" - tokenizer.pad_token_id = 11 - return tokenizer - - def get_src_model(self): - print("Loading src model: ", self.model_name) - kwargs = { - "torch_dtype": torch.float32, - "trust_remote_code": True, - "token": self.hf_auth_token, - } - if self.precision == "int4": - quantization_config = GPTQConfig(bits=4, disable_exllama=True) - kwargs["quantization_config"] = quantization_config - kwargs["device_map"] = "cpu" - falcon_model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, **kwargs - ) - return falcon_model - - def compile(self): - if args.use_precompiled_model: - if not self.falcon_vmfb_path.exists(): - # Downloading VMFB from shark_tank - download_public_file( - "gs://shark_tank/falcon/" - + "falcon_" - + args.falcon_variant_to_use - + "_" - + self.precision - + "_" - + self.device - + ".vmfb", - self.falcon_vmfb_path.absolute(), - single_file=True, - ) - vmfb = get_vmfb_from_path( - self.falcon_vmfb_path, self.device, "linalg" - ) - if vmfb is not None: - return vmfb - - print(f"[DEBUG] vmfb not found at {self.falcon_vmfb_path.absolute()}") - if self.falcon_mlir_path.exists(): - print(f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}") - with open(self.falcon_mlir_path, "rb") as f: - bytecode = f.read() - else: - mlir_generated = False - print( - f"[DEBUG] mlir not found at {self.falcon_mlir_path.absolute()}" - ) - if args.load_mlir_from_shark_tank: - # Downloading MLIR from shark_tank - print(f"[DEBUG] Trying to download mlir from shark_tank") - download_public_file( - "gs://shark_tank/falcon/" - + "falcon_" - + args.falcon_variant_to_use - + "_" - + self.precision - + ".mlir", - self.falcon_mlir_path.absolute(), - single_file=True, - ) - if self.falcon_mlir_path.exists(): - print( - f"[DEBUG] mlir found at {self.falcon_mlir_path.absolute()}" - ) - mlir_generated = True - - if not mlir_generated: - print(f"[DEBUG] generating MLIR locally") - compilation_input_ids = torch.randint( - low=1, high=10000, size=(1, 100) - ) - compilation_attention_mask = torch.ones( - 1, 100, dtype=torch.int64 - ) - falconCompileInput = ( - compilation_input_ids, - compilation_attention_mask, - ) - model = FalconModel(self.src_model) - - print(f"[DEBUG] generating torchscript graph") - ts_graph = import_with_fx( - model, - falconCompileInput, - is_f16=self.precision in ["fp16", "int4"], - f16_input_mask=[False, False], - mlir_type="torchscript", - is_gptq=self.precision == "int4", - ) - del model - print(f"[DEBUG] generating torch mlir") - - module = torch_mlir.compile( - ts_graph, - [*falconCompileInput], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - del ts_graph - - print(f"[DEBUG] converting to bytecode") - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - del module - - f_ = open(self.falcon_mlir_path, "wb") - f_.write(bytecode) - print("Saved falcon mlir at ", str(self.falcon_mlir_path)) - f_.close() - del bytecode - - shark_module = SharkInference( - mlir_module=self.falcon_mlir_path, - device=self.device, - mlir_dialect="linalg", - ) - path = shark_module.save_module( - self.falcon_vmfb_path.parent.absolute(), - self.falcon_vmfb_path.stem, - extra_args=[ - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - + [ - "--iree-llvmcpu-use-fast-min-max-ops", - ] - if self.precision == "int4" - else [], - debug=self.debug, - ) - print("Saved falcon vmfb at ", str(path)) - shark_module.load_module(path) - - return shark_module - - def generate(self, prompt): - model_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.max_padding_length, - add_special_tokens=False, - return_tensors="pt", - ) - model_inputs["prompt_text"] = prompt - - input_ids = model_inputs["input_ids"] - attention_mask = model_inputs.get("attention_mask", None) - - # Allow empty prompts - if input_ids.shape[1] == 0: - input_ids = None - attention_mask = None - in_b = 1 - else: - in_b = input_ids.shape[0] - - generate_kwargs = { - "max_length": self.max_num_tokens, - "do_sample": True, - "top_k": 10, - "num_return_sequences": 1, - "eos_token_id": 11, - } - generate_kwargs["input_ids"] = input_ids - generate_kwargs["attention_mask"] = attention_mask - generation_config_ = GenerationConfig.from_model_config( - self.src_model.config - ) - generation_config = copy.deepcopy(generation_config_) - model_kwargs = generation_config.update(**generate_kwargs) - - logits_processor = LogitsProcessorList() - stopping_criteria = StoppingCriteriaList() - - eos_token_id = generation_config.eos_token_id - generation_config.pad_token_id = eos_token_id - - ( - inputs_tensor, - model_input_name, - model_kwargs, - ) = self.src_model._prepare_model_inputs( - None, generation_config.bos_token_id, model_kwargs - ) - batch_size = inputs_tensor.shape[0] - - model_kwargs["output_attentions"] = generation_config.output_attentions - model_kwargs[ - "output_hidden_states" - ] = generation_config.output_hidden_states - model_kwargs["use_cache"] = generation_config.use_cache - - input_ids = ( - inputs_tensor - if model_input_name == "input_ids" - else model_kwargs.pop("input_ids") - ) - - self.logits_processor = self.src_model._get_logits_processor( - generation_config=generation_config, - input_ids_seq_length=input_ids.shape[-1], - encoder_input_ids=inputs_tensor, - prefix_allowed_tokens_fn=None, - logits_processor=logits_processor, - ) - - self.stopping_criteria = self.src_model._get_stopping_criteria( - generation_config=generation_config, - stopping_criteria=stopping_criteria, - ) - - self.logits_warper = self.src_model._get_logits_warper( - generation_config - ) - - ( - self.input_ids, - self.model_kwargs, - ) = self.src_model._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, # 1 - is_encoder_decoder=self.src_model.config.is_encoder_decoder, # False - **model_kwargs, - ) - - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - self.eos_token_id_tensor = ( - torch.tensor(eos_token_id) if eos_token_id is not None else None - ) - - self.pad_token_id = generation_config.pad_token_id - self.eos_token_id = eos_token_id - - output_scores = generation_config.output_scores # False - output_attentions = generation_config.output_attentions # False - output_hidden_states = generation_config.output_hidden_states # False - return_dict_in_generate = ( - generation_config.return_dict_in_generate # False - ) - - # init attention / hidden states / scores tuples - self.scores = ( - () if (return_dict_in_generate and output_scores) else None - ) - decoder_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - cross_attentions = ( - () if (return_dict_in_generate and output_attentions) else None - ) - decoder_hidden_states = ( - () if (return_dict_in_generate and output_hidden_states) else None - ) - - # keep track of which sequences are already finished - self.unfinished_sequences = torch.ones( - input_ids.shape[0], dtype=torch.long, device=input_ids.device - ) - - all_text = prompt - - start = time.time() - count = 0 - for i in range(self.max_num_tokens - 1): - count = count + 1 - - next_token = self.generate_new_token() - new_word = self.tokenizer.decode( - next_token.cpu().numpy(), - add_special_tokens=False, - skip_special_tokens=True, - clean_up_tokenization_spaces=True, - ) - - all_text = all_text + new_word - - print(f"{new_word}", end="", flush=True) - - # if eos_token was found in one sentence, set sentence to finished - if self.eos_token_id_tensor is not None: - self.unfinished_sequences = self.unfinished_sequences.mul( - next_token.tile(self.eos_token_id_tensor.shape[0], 1) - .ne(self.eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - ) - # stop when each sentence is finished - if ( - self.unfinished_sequences.max() == 0 - or self.stopping_criteria(input_ids, self.scores) - ): - break - - end = time.time() - print( - "\n\nTime taken is {:.2f} seconds/token\n".format( - (end - start) / count - ) - ) - - torch.cuda.empty_cache() - gc.collect() - - return all_text - - def generate_new_token(self): - model_inputs = self.src_model.prepare_inputs_for_generation( - self.input_ids, **self.model_kwargs - ) - outputs = torch.from_numpy( - self.shark_model( - "forward", - (model_inputs["input_ids"], model_inputs["attention_mask"]), - ) - ) - if self.precision in ["fp16", "int4"]: - outputs = outputs.to(dtype=torch.float32) - next_token_logits = outputs - - # pre-process distribution - next_token_scores = self.logits_processor( - self.input_ids, next_token_logits - ) - next_token_scores = self.logits_warper( - self.input_ids, next_token_scores - ) - - # sample - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - - next_token = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if self.eos_token_id is not None: - if self.pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_token = ( - next_token * self.unfinished_sequences - + self.pad_token_id * (1 - self.unfinished_sequences) - ) - - self.input_ids = torch.cat( - [self.input_ids, next_token[:, None]], dim=-1 - ) - - self.model_kwargs["past_key_values"] = None - if "attention_mask" in self.model_kwargs: - attention_mask = self.model_kwargs["attention_mask"] - self.model_kwargs["attention_mask"] = torch.cat( - [ - attention_mask, - attention_mask.new_ones((attention_mask.shape[0], 1)), - ], - dim=-1, - ) - - self.input_ids = self.input_ids[:, 1:] - self.model_kwargs["attention_mask"] = self.model_kwargs[ - "attention_mask" - ][:, 1:] - - return next_token - - -if __name__ == "__main__": - args = parser.parse_args() - - falcon_mlir_path = ( - Path( - "falcon_" - + args.falcon_variant_to_use - + "_" - + args.precision - + ".mlir" - ) - if args.falcon_mlir_path is None - else Path(args.falcon_mlir_path) - ) - falcon_vmfb_path = ( - Path( - "falcon_" - + args.falcon_variant_to_use - + "_" - + args.precision - + "_" - + args.device - + ".vmfb" - ) - if args.falcon_vmfb_path is None - else Path(args.falcon_vmfb_path) - ) - - if args.precision == "int4": - if args.falcon_variant_to_use == "180b": - hf_model_path_value = "TheBloke/Falcon-180B-Chat-GPTQ" - else: - hf_model_path_value = ( - "TheBloke/falcon-" - + args.falcon_variant_to_use - + "-instruct-GPTQ" - ) - else: - if args.falcon_variant_to_use == "180b": - hf_model_path_value = "tiiuae/falcon-180B-chat" - else: - hf_model_path_value = ( - "tiiuae/falcon-" + args.falcon_variant_to_use + "-instruct" - ) - - if not args.sharded: - falcon = UnshardedFalcon( - model_name="falcon_" + args.falcon_variant_to_use, - hf_model_path=hf_model_path_value, - device=args.device, - precision=args.precision, - falcon_mlir_path=falcon_mlir_path, - falcon_vmfb_path=falcon_vmfb_path, - ) - else: - falcon = ShardedFalcon( - model_name="falcon_" + args.falcon_variant_to_use, - hf_model_path=hf_model_path_value, - device=args.device, - precision=args.precision, - ) - - default_prompt_text = "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:" - continue_execution = True - - print("\n-----\nScript executing for the following config: \n") - print("Falcon Model: ", falcon.model_name) - print("Precision: ", args.precision) - print("Device: ", args.device) - - while continue_execution: - use_default_prompt = input( - "\nDo you wish to use the default prompt text? Y/N ?: " - ) - if use_default_prompt in ["Y", "y"]: - prompt = default_prompt_text - else: - prompt = input("Please enter the prompt text: ") - print("\nPrompt Text: ", prompt) - - prompt_template = f"""A helpful assistant who helps the user with any questions asked. - User: {prompt} - Assistant:""" - - res_str = falcon.generate(prompt_template) - torch.cuda.empty_cache() - gc.collect() - print( - "\n\n-----\nHere's the complete formatted result: \n\n", - res_str, - ) - continue_execution = input( - "\nDo you wish to run script one more time? Y/N ?: " - ) - continue_execution = ( - True if continue_execution in ["Y", "y"] else False - ) diff --git a/apps/language_models/src/pipelines/minigpt4_pipeline.py b/apps/language_models/src/pipelines/minigpt4_pipeline.py deleted file mode 100644 index 98ab9ed6..00000000 --- a/apps/language_models/src/pipelines/minigpt4_pipeline.py +++ /dev/null @@ -1,1449 +0,0 @@ -from apps.language_models.src.model_wrappers.minigpt4 import ( - LayerNorm, - VisionModel, - QformerBertModel, - FirstLlamaModel, - SecondLlamaModel, - StoppingCriteriaSub, - CONV_VISION, -) -from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase -from apps.language_models.utils import ( - get_vmfb_from_path, - get_vmfb_from_config, -) -from omegaconf import OmegaConf -from pathlib import Path -from shark.shark_downloader import download_public_file -from transformers import LlamaTokenizer, LlamaForCausalLM -from transformers import AutoTokenizer, AutoModelForCausalLM -from transformers import StoppingCriteriaList -from transformers.generation import GenerationConfig, LogitsProcessorList - -import re -import torch -import os -from PIL import Image -import sys -import requests - -# SHARK dependencies -from shark.shark_compile import ( - shark_compile_through_fx, -) -import random -import contextlib -from transformers import BertTokenizer -from transformers.generation import GenerationConfig, LogitsProcessorList -import copy -import tempfile - -# QFormer, eva_vit, blip_processor -from apps.language_models.src.pipelines.minigpt4_utils.Qformer import ( - BertConfig, - BertLMHeadModel, -) -from apps.language_models.src.pipelines.minigpt4_utils.eva_vit import ( - create_eva_vit_g, -) -from apps.language_models.src.pipelines.minigpt4_utils.blip_processors import ( - Blip2ImageEvalProcessor, -) - -import argparse - -parser = argparse.ArgumentParser( - prog="MiniGPT4 runner", - description="runs MiniGPT4", -) - -parser.add_argument( - "--precision", "-p", default="fp16", help="fp32, fp16, int8, int4" -) -parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") -parser.add_argument( - "--vision_model_vmfb_path", - default=None, - help="path to vision model's vmfb", -) -parser.add_argument( - "--qformer_vmfb_path", - default=None, - help="path to qformer model's vmfb", -) -parser.add_argument( - "--image_path", - type=str, - default="", - help="path to the input image", -) -parser.add_argument( - "--load_mlir_from_shark_tank", - default=False, - action=argparse.BooleanOptionalAction, - help="download precompile mlir from shark tank", -) -parser.add_argument( - "--cli", - default=True, - action=argparse.BooleanOptionalAction, - help="Run model in cli mode", -) -parser.add_argument( - "--compile", - default=False, - action=argparse.BooleanOptionalAction, - help="Compile all models", -) -parser.add_argument( - "--max_length", - type=int, - default=2000, - help="Max length of the entire conversation", -) -parser.add_argument( - "--max_new_tokens", - type=int, - default=300, - help="Maximum no. of new tokens that can be generated for a query", -) - - -def disabled_train(self, mode=True): - """Overwrite model.train with this function to make sure train/eval mode - does not change anymore.""" - return self - - -def is_url(input_url): - """ - Check if an input string is a url. look for http(s):// and ignoring the case - """ - is_url = re.match(r"^(?:http)s?://", input_url, re.IGNORECASE) is not None - return is_url - - -import os -import tempfile -from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx, save_mlir -import torch -import torch_mlir -from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from typing import List, Tuple -from io import BytesIO -from brevitas_examples.common.generative.quantize import quantize_model -from brevitas_examples.llm.llm_quant.run_utils import get_model_impl - - -# fmt: off -def quant〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]: - if len(lhs) == 3 and len(rhs) == 2: - return [lhs[0], lhs[1], rhs[0]] - elif len(lhs) == 2 and len(rhs) == 2: - return [lhs[0], rhs[0]] - else: - raise ValueError("Input shapes not supported.") - - -def quant〇matmul_rhs_group_quant〡dtype(lhs_rank_dtype: Tuple[int, int], rhs_rank_dtype: Tuple[int, int], rhs_scale_rank_dtype: Tuple[int, int], rhs_zero_point_rank_dtype: Tuple[int, int], rhs_bit_width: int, rhs_group_size: int) -> int: - # output dtype is the dtype of the lhs float input - lhs_rank, lhs_dtype = lhs_rank_dtype - return lhs_dtype - - -def quant〇matmul_rhs_group_quant〡has_value_semantics(lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size) -> None: - return - - -brevitas_matmul_rhs_group_quant_library = [ - quant〇matmul_rhs_group_quant〡shape, - quant〇matmul_rhs_group_quant〡dtype, - quant〇matmul_rhs_group_quant〡has_value_semantics] -# fmt: on - - -def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]): - vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb") - shark_module = None - if os.path.isfile(vmfb_path): - shark_module = SharkInference( - None, - device=device, - mlir_dialect=mlir_dialect, - ) - print(f"loading existing vmfb from: {vmfb_path}") - shark_module.load_module(vmfb_path, extra_args=extra_args) - return shark_module - - -def compile_module( - shark_module, extended_model_name, generate_vmfb, extra_args=[], debug=False, -): - if generate_vmfb: - vmfb_path = os.path.join(os.getcwd(), extended_model_name + ".vmfb") - if os.path.isfile(vmfb_path): - print(f"loading existing vmfb from: {vmfb_path}") - shark_module.load_module(vmfb_path, extra_args=extra_args) - else: - print( - "No vmfb found. Compiling and saving to {}".format(vmfb_path) - ) - path = shark_module.save_module( - os.getcwd(), extended_model_name, extra_args, debug=debug - ) - shark_module.load_module(path, extra_args=extra_args) - else: - shark_module.compile(extra_args) - return shark_module - - -def compile_int_precision( - model, inputs, precision, device, generate_vmfb, extended_model_name, debug=False -): - torchscript_module = import_with_fx( - model, - inputs, - precision=precision, - mlir_type="torchscript", - ) - mlir_module = torch_mlir.compile( - torchscript_module, - inputs, - output_type="torch", - backend_legal_ops=["quant.matmul_rhs_group_quant"], - extra_library=brevitas_matmul_rhs_group_quant_library, - use_tracing=False, - verbose=False, - ) - print(f"[DEBUG] converting torch to linalg") - run_pipeline_with_repro_report( - mlir_module, - "builtin.module(func.func(torch-unpack-quant-tensor),func.func(torch-convert-custom-quant-op),torch-backend-to-linalg-on-tensors-backend-pipeline)", - description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR", - ) - from contextlib import redirect_stdout - - mlir_file_path = os.path.join( - os.getcwd(), f"{extended_model_name}_linalg.mlir" - ) - with open(mlir_file_path, "w") as f: - with redirect_stdout(f): - print(mlir_module.operation.get_asm()) - mlir_module = str(mlir_module) - mlir_module = mlir_module.encode("UTF-8") - mlir_module = BytesIO(mlir_module) - bytecode = mlir_module.read() - print(f"Elided IR written for {extended_model_name}") - bytecode = save_mlir( - bytecode, - model_name=extended_model_name, - frontend="torch", - dir=os.getcwd(), - ) - return bytecode - shark_module = SharkInference( - mlir_module=bytecode, device=device, mlir_dialect="tm_tensor" - ) - extra_args = [ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - return ( - compile_module( - shark_module, - extended_model_name=extended_model_name, - generate_vmfb=generate_vmfb, - extra_args=extra_args, - debug=debug, - ), - bytecode, - ) - - -def shark_compile_through_fx_int( - model, - inputs, - extended_model_name, - precision, - f16_input_mask=None, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=None, - mlir_dialect="tm_tensor", -): - if generate_or_load_vmfb: - shark_module = load_vmfb( - extended_model_name=extended_model_name, - device=device, - mlir_dialect=mlir_dialect, - extra_args=extra_args, - ) - if shark_module: - return ( - shark_module, - None, - ) - - from shark.parser import shark_args - - if "cuda" in device: - shark_args.enable_tf32 = True - - mlir_module = compile_int_precision( - model, - inputs, - precision, - device, - generate_or_load_vmfb, - extended_model_name, - debug, - ) - extra_args = [ - "--iree-hal-dump-executable-sources-to=ies", - "--iree-vm-target-truncate-unsupported-floats", - "--iree-codegen-check-ir-before-llvm-conversion=false", - "--iree-vm-bytecode-module-output-format=flatbuffer-binary", - ] - - shark_module = SharkInference( - mlir_module, - device=device, - mlir_dialect=mlir_dialect, - ) - return ( - compile_module( - shark_module, - extended_model_name, - generate_vmfb=generate_or_load_vmfb, - extra_args=extra_args, - ), - mlir_module, - ) - - -class MiniGPT4BaseModel(torch.nn.Module): - @classmethod - def from_config(cls, cfg): - vit_model = cfg.get("vit_model", "eva_clip_g") - q_former_model = cfg.get( - "q_former_model", - "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", - ) - img_size = cfg.get("image_size") - num_query_token = cfg.get("num_query_token") - llama_model = cfg.get("llama_model") - - drop_path_rate = cfg.get("drop_path_rate", 0) - use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) - vit_precision = cfg.get("vit_precision", "fp16") - freeze_vit = cfg.get("freeze_vit", True) - freeze_qformer = cfg.get("freeze_qformer", True) - low_resource = cfg.get("low_resource", False) - device_8bit = cfg.get("device_8bit", 0) - - prompt_path = cfg.get("prompt_path", "") - prompt_template = cfg.get("prompt_template", "") - max_txt_len = cfg.get("max_txt_len", 32) - end_sym = cfg.get("end_sym", "\n") - - model = cls( - vit_model=vit_model, - q_former_model=q_former_model, - img_size=img_size, - drop_path_rate=drop_path_rate, - use_grad_checkpoint=use_grad_checkpoint, - vit_precision=vit_precision, - freeze_vit=freeze_vit, - freeze_qformer=freeze_qformer, - num_query_token=num_query_token, - llama_model=llama_model, - prompt_path=prompt_path, - prompt_template=prompt_template, - max_txt_len=max_txt_len, - end_sym=end_sym, - low_resource=low_resource, - device_8bit=device_8bit, - ) - - ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 - if ckpt_path: - print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path)) - ckpt = torch.load(ckpt_path, map_location="cpu") - model.load_state_dict(ckpt["model"], strict=False) - - return model - - PRETRAINED_MODEL_CONFIG_DICT = { - "pretrain_vicuna": "minigpt4_utils/configs/minigpt4.yaml", - } - - def maybe_autocast(self, dtype=torch.float32): - # if on cpu, don't use autocast - # if on gpu, use autocast with dtype if provided, otherwise use torch.float16 - # enable_autocast = self.device != torch.device("cpu") - enable_autocast = True - - if enable_autocast: - return torch.cuda.amp.autocast(dtype=dtype) - else: - return contextlib.nullcontext() - - def init_tokenizer(cls): - tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") - tokenizer.add_special_tokens({"bos_token": "[DEC]"}) - return tokenizer - - def init_vision_encoder( - self, - model_name, - img_size, - drop_path_rate, - use_grad_checkpoint, - precision, - ): - assert ( - model_name == "eva_clip_g" - ), "vit model must be eva_clip_g for current version of MiniGPT-4" - visual_encoder = create_eva_vit_g( - img_size, drop_path_rate, use_grad_checkpoint, precision - ) - - ln_vision = LayerNorm(visual_encoder.num_features) - return visual_encoder, ln_vision - - def init_Qformer( - cls, num_query_token, vision_width, cross_attention_freq=2 - ): - encoder_config = BertConfig.from_pretrained("bert-base-uncased") - encoder_config.encoder_width = vision_width - # insert cross-attention layer every other block - encoder_config.add_cross_attention = True - encoder_config.cross_attention_freq = cross_attention_freq - encoder_config.query_length = num_query_token - Qformer = BertLMHeadModel(config=encoder_config) - query_tokens = torch.nn.Parameter( - torch.zeros(1, num_query_token, encoder_config.hidden_size) - ) - query_tokens.data.normal_( - mean=0.0, std=encoder_config.initializer_range - ) - return Qformer, query_tokens - - def load_from_pretrained(self, url_or_filename): - if is_url(url_or_filename): - local_filename = "blip2_pretrained_flant5xxl.pth" - response = requests.get(url_or_filename) - if response.status_code == 200: - with open(local_filename, "wb") as f: - f.write(response.content) - print("File downloaded successfully.") - checkpoint = torch.load(local_filename, map_location="cpu") - elif os.path.isfile(url_or_filename): - checkpoint = torch.load(url_or_filename, map_location="cpu") - else: - raise RuntimeError("checkpoint url or path is invalid") - - state_dict = checkpoint["model"] - - self.load_state_dict(state_dict, strict=False) - - def __init__( - self, - vit_model="eva_clip_g", - q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", - img_size=224, - drop_path_rate=0, - use_grad_checkpoint=False, - vit_precision="fp16", - freeze_vit=True, - freeze_qformer=True, - num_query_token=32, - llama_model="", - prompt_path="", - prompt_template="", - max_txt_len=32, - end_sym="\n", - low_resource=False, # use 8 bit and put vit in cpu - device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. - ): - super().__init__() - self.tokenizer = self.init_tokenizer() - self.low_resource = low_resource - - print("Loading VIT") - self.visual_encoder, self.ln_vision = self.init_vision_encoder( - vit_model, - img_size, - drop_path_rate, - use_grad_checkpoint, - vit_precision, - ) - if freeze_vit: - for _, param in self.visual_encoder.named_parameters(): - param.requires_grad = False - self.visual_encoder = self.visual_encoder.eval() - self.visual_encoder.train = disabled_train - for _, param in self.ln_vision.named_parameters(): - param.requires_grad = False - self.ln_vision = self.ln_vision.eval() - self.ln_vision.train = disabled_train - # logging.info("freeze vision encoder") - print("Loading VIT Done") - - print("Loading Q-Former") - self.Qformer, self.query_tokens = self.init_Qformer( - num_query_token, self.visual_encoder.num_features - ) - self.Qformer.cls = None - self.Qformer.bert.embeddings.word_embeddings = None - self.Qformer.bert.embeddings.position_embeddings = None - for layer in self.Qformer.bert.encoder.layer: - layer.output = None - layer.intermediate = None - self.load_from_pretrained(url_or_filename=q_former_model) - - if freeze_qformer: - for _, param in self.Qformer.named_parameters(): - param.requires_grad = False - self.Qformer = self.Qformer.eval() - self.Qformer.train = disabled_train - self.query_tokens.requires_grad = False - # logging.info("freeze Qformer") - print("Loading Q-Former Done") - - print(f"Loading Llama model from {llama_model}") - self.llama_tokenizer = AutoTokenizer.from_pretrained( - llama_model, use_fast=False, legacy=False - ) - # self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token - - if self.low_resource: - self.llama_model = AutoModelForCausalLM.from_pretrained( - llama_model, - torch_dtype=torch.float16, - load_in_8bit=True, - device_map={"": device_8bit}, - ) - else: - self.llama_model = AutoModelForCausalLM.from_pretrained( - llama_model, - torch_dtype=torch.float32, - ) - - print( - "During init :-\nLlama model pad token : ", - self.llama_model.config.pad_token_id, - ) - print( - "Llama tokenizer pad token : ", self.llama_tokenizer.pad_token_id - ) - - for _, param in self.llama_model.named_parameters(): - param.requires_grad = False - print("Loading Llama Done") - - self.llama_proj = torch.nn.Linear( - self.Qformer.config.hidden_size, - self.llama_model.config.hidden_size, - ) - self.max_txt_len = max_txt_len - self.end_sym = end_sym - - if prompt_path: - with open(prompt_path, "r") as f: - raw_prompts = f.read().splitlines() - filted_prompts = [ - raw_prompt - for raw_prompt in raw_prompts - if "" in raw_prompt - ] - self.prompt_list = [ - prompt_template.format(p) for p in filted_prompts - ] - print("Load {} training prompts".format(len(self.prompt_list))) - print( - "Prompt Example \n{}".format(random.choice(self.prompt_list)) - ) - else: - self.prompt_list = [] - - -def resource_path(relative_path): - """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) - return os.path.join(base_path, relative_path) - - -class MiniGPT4(SharkLLMBase): - def __init__( - self, - model_name, - hf_model_path=None, - max_new_tokens=300, - device="cuda", - precision="fp16", - _compile=False, - vision_model_vmfb_path=Path("vision_model_fp16_cuda.vmfb"), - qformer_vmfb_path=Path("qformer_fp32_cuda.vmfb"), - ) -> None: - self.model_name = model_name - self.shark_model = None - super().__init__(model_name, hf_model_path, max_new_tokens) - self.download_dependencies() - self.device = device - self.precision = precision - self._compile = _compile - - self.vision_model_vmfb_path = vision_model_vmfb_path - self.qformer_vmfb_path = qformer_vmfb_path - self.first_llama_vmfb_path = None - self.second_llama_vmfb_path = None - - print("Initializing Chat") - config = OmegaConf.load( - resource_path("minigpt4_utils/configs/minigpt4_eval.yaml") - ) - model_config = OmegaConf.create() - model_config = OmegaConf.merge( - model_config, - OmegaConf.load( - resource_path("minigpt4_utils/configs/minigpt4.yaml") - ), - {"model": config["model"]}, - ) - model_config = model_config["model"] - model_config.device_8bit = 0 - model = MiniGPT4BaseModel.from_config(model_config).to("cpu") - datasets = config.get("datasets", None) - dataset_config = OmegaConf.create() - for dataset_name in datasets: - dataset_config_path = resource_path( - "minigpt4_utils/configs/cc_sbu_align.yaml" - ) - dataset_config = OmegaConf.merge( - dataset_config, - OmegaConf.load(dataset_config_path), - {"datasets": {dataset_name: config["datasets"][dataset_name]}}, - ) - dataset_config = dataset_config["datasets"] - vis_processor_cfg = dataset_config.cc_sbu_align.vis_processor.train - vis_processor = Blip2ImageEvalProcessor.from_config(vis_processor_cfg) - print("Initialization complete") - - self.model = model - self.vis_processor = vis_processor - stop_words_ids = [ - torch.tensor([835]).to("cpu"), - torch.tensor([2277, 29937]).to("cpu"), - ] # '###' can be encoded in two different ways. - self.stopping_criteria = StoppingCriteriaList( - [StoppingCriteriaSub(stops=stop_words_ids)] - ) - - self.first_llama = None - self.second_llama = None - - def download_dependencies(self): - pretrained_file = "prerained_minigpt4_7b.pth" - pretrained_file_url = f"gs://shark_tank/MiniGPT4/{pretrained_file}" - if not os.path.isfile(pretrained_file): - download_public_file( - pretrained_file_url, - Path("prerained_minigpt4_7b.pth").absolute(), - single_file=True, - ) - - if os.path.isfile(pretrained_file): - print(f"File downloaded successfully: {pretrained_file}") - else: - print(f"Error downloading {pretrained_file}") - sys.exit() - - # Currently we're compiling VisionModel for fp32/cuda. - def compile_vision_model(self): - # TODO: Hardcoding precision based on input choices. Take this down - # later. - vision_model_precision = "fp32" - if self.precision in ["int4", "int8", "fp16"]: - vision_model_precision = "fp16" - - if not self._compile: - vmfb = get_vmfb_from_path( - self.vision_model_vmfb_path, self.device, "tm_tensor" - ) - if vmfb is not None: - return vmfb - else: - vmfb = get_vmfb_from_config( - self.model_name, - "vision_model", - vision_model_precision, - self.device, - self.vision_model_vmfb_path, - ) - if vmfb is not None: - return vmfb - - visionModel = VisionModel( - copy.deepcopy(self.model.ln_vision), - copy.deepcopy(self.model.visual_encoder), - vision_model_precision, - ) - extended_model_name = ( - f"vision_model_{vision_model_precision}_{self.device}" - ) - print(f"Going to compile {extended_model_name}") - # Inputs for VisionModel. - inputs = [torch.randint(3, (1, 3, 224, 224), dtype=torch.float32)] - is_f16 = False - if vision_model_precision == "fp16": - is_f16 = True - if self.precision in ["int4", "int8"]: - shark_visionModel, _ = shark_compile_through_fx_int( - visionModel, - inputs, - extended_model_name=extended_model_name, - precision=vision_model_precision, - f16_input_mask=None, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=self.device, - mlir_dialect="tm_tensor", - ) - else: - shark_visionModel, _ = shark_compile_through_fx( - visionModel, - inputs, - extended_model_name=extended_model_name, - precision=vision_model_precision, - f16_input_mask=None, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=self.device, - mlir_dialect="tm_tensor", - ) - print(f"Generated {extended_model_name}.vmfb") - return shark_visionModel - - def compile_qformer_model(self): - if not self._compile: - vmfb = get_vmfb_from_path( - self.qformer_vmfb_path, self.device, "tm_tensor" - ) - if vmfb is not None: - return vmfb - else: - vmfb = get_vmfb_from_config( - self.model_name, - "qformer", - "fp32", - self.device, - self.qformer_vmfb_path, - ) - if vmfb is not None: - return vmfb - - qformerBertModel = QformerBertModel(self.model.Qformer.bert) - extended_model_name = f"qformer_fp32_{self.device}" - print(f"Going to compile {extended_model_name}") - # Inputs for QFormer. - inputs = [ - torch.randint(3, (1, 32, 768), dtype=torch.float32), - torch.randint(3, (1, 257, 1408), dtype=torch.float32), - torch.randint(3, (1, 257), dtype=torch.int64), - ] - is_f16 = False - f16_input_mask = [] - shark_QformerBertModel, _ = shark_compile_through_fx( - qformerBertModel, - inputs, - extended_model_name=extended_model_name, - precision="fp32", - f16_input_mask=f16_input_mask, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=self.device, - mlir_dialect="tm_tensor", - ) - print(f"Generated {extended_model_name}.vmfb") - return shark_QformerBertModel - - def compile_first_llama(self, padding): - self.first_llama_vmfb_path = Path( - f"first_llama_{self.precision}_{self.device}_{padding}.vmfb" - ) - if not self._compile: - vmfb = get_vmfb_from_path( - self.first_llama_vmfb_path, self.device, "tm_tensor" - ) - if vmfb is not None: - self.first_llama = vmfb - return vmfb - else: - vmfb = get_vmfb_from_config( - self.model_name, - "first_llama", - self.precision, - self.device, - self.first_llama_vmfb_path, - padding, - ) - if vmfb is not None: - self.first_llama = vmfb - return vmfb - - firstLlamaModel = FirstLlamaModel( - copy.deepcopy(self.model.llama_model), self.precision - ) - extended_model_name = ( - f"first_llama_{self.precision}_{self.device}_{padding}" - ) - print(f"Going to compile {extended_model_name}") - # Inputs for FirstLlama. - inputs_embeds = torch.ones((1, padding, 4096), dtype=torch.float32) - position_ids = torch.ones((1, padding), dtype=torch.int64) - attention_mask = torch.ones((1, padding), dtype=torch.int32) - inputs = [inputs_embeds, position_ids, attention_mask] - is_f16 = False - f16_input_mask = [] - if self.precision == "fp16": - is_f16 = True - f16_input_mask = [True, False, False] - if self.precision in ["int4", "int8"]: - shark_firstLlamaModel, _ = shark_compile_through_fx_int( - firstLlamaModel, - inputs, - extended_model_name=extended_model_name, - precision=self.precision, - f16_input_mask=f16_input_mask, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=self.device, - mlir_dialect="tm_tensor", - ) - else: - shark_firstLlamaModel, _ = shark_compile_through_fx( - firstLlamaModel, - inputs, - extended_model_name=extended_model_name, - precision=self.precision, - f16_input_mask=f16_input_mask, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=self.device, - mlir_dialect="tm_tensor", - ) - print(f"Generated {extended_model_name}.vmfb") - self.first_llama = shark_firstLlamaModel - return shark_firstLlamaModel - - def compile_second_llama(self, padding): - self.second_llama_vmfb_path = Path( - f"second_llama_{self.precision}_{self.device}_{padding}.vmfb" - ) - if not self._compile: - vmfb = get_vmfb_from_path( - self.second_llama_vmfb_path, self.device, "tm_tensor" - ) - if vmfb is not None: - self.second_llama = vmfb - return vmfb - else: - vmfb = get_vmfb_from_config( - self.model_name, - "second_llama", - self.precision, - self.device, - self.second_llama_vmfb_path, - padding, - ) - if vmfb is not None: - self.second_llama = vmfb - return vmfb - - secondLlamaModel = SecondLlamaModel( - copy.deepcopy(self.model.llama_model), self.precision - ) - extended_model_name = ( - f"second_llama_{self.precision}_{self.device}_{padding}" - ) - print(f"Going to compile {extended_model_name}") - # Inputs for SecondLlama. - input_ids = torch.zeros((1, 1), dtype=torch.int64) - position_ids = torch.zeros((1, 1), dtype=torch.int64) - attention_mask = torch.zeros((1, padding + 1), dtype=torch.int32) - past_key_value = [] - for i in range(64): - past_key_value.append( - torch.zeros(1, 32, padding, 128, dtype=torch.float32) - ) - inputs = [input_ids, position_ids, attention_mask, *past_key_value] - is_f16 = False - f16_input_mask = [] - if self.precision == "fp16": - is_f16 = True - f16_input_mask = [False, False, False] - for i in past_key_value: - f16_input_mask.append(True) - - if self.precision in ["int4", "int8"]: - shark_secondLlamaModel, _ = shark_compile_through_fx_int( - secondLlamaModel, - inputs, - extended_model_name=extended_model_name, - precision=self.precision, - f16_input_mask=f16_input_mask, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=self.device, - mlir_dialect="tm_tensor", - ) - else: - shark_secondLlamaModel, _ = shark_compile_through_fx( - secondLlamaModel, - inputs, - extended_model_name=extended_model_name, - precision=self.precision, - f16_input_mask=f16_input_mask, - save_dir=tempfile.gettempdir(), - debug=False, - generate_or_load_vmfb=True, - extra_args=[], - device=self.device, - mlir_dialect="tm_tensor", - ) - print(f"Generated {extended_model_name}.vmfb") - self.second_llama = shark_secondLlamaModel - return shark_secondLlamaModel - - # Not yet sure why to use this. - def compile(self): - pass - - # Going to use `answer` instead. - def generate(self, prompt): - pass - - # Might use within `answer`, if needed. - def generate_new_token(self, params): - pass - - # Not needed yet because MiniGPT4BaseModel already loads this - will revisit later, - # if required. - def get_tokenizer(self): - pass - - # DumDum func - doing the intended stuff already at MiniGPT4BaseModel, - # i.e load llama, etc. - def get_src_model(self): - pass - - def ask(self, text, conv): - if ( - len(conv.messages) > 0 - and conv.messages[-1][0] == conv.roles[0] - and conv.messages[-1][1][-6:] == "" - ): # last message is image. - conv.messages[-1][1] = " ".join([conv.messages[-1][1], text]) - else: - conv.append_message(conv.roles[0], text) - - def answer( - self, - conv, - img_list, - max_new_tokens=300, - num_beams=1, - min_length=1, - top_p=0.9, - repetition_penalty=1.0, - length_penalty=1, - temperature=1.0, - max_length=2000, - ): - conv.append_message(conv.roles[1], None) - embs = self.get_context_emb( - conv, img_list, max_length - max_new_tokens - ) - padding = max_length - max_new_tokens - - current_max_len = embs.shape[1] + max_new_tokens - - if current_max_len - max_length > 0: - print( - "Warning: The number of tokens in current conversation exceeds the max length. " - "The model will not see the contexts outside the range." - ) - begin_idx = max(0, current_max_len - max_length) - - embs = embs[:, begin_idx:] - - ######################################################################################################### - - generation_config = GenerationConfig.from_model_config( - self.model.llama_model.config - ) - kwargs = { - "inputs_embeds": embs, - "max_new_tokens": max_new_tokens, - "num_beams": num_beams, - "do_sample": True, - "min_length": min_length, - "top_p": top_p, - "repetition_penalty": repetition_penalty, - "length_penalty": length_penalty, - "temperature": temperature, - } - generation_config = copy.deepcopy(generation_config) - model_kwargs = generation_config.update(**kwargs) - logits_processor = LogitsProcessorList() - stopping_criteria = self.stopping_criteria - inputs = None - ( - inputs_tensor, - model_input_name, - model_kwargs, - ) = self.model.llama_model._prepare_model_inputs( - inputs, generation_config.bos_token_id, model_kwargs - ) - model_kwargs["output_attentions"] = generation_config.output_attentions - model_kwargs[ - "output_hidden_states" - ] = generation_config.output_hidden_states - model_kwargs["use_cache"] = generation_config.use_cache - generation_config.pad_token_id = ( - self.model.llama_tokenizer.pad_token_id - ) - pad_token_id = generation_config.pad_token_id - embs_for_pad_token_id = self.model.llama_model.model.embed_tokens( - torch.tensor([pad_token_id]) - ) - model_kwargs["attention_mask"] = torch.logical_not( - torch.tensor( - [ - torch.all( - torch.eq(inputs_tensor[:, d, :], embs_for_pad_token_id) - ).int() - for d in range(inputs_tensor.shape[1]) - ] - ).unsqueeze(0) - ).int() - attention_meta_data = (model_kwargs["attention_mask"][0] == 0).nonzero( - as_tuple=True - )[0] - first_zero = attention_meta_data[0].item() - last_zero = attention_meta_data[-1].item() - input_ids = ( - inputs_tensor - if model_input_name == "input_ids" - else model_kwargs.pop("input_ids") - ) - input_ids_seq_length = input_ids.shape[-1] - generation_config.max_length = ( - generation_config.max_new_tokens + input_ids_seq_length - ) - logits_warper = self.model.llama_model._get_logits_warper( - generation_config - ) - ( - input_ids, - model_kwargs, - ) = self.model.llama_model._expand_inputs_for_generation( - input_ids=input_ids, - expand_size=generation_config.num_return_sequences, - is_encoder_decoder=False, - **model_kwargs, - ) - # DOUBT: stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) - logits_warper = ( - logits_warper - if logits_warper is not None - else LogitsProcessorList() - ) - pad_token_id = generation_config.pad_token_id - eos_token_id = generation_config.eos_token_id - if isinstance(eos_token_id, int): - eos_token_id = [eos_token_id] - eos_token_id_tensor = ( - torch.tensor(eos_token_id).to(input_ids.device) - if eos_token_id is not None - else None - ) - scores = None - - # keep track of which sequences are already finished - unfinished_sequences = torch.ones( - input_ids.shape[0], dtype=torch.long, device=input_ids.device - ) - i = 0 - timesRan = 0 - is_fp16 = self.precision == "fp16" - llama_list = [] - isPyTorchVariant = False - while True: - print("****** Iteration %d ******" % (i)) - # prepare model inputs - model_inputs = ( - self.model.llama_model.prepare_inputs_for_generation( - input_ids, **model_kwargs - ) - ) - - # forward pass to get next token - if i == 0: - shark_inputs = [] - if is_fp16: - model_inputs["inputs_embeds"] = model_inputs[ - "inputs_embeds" - ].to(torch.float16) - shark_inputs.append(model_inputs["inputs_embeds"].detach()) - shark_inputs.append(model_inputs["position_ids"].detach()) - shark_inputs.append(model_inputs["attention_mask"].detach()) - - if self.first_llama is None: - self.compile_first_llama(padding) - outputs_shark = self.first_llama("forward", shark_inputs) - outputs = [] - for out_shark in outputs_shark: - outputs.append(torch.from_numpy(out_shark)) - del outputs_shark - else: - shark_inputs = [] - shark_inputs.append(model_inputs["input_ids"].detach()) - shark_inputs.append(model_inputs["position_ids"].detach()) - shark_inputs.append(model_inputs["attention_mask"].detach()) - for pkv in list(model_inputs["past_key_values"]): - shark_inputs.append(pkv.detach()) - if self.second_llama is None: - self.compile_second_llama(padding) - outputs_shark = self.second_llama("forward", shark_inputs) - outputs = [] - for out_shark in outputs_shark: - outputs.append(torch.from_numpy(out_shark)) - del outputs_shark - - outputs_logits = outputs[0] - next_token_logits = outputs_logits[:, -1, :] - if is_fp16: - next_token_logits = next_token_logits.to(torch.float32) - - # pre-process distribution - next_token_scores = logits_processor(input_ids, next_token_logits) - next_token_scores = logits_warper(input_ids, next_token_scores) - probs = torch.nn.functional.softmax(next_token_scores, dim=-1) - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) - - # finished sentences should have their next token be a padding token - if eos_token_id is not None: - if pad_token_id is None: - raise ValueError( - "If `eos_token_id` is defined, make sure that `pad_token_id` is defined." - ) - next_tokens = ( - next_tokens * unfinished_sequences - + pad_token_id * (1 - unfinished_sequences) - ) - - # update generated ids, model inputs, and length for next step - outputs_for_update_func = {} - input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) - model_kwargs = ( - self.model.llama_model._update_model_kwargs_for_generation( - outputs_for_update_func, - model_kwargs, - is_encoder_decoder=False, - ) - ) - model_kwargs["past_key_values"] = outputs[1:] - if timesRan >= 1: - tmp_attention_mask = torch.cat( - ( - model_kwargs["attention_mask"][:, :first_zero], - model_kwargs["attention_mask"][:, first_zero + 1 :], - ), - dim=1, - ) - model_kwargs["attention_mask"] = tmp_attention_mask - pkv_list = [] - for pkv_pair_tuple in model_kwargs["past_key_values"]: - x = torch.cat( - ( - pkv_pair_tuple[:, :, :first_zero, :], - pkv_pair_tuple[:, :, first_zero + 1 :, :], - ), - dim=2, - ) - if is_fp16: - x = x.to(torch.float16) - pkv_list.append(x) - model_kwargs["past_key_values"] = tuple(pkv_list) - - # if eos_token was found in one sentence, set sentence to finished - if eos_token_id_tensor is not None: - unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile(eos_token_id_tensor.shape[0], 1) - .ne(eos_token_id_tensor.unsqueeze(1)) - .prod(dim=0) - ) - - # stop when each sentence is finished, or if we exceed the maximum length - if unfinished_sequences.max() == 0 or stopping_criteria( - input_ids, scores - ): - break - - i = i + 1 - timesRan += 1 - llama_list.clear() - output_token = input_ids[0] - - if ( - output_token[0] == 0 - ): # the model might output a unknow token at the beginning. remove it - output_token = output_token[1:] - if ( - output_token[0] == 1 - ): # some users find that there is a start token at the beginning. remove it - output_token = output_token[1:] - output_text = self.model.llama_tokenizer.decode( - output_token, add_special_tokens=False - ) - output_text = output_text.split("###")[0] # remove the stop sign '###' - output_text = output_text.split("Assistant:")[-1].strip() - conv.messages[-1][1] = output_text - return output_text, output_token.cpu().numpy() - - def upload_img(self, image, conv, img_list): - if isinstance(image, str): # is a image path - raw_image = Image.open(image).convert("RGB") - image = self.vis_processor(raw_image).unsqueeze(0).to("cpu") - elif isinstance(image, Image.Image): - raw_image = image - image = self.vis_processor(raw_image).unsqueeze(0).to("cpu") - elif isinstance(image, torch.Tensor): - if len(image.shape) == 3: - image = image.unsqueeze(0) - image = image.to("cpu") - - device = image.device - if self.model.low_resource: - self.model.vit_to_cpu() - image = image.to("cpu") - - with self.model.maybe_autocast(): - shark_visionModel = self.compile_vision_model() - if self.precision in ["int4", "int8", "fp16"]: - image = image.to(torch.float16) - image_embeds = shark_visionModel("forward", (image,)) - # image_embeds = shark_visionModel.forward(image) - image_embeds = torch.from_numpy(image_embeds) - image_embeds = image_embeds.to(device).to(torch.float32) - image_atts = torch.ones( - image_embeds.size()[:-1], dtype=torch.long - ).to(device) - - query_tokens = self.model.query_tokens.expand( - image_embeds.shape[0], -1, -1 - ).to(device) - shark_QformerBertModel = self.compile_qformer_model() - query_output = shark_QformerBertModel( - "forward", - ( - query_tokens, - image_embeds, - image_atts, - ), - ) - query_output = torch.from_numpy(query_output) - - inputs_llama = self.model.llama_proj(query_output) - image_emb = inputs_llama - img_list.append(image_emb) - conv.append_message(conv.roles[0], "") - msg = "Received." - return msg - - # """ - def get_context_emb(self, conv, img_list, max_allowed_tokens=200): - self.model.llama_tokenizer.padding_side = "left" - prompt = conv.get_prompt() - prompt_segs = prompt.split("") - assert ( - len(prompt_segs) == len(img_list) + 1 - ), "Unmatched numbers of image placeholders and images." - prompt_segs_pre = prompt_segs[:-1] - seg_tokens_pre = [] - for i, seg in enumerate(prompt_segs_pre): - # only add bos to the first seg - if i == 0: - add_special_tokens = True - else: - add_special_tokens = False - stp = ( - self.model.llama_tokenizer( - seg, - return_tensors="pt", - add_special_tokens=add_special_tokens, - ) - .to("cpu") - .input_ids - ) - seg_tokens_pre.append(stp) - # seg_tokens_pre = [ - # self.model.llama_tokenizer( - # seg, return_tensors="pt", add_special_tokens=i == 0 - # ) - # .to("cpu") - # .input_ids - # for i, seg in enumerate(prompt_segs_pre) - # ] - print( - "Before :-\nLlama model pad token : ", - self.model.llama_model.config.pad_token_id, - ) - print( - "Llama tokenizer pad token : ", - self.model.llama_tokenizer.pad_token_id, - ) - self.model.llama_model.config.pad_token_id = ( - self.model.llama_tokenizer.pad_token_id - ) - print( - "After :-\nLlama model pad token : ", - self.model.llama_model.config.pad_token_id, - ) - print( - "Llama tokenizer pad token : ", - self.model.llama_tokenizer.pad_token_id, - ) - print("seg_t :", seg_tokens_pre[0]) - - seg_embs_pre = [ - self.model.llama_model.model.embed_tokens(seg_t) - for seg_t in seg_tokens_pre - ] - mixed_embs_pre = [ - emb.to("cpu") - for pair in zip(seg_embs_pre, img_list) - for emb in pair - ] - mixed_embs_pre = torch.cat(mixed_embs_pre, dim=1) - max_allowed_tokens = max_allowed_tokens - mixed_embs_pre.shape[1] - final_prompt = prompt_segs[-1] - seg_tokens_post = [ - self.model.llama_tokenizer( - seg, - return_tensors="pt", - padding="max_length", - max_length=max_allowed_tokens, - add_special_tokens=False, - ) - .to("cpu") - .input_ids - # only add bos to the first seg - for i, seg in enumerate([final_prompt]) - ] - seg_tokens_post = seg_tokens_post[0] - seg_embs_post = [ - self.model.llama_model.model.embed_tokens(seg_t) - for seg_t in seg_tokens_post - ] - mixed_embs_post = [seg_embs_post[0].to("cpu")] - mixed_embs_post = torch.unsqueeze(mixed_embs_post[0], 0) - mixed_embs = [mixed_embs_pre] + [mixed_embs_post] - mixed_embs = torch.cat(mixed_embs, dim=1) - return mixed_embs - - -if __name__ == "__main__": - args = parser.parse_args() - - device = args.device - precision = args.precision - _compile = args.compile - max_length = args.max_length - max_new_tokens = args.max_new_tokens - print("Will run SHARK MultiModal for the following paramters :-\n") - print( - f"Device={device} precision={precision} compile={_compile} max_length={max_length} max_new_tokens={max_new_tokens}" - ) - - padding = max_length - max_new_tokens - assert ( - padding > 0 - ), "max_length should be strictly greater than max_new_tokens" - - if args.image_path == "": - print( - "To run MiniGPT4 in CLI mode please provide an image's path using --image_path" - ) - sys.exit() - - vision_model_precision = precision - if precision in ["int4", "int8"]: - vision_model_precision = "fp16" - vision_model_vmfb_path = ( - Path(f"vision_model_{vision_model_precision}_{device}.vmfb") - if args.vision_model_vmfb_path is None - else Path(args.vision_model_vmfb_path) - ) - qformer_vmfb_path = ( - Path(f"qformer_fp32_{device}.vmfb") - if args.qformer_vmfb_path is None - else Path(args.qformer_vmfb_path) - ) - chat = MiniGPT4( - model_name="MiniGPT4", - hf_model_path=None, - max_new_tokens=30, - device=device, - precision=precision, - _compile=_compile, - vision_model_vmfb_path=vision_model_vmfb_path, - qformer_vmfb_path=qformer_vmfb_path, - ) - - chat_state = CONV_VISION.copy() - img_list = [] - chat.upload_img(args.image_path, chat_state, img_list) - print( - "Uploaded image successfully to the bot. You may now start chatting with the bot. Enter 'END' without quotes to end the interaction" - ) - continue_execution = True - - while continue_execution: - user_message = input("User: ") - if user_message == "END": - print("Bot: Good bye.\n") - break - chat.ask(user_message, chat_state) - bot_message = chat.answer( - conv=chat_state, - img_list=img_list, - num_beams=1, - temperature=1.0, - max_new_tokens=max_new_tokens, - max_length=max_length, - )[0] - print("Bot: ", bot_message) - - del chat_state, img_list, chat diff --git a/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py b/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py deleted file mode 100644 index 6944a9dd..00000000 --- a/apps/language_models/src/pipelines/minigpt4_utils/Qformer.py +++ /dev/null @@ -1,1297 +0,0 @@ -""" - * Copyright (c) 2023, salesforce.com, inc. - * All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause - * By Junnan Li - * Based on huggingface code base - * https://github.com/huggingface/transformers/blob/v4.15.0/src/transformers/models/bert -""" - -import math -from dataclasses import dataclass -from typing import Tuple, Dict, Any - -import torch -from torch import Tensor, device, nn -import torch.utils.checkpoint -from torch import nn -from torch.nn import CrossEntropyLoss - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutputWithPastAndCrossAttentions, - BaseModelOutputWithPoolingAndCrossAttentions, - CausalLMOutputWithCrossAttentions, - MaskedLMOutput, -) -from transformers.modeling_utils import ( - PreTrainedModel, - apply_chunking_to_forward, - find_pruneable_heads_and_indices, - prune_linear_layer, -) -from transformers.utils import logging -from transformers.models.bert.configuration_bert import BertConfig - -logger = logging.get_logger(__name__) - - -class BertEmbeddings(nn.Module): - """Construct the embeddings from word and position embeddings.""" - - def __init__(self, config): - super().__init__() - self.word_embeddings = nn.Embedding( - config.vocab_size, - config.hidden_size, - padding_idx=config.pad_token_id, - ) - self.position_embeddings = nn.Embedding( - config.max_position_embeddings, config.hidden_size - ) - - # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load - # any TensorFlow checkpoint file - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps, device="cpu" - ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - # position_ids (1, len position emb) is contiguous in memory and exported when serialized - self.register_buffer( - "position_ids", - torch.arange(config.max_position_embeddings).expand((1, -1)), - ) - self.position_embedding_type = getattr( - config, "position_embedding_type", "absolute" - ) - - self.config = config - - def forward( - self, - input_ids=None, - position_ids=None, - query_embeds=None, - past_key_values_length=0, - ): - if input_ids is not None: - seq_length = input_ids.size()[1] - else: - seq_length = 0 - - if position_ids is None: - position_ids = self.position_ids[ - :, past_key_values_length : seq_length + past_key_values_length - ].clone() - - if input_ids is not None: - embeddings = self.word_embeddings(input_ids) - if self.position_embedding_type == "absolute": - position_embeddings = self.position_embeddings(position_ids) - embeddings = embeddings + position_embeddings - - if query_embeds is not None: - embeddings = torch.cat((query_embeds, embeddings), dim=1) - else: - embeddings = query_embeds - - embeddings = self.LayerNorm(embeddings) - embeddings = self.dropout(embeddings) - return embeddings - - -class BertSelfAttention(nn.Module): - def __init__(self, config, is_cross_attention): - super().__init__() - self.config = config - if ( - config.hidden_size % config.num_attention_heads != 0 - and not hasattr(config, "embedding_size") - ): - raise ValueError( - "The hidden size (%d) is not a multiple of the number of attention " - "heads (%d)" % (config.hidden_size, config.num_attention_heads) - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int( - config.hidden_size / config.num_attention_heads - ) - self.all_head_size = ( - self.num_attention_heads * self.attention_head_size - ) - - self.query = nn.Linear(config.hidden_size, self.all_head_size) - if is_cross_attention: - self.key = nn.Linear(config.encoder_width, self.all_head_size) - self.value = nn.Linear(config.encoder_width, self.all_head_size) - else: - self.key = nn.Linear(config.hidden_size, self.all_head_size) - self.value = nn.Linear(config.hidden_size, self.all_head_size) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - self.position_embedding_type = getattr( - config, "position_embedding_type", "absolute" - ) - if ( - self.position_embedding_type == "relative_key" - or self.position_embedding_type == "relative_key_query" - ): - self.max_position_embeddings = config.max_position_embeddings - self.distance_embedding = nn.Embedding( - 2 * config.max_position_embeddings - 1, - self.attention_head_size, - ) - self.save_attention = False - - def save_attn_gradients(self, attn_gradients): - self.attn_gradients = attn_gradients - - def get_attn_gradients(self): - return self.attn_gradients - - def save_attention_map(self, attention_map): - self.attention_map = attention_map - - def get_attention_map(self): - return self.attention_map - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + ( - self.num_attention_heads, - self.attention_head_size, - ) - x = x.view(*new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. - is_cross_attention = encoder_hidden_states is not None - - if is_cross_attention: - key_layer = self.transpose_for_scores( - self.key(encoder_hidden_states) - ) - value_layer = self.transpose_for_scores( - self.value(encoder_hidden_states) - ) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) - else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - - mixed_query_layer = self.query(hidden_states) - - query_layer = self.transpose_for_scores(mixed_query_layer) - - past_key_value = (key_layer, value_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul( - query_layer, key_layer.transpose(-1, -2) - ) - - if ( - self.position_embedding_type == "relative_key" - or self.position_embedding_type == "relative_key_query" - ): - seq_length = hidden_states.size()[1] - position_ids_l = torch.arange( - seq_length, dtype=torch.long, device=hidden_states.device - ).view(-1, 1) - position_ids_r = torch.arange( - seq_length, dtype=torch.long, device=hidden_states.device - ).view(1, -1) - distance = position_ids_l - position_ids_r - positional_embedding = self.distance_embedding( - distance + self.max_position_embeddings - 1 - ) - positional_embedding = positional_embedding.to( - dtype=query_layer.dtype - ) # fp16 compatibility - - if self.position_embedding_type == "relative_key": - relative_position_scores = torch.einsum( - "bhld,lrd->bhlr", query_layer, positional_embedding - ) - attention_scores = attention_scores + relative_position_scores - elif self.position_embedding_type == "relative_key_query": - relative_position_scores_query = torch.einsum( - "bhld,lrd->bhlr", query_layer, positional_embedding - ) - relative_position_scores_key = torch.einsum( - "bhrd,lrd->bhlr", key_layer, positional_embedding - ) - attention_scores = ( - attention_scores - + relative_position_scores_query - + relative_position_scores_key - ) - - attention_scores = attention_scores / math.sqrt( - self.attention_head_size - ) - if attention_mask is not None: - # Apply the attention mask is (precomputed for all layers in BertModel forward() function) - attention_scores = attention_scores + attention_mask - - # Normalize the attention scores to probabilities. - attention_probs = nn.Softmax(dim=-1)(attention_scores) - - if is_cross_attention and self.save_attention: - self.save_attention_map(attention_probs) - attention_probs.register_hook(self.save_attn_gradients) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs_dropped = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs_dropped = attention_probs_dropped * head_mask - - context_layer = torch.matmul(attention_probs_dropped, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + ( - self.all_head_size, - ) - context_layer = context_layer.view(*new_context_layer_shape) - - outputs = ( - (context_layer, attention_probs) - if output_attentions - else (context_layer,) - ) - - outputs = outputs + (past_key_value,) - return outputs - - -class BertSelfOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertAttention(nn.Module): - def __init__(self, config, is_cross_attention=False): - super().__init__() - self.self = BertSelfAttention(config, is_cross_attention) - self.output = BertSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads): - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, - self.self.num_attention_heads, - self.self.attention_head_size, - self.pruned_heads, - ) - - # Prune linear layers - self.self.query = prune_linear_layer(self.self.query, index) - self.self.key = prune_linear_layer(self.self.key, index) - self.self.value = prune_linear_layer(self.self.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.self.num_attention_heads = self.self.num_attention_heads - len( - heads - ) - self.self.all_head_size = ( - self.self.attention_head_size * self.self.num_attention_heads - ) - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - ): - self_outputs = self.self( - hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - ) - attention_output = self.output(self_outputs[0], hidden_states) - - outputs = (attention_output,) + self_outputs[ - 1: - ] # add attentions if we output them - return outputs - - -class BertIntermediate(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - return hidden_states - - -class BertOutput(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states, input_tensor): - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - return hidden_states - - -class BertLayer(nn.Module): - def __init__(self, config, layer_num): - super().__init__() - self.config = config - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = BertAttention(config) - self.layer_num = layer_num - if ( - self.config.add_cross_attention - and layer_num % self.config.cross_attention_freq == 0 - ): - self.crossattention = BertAttention( - config, is_cross_attention=self.config.add_cross_attention - ) - self.has_cross_attention = True - else: - self.has_cross_attention = False - self.intermediate = BertIntermediate(config) - self.output = BertOutput(config) - - self.intermediate_query = BertIntermediate(config) - self.output_query = BertOutput(config) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_value=None, - output_attentions=False, - query_length=0, - ): - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = ( - past_key_value[:2] if past_key_value is not None else None - ) - self_attention_outputs = self.attention( - hidden_states, - attention_mask, - head_mask, - output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:-1] - - present_key_value = self_attention_outputs[-1] - - if query_length > 0: - query_attention_output = attention_output[:, :query_length, :] - - if self.has_cross_attention: - assert ( - encoder_hidden_states is not None - ), "encoder_hidden_states must be given for cross-attention layers" - cross_attention_outputs = self.crossattention( - query_attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - output_attentions=output_attentions, - ) - query_attention_output = cross_attention_outputs[0] - outputs = ( - outputs + cross_attention_outputs[1:-1] - ) # add cross attentions if we output attention weights - - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk_query, - self.chunk_size_feed_forward, - self.seq_len_dim, - query_attention_output, - ) - if attention_output.shape[1] > query_length: - layer_output_text = apply_chunking_to_forward( - self.feed_forward_chunk, - self.chunk_size_feed_forward, - self.seq_len_dim, - attention_output[:, query_length:, :], - ) - layer_output = torch.cat( - [layer_output, layer_output_text], dim=1 - ) - else: - layer_output = apply_chunking_to_forward( - self.feed_forward_chunk, - self.chunk_size_feed_forward, - self.seq_len_dim, - attention_output, - ) - outputs = (layer_output,) + outputs - - outputs = outputs + (present_key_value,) - - return outputs - - def feed_forward_chunk(self, attention_output): - intermediate_output = self.intermediate(attention_output) - layer_output = self.output(intermediate_output, attention_output) - return layer_output - - def feed_forward_chunk_query(self, attention_output): - intermediate_output = self.intermediate_query(attention_output) - layer_output = self.output_query(intermediate_output, attention_output) - return layer_output - - -class BertEncoder(nn.Module): - def __init__(self, config): - super().__init__() - self.config = config - self.layer = nn.ModuleList( - [BertLayer(config, i) for i in range(config.num_hidden_layers)] - ) - - def forward( - self, - hidden_states, - attention_mask=None, - head_mask=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=False, - output_hidden_states=False, - return_dict=True, - query_length=0, - ): - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - all_cross_attentions = ( - () - if output_attentions and self.config.add_cross_attention - else None - ) - - next_decoder_cache = () if use_cache else None - - for i in range(self.config.num_hidden_layers): - layer_module = self.layer[i] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = ( - past_key_values[i] if past_key_values is not None else None - ) - - if ( - getattr(self.config, "gradient_checkpointing", False) - and self.training - ): - if use_cache: - logger.warn( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - def create_custom_forward(module): - def custom_forward(*inputs): - return module( - *inputs, - past_key_value, - output_attentions, - query_length - ) - - return custom_forward - - layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - ) - else: - layer_outputs = layer_module( - hidden_states, - attention_mask, - layer_head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, - query_length, - ) - - hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - all_cross_attentions = all_cross_attentions + ( - layer_outputs[2], - ) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple( - v - for v in [ - hidden_states, - next_decoder_cache, - all_hidden_states, - all_self_attentions, - all_cross_attentions, - ] - if v is not None - ) - return BaseModelOutputWithPastAndCrossAttentions( - last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - cross_attentions=all_cross_attentions, - ) - - -class BertPooler(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output - - -class BertPredictionHeadTransform(nn.Module): - def __init__(self, config): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - if isinstance(config.hidden_act, str): - self.transform_act_fn = ACT2FN[config.hidden_act] - else: - self.transform_act_fn = config.hidden_act - self.LayerNorm = nn.LayerNorm( - config.hidden_size, eps=config.layer_norm_eps - ) - - def forward(self, hidden_states): - hidden_states = self.dense(hidden_states) - hidden_states = self.transform_act_fn(hidden_states) - hidden_states = self.LayerNorm(hidden_states) - return hidden_states - - -class BertLMPredictionHead(nn.Module): - def __init__(self, config): - super().__init__() - self.transform = BertPredictionHeadTransform(config) - - # The output weights are the same as the input embeddings, but there is - # an output-only bias for each token. - self.decoder = nn.Linear( - config.hidden_size, config.vocab_size, bias=False - ) - - self.bias = nn.Parameter(torch.zeros(config.vocab_size)) - - # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` - self.decoder.bias = self.bias - - def forward(self, hidden_states): - hidden_states = self.transform(hidden_states) - hidden_states = self.decoder(hidden_states) - return hidden_states - - -class BertOnlyMLMHead(nn.Module): - def __init__(self, config): - super().__init__() - self.predictions = BertLMPredictionHead(config) - - def forward(self, sequence_output): - prediction_scores = self.predictions(sequence_output) - return prediction_scores - - -class BertPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = BertConfig - base_model_prefix = "bert" - _keys_to_ignore_on_load_missing = [r"position_ids"] - - def _init_weights(self, module): - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Embedding)): - # Slightly different from the TF version which uses truncated_normal for initialization - # cf https://github.com/pytorch/pytorch/pull/5617 - module.weight.data.normal_( - mean=0.0, std=self.config.initializer_range - ) - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - if isinstance(module, nn.Linear) and module.bias is not None: - module.bias.data.zero_() - - -class BertModel(BertPreTrainedModel): - """ - The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of - cross-attention is added between the self-attention layers, following the architecture described in `Attention is - all you need `__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, - Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin. - argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an - input to the forward pass. - """ - - def __init__(self, config, add_pooling_layer=False): - super().__init__(config) - self.config = config - - self.embeddings = BertEmbeddings(config) - - self.encoder = BertEncoder(config) - - self.pooler = BertPooler(config) if add_pooling_layer else None - - self.init_weights() - - def get_input_embeddings(self): - return self.embeddings.word_embeddings - - def set_input_embeddings(self, value): - self.embeddings.word_embeddings = value - - def _prune_heads(self, heads_to_prune): - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def get_extended_attention_mask( - self, - attention_mask: Tensor, - input_shape: Tuple[int], - device: device, - is_decoder: bool, - has_query: bool = False, - ) -> Tensor: - """ - Makes broadcastable attention and causal masks so that future and masked tokens are ignored. - - Arguments: - attention_mask (:obj:`torch.Tensor`): - Mask with ones indicating tokens to attend to, zeros for tokens to ignore. - input_shape (:obj:`Tuple[int]`): - The shape of the input to the model. - device: (:obj:`torch.device`): - The device of the input to the model. - - Returns: - :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`. - """ - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if attention_mask.dim() == 3: - extended_attention_mask = attention_mask[:, None, :, :] - elif attention_mask.dim() == 2: - # Provided a padding mask of dimensions [batch_size, seq_length] - # - if the model is a decoder, apply a causal mask in addition to the padding mask - # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length] - if is_decoder: - batch_size, seq_length = input_shape - - seq_ids = torch.arange(seq_length, device=device) - causal_mask = ( - seq_ids[None, None, :].repeat(batch_size, seq_length, 1) - <= seq_ids[None, :, None] - ) - - # add a prefix ones mask to the causal mask - # causal and attention masks must have same type with pytorch version < 1.3 - causal_mask = causal_mask.to(attention_mask.dtype) - - if causal_mask.shape[1] < attention_mask.shape[1]: - prefix_seq_len = ( - attention_mask.shape[1] - causal_mask.shape[1] - ) - if has_query: # UniLM style attention mask - causal_mask = torch.cat( - [ - torch.zeros( - (batch_size, prefix_seq_len, seq_length), - device=device, - dtype=causal_mask.dtype, - ), - causal_mask, - ], - axis=1, - ) - causal_mask = torch.cat( - [ - torch.ones( - ( - batch_size, - causal_mask.shape[1], - prefix_seq_len, - ), - device=device, - dtype=causal_mask.dtype, - ), - causal_mask, - ], - axis=-1, - ) - extended_attention_mask = ( - causal_mask[:, None, :, :] - * attention_mask[:, None, None, :] - ) - else: - extended_attention_mask = attention_mask[:, None, None, :] - else: - raise ValueError( - "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format( - input_shape, attention_mask.shape - ) - ) - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and -10000.0 for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - extended_attention_mask = extended_attention_mask.to( - dtype=self.dtype - ) # fp16 compatibility - extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 - return extended_attention_mask - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - query_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - past_key_values=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - is_decoder=False, - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - """ - output_attentions = ( - output_attentions - if output_attentions is not None - else self.config.output_attentions - ) - output_hidden_states = ( - output_hidden_states - if output_hidden_states is not None - else self.config.output_hidden_states - ) - return_dict = ( - return_dict - if return_dict is not None - else self.config.use_return_dict - ) - - # use_cache = use_cache if use_cache is not None else self.config.use_cache - - if input_ids is None: - assert ( - query_embeds is not None - ), "You have to specify query_embeds when input_ids is None" - - # past_key_values_length - past_key_values_length = ( - past_key_values[0][0].shape[2] - self.config.query_length - if past_key_values is not None - else 0 - ) - - query_length = query_embeds.shape[1] if query_embeds is not None else 0 - - embedding_output = self.embeddings( - input_ids=input_ids, - position_ids=position_ids, - query_embeds=query_embeds, - past_key_values_length=past_key_values_length, - ) - - input_shape = embedding_output.size()[:-1] - batch_size, seq_length = input_shape - device = embedding_output.device - - if attention_mask is None: - attention_mask = torch.ones( - ((batch_size, seq_length + past_key_values_length)), - device=device, - ) - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. - if is_decoder: - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, - input_ids.shape, - device, - is_decoder, - has_query=(query_embeds is not None), - ) - else: - extended_attention_mask = self.get_extended_attention_mask( - attention_mask, input_shape, device, is_decoder - ) - - # If a 2D or 3D attention mask is provided for the cross-attention - # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] - if encoder_hidden_states is not None: - if type(encoder_hidden_states) == list: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states[0].size() - else: - ( - encoder_batch_size, - encoder_sequence_length, - _, - ) = encoder_hidden_states.size() - encoder_hidden_shape = ( - encoder_batch_size, - encoder_sequence_length, - ) - - if type(encoder_attention_mask) == list: - encoder_extended_attention_mask = [ - self.invert_attention_mask(mask) - for mask in encoder_attention_mask - ] - elif encoder_attention_mask is None: - encoder_attention_mask = torch.ones( - encoder_hidden_shape, device=device - ) - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) - else: - encoder_extended_attention_mask = self.invert_attention_mask( - encoder_attention_mask - ) - else: - encoder_extended_attention_mask = None - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask( - head_mask, self.config.num_hidden_layers - ) - - encoder_outputs = self.encoder( - embedding_output, - attention_mask=extended_attention_mask, - head_mask=head_mask, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_extended_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - query_length=query_length, - ) - sequence_output = encoder_outputs[0] - pooled_output = ( - self.pooler(sequence_output) if self.pooler is not None else None - ) - - if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] - - return BaseModelOutputWithPoolingAndCrossAttentions( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - past_key_values=encoder_outputs.past_key_values, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - cross_attentions=encoder_outputs.cross_attentions, - ) - - -class BertLMHeadModel(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [ - r"position_ids", - r"predictions.decoder.bias", - ] - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - query_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - past_key_values=None, - use_cache=True, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - return_logits=False, - is_decoder=True, - reduction="mean", - ): - r""" - encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`): - Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if - the model is configured as a decoder. - encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in - the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``: - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in - ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are - ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]`` - past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`): - Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding. - If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids` - (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)` - instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`. - use_cache (:obj:`bool`, `optional`): - If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up - decoding (see :obj:`past_key_values`). - Returns: - Example:: - >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig - >>> import torch - >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased') - >>> config = BertConfig.from_pretrained("bert-base-cased") - >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config) - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") - >>> outputs = model(**inputs) - >>> prediction_logits = outputs.logits - """ - return_dict = ( - return_dict - if return_dict is not None - else self.config.use_return_dict - ) - if labels is not None: - use_cache = False - if past_key_values is not None: - query_embeds = None - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - query_embeds=query_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - past_key_values=past_key_values, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - is_decoder=is_decoder, - ) - - sequence_output = outputs[0] - if query_embeds is not None: - sequence_output = outputs[0][:, query_embeds.shape[1] :, :] - - prediction_scores = self.cls(sequence_output) - - if return_logits: - return prediction_scores[:, :-1, :].contiguous() - - lm_loss = None - if labels is not None: - # we are doing next-token prediction; shift prediction scores and input ids by one - shifted_prediction_scores = prediction_scores[ - :, :-1, : - ].contiguous() - labels = labels[:, 1:].contiguous() - loss_fct = CrossEntropyLoss( - reduction=reduction, label_smoothing=0.1 - ) - lm_loss = loss_fct( - shifted_prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1), - ) - if reduction == "none": - lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ((lm_loss,) + output) if lm_loss is not None else output - - return CausalLMOutputWithCrossAttentions( - loss=lm_loss, - logits=prediction_scores, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - cross_attentions=outputs.cross_attentions, - ) - - def prepare_inputs_for_generation( - self, - input_ids, - query_embeds, - past=None, - attention_mask=None, - **model_kwargs - ): - # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly - if attention_mask is None: - attention_mask = input_ids.new_ones(input_ids.shape) - query_mask = input_ids.new_ones(query_embeds.shape[:-1]) - attention_mask = torch.cat([query_mask, attention_mask], dim=-1) - - # cut decoder_input_ids if past is used - if past is not None: - input_ids = input_ids[:, -1:] - - return { - "input_ids": input_ids, - "query_embeds": query_embeds, - "attention_mask": attention_mask, - "past_key_values": past, - "encoder_hidden_states": model_kwargs.get( - "encoder_hidden_states", None - ), - "encoder_attention_mask": model_kwargs.get( - "encoder_attention_mask", None - ), - "is_decoder": True, - } - - def _reorder_cache(self, past, beam_idx): - reordered_past = () - for layer_past in past: - reordered_past += ( - tuple( - past_state.index_select(0, beam_idx) - for past_state in layer_past - ), - ) - return reordered_past - - -class BertForMaskedLM(BertPreTrainedModel): - _keys_to_ignore_on_load_unexpected = [r"pooler"] - _keys_to_ignore_on_load_missing = [ - r"position_ids", - r"predictions.decoder.bias", - ] - - def __init__(self, config): - super().__init__(config) - - self.bert = BertModel(config, add_pooling_layer=False) - self.cls = BertOnlyMLMHead(config) - - self.init_weights() - - def get_output_embeddings(self): - return self.cls.predictions.decoder - - def set_output_embeddings(self, new_embeddings): - self.cls.predictions.decoder = new_embeddings - - def forward( - self, - input_ids=None, - attention_mask=None, - position_ids=None, - head_mask=None, - query_embeds=None, - encoder_hidden_states=None, - encoder_attention_mask=None, - labels=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - return_logits=False, - is_decoder=False, - ): - r""" - labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): - Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ..., - config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored - (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]`` - """ - - return_dict = ( - return_dict - if return_dict is not None - else self.config.use_return_dict - ) - - outputs = self.bert( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - head_mask=head_mask, - query_embeds=query_embeds, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - is_decoder=is_decoder, - ) - - if query_embeds is not None: - sequence_output = outputs[0][:, query_embeds.shape[1] :, :] - prediction_scores = self.cls(sequence_output) - - if return_logits: - return prediction_scores - - masked_lm_loss = None - if labels is not None: - loss_fct = CrossEntropyLoss() # -100 index = padding token - masked_lm_loss = loss_fct( - prediction_scores.view(-1, self.config.vocab_size), - labels.view(-1), - ) - - if not return_dict: - output = (prediction_scores,) + outputs[2:] - return ( - ((masked_lm_loss,) + output) - if masked_lm_loss is not None - else output - ) - - return MaskedLMOutput( - loss=masked_lm_loss, - logits=prediction_scores, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) diff --git a/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py b/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py deleted file mode 100644 index 8c10c659..00000000 --- a/apps/language_models/src/pipelines/minigpt4_utils/blip_processors.py +++ /dev/null @@ -1,68 +0,0 @@ -""" - Copyright (c) 2022, salesforce.com, inc. - All rights reserved. - SPDX-License-Identifier: BSD-3-Clause - For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause -""" -from omegaconf import OmegaConf -from torchvision import transforms -from torchvision.transforms.functional import InterpolationMode - - -class BaseProcessor: - def __init__(self): - self.transform = lambda x: x - return - - def __call__(self, item): - return self.transform(item) - - @classmethod - def from_config(cls, cfg=None): - return cls() - - def build(self, **kwargs): - cfg = OmegaConf.create(kwargs) - - return self.from_config(cfg) - - -class BlipImageBaseProcessor(BaseProcessor): - def __init__(self, mean=None, std=None): - if mean is None: - mean = (0.48145466, 0.4578275, 0.40821073) - if std is None: - std = (0.26862954, 0.26130258, 0.27577711) - - self.normalize = transforms.Normalize(mean, std) - - -class Blip2ImageEvalProcessor(BlipImageBaseProcessor): - def __init__(self, image_size=224, mean=None, std=None): - super().__init__(mean=mean, std=std) - - self.transform = transforms.Compose( - [ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC, - ), - transforms.ToTensor(), - self.normalize, - ] - ) - - def __call__(self, item): - return self.transform(item) - - @classmethod - def from_config(cls, cfg=None): - if cfg is None: - cfg = OmegaConf.create() - - image_size = cfg.get("image_size", 224) - - mean = cfg.get("mean", None) - std = cfg.get("std", None) - - return cls(image_size=image_size, mean=mean, std=std) diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml deleted file mode 100644 index 57108342..00000000 --- a/apps/language_models/src/pipelines/minigpt4_utils/configs/cc_sbu_align.yaml +++ /dev/null @@ -1,5 +0,0 @@ -datasets: - cc_sbu_align: - data_type: images - build_info: - storage: /path/to/cc_sbu_align/ diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml deleted file mode 100644 index 803d0a7e..00000000 --- a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4.yaml +++ /dev/null @@ -1,33 +0,0 @@ -model: - arch: mini_gpt4 - - # vit encoder - image_size: 224 - drop_path_rate: 0 - use_grad_checkpoint: False - vit_precision: "fp16" - freeze_vit: True - freeze_qformer: True - - # Q-Former - num_query_token: 32 - - # Vicuna - llama_model: "lmsys/vicuna-7b-v1.3" - - # generation configs - prompt: "" - -preprocess: - vis_processor: - train: - name: "blip2_image_train" - image_size: 224 - eval: - name: "blip2_image_eval" - image_size: 224 - text_processor: - train: - name: "blip_caption" - eval: - name: "blip_caption" diff --git a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml b/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml deleted file mode 100644 index e54f44cc..00000000 --- a/apps/language_models/src/pipelines/minigpt4_utils/configs/minigpt4_eval.yaml +++ /dev/null @@ -1,25 +0,0 @@ -model: - arch: mini_gpt4 - model_type: pretrain_vicuna - freeze_vit: True - freeze_qformer: True - max_txt_len: 160 - end_sym: "###" - low_resource: False - prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt" - prompt_template: '###Human: {} ###Assistant: ' - ckpt: 'prerained_minigpt4_7b.pth' - - -datasets: - cc_sbu_align: - vis_processor: - train: - name: "blip2_image_eval" - image_size: 224 - text_processor: - train: - name: "blip_caption" - -run: - task: image_text_pretrain diff --git a/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py b/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py deleted file mode 100644 index 7cf50989..00000000 --- a/apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py +++ /dev/null @@ -1,629 +0,0 @@ -# Based on EVA, BEIT, timm and DeiT code bases -# https://github.com/baaivision/EVA -# https://github.com/rwightman/pytorch-image-models/tree/master/timm -# https://github.com/microsoft/unilm/tree/master/beit -# https://github.com/facebookresearch/deit/ -# https://github.com/facebookresearch/dino -# --------------------------------------------------------' -import math -import requests -from functools import partial - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -from timm.models.layers import drop_path, to_2tuple, trunc_normal_ - - -def _cfg(url="", **kwargs): - return { - "url": url, - "num_classes": 1000, - "input_size": (3, 224, 224), - "pool_size": None, - "crop_pct": 0.9, - "interpolation": "bicubic", - "mean": (0.5, 0.5, 0.5), - "std": (0.5, 0.5, 0.5), - **kwargs, - } - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" - - def __init__(self, drop_prob=None): - super(DropPath, self).__init__() - self.drop_prob = drop_prob - - def forward(self, x): - return drop_path(x, self.drop_prob, self.training) - - def extra_repr(self) -> str: - return "p={}".format(self.drop_prob) - - -class Mlp(nn.Module): - def __init__( - self, - in_features, - hidden_features=None, - out_features=None, - act_layer=nn.GELU, - drop=0.0, - ): - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.fc1 = nn.Linear(in_features, hidden_features) - self.act = act_layer() - self.fc2 = nn.Linear(hidden_features, out_features) - self.drop = nn.Dropout(drop) - - def forward(self, x): - x = self.fc1(x) - x = self.act(x) - # x = self.drop(x) - # commit this for the orignal BERT implement - x = self.fc2(x) - x = self.drop(x) - return x - - -class Attention(nn.Module): - def __init__( - self, - dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - window_size=None, - attn_head_dim=None, - ): - super().__init__() - self.num_heads = num_heads - head_dim = dim // num_heads - if attn_head_dim is not None: - head_dim = attn_head_dim - all_head_dim = head_dim * self.num_heads - self.scale = qk_scale or head_dim**-0.5 - - self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False) - if qkv_bias: - self.q_bias = nn.Parameter(torch.zeros(all_head_dim)) - self.v_bias = nn.Parameter(torch.zeros(all_head_dim)) - else: - self.q_bias = None - self.v_bias = None - - if window_size: - self.window_size = window_size - self.num_relative_distance = (2 * window_size[0] - 1) * ( - 2 * window_size[1] - 1 - ) + 3 - self.relative_position_bias_table = nn.Parameter( - torch.zeros(self.num_relative_distance, num_heads) - ) # 2*Wh-1 * 2*Ww-1, nH - # cls to token & token 2 cls & cls to cls - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(window_size[0]) - coords_w = torch.arange(window_size[1]) - coords = torch.stack( - torch.meshgrid([coords_h, coords_w]) - ) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = ( - coords_flatten[:, :, None] - coords_flatten[:, None, :] - ) # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute( - 1, 2, 0 - ).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += ( - window_size[0] - 1 - ) # shift to start from 0 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - relative_position_index = torch.zeros( - size=(window_size[0] * window_size[1] + 1,) * 2, - dtype=relative_coords.dtype, - ) - relative_position_index[1:, 1:] = relative_coords.sum( - -1 - ) # Wh*Ww, Wh*Ww - relative_position_index[0, 0:] = self.num_relative_distance - 3 - relative_position_index[0:, 0] = self.num_relative_distance - 2 - relative_position_index[0, 0] = self.num_relative_distance - 1 - - self.register_buffer( - "relative_position_index", relative_position_index - ) - else: - self.window_size = None - self.relative_position_bias_table = None - self.relative_position_index = None - - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(all_head_dim, dim) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, rel_pos_bias=None): - B, N, C = x.shape - qkv_bias = None - if self.q_bias is not None: - qkv_bias = torch.cat( - ( - self.q_bias, - torch.zeros_like(self.v_bias, requires_grad=False), - self.v_bias, - ) - ) - # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) - qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = ( - qkv[0], - qkv[1], - qkv[2], - ) # make torchscript happy (cannot use tensor as tuple) - - q = q * self.scale - attn = q @ k.transpose(-2, -1) - - if self.relative_position_bias_table is not None: - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) - ].view( - self.window_size[0] * self.window_size[1] + 1, - self.window_size[0] * self.window_size[1] + 1, - -1, - ) # Wh*Ww,Wh*Ww,nH - relative_position_bias = relative_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wh*Ww, Wh*Ww - attn = attn + relative_position_bias.unsqueeze(0) - - if rel_pos_bias is not None: - attn = attn + rel_pos_bias - - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, -1) - x = self.proj(x) - x = self.proj_drop(x) - return x - - -class Block(nn.Module): - def __init__( - self, - dim, - num_heads, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, - init_values=None, - act_layer=nn.GELU, - norm_layer=nn.LayerNorm, - window_size=None, - attn_head_dim=None, - ): - super().__init__() - self.norm1 = norm_layer(dim) - self.attn = Attention( - dim, - num_heads=num_heads, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - attn_drop=attn_drop, - proj_drop=drop, - window_size=window_size, - attn_head_dim=attn_head_dim, - ) - # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here - self.drop_path = ( - DropPath(drop_path) if drop_path > 0.0 else nn.Identity() - ) - self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp( - in_features=dim, - hidden_features=mlp_hidden_dim, - act_layer=act_layer, - drop=drop, - ) - - if init_values is not None and init_values > 0: - self.gamma_1 = nn.Parameter( - init_values * torch.ones((dim)), requires_grad=True - ) - self.gamma_2 = nn.Parameter( - init_values * torch.ones((dim)), requires_grad=True - ) - else: - self.gamma_1, self.gamma_2 = None, None - - def forward(self, x, rel_pos_bias=None): - if self.gamma_1 is None: - x = x + self.drop_path( - self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) - ) - x = x + self.drop_path(self.mlp(self.norm2(x))) - else: - x = x + self.drop_path( - self.gamma_1 - * self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias) - ) - x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x))) - return x - - -class PatchEmbed(nn.Module): - """Image to Patch Embedding""" - - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): - super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - num_patches = (img_size[1] // patch_size[1]) * ( - img_size[0] // patch_size[0] - ) - self.patch_shape = ( - img_size[0] // patch_size[0], - img_size[1] // patch_size[1], - ) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches - - self.proj = nn.Conv2d( - in_chans, embed_dim, kernel_size=patch_size, stride=patch_size - ) - - def forward(self, x, **kwargs): - B, C, H, W = x.shape - # FIXME look at relaxing size constraints - assert ( - H == self.img_size[0] and W == self.img_size[1] - ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) - return x - - -class RelativePositionBias(nn.Module): - def __init__(self, window_size, num_heads): - super().__init__() - self.window_size = window_size - self.num_relative_distance = (2 * window_size[0] - 1) * ( - 2 * window_size[1] - 1 - ) + 3 - self.relative_position_bias_table = nn.Parameter( - torch.zeros(self.num_relative_distance, num_heads) - ) # 2*Wh-1 * 2*Ww-1, nH - # cls to token & token 2 cls & cls to cls - - # get pair-wise relative position index for each token inside the window - coords_h = torch.arange(window_size[0]) - coords_w = torch.arange(window_size[1]) - coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww - coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww - relative_coords = ( - coords_flatten[:, :, None] - coords_flatten[:, None, :] - ) # 2, Wh*Ww, Wh*Ww - relative_coords = relative_coords.permute( - 1, 2, 0 - ).contiguous() # Wh*Ww, Wh*Ww, 2 - relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0 - relative_coords[:, :, 1] += window_size[1] - 1 - relative_coords[:, :, 0] *= 2 * window_size[1] - 1 - relative_position_index = torch.zeros( - size=(window_size[0] * window_size[1] + 1,) * 2, - dtype=relative_coords.dtype, - ) - relative_position_index[1:, 1:] = relative_coords.sum( - -1 - ) # Wh*Ww, Wh*Ww - relative_position_index[0, 0:] = self.num_relative_distance - 3 - relative_position_index[0:, 0] = self.num_relative_distance - 2 - relative_position_index[0, 0] = self.num_relative_distance - 1 - - self.register_buffer( - "relative_position_index", relative_position_index - ) - - # trunc_normal_(self.relative_position_bias_table, std=.02) - - def forward(self): - relative_position_bias = self.relative_position_bias_table[ - self.relative_position_index.view(-1) - ].view( - self.window_size[0] * self.window_size[1] + 1, - self.window_size[0] * self.window_size[1] + 1, - -1, - ) # Wh*Ww,Wh*Ww,nH - return relative_position_bias.permute( - 2, 0, 1 - ).contiguous() # nH, Wh*Ww, Wh*Ww - - -class VisionTransformer(nn.Module): - """Vision Transformer with support for patch or hybrid CNN input stage""" - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - embed_dim=768, - depth=12, - num_heads=12, - mlp_ratio=4.0, - qkv_bias=False, - qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.0, - norm_layer=nn.LayerNorm, - init_values=None, - use_abs_pos_emb=True, - use_rel_pos_bias=False, - use_shared_rel_pos_bias=False, - use_mean_pooling=True, - init_scale=0.001, - use_checkpoint=False, - ): - super().__init__() - self.image_size = img_size - self.num_classes = num_classes - self.num_features = ( - self.embed_dim - ) = embed_dim # num_features for consistency with other models - - self.patch_embed = PatchEmbed( - img_size=img_size, - patch_size=patch_size, - in_chans=in_chans, - embed_dim=embed_dim, - ) - num_patches = self.patch_embed.num_patches - - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - if use_abs_pos_emb: - self.pos_embed = nn.Parameter( - torch.zeros(1, num_patches + 1, embed_dim) - ) - else: - self.pos_embed = None - self.pos_drop = nn.Dropout(p=drop_rate) - - if use_shared_rel_pos_bias: - self.rel_pos_bias = RelativePositionBias( - window_size=self.patch_embed.patch_shape, num_heads=num_heads - ) - else: - self.rel_pos_bias = None - self.use_checkpoint = use_checkpoint - - dpr = [ - x.item() for x in torch.linspace(0, drop_path_rate, depth) - ] # stochastic depth decay rule - self.use_rel_pos_bias = use_rel_pos_bias - self.blocks = nn.ModuleList( - [ - Block( - dim=embed_dim, - num_heads=num_heads, - mlp_ratio=mlp_ratio, - qkv_bias=qkv_bias, - qk_scale=qk_scale, - drop=drop_rate, - attn_drop=attn_drop_rate, - drop_path=dpr[i], - norm_layer=norm_layer, - init_values=init_values, - window_size=self.patch_embed.patch_shape - if use_rel_pos_bias - else None, - ) - for i in range(depth) - ] - ) - # self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) - # self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None - # self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() - - if self.pos_embed is not None: - trunc_normal_(self.pos_embed, std=0.02) - trunc_normal_(self.cls_token, std=0.02) - # trunc_normal_(self.mask_token, std=.02) - # if isinstance(self.head, nn.Linear): - # trunc_normal_(self.head.weight, std=.02) - self.apply(self._init_weights) - self.fix_init_weight() - - # if isinstance(self.head, nn.Linear): - # self.head.weight.data.mul_(init_scale) - # self.head.bias.data.mul_(init_scale) - - def fix_init_weight(self): - def rescale(param, layer_id): - param.div_(math.sqrt(2.0 * layer_id)) - - for layer_id, layer in enumerate(self.blocks): - rescale(layer.attn.proj.weight.data, layer_id + 1) - rescale(layer.mlp.fc2.weight.data, layer_id + 1) - - def _init_weights(self, m): - if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=0.02) - if isinstance(m, nn.Linear) and m.bias is not None: - nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) - - def get_classifier(self): - return self.head - - def reset_classifier(self, num_classes, global_pool=""): - self.num_classes = num_classes - self.head = ( - nn.Linear(self.embed_dim, num_classes) - if num_classes > 0 - else nn.Identity() - ) - - def forward_features(self, x): - x = self.patch_embed(x) - batch_size, seq_len, _ = x.size() - - cls_tokens = self.cls_token.expand( - batch_size, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.pos_drop(x) - - rel_pos_bias = ( - self.rel_pos_bias() if self.rel_pos_bias is not None else None - ) - for blk in self.blocks: - if self.use_checkpoint: - x = checkpoint.checkpoint(blk, x, rel_pos_bias) - else: - x = blk(x, rel_pos_bias) - return x - - # x = self.norm(x) - - # if self.fc_norm is not None: - # t = x[:, 1:, :] - # return self.fc_norm(t.mean(1)) - # else: - # return x[:, 0] - - def forward(self, x): - x = self.forward_features(x) - # x = self.head(x) - return x - - def get_intermediate_layers(self, x): - x = self.patch_embed(x) - batch_size, seq_len, _ = x.size() - - cls_tokens = self.cls_token.expand( - batch_size, -1, -1 - ) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) - if self.pos_embed is not None: - x = x + self.pos_embed - x = self.pos_drop(x) - - features = [] - rel_pos_bias = ( - self.rel_pos_bias() if self.rel_pos_bias is not None else None - ) - for blk in self.blocks: - x = blk(x, rel_pos_bias) - features.append(x) - - return features - - -def interpolate_pos_embed(model, checkpoint_model): - if "pos_embed" in checkpoint_model: - pos_embed_checkpoint = checkpoint_model["pos_embed"].float() - embedding_size = pos_embed_checkpoint.shape[-1] - num_patches = model.patch_embed.num_patches - num_extra_tokens = model.pos_embed.shape[-2] - num_patches - # height (== width) for the checkpoint position embedding - orig_size = int( - (pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5 - ) - # height (== width) for the new position embedding - new_size = int(num_patches**0.5) - # class_token and dist_token are kept unchanged - if orig_size != new_size: - print( - "Position interpolate from %dx%d to %dx%d" - % (orig_size, orig_size, new_size, new_size) - ) - extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] - # only the position tokens are interpolated - pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape( - -1, orig_size, orig_size, embedding_size - ).permute(0, 3, 1, 2) - pos_tokens = torch.nn.functional.interpolate( - pos_tokens, - size=(new_size, new_size), - mode="bicubic", - align_corners=False, - ) - pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) - new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) - checkpoint_model["pos_embed"] = new_pos_embed - - -def convert_weights_to_fp16(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - # l.weight.data = l.weight.data.half() - l.weight.data = l.weight.data - if l.bias is not None: - # l.bias.data = l.bias.data.half() - l.bias.data = l.bias.data - - # if isinstance(l, (nn.MultiheadAttention, Attention)): - # for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - # tensor = getattr(l, attr) - # if tensor is not None: - # tensor.data = tensor.data.half() - - model.apply(_convert_weights_to_fp16) - - -def create_eva_vit_g( - img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16" -): - model = VisionTransformer( - img_size=img_size, - patch_size=14, - use_mean_pooling=False, - embed_dim=1408, - depth=39, - num_heads=1408 // 88, - mlp_ratio=4.3637, - qkv_bias=True, - drop_path_rate=drop_path_rate, - norm_layer=partial(nn.LayerNorm, eps=1e-6), - use_checkpoint=use_checkpoint, - ) - url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth" - - local_filename = "eva_vit_g.pth" - response = requests.get(url) - if response.status_code == 200: - with open(local_filename, "wb") as f: - f.write(response.content) - print("File downloaded successfully.") - state_dict = torch.load(local_filename, map_location="cpu") - interpolate_pos_embed(model, state_dict) - - incompatible_keys = model.load_state_dict(state_dict, strict=False) - - if precision == "fp16": - # model.to("cuda") - convert_weights_to_fp16(model) - return model diff --git a/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt b/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt deleted file mode 100644 index 38ae75a9..00000000 --- a/apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt +++ /dev/null @@ -1,4 +0,0 @@ - Describe this image in detail. - Take a look at this image and describe what you notice. - Please provide a detailed description of the picture. - Could you describe the contents of this image for me? \ No newline at end of file diff --git a/apps/language_models/src/pipelines/stablelm_pipeline.py b/apps/language_models/src/pipelines/stablelm_pipeline.py deleted file mode 100644 index 9264d8b9..00000000 --- a/apps/language_models/src/pipelines/stablelm_pipeline.py +++ /dev/null @@ -1,300 +0,0 @@ -import torch -import torch_mlir -from transformers import AutoTokenizer, StoppingCriteria, AutoModelForCausalLM -from io import BytesIO -from pathlib import Path -from apps.language_models.utils import ( - get_vmfb_from_path, -) -from apps.language_models.src.pipelines.SharkLLMBase import SharkLLMBase -from apps.language_models.src.model_wrappers.stablelm_model import ( - StableLMModel, -) -import argparse - -parser = argparse.ArgumentParser( - prog="stablelm runner", - description="runs a StableLM model", -) - -parser.add_argument( - "--precision", "-p", default="fp16", choices=["fp32", "fp16", "int4"] -) -parser.add_argument("--device", "-d", default="cuda", help="vulkan, cpu, cuda") -parser.add_argument( - "--stablelm_vmfb_path", default=None, help="path to StableLM's vmfb" -) -parser.add_argument( - "--stablelm_mlir_path", - default=None, - help="path to StableLM's mlir file", -) -parser.add_argument( - "--use_precompiled_model", - default=True, - action=argparse.BooleanOptionalAction, - help="use the precompiled vmfb", -) -parser.add_argument( - "--load_mlir_from_shark_tank", - default=True, - action=argparse.BooleanOptionalAction, - help="download precompile mlir from shark tank", -) -parser.add_argument( - "--hf_auth_token", - type=str, - default=None, - help="Specify your own huggingface authentication token for stablelm-3B model.", -) - - -class StopOnTokens(StoppingCriteria): - def __call__( - self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs - ) -> bool: - stop_ids = [50278, 50279, 50277, 1, 0] - for stop_id in stop_ids: - if input_ids[0][-1] == stop_id: - return True - return False - - -class SharkStableLM(SharkLLMBase): - def __init__( - self, - model_name, - hf_model_path="stabilityai/stablelm-tuned-alpha-3b", - max_num_tokens=256, - device="cuda", - precision="fp32", - debug="False", - ) -> None: - super().__init__(model_name, hf_model_path, max_num_tokens) - self.max_sequence_len = 256 - self.device = device - if precision != "int4" and args.hf_auth_token == None: - raise ValueError( - """ HF auth token required for StableLM-3B. Pass it using - --hf_auth_token flag. You can ask for the access to the model - here: https://huggingface.co/tiiuae/falcon-180B-chat.""" - ) - self.hf_auth_token = args.hf_auth_token - - self.precision = precision - self.debug = debug - self.tokenizer = self.get_tokenizer() - self.shark_model = self.compile() - - def shouldStop(self, tokens): - stop_ids = [50278, 50279, 50277, 1, 0] - for stop_id in stop_ids: - if tokens[0][-1] == stop_id: - return True - return False - - def get_src_model(self): - kwargs = {} - if self.precision == "int4": - self.hf_model_path = "TheBloke/stablelm-zephyr-3b-GPTQ" - from transformers import GPTQConfig - - quantization_config = GPTQConfig(bits=4, disable_exllama=True) - kwargs["quantization_config"] = quantization_config - kwargs["device_map"] = "cpu" - print("[DEBUG] Loading Model") - model = AutoModelForCausalLM.from_pretrained( - self.hf_model_path, - trust_remote_code=True, - torch_dtype=torch.float32, - use_auth_token=self.hf_auth_token, - **kwargs, - ) - print("[DEBUG] Model loaded successfully") - return model - - def get_model_inputs(self): - input_ids = torch.randint(3, (1, self.max_sequence_len)) - attention_mask = torch.randint(3, (1, self.max_sequence_len)) - return input_ids, attention_mask - - def compile(self): - tmp_model_name = f"{self.model_name}_linalg_{self.precision}_seqLen{self.max_sequence_len}" - - # device = "cuda" # "cpu" - # TODO: vmfb and mlir name should include precision and device - model_vmfb_name = None - vmfb_path = ( - Path(tmp_model_name + f"_{self.device}.vmfb") - if model_vmfb_name is None - else Path(model_vmfb_name) - ) - shark_module = get_vmfb_from_path( - vmfb_path, self.device, mlir_dialect="tm_tensor" - ) - if shark_module is not None: - return shark_module - - mlir_path = Path(tmp_model_name + ".mlir") - print( - f"[DEBUG] mlir path {mlir_path} {'exists' if mlir_path.exists() else 'does not exist'}" - ) - if not mlir_path.exists(): - model = StableLMModel(self.get_src_model()) - model_inputs = self.get_model_inputs() - from shark.shark_importer import import_with_fx - - ts_graph = import_with_fx( - model, - model_inputs, - is_f16=True if self.precision in ["fp16"] else False, - precision=self.precision, - f16_input_mask=[False, False], - mlir_type="torchscript", - ) - module = torch_mlir.compile( - ts_graph, - [*model_inputs], - torch_mlir.OutputType.LINALG_ON_TENSORS, - use_tracing=False, - verbose=False, - ) - bytecode_stream = BytesIO() - module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - f_ = open(mlir_path, "wb") - f_.write(bytecode) - print("Saved mlir at: ", mlir_path) - f_.close() - del bytecode - - from shark.shark_inference import SharkInference - - shark_module = SharkInference( - mlir_module=mlir_path, device=self.device, mlir_dialect="tm_tensor" - ) - shark_module.compile() - - path = shark_module.save_module( - vmfb_path.parent.absolute(), vmfb_path.stem, debug=self.debug - ) - print("Saved vmfb at ", str(path)) - - return shark_module - - def get_tokenizer(self): - tok = AutoTokenizer.from_pretrained( - self.hf_model_path, - use_auth_token=self.hf_auth_token, - ) - tok.add_special_tokens({"pad_token": ""}) - # print("[DEBUG] Sucessfully loaded the tokenizer to the memory") - return tok - - def generate(self, prompt): - words_list = [] - import time - - start = time.time() - count = 0 - for i in range(self.max_num_tokens): - count = count + 1 - params = { - "new_text": prompt, - } - - generated_token_op = self.generate_new_token(params) - - detok = generated_token_op["detok"] - stop_generation = generated_token_op["stop_generation"] - - if stop_generation: - break - - print(detok, end="", flush=True) # this is for CLI and DEBUG - words_list.append(detok) - if detok == "": - break - prompt = prompt + detok - end = time.time() - print( - "\n\nTime taken is {:.2f} tokens/second\n".format( - count / (end - start) - ) - ) - return words_list - - def generate_new_token(self, params): - new_text = params["new_text"] - model_inputs = self.tokenizer( - [new_text], - padding="max_length", - max_length=self.max_sequence_len, - truncation=True, - return_tensors="pt", - ) - sum_attentionmask = torch.sum(model_inputs.attention_mask) - output = self.shark_model( - "forward", [model_inputs.input_ids, model_inputs.attention_mask] - ) - output = torch.from_numpy(output) - next_toks = torch.topk(output, 1) - stop_generation = False - if self.shouldStop(next_toks.indices): - stop_generation = True - new_token = next_toks.indices[0][int(sum_attentionmask) - 1] - detok = self.tokenizer.decode( - new_token, - skip_special_tokens=True, - ) - ret_dict = { - "new_token": new_token, - "detok": detok, - "stop_generation": stop_generation, - } - return ret_dict - - -if __name__ == "__main__": - args = parser.parse_args() - - stable_lm = SharkStableLM( - model_name="stablelm_zephyr_3b", - hf_model_path="stabilityai/stablelm-zephyr-3b", - device=args.device, - precision=args.precision, - ) - - default_prompt_text = "The weather is always wonderful" - continue_execution = True - - print("\n-----\nScript executing for the following config: \n") - print("StableLM Model: ", stable_lm.hf_model_path) - print("Precision: ", args.precision) - print("Device: ", args.device) - - while continue_execution: - use_default_prompt = input( - "\nDo you wish to use the default prompt text? Y/N ?: " - ) - if use_default_prompt in ["Y", "y"]: - prompt = default_prompt_text - else: - prompt = input("Please enter the prompt text: ") - print("\nPrompt Text: ", prompt) - - res_str = stable_lm.generate(prompt) - torch.cuda.empty_cache() - import gc - - gc.collect() - print( - "\n\n-----\nHere's the complete formatted result: \n\n", - prompt + "".join(res_str), - ) - continue_execution = input( - "\nDo you wish to run script one more time? Y/N ?: " - ) - continue_execution = ( - True if continue_execution in ["Y", "y"] else False - ) diff --git a/apps/language_models/utils.py b/apps/language_models/utils.py deleted file mode 100644 index 20bebdf9..00000000 --- a/apps/language_models/utils.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch -from torch.fx.experimental.proxy_tensor import make_fx -from torch._decomp import get_decompositions -from typing import List -from pathlib import Path -from shark.shark_downloader import download_public_file - - -# expects a Path / str as arg -# returns None if path not found or SharkInference module -def get_vmfb_from_path(vmfb_path, device, mlir_dialect, device_id=None): - if not isinstance(vmfb_path, Path): - vmfb_path = Path(vmfb_path) - - from shark.shark_inference import SharkInference - - if not vmfb_path.exists(): - return None - - print("Loading vmfb from: ", vmfb_path) - print("Device from get_vmfb_from_path - ", device) - shark_module = SharkInference( - None, device=device, mlir_dialect=mlir_dialect, device_idx=device_id - ) - shark_module.load_module(vmfb_path) - print("Successfully loaded vmfb") - return shark_module - - -def get_vmfb_from_config( - shark_container, - model, - precision, - device, - vmfb_path, - padding=None, - device_id=None, -): - vmfb_url = ( - f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}" - ) - if padding: - vmfb_url = vmfb_url + f"_{padding}" - vmfb_url = vmfb_url + ".vmfb" - download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True) - return get_vmfb_from_path( - vmfb_path, device, "tm_tensor", device_id=device_id - ) diff --git a/apps/stable_diffusion/__init__.py b/apps/stable_diffusion/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/apps/stable_diffusion/profiling_with_iree.md b/apps/stable_diffusion/profiling_with_iree.md deleted file mode 100644 index 0beba6e3..00000000 --- a/apps/stable_diffusion/profiling_with_iree.md +++ /dev/null @@ -1,87 +0,0 @@ -Compile / Run Instructions: - -To compile .vmfb for SD (vae, unet, CLIP), run the following commands with the .mlir in your local shark_tank cache (default location for Linux users is `~/.local/shark_tank`). These will be available once the script from [this README](https://github.com/nod-ai/SHARK/blob/main/shark/examples/shark_inference/stable_diffusion/README.md) is run once. -Running the script mentioned above with the `--save_vmfb` flag will also save the .vmfb in your SHARK base directory if you want to skip straight to benchmarks. - -Compile Commands FP32/FP16: - -```shell -Vulkan AMD: -iree-compile --iree-input-type=none --iree-hal-target-backends=vulkan --iree-vulkan-target-triple=rdna2-unknown-linux /path/to/input/mlir -o /path/to/output/vmfb - -# add --mlir-print-debuginfo --mlir-print-op-on-diagnostic=true for debug -# use –iree-input-type=auto or "mhlo_legacy" or "stablehlo" for TF models - -CUDA NVIDIA: -iree-compile --iree-input-type=none --iree-hal-target-backends=cuda /path/to/input/mlir -o /path/to/output/vmfb - -CPU: -iree-compile --iree-input-type=none --iree-hal-target-backends=llvm-cpu /path/to/input/mlir -o /path/to/output/vmfb -``` - - - -Run / Benchmark Command (FP32 - NCHW): -(NEED to use BS=2 since we do two forward passes to unet as a result of classifier free guidance.) - -```shell -## Vulkan AMD: -iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0 - -## CUDA: -iree-benchmark-module --module=/path/to/vmfb --function=forward --device=cuda --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0 - -## CPU: -iree-benchmark-module --module=/path/to/vmfb --function=forward --device=local-task --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0 - -``` - -Run via vulkan_gui for RGP Profiling: - -To build the vulkan app for profiling UNet follow the instructions [here](https://github.com/nod-ai/SHARK/tree/main/cpp) and then run the following command from the cpp directory with your compiled stable_diff.vmfb -```shell -./build/vulkan_gui/iree-vulkan-gui --module=/path/to/unet.vmfb --input=1x4x64x64xf32 --input=1xf32 --input=2x77x768xf32 --input=f32=1.0 --input=f32=1.0 -``` - - -
    - Debug Commands - -## Debug commands and other advanced usage follows. - -```shell -python txt2img.py --precision="fp32"|"fp16" --device="cpu"|"cuda"|"vulkan" --import_mlir|--no-import_mlir --prompt "enter the text" -``` - -## dump all dispatch .spv and isa using amdllpc - -```shell -python txt2img.py --precision="fp16" --device="vulkan" --iree-vulkan-target-triple=rdna3-unknown-linux --no-load_vmfb --dispatch_benchmarks="all" --dispatch_benchmarks_dir="SD_dispatches" --dump_isa -``` - -## Compile and save the .vmfb (using vulkan fp16 as an example): - -```shell -python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb -``` - -## Capture an RGP trace - -```shell -python txt2img.py --precision=fp16 --device=vulkan --steps=50 --save_vmfb --enable_rgp -``` - -## Run the vae module with iree-benchmark-module (NCHW, fp16, vulkan, for example): - -```shell -iree-benchmark-module --module=/path/to/output/vmfb --function=forward --device=vulkan --input=1x4x64x64xf16 -``` - -## Run the unet module with iree-benchmark-module (same config as above): -```shell -##if you want to use .npz inputs: -unzip ~/.local/shark_tank//inputs.npz -iree-benchmark-module --module=/path/to/output/vmfb --function=forward --input=@arr_0.npy --input=1xf16 --input=@arr_2.npy --input=@arr_3.npy --input=@arr_4.npy -``` - -
    diff --git a/apps/stable_diffusion/scripts/__init__.py b/apps/stable_diffusion/scripts/__init__.py deleted file mode 100644 index ee4bfbb8..00000000 --- a/apps/stable_diffusion/scripts/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from apps.stable_diffusion.scripts.train_lora_word import lora_train diff --git a/apps/stable_diffusion/scripts/img2img.py b/apps/stable_diffusion/scripts/img2img.py deleted file mode 100644 index 7095c758..00000000 --- a/apps/stable_diffusion/scripts/img2img.py +++ /dev/null @@ -1,128 +0,0 @@ -import sys -import torch -import time -from PIL import Image -import transformers -from apps.stable_diffusion.src import ( - args, - Image2ImagePipeline, - StencilPipeline, - resize_stencil, - get_schedulers, - set_init_device_flags, - utils, - clear_all, - save_output_img, -) -from apps.stable_diffusion.src.utils import get_generation_text_info - - -def main(): - if args.clear_all: - clear_all() - - if args.img_path is None: - print("Flag --img_path is required.") - exit() - - image = Image.open(args.img_path).convert("RGB") - # When the models get uploaded, it should be default to False. - args.import_mlir = True - - use_stencil = args.use_stencil - if use_stencil: - args.scheduler = "DDIM" - args.hf_model_id = "runwayml/stable-diffusion-v1-5" - image, args.width, args.height = resize_stencil(image) - elif "Shark" in args.scheduler: - print( - f"Shark schedulers are not supported. Switching to EulerDiscrete scheduler" - ) - args.scheduler = "EulerDiscrete" - cpu_scheduling = not args.scheduler.startswith("Shark") - dtype = torch.float32 if args.precision == "fp32" else torch.half - set_init_device_flags() - schedulers = get_schedulers(args.hf_model_id) - scheduler_obj = schedulers[args.scheduler] - seed = utils.sanitize_seed(args.seed) - # Adjust for height and width based on model - - if use_stencil: - img2img_obj = StencilPipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - use_stencil=use_stencil, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - else: - img2img_obj = Image2ImagePipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - - start_time = time.time() - generated_imgs = img2img_obj.generate_images( - args.prompts, - args.negative_prompts, - image, - args.batch_size, - args.height, - args.width, - args.steps, - args.strength, - args.guidance_scale, - seed, - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - use_stencil=use_stencil, - control_mode=args.control_mode, - ) - total_time = time.time() - start_time - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" - text_output += f"\nscheduler={args.scheduler}, device={args.device}" - text_output += f"\nsteps={args.steps}, strength={args.strength}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}" - text_output += ( - f", batch size={args.batch_size}, max_length={args.max_length}" - ) - text_output += img2img_obj.log - text_output += f"\nTotal image generation time: {total_time:.4f}sec" - - extra_info = {"STRENGTH": args.strength} - save_output_img(generated_imgs[0], seed, extra_info) - print(text_output) - - -if __name__ == "__main__": - main() diff --git a/apps/stable_diffusion/scripts/inpaint.py b/apps/stable_diffusion/scripts/inpaint.py deleted file mode 100644 index 2f2a8af6..00000000 --- a/apps/stable_diffusion/scripts/inpaint.py +++ /dev/null @@ -1,105 +0,0 @@ -import torch -import time -from PIL import Image -import transformers -from apps.stable_diffusion.src import ( - args, - InpaintPipeline, - get_schedulers, - set_init_device_flags, - utils, - clear_all, - save_output_img, -) -from apps.stable_diffusion.src.utils import get_generation_text_info - - -def main(): - if args.clear_all: - clear_all() - - if args.img_path is None: - print("Flag --img_path is required.") - exit() - if args.mask_path is None: - print("Flag --mask_path is required.") - exit() - - dtype = torch.float32 if args.precision == "fp32" else torch.half - cpu_scheduling = not args.scheduler.startswith("Shark") - set_init_device_flags() - model_id = ( - args.hf_model_id - if "inpaint" in args.hf_model_id - else "stabilityai/stable-diffusion-2-inpainting" - ) - schedulers = get_schedulers(model_id) - scheduler_obj = schedulers[args.scheduler] - seed = args.seed - image = Image.open(args.img_path) - mask_image = Image.open(args.mask_path) - - inpaint_obj = InpaintPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - custom_vae=args.custom_vae, - precision=args.precision, - max_length=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - - seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds) - for current_batch in range(args.batch_count): - start_time = time.time() - generated_imgs = inpaint_obj.generate_images( - args.prompts, - args.negative_prompts, - image, - mask_image, - args.batch_size, - args.height, - args.width, - args.inpaint_full_res, - args.inpaint_full_res_padding, - args.steps, - args.guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - total_time = time.time() - start_time - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, device={args.device}" - text_output += ( - f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}," - ) - text_output += f"seed={seed}, size={args.height}x{args.width}" - text_output += ( - f", batch size={args.batch_size}, max_length={args.max_length}" - ) - text_output += inpaint_obj.log - text_output += f"\nTotal image generation time: {total_time:.4f}sec" - - save_output_img(generated_imgs[0], seed) - print(text_output) - - -if __name__ == "__main__": - main() diff --git a/apps/stable_diffusion/scripts/main.py b/apps/stable_diffusion/scripts/main.py deleted file mode 100644 index a73ead42..00000000 --- a/apps/stable_diffusion/scripts/main.py +++ /dev/null @@ -1,19 +0,0 @@ -from apps.stable_diffusion.src import args -from apps.stable_diffusion.scripts import ( - img2img, - txt2img, - # inpaint, - # outpaint, -) - -if __name__ == "__main__": - if args.app == "txt2img": - txt2img.main() - elif args.app == "img2img": - img2img.main() - # elif args.app == "inpaint": - # inpaint.main() - # elif args.app == "outpaint": - # outpaint.main() - else: - print(f"args.app value is {args.app} but this isn't supported") diff --git a/apps/stable_diffusion/scripts/outpaint.py b/apps/stable_diffusion/scripts/outpaint.py deleted file mode 100644 index 2c1a1ff8..00000000 --- a/apps/stable_diffusion/scripts/outpaint.py +++ /dev/null @@ -1,120 +0,0 @@ -import torch -import time -from PIL import Image -import transformers -from apps.stable_diffusion.src import ( - args, - OutpaintPipeline, - get_schedulers, - set_init_device_flags, - utils, - clear_all, - save_output_img, -) - - -def main(): - if args.clear_all: - clear_all() - - if args.img_path is None: - print("Flag --img_path is required.") - exit() - - dtype = torch.float32 if args.precision == "fp32" else torch.half - cpu_scheduling = not args.scheduler.startswith("Shark") - set_init_device_flags() - model_id = ( - args.hf_model_id - if "inpaint" in args.hf_model_id - else "stabilityai/stable-diffusion-2-inpainting" - ) - schedulers = get_schedulers(model_id) - scheduler_obj = schedulers[args.scheduler] - seed = args.seed - image = Image.open(args.img_path) - - outpaint_obj = OutpaintPipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - - seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds) - for current_batch in range(args.batch_count): - start_time = time.time() - generated_imgs = outpaint_obj.generate_images( - args.prompts, - args.negative_prompts, - image, - args.pixels, - args.mask_blur, - args.left, - args.right, - args.top, - args.bottom, - args.noise_q, - args.color_variation, - args.batch_size, - args.height, - args.width, - args.steps, - args.guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - total_time = time.time() - start_time - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, device={args.device}" - text_output += ( - f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}," - ) - text_output += f"seed={seed}, size={args.height}x{args.width}" - text_output += ( - f", batch size={args.batch_size}, max_length={args.max_length}" - ) - text_output += outpaint_obj.log - text_output += f"\nTotal image generation time: {total_time:.4f}sec" - - # save this information as metadata of output generated image. - directions = [] - if args.left: - directions.append("left") - if args.right: - directions.append("right") - if args.top: - directions.append("up") - if args.bottom: - directions.append("down") - extra_info = { - "PIXELS": args.pixels, - "MASK_BLUR": args.mask_blur, - "DIRECTIONS": directions, - "NOISE_Q": args.noise_q, - "COLOR_VARIATION": args.color_variation, - } - save_output_img(generated_imgs[0], seed, extra_info) - print(text_output) - - -if __name__ == "__main__": - main() diff --git a/apps/stable_diffusion/scripts/telegram_bot.py b/apps/stable_diffusion/scripts/telegram_bot.py deleted file mode 100644 index b178efe6..00000000 --- a/apps/stable_diffusion/scripts/telegram_bot.py +++ /dev/null @@ -1,240 +0,0 @@ -import logging -import os -from models.stable_diffusion.main import stable_diff_inf -from models.stable_diffusion.utils import get_available_devices -from dotenv import load_dotenv -from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup -from telegram import BotCommand -from telegram.ext import Application, ApplicationBuilder, CallbackQueryHandler -from telegram.ext import ContextTypes, MessageHandler, CommandHandler, filters -from io import BytesIO -import random - -log = logging.getLogger("TG.Bot") -logging.basicConfig() -log.warning("Start") -load_dotenv() -os.environ["AMD_ENABLE_LLPC"] = "0" -TG_TOKEN = os.getenv("TG_TOKEN") -SELECTED_MODEL = "stablediffusion" -SELECTED_SCHEDULER = "EulerAncestralDiscrete" -STEPS = 30 -NEGATIVE_PROMPT = ( - "Ugly,Morbid,Extra fingers,Poorly drawn hands,Mutation,Blurry,Extra" - " limbs,Gross proportions,Missing arms,Mutated hands,Long" - " neck,Duplicate,Mutilated,Mutilated hands,Poorly drawn face,Deformed,Bad" - " anatomy,Cloned face,Malformed limbs,Missing legs,Too many" - " fingers,blurry, lowres, text, error, cropped, worst quality, low" - " quality, jpeg artifacts, out of frame, extra fingers, mutated hands," - " poorly drawn hands, poorly drawn face, bad anatomy, extra limbs, cloned" - " face, malformed limbs, missing arms, missing legs, extra arms, extra" - " legs, fused fingers, too many fingers" -) -GUIDANCE_SCALE = 6 -available_devices = get_available_devices() -models_list = [ - "stablediffusion", - "anythingv3", - "analogdiffusion", - "openjourney", - "dreamlike", -] -sheds_list = [ - "DDIM", - "PNDM", - "LMSDiscrete", - "DPMSolverMultistep", - "EulerDiscrete", - "EulerAncestralDiscrete", - "SharkEulerDiscrete", -] - - -def image_to_bytes(image): - bio = BytesIO() - bio.name = "image.jpeg" - image.save(bio, "JPEG") - bio.seek(0) - return bio - - -def get_try_again_markup(): - keyboard = [[InlineKeyboardButton("Try again", callback_data="TRYAGAIN")]] - reply_markup = InlineKeyboardMarkup(keyboard) - return reply_markup - - -def generate_image(prompt): - seed = random.randint(1, 10000) - log.warning(SELECTED_MODEL) - log.warning(STEPS) - image, text = stable_diff_inf( - prompt=prompt, - negative_prompt=NEGATIVE_PROMPT, - steps=STEPS, - guidance_scale=GUIDANCE_SCALE, - seed=seed, - scheduler_key=SELECTED_SCHEDULER, - variant=SELECTED_MODEL, - device_key=available_devices[0], - ) - - return image, seed - - -async def generate_and_send_photo( - update: Update, context: ContextTypes.DEFAULT_TYPE -) -> None: - progress_msg = await update.message.reply_text( - "Generating image...", reply_to_message_id=update.message.message_id - ) - im, seed = generate_image(prompt=update.message.text) - await context.bot.delete_message( - chat_id=progress_msg.chat_id, message_id=progress_msg.message_id - ) - await context.bot.send_photo( - update.effective_user.id, - image_to_bytes(im), - caption=f'"{update.message.text}" (Seed: {seed})', - reply_markup=get_try_again_markup(), - reply_to_message_id=update.message.message_id, - ) - - -async def button(update: Update, context: ContextTypes.DEFAULT_TYPE) -> None: - query = update.callback_query - if query.data in models_list: - global SELECTED_MODEL - SELECTED_MODEL = query.data - await query.answer() - await query.edit_message_text(text=f"Selected model: {query.data}") - return - if query.data in sheds_list: - global SELECTED_SCHEDULER - SELECTED_SCHEDULER = query.data - await query.answer() - await query.edit_message_text(text=f"Selected scheduler: {query.data}") - return - replied_message = query.message.reply_to_message - await query.answer() - progress_msg = await query.message.reply_text( - "Generating image...", reply_to_message_id=replied_message.message_id - ) - - if query.data == "TRYAGAIN": - prompt = replied_message.text - im, seed = generate_image(prompt) - - await context.bot.delete_message( - chat_id=progress_msg.chat_id, message_id=progress_msg.message_id - ) - await context.bot.send_photo( - update.effective_user.id, - image_to_bytes(im), - caption=f'"{prompt}" (Seed: {seed})', - reply_markup=get_try_again_markup(), - reply_to_message_id=replied_message.message_id, - ) - - -async def select_model_handler(update, context): - text = "Select model" - keyboard = [] - for model in models_list: - keyboard.append( - [ - InlineKeyboardButton(text=model, callback_data=model), - ] - ) - markup = InlineKeyboardMarkup(keyboard) - await update.message.reply_text(text=text, reply_markup=markup) - - -async def select_scheduler_handler(update, context): - text = "Select schedule" - keyboard = [] - for shed in sheds_list: - keyboard.append( - [ - InlineKeyboardButton(text=shed, callback_data=shed), - ] - ) - markup = InlineKeyboardMarkup(keyboard) - await update.message.reply_text(text=text, reply_markup=markup) - - -async def set_steps_handler(update, context): - input_mex = update.message.text - log.warning(input_mex) - try: - input_args = input_mex.split("/set_steps ")[1] - global STEPS - STEPS = int(input_args) - except Exception: - input_args = ( - "Invalid parameter for command. Correct command looks like\n" - " /set_steps 30" - ) - await update.message.reply_text(input_args) - - -async def set_negative_prompt_handler(update, context): - input_mex = update.message.text - log.warning(input_mex) - try: - input_args = input_mex.split("/set_negative_prompt ")[1] - global NEGATIVE_PROMPT - NEGATIVE_PROMPT = input_args - except Exception: - input_args = ( - "Invalid parameter for command. Correct command looks like\n" - " /set_negative_prompt ugly, bad art, mutated" - ) - await update.message.reply_text(input_args) - - -async def set_guidance_scale_handler(update, context): - input_mex = update.message.text - log.warning(input_mex) - try: - input_args = input_mex.split("/set_guidance_scale ")[1] - global GUIDANCE_SCALE - GUIDANCE_SCALE = int(input_args) - except Exception: - input_args = ( - "Invalid parameter for command. Correct command looks like\n" - " /set_guidance_scale 7" - ) - await update.message.reply_text(input_args) - - -async def setup_bot_commands(application: Application) -> None: - await application.bot.set_my_commands( - [ - BotCommand("select_model", "to select model"), - BotCommand("select_scheduler", "to select scheduler"), - BotCommand("set_steps", "to set steps"), - BotCommand("set_guidance_scale", "to set guidance scale"), - BotCommand("set_negative_prompt", "to set negative prompt"), - ] - ) - - -app = ( - ApplicationBuilder().token(TG_TOKEN).post_init(setup_bot_commands).build() -) -app.add_handler(CommandHandler("select_model", select_model_handler)) -app.add_handler(CommandHandler("select_scheduler", select_scheduler_handler)) -app.add_handler(CommandHandler("set_steps", set_steps_handler)) -app.add_handler( - CommandHandler("set_guidance_scale", set_guidance_scale_handler) -) -app.add_handler( - CommandHandler("set_negative_prompt", set_negative_prompt_handler) -) -app.add_handler( - MessageHandler(filters.TEXT & ~filters.COMMAND, generate_and_send_photo) -) -app.add_handler(CallbackQueryHandler(button)) -log.warning("Start bot") -app.run_polling() diff --git a/apps/stable_diffusion/scripts/train_lora_word.py b/apps/stable_diffusion/scripts/train_lora_word.py deleted file mode 100644 index bbfe22fd..00000000 --- a/apps/stable_diffusion/scripts/train_lora_word.py +++ /dev/null @@ -1,693 +0,0 @@ -# Install the required libs -# pip install -U git+https://github.com/huggingface/diffusers.git -# pip install accelerate transformers ftfy - -# HuggingFace Token -# YOUR_TOKEN = "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" - - -# Import required libraries -import itertools -import math -import os -from typing import List -import random -import torch_mlir - -import numpy as np -import torch -import torch.nn.functional as F -import torch.utils.checkpoint -from torch.utils.data import Dataset - -import PIL -import logging - -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - PNDMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) -from PIL import Image -from tqdm.auto import tqdm -from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRAXFormersAttnProcessor - -import torch_mlir -from torch_mlir.dynamo import make_simple_dynamo_backend -import torch._dynamo as dynamo -from torch.fx.experimental.proxy_tensor import make_fx -from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend -from shark.shark_inference import SharkInference - -torch._dynamo.config.verbose = True - -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - PNDMScheduler, - StableDiffusionPipeline, - UNet2DConditionModel, -) -from diffusers.optimization import get_scheduler -from diffusers.pipelines.stable_diffusion import ( - StableDiffusionSafetyChecker, -) -from PIL import Image -from tqdm.auto import tqdm -from transformers import ( - CLIPFeatureExtractor, - CLIPTextModel, - CLIPTokenizer, -) - -from io import BytesIO - -from dataclasses import dataclass -from apps.stable_diffusion.src import ( - args, - get_schedulers, - set_init_device_flags, - clear_all, -) -from apps.stable_diffusion.src.utils import update_lora_weight - - -# Setup the dataset -class LoraDataset(Dataset): - def __init__( - self, - data_root, - tokenizer, - size=512, - repeats=100, - interpolation="bicubic", - set="train", - prompt="myloraprompt", - center_crop=False, - ): - self.data_root = data_root - self.tokenizer = tokenizer - self.size = size - self.center_crop = center_crop - self.prompt = prompt - - self.image_paths = [ - os.path.join(self.data_root, file_path) - for file_path in os.listdir(self.data_root) - ] - - self.num_images = len(self.image_paths) - self._length = self.num_images - - if set == "train": - self._length = self.num_images * repeats - - self.interpolation = { - "linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, - }[interpolation] - - def __len__(self): - return self._length - - def __getitem__(self, i): - example = {} - image = Image.open(self.image_paths[i % self.num_images]) - - if not image.mode == "RGB": - image = image.convert("RGB") - - example["input_ids"] = self.tokenizer( - self.prompt, - padding="max_length", - truncation=True, - max_length=self.tokenizer.model_max_length, - return_tensors="pt", - ).input_ids[0] - - # default to score-sde preprocessing - img = np.array(image).astype(np.uint8) - - if self.center_crop: - crop = min(img.shape[0], img.shape[1]) - ( - h, - w, - ) = ( - img.shape[0], - img.shape[1], - ) - img = img[ - (h - crop) // 2 : (h + crop) // 2, - (w - crop) // 2 : (w + crop) // 2, - ] - - image = Image.fromarray(img) - image = image.resize( - (self.size, self.size), resample=self.interpolation - ) - - image = np.array(image).astype(np.uint8) - image = (image / 127.5 - 1.0).astype(np.float32) - - example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) - return example - - -def torch_device(device): - device_tokens = device.split("=>") - if len(device_tokens) == 1: - device_str = device_tokens[0].strip() - else: - device_str = device_tokens[1].strip() - device_type_tokens = device_str.split("://") - if device_type_tokens[0] == "metal": - device_type_tokens[0] = "vulkan" - if len(device_type_tokens) > 1: - return device_type_tokens[0] + ":" + device_type_tokens[1] - else: - return device_type_tokens[0] - - -########## Setting up the model ########## -def lora_train( - prompt: str, - height: int, - width: int, - steps: int, - guidance_scale: float, - seed: int, - batch_count: int, - batch_size: int, - scheduler: str, - custom_model: str, - hf_model_id: str, - precision: str, - device: str, - max_length: int, - training_images_dir: str, - lora_save_dir: str, - use_lora: str, -): - from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - Config, - ) - import apps.stable_diffusion.web.utils.global_obj as global_obj - - print( - "Note LoRA training is not compatible with the latest torch-mlir branch" - ) - print( - "To run LoRA training you'll need this to follow this guide for the torch-mlir branch: https://github.com/nod-ai/SHARK/tree/main/shark/examples/shark_training/stable_diffusion" - ) - torch.manual_seed(seed) - - args.prompts = [prompt] - args.steps = steps - - # set ckpt_loc and hf_model_id. - types = ( - ".ckpt", - ".safetensors", - ) # the tuple of file types - args.ckpt_loc = "" - args.hf_model_id = "" - if custom_model == "None": - if not hf_model_id: - return ( - None, - "Please provide either custom model or huggingface model ID, both must not be " - "empty.", - ) - args.hf_model_id = hf_model_id - elif ".ckpt" in custom_model or ".safetensors" in custom_model: - args.ckpt_loc = custom_model - else: - args.hf_model_id = custom_model - - args.training_images_dir = training_images_dir - args.lora_save_dir = lora_save_dir - - args.precision = precision - args.batch_size = batch_size - args.max_length = max_length - args.height = height - args.width = width - args.device = torch_device(device) - args.use_lora = use_lora - - # Load the Stable Diffusion model - text_encoder = CLIPTextModel.from_pretrained( - args.hf_model_id, subfolder="text_encoder" - ) - vae = AutoencoderKL.from_pretrained(args.hf_model_id, subfolder="vae") - unet = UNet2DConditionModel.from_pretrained( - args.hf_model_id, subfolder="unet" - ) - - def freeze_params(params): - for param in params: - param.requires_grad = False - - # Freeze everything but LoRA - freeze_params(vae.parameters()) - freeze_params(unet.parameters()) - freeze_params(text_encoder.parameters()) - - # Move vae and unet to device - vae.to(args.device) - unet.to(args.device) - text_encoder.to(args.device) - - if use_lora != "": - update_lora_weight(unet, args.use_lora, "unet") - else: - lora_attn_procs = {} - for name in unet.attn_processors.keys(): - cross_attention_dim = ( - None - if name.endswith("attn1.processor") - else unet.config.cross_attention_dim - ) - if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[ - block_id - ] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - - lora_attn_procs[name] = LoRAXFormersAttnProcessor( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - ) - - unet.set_attn_processor(lora_attn_procs) - lora_layers = AttnProcsLayers(unet.attn_processors) - - class VaeModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.vae = vae - - def forward(self, input): - x = self.vae.encode(input, return_dict=False)[0] - return x - - class UnetModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.unet = unet - - def forward(self, x, y, z): - return self.unet.forward(x, y, z, return_dict=False)[0] - - shark_vae = VaeModel() - shark_unet = UnetModel() - - ####### Creating our training data ######## - - tokenizer = CLIPTokenizer.from_pretrained( - args.hf_model_id, - subfolder="tokenizer", - ) - - # Let's create the Dataset and Dataloader - train_dataset = LoraDataset( - data_root=args.training_images_dir, - tokenizer=tokenizer, - size=vae.sample_size, - prompt=args.prompts[0], - repeats=100, - center_crop=False, - set="train", - ) - - def create_dataloader(train_batch_size=1): - return torch.utils.data.DataLoader( - train_dataset, batch_size=train_batch_size, shuffle=True - ) - - # Create noise_scheduler for training - noise_scheduler = DDPMScheduler.from_config( - args.hf_model_id, subfolder="scheduler" - ) - - ######## Training ########### - - # Define hyperparameters for our training. If you are not happy with your results, - # you can tune the `learning_rate` and the `max_train_steps` - - # Setting up all training args - hyperparameters = { - "learning_rate": 5e-04, - "scale_lr": True, - "max_train_steps": steps, - "train_batch_size": batch_size, - "gradient_accumulation_steps": 1, - "gradient_checkpointing": True, - "mixed_precision": "fp16", - "seed": 42, - "output_dir": "sd-concept-output", - } - # creating output directory - cwd = os.getcwd() - out_dir = os.path.join(cwd, hyperparameters["output_dir"]) - while not os.path.exists(str(out_dir)): - try: - os.mkdir(out_dir) - except OSError as error: - print("Output directory not created") - - ###### Torch-MLIR Compilation ###### - - def _remove_nones(fx_g: torch.fx.GraphModule) -> List[int]: - removed_indexes = [] - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, (list, tuple)): - node_arg = list(node_arg) - node_args_len = len(node_arg) - for i in range(node_args_len): - curr_index = node_args_len - (i + 1) - if node_arg[curr_index] is None: - removed_indexes.append(curr_index) - node_arg.pop(curr_index) - node.args = (tuple(node_arg),) - break - - if len(removed_indexes) > 0: - fx_g.graph.lint() - fx_g.graph.eliminate_dead_code() - fx_g.recompile() - removed_indexes.sort() - return removed_indexes - - def _unwrap_single_tuple_return(fx_g: torch.fx.GraphModule) -> bool: - """ - Replace tuple with tuple element in functions that return one-element tuples. - Returns true if an unwrapping took place, and false otherwise. - """ - unwrapped_tuple = False - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - if len(node_arg) == 1: - node.args = (node_arg[0],) - unwrapped_tuple = True - break - - if unwrapped_tuple: - fx_g.graph.lint() - fx_g.recompile() - return unwrapped_tuple - - def _returns_nothing(fx_g: torch.fx.GraphModule) -> bool: - for node in fx_g.graph.nodes: - if node.op == "output": - assert ( - len(node.args) == 1 - ), "Output node must have a single argument" - node_arg = node.args[0] - if isinstance(node_arg, tuple): - return len(node_arg) == 0 - return False - - def transform_fx(fx_g): - for node in fx_g.graph.nodes: - if node.op == "call_function": - if node.target in [ - torch.ops.aten.empty, - ]: - # aten.empty should be filled with zeros. - if node.target in [torch.ops.aten.empty]: - with fx_g.graph.inserting_after(node): - new_node = fx_g.graph.call_function( - torch.ops.aten.zero_, - args=(node,), - ) - node.append(new_node) - node.replace_all_uses_with(new_node) - new_node.args = (node,) - - fx_g.graph.lint() - - @make_simple_dynamo_backend - def refbackend_torchdynamo_backend( - fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor] - ): - # handling usage of empty tensor without initializing - transform_fx(fx_graph) - fx_graph.recompile() - if _returns_nothing(fx_graph): - return fx_graph - removed_none_indexes = _remove_nones(fx_graph) - was_unwrapped = _unwrap_single_tuple_return(fx_graph) - - mlir_module = torch_mlir.compile( - fx_graph, example_inputs, output_type="linalg-on-tensors" - ) - - bytecode_stream = BytesIO() - mlir_module.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - shark_module = SharkInference( - mlir_module=bytecode, device=args.device, mlir_dialect="tm_tensor" - ) - shark_module.compile() - - def compiled_callable(*inputs): - inputs = [x.numpy() for x in inputs] - result = shark_module("forward", inputs) - if was_unwrapped: - result = [ - result, - ] - if not isinstance(result, list): - result = torch.from_numpy(result) - else: - result = tuple(torch.from_numpy(x) for x in result) - result = list(result) - for removed_index in removed_none_indexes: - result.insert(removed_index, None) - result = tuple(result) - return result - - return compiled_callable - - def predictions(torch_func, jit_func, batchA, batchB): - res = jit_func(batchA.numpy(), batchB.numpy()) - if res is not None: - # prediction = torch.from_numpy(res) - prediction = res - else: - prediction = None - return prediction - - logger = logging.getLogger(__name__) - - train_batch_size = hyperparameters["train_batch_size"] - gradient_accumulation_steps = hyperparameters[ - "gradient_accumulation_steps" - ] - learning_rate = hyperparameters["learning_rate"] - if hyperparameters["scale_lr"]: - learning_rate = ( - learning_rate - * gradient_accumulation_steps - * train_batch_size - # * accelerator.num_processes - ) - - # Initialize the optimizer - optimizer = torch.optim.AdamW( - lora_layers.parameters(), # only optimize the embeddings - lr=learning_rate, - ) - - # Training function - def train_func(batch_pixel_values, batch_input_ids): - # Convert images to latent space - latents = shark_vae(batch_pixel_values).sample().detach() - latents = latents * 0.18215 - - # Sample noise that we'll add to the latents - noise = torch.randn_like(latents) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, - noise_scheduler.num_train_timesteps, - (bsz,), - device=latents.device, - ).long() - - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - - # Get the text embedding for conditioning - encoder_hidden_states = text_encoder(batch_input_ids)[0] - - # Predict the noise residual - noise_pred = shark_unet( - noisy_latents, - timesteps, - encoder_hidden_states, - ) - - # Get the target for loss depending on the prediction type - if noise_scheduler.config.prediction_type == "epsilon": - target = noise - elif noise_scheduler.config.prediction_type == "v_prediction": - target = noise_scheduler.get_velocity(latents, noise, timesteps) - else: - raise ValueError( - f"Unknown prediction type {noise_scheduler.config.prediction_type}" - ) - - loss = ( - F.mse_loss(noise_pred, target, reduction="none") - .mean([1, 2, 3]) - .mean() - ) - loss.backward() - - optimizer.step() - optimizer.zero_grad() - - return loss - - def training_function(): - max_train_steps = hyperparameters["max_train_steps"] - output_dir = hyperparameters["output_dir"] - gradient_checkpointing = hyperparameters["gradient_checkpointing"] - - train_dataloader = create_dataloader(train_batch_size) - - # We need to recalculate our total training steps as the size of the training dataloader may have changed. - num_update_steps_per_epoch = math.ceil( - len(train_dataloader) / gradient_accumulation_steps - ) - num_train_epochs = math.ceil( - max_train_steps / num_update_steps_per_epoch - ) - - # Train! - total_batch_size = ( - train_batch_size - * gradient_accumulation_steps - # train_batch_size * accelerator.num_processes * gradient_accumulation_steps - ) - - logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") - logger.info( - f" Instantaneous batch size per device = {train_batch_size}" - ) - logger.info( - f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}" - ) - logger.info( - f" Gradient Accumulation steps = {gradient_accumulation_steps}" - ) - logger.info(f" Total optimization steps = {max_train_steps}") - # Only show the progress bar once on each machine. - progress_bar = tqdm( - # range(max_train_steps), disable=not accelerator.is_local_main_process - range(max_train_steps) - ) - progress_bar.set_description("Steps") - global_step = 0 - - params__ = [ - i for i in text_encoder.get_input_embeddings().parameters() - ] - - for epoch in range(num_train_epochs): - unet.train() - for step, batch in enumerate(train_dataloader): - dynamo_callable = dynamo.optimize( - refbackend_torchdynamo_backend - )(train_func) - lam_func = lambda x, y: dynamo_callable( - torch.from_numpy(x), torch.from_numpy(y) - ) - loss = predictions( - train_func, - lam_func, - batch["pixel_values"], - batch["input_ids"], - ) - - # Checks if the accelerator has performed an optimization step behind the scenes - progress_bar.update(1) - global_step += 1 - - logs = {"loss": loss.detach().item()} - progress_bar.set_postfix(**logs) - - if global_step >= max_train_steps: - break - - training_function() - - # Save the lora weights - unet.save_attn_procs(args.lora_save_dir) - - for param in itertools.chain(unet.parameters(), text_encoder.parameters()): - if param.grad is not None: - del param.grad # free some memory - torch.cuda.empty_cache() - - -if __name__ == "__main__": - if args.clear_all: - clear_all() - - dtype = torch.float32 if args.precision == "fp32" else torch.half - cpu_scheduling = not args.scheduler.startswith("Shark") - set_init_device_flags() - schedulers = get_schedulers(args.hf_model_id) - scheduler_obj = schedulers[args.scheduler] - seed = args.seed - if len(args.prompts) != 1: - print("Need exactly one prompt for the LoRA word") - lora_train( - args.prompts[0], - args.height, - args.width, - args.training_steps, - args.guidance_scale, - args.seed, - args.batch_count, - args.batch_size, - args.scheduler, - "None", - args.hf_model_id, - args.precision, - args.device, - args.max_length, - args.training_images_dir, - args.lora_save_dir, - args.use_lora, - ) diff --git a/apps/stable_diffusion/scripts/tuner.py b/apps/stable_diffusion/scripts/tuner.py deleted file mode 100644 index 18b4c047..00000000 --- a/apps/stable_diffusion/scripts/tuner.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -from pathlib import Path -from shark_tuner.codegen_tuner import SharkCodegenTuner -from shark_tuner.iree_utils import ( - dump_dispatches, - create_context, - export_module_to_mlir_file, -) -from shark_tuner.model_annotation import model_annotation -from apps.stable_diffusion.src.utils.stable_args import args -from apps.stable_diffusion.src.utils.utils import set_init_device_flags -from apps.stable_diffusion.src.utils.sd_annotation import ( - get_device_args, - load_winograd_configs, -) -from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel - - -def load_mlir_module(): - if "upscaler" in args.hf_model_id: - is_upscaler = True - else: - is_upscaler = False - sd_model = SharkifyStableDiffusionModel( - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - max_len=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - is_upscaler=is_upscaler, - use_tuned=False, - low_cpu_mem_usage=args.low_cpu_mem_usage, - return_mlir=True, - ) - - if args.annotation_model == "unet": - mlir_module = sd_model.unet() - model_name = sd_model.model_name["unet"] - elif args.annotation_model == "vae": - mlir_module = sd_model.vae() - model_name = sd_model.model_name["vae"] - else: - raise ValueError( - f"{args.annotation_model} is not supported for tuning." - ) - - return mlir_module, model_name - - -def main(): - args.use_tuned = False - set_init_device_flags() - mlir_module, model_name = load_mlir_module() - - # Get device and device specific arguments - device, device_spec_args = get_device_args() - device_spec = "" - vulkan_target_triple = "" - if device_spec_args: - device_spec = device_spec_args[-1].split("=")[-1].strip() - if device == "vulkan": - vulkan_target_triple = device_spec - device_spec = device_spec.split("-")[0] - - # Add winograd annotation for vulkan device - use_winograd = ( - True - if device == "vulkan" and args.annotation_model in ["unet", "vae"] - else False - ) - winograd_config = ( - load_winograd_configs() - if device == "vulkan" and args.annotation_model in ["unet", "vae"] - else "" - ) - with create_context() as ctx: - input_module = model_annotation( - ctx, - input_contents=mlir_module, - config_path=winograd_config, - search_op="conv", - winograd=use_winograd, - ) - - # Dump model dispatches - generates_dir = Path.home() / "tmp" - if not os.path.exists(generates_dir): - os.makedirs(generates_dir) - dump_mlir = generates_dir / "temp.mlir" - dispatch_dir = generates_dir / f"{model_name}_{device_spec}_dispatches" - export_module_to_mlir_file(input_module, dump_mlir) - dump_dispatches( - dump_mlir, - device, - dispatch_dir, - vulkan_target_triple, - use_winograd=use_winograd, - ) - - # Tune each dispatch - dtype = "f16" if args.precision == "fp16" else "f32" - config_filename = f"{model_name}_{device_spec}_configs.json" - - for f_path in os.listdir(dispatch_dir): - if not f_path.endswith(".mlir"): - continue - - model_dir = os.path.join(dispatch_dir, f_path) - - tuner = SharkCodegenTuner( - model_dir, - device, - "random", - args.num_iters, - args.tuned_config_dir, - dtype, - args.search_op, - batch_size=1, - config_filename=config_filename, - use_dispatch=True, - vulkan_target_triple=vulkan_target_triple, - ) - tuner.tune() - - -if __name__ == "__main__": - main() diff --git a/apps/stable_diffusion/scripts/txt2img.py b/apps/stable_diffusion/scripts/txt2img.py deleted file mode 100644 index f425f48c..00000000 --- a/apps/stable_diffusion/scripts/txt2img.py +++ /dev/null @@ -1,88 +0,0 @@ -import torch -import transformers -import time -from apps.stable_diffusion.src import ( - args, - Text2ImagePipeline, - get_schedulers, - set_init_device_flags, - utils, - clear_all, - save_output_img, -) - - -def main(): - if args.clear_all: - clear_all() - - dtype = torch.float32 if args.precision == "fp32" else torch.half - cpu_scheduling = not args.scheduler.startswith("Shark") - set_init_device_flags() - schedulers = get_schedulers(args.hf_model_id) - scheduler_obj = schedulers[args.scheduler] - seed = args.seed - txt2img_obj = Text2ImagePipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=args.precision, - max_length=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=args.ondemand, - ) - - seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds) - for current_batch in range(args.batch_count): - start_time = time.time() - generated_imgs = txt2img_obj.generate_images( - args.prompts, - args.negative_prompts, - args.batch_size, - args.height, - args.width, - args.steps, - args.guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - total_time = time.time() - start_time - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, device={args.device}" - text_output += ( - f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}," - ) - text_output += ( - f"seed={seeds[current_batch]}, size={args.height}x{args.width}" - ) - text_output += ( - f", batch size={args.batch_size}, max_length={args.max_length}" - ) - # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start - text_output += txt2img_obj.log - text_output += f"\nTotal image generation time: {total_time:.4f}sec" - - save_output_img(generated_imgs[0], seed) - print(text_output) - - -if __name__ == "__main__": - main() diff --git a/apps/stable_diffusion/scripts/txt2img_sdxl.py b/apps/stable_diffusion/scripts/txt2img_sdxl.py deleted file mode 100644 index e930c605..00000000 --- a/apps/stable_diffusion/scripts/txt2img_sdxl.py +++ /dev/null @@ -1,96 +0,0 @@ -import torch -import time -from apps.stable_diffusion.src import ( - args, - Text2ImageSDXLPipeline, - get_schedulers, - set_init_device_flags, - utils, - clear_all, - save_output_img, -) - - -def main(): - if args.clear_all: - clear_all() - - # TODO: prompt_embeds and text_embeds form base_model.json requires fixing - args.precision = "fp16" - args.height = 1024 - args.width = 1024 - args.max_length = 77 - args.scheduler = "DDIM" - print( - "Using default supported configuration for SDXL :-\nprecision=fp16, width*height= 1024*1024, max_length=77 and scheduler=DDIM" - ) - dtype = torch.float32 if args.precision == "fp32" else torch.half - cpu_scheduling = not args.scheduler.startswith("Shark") - set_init_device_flags() - schedulers = get_schedulers(args.hf_model_id) - scheduler_obj = schedulers[args.scheduler] - seed = args.seed - txt2img_obj = Text2ImageSDXLPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=args.precision, - max_length=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=args.ondemand, - ) - - seeds = utils.batch_seeds(seed, args.batch_count, args.repeatable_seeds) - for current_batch in range(args.batch_count): - start_time = time.time() - generated_imgs = txt2img_obj.generate_images( - args.prompts, - args.negative_prompts, - args.batch_size, - args.height, - args.width, - args.steps, - args.guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - total_time = time.time() - start_time - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, device={args.device}" - text_output += ( - f"\nsteps={args.steps}, guidance_scale={args.guidance_scale}," - ) - text_output += ( - f"seed={seeds[current_batch]}, size={args.height}x{args.width}" - ) - text_output += ( - f", batch size={args.batch_size}, max_length={args.max_length}" - ) - # TODO: if using --batch_count=x txt2img_obj.log will output on each display every iteration infos from the start - text_output += txt2img_obj.log - text_output += f"\nTotal image generation time: {total_time:.4f}sec" - - save_output_img(generated_imgs[0], seed) - print(text_output) - - -if __name__ == "__main__": - main() diff --git a/apps/stable_diffusion/scripts/upscaler.py b/apps/stable_diffusion/scripts/upscaler.py deleted file mode 100644 index 3ea57033..00000000 --- a/apps/stable_diffusion/scripts/upscaler.py +++ /dev/null @@ -1,92 +0,0 @@ -import torch -import time -from PIL import Image -import transformers -from apps.stable_diffusion.src import ( - args, - UpscalerPipeline, - get_schedulers, - set_init_device_flags, - utils, - clear_all, - save_output_img, -) - - -if __name__ == "__main__": - if args.clear_all: - clear_all() - - if args.img_path is None: - print("Flag --img_path is required.") - exit() - - # When the models get uploaded, it should be defaulted to False. - args.import_mlir = True - - cpu_scheduling = not args.scheduler.startswith("Shark") - dtype = torch.float32 if args.precision == "fp32" else torch.half - set_init_device_flags() - schedulers = get_schedulers(args.hf_model_id) - - scheduler_obj = schedulers[args.scheduler] - image = ( - Image.open(args.img_path) - .convert("RGB") - .resize((args.height, args.width)) - ) - seed = utils.sanitize_seed(args.seed) - # Adjust for height and width based on model - - upscaler_obj = UpscalerPipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - use_lora=args.use_lora, - ddpm_scheduler=schedulers["DDPM"], - ondemand=args.ondemand, - ) - - start_time = time.time() - generated_imgs = upscaler_obj.generate_images( - args.prompts, - args.negative_prompts, - image, - args.batch_size, - args.height, - args.width, - args.steps, - args.noise_level, - args.guidance_scale, - seed, - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - total_time = time.time() - start_time - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += f"\nmodel_id={args.hf_model_id}, ckpt_loc={args.ckpt_loc}" - text_output += f"\nscheduler={args.scheduler}, device={args.device}" - text_output += f"\nsteps={args.steps}, noise_level={args.noise_level}, guidance_scale={args.guidance_scale}, seed={seed}, size={args.height}x{args.width}" - text_output += ( - f", batch size={args.batch_size}, max_length={args.max_length}" - ) - text_output += upscaler_obj.log - text_output += f"\nTotal image generation time: {total_time:.4f}sec" - - extra_info = {"NOISE LEVEL": args.noise_level} - save_output_img(generated_imgs[0], seed, extra_info) - print(text_output) diff --git a/apps/stable_diffusion/shark_sd.spec b/apps/stable_diffusion/shark_sd.spec deleted file mode 100644 index 07b4a81a..00000000 --- a/apps/stable_diffusion/shark_sd.spec +++ /dev/null @@ -1,48 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- -from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports - -binaries = [] - -block_cipher = None - -a = Analysis( - ['web/index.py'], - pathex=pathex, - binaries=binaries, - datas=datas, - hiddenimports=hiddenimports, - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - win_no_prefer_redirects=False, - win_private_assemblies=False, - cipher=block_cipher, - noarchive=False, - module_collection_mode={ - 'gradio': 'py', # Collect gradio package as source .py files - }, -) -pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) - -exe = EXE( - pyz, - a.scripts, - a.binaries, - a.zipfiles, - a.datas, - [], - name='nodai_shark_studio', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=False, - upx_exclude=[], - runtime_tmpdir=None, - console=True, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, -) diff --git a/apps/stable_diffusion/shark_sd_cli.spec b/apps/stable_diffusion/shark_sd_cli.spec deleted file mode 100644 index 21749c8f..00000000 --- a/apps/stable_diffusion/shark_sd_cli.spec +++ /dev/null @@ -1,85 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- -from PyInstaller.utils.hooks import collect_data_files -from PyInstaller.utils.hooks import collect_submodules -from PyInstaller.utils.hooks import copy_metadata - -import sys ; sys.setrecursionlimit(sys.getrecursionlimit() * 5) - -datas = [] -datas += collect_data_files('torch') -datas += copy_metadata('torch') -datas += copy_metadata('tqdm') -datas += copy_metadata('regex') -datas += copy_metadata('requests') -datas += copy_metadata('packaging') -datas += copy_metadata('filelock') -datas += copy_metadata('numpy') -datas += copy_metadata('tokenizers') -datas += copy_metadata('importlib_metadata') -datas += copy_metadata('torch-mlir') -datas += copy_metadata('omegaconf') -datas += copy_metadata('safetensors') -datas += collect_data_files('diffusers') -datas += collect_data_files('transformers') -datas += collect_data_files('opencv-python') -datas += collect_data_files('pytorch_lightning') -datas += collect_data_files('skimage') -datas += collect_data_files('gradio') -datas += collect_data_files('gradio_client') -datas += collect_data_files('iree') -datas += collect_data_files('google-cloud-storage') -datas += collect_data_files('shark') -datas += collect_data_files('py-cpuinfo') -datas += [ - ( 'src/utils/resources/prompts.json', 'resources' ), - ( 'src/utils/resources/model_db.json', 'resources' ), - ( 'src/utils/resources/opt_flags.json', 'resources' ), - ( 'src/utils/resources/base_model.json', 'resources' ), - ] - -binaries = [] - -block_cipher = None - -hiddenimports = ['shark', 'shark.shark_inference', 'apps'] -hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] -hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] - -a = Analysis( - ['scripts/main.py'], - pathex=['.'], - binaries=binaries, - datas=datas, - hiddenimports=hiddenimports, - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - win_no_prefer_redirects=False, - win_private_assemblies=False, - cipher=block_cipher, - noarchive=False, -) -pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) - -exe = EXE( - pyz, - a.scripts, - a.binaries, - a.zipfiles, - a.datas, - [], - name='shark_sd_cli', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - upx_exclude=[], - runtime_tmpdir=None, - console=True, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, -) diff --git a/apps/stable_diffusion/shark_studio_imports.py b/apps/stable_diffusion/shark_studio_imports.py deleted file mode 100644 index dfb5bf80..00000000 --- a/apps/stable_diffusion/shark_studio_imports.py +++ /dev/null @@ -1,90 +0,0 @@ -from PyInstaller.utils.hooks import collect_data_files -from PyInstaller.utils.hooks import copy_metadata -from PyInstaller.utils.hooks import collect_submodules - -import sys - -sys.setrecursionlimit(sys.getrecursionlimit() * 5) - -# python path for pyinstaller -pathex = [ - ".", - "./apps/language_models/langchain", - "./apps/language_models/src/pipelines/minigpt4_utils", -] - -# datafiles for pyinstaller -datas = [] -datas += copy_metadata("torch") -datas += copy_metadata("tokenizers") -datas += copy_metadata("tqdm") -datas += copy_metadata("regex") -datas += copy_metadata("requests") -datas += copy_metadata("packaging") -datas += copy_metadata("filelock") -datas += copy_metadata("numpy") -datas += copy_metadata("importlib_metadata") -datas += copy_metadata("torch-mlir") -datas += copy_metadata("omegaconf") -datas += copy_metadata("safetensors") -datas += copy_metadata("Pillow") -datas += copy_metadata("sentencepiece") -datas += copy_metadata("pyyaml") -datas += copy_metadata("huggingface-hub") -datas += copy_metadata("gradio") -datas += collect_data_files("torch") -datas += collect_data_files("tokenizers") -datas += collect_data_files("tiktoken") -datas += collect_data_files("accelerate") -datas += collect_data_files("diffusers") -datas += collect_data_files("transformers") -datas += collect_data_files("pytorch_lightning") -datas += collect_data_files("skimage") -datas += collect_data_files("gradio") -datas += collect_data_files("gradio_client") -datas += collect_data_files("iree") -datas += collect_data_files("shark", include_py_files=True) -datas += collect_data_files("timm", include_py_files=True) -datas += collect_data_files("tqdm") -datas += collect_data_files("tkinter") -datas += collect_data_files("webview") -datas += collect_data_files("sentencepiece") -datas += collect_data_files("jsonschema") -datas += collect_data_files("jsonschema_specifications") -datas += collect_data_files("cpuinfo") -datas += collect_data_files("langchain") -datas += collect_data_files("cv2") -datas += collect_data_files("einops") -datas += [ - ("src/utils/resources/prompts.json", "resources"), - ("src/utils/resources/model_db.json", "resources"), - ("src/utils/resources/opt_flags.json", "resources"), - ("src/utils/resources/base_model.json", "resources"), - ("web/ui/css/*", "ui/css"), - ("web/ui/logos/*", "logos"), - ( - "../language_models/src/pipelines/minigpt4_utils/configs/*", - "minigpt4_utils/configs", - ), - ( - "../language_models/src/pipelines/minigpt4_utils/prompts/*", - "minigpt4_utils/prompts", - ), -] - - -# hidden imports for pyinstaller -hiddenimports = ["shark", "shark.shark_inference", "apps"] -hiddenimports += [x for x in collect_submodules("skimage") if "tests" not in x] -hiddenimports += [x for x in collect_submodules("gradio") if "tests" not in x] -hiddenimports += [ - x for x in collect_submodules("diffusers") if "tests" not in x -] -blacklist = ["tests", "convert"] -hiddenimports += [ - x - for x in collect_submodules("transformers") - if not any(kw in x for kw in blacklist) -] -hiddenimports += [x for x in collect_submodules("iree") if "tests" not in x] -hiddenimports += ["iree._runtime"] diff --git a/apps/stable_diffusion/src/__init__.py b/apps/stable_diffusion/src/__init__.py deleted file mode 100644 index a40bafb7..00000000 --- a/apps/stable_diffusion/src/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from apps.stable_diffusion.src.utils import ( - args, - set_init_device_flags, - prompt_examples, - get_available_devices, - clear_all, - save_output_img, - resize_stencil, -) -from apps.stable_diffusion.src.pipelines import ( - Text2ImagePipeline, - Text2ImageSDXLPipeline, - Image2ImagePipeline, - InpaintPipeline, - OutpaintPipeline, - StencilPipeline, - UpscalerPipeline, -) -from apps.stable_diffusion.src.schedulers import get_schedulers diff --git a/apps/stable_diffusion/src/models/__init__.py b/apps/stable_diffusion/src/models/__init__.py deleted file mode 100644 index 8d8ca717..00000000 --- a/apps/stable_diffusion/src/models/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -from apps.stable_diffusion.src.models.model_wrappers import ( - SharkifyStableDiffusionModel, -) -from apps.stable_diffusion.src.models.opt_params import ( - get_vae_encode, - get_vae, - get_unet, - get_clip, - get_tokenizer, - get_params, - get_variant_version, -) diff --git a/apps/stable_diffusion/src/models/model_wrappers.py b/apps/stable_diffusion/src/models/model_wrappers.py deleted file mode 100644 index 57be6da9..00000000 --- a/apps/stable_diffusion/src/models/model_wrappers.py +++ /dev/null @@ -1,1356 +0,0 @@ -from diffusers import AutoencoderKL, UNet2DConditionModel, ControlNetModel -from transformers import CLIPTextModel, CLIPTextModelWithProjection -from collections import defaultdict -from pathlib import Path -import torch -import safetensors.torch -import traceback -import subprocess -import sys -import os -import requests -from apps.stable_diffusion.src.utils import ( - compile_through_fx, - get_opt_flags, - base_models, - args, - preprocessCKPT, - convert_original_vae, - get_path_to_diffusers_checkpoint, - get_civitai_checkpoint, - fetch_and_update_base_model_id, - get_path_stem, - get_extended_name, - get_stencil_model_id, - update_lora_weight, -) -from shark.shark_downloader import download_public_file -from shark.shark_inference import SharkInference - - -# These shapes are parameter dependent. -def replace_shape_str(shape, max_len, width, height, batch_size): - new_shape = [] - for i in range(len(shape)): - if shape[i] == "max_len": - new_shape.append(max_len) - elif shape[i] == "height": - new_shape.append(height) - elif shape[i] == "width": - new_shape.append(width) - elif isinstance(shape[i], str): - if "*" in shape[i]: - mul_val = int(shape[i].split("*")[0]) - if "batch_size" in shape[i]: - new_shape.append(batch_size * mul_val) - elif "height" in shape[i]: - new_shape.append(height * mul_val) - elif "width" in shape[i]: - new_shape.append(width * mul_val) - elif "/" in shape[i]: - import math - - div_val = int(shape[i].split("/")[1]) - if "batch_size" in shape[i]: - new_shape.append(math.ceil(batch_size / div_val)) - elif "height" in shape[i]: - new_shape.append(math.ceil(height / div_val)) - elif "width" in shape[i]: - new_shape.append(math.ceil(width / div_val)) - elif "+" in shape[i]: - # Currently this case only hits for SDXL. So, in case any other - # case requires this operator, change this. - new_shape.append(height + width) - else: - new_shape.append(shape[i]) - return new_shape - - -def check_compilation(model, model_name): - if not model: - raise Exception( - f"Could not compile {model_name}. Please create an issue with the detailed log at https://github.com/nod-ai/SHARK/issues" - ) - - -def shark_compile_after_ir( - module_name, - device, - vmfb_path, - precision, - ir_path=None, -): - if ir_path: - print(f"[DEBUG] mlir found at {ir_path.absolute()}") - - module = SharkInference( - mlir_module=ir_path, - device=device, - mlir_dialect="tm_tensor", - ) - print(f"Will get extra flag for {module_name} and precision = {precision}") - path = module.save_module( - vmfb_path.parent.absolute(), - vmfb_path.stem, - extra_args=get_opt_flags(module_name, precision=precision), - ) - print(f"Saved {module_name} vmfb at {path}") - module.load_module(path) - return module - - -def process_vmfb_ir_sdxl(extended_model_name, model_name, device, precision): - name_split = extended_model_name.split("_") - if "vae" in model_name: - name_split[5] = "fp32" - extended_model_name_for_vmfb = "_".join(name_split) - extended_model_name_for_mlir = "_".join(name_split[:-1]) - vmfb_path = Path(extended_model_name_for_vmfb + ".vmfb") - if "vulkan" in device: - _device = args.iree_vulkan_target_triple - _device = _device.replace("-", "_") - vmfb_path = Path(extended_model_name_for_vmfb + f"_vulkan.vmfb") - if vmfb_path.exists(): - shark_module = SharkInference( - None, - device=device, - mlir_dialect="tm_tensor", - ) - print(f"loading existing vmfb from: {vmfb_path}") - shark_module.load_module(vmfb_path, extra_args=[]) - return shark_module, None - mlir_path = Path(extended_model_name_for_mlir + ".mlir") - if not mlir_path.exists(): - print(f"Looking into gs://shark_tank/SDXL/mlir/{mlir_path.name}") - download_public_file( - f"gs://shark_tank/SDXL/mlir/{mlir_path.name}", - mlir_path.absolute(), - single_file=True, - ) - if mlir_path.exists(): - return ( - shark_compile_after_ir( - model_name, device, vmfb_path, precision, mlir_path - ), - None, - ) - return None, None - - -class SharkifyStableDiffusionModel: - def __init__( - self, - model_id: str, - custom_weights: str, - custom_vae: str, - precision: str, - max_len: int = 64, - width: int = 512, - height: int = 512, - batch_size: int = 1, - use_base_vae: bool = False, - use_tuned: bool = False, - low_cpu_mem_usage: bool = False, - debug: bool = False, - sharktank_dir: str = "", - generate_vmfb: bool = True, - is_inpaint: bool = False, - is_upscaler: bool = False, - is_sdxl: bool = False, - stencils: list[str] = [], - use_lora: str = "", - use_quantize: str = None, - return_mlir: bool = False, - ): - self.check_params(max_len, width, height) - self.max_len = max_len - self.is_sdxl = is_sdxl - self.height = height // 8 - self.width = width // 8 - self.batch_size = batch_size - self.custom_weights = custom_weights.strip() - self.use_quantize = use_quantize - if custom_weights != "": - if custom_weights.startswith("https://civitai.com/api/"): - # download the checkpoint from civitai if we don't already have it - weights_path = get_civitai_checkpoint(custom_weights) - - # act as if we were given the local file as custom_weights originally - custom_weights = get_path_to_diffusers_checkpoint(weights_path) - self.custom_weights = weights_path - - # needed to ensure webui sets the correct model name metadata - args.ckpt_loc = weights_path - else: - assert custom_weights.lower().endswith( - (".ckpt", ".safetensors") - ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" - custom_weights = get_path_to_diffusers_checkpoint( - custom_weights - ) - - self.model_id = model_id if custom_weights == "" else custom_weights - self.custom_vae = custom_vae - self.precision = precision - self.base_vae = use_base_vae - self.model_name = ( - "_" - + str(batch_size) - + "_" - + str(max_len) - + "_" - + str(height) - + "_" - + str(width) - + "_" - + precision - ) - self.model_namedata = self.model_name - print(f"use_tuned? sharkify: {use_tuned}") - self.use_tuned = use_tuned - if use_tuned: - self.model_name = self.model_name + "_tuned" - self.model_name = self.model_name + "_" + get_path_stem(self.model_id) - self.low_cpu_mem_usage = low_cpu_mem_usage - self.is_inpaint = is_inpaint - self.is_upscaler = is_upscaler - self.stencils = [get_stencil_model_id(x) for x in stencils] - if use_lora != "": - self.model_name = self.model_name + "_" + get_path_stem(use_lora) - self.use_lora = use_lora - - self.model_name = self.get_extended_name_for_all_model() - self.debug = debug - self.sharktank_dir = sharktank_dir - self.generate_vmfb = generate_vmfb - - self.inputs = dict() - self.model_to_run = "" - if self.custom_weights != "": - self.model_to_run = self.custom_weights - assert self.custom_weights.lower().endswith( - (".ckpt", ".safetensors") - ), "checkpoint files supported can be any of [.ckpt, .safetensors] type" - preprocessCKPT(self.custom_weights, self.is_inpaint) - else: - self.model_to_run = args.hf_model_id - self.custom_vae = self.process_custom_vae() - self.base_model_id = fetch_and_update_base_model_id(self.model_to_run) - if self.base_model_id != "" and args.ckpt_loc != "": - args.hf_model_id = self.base_model_id - self.return_mlir = return_mlir - - def get_extended_name_for_all_model(self, model_list=None): - model_name = {} - sub_model_list = [ - "clip", - "clip2", - "unet", - "unet512", - "stencil_unet", - "stencil_unet_512", - "vae", - "vae_encode", - "stencil_adapter", - "stencil_adapter_512", - ] - if model_list is not None: - sub_model_list = model_list - index = 0 - for model in sub_model_list: - sub_model = model - model_config = self.model_name - if "vae" == model: - if self.custom_vae != "": - model_config = model_config + get_path_stem( - self.custom_vae - ) - if self.base_vae: - sub_model = "base_vae" - if "stencil_adapter" in model: - stencil_names = [] - for i, stencil in enumerate(self.stencils): - if stencil is not None: - cnet_config = ( - self.model_namedata - + "_sd15_" - + stencil.split("_")[-1] - ) - stencil_names.append( - get_extended_name(sub_model + cnet_config) - ) - - model_name[model] = stencil_names - else: - model_name[model] = get_extended_name(sub_model + model_config) - index += 1 - - return model_name - - def check_params(self, max_len, width, height): - if not (max_len >= 32 and max_len <= 77): - sys.exit("please specify max_len in the range [32, 77].") - if not (width % 8 == 0 and width >= 128): - sys.exit("width should be greater than 128 and multiple of 8") - if not (height % 8 == 0 and height >= 128): - sys.exit("height should be greater than 128 and multiple of 8") - - # Get the input info for a model i.e. "unet", "clip", "vae", etc. - def get_input_info_for(self, model_info): - dtype_config = {"f32": torch.float32, "i64": torch.int64} - input_map = [] - for inp in model_info: - shape = model_info[inp]["shape"] - dtype = dtype_config[model_info[inp]["dtype"]] - tensor = None - if isinstance(shape, list): - clean_shape = replace_shape_str( - shape, - self.max_len, - self.width, - self.height, - self.batch_size, - ) - if dtype == torch.int64: - tensor = torch.randint(1, 3, tuple(clean_shape)) - else: - tensor = torch.randn(*clean_shape).to(dtype) - elif isinstance(shape, int): - tensor = torch.tensor(shape).to(dtype) - else: - sys.exit("shape isn't specified correctly.") - input_map.append(tensor) - return input_map - - def get_vae_encode(self): - class VaeEncodeModel(torch.nn.Module): - def __init__( - self, model_id=self.model_id, low_cpu_mem_usage=False - ): - super().__init__() - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - - def forward(self, input): - latents = self.vae.encode(input).latent_dist.sample() - return 0.18215 * latents - - vae_encode = VaeEncodeModel() - inputs = tuple(self.inputs["vae_encode"]) - is_f16 = ( - True - if not self.is_upscaler and self.precision == "fp16" - else False - ) - shark_vae_encode, vae_encode_mlir = compile_through_fx( - vae_encode, - inputs, - is_f16=is_f16, - use_tuned=self.use_tuned, - extended_model_name=self.model_name["vae_encode"], - extra_args=get_opt_flags("vae", precision=self.precision), - base_model_id=self.base_model_id, - model_name="vae_encode", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_vae_encode, vae_encode_mlir - - def get_vae(self): - class VaeModel(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - base_vae=self.base_vae, - custom_vae=self.custom_vae, - low_cpu_mem_usage=False, - ): - super().__init__() - self.vae = None - if custom_vae == "": - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - elif not isinstance(custom_vae, dict): - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - else: - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.vae.load_state_dict(custom_vae) - self.base_vae = base_vae - - def forward(self, input): - if not self.base_vae: - input = 1 / 0.18215 * input - x = self.vae.decode(input, return_dict=False)[0] - x = (x / 2 + 0.5).clamp(0, 1) - if self.base_vae: - return x - x = x * 255.0 - return x.round() - - vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - inputs = tuple(self.inputs["vae"]) - is_f16 = ( - True - if not self.is_upscaler and self.precision == "fp16" - else False - ) - save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"]) - if self.debug: - os.makedirs(save_dir, exist_ok=True) - shark_vae, vae_mlir = compile_through_fx( - vae, - inputs, - is_f16=is_f16, - use_tuned=self.use_tuned, - extended_model_name=self.model_name["vae"], - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("vae", precision=self.precision), - base_model_id=self.base_model_id, - model_name="vae", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_vae, vae_mlir - - def get_vae_sdxl(self): - # TODO: Remove this after convergence with shark_tank. This should just be part of - # opt_params.py. - shark_module_or_none = process_vmfb_ir_sdxl( - self.model_name["vae"], "vae", args.device, self.precision - ) - if shark_module_or_none[0]: - return shark_module_or_none - - class VaeModel(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - base_vae=self.base_vae, - custom_vae=self.custom_vae, - low_cpu_mem_usage=False, - ): - super().__init__() - self.vae = None - if custom_vae == "": - print(f"Loading default vae, with target {model_id}") - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - elif not isinstance(custom_vae, dict): - precision = "fp16" if "fp16" in custom_vae else None - print(f"Loading custom vae, with target {custom_vae}") - if os.path.exists(custom_vae): - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - else: - custom_vae = "/".join( - [ - custom_vae.split("/")[-2].split("\\")[-1], - custom_vae.split("/")[-1], - ] - ) - print("Using hub to get custom vae") - try: - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - low_cpu_mem_usage=low_cpu_mem_usage, - variant=precision, - ) - except: - self.vae = AutoencoderKL.from_pretrained( - custom_vae, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - else: - print(f"Loading custom vae, with state {custom_vae}") - self.vae = AutoencoderKL.from_pretrained( - model_id, - subfolder="vae", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.vae.load_state_dict(custom_vae) - self.base_vae = base_vae - - def forward(self, latents): - image = self.vae.decode(latents / 0.13025, return_dict=False)[ - 0 - ] - return image - - vae = VaeModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - inputs = tuple(self.inputs["vae"]) - # Make sure the VAE is in float32 mode, as it overflows in float16 as per SDXL - # pipeline. - if not self.custom_vae: - is_f16 = False - elif "16" in self.custom_vae: - is_f16 = True - else: - is_f16 = False - save_dir = os.path.join(self.sharktank_dir, self.model_name["vae"]) - if self.debug: - os.makedirs(save_dir, exist_ok=True) - shark_vae, vae_mlir = compile_through_fx( - vae, - inputs, - is_f16=is_f16, - use_tuned=self.use_tuned, - extended_model_name=self.model_name["vae"], - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("vae", precision=self.precision), - base_model_id=self.base_model_id, - model_name="vae", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_vae, vae_mlir - - def get_controlled_unet(self, use_large=False): - class ControlledUnetModel(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - low_cpu_mem_usage=False, - use_lora=self.use_lora, - ): - super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - if use_lora != "": - update_lora_weight(self.unet, use_lora, "unet") - self.in_channels = self.unet.config.in_channels - self.train(False) - - def forward( - self, - latent, - timestep, - text_embedding, - guidance_scale, - control1, - control2, - control3, - control4, - control5, - control6, - control7, - control8, - control9, - control10, - control11, - control12, - control13, - scale1, - scale2, - scale3, - scale4, - scale5, - scale6, - scale7, - scale8, - scale9, - scale10, - scale11, - scale12, - scale13, - ): - # TODO: Average pooling - db_res_samples = [ - control1, - control2, - control3, - control4, - control5, - control6, - control7, - control8, - control9, - control10, - control11, - control12, - ] - - # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. - db_res_samples = tuple( - [ - control1 * scale1, - control2 * scale2, - control3 * scale3, - control4 * scale4, - control5 * scale5, - control6 * scale6, - control7 * scale7, - control8 * scale8, - control9 * scale9, - control10 * scale10, - control11 * scale11, - control12 * scale12, - ] - ) - mb_res_samples = control13 * scale13 - latents = torch.cat([latent] * 2) - unet_out = self.unet.forward( - latents, - timestep, - encoder_hidden_states=text_embedding, - down_block_additional_residuals=db_res_samples, - mid_block_additional_residual=mb_res_samples, - return_dict=False, - )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - return noise_pred - - unet = ControlledUnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - is_f16 = True if self.precision == "fp16" else False - - inputs = tuple(self.inputs["unet"]) - model_name = "stencil_unet" - if use_large: - pad = (0, 0) * (len(inputs[2].shape) - 2) - pad = pad + (0, 512 - inputs[2].shape[1]) - inputs = ( - inputs[:2] - + (torch.nn.functional.pad(inputs[2], pad),) - + inputs[3:] - ) - model_name = "stencil_unet_512" - input_mask = [ - True, - True, - True, - False, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - True, - ] - shark_controlled_unet, controlled_unet_mlir = compile_through_fx( - unet, - inputs, - extended_model_name=self.model_name[model_name], - is_f16=is_f16, - f16_input_mask=input_mask, - use_tuned=self.use_tuned, - extra_args=get_opt_flags("unet", precision=self.precision), - base_model_id=self.base_model_id, - model_name=model_name, - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_controlled_unet, controlled_unet_mlir - - def get_control_net(self, stencil_id, use_large=False): - stencil_id = get_stencil_model_id(stencil_id) - adapter_id, base_model_safe_id, ext_model_name = (None, None, None) - print(f"Importing ControlNet adapter from {stencil_id}") - - class StencilControlNetModel(torch.nn.Module): - def __init__(self, model_id=stencil_id, low_cpu_mem_usage=False): - super().__init__() - self.cnet = ControlNetModel.from_pretrained( - model_id, - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.in_channels = self.cnet.config.in_channels - self.train(False) - - def forward( - self, - latent, - timestep, - text_embedding, - stencil_image_input, - acc1, - acc2, - acc3, - acc4, - acc5, - acc6, - acc7, - acc8, - acc9, - acc10, - acc11, - acc12, - acc13, - ): - # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. - # TODO: guidance NOT NEEDED change in `get_input_info` later - latents = torch.cat( - [latent] * 2 - ) # needs to be same as controlledUNET latents - stencil_image = torch.cat( - [stencil_image_input] * 2 - ) # needs to be same as controlledUNET latents - ( - down_block_res_samples, - mid_block_res_sample, - ) = self.cnet.forward( - latents, - timestep, - encoder_hidden_states=text_embedding, - controlnet_cond=stencil_image, - return_dict=False, - ) - return tuple( - list(down_block_res_samples) + [mid_block_res_sample] - ) + ( - acc1 + down_block_res_samples[0], - acc2 + down_block_res_samples[1], - acc3 + down_block_res_samples[2], - acc4 + down_block_res_samples[3], - acc5 + down_block_res_samples[4], - acc6 + down_block_res_samples[5], - acc7 + down_block_res_samples[6], - acc8 + down_block_res_samples[7], - acc9 + down_block_res_samples[8], - acc10 + down_block_res_samples[9], - acc11 + down_block_res_samples[10], - acc12 + down_block_res_samples[11], - acc13 + mid_block_res_sample, - ) - - scnet = StencilControlNetModel( - low_cpu_mem_usage=self.low_cpu_mem_usage - ) - is_f16 = True if self.precision == "fp16" else False - - inputs = tuple(self.inputs["stencil_adapter"]) - model_name = "stencil_adapter_512" if use_large else "stencil_adapter" - stencil_names = self.get_extended_name_for_all_model([model_name]) - ext_model_name = stencil_names[model_name] - if isinstance(ext_model_name, list): - desired_name = None - print(ext_model_name) - for i in ext_model_name: - if stencil_id.split("_")[-1] in i: - desired_name = i - else: - continue - if desired_name: - ext_model_name = desired_name - else: - raise Exception( - f"Could not find extended configuration for {stencil_id}" - ) - - if use_large: - pad = (0, 0) * (len(inputs[2].shape) - 2) - pad = pad + (0, 512 - inputs[2].shape[1]) - inputs = ( - inputs[0], - inputs[1], - torch.nn.functional.pad(inputs[2], pad), - *inputs[3:], - ) - save_dir = os.path.join(self.sharktank_dir, ext_model_name) - input_mask = [True, True, True, True] + ([True] * 13) - - shark_cnet, cnet_mlir = compile_through_fx( - scnet, - inputs, - extended_model_name=ext_model_name, - is_f16=is_f16, - f16_input_mask=input_mask, - use_tuned=self.use_tuned, - extra_args=get_opt_flags("unet", precision=self.precision), - base_model_id=self.base_model_id, - model_name=model_name, - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_cnet, cnet_mlir - - def get_unet(self, use_large=False): - class UnetModel(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - low_cpu_mem_usage=False, - use_lora=self.use_lora, - ): - super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - if use_lora != "": - update_lora_weight(self.unet, use_lora, "unet") - self.in_channels = self.unet.config.in_channels - self.train(False) - if ( - args.attention_slicing is not None - and args.attention_slicing != "none" - ): - if args.attention_slicing.isdigit(): - self.unet.set_attention_slice( - int(args.attention_slicing) - ) - else: - self.unet.set_attention_slice(args.attention_slicing) - - # TODO: Instead of flattening the `control` try to use the list. - def forward( - self, - latent, - timestep, - text_embedding, - guidance_scale, - ): - # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. - latents = torch.cat([latent] * 2) - unet_out = self.unet.forward( - latents, timestep, text_embedding, return_dict=False - )[0] - noise_pred_uncond, noise_pred_text = unet_out.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - return noise_pred - - unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - is_f16 = True if self.precision == "fp16" else False - inputs = tuple(self.inputs["unet"]) - if use_large: - pad = (0, 0) * (len(inputs[2].shape) - 2) - pad = pad + (0, 512 - inputs[2].shape[1]) - inputs = ( - inputs[0], - inputs[1], - torch.nn.functional.pad(inputs[2], pad), - inputs[3], - ) - save_dir = os.path.join( - self.sharktank_dir, self.model_name["unet512"] - ) - else: - save_dir = os.path.join( - self.sharktank_dir, self.model_name["unet"] - ) - input_mask = [True, True, True, False] - if self.debug: - os.makedirs( - save_dir, - exist_ok=True, - ) - model_name = "unet512" if use_large else "unet" - shark_unet, unet_mlir = compile_through_fx( - unet, - inputs, - extended_model_name=self.model_name[model_name], - is_f16=is_f16, - f16_input_mask=input_mask, - use_tuned=self.use_tuned, - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("unet", precision=self.precision), - base_model_id=self.base_model_id, - model_name=model_name, - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_unet, unet_mlir - - def get_unet_upscaler(self, use_large=False): - class UnetModel(torch.nn.Module): - def __init__( - self, model_id=self.model_id, low_cpu_mem_usage=False - ): - super().__init__() - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - self.in_channels = self.unet.in_channels - self.train(False) - - def forward(self, latent, timestep, text_embedding, noise_level): - unet_out = self.unet.forward( - latent, - timestep, - text_embedding, - noise_level, - return_dict=False, - )[0] - return unet_out - - unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - is_f16 = True if self.precision == "fp16" else False - inputs = tuple(self.inputs["unet"]) - if use_large: - pad = (0, 0) * (len(inputs[2].shape) - 2) - pad = pad + (0, 512 - inputs[2].shape[1]) - inputs = ( - inputs[0], - inputs[1], - torch.nn.functional.pad(inputs[2], pad), - inputs[3], - ) - input_mask = [True, True, True, False] - model_name = "unet512" if use_large else "unet" - shark_unet, unet_mlir = compile_through_fx( - unet, - inputs, - extended_model_name=self.model_name[model_name], - is_f16=is_f16, - f16_input_mask=input_mask, - use_tuned=self.use_tuned, - extra_args=get_opt_flags("unet", precision=self.precision), - base_model_id=self.base_model_id, - model_name=model_name, - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_unet, unet_mlir - - def get_unet_sdxl(self): - # TODO: Remove this after convergence with shark_tank. This should just be part of - # opt_params.py. - shark_module_or_none = process_vmfb_ir_sdxl( - self.model_name["unet"], "unet", args.device, self.precision - ) - if shark_module_or_none[0]: - return shark_module_or_none - - class UnetModel(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - low_cpu_mem_usage=False, - ): - super().__init__() - try: - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - variant="fp16", - ) - except: - self.unet = UNet2DConditionModel.from_pretrained( - model_id, - subfolder="unet", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - if ( - args.attention_slicing is not None - and args.attention_slicing != "none" - ): - if args.attention_slicing.isdigit(): - self.unet.set_attention_slice( - int(args.attention_slicing) - ) - else: - self.unet.set_attention_slice(args.attention_slicing) - - def forward( - self, - latent, - timestep, - prompt_embeds, - text_embeds, - time_ids, - guidance_scale, - ): - added_cond_kwargs = { - "text_embeds": text_embeds, - "time_ids": time_ids, - } - noise_pred = self.unet.forward( - latent, - timestep, - encoder_hidden_states=prompt_embeds, - cross_attention_kwargs=None, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - return noise_pred - - unet = UnetModel(low_cpu_mem_usage=self.low_cpu_mem_usage) - is_f16 = True if self.precision == "fp16" else False - inputs = tuple(self.inputs["unet"]) - save_dir = os.path.join(self.sharktank_dir, self.model_name["unet"]) - input_mask = [True, True, True, True, True, True] - if self.debug: - os.makedirs( - save_dir, - exist_ok=True, - ) - shark_unet, unet_mlir = compile_through_fx( - unet, - inputs, - extended_model_name=self.model_name["unet"], - is_f16=is_f16, - f16_input_mask=input_mask, - use_tuned=self.use_tuned, - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("unet", precision=self.precision), - base_model_id=self.base_model_id, - model_name="unet", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_unet, unet_mlir - - def get_clip(self): - class CLIPText(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - low_cpu_mem_usage=False, - use_lora=self.use_lora, - ): - super().__init__() - self.text_encoder = CLIPTextModel.from_pretrained( - model_id, - subfolder="text_encoder", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - if use_lora != "": - update_lora_weight( - self.text_encoder, use_lora, "text_encoder" - ) - - def forward(self, input): - return self.text_encoder(input)[0] - - clip_model = CLIPText(low_cpu_mem_usage=self.low_cpu_mem_usage) - save_dir = "" - if self.debug: - save_dir = os.path.join( - self.sharktank_dir, self.model_name["clip"] - ) - os.makedirs( - save_dir, - exist_ok=True, - ) - shark_clip, clip_mlir = compile_through_fx( - clip_model, - tuple(self.inputs["clip"]), - extended_model_name=self.model_name["clip"], - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("clip", precision="fp32"), - base_model_id=self.base_model_id, - model_name="clip", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_clip, clip_mlir - - def get_clip_sdxl(self, clip_index=1): - if clip_index == 1: - extended_model_name = self.model_name["clip"] - model_name = "clip" - else: - extended_model_name = self.model_name["clip2"] - model_name = "clip2" - # TODO: Remove this after convergence with shark_tank. This should just be part of - # opt_params.py. - shark_module_or_none = process_vmfb_ir_sdxl( - extended_model_name, f"clip", args.device, self.precision - ) - if shark_module_or_none[0]: - return shark_module_or_none - - class CLIPText(torch.nn.Module): - def __init__( - self, - model_id=self.model_id, - low_cpu_mem_usage=False, - clip_index=1, - ): - super().__init__() - if clip_index == 1: - self.text_encoder = CLIPTextModel.from_pretrained( - model_id, - subfolder="text_encoder", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - else: - self.text_encoder = ( - CLIPTextModelWithProjection.from_pretrained( - model_id, - subfolder="text_encoder_2", - low_cpu_mem_usage=low_cpu_mem_usage, - ) - ) - - def forward(self, input): - prompt_embeds = self.text_encoder( - input, - output_hidden_states=True, - ) - # We are only ALWAYS interested in the pooled output of the final text encoder - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - return prompt_embeds, pooled_prompt_embeds - - clip_model = CLIPText( - low_cpu_mem_usage=self.low_cpu_mem_usage, clip_index=clip_index - ) - save_dir = os.path.join(self.sharktank_dir, extended_model_name) - if self.debug: - os.makedirs( - save_dir, - exist_ok=True, - ) - shark_clip, clip_mlir = compile_through_fx( - clip_model, - tuple(self.inputs["clip"]), - extended_model_name=extended_model_name, - debug=self.debug, - generate_vmfb=self.generate_vmfb, - save_dir=save_dir, - extra_args=get_opt_flags("clip", precision="fp32"), - base_model_id=self.base_model_id, - model_name="clip", - precision=self.precision, - return_mlir=self.return_mlir, - ) - return shark_clip, clip_mlir - - def process_custom_vae(self): - custom_vae = self.custom_vae.lower() - if not custom_vae.endswith((".ckpt", ".safetensors")): - return self.custom_vae - try: - preprocessCKPT(self.custom_vae) - return get_path_to_diffusers_checkpoint(self.custom_vae) - except: - print("Processing standalone Vae checkpoint") - vae_checkpoint = None - vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - if custom_vae.endswith(".ckpt"): - vae_checkpoint = torch.load( - self.custom_vae, map_location="cpu" - ) - else: - vae_checkpoint = safetensors.torch.load_file( - self.custom_vae, device="cpu" - ) - if "state_dict" in vae_checkpoint: - vae_checkpoint = vae_checkpoint["state_dict"] - - try: - vae_checkpoint = convert_original_vae(vae_checkpoint) - finally: - vae_dict = { - k: v - for k, v in vae_checkpoint.items() - if k[0:4] != "loss" and k not in vae_ignore_keys - } - return vae_dict - - def compile_unet_variants(self, model, use_large=False, base_model=""): - if self.is_sdxl: - return self.get_unet_sdxl() - if model == "unet": - if self.is_upscaler: - return self.get_unet_upscaler(use_large=use_large) - # TODO: Plug the experimental "int8" support at right place. - elif self.use_quantize == "int8": - from apps.stable_diffusion.src.models.opt_params import ( - get_unet, - ) - - return get_unet() - else: - return self.get_unet(use_large=use_large) - else: - return self.get_controlled_unet(use_large=use_large) - - def vae_encode(self): - try: - self.inputs["vae_encode"] = self.get_input_info_for( - base_models["vae_encode"] - ) - compiled_vae_encode, vae_encode_mlir = self.get_vae_encode() - - check_compilation(compiled_vae_encode, "Vae Encode") - if self.return_mlir: - return vae_encode_mlir - return compiled_vae_encode - except Exception as e: - sys.exit(e) - - def clip(self): - try: - self.inputs["clip"] = self.get_input_info_for(base_models["clip"]) - compiled_clip, clip_mlir = self.get_clip() - - check_compilation(compiled_clip, "Clip") - if self.return_mlir: - return clip_mlir - return compiled_clip - except Exception as e: - sys.exit(e) - - def sdxl_clip(self): - try: - self.inputs["clip"] = self.get_input_info_for( - base_models["sdxl_clip"] - ) - compiled_clip, clip_mlir = self.get_clip_sdxl(clip_index=1) - compiled_clip2, clip_mlir2 = self.get_clip_sdxl(clip_index=2) - - check_compilation(compiled_clip, "Clip") - check_compilation(compiled_clip, "Clip2") - if self.return_mlir: - return clip_mlir, clip_mlir2 - return compiled_clip, compiled_clip2 - except Exception as e: - sys.exit(e) - - def unet(self, use_large=False): - try: - stencil_count = 0 - for stencil in self.stencils: - stencil_count += 1 - model = "stencil_unet" if stencil_count > 0 else "unet" - compiled_unet = None - unet_inputs = base_models[model] - - if self.base_model_id != "": - self.inputs["unet"] = self.get_input_info_for( - unet_inputs[self.base_model_id] - ) - compiled_unet, unet_mlir = self.compile_unet_variants( - model, use_large=use_large, base_model=self.base_model_id - ) - else: - for model_id in unet_inputs: - self.base_model_id = model_id - self.inputs["unet"] = self.get_input_info_for( - unet_inputs[model_id] - ) - - try: - compiled_unet, unet_mlir = self.compile_unet_variants( - model, use_large=use_large, base_model=model_id - ) - except Exception as e: - print(e) - print( - "Retrying with a different base model configuration" - ) - continue - - # -- Once a successful compilation has taken place we'd want to store - # the base model's configuration inferred. - fetch_and_update_base_model_id(self.model_to_run, model_id) - # This is done just because in main.py we are basing the choice of tokenizer and scheduler - # on `args.hf_model_id`. Since now, we don't maintain 1:1 mapping of variants and the base - # model and rely on retrying method to find the input configuration, we should also update - # the knowledge of base model id accordingly into `args.hf_model_id`. - if args.ckpt_loc != "": - args.hf_model_id = model_id - break - - check_compilation(compiled_unet, "Unet") - if self.return_mlir: - return unet_mlir - return compiled_unet - except Exception as e: - sys.exit(e) - - def vae(self): - try: - vae_input = ( - base_models["vae"]["vae_upscaler"] - if self.is_upscaler - else base_models["vae"]["vae"] - ) - self.inputs["vae"] = self.get_input_info_for(vae_input) - - is_base_vae = self.base_vae - if self.is_upscaler: - self.base_vae = True - if self.is_sdxl: - compiled_vae, vae_mlir = self.get_vae_sdxl() - else: - compiled_vae, vae_mlir = self.get_vae() - self.base_vae = is_base_vae - - check_compilation(compiled_vae, "Vae") - if self.return_mlir: - return vae_mlir - return compiled_vae - except Exception as e: - sys.exit(e) - - def controlnet(self, stencil_id, use_large=False): - try: - self.inputs["stencil_adapter"] = self.get_input_info_for( - base_models["stencil_adapter"] - ) - compiled_stencil_adapter, controlnet_mlir = self.get_control_net( - stencil_id, use_large=use_large - ) - - check_compilation(compiled_stencil_adapter, "Stencil") - if self.return_mlir: - return controlnet_mlir - return compiled_stencil_adapter - except Exception as e: - sys.exit(e) diff --git a/apps/stable_diffusion/src/models/opt_params.py b/apps/stable_diffusion/src/models/opt_params.py deleted file mode 100644 index f03897de..00000000 --- a/apps/stable_diffusion/src/models/opt_params.py +++ /dev/null @@ -1,133 +0,0 @@ -import sys -from transformers import CLIPTokenizer -from apps.stable_diffusion.src.utils import ( - models_db, - args, - get_shark_model, - get_opt_flags, -) - - -hf_model_variant_map = { - "Linaqruf/anything-v3.0": ["anythingv3", "v1_4"], - "dreamlike-art/dreamlike-diffusion-1.0": ["dreamlike", "v1_4"], - "prompthero/openjourney": ["openjourney", "v1_4"], - "wavymulder/Analog-Diffusion": ["analogdiffusion", "v1_4"], - "stabilityai/stable-diffusion-2-1": ["stablediffusion", "v2_1base"], - "stabilityai/stable-diffusion-2-1-base": ["stablediffusion", "v2_1base"], - "CompVis/stable-diffusion-v1-4": ["stablediffusion", "v1_4"], - "runwayml/stable-diffusion-inpainting": ["stablediffusion", "inpaint_v1"], - "stabilityai/stable-diffusion-2-inpainting": [ - "stablediffusion", - "inpaint_v2", - ], -} - - -# TODO: Add the quantized model as a part model_db.json. -# This is currently in experimental phase. -def get_quantize_model(): - bucket_key = "gs://shark_tank/prashant_nod" - model_key = "unet_int8" - iree_flags = get_opt_flags("unet", precision="fp16") - if args.height != 512 and args.width != 512 and args.max_length != 77: - sys.exit( - "The int8 quantized model currently requires the height and width to be 512, and max_length to be 77" - ) - return bucket_key, model_key, iree_flags - - -def get_variant_version(hf_model_id): - return hf_model_variant_map[hf_model_id] - - -def get_params(bucket_key, model_key, model, is_tuned, precision): - try: - bucket = models_db[0][bucket_key] - model_name = models_db[1][model_key] - except KeyError: - raise Exception( - f"{bucket_key}/{model_key} is not present in the models database" - ) - iree_flags = get_opt_flags(model, precision="fp16") - return bucket, model_name, iree_flags - - -def get_unet(): - variant, version = get_variant_version(args.hf_model_id) - # Tuned model is present only for `fp16` precision. - is_tuned = "tuned" if args.use_tuned else "untuned" - - # TODO: Get the quantize model from model_db.json - if args.use_quantize == "int8": - bk, mk, flags = get_quantize_model() - return get_shark_model(bk, mk, flags) - - if "vulkan" not in args.device and args.use_tuned: - bucket_key = f"{variant}/{is_tuned}/{args.device}" - model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}/{args.device}" - else: - bucket_key = f"{variant}/{is_tuned}" - model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/{is_tuned}" - - bucket, model_name, iree_flags = get_params( - bucket_key, model_key, "unet", is_tuned, args.precision - ) - return get_shark_model(bucket, model_name, iree_flags) - - -def get_vae_encode(): - variant, version = get_variant_version(args.hf_model_id) - # Tuned model is present only for `fp16` precision. - is_tuned = "tuned" if args.use_tuned else "untuned" - if "vulkan" not in args.device and args.use_tuned: - bucket_key = f"{variant}/{is_tuned}/{args.device}" - model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}/{args.device}" - else: - bucket_key = f"{variant}/{is_tuned}" - model_key = f"{variant}/{version}/vae_encode/{args.precision}/length_77/{is_tuned}" - - bucket, model_name, iree_flags = get_params( - bucket_key, model_key, "vae", is_tuned, args.precision - ) - return get_shark_model(bucket, model_name, iree_flags) - - -def get_vae(): - variant, version = get_variant_version(args.hf_model_id) - # Tuned model is present only for `fp16` precision. - is_tuned = "tuned" if args.use_tuned else "untuned" - is_base = "/base" if args.use_base_vae else "" - if "vulkan" not in args.device and args.use_tuned: - bucket_key = f"{variant}/{is_tuned}/{args.device}" - model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}/{args.device}" - else: - bucket_key = f"{variant}/{is_tuned}" - model_key = f"{variant}/{version}/vae/{args.precision}/length_77/{is_tuned}{is_base}" - - bucket, model_name, iree_flags = get_params( - bucket_key, model_key, "vae", is_tuned, args.precision - ) - return get_shark_model(bucket, model_name, iree_flags) - - -def get_clip(): - variant, version = get_variant_version(args.hf_model_id) - bucket_key = f"{variant}/untuned" - model_key = ( - f"{variant}/{version}/clip/fp32/length_{args.max_length}/untuned" - ) - bucket, model_name, iree_flags = get_params( - bucket_key, model_key, "clip", "untuned", "fp32" - ) - return get_shark_model(bucket, model_name, iree_flags) - - -def get_tokenizer(subfolder="tokenizer", hf_model_id=None): - if hf_model_id is not None: - args.hf_model_id = hf_model_id - - tokenizer = CLIPTokenizer.from_pretrained( - args.hf_model_id, subfolder=subfolder - ) - return tokenizer diff --git a/apps/stable_diffusion/src/pipelines/__init__.py b/apps/stable_diffusion/src/pipelines/__init__.py deleted file mode 100644 index d65921c7..00000000 --- a/apps/stable_diffusion/src/pipelines/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img import ( - Text2ImagePipeline, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_txt2img_sdxl import ( - Text2ImageSDXLPipeline, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_img2img import ( - Image2ImagePipeline, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_inpaint import ( - InpaintPipeline, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_outpaint import ( - OutpaintPipeline, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_stencil import ( - StencilPipeline, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_upscaler import ( - UpscalerPipeline, -) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py deleted file mode 100644 index 9ff4c123..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_img2img.py +++ /dev/null @@ -1,231 +0,0 @@ -import torch -import time -import numpy as np -from tqdm.auto import tqdm -from random import randint -from PIL import Image -from transformers import CLIPTokenizer -from typing import Union -from shark.shark_inference import SharkInference -from diffusers import ( - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - StableDiffusionPipeline, -) -from apps.stable_diffusion.src.models import ( - SharkifyStableDiffusionModel, - get_vae_encode, -) -from apps.stable_diffusion.src.utils import ( - resamplers, - resampler_list, -) - - -class Image2ImagePipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - self.vae_encode = None - - def load_vae_encode(self): - if self.vae_encode is not None: - return - - if self.import_mlir or self.use_lora: - self.vae_encode = self.sd_model.vae_encode() - else: - try: - self.vae_encode = get_vae_encode() - except: - print("download pipeline failed, falling back to import_mlir") - self.vae_encode = self.sd_model.vae_encode() - - def unload_vae_encode(self): - del self.vae_encode - self.vae_encode = None - - def prepare_image_latents( - self, - image, - batch_size, - height, - width, - generator, - num_inference_steps, - strength, - dtype, - resample_type, - ): - # Pre process image -> get image encoded -> process latents - - # TODO: process with variable HxW combos - - # Pre-process image - resample_type = ( - resamplers[resample_type] - if resample_type in resampler_list - # Fallback to Lanczos - else Image.Resampling.LANCZOS - ) - - image = image.resize((width, height), resample=resample_type) - image_arr = np.stack([np.array(i) for i in (image,)], axis=0) - image_arr = image_arr / 255.0 - image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype) - image_arr = 2 * (image_arr - 0.5) - - # set scheduler steps - self.scheduler.set_timesteps(num_inference_steps) - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) - t_start = max(num_inference_steps - init_timestep, 0) - # timesteps reduced as per strength - timesteps = self.scheduler.timesteps[t_start:] - # new number of steps to be used as per strength will be - # num_inference_steps = num_inference_steps - t_start - - # image encode - latents = self.encode_image((image_arr,)) - latents = torch.from_numpy(latents).to(dtype) - # add noise to data - noise = torch.randn(latents.shape, generator=generator, dtype=dtype) - latents = self.scheduler.add_noise( - latents, noise, timesteps[0].repeat(1) - ) - - return latents, timesteps - - def encode_image(self, input_image): - self.load_vae_encode() - vae_encode_start = time.time() - latents = self.vae_encode("forward", input_image) - vae_inf_time = (time.time() - vae_encode_start) * 1000 - if self.ondemand: - self.unload_vae_encode() - self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}" - - return latents - - def generate_images( - self, - prompts, - neg_prompts, - image, - batch_size, - height, - width, - num_inference_steps, - strength, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - stencils, - images, - resample_type, - control_mode, - preprocessed_hints=[], - ): - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get text embeddings with weight emphasis from prompts - text_embeddings = self.encode_prompts_weight( - prompts, - neg_prompts, - max_length, - max_embeddings_multiples=max_embeddings_multiples, - ) - - # guidance scale as a float32 tensor. - guidance_scale = torch.tensor(guidance_scale).to(torch.float32) - - # Prepare input image latent - image_latents, final_timesteps = self.prepare_image_latents( - image=image, - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - strength=strength, - dtype=dtype, - resample_type=resample_type, - ) - - # Get Image latents - latents = self.produce_img_latents( - latents=image_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=final_timesteps, - dtype=dtype, - cpu_scheduling=cpu_scheduling, - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_vae() - for i in tqdm(range(0, latents.shape[0], batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + batch_size], - use_base_vae=use_base_vae, - cpu_scheduling=cpu_scheduling, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_vae() - - return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_inpaint.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_inpaint.py deleted file mode 100644 index b43c643b..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_inpaint.py +++ /dev/null @@ -1,487 +0,0 @@ -import torch -from tqdm.auto import tqdm -import numpy as np -from random import randint -from PIL import Image, ImageOps -from transformers import CLIPTokenizer -from typing import Union -from shark.shark_inference import SharkInference -from diffusers import ( - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - StableDiffusionPipeline, -) -from apps.stable_diffusion.src.models import ( - SharkifyStableDiffusionModel, - get_vae_encode, -) - - -class InpaintPipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - self.vae_encode = None - - def load_vae_encode(self): - if self.vae_encode is not None: - return - - if self.import_mlir or self.use_lora: - self.vae_encode = self.sd_model.vae_encode() - else: - try: - self.vae_encode = get_vae_encode() - except: - print("download pipeline failed, falling back to import_mlir") - self.vae_encode = self.sd_model.vae_encode() - - def unload_vae_encode(self): - del self.vae_encode - self.vae_encode = None - - def prepare_latents( - self, - batch_size, - height, - width, - generator, - num_inference_steps, - dtype, - ): - latents = torch.randn( - ( - batch_size, - 4, - height // 8, - width // 8, - ), - generator=generator, - dtype=torch.float32, - ).to(dtype) - - self.scheduler.set_timesteps(num_inference_steps) - latents = latents * self.scheduler.init_noise_sigma - return latents - - def get_crop_region(self, mask, pad=0): - h, w = mask.shape - - crop_left = 0 - for i in range(w): - if not (mask[:, i] == 0).all(): - break - crop_left += 1 - - crop_right = 0 - for i in reversed(range(w)): - if not (mask[:, i] == 0).all(): - break - crop_right += 1 - - crop_top = 0 - for i in range(h): - if not (mask[i] == 0).all(): - break - crop_top += 1 - - crop_bottom = 0 - for i in reversed(range(h)): - if not (mask[i] == 0).all(): - break - crop_bottom += 1 - - return ( - int(max(crop_left - pad, 0)), - int(max(crop_top - pad, 0)), - int(min(w - crop_right + pad, w)), - int(min(h - crop_bottom + pad, h)), - ) - - def expand_crop_region( - self, - crop_region, - processing_width, - processing_height, - image_width, - image_height, - ): - x1, y1, x2, y2 = crop_region - - ratio_crop_region = (x2 - x1) / (y2 - y1) - ratio_processing = processing_width / processing_height - - if ratio_crop_region > ratio_processing: - desired_height = (x2 - x1) / ratio_processing - desired_height_diff = int(desired_height - (y2 - y1)) - y1 -= desired_height_diff // 2 - y2 += desired_height_diff - desired_height_diff // 2 - if y2 >= image_height: - diff = y2 - image_height - y2 -= diff - y1 -= diff - if y1 < 0: - y2 -= y1 - y1 -= y1 - if y2 >= image_height: - y2 = image_height - else: - desired_width = (y2 - y1) * ratio_processing - desired_width_diff = int(desired_width - (x2 - x1)) - x1 -= desired_width_diff // 2 - x2 += desired_width_diff - desired_width_diff // 2 - if x2 >= image_width: - diff = x2 - image_width - x2 -= diff - x1 -= diff - if x1 < 0: - x2 -= x1 - x1 -= x1 - if x2 >= image_width: - x2 = image_width - - return x1, y1, x2, y2 - - def resize_image(self, resize_mode, im, width, height): - """ - resize_mode: - 0: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. - 1: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. - """ - - if resize_mode == 0: - ratio = width / height - src_ratio = im.width / im.height - - src_w = ( - width if ratio > src_ratio else im.width * height // im.height - ) - src_h = ( - height if ratio <= src_ratio else im.height * width // im.width - ) - - resized = im.resize((src_w, src_h), resample=Image.LANCZOS) - res = Image.new("RGB", (width, height)) - res.paste( - resized, - box=(width // 2 - src_w // 2, height // 2 - src_h // 2), - ) - - else: - ratio = width / height - src_ratio = im.width / im.height - - src_w = ( - width if ratio < src_ratio else im.width * height // im.height - ) - src_h = ( - height if ratio >= src_ratio else im.height * width // im.width - ) - - resized = im.resize((src_w, src_h), resample=Image.LANCZOS) - res = Image.new("RGB", (width, height)) - res.paste( - resized, - box=(width // 2 - src_w // 2, height // 2 - src_h // 2), - ) - - if ratio < src_ratio: - fill_height = height // 2 - src_h // 2 - res.paste( - resized.resize((width, fill_height), box=(0, 0, width, 0)), - box=(0, 0), - ) - res.paste( - resized.resize( - (width, fill_height), - box=(0, resized.height, width, resized.height), - ), - box=(0, fill_height + src_h), - ) - elif ratio > src_ratio: - fill_width = width // 2 - src_w // 2 - res.paste( - resized.resize( - (fill_width, height), box=(0, 0, 0, height) - ), - box=(0, 0), - ) - res.paste( - resized.resize( - (fill_width, height), - box=(resized.width, 0, resized.width, height), - ), - box=(fill_width + src_w, 0), - ) - - return res - - def prepare_mask_and_masked_image( - self, - image, - mask, - height, - width, - inpaint_full_res, - inpaint_full_res_padding, - ): - # preprocess image - image = image.resize((width, height)) - mask = mask.resize((width, height)) - - paste_to = () - overlay_image = None - if inpaint_full_res: - # prepare overlay image - overlay_image = Image.new("RGB", (image.width, image.height)) - overlay_image.paste( - image.convert("RGB"), - mask=ImageOps.invert(mask.convert("L")), - ) - - # prepare mask - mask = mask.convert("L") - crop_region = self.get_crop_region( - np.array(mask), inpaint_full_res_padding - ) - crop_region = self.expand_crop_region( - crop_region, width, height, mask.width, mask.height - ) - x1, y1, x2, y2 = crop_region - mask = mask.crop(crop_region) - mask = self.resize_image(1, mask, width, height) - paste_to = (x1, y1, x2 - x1, y2 - y1) - - # prepare image - image = image.crop(crop_region) - image = self.resize_image(1, image, width, height) - - if isinstance(image, (Image.Image, np.ndarray)): - image = [image] - - if isinstance(image, list) and isinstance(image[0], Image.Image): - image = [np.array(i.convert("RGB"))[None, :] for i in image] - image = np.concatenate(image, axis=0) - elif isinstance(image, list) and isinstance(image[0], np.ndarray): - image = np.concatenate([i[None, :] for i in image], axis=0) - - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - - # preprocess mask - if isinstance(mask, (Image.Image, np.ndarray)): - mask = [mask] - - if isinstance(mask, list) and isinstance(mask[0], Image.Image): - mask = np.concatenate( - [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 - ) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - - masked_image = image * (mask < 0.5) - - return mask, masked_image, paste_to, overlay_image - - def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - height, - width, - dtype, - ): - mask = torch.nn.functional.interpolate( - mask, size=(height // 8, width // 8) - ) - mask = mask.to(dtype) - - self.load_vae_encode() - masked_image = masked_image.to(dtype) - masked_image_latents = self.vae_encode("forward", (masked_image,)) - masked_image_latents = torch.from_numpy(masked_image_latents) - if self.ondemand: - self.unload_vae_encode() - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - return mask, masked_image_latents - - def apply_overlay(self, image, paste_loc, overlay): - x, y, w, h = paste_loc - image = self.resize_image(0, image, w, h) - overlay.paste(image, (x, y)) - - return overlay - - def generate_images( - self, - prompts, - neg_prompts, - image, - mask_image, - batch_size, - height, - width, - inpaint_full_res, - inpaint_full_res_padding, - num_inference_steps, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - ): - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get initial latents - init_latents = self.prepare_latents( - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - dtype=dtype, - ) - - # Get text embeddings with weight emphasis from prompts - text_embeddings = self.encode_prompts_weight( - prompts, - neg_prompts, - max_length, - max_embeddings_multiples=max_embeddings_multiples, - ) - - # guidance scale as a float32 tensor. - guidance_scale = torch.tensor(guidance_scale).to(torch.float32) - - # Preprocess mask and image - ( - mask, - masked_image, - paste_to, - overlay_image, - ) = self.prepare_mask_and_masked_image( - image, - mask_image, - height, - width, - inpaint_full_res, - inpaint_full_res_padding, - ) - - # Prepare mask latent variables - mask, masked_image_latents = self.prepare_mask_latents( - mask=mask, - masked_image=masked_image, - batch_size=batch_size, - height=height, - width=width, - dtype=dtype, - ) - - # Get Image latents - latents = self.produce_img_latents( - latents=init_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=self.scheduler.timesteps, - dtype=dtype, - cpu_scheduling=cpu_scheduling, - mask=mask, - masked_image_latents=masked_image_latents, - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_vae() - for i in tqdm(range(0, latents.shape[0], batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + batch_size], - use_base_vae=use_base_vae, - cpu_scheduling=cpu_scheduling, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_vae() - - if inpaint_full_res: - output_image = self.apply_overlay( - all_imgs[0], paste_to, overlay_image - ) - return [output_image] - - return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_outpaint.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_outpaint.py deleted file mode 100644 index 982bec6d..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_outpaint.py +++ /dev/null @@ -1,581 +0,0 @@ -import torch -from tqdm.auto import tqdm -import numpy as np -from random import randint -from PIL import Image, ImageDraw, ImageFilter -from transformers import CLIPTokenizer -from typing import Union -from shark.shark_inference import SharkInference -from diffusers import ( - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - StableDiffusionPipeline, -) -import math -from apps.stable_diffusion.src.models import ( - SharkifyStableDiffusionModel, - get_vae_encode, -) - - -class OutpaintPipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - self.vae_encode = None - - def load_vae_encode(self): - if self.vae_encode is not None: - return - - if self.import_mlir or self.use_lora: - self.vae_encode = self.sd_model.vae_encode() - else: - try: - self.vae_encode = get_vae_encode() - except: - print("download pipeline failed, falling back to import_mlir") - self.vae_encode = self.sd_model.vae_encode() - - def unload_vae_encode(self): - del self.vae_encode - self.vae_encode = None - - def prepare_latents( - self, - batch_size, - height, - width, - generator, - num_inference_steps, - dtype, - ): - latents = torch.randn( - ( - batch_size, - 4, - height // 8, - width // 8, - ), - generator=generator, - dtype=torch.float32, - ).to(dtype) - - self.scheduler.set_timesteps(num_inference_steps) - latents = latents * self.scheduler.init_noise_sigma - return latents - - def prepare_mask_and_masked_image( - self, image, mask, mask_blur, width, height - ): - if mask_blur > 0: - mask = mask.filter(ImageFilter.GaussianBlur(mask_blur)) - image = image.resize((width, height)) - mask = mask.resize((width, height)) - - # preprocess image - if isinstance(image, (Image.Image, np.ndarray)): - image = [image] - - if isinstance(image, list) and isinstance(image[0], Image.Image): - image = [np.array(i.convert("RGB"))[None, :] for i in image] - image = np.concatenate(image, axis=0) - elif isinstance(image, list) and isinstance(image[0], np.ndarray): - image = np.concatenate([i[None, :] for i in image], axis=0) - - image = image.transpose(0, 3, 1, 2) - image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0 - - # preprocess mask - if isinstance(mask, (Image.Image, np.ndarray)): - mask = [mask] - - if isinstance(mask, list) and isinstance(mask[0], Image.Image): - mask = np.concatenate( - [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0 - ) - mask = mask.astype(np.float32) / 255.0 - elif isinstance(mask, list) and isinstance(mask[0], np.ndarray): - mask = np.concatenate([m[None, None, :] for m in mask], axis=0) - - mask[mask < 0.5] = 0 - mask[mask >= 0.5] = 1 - mask = torch.from_numpy(mask) - - masked_image = image * (mask < 0.5) - - return mask, masked_image - - def prepare_mask_latents( - self, - mask, - masked_image, - batch_size, - height, - width, - dtype, - ): - mask = torch.nn.functional.interpolate( - mask, size=(height // 8, width // 8) - ) - mask = mask.to(dtype) - - self.load_vae_encode() - masked_image = masked_image.to(dtype) - masked_image_latents = self.vae_encode("forward", (masked_image,)) - masked_image_latents = torch.from_numpy(masked_image_latents) - if self.ondemand: - self.unload_vae_encode() - - # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method - if mask.shape[0] < batch_size: - if not batch_size % mask.shape[0] == 0: - raise ValueError( - "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to" - f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number" - " of masks that you pass is divisible by the total requested batch size." - ) - mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1) - if masked_image_latents.shape[0] < batch_size: - if not batch_size % masked_image_latents.shape[0] == 0: - raise ValueError( - "The passed images and the required batch size don't match. Images are supposed to be duplicated" - f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed." - " Make sure the number of images that you pass is divisible by the total requested batch size." - ) - masked_image_latents = masked_image_latents.repeat( - batch_size // masked_image_latents.shape[0], 1, 1, 1 - ) - return mask, masked_image_latents - - def get_matched_noise( - self, _np_src_image, np_mask_rgb, noise_q=1, color_variation=0.05 - ): - # helper fft routines that keep ortho normalization and auto-shift before and after fft - def _fft2(data): - if data.ndim > 2: # has channels - out_fft = np.zeros( - (data.shape[0], data.shape[1], data.shape[2]), - dtype=np.complex128, - ) - for c in range(data.shape[2]): - c_data = data[:, :, c] - out_fft[:, :, c] = np.fft.fft2( - np.fft.fftshift(c_data), norm="ortho" - ) - out_fft[:, :, c] = np.fft.ifftshift(out_fft[:, :, c]) - else: # one channel - out_fft = np.zeros( - (data.shape[0], data.shape[1]), dtype=np.complex128 - ) - out_fft[:, :] = np.fft.fft2( - np.fft.fftshift(data), norm="ortho" - ) - out_fft[:, :] = np.fft.ifftshift(out_fft[:, :]) - - return out_fft - - def _ifft2(data): - if data.ndim > 2: # has channels - out_ifft = np.zeros( - (data.shape[0], data.shape[1], data.shape[2]), - dtype=np.complex128, - ) - for c in range(data.shape[2]): - c_data = data[:, :, c] - out_ifft[:, :, c] = np.fft.ifft2( - np.fft.fftshift(c_data), norm="ortho" - ) - out_ifft[:, :, c] = np.fft.ifftshift(out_ifft[:, :, c]) - else: # one channel - out_ifft = np.zeros( - (data.shape[0], data.shape[1]), dtype=np.complex128 - ) - out_ifft[:, :] = np.fft.ifft2( - np.fft.fftshift(data), norm="ortho" - ) - out_ifft[:, :] = np.fft.ifftshift(out_ifft[:, :]) - - return out_ifft - - def _get_gaussian_window(width, height, std=3.14, mode=0): - window_scale_x = float(width / min(width, height)) - window_scale_y = float(height / min(width, height)) - - window = np.zeros((width, height)) - x = (np.arange(width) / width * 2.0 - 1.0) * window_scale_x - for y in range(height): - fy = (y / height * 2.0 - 1.0) * window_scale_y - if mode == 0: - window[:, y] = np.exp(-(x**2 + fy**2) * std) - else: - window[:, y] = ( - 1 / ((x**2 + 1.0) * (fy**2 + 1.0)) - ) ** (std / 3.14) - - return window - - def _get_masked_window_rgb(np_mask_grey, hardness=1.0): - np_mask_rgb = np.zeros( - (np_mask_grey.shape[0], np_mask_grey.shape[1], 3) - ) - if hardness != 1.0: - hardened = np_mask_grey[:] ** hardness - else: - hardened = np_mask_grey[:] - for c in range(3): - np_mask_rgb[:, :, c] = hardened[:] - return np_mask_rgb - - def _match_cumulative_cdf(source, template): - src_values, src_unique_indices, src_counts = np.unique( - source.ravel(), return_inverse=True, return_counts=True - ) - tmpl_values, tmpl_counts = np.unique( - template.ravel(), return_counts=True - ) - - # calculate normalized quantiles for each array - src_quantiles = np.cumsum(src_counts) / source.size - tmpl_quantiles = np.cumsum(tmpl_counts) / template.size - - interp_a_values = np.interp( - src_quantiles, tmpl_quantiles, tmpl_values - ) - return interp_a_values[src_unique_indices].reshape(source.shape) - - def _match_histograms(image, reference): - if image.ndim != reference.ndim: - raise ValueError( - "Image and reference must have the same number of channels." - ) - - if image.shape[-1] != reference.shape[-1]: - raise ValueError( - "Number of channels in the input image and reference image must match!" - ) - - matched = np.empty(image.shape, dtype=image.dtype) - for channel in range(image.shape[-1]): - matched_channel = _match_cumulative_cdf( - image[..., channel], reference[..., channel] - ) - matched[..., channel] = matched_channel - - matched = matched.astype(np.float64, copy=False) - return matched - - width = _np_src_image.shape[0] - height = _np_src_image.shape[1] - num_channels = _np_src_image.shape[2] - - np_src_image = _np_src_image[:] * (1.0 - np_mask_rgb) - np_mask_grey = np.sum(np_mask_rgb, axis=2) / 3.0 - img_mask = np_mask_grey > 1e-6 - ref_mask = np_mask_grey < 1e-3 - - # rather than leave the masked area black, we get better results from fft by filling the average unmasked color - windowed_image = _np_src_image * ( - 1.0 - _get_masked_window_rgb(np_mask_grey) - ) - windowed_image /= np.max(windowed_image) - windowed_image += np.average(_np_src_image) * np_mask_rgb - - src_fft = _fft2( - windowed_image - ) # get feature statistics from masked src img - src_dist = np.absolute(src_fft) - src_phase = src_fft / src_dist - - # create a generator with a static seed to make outpainting deterministic / only follow global seed - rng = np.random.default_rng(0) - - noise_window = _get_gaussian_window( - width, height, mode=1 - ) # start with simple gaussian noise - noise_rgb = rng.random((width, height, num_channels)) - noise_grey = np.sum(noise_rgb, axis=2) / 3.0 - # the colorfulness of the starting noise is blended to greyscale with a parameter - noise_rgb *= color_variation - for c in range(num_channels): - noise_rgb[:, :, c] += (1.0 - color_variation) * noise_grey - - noise_fft = _fft2(noise_rgb) - for c in range(num_channels): - noise_fft[:, :, c] *= noise_window - noise_rgb = np.real(_ifft2(noise_fft)) - shaped_noise_fft = _fft2(noise_rgb) - shaped_noise_fft[:, :, :] = ( - np.absolute(shaped_noise_fft[:, :, :]) ** 2 - * (src_dist**noise_q) - * src_phase - ) # perform the actual shaping - - # color_variation - brightness_variation = 0.0 - contrast_adjusted_np_src = ( - _np_src_image[:] * (brightness_variation + 1.0) - - brightness_variation * 2.0 - ) - - shaped_noise = np.real(_ifft2(shaped_noise_fft)) - shaped_noise -= np.min(shaped_noise) - shaped_noise /= np.max(shaped_noise) - shaped_noise[img_mask, :] = _match_histograms( - shaped_noise[img_mask, :] ** 1.0, - contrast_adjusted_np_src[ref_mask, :], - ) - shaped_noise = ( - _np_src_image[:] * (1.0 - np_mask_rgb) + shaped_noise * np_mask_rgb - ) - - matched_noise = shaped_noise[:] - - return np.clip(matched_noise, 0.0, 1.0) - - def generate_images( - self, - prompts, - neg_prompts, - image, - pixels, - mask_blur, - is_left, - is_right, - is_top, - is_bottom, - noise_q, - color_variation, - batch_size, - height, - width, - num_inference_steps, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - ): - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get initial latents - init_latents = self.prepare_latents( - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - dtype=dtype, - ) - - # Get text embeddings with weight emphasis from prompts - text_embeddings = self.encode_prompts_weight( - prompts, - neg_prompts, - max_length, - max_embeddings_multiples=max_embeddings_multiples, - ) - - # guidance scale as a float32 tensor. - guidance_scale = torch.tensor(guidance_scale).to(torch.float32) - - process_width = width - process_height = height - left = pixels if is_left else 0 - right = pixels if is_right else 0 - up = pixels if is_top else 0 - down = pixels if is_bottom else 0 - target_w = math.ceil((image.width + left + right) / 64) * 64 - target_h = math.ceil((image.height + up + down) / 64) * 64 - - if left > 0: - left = left * (target_w - image.width) // (left + right) - if right > 0: - right = target_w - image.width - left - if up > 0: - up = up * (target_h - image.height) // (up + down) - if down > 0: - down = target_h - image.height - up - - def expand( - init_img, - expand_pixels, - is_left=False, - is_right=False, - is_top=False, - is_bottom=False, - ): - is_horiz = is_left or is_right - is_vert = is_top or is_bottom - pixels_horiz = expand_pixels if is_horiz else 0 - pixels_vert = expand_pixels if is_vert else 0 - - res_w = init_img.width + pixels_horiz - res_h = init_img.height + pixels_vert - process_res_w = math.ceil(res_w / 64) * 64 - process_res_h = math.ceil(res_h / 64) * 64 - - img = Image.new("RGB", (process_res_w, process_res_h)) - img.paste( - init_img, - (pixels_horiz if is_left else 0, pixels_vert if is_top else 0), - ) - - msk = Image.new("RGB", (process_res_w, process_res_h), "white") - draw = ImageDraw.Draw(msk) - draw.rectangle( - ( - expand_pixels + mask_blur if is_left else 0, - expand_pixels + mask_blur if is_top else 0, - msk.width - expand_pixels - mask_blur - if is_right - else res_w, - msk.height - expand_pixels - mask_blur - if is_bottom - else res_h, - ), - fill="black", - ) - - np_image = (np.asarray(img) / 255.0).astype(np.float64) - np_mask = (np.asarray(msk) / 255.0).astype(np.float64) - noised = self.get_matched_noise( - np_image, np_mask, noise_q, color_variation - ) - output_image = Image.fromarray( - np.clip(noised * 255.0, 0.0, 255.0).astype(np.uint8), - mode="RGB", - ) - - target_width = ( - min(width, init_img.width + pixels_horiz) - if is_horiz - else img.width - ) - target_height = ( - min(height, init_img.height + pixels_vert) - if is_vert - else img.height - ) - crop_region = ( - 0 if is_left else output_image.width - target_width, - 0 if is_top else output_image.height - target_height, - target_width if is_left else output_image.width, - target_height if is_top else output_image.height, - ) - mask_to_process = msk.crop(crop_region) - image_to_process = output_image.crop(crop_region) - - # Preprocess mask and image - mask, masked_image = self.prepare_mask_and_masked_image( - image_to_process, mask_to_process, mask_blur, width, height - ) - - # Prepare mask latent variables - mask, masked_image_latents = self.prepare_mask_latents( - mask=mask, - masked_image=masked_image, - batch_size=batch_size, - height=height, - width=width, - dtype=dtype, - ) - - # Get Image latents - latents = self.produce_img_latents( - latents=init_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=self.scheduler.timesteps, - dtype=dtype, - cpu_scheduling=cpu_scheduling, - mask=mask, - masked_image_latents=masked_image_latents, - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_vae() - for i in tqdm(range(0, latents.shape[0], batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + batch_size], - use_base_vae=use_base_vae, - cpu_scheduling=cpu_scheduling, - ) - all_imgs.extend(imgs) - - res_img = all_imgs[0].resize( - (image_to_process.width, image_to_process.height) - ) - output_image.paste( - res_img, - ( - 0 if is_left else output_image.width - res_img.width, - 0 if is_top else output_image.height - res_img.height, - ), - ) - output_image = output_image.crop((0, 0, res_w, res_h)) - - return output_image - - img = image.resize((width, height)) - if left > 0: - img = expand(img, left, is_left=True) - if right > 0: - img = expand(img, right, is_right=True) - if up > 0: - img = expand(img, up, is_top=True) - if down > 0: - img = expand(img, down, is_bottom=True) - - return [img] diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py deleted file mode 100644 index 51fd51f9..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_stencil.py +++ /dev/null @@ -1,603 +0,0 @@ -import torch -import time -import numpy as np -from tqdm.auto import tqdm -from random import randint -from PIL import Image -from transformers import CLIPTokenizer -from typing import Union -from shark.shark_inference import SharkInference -from diffusers import ( - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - StableDiffusionPipeline, -) -from apps.stable_diffusion.src.utils import ( - controlnet_hint_conversion, - controlnet_hint_reshaping, -) -from apps.stable_diffusion.src.utils import ( - start_profiling, - end_profiling, -) -from apps.stable_diffusion.src.utils import ( - resamplers, - resampler_list, -) -from apps.stable_diffusion.src.models import ( - SharkifyStableDiffusionModel, - get_vae_encode, -) - - -class StencilPipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - DDPMScheduler, - KDPM2DiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - controlnet_names: list[str], - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - self.controlnet = [None] * len(controlnet_names) - self.controlnet_512 = [None] * len(controlnet_names) - self.controlnet_id = [str] * len(controlnet_names) - self.controlnet_512_id = [str] * len(controlnet_names) - self.controlnet_names = controlnet_names - self.vae_encode = None - - def load_vae_encode(self): - if self.vae_encode is not None: - return - - if self.import_mlir or self.use_lora: - self.vae_encode = self.sd_model.vae_encode() - else: - try: - self.vae_encode = get_vae_encode() - except: - print("download pipeline failed, falling back to import_mlir") - self.vae_encode = self.sd_model.vae_encode() - - def unload_vae_encode(self): - del self.vae_encode - self.vae_encode = None - - def load_controlnet(self, index, model_name): - if model_name is None: - return - if ( - self.controlnet[index] is not None - and self.controlnet_id[index] is not None - and self.controlnet_id[index] == model_name - ): - return - self.controlnet_id[index] = model_name - self.controlnet[index] = self.sd_model.controlnet(model_name) - - def unload_controlnet(self, index): - del self.controlnet[index] - self.controlnet_id[index] = None - self.controlnet[index] = None - - def load_controlnet_512(self, index, model_name): - if ( - self.controlnet_512[index] is not None - and self.controlnet_512_id[index] == model_name - ): - return - self.controlnet_512_id[index] = model_name - self.controlnet_512[index] = self.sd_model.controlnet( - model_name, use_large=True - ) - - def unload_controlnet_512(self, index): - del self.controlnet_512[index] - self.controlnet_512_id[index] = None - self.controlnet_512[index] = None - - def prepare_latents( - self, - batch_size, - height, - width, - generator, - num_inference_steps, - dtype, - ): - latents = torch.randn( - ( - batch_size, - 4, - height // 8, - width // 8, - ), - generator=generator, - dtype=torch.float32, - ).to(dtype) - - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.is_scale_input_called = True - latents = latents * self.scheduler.init_noise_sigma - return latents - - def prepare_image_latents( - self, - image, - batch_size, - height, - width, - generator, - num_inference_steps, - strength, - dtype, - resample_type, - ): - # Pre process image -> get image encoded -> process latents - - # TODO: process with variable HxW combos - - # Pre-process image - resample_type = ( - resamplers[resample_type] - if resample_type in resampler_list - # Fallback to Lanczos - else Image.Resampling.LANCZOS - ) - - image = image.resize((width, height), resample=resample_type) - image_arr = np.stack([np.array(i) for i in (image,)], axis=0) - image_arr = image_arr / 255.0 - image_arr = torch.from_numpy(image_arr).permute(0, 3, 1, 2).to(dtype) - image_arr = 2 * (image_arr - 0.5) - - # set scheduler steps - self.scheduler.set_timesteps(num_inference_steps) - init_timestep = min( - int(num_inference_steps * strength), num_inference_steps - ) - t_start = max(num_inference_steps - init_timestep, 0) - # timesteps reduced as per strength - timesteps = self.scheduler.timesteps[t_start:] - # new number of steps to be used as per strength will be - # num_inference_steps = num_inference_steps - t_start - - # image encode - latents = self.encode_image((image_arr,)) - latents = torch.from_numpy(latents).to(dtype) - # add noise to data - noise = torch.randn(latents.shape, generator=generator, dtype=dtype) - latents = self.scheduler.add_noise( - latents, noise, timesteps[0].repeat(1) - ) - - return latents, timesteps - - def produce_stencil_latents( - self, - latents, - text_embeddings, - guidance_scale, - total_timesteps, - dtype, - cpu_scheduling, - stencil_hints=[None], - controlnet_conditioning_scale: float = 1.0, - control_mode="Balanced", # Prompt, Balanced, or Controlnet - mask=None, - masked_image_latents=None, - return_all_latents=False, - ): - step_time_sum = 0 - latent_history = [latents] - text_embeddings = torch.from_numpy(text_embeddings).to(dtype) - text_embeddings_numpy = text_embeddings.detach().numpy() - assert control_mode in ["Prompt", "Balanced", "Controlnet"] - if text_embeddings.shape[1] <= self.model_max_length: - self.load_unet() - else: - self.load_unet_512() - - for i, name in enumerate(self.controlnet_names): - use_names = [] - if name is not None: - use_names.append(name) - else: - continue - if text_embeddings.shape[1] <= self.model_max_length: - self.load_controlnet(i, name) - else: - self.load_controlnet_512(i, name) - self.controlnet_names = use_names - - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - timestep = torch.tensor([t]).to(dtype) - latent_model_input = self.scheduler.scale_model_input(latents, t) - if mask is not None and masked_image_latents is not None: - latent_model_input = torch.cat( - [ - torch.from_numpy(np.asarray(latent_model_input)), - mask, - masked_image_latents, - ], - dim=1, - ).to(dtype) - if cpu_scheduling: - latent_model_input = latent_model_input.detach().numpy() - - if not torch.is_tensor(latent_model_input): - latent_model_input_1 = torch.from_numpy( - np.asarray(latent_model_input) - ).to(dtype) - else: - latent_model_input_1 = latent_model_input - - # Multicontrolnet - width = latent_model_input_1.shape[2] - height = latent_model_input_1.shape[3] - dtype = latent_model_input_1.dtype - control_acc = ( - [torch.zeros((2, 320, height, width), dtype=dtype)] * 3 - + [ - torch.zeros( - (2, 320, int(height / 2), int(width / 2)), dtype=dtype - ) - ] - + [ - torch.zeros( - (2, 640, int(height / 2), int(width / 2)), dtype=dtype - ) - ] - * 2 - + [ - torch.zeros( - (2, 640, int(height / 4), int(width / 4)), dtype=dtype - ) - ] - + [ - torch.zeros( - (2, 1280, int(height / 4), int(width / 4)), dtype=dtype - ) - ] - * 2 - + [ - torch.zeros( - (2, 1280, int(height / 8), int(width / 8)), dtype=dtype - ) - ] - * 4 - ) - for i, controlnet_hint in enumerate(stencil_hints): - if controlnet_hint is None: - pass - if text_embeddings.shape[1] <= self.model_max_length: - control = self.controlnet[i]( - "forward", - ( - latent_model_input_1, - timestep, - text_embeddings, - controlnet_hint, - *control_acc, - ), - send_to_host=False, - ) - else: - control = self.controlnet_512[i]( - "forward", - ( - latent_model_input_1, - timestep, - text_embeddings, - controlnet_hint, - *control_acc, - ), - send_to_host=False, - ) - control_acc = control[13:] - control = control[:13] - - timestep = timestep.detach().numpy() - # Profiling Unet. - profile_device = start_profiling(file_path="unet.rdc") - # TODO: Pass `control` as it is to Unet. Same as TODO mentioned in model_wrappers.py. - - dtype = latents.dtype - if control_mode == "Balanced": - control_scale = [ - torch.tensor(1.0, dtype=dtype) for _ in range(len(control)) - ] - elif control_mode == "Prompt": - control_scale = [ - torch.tensor(0.825**x, dtype=dtype) - for x in range(len(control)) - ] - elif control_mode == "Controlnet": - control_scale = [ - torch.tensor(float(guidance_scale), dtype=dtype) - for _ in range(len(control)) - ] - - if text_embeddings.shape[1] <= self.model_max_length: - noise_pred = self.unet( - "forward", - ( - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - control[0], - control[1], - control[2], - control[3], - control[4], - control[5], - control[6], - control[7], - control[8], - control[9], - control[10], - control[11], - control[12], - control_scale[0], - control_scale[1], - control_scale[2], - control_scale[3], - control_scale[4], - control_scale[5], - control_scale[6], - control_scale[7], - control_scale[8], - control_scale[9], - control_scale[10], - control_scale[11], - control_scale[12], - ), - send_to_host=False, - ) - else: - noise_pred = self.unet_512( - "forward", - ( - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - control[0], - control[1], - control[2], - control[3], - control[4], - control[5], - control[6], - control[7], - control[8], - control[9], - control[10], - control[11], - control[12], - control_scale[0], - control_scale[1], - control_scale[2], - control_scale[3], - control_scale[4], - control_scale[5], - control_scale[6], - control_scale[7], - control_scale[8], - control_scale[9], - control_scale[10], - control_scale[11], - control_scale[12], - ), - send_to_host=False, - ) - end_profiling(profile_device) - - if cpu_scheduling: - noise_pred = torch.from_numpy(noise_pred.to_host()) - latents = self.scheduler.step( - noise_pred, t, latents - ).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents) - - latent_history.append(latents) - step_time = (time.time() - step_start_time) * 1000 - # self.log += ( - # f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms" - # ) - step_time_sum += step_time - - if self.ondemand: - self.unload_unet() - self.unload_unet_512() - for i in range(len(self.controlnet_names)): - self.unload_controlnet(i) - self.unload_controlnet_512(i) - avg_step_time = step_time_sum / len(total_timesteps) - self.log += f"\nAverage step time: {avg_step_time}ms/it" - - if not return_all_latents: - return latents - all_latents = torch.cat(latent_history, dim=0) - return all_latents - - def encode_image(self, input_image): - self.load_vae_encode() - vae_encode_start = time.time() - latents = self.vae_encode("forward", input_image) - vae_inf_time = (time.time() - vae_encode_start) * 1000 - if self.ondemand: - self.unload_vae_encode() - self.log += f"\nVAE Encode Inference time (ms): {vae_inf_time:.3f}" - - return latents - - def generate_images( - self, - prompts, - neg_prompts, - image, - batch_size, - height, - width, - num_inference_steps, - strength, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - stencils, - stencil_images, - resample_type, - control_mode, - preprocessed_hints, - ): - # Control Embedding check & conversion - # controlnet_hint = controlnet_hint_conversion( - # image, use_stencil, height, width, dtype, num_images_per_prompt=1 - # ) - stencil_hints = [] - self.sd_model.stencils = stencils - for i, hint in enumerate(preprocessed_hints): - if hint is not None: - hint = controlnet_hint_reshaping( - hint, - height, - width, - dtype, - num_images_per_prompt=1, - ) - stencil_hints.append(hint) - - for i, stencil in enumerate(stencils): - if stencil == None: - continue - if len(stencil_hints) > i: - if stencil_hints[i] is not None: - print(f"Using preprocessed controlnet hint for {stencil}") - continue - image = stencil_images[i] - stencil_hints.append( - controlnet_hint_conversion( - image, - stencil, - height, - width, - dtype, - num_images_per_prompt=1, - ) - ) - - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get text embeddings with weight emphasis from prompts - text_embeddings = self.encode_prompts_weight( - prompts, - neg_prompts, - max_length, - max_embeddings_multiples=max_embeddings_multiples, - ) - - # guidance scale as a float32 tensor. - guidance_scale = torch.tensor(guidance_scale).to(torch.float32) - if image is not None: - # Prepare input image latent - init_latents, final_timesteps = self.prepare_image_latents( - image=image, - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - strength=strength, - dtype=dtype, - resample_type=resample_type, - ) - else: - # Prepare initial latent. - init_latents = self.prepare_latents( - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - dtype=dtype, - ) - final_timesteps = self.scheduler.timesteps - - # Get Image latents - latents = self.produce_stencil_latents( - latents=init_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=final_timesteps, - dtype=dtype, - cpu_scheduling=cpu_scheduling, - control_mode=control_mode, - stencil_hints=stencil_hints, - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_vae() - for i in tqdm(range(0, latents.shape[0], batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + batch_size], - use_base_vae=use_base_vae, - cpu_scheduling=cpu_scheduling, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_vae() - - return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py deleted file mode 100644 index 51a790bb..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img.py +++ /dev/null @@ -1,159 +0,0 @@ -import torch -import numpy as np -from random import randint -from transformers import CLIPTokenizer -from typing import Union -from shark.shark_inference import SharkInference -from diffusers import ( - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import ( - SharkEulerDiscreteScheduler, - SharkEulerAncestralDiscreteScheduler, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - StableDiffusionPipeline, -) -from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel - - -class Text2ImagePipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - - def prepare_latents( - self, - batch_size, - height, - width, - generator, - num_inference_steps, - dtype, - ): - latents = torch.randn( - ( - batch_size, - 4, - height // 8, - width // 8, - ), - generator=generator, - dtype=torch.float32, - ).to(dtype) - - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.is_scale_input_called = True - latents = latents * self.scheduler.init_noise_sigma - return latents - - def generate_images( - self, - prompts, - neg_prompts, - batch_size, - height, - width, - num_inference_steps, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - ): - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly? - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get initial latents - init_latents = self.prepare_latents( - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - dtype=dtype, - ) - - # Get text embeddings with weight emphasis from prompts - text_embeddings = self.encode_prompts_weight( - prompts, - neg_prompts, - max_length, - max_embeddings_multiples=max_embeddings_multiples, - ) - - # guidance scale as a float32 tensor. - guidance_scale = torch.tensor(guidance_scale).to(torch.float32) - - # Get Image latents - latents = self.produce_img_latents( - latents=init_latents, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - total_timesteps=self.scheduler.timesteps, - dtype=dtype, - cpu_scheduling=cpu_scheduling, - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_vae() - for i in range(0, latents.shape[0], batch_size): - imgs = self.decode_latents( - latents=latents[i : i + batch_size], - use_base_vae=use_base_vae, - cpu_scheduling=cpu_scheduling, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_vae() - - return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py deleted file mode 100644 index a3b52793..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_txt2img_sdxl.py +++ /dev/null @@ -1,220 +0,0 @@ -import torch -import numpy as np -from random import randint -from typing import Union -from diffusers import ( - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import ( - SharkEulerDiscreteScheduler, - SharkEulerAncestralDiscreteScheduler, -) -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - StableDiffusionPipeline, -) -from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel -from transformers.utils import logging - -logger = logging.get_logger(__name__) - - -class Text2ImageSDXLPipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - SharkEulerAncestralDiscreteScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - is_fp32_vae: bool, - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - self.is_fp32_vae = is_fp32_vae - - def prepare_latents( - self, - batch_size, - height, - width, - generator, - num_inference_steps, - dtype, - ): - latents = torch.randn( - ( - batch_size, - 4, - height // 8, - width // 8, - ), - generator=generator, - dtype=torch.float32, - ).to(dtype) - - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.is_scale_input_called = True - latents = latents * self.scheduler.init_noise_sigma - return latents - - def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, dtype - ): - add_time_ids = list( - original_size + crops_coords_top_left + target_size - ) - - # self.unet.config.addition_time_embed_dim IS 256. - # self.text_encoder_2.config.projection_dim IS 1280. - passed_add_embed_dim = 256 * len(add_time_ids) + 1280 - expected_add_embed_dim = 2816 - # self.unet.add_embedding.linear_1.in_features IS 2816. - - if expected_add_embed_dim != passed_add_embed_dim: - raise ValueError( - f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." - ) - - add_time_ids = torch.tensor([add_time_ids], dtype=dtype) - return add_time_ids - - def generate_images( - self, - prompts, - neg_prompts, - batch_size, - height, - width, - num_inference_steps, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - ): - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly? - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get initial latents. - init_latents = self.prepare_latents( - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - dtype=dtype, - ) - - # Get text embeddings. - ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) = self.encode_prompt_sdxl( - prompt=prompts, - num_images_per_prompt=1, - do_classifier_free_guidance=True, - negative_prompt=neg_prompts, - ) - - # Prepare timesteps. - self.scheduler.set_timesteps(num_inference_steps) - - timesteps = self.scheduler.timesteps - - # Prepare added time ids & embeddings. - original_size = (height, width) - target_size = (height, width) - crops_coords_top_left = (0, 0) - add_text_embeds = pooled_prompt_embeds - add_time_ids = self._get_add_time_ids( - original_size, - crops_coords_top_left, - target_size, - dtype=prompt_embeds.dtype, - ) - - prompt_embeds = torch.cat( - [negative_prompt_embeds, prompt_embeds], dim=0 - ) - add_text_embeds = torch.cat( - [negative_pooled_prompt_embeds, add_text_embeds], dim=0 - ) - add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) - - prompt_embeds = prompt_embeds - add_text_embeds = add_text_embeds.to(dtype) - add_time_ids = add_time_ids.repeat(batch_size * 1, 1) - - # guidance scale as a float32 tensor. - guidance_scale = torch.tensor(guidance_scale).to(dtype) - prompt_embeds = prompt_embeds.to(dtype) - add_time_ids = add_time_ids.to(dtype) - - # Get Image latents. - latents = self.produce_img_latents_sdxl( - init_latents, - timesteps, - add_text_embeds, - add_time_ids, - prompt_embeds, - cpu_scheduling, - guidance_scale, - dtype, - ) - - # Img latents -> PIL images. - all_imgs = [] - self.load_vae() - for i in range(0, latents.shape[0], batch_size): - imgs = self.decode_latents_sdxl( - latents[i : i + batch_size], is_fp32_vae=self.is_fp32_vae - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_vae() - - return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_upscaler.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_upscaler.py deleted file mode 100644 index da78e82f..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_upscaler.py +++ /dev/null @@ -1,357 +0,0 @@ -import inspect -import torch -import time -from tqdm.auto import tqdm -import numpy as np -from random import randint -from transformers import CLIPTokenizer -from typing import Union -from shark.shark_inference import SharkInference -from diffusers import ( - DDIMScheduler, - DDPMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers import SharkEulerDiscreteScheduler -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_IDLE, - SD_STATE_CANCEL, - StableDiffusionPipeline, -) -from apps.stable_diffusion.src.utils import ( - start_profiling, - end_profiling, -) -from PIL import Image -from apps.stable_diffusion.src.models import SharkifyStableDiffusionModel - - -def preprocess(image): - if isinstance(image, torch.Tensor): - return image - elif isinstance(image, Image.Image): - image = [image] - - if isinstance(image[0], Image.Image): - w, h = image[0].size - w, h = map( - lambda x: x - x % 64, (w, h) - ) # resize to integer multiple of 64 - - image = [np.array(i.resize((w, h)))[None, :] for i in image] - image = np.concatenate(image, axis=0) - image = np.array(image).astype(np.float32) / 255.0 - image = image.transpose(0, 3, 1, 2) - image = 2.0 * image - 1.0 - image = torch.from_numpy(image) - elif isinstance(image[0], torch.Tensor): - image = torch.cat(image, dim=0) - return image - - -class UpscalerPipeline(StableDiffusionPipeline): - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - ], - low_res_scheduler: Union[ - DDIMScheduler, - DDPMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2DiscreteScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - ): - super().__init__(scheduler, sd_model, import_mlir, use_lora, ondemand) - self.low_res_scheduler = low_res_scheduler - self.status = SD_STATE_IDLE - - def prepare_extra_step_kwargs(self, generator, eta): - accepts_eta = "eta" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set( - inspect.signature(self.scheduler.step).parameters.keys() - ) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def decode_latents(self, latents, use_base_vae, cpu_scheduling): - latents = 1 / 0.08333 * (latents.float()) - latents_numpy = latents - if cpu_scheduling: - latents_numpy = latents.detach().numpy() - - profile_device = start_profiling(file_path="vae.rdc") - vae_start = time.time() - images = self.vae("forward", (latents_numpy,)) - vae_inf_time = (time.time() - vae_start) * 1000 - end_profiling(profile_device) - self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}" - - images = torch.from_numpy(images) - images = (images.detach().cpu() * 255.0).numpy() - images = images.round() - - images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1) - pil_images = [Image.fromarray(image) for image in images.numpy()] - return pil_images - - def prepare_latents( - self, - batch_size, - height, - width, - generator, - num_inference_steps, - dtype, - ): - latents = torch.randn( - ( - batch_size, - 4, - height, - width, - ), - generator=generator, - dtype=torch.float32, - ).to(dtype) - - self.scheduler.set_timesteps(num_inference_steps) - self.scheduler.is_scale_input_called = True - latents = latents * self.scheduler.init_noise_sigma - return latents - - def produce_img_latents( - self, - latents, - image, - text_embeddings, - guidance_scale, - noise_level, - total_timesteps, - dtype, - cpu_scheduling, - extra_step_kwargs, - return_all_latents=False, - ): - step_time_sum = 0 - latent_history = [latents] - text_embeddings = torch.from_numpy(text_embeddings).to(dtype) - text_embeddings_numpy = text_embeddings.detach().numpy() - self.status = SD_STATE_IDLE - if text_embeddings.shape[1] <= self.model_max_length: - self.load_unet() - else: - self.load_unet_512() - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) - latent_model_input = torch.cat([latent_model_input, image], dim=1) - timestep = torch.tensor([t]).to(dtype).detach().numpy() - if cpu_scheduling: - latent_model_input = latent_model_input.detach().numpy() - - # Profiling Unet. - profile_device = start_profiling(file_path="unet.rdc") - if text_embeddings.shape[1] <= self.model_max_length: - noise_pred = self.unet( - "forward", - ( - latent_model_input, - timestep, - text_embeddings_numpy, - noise_level, - ), - ) - else: - noise_pred = self.unet_512( - "forward", - ( - latent_model_input, - timestep, - text_embeddings_numpy, - noise_level, - ), - ) - end_profiling(profile_device) - noise_pred = torch.from_numpy(noise_pred) - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * ( - noise_pred_text - noise_pred_uncond - ) - - if cpu_scheduling: - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ).prev_sample - else: - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs - ) - - latent_history.append(latents) - step_time = (time.time() - step_start_time) * 1000 - # self.log += ( - # f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms" - # ) - step_time_sum += step_time - - if self.status == SD_STATE_CANCEL: - break - - if self.ondemand: - self.unload_unet() - self.unload_unet_512() - avg_step_time = step_time_sum / len(total_timesteps) - self.log += f"\nAverage step time: {avg_step_time}ms/it" - - if not return_all_latents: - return latents - all_latents = torch.cat(latent_history, dim=0) - return all_latents - - def generate_images( - self, - prompts, - neg_prompts, - image, - batch_size, - height, - width, - num_inference_steps, - noise_level, - guidance_scale, - seed, - max_length, - dtype, - use_base_vae, - cpu_scheduling, - max_embeddings_multiples, - ): - # prompts and negative prompts must be a list. - if isinstance(prompts, str): - prompts = [prompts] - - if isinstance(neg_prompts, str): - neg_prompts = [neg_prompts] - - prompts = prompts * batch_size - neg_prompts = neg_prompts * batch_size - - # seed generator to create the inital latent noise. Also handle out of range seeds. - # TODO: Wouldn't it be preferable to just report an error instead of modifying the seed on the fly? - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - generator = torch.manual_seed(seed) - - # Get text embeddings with weight emphasis from prompts - text_embeddings = self.encode_prompts_weight( - prompts, - neg_prompts, - max_length, - max_embeddings_multiples=max_embeddings_multiples, - ) - - # 4. Preprocess image - image = preprocess(image).to(dtype) - - # 5. Add noise to image - noise_level = torch.tensor([noise_level], dtype=torch.long) - noise = torch.randn( - image.shape, - generator=generator, - ).to(dtype) - image = self.low_res_scheduler.add_noise(image, noise, noise_level) - image = torch.cat([image] * 2) - noise_level = torch.cat([noise_level] * image.shape[0]) - - height, width = image.shape[2:] - # Get initial latents - init_latents = self.prepare_latents( - batch_size=batch_size, - height=height, - width=width, - generator=generator, - num_inference_steps=num_inference_steps, - dtype=dtype, - ) - - eta = 0.0 - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # guidance scale as a float32 tensor. - # guidance_scale = torch.tensor(guidance_scale).to(torch.float32) - - # Get Image latents - latents = self.produce_img_latents( - latents=init_latents, - image=image, - text_embeddings=text_embeddings, - guidance_scale=guidance_scale, - noise_level=noise_level, - total_timesteps=self.scheduler.timesteps, - dtype=dtype, - cpu_scheduling=cpu_scheduling, - extra_step_kwargs=extra_step_kwargs, - ) - - # Img latents -> PIL images - all_imgs = [] - self.load_vae() - for i in tqdm(range(0, latents.shape[0], batch_size)): - imgs = self.decode_latents( - latents=latents[i : i + batch_size], - use_base_vae=use_base_vae, - cpu_scheduling=cpu_scheduling, - ) - all_imgs.extend(imgs) - if self.ondemand: - self.unload_vae() - - return all_imgs diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py deleted file mode 100644 index 6cfda0cb..00000000 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ /dev/null @@ -1,1264 +0,0 @@ -import torch -import numpy as np -from transformers import CLIPTokenizer -from PIL import Image -from tqdm.auto import tqdm -import time -from typing import Union -from diffusers import ( - DDIMScheduler, - DDPMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, -) -from shark.shark_inference import SharkInference -from apps.stable_diffusion.src.schedulers import ( - SharkEulerDiscreteScheduler, - SharkEulerAncestralDiscreteScheduler, -) -from apps.stable_diffusion.src.models import ( - SharkifyStableDiffusionModel, - get_vae, - get_clip, - get_unet, - get_tokenizer, -) -from apps.stable_diffusion.src.utils import ( - start_profiling, - end_profiling, -) -import sys -import gc -from typing import List, Optional - -SD_STATE_IDLE = "idle" -SD_STATE_CANCEL = "cancel" - - -class StableDiffusionPipeline: - def __init__( - self, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - SharkEulerAncestralDiscreteScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - ], - sd_model: SharkifyStableDiffusionModel, - import_mlir: bool, - use_lora: str, - ondemand: bool, - is_f32_vae: bool = False, - ): - self.vae = None - self.text_encoder = None - self.text_encoder_2 = None - self.unet = None - self.unet_512 = None - self.model_max_length = 77 - # TODO: Implement using logging python utility. - self.log = "" - self.status = SD_STATE_IDLE - self.sd_model = sd_model - self.scheduler = scheduler - self.import_mlir = import_mlir - self.use_lora = use_lora - self.ondemand = ondemand - self.is_f32_vae = is_f32_vae - # TODO: Find a better workaround for fetching base_model_id early - # enough for CLIPTokenizer. - try: - self.tokenizer = get_tokenizer() - except: - self.load_unet() - self.unload_unet() - self.tokenizer = get_tokenizer() - - def load_clip(self): - if self.text_encoder is not None: - return - - if self.import_mlir or self.use_lora: - if not self.import_mlir: - print( - "Warning: LoRA provided but import_mlir not specified. " - "Importing MLIR anyways." - ) - self.text_encoder = self.sd_model.clip() - else: - try: - self.text_encoder = get_clip() - except Exception as e: - print(e) - print("download pipeline failed, falling back to import_mlir") - self.text_encoder = self.sd_model.clip() - - def unload_clip(self): - del self.text_encoder - self.text_encoder = None - - def load_clip_sdxl(self): - if self.text_encoder and self.text_encoder_2: - return - - if self.import_mlir or self.use_lora: - if not self.import_mlir: - print( - "Warning: LoRA provided but import_mlir not specified. " - "Importing MLIR anyways." - ) - self.text_encoder, self.text_encoder_2 = self.sd_model.sdxl_clip() - else: - try: - # TODO: Fix this for SDXL - self.text_encoder = get_clip() - except Exception as e: - print(e) - print("download pipeline failed, falling back to import_mlir") - ( - self.text_encoder, - self.text_encoder_2, - ) = self.sd_model.sdxl_clip() - - def unload_clip_sdxl(self): - del self.text_encoder, self.text_encoder_2 - self.text_encoder = None - self.text_encoder_2 = None - - def load_unet(self): - if self.unet is not None: - return - - if self.import_mlir or self.use_lora: - self.unet = self.sd_model.unet() - else: - try: - self.unet = get_unet() - except Exception as e: - print(e) - print("download pipeline failed, falling back to import_mlir") - self.unet = self.sd_model.unet() - - def unload_unet(self): - del self.unet - self.unet = None - - def load_unet_512(self): - if self.unet_512 is not None: - return - - if self.import_mlir or self.use_lora: - self.unet_512 = self.sd_model.unet(use_large=True) - else: - try: - self.unet_512 = get_unet(use_large=True) - except Exception as e: - print(e) - print("download pipeline failed, falling back to import_mlir") - self.unet_512 = self.sd_model.unet(use_large=True) - - def unload_unet_512(self): - del self.unet_512 - self.unet_512 = None - - def load_vae(self): - if self.vae is not None: - return - - if self.import_mlir or self.use_lora: - self.vae = self.sd_model.vae() - else: - try: - self.vae = get_vae() - except Exception as e: - print(e) - print("download pipeline failed, falling back to import_mlir") - self.vae = self.sd_model.vae() - - def unload_vae(self): - del self.vae - self.vae = None - gc.collect() - - def encode_prompt_sdxl( - self, - prompt: str, - num_images_per_prompt: int = 1, - do_classifier_free_guidance: bool = True, - negative_prompt: Optional[str] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, - hf_model_id: Optional[ - str - ] = "stabilityai/stable-diffusion-xl-base-1.0", - ): - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - # Define tokenizers and text encoders - self.tokenizer_2 = get_tokenizer("tokenizer_2", hf_model_id) - self.load_clip_sdxl() - tokenizers = ( - [self.tokenizer, self.tokenizer_2] - if self.tokenizer is not None - else [self.tokenizer_2] - ) - text_encoders = ( - [self.text_encoder, self.text_encoder_2] - if self.text_encoder is not None - else [self.text_encoder_2] - ) - - # textual inversion: procecss multi-vector tokens if necessary - prompt_embeds_list = [] - prompts = [prompt, prompt] - for prompt, tokenizer, text_encoder in zip( - prompts, tokenizers, text_encoders - ): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - untruncated_ids = tokenizer( - prompt, padding="longest", return_tensors="pt" - ).input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[ - -1 - ] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = tokenizer.batch_decode( - untruncated_ids[:, tokenizer.model_max_length - 1 : -1] - ) - print( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {tokenizer.model_max_length} tokens: {removed_text}" - ) - - text_encoder_output = text_encoder("forward", (text_input_ids,)) - prompt_embeds = torch.from_numpy(text_encoder_output[0]) - pooled_prompt_embeds = torch.from_numpy(text_encoder_output[1]) - - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - # get unconditional embeddings for classifier free guidance - zero_out_negative_prompt = ( - negative_prompt is None - and self.config.force_zeros_for_empty_prompt - ) - if ( - do_classifier_free_guidance - and negative_prompt_embeds is None - and zero_out_negative_prompt - ): - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like( - pooled_prompt_embeds - ) - elif do_classifier_free_guidance and negative_prompt_embeds is None: - negative_prompt = negative_prompt or "" - negative_prompt_2 = negative_prompt - - uncond_tokens: List[str] - if prompt is not None and type(prompt) is not type( - negative_prompt - ): - raise TypeError( - f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" - f" {type(prompt)}." - ) - elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt, negative_prompt_2] - elif batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" - f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" - " the batch size of `prompt`." - ) - else: - uncond_tokens = [negative_prompt, negative_prompt_2] - - negative_prompt_embeds_list = [] - for negative_prompt, tokenizer, text_encoder in zip( - uncond_tokens, tokenizers, text_encoders - ): - max_length = prompt_embeds.shape[1] - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - text_encoder_output = text_encoder( - "forward", (uncond_input.input_ids,) - ) - negative_prompt_embeds = torch.from_numpy( - text_encoder_output[0] - ) - negative_pooled_prompt_embeds = torch.from_numpy( - text_encoder_output[1] - ) - - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat( - negative_prompt_embeds_list, dim=-1 - ) - - if self.ondemand: - self.unload_clip_sdxl() - gc.collect() - - # TODO: Look into dtype for text_encoder_2! - prompt_embeds = prompt_embeds.to(dtype=torch.float16) - bs_embed, seq_len, _ = prompt_embeds.shape - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view( - bs_embed * num_images_per_prompt, seq_len, -1 - ) - - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=torch.float32) - negative_prompt_embeds = negative_prompt_embeds.repeat( - 1, num_images_per_prompt, 1 - ) - negative_prompt_embeds = negative_prompt_embeds.view( - batch_size * num_images_per_prompt, seq_len, -1 - ) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( - 1, num_images_per_prompt - ).view(bs_embed * num_images_per_prompt, -1) - - return ( - prompt_embeds, - negative_prompt_embeds, - pooled_prompt_embeds, - negative_pooled_prompt_embeds, - ) - - def encode_prompts(self, prompts, neg_prompts, max_length): - # Tokenize text and get embeddings - text_input = self.tokenizer( - prompts, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - - # Get unconditional embeddings as well - uncond_input = self.tokenizer( - neg_prompts, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) - text_input = torch.cat([uncond_input.input_ids, text_input.input_ids]) - - self.load_clip() - clip_inf_start = time.time() - text_embeddings = self.text_encoder("forward", (text_input,)) - clip_inf_time = (time.time() - clip_inf_start) * 1000 - if self.ondemand: - self.unload_clip() - gc.collect() - self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}" - - return text_embeddings - - def decode_latents(self, latents, use_base_vae, cpu_scheduling): - if use_base_vae: - latents = 1 / 0.18215 * latents - - latents_numpy = latents - if cpu_scheduling: - latents_numpy = latents.detach().numpy() - - profile_device = start_profiling(file_path="vae.rdc") - vae_start = time.time() - images = self.vae("forward", (latents_numpy,)) - vae_inf_time = (time.time() - vae_start) * 1000 - end_profiling(profile_device) - self.log += f"\nVAE Inference time (ms): {vae_inf_time:.3f}" - - if use_base_vae: - images = torch.from_numpy(images) - images = (images.detach().cpu() * 255.0).numpy() - images = images.round() - - images = torch.from_numpy(images).to(torch.uint8).permute(0, 2, 3, 1) - pil_images = [Image.fromarray(image) for image in images.numpy()] - return pil_images - - def produce_img_latents( - self, - latents, - text_embeddings, - guidance_scale, - total_timesteps, - dtype, - cpu_scheduling, - mask=None, - masked_image_latents=None, - return_all_latents=False, - ): - self.status = SD_STATE_IDLE - step_time_sum = 0 - latent_history = [latents] - text_embeddings = torch.from_numpy(text_embeddings).to(dtype) - text_embeddings_numpy = text_embeddings.detach().numpy() - if text_embeddings.shape[1] <= self.model_max_length: - self.load_unet() - else: - self.load_unet_512() - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - timestep = torch.tensor([t]).to(dtype).detach().numpy() - latent_model_input = self.scheduler.scale_model_input(latents, t) - if mask is not None and masked_image_latents is not None: - latent_model_input = torch.cat( - [ - torch.from_numpy(np.asarray(latent_model_input)), - mask, - masked_image_latents, - ], - dim=1, - ).to(dtype) - if cpu_scheduling: - latent_model_input = latent_model_input.detach().numpy() - - # Profiling Unet. - profile_device = start_profiling(file_path="unet.rdc") - if text_embeddings.shape[1] <= self.model_max_length: - noise_pred = self.unet( - "forward", - ( - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - ), - send_to_host=False, - ) - else: - noise_pred = self.unet_512( - "forward", - ( - latent_model_input, - timestep, - text_embeddings_numpy, - guidance_scale, - ), - send_to_host=False, - ) - end_profiling(profile_device) - - if cpu_scheduling: - noise_pred = torch.from_numpy(noise_pred.to_host()) - latents = self.scheduler.step( - noise_pred, t, latents - ).prev_sample - else: - latents = self.scheduler.step(noise_pred, t, latents) - - latent_history.append(latents) - step_time = (time.time() - step_start_time) * 1000 - # self.log += ( - # f"\nstep = {i} | timestep = {t} | time = {step_time:.2f}ms" - # ) - step_time_sum += step_time - - if self.status == SD_STATE_CANCEL: - break - - if self.ondemand: - self.unload_unet() - self.unload_unet_512() - gc.collect() - - avg_step_time = step_time_sum / len(total_timesteps) - self.log += f"\nAverage step time: {avg_step_time}ms/it" - - if not return_all_latents: - return latents - all_latents = torch.cat(latent_history, dim=0) - return all_latents - - def produce_img_latents_sdxl( - self, - latents, - total_timesteps, - add_text_embeds, - add_time_ids, - prompt_embeds, - cpu_scheduling, - guidance_scale, - dtype, - mask=None, - masked_image_latents=None, - return_all_latents=False, - ): - # return None - self.status = SD_STATE_IDLE - step_time_sum = 0 - extra_step_kwargs = {"generator": None} - self.load_unet() - for i, t in tqdm(enumerate(total_timesteps)): - step_start_time = time.time() - timestep = torch.tensor([t]).to(dtype).detach().numpy() - # expand the latents if we are doing classifier free guidance - if isinstance(latents, np.ndarray): - latents = torch.tensor(latents) - latent_model_input = torch.cat([latents] * 2) - - latent_model_input = self.scheduler.scale_model_input( - latent_model_input, t - ) - if mask is not None and masked_image_latents is not None: - latent_model_input = torch.cat( - [ - torch.from_numpy(np.asarray(latent_model_input)), - mask, - masked_image_latents, - ], - dim=1, - ).to(dtype) - - noise_pred = self.unet( - "forward", - ( - latent_model_input, - timestep, - prompt_embeds, - add_text_embeds, - add_time_ids, - guidance_scale, - ), - send_to_host=True, - ) - if not isinstance(latents, torch.Tensor): - latents = torch.from_numpy(latents).to("cpu") - noise_pred = torch.from_numpy(noise_pred).to("cpu") - - latents = self.scheduler.step( - noise_pred, t, latents, **extra_step_kwargs, return_dict=False - )[0] - latents = latents.detach().numpy() - noise_pred = noise_pred.detach().numpy() - - step_time = (time.time() - step_start_time) * 1000 - step_time_sum += step_time - - if self.status == SD_STATE_CANCEL: - break - if self.ondemand: - self.unload_unet() - gc.collect() - - avg_step_time = step_time_sum / len(total_timesteps) - self.log += f"\nAverage step time: {avg_step_time}ms/it" - - return latents - - def decode_latents_sdxl(self, latents, is_fp32_vae): - # latents are in unet dtype here so switch if we want to use fp32 - if is_fp32_vae: - print("Casting latents to float32 for VAE") - latents = latents.to(torch.float32) - images = self.vae("forward", (latents,)) - images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1) - images = images.cpu().permute(0, 2, 3, 1).float().numpy() - - images = (images * 255).round().astype("uint8") - pil_images = [Image.fromarray(image[:, :, :3]) for image in images] - - return pil_images - - @classmethod - def from_pretrained( - cls, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - SharkEulerDiscreteScheduler, - DEISMultistepScheduler, - DDPMScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, - ], - import_mlir: bool, - model_id: str, - ckpt_loc: str, - custom_vae: str, - precision: str, - max_length: int, - batch_size: int, - height: int, - width: int, - use_base_vae: bool, - use_tuned: bool, - ondemand: bool, - low_cpu_mem_usage: bool = False, - debug: bool = False, - stencils: list[str] = [], - # stencil_images: list[Image] = [] - use_lora: str = "", - ddpm_scheduler: DDPMScheduler = None, - use_quantize=None, - ): - if ( - not import_mlir - and not use_lora - and cls.__name__ == "StencilPipeline" - ): - sys.exit("StencilPipeline not supported with SharkTank currently.") - - is_inpaint = cls.__name__ in [ - "InpaintPipeline", - "OutpaintPipeline", - ] - is_upscaler = cls.__name__ in ["UpscalerPipeline"] - is_sdxl = cls.__name__ in ["Text2ImageSDXLPipeline"] - - sd_model = SharkifyStableDiffusionModel( - model_id, - ckpt_loc, - custom_vae, - precision, - max_len=max_length, - batch_size=batch_size, - height=height, - width=width, - use_base_vae=use_base_vae, - use_tuned=use_tuned, - low_cpu_mem_usage=low_cpu_mem_usage, - debug=debug, - is_inpaint=is_inpaint, - is_upscaler=is_upscaler, - is_sdxl=is_sdxl, - stencils=stencils, - use_lora=use_lora, - use_quantize=use_quantize, - ) - - if cls.__name__ in ["UpscalerPipeline"]: - return cls( - scheduler, - ddpm_scheduler, - sd_model, - import_mlir, - use_lora, - ondemand, - ) - - if cls.__name__ == "StencilPipeline": - return cls( - scheduler, sd_model, import_mlir, use_lora, ondemand, stencils - ) - if cls.__name__ == "Text2ImageSDXLPipeline": - is_fp32_vae = True if "16" not in custom_vae else False - return cls( - scheduler, - sd_model, - import_mlir, - use_lora, - ondemand, - is_fp32_vae, - ) - - return cls(scheduler, sd_model, import_mlir, use_lora, ondemand) - - # ##################################################### - # Implements text embeddings with weights from prompts - # https://huggingface.co/AlanB/lpw_stable_diffusion_mod - # ##################################################### - def encode_prompts_weight( - self, - prompt, - negative_prompt, - model_max_length, - do_classifier_free_guidance=True, - max_embeddings_multiples=1, - num_images_per_prompt=1, - ): - r""" - Encodes the prompt into text encoder hidden states. - Args: - prompt (`str` or `list(int)`): - prompt to be encoded - negative_prompt (`str` or `List[str]`): - The prompt or prompts not to guide the image generation. - Ignored when not using guidance - (i.e., ignored if `guidance_scale` is less than `1`). - model_max_length (int): - SHARK: pass the max length instead of relying on - pipe.tokenizer.model_max_length - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not, - SHARK: must be set to True as we always expect neg embeddings - (defaulted to True) - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the - max output length of text encoder. - SHARK: max_embeddings_multiples>1 produce a tensor shape error - (defaulted to 1) - num_images_per_prompt (`int`): - number of images that should be generated per prompt - SHARK: num_images_per_prompt is not used (defaulted to 1) - """ - - # SHARK: Save model_max_length, load the clip and init inference time - self.model_max_length = model_max_length - self.load_clip() - clip_inf_start = time.time() - - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - if negative_prompt is None: - negative_prompt = [""] * batch_size - elif isinstance(negative_prompt, str): - negative_prompt = [negative_prompt] * batch_size - if batch_size != len(negative_prompt): - raise ValueError( - f"`negative_prompt`: " - f"{negative_prompt} has batch size {len(negative_prompt)}, " - f"but `prompt`: {prompt} has batch size {batch_size}. " - f"Please make sure that passed `negative_prompt` matches " - "the batch size of `prompt`." - ) - - text_embeddings, uncond_embeddings = get_weighted_text_embeddings( - pipe=self, - prompt=prompt, - uncond_prompt=negative_prompt - if do_classifier_free_guidance - else None, - max_embeddings_multiples=max_embeddings_multiples, - ) - # SHARK: we are not using num_images_per_prompt - # bs_embed, seq_len, _ = text_embeddings.shape - # text_embeddings = text_embeddings.repeat( - # 1, - # num_images_per_prompt, - # 1 - # ) - # text_embeddings = ( - # text_embeddings.view( - # bs_embed * num_images_per_prompt, - # seq_len, - # -1 - # ) - # ) - - if do_classifier_free_guidance: - # SHARK: we are not using num_images_per_prompt - # bs_embed, seq_len, _ = uncond_embeddings.shape - # uncond_embeddings = ( - # uncond_embeddings.repeat( - # 1, - # num_images_per_prompt, - # 1 - # ) - # ) - # uncond_embeddings = ( - # uncond_embeddings.view( - # bs_embed * num_images_per_prompt, - # seq_len, - # -1 - # ) - # ) - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - if text_embeddings.shape[1] > model_max_length: - pad = (0, 0) * (len(text_embeddings.shape) - 2) - pad = pad + (0, 512 - text_embeddings.shape[1]) - text_embeddings = torch.nn.functional.pad(text_embeddings, pad) - - # SHARK: Report clip inference time - clip_inf_time = (time.time() - clip_inf_start) * 1000 - if self.ondemand: - self.unload_clip() - gc.collect() - self.log += f"\nClip Inference time (ms) = {clip_inf_time:.3f}" - - return text_embeddings.numpy().astype(np.float16) - - -from typing import List, Optional, Union -import re - -re_attention = re.compile( - r""" -\\\(| -\\\)| -\\\[| -\\]| -\\\\| -\\| -\(| -\[| -:([+-]?[.\d]+)\)| -\)| -]| -[^\\()\[\]:]+| -: -""", - re.X, -) - - -def parse_prompt_attention(text): - """ - Parses a string with attention tokens and returns a list of pairs: - text and its associated weight. - Accepted tokens are: - (abc) - increases attention to abc by a multiplier of 1.1 - (abc:3.12) - increases attention to abc by a multiplier of 3.12 - [abc] - decreases attention to abc by a multiplier of 1.1 - \( - literal character '(' - \[ - literal character '[' - \) - literal character ')' - \] - literal character ']' - \\ - literal character '\' - anything else - just text - >>> parse_prompt_attention('normal text') - [['normal text', 1.0]] - >>> parse_prompt_attention('an (important) word') - [['an ', 1.0], ['important', 1.1], [' word', 1.0]] - >>> parse_prompt_attention('(unbalanced') - [['unbalanced', 1.1]] - >>> parse_prompt_attention('\(literal\]') - [['(literal]', 1.0]] - >>> parse_prompt_attention('(unnecessary)(parens)') - [['unnecessaryparens', 1.1]] - >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).') - [['a ', 1.0], - ['house', 1.5730000000000004], - [' ', 1.1], - ['on', 1.0], - [' a ', 1.1], - ['hill', 0.55], - [', sun, ', 1.1], - ['sky', 1.4641000000000006], - ['.', 1.1]] - """ - - res = [] - round_brackets = [] - square_brackets = [] - - round_bracket_multiplier = 1.1 - square_bracket_multiplier = 1 / 1.1 - - def multiply_range(start_position, multiplier): - for p in range(start_position, len(res)): - res[p][1] *= multiplier - - for m in re_attention.finditer(text): - text = m.group(0) - weight = m.group(1) - - if text.startswith("\\"): - res.append([text[1:], 1.0]) - elif text == "(": - round_brackets.append(len(res)) - elif text == "[": - square_brackets.append(len(res)) - elif weight is not None and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), float(weight)) - elif text == ")" and len(round_brackets) > 0: - multiply_range(round_brackets.pop(), round_bracket_multiplier) - elif text == "]" and len(square_brackets) > 0: - multiply_range(square_brackets.pop(), square_bracket_multiplier) - else: - res.append([text, 1.0]) - - for pos in round_brackets: - multiply_range(pos, round_bracket_multiplier) - - for pos in square_brackets: - multiply_range(pos, square_bracket_multiplier) - - if len(res) == 0: - res = [["", 1.0]] - - # merge runs of identical weights - i = 0 - while i + 1 < len(res): - if res[i][1] == res[i + 1][1]: - res[i][0] += res[i + 1][0] - res.pop(i + 1) - else: - i += 1 - - return res - - -def get_prompts_with_weights( - pipe: StableDiffusionPipeline, prompt: List[str], max_length: int -): - r""" - Tokenize a list of prompts and return its tokens with weights of each token. - No padding, starting or ending token is included. - """ - tokens = [] - weights = [] - truncated = False - for text in prompt: - texts_and_weights = parse_prompt_attention(text) - text_token = [] - text_weight = [] - for word, weight in texts_and_weights: - # tokenize and discard the starting and the ending token - token = pipe.tokenizer(word).input_ids[1:-1] - text_token += token - # copy the weight by length of token - text_weight += [weight] * len(token) - # stop if the text is too long (longer than truncation limit) - if len(text_token) > max_length: - truncated = True - break - # truncate - if len(text_token) > max_length: - truncated = True - text_token = text_token[:max_length] - text_weight = text_weight[:max_length] - tokens.append(text_token) - weights.append(text_weight) - if truncated: - print( - "Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples" - ) - return tokens, weights - - -def pad_tokens_and_weights( - tokens, - weights, - max_length, - bos, - eos, - no_boseos_middle=True, - chunk_length=77, -): - r""" - Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length. - """ - max_embeddings_multiples = (max_length - 2) // (chunk_length - 2) - weights_length = ( - max_length - if no_boseos_middle - else max_embeddings_multiples * chunk_length - ) - for i in range(len(tokens)): - tokens[i] = ( - [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i])) - ) - if no_boseos_middle: - weights[i] = ( - [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i])) - ) - else: - w = [] - if len(weights[i]) == 0: - w = [1.0] * weights_length - else: - for j in range(max_embeddings_multiples): - w.append(1.0) # weight for starting token in this chunk - w += weights[i][ - j - * (chunk_length - 2) : min( - len(weights[i]), (j + 1) * (chunk_length - 2) - ) - ] - w.append(1.0) # weight for ending token in this chunk - w += [1.0] * (weights_length - len(w)) - weights[i] = w[:] - - return tokens, weights - - -def get_unweighted_text_embeddings( - pipe: StableDiffusionPipeline, - text_input: torch.Tensor, - chunk_length: int, - no_boseos_middle: Optional[bool] = True, -): - """ - When the length of tokens is a multiple of the capacity of the text encoder, - it should be split into chunks and sent to the text encoder individually. - """ - max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2) - if max_embeddings_multiples > 1: - text_embeddings = [] - for i in range(max_embeddings_multiples): - # extract the i-th chunk - text_input_chunk = text_input[ - :, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2 - ].clone() - - # cover the head and the tail by the starting and the ending tokens - text_input_chunk[:, 0] = text_input[0, 0] - text_input_chunk[:, -1] = text_input[0, -1] - # text_embedding = pipe.text_encoder(text_input_chunk)[0] - # SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens - formatted_text_input_chunk = torch.cat( - [text_input_chunk, text_input_chunk] - ) - text_embedding = pipe.text_encoder( - "forward", (formatted_text_input_chunk,) - )[0] - - if no_boseos_middle: - if i == 0: - # discard the ending token - text_embedding = text_embedding[:, :-1] - elif i == max_embeddings_multiples - 1: - # discard the starting token - text_embedding = text_embedding[:, 1:] - else: - # discard both starting and ending tokens - text_embedding = text_embedding[:, 1:-1] - - text_embeddings.append(text_embedding) - # SHARK: Convert the result to tensor - # text_embeddings = torch.concat(text_embeddings, axis=1) - text_embeddings_np = np.concatenate(np.array(text_embeddings)) - text_embeddings = torch.from_numpy(text_embeddings_np)[None, :] - else: - # SHARK: deplicate the text_input as Shark runner expects tokens and neg tokens - # Convert the result to tensor - # text_embeddings = pipe.text_encoder(text_input)[0] - formatted_text_input = torch.cat([text_input, text_input]) - text_embeddings = pipe.text_encoder( - "forward", (formatted_text_input,) - )[0] - text_embeddings = torch.from_numpy(text_embeddings)[None, :] - return text_embeddings - - -# This function deals with NoneType values occuring in tokens after padding -# It switches out None with 49407 as truncating None values causes matrix dimension errors, -def filter_nonetype_tokens(tokens: List[List]): - return [[49407 if token is None else token for token in tokens[0]]] - - -def get_weighted_text_embeddings( - pipe: StableDiffusionPipeline, - prompt: Union[str, List[str]], - uncond_prompt: Optional[Union[str, List[str]]] = None, - max_embeddings_multiples: Optional[int] = 3, - no_boseos_middle: Optional[bool] = False, - skip_parsing: Optional[bool] = False, - skip_weighting: Optional[bool] = False, -): - r""" - Prompts can be assigned with local weights using brackets. For example, - prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful', - and the embedding tokens corresponding to the words get multiplied by a constant, 1.1. - Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean. - Args: - pipe (`StableDiffusionPipeline`): - Pipe to provide access to the tokenizer and the text encoder. - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. - uncond_prompt (`str` or `List[str]`): - The unconditional prompt or prompts for guide the image generation. If unconditional prompt - is provided, the embeddings of prompt and uncond_prompt are concatenated. - max_embeddings_multiples (`int`, *optional*, defaults to `3`): - The max multiple length of prompt embeddings compared to the max output length of text encoder. - no_boseos_middle (`bool`, *optional*, defaults to `False`): - If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and - ending token in each of the chunk in the middle. - skip_parsing (`bool`, *optional*, defaults to `False`): - Skip the parsing of brackets. - skip_weighting (`bool`, *optional*, defaults to `False`): - Skip the weighting. When the parsing is skipped, it is forced True. - """ - max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 - if isinstance(prompt, str): - prompt = [prompt] - - if not skip_parsing: - prompt_tokens, prompt_weights = get_prompts_with_weights( - pipe, prompt, max_length - 2 - ) - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens, uncond_weights = get_prompts_with_weights( - pipe, uncond_prompt, max_length - 2 - ) - else: - prompt_tokens = [ - token[1:-1] - for token in pipe.tokenizer( - prompt, max_length=max_length, truncation=True - ).input_ids - ] - prompt_weights = [[1.0] * len(token) for token in prompt_tokens] - if uncond_prompt is not None: - if isinstance(uncond_prompt, str): - uncond_prompt = [uncond_prompt] - uncond_tokens = [ - token[1:-1] - for token in pipe.tokenizer( - uncond_prompt, max_length=max_length, truncation=True - ).input_ids - ] - uncond_weights = [[1.0] * len(token) for token in uncond_tokens] - - # round up the longest length of tokens to a multiple of (model_max_length - 2) - max_length = max([len(token) for token in prompt_tokens]) - if uncond_prompt is not None: - max_length = max( - max_length, max([len(token) for token in uncond_tokens]) - ) - - max_embeddings_multiples = min( - max_embeddings_multiples, - (max_length - 1) // (pipe.model_max_length - 2) + 1, - ) - max_embeddings_multiples = max(1, max_embeddings_multiples) - max_length = (pipe.model_max_length - 2) * max_embeddings_multiples + 2 - - # pad the length of tokens and weights - bos = pipe.tokenizer.bos_token_id - eos = pipe.tokenizer.eos_token_id - prompt_tokens, prompt_weights = pad_tokens_and_weights( - prompt_tokens, - prompt_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.model_max_length, - ) - - # FIXME: This is a hacky fix caused by tokenizer padding with None values - prompt_tokens = filter_nonetype_tokens(prompt_tokens) - - # prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device) - prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device="cpu") - if uncond_prompt is not None: - uncond_tokens, uncond_weights = pad_tokens_and_weights( - uncond_tokens, - uncond_weights, - max_length, - bos, - eos, - no_boseos_middle=no_boseos_middle, - chunk_length=pipe.model_max_length, - ) - - # FIXME: This is a hacky fix caused by tokenizer padding with None values - uncond_tokens = filter_nonetype_tokens(uncond_tokens) - - # uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device) - uncond_tokens = torch.tensor( - uncond_tokens, dtype=torch.long, device="cpu" - ) - - # get the embeddings - text_embeddings = get_unweighted_text_embeddings( - pipe, - prompt_tokens, - pipe.model_max_length, - no_boseos_middle=no_boseos_middle, - ) - # prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device) - prompt_weights = torch.tensor( - prompt_weights, dtype=torch.float, device="cpu" - ) - if uncond_prompt is not None: - uncond_embeddings = get_unweighted_text_embeddings( - pipe, - uncond_tokens, - pipe.model_max_length, - no_boseos_middle=no_boseos_middle, - ) - # uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device) - uncond_weights = torch.tensor( - uncond_weights, dtype=torch.float, device="cpu" - ) - - # assign weights to the prompts and normalize in the sense of mean - # TODO: should we normalize by chunk or in a whole (current implementation)? - if (not skip_parsing) and (not skip_weighting): - previous_mean = ( - text_embeddings.float() - .mean(axis=[-2, -1]) - .to(text_embeddings.dtype) - ) - text_embeddings *= prompt_weights.unsqueeze(-1) - current_mean = ( - text_embeddings.float() - .mean(axis=[-2, -1]) - .to(text_embeddings.dtype) - ) - text_embeddings *= ( - (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - ) - if uncond_prompt is not None: - previous_mean = ( - uncond_embeddings.float() - .mean(axis=[-2, -1]) - .to(uncond_embeddings.dtype) - ) - uncond_embeddings *= uncond_weights.unsqueeze(-1) - current_mean = ( - uncond_embeddings.float() - .mean(axis=[-2, -1]) - .to(uncond_embeddings.dtype) - ) - uncond_embeddings *= ( - (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1) - ) - - if uncond_prompt is not None: - return text_embeddings, uncond_embeddings - return text_embeddings, None diff --git a/apps/stable_diffusion/src/schedulers/__init__.py b/apps/stable_diffusion/src/schedulers/__init__.py deleted file mode 100644 index e7864e2d..00000000 --- a/apps/stable_diffusion/src/schedulers/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import ( - SharkEulerDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import ( - SharkEulerAncestralDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers.sd_schedulers import get_schedulers diff --git a/apps/stable_diffusion/src/schedulers/sd_schedulers.py b/apps/stable_diffusion/src/schedulers/sd_schedulers.py deleted file mode 100644 index 913b15c9..00000000 --- a/apps/stable_diffusion/src/schedulers/sd_schedulers.py +++ /dev/null @@ -1,128 +0,0 @@ -from diffusers import ( - LCMScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - DDPMScheduler, - DDIMScheduler, - DPMSolverMultistepScheduler, - KDPM2DiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DEISMultistepScheduler, - DPMSolverSinglestepScheduler, - KDPM2AncestralDiscreteScheduler, - HeunDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers.shark_eulerdiscrete import ( - SharkEulerDiscreteScheduler, -) -from apps.stable_diffusion.src.schedulers.shark_eulerancestraldiscrete import ( - SharkEulerAncestralDiscreteScheduler, -) - - -def get_schedulers(model_id): - # TODO: Robust scheduler setup on pipeline creation -- if we don't - # set batch_size here, the SHARK schedulers will - # compile with batch size = 1 regardless of whether the model - # outputs latents of a larger batch size, e.g. SDXL. - # However, obviously, searching for whether the base model ID - # contains "xl" is not very robust. - - batch_size = 2 if "xl" in model_id.lower() else 1 - - schedulers = dict() - schedulers["PNDM"] = PNDMScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["DDPM"] = DDPMScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["KDPM2Discrete"] = KDPM2DiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["LMSDiscrete"] = LMSDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["DDIM"] = DDIMScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["LCMScheduler"] = LCMScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers[ - "DPMSolverMultistep" - ] = DPMSolverMultistepScheduler.from_pretrained( - model_id, subfolder="scheduler", algorithm_type="dpmsolver" - ) - schedulers[ - "DPMSolverMultistep++" - ] = DPMSolverMultistepScheduler.from_pretrained( - model_id, subfolder="scheduler", algorithm_type="dpmsolver++" - ) - schedulers[ - "DPMSolverMultistepKarras" - ] = DPMSolverMultistepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - use_karras_sigmas=True, - ) - schedulers[ - "DPMSolverMultistepKarras++" - ] = DPMSolverMultistepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - algorithm_type="dpmsolver++", - use_karras_sigmas=True, - ) - schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers[ - "EulerAncestralDiscrete" - ] = EulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["DEISMultistep"] = DEISMultistepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers[ - "SharkEulerDiscrete" - ] = SharkEulerDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers[ - "SharkEulerAncestralDiscrete" - ] = SharkEulerAncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers[ - "DPMSolverSinglestep" - ] = DPMSolverSinglestepScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers[ - "KDPM2AncestralDiscrete" - ] = KDPM2AncestralDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["HeunDiscrete"] = HeunDiscreteScheduler.from_pretrained( - model_id, - subfolder="scheduler", - ) - schedulers["SharkEulerDiscrete"].compile(batch_size) - schedulers["SharkEulerAncestralDiscrete"].compile(batch_size) - return schedulers diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py deleted file mode 100644 index c941e562..00000000 --- a/apps/stable_diffusion/src/schedulers/shark_eulerancestraldiscrete.py +++ /dev/null @@ -1,251 +0,0 @@ -import sys -import numpy as np -from typing import List, Optional, Tuple, Union -from diffusers import ( - EulerAncestralDiscreteScheduler, -) -from diffusers.utils.torch_utils import randn_tensor -from diffusers.configuration_utils import register_to_config -from apps.stable_diffusion.src.utils import ( - compile_through_fx, - get_shark_model, - args, -) -import torch - - -class SharkEulerAncestralDiscreteScheduler(EulerAncestralDiscreteScheduler): - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - timestep_spacing: str = "linspace", - steps_offset: int = 0, - ): - super().__init__( - num_train_timesteps, - beta_start, - beta_end, - beta_schedule, - trained_betas, - prediction_type, - timestep_spacing, - steps_offset, - ) - # TODO: make it dynamic so we dont have to worry about batch size - self.batch_size = None - self.init_input_shape = None - - def compile(self, batch_size=1): - SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" - device = args.device.split(":", 1)[0].strip() - self.batch_size = batch_size - - model_input = { - "eulera": { - "output": torch.randn( - batch_size, 4, args.height // 8, args.width // 8 - ), - "latent": torch.randn( - batch_size, 4, args.height // 8, args.width // 8 - ), - "sigma": torch.tensor(1).to(torch.float32), - "sigma_from": torch.tensor(1).to(torch.float32), - "sigma_to": torch.tensor(1).to(torch.float32), - "noise": torch.randn( - batch_size, 4, args.height // 8, args.width // 8 - ), - }, - } - - example_latent = model_input["eulera"]["latent"] - example_output = model_input["eulera"]["output"] - example_noise = model_input["eulera"]["noise"] - if args.precision == "fp16": - example_latent = example_latent.half() - example_output = example_output.half() - example_noise = example_noise.half() - example_sigma = model_input["eulera"]["sigma"] - example_sigma_from = model_input["eulera"]["sigma_from"] - example_sigma_to = model_input["eulera"]["sigma_to"] - - class ScalingModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, latent, sigma): - return latent / ((sigma**2 + 1) ** 0.5) - - class SchedulerStepEpsilonModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, noise_pred, latent, sigma, sigma_from, sigma_to, noise - ): - sigma_up = ( - sigma_to**2 - * (sigma_from**2 - sigma_to**2) - / sigma_from**2 - ) ** 0.5 - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - dt = sigma_down - sigma - pred_original_sample = latent - sigma * noise_pred - derivative = (latent - pred_original_sample) / sigma - prev_sample = latent + derivative * dt - return prev_sample + noise * sigma_up - - class SchedulerStepVPredictionModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward( - self, noise_pred, sigma, sigma_from, sigma_to, latent, noise - ): - sigma_up = ( - sigma_to**2 - * (sigma_from**2 - sigma_to**2) - / sigma_from**2 - ) ** 0.5 - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - dt = sigma_down - sigma - pred_original_sample = noise_pred * ( - -sigma / (sigma**2 + 1) ** 0.5 - ) + (latent / (sigma**2 + 1)) - derivative = (latent - pred_original_sample) / sigma - prev_sample = latent + derivative * dt - return prev_sample + noise * sigma_up - - iree_flags = [] - if len(args.iree_vulkan_target_triple) > 0: - iree_flags.append( - f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" - ) - - def _import(self): - scaling_model = ScalingModel() - self.scaling_model, _ = compile_through_fx( - model=scaling_model, - inputs=(example_latent, example_sigma), - extended_model_name=f"euler_a_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_" - + args.precision, - extra_args=iree_flags, - ) - - pred_type_model_dict = { - "epsilon": SchedulerStepEpsilonModel(), - "v_prediction": SchedulerStepVPredictionModel(), - } - step_model = pred_type_model_dict[self.config.prediction_type] - self.step_model, _ = compile_through_fx( - step_model, - ( - example_output, - example_latent, - example_sigma, - example_sigma_from, - example_sigma_to, - example_noise, - ), - extended_model_name=f"euler_a_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_" - + args.precision, - extra_args=iree_flags, - ) - - if args.import_mlir: - _import(self) - - else: - try: - self.scaling_model = get_shark_model( - SCHEDULER_BUCKET, - "euler_a_scale_model_input_" + args.precision, - iree_flags, - ) - self.step_model = get_shark_model( - SCHEDULER_BUCKET, - "euler_a_step_" - + self.config.prediction_type - + args.precision, - iree_flags, - ) - except: - print( - "failed to download model, falling back and using import_mlir" - ) - args.import_mlir = True - _import(self) - - def scale_model_input(self, sample, timestep): - if self.step_index is None: - self._init_step_index(timestep) - sigma = self.sigmas[self.step_index] - return self.scaling_model( - "forward", - ( - sample, - sigma, - ), - send_to_host=False, - ) - - def step( - self, - noise_pred, - timestep, - latent, - generator: Optional[torch.Generator] = None, - return_dict: Optional[bool] = False, - ): - step_inputs = [] - - if self.step_index is None: - self._init_step_index(timestep) - - sigma = self.sigmas[self.step_index] - - sigma_from = self.sigmas[self.step_index] - sigma_to = self.sigmas[self.step_index + 1] - noise = randn_tensor( - torch.Size(noise_pred.shape), - dtype=torch.float16, - device="cpu", - generator=generator, - ) - step_inputs = [ - noise_pred, - latent, - sigma, - sigma_from, - sigma_to, - noise, - ] - # TODO: deal with dynamic inputs in turbine flow. - # update step index since we're done with the variable and will return with compiled module output. - self._step_index += 1 - - if noise_pred.shape[0] < self.batch_size: - for i in [0, 1, 5]: - try: - step_inputs[i] = torch.tensor(step_inputs[i]) - except: - step_inputs[i] = torch.tensor(step_inputs[i].to_host()) - step_inputs[i] = torch.cat( - (step_inputs[i], step_inputs[i]), axis=0 - ) - return self.step_model( - "forward", - tuple(step_inputs), - send_to_host=True, - ) - - return self.step_model( - "forward", - tuple(step_inputs), - send_to_host=False, - ) diff --git a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py b/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py deleted file mode 100644 index 5e9040c5..00000000 --- a/apps/stable_diffusion/src/schedulers/shark_eulerdiscrete.py +++ /dev/null @@ -1,245 +0,0 @@ -import sys -import numpy as np -from typing import List, Optional, Tuple, Union -from diffusers import ( - EulerDiscreteScheduler, -) -from diffusers.utils.torch_utils import randn_tensor -from diffusers.configuration_utils import register_to_config -from apps.stable_diffusion.src.utils import ( - compile_through_fx, - get_shark_model, - args, -) -import torch - - -class SharkEulerDiscreteScheduler(EulerDiscreteScheduler): - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[Union[np.ndarray, List[float]]] = None, - prediction_type: str = "epsilon", - interpolation_type: str = "linear", - use_karras_sigmas: bool = False, - sigma_min: Optional[float] = None, - sigma_max: Optional[float] = None, - timestep_spacing: str = "linspace", - timestep_type: str = "discrete", - steps_offset: int = 0, - ): - super().__init__( - num_train_timesteps, - beta_start, - beta_end, - beta_schedule, - trained_betas, - prediction_type, - interpolation_type, - use_karras_sigmas, - sigma_min, - sigma_max, - timestep_spacing, - timestep_type, - steps_offset, - ) - # TODO: make it dynamic so we dont have to worry about batch size - self.batch_size = 1 - - def compile(self, batch_size=1): - SCHEDULER_BUCKET = "gs://shark_tank/stable_diffusion/schedulers" - device = args.device.split(":", 1)[0].strip() - self.batch_size = batch_size - - model_input = { - "euler": { - "latent": torch.randn( - batch_size, 4, args.height // 8, args.width // 8 - ), - "output": torch.randn( - batch_size, 4, args.height // 8, args.width // 8 - ), - "sigma": torch.tensor(1).to(torch.float32), - "dt": torch.tensor(1).to(torch.float32), - }, - } - - example_latent = model_input["euler"]["latent"] - example_output = model_input["euler"]["output"] - if args.precision == "fp16": - example_latent = example_latent.half() - example_output = example_output.half() - example_sigma = model_input["euler"]["sigma"] - example_dt = model_input["euler"]["dt"] - - class ScalingModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, latent, sigma): - return latent / ((sigma**2 + 1) ** 0.5) - - class SchedulerStepEpsilonModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, noise_pred, sigma_hat, latent, dt): - pred_original_sample = latent - sigma_hat * noise_pred - derivative = (latent - pred_original_sample) / sigma_hat - return latent + derivative * dt - - class SchedulerStepSampleModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, noise_pred, sigma_hat, latent, dt): - pred_original_sample = noise_pred - derivative = (latent - pred_original_sample) / sigma_hat - return latent + derivative * dt - - class SchedulerStepVPredictionModel(torch.nn.Module): - def __init__(self): - super().__init__() - - def forward(self, noise_pred, sigma, latent, dt): - pred_original_sample = noise_pred * ( - -sigma / (sigma**2 + 1) ** 0.5 - ) + (latent / (sigma**2 + 1)) - derivative = (latent - pred_original_sample) / sigma - return latent + derivative * dt - - iree_flags = [] - if len(args.iree_vulkan_target_triple) > 0: - iree_flags.append( - f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" - ) - - def _import(self): - scaling_model = ScalingModel() - self.scaling_model, _ = compile_through_fx( - model=scaling_model, - inputs=(example_latent, example_sigma), - extended_model_name=f"euler_scale_model_input_{self.batch_size}_{args.height}_{args.width}_{device}_" - + args.precision, - extra_args=iree_flags, - ) - - pred_type_model_dict = { - "epsilon": SchedulerStepEpsilonModel(), - "v_prediction": SchedulerStepVPredictionModel(), - "sample": SchedulerStepSampleModel(), - "original_sample": SchedulerStepSampleModel(), - } - step_model = pred_type_model_dict[self.config.prediction_type] - self.step_model, _ = compile_through_fx( - step_model, - (example_output, example_sigma, example_latent, example_dt), - extended_model_name=f"euler_step_{self.config.prediction_type}_{self.batch_size}_{args.height}_{args.width}_{device}_" - + args.precision, - extra_args=iree_flags, - ) - - if args.import_mlir: - _import(self) - - else: - try: - step_model_type = ( - "sample" - if "sample" in self.config.prediction_type - else self.config.prediction_type - ) - self.scaling_model = get_shark_model( - SCHEDULER_BUCKET, - "euler_scale_model_input_" + args.precision, - iree_flags, - ) - self.step_model = get_shark_model( - SCHEDULER_BUCKET, - "euler_step_" + step_model_type + args.precision, - iree_flags, - ) - except: - print( - "failed to download model, falling back and using import_mlir" - ) - args.import_mlir = True - _import(self) - - def scale_model_input(self, sample, timestep): - if self.step_index is None: - self._init_step_index(timestep) - sigma = self.sigmas[self.step_index] - return self.scaling_model( - "forward", - ( - sample, - sigma, - ), - send_to_host=False, - ) - - def step( - self, - noise_pred, - timestep, - latent, - s_churn: float = 0.0, - s_tmin: float = 0.0, - s_tmax: float = float("inf"), - s_noise: float = 1.0, - generator: Optional[torch.Generator] = None, - return_dict: Optional[bool] = False, - ): - if self.step_index is None: - self._init_step_index(timestep) - - sigma = self.sigmas[self.step_index] - - gamma = ( - min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) - if s_tmin <= sigma <= s_tmax - else 0.0 - ) - - sigma_hat = sigma * (gamma + 1) - - noise_pred = ( - torch.from_numpy(noise_pred) - if isinstance(noise_pred, np.ndarray) - else noise_pred - ) - - noise = randn_tensor( - torch.Size(noise_pred.shape), - dtype=torch.float16, - device="cpu", - generator=generator, - ) - - eps = noise * s_noise - - if gamma > 0: - latent = latent + eps * (sigma_hat**2 - sigma**2) ** 0.5 - - if self.config.prediction_type == "v_prediction": - sigma_hat = sigma - - dt = self.sigmas[self.step_index + 1] - sigma_hat - - self._step_index += 1 - - return self.step_model( - "forward", - ( - noise_pred, - sigma_hat, - latent, - dt, - ), - send_to_host=False, - ) diff --git a/apps/stable_diffusion/src/utils/__init__.py b/apps/stable_diffusion/src/utils/__init__.py deleted file mode 100644 index 8436e655..00000000 --- a/apps/stable_diffusion/src/utils/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -from apps.stable_diffusion.src.utils.profiler import ( - start_profiling, - end_profiling, -) -from apps.stable_diffusion.src.utils.resources import ( - prompt_examples, - models_db, - base_models, - opt_flags, - resource_path, -) -from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation -from apps.stable_diffusion.src.utils.stable_args import args -from apps.stable_diffusion.src.utils.stencils.stencil_utils import ( - controlnet_hint_conversion, - controlnet_hint_reshaping, - get_stencil_model_id, -) -from apps.stable_diffusion.src.utils.utils import ( - get_shark_model, - compile_through_fx, - set_iree_runtime_flags, - map_device_to_name_path, - set_init_device_flags, - get_available_devices, - get_opt_flags, - preprocessCKPT, - convert_original_vae, - fetch_and_update_base_model_id, - get_path_to_diffusers_checkpoint, - sanitize_seed, - parse_seed_input, - batch_seeds, - get_path_stem, - get_extended_name, - get_generated_imgs_path, - get_generated_imgs_todays_subdir, - clear_all, - save_output_img, - get_generation_text_info, - update_lora_weight, - resize_stencil, - _compile_module, -) -from apps.stable_diffusion.src.utils.civitai import get_civitai_checkpoint -from apps.stable_diffusion.src.utils.resamplers import ( - resamplers, - resampler_list, -) diff --git a/apps/stable_diffusion/src/utils/civitai.py b/apps/stable_diffusion/src/utils/civitai.py deleted file mode 100644 index ff2aa45b..00000000 --- a/apps/stable_diffusion/src/utils/civitai.py +++ /dev/null @@ -1,42 +0,0 @@ -import re -import requests -from apps.stable_diffusion.src.utils.stable_args import args - -from pathlib import Path -from tqdm import tqdm - - -def get_civitai_checkpoint(url: str): - with requests.get(url, allow_redirects=True, stream=True) as response: - response.raise_for_status() - - # civitai api returns the filename in the content disposition - base_filename = re.findall( - '"([^"]*)"', response.headers["Content-Disposition"] - )[0] - destination_path = ( - Path.cwd() / (args.ckpt_dir or "models") / base_filename - ) - - # we don't have this model downloaded yet - if not destination_path.is_file(): - print( - f"downloading civitai model from {url} to {destination_path}" - ) - - size = int(response.headers["content-length"], 0) - progress_bar = tqdm(total=size, unit="iB", unit_scale=True) - - with open(destination_path, "wb") as f: - for chunk in response.iter_content(chunk_size=65536): - f.write(chunk) - progress_bar.update(len(chunk)) - - progress_bar.close() - - # we already have this model downloaded - else: - print(f"civitai model already downloaded to {destination_path}") - - response.close() - return destination_path.as_posix() diff --git a/apps/stable_diffusion/src/utils/profiler.py b/apps/stable_diffusion/src/utils/profiler.py deleted file mode 100644 index 8de53ef0..00000000 --- a/apps/stable_diffusion/src/utils/profiler.py +++ /dev/null @@ -1,20 +0,0 @@ -from apps.stable_diffusion.src.utils.stable_args import args - - -# Helper function to profile the vulkan device. -def start_profiling(file_path="foo.rdc", profiling_mode="queue"): - from shark.parser import shark_args - - if shark_args.vulkan_debug_utils and "vulkan" in args.device: - import iree - - print(f"Profiling and saving to {file_path}.") - vulkan_device = iree.runtime.get_device(args.device) - vulkan_device.begin_profiling(mode=profiling_mode, file_path=file_path) - return vulkan_device - return None - - -def end_profiling(device): - if device: - return device.end_profiling() diff --git a/apps/stable_diffusion/src/utils/resamplers.py b/apps/stable_diffusion/src/utils/resamplers.py deleted file mode 100644 index 3d24ad1b..00000000 --- a/apps/stable_diffusion/src/utils/resamplers.py +++ /dev/null @@ -1,12 +0,0 @@ -import PIL.Image as Image - -resamplers = { - "Lanczos": Image.Resampling.LANCZOS, - "Nearest Neighbor": Image.Resampling.NEAREST, - "Bilinear": Image.Resampling.BILINEAR, - "Bicubic": Image.Resampling.BICUBIC, - "Hamming": Image.Resampling.HAMMING, - "Box": Image.Resampling.BOX, -} - -resampler_list = resamplers.keys() diff --git a/apps/stable_diffusion/src/utils/resources.py b/apps/stable_diffusion/src/utils/resources.py deleted file mode 100644 index 43504b82..00000000 --- a/apps/stable_diffusion/src/utils/resources.py +++ /dev/null @@ -1,37 +0,0 @@ -import os -import json -import sys - - -def resource_path(relative_path): - """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) - return os.path.join(base_path, relative_path) - - -def get_json_file(path): - json_var = [] - loc_json = resource_path(path) - if os.path.exists(loc_json): - with open(loc_json, encoding="utf-8") as fopen: - json_var = json.load(fopen) - - if not json_var: - print(f"Unable to fetch {path}") - - return json_var - - -# TODO: This shouldn't be called from here, every time the file imports -# it will run all the global vars. -prompt_examples = get_json_file("resources/prompts.json") -models_db = get_json_file("resources/model_db.json") - -# The base_model contains the input configuration for the different -# models and also helps in providing information for the variants. -base_models = get_json_file("resources/base_model.json") - -# Contains optimization flags for different models. -opt_flags = get_json_file("resources/opt_flags.json") diff --git a/apps/stable_diffusion/src/utils/resources/base_model.json b/apps/stable_diffusion/src/utils/resources/base_model.json deleted file mode 100644 index 6adce2ad..00000000 --- a/apps/stable_diffusion/src/utils/resources/base_model.json +++ /dev/null @@ -1,495 +0,0 @@ -{ - "clip": { - "token" : { - "shape" : [ - "2*batch_size", - "max_len" - ], - "dtype":"i64" - } - }, - "sdxl_clip": { - "token" : { - "shape" : [ - "1*batch_size", - "max_len" - ], - "dtype":"i64" - } - }, - "vae_encode": { - "image" : { - "shape" : [ - "1*batch_size",3,"8*height","8*width" - ], - "dtype":"f32" - } - }, - "vae": { - "vae": { - "latents" : { - "shape" : [ - "1*batch_size",4,"height","width" - ], - "dtype":"f32" - } - }, - "vae_upscaler": { - "latents" : { - "shape" : [ - "1*batch_size",4,"8*height","8*width" - ], - "dtype":"f32" - } - } - }, - "unet": { - "stabilityai/stable-diffusion-2-1": { - "latents": { - "shape": [ - "1*batch_size", - 4, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "embedding": { - "shape": [ - "2*batch_size", - "max_len", - 1024 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 2, - "dtype": "f32" - } - }, - "CompVis/stable-diffusion-v1-4": { - "latents": { - "shape": [ - "1*batch_size", - 4, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "embedding": { - "shape": [ - "2*batch_size", - "max_len", - 768 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 2, - "dtype": "f32" - } - }, - "stabilityai/stable-diffusion-2-inpainting": { - "latents": { - "shape": [ - "1*batch_size", - 9, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "embedding": { - "shape": [ - "2*batch_size", - "max_len", - 1024 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 2, - "dtype": "f32" - } - }, - "runwayml/stable-diffusion-inpainting": { - "latents": { - "shape": [ - "1*batch_size", - 9, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "embedding": { - "shape": [ - "2*batch_size", - "max_len", - 768 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 2, - "dtype": "f32" - } - }, - "stabilityai/stable-diffusion-x4-upscaler": { - "latents": { - "shape": [ - "2*batch_size", - 7, - "8*height", - "8*width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "embedding": { - "shape": [ - "2*batch_size", - "max_len", - 1024 - ], - "dtype": "f32" - }, - "noise_level": { - "shape": [2], - "dtype": "i64" - } - }, - "stabilityai/sdxl-turbo": { - "latents": { - "shape": [ - "2*batch_size", - 4, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "prompt_embeds": { - "shape": [ - "2*batch_size", - "max_len", - 2048 - ], - "dtype": "f32" - }, - "text_embeds": { - "shape": [ - "2*batch_size", - 1280 - ], - "dtype": "f32" - }, - "time_ids": { - "shape": [ - "2*batch_size", - 6 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 1, - "dtype": "f32" - } - }, - "stabilityai/stable-diffusion-xl-base-1.0": { - "latents": { - "shape": [ - "2*batch_size", - 4, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "prompt_embeds": { - "shape": [ - "2*batch_size", - "max_len", - 2048 - ], - "dtype": "f32" - }, - "text_embeds": { - "shape": [ - "2*batch_size", - 1280 - ], - "dtype": "f32" - }, - "time_ids": { - "shape": [ - "2*batch_size", - 6 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 1, - "dtype": "f32" - } - } - }, - "stencil_adapter": { - "latents": { - "shape": [ - "1*batch_size", - 4, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "embedding": { - "shape": [ - "2*batch_size", - "max_len", - 768 - ], - "dtype": "f32" - }, - "controlnet_hint": { - "shape": [1, 3, "8*height", "8*width"], - "dtype": "f32" - }, - "acc1": { - "shape": [2, 320, "height", "width"], - "dtype": "f32" - }, - "acc2": { - "shape": [2, 320, "height", "width"], - "dtype": "f32" - }, - "acc3": { - "shape": [2, 320, "height", "width"], - "dtype": "f32" - }, - "acc4": { - "shape": [2, 320, "height/2", "width/2"], - "dtype": "f32" - }, - "acc5": { - "shape": [2, 640, "height/2", "width/2"], - "dtype": "f32" - }, - "acc6": { - "shape": [2, 640, "height/2", "width/2"], - "dtype": "f32" - }, - "acc7": { - "shape": [2, 640, "height/4", "width/4"], - "dtype": "f32" - }, - "acc8": { - "shape": [2, 1280, "height/4", "width/4"], - "dtype": "f32" - }, - "acc9": { - "shape": [2, 1280, "height/4", "width/4"], - "dtype": "f32" - }, - "acc10": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - }, - "acc11": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - }, - "acc12": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - }, - "acc13": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - } - }, - "stencil_unet": { - "CompVis/stable-diffusion-v1-4": { - "latents": { - "shape": [ - "1*batch_size", - 4, - "height", - "width" - ], - "dtype": "f32" - }, - "timesteps": { - "shape": [ - 1 - ], - "dtype": "f32" - }, - "embedding": { - "shape": [ - "2*batch_size", - "max_len", - 768 - ], - "dtype": "f32" - }, - "guidance_scale": { - "shape": 2, - "dtype": "f32" - }, - "control1": { - "shape": [2, 320, "height", "width"], - "dtype": "f32" - }, - "control2": { - "shape": [2, 320, "height", "width"], - "dtype": "f32" - }, - "control3": { - "shape": [2, 320, "height", "width"], - "dtype": "f32" - }, - "control4": { - "shape": [2, 320, "height/2", "width/2"], - "dtype": "f32" - }, - "control5": { - "shape": [2, 640, "height/2", "width/2"], - "dtype": "f32" - }, - "control6": { - "shape": [2, 640, "height/2", "width/2"], - "dtype": "f32" - }, - "control7": { - "shape": [2, 640, "height/4", "width/4"], - "dtype": "f32" - }, - "control8": { - "shape": [2, 1280, "height/4", "width/4"], - "dtype": "f32" - }, - "control9": { - "shape": [2, 1280, "height/4", "width/4"], - "dtype": "f32" - }, - "control10": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - }, - "control11": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - }, - "control12": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - }, - "control13": { - "shape": [2, 1280, "height/8", "width/8"], - "dtype": "f32" - }, - "scale1": { - "shape": 1, - "dtype": "f32" - }, - "scale2": { - "shape": 1, - "dtype": "f32" - }, - "scale3": { - "shape": 1, - "dtype": "f32" - }, - "scale4": { - "shape": 1, - "dtype": "f32" - }, - "scale5": { - "shape": 1, - "dtype": "f32" - }, - "scale6": { - "shape": 1, - "dtype": "f32" - }, - "scale7": { - "shape": 1, - "dtype": "f32" - }, - "scale8": { - "shape": 1, - "dtype": "f32" - }, - "scale9": { - "shape": 1, - "dtype": "f32" - }, - "scale10": { - "shape": 1, - "dtype": "f32" - }, - "scale11": { - "shape": 1, - "dtype": "f32" - }, - "scale12": { - "shape": 1, - "dtype": "f32" - }, - "scale13": { - "shape": 1, - "dtype": "f32" - } - } - } -} diff --git a/apps/stable_diffusion/src/utils/resources/model_config.json b/apps/stable_diffusion/src/utils/resources/model_config.json deleted file mode 100644 index ad9191e2..00000000 --- a/apps/stable_diffusion/src/utils/resources/model_config.json +++ /dev/null @@ -1,23 +0,0 @@ -[ - { - "stablediffusion/v1_4":"CompVis/stable-diffusion-v1-4", - "stablediffusion/v2_1base":"stabilityai/stable-diffusion-2-1-base", - "stablediffusion/v2_1":"stabilityai/stable-diffusion-2-1", - "stablediffusion/inpaint_v1":"runwayml/stable-diffusion-inpainting", - "stablediffusion/inpaint_v2":"stabilityai/stable-diffusion-2-inpainting", - "anythingv3/v1_4":"Linaqruf/anything-v3.0", - "analogdiffusion/v1_4":"wavymulder/Analog-Diffusion", - "openjourney/v1_4":"prompthero/openjourney", - "dreamlike/v1_4":"dreamlike-art/dreamlike-diffusion-1.0" - }, - { - "stablediffusion/fp16":"fp16", - "stablediffusion/fp32":"main", - "anythingv3/fp16":"diffusers", - "anythingv3/fp32":"diffusers", - "analogdiffusion/fp16":"main", - "analogdiffusion/fp32":"main", - "openjourney/fp16":"main", - "openjourney/fp32":"main" - } -] diff --git a/apps/stable_diffusion/src/utils/resources/model_db.json b/apps/stable_diffusion/src/utils/resources/model_db.json deleted file mode 100644 index 14cc24e1..00000000 --- a/apps/stable_diffusion/src/utils/resources/model_db.json +++ /dev/null @@ -1,19 +0,0 @@ -[ - { - "stablediffusion/untuned":"gs://shark_tank/nightly" - }, - { - "stablediffusion/v1_4/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan", - "stablediffusion/v1_4/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan", - "stablediffusion/v1_4/vae/fp16/length_64/untuned":"vae_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan", - "stablediffusion/v1_4/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-v1-4_vulkan", - "stablediffusion/v2_1base/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan", - "stablediffusion/v2_1base/unet/fp16/length_64/untuned":"unet_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan", - "stablediffusion/v2_1base/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan", - "stablediffusion/v2_1base/clip/fp32/length_77/untuned":"clip_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan", - "stablediffusion/v2_1base/clip/fp32/length_64/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan", - "stablediffusion/v2_1/unet/fp16/length_77/untuned":"unet_1_77_512_512_fp16_stable-diffusion-2-1-base_vulkan", - "stablediffusion/v2_1/vae/fp16/length_77/untuned":"vae_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan", - "stablediffusion/v2_1/clip/fp32/length_77/untuned":"clip_1_64_512_512_fp16_stable-diffusion-2-1-base_vulkan" - } -] diff --git a/apps/stable_diffusion/src/utils/resources/opt_flags.json b/apps/stable_diffusion/src/utils/resources/opt_flags.json deleted file mode 100644 index 544bda46..00000000 --- a/apps/stable_diffusion/src/utils/resources/opt_flags.json +++ /dev/null @@ -1,88 +0,0 @@ -{ - "unet": { - "tuned": { - "fp16": { - "default_compilation_flags": [] - }, - "fp32": { - "default_compilation_flags": [] - } - }, - "untuned": { - "fp16": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32}))" - ] - }, - "fp32": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16}))" - ] - } - } - }, - "vae": { - "tuned": { - "fp16": { - "default_compilation_flags": [], - "specified_compilation_flags": { - "cuda": [], - "default_device": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))" - ] - } - }, - "fp32": { - "default_compilation_flags": [], - "specified_compilation_flags": { - "cuda": [], - "default_device": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))" - ] - } - } - }, - "untuned": { - "fp16": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=32},iree-linalg-ext-convert-conv2d-to-winograd))" - ] - }, - "fp32": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-global-opt-detach-elementwise-from-named-ops,iree-global-opt-convert-1x1-filter-conv2d-to-matmul,iree-preprocessing-convert-conv2d-to-img2col,iree-preprocessing-pad-linalg-ops{pad-size=16},iree-linalg-ext-convert-conv2d-to-winograd))" - ] - } - } - }, - "clip": { - "tuned": { - "fp16": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))", - "--iree-opt-data-tiling=False" - ] - }, - "fp32": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))", - "--iree-opt-data-tiling=False" - ] - } - }, - "untuned": { - "fp16": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))", - "--iree-opt-data-tiling=False" - ] - }, - "fp32": { - "default_compilation_flags": [ - "--iree-preprocessing-pass-pipeline=builtin.module(func.func(iree-preprocessing-pad-linalg-ops{pad-size=16}))", - "--iree-opt-data-tiling=False" - ] - } - } - } -} diff --git a/apps/stable_diffusion/src/utils/resources/prompts.json b/apps/stable_diffusion/src/utils/resources/prompts.json deleted file mode 100644 index 7ecce99e..00000000 --- a/apps/stable_diffusion/src/utils/resources/prompts.json +++ /dev/null @@ -1,12 +0,0 @@ -[["A high tech solarpunk utopia in the Amazon rainforest"], -["Astrophotography, the shark nebula, nebula with a tiny shark-like cloud in the middle in the middle, hubble telescope, vivid colors"], -["A pikachu fine dining with a view to the Eiffel Tower"], -["A mecha robot in a favela in expressionist style"], -["an insect robot preparing a delicious meal"], -["A digital Illustration of the Babel tower, 4k, detailed, trending in artstation, fantasy vivid colors"], -["Cluttered house in the woods, anime, oil painting, high resolution, cottagecore, ghibli inspired, 4k"], -["A beautiful mansion beside a waterfall in the woods, by josef thoma, matte painting, trending on artstation HQ"], -["portrait photo of a asia old warrior chief, tribal panther make up, blue on red, side profile, looking away, serious eyes"], -["A photo of a beach, sunset, calm, beautiful landscape, waves, water"], -["(a large body of water with snowy mountains in the background), (fog, foggy, rolling fog), (clouds, cloudy, rolling clouds), dramatic sky and landscape, extraordinary landscape, (beautiful snow capped mountain background), (forest, dirt path)"], -["a photo taken of the front of a super-car drifting on a road near mountains at high speeds with smokes coming off the tires, front angle, front point of view, trees in the mountains of the background, ((sharp focus))"]] diff --git a/apps/stable_diffusion/src/utils/sd_annotation.py b/apps/stable_diffusion/src/utils/sd_annotation.py deleted file mode 100644 index 525f0c0c..00000000 --- a/apps/stable_diffusion/src/utils/sd_annotation.py +++ /dev/null @@ -1,300 +0,0 @@ -import os -import io -from shark.model_annotation import model_annotation, create_context -from shark.iree_utils._common import iree_target_map, run_cmd -from shark.shark_downloader import ( - download_model, - download_public_file, - WORKDIR, -) -from shark.parser import shark_args -from apps.stable_diffusion.src.utils.stable_args import args - - -def get_device(): - device = ( - args.device - if "://" not in args.device - else args.device.split("://")[0] - ) - return device - - -def get_device_args(): - device = get_device() - device_spec_args = [] - if device == "cuda": - from shark.iree_utils.gpu_utils import get_iree_gpu_args - - gpu_flags = get_iree_gpu_args() - for flag in gpu_flags: - device_spec_args.append(flag) - elif device == "vulkan": - device_spec_args.append( - f"--iree-vulkan-target-triple={args.iree_vulkan_target_triple} " - ) - return device, device_spec_args - - -# Download the model (Unet or VAE fp16) from shark_tank -def load_model_from_tank(): - from apps.stable_diffusion.src.models import ( - get_params, - get_variant_version, - ) - - variant, version = get_variant_version(args.hf_model_id) - - shark_args.local_tank_cache = args.local_tank_cache - bucket_key = f"{variant}/untuned" - if args.annotation_model == "unet": - model_key = f"{variant}/{version}/unet/{args.precision}/length_{args.max_length}/untuned" - elif args.annotation_model == "vae": - is_base = "/base" if args.use_base_vae else "" - model_key = f"{variant}/{version}/vae/{args.precision}/length_77/untuned{is_base}" - - bucket, model_name, iree_flags = get_params( - bucket_key, model_key, args.annotation_model, "untuned", args.precision - ) - mlir_model, func_name, inputs, golden_out = download_model( - model_name, - tank_url=bucket, - frontend="torch", - ) - return mlir_model, model_name - - -# Download the tuned config files from shark_tank -def load_winograd_configs(): - device = get_device() - config_bucket = "gs://shark_tank/sd_tuned/configs/" - config_name = f"{args.annotation_model}_winograd_{device}.json" - full_gs_url = config_bucket + config_name - if not os.path.exists(WORKDIR): - os.mkdir(WORKDIR) - winograd_config_dir = os.path.join(WORKDIR, "configs", config_name) - print("Loading Winograd config file from ", winograd_config_dir) - download_public_file(full_gs_url, winograd_config_dir, True) - return winograd_config_dir - - -def load_lower_configs(base_model_id=None): - from apps.stable_diffusion.src.models import get_variant_version - from apps.stable_diffusion.src.utils.utils import ( - fetch_and_update_base_model_id, - ) - - if not base_model_id: - if args.ckpt_loc != "": - base_model_id = fetch_and_update_base_model_id(args.ckpt_loc) - else: - base_model_id = fetch_and_update_base_model_id(args.hf_model_id) - if base_model_id == "": - base_model_id = args.hf_model_id - - variant, version = get_variant_version(base_model_id) - - if version == "inpaint_v1": - version = "v1_4" - elif version == "inpaint_v2": - version = "v2_1base" - - config_bucket = "gs://shark_tank/sd_tuned_configs/" - - device, device_spec_args = get_device_args() - spec = "" - if device_spec_args: - spec = device_spec_args[-1].split("=")[-1].strip() - if device == "vulkan": - spec = spec.split("-")[0] - - if args.annotation_model == "vae": - if not spec or spec in ["sm_80"]: - config_name = ( - f"{args.annotation_model}_{args.precision}_{device}.json" - ) - else: - config_name = f"{args.annotation_model}_{args.precision}_{device}_{spec}.json" - else: - if not spec or spec in ["sm_80"]: - if ( - version in ["v2_1", "v2_1base"] - and args.height == 768 - and args.width == 768 - ): - config_name = f"{args.annotation_model}_v2_1_768_{args.precision}_{device}.json" - else: - config_name = f"{args.annotation_model}_{version}_{args.precision}_{device}.json" - elif spec in ["rdna3"] and version in [ - "v2_1", - "v2_1base", - "v1_4", - "v1_5", - ]: - config_name = ( - f"{args.annotation_model}_" - f"{version}_" - f"{args.max_length}_" - f"{args.precision}_" - f"{device}_" - f"{spec}_" - f"{args.width}x{args.height}.json" - ) - elif spec in ["rdna2"] and version in ["v2_1", "v2_1base", "v1_4"]: - config_name = ( - f"{args.annotation_model}_" - f"{version}_" - f"{args.precision}_" - f"{device}_" - f"{spec}_" - f"{args.width}x{args.height}.json" - ) - else: - config_name = ( - f"{args.annotation_model}_" - f"{version}_" - f"{args.precision}_" - f"{device}_" - f"{spec}.json" - ) - - lowering_config_dir = os.path.join(WORKDIR, "configs", config_name) - print("Loading lowering config file from ", lowering_config_dir) - full_gs_url = config_bucket + config_name - download_public_file(full_gs_url, lowering_config_dir, True) - return lowering_config_dir - - -# Annotate the model with Winograd attribute on selected conv ops -def annotate_with_winograd(input_mlir, winograd_config_dir, model_name): - with create_context() as ctx: - winograd_model = model_annotation( - ctx, - input_contents=input_mlir, - config_path=winograd_config_dir, - search_op="conv", - winograd=True, - ) - - bytecode_stream = io.BytesIO() - winograd_model.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - if args.save_annotation: - if model_name.split("_")[-1] != "tuned": - out_file_path = os.path.join( - args.annotation_output, model_name + "_tuned_torch.mlir" - ) - else: - out_file_path = os.path.join( - args.annotation_output, model_name + "_torch.mlir" - ) - with open(out_file_path, "w") as f: - f.write(str(winograd_model)) - f.close() - - return bytecode - - -def dump_after_mlir(input_mlir, use_winograd): - import iree.compiler as ireec - - device, device_spec_args = get_device_args() - if use_winograd: - preprocess_flag = ( - "--iree-preprocessing-pass-pipeline=builtin.module" - "(func.func(iree-global-opt-detach-elementwise-from-named-ops," - "iree-global-opt-convert-1x1-filter-conv2d-to-matmul," - "iree-preprocessing-convert-conv2d-to-img2col," - "iree-preprocessing-pad-linalg-ops{pad-size=32}," - "iree-linalg-ext-convert-conv2d-to-winograd))" - ) - else: - preprocess_flag = ( - "--iree-preprocessing-pass-pipeline=builtin.module" - "(func.func(iree-global-opt-detach-elementwise-from-named-ops," - "iree-global-opt-convert-1x1-filter-conv2d-to-matmul," - "iree-preprocessing-convert-conv2d-to-img2col," - "iree-preprocessing-pad-linalg-ops{pad-size=32}))" - ) - - dump_module = ireec.compile_str( - input_mlir, - target_backends=[iree_target_map(device)], - extra_args=device_spec_args - + [ - preprocess_flag, - "--compile-to=preprocessing", - ], - ) - return dump_module - - -# For Unet annotate the model with tuned lowering configs -def annotate_with_lower_configs( - input_mlir, lowering_config_dir, model_name, use_winograd -): - # Dump IR after padding/img2col/winograd passes - dump_module = dump_after_mlir(input_mlir, use_winograd) - print("Applying tuned configs on", model_name) - - # Annotate the model with lowering configs in the config file - with create_context() as ctx: - tuned_model = model_annotation( - ctx, - input_contents=dump_module, - config_path=lowering_config_dir, - search_op="all", - ) - - bytecode_stream = io.BytesIO() - tuned_model.operation.write_bytecode(bytecode_stream) - bytecode = bytecode_stream.getvalue() - - if args.save_annotation: - if model_name.split("_")[-1] != "tuned": - out_file_path = ( - f"{args.annotation_output}/{model_name}_tuned_torch.mlir" - ) - else: - out_file_path = f"{args.annotation_output}/{model_name}_torch.mlir" - with open(out_file_path, "w") as f: - f.write(str(tuned_model)) - f.close() - - return bytecode - - -def sd_model_annotation(mlir_model, model_name, base_model_id=None): - device = get_device() - if args.annotation_model == "unet" and device == "vulkan": - use_winograd = True - winograd_config_dir = load_winograd_configs() - winograd_model = annotate_with_winograd( - mlir_model, winograd_config_dir, model_name - ) - lowering_config_dir = load_lower_configs(base_model_id) - tuned_model = annotate_with_lower_configs( - winograd_model, lowering_config_dir, model_name, use_winograd - ) - elif args.annotation_model == "vae" and device == "vulkan": - if "rdna2" not in args.iree_vulkan_target_triple.split("-")[0]: - use_winograd = True - winograd_config_dir = load_winograd_configs() - tuned_model = annotate_with_winograd( - mlir_model, winograd_config_dir, model_name - ) - else: - tuned_model = mlir_model - else: - use_winograd = False - lowering_config_dir = load_lower_configs(base_model_id) - tuned_model = annotate_with_lower_configs( - mlir_model, lowering_config_dir, model_name, use_winograd - ) - return tuned_model - - -if __name__ == "__main__": - mlir_model, model_name = load_model_from_tank() - sd_model_annotation(mlir_model, model_name) diff --git a/apps/stable_diffusion/src/utils/stable_args.py b/apps/stable_diffusion/src/utils/stable_args.py deleted file mode 100644 index 88434ff5..00000000 --- a/apps/stable_diffusion/src/utils/stable_args.py +++ /dev/null @@ -1,771 +0,0 @@ -import argparse -import os -from pathlib import Path - -from apps.stable_diffusion.src.utils.resamplers import resampler_list - - -def path_expand(s): - return Path(s).expanduser().resolve() - - -def is_valid_file(arg): - if not os.path.exists(arg): - return None - else: - return arg - - -p = argparse.ArgumentParser( - description=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter -) - -############################################################################## -# Stable Diffusion Params -############################################################################## - -p.add_argument( - "-a", - "--app", - default="txt2img", - help="Which app to use, one of: txt2img, img2img, outpaint, inpaint.", -) -p.add_argument( - "-p", - "--prompts", - nargs="+", - default=[ - "a photo taken of the front of a super-car drifting on a road near " - "mountains at high speeds with smokes coming off the tires, front " - "angle, front point of view, trees in the mountains of the " - "background, ((sharp focus))" - ], - help="Text of which images to be generated.", -) - -p.add_argument( - "--negative_prompts", - nargs="+", - default=[ - "watermark, signature, logo, text, lowres, ((monochrome, grayscale)), " - "blurry, ugly, blur, oversaturated, cropped" - ], - help="Text you don't want to see in the generated image.", -) - -p.add_argument( - "--img_path", - type=str, - help="Path to the image input for img2img/inpainting.", -) - -p.add_argument( - "--steps", - type=int, - default=50, - help="The number of steps to do the sampling.", -) - -p.add_argument( - "--seed", - type=str, - default=-1, - help="The seed or list of seeds to use. -1 for a random one.", -) - -p.add_argument( - "--batch_size", - type=int, - default=1, - choices=range(1, 4), - help="The number of inferences to be made in a single `batch_count`.", -) - -p.add_argument( - "--height", - type=int, - default=512, - choices=range(128, 1025, 8), - help="The height of the output image.", -) - -p.add_argument( - "--width", - type=int, - default=512, - choices=range(128, 1025, 8), - help="The width of the output image.", -) - -p.add_argument( - "--guidance_scale", - type=float, - default=7.5, - help="The value to be used for guidance scaling.", -) - -p.add_argument( - "--noise_level", - type=int, - default=20, - help="The value to be used for noise level of upscaler.", -) - -p.add_argument( - "--max_length", - type=int, - default=64, - help="Max length of the tokenizer output, options are 64 and 77.", -) - -p.add_argument( - "--max_embeddings_multiples", - type=int, - default=5, - help="The max multiple length of prompt embeddings compared to the max " - "output length of text encoder.", -) - -p.add_argument( - "--strength", - type=float, - default=0.8, - help="The strength of change applied on the given input image for " - "img2img.", -) - -p.add_argument( - "--use_hiresfix", - type=bool, - default=False, - help="Use Hires Fix to do higher resolution images, while trying to " - "avoid the issues that come with it. This is accomplished by first " - "generating an image using txt2img, then running it through img2img.", -) - -p.add_argument( - "--hiresfix_height", - type=int, - default=768, - choices=range(128, 769, 8), - help="The height of the Hires Fix image.", -) - -p.add_argument( - "--hiresfix_width", - type=int, - default=768, - choices=range(128, 769, 8), - help="The width of the Hires Fix image.", -) - -p.add_argument( - "--hiresfix_strength", - type=float, - default=0.6, - help="The denoising strength to apply for the Hires Fix.", -) - -p.add_argument( - "--resample_type", - type=str, - default="Nearest Neighbor", - choices=resampler_list, - help="The resample type to use when resizing an image before being run " - "through stable diffusion.", -) - -############################################################################## -# Stable Diffusion Training Params -############################################################################## - -p.add_argument( - "--lora_save_dir", - type=str, - default="models/lora/", - help="Directory to save the lora fine tuned model.", -) - -p.add_argument( - "--training_images_dir", - type=str, - default="models/lora/training_images/", - help="Directory containing images that are an example of the prompt.", -) - -p.add_argument( - "--training_steps", - type=int, - default=2000, - help="The number of steps to train.", -) - -############################################################################## -# Inpainting and Outpainting Params -############################################################################## - -p.add_argument( - "--mask_path", - type=str, - help="Path to the mask image input for inpainting.", -) - -p.add_argument( - "--inpaint_full_res", - default=False, - action=argparse.BooleanOptionalAction, - help="If inpaint only masked area or whole picture.", -) - -p.add_argument( - "--inpaint_full_res_padding", - type=int, - default=32, - choices=range(0, 257, 4), - help="Number of pixels for only masked padding.", -) - -p.add_argument( - "--pixels", - type=int, - default=128, - choices=range(8, 257, 8), - help="Number of expended pixels for one direction for outpainting.", -) - -p.add_argument( - "--mask_blur", - type=int, - default=8, - choices=range(0, 65), - help="Number of blur pixels for outpainting.", -) - -p.add_argument( - "--left", - default=False, - action=argparse.BooleanOptionalAction, - help="If extend left for outpainting.", -) - -p.add_argument( - "--right", - default=False, - action=argparse.BooleanOptionalAction, - help="If extend right for outpainting.", -) - -p.add_argument( - "--up", - "--top", - default=False, - action=argparse.BooleanOptionalAction, - help="If extend top for outpainting.", -) - -p.add_argument( - "--down", - "--bottom", - default=False, - action=argparse.BooleanOptionalAction, - help="If extend bottom for outpainting.", -) - -p.add_argument( - "--noise_q", - type=float, - default=1.0, - help="Fall-off exponent for outpainting (lower=higher detail) " - "(min=0.0, max=4.0).", -) - -p.add_argument( - "--color_variation", - type=float, - default=0.05, - help="Color variation for outpainting (min=0.0, max=1.0).", -) - -############################################################################## -# Model Config and Usage Params -############################################################################## - -p.add_argument( - "--device", type=str, default="vulkan", help="Device to run the model." -) - -p.add_argument( - "--precision", type=str, default="fp16", help="Precision to run the model." -) - -p.add_argument( - "--import_mlir", - default=True, - action=argparse.BooleanOptionalAction, - help="Imports the model from torch module to shark_module otherwise " - "downloads the model from shark_tank.", -) - -p.add_argument( - "--load_vmfb", - default=True, - action=argparse.BooleanOptionalAction, - help="Attempts to load the model from a precompiled flat-buffer " - "and compiles + saves it if not found.", -) - -p.add_argument( - "--save_vmfb", - default=False, - action=argparse.BooleanOptionalAction, - help="Saves the compiled flat-buffer to the local directory.", -) - -p.add_argument( - "--use_tuned", - default=False, - action=argparse.BooleanOptionalAction, - help="Download and use the tuned version of the model if available.", -) - -p.add_argument( - "--use_base_vae", - default=False, - action=argparse.BooleanOptionalAction, - help="Do conversion from the VAE output to pixel space on cpu.", -) - -p.add_argument( - "--scheduler", - type=str, - default="SharkEulerDiscrete", - help="Other supported schedulers are [DDIM, PNDM, LMSDiscrete, " - "DPMSolverMultistep, DPMSolverMultistep++, DPMSolverMultistepKarras, " - "DPMSolverMultistepKarras++, EulerDiscrete, EulerAncestralDiscrete, " - "DEISMultistep, KDPM2AncestralDiscrete, DPMSolverSinglestep, DDPM, " - "HeunDiscrete].", -) - -p.add_argument( - "--output_img_format", - type=str, - default="png", - help="Specify the format in which output image is save. " - "Supported options: jpg / png.", -) - -p.add_argument( - "--output_dir", - type=str, - default=None, - help="Directory path to save the output images and json.", -) - -p.add_argument( - "--batch_count", - type=int, - default=1, - help="Number of batches to be generated with random seeds in " - "single execution.", -) - -p.add_argument( - "--repeatable_seeds", - default=False, - action=argparse.BooleanOptionalAction, - help="The seed of the first batch will be used as the rng seed to " - "generate the subsequent seeds for subsequent batches in that run.", -) - -p.add_argument( - "--ckpt_loc", - type=str, - default="", - help="Path to SD's .ckpt file.", -) - -p.add_argument( - "--custom_vae", - type=str, - default="", - help="HuggingFace repo-id or path to SD model's checkpoint whose VAE " - "needs to be plugged in.", -) - -p.add_argument( - "--hf_model_id", - type=str, - default="stabilityai/stable-diffusion-2-1-base", - help="The repo-id of hugging face.", -) - -p.add_argument( - "--low_cpu_mem_usage", - default=False, - action=argparse.BooleanOptionalAction, - help="Use the accelerate package to reduce cpu memory consumption.", -) - -p.add_argument( - "--attention_slicing", - type=str, - default="none", - help="Amount of attention slicing to use (one of 'max', 'auto', 'none', " - "or an integer).", -) - -p.add_argument( - "--use_stencil", - choices=["canny", "openpose", "scribble", "zoedepth"], - help="Enable the stencil feature.", -) - -p.add_argument( - "--control_mode", - choices=["Prompt", "Balanced", "Controlnet"], - default="Balanced", - help="How Controlnet injection should be prioritized.", -) - -p.add_argument( - "--use_lora", - type=str, - default="", - help="Use standalone LoRA weight using a HF ID or a checkpoint " - "file (~3 MB).", -) - -p.add_argument( - "--use_quantize", - type=str, - default="none", - help="Runs the quantized version of stable diffusion model. " - "This is currently in experimental phase. " - "Currently, only runs the stable-diffusion-2-1-base model in " - "int8 quantization.", -) - -p.add_argument( - "--ondemand", - default=False, - action=argparse.BooleanOptionalAction, - help="Load and unload models for low VRAM.", -) - -p.add_argument( - "--hf_auth_token", - type=str, - default=None, - help="Specify your own huggingface authentication tokens for models like Llama2.", -) - -p.add_argument( - "--device_allocator_heap_key", - type=str, - default="", - help="Specify heap key for device caching allocator." - "Expected form: max_allocation_size;max_allocation_capacity;max_free_allocation_count" - "Example: --device_allocator_heap_key='*;1gib' (will limit caching on device to 1 gigabyte)", -) - -p.add_argument( - "--autogen", - type=bool, - default="False", - help="Only used for a gradio workaround.", -) -############################################################################## -# IREE - Vulkan supported flags -############################################################################## - -p.add_argument( - "--iree_vulkan_target_triple", - type=str, - default="", - help="Specify target triple for vulkan.", -) - -p.add_argument( - "--iree_metal_target_platform", - type=str, - default="", - help="Specify target triple for metal.", -) - -############################################################################## -# Misc. Debug and Optimization flags -############################################################################## - -p.add_argument( - "--use_compiled_scheduler", - default=True, - action=argparse.BooleanOptionalAction, - help="Use the default scheduler precompiled into the model if available.", -) - -p.add_argument( - "--local_tank_cache", - default="", - help="Specify where to save downloaded shark_tank artifacts. " - "If this is not set, the default is ~/.local/shark_tank/.", -) - -p.add_argument( - "--dump_isa", - default=False, - action="store_true", - help="When enabled call amdllpc to get ISA dumps. " - "Use with dispatch benchmarks.", -) - -p.add_argument( - "--dispatch_benchmarks", - default=None, - help="Dispatches to return benchmark data on. " - 'Use "All" for all, and None for none.', -) - -p.add_argument( - "--dispatch_benchmarks_dir", - default="temp_dispatch_benchmarks", - help="Directory where you want to store dispatch data " - 'generated with "--dispatch_benchmarks".', -) - -p.add_argument( - "--enable_rgp", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag for inserting debug frames between iterations " - "for use with rgp.", -) - -p.add_argument( - "--hide_steps", - default=True, - action=argparse.BooleanOptionalAction, - help="Flag for hiding the details of iteration/sec for each step.", -) - -p.add_argument( - "--warmup_count", - type=int, - default=0, - help="Flag setting warmup count for CLIP and VAE [>= 0].", -) - -p.add_argument( - "--clear_all", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag to clear all mlir and vmfb from common locations. " - "Recompiling will take several minutes.", -) - -p.add_argument( - "--save_metadata_to_json", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag for whether or not to save a generation information " - "json file with the image.", -) - -p.add_argument( - "--write_metadata_to_png", - default=True, - action=argparse.BooleanOptionalAction, - help="Flag for whether or not to save generation information in " - "PNG chunk text to generated images.", -) - -p.add_argument( - "--import_debug", - default=False, - action=argparse.BooleanOptionalAction, - help="If import_mlir is True, saves mlir via the debug option " - "in shark importer. Does nothing if import_mlir is false (the default).", -) - -p.add_argument( - "--compile_debug", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag to toggle debug assert/verify flags for imported IR in the" - "iree-compiler. Default to false.", -) - -p.add_argument( - "--iree_constant_folding", - default=True, - action=argparse.BooleanOptionalAction, - help="Controls constant folding in iree-compile for all SD models.", -) - -p.add_argument( - "--data_tiling", - default=False, - action=argparse.BooleanOptionalAction, - help="Controls data tiling in iree-compile for all SD models.", -) - -############################################################################## -# Web UI flags -############################################################################## - -p.add_argument( - "--progress_bar", - default=True, - action=argparse.BooleanOptionalAction, - help="Flag for removing the progress bar animation during " - "image generation.", -) - -p.add_argument( - "--ckpt_dir", - type=str, - default="", - help="Path to directory where all .ckpts are stored in order to populate " - "them in the web UI.", -) -# TODO: replace API flag when these can be run together -p.add_argument( - "--ui", - type=str, - default="app" if os.name == "nt" else "web", - help="One of: [api, app, web].", -) - -p.add_argument( - "--share", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag for generating a public URL.", -) - -p.add_argument( - "--server_port", - type=int, - default=8080, - help="Flag for setting server port.", -) - -p.add_argument( - "--api", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag for enabling rest API.", -) - -p.add_argument( - "--api_accept_origin", - action="append", - type=str, - help="An origin to be accepted by the REST api for Cross Origin" - "Resource Sharing (CORS). Use multiple times for multiple origins, " - 'or use --api_accept_origin="*" to accept all origins. If no origins ' - "are set no CORS headers will be returned by the api. Use, for " - "instance, if you need to access the REST api from Javascript running " - "in a web browser.", -) - -p.add_argument( - "--debug", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag for enabling debugging log in WebUI.", -) - -p.add_argument( - "--output_gallery", - default=True, - action=argparse.BooleanOptionalAction, - help="Flag for removing the output gallery tab, and avoid exposing " - "images under --output_dir in the UI.", -) - -p.add_argument( - "--output_gallery_followlinks", - default=False, - action=argparse.BooleanOptionalAction, - help="Flag for whether the output gallery tab in the UI should " - "follow symlinks when listing subdirectories under --output_dir.", -) - - -############################################################################## -# SD model auto-annotation flags -############################################################################## - -p.add_argument( - "--annotation_output", - type=path_expand, - default="./", - help="Directory to save the annotated mlir file.", -) - -p.add_argument( - "--annotation_model", - type=str, - default="unet", - help="Options are unet and vae.", -) - -p.add_argument( - "--save_annotation", - default=False, - action=argparse.BooleanOptionalAction, - help="Save annotated mlir file.", -) -############################################################################## -# SD model auto-tuner flags -############################################################################## - -p.add_argument( - "--tuned_config_dir", - type=path_expand, - default="./", - help="Directory to save the tuned config file.", -) - -p.add_argument( - "--num_iters", - type=int, - default=400, - help="Number of iterations for tuning.", -) - -p.add_argument( - "--search_op", - type=str, - default="all", - help="Op to be optimized, options are matmul, bmm, conv and all.", -) - -############################################################################## -# DocuChat Flags -############################################################################## - -p.add_argument( - "--run_docuchat_web", - default=False, - action=argparse.BooleanOptionalAction, - help="Specifies whether the docuchat's web version is running or not.", -) - -############################################################################## -# rocm Flags -############################################################################## - -p.add_argument( - "--iree_rocm_target_chip", - type=str, - default="", - help="Add the rocm device architecture ex gfx1100, gfx90a, etc. Use `hipinfo` " - "or `iree-run-module --dump_devices=rocm` or `hipinfo` to get desired arch name", -) - -args, unknown = p.parse_known_args() -if args.import_debug: - os.environ["IREE_SAVE_TEMPS"] = os.path.join( - os.getcwd(), args.hf_model_id.replace("/", "_") - ) diff --git a/apps/stable_diffusion/src/utils/stencils/__init__.py b/apps/stable_diffusion/src/utils/stencils/__init__.py deleted file mode 100644 index aa7383b4..00000000 --- a/apps/stable_diffusion/src/utils/stencils/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from apps.stable_diffusion.src.utils.stencils.canny import CannyDetector -from apps.stable_diffusion.src.utils.stencils.openpose import OpenposeDetector -from apps.stable_diffusion.src.utils.stencils.zoe import ZoeDetector diff --git a/apps/stable_diffusion/src/utils/stencils/canny/__init__.py b/apps/stable_diffusion/src/utils/stencils/canny/__init__.py deleted file mode 100644 index cb0da951..00000000 --- a/apps/stable_diffusion/src/utils/stencils/canny/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -import cv2 - - -class CannyDetector: - def __call__(self, img, low_threshold, high_threshold): - return cv2.Canny(img, low_threshold, high_threshold) diff --git a/apps/stable_diffusion/src/utils/stencils/openpose/__init__.py b/apps/stable_diffusion/src/utils/stencils/openpose/__init__.py deleted file mode 100644 index f516ec22..00000000 --- a/apps/stable_diffusion/src/utils/stencils/openpose/__init__.py +++ /dev/null @@ -1,62 +0,0 @@ -import requests -from pathlib import Path - -import torch -import numpy as np - -# from annotator.util import annotator_ckpts_path -from apps.stable_diffusion.src.utils.stencils.openpose.body import Body -from apps.stable_diffusion.src.utils.stencils.openpose.hand import Hand -from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import ( - draw_bodypose, - draw_handpose, - handDetect, -) - - -body_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/body_pose_model.pth" -hand_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/hand_pose_model.pth" - - -class OpenposeDetector: - def __init__(self): - cwd = Path.cwd() - ckpt_path = Path(cwd, "stencil_annotator") - ckpt_path.mkdir(parents=True, exist_ok=True) - body_modelpath = ckpt_path / "body_pose_model.pth" - hand_modelpath = ckpt_path / "hand_pose_model.pth" - - if not body_modelpath.is_file(): - r = requests.get(body_model_path, allow_redirects=True) - open(body_modelpath, "wb").write(r.content) - if not hand_modelpath.is_file(): - r = requests.get(hand_model_path, allow_redirects=True) - open(hand_modelpath, "wb").write(r.content) - - self.body_estimation = Body(body_modelpath) - self.hand_estimation = Hand(hand_modelpath) - - def __call__(self, oriImg, hand=False): - oriImg = oriImg[:, :, ::-1].copy() - with torch.no_grad(): - candidate, subset = self.body_estimation(oriImg) - canvas = np.zeros_like(oriImg) - canvas = draw_bodypose(canvas, candidate, subset) - if hand: - hands_list = handDetect(candidate, subset, oriImg) - all_hand_peaks = [] - for x, y, w, is_left in hands_list: - peaks = self.hand_estimation( - oriImg[y : y + w, x : x + w, :] - ) - peaks[:, 0] = np.where( - peaks[:, 0] == 0, peaks[:, 0], peaks[:, 0] + x - ) - peaks[:, 1] = np.where( - peaks[:, 1] == 0, peaks[:, 1], peaks[:, 1] + y - ) - all_hand_peaks.append(peaks) - canvas = draw_handpose(canvas, all_hand_peaks) - return canvas, dict( - candidate=candidate.tolist(), subset=subset.tolist() - ) diff --git a/apps/stable_diffusion/src/utils/stencils/openpose/body.py b/apps/stable_diffusion/src/utils/stencils/openpose/body.py deleted file mode 100644 index 40839750..00000000 --- a/apps/stable_diffusion/src/utils/stencils/openpose/body.py +++ /dev/null @@ -1,499 +0,0 @@ -import cv2 -import numpy as np -import math -from scipy.ndimage.filters import gaussian_filter -import torch -import torch.nn as nn -from collections import OrderedDict -from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import ( - make_layers, - transfer, - padRightDownCorner, -) - - -class BodyPoseModel(nn.Module): - def __init__(self): - super(BodyPoseModel, self).__init__() - - # these layers have no relu layer - no_relu_layers = [ - "conv5_5_CPM_L1", - "conv5_5_CPM_L2", - "Mconv7_stage2_L1", - "Mconv7_stage2_L2", - "Mconv7_stage3_L1", - "Mconv7_stage3_L2", - "Mconv7_stage4_L1", - "Mconv7_stage4_L2", - "Mconv7_stage5_L1", - "Mconv7_stage5_L2", - "Mconv7_stage6_L1", - "Mconv7_stage6_L1", - ] - blocks = {} - block0 = OrderedDict( - [ - ("conv1_1", [3, 64, 3, 1, 1]), - ("conv1_2", [64, 64, 3, 1, 1]), - ("pool1_stage1", [2, 2, 0]), - ("conv2_1", [64, 128, 3, 1, 1]), - ("conv2_2", [128, 128, 3, 1, 1]), - ("pool2_stage1", [2, 2, 0]), - ("conv3_1", [128, 256, 3, 1, 1]), - ("conv3_2", [256, 256, 3, 1, 1]), - ("conv3_3", [256, 256, 3, 1, 1]), - ("conv3_4", [256, 256, 3, 1, 1]), - ("pool3_stage1", [2, 2, 0]), - ("conv4_1", [256, 512, 3, 1, 1]), - ("conv4_2", [512, 512, 3, 1, 1]), - ("conv4_3_CPM", [512, 256, 3, 1, 1]), - ("conv4_4_CPM", [256, 128, 3, 1, 1]), - ] - ) - - # Stage 1 - block1_1 = OrderedDict( - [ - ("conv5_1_CPM_L1", [128, 128, 3, 1, 1]), - ("conv5_2_CPM_L1", [128, 128, 3, 1, 1]), - ("conv5_3_CPM_L1", [128, 128, 3, 1, 1]), - ("conv5_4_CPM_L1", [128, 512, 1, 1, 0]), - ("conv5_5_CPM_L1", [512, 38, 1, 1, 0]), - ] - ) - - block1_2 = OrderedDict( - [ - ("conv5_1_CPM_L2", [128, 128, 3, 1, 1]), - ("conv5_2_CPM_L2", [128, 128, 3, 1, 1]), - ("conv5_3_CPM_L2", [128, 128, 3, 1, 1]), - ("conv5_4_CPM_L2", [128, 512, 1, 1, 0]), - ("conv5_5_CPM_L2", [512, 19, 1, 1, 0]), - ] - ) - blocks["block1_1"] = block1_1 - blocks["block1_2"] = block1_2 - - self.model0 = make_layers(block0, no_relu_layers) - - # Stages 2 - 6 - for i in range(2, 7): - blocks["block%d_1" % i] = OrderedDict( - [ - ("Mconv1_stage%d_L1" % i, [185, 128, 7, 1, 3]), - ("Mconv2_stage%d_L1" % i, [128, 128, 7, 1, 3]), - ("Mconv3_stage%d_L1" % i, [128, 128, 7, 1, 3]), - ("Mconv4_stage%d_L1" % i, [128, 128, 7, 1, 3]), - ("Mconv5_stage%d_L1" % i, [128, 128, 7, 1, 3]), - ("Mconv6_stage%d_L1" % i, [128, 128, 1, 1, 0]), - ("Mconv7_stage%d_L1" % i, [128, 38, 1, 1, 0]), - ] - ) - - blocks["block%d_2" % i] = OrderedDict( - [ - ("Mconv1_stage%d_L2" % i, [185, 128, 7, 1, 3]), - ("Mconv2_stage%d_L2" % i, [128, 128, 7, 1, 3]), - ("Mconv3_stage%d_L2" % i, [128, 128, 7, 1, 3]), - ("Mconv4_stage%d_L2" % i, [128, 128, 7, 1, 3]), - ("Mconv5_stage%d_L2" % i, [128, 128, 7, 1, 3]), - ("Mconv6_stage%d_L2" % i, [128, 128, 1, 1, 0]), - ("Mconv7_stage%d_L2" % i, [128, 19, 1, 1, 0]), - ] - ) - - for k in blocks.keys(): - blocks[k] = make_layers(blocks[k], no_relu_layers) - - self.model1_1 = blocks["block1_1"] - self.model2_1 = blocks["block2_1"] - self.model3_1 = blocks["block3_1"] - self.model4_1 = blocks["block4_1"] - self.model5_1 = blocks["block5_1"] - self.model6_1 = blocks["block6_1"] - - self.model1_2 = blocks["block1_2"] - self.model2_2 = blocks["block2_2"] - self.model3_2 = blocks["block3_2"] - self.model4_2 = blocks["block4_2"] - self.model5_2 = blocks["block5_2"] - self.model6_2 = blocks["block6_2"] - - def forward(self, x): - out1 = self.model0(x) - - out1_1 = self.model1_1(out1) - out1_2 = self.model1_2(out1) - out2 = torch.cat([out1_1, out1_2, out1], 1) - - out2_1 = self.model2_1(out2) - out2_2 = self.model2_2(out2) - out3 = torch.cat([out2_1, out2_2, out1], 1) - - out3_1 = self.model3_1(out3) - out3_2 = self.model3_2(out3) - out4 = torch.cat([out3_1, out3_2, out1], 1) - - out4_1 = self.model4_1(out4) - out4_2 = self.model4_2(out4) - out5 = torch.cat([out4_1, out4_2, out1], 1) - - out5_1 = self.model5_1(out5) - out5_2 = self.model5_2(out5) - out6 = torch.cat([out5_1, out5_2, out1], 1) - - out6_1 = self.model6_1(out6) - out6_2 = self.model6_2(out6) - - return out6_1, out6_2 - - -class Body(object): - def __init__(self, model_path): - self.model = BodyPoseModel() - if torch.cuda.is_available(): - self.model = self.model.cuda() - model_dict = transfer(self.model, torch.load(model_path)) - self.model.load_state_dict(model_dict) - self.model.eval() - - def __call__(self, oriImg): - scale_search = [0.5] - boxsize = 368 - stride = 8 - padValue = 128 - thre1 = 0.1 - thre2 = 0.05 - multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] - heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) - paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) - - for m in range(len(multiplier)): - scale = multiplier[m] - imageToTest = cv2.resize( - oriImg, - (0, 0), - fx=scale, - fy=scale, - interpolation=cv2.INTER_CUBIC, - ) - imageToTest_padded, pad = padRightDownCorner( - imageToTest, stride, padValue - ) - im = ( - np.transpose( - np.float32(imageToTest_padded[:, :, :, np.newaxis]), - (3, 2, 0, 1), - ) - / 256 - - 0.5 - ) - im = np.ascontiguousarray(im) - - data = torch.from_numpy(im).float() - if torch.cuda.is_available(): - data = data.cuda() - with torch.no_grad(): - Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) - Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() - Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() - - # extract outputs, resize, and remove padding - heatmap = np.transpose( - np.squeeze(Mconv7_stage6_L2), (1, 2, 0) - ) # output 1 is heatmaps - heatmap = cv2.resize( - heatmap, - (0, 0), - fx=stride, - fy=stride, - interpolation=cv2.INTER_CUBIC, - ) - heatmap = heatmap[ - : imageToTest_padded.shape[0] - pad[2], - : imageToTest_padded.shape[1] - pad[3], - :, - ] - heatmap = cv2.resize( - heatmap, - (oriImg.shape[1], oriImg.shape[0]), - interpolation=cv2.INTER_CUBIC, - ) - - # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs - paf = np.transpose( - np.squeeze(Mconv7_stage6_L1), (1, 2, 0) - ) # output 0 is PAFs - paf = cv2.resize( - paf, - (0, 0), - fx=stride, - fy=stride, - interpolation=cv2.INTER_CUBIC, - ) - paf = paf[ - : imageToTest_padded.shape[0] - pad[2], - : imageToTest_padded.shape[1] - pad[3], - :, - ] - paf = cv2.resize( - paf, - (oriImg.shape[1], oriImg.shape[0]), - interpolation=cv2.INTER_CUBIC, - ) - - heatmap_avg += heatmap_avg + heatmap / len(multiplier) - paf_avg += +paf / len(multiplier) - - all_peaks = [] - peak_counter = 0 - - for part in range(18): - map_ori = heatmap_avg[:, :, part] - one_heatmap = gaussian_filter(map_ori, sigma=3) - - map_left = np.zeros(one_heatmap.shape) - map_left[1:, :] = one_heatmap[:-1, :] - map_right = np.zeros(one_heatmap.shape) - map_right[:-1, :] = one_heatmap[1:, :] - map_up = np.zeros(one_heatmap.shape) - map_up[:, 1:] = one_heatmap[:, :-1] - map_down = np.zeros(one_heatmap.shape) - map_down[:, :-1] = one_heatmap[:, 1:] - - peaks_binary = np.logical_and.reduce( - ( - one_heatmap >= map_left, - one_heatmap >= map_right, - one_heatmap >= map_up, - one_heatmap >= map_down, - one_heatmap > thre1, - ) - ) - peaks = list( - zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0]) - ) # note reverse - peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] - peak_id = range(peak_counter, peak_counter + len(peaks)) - peaks_with_score_and_id = [ - peaks_with_score[i] + (peak_id[i],) - for i in range(len(peak_id)) - ] - - all_peaks.append(peaks_with_score_and_id) - peak_counter += len(peaks) - - # find connection in the specified sequence, center 29 is in the position 15 - limbSeq = [ - [2, 3], - [2, 6], - [3, 4], - [4, 5], - [6, 7], - [7, 8], - [2, 9], - [9, 10], - [10, 11], - [2, 12], - [12, 13], - [13, 14], - [2, 1], - [1, 15], - [15, 17], - [1, 16], - [16, 18], - [3, 17], - [6, 18], - ] - # the middle joints heatmap correpondence - mapIdx = [ - [31, 32], - [39, 40], - [33, 34], - [35, 36], - [41, 42], - [43, 44], - [19, 20], - [21, 22], - [23, 24], - [25, 26], - [27, 28], - [29, 30], - [47, 48], - [49, 50], - [53, 54], - [51, 52], - [55, 56], - [37, 38], - [45, 46], - ] - - connection_all = [] - special_k = [] - mid_num = 10 - - for k in range(len(mapIdx)): - score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] - candA = all_peaks[limbSeq[k][0] - 1] - candB = all_peaks[limbSeq[k][1] - 1] - nA = len(candA) - nB = len(candB) - indexA, indexB = limbSeq[k] - if nA != 0 and nB != 0: - connection_candidate = [] - for i in range(nA): - for j in range(nB): - vec = np.subtract(candB[j][:2], candA[i][:2]) - norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) - norm = max(0.001, norm) - vec = np.divide(vec, norm) - - startend = list( - zip( - np.linspace( - candA[i][0], candB[j][0], num=mid_num - ), - np.linspace( - candA[i][1], candB[j][1], num=mid_num - ), - ) - ) - - vec_x = np.array( - [ - score_mid[ - int(round(startend[I][1])), - int(round(startend[I][0])), - 0, - ] - for I in range(len(startend)) - ] - ) - vec_y = np.array( - [ - score_mid[ - int(round(startend[I][1])), - int(round(startend[I][0])), - 1, - ] - for I in range(len(startend)) - ] - ) - - score_midpts = np.multiply( - vec_x, vec[0] - ) + np.multiply(vec_y, vec[1]) - score_with_dist_prior = sum(score_midpts) / len( - score_midpts - ) + min(0.5 * oriImg.shape[0] / norm - 1, 0) - criterion1 = len( - np.nonzero(score_midpts > thre2)[0] - ) > 0.8 * len(score_midpts) - criterion2 = score_with_dist_prior > 0 - if criterion1 and criterion2: - connection_candidate.append( - [ - i, - j, - score_with_dist_prior, - score_with_dist_prior - + candA[i][2] - + candB[j][2], - ] - ) - - connection_candidate = sorted( - connection_candidate, key=lambda x: x[2], reverse=True - ) - connection = np.zeros((0, 5)) - for c in range(len(connection_candidate)): - i, j, s = connection_candidate[c][0:3] - if i not in connection[:, 3] and j not in connection[:, 4]: - connection = np.vstack( - [connection, [candA[i][3], candB[j][3], s, i, j]] - ) - if len(connection) >= min(nA, nB): - break - - connection_all.append(connection) - else: - special_k.append(k) - connection_all.append([]) - - # last number in each row is the total parts number of that person - # the second last number in each row is the score of the overall configuration - subset = -1 * np.ones((0, 20)) - candidate = np.array( - [item for sublist in all_peaks for item in sublist] - ) - - for k in range(len(mapIdx)): - if k not in special_k: - partAs = connection_all[k][:, 0] - partBs = connection_all[k][:, 1] - indexA, indexB = np.array(limbSeq[k]) - 1 - - for i in range(len(connection_all[k])): # = 1:size(temp,1) - found = 0 - subset_idx = [-1, -1] - for j in range(len(subset)): # 1:size(subset,1): - if ( - subset[j][indexA] == partAs[i] - or subset[j][indexB] == partBs[i] - ): - subset_idx[found] = j - found += 1 - - if found == 1: - j = subset_idx[0] - if subset[j][indexB] != partBs[i]: - subset[j][indexB] = partBs[i] - subset[j][-1] += 1 - subset[j][-2] += ( - candidate[partBs[i].astype(int), 2] - + connection_all[k][i][2] - ) - elif found == 2: # if found 2 and disjoint, merge them - j1, j2 = subset_idx - membership = ( - (subset[j1] >= 0).astype(int) - + (subset[j2] >= 0).astype(int) - )[:-2] - if len(np.nonzero(membership == 2)[0]) == 0: # merge - subset[j1][:-2] += subset[j2][:-2] + 1 - subset[j1][-2:] += subset[j2][-2:] - subset[j1][-2] += connection_all[k][i][2] - subset = np.delete(subset, j2, 0) - else: # as like found == 1 - subset[j1][indexB] = partBs[i] - subset[j1][-1] += 1 - subset[j1][-2] += ( - candidate[partBs[i].astype(int), 2] - + connection_all[k][i][2] - ) - - # if find no partA in the subset, create a new subset - elif not found and k < 17: - row = -1 * np.ones(20) - row[indexA] = partAs[i] - row[indexB] = partBs[i] - row[-1] = 2 - row[-2] = ( - sum( - candidate[ - connection_all[k][i, :2].astype(int), 2 - ] - ) - + connection_all[k][i][2] - ) - subset = np.vstack([subset, row]) - # delete some rows of subset which has few parts occur - deleteIdx = [] - for i in range(len(subset)): - if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: - deleteIdx.append(i) - subset = np.delete(subset, deleteIdx, axis=0) - - # candidate: x, y, score, id - return candidate, subset diff --git a/apps/stable_diffusion/src/utils/stencils/openpose/hand.py b/apps/stable_diffusion/src/utils/stencils/openpose/hand.py deleted file mode 100644 index e48a3fa0..00000000 --- a/apps/stable_diffusion/src/utils/stencils/openpose/hand.py +++ /dev/null @@ -1,205 +0,0 @@ -import cv2 -import numpy as np -from scipy.ndimage.filters import gaussian_filter -import torch -import torch.nn as nn -from skimage.measure import label -from collections import OrderedDict -from apps.stable_diffusion.src.utils.stencils.openpose.openpose_util import ( - make_layers, - transfer, - padRightDownCorner, - npmax, -) - - -class HandPoseModel(nn.Module): - def __init__(self): - super(HandPoseModel, self).__init__() - - # these layers have no relu layer - no_relu_layers = [ - "conv6_2_CPM", - "Mconv7_stage2", - "Mconv7_stage3", - "Mconv7_stage4", - "Mconv7_stage5", - "Mconv7_stage6", - ] - # stage 1 - block1_0 = OrderedDict( - [ - ("conv1_1", [3, 64, 3, 1, 1]), - ("conv1_2", [64, 64, 3, 1, 1]), - ("pool1_stage1", [2, 2, 0]), - ("conv2_1", [64, 128, 3, 1, 1]), - ("conv2_2", [128, 128, 3, 1, 1]), - ("pool2_stage1", [2, 2, 0]), - ("conv3_1", [128, 256, 3, 1, 1]), - ("conv3_2", [256, 256, 3, 1, 1]), - ("conv3_3", [256, 256, 3, 1, 1]), - ("conv3_4", [256, 256, 3, 1, 1]), - ("pool3_stage1", [2, 2, 0]), - ("conv4_1", [256, 512, 3, 1, 1]), - ("conv4_2", [512, 512, 3, 1, 1]), - ("conv4_3", [512, 512, 3, 1, 1]), - ("conv4_4", [512, 512, 3, 1, 1]), - ("conv5_1", [512, 512, 3, 1, 1]), - ("conv5_2", [512, 512, 3, 1, 1]), - ("conv5_3_CPM", [512, 128, 3, 1, 1]), - ] - ) - - block1_1 = OrderedDict( - [ - ("conv6_1_CPM", [128, 512, 1, 1, 0]), - ("conv6_2_CPM", [512, 22, 1, 1, 0]), - ] - ) - - blocks = {} - blocks["block1_0"] = block1_0 - blocks["block1_1"] = block1_1 - - # stage 2-6 - for i in range(2, 7): - blocks["block%d" % i] = OrderedDict( - [ - ("Mconv1_stage%d" % i, [150, 128, 7, 1, 3]), - ("Mconv2_stage%d" % i, [128, 128, 7, 1, 3]), - ("Mconv3_stage%d" % i, [128, 128, 7, 1, 3]), - ("Mconv4_stage%d" % i, [128, 128, 7, 1, 3]), - ("Mconv5_stage%d" % i, [128, 128, 7, 1, 3]), - ("Mconv6_stage%d" % i, [128, 128, 1, 1, 0]), - ("Mconv7_stage%d" % i, [128, 22, 1, 1, 0]), - ] - ) - - for k in blocks.keys(): - blocks[k] = make_layers(blocks[k], no_relu_layers) - - self.model1_0 = blocks["block1_0"] - self.model1_1 = blocks["block1_1"] - self.model2 = blocks["block2"] - self.model3 = blocks["block3"] - self.model4 = blocks["block4"] - self.model5 = blocks["block5"] - self.model6 = blocks["block6"] - - def forward(self, x): - out1_0 = self.model1_0(x) - out1_1 = self.model1_1(out1_0) - concat_stage2 = torch.cat([out1_1, out1_0], 1) - out_stage2 = self.model2(concat_stage2) - concat_stage3 = torch.cat([out_stage2, out1_0], 1) - out_stage3 = self.model3(concat_stage3) - concat_stage4 = torch.cat([out_stage3, out1_0], 1) - out_stage4 = self.model4(concat_stage4) - concat_stage5 = torch.cat([out_stage4, out1_0], 1) - out_stage5 = self.model5(concat_stage5) - concat_stage6 = torch.cat([out_stage5, out1_0], 1) - out_stage6 = self.model6(concat_stage6) - return out_stage6 - - -class Hand(object): - def __init__(self, model_path): - self.model = HandPoseModel() - if torch.cuda.is_available(): - self.model = self.model.cuda() - model_dict = transfer(self.model, torch.load(model_path)) - self.model.load_state_dict(model_dict) - self.model.eval() - - def __call__(self, oriImg): - scale_search = [0.5, 1.0, 1.5, 2.0] - # scale_search = [0.5] - boxsize = 368 - stride = 8 - padValue = 128 - thre = 0.05 - multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] - heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 22)) - # paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) - - for m in range(len(multiplier)): - scale = multiplier[m] - imageToTest = cv2.resize( - oriImg, - (0, 0), - fx=scale, - fy=scale, - interpolation=cv2.INTER_CUBIC, - ) - imageToTest_padded, pad = padRightDownCorner( - imageToTest, stride, padValue - ) - im = ( - np.transpose( - np.float32(imageToTest_padded[:, :, :, np.newaxis]), - (3, 2, 0, 1), - ) - / 256 - - 0.5 - ) - im = np.ascontiguousarray(im) - - data = torch.from_numpy(im).float() - if torch.cuda.is_available(): - data = data.cuda() - # data = data.permute([2, 0, 1]).unsqueeze(0).float() - with torch.no_grad(): - output = self.model(data).cpu().numpy() - # output = self.model(data).numpy()q - - # extract outputs, resize, and remove padding - heatmap = np.transpose( - np.squeeze(output), (1, 2, 0) - ) # output 1 is heatmaps - heatmap = cv2.resize( - heatmap, - (0, 0), - fx=stride, - fy=stride, - interpolation=cv2.INTER_CUBIC, - ) - heatmap = heatmap[ - : imageToTest_padded.shape[0] - pad[2], - : imageToTest_padded.shape[1] - pad[3], - :, - ] - heatmap = cv2.resize( - heatmap, - (oriImg.shape[1], oriImg.shape[0]), - interpolation=cv2.INTER_CUBIC, - ) - - heatmap_avg += heatmap / len(multiplier) - - all_peaks = [] - for part in range(21): - map_ori = heatmap_avg[:, :, part] - one_heatmap = gaussian_filter(map_ori, sigma=3) - binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) - # 全部小于阈值 - if np.sum(binary) == 0: - all_peaks.append([0, 0]) - continue - label_img, label_numbers = label( - binary, return_num=True, connectivity=binary.ndim - ) - max_index = ( - np.argmax( - [ - np.sum(map_ori[label_img == i]) - for i in range(1, label_numbers + 1) - ] - ) - + 1 - ) - label_img[label_img != max_index] = 0 - map_ori[label_img == 0] = 0 - - y, x = npmax(map_ori) - all_peaks.append([x, y]) - return np.array(all_peaks) diff --git a/apps/stable_diffusion/src/utils/stencils/openpose/openpose_util.py b/apps/stable_diffusion/src/utils/stencils/openpose/openpose_util.py deleted file mode 100644 index 46dba60e..00000000 --- a/apps/stable_diffusion/src/utils/stencils/openpose/openpose_util.py +++ /dev/null @@ -1,272 +0,0 @@ -import math -import numpy as np -import matplotlib -import cv2 -from collections import OrderedDict -import torch.nn as nn - - -def make_layers(block, no_relu_layers): - layers = [] - for layer_name, v in block.items(): - if "pool" in layer_name: - layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], padding=v[2]) - layers.append((layer_name, layer)) - else: - conv2d = nn.Conv2d( - in_channels=v[0], - out_channels=v[1], - kernel_size=v[2], - stride=v[3], - padding=v[4], - ) - layers.append((layer_name, conv2d)) - if layer_name not in no_relu_layers: - layers.append(("relu_" + layer_name, nn.ReLU(inplace=True))) - - return nn.Sequential(OrderedDict(layers)) - - -def padRightDownCorner(img, stride, padValue): - h = img.shape[0] - w = img.shape[1] - - pad = 4 * [None] - pad[0] = 0 # up - pad[1] = 0 # left - pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down - pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right - - img_padded = img - pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1)) - img_padded = np.concatenate((pad_up, img_padded), axis=0) - pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1)) - img_padded = np.concatenate((pad_left, img_padded), axis=1) - pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1)) - img_padded = np.concatenate((img_padded, pad_down), axis=0) - pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1)) - img_padded = np.concatenate((img_padded, pad_right), axis=1) - - return img_padded, pad - - -# transfer caffe model to pytorch which will match the layer name -def transfer(model, model_weights): - transfered_model_weights = {} - for weights_name in model.state_dict().keys(): - transfered_model_weights[weights_name] = model_weights[ - ".".join(weights_name.split(".")[1:]) - ] - return transfered_model_weights - - -# draw the body keypoint and lims -def draw_bodypose(canvas, candidate, subset): - stickwidth = 4 - limbSeq = [ - [2, 3], - [2, 6], - [3, 4], - [4, 5], - [6, 7], - [7, 8], - [2, 9], - [9, 10], - [10, 11], - [2, 12], - [12, 13], - [13, 14], - [2, 1], - [1, 15], - [15, 17], - [1, 16], - [16, 18], - [3, 17], - [6, 18], - ] - - colors = [ - [255, 0, 0], - [255, 85, 0], - [255, 170, 0], - [255, 255, 0], - [170, 255, 0], - [85, 255, 0], - [0, 255, 0], - [0, 255, 85], - [0, 255, 170], - [0, 255, 255], - [0, 170, 255], - [0, 85, 255], - [0, 0, 255], - [85, 0, 255], - [170, 0, 255], - [255, 0, 255], - [255, 0, 170], - [255, 0, 85], - ] - for i in range(18): - for n in range(len(subset)): - index = int(subset[n][i]) - if index == -1: - continue - x, y = candidate[index][0:2] - cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) - for i in range(17): - for n in range(len(subset)): - index = subset[n][np.array(limbSeq[i]) - 1] - if -1 in index: - continue - cur_canvas = canvas.copy() - Y = candidate[index.astype(int), 0] - X = candidate[index.astype(int), 1] - mX = np.mean(X) - mY = np.mean(Y) - length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 - angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) - polygon = cv2.ellipse2Poly( - (int(mY), int(mX)), - (int(length / 2), stickwidth), - int(angle), - 0, - 360, - 1, - ) - cv2.fillConvexPoly(cur_canvas, polygon, colors[i]) - canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0) - return canvas - - -# image drawed by opencv is not good. -def draw_handpose(canvas, all_hand_peaks, show_number=False): - edges = [ - [0, 1], - [1, 2], - [2, 3], - [3, 4], - [0, 5], - [5, 6], - [6, 7], - [7, 8], - [0, 9], - [9, 10], - [10, 11], - [11, 12], - [0, 13], - [13, 14], - [14, 15], - [15, 16], - [0, 17], - [17, 18], - [18, 19], - [19, 20], - ] - - for peaks in all_hand_peaks: - for ie, e in enumerate(edges): - if np.sum(np.all(peaks[e], axis=1) == 0) == 0: - x1, y1 = peaks[e[0]] - x2, y2 = peaks[e[1]] - cv2.line( - canvas, - (x1, y1), - (x2, y2), - matplotlib.colors.hsv_to_rgb( - [ie / float(len(edges)), 1.0, 1.0] - ) - * 255, - thickness=2, - ) - - for i, keyponit in enumerate(peaks): - x, y = keyponit - cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) - if show_number: - cv2.putText( - canvas, - str(i), - (x, y), - cv2.FONT_HERSHEY_SIMPLEX, - 0.3, - (0, 0, 0), - lineType=cv2.LINE_AA, - ) - return canvas - - -# detect hand according to body pose keypoints -# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp -def handDetect(candidate, subset, oriImg): - # right hand: wrist 4, elbow 3, shoulder 2 - # left hand: wrist 7, elbow 6, shoulder 5 - ratioWristElbow = 0.33 - detect_result = [] - image_height, image_width = oriImg.shape[0:2] - for person in subset.astype(int): - # if any of three not detected - has_left = np.sum(person[[5, 6, 7]] == -1) == 0 - has_right = np.sum(person[[2, 3, 4]] == -1) == 0 - if not (has_left or has_right): - continue - hands = [] - # left hand - if has_left: - left_shoulder_index, left_elbow_index, left_wrist_index = person[ - [5, 6, 7] - ] - x1, y1 = candidate[left_shoulder_index][:2] - x2, y2 = candidate[left_elbow_index][:2] - x3, y3 = candidate[left_wrist_index][:2] - hands.append([x1, y1, x2, y2, x3, y3, True]) - # right hand - if has_right: - ( - right_shoulder_index, - right_elbow_index, - right_wrist_index, - ) = person[[2, 3, 4]] - x1, y1 = candidate[right_shoulder_index][:2] - x2, y2 = candidate[right_elbow_index][:2] - x3, y3 = candidate[right_wrist_index][:2] - hands.append([x1, y1, x2, y2, x3, y3, False]) - - for x1, y1, x2, y2, x3, y3, is_left in hands: - x = x3 + ratioWristElbow * (x3 - x2) - y = y3 + ratioWristElbow * (y3 - y2) - distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) - distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) - width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) - # x-y refers to the center --> offset to topLeft point - x -= width / 2 - y -= width / 2 # width = height - # overflow the image - if x < 0: - x = 0 - if y < 0: - y = 0 - width1 = width - width2 = width - if x + width > image_width: - width1 = image_width - x - if y + width > image_height: - width2 = image_height - y - width = min(width1, width2) - # the max hand box value is 20 pixels - if width >= 20: - detect_result.append([int(x), int(y), int(width), is_left]) - - """ - return value: [[x, y, w, True if left hand else False]]. - width=height since the network require squared input. - x, y is the coordinate of top left - """ - return detect_result - - -# get max index of 2d array -def npmax(array): - arrayindex = array.argmax(1) - arrayvalue = array.max(1) - i = arrayvalue.argmax() - j = arrayindex[i] - return (i,) diff --git a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py b/apps/stable_diffusion/src/utils/stencils/stencil_utils.py deleted file mode 100644 index 8a7466db..00000000 --- a/apps/stable_diffusion/src/utils/stencils/stencil_utils.py +++ /dev/null @@ -1,253 +0,0 @@ -import numpy as np -from PIL import Image -import torch -import os -from pathlib import Path -import torchvision -import time -from apps.stable_diffusion.src.utils.stencils import ( - CannyDetector, - OpenposeDetector, - ZoeDetector, -) - -stencil = {} - - -def save_img(img): - from apps.stable_diffusion.src.utils import ( - get_generated_imgs_path, - get_generated_imgs_todays_subdir, - ) - - subdir = Path(get_generated_imgs_path(), "preprocessed_control_hints") - os.makedirs(subdir, exist_ok=True) - if isinstance(img, Image.Image): - img.save( - os.path.join( - subdir, "controlnet_" + str(int(time.time())) + ".png" - ) - ) - elif isinstance(img, np.ndarray): - img = Image.fromarray(img) - img.save(os.path.join(subdir, str(int(time.time())) + ".png")) - else: - converter = torchvision.transforms.ToPILImage() - for i in img: - converter(i).save( - os.path.join(subdir, str(int(time.time())) + ".png") - ) - - -def HWC3(x): - assert x.dtype == np.uint8 - if x.ndim == 2: - x = x[:, :, None] - assert x.ndim == 3 - H, W, C = x.shape - assert C == 1 or C == 3 or C == 4 - if C == 3: - return x - if C == 1: - return np.concatenate([x, x, x], axis=2) - if C == 4: - color = x[:, :, 0:3].astype(np.float32) - alpha = x[:, :, 3:4].astype(np.float32) / 255.0 - y = color * alpha + 255.0 * (1.0 - alpha) - y = y.clip(0, 255).astype(np.uint8) - return y - - -def controlnet_hint_reshaping( - controlnet_hint, height, width, dtype, num_images_per_prompt=1 -): - channels = 3 - if isinstance(controlnet_hint, torch.Tensor): - # torch.Tensor: acceptble shape are any of chw, bchw(b==1) or bchw(b==num_images_per_prompt) - shape_chw = (channels, height, width) - shape_bchw = (1, channels, height, width) - shape_nchw = (num_images_per_prompt, channels, height, width) - if controlnet_hint.shape in [shape_chw, shape_bchw, shape_nchw]: - controlnet_hint = controlnet_hint.to( - dtype=dtype, device=torch.device("cpu") - ) - if controlnet_hint.shape != shape_nchw: - controlnet_hint = controlnet_hint.repeat( - num_images_per_prompt, 1, 1, 1 - ) - return controlnet_hint - else: - return controlnet_hint_reshaping( - Image.fromarray(controlnet_hint.detach().numpy()), - height, - width, - dtype, - num_images_per_prompt, - ) - elif isinstance(controlnet_hint, np.ndarray): - # np.ndarray: acceptable shape is any of hw, hwc, bhwc(b==1) or bhwc(b==num_images_per_promot) - # hwc is opencv compatible image format. Color channel must be BGR Format. - if controlnet_hint.shape == (height, width): - controlnet_hint = np.repeat( - controlnet_hint[:, :, np.newaxis], channels, axis=2 - ) # hw -> hwc(c==3) - shape_hwc = (height, width, channels) - shape_bhwc = (1, height, width, channels) - shape_nhwc = (num_images_per_prompt, height, width, channels) - if controlnet_hint.shape in [shape_hwc, shape_bhwc, shape_nhwc]: - controlnet_hint = torch.from_numpy(controlnet_hint.copy()) - controlnet_hint = controlnet_hint.to( - dtype=dtype, device=torch.device("cpu") - ) - controlnet_hint /= 255.0 - if controlnet_hint.shape != shape_nhwc: - controlnet_hint = controlnet_hint.repeat( - num_images_per_prompt, 1, 1, 1 - ) - controlnet_hint = controlnet_hint.permute( - 0, 3, 1, 2 - ) # b h w c -> b c h w - return controlnet_hint - else: - return controlnet_hint_reshaping( - Image.fromarray(controlnet_hint), - height, - width, - dtype, - num_images_per_prompt, - ) - - elif isinstance(controlnet_hint, Image.Image): - controlnet_hint = controlnet_hint.convert( - "RGB" - ) # make sure 3 channel RGB format - if controlnet_hint.size == (width, height): - controlnet_hint = np.array(controlnet_hint).astype( - np.float16 - ) # to numpy - controlnet_hint = controlnet_hint[:, :, ::-1] # RGB -> BGR - return controlnet_hint_reshaping( - controlnet_hint, height, width, dtype, num_images_per_prompt - ) - else: - (hint_w, hint_h) = controlnet_hint.size - left = int((hint_w - width) / 2) - right = left + height - controlnet_hint = controlnet_hint.crop((left, 0, right, hint_h)) - controlnet_hint = controlnet_hint.resize((width, height)) - return controlnet_hint_reshaping( - controlnet_hint, height, width, dtype, num_images_per_prompt - ) - else: - raise ValueError( - f"Acceptible controlnet input types are any of torch.Tensor, np.ndarray, PIL.Image.Image but is {type(controlnet_hint)}" - ) - - -def controlnet_hint_conversion( - image, use_stencil, height, width, dtype, num_images_per_prompt=1 -): - controlnet_hint = None - match use_stencil: - case "canny": - print( - "Converting controlnet hint to edge detection mask with canny preprocessor." - ) - controlnet_hint = hint_canny(image) - case "openpose": - print( - "Detecting human pose in controlnet hint with openpose preprocessor." - ) - controlnet_hint = hint_openpose(image) - case "scribble": - print("Using your scribble as a controlnet hint.") - controlnet_hint = hint_scribble(image) - case "zoedepth": - print( - "Converting controlnet hint to a depth mapping with ZoeDepth." - ) - controlnet_hint = hint_zoedepth(image) - case _: - return None - controlnet_hint = controlnet_hint_reshaping( - controlnet_hint, height, width, dtype, num_images_per_prompt - ) - return controlnet_hint - - -stencil_to_model_id_map = { - "canny": "lllyasviel/control_v11p_sd15_canny", - "zoedepth": "lllyasviel/control_v11f1p_sd15_depth", - "hed": "lllyasviel/sd-controlnet-hed", - "mlsd": "lllyasviel/control_v11p_sd15_mlsd", - "normal": "lllyasviel/control_v11p_sd15_normalbae", - "openpose": "lllyasviel/control_v11p_sd15_openpose", - "scribble": "lllyasviel/control_v11p_sd15_scribble", - "seg": "lllyasviel/control_v11p_sd15_seg", -} - - -def get_stencil_model_id(use_stencil): - if use_stencil in stencil_to_model_id_map: - return stencil_to_model_id_map[use_stencil] - return None - - -# Stencil 1. Canny -def hint_canny( - image: Image.Image, - low_threshold=100, - high_threshold=200, -): - with torch.no_grad(): - input_image = np.array(image) - - if not "canny" in stencil: - stencil["canny"] = CannyDetector() - detected_map = stencil["canny"]( - input_image, low_threshold, high_threshold - ) - save_img(detected_map) - detected_map = HWC3(detected_map) - return detected_map - - -# Stencil 2. OpenPose. -def hint_openpose( - image: Image.Image, -): - with torch.no_grad(): - input_image = np.array(image) - - if not "openpose" in stencil: - stencil["openpose"] = OpenposeDetector() - - detected_map, _ = stencil["openpose"](input_image) - save_img(detected_map) - detected_map = HWC3(detected_map) - return detected_map - - -# Stencil 3. Scribble. -def hint_scribble(image: Image.Image): - with torch.no_grad(): - input_image = np.array(image) - - detected_map = np.zeros_like(input_image, dtype=np.uint8) - detected_map[np.min(input_image, axis=2) < 127] = 255 - save_img(detected_map) - return detected_map - - -# Stencil 4. Depth (Only Zoe Preprocessing) -def hint_zoedepth(image: Image.Image): - with torch.no_grad(): - input_image = np.array(image) - - if not "depth" in stencil: - stencil["depth"] = ZoeDetector() - - detected_map = stencil["depth"](input_image) - save_img(detected_map) - detected_map = HWC3(detected_map) - return detected_map diff --git a/apps/stable_diffusion/src/utils/stencils/zoe/__init__.py b/apps/stable_diffusion/src/utils/stencils/zoe/__init__.py deleted file mode 100644 index fbf299e4..00000000 --- a/apps/stable_diffusion/src/utils/stencils/zoe/__init__.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np -import torch -from pathlib import Path -import requests - - -from einops import rearrange - -remote_model_path = ( - "https://huggingface.co/lllyasviel/Annotators/resolve/main/ZoeD_M12_N.pt" -) - - -class ZoeDetector: - def __init__(self): - cwd = Path.cwd() - ckpt_path = Path(cwd, "stencil_annotator") - ckpt_path.mkdir(parents=True, exist_ok=True) - modelpath = ckpt_path / "ZoeD_M12_N.pt" - - with requests.get(remote_model_path, stream=True) as r: - r.raise_for_status() - with open(modelpath, "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - model = torch.hub.load( - "monorimet/ZoeDepth:torch_update", - "ZoeD_N", - pretrained=False, - force_reload=False, - ) - - # Hack to fix the ZoeDepth import issue - model_keys = model.state_dict().keys() - loaded_dict = torch.load(modelpath, map_location=model.device)["model"] - loaded_keys = loaded_dict.keys() - for key in loaded_keys - model_keys: - loaded_dict.pop(key) - - model.load_state_dict(loaded_dict) - model.eval() - self.model = model - - def __call__(self, input_image): - assert input_image.ndim == 3 - image_depth = input_image - with torch.no_grad(): - image_depth = torch.from_numpy(image_depth).float() - image_depth = image_depth / 255.0 - image_depth = rearrange(image_depth, "h w c -> 1 c h w") - depth = self.model.infer(image_depth) - - depth = depth[0, 0].cpu().numpy() - - vmin = np.percentile(depth, 2) - vmax = np.percentile(depth, 85) - - depth -= vmin - depth /= vmax - vmin - depth = 1.0 - depth - depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) - - return depth_image diff --git a/apps/stable_diffusion/src/utils/utils.py b/apps/stable_diffusion/src/utils/utils.py deleted file mode 100644 index b3b6aae8..00000000 --- a/apps/stable_diffusion/src/utils/utils.py +++ /dev/null @@ -1,1045 +0,0 @@ -import os -import gc -import json -import re -from PIL import PngImagePlugin -from PIL import Image -from datetime import datetime as dt -from csv import DictWriter -from pathlib import Path -import numpy as np -from random import ( - randint, - seed as seed_random, - getstate as random_getstate, - setstate as random_setstate, -) -import tempfile -import torch -from safetensors.torch import load_file -from shark.shark_inference import SharkInference -from shark.shark_importer import import_with_fx, save_mlir -from shark.iree_utils.vulkan_utils import ( - set_iree_vulkan_runtime_flags, - get_vulkan_target_triple, - get_iree_vulkan_runtime_flags, -) -from shark.iree_utils.metal_utils import get_metal_target_triple -from shark.iree_utils.gpu_utils import get_cuda_sm_cc, get_iree_rocm_args -from apps.stable_diffusion.src.utils.stable_args import args -from apps.stable_diffusion.src.utils.resources import opt_flags -from apps.stable_diffusion.src.utils.sd_annotation import sd_model_annotation -import sys -from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( - download_from_original_stable_diffusion_ckpt, - create_vae_diffusers_config, - convert_ldm_vae_checkpoint, -) -import requests -from io import BytesIO -from omegaconf import OmegaConf -from cpuinfo import get_cpu_info - - -def get_extended_name(model_name): - device = args.device.split("://", 1)[0] - extended_name = "{}_{}".format(model_name, device) - return extended_name - - -def get_vmfb_path_name(model_name): - vmfb_path = os.path.join(os.getcwd(), model_name + ".vmfb") - return vmfb_path - - -def _load_vmfb(shark_module, vmfb_path, model, precision): - model = "vae" if "base_vae" in model or "vae_encode" in model else model - model = "unet" if "stencil" in model else model - model = "unet" if "unet512" in model else model - precision = "fp32" if "clip" in model else precision - extra_args = get_opt_flags(model, precision) - shark_module.load_module(vmfb_path, extra_args=extra_args) - return shark_module - - -def _compile_module(shark_module, model_name, extra_args=[]): - if args.load_vmfb or args.save_vmfb: - vmfb_path = get_vmfb_path_name(model_name) - if args.load_vmfb and os.path.isfile(vmfb_path) and not args.save_vmfb: - print(f"loading existing vmfb from: {vmfb_path}") - shark_module.load_module(vmfb_path, extra_args=extra_args) - else: - if args.save_vmfb: - print("Saving to {}".format(vmfb_path)) - else: - print( - "No vmfb found. Compiling and saving to {}".format( - vmfb_path - ) - ) - path = shark_module.save_module( - os.getcwd(), model_name, extra_args, debug=args.compile_debug - ) - shark_module.load_module(path, extra_args=extra_args) - else: - shark_module.compile(extra_args) - return shark_module - - -# Downloads the model from shark_tank and returns the shark_module. -def get_shark_model(tank_url, model_name, extra_args=None): - if extra_args is None: - extra_args = [] - from shark.parser import shark_args - - # Set local shark_tank cache directory. - shark_args.local_tank_cache = args.local_tank_cache - from shark.shark_downloader import download_model - - if "cuda" in args.device: - shark_args.enable_tf32 = True - - mlir_model, func_name, inputs, golden_out = download_model( - model_name, - tank_url=tank_url, - frontend="torch", - ) - shark_module = SharkInference( - mlir_model, device=args.device, mlir_dialect="tm_tensor" - ) - return _compile_module(shark_module, model_name, extra_args) - - -# Converts the torch-module into a shark_module. -def compile_through_fx( - model, - inputs, - extended_model_name, - is_f16=False, - f16_input_mask=None, - use_tuned=False, - save_dir="", - debug=False, - generate_vmfb=True, - extra_args=None, - base_model_id=None, - model_name=None, - precision=None, - return_mlir=False, - device=None, -): - if extra_args is None: - extra_args = [] - if not return_mlir and model_name is not None: - vmfb_path = get_vmfb_path_name(extended_model_name) - if os.path.isfile(vmfb_path): - shark_module = SharkInference(mlir_module=None, device=args.device) - return ( - _load_vmfb(shark_module, vmfb_path, model_name, precision), - None, - ) - - from shark.parser import shark_args - - if "cuda" in args.device: - shark_args.enable_tf32 = True - - ( - mlir_module, - func_name, - ) = import_with_fx( - model=model, - inputs=inputs, - is_f16=is_f16, - f16_input_mask=f16_input_mask, - debug=debug, - model_name=extended_model_name, - ) - - if use_tuned: - if "vae" in extended_model_name.split("_")[0]: - args.annotation_model = "vae" - if ( - "unet" in model_name.split("_")[0] - or "unet_512" in model_name.split("_")[0] - ): - args.annotation_model = "unet" - mlir_module = sd_model_annotation( - mlir_module, extended_model_name, base_model_id - ) - - if not os.path.isdir(save_dir): - save_dir = "" - - mlir_module = save_mlir( - mlir_module, - model_name=extended_model_name, - dir=save_dir, - ) - shark_module = SharkInference( - mlir_module, - device=args.device if device is None else device, - mlir_dialect="tm_tensor", - ) - if generate_vmfb: - return ( - _compile_module(shark_module, extended_model_name, extra_args), - mlir_module, - ) - - gc.collect() - - -def set_iree_runtime_flags(): - # TODO: This function should be device-agnostic and piped properly - # to general runtime driver init. - vulkan_runtime_flags = get_iree_vulkan_runtime_flags() - if args.enable_rgp: - vulkan_runtime_flags += [ - f"--enable_rgp=true", - f"--vulkan_debug_utils=true", - ] - if args.device_allocator_heap_key: - vulkan_runtime_flags += [ - f"--device_allocator=caching:device_local={args.device_allocator_heap_key}", - ] - set_iree_vulkan_runtime_flags(flags=vulkan_runtime_flags) - - -def get_all_devices(driver_name): - """ - Inputs: driver_name - Returns a list of all the available devices for a given driver sorted by - the iree path names of the device as in --list_devices option in iree. - """ - from iree.runtime import get_driver - - driver = get_driver(driver_name) - device_list_src = driver.query_available_devices() - device_list_src.sort(key=lambda d: d["path"]) - return device_list_src - - -def get_device_mapping(driver, key_combination=3): - """This method ensures consistent device ordering when choosing - specific devices for execution - Args: - driver (str): execution driver (vulkan, cuda, rocm, etc) - key_combination (int, optional): choice for mapping value for - device name. - 1 : path - 2 : name - 3 : (name, path) - Defaults to 3. - Returns: - dict: map to possible device names user can input mapped to desired - combination of name/path. - """ - from shark.iree_utils._common import iree_device_map - - driver = iree_device_map(driver) - device_list = get_all_devices(driver) - device_map = dict() - - def get_output_value(dev_dict): - if key_combination == 1: - return f"{driver}://{dev_dict['path']}" - if key_combination == 2: - return dev_dict["name"] - if key_combination == 3: - return dev_dict["name"], f"{driver}://{dev_dict['path']}" - - # mapping driver name to default device (driver://0) - device_map[f"{driver}"] = get_output_value(device_list[0]) - for i, device in enumerate(device_list): - # mapping with index - device_map[f"{driver}://{i}"] = get_output_value(device) - # mapping with full path - device_map[f"{driver}://{device['path']}"] = get_output_value(device) - return device_map - - -def map_device_to_name_path(device, key_combination=3): - """Gives the appropriate device data (supported name/path) for user - selected execution device - Args: - device (str): user - key_combination (int, optional): choice for mapping value for - device name. - 1 : path - 2 : name - 3 : (name, path) - Defaults to 3. - Raises: - ValueError: - Returns: - str / tuple: returns the mapping str or tuple of mapping str for - the device depending on key_combination value - """ - driver = device.split("://")[0] - device_map = get_device_mapping(driver, key_combination) - try: - device_mapping = device_map[device] - except KeyError: - raise ValueError(f"Device '{device}' is not a valid device.") - return device_mapping - - -def set_init_device_flags(): - if "vulkan" in args.device: - # set runtime flags for vulkan. - set_iree_runtime_flags() - - # set triple flag to avoid multiple calls to get_vulkan_triple_flag - device_name, args.device = map_device_to_name_path(args.device) - if not args.iree_vulkan_target_triple: - triple = get_vulkan_target_triple(device_name) - if triple is not None: - args.iree_vulkan_target_triple = triple - print( - f"Found device {device_name}. Using target triple " - f"{args.iree_vulkan_target_triple}." - ) - elif "cuda" in args.device: - args.device = "cuda" - elif "metal" in args.device: - device_name, args.device = map_device_to_name_path(args.device) - if not args.iree_metal_target_platform: - triple = get_metal_target_triple(device_name) - if triple is not None: - args.iree_metal_target_platform = triple.split("-")[-1] - print( - f"Found device {device_name}. Using target triple " - f"{args.iree_metal_target_platform}." - ) - elif "cpu" in args.device: - args.device = "cpu" - - # set max_length based on availability. - if args.hf_model_id in [ - "Linaqruf/anything-v3.0", - "wavymulder/Analog-Diffusion", - "dreamlike-art/dreamlike-diffusion-1.0", - ]: - args.max_length = 77 - elif args.hf_model_id == "prompthero/openjourney": - args.max_length = 64 - - # Use tuned models in the case of fp16, vulkan rdna3 or cuda sm devices. - if args.ckpt_loc != "": - base_model_id = fetch_and_update_base_model_id(args.ckpt_loc) - else: - base_model_id = fetch_and_update_base_model_id(args.hf_model_id) - if base_model_id == "": - base_model_id = args.hf_model_id - - if ( - args.precision != "fp16" - or args.height not in [512, 768] - or (args.height == 512 and args.width not in [512, 768]) - or (args.height == 768 and args.width not in [512, 768]) - or args.batch_size != 1 - or ("vulkan" not in args.device and "cuda" not in args.device) - ): - args.use_tuned = False - - elif ( - args.height != args.width - and "rdna2" in args.iree_vulkan_target_triple - and base_model_id - not in [ - "CompVis/stable-diffusion-v1-4", - "runwayml/stable-diffusion-v1-5", - ] - ): - args.use_tuned = False - - elif base_model_id not in [ - "Linaqruf/anything-v3.0", - "dreamlike-art/dreamlike-diffusion-1.0", - "prompthero/openjourney", - "wavymulder/Analog-Diffusion", - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-2-1-base", - "CompVis/stable-diffusion-v1-4", - "runwayml/stable-diffusion-v1-5", - "runwayml/stable-diffusion-inpainting", - "stabilityai/stable-diffusion-2-inpainting", - ]: - args.use_tuned = False - - elif "vulkan" in args.device and not any( - x in args.iree_vulkan_target_triple for x in ["rdna2", "rdna3"] - ): - args.use_tuned = False - - elif "cuda" in args.device and get_cuda_sm_cc() not in ["sm_80", "sm_89"]: - args.use_tuned = False - - elif args.use_base_vae and args.hf_model_id not in [ - "stabilityai/stable-diffusion-2-1-base", - "CompVis/stable-diffusion-v1-4", - ]: - args.use_tuned = False - - elif ( - args.height == 768 - and args.width == 768 - and ( - base_model_id - not in [ - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-2-1-base", - ] - or "rdna" not in args.iree_vulkan_target_triple - ) - ): - args.use_tuned = False - - elif "rdna2" in args.iree_vulkan_target_triple and ( - base_model_id - not in [ - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-2-1-base", - "CompVis/stable-diffusion-v1-4", - ] - ): - args.use_tuned = False - - if args.use_tuned: - print( - f"Using tuned models for {base_model_id}(fp16) on " - f"device {args.device}." - ) - else: - print("Tuned models are currently not supported for this setting.") - - # set import_mlir to True for unuploaded models. - if args.ckpt_loc != "": - args.import_mlir = True - - elif args.hf_model_id not in [ - "Linaqruf/anything-v3.0", - "dreamlike-art/dreamlike-diffusion-1.0", - "prompthero/openjourney", - "wavymulder/Analog-Diffusion", - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-2-1-base", - "CompVis/stable-diffusion-v1-4", - ]: - args.import_mlir = True - - elif args.height != 512 or args.width != 512 or args.batch_size != 1: - args.import_mlir = True - - elif args.use_tuned and args.hf_model_id in [ - "dreamlike-art/dreamlike-diffusion-1.0", - "prompthero/openjourney", - "stabilityai/stable-diffusion-2-1", - ]: - args.import_mlir = True - - elif ( - args.use_tuned - and "vulkan" in args.device - and "rdna2" in args.iree_vulkan_target_triple - ): - args.import_mlir = True - - elif ( - args.use_tuned - and "cuda" in args.device - and get_cuda_sm_cc() == "sm_89" - ): - args.import_mlir = True - - -# Utility to get list of devices available. -def get_available_devices(): - def get_devices_by_name(driver_name): - from shark.iree_utils._common import iree_device_map - - device_list = [] - try: - driver_name = iree_device_map(driver_name) - device_list_dict = get_all_devices(driver_name) - print(f"{driver_name} devices are available.") - except: - print(f"{driver_name} devices are not available.") - else: - cpu_name = get_cpu_info()["brand_raw"] - for i, device in enumerate(device_list_dict): - device_name = ( - cpu_name if device["name"] == "default" else device["name"] - ) - if "local" in driver_name: - device_list.append( - f"{device_name} => {driver_name.replace('local', 'cpu')}" - ) - else: - # for drivers with single devices - # let the default device be selected without any indexing - if len(device_list_dict) == 1: - device_list.append(f"{device_name} => {driver_name}") - else: - device_list.append( - f"{device_name} => {driver_name}://{i}" - ) - return device_list - - set_iree_runtime_flags() - - available_devices = [] - from shark.iree_utils.vulkan_utils import ( - get_all_vulkan_devices, - ) - - vulkaninfo_list = get_all_vulkan_devices() - vulkan_devices = [] - id = 0 - for device in vulkaninfo_list: - vulkan_devices.append(f"{device.strip()} => vulkan://{id}") - id += 1 - if id != 0: - print(f"vulkan devices are available.") - available_devices.extend(vulkan_devices) - metal_devices = get_devices_by_name("metal") - available_devices.extend(metal_devices) - cuda_devices = get_devices_by_name("cuda") - available_devices.extend(cuda_devices) - rocm_devices = get_devices_by_name("rocm") - available_devices.extend(rocm_devices) - cpu_device = get_devices_by_name("cpu-sync") - available_devices.extend(cpu_device) - cpu_device = get_devices_by_name("cpu-task") - available_devices.extend(cpu_device) - return available_devices - - -def disk_space_check(path, lim=20): - from shutil import disk_usage - - du = disk_usage(path) - free = du.free / (1024 * 1024 * 1024) - if free <= lim: - print(f"[WARNING] Only {free:.2f}GB space available in {path}.") - - -def get_opt_flags(model, precision="fp16"): - iree_flags = [] - is_tuned = "tuned" if args.use_tuned else "untuned" - if len(args.iree_vulkan_target_triple) > 0: - iree_flags.append( - f"-iree-vulkan-target-triple={args.iree_vulkan_target_triple}" - ) - if "rocm" in args.device: - rocm_args = get_iree_rocm_args() - iree_flags.extend(rocm_args) - print(iree_flags) - if args.iree_constant_folding == False: - iree_flags.append("--iree-opt-const-expr-hoisting=False") - iree_flags.append( - "--iree-codegen-linalg-max-constant-fold-elements=9223372036854775807" - ) - if args.data_tiling == False: - iree_flags.append("--iree-opt-data-tiling=False") - - if "default_compilation_flags" in opt_flags[model][is_tuned][precision]: - iree_flags += opt_flags[model][is_tuned][precision][ - "default_compilation_flags" - ] - - if "specified_compilation_flags" in opt_flags[model][is_tuned][precision]: - device = ( - args.device - if "://" not in args.device - else args.device.split("://")[0] - ) - if ( - device - not in opt_flags[model][is_tuned][precision][ - "specified_compilation_flags" - ] - ): - device = "default_device" - iree_flags += opt_flags[model][is_tuned][precision][ - "specified_compilation_flags" - ][device] - if "vae" not in model: - # Due to lack of support for multi-reduce, we always collapse reduction - # dims before dispatch formation right now. - iree_flags += ["--iree-flow-collapse-reduction-dims"] - return iree_flags - - -def get_path_stem(path): - path = Path(path) - return path.stem - - -def get_path_to_diffusers_checkpoint(custom_weights): - path = Path(custom_weights) - diffusers_path = path.parent.absolute() - diffusers_directory_name = os.path.join("diffusers", path.stem) - complete_path_to_diffusers = diffusers_path / diffusers_directory_name - complete_path_to_diffusers.mkdir(parents=True, exist_ok=True) - path_to_diffusers = complete_path_to_diffusers.as_posix() - return path_to_diffusers - - -def preprocessCKPT(custom_weights, is_inpaint=False): - path_to_diffusers = get_path_to_diffusers_checkpoint(custom_weights) - if next(Path(path_to_diffusers).iterdir(), None): - print("Checkpoint already loaded at : ", path_to_diffusers) - return - else: - print( - "Diffusers' checkpoint will be identified here : ", - path_to_diffusers, - ) - from_safetensors = ( - True if custom_weights.lower().endswith(".safetensors") else False - ) - # EMA weights usually yield higher quality images for inference but - # non-EMA weights have been yielding better results in our case. - # TODO: Add an option `--ema` (`--no-ema`) for users to specify if - # they want to go for EMA weight extraction or not. - extract_ema = False - print( - "Loading diffusers' pipeline from original stable diffusion checkpoint" - ) - num_in_channels = 9 if is_inpaint else 4 - pipe = download_from_original_stable_diffusion_ckpt( - checkpoint_path_or_dict=custom_weights, - extract_ema=extract_ema, - from_safetensors=from_safetensors, - num_in_channels=num_in_channels, - ) - pipe.save_pretrained(path_to_diffusers) - print("Loading complete") - - -def convert_original_vae(vae_checkpoint): - vae_state_dict = {} - for key in list(vae_checkpoint.keys()): - vae_state_dict["first_stage_model." + key] = vae_checkpoint.get(key) - - config_url = ( - "https://raw.githubusercontent.com/CompVis/stable-diffusion/" - "main/configs/stable-diffusion/v1-inference.yaml" - ) - original_config_file = BytesIO(requests.get(config_url).content) - original_config = OmegaConf.load(original_config_file) - vae_config = create_vae_diffusers_config(original_config, image_size=512) - - converted_vae_checkpoint = convert_ldm_vae_checkpoint( - vae_state_dict, vae_config - ) - return converted_vae_checkpoint - - -def processLoRA(model, use_lora, splitting_prefix): - state_dict = "" - if ".safetensors" in use_lora: - state_dict = load_file(use_lora) - else: - state_dict = torch.load(use_lora) - alpha = 0.75 - visited = [] - - # directly update weight in model - process_unet = "te" not in splitting_prefix - for key in state_dict: - if ".alpha" in key or key in visited: - continue - - curr_layer = model - if ("text" not in key and process_unet) or ( - "text" in key and not process_unet - ): - layer_infos = ( - key.split(".")[0].split(splitting_prefix)[-1].split("_") - ) - else: - continue - - # find the target layer - temp_name = layer_infos.pop(0) - while len(layer_infos) > -1: - try: - curr_layer = curr_layer.__getattr__(temp_name) - if len(layer_infos) > 0: - temp_name = layer_infos.pop(0) - elif len(layer_infos) == 0: - break - except Exception: - if len(temp_name) > 0: - temp_name += "_" + layer_infos.pop(0) - else: - temp_name = layer_infos.pop(0) - - pair_keys = [] - if "lora_down" in key: - pair_keys.append(key.replace("lora_down", "lora_up")) - pair_keys.append(key) - else: - pair_keys.append(key) - pair_keys.append(key.replace("lora_up", "lora_down")) - - # update weight - if len(state_dict[pair_keys[0]].shape) == 4: - weight_up = ( - state_dict[pair_keys[0]] - .squeeze(3) - .squeeze(2) - .to(torch.float32) - ) - weight_down = ( - state_dict[pair_keys[1]] - .squeeze(3) - .squeeze(2) - .to(torch.float32) - ) - curr_layer.weight.data += alpha * torch.mm( - weight_up, weight_down - ).unsqueeze(2).unsqueeze(3) - else: - weight_up = state_dict[pair_keys[0]].to(torch.float32) - weight_down = state_dict[pair_keys[1]].to(torch.float32) - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) - # update visited list - for item in pair_keys: - visited.append(item) - return model - - -def update_lora_weight_for_unet(unet, use_lora): - extensions = [".bin", ".safetensors", ".pt"] - if not any([extension in use_lora for extension in extensions]): - # We assume if it is a HF ID with standalone LoRA weights. - unet.load_attn_procs(use_lora) - return unet - - main_file_name = get_path_stem(use_lora) - if ".bin" in use_lora: - main_file_name += ".bin" - elif ".safetensors" in use_lora: - main_file_name += ".safetensors" - elif ".pt" in use_lora: - main_file_name += ".pt" - else: - sys.exit("Only .bin and .safetensors format for LoRA is supported") - - try: - dir_name = os.path.dirname(use_lora) - unet.load_attn_procs(dir_name, weight_name=main_file_name) - return unet - except: - return processLoRA(unet, use_lora, "lora_unet_") - - -def update_lora_weight(model, use_lora, model_name): - if "unet" in model_name: - return update_lora_weight_for_unet(model, use_lora) - try: - return processLoRA(model, use_lora, "lora_te_") - except: - return None - - -# `fetch_and_update_base_model_id` is a resource utility function which -# helps to maintain mapping of the model to run with its base model. -# If `base_model` is "", then this function tries to fetch the base model -# info for the `model_to_run`. -def fetch_and_update_base_model_id(model_to_run, base_model=""): - variants_path = os.path.join(os.getcwd(), "variants.json") - data = {model_to_run: base_model} - json_data = {} - if os.path.exists(variants_path): - with open(variants_path, "r", encoding="utf-8") as jsonFile: - json_data = json.load(jsonFile) - # Return with base_model's info if base_model is "". - if base_model == "": - if model_to_run in json_data: - base_model = json_data[model_to_run] - return base_model - elif base_model == "": - return base_model - # Update JSON data to contain an entry mapping model_to_run with - # base_model. - json_data.update(data) - with open(variants_path, "w", encoding="utf-8") as jsonFile: - json.dump(json_data, jsonFile) - - -# Generate and return a new seed if the provided one is not in the -# supported range (including -1) -def sanitize_seed(seed: int | str): - seed = int(seed) - uint32_info = np.iinfo(np.uint32) - uint32_min, uint32_max = uint32_info.min, uint32_info.max - if seed < uint32_min or seed >= uint32_max: - seed = randint(uint32_min, uint32_max) - return seed - - -# take a seed expression in an input format and convert it to -# a list of integers, where possible -def parse_seed_input(seed_input: str | list | int): - if isinstance(seed_input, str): - try: - seed_input = json.loads(seed_input) - except (ValueError, TypeError): - seed_input = None - - if isinstance(seed_input, int): - return [seed_input] - - if isinstance(seed_input, list) and all( - type(seed) is int for seed in seed_input - ): - return seed_input - - raise TypeError( - "Seed input must be an integer or an array of integers in JSON format" - ) - - -# Generate a set of seeds from an input expression for batch_count batches, -# optionally using that input as the rng seed for any randomly generated seeds. -def batch_seeds( - seed_input: str | list | int, batch_count: int, repeatable=False -): - # turn the input into a list if possible - seeds = parse_seed_input(seed_input) - - # slice or pad the list to be of batch_count length - seeds = seeds[:batch_count] + [-1] * (batch_count - len(seeds)) - - if repeatable: - if all(seed < 0 for seed in seeds): - seeds[0] = sanitize_seed(seeds[0]) - - # set seed for the rng based on what we have so far - saved_random_state = random_getstate() - seed_random(str([n for n in seeds if n > -1])) - - # generate any seeds that are unspecified - seeds = [sanitize_seed(seed) for seed in seeds] - - if repeatable: - # reset the rng back to normal - random_setstate(saved_random_state) - - return seeds - - -# clear all the cached objects to recompile cleanly. -def clear_all(): - print("CLEARING ALL, EXPECT SEVERAL MINUTES TO RECOMPILE") - from glob import glob - import shutil - - vmfbs = glob(os.path.join(os.getcwd(), "*.vmfb")) - for vmfb in vmfbs: - if os.path.exists(vmfb): - os.remove(vmfb) - # Temporary workaround of deleting yaml files to incorporate - # diffusers' pipeline. - # TODO: Remove this once we have better weight updation logic. - inference_yaml = ["v2-inference-v.yaml", "v1-inference.yaml"] - for yaml in inference_yaml: - if os.path.exists(yaml): - os.remove(yaml) - home = os.path.expanduser("~") - if os.name == "nt": # Windows - appdata = os.getenv("LOCALAPPDATA") - shutil.rmtree(os.path.join(appdata, "AMD/VkCache"), ignore_errors=True) - shutil.rmtree( - os.path.join(home, ".local/shark_tank"), ignore_errors=True - ) - elif os.name == "unix": - shutil.rmtree(os.path.join(home, ".cache/AMD/VkCache")) - shutil.rmtree(os.path.join(home, ".local/shark_tank")) - if args.local_tank_cache != "": - shutil.rmtree(args.local_tank_cache) - - -def get_generated_imgs_path() -> Path: - return Path( - args.output_dir if args.output_dir else Path.cwd(), "generated_imgs" - ) - - -def get_generated_imgs_todays_subdir() -> str: - return dt.now().strftime("%Y%m%d") - - -# save output images and the inputs corresponding to it. -def save_output_img(output_img, img_seed, extra_info=None): - if extra_info is None: - extra_info = {} - generated_imgs_path = Path( - get_generated_imgs_path(), get_generated_imgs_todays_subdir() - ) - generated_imgs_path.mkdir(parents=True, exist_ok=True) - csv_path = Path(generated_imgs_path, "imgs_details.csv") - - prompt_slice = re.sub("[^a-zA-Z0-9]", "_", args.prompts[0][:15]) - out_img_name = f"{dt.now().strftime('%H%M%S')}_{prompt_slice}_{img_seed}" - - img_model = args.hf_model_id - if args.ckpt_loc: - img_model = Path(os.path.basename(args.ckpt_loc)).stem - - img_vae = None - if args.custom_vae: - img_vae = Path(os.path.basename(args.custom_vae)).stem - - img_lora = None - if args.use_lora: - img_lora = Path(os.path.basename(args.use_lora)).stem - - if args.output_img_format == "jpg": - out_img_path = Path(generated_imgs_path, f"{out_img_name}.jpg") - output_img.save(out_img_path, quality=95, subsampling=0) - else: - out_img_path = Path(generated_imgs_path, f"{out_img_name}.png") - pngInfo = PngImagePlugin.PngInfo() - - if args.write_metadata_to_png: - # Using a conditional expression caused problems, so setting a new - # variable for now. - if args.use_hiresfix: - png_size_text = f"{args.hiresfix_width}x{args.hiresfix_height}" - else: - png_size_text = f"{args.width}x{args.height}" - - pngInfo.add_text( - "parameters", - f"{args.prompts[0]}" - f"\nNegative prompt: {args.negative_prompts[0]}" - f"\nSteps: {args.steps}," - f"Sampler: {args.scheduler}, " - f"CFG scale: {args.guidance_scale}, " - f"Seed: {img_seed}," - f"Size: {png_size_text}, " - f"Model: {img_model}, " - f"VAE: {img_vae}, " - f"LoRA: {img_lora}", - ) - - output_img.save(out_img_path, "PNG", pnginfo=pngInfo) - - if args.output_img_format not in ["png", "jpg"]: - print( - f"[ERROR] Format {args.output_img_format} is not " - f"supported yet. Image saved as png instead." - f"Supported formats: png / jpg" - ) - - # To be as low-impact as possible to the existing CSV format, we append - # "VAE" and "LORA" to the end. However, it does not fit the hierarchy of - # importance for each data point. Something to consider. - new_entry = { - "VARIANT": img_model, - "SCHEDULER": args.scheduler, - "PROMPT": args.prompts[0], - "NEG_PROMPT": args.negative_prompts[0], - "SEED": img_seed, - "CFG_SCALE": args.guidance_scale, - "PRECISION": args.precision, - "STEPS": args.steps, - "HEIGHT": args.height - if not args.use_hiresfix - else args.hiresfix_height, - "WIDTH": args.width if not args.use_hiresfix else args.hiresfix_width, - "MAX_LENGTH": args.max_length, - "OUTPUT": out_img_path, - "VAE": img_vae, - "LORA": img_lora, - } - - new_entry.update(extra_info) - - csv_mode = "a" if os.path.isfile(csv_path) else "w" - with open(csv_path, csv_mode, encoding="utf-8") as csv_obj: - dictwriter_obj = DictWriter(csv_obj, fieldnames=list(new_entry.keys())) - if csv_mode == "w": - dictwriter_obj.writeheader() - dictwriter_obj.writerow(new_entry) - csv_obj.close() - - if args.save_metadata_to_json: - del new_entry["OUTPUT"] - json_path = Path(generated_imgs_path, f"{out_img_name}.json") - with open(json_path, "w") as f: - json.dump(new_entry, f, indent=4) - - -def get_generation_text_info(seeds, device): - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, " f"device={device}" - text_output += ( - f"\nsteps={args.steps}, " - f"guidance_scale={args.guidance_scale}, " - f"seed={seeds}" - ) - text_output += ( - f"\nsize={args.height}x{args.width}, " - if not args.use_hiresfix - else f"\nsize={args.hiresfix_height}x{args.hiresfix_width}, " - ) - text_output += ( - f"batch_count={args.batch_count}, " - f"batch_size={args.batch_size}, " - f"max_length={args.max_length}" - ) - - return text_output - - -# For stencil, the input image can be of any size, but we need to ensure that -# it conforms with our model constraints :- -# Both width and height should be in the range of [128, 768] and multiple of 8. -# This utility function performs the transformation on the input image while -# also maintaining the aspect ratio before sending it to the stencil pipeline. -def resize_stencil(image: Image.Image, width, height): - aspect_ratio = width / height - min_size = min(width, height) - if min_size < 128: - n_size = 128 - if width == min_size: - width = n_size - height = n_size / aspect_ratio - else: - height = n_size - width = n_size * aspect_ratio - width = int(width) - height = int(height) - n_width = width // 8 - n_height = height // 8 - n_width *= 8 - n_height *= 8 - - min_size = min(width, height) - if min_size > 768: - n_size = 768 - if width == min_size: - height = n_size - width = n_size * aspect_ratio - else: - width = n_size - height = n_size / aspect_ratio - width = int(width) - height = int(height) - n_width = width // 8 - n_height = height // 8 - n_width *= 8 - n_height *= 8 - new_image = image.resize((n_width, n_height)) - return new_image, n_width, n_height diff --git a/apps/stable_diffusion/stable_diffusion_telegram_bot.md b/apps/stable_diffusion/stable_diffusion_telegram_bot.md deleted file mode 100644 index 0784910c..00000000 --- a/apps/stable_diffusion/stable_diffusion_telegram_bot.md +++ /dev/null @@ -1,15 +0,0 @@ -You need to pre-create your bot (https://core.telegram.org/bots#how-do-i-create-a-bot) -Then create in the directory web file .env -In it the record: -TG_TOKEN="your_token" -specifying your bot's token from previous step. -Then run telegram_bot.py with the same parameters that you use when running index.py, for example: -python telegram_bot.py --max_length=77 --vulkan_large_heap_block_size=0 --use_base_vae --local_tank_cache h:\shark\TEMP - -Bot commands: -/select_model -/select_scheduler -/set_steps "integer number of steps" -/set_guidance_scale "integer number" -/set_negative_prompt "negative text" -Any other text triggers the creation of an image based on it. diff --git a/apps/stable_diffusion/studio_bundle.spec b/apps/stable_diffusion/studio_bundle.spec deleted file mode 100644 index d73abd1e..00000000 --- a/apps/stable_diffusion/studio_bundle.spec +++ /dev/null @@ -1,54 +0,0 @@ -# -*- mode: python ; coding: utf-8 -*- -from apps.stable_diffusion.shark_studio_imports import pathex, datas, hiddenimports - -binaries = [] - -block_cipher = None - -a = Analysis( - ['web\\index.py'], - pathex=pathex, - binaries=binaries, - datas=datas, - hiddenimports=hiddenimports, - hookspath=[], - hooksconfig={}, - runtime_hooks=[], - excludes=[], - win_no_prefer_redirects=False, - win_private_assemblies=False, - cipher=block_cipher, - noarchive=False, - module_collection_mode={ - 'gradio': 'py', # Collect gradio package as source .py files - }, -) -pyz = PYZ(a.pure, a.zipped_data, cipher=block_cipher) - -exe = EXE( - pyz, - a.scripts, - [], - exclude_binaries=True, - name='studio_bundle', - debug=False, - bootloader_ignore_signals=False, - strip=False, - upx=True, - console=True, - disable_windowed_traceback=False, - argv_emulation=False, - target_arch=None, - codesign_identity=None, - entitlements_file=None, -) -coll = COLLECT( - exe, - a.binaries, - a.zipfiles, - a.datas, - strip=False, - upx=True, - upx_exclude=[], - name='studio_bundle', -) diff --git a/apps/stable_diffusion/web/api/__init__.py b/apps/stable_diffusion/web/api/__init__.py deleted file mode 100644 index 892d1976..00000000 --- a/apps/stable_diffusion/web/api/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from apps.stable_diffusion.web.api.sdapi_v1 import sdapi diff --git a/apps/stable_diffusion/web/api/sdapi_v1.py b/apps/stable_diffusion/web/api/sdapi_v1.py deleted file mode 100644 index f376f0fe..00000000 --- a/apps/stable_diffusion/web/api/sdapi_v1.py +++ /dev/null @@ -1,580 +0,0 @@ -import os - -from collections import defaultdict -from enum import Enum -from fastapi import FastAPI -from pydantic import BaseModel, Field, conlist, model_validator - -from apps.stable_diffusion.web.api.utils import ( - frozen_args, - sampler_aliases, - encode_pil_to_base64, - decode_base64_to_image, - get_model_from_request, - get_scheduler_from_request, - get_lora_params, - get_device, - GenerationInputData, - GenerationResponseData, -) - -from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_files, - get_custom_model_pathfile, - predefined_models, - predefined_paint_models, - predefined_upscaler_models, - scheduler_list, -) -from apps.stable_diffusion.web.ui.txt2img_ui import txt2img_inf -from apps.stable_diffusion.web.ui.img2img_ui import img2img_inf -from apps.stable_diffusion.web.ui.inpaint_ui import inpaint_inf -from apps.stable_diffusion.web.ui.outpaint_ui import outpaint_inf -from apps.stable_diffusion.web.ui.upscaler_ui import upscaler_inf - -sdapi = FastAPI() - - -# Rest API: /sdapi/v1/sd-models (lists available models) -class AppParam(str, Enum): - txt2img = "txt2img" - img2img = "img2img" - inpaint = "inpaint" - outpaint = "outpaint" - upscaler = "upscaler" - - -@sdapi.get( - "/v1/sd-models", - summary="lists available models", - description=( - "This is all the models that this server currently knows about.\n " - "Models listed may still have a compilation and build pending that " - "will be triggered the first time they are used." - ), -) -def sd_models_api(app: AppParam = frozen_args.app): - match app: - case "inpaint" | "outpaint": - checkpoint_type = "inpainting" - predefined = predefined_paint_models - case "upscaler": - checkpoint_type = "upscaler" - predefined = predefined_upscaler_models - case _: - checkpoint_type = "" - predefined = predefined_models - - return [ - { - "title": model_file, - "model_name": model_file, - "hash": None, - "sha256": None, - "filename": get_custom_model_pathfile(model_file), - "config": None, - } - for model_file in get_custom_model_files( - custom_checkpoint_type=checkpoint_type - ) - ] + [ - { - "title": model, - "model_name": model, - "hash": None, - "sha256": None, - "filename": None, - "config": None, - } - for model in predefined - ] - - -# Rest API: /sdapi/v1/samplers (lists schedulers) -@sdapi.get( - "/v1/samplers", - summary="lists available schedulers/samplers", - description=( - "These are all the Schedulers defined and available. Not " - "every scheduler is compatible with all apis. Aliases are " - "equivalent samplers in A1111 if they are known." - ), -) -def sd_samplers_api(): - reverse_sampler_aliases = defaultdict(list) - for key, value in sampler_aliases.items(): - reverse_sampler_aliases[value].append(key) - - return ( - { - "name": scheduler, - "aliases": reverse_sampler_aliases.get(scheduler, []), - "options": {}, - } - for scheduler in scheduler_list - ) - - -# Rest API: /sdapi/v1/options (lists application level options) -@sdapi.get( - "/v1/options", - summary="lists current settings of application level options", - description=( - "A subset of the command line arguments set at startup renamed " - "to correspond to the A1111 naming. Only a small subset of A1111 " - "options are returned." - ), -) -def options_api(): - # This is mostly just enough to support what Koboldcpp wants, with a - # few other things that seemed obvious - return { - "samples_save": True, - "samples_format": frozen_args.output_img_format, - "sd_model_checkpoint": os.path.basename(frozen_args.ckpt_loc) - if frozen_args.ckpt_loc - else frozen_args.hf_model_id, - "sd_lora": frozen_args.use_lora, - "sd_vae": frozen_args.custom_vae or "Automatic", - "enable_pnginfo": frozen_args.write_metadata_to_png, - } - - -# Rest API: /sdapi/v1/cmd-flags (lists command line argument settings) -@sdapi.get( - "/v1/cmd-flags", - summary="lists the command line arguments value that were set on startup.", -) -def cmd_flags_api(): - return vars(frozen_args) - - -# Rest API: /sdapi/v1/txt2img (Text to image) -class ModelOverrideSettings(BaseModel): - sd_model_checkpoint: str = get_model_from_request( - fallback_model="stabilityai/stable-diffusion-2-1-base" - ) - - -class Txt2ImgInputData(GenerationInputData): - enable_hr: bool = frozen_args.use_hiresfix - hr_resize_y: int = Field( - default=frozen_args.hiresfix_height, ge=128, le=768, multiple_of=8 - ) - hr_resize_x: int = Field( - default=frozen_args.hiresfix_width, ge=128, le=768, multiple_of=8 - ) - override_settings: ModelOverrideSettings = None - - -@sdapi.post( - "/v1/txt2img", - summary="Does text to image generation", - response_model=GenerationResponseData, -) -def txt2img_api(InputData: Txt2ImgInputData): - model_id = get_model_from_request( - InputData, - fallback_model="stabilityai/stable-diffusion-2-1-base", - ) - scheduler = get_scheduler_from_request( - InputData, "txt2img_hires" if InputData.enable_hr else "txt2img" - ) - (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) - - print( - f"Prompt: {InputData.prompt}, " - f"Negative Prompt: {InputData.negative_prompt}, " - f"Seed: {InputData.seed}," - f"Model: {model_id}, " - f"Scheduler: {scheduler}. " - ) - - res = txt2img_inf( - InputData.prompt, - InputData.negative_prompt, - InputData.height, - InputData.width, - InputData.steps, - InputData.cfg_scale, - InputData.seed, - batch_count=InputData.n_iter, - batch_size=1, - scheduler=scheduler, - model_id=model_id, - custom_vae=frozen_args.custom_vae or "None", - precision="fp16", - device=get_device(frozen_args.device), - max_length=frozen_args.max_length, - save_metadata_to_json=frozen_args.save_metadata_to_json, - save_metadata_to_png=frozen_args.write_metadata_to_png, - lora_weights=lora_weights, - lora_hf_id=lora_hf_id, - ondemand=frozen_args.ondemand, - repeatable_seeds=False, - use_hiresfix=InputData.enable_hr, - hiresfix_height=InputData.hr_resize_y, - hiresfix_width=InputData.hr_resize_x, - hiresfix_strength=frozen_args.hiresfix_strength, - resample_type=frozen_args.resample_type, - ) - - # Since we're not streaming we just want the last generator result - for items_so_far in res: - items = items_so_far - - return { - "images": encode_pil_to_base64(items[0]), - "parameters": {}, - "info": items[1], - } - - -# Rest API: /sdapi/v1/img2img (Image to image) -class StencilParam(str, Enum): - canny = "canny" - openpose = "openpose" - scribble = "scribble" - zoedepth = "zoedepth" - - -class Img2ImgInputData(GenerationInputData): - init_images: conlist(str, min_length=1, max_length=2) - denoising_strength: float = frozen_args.strength - use_stencil: StencilParam = frozen_args.use_stencil - override_settings: ModelOverrideSettings = None - - @model_validator(mode="after") - def check_image_supplied_for_scribble_stencil(self) -> "Img2ImgInputData": - if ( - self.use_stencil == StencilParam.scribble - and len(self.init_images) < 2 - ): - raise ValueError( - "a second image must be supplied for the controlnet:scribble stencil" - ) - - return self - - -@sdapi.post( - "/v1/img2img", - summary="Does image to image generation", - response_model=GenerationResponseData, -) -def img2img_api( - InputData: Img2ImgInputData, -): - model_id = get_model_from_request( - InputData, - fallback_model="stabilityai/stable-diffusion-2-1-base", - ) - scheduler = get_scheduler_from_request(InputData, "img2img") - (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) - - init_image = decode_base64_to_image(InputData.init_images[0]) - mask_image = ( - decode_base64_to_image(InputData.init_images[1]) - if len(InputData.init_images) > 1 - else None - ) - - print( - f"Prompt: {InputData.prompt}, " - f"Negative Prompt: {InputData.negative_prompt}, " - f"Seed: {InputData.seed}, " - f"Model: {model_id}, " - f"Scheduler: {scheduler}." - ) - - res = img2img_inf( - InputData.prompt, - InputData.negative_prompt, - {"image": init_image, "mask": mask_image}, - InputData.height, - InputData.width, - InputData.steps, - InputData.denoising_strength, - InputData.cfg_scale, - InputData.seed, - batch_count=InputData.n_iter, - batch_size=1, - scheduler=scheduler, - model_id=model_id, - custom_vae=frozen_args.custom_vae or "None", - precision="fp16", - device=get_device(frozen_args.device), - max_length=frozen_args.max_length, - use_stencil=InputData.use_stencil, - save_metadata_to_json=frozen_args.save_metadata_to_json, - save_metadata_to_png=frozen_args.write_metadata_to_png, - lora_weights=lora_weights, - lora_hf_id=lora_hf_id, - ondemand=frozen_args.ondemand, - repeatable_seeds=False, - resample_type=frozen_args.resample_type, - ) - - # Since we're not streaming we just want the last generator result - for items_so_far in res: - items = items_so_far - - return { - "images": encode_pil_to_base64(items[0]), - "parameters": {}, - "info": items[1], - } - - -# Rest API: /sdapi/v1/inpaint (Inpainting) -class PaintModelOverideSettings(BaseModel): - sd_model_checkpoint: str = get_model_from_request( - checkpoint_type="inpainting", - fallback_model="stabilityai/stable-diffusion-2-inpainting", - ) - - -class InpaintInputData(GenerationInputData): - image: str = Field(description="Base64 encoded input image") - mask: str = Field(description="Base64 encoded mask image") - is_full_res: bool = False # Is this setting backwards in the UI? - full_res_padding: int = Field(default=32, ge=0, le=256, multiple_of=4) - denoising_strength: float = frozen_args.strength - use_stencil: StencilParam = frozen_args.use_stencil - override_settings: PaintModelOverideSettings = None - - -@sdapi.post( - "/v1/inpaint", - summary="Does inpainting generation on an image", - response_model=GenerationResponseData, -) -def inpaint_api( - InputData: InpaintInputData, -): - model_id = get_model_from_request( - InputData, - checkpoint_type="inpainting", - fallback_model="stabilityai/stable-diffusion-2-inpainting", - ) - scheduler = get_scheduler_from_request(InputData, "inpaint") - (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) - - init_image = decode_base64_to_image(InputData.image) - mask = decode_base64_to_image(InputData.mask) - - print( - f"Prompt: {InputData.prompt}, " - f'Negative Prompt: {InputData.negative_prompt}", ' - f'Seed: {InputData.seed}", ' - f"Model: {model_id}, " - f"Scheduler: {scheduler}." - ) - - res = inpaint_inf( - InputData.prompt, - InputData.negative_prompt, - init_image, - mask, - InputData.height, - InputData.width, - InputData.is_full_res, - InputData.full_res_padding, - InputData.steps, - InputData.cfg_scale, - InputData.seed, - batch_count=InputData.n_iter, - batch_size=1, - scheduler=scheduler, - model_id=model_id, - custom_vae=frozen_args.custom_vae or "None", - precision="fp16", - device=get_device(frozen_args.device), - max_length=frozen_args.max_length, - save_metadata_to_json=frozen_args.save_metadata_to_json, - save_metadata_to_png=frozen_args.write_metadata_to_png, - lora_weights=lora_weights, - lora_hf_id=lora_hf_id, - ondemand=frozen_args.ondemand, - repeatable_seeds=False, - ) - - # Since we're not streaming we just want the last generator result - for items_so_far in res: - items = items_so_far - - return { - "images": encode_pil_to_base64(items[0]), - "parameters": {}, - "info": items[1], - } - - -# Rest API: /sdapi/v1/outpaint (Outpainting) -class DirectionParam(str, Enum): - left = "left" - right = "right" - up = "up" - down = "down" - - -class OutpaintInputData(GenerationInputData): - init_images: list[str] - pixels: int = Field( - default=frozen_args.pixels, ge=8, le=256, multiple_of=8 - ) - mask_blur: int = Field(default=frozen_args.mask_blur, ge=0, le=64) - directions: set[DirectionParam] = [ - direction - for direction in ["left", "right", "up", "down"] - if vars(frozen_args)[direction] - ] - noise_q: float = frozen_args.noise_q - color_variation: float = frozen_args.color_variation - override_settings: PaintModelOverideSettings = None - - -@sdapi.post( - "/v1/outpaint", - summary="Does outpainting generation on an image", - response_model=GenerationResponseData, -) -def outpaint_api( - InputData: OutpaintInputData, -): - model_id = get_model_from_request( - InputData, - checkpoint_type="inpainting", - fallback_model="stabilityai/stable-diffusion-2-inpainting", - ) - scheduler = get_scheduler_from_request(InputData, "outpaint") - (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) - - init_image = decode_base64_to_image(InputData.init_images[0]) - - print( - f"Prompt: {InputData.prompt}, " - f"Negative Prompt: {InputData.negative_prompt}, " - f"Seed: {InputData.seed}, " - f"Model: {model_id}, " - f"Scheduler: {scheduler}." - ) - - res = outpaint_inf( - InputData.prompt, - InputData.negative_prompt, - init_image, - InputData.pixels, - InputData.mask_blur, - InputData.directions, - InputData.noise_q, - InputData.color_variation, - InputData.height, - InputData.width, - InputData.steps, - InputData.cfg_scale, - InputData.seed, - batch_count=InputData.n_iter, - batch_size=1, - scheduler=scheduler, - model_id=model_id, - custom_vae=frozen_args.custom_vae or "None", - precision="fp16", - device=get_device(frozen_args.device), - max_length=frozen_args.max_length, - save_metadata_to_json=frozen_args.save_metadata_to_json, - save_metadata_to_png=frozen_args.write_metadata_to_png, - lora_weights=lora_weights, - lora_hf_id=lora_hf_id, - ondemand=frozen_args.ondemand, - repeatable_seeds=False, - ) - - # Since we're not streaming we just want the last generator result - for items_so_far in res: - items = items_so_far - - return { - "images": encode_pil_to_base64(items[0]), - "parameters": {}, - "info": items[1], - } - - -# Rest API: /sdapi/v1/upscaler (Upscaling) -class UpscalerModelOverideSettings(BaseModel): - sd_model_checkpoint: str = get_model_from_request( - checkpoint_type="upscaler", - fallback_model="stabilityai/stable-diffusion-x4-upscaler", - ) - - -class UpscalerInputData(GenerationInputData): - init_images: list[str] = Field( - description="Base64 encoded image to upscale" - ) - noise_level: int = frozen_args.noise_level - override_settings: UpscalerModelOverideSettings = None - - -@sdapi.post( - "/v1/upscaler", - summary="Does image upscaling", - response_model=GenerationResponseData, -) -def upscaler_api( - InputData: UpscalerInputData, -): - model_id = get_model_from_request( - InputData, - checkpoint_type="upscaler", - fallback_model="stabilityai/stable-diffusion-x4-upscaler", - ) - scheduler = get_scheduler_from_request(InputData, "upscaler") - (lora_weights, lora_hf_id) = get_lora_params(frozen_args.use_lora) - - init_image = decode_base64_to_image(InputData.init_images[0]) - - print( - f"Prompt: {InputData.prompt}, " - f"Negative Prompt: {InputData.negative_prompt}, " - f"Seed: {InputData.seed}, " - f"Model: {model_id}, " - f"Scheduler: {scheduler}." - ) - - res = upscaler_inf( - InputData.prompt, - InputData.negative_prompt, - init_image, - InputData.height, - InputData.width, - InputData.steps, - InputData.noise_level, - InputData.cfg_scale, - InputData.seed, - batch_count=InputData.n_iter, - batch_size=1, - scheduler=scheduler, - model_id=model_id, - custom_vae=frozen_args.custom_vae or "None", - precision="fp16", - device=get_device(frozen_args.device), - max_length=frozen_args.max_length, - save_metadata_to_json=frozen_args.save_metadata_to_json, - save_metadata_to_png=frozen_args.write_metadata_to_png, - lora_weights=lora_weights, - lora_hf_id=lora_hf_id, - ondemand=frozen_args.ondemand, - repeatable_seeds=False, - ) - - # Since we're not streaming we just want the last generator result - for items_so_far in res: - items = items_so_far - - return { - "images": encode_pil_to_base64(items[0]), - "parameters": {}, - "info": items[1], - } diff --git a/apps/stable_diffusion/web/api/utils.py b/apps/stable_diffusion/web/api/utils.py deleted file mode 100644 index eca422f9..00000000 --- a/apps/stable_diffusion/web/api/utils.py +++ /dev/null @@ -1,211 +0,0 @@ -import base64 -import pickle - -from argparse import Namespace -from fastapi.exceptions import HTTPException -from io import BytesIO -from PIL import Image -from pydantic import BaseModel, Field - -from apps.stable_diffusion.src import args -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - get_custom_model_files, - predefined_models, - predefined_paint_models, - predefined_upscaler_models, - scheduler_list, - scheduler_list_cpu_only, -) - - -# Probably overly cautious, but try to ensure we only use the starting -# args in each api call, as the code does `args. = ` -# in lots of places and in testing, it seemed to me, these changes leaked -# into subsequent api calls. - -# Roundtripping through pickle for deepcopy, there is probably a better way -frozen_args = Namespace(**(pickle.loads(pickle.dumps(vars(args))))) - -# an attempt to map some of the A1111 sampler names to scheduler names -# https://github.com/huggingface/diffusers/issues/4167 is where the -# (not so obvious) ones come from -sampler_aliases = { - # a1111/onnx (these point to diffusers classes in A1111) - "pndm": "PNDM", - "heun": "HeunDiscrete", - "ddim": "DDIM", - "ddpm": "DDPM", - "euler": "EulerDiscrete", - "euler-ancestral": "EulerAncestralDiscrete", - "dpm": "DPMSolverMultistep", - # a1111/k_diffusion (the obvious ones) - "Euler a": "EulerAncestralDiscrete", - "Euler": "EulerDiscrete", - "LMS": "LMSDiscrete", - "Heun": "HeunDiscrete", - # a1111/k_diffusion (not so obvious) - "DPM++ 2M": "DPMSolverMultistep", - "DPM++ 2M Karras": "DPMSolverMultistepKarras", - "DPM++ 2M SDE": "DPMSolverMultistep++", - "DPM++ 2M SDE Karras": "DPMSolverMultistepKarras++", - "DPM2": "KDPM2Discrete", - "DPM2 a": "KDPM2AncestralDiscrete", -} - -allowed_schedulers = { - "txt2img": { - "schedulers": scheduler_list, - "fallback": "SharkEulerDiscrete", - }, - "txt2img_hires": { - "schedulers": scheduler_list_cpu_only, - "fallback": "DEISMultistep", - }, - "img2img": { - "schedulers": scheduler_list_cpu_only, - "fallback": "EulerDiscrete", - }, - "inpaint": { - "schedulers": scheduler_list_cpu_only, - "fallback": "DDIM", - }, - "outpaint": { - "schedulers": scheduler_list_cpu_only, - "fallback": "DDIM", - }, - "upscaler": { - "schedulers": scheduler_list_cpu_only, - "fallback": "DDIM", - }, -} - -# base pydantic model for sd generation apis - - -class GenerationInputData(BaseModel): - prompt: str = "" - negative_prompt: str = "" - hf_model_id: str | None = None - height: int = Field( - default=frozen_args.height, ge=128, le=768, multiple_of=8 - ) - width: int = Field( - default=frozen_args.width, ge=128, le=768, multiple_of=8 - ) - sampler_name: str = frozen_args.scheduler - cfg_scale: float = Field(default=frozen_args.guidance_scale, ge=1) - steps: int = Field(default=frozen_args.steps, ge=1, le=100) - seed: int = frozen_args.seed - n_iter: int = Field(default=frozen_args.batch_count) - - -class GenerationResponseData(BaseModel): - images: list[str] = Field(description="Generated images, Base64 encoded") - properties: dict = {} - info: str - - -# image encoding/decoding - - -def encode_pil_to_base64(images: list[Image.Image]): - encoded_imgs = [] - for image in images: - with BytesIO() as output_bytes: - if frozen_args.output_img_format.lower() == "png": - image.save(output_bytes, format="PNG") - - elif frozen_args.output_img_format.lower() in ("jpg", "jpeg"): - image.save(output_bytes, format="JPEG") - else: - raise HTTPException( - status_code=500, detail="Invalid image format" - ) - bytes_data = output_bytes.getvalue() - encoded_imgs.append(base64.b64encode(bytes_data)) - return encoded_imgs - - -def decode_base64_to_image(encoding: str): - if encoding.startswith("data:image/"): - encoding = encoding.split(";", 1)[1].split(",", 1)[1] - try: - image = Image.open(BytesIO(base64.b64decode(encoding))) - return image - except Exception as err: - print(err) - raise HTTPException(status_code=400, detail="Invalid encoded image") - - -# get valid sd models/vaes/schedulers etc. - - -def get_predefined_models(custom_checkpoint_type: str): - match custom_checkpoint_type: - case "inpainting": - return predefined_paint_models - case "upscaler": - return predefined_upscaler_models - case _: - return predefined_models - - -def get_model_from_request( - request_data=None, - checkpoint_type: str = "", - fallback_model: str = "", -): - model = None - if request_data: - if request_data.hf_model_id: - model = request_data.hf_model_id - elif request_data.override_settings: - model = request_data.override_settings.sd_model_checkpoint - - # if the request didn't specify a model try the command line args - result = model or frozen_args.ckpt_loc or frozen_args.hf_model_id - - # make sure whatever we have is a valid model for the checkpoint type - if result in get_custom_model_files( - custom_checkpoint_type=checkpoint_type - ) + get_predefined_models(checkpoint_type): - return result - # if not return what was specified as the fallback - else: - return fallback_model - - -def get_scheduler_from_request( - request_data: GenerationInputData, operation: str -): - allowed = allowed_schedulers[operation] - - requested = request_data.sampler_name - requested = sampler_aliases.get(requested, requested) - - return ( - requested - if requested in allowed["schedulers"] - else allowed["fallback"] - ) - - -def get_lora_params(use_lora: str): - # TODO: since the inference functions in the webui, which we are - # still calling into for the api, jam these back together again before - # handing them off to the pipeline, we should remove this nonsense - # and unify their selection in the UI and command line args proper - if use_lora in get_custom_model_files("lora"): - return (use_lora, "") - - return ("None", use_lora) - - -def get_device(device_str: str): - # first substring match in the list available devices, with first - # device when none are matched - return next( - (device for device in available_devices if device_str in device), - available_devices[0], - ) diff --git a/apps/stable_diffusion/web/index.py b/apps/stable_diffusion/web/index.py deleted file mode 100644 index b6795863..00000000 --- a/apps/stable_diffusion/web/index.py +++ /dev/null @@ -1,480 +0,0 @@ -from multiprocessing import freeze_support -import os -import sys -import logging -import apps.stable_diffusion.web.utils.app as app - -if sys.platform == "darwin": - # import before IREE to avoid torch-MLIR library issues - import torch_mlir - -import shutil -import PIL, transformers, sentencepiece # ensures inclusion in pysintaller exe generation -from apps.stable_diffusion.src import args, clear_all -import apps.stable_diffusion.web.utils.global_obj as global_obj - -if sys.platform == "darwin": - os.environ["DYLD_LIBRARY_PATH"] = "/usr/local/lib" - # import before IREE to avoid MLIR library issues - import torch_mlir - -if args.clear_all: - clear_all() - - -if __name__ == "__main__": - if args.debug: - logging.basicConfig(level=logging.DEBUG) - # required to do multiprocessing in a pyinstaller freeze - freeze_support() - if args.api or "api" in args.ui.split(","): - from apps.stable_diffusion.web.ui import ( - llm_chat_api, - ) - from apps.stable_diffusion.web.api import sdapi - - from fastapi import FastAPI, APIRouter - from fastapi.middleware.cors import CORSMiddleware - import uvicorn - - # init global sd pipeline and config - global_obj._init() - - api = FastAPI() - api.mount("/sdapi/", sdapi) - - # chat APIs needed for compatibility with multiple extensions using OpenAI API - api.add_api_route( - "/v1/chat/completions", llm_chat_api, methods=["post"] - ) - api.add_api_route("/v1/completions", llm_chat_api, methods=["post"]) - api.add_api_route("/chat/completions", llm_chat_api, methods=["post"]) - api.add_api_route("/completions", llm_chat_api, methods=["post"]) - api.add_api_route( - "/v1/engines/codegen/completions", llm_chat_api, methods=["post"] - ) - api.include_router(APIRouter()) - - # deal with CORS requests if CORS accept origins are set - if args.api_accept_origin: - print( - f"API Configured for CORS. Accepting origins: { args.api_accept_origin }" - ) - api.add_middleware( - CORSMiddleware, - allow_origins=args.api_accept_origin, - allow_methods=["GET", "POST"], - allow_headers=["*"], - ) - else: - print("API not configured for CORS") - - uvicorn.run(api, host="0.0.0.0", port=args.server_port) - sys.exit(0) - - # Setup to use shark_tmp for gradio's temporary image files and clear any - # existing temporary images there if they exist. Then we can import gradio. - # It has to be in this order or gradio ignores what we've set up. - from apps.stable_diffusion.web.utils.tmp_configs import ( - config_tmp, - shark_tmp, - ) - - config_tmp() - import gradio as gr - - # Create custom models folders if they don't exist - from apps.stable_diffusion.web.ui.utils import ( - create_custom_models_folders, - nodicon_loc, - mask_editor_value_for_gallery_data, - mask_editor_value_for_image_file, - ) - - create_custom_models_folders() - - def resource_path(relative_path): - """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) - return os.path.join(base_path, relative_path) - - from apps.stable_diffusion.web.ui import ( - txt2img_web, - txt2img_custom_model, - txt2img_gallery, - txt2img_png_info_img, - txt2img_status, - txt2img_sendto_img2img, - txt2img_sendto_inpaint, - txt2img_sendto_outpaint, - txt2img_sendto_upscaler, - # SDXL - txt2img_sdxl_web, - txt2img_sdxl_custom_model, - txt2img_sdxl_gallery, - txt2img_sdxl_png_info_img, - txt2img_sdxl_status, - txt2img_sdxl_sendto_img2img, - txt2img_sdxl_sendto_inpaint, - txt2img_sdxl_sendto_outpaint, - txt2img_sdxl_sendto_upscaler, - # h2ogpt_upload, - # h2ogpt_web, - img2img_web, - img2img_custom_model, - img2img_gallery, - img2img_init_image, - img2img_status, - img2img_sendto_inpaint, - img2img_sendto_outpaint, - img2img_sendto_upscaler, - inpaint_web, - inpaint_custom_model, - inpaint_gallery, - inpaint_init_image, - inpaint_status, - inpaint_sendto_img2img, - inpaint_sendto_outpaint, - inpaint_sendto_upscaler, - outpaint_web, - outpaint_custom_model, - outpaint_gallery, - outpaint_init_image, - outpaint_status, - outpaint_sendto_img2img, - outpaint_sendto_inpaint, - outpaint_sendto_upscaler, - upscaler_web, - upscaler_custom_model, - upscaler_gallery, - upscaler_init_image, - upscaler_status, - upscaler_sendto_img2img, - upscaler_sendto_inpaint, - upscaler_sendto_outpaint, - # lora_train_web, - # model_web, - model_config_web, - hf_models, - modelmanager_sendto_txt2img, - modelmanager_sendto_img2img, - modelmanager_sendto_inpaint, - modelmanager_sendto_outpaint, - modelmanager_sendto_upscaler, - stablelm_chat, - minigpt4_web, - outputgallery_web, - outputgallery_tab_select, - outputgallery_watch, - outputgallery_filename, - outputgallery_sendto_txt2img, - outputgallery_sendto_txt2img_sdxl, - outputgallery_sendto_img2img, - outputgallery_sendto_inpaint, - outputgallery_sendto_outpaint, - outputgallery_sendto_upscaler, - ) - - # init global sd pipeline and config - global_obj._init() - - def register_sendto_click(button, selectedid, inputs, outputs): - button.click( - lambda x: ( - x.root[0].image.path if len(x.root) != 0 else None, - gr.Tabs(selected=selectedid), - ), - inputs, - outputs, - ) - - def register_sendto_editor_click(button, selectedid, inputs, outputs): - button.click( - lambda x: ( - mask_editor_value_for_gallery_data(x), - gr.Tabs(selected=selectedid), - ), - inputs, - outputs, - ) - - def register_modelmanager_button(button, selectedid, inputs, outputs): - button.click( - lambda x: ( - "None", - x, - gr.Tabs(selected=selectedid), - ), - inputs, - outputs, - queue=False, - ) - - def register_outputgallery_sendto_button( - button, selectedid, inputs, outputs - ): - button.click( - lambda x: ( - x, - gr.Tabs(selected=selectedid), - ), - inputs, - outputs, - ) - - def register_outputgallery_sendto_editor_button( - button, selectedid, inputs, outputs - ): - button.click( - lambda x: ( - mask_editor_value_for_image_file(x), - gr.Tabs(selected=selectedid), - ), - inputs, - outputs, - ) - - dark_theme = resource_path("ui/css/sd_dark_theme.css") - - with gr.Blocks( - css=dark_theme, analytics_enabled=False, title="SHARK AI Studio" - ) as sd_web: - with gr.Tabs() as tabs: - # NOTE: If adding, removing, or re-ordering tabs, make sure that they - # have a unique id that doesn't clash with any of the other tabs, - # and that the order in the code here is the order they should - # appear in the ui, as the id value doesn't determine the order. - - # Where possible, avoid changing the id of any tab that is the - # destination of one of the 'send to' buttons. If you do have to change - # that id, make sure you update the relevant register_button_click calls - # further down with the new id. - with gr.TabItem(label="Text-to-Image", id=0): - txt2img_web.render() - with gr.TabItem(label="Image-to-Image", id=1): - img2img_web.render() - with gr.TabItem(label="Inpainting", id=2): - inpaint_web.render() - with gr.TabItem(label="Outpainting", id=3): - outpaint_web.render() - with gr.TabItem(label="Upscaler", id=4): - upscaler_web.render() - if args.output_gallery: - with gr.TabItem(label="Output Gallery", id=5) as og_tab: - outputgallery_web.render() - # with gr.TabItem(label="Model Manager", id=6): - # model_web.render() - # with gr.TabItem(label="LoRA Training (Experimental)", id=7): - # lora_train_web.render() - with gr.TabItem(label="Chat Bot", id=8): - stablelm_chat.render() - # with gr.TabItem( - # label="Generate Sharding Config (Experimental)", id=9 - # ): - # model_config_web.render() - # with gr.TabItem(label="MultiModal (Experimental)", id=10): - # minigpt4_web.render() - # with gr.TabItem(label="DocuChat Upload", id=11): - # h2ogpt_upload.render() - # with gr.TabItem(label="DocuChat(Experimental)", id=12): - # h2ogpt_web.render() - with gr.TabItem(label="Text-to-Image (SDXL)", id=13): - txt2img_sdxl_web.render() - - # extra output gallery configuration - outputgallery_tab_select(og_tab.select) - outputgallery_watch( - [ - txt2img_status, - img2img_status, - inpaint_status, - outpaint_status, - upscaler_status, - txt2img_sdxl_status, - ], - ) - - actual_port = app.usable_port() - if actual_port != args.server_port: - sd_web.load( - fn=lambda: gr.Info( - f"Port {args.server_port} is in use by another application. " - f"Shark is running on port {actual_port} instead." - ) - ) - - # send to buttons - register_sendto_click( - txt2img_sendto_img2img, - 1, - [txt2img_gallery], - [img2img_init_image, tabs], - ) - register_sendto_editor_click( - txt2img_sendto_inpaint, - 2, - [txt2img_gallery], - [inpaint_init_image, tabs], - ) - register_sendto_click( - txt2img_sendto_outpaint, - 3, - [txt2img_gallery], - [outpaint_init_image, tabs], - ) - register_sendto_click( - txt2img_sendto_upscaler, - 4, - [txt2img_gallery], - [upscaler_init_image, tabs], - ) - register_sendto_editor_click( - img2img_sendto_inpaint, - 2, - [img2img_gallery], - [inpaint_init_image, tabs], - ) - register_sendto_click( - img2img_sendto_outpaint, - 3, - [img2img_gallery], - [outpaint_init_image, tabs], - ) - register_sendto_click( - img2img_sendto_upscaler, - 4, - [img2img_gallery], - [upscaler_init_image, tabs], - ) - register_sendto_click( - inpaint_sendto_img2img, - 1, - [inpaint_gallery], - [img2img_init_image, tabs], - ) - register_sendto_click( - inpaint_sendto_outpaint, - 3, - [inpaint_gallery], - [outpaint_init_image, tabs], - ) - register_sendto_click( - inpaint_sendto_upscaler, - 4, - [inpaint_gallery], - [upscaler_init_image, tabs], - ) - register_sendto_click( - outpaint_sendto_img2img, - 1, - [outpaint_gallery], - [img2img_init_image, tabs], - ) - register_sendto_editor_click( - outpaint_sendto_inpaint, - 2, - [outpaint_gallery], - [inpaint_init_image, tabs], - ) - register_sendto_click( - outpaint_sendto_upscaler, - 4, - [outpaint_gallery], - [upscaler_init_image, tabs], - ) - register_sendto_click( - upscaler_sendto_img2img, - 1, - [upscaler_gallery], - [img2img_init_image, tabs], - ) - register_sendto_editor_click( - upscaler_sendto_inpaint, - 2, - [upscaler_gallery], - [inpaint_init_image, tabs], - ) - register_sendto_click( - upscaler_sendto_outpaint, - 3, - [upscaler_gallery], - [outpaint_init_image, tabs], - ) - if args.output_gallery: - register_outputgallery_sendto_button( - outputgallery_sendto_txt2img, - 0, - [outputgallery_filename], - [txt2img_png_info_img, tabs], - ) - register_outputgallery_sendto_button( - outputgallery_sendto_img2img, - 1, - [outputgallery_filename], - [img2img_init_image, tabs], - ) - register_outputgallery_sendto_editor_button( - outputgallery_sendto_inpaint, - 2, - [outputgallery_filename], - [inpaint_init_image, tabs], - ) - register_outputgallery_sendto_button( - outputgallery_sendto_outpaint, - 3, - [outputgallery_filename], - [outpaint_init_image, tabs], - ) - register_outputgallery_sendto_button( - outputgallery_sendto_upscaler, - 4, - [outputgallery_filename], - [upscaler_init_image, tabs], - ) - register_outputgallery_sendto_button( - outputgallery_sendto_txt2img_sdxl, - 0, - [outputgallery_filename], - [txt2img_sdxl_png_info_img, tabs], - ) - register_modelmanager_button( - modelmanager_sendto_txt2img, - 0, - [hf_models], - [txt2img_custom_model, tabs], - ) - register_modelmanager_button( - modelmanager_sendto_img2img, - 1, - [hf_models], - [img2img_custom_model, tabs], - ) - register_modelmanager_button( - modelmanager_sendto_inpaint, - 2, - [hf_models], - [inpaint_custom_model, tabs], - ) - register_modelmanager_button( - modelmanager_sendto_outpaint, - 3, - [hf_models], - [outpaint_custom_model, tabs], - ) - register_modelmanager_button( - modelmanager_sendto_upscaler, - 4, - [hf_models], - [upscaler_custom_model, tabs], - ) - - sd_web.queue() - sd_web.launch( - share=args.share, - inbrowser=not app.launch(actual_port), - server_name="0.0.0.0", - server_port=actual_port, - favicon_path=nodicon_loc, - ) diff --git a/apps/stable_diffusion/web/ui/__init__.py b/apps/stable_diffusion/web/ui/__init__.py deleted file mode 100644 index 979c1298..00000000 --- a/apps/stable_diffusion/web/ui/__init__.py +++ /dev/null @@ -1,96 +0,0 @@ -from apps.stable_diffusion.web.ui.txt2img_ui import ( - txt2img_inf, - txt2img_web, - txt2img_custom_model, - txt2img_gallery, - txt2img_png_info_img, - txt2img_status, - txt2img_sendto_img2img, - txt2img_sendto_inpaint, - txt2img_sendto_outpaint, - txt2img_sendto_upscaler, -) -from apps.stable_diffusion.web.ui.txt2img_sdxl_ui import ( - txt2img_sdxl_inf, - txt2img_sdxl_web, - txt2img_sdxl_custom_model, - txt2img_sdxl_gallery, - txt2img_sdxl_status, - txt2img_sdxl_png_info_img, - txt2img_sdxl_sendto_img2img, - txt2img_sdxl_sendto_inpaint, - txt2img_sdxl_sendto_outpaint, - txt2img_sdxl_sendto_upscaler, -) -from apps.stable_diffusion.web.ui.img2img_ui import ( - img2img_inf, - img2img_web, - img2img_custom_model, - img2img_gallery, - img2img_init_image, - img2img_status, - img2img_sendto_inpaint, - img2img_sendto_outpaint, - img2img_sendto_upscaler, -) -from apps.stable_diffusion.web.ui.inpaint_ui import ( - inpaint_inf, - inpaint_web, - inpaint_custom_model, - inpaint_gallery, - inpaint_init_image, - inpaint_status, - inpaint_sendto_img2img, - inpaint_sendto_outpaint, - inpaint_sendto_upscaler, -) -from apps.stable_diffusion.web.ui.outpaint_ui import ( - outpaint_inf, - outpaint_web, - outpaint_custom_model, - outpaint_gallery, - outpaint_init_image, - outpaint_status, - outpaint_sendto_img2img, - outpaint_sendto_inpaint, - outpaint_sendto_upscaler, -) -from apps.stable_diffusion.web.ui.upscaler_ui import ( - upscaler_inf, - upscaler_web, - upscaler_custom_model, - upscaler_gallery, - upscaler_init_image, - upscaler_status, - upscaler_sendto_img2img, - upscaler_sendto_inpaint, - upscaler_sendto_outpaint, -) -from apps.stable_diffusion.web.ui.model_manager import ( - model_web, - hf_models, - modelmanager_sendto_txt2img, - modelmanager_sendto_img2img, - modelmanager_sendto_inpaint, - modelmanager_sendto_outpaint, - modelmanager_sendto_upscaler, -) -from apps.stable_diffusion.web.ui.lora_train_ui import lora_train_web -from apps.stable_diffusion.web.ui.stablelm_ui import ( - stablelm_chat, - llm_chat_api, -) -from apps.stable_diffusion.web.ui.generate_config import model_config_web -from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web -from apps.stable_diffusion.web.ui.outputgallery_ui import ( - outputgallery_web, - outputgallery_tab_select, - outputgallery_watch, - outputgallery_filename, - outputgallery_sendto_txt2img, - outputgallery_sendto_txt2img_sdxl, - outputgallery_sendto_img2img, - outputgallery_sendto_inpaint, - outputgallery_sendto_outpaint, - outputgallery_sendto_upscaler, -) diff --git a/apps/stable_diffusion/web/ui/common_ui_events.py b/apps/stable_diffusion/web/ui/common_ui_events.py deleted file mode 100644 index f467f6b0..00000000 --- a/apps/stable_diffusion/web/ui/common_ui_events.py +++ /dev/null @@ -1,57 +0,0 @@ -import gradio as gr - -from apps.stable_diffusion.web.ui.utils import ( - HSLHue, - hsl_color, - get_lora_metadata, -) - - -# Answers HTML to show the most frequent tags used when a LoRA was trained, -# taken from the metadata of its .safetensors file. -def lora_changed(lora_file): - # tag frequency percentage, that gets maximum amount of the staring hue - TAG_COLOR_THRESHOLD = 0.55 - # tag frequency percentage, above which a tag is displayed - TAG_DISPLAY_THRESHOLD = 0.65 - # template for the html used to display a tag - TAG_HTML_TEMPLATE = '{tag}' - - if lora_file == "None": - return ["
    No LoRA selected
    "] - elif not lora_file.lower().endswith(".safetensors"): - return [ - "
    Only metadata queries for .safetensors files are currently supported
    " - ] - else: - metadata = get_lora_metadata(lora_file) - if metadata: - frequencies = metadata["frequencies"] - return [ - "".join( - [ - f'
    Trained against weights in: {metadata["model"]}
    ' - ] - + [ - TAG_HTML_TEMPLATE.format( - color=hsl_color( - (tag[1] - TAG_COLOR_THRESHOLD) - / (1 - TAG_COLOR_THRESHOLD), - start=HSLHue.RED, - end=HSLHue.GREEN, - ), - tag=tag[0], - ) - for tag in frequencies - if tag[1] > TAG_DISPLAY_THRESHOLD - ], - ) - ] - elif metadata is None: - return [ - "
    This LoRA does not publish tag frequency metadata
    " - ] - else: - return [ - "
    This LoRA has empty tag frequency metadata, or we could not parse it
    " - ] diff --git a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css b/apps/stable_diffusion/web/ui/css/sd_dark_theme.css deleted file mode 100644 index fa8d50ad..00000000 --- a/apps/stable_diffusion/web/ui/css/sd_dark_theme.css +++ /dev/null @@ -1,339 +0,0 @@ -/* -Apply Gradio dark theme to the default Gradio theme. -Procedure to upgrade the dark theme: -- Using your browser, visit http://localhost:8080/?__theme=dark -- Open your browser inspector, search for the .dark css class -- Copy .dark class declarations, apply them here into :root -*/ - -:root { - --body-background-fill: var(--background-fill-primary); - --body-text-color: var(--neutral-100); - --color-accent-soft: var(--neutral-700); - --background-fill-primary: var(--neutral-950); - --background-fill-secondary: var(--neutral-900); - --border-color-accent: var(--neutral-600); - --border-color-primary: var(--neutral-700); - --link-text-color-active: var(--secondary-500); - --link-text-color: var(--secondary-500); - --link-text-color-hover: var(--secondary-400); - --link-text-color-visited: var(--secondary-600); - --body-text-color-subdued: var(--neutral-400); - --shadow-spread: 1px; - --block-background-fill: var(--neutral-800); - --block-border-color: var(--border-color-primary); - --block_border_width: None; - --block-info-text-color: var(--body-text-color-subdued); - --block-label-background-fill: var(--background-fill-secondary); - --block-label-border-color: var(--border-color-primary); - --block_label_border_width: None; - --block-label-text-color: var(--neutral-200); - --block_shadow: None; - --block_title_background_fill: None; - --block_title_border_color: None; - --block_title_border_width: None; - --block-title-text-color: var(--neutral-200); - --panel-background-fill: var(--background-fill-secondary); - --panel-border-color: var(--border-color-primary); - --panel_border_width: None; - --checkbox-background-color: var(--neutral-800); - --checkbox-background-color-focus: var(--checkbox-background-color); - --checkbox-background-color-hover: var(--checkbox-background-color); - --checkbox-background-color-selected: var(--secondary-600); - --checkbox-border-color: var(--neutral-700); - --checkbox-border-color-focus: var(--secondary-500); - --checkbox-border-color-hover: var(--neutral-600); - --checkbox-border-color-selected: var(--secondary-600); - --checkbox-border-width: var(--input-border-width); - --checkbox-label-background-fill: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); - --checkbox-label-background-fill-hover: linear-gradient(to top, var(--neutral-900), var(--neutral-800)); - --checkbox-label-background-fill-selected: var(--checkbox-label-background-fill); - --checkbox-label-border-color: var(--border-color-primary); - --checkbox-label-border-color-hover: var(--checkbox-label-border-color); - --checkbox-label-border-width: var(--input-border-width); - --checkbox-label-text-color: var(--body-text-color); - --checkbox-label-text-color-selected: var(--checkbox-label-text-color); - --error-background-fill: var(--background-fill-primary); - --error-border-color: var(--border-color-primary); - --error_border_width: None; - --error-text-color: #ef4444; - --input-background-fill: var(--neutral-800); - --input-background-fill-focus: var(--secondary-600); - --input-background-fill-hover: var(--input-background-fill); - --input-border-color: var(--border-color-primary); - --input-border-color-focus: var(--neutral-700); - --input-border-color-hover: var(--input-border-color); - --input_border_width: None; - --input-placeholder-color: var(--neutral-500); - --input_shadow: None; - --input-shadow-focus: 0 0 0 var(--shadow-spread) var(--neutral-700), var(--shadow-inset); - --loader_color: None; - --slider_color: None; - --stat-background-fill: linear-gradient(to right, var(--primary-400), var(--primary-600)); - --table-border-color: var(--neutral-700); - --table-even-background-fill: var(--neutral-950); - --table-odd-background-fill: var(--neutral-900); - --table-row-focus: var(--color-accent-soft); - --button-border-width: var(--input-border-width); - --button-cancel-background-fill: linear-gradient(to bottom right, #dc2626, #b91c1c); - --button-cancel-background-fill-hover: linear-gradient(to bottom right, #dc2626, #dc2626); - --button-cancel-border-color: #dc2626; - --button-cancel-border-color-hover: var(--button-cancel-border-color); - --button-cancel-text-color: white; - --button-cancel-text-color-hover: var(--button-cancel-text-color); - --button-primary-background-fill: linear-gradient(to bottom right, var(--primary-500), var(--primary-600)); - --button-primary-background-fill-hover: linear-gradient(to bottom right, var(--primary-500), var(--primary-500)); - --button-primary-border-color: var(--primary-500); - --button-primary-border-color-hover: var(--button-primary-border-color); - --button-primary-text-color: white; - --button-primary-text-color-hover: var(--button-primary-text-color); - --button-secondary-background-fill: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-700)); - --button-secondary-background-fill-hover: linear-gradient(to bottom right, var(--neutral-600), var(--neutral-600)); - --button-secondary-border-color: var(--neutral-600); - --button-secondary-border-color-hover: var(--button-secondary-border-color); - --button-secondary-text-color: white; - --button-secondary-text-color-hover: var(--button-secondary-text-color); - --block-border-width: 1px; - --block-label-border-width: 1px; - --form-gap-width: 1px; - --error-border-width: 1px; - --input-border-width: 1px; -} - -/* SHARK theme */ -body { - background-color: var(--background-fill-primary); -} - -.generating.svelte-zlszon.svelte-zlszon { - border: none; -} - -.generating { - border: none !important; -} - -#chatbot { - height: 100% !important; -} - -/* display in full width for desktop devices */ -@media (min-width: 1536px) -{ - .gradio-container { - max-width: var(--size-full) !important; - } -} - -.gradio-container .contain { - padding: 0 var(--size-4) !important; -} - -#top_logo { - color: transparent; - background-color: transparent; - border-radius: 0 !important; - border: 0; -} - -#ui_title { - padding: var(--size-2) 0 0 var(--size-1); -} - -#demo_title_outer { - border-radius: 0; -} - -#prompt_box_outer div:first-child { - border-radius: 0 !important -} - -#prompt_box textarea, #negative_prompt_box textarea { - background-color: var(--background-fill-primary) !important; -} - -#prompt_examples { - margin: 0 !important; -} - -#prompt_examples svg { - display: none !important; -} - -#ui_body { - padding: var(--size-2) !important; - border-radius: 0.5em !important; -} - -#img_result+div { - display: none !important; -} - -footer { - display: none !important; -} - -#gallery + div { - border-radius: 0 !important; -} - -/* Gallery: Remove the default square ratio thumbnail and limit images height to the container */ -#gallery .thumbnail-item.thumbnail-lg { - aspect-ratio: unset; - max-height: calc(55vh - (2 * var(--spacing-lg))); -} -@media (min-width: 1921px) { - /* Force a 768px_height + 4px_margin_height + navbar_height for the gallery */ - #gallery .grid-wrap, #gallery .preview{ - min-height: calc(768px + 4px + var(--size-14)); - max-height: calc(768px + 4px + var(--size-14)); - } - /* Limit height to 768px_height + 2px_margin_height for the thumbnails */ - #gallery .thumbnail-item.thumbnail-lg { - max-height: 770px !important; - } -} -/* Don't upscale when viewing in solo image mode */ -#gallery .preview img { - object-fit: scale-down; -} -/* Navbar images in cover mode*/ -#gallery .preview .thumbnail-item img { - object-fit: cover; -} - -/* Limit the stable diffusion text output height */ -#std_output textarea { - max-height: 215px; -} - -/* Prevent progress bar to block gallery navigation while building images (Gradio V3.19.0) */ -#gallery .wrap.default { - pointer-events: none; -} - -/* Import Png info box */ -#txt2img_prompt_image { - height: var(--size-32) !important; -} - -/* Hide "remove buttons" from ui dropdowns */ -#custom_model .token-remove.remove-all, -#lora_weights .token-remove.remove-all, -#scheduler .token-remove.remove-all, -#device .token-remove.remove-all, -#stencil_model .token-remove.remove-all { - display: none; -} - -/* Hide selected items from ui dropdowns */ -#custom_model .options .item .inner-item, -#scheduler .options .item .inner-item, -#device .options .item .inner-item, -#stencil_model .options .item .inner-item { - display:none; -} - -/* workarounds for container=false not currently working for dropdowns */ -.dropdown_no_container { - padding: 0 !important; -} - -#output_subdir_container { - background-color: var(--block-background-fill); - padding-right: 8px; -} - -/* reduced animation load when generating */ -.generating { - animation-play-state: paused !important; -} - -/* better clarity when progress bars are minimal */ -.meta-text { - background-color: var(--block-label-background-fill); -} - -/* lora tag pills */ -.lora-tags { - border: 1px solid var(--border-color-primary); - color: var(--block-info-text-color) !important; - padding: var(--block-padding); -} - -.lora-tag { - display: inline-block; - height: 2em; - color: rgb(212 212 212) !important; - margin-right: 5pt; - margin-bottom: 5pt; - padding: 2pt 5pt; - border-radius: 5pt; - white-space: nowrap; -} - -.lora-model { - margin-bottom: var(--spacing-lg); - color: var(--block-info-text-color) !important; - line-height: var(--line-sm); -} - -/* output gallery tab */ -.output_parameters_dataframe table.table { -/* works around a gradio bug that always shows scrollbars */ - overflow: clip auto; -} - -.output_parameters_dataframe .cell-wrap span { - /* inadequate workaround for gradio issue #6086 */ - user-select:text !important; - -moz-user-select:text !important; - -webkit-user-select:text !important; - -o-user-select:text !important; - -ms-user-select:text !important; -} - -.output_parameters_dataframe tbody td { - font-size: small; - line-height: var(--line-xs); -} - -.output_icon_button { - max-width: 30px; - align-self: end; - padding-bottom: 16px !important; -} - -.outputgallery_sendto { - min-width: 7em !important; -} - -/* output gallery should take up most of the viewport height regardless of image size/number */ -#outputgallery_gallery .fixed-height { - min-height: 89vh !important; -} - -/* don't stretch non-square images to be square, breaking their aspect ratio */ -#outputgallery_gallery .thumbnail-item.thumbnail-lg > img { - object-fit: contain !important; -} - -/* use the whole gallery area for previeews */ -#outputgallery_gallery .preview { - width: inherit; -} - -/* centered logo for when there are no images */ -#top_logo.logo_centered { - height: 100%; - width: 100%; -} - -#top_logo.logo_centered img{ - object-fit: scale-down; - position: absolute; - width: 80%; - top: 50%; - left: 50%; - transform: translate(-50%, -50%); -} diff --git a/apps/stable_diffusion/web/ui/generate_config.py b/apps/stable_diffusion/web/ui/generate_config.py deleted file mode 100644 index f63c4e50..00000000 --- a/apps/stable_diffusion/web/ui/generate_config.py +++ /dev/null @@ -1,41 +0,0 @@ -import gradio as gr -import torch -from transformers import AutoTokenizer -from apps.language_models.src.model_wrappers.vicuna_model import CombinedModel -from shark.shark_generate_model_config import GenerateConfigFile - - -def get_model_config(): - hf_model_path = "TheBloke/vicuna-7B-1.1-HF" - tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False) - compilation_prompt = "".join(["0" for _ in range(17)]) - compilation_input_ids = tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor(compilation_input_ids).reshape( - [1, 19] - ) - firstVicunaCompileInput = (compilation_input_ids,) - - model = CombinedModel() - c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput) - return c.split_into_layers() - - -with gr.Blocks() as model_config_web: - with gr.Row(): - hf_models = gr.Dropdown( - label="Model List", - choices=["Vicuna"], - value="Vicuna", - visible=True, - ) - get_model_config_btn = gr.Button(value="Get Model Config") - json_view = gr.JSON() - - get_model_config_btn.click( - fn=get_model_config, - inputs=[], - outputs=[json_view], - ) diff --git a/apps/stable_diffusion/web/ui/h2ogpt.py b/apps/stable_diffusion/web/ui/h2ogpt.py deleted file mode 100644 index 9db556a2..00000000 --- a/apps/stable_diffusion/web/ui/h2ogpt.py +++ /dev/null @@ -1,367 +0,0 @@ -import gradio as gr -import torch -import os -from pathlib import Path -from transformers import ( - AutoModelForCausalLM, -) -from apps.stable_diffusion.web.ui.utils import available_devices - -from apps.language_models.langchain.enums import ( - DocumentChoices, - LangChainAction, -) -import apps.language_models.langchain.gen as gen -from gpt_langchain import ( - path_to_docs, - create_or_update_db, -) -from apps.stable_diffusion.src import args - - -def user(message, history): - # Append the user's message to the conversation history - return "", history + [[message, ""]] - - -sharkModel = 0 -h2ogpt_model = 0 - - -# NOTE: Each `model_name` should have its own start message -start_message = """ - SHARK DocuChat - Chat with an AI, contextualized with provided files. -""" - - -def create_prompt(history): - system_message = start_message - for item in history: - print("His item: ", item) - - conversation = "<|endoftext|>".join( - [ - "<|endoftext|><|answer|>".join([item[0], item[1]]) - for item in history - ] - ) - - msg = system_message + conversation - msg = msg.strip() - return msg - - -def chat(curr_system_message, history, device, precision): - args.run_docuchat_web = True - global h2ogpt_model - global sharkModel - global h2ogpt_tokenizer - global model_state - global langchain - global userpath_selector - from apps.language_models.langchain.h2oai_pipeline import generate_token - - if h2ogpt_model == 0: - if "cuda" in device: - shark_device = "cuda" - elif "sync" in device: - shark_device = "cpu" - elif "task" in device: - shark_device = "cpu" - elif "vulkan" in device: - shark_device = "vulkan" - else: - print("unrecognized device") - - device = "cpu" if shark_device == "cpu" else "cuda" - - args.device = shark_device - args.precision = precision - - from apps.language_models.langchain.gen import Langchain - - langchain = Langchain(device, precision) - h2ogpt_model, h2ogpt_tokenizer, _ = langchain.get_model( - load_4bit=True - if device == "cuda" - else False, # load model in 4bit if device is cuda to save memory - load_gptq="", - use_safetensors=False, - infer_devices=True, - device=device, - base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", - inference_server="", - tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", - lora_weights="", - gpu_id=0, - reward_type=None, - local_files_only=False, - resume_download=True, - use_auth_token=False, - trust_remote_code=True, - offload_folder=None, - compile_model=False, - verbose=False, - ) - model_state = dict( - model=h2ogpt_model, - tokenizer=h2ogpt_tokenizer, - device=device, - base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", - tokenizer_base_model="h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3", - lora_weights="", - inference_server="", - prompt_type=None, - prompt_dict=None, - ) - from apps.language_models.langchain.h2oai_pipeline import ( - H2OGPTSHARKModel, - ) - - sharkModel = H2OGPTSHARKModel() - - prompt = create_prompt(history) - output_dict = langchain.evaluate( - model_state=model_state, - my_db_state=None, - instruction=prompt, - iinput="", - context="", - stream_output=True, - prompt_type="prompt_answer", - prompt_dict={ - "promptA": "", - "promptB": "", - "PreInstruct": "<|prompt|>", - "PreInput": None, - "PreResponse": "<|answer|>", - "terminate_response": [ - "<|prompt|>", - "<|answer|>", - "<|endoftext|>", - ], - "chat_sep": "<|endoftext|>", - "chat_turn_sep": "<|endoftext|>", - "humanstr": "<|prompt|>", - "botstr": "<|answer|>", - "generates_leading_space": False, - }, - temperature=0.1, - top_p=0.75, - top_k=40, - num_beams=1, - max_new_tokens=256, - min_new_tokens=0, - early_stopping=False, - max_time=180, - repetition_penalty=1.07, - num_return_sequences=1, - do_sample=False, - chat=True, - instruction_nochat=prompt, - iinput_nochat="", - langchain_mode="UserData", - langchain_action=LangChainAction.QUERY.value, - top_k_docs=3, - chunk=True, - chunk_size=512, - document_choice=[DocumentChoices.All_Relevant.name], - concurrency_count=1, - memory_restriction_level=2, - raise_generate_gpu_exceptions=False, - chat_context="", - use_openai_embedding=False, - use_openai_model=False, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - db_type="chroma", - n_jobs=-1, - first_para=False, - max_max_time=60 * 2, - model_state0=model_state, - model_lock=True, - user_path=userpath_selector.value, - ) - - output = generate_token(sharkModel, **output_dict) - for partial_text in output: - history[-1][1] = partial_text - yield history - return history - - -userpath_selector = gr.Textbox( - label="Document Directory", - value=str(os.path.abspath("apps/language_models/langchain/user_path/")), - interactive=True, - container=True, -) - -with gr.Blocks(title="DocuChat") as h2ogpt_web: - with gr.Row(): - supported_devices = available_devices - enabled = len(supported_devices) > 0 - # show cpu-task device first in list for chatbot - supported_devices = supported_devices[-1:] + supported_devices[:-1] - supported_devices = [x for x in supported_devices if "sync" not in x] - print(supported_devices) - device = gr.Dropdown( - label="Device", - value=supported_devices[0] - if enabled - else "Only CUDA Supported for now", - choices=supported_devices, - interactive=enabled, - allow_custom_value=True, - ) - precision = gr.Radio( - label="Precision", - value="fp16", - choices=[ - "int4", - "int8", - "fp16", - "fp32", - ], - visible=True, - ) - chatbot = gr.Chatbot(height=500) - with gr.Row(): - with gr.Column(): - msg = gr.Textbox( - label="Chat Message Box", - placeholder="Chat Message Box", - show_label=False, - interactive=enabled, - container=False, - ) - with gr.Column(): - with gr.Row(): - submit = gr.Button("Submit", interactive=enabled) - stop = gr.Button("Stop", interactive=enabled) - clear = gr.Button("Clear", interactive=enabled) - system_msg = gr.Textbox( - start_message, label="System Message", interactive=False, visible=False - ) - - submit_event = msg.submit( - fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False - ).then( - fn=chat, - inputs=[system_msg, chatbot, device, precision], - outputs=[chatbot], - queue=True, - ) - submit_click_event = submit.click( - fn=user, inputs=[msg, chatbot], outputs=[msg, chatbot], queue=False - ).then( - fn=chat, - inputs=[system_msg, chatbot, device, precision], - outputs=[chatbot], - queue=True, - ) - stop.click( - fn=None, - inputs=None, - outputs=None, - cancels=[submit_event, submit_click_event], - queue=False, - ) - clear.click(lambda: None, None, [chatbot], queue=False) - - -with gr.Blocks(title="DocuChat Upload") as h2ogpt_upload: - import pathlib - - upload_path = None - database = None - database_directory = os.path.abspath( - "apps/language_models/langchain/db_path/" - ) - - def read_path(): - global upload_path - filenames = [ - [f] - for f in os.listdir(upload_path) - if os.path.isfile(os.path.join(upload_path, f)) - ] - filenames.sort() - return filenames - - def upload_file(f): - names = [] - for tmpfile in f: - name = tmpfile.name.split("/")[-1] - basename = os.path.join(upload_path, name) - with open(basename, "wb") as w: - with open(tmpfile.name, "rb") as r: - w.write(r.read()) - update_or_create_db() - return read_path() - - def update_userpath(newpath): - global upload_path - upload_path = newpath - pathlib.Path(upload_path).mkdir(parents=True, exist_ok=True) - return read_path() - - def update_or_create_db(): - global database - global upload_path - - sources = path_to_docs( - upload_path, - verbose=True, - fail_any_exception=False, - n_jobs=-1, - chunk=True, - chunk_size=512, - url=None, - enable_captions=False, - captions_model=None, - caption_loader=None, - enable_ocr=False, - ) - - pathlib.Path(database_directory).mkdir(parents=True, exist_ok=True) - - database = create_or_update_db( - "chroma", - database_directory, - "UserData", - sources, - False, - True, - True, - "sentence-transformers/all-MiniLM-L6-v2", - ) - - def first_run(): - global database - if database is None: - update_or_create_db() - - update_userpath( - os.path.abspath("apps/language_models/langchain/user_path/") - ) - h2ogpt_upload.load(fn=first_run) - h2ogpt_web.load(fn=first_run) - - with gr.Column(): - text = gr.DataFrame( - col_count=(1, "fixed"), - type="array", - label="Documents", - value=read_path(), - ) - with gr.Row(): - upload = gr.UploadButton( - label="Upload documents", - file_count="multiple", - ) - upload.upload(fn=upload_file, inputs=upload, outputs=text) - userpath_selector.render() - userpath_selector.input( - fn=update_userpath, inputs=userpath_selector, outputs=text - ).then(fn=update_or_create_db) diff --git a/apps/stable_diffusion/web/ui/img2img_ui.py b/apps/stable_diffusion/web/ui/img2img_ui.py deleted file mode 100644 index a6df2463..00000000 --- a/apps/stable_diffusion/web/ui/img2img_ui.py +++ /dev/null @@ -1,1056 +0,0 @@ -import os -import torch -import time -import gradio as gr -import PIL -from math import ceil -from PIL import Image - -from gradio.components.image_editor import ( - Brush, - Eraser, - EditorData, - EditorValue, -) -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - nodlogo_loc, - get_custom_model_path, - get_custom_model_files, - scheduler_list_cpu_only, - predefined_models, - cancel_sd, -) -from apps.stable_diffusion.web.ui.common_ui_events import lora_changed -from apps.stable_diffusion.src import ( - args, - Image2ImagePipeline, - StencilPipeline, - resize_stencil, - get_schedulers, - set_init_device_flags, - utils, - save_output_img, -) -from apps.stable_diffusion.src.utils import ( - get_generated_imgs_path, - get_generation_text_info, - resampler_list, -) -from apps.stable_diffusion.src.utils.stencils import ( - CannyDetector, - OpenposeDetector, - ZoeDetector, -) -from apps.stable_diffusion.web.utils.common_label_calc import status_label -import numpy as np - - -# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. -init_iree_vulkan_target_triple = args.iree_vulkan_target_triple -init_use_tuned = args.use_tuned -init_import_mlir = args.import_mlir - - -# Exposed to UI. -def img2img_inf( - prompt: str, - negative_prompt: str, - image_dict, - height: int, - width: int, - steps: int, - strength: float, - guidance_scale: float, - seed: str | int, - batch_count: int, - batch_size: int, - scheduler: str, - model_id: str, - custom_vae: str, - precision: str, - device: str, - max_length: int, - save_metadata_to_json: bool, - save_metadata_to_png: bool, - lora_weights: str, - lora_hf_id: str, - ondemand: bool, - repeatable_seeds: bool, - resample_type: str, - control_mode: str, - stencils: list, - images: list, - preprocessed_hints: list, -): - from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - get_custom_vae_or_lora_weights, - Config, - ) - import apps.stable_diffusion.web.utils.global_obj as global_obj - from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_CANCEL, - ) - - args.prompts = [prompt] - args.negative_prompts = [negative_prompt] - args.guidance_scale = guidance_scale - args.seed = seed - args.steps = steps - args.strength = strength - args.scheduler = scheduler - args.img_path = "not none" - args.ondemand = ondemand - - for i, stencil in enumerate(stencils): - if images[i] is None and stencil is not None: - continue - if images[i] is not None: - if isinstance(images[i], dict): - images[i] = images[i]["composite"] - images[i] = images[i].convert("RGB") - - if image_dict is None and images[0] is None: - return None, "An Initial Image is required" - if isinstance(image_dict, PIL.Image.Image): - image = image_dict.convert("RGB") - elif image_dict: - image = image_dict["image"].convert("RGB") - else: - # TODO: enable t2i + controlnets - image = None - if image: - image, _, _ = resize_stencil(image, width, height) - - # set ckpt_loc and hf_model_id. - args.ckpt_loc = "" - args.hf_model_id = "" - args.custom_vae = "" - - # .safetensor or .chkpt on the custom model path - if model_id in get_custom_model_files(): - args.ckpt_loc = get_custom_model_pathfile(model_id) - # civitai download - elif "civitai" in model_id: - args.ckpt_loc = model_id - # either predefined or huggingface - else: - args.hf_model_id = model_id - - if custom_vae != "None": - args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") - - args.use_lora = get_custom_vae_or_lora_weights( - lora_weights, lora_hf_id, "lora" - ) - - args.save_metadata_to_json = save_metadata_to_json - args.write_metadata_to_png = save_metadata_to_png - - stencil_count = 0 - for stencil in stencils: - if stencil is not None: - stencil_count += 1 - if stencil_count > 0: - args.hf_model_id = "runwayml/stable-diffusion-v1-5" - elif "Shark" in args.scheduler: - print( - f"Shark schedulers are not supported. Switching to EulerDiscrete " - f"scheduler" - ) - args.scheduler = "EulerDiscrete" - cpu_scheduling = not args.scheduler.startswith("Shark") - args.precision = precision - dtype = torch.float32 if precision == "fp32" else torch.half - print(stencils) - new_config_obj = Config( - "img2img", - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - precision, - batch_size, - max_length, - height, - width, - device, - use_lora=args.use_lora, - stencils=stencils, - ondemand=ondemand, - ) - if ( - not global_obj.get_sd_obj() - or global_obj.get_cfg_obj() != new_config_obj - or any( - global_obj.get_cfg_obj().stencils[idx] != stencil - for idx, stencil in enumerate(stencils) - ) - ): - print("clearing config because you changed something important") - global_obj.clear_cache() - global_obj.set_cfg_obj(new_config_obj) - args.batch_count = batch_count - args.batch_size = batch_size - args.max_length = max_length - args.height = height - args.width = width - args.device = device.split("=>", 1)[1].strip() - args.iree_vulkan_target_triple = init_iree_vulkan_target_triple - args.use_tuned = init_use_tuned - args.import_mlir = init_import_mlir - set_init_device_flags() - model_id = ( - args.hf_model_id - if args.hf_model_id - else "runwayml/stable-diffusion-v1-5" - ) - global_obj.set_schedulers(get_schedulers(model_id)) - scheduler_obj = global_obj.get_scheduler(args.scheduler) - - if stencil_count > 0: - args.use_tuned = False - global_obj.set_sd_obj( - StencilPipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - stencils=stencils, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - ) - else: - global_obj.set_sd_obj( - Image2ImagePipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - ) - - global_obj.set_sd_scheduler(args.scheduler) - - start_time = time.time() - global_obj.get_sd_obj().log = "" - generated_imgs = [] - extra_info = {"STRENGTH": strength} - text_output = "" - try: - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) - except TypeError as error: - raise gr.Error(str(error)) from None - - for current_batch in range(batch_count): - out_imgs = global_obj.get_sd_obj().generate_images( - prompt, - negative_prompt, - image, - batch_size, - height, - width, - ceil(steps / strength), - strength, - guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - stencils, - images, - resample_type=resample_type, - control_mode=control_mode, - preprocessed_hints=preprocessed_hints, - ) - total_time = time.time() - start_time - text_output = get_generation_text_info( - seeds[: current_batch + 1], device - ) - text_output += "\n" + global_obj.get_sd_obj().log - text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec" - - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - else: - save_output_img( - out_imgs[0], - seeds[current_batch], - extra_info, - ) - generated_imgs.extend(out_imgs) - yield generated_imgs, text_output, status_label( - "Image-to-Image", current_batch + 1, batch_count, batch_size - ), stencils, images - - return generated_imgs, text_output, "", stencils, images - - -with gr.Blocks(title="Image-to-Image") as img2img_web: - # Stencils - # TODO: Add more stencils here - STENCIL_COUNT = 2 - stencils = gr.State([None] * STENCIL_COUNT) - images = gr.State([None] * STENCIL_COUNT) - preprocessed_hints = gr.State([None] * STENCIL_COUNT) - with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - show_download_button=False, - elem_id="top_logo", - width=150, - height=50, - ) - with gr.Row(elem_id="ui_body"): - with gr.Row(): - with gr.Column(scale=1, min_width=600): - # TODO: make this import image prompt info if it exists - img2img_init_image = gr.Image( - label="Input Image", - type="pil", - interactive=True, - sources=["upload"], - ) - with gr.Row(): - # janky fix for overflowing text - i2i_model_info = ( - f"Custom Model Path: {str(get_custom_model_path())}" - ) - img2img_custom_model = gr.Dropdown( - label=f"Models", - info="Select, or enter HuggingFace Model ID or Civitai model download URL", - elem_id="custom_model", - value=os.path.basename(args.ckpt_loc) - if args.ckpt_loc - else "stabilityai/stable-diffusion-2-1-base", - choices=get_custom_model_files() + predefined_models, - allow_custom_value=True, - scale=2, - ) - # janky fix for overflowing text - i2i_vae_info = (str(get_custom_model_path("vae"))).replace( - "\\", "\n\\" - ) - i2i_vae_info = f"VAE Path: {i2i_vae_info}" - custom_vae = gr.Dropdown( - label=f"Custom VAE Models", - info=i2i_vae_info, - elem_id="custom_model", - value=os.path.basename(args.custom_vae) - if args.custom_vae - else "None", - choices=["None"] + get_custom_model_files("vae"), - allow_custom_value=True, - scale=1, - ) - - with gr.Group(elem_id="prompt_box_outer"): - prompt = gr.Textbox( - label="Prompt", - value=args.prompts[0], - lines=2, - elem_id="prompt_box", - ) - negative_prompt = gr.Textbox( - label="Negative Prompt", - value=args.negative_prompts[0], - lines=2, - elem_id="negative_prompt_box", - ) - with gr.Accordion(label="Multistencil Options", open=False): - choices = [ - "None", - "canny", - "openpose", - "scribble", - "zoedepth", - ] - - def cnet_preview( - model, - input_image, - index, - stencils, - images, - preprocessed_hints, - ): - if isinstance(input_image, PIL.Image.Image): - img_dict = { - "background": None, - "layers": [None], - "composite": input_image, - } - input_image = EditorValue(img_dict) - images[index] = input_image - if model: - stencils[index] = model - match model: - case "canny": - canny = CannyDetector() - result = canny( - np.array(input_image["composite"]), - 100, - 200, - ) - preprocessed_hints[index] = Image.fromarray( - result - ) - return ( - Image.fromarray(result), - stencils, - images, - preprocessed_hints, - ) - case "openpose": - openpose = OpenposeDetector() - result = openpose( - np.array(input_image["composite"]) - ) - preprocessed_hints[index] = Image.fromarray( - result[0] - ) - return ( - Image.fromarray(result[0]), - stencils, - images, - preprocessed_hints, - ) - case "zoedepth": - zoedepth = ZoeDetector() - result = zoedepth( - np.array(input_image["composite"]) - ) - preprocessed_hints[index] = Image.fromarray( - result - ) - return ( - Image.fromarray(result), - stencils, - images, - preprocessed_hints, - ) - case "scribble": - preprocessed_hints[index] = input_image[ - "composite" - ] - return ( - input_image["composite"], - stencils, - images, - preprocessed_hints, - ) - case _: - preprocessed_hints[index] = None - return ( - None, - stencils, - images, - preprocessed_hints, - ) - - def import_original(original_img, width, height): - resized_img, _, _ = resize_stencil( - original_img, width, height - ) - img_dict = { - "background": resized_img, - "layers": [resized_img], - "composite": None, - } - return gr.ImageEditor( - value=EditorValue(img_dict), - crop_size=(width, height), - ) - - def create_canvas(width, height): - data = Image.fromarray( - np.zeros( - shape=(height, width, 3), - dtype=np.uint8, - ) - + 255 - ) - img_dict = { - "background": data, - "layers": [data], - "composite": None, - } - return EditorValue(img_dict) - - def update_cn_input( - model, - width, - height, - stencils, - images, - preprocessed_hints, - index, - ): - if model == None: - stencils[index] = None - images[index] = None - preprocessed_hints[index] = None - return [ - gr.ImageEditor(value=None, visible=False), - gr.Image(value=None), - gr.Slider(visible=False), - gr.Slider(visible=False), - gr.Button(visible=False), - gr.Button(visible=False), - stencils, - images, - preprocessed_hints, - ] - elif model == "scribble": - return [ - gr.ImageEditor( - visible=True, - interactive=True, - show_label=False, - image_mode="RGB", - type="pil", - brush=Brush( - colors=["#000000"], - color_mode="fixed", - default_size=2, - ), - ), - gr.Image( - visible=True, - show_label=False, - interactive=True, - show_download_button=False, - ), - gr.Slider(visible=True, label="Canvas Width"), - gr.Slider(visible=True, label="Canvas Height"), - gr.Button(visible=True), - gr.Button(visible=False), - stencils, - images, - preprocessed_hints, - ] - else: - return [ - gr.ImageEditor( - visible=True, - image_mode="RGB", - type="pil", - interactive=True, - ), - gr.Image( - visible=True, - show_label=False, - interactive=True, - show_download_button=False, - ), - gr.Slider(visible=True, label="Input Width"), - gr.Slider(visible=True, label="Input Height"), - gr.Button(visible=False), - gr.Button(visible=True), - stencils, - images, - preprocessed_hints, - ] - - with gr.Row(): - with gr.Column(): - cnet_1 = gr.Button( - value="Generate controlnet input" - ) - cnet_1_model = gr.Dropdown( - label="Controlnet 1", - value="None", - choices=choices, - ) - canvas_width = gr.Slider( - label="Canvas Width", - minimum=256, - maximum=1024, - value=512, - step=1, - visible=False, - ) - canvas_height = gr.Slider( - label="Canvas Height", - minimum=256, - maximum=1024, - value=512, - step=1, - visible=False, - ) - make_canvas = gr.Button( - value="Make Canvas!", - visible=False, - ) - use_input_img_1 = gr.Button( - value="Use Original Image", - visible=False, - ) - - cnet_1_image = gr.ImageEditor( - visible=False, - image_mode="RGB", - interactive=True, - show_label=True, - label="Input Image", - type="pil", - ) - cnet_1_output = gr.Image( - value=None, - visible=True, - label="Preprocessed Hint", - interactive=True, - ) - - use_input_img_1.click( - import_original, - [img2img_init_image, canvas_width, canvas_height], - [cnet_1_image], - ) - - cnet_1_model.change( - fn=( - lambda m, w, h, s, i, p: update_cn_input( - m, w, h, s, i, p, 0 - ) - ), - inputs=[ - cnet_1_model, - canvas_width, - canvas_height, - stencils, - images, - preprocessed_hints, - ], - outputs=[ - cnet_1_image, - cnet_1_output, - canvas_width, - canvas_height, - make_canvas, - use_input_img_1, - stencils, - images, - preprocessed_hints, - ], - ) - make_canvas.click( - create_canvas, - [canvas_width, canvas_height], - [ - cnet_1_image, - ], - ) - gr.on( - triggers=[cnet_1.click], - fn=( - lambda a, b, s, i, p: cnet_preview( - a, b, 0, s, i, p - ) - ), - inputs=[ - cnet_1_model, - cnet_1_image, - stencils, - images, - preprocessed_hints, - ], - outputs=[ - cnet_1_output, - stencils, - images, - preprocessed_hints, - ], - ) - with gr.Row(): - with gr.Column(): - cnet_2 = gr.Button( - value="Generate controlnet input" - ) - cnet_2_model = gr.Dropdown( - label="Controlnet 2", - value="None", - choices=choices, - ) - canvas_width = gr.Slider( - label="Canvas Width", - minimum=256, - maximum=1024, - value=512, - step=1, - visible=False, - ) - canvas_height = gr.Slider( - label="Canvas Height", - minimum=256, - maximum=1024, - value=512, - step=1, - visible=False, - ) - make_canvas = gr.Button( - value="Make Canvas!", - visible=False, - ) - use_input_img_2 = gr.Button( - value="Use Original Image", - visible=False, - ) - cnet_2_image = gr.ImageEditor( - visible=False, - image_mode="RGB", - interactive=True, - type="pil", - show_label=True, - label="Input Image", - ) - use_input_img_2.click( - import_original, - [img2img_init_image, canvas_width, canvas_height], - [cnet_2_image], - ) - cnet_2_output = gr.Image( - value=None, - visible=True, - label="Preprocessed Hint", - interactive=True, - ) - cnet_2_model.change( - fn=( - lambda m, w, h, s, i, p: update_cn_input( - m, w, h, s, i, p, 0 - ) - ), - inputs=[ - cnet_2_model, - canvas_width, - canvas_height, - stencils, - images, - preprocessed_hints, - ], - outputs=[ - cnet_2_image, - cnet_2_output, - canvas_width, - canvas_height, - make_canvas, - use_input_img_2, - stencils, - images, - preprocessed_hints, - ], - ) - make_canvas.click( - create_canvas, - [canvas_width, canvas_height], - [ - cnet_2_image, - ], - ) - cnet_2.click( - fn=( - lambda a, b, s, i, p: cnet_preview( - a, b, 1, s, i, p - ) - ), - inputs=[ - cnet_2_model, - cnet_2_image, - stencils, - images, - preprocessed_hints, - ], - outputs=[ - cnet_2_output, - stencils, - images, - preprocessed_hints, - ], - ) - control_mode = gr.Radio( - choices=["Prompt", "Balanced", "Controlnet"], - value="Balanced", - label="Control Mode", - ) - - with gr.Accordion(label="LoRA Options", open=False): - with gr.Row(): - # janky fix for overflowing text - i2i_lora_info = ( - str(get_custom_model_path("lora")) - ).replace("\\", "\n\\") - i2i_lora_info = f"LoRA Path: {i2i_lora_info}" - lora_weights = gr.Dropdown( - allow_custom_value=True, - label=f"Standalone LoRA Weights", - info=i2i_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), - ) - lora_hf_id = gr.Textbox( - elem_id="lora_hf_id", - placeholder="Select 'None' in the Standalone LoRA " - "weights dropdown on the left if you want to use " - "a standalone HuggingFace model ID for LoRA here " - "e.g: sayakpaul/sd-model-finetuned-lora-t4", - value="", - label="HuggingFace Model ID", - lines=3, - ) - with gr.Row(): - lora_tags = gr.HTML( - value="
    No LoRA selected
    ", - elem_classes="lora-tags", - ) - with gr.Accordion(label="Advanced Options", open=False): - with gr.Row(): - scheduler = gr.Dropdown( - elem_id="scheduler", - label="Scheduler", - value="EulerDiscrete", - choices=scheduler_list_cpu_only, - allow_custom_value=True, - ) - with gr.Group(): - save_metadata_to_png = gr.Checkbox( - label="Save prompt information to PNG", - value=args.write_metadata_to_png, - interactive=True, - ) - save_metadata_to_json = gr.Checkbox( - label="Save prompt information to JSON file", - value=args.save_metadata_to_json, - interactive=True, - ) - with gr.Row(): - height = gr.Slider( - 384, 768, value=args.height, step=8, label="Height" - ) - width = gr.Slider( - 384, 768, value=args.width, step=8, label="Width" - ) - max_length = gr.Radio( - label="Max Length", - value=args.max_length, - choices=[ - 64, - 77, - ], - visible=False, - ) - with gr.Row(): - with gr.Column(scale=3): - steps = gr.Slider( - 1, 100, value=args.steps, step=1, label="Steps" - ) - with gr.Column(scale=3): - strength = gr.Slider( - 0, - 1, - value=args.strength, - step=0.01, - label="Denoising Strength", - ) - resample_type = gr.Dropdown( - value=args.resample_type, - choices=resampler_list, - label="Resample Type", - allow_custom_value=True, - ) - ondemand = gr.Checkbox( - value=args.ondemand, - label="Low VRAM", - interactive=True, - ) - precision = gr.Radio( - label="Precision", - value=args.precision, - choices=[ - "fp16", - "fp32", - ], - visible=True, - ) - with gr.Row(): - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=args.guidance_scale, - step=0.1, - label="CFG Scale", - ) - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=args.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - repeatable_seeds = gr.Checkbox( - args.repeatable_seeds, - label="Repeatable Seeds", - ) - with gr.Row(): - batch_size = gr.Slider( - 1, - 4, - value=args.batch_size, - step=1, - label="Batch Size", - interactive=False, - visible=False, - ) - with gr.Row(): - seed = gr.Textbox( - value=args.seed, - label="Seed", - info="An integer or a JSON list of integers, -1 for random", - ) - device = gr.Dropdown( - elem_id="device", - label="Device", - value=available_devices[0], - choices=available_devices, - allow_custom_value=True, - ) - - with gr.Column(scale=1, min_width=600): - with gr.Group(): - img2img_gallery = gr.Gallery( - label="Generated images", - show_label=False, - elem_id="gallery", - columns=2, - object_fit="contain", - # TODO: Re-enable download when fixed in Gradio - show_download_button=False, - ) - std_output = gr.Textbox( - value=f"{i2i_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - lines=2, - elem_id="std_output", - show_label=False, - ) - img2img_status = gr.Textbox(visible=False) - with gr.Row(): - stable_diffusion = gr.Button("Generate Image(s)") - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], - queue=False, - ) - stop_batch = gr.Button("Stop Batch") - with gr.Row(): - blank_thing_for_row = None - with gr.Row(): - img2img_sendto_inpaint = gr.Button(value="SendTo Inpaint") - img2img_sendto_outpaint = gr.Button( - value="SendTo Outpaint" - ) - img2img_sendto_upscaler = gr.Button( - value="SendTo Upscaler" - ) - - kwargs = dict( - fn=img2img_inf, - inputs=[ - prompt, - negative_prompt, - img2img_init_image, - height, - width, - steps, - strength, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - img2img_custom_model, - custom_vae, - precision, - device, - max_length, - save_metadata_to_json, - save_metadata_to_png, - lora_weights, - lora_hf_id, - ondemand, - repeatable_seeds, - resample_type, - control_mode, - stencils, - images, - preprocessed_hints, - ], - outputs=[ - img2img_gallery, - std_output, - img2img_status, - stencils, - images, - ], - show_progress="minimal" if args.progress_bar else "none", - ) - - status_kwargs = dict( - fn=lambda bc, bs: status_label("Image-to-Image", 0, bc, bs), - inputs=[batch_count, batch_size], - outputs=img2img_status, - ) - - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs - ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) - stop_batch.click( - fn=cancel_sd, - cancels=[prompt_submit, neg_prompt_submit, generate_click], - ) - - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/stable_diffusion/web/ui/inpaint_ui.py b/apps/stable_diffusion/web/ui/inpaint_ui.py deleted file mode 100644 index 4ce4795a..00000000 --- a/apps/stable_diffusion/web/ui/inpaint_ui.py +++ /dev/null @@ -1,624 +0,0 @@ -import os -import torch -import time -import sys -import gradio as gr -import PIL.ImageOps -from PIL import Image - -from gradio.components.image_editor import ( - Brush, - Eraser, - EditorData, - EditorValue, -) -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - nodlogo_loc, - get_custom_model_path, - get_custom_model_files, - scheduler_list_cpu_only, - predefined_paint_models, - cancel_sd, -) -from apps.stable_diffusion.web.ui.common_ui_events import lora_changed -from apps.stable_diffusion.src import ( - args, - InpaintPipeline, - get_schedulers, - set_init_device_flags, - utils, - clear_all, - save_output_img, -) -from apps.stable_diffusion.src.utils import ( - get_generated_imgs_path, - get_generation_text_info, -) -from apps.stable_diffusion.web.utils.common_label_calc import status_label - - -# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. -init_iree_vulkan_target_triple = args.iree_vulkan_target_triple -init_use_tuned = args.use_tuned -init_import_mlir = args.import_mlir - - -def set_image_states(editor_data): - input_mask = editor_data["layers"][0] - - # inpaint_inf wants white mask on black background (?), whilst ImageEditor - # delivers black mask on transparent (0 opacity) background - inference_mask = Image.new( - mode="RGB", size=input_mask.size, color=(255, 255, 255) - ) - inference_mask.paste(input_mask, input_mask) - inference_mask = PIL.ImageOps.invert(inference_mask) - - return ( - # we set the ImageEditor data again, because it likes to clear - # the image layers (which include the mask) if the user hasn't - # used the upload button, and we sent it and image - # TODO: work out what is going wrong in that case so we don't have - # to do this - { - "background": editor_data["background"], - "layers": [input_mask], - "composite": None, - }, - editor_data["background"], - input_mask, - inference_mask, - ) - - -def reload_image_editor(editor_image, editor_mask): - # we set the ImageEditor data again, because it likes to clear - # the image layers (which include the mask) if the user hasn't - # used the upload button, and we sent it the image - # TODO: work out what is going wrong in that case so we don't have - # to do this - return { - "background": editor_image, - "layers": [editor_mask], - "composite": None, - } - - -# Exposed to UI. -def inpaint_inf( - prompt: str, - negative_prompt: str, - image, - mask_image, - height: int, - width: int, - inpaint_full_res: bool, - inpaint_full_res_padding: int, - steps: int, - guidance_scale: float, - seed: str | int, - batch_count: int, - batch_size: int, - scheduler: str, - model_id: str, - custom_vae: str, - precision: str, - device: str, - max_length: int, - save_metadata_to_json: bool, - save_metadata_to_png: bool, - lora_weights: str, - lora_hf_id: str, - ondemand: bool, - repeatable_seeds: int, -): - from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - get_custom_vae_or_lora_weights, - Config, - ) - import apps.stable_diffusion.web.utils.global_obj as global_obj - from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_CANCEL, - ) - - args.prompts = [prompt] - args.negative_prompts = [negative_prompt] - args.guidance_scale = guidance_scale - args.steps = steps - args.scheduler = scheduler - args.img_path = "not none" - args.mask_path = "not none" - args.ondemand = ondemand - - # set ckpt_loc and hf_model_id. - args.ckpt_loc = "" - args.hf_model_id = "" - args.custom_vae = "" - - # .safetensor or .chkpt on the custom model path - if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"): - args.ckpt_loc = get_custom_model_pathfile(model_id) - # civitai download - elif "civitai" in model_id: - args.ckpt_loc = model_id - # either predefined or huggingface - else: - args.hf_model_id = model_id - - if custom_vae != "None": - args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") - - args.use_lora = get_custom_vae_or_lora_weights( - lora_weights, lora_hf_id, "lora" - ) - - args.save_metadata_to_json = save_metadata_to_json - args.write_metadata_to_png = save_metadata_to_png - - dtype = torch.float32 if precision == "fp32" else torch.half - cpu_scheduling = not scheduler.startswith("Shark") - new_config_obj = Config( - "inpaint", - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - precision, - batch_size, - max_length, - height, - width, - device, - use_lora=args.use_lora, - stencils=[], - ondemand=ondemand, - ) - if ( - not global_obj.get_sd_obj() - or global_obj.get_cfg_obj() != new_config_obj - ): - global_obj.clear_cache() - global_obj.set_cfg_obj(new_config_obj) - args.precision = precision - args.batch_count = batch_count - args.batch_size = batch_size - args.max_length = max_length - args.height = height - args.width = width - args.device = device.split("=>", 1)[1].strip() - args.iree_vulkan_target_triple = init_iree_vulkan_target_triple - args.use_tuned = init_use_tuned - args.import_mlir = init_import_mlir - set_init_device_flags() - model_id = ( - args.hf_model_id - if args.hf_model_id - else "stabilityai/stable-diffusion-2-inpainting" - ) - global_obj.set_schedulers(get_schedulers(model_id)) - scheduler_obj = global_obj.get_scheduler(scheduler) - global_obj.set_sd_obj( - InpaintPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - custom_vae=args.custom_vae, - precision=args.precision, - max_length=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - ) - - global_obj.set_sd_scheduler(scheduler) - - start_time = time.time() - global_obj.get_sd_obj().log = "" - generated_imgs = [] - text_output = "" - try: - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) - except TypeError as error: - raise gr.Error(str(error)) from None - - for current_batch in range(batch_count): - out_imgs = global_obj.get_sd_obj().generate_images( - prompt, - negative_prompt, - image, - mask_image, - batch_size, - height, - width, - inpaint_full_res, - inpaint_full_res_padding, - steps, - guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - total_time = time.time() - start_time - text_output = get_generation_text_info( - seeds[: current_batch + 1], device - ) - text_output += "\n" + global_obj.get_sd_obj().log - text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec" - - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - else: - save_output_img(out_imgs[0], seeds[current_batch]) - generated_imgs.extend(out_imgs) - yield generated_imgs, text_output, status_label( - "Inpaint", current_batch + 1, batch_count, batch_size - ) - - return generated_imgs, text_output - - -with gr.Blocks(title="Inpainting") as inpaint_web: - editor_image = gr.State() - editor_mask = gr.State() - inference_mask = gr.State() - with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - show_download_button=False, - elem_id="top_logo", - width=150, - height=50, - ) - with gr.Row(elem_id="ui_body"): - with gr.Row(): - with gr.Column(scale=1, min_width=600): - inpaint_init_image = gr.Sketchpad( - label="Masked Image", - type="pil", - sources=("clipboard", "upload"), - interactive=True, - brush=Brush( - colors=["#000000"], - color_mode="fixed", - ), - ) - with gr.Row(): - # janky fix for overflowing text - inpaint_model_info = ( - f"Custom Model Path: {str(get_custom_model_path())}" - ) - inpaint_custom_model = gr.Dropdown( - label=f"Models", - info="Select, or enter HuggingFace Model ID or Civitai model download URL", - elem_id="custom_model", - value=os.path.basename(args.ckpt_loc) - if args.ckpt_loc - else "stabilityai/stable-diffusion-2-inpainting", - choices=get_custom_model_files( - custom_checkpoint_type="inpainting" - ) - + predefined_paint_models, - allow_custom_value=True, - scale=2, - ) - # janky fix for overflowing text - inpaint_vae_info = ( - str(get_custom_model_path("vae")) - ).replace("\\", "\n\\") - inpaint_vae_info = f"VAE Path: {inpaint_vae_info}" - custom_vae = gr.Dropdown( - label=f"Custom VAE Models", - info=inpaint_vae_info, - elem_id="custom_model", - value=os.path.basename(args.custom_vae) - if args.custom_vae - else "None", - choices=["None"] + get_custom_model_files("vae"), - allow_custom_value=True, - scale=1, - ) - - with gr.Group(elem_id="prompt_box_outer"): - prompt = gr.Textbox( - label="Prompt", - value=args.prompts[0], - lines=2, - elem_id="prompt_box", - ) - negative_prompt = gr.Textbox( - label="Negative Prompt", - value=args.negative_prompts[0], - lines=2, - elem_id="negative_prompt_box", - ) - with gr.Accordion(label="LoRA Options", open=False): - with gr.Row(): - # janky fix for overflowing text - inpaint_lora_info = ( - str(get_custom_model_path("lora")) - ).replace("\\", "\n\\") - inpaint_lora_info = f"LoRA Path: {inpaint_lora_info}" - lora_weights = gr.Dropdown( - label=f"Standalone LoRA Weights", - info=inpaint_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), - allow_custom_value=True, - ) - lora_hf_id = gr.Textbox( - elem_id="lora_hf_id", - placeholder="Select 'None' in the Standalone LoRA " - "weights dropdown on the left if you want to use " - "a standalone HuggingFace model ID for LoRA here " - "e.g: sayakpaul/sd-model-finetuned-lora-t4", - value="", - label="HuggingFace Model ID", - lines=3, - ) - with gr.Row(): - lora_tags = gr.HTML( - value="
    No LoRA selected
    ", - elem_classes="lora-tags", - ) - with gr.Accordion(label="Advanced Options", open=False): - with gr.Row(): - scheduler = gr.Dropdown( - elem_id="scheduler", - label="Scheduler", - value="EulerDiscrete", - choices=scheduler_list_cpu_only, - allow_custom_value=True, - ) - with gr.Group(): - save_metadata_to_png = gr.Checkbox( - label="Save prompt information to PNG", - value=args.write_metadata_to_png, - interactive=True, - ) - save_metadata_to_json = gr.Checkbox( - label="Save prompt information to JSON file", - value=args.save_metadata_to_json, - interactive=True, - ) - with gr.Row(): - height = gr.Slider( - 384, 768, value=args.height, step=8, label="Height" - ) - width = gr.Slider( - 384, 768, value=args.width, step=8, label="Width" - ) - precision = gr.Radio( - label="Precision", - value=args.precision, - choices=[ - "fp16", - "fp32", - ], - visible=False, - ) - max_length = gr.Radio( - label="Max Length", - value=args.max_length, - choices=[ - 64, - 77, - ], - visible=False, - ) - with gr.Row(): - inpaint_full_res = gr.Radio( - choices=["Whole picture", "Only masked"], - type="index", - value="Whole picture", - label="Inpaint area", - ) - inpaint_full_res_padding = gr.Slider( - minimum=0, - maximum=256, - step=4, - value=32, - label="Only masked padding, pixels", - ) - with gr.Row(): - steps = gr.Slider( - 1, 100, value=args.steps, step=1, label="Steps" - ) - ondemand = gr.Checkbox( - value=args.ondemand, - label="Low VRAM", - interactive=True, - ) - with gr.Row(): - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=args.guidance_scale, - step=0.1, - label="CFG Scale", - ) - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=args.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - repeatable_seeds = gr.Checkbox( - args.repeatable_seeds, - label="Repeatable Seeds", - ) - with gr.Row(): - batch_size = gr.Slider( - 1, - 4, - value=args.batch_size, - step=1, - label="Batch Size", - interactive=False, - visible=False, - ) - with gr.Row(): - seed = gr.Textbox( - value=args.seed, - label="Seed", - info="An integer or a JSON list of integers, -1 for random", - ) - device = gr.Dropdown( - elem_id="device", - label="Device", - value=available_devices[0], - choices=available_devices, - allow_custom_value=True, - ) - - with gr.Column(scale=1, min_width=600): - with gr.Group(): - inpaint_gallery = gr.Gallery( - label="Generated images", - show_label=False, - elem_id="gallery", - columns=[2], - object_fit="contain", - # TODO: Re-enable download when fixed in Gradio - show_download_button=False, - ) - std_output = gr.Textbox( - value=f"{inpaint_model_info}\n" - "Images will be saved at " - f"{get_generated_imgs_path()}", - lines=2, - elem_id="std_output", - show_label=False, - ) - inpaint_status = gr.Textbox(visible=False) - with gr.Row(): - stable_diffusion = gr.Button("Generate Image(s)") - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], - queue=False, - ) - stop_batch = gr.Button("Stop Batch") - with gr.Row(): - blank_thing_for_row = None - with gr.Row(): - inpaint_sendto_img2img = gr.Button(value="SendTo Img2Img") - inpaint_sendto_outpaint = gr.Button( - value="SendTo Outpaint" - ) - inpaint_sendto_upscaler = gr.Button( - value="SendTo Upscaler" - ) - - kwargs = dict( - fn=inpaint_inf, - inputs=[ - prompt, - negative_prompt, - editor_image, - inference_mask, - height, - width, - inpaint_full_res, - inpaint_full_res_padding, - steps, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - inpaint_custom_model, - custom_vae, - precision, - device, - max_length, - save_metadata_to_json, - save_metadata_to_png, - lora_weights, - lora_hf_id, - ondemand, - repeatable_seeds, - ], - outputs=[inpaint_gallery, std_output, inpaint_status], - show_progress="minimal" if args.progress_bar else "none", - ) - status_kwargs = dict( - fn=lambda bc, bs: status_label("Inpaint", 0, bc, bs), - inputs=[batch_count, batch_size], - outputs=inpaint_status, - show_progress="none", - ) - set_image_states_args = dict( - fn=set_image_states, - inputs=[inpaint_init_image], - outputs=[ - inpaint_init_image, - editor_image, - editor_mask, - inference_mask, - ], - show_progress="none", - ) - reload_image_editor_args = dict( - fn=reload_image_editor, - inputs=[editor_image, editor_mask], - outputs=[inpaint_init_image], - show_progress="none", - ) - - # all these trigger generation - prompt_submit = ( - prompt.submit(**set_image_states_args) - .then(**status_kwargs) - .then(**kwargs) - .then(**reload_image_editor_args) - ) - neg_prompt_submit = ( - negative_prompt.submit(**set_image_states_args) - .then(**status_kwargs) - .then(**kwargs) - .then(**reload_image_editor_args) - ) - generate_click = ( - stable_diffusion.click(**set_image_states_args) - .then(**status_kwargs) - .then(**kwargs) - .then(**reload_image_editor_args) - ) - - # Attempts to cancel generation - stop_batch.click( - fn=cancel_sd, - cancels=[prompt_submit, neg_prompt_submit, generate_click], - ) - - # Updates LoRA information when one is selected - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/stable_diffusion/web/ui/logos/nod-icon.png b/apps/stable_diffusion/web/ui/logos/nod-icon.png deleted file mode 100644 index 29f7e322..00000000 Binary files a/apps/stable_diffusion/web/ui/logos/nod-icon.png and /dev/null differ diff --git a/apps/stable_diffusion/web/ui/logos/nod-logo.png b/apps/stable_diffusion/web/ui/logos/nod-logo.png deleted file mode 100644 index 4727e15a..00000000 Binary files a/apps/stable_diffusion/web/ui/logos/nod-logo.png and /dev/null differ diff --git a/apps/stable_diffusion/web/ui/lora_train_ui.py b/apps/stable_diffusion/web/ui/lora_train_ui.py deleted file mode 100644 index 45c1c3e2..00000000 --- a/apps/stable_diffusion/web/ui/lora_train_ui.py +++ /dev/null @@ -1,251 +0,0 @@ -from pathlib import Path -import os -import gradio as gr -from PIL import Image -from apps.stable_diffusion.scripts import lora_train -from apps.stable_diffusion.src import prompt_examples, args, utils -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - nodlogo_loc, - get_custom_model_path, - get_custom_model_files, - get_custom_vae_or_lora_weights, - scheduler_list, - predefined_models, -) - -with gr.Blocks(title="Lora Training") as lora_train_web: - with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - show_download_button=False, - elem_id="top_logo", - width=150, - height=50, - ) - with gr.Row(elem_id="ui_body"): - with gr.Row(): - with gr.Column(scale=1, min_width=600): - with gr.Row(): - with gr.Column(scale=10): - with gr.Row(): - # janky fix for overflowing text - train_lora_model_info = ( - str(get_custom_model_path()) - ).replace("\\", "\n\\") - train_lora_model_info = ( - f"Custom Model Path: {train_lora_model_info}" - ) - custom_model = gr.Dropdown( - label=f"Models", - info=train_lora_model_info, - elem_id="custom_model", - value=os.path.basename(args.ckpt_loc) - if args.ckpt_loc - else "None", - choices=["None"] - + get_custom_model_files() - + predefined_models, - allow_custom_value=True, - ) - hf_model_id = gr.Textbox( - elem_id="hf_model_id", - placeholder="Select 'None' in the Models " - "dropdown on the left and enter model ID here " - "e.g: SG161222/Realistic_Vision_V1.3", - value="", - label="HuggingFace Model ID", - lines=3, - ) - - with gr.Row(): - # janky fix for overflowing text - train_lora_info = ( - str(get_custom_model_path("lora")) - ).replace("\\", "\n\\") - train_lora_info = f"LoRA Path: {train_lora_info}" - lora_weights = gr.Dropdown( - label=f"Standalone LoRA weights to initialize weights", - info=train_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), - allow_custom_value=True, - ) - lora_hf_id = gr.Textbox( - elem_id="lora_hf_id", - placeholder="Select 'None' in the Standalone LoRA " - "weights dropdown on the left if you want to use a " - "standalone HuggingFace model ID for LoRA here " - "e.g: sayakpaul/sd-model-finetuned-lora-t4", - value="", - label="HuggingFace Model ID to initialize weights", - lines=3, - ) - with gr.Group(elem_id="image_dir_box_outer"): - training_images_dir = gr.Textbox( - label="ImageDirectory", - value=args.training_images_dir, - lines=1, - elem_id="prompt_box", - ) - with gr.Group(elem_id="prompt_box_outer"): - prompt = gr.Textbox( - label="Prompt", - value=args.prompts[0], - lines=2, - elem_id="prompt_box", - ) - with gr.Accordion(label="Advanced Options", open=False): - with gr.Row(): - scheduler = gr.Dropdown( - elem_id="scheduler", - label="Scheduler", - value=args.scheduler, - choices=scheduler_list, - allow_custom_value=True, - ) - with gr.Row(): - height = gr.Slider( - 384, 768, value=args.height, step=8, label="Height" - ) - width = gr.Slider( - 384, 768, value=args.width, step=8, label="Width" - ) - precision = gr.Radio( - label="Precision", - value=args.precision, - choices=[ - "fp16", - "fp32", - ], - visible=False, - ) - max_length = gr.Radio( - label="Max Length", - value=args.max_length, - choices=[ - 64, - 77, - ], - visible=False, - ) - with gr.Row(): - steps = gr.Slider( - 1, - 2000, - value=args.training_steps, - step=1, - label="Training Steps", - ) - guidance_scale = gr.Slider( - 0, - 50, - value=args.guidance_scale, - step=0.1, - label="CFG Scale", - ) - with gr.Row(): - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=args.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - with gr.Column(scale=3): - batch_size = gr.Slider( - 1, - 4, - value=args.batch_size, - step=1, - label="Batch Size", - interactive=True, - ) - stop_batch = gr.Button("Stop Batch") - with gr.Row(): - seed = gr.Number( - value=utils.parse_seed_input(args.seed)[0], - precision=0, - label="Seed", - ) - device = gr.Dropdown( - elem_id="device", - label="Device", - value=available_devices[0], - choices=available_devices, - allow_custom_value=True, - ) - with gr.Row(): - with gr.Column(scale=2): - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], - queue=False, - ) - with gr.Column(scale=6): - train_lora = gr.Button("Train LoRA") - - with gr.Accordion(label="Prompt Examples!", open=False): - ex = gr.Examples( - examples=prompt_examples, - inputs=prompt, - cache_examples=False, - elem_id="prompt_examples", - ) - - with gr.Column(scale=1, min_width=600): - with gr.Group(): - std_output = gr.Textbox( - value="Nothing to show.", - lines=1, - show_label=False, - ) - lora_save_dir = ( - args.lora_save_dir if args.lora_save_dir else Path.cwd() - ) - lora_save_dir = Path(lora_save_dir, "lora") - output_loc = gr.Textbox( - label="Saving Lora at", - value=lora_save_dir, - ) - - kwargs = dict( - fn=lora_train, - inputs=[ - prompt, - height, - width, - steps, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - custom_model, - hf_model_id, - precision, - device, - max_length, - training_images_dir, - output_loc, - get_custom_vae_or_lora_weights( - lora_weights, lora_hf_id, "lora" - ), - ], - outputs=[std_output], - show_progress="minimal" if args.progress_bar else "none", - ) - - prompt_submit = prompt.submit(**kwargs) - train_click = train_lora.click(**kwargs) - stop_batch.click(fn=None, cancels=[prompt_submit, train_click]) diff --git a/apps/stable_diffusion/web/ui/minigpt4_ui.py b/apps/stable_diffusion/web/ui/minigpt4_ui.py deleted file mode 100644 index cf13f8e2..00000000 --- a/apps/stable_diffusion/web/ui/minigpt4_ui.py +++ /dev/null @@ -1,194 +0,0 @@ -# ======================================== -# Gradio Setting -# ======================================== -import gradio as gr - -# from apps.language_models.src.pipelines.minigpt4_pipeline import ( -# # MiniGPT4, -# CONV_VISION, -# ) -from pathlib import Path - -chat = None - - -def gradio_reset(chat_state, img_list): - if chat_state is not None: - chat_state.messages = [] - if img_list is not None: - img_list = [] - return ( - None, - gr.update(value=None, interactive=True), - gr.update( - placeholder="Please upload your image first", interactive=False - ), - gr.update(value="Upload & Start Chat", interactive=True), - chat_state, - img_list, - ) - - -def upload_img(gr_img, text_input, chat_state, device, precision, _compile): - global chat - if chat is None: - from apps.language_models.src.pipelines.minigpt4_pipeline import ( - MiniGPT4, - CONV_VISION, - ) - - vision_model_precision = precision - if precision in ["int4", "int8"]: - vision_model_precision = "fp16" - vision_model_vmfb_path = Path( - f"vision_model_{vision_model_precision}_{device}.vmfb" - ) - qformer_vmfb_path = Path(f"qformer_fp32_{device}.vmfb") - chat = MiniGPT4( - model_name="MiniGPT4", - hf_model_path=None, - max_new_tokens=30, - device=device, - precision=precision, - _compile=_compile, - vision_model_vmfb_path=vision_model_vmfb_path, - qformer_vmfb_path=qformer_vmfb_path, - ) - if gr_img is None: - return None, None, gr.update(interactive=True), chat_state, None - chat_state = CONV_VISION.copy() - img_list = [] - llm_message = chat.upload_img(gr_img, chat_state, img_list) - return ( - gr.update(interactive=False), - gr.update(interactive=True, placeholder="Type and press Enter"), - gr.update(value="Start Chatting", interactive=False), - chat_state, - img_list, - ) - - -def gradio_ask(user_message, chatbot, chat_state): - if len(user_message) == 0: - return ( - gr.update( - interactive=True, placeholder="Input should not be empty!" - ), - chatbot, - chat_state, - ) - chat.ask(user_message, chat_state) - chatbot = chatbot + [[user_message, None]] - return "", chatbot, chat_state - - -def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature): - llm_message = chat.answer( - conv=chat_state, - img_list=img_list, - num_beams=num_beams, - temperature=temperature, - max_new_tokens=300, - max_length=2000, - )[0] - print(llm_message) - print("************") - chatbot[-1][1] = llm_message - return chatbot, chat_state, img_list - - -title = """

    MultiModal SHARK (experimental)

    """ -description = """

    Upload your images and start chatting!

    """ -article = """

    -""" - -# TODO show examples below - -with gr.Blocks() as minigpt4_web: - gr.Markdown(title) - gr.Markdown(description) - - with gr.Row(): - with gr.Column(): - image = gr.Image(type="pil") - upload_button = gr.Button( - value="Upload & Start Chat", - interactive=True, - variant="primary", - ) - clear = gr.Button("Restart") - - num_beams = gr.Slider( - minimum=1, - maximum=10, - value=1, - step=1, - interactive=True, - label="beam search numbers)", - ) - - temperature = gr.Slider( - minimum=0.1, - maximum=2.0, - value=1.0, - step=0.1, - interactive=True, - label="Temperature", - ) - - device = gr.Dropdown( - label="Device", - value="cuda", - # if enabled - # else "Only CUDA Supported for now", - choices=["cuda"], - interactive=False, - allow_custom_value=True, - ) - - with gr.Column(): - chat_state = gr.State() - img_list = gr.State() - chatbot = gr.Chatbot(label="MiniGPT-4") - text_input = gr.Textbox( - label="User", - placeholder="Please upload your image first", - interactive=False, - ) - precision = gr.Radio( - label="Precision", - value="int8", - choices=[ - "int8", - "fp16", - "fp32", - ], - visible=True, - ) - _compile = gr.Checkbox( - value=False, - label="Compile", - interactive=True, - ) - - upload_button.click( - upload_img, - [image, text_input, chat_state, device, precision, _compile], - [image, text_input, upload_button, chat_state, img_list], - ) - - text_input.submit( - gradio_ask, - [text_input, chatbot, chat_state], - [text_input, chatbot, chat_state], - ).then( - gradio_answer, - [chatbot, chat_state, img_list, num_beams, temperature], - [chatbot, chat_state, img_list], - ) - clear.click( - gradio_reset, - [chat_state, img_list], - [chatbot, image, text_input, upload_button, chat_state, img_list], - queue=False, - ) diff --git a/apps/stable_diffusion/web/ui/model_manager.py b/apps/stable_diffusion/web/ui/model_manager.py deleted file mode 100644 index 21c0939f..00000000 --- a/apps/stable_diffusion/web/ui/model_manager.py +++ /dev/null @@ -1,161 +0,0 @@ -import os -import gradio as gr -import requests -from io import BytesIO -from PIL import Image - - -def get_hf_list(num_of_models=20): - path = "https://huggingface.co/api/models" - params = { - "search": "stable-diffusion", - "sort": "downloads", - "direction": "-1", - "limit": {num_of_models}, - "full": "true", - } - response = requests.get(path, params=params) - return response.json() - - -def get_civit_list(num_of_models=50): - path = ( - f"https://civitai.com/api/v1/models?limit=" - f"{num_of_models}&types=Checkpoint" - ) - headers = {"Content-Type": "application/json"} - raw_json = requests.get(path, headers=headers).json() - models = list(raw_json.items())[0][1] - safe_models = [ - safe_model for safe_model in models if not safe_model["nsfw"] - ] - version_id = 0 # Currently just using the first version. - safe_models = [ - safe_model - for safe_model in safe_models - if safe_model["modelVersions"][version_id]["files"][0]["metadata"][ - "format" - ] - == "SafeTensor" - ] - first_version_models = [] - for model_iter in safe_models: - # The modelVersion would only keep the version name. - if ( - model_iter["modelVersions"][version_id]["images"][0]["nsfw"] - != "None" - ): - continue - model_iter["modelVersions"][version_id]["modelName"] = model_iter[ - "name" - ] - model_iter["modelVersions"][version_id]["rating"] = model_iter[ - "stats" - ]["rating"] - model_iter["modelVersions"][version_id]["favoriteCount"] = model_iter[ - "stats" - ]["favoriteCount"] - model_iter["modelVersions"][version_id]["downloadCount"] = model_iter[ - "stats" - ]["downloadCount"] - first_version_models.append(model_iter["modelVersions"][version_id]) - return first_version_models - - -def get_image_from_model(model_json): - model_id = model_json["modelId"] - image = None - for img_info in model_json["images"]: - if img_info["nsfw"] == "None": - image_url = model_json["images"][0]["url"] - response = requests.get(image_url) - image = BytesIO(response.content) - break - return image - - -with gr.Blocks() as model_web: - with gr.Row(): - model_source = gr.Radio( - value=None, - choices=["Hugging Face", "Civitai"], - type="value", - label="Model Source", - ) - model_number = gr.Slider( - 1, - 100, - value=10, - step=1, - label="Number of models", - interactive=True, - ) - # TODO: add more filters - get_model_btn = gr.Button(value="Get Models") - - hf_models = gr.Dropdown( - label="Hugging Face Model List", - choices=None, - value=None, - visible=False, - allow_custom_value=True, - ) - # TODO: select and SendTo - civit_models = gr.Gallery( - label="Civitai Model Gallery", - value=None, - visible=False, - show_download_button=False, - ) - - with gr.Row(visible=False) as sendto_btns: - modelmanager_sendto_txt2img = gr.Button(value="SendTo Txt2Img") - modelmanager_sendto_img2img = gr.Button(value="SendTo Img2Img") - modelmanager_sendto_inpaint = gr.Button(value="SendTo Inpaint") - modelmanager_sendto_outpaint = gr.Button(value="SendTo Outpaint") - modelmanager_sendto_upscaler = gr.Button(value="SendTo Upscaler") - - def get_model_list(model_source, model_number): - if model_source == "Hugging Face": - hf_model_list = get_hf_list(model_number) - models = [] - for model in hf_model_list: - # TODO: add model info - models.append(f'{model["modelId"]}') - return ( - gr.Dropdown.update(choices=models, visible=True), - gr.Gallery.update(value=None, visible=False), - gr.Row.update(visible=True), - ) - elif model_source == "Civitai": - civit_model_list = get_civit_list(model_number) - models = [] - for model in civit_model_list: - image = get_image_from_model(model) - if image is None: - continue - # TODO: add model info - models.append( - (Image.open(image), f'{model["files"][0]["downloadUrl"]}') - ) - return ( - gr.Dropdown.update(value=None, choices=None, visible=False), - gr.Gallery.update(value=models, visible=True), - gr.Row.update(visible=False), - ) - else: - return ( - gr.Dropdown.update(value=None, choices=None, visible=False), - gr.Gallery.update(value=None, visible=False), - gr.Row.update(visible=False), - ) - - get_model_btn.click( - fn=get_model_list, - inputs=[model_source, model_number], - outputs=[ - hf_models, - civit_models, - sendto_btns, - ], - ) diff --git a/apps/stable_diffusion/web/ui/outpaint_ui.py b/apps/stable_diffusion/web/ui/outpaint_ui.py deleted file mode 100644 index a515f6c9..00000000 --- a/apps/stable_diffusion/web/ui/outpaint_ui.py +++ /dev/null @@ -1,558 +0,0 @@ -import os -import torch -import time -import gradio as gr -from PIL import Image - -from apps.stable_diffusion.web.ui.common_ui_events import lora_changed -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - nodlogo_loc, - get_custom_model_path, - get_custom_model_files, - scheduler_list_cpu_only, - predefined_paint_models, - cancel_sd, -) -from apps.stable_diffusion.src import ( - args, - OutpaintPipeline, - get_schedulers, - set_init_device_flags, - utils, - save_output_img, -) -from apps.stable_diffusion.src.utils import ( - get_generated_imgs_path, - get_generation_text_info, -) -from apps.stable_diffusion.web.utils.common_label_calc import status_label - -# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. -init_iree_vulkan_target_triple = args.iree_vulkan_target_triple -init_use_tuned = args.use_tuned -init_import_mlir = args.import_mlir - - -# Exposed to UI. -def outpaint_inf( - prompt: str, - negative_prompt: str, - init_image, - pixels: int, - mask_blur: int, - directions: list, - noise_q: float, - color_variation: float, - height: int, - width: int, - steps: int, - guidance_scale: float, - seed: str, - batch_count: int, - batch_size: int, - scheduler: str, - model_id: str, - custom_vae: str, - precision: str, - device: str, - max_length: int, - save_metadata_to_json: bool, - save_metadata_to_png: bool, - lora_weights: str, - lora_hf_id: str, - ondemand: bool, - repeatable_seeds: bool, -): - from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - get_custom_vae_or_lora_weights, - Config, - ) - import apps.stable_diffusion.web.utils.global_obj as global_obj - from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_CANCEL, - ) - - args.prompts = [prompt] - args.negative_prompts = [negative_prompt] - args.guidance_scale = guidance_scale - args.steps = steps - args.scheduler = scheduler - args.img_path = "not none" - args.ondemand = ondemand - - # set ckpt_loc and hf_model_id. - args.ckpt_loc = "" - args.hf_model_id = "" - args.custom_vae = "" - - # .safetensor or .chkpt on the custom model path - if model_id in get_custom_model_files(custom_checkpoint_type="inpainting"): - args.ckpt_loc = get_custom_model_pathfile(model_id) - # civitai download - elif "civitai" in model_id: - args.ckpt_loc = model_id - # either predefined or huggingface - else: - args.hf_model_id = model_id - - if custom_vae != "None": - args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") - - args.use_lora = get_custom_vae_or_lora_weights( - lora_weights, lora_hf_id, "lora" - ) - - args.save_metadata_to_json = save_metadata_to_json - args.write_metadata_to_png = save_metadata_to_png - - dtype = torch.float32 if precision == "fp32" else torch.half - cpu_scheduling = not scheduler.startswith("Shark") - new_config_obj = Config( - "outpaint", - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - precision, - batch_size, - max_length, - height, - width, - device, - use_lora=args.use_lora, - stencils=[], - ondemand=ondemand, - ) - if ( - not global_obj.get_sd_obj() - or global_obj.get_cfg_obj() != new_config_obj - ): - global_obj.clear_cache() - global_obj.set_cfg_obj(new_config_obj) - args.precision = precision - args.batch_count = batch_count - args.batch_size = batch_size - args.max_length = max_length - args.height = height - args.width = width - args.device = device.split("=>", 1)[1].strip() - args.iree_vulkan_target_triple = init_iree_vulkan_target_triple - args.use_tuned = init_use_tuned - args.import_mlir = init_import_mlir - set_init_device_flags() - model_id = ( - args.hf_model_id - if args.hf_model_id - else "stabilityai/stable-diffusion-2-inpainting" - ) - global_obj.set_schedulers(get_schedulers(model_id)) - scheduler_obj = global_obj.get_scheduler(scheduler) - global_obj.set_sd_obj( - OutpaintPipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - ) - - global_obj.set_sd_scheduler(scheduler) - - start_time = time.time() - global_obj.get_sd_obj().log = "" - generated_imgs = [] - try: - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) - except TypeError as error: - raise gr.Error(str(error)) from None - - left = True if "left" in directions else False - right = True if "right" in directions else False - top = True if "up" in directions else False - bottom = True if "down" in directions else False - - text_output = "" - for current_batch in range(batch_count): - out_imgs = global_obj.get_sd_obj().generate_images( - prompt, - negative_prompt, - init_image, - pixels, - mask_blur, - left, - right, - top, - bottom, - noise_q, - color_variation, - batch_size, - height, - width, - steps, - guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - total_time = time.time() - start_time - text_output = get_generation_text_info( - seeds[: current_batch + 1], device - ) - text_output += "\n" + global_obj.get_sd_obj().log - text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec" - - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - else: - save_output_img(out_imgs[0], seeds[current_batch]) - generated_imgs.extend(out_imgs) - yield generated_imgs, text_output, status_label( - "Outpaint", current_batch + 1, batch_count, batch_size - ) - - return generated_imgs, text_output, "" - - -with gr.Blocks(title="Outpainting") as outpaint_web: - with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - show_download_button=False, - elem_id="top_logo", - width=150, - height=50, - ) - with gr.Row(elem_id="ui_body"): - with gr.Row(): - with gr.Column(scale=1, min_width=600): - outpaint_init_image = gr.Image( - label="Input Image", type="pil", sources=["upload"] - ) - with gr.Row(): - outpaint_model_info = ( - f"Custom Model Path: {str(get_custom_model_path())}" - ) - outpaint_custom_model = gr.Dropdown( - label=f"Models", - info="Select, or enter HuggingFace Model ID or Civitai model download URL", - elem_id="custom_model", - value=os.path.basename(args.ckpt_loc) - if args.ckpt_loc - else "stabilityai/stable-diffusion-2-inpainting", - choices=get_custom_model_files( - custom_checkpoint_type="inpainting" - ) - + predefined_paint_models, - allow_custom_value=True, - scale=2, - ) - # janky fix for overflowing text - outpaint_vae_info = ( - str(get_custom_model_path("vae")) - ).replace("\\", "\n\\") - outpaint_vae_info = f"VAE Path: {outpaint_vae_info}" - custom_vae = gr.Dropdown( - label=f"Custom VAE Models", - info=outpaint_vae_info, - elem_id="custom_model", - value=os.path.basename(args.custom_vae) - if args.custom_vae - else "None", - choices=["None"] + get_custom_model_files("vae"), - allow_custom_value=True, - scale=1, - ) - with gr.Group(elem_id="prompt_box_outer"): - prompt = gr.Textbox( - label="Prompt", - value=args.prompts[0], - lines=2, - elem_id="prompt_box", - ) - negative_prompt = gr.Textbox( - label="Negative Prompt", - value=args.negative_prompts[0], - lines=2, - elem_id="negative_prompt_box", - ) - with gr.Accordion(label="LoRA Options", open=False): - with gr.Row(): - # janky fix for overflowing text - outpaint_lora_info = ( - str(get_custom_model_path("lora")) - ).replace("\\", "\n\\") - outpaint_lora_info = f"LoRA Path: {outpaint_lora_info}" - lora_weights = gr.Dropdown( - label=f"Standalone LoRA Weights", - info=outpaint_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), - allow_custom_value=True, - ) - lora_hf_id = gr.Textbox( - elem_id="lora_hf_id", - placeholder="Select 'None' in the Standalone LoRA " - "weights dropdown on the left if you want to use " - "a standalone HuggingFace model ID for LoRA here " - "e.g: sayakpaul/sd-model-finetuned-lora-t4", - value="", - label="HuggingFace Model ID", - lines=3, - ) - with gr.Row(): - lora_tags = gr.HTML( - value="
    No LoRA selected
    ", - elem_classes="lora-tags", - ) - with gr.Accordion(label="Advanced Options", open=False): - with gr.Row(): - scheduler = gr.Dropdown( - elem_id="scheduler", - label="Scheduler", - value="EulerDiscrete", - choices=scheduler_list_cpu_only, - allow_custom_value=True, - ) - with gr.Group(): - save_metadata_to_png = gr.Checkbox( - label="Save prompt information to PNG", - value=args.write_metadata_to_png, - interactive=True, - ) - save_metadata_to_json = gr.Checkbox( - label="Save prompt information to JSON file", - value=args.save_metadata_to_json, - interactive=True, - ) - with gr.Row(): - pixels = gr.Slider( - 8, - 256, - value=args.pixels, - step=8, - label="Pixels to expand", - ) - mask_blur = gr.Slider( - 0, - 64, - value=args.mask_blur, - step=1, - label="Mask blur", - ) - with gr.Row(): - directions = gr.CheckboxGroup( - label="Outpainting direction", - choices=["left", "right", "up", "down"], - value=["left", "right", "up", "down"], - ) - with gr.Row(): - noise_q = gr.Slider( - 0.0, - 4.0, - value=1.0, - step=0.01, - label="Fall-off exponent (lower=higher detail)", - ) - color_variation = gr.Slider( - 0.0, - 1.0, - value=0.05, - step=0.01, - label="Color variation", - ) - with gr.Row(): - height = gr.Slider( - 384, 768, value=args.height, step=8, label="Height" - ) - width = gr.Slider( - 384, 768, value=args.width, step=8, label="Width" - ) - precision = gr.Radio( - label="Precision", - value=args.precision, - choices=[ - "fp16", - "fp32", - ], - visible=False, - ) - max_length = gr.Radio( - label="Max Length", - value=args.max_length, - choices=[ - 64, - 77, - ], - visible=False, - ) - with gr.Row(): - steps = gr.Slider( - 1, 100, value=20, step=1, label="Steps" - ) - ondemand = gr.Checkbox( - value=args.ondemand, - label="Low VRAM", - interactive=True, - ) - with gr.Row(): - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=args.guidance_scale, - step=0.1, - label="CFG Scale", - ) - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=args.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - repeatable_seeds = gr.Checkbox( - args.repeatable_seeds, - label="Repeatable Seeds", - ) - - with gr.Row(): - batch_size = gr.Slider( - 1, - 4, - value=args.batch_size, - step=1, - label="Batch Size", - interactive=False, - visible=False, - ) - with gr.Row(): - seed = gr.Textbox( - value=args.seed, - label="Seed", - info="An integer or a JSON list of integers, -1 for random", - ) - device = gr.Dropdown( - elem_id="device", - label="Device", - value=available_devices[0], - choices=available_devices, - allow_custom_value=True, - ) - - with gr.Column(scale=1, min_width=600): - with gr.Group(): - outpaint_gallery = gr.Gallery( - label="Generated images", - show_label=False, - elem_id="gallery", - columns=[2], - object_fit="contain", - # TODO: Re-enable download when fixed in Gradio - show_download_button=False, - ) - std_output = gr.Textbox( - value=f"{outpaint_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - lines=2, - elem_id="std_output", - show_label=False, - ) - outpaint_status = gr.Textbox(visible=False) - with gr.Row(): - stable_diffusion = gr.Button("Generate Image(s)") - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], - queue=False, - ) - stop_batch = gr.Button("Stop Batch") - with gr.Row(): - blank_thing_for_row = None - with gr.Row(): - outpaint_sendto_img2img = gr.Button(value="SendTo Img2Img") - outpaint_sendto_inpaint = gr.Button(value="SendTo Inpaint") - outpaint_sendto_upscaler = gr.Button( - value="SendTo Upscaler" - ) - - kwargs = dict( - fn=outpaint_inf, - inputs=[ - prompt, - negative_prompt, - outpaint_init_image, - pixels, - mask_blur, - directions, - noise_q, - color_variation, - height, - width, - steps, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - outpaint_custom_model, - custom_vae, - precision, - device, - max_length, - save_metadata_to_json, - save_metadata_to_png, - lora_weights, - lora_hf_id, - ondemand, - repeatable_seeds, - ], - outputs=[outpaint_gallery, std_output, outpaint_status], - show_progress="minimal" if args.progress_bar else "none", - ) - status_kwargs = dict( - fn=lambda bc, bs: status_label("Outpaint", 0, bc, bs), - inputs=[batch_count, batch_size], - outputs=outpaint_status, - ) - - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs - ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) - stop_batch.click( - fn=cancel_sd, - cancels=[prompt_submit, neg_prompt_submit, generate_click], - ) - - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/stable_diffusion/web/ui/outputgallery_ui.py b/apps/stable_diffusion/web/ui/outputgallery_ui.py deleted file mode 100644 index d33e5f53..00000000 --- a/apps/stable_diffusion/web/ui/outputgallery_ui.py +++ /dev/null @@ -1,468 +0,0 @@ -import glob -import gradio as gr -import os -import subprocess -import sys -from PIL import Image - -from apps.stable_diffusion.src import args -from apps.stable_diffusion.src.utils import ( - get_generated_imgs_path, - get_generated_imgs_todays_subdir, -) -from apps.stable_diffusion.web.ui.utils import nodlogo_loc -from apps.stable_diffusion.web.utils.metadata import displayable_metadata - -# -- Functions for file, directory and image info querying - -output_dir = get_generated_imgs_path() - - -def outputgallery_filenames(subdir) -> list[str]: - new_dir_path = os.path.join(output_dir, subdir) - if os.path.exists(new_dir_path): - filenames = [ - glob.glob(new_dir_path + "/" + ext) - for ext in ("*.png", "*.jpg", "*.jpeg") - ] - - return sorted(sum(filenames, []), key=os.path.getmtime, reverse=True) - else: - return [] - - -def output_subdirs() -> list[str]: - # Gets a list of subdirectories of output_dir and below, as relative paths. - relative_paths = [ - os.path.relpath(entry[0], output_dir) - for entry in os.walk( - output_dir, followlinks=args.output_gallery_followlinks - ) - ] - - # It is less confusing to always including the subdir that will take any - # images generated today even if it doesn't exist yet - if get_generated_imgs_todays_subdir() not in relative_paths: - relative_paths.append(get_generated_imgs_todays_subdir()) - - # sort subdirectories so that the date named ones we probably - # created in this or previous sessions come first, sorted with the most - # recent first. Other subdirs are listed after. - generated_paths = sorted( - [path for path in relative_paths if path.isnumeric()], reverse=True - ) - result_paths = generated_paths + sorted( - [ - path - for path in relative_paths - if (not path.isnumeric()) and path != "." - ] - ) - - return result_paths - - -# --- Define UI layout for Gradio - -with gr.Blocks() as outputgallery_web: - nod_logo = Image.open(nodlogo_loc) - - with gr.Row(elem_id="outputgallery_gallery"): - # needed to workaround gradio issue: - # https://github.com/gradio-app/gradio/issues/2907 - dev_null = gr.Textbox("", visible=False) - - gallery_files = gr.State(value=[]) - subdirectory_paths = gr.State(value=[]) - - with gr.Column(scale=6): - logo = gr.Image( - label="Getting subdirectories...", - value=nod_logo, - interactive=False, - show_download_button=False, - visible=True, - show_label=True, - elem_id="top_logo", - elem_classes="logo_centered", - ) - gallery = gr.Gallery( - label="", - value=gallery_files.value, - visible=False, - show_label=True, - columns=4, - # TODO: Re-enable download when fixed in Gradio - show_download_button=False, - ) - - with gr.Column(scale=4): - with gr.Group(): - with gr.Row(elem_id="output_subdir_container"): - with gr.Column( - scale=15, - min_width=160, - ): - subdirectories = gr.Dropdown( - label=f"Subdirectories of {output_dir}", - type="value", - choices=subdirectory_paths.value, - value="", - interactive=True, - # elem_classes="dropdown_no_container", - allow_custom_value=True, - ) - with gr.Column( - scale=1, - min_width=32, - elem_classes="output_icon_button", - ): - open_subdir = gr.Button( - variant="secondary", - value="\U0001F5C1", # unicode open folder - interactive=False, - size="sm", - ) - with gr.Column( - scale=1, - min_width=32, - elem_classes="output_icon_button", - ): - refresh = gr.Button( - variant="secondary", - value="\u21BB", # unicode clockwise arrow circle - size="sm", - ) - - image_columns = gr.Slider( - label="Columns shown", value=4, minimum=1, maximum=16, step=1 - ) - outputgallery_filename = gr.Textbox( - label="Filename", - value="None", - interactive=False, - show_copy_button=True, - ) - - with gr.Accordion( - label="Parameter Information", open=False - ) as parameters_accordian: - image_parameters = gr.DataFrame( - headers=["Parameter", "Value"], - col_count=(2, "fixed"), - row_count=(1, "fixed"), - wrap=True, - elem_classes="output_parameters_dataframe", - value=[["Status", "No image selected"]], - interactive=False, - ) - - with gr.Accordion(label="Send To", open=True): - with gr.Row(): - outputgallery_sendto_txt2img = gr.Button( - value="Txt2Img", - interactive=False, - elem_classes="outputgallery_sendto", - size="sm", - ) - outputgallery_sendto_txt2img_sdxl = gr.Button( - value="Txt2Img XL", - interactive=False, - elem_classes="outputgallery_sendto", - size="sm", - ) - - outputgallery_sendto_img2img = gr.Button( - value="Img2Img", - interactive=False, - elem_classes="outputgallery_sendto", - size="sm", - ) - - outputgallery_sendto_inpaint = gr.Button( - value="Inpaint", - interactive=False, - elem_classes="outputgallery_sendto", - size="sm", - ) - - outputgallery_sendto_outpaint = gr.Button( - value="Outpaint", - interactive=False, - elem_classes="outputgallery_sendto", - size="sm", - ) - - outputgallery_sendto_upscaler = gr.Button( - value="Upscaler", - interactive=False, - elem_classes="outputgallery_sendto", - size="sm", - ) - - # --- Event handlers - - def on_clear_gallery(): - return [ - gr.Gallery( - value=[], - visible=False, - ), - gr.Image( - visible=True, - ), - ] - - def on_image_columns_change(columns): - return gr.Gallery(columns=columns) - - def on_select_subdir(subdir) -> list: - # evt.value is the subdirectory name - new_images = outputgallery_filenames(subdir) - new_label = ( - f"{len(new_images)} images in {os.path.join(output_dir, subdir)}" - ) - return [ - new_images, - gr.Gallery( - value=new_images, - label=new_label, - visible=len(new_images) > 0, - ), - gr.Image( - label=new_label, - visible=len(new_images) == 0, - ), - ] - - def on_open_subdir(subdir): - subdir_path = os.path.normpath(os.path.join(output_dir, subdir)) - - if os.path.isdir(subdir_path): - if sys.platform == "linux": - subprocess.run(["xdg-open", subdir_path]) - elif sys.platform == "darwin": - subprocess.run(["open", subdir_path]) - elif sys.platform == "win32": - os.startfile(subdir_path) - - def on_refresh(current_subdir: str) -> list: - # get an up-to-date subdirectory list - refreshed_subdirs = output_subdirs() - # get the images using either the current subdirectory or the most - # recent valid one - new_subdir = ( - current_subdir - if current_subdir in refreshed_subdirs - else refreshed_subdirs[0] - ) - new_images = outputgallery_filenames(new_subdir) - new_label = ( - f"{len(new_images)} images in " - f"{os.path.join(output_dir, new_subdir)}" - ) - - return [ - gr.Dropdown( - choices=refreshed_subdirs, - value=new_subdir, - ), - refreshed_subdirs, - new_images, - gr.Gallery( - value=new_images, label=new_label, visible=len(new_images) > 0 - ), - gr.Image( - label=new_label, - visible=len(new_images) == 0, - ), - ] - - def on_new_image(subdir, subdir_paths, status) -> list: - # prevent error triggered when an image generates before the tab - # has even been selected - subdir_paths = ( - subdir_paths - if len(subdir_paths) > 0 - else [get_generated_imgs_todays_subdir()] - ) - - # only update if the current subdir is the most recent one as - # new images only go there - if subdir_paths[0] == subdir: - new_images = outputgallery_filenames(subdir) - new_label = ( - f"{len(new_images)} images in " - f"{os.path.join(output_dir, subdir)} - {status}" - ) - - return [ - new_images, - gr.Gallery( - value=new_images, - label=new_label, - visible=len(new_images) > 0, - ), - gr.Image( - label=new_label, - visible=len(new_images) == 0, - ), - ] - else: - # otherwise change nothing, - # (only untyped gradio gr.update() does this) - return [gr.update(), gr.update(), gr.update()] - - def on_select_image(images: list[str], evt: gr.SelectData) -> list: - # evt.index is an index into the full list of filenames for - # the current subdirectory - filename = images[evt.index] - params = displayable_metadata(filename) - - if params: - if params["source"] == "missing": - return [ - "Could not find this image file, refresh the gallery and update the images", - [["Status", "File missing"]], - ] - else: - return [ - filename, - gr.DataFrame( - value=list(map(list, params["parameters"].items())), - row_count=(len(params["parameters"]), "fixed"), - ), - ] - - return [ - filename, - gr.DataFrame( - value=[["Status", "No parameters found"]], - row_count=(1, "fixed"), - ), - ] - - def on_outputgallery_filename_change(filename: str) -> list: - exists = filename != "None" and os.path.exists(filename) - return [ - # disable or enable each of the sendto button based on whether - # an image is selected - gr.Button(interactive=exists), - gr.Button(interactive=exists), - gr.Button(interactive=exists), - gr.Button(interactive=exists), - gr.Button(interactive=exists), - gr.Button(interactive=exists), - ] - - # The time first our tab is selected we need to do an initial refresh - # to populate the subdirectory select box and the images from the most - # recent subdirectory. - # - # We do it at this point rather than setting this up in the controls' - # definitions as when you refresh the browser you always get what was - # *initially* set, which won't include any new subdirectories or images - # that might have created since the application was started. Doing it - # this way means a browser refresh/reload always gets the most - # up-to-date data. - def on_select_tab(subdir_paths, request: gr.Request): - local_client = request.headers["host"].startswith( - "127.0.0.1:" - ) or request.headers["host"].startswith("localhost:") - - if len(subdir_paths) == 0: - return on_refresh("") + [gr.update(interactive=local_client)] - else: - return ( - # Change nothing, (only untyped gr.update() does this) - gr.update(), - gr.update(), - gr.update(), - gr.update(), - gr.update(), - gr.update(), - ) - - # clearing images when we need to completely change what's in the - # gallery avoids current images being shown replacing piecemeal and - # prevents weirdness and errors if the user selects an image during the - # replacement phase. - clear_gallery = dict( - fn=on_clear_gallery, - inputs=None, - outputs=[gallery, logo], - queue=False, - ) - - subdirectories.select(**clear_gallery).then( - on_select_subdir, - [subdirectories], - [gallery_files, gallery, logo], - queue=False, - ) - - open_subdir.click(on_open_subdir, inputs=[subdirectories], queue=False) - - refresh.click(**clear_gallery).then( - on_refresh, - [subdirectories], - [subdirectories, subdirectory_paths, gallery_files, gallery, logo], - queue=False, - ) - - image_columns.change( - fn=on_image_columns_change, - inputs=[image_columns], - outputs=[gallery], - queue=False, - ) - - gallery.select( - on_select_image, - [gallery_files], - [outputgallery_filename, image_parameters], - queue=False, - ) - - outputgallery_filename.change( - on_outputgallery_filename_change, - [outputgallery_filename], - [ - outputgallery_sendto_txt2img, - outputgallery_sendto_txt2img_sdxl, - outputgallery_sendto_img2img, - outputgallery_sendto_inpaint, - outputgallery_sendto_outpaint, - outputgallery_sendto_upscaler, - ], - queue=False, - ) - - # We should have been given the .select function for our tab, so set it up - def outputgallery_tab_select(select): - select( - fn=on_select_tab, - inputs=[subdirectory_paths], - outputs=[ - subdirectories, - subdirectory_paths, - gallery_files, - gallery, - logo, - open_subdir, - ], - queue=False, - ) - - # We should have been passed a list of components on other tabs that update - # when a new image has generated on that tab, so set things up so the user - # will see that new image if they are looking at today's subdirectory - def outputgallery_watch(components: gr.Textbox, queued_components=[]): - for component in components: - component.change( - on_new_image, - inputs=[subdirectories, subdirectory_paths, component], - outputs=[gallery_files, gallery, logo], - queue=component in queued_components, - show_progress="none", - ) diff --git a/apps/stable_diffusion/web/ui/stablelm_ui.py b/apps/stable_diffusion/web/ui/stablelm_ui.py deleted file mode 100644 index 5b13145a..00000000 --- a/apps/stable_diffusion/web/ui/stablelm_ui.py +++ /dev/null @@ -1,549 +0,0 @@ -import gradio as gr -import torch -import os -from pathlib import Path -from transformers import ( - AutoModelForCausalLM, -) -from apps.stable_diffusion.web.ui.utils import available_devices -from shark.iree_utils.compile_utils import clean_device_info -from datetime import datetime as dt -import json -import sys - - -def user(message, history): - # Append the user's message to the conversation history - return "", history + [[message, ""]] - - -sharkModel = 0 -sharded_model = 0 -vicuna_model = 0 - -past_key_values = None - -model_map = { - "llama2_7b": "meta-llama/Llama-2-7b-chat-hf", - "llama2_13b": "meta-llama/Llama-2-13b-chat-hf", - "llama2_70b": "meta-llama/Llama-2-70b-chat-hf", - "vicuna": "TheBloke/vicuna-7B-1.1-HF", -} - -# NOTE: Each `model_name` should have its own start message -start_message = { - "llama2_7b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_13b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "llama2_70b": ( - "You are a helpful, respectful and honest assistant. Always answer " - "as helpfully as possible, while being safe. Your answers should not " - "include any harmful, unethical, racist, sexist, toxic, dangerous, or " - "illegal content. Please ensure that your responses are socially " - "unbiased and positive in nature. If a question does not make any " - "sense, or is not factually coherent, explain why instead of " - "answering something not correct. If you don't know the answer " - "to a question, please don't share false information." - ), - "vicuna": ( - "A chat between a curious user and an artificial intelligence " - "assistant. The assistant gives helpful, detailed, and " - "polite answers to the user's questions.\n" - ), -} - - -def create_prompt(model_name, history, prompt_prefix): - system_message = "" - if prompt_prefix: - system_message = start_message[model_name] - - if "llama2" in model_name: - B_INST, E_INST = "[INST]", "[/INST]" - B_SYS, E_SYS = "<>\n", "\n<>\n\n" - conversation = "".join( - [f"{B_INST} {item[0]} {E_INST} {item[1]} " for item in history[1:]] - ) - if prompt_prefix: - msg = f"{B_INST} {B_SYS}{system_message}{E_SYS}{history[0][0]} {E_INST} {history[0][1]} {conversation}" - else: - msg = f"{B_INST} {history[0][0]} {E_INST} {history[0][1]} {conversation}" - elif model_name in ["vicuna"]: - conversation = "".join( - [ - "".join(["<|USER|>" + item[0], "<|ASSISTANT|>" + item[1]]) - for item in history - ] - ) - msg = system_message + conversation - msg = msg.strip() - else: - conversation = "".join( - ["".join([item[0], item[1]]) for item in history] - ) - msg = system_message + conversation - msg = msg.strip() - return msg - - -def set_vicuna_model(model): - global vicuna_model - vicuna_model = model - - -def get_default_config(): - import torch - from transformers import AutoTokenizer - - hf_model_path = "TheBloke/vicuna-7B-1.1-HF" - tokenizer = AutoTokenizer.from_pretrained(hf_model_path, use_fast=False) - compilation_prompt = "".join(["0" for _ in range(17)]) - compilation_input_ids = tokenizer( - compilation_prompt, - return_tensors="pt", - ).input_ids - compilation_input_ids = torch.tensor(compilation_input_ids).reshape( - [1, 19] - ) - firstVicunaCompileInput = (compilation_input_ids,) - from apps.language_models.src.model_wrappers.vicuna_model import ( - CombinedModel, - ) - from shark.shark_generate_model_config import GenerateConfigFile - - model = CombinedModel() - c = GenerateConfigFile(model, 1, ["gpu_id"], firstVicunaCompileInput) - c.split_into_layers() - - -model_vmfb_key = "" - - -# TODO: Make chat reusable for UI and API -def chat( - prompt_prefix, - history, - model, - backend, - devices, - sharded, - precision, - download_vmfb, - config_file, - cli=False, - progress=gr.Progress(), -): - global past_key_values - global model_vmfb_key - global vicuna_model - - model_name, model_path = list(map(str.strip, model.split("=>"))) - device, device_id = clean_device_info(devices[0]) - no_of_devices = len(devices) - - from apps.language_models.scripts.vicuna import ShardedVicuna - from apps.language_models.scripts.vicuna import UnshardedVicuna - from apps.stable_diffusion.src import args - - new_model_vmfb_key = f"{model_name}#{model_path}#{device}#{device_id}#{precision}#{download_vmfb}" - if vicuna_model is None or new_model_vmfb_key != model_vmfb_key: - model_vmfb_key = new_model_vmfb_key - max_toks = 128 if model_name == "codegen" else 512 - - # get iree flags that need to be overridden, from commandline args - _extra_args = [] - # vulkan target triple - vulkan_target_triple = args.iree_vulkan_target_triple - from shark.iree_utils.vulkan_utils import ( - get_all_vulkan_devices, - get_vulkan_target_triple, - ) - - _extra_args = _extra_args + [ - "--iree-global-opt-enable-quantized-matmul-reassociation", - "--iree-llvmcpu-enable-quantized-matmul-reassociation", - "--iree-opt-const-eval=false", - "--iree-opt-data-tiling=false", - ] - - if device == "vulkan": - vulkaninfo_list = get_all_vulkan_devices() - if vulkan_target_triple == "": - # We already have the device_id extracted via WebUI, so we directly use - # that to find the target triple. - vulkan_target_triple = get_vulkan_target_triple( - vulkaninfo_list[device_id] - ) - _extra_args.append( - f"-iree-vulkan-target-triple={vulkan_target_triple}" - ) - if "rdna" in vulkan_target_triple: - flags_to_add = [ - "--iree-spirv-index-bits=64", - ] - _extra_args = _extra_args + flags_to_add - - if device_id is None: - id = 0 - for device in vulkaninfo_list: - target_triple = get_vulkan_target_triple( - vulkaninfo_list[id] - ) - if target_triple == vulkan_target_triple: - device_id = id - break - id += 1 - - assert ( - device_id - ), f"no vulkan hardware for target-triple '{vulkan_target_triple}' exists" - print(f"Will use vulkan target triple : {vulkan_target_triple}") - - elif "rocm" in device: - # add iree rocm flags - if args.iree_rocm_target_chip != "": - _extra_args.append( - f"--iree-rocm-target-chip={args.iree_rocm_target_chip}" - ) - print(f"extra args = {_extra_args}") - - if sharded: - vicuna_model = ShardedVicuna( - model_name, - hf_model_path=model_path, - device=device, - precision=precision, - max_num_tokens=max_toks, - compressed=True, - extra_args_cmd=_extra_args, - n_devices=no_of_devices, - ) - else: - # if config_file is None: - vicuna_model = UnshardedVicuna( - model_name, - hf_model_path=model_path, - hf_auth_token=args.hf_auth_token, - device=device, - vulkan_target_triple=vulkan_target_triple, - precision=precision, - max_num_tokens=max_toks, - download_vmfb=download_vmfb, - load_mlir_from_shark_tank=True, - extra_args_cmd=_extra_args, - device_id=device_id, - ) - - if vicuna_model is None: - sys.exit("Unable to instantiate the model object, exiting.") - - prompt = create_prompt(model_name, history, prompt_prefix) - - partial_text = "" - token_count = 0 - total_time_ms = 0.001 # In order to avoid divide by zero error - prefill_time = 0 - is_first = True - # for text, msg, exec_time in progress.tqdm( - # vicuna_model.generate(prompt, cli=cli), - # desc="generating response", - # ): - for text, msg, exec_time in vicuna_model.generate(prompt, cli=cli): - if msg is None: - if is_first: - prefill_time = exec_time / 1000 - is_first = False - else: - total_time_ms += exec_time - token_count += 1 - partial_text += text + " " - history[-1][1] = partial_text - yield history, f"Prefill: {prefill_time:.2f}" - elif "formatted" in msg: - history[-1][1] = text - tokens_per_sec = (token_count / total_time_ms) * 1000 - yield history, f"Prefill: {prefill_time:.2f} seconds\n Decode: {tokens_per_sec:.2f} tokens/sec" - else: - sys.exit( - "unexpected message from the vicuna generate call, exiting." - ) - - return history, "" - - -def llm_chat_api(InputData: dict): - print(f"Input keys : {InputData.keys()}") - # print(f"model : {InputData['model']}") - is_chat_completion_api = ( - "messages" in InputData.keys() - ) # else it is the legacy `completion` api - # For Debugging input data from API - # if is_chat_completion_api: - # print(f"message -> role : {InputData['messages'][0]['role']}") - # print(f"message -> content : {InputData['messages'][0]['content']}") - # else: - # print(f"prompt : {InputData['prompt']}") - # print(f"max_tokens : {InputData['max_tokens']}") # Default to 128 for now - global vicuna_model - model_name = ( - InputData["model"] if "model" in InputData.keys() else "codegen" - ) - model_path = model_map[model_name] - device = "cpu-task" - precision = "fp16" - max_toks = ( - None - if "max_tokens" not in InputData.keys() - else InputData["max_tokens"] - ) - if max_toks is None: - max_toks = 128 if model_name == "codegen" else 512 - - # make it working for codegen first - from apps.language_models.scripts.vicuna import ( - UnshardedVicuna, - ) - - device_id = None - if vicuna_model == 0: - device, device_id = clean_device_info(device) - - vicuna_model = UnshardedVicuna( - model_name, - hf_model_path=model_path, - device=device, - precision=precision, - max_num_tokens=max_toks, - download_vmfb=True, - load_mlir_from_shark_tank=True, - device_id=device_id, - ) - - # TODO: add role dict for different models - if is_chat_completion_api: - # TODO: add funtionality for multiple messages - prompt = create_prompt( - model_name, [(InputData["messages"][0]["content"], "")] - ) - else: - prompt = InputData["prompt"] - print("prompt = ", prompt) - - res = vicuna_model.generate(prompt) - res_op = None - for op in res: - res_op = op - - if is_chat_completion_api: - choices = [ - { - "index": 0, - "message": { - "role": "assistant", - "content": res_op, # since we are yeilding the result - }, - "finish_reason": "stop", # or length - } - ] - else: - choices = [ - { - "text": res_op, - "index": 0, - "logprobs": None, - "finish_reason": "stop", # or length - } - ] - end_time = dt.now().strftime("%Y%m%d%H%M%S%f") - return { - "id": end_time, - "object": "chat.completion" - if is_chat_completion_api - else "text_completion", - "created": int(end_time), - "choices": choices, - } - - -def view_json_file(file_obj): - content = "" - with open(file_obj.name, "r") as fopen: - content = fopen.read() - return content - - -filtered_devices = dict() - - -def change_backend(backend): - new_choices = gr.Dropdown( - choices=filtered_devices[backend], label=f"{backend} devices" - ) - return new_choices - - -with gr.Blocks(title="Chatbot") as stablelm_chat: - with gr.Row(): - model_choices = list( - map(lambda x: f"{x[0]: <10} => {x[1]}", model_map.items()) - ) - model = gr.Dropdown( - label="Select Model", - value=model_choices[0], - choices=model_choices, - allow_custom_value=True, - ) - supported_devices = available_devices - enabled = len(supported_devices) > 0 - # show cpu-task device first in list for chatbot - supported_devices = supported_devices[-1:] + supported_devices[:-1] - supported_devices = [x for x in supported_devices if "sync" not in x] - backend_list = ["cpu", "cuda", "vulkan", "rocm"] - for x in backend_list: - filtered_devices[x] = [y for y in supported_devices if x in y] - print(filtered_devices) - - backend = gr.Radio( - label="backend", - value="cpu", - choices=backend_list, - ) - device = gr.Dropdown( - label="cpu devices", - choices=filtered_devices["cpu"], - interactive=True, - allow_custom_value=True, - multiselect=True, - ) - precision = gr.Radio( - label="Precision", - value="int4", - choices=[ - "int4", - "int8", - "fp16", - ], - visible=False, - ) - tokens_time = gr.Textbox(label="Tokens generated per second") - with gr.Column(): - download_vmfb = gr.Checkbox( - label="Download vmfb from Shark tank if available", - value=True, - interactive=True, - ) - prompt_prefix = gr.Checkbox( - label="Add System Prompt", - value=False, - interactive=True, - ) - sharded = gr.Checkbox( - label="Shard Model", - value=False, - interactive=True, - ) - - with gr.Row(visible=False): - with gr.Group(): - config_file = gr.File( - label="Upload sharding configuration", visible=False - ) - json_view_button = gr.Button(value="View as JSON", visible=False) - json_view = gr.JSON(visible=False) - json_view_button.click( - fn=view_json_file, inputs=[config_file], outputs=[json_view] - ) - chatbot = gr.Chatbot(elem_id="chatbot") - with gr.Row(): - with gr.Column(): - msg = gr.Textbox( - label="Chat Message Box", - placeholder="Chat Message Box", - show_label=False, - interactive=enabled, - container=False, - ) - with gr.Column(): - with gr.Row(): - submit = gr.Button("Submit", interactive=enabled) - stop = gr.Button("Stop", interactive=enabled) - clear = gr.Button("Clear", interactive=enabled) - - backend.change( - fn=change_backend, - inputs=[backend], - outputs=[device], - show_progress=False, - ) - - submit_event = msg.submit( - fn=user, - inputs=[msg, chatbot], - outputs=[msg, chatbot], - show_progress=False, - queue=False, - ).then( - fn=chat, - inputs=[ - prompt_prefix, - chatbot, - model, - backend, - device, - sharded, - precision, - download_vmfb, - config_file, - ], - outputs=[chatbot, tokens_time], - show_progress=False, - queue=True, - ) - submit_click_event = submit.click( - fn=user, - inputs=[msg, chatbot], - outputs=[msg, chatbot], - show_progress=False, - queue=False, - ).then( - fn=chat, - inputs=[ - prompt_prefix, - chatbot, - model, - backend, - device, - sharded, - precision, - download_vmfb, - config_file, - ], - outputs=[chatbot, tokens_time], - show_progress=False, - queue=True, - ) - stop.click( - fn=None, - inputs=None, - outputs=None, - cancels=[submit_event, submit_click_event], - queue=False, - ) - clear.click(lambda: None, None, [chatbot], queue=False) diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py deleted file mode 100644 index 807c30ad..00000000 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ /dev/null @@ -1,653 +0,0 @@ -import os -import torch -import time -import sys -import gradio as gr -from PIL import Image -from math import ceil -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - nodlogo_loc, - get_custom_model_path, - get_custom_model_files, - scheduler_list, - predefined_sdxl_models, - cancel_sd, - set_model_default_configs, -) -from apps.stable_diffusion.web.ui.common_ui_events import lora_changed -from apps.stable_diffusion.web.utils.metadata import import_png_metadata -from apps.stable_diffusion.web.utils.common_label_calc import status_label -from apps.stable_diffusion.src import ( - args, - Text2ImageSDXLPipeline, - get_schedulers, - set_init_device_flags, - utils, - save_output_img, - prompt_examples, - Image2ImagePipeline, -) -from apps.stable_diffusion.src.utils import ( - get_generated_imgs_path, - get_generation_text_info, -) - -# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. -init_iree_vulkan_target_triple = args.iree_vulkan_target_triple -init_iree_metal_target_platform = args.iree_metal_target_platform -init_use_tuned = args.use_tuned -init_import_mlir = args.import_mlir - - -def txt2img_sdxl_inf( - prompt: str, - negative_prompt: str, - height: int, - width: int, - steps: int, - guidance_scale: float, - seed: str | int, - batch_count: int, - batch_size: int, - scheduler: str, - model_id: str, - custom_vae: str, - precision: str, - device: str, - max_length: int, - save_metadata_to_json: bool, - save_metadata_to_png: bool, - lora_weights: str, - lora_hf_id: str, - ondemand: bool, - repeatable_seeds: bool, -): - from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - get_custom_vae_or_lora_weights, - Config, - ) - import apps.stable_diffusion.web.utils.global_obj as global_obj - from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_CANCEL, - ) - - if precision != "fp16": - print("currently we support fp16 for SDXL") - precision = "fp16" - - args.prompts = [prompt] - args.negative_prompts = [negative_prompt] - args.guidance_scale = guidance_scale - args.steps = steps - args.scheduler = scheduler - args.ondemand = ondemand - - # set ckpt_loc and hf_model_id. - args.ckpt_loc = "" - args.hf_model_id = "" - args.custom_vae = "" - - # .safetensor or .chkpt on the custom model path - if model_id in get_custom_model_files(): - args.ckpt_loc = get_custom_model_pathfile(model_id) - # civitai download - elif "civitai" in model_id: - args.ckpt_loc = model_id - # either predefined or huggingface - else: - args.hf_model_id = model_id - - if custom_vae: - args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") - - args.save_metadata_to_json = save_metadata_to_json - args.write_metadata_to_png = save_metadata_to_png - - args.use_lora = get_custom_vae_or_lora_weights( - lora_weights, lora_hf_id, "lora" - ) - - dtype = torch.float32 if precision == "fp32" else torch.half - cpu_scheduling = not scheduler.startswith("Shark") - new_config_obj = Config( - "txt2img_sdxl", - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - precision, - batch_size, - max_length, - height, - width, - device, - use_lora=args.use_lora, - stencils=None, - ondemand=ondemand, - ) - if ( - not global_obj.get_sd_obj() - or global_obj.get_cfg_obj() != new_config_obj - ): - global_obj.clear_cache() - global_obj.set_cfg_obj(new_config_obj) - args.precision = precision - args.batch_count = batch_count - args.batch_size = batch_size - args.max_length = max_length - args.height = height - args.width = width - args.device = device.split("=>", 1)[1].strip() - args.iree_vulkan_target_triple = init_iree_vulkan_target_triple - args.iree_metal_target_platform = init_iree_metal_target_platform - args.use_tuned = init_use_tuned - args.import_mlir = init_import_mlir - args.img_path = None - set_init_device_flags() - model_id = ( - args.hf_model_id - if args.hf_model_id - else "stabilityai/stable-diffusion-xl-base-1.0" - ) - global_obj.set_schedulers(get_schedulers(model_id)) - scheduler_obj = global_obj.get_scheduler(scheduler) - if global_obj.get_cfg_obj().ondemand: - print("Running txt2img in memory efficient mode.") - global_obj.set_sd_obj( - Text2ImageSDXLPipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=precision, - max_length=max_length, - batch_size=batch_size, - height=height, - width=width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - use_quantize=args.use_quantize, - ondemand=global_obj.get_cfg_obj().ondemand, - ) - ) - - global_obj.set_sd_scheduler(scheduler) - - start_time = time.time() - global_obj.get_sd_obj().log = "" - generated_imgs = [] - text_output = "" - try: - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) - except TypeError as error: - raise gr.Error(str(error)) from None - - for current_batch in range(batch_count): - out_imgs = global_obj.get_sd_obj().generate_images( - prompt, - negative_prompt, - batch_size, - height, - width, - steps, - guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - - total_time = time.time() - start_time - text_output = get_generation_text_info( - seeds[: current_batch + 1], device - ) - text_output += "\n" + global_obj.get_sd_obj().log - text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec" - - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - else: - save_output_img(out_imgs[0], seeds[current_batch]) - generated_imgs.extend(out_imgs) - yield generated_imgs, text_output, status_label( - "Text-to-Image-SDXL", - current_batch + 1, - batch_count, - batch_size, - ) - - return generated_imgs, text_output, "" - - -theme = gr.themes.Glass( - primary_hue="slate", - secondary_hue="gray", -) - -with gr.Blocks(title="Text-to-Image-SDXL", theme=theme) as txt2img_sdxl_web: - with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - show_download_button=False, - elem_id="top_logo", - width=150, - height=50, - ) - with gr.Row(elem_id="ui_body"): - with gr.Row(): - with gr.Column(scale=1, min_width=600): - with gr.Row(): - with gr.Column(scale=10): - with gr.Row(): - t2i_sdxl_model_info = f"Custom Model Path: {str(get_custom_model_path())}" - txt2img_sdxl_custom_model = gr.Dropdown( - label=f"Models", - info="Select, or enter HuggingFace Model ID or Civitai model download URL", - elem_id="custom_model", - value=os.path.basename(args.ckpt_loc) - if args.ckpt_loc - else "stabilityai/stable-diffusion-xl-base-1.0", - choices=predefined_sdxl_models - + get_custom_model_files( - custom_checkpoint_type="sdxl" - ), - allow_custom_value=True, - scale=11, - ) - t2i_sdxl_vae_info = ( - str(get_custom_model_path("vae")) - ).replace("\\", "\n\\") - t2i_sdxl_vae_info = ( - f"VAE Path: {t2i_sdxl_vae_info}" - ) - custom_vae = gr.Dropdown( - label=f"VAE Models", - info=t2i_sdxl_vae_info, - elem_id="custom_model", - value="None", - choices=[ - None, - "madebyollin/sdxl-vae-fp16-fix", - ] - + get_custom_model_files("vae"), - allow_custom_value=True, - scale=4, - ) - txt2img_sdxl_png_info_img = gr.Image( - scale=1, - label="Import PNG info", - elem_id="txt2img_prompt_image", - type="pil", - visible=True, - sources=["upload"], - ) - - with gr.Group(elem_id="prompt_box_outer"): - txt2img_sdxl_autogen = gr.Checkbox( - label="Auto-Generate Images", - value=False, - visible=False, - ) - prompt = gr.Textbox( - label="Prompt", - value=args.prompts[0], - lines=2, - elem_id="prompt_box", - show_copy_button=True, - ) - negative_prompt = gr.Textbox( - label="Negative Prompt", - value=args.negative_prompts[0], - lines=2, - elem_id="negative_prompt_box", - show_copy_button=True, - ) - with gr.Accordion(label="LoRA Options", open=False): - with gr.Row(): - # janky fix for overflowing text - t2i_sdxl_lora_info = ( - str(get_custom_model_path("lora")) - ).replace("\\", "\n\\") - t2i_sdxl_lora_info = f"LoRA Path: {t2i_sdxl_lora_info}" - lora_weights = gr.Dropdown( - label=f"Standalone LoRA Weights", - info=t2i_sdxl_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), - allow_custom_value=True, - ) - lora_hf_id = gr.Textbox( - elem_id="lora_hf_id", - placeholder="Select 'None' in the Standalone LoRA " - "weights dropdown on the left if you want to use " - "a standalone HuggingFace model ID for LoRA here " - "e.g: sayakpaul/sd-model-finetuned-lora-t4", - value="", - label="HuggingFace Model ID", - lines=3, - ) - with gr.Row(): - lora_tags = gr.HTML( - value="
    No LoRA selected
    ", - elem_classes="lora-tags", - ) - with gr.Accordion(label="Advanced Options", open=False): - with gr.Row(): - scheduler = gr.Dropdown( - elem_id="scheduler", - label="Scheduler", - value="EulerDiscrete", - choices=[ - "DDIM", - "EulerAncestralDiscrete", - "EulerDiscrete", - "LCMScheduler", - ], - allow_custom_value=True, - visible=True, - ) - with gr.Column(): - save_metadata_to_png = gr.Checkbox( - label="Save prompt information to PNG", - value=args.write_metadata_to_png, - interactive=True, - ) - save_metadata_to_json = gr.Checkbox( - label="Save prompt information to JSON file", - value=args.save_metadata_to_json, - interactive=True, - ) - with gr.Row(): - height = gr.Slider( - 512, - 1024, - value=1024, - step=256, - label="Height", - visible=True, - interactive=True, - ) - width = gr.Slider( - 512, - 1024, - value=1024, - step=256, - label="Width", - visible=True, - interactive=True, - ) - precision = gr.Radio( - label="Precision", - value="fp16", - choices=[ - "fp16", - ], - visible=False, - ) - max_length = gr.Radio( - label="Max Length", - value=77, - choices=[ - 64, - 77, - ], - visible=False, - ) - with gr.Row(): - with gr.Column(scale=3): - steps = gr.Slider( - 1, 100, value=args.steps, step=1, label="Steps" - ) - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=args.guidance_scale, - step=0.1, - label="Guidance Scale", - ) - ondemand = gr.Checkbox( - value=args.ondemand, - label="Low VRAM", - interactive=True, - ) - with gr.Row(): - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=args.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - with gr.Column(scale=3): - batch_size = gr.Slider( - 1, - 4, - value=args.batch_size, - step=1, - label="Batch Size", - interactive=False, - visible=False, - ) - repeatable_seeds = gr.Checkbox( - args.repeatable_seeds, - label="Repeatable Seeds", - ) - - with gr.Row(): - seed = gr.Textbox( - value=args.seed, - label="Seed", - info="An integer or a JSON list of integers, -1 for random", - ) - device = gr.Dropdown( - elem_id="device", - label="Device", - value=available_devices[0], - choices=available_devices, - allow_custom_value=True, - ) - with gr.Accordion(label="Prompt Examples!", open=False): - ex = gr.Examples( - examples=prompt_examples, - inputs=prompt, - cache_examples=False, - elem_id="prompt_examples", - ) - - with gr.Column(scale=1, min_width=600): - with gr.Group(): - txt2img_sdxl_gallery = gr.Gallery( - label="Generated images", - show_label=False, - elem_id="gallery", - columns=[2], - object_fit="scale_down", - # TODO: Re-enable download when fixed in Gradio - show_download_button=False, - ) - std_output = gr.Textbox( - value=f"{t2i_sdxl_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - lines=1, - elem_id="std_output", - show_label=False, - ) - txt2img_sdxl_status = gr.Textbox(visible=False) - with gr.Row(): - stable_diffusion = gr.Button("Generate Image(s)") - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], - queue=False, - ) - stop_batch = gr.Button("Stop Batch") - with gr.Row(): - txt2img_sdxl_sendto_img2img = gr.Button( - value="Send To Img2Img", - visible=False, - ) - txt2img_sdxl_sendto_inpaint = gr.Button( - value="Send To Inpaint", - visible=False, - ) - txt2img_sdxl_sendto_outpaint = gr.Button( - value="Send To Outpaint", - visible=False, - ) - txt2img_sdxl_sendto_upscaler = gr.Button( - value="Send To Upscaler", - visible=False, - ) - - kwargs = dict( - fn=txt2img_sdxl_inf, - inputs=[ - prompt, - negative_prompt, - height, - width, - steps, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - txt2img_sdxl_custom_model, - custom_vae, - precision, - device, - max_length, - save_metadata_to_json, - save_metadata_to_png, - lora_weights, - lora_hf_id, - ondemand, - repeatable_seeds, - ], - outputs=[txt2img_sdxl_gallery, std_output, txt2img_sdxl_status], - show_progress="minimal" if args.progress_bar else "none", - queue=True, - ) - - status_kwargs = dict( - fn=lambda bc, bs: status_label("Text-to-Image-SDXL", 0, bc, bs), - inputs=[batch_count, batch_size], - outputs=txt2img_sdxl_status, - concurrency_limit=1, - ) - - def autogen_changed(checked): - if checked: - args.autogen = True - else: - args.autogen = False - - def check_last_input(prompt): - if not prompt.endswith(" "): - return True - elif not args.autogen: - return True - else: - return False - - auto_gen_kwargs = dict( - fn=check_last_input, - inputs=[negative_prompt], - outputs=[txt2img_sdxl_status], - concurrency_limit=1, - ) - - txt2img_sdxl_autogen.change( - fn=autogen_changed, - inputs=[txt2img_sdxl_autogen], - outputs=None, - ) - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs - ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) - stop_batch.click( - fn=cancel_sd, - cancels=[ - prompt_submit, - neg_prompt_submit, - generate_click, - ], - ) - - txt2img_sdxl_png_info_img.change( - fn=import_png_metadata, - inputs=[ - txt2img_sdxl_png_info_img, - prompt, - negative_prompt, - steps, - scheduler, - guidance_scale, - seed, - width, - height, - txt2img_sdxl_custom_model, - lora_weights, - lora_hf_id, - custom_vae, - ], - outputs=[ - txt2img_sdxl_png_info_img, - prompt, - negative_prompt, - steps, - scheduler, - guidance_scale, - seed, - width, - height, - txt2img_sdxl_custom_model, - lora_weights, - lora_hf_id, - custom_vae, - ], - ) - txt2img_sdxl_custom_model.change( - fn=set_model_default_configs, - inputs=[ - txt2img_sdxl_custom_model, - ], - outputs=[ - prompt, - negative_prompt, - steps, - scheduler, - guidance_scale, - width, - height, - custom_vae, - txt2img_sdxl_autogen, - ], - ) - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/stable_diffusion/web/ui/txt2img_ui.py b/apps/stable_diffusion/web/ui/txt2img_ui.py deleted file mode 100644 index 3b6c936c..00000000 --- a/apps/stable_diffusion/web/ui/txt2img_ui.py +++ /dev/null @@ -1,903 +0,0 @@ -import json -import os -import warnings -import torch -import time -import sys -import gradio as gr -from PIL import Image -from math import ceil - -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - nodlogo_loc, - get_custom_model_path, - get_custom_model_files, - scheduler_list, - scheduler_list_cpu_only, - predefined_models, - cancel_sd, -) -from apps.stable_diffusion.web.ui.common_ui_events import lora_changed -from apps.stable_diffusion.web.utils.metadata import import_png_metadata -from apps.stable_diffusion.web.utils.common_label_calc import status_label -from apps.stable_diffusion.src import ( - args, - Text2ImagePipeline, - get_schedulers, - set_init_device_flags, - utils, - save_output_img, - prompt_examples, - Image2ImagePipeline, -) -from apps.stable_diffusion.src.utils import ( - get_generated_imgs_path, - get_generation_text_info, - resampler_list, -) - -# Names of all interactive fields that can be edited by user -all_gradio_labels = [ - "txt2img_custom_model", - "custom_vae", - "prompt", - "negative_prompt", - "lora_weights", - "lora_hf_id", - "scheduler", - "save_metadata_to_png", - "save_metadata_to_json", - "height", - "width", - "steps", - "guidance_scale", - "Low VRAM", - "use_hiresfix", - "resample_type", - "hiresfix_height", - "hiresfix_width", - "hiresfix_strength", - "batch_count", - "batch_size", - "repeatable_seeds", - "seed", - "device", -] - -# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. -init_iree_vulkan_target_triple = args.iree_vulkan_target_triple -init_iree_metal_target_platform = args.iree_metal_target_platform -init_use_tuned = args.use_tuned -init_import_mlir = args.import_mlir - - -def txt2img_inf( - prompt: str, - negative_prompt: str, - height: int, - width: int, - steps: int, - guidance_scale: float, - seed: str | int, - batch_count: int, - batch_size: int, - scheduler: str, - model_id: str, - custom_vae: str, - precision: str, - device: str, - max_length: int, - save_metadata_to_json: bool, - save_metadata_to_png: bool, - lora_weights: str, - lora_hf_id: str, - ondemand: bool, - repeatable_seeds: bool, - use_hiresfix: bool, - hiresfix_height: int, - hiresfix_width: int, - hiresfix_strength: float, - resample_type: str, -): - from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - get_custom_vae_or_lora_weights, - Config, - ) - import apps.stable_diffusion.web.utils.global_obj as global_obj - from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_CANCEL, - ) - - args.prompts = [prompt] - args.negative_prompts = [negative_prompt] - args.guidance_scale = guidance_scale - args.steps = steps - args.scheduler = scheduler - args.ondemand = ondemand - - # set ckpt_loc and hf_model_id. - args.ckpt_loc = "" - args.hf_model_id = "" - args.custom_vae = "" - - # .safetensor or .chkpt on the custom model path - if model_id in get_custom_model_files(): - args.ckpt_loc = get_custom_model_pathfile(model_id) - # civitai download - elif "civitai" in model_id: - args.ckpt_loc = model_id - # either predefined or huggingface - else: - args.hf_model_id = model_id - - if custom_vae != "None": - args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") - - args.save_metadata_to_json = save_metadata_to_json - args.write_metadata_to_png = save_metadata_to_png - - args.use_lora = get_custom_vae_or_lora_weights( - lora_weights, lora_hf_id, "lora" - ) - - dtype = torch.float32 if precision == "fp32" else torch.half - cpu_scheduling = not scheduler.startswith("Shark") - new_config_obj = Config( - "txt2img", - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - precision, - batch_size, - max_length, - height, - width, - device, - use_lora=args.use_lora, - stencils=[], - ondemand=ondemand, - ) - if ( - not global_obj.get_sd_obj() - or global_obj.get_cfg_obj() != new_config_obj - ): - global_obj.clear_cache() - global_obj.set_cfg_obj(new_config_obj) - args.precision = precision - args.batch_count = batch_count - args.batch_size = batch_size - args.max_length = max_length - args.height = height - args.width = width - args.use_hiresfix = use_hiresfix - args.hiresfix_height = hiresfix_height - args.hiresfix_width = hiresfix_width - args.hiresfix_strength = hiresfix_strength - args.resample_type = resample_type - args.device = device.split("=>", 1)[1].strip() - args.iree_vulkan_target_triple = init_iree_vulkan_target_triple - args.iree_metal_target_platform = init_iree_metal_target_platform - args.use_tuned = init_use_tuned - args.import_mlir = init_import_mlir - args.img_path = None - set_init_device_flags() - model_id = ( - args.hf_model_id - if args.hf_model_id - else "stabilityai/stable-diffusion-2-1-base" - ) - global_obj.set_schedulers(get_schedulers(model_id)) - scheduler_obj = global_obj.get_scheduler(scheduler) - global_obj.set_sd_obj( - Text2ImagePipeline.from_pretrained( - scheduler=scheduler_obj, - import_mlir=args.import_mlir, - model_id=args.hf_model_id, - ckpt_loc=args.ckpt_loc, - precision=args.precision, - max_length=args.max_length, - batch_size=args.batch_size, - height=args.height, - width=args.width, - use_base_vae=args.use_base_vae, - use_tuned=args.use_tuned, - custom_vae=args.custom_vae, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - ) - - global_obj.set_sd_scheduler(scheduler) - - start_time = time.time() - global_obj.get_sd_obj().log = "" - generated_imgs = [] - text_output = "" - try: - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) - except TypeError as error: - raise gr.Error(str(error)) from None - - for current_batch in range(batch_count): - out_imgs = global_obj.get_sd_obj().generate_images( - prompt, - negative_prompt, - batch_size, - height, - width, - steps, - guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - # TODO: allow user to save original image - # TODO: add option to let user keep both pipelines loaded, and unload - # either at will - # TODO: add custom step value slider - # TODO: add option to use secondary model for the img2img pass - if use_hiresfix is True: - new_config_obj = Config( - "img2img", - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - precision, - 1, - max_length, - height, - width, - device, - use_lora=args.use_lora, - stencils=[], - ondemand=ondemand, - ) - - global_obj.clear_cache() - global_obj.set_cfg_obj(new_config_obj) - set_init_device_flags() - model_id = ( - args.hf_model_id - if args.hf_model_id - else "stabilityai/stable-diffusion-2-1-base" - ) - global_obj.set_schedulers(get_schedulers(model_id)) - scheduler_obj = global_obj.get_scheduler(args.scheduler) - - global_obj.set_sd_obj( - Image2ImagePipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - 1, - hiresfix_height, - hiresfix_width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - debug=args.import_debug if args.import_mlir else False, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - ) - - global_obj.set_sd_scheduler(args.scheduler) - - out_imgs = global_obj.get_sd_obj().generate_images( - prompt, - negative_prompt, - out_imgs[0], - batch_size, - hiresfix_height, - hiresfix_width, - ceil(steps / hiresfix_strength), - hiresfix_strength, - guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - stencils=[], - control_mode=None, - resample_type=resample_type, - ) - total_time = time.time() - start_time - text_output = get_generation_text_info( - seeds[: current_batch + 1], device - ) - text_output += "\n" + global_obj.get_sd_obj().log - text_output += f"\nTotal image(s) generation time: {total_time:.4f}sec" - - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - else: - save_output_img(out_imgs[0], seeds[current_batch]) - generated_imgs.extend(out_imgs) - yield generated_imgs, text_output, status_label( - "Text-to-Image", current_batch + 1, batch_count, batch_size - ) - - return generated_imgs, text_output, "" - - -def resource_path(relative_path): - """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) - return os.path.join(base_path, relative_path) - - -dark_theme = resource_path("ui/css/sd_dark_theme.css") - - -# This function export values for all fields that can be edited by user to the settings.json file in ui folder -def export_settings(*values): - settings_list = list(zip(all_gradio_labels, values)) - settings = {} - - for label, value in settings_list: - settings[label] = value - - settings = {"txt2img": settings} - with open("./ui/settings.json", "w") as json_file: - json.dump(settings, json_file, indent=4) - - -# This function loads all values for all fields that can be edited by user from the settings.json file in ui folder -def load_settings(): - try: - with open("./ui/settings.json", "r") as json_file: - loaded_settings = json.load(json_file)["txt2img"] - except (FileNotFoundError, KeyError): - warnings.warn( - "Settings.json file not found or 'txt2img' key is missing. Using default values for fields." - ) - loaded_settings = ( - {} - ) # json file not existing or the data wasn't saved yet - - return [ - loaded_settings.get( - "txt2img_custom_model", - os.path.basename(args.ckpt_loc) - if args.ckpt_loc - else "stabilityai/stable-diffusion-2-1-base", - ), - loaded_settings.get( - "custom_vae", - os.path.basename(args.custom_vae) if args.custom_vae else "None", - ), - loaded_settings.get("prompt", args.prompts[0]), - loaded_settings.get("negative_prompt", args.negative_prompts[0]), - loaded_settings.get("lora_weights", "None"), - loaded_settings.get("lora_hf_id", ""), - loaded_settings.get("scheduler", args.scheduler), - loaded_settings.get( - "save_metadata_to_png", args.write_metadata_to_png - ), - loaded_settings.get( - "save_metadata_to_json", args.save_metadata_to_json - ), - loaded_settings.get("height", args.height), - loaded_settings.get("width", args.width), - loaded_settings.get("steps", args.steps), - loaded_settings.get("guidance_scale", args.guidance_scale), - loaded_settings.get("Low VRAM", args.ondemand), - loaded_settings.get("use_hiresfix", args.use_hiresfix), - loaded_settings.get("resample_type", args.resample_type), - loaded_settings.get("hiresfix_height", args.hiresfix_height), - loaded_settings.get("hiresfix_width", args.hiresfix_width), - loaded_settings.get("hiresfix_strength", args.hiresfix_strength), - loaded_settings.get("batch_count", args.batch_count), - loaded_settings.get("batch_size", args.batch_size), - loaded_settings.get("repeatable_seeds", args.repeatable_seeds), - loaded_settings.get("seed", args.seed), - loaded_settings.get("device", available_devices[0]), - ] - - -# This function loads the user's exported default settings on the start of program -def onload_load_settings(): - loaded_data = load_settings() - structured_data = settings_list = list(zip(all_gradio_labels, loaded_data)) - return dict(structured_data) - - -default_settings = onload_load_settings() -with gr.Blocks(title="Text-to-Image", css=dark_theme) as txt2img_web: - with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - show_download_button=False, - elem_id="top_logo", - width=150, - height=50, - ) - with gr.Row(elem_id="ui_body"): - with gr.Row(): - with gr.Column(scale=1, min_width=600): - with gr.Row(): - with gr.Column(): - with gr.Row(): - t2i_model_info = f"Custom Model Path: {str(get_custom_model_path())}" - txt2img_custom_model = gr.Dropdown( - label=f"Models", - info="Select, or enter HuggingFace Model ID or Civitai model download URL", - elem_id="custom_model", - value=default_settings.get( - "txt2img_custom_model" - ), - choices=get_custom_model_files() - + predefined_models, - allow_custom_value=True, - scale=11, - ) - # janky fix for overflowing text - t2i_vae_info = ( - str(get_custom_model_path("vae")) - ).replace("\\", "\n\\") - t2i_vae_info = f"VAE Path: {t2i_vae_info}" - custom_vae = gr.Dropdown( - label=f"VAE Models", - info=t2i_vae_info, - elem_id="custom_model", - value=default_settings.get("custom_vae"), - choices=["None"] - + get_custom_model_files("vae"), - allow_custom_value=True, - scale=4, - ) - txt2img_png_info_img = gr.Image( - label="Import PNG info", - elem_id="txt2img_prompt_image", - type="pil", - visible=True, - sources=["upload"], - scale=1, - ) - with gr.Group(elem_id="prompt_box_outer"): - prompt = gr.Textbox( - label="Prompt", - value=default_settings.get("prompt"), - lines=2, - elem_id="prompt_box", - ) - # TODO: coming soon - autogen = gr.Checkbox( - label="Continuous Generation", - visible=False, - ) - negative_prompt = gr.Textbox( - label="Negative Prompt", - value=default_settings.get("negative_prompt"), - lines=2, - elem_id="negative_prompt_box", - ) - with gr.Accordion(label="LoRA Options", open=False): - with gr.Row(): - # janky fix for overflowing text - t2i_lora_info = ( - str(get_custom_model_path("lora")) - ).replace("\\", "\n\\") - t2i_lora_info = f"LoRA Path: {t2i_lora_info}" - lora_weights = gr.Dropdown( - label=f"Standalone LoRA Weights", - info=t2i_lora_info, - elem_id="lora_weights", - value=default_settings.get("lora_weights"), - choices=["None"] + get_custom_model_files("lora"), - allow_custom_value=True, - ) - lora_hf_id = gr.Textbox( - elem_id="lora_hf_id", - placeholder="Select 'None' in the Standalone LoRA " - "weights dropdown on the left if you want to use " - "a standalone HuggingFace model ID for LoRA here " - "e.g: sayakpaul/sd-model-finetuned-lora-t4", - value=default_settings.get("lora_hf_id"), - label="HuggingFace Model ID", - lines=3, - ) - with gr.Row(): - lora_tags = gr.HTML( - value="
    No LoRA selected
    ", - elem_classes="lora-tags", - ) - with gr.Accordion(label="Advanced Options", open=False): - with gr.Row(): - scheduler = gr.Dropdown( - elem_id="scheduler", - label="Scheduler", - value=default_settings.get("scheduler"), - choices=scheduler_list, - allow_custom_value=True, - ) - with gr.Column(): - save_metadata_to_png = gr.Checkbox( - label="Save prompt information to PNG", - value=default_settings.get( - "save_metadata_to_png" - ), - interactive=True, - ) - save_metadata_to_json = gr.Checkbox( - label="Save prompt information to JSON file", - value=default_settings.get( - "save_metadata_to_json" - ), - interactive=True, - ) - with gr.Row(): - height = gr.Slider( - 384, - 768, - value=default_settings.get("height"), - step=8, - label="Height", - ) - width = gr.Slider( - 384, - 768, - value=default_settings.get("width"), - step=8, - label="Width", - ) - precision = gr.Radio( - label="Precision", - value=args.precision, - choices=[ - "fp16", - "fp32", - ], - visible=False, - ) - max_length = gr.Radio( - label="Max Length", - value=args.max_length, - choices=[ - 64, - 77, - ], - visible=False, - ) - with gr.Row(): - with gr.Column(scale=3): - steps = gr.Slider( - 1, - 100, - value=default_settings.get("steps"), - step=1, - label="Steps", - ) - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=default_settings.get("guidance_scale"), - step=0.1, - label="CFG Scale", - ) - ondemand = gr.Checkbox( - value=default_settings.get("Low VRAM"), - label="Low VRAM", - interactive=True, - ) - with gr.Row(): - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=default_settings.get("batch_count"), - step=1, - label="Batch Count", - interactive=True, - ) - with gr.Column(scale=3): - batch_size = gr.Slider( - 1, - 4, - value=args.batch_size, - step=1, - label=default_settings.get("batch_size"), - interactive=True, - ) - repeatable_seeds = gr.Checkbox( - default_settings.get("repeatable_seeds"), - label="Repeatable Seeds", - ) - with gr.Accordion(label="Hires Fix Options", open=False): - with gr.Group(): - with gr.Row(): - use_hiresfix = gr.Checkbox( - value=default_settings.get("use_hiresfix"), - label="Use Hires Fix", - interactive=True, - ) - resample_type = gr.Dropdown( - value=default_settings.get("resample_type"), - choices=resampler_list, - label="Resample Type", - allow_custom_value=False, - ) - hiresfix_height = gr.Slider( - 384, - 768, - value=default_settings.get("hiresfix_height"), - step=8, - label="Hires Fix Height", - ) - hiresfix_width = gr.Slider( - 384, - 768, - value=default_settings.get("hiresfix_width"), - step=8, - label="Hires Fix Width", - ) - hiresfix_strength = gr.Slider( - 0, - 1, - value=default_settings.get("hiresfix_strength"), - step=0.01, - label="Hires Fix Denoising Strength", - ) - with gr.Row(): - seed = gr.Textbox( - value=default_settings.get("seed"), - label="Seed", - info="An integer or a JSON list of integers, -1 for random", - ) - device = gr.Dropdown( - elem_id="device", - label="Device", - value=default_settings.get("device"), - choices=available_devices, - allow_custom_value=True, - ) - with gr.Accordion(label="Prompt Examples!", open=False): - ex = gr.Examples( - examples=prompt_examples, - inputs=prompt, - cache_examples=False, - elem_id="prompt_examples", - ) - - with gr.Column(scale=1, min_width=600): - with gr.Group(): - txt2img_gallery = gr.Gallery( - label="Generated images", - show_label=False, - elem_id="gallery", - columns=[2], - object_fit="contain", - # TODO: Re-enable download when fixed in Gradio - show_download_button=False, - ) - std_output = gr.Textbox( - value=f"{t2i_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - lines=1, - elem_id="std_output", - show_label=False, - ) - txt2img_status = gr.Textbox(visible=False) - with gr.Row(): - stable_diffusion = gr.Button("Generate Image(s)") - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], - queue=False, - ) - stop_batch = gr.Button("Stop Batch") - with gr.Row(): - blank_thing_for_row = None - with gr.Row(): - txt2img_sendto_img2img = gr.Button(value="SendTo Img2Img") - txt2img_sendto_inpaint = gr.Button(value="SendTo Inpaint") - txt2img_sendto_outpaint = gr.Button( - value="SendTo Outpaint" - ) - txt2img_sendto_upscaler = gr.Button( - value="SendTo Upscaler" - ) - with gr.Row(): - with gr.Column(scale=2): - export_defaults = gr.Button( - value="Load Default Settings" - ) - export_defaults.click( - fn=load_settings, - inputs=[], - outputs=[ - txt2img_custom_model, - custom_vae, - prompt, - negative_prompt, - lora_weights, - lora_hf_id, - scheduler, - save_metadata_to_png, - save_metadata_to_json, - height, - width, - steps, - guidance_scale, - ondemand, - use_hiresfix, - resample_type, - hiresfix_height, - hiresfix_width, - hiresfix_strength, - batch_count, - batch_size, - repeatable_seeds, - seed, - device, - ], - ) - with gr.Column(scale=2): - export_defaults = gr.Button( - value="Export Default Settings" - ) - export_defaults.click( - fn=export_settings, - inputs=[ - txt2img_custom_model, - custom_vae, - prompt, - negative_prompt, - lora_weights, - lora_hf_id, - scheduler, - save_metadata_to_png, - save_metadata_to_json, - height, - width, - steps, - guidance_scale, - ondemand, - use_hiresfix, - resample_type, - hiresfix_height, - hiresfix_width, - hiresfix_strength, - batch_count, - batch_size, - repeatable_seeds, - seed, - device, - ], - outputs=[], - ) - - kwargs = dict( - fn=txt2img_inf, - inputs=[ - prompt, - negative_prompt, - height, - width, - steps, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - txt2img_custom_model, - custom_vae, - precision, - device, - max_length, - save_metadata_to_json, - save_metadata_to_png, - lora_weights, - lora_hf_id, - ondemand, - repeatable_seeds, - use_hiresfix, - hiresfix_height, - hiresfix_width, - hiresfix_strength, - resample_type, - ], - outputs=[txt2img_gallery, std_output, txt2img_status], - show_progress="minimal" if args.progress_bar else "none", - ) - - status_kwargs = dict( - fn=lambda bc, bs: status_label("Text-to-Image", 0, bc, bs), - inputs=[batch_count, batch_size], - outputs=txt2img_status, - ) - - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs - ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) - stop_batch.click( - fn=cancel_sd, - cancels=[prompt_submit, neg_prompt_submit, generate_click], - ) - - txt2img_png_info_img.change( - fn=import_png_metadata, - inputs=[ - txt2img_png_info_img, - prompt, - negative_prompt, - steps, - scheduler, - guidance_scale, - seed, - width, - height, - txt2img_custom_model, - lora_weights, - lora_hf_id, - custom_vae, - ], - outputs=[ - txt2img_png_info_img, - prompt, - negative_prompt, - steps, - scheduler, - guidance_scale, - seed, - width, - height, - txt2img_custom_model, - lora_weights, - lora_hf_id, - custom_vae, - ], - ) - - # SharkEulerDiscrete doesn't work with img2img which hires_fix uses - def set_compatible_schedulers(hires_fix_selected): - if hires_fix_selected: - return gr.Dropdown( - choices=scheduler_list_cpu_only, - value="DEISMultistep", - ) - else: - return gr.Dropdown( - choices=scheduler_list, - value="SharkEulerDiscrete", - ) - - use_hiresfix.change( - fn=set_compatible_schedulers, - inputs=[use_hiresfix], - outputs=[scheduler], - queue=False, - ) - - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/stable_diffusion/web/ui/upscaler_ui.py b/apps/stable_diffusion/web/ui/upscaler_ui.py deleted file mode 100644 index 88d0507a..00000000 --- a/apps/stable_diffusion/web/ui/upscaler_ui.py +++ /dev/null @@ -1,554 +0,0 @@ -import os -import torch -import time -import gradio as gr -from PIL import Image - -from apps.stable_diffusion.web.ui.utils import ( - available_devices, - nodlogo_loc, - get_custom_model_path, - get_custom_model_files, - scheduler_list_cpu_only, - predefined_upscaler_models, - cancel_sd, -) -from apps.stable_diffusion.web.ui.common_ui_events import lora_changed -from apps.stable_diffusion.web.utils.common_label_calc import status_label -from apps.stable_diffusion.src import ( - args, - UpscalerPipeline, - get_schedulers, - set_init_device_flags, - utils, - save_output_img, -) -from apps.stable_diffusion.src.utils import get_generated_imgs_path - -# set initial values of iree_vulkan_target_triple, use_tuned and import_mlir. -init_iree_vulkan_target_triple = args.iree_vulkan_target_triple -init_use_tuned = args.use_tuned -init_import_mlir = args.import_mlir - - -# Exposed to UI. -def upscaler_inf( - prompt: str, - negative_prompt: str, - init_image, - height: int, - width: int, - steps: int, - noise_level: int, - guidance_scale: float, - seed: str, - batch_count: int, - batch_size: int, - scheduler: str, - model_id: str, - custom_vae: str, - precision: str, - device: str, - max_length: int, - save_metadata_to_json: bool, - save_metadata_to_png: bool, - lora_weights: str, - lora_hf_id: str, - ondemand: bool, - repeatable_seeds: bool, -): - from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - get_custom_vae_or_lora_weights, - Config, - ) - import apps.stable_diffusion.web.utils.global_obj as global_obj - from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_CANCEL, - ) - - args.prompts = [prompt] - args.negative_prompts = [negative_prompt] - args.guidance_scale = guidance_scale - args.seed = seed - args.steps = steps - args.scheduler = scheduler - args.ondemand = ondemand - - if init_image is None: - return None, "An Initial Image is required" - image = init_image.convert("RGB").resize((height, width)) - - # set ckpt_loc and hf_model_id. - args.ckpt_loc = "" - args.hf_model_id = "" - args.custom_vae = "" - - # .safetensor or .chkpt on the custom model path - if model_id in get_custom_model_files(custom_checkpoint_type="upscaler"): - args.ckpt_loc = get_custom_model_pathfile(model_id) - # civitai download - elif "civitai" in model_id: - args.ckpt_loc = model_id - # either predefined or huggingface - else: - args.hf_model_id = model_id - - if custom_vae != "None": - args.custom_vae = get_custom_model_pathfile(custom_vae, model="vae") - - args.save_metadata_to_json = save_metadata_to_json - args.write_metadata_to_png = save_metadata_to_png - - args.use_lora = get_custom_vae_or_lora_weights( - lora_weights, lora_hf_id, "lora" - ) - - dtype = torch.float32 if precision == "fp32" else torch.half - cpu_scheduling = not scheduler.startswith("Shark") - args.height = 128 - args.width = 128 - new_config_obj = Config( - "upscaler", - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - precision, - batch_size, - max_length, - args.height, - args.width, - device, - use_lora=args.use_lora, - stencils=[], - ondemand=ondemand, - ) - if ( - not global_obj.get_sd_obj() - or global_obj.get_cfg_obj() != new_config_obj - ): - global_obj.clear_cache() - global_obj.set_cfg_obj(new_config_obj) - args.batch_size = batch_size - args.max_length = max_length - args.device = device.split("=>", 1)[1].strip() - args.iree_vulkan_target_triple = init_iree_vulkan_target_triple - args.use_tuned = init_use_tuned - args.import_mlir = init_import_mlir - set_init_device_flags() - model_id = ( - args.hf_model_id - if args.hf_model_id - else "stabilityai/stable-diffusion-2-1-base" - ) - global_obj.set_schedulers(get_schedulers(model_id)) - scheduler_obj = global_obj.get_scheduler(scheduler) - global_obj.set_sd_obj( - UpscalerPipeline.from_pretrained( - scheduler_obj, - args.import_mlir, - args.hf_model_id, - args.ckpt_loc, - args.custom_vae, - args.precision, - args.max_length, - args.batch_size, - args.height, - args.width, - args.use_base_vae, - args.use_tuned, - low_cpu_mem_usage=args.low_cpu_mem_usage, - use_lora=args.use_lora, - ondemand=args.ondemand, - ) - ) - - global_obj.set_sd_scheduler(scheduler) - global_obj.get_sd_obj().low_res_scheduler = global_obj.get_scheduler( - "DDPM" - ) - - start_time = time.time() - global_obj.get_sd_obj().log = "" - generated_imgs = [] - extra_info = {"NOISE LEVEL": noise_level} - try: - seeds = utils.batch_seeds(seed, batch_count, repeatable_seeds) - except TypeError as error: - raise gr.Error(str(error)) from None - - for current_batch in range(batch_count): - low_res_img = image - high_res_img = Image.new("RGB", (height * 4, width * 4)) - - for i in range(0, width, 128): - for j in range(0, height, 128): - box = (j, i, j + 128, i + 128) - upscaled_image = global_obj.get_sd_obj().generate_images( - prompt, - negative_prompt, - low_res_img.crop(box), - batch_size, - args.height, - args.width, - steps, - noise_level, - guidance_scale, - seeds[current_batch], - args.max_length, - dtype, - args.use_base_vae, - cpu_scheduling, - args.max_embeddings_multiples, - ) - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - else: - high_res_img.paste(upscaled_image[0], (j * 4, i * 4)) - - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - - total_time = time.time() - start_time - text_output = f"prompt={args.prompts}" - text_output += f"\nnegative prompt={args.negative_prompts}" - text_output += ( - f"\nmodel_id={args.hf_model_id}, " f"ckpt_loc={args.ckpt_loc}" - ) - text_output += f"\nscheduler={args.scheduler}, " f"device={device}" - text_output += ( - f"\nsteps={steps}, " - f"noise_level={noise_level}, " - f"guidance_scale={guidance_scale}, " - f"seed={seeds[:current_batch + 1]}" - ) - text_output += ( - f"\ninput size={height}x{width}, " - f"output size={height*4}x{width*4}, " - f"batch_count={batch_count}, " - f"batch_size={batch_size}, " - f"max_length={args.max_length}\n" - ) - - text_output += global_obj.get_sd_obj().log - text_output += f"\nTotal image generation time: {total_time:.4f}sec" - - if global_obj.get_sd_status() == SD_STATE_CANCEL: - break - else: - save_output_img(high_res_img, seeds[current_batch], extra_info) - generated_imgs.append(high_res_img) - global_obj.get_sd_obj().log += "\n" - yield generated_imgs, text_output, status_label( - "Upscaler", current_batch + 1, batch_count, batch_size - ) - - yield generated_imgs, text_output, "" - - -with gr.Blocks(title="Upscaler") as upscaler_web: - with gr.Row(elem_id="ui_title"): - nod_logo = Image.open(nodlogo_loc) - with gr.Row(): - with gr.Column(scale=1, elem_id="demo_title_outer"): - gr.Image( - value=nod_logo, - show_label=False, - interactive=False, - show_download_button=False, - elem_id="top_logo", - width=150, - height=50, - ) - with gr.Row(elem_id="ui_body"): - with gr.Row(): - with gr.Column(scale=1, min_width=600): - upscaler_init_image = gr.Image( - label="Input Image", - type="pil", - sources=["upload"], - ) - with gr.Row(): - upscaler_model_info = ( - f"Custom Model Path: {str(get_custom_model_path())}" - ) - upscaler_custom_model = gr.Dropdown( - label=f"Models", - info="Select, or enter HuggingFace Model ID or Civitai model download URL", - elem_id="custom_model", - value=os.path.basename(args.ckpt_loc) - if args.ckpt_loc - else "stabilityai/stable-diffusion-x4-upscaler", - choices=get_custom_model_files( - custom_checkpoint_type="upscaler" - ) - + predefined_upscaler_models, - allow_custom_value=True, - scale=2, - ) - # janky fix for overflowing text - upscaler_vae_info = ( - str(get_custom_model_path("vae")) - ).replace("\\", "\n\\") - upscaler_vae_info = f"VAE Path: {upscaler_vae_info}" - custom_vae = gr.Dropdown( - label=f"Custom VAE Models", - info=upscaler_vae_info, - elem_id="custom_model", - value=os.path.basename(args.custom_vae) - if args.custom_vae - else "None", - choices=["None"] + get_custom_model_files("vae"), - allow_custom_value=True, - scale=1, - ) - - with gr.Group(elem_id="prompt_box_outer"): - prompt = gr.Textbox( - label="Prompt", - value=args.prompts[0], - lines=2, - elem_id="prompt_box", - ) - negative_prompt = gr.Textbox( - label="Negative Prompt", - value=args.negative_prompts[0], - lines=2, - elem_id="negative_prompt_box", - ) - with gr.Accordion(label="LoRA Options", open=False): - with gr.Row(): - # janky fix for overflowing text - upscaler_lora_info = ( - str(get_custom_model_path("lora")) - ).replace("\\", "\n\\") - upscaler_lora_info = f"LoRA Path: {upscaler_lora_info}" - lora_weights = gr.Dropdown( - label=f"Standalone LoRA Weights", - info=upscaler_lora_info, - elem_id="lora_weights", - value="None", - choices=["None"] + get_custom_model_files("lora"), - allow_custom_value=True, - ) - lora_hf_id = gr.Textbox( - elem_id="lora_hf_id", - placeholder="Select 'None' in the Standalone LoRA " - "weights dropdown on the left if you want to use " - "a standalone HuggingFace model ID for LoRA here " - "e.g: sayakpaul/sd-model-finetuned-lora-t4", - value="", - label="HuggingFace Model ID", - lines=3, - ) - with gr.Row(): - lora_tags = gr.HTML( - value="
    No LoRA selected
    ", - elem_classes="lora-tags", - ) - with gr.Accordion(label="Advanced Options", open=False): - with gr.Row(): - scheduler = gr.Dropdown( - elem_id="scheduler", - label="Scheduler", - value="DDIM", - choices=scheduler_list_cpu_only, - allow_custom_value=True, - ) - with gr.Group(): - save_metadata_to_png = gr.Checkbox( - label="Save prompt information to PNG", - value=args.write_metadata_to_png, - interactive=True, - ) - save_metadata_to_json = gr.Checkbox( - label="Save prompt information to JSON file", - value=args.save_metadata_to_json, - interactive=True, - ) - with gr.Row(): - height = gr.Slider( - 128, - 512, - value=args.height, - step=128, - label="Height", - ) - width = gr.Slider( - 128, - 512, - value=args.width, - step=128, - label="Width", - ) - precision = gr.Radio( - label="Precision", - value=args.precision, - choices=[ - "fp16", - "fp32", - ], - visible=True, - ) - max_length = gr.Radio( - label="Max Length", - value=args.max_length, - choices=[ - 64, - 77, - ], - visible=False, - ) - with gr.Row(): - steps = gr.Slider( - 1, 100, value=args.steps, step=1, label="Steps" - ) - noise_level = gr.Slider( - 0, - 100, - value=args.noise_level, - step=1, - label="Noise Level", - ) - ondemand = gr.Checkbox( - value=args.ondemand, - label="Low VRAM", - interactive=True, - ) - with gr.Row(): - with gr.Column(scale=3): - guidance_scale = gr.Slider( - 0, - 50, - value=args.guidance_scale, - step=0.1, - label="CFG Scale", - ) - with gr.Column(scale=3): - batch_count = gr.Slider( - 1, - 100, - value=args.batch_count, - step=1, - label="Batch Count", - interactive=True, - ) - repeatable_seeds = gr.Checkbox( - args.repeatable_seeds, - label="Repeatable Seeds", - ) - with gr.Row(): - batch_size = gr.Slider( - 1, - 4, - value=args.batch_size, - step=1, - label="Batch Size", - interactive=False, - visible=False, - ) - with gr.Row(): - seed = gr.Textbox( - value=args.seed, - label="Seed", - info="An integer or a JSON list of integers, -1 for random", - ) - device = gr.Dropdown( - elem_id="device", - label="Device", - value=available_devices[0], - choices=available_devices, - allow_custom_value=True, - ) - - with gr.Column(scale=1, min_width=600): - with gr.Group(): - upscaler_gallery = gr.Gallery( - label="Generated images", - show_label=False, - elem_id="gallery", - columns=[2], - object_fit="contain", - # TODO: Re-enable download when fixed in Gradio - show_download_button=False, - ) - std_output = gr.Textbox( - value=f"{upscaler_model_info}\n" - f"Images will be saved at " - f"{get_generated_imgs_path()}", - lines=2, - elem_id="std_output", - show_label=False, - ) - upscaler_status = gr.Textbox(visible=False) - with gr.Row(): - stable_diffusion = gr.Button("Generate Image(s)") - random_seed = gr.Button("Randomize Seed") - random_seed.click( - lambda: -1, - inputs=[], - outputs=[seed], - queue=False, - ) - stop_batch = gr.Button("Stop Batch") - with gr.Row(): - blank_thing_for_row = None - with gr.Row(): - upscaler_sendto_img2img = gr.Button(value="SendTo Img2Img") - upscaler_sendto_inpaint = gr.Button(value="SendTo Inpaint") - upscaler_sendto_outpaint = gr.Button( - value="SendTo Outpaint" - ) - - kwargs = dict( - fn=upscaler_inf, - inputs=[ - prompt, - negative_prompt, - upscaler_init_image, - height, - width, - steps, - noise_level, - guidance_scale, - seed, - batch_count, - batch_size, - scheduler, - upscaler_custom_model, - custom_vae, - precision, - device, - max_length, - save_metadata_to_json, - save_metadata_to_png, - lora_weights, - lora_hf_id, - ondemand, - repeatable_seeds, - ], - outputs=[upscaler_gallery, std_output, upscaler_status], - show_progress="minimal" if args.progress_bar else "none", - ) - status_kwargs = dict( - fn=lambda bc, bs: status_label("Upscaler", 0, bc, bs), - inputs=[batch_count, batch_size], - outputs=upscaler_status, - ) - - prompt_submit = prompt.submit(**status_kwargs).then(**kwargs) - neg_prompt_submit = negative_prompt.submit(**status_kwargs).then( - **kwargs - ) - generate_click = stable_diffusion.click(**status_kwargs).then(**kwargs) - stop_batch.click( - fn=cancel_sd, - cancels=[prompt_submit, neg_prompt_submit, generate_click], - ) - - lora_weights.change( - fn=lora_changed, - inputs=[lora_weights], - outputs=[lora_tags], - queue=True, - ) diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py deleted file mode 100644 index 0572089e..00000000 --- a/apps/stable_diffusion/web/ui/utils.py +++ /dev/null @@ -1,377 +0,0 @@ -import os -import sys -import glob -import math -import json -import safetensors -import gradio as gr -import PIL.Image as Image - -from pathlib import Path -from apps.stable_diffusion.src import args -from dataclasses import dataclass -from enum import IntEnum -from gradio.components.image_editor import EditorValue - -from apps.stable_diffusion.src import get_available_devices -import apps.stable_diffusion.web.utils.global_obj as global_obj -from apps.stable_diffusion.src.pipelines.pipeline_shark_stable_diffusion_utils import ( - SD_STATE_CANCEL, -) - - -@dataclass -class Config: - mode: str - model_id: str - ckpt_loc: str - custom_vae: str - precision: str - batch_size: int - max_length: int - height: int - width: int - device: str - use_lora: str - stencils: list[str] - ondemand: str # should this be expecting a bool instead? - - -class HSLHue(IntEnum): - RED = 0 - YELLOW = 60 - GREEN = 120 - CYAN = 180 - BLUE = 240 - MAGENTA = 300 - - -custom_model_filetypes = ( - "*.ckpt", - "*.safetensors", -) # the tuple of file types - -scheduler_list_cpu_only = [ - "DDIM", - "PNDM", - "LMSDiscrete", - "KDPM2Discrete", - "DPMSolverMultistep", - "DPMSolverMultistep++", - "DPMSolverMultistepKarras", - "DPMSolverMultistepKarras++", - "EulerDiscrete", - "EulerAncestralDiscrete", - "DEISMultistep", - "KDPM2AncestralDiscrete", - "DPMSolverSinglestep", - "DDPM", - "HeunDiscrete", - "LCMScheduler", -] -scheduler_list = scheduler_list_cpu_only + [ - "SharkEulerDiscrete", - "SharkEulerAncestralDiscrete", -] - -predefined_models = [ - "Linaqruf/anything-v3.0", - "prompthero/openjourney", - "wavymulder/Analog-Diffusion", - "xzuyn/PhotoMerge", - "stabilityai/stable-diffusion-2-1", - "stabilityai/stable-diffusion-2-1-base", - "CompVis/stable-diffusion-v1-4", -] - -predefined_paint_models = [ - "runwayml/stable-diffusion-inpainting", - "stabilityai/stable-diffusion-2-inpainting", - "xzuyn/PhotoMerge-inpainting", -] -predefined_upscaler_models = [ - "stabilityai/stable-diffusion-x4-upscaler", -] -predefined_sdxl_models = [ - "stabilityai/sdxl-turbo", - "stabilityai/stable-diffusion-xl-base-1.0", -] - - -def resource_path(relative_path): - """Get absolute path to resource, works for dev and for PyInstaller""" - base_path = getattr( - sys, "_MEIPASS", os.path.dirname(os.path.abspath(__file__)) - ) - return os.path.join(base_path, relative_path) - - -def create_custom_models_folders(): - dir = ["vae", "lora"] - if not args.ckpt_dir: - dir.insert(0, "models") - else: - if not os.path.isdir(args.ckpt_dir): - sys.exit( - f"Invalid --ckpt_dir argument, " - f"{args.ckpt_dir} folder does not exists." - ) - for root in dir: - get_custom_model_path(root).mkdir(parents=True, exist_ok=True) - - -def get_custom_model_path(model="models"): - # structure in WebUI :- - # models or args.ckpt_dir - # |___lora - # |___vae - sub_folder = "" if model == "models" else model - if args.ckpt_dir: - return Path(Path(args.ckpt_dir), sub_folder) - else: - return Path(Path.cwd(), "models/" + sub_folder) - - -def get_custom_model_pathfile(custom_model_name, model="models"): - return os.path.join(get_custom_model_path(model), custom_model_name) - - -def get_custom_model_files(model="models", custom_checkpoint_type=""): - ckpt_files = [] - file_types = custom_model_filetypes - if model == "lora": - file_types = custom_model_filetypes + ("*.pt", "*.bin") - for extn in file_types: - files = [ - os.path.basename(x) - for x in glob.glob( - os.path.join(get_custom_model_path(model), extn) - ) - ] - match custom_checkpoint_type: - case "sdxl": - files = [ - val - for val in files - if any(x in val for x in ["XL", "xl", "Xl"]) - ] - case "inpainting": - files = [ - val - for val in files - if val.endswith("inpainting" + extn.removeprefix("*")) - ] - case "upscaler": - files = [ - val - for val in files - if val.endswith("upscaler" + extn.removeprefix("*")) - ] - case _: - files = [ - val - for val in files - if not ( - val.endswith("inpainting" + extn.removeprefix("*")) - or val.endswith("upscaler" + extn.removeprefix("*")) - ) - ] - ckpt_files.extend(files) - return sorted(ckpt_files, key=str.casefold) - - -def get_custom_vae_or_lora_weights(weights, hf_id, model): - use_weight = "" - if weights == "None" and not hf_id: - use_weight = "" - elif not hf_id: - use_weight = get_custom_model_pathfile(weights, model) - else: - use_weight = hf_id - return use_weight - - -def hsl_color(alpha: float, start, end): - b = (end - start) * (alpha if alpha > 0 else 0) - result = b + start - - # Return a CSS HSL string - return f"hsl({math.floor(result)}, 80%, 35%)" - - -def get_lora_metadata(lora_filename): - # get the metadata from the file - filename = get_custom_model_pathfile(lora_filename, "lora") - with safetensors.safe_open(filename, framework="pt", device="cpu") as f: - metadata = f.metadata() - - # guard clause for if there isn't any metadata - if not metadata: - return None - - # metadata is a dictionary of strings, the values of the keys we're - # interested in are actually json, and need to be loaded as such - tag_frequencies = json.loads(metadata.get("ss_tag_frequency", str("{}"))) - dataset_dirs = json.loads(metadata.get("ss_dataset_dirs", str("{}"))) - tag_dirs = [dir for dir in tag_frequencies.keys()] - - # gather the tag frequency information for all the datasets trained - all_frequencies = {} - for dataset in tag_dirs: - frequencies = sorted( - [entry for entry in tag_frequencies[dataset].items()], - reverse=True, - key=lambda x: x[1], - ) - - # get a figure for the total number of images processed for this dataset - # either then number actually listed or in its dataset_dir entry or - # the highest frequency's number if that doesn't exist - img_count = dataset_dirs.get(dir, {}).get( - "img_count", frequencies[0][1] - ) - - # add the dataset frequencies to the overall frequencies replacing the - # frequency counts on the tags with a percentage/ratio - all_frequencies.update( - [(entry[0], entry[1] / img_count) for entry in frequencies] - ) - - trained_model_id = " ".join( - [ - metadata.get("ss_sd_model_hash", ""), - metadata.get("ss_sd_model_name", ""), - metadata.get("ss_base_model_version", ""), - ] - ).strip() - - # return the topmost of all frequencies in all datasets - return { - "model": trained_model_id, - "frequencies": sorted( - all_frequencies.items(), reverse=True, key=lambda x: x[1] - ), - } - - -def cancel_sd(): - # Try catch it, as gc can delete global_obj.sd_obj while switching model - try: - global_obj.set_sd_status(SD_STATE_CANCEL) - except Exception: - pass - - -def set_model_default_configs(model_ckpt_or_id, jsonconfig=None): - import gradio as gr - - config_modelname = default_config_exists(model_ckpt_or_id) - if jsonconfig: - return get_config_from_json(jsonconfig) - elif config_modelname: - return default_configs[config_modelname] - # TODO: Use HF metadata to setup pipeline if available - # elif is_valid_hf_id(model_ckpt_or_id): - # return get_HF_default_configs(model_ckpt_or_id) - else: - # We don't have default metadata to setup a good config. Do not change configs. - return [ - gr.Textbox(label="Prompt", interactive=True, visible=True), - gr.Textbox(label="Negative Prompt", interactive=True), - gr.update(), - gr.update(), - gr.update(), - gr.update(), - gr.update(), - gr.update(), - gr.Checkbox( - label="Auto-Generate", - visible=False, - interactive=False, - value=False, - ), - ] - - -def get_config_from_json(model_ckpt_or_id, jsonconfig): - # TODO: make this work properly. It is currently not user-exposed. - cfgdata = json.load(jsonconfig) - return [ - cfgdata["prompt_box_behavior"], - cfgdata["neg_prompt_box_behavior"], - cfgdata["steps"], - cfgdata["scheduler"], - cfgdata["guidance_scale"], - cfgdata["width"], - cfgdata["height"], - cfgdata["custom_vae"], - ] - - -def default_config_exists(model_ckpt_or_id): - if model_ckpt_or_id in default_configs.keys(): - return model_ckpt_or_id - elif "turbo" in model_ckpt_or_id.lower(): - return "stabilityai/sdxl-turbo" - else: - return None - - -def mask_editor_value_for_image_file(filepath): - image = Image.open(filepath) - mask = Image.new(mode="RGBA", size=image.size, color=(0, 0, 0, 0)) - return {"background": image, "layers": [mask], "composite": image} - - -def mask_editor_value_for_gallery_data(gallery_data): - filepath = ( - gallery_data.root[0].image.path - if len(gallery_data.root) != 0 - else None - ) - - if os.path.isfile(filepath): - return mask_editor_value_for_image_file(filepath) - - return EditorValue() - - -default_configs = { - "stabilityai/sdxl-turbo": [ - gr.Textbox(label="", interactive=False, value=None, visible=False), - gr.Textbox( - label="Prompt", - value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette", - ), - gr.Slider(0, 10, value=2), - "EulerAncestralDiscrete", - gr.Slider(0, value=0), - 512, - 512, - "madebyollin/sdxl-vae-fp16-fix", - gr.Checkbox( - label="Auto-Generate", visible=False, interactive=True, value=False - ), - ], - "stabilityai/stable-diffusion-xl-base-1.0": [ - gr.Textbox(label="Prompt", interactive=True, visible=True), - gr.Textbox(label="Negative Prompt", interactive=True), - 40, - "EulerDiscrete", - 7.5, - gr.Slider(value=768, interactive=True), - gr.Slider(value=768, interactive=True), - "madebyollin/sdxl-vae-fp16-fix", - gr.Checkbox( - label="Auto-Generate", - visible=False, - interactive=False, - value=False, - ), - ], -} - - -nodlogo_loc = resource_path("logos/nod-logo.png") -nodicon_loc = resource_path("logos/nod-icon.png") -available_devices = get_available_devices() diff --git a/apps/stable_diffusion/web/utils/app.py b/apps/stable_diffusion/web/utils/app.py deleted file mode 100644 index 4f927578..00000000 --- a/apps/stable_diffusion/web/utils/app.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -import sys -import webview -import webview.util -import socket - -from contextlib import closing -from multiprocessing import Process - -from apps.stable_diffusion.src import args - - -def webview2_installed(): - if sys.platform != "win32": - return False - - # On windows we want to ensure we have MS webview2 available so we don't fall back - # to MSHTML (aka ye olde Internet Explorer) which is deprecated by pywebview, and - # apparently causes SHARK not to load in properly. - - # Checking these registry entries is how Microsoft says to detect a webview2 installation: - # https://learn.microsoft.com/en-us/microsoft-edge/webview2/concepts/distribution - import winreg - - path = r"SOFTWARE\WOW6432Node\Microsoft\EdgeUpdate\Clients\{F3017226-FE2A-4295-8BDF-00C3A9A7E4C5}" - - # only way can find if a registry entry even exists is to try and open it - try: - # check for an all user install - with winreg.OpenKey( - winreg.HKEY_LOCAL_MACHINE, - path, - 0, - winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY, - ) as registry_key: - value, type = winreg.QueryValueEx(registry_key, "pv") - - # if it didn't exist, we want to continue on... - except WindowsError: - try: - # ...to check for a current user install - with winreg.OpenKey( - winreg.HKEY_CURRENT_USER, - path, - 0, - winreg.KEY_QUERY_VALUE | winreg.KEY_WOW64_64KEY, - ) as registry_key: - value, type = winreg.QueryValueEx(registry_key, "pv") - except WindowsError: - value = None - finally: - return (value is not None) and value != "" and value != "0.0.0.0" - - -def window(address): - from tkinter import Tk - - window = Tk() - - # get screen width and height of display and make it more reasonably - # sized as we aren't making it full-screen or maximized - width = int(window.winfo_screenwidth() * 0.81) - height = int(window.winfo_screenheight() * 0.91) - webview.create_window( - "SHARK AI Studio", - url=address, - width=width, - height=height, - text_select=True, - ) - webview.start(private_mode=False, storage_path=os.getcwd()) - - -def usable_port(): - # Make sure we can actually use the port given in args.server_port. If - # not ask the OS for a port and return that as our port to use. - - port = args.server_port - - with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as sock: - try: - sock.bind(("0.0.0.0", port)) - except OSError: - with closing( - socket.socket(socket.AF_INET, socket.SOCK_STREAM) - ) as sock: - sock.bind(("0.0.0.0", 0)) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - return sock.getsockname()[1] - - return port - - -def launch(port): - # setup to launch as an app if app mode has been requested and we're able - # to do it, answering whether we succeeded. - if args.ui == "app" and (sys.platform != "win32" or webview2_installed()): - try: - t = Process(target=window, args=[f"http://localhost:{port}"]) - t.start() - return True - except webview.util.WebViewException: - return False - else: - return False diff --git a/apps/stable_diffusion/web/utils/common_label_calc.py b/apps/stable_diffusion/web/utils/common_label_calc.py deleted file mode 100644 index f55e2396..00000000 --- a/apps/stable_diffusion/web/utils/common_label_calc.py +++ /dev/null @@ -1,9 +0,0 @@ -# functions for generating labels used in common by tabs across the UI - - -def status_label(tab_name, batch_index=0, batch_count=1, batch_size=1): - if batch_index < batch_count: - bs = f"x{batch_size}" if batch_size > 1 else "" - return f"{tab_name} generating {batch_index+1}/{batch_count}{bs}" - else: - return f"{tab_name} complete" diff --git a/apps/stable_diffusion/web/utils/global_obj.py b/apps/stable_diffusion/web/utils/global_obj.py deleted file mode 100644 index c1a4aae9..00000000 --- a/apps/stable_diffusion/web/utils/global_obj.py +++ /dev/null @@ -1,75 +0,0 @@ -import gc - - -""" -The global objects include SD pipeline and config. -Maintaining the global objects would avoid creating extra pipeline objects when switching modes. -Also we could avoid memory leak when switching models by clearing the cache. -""" - - -def _init(): - global _sd_obj - global _config_obj - global _schedulers - _sd_obj = None - _config_obj = None - _schedulers = None - - -def set_sd_obj(value): - global _sd_obj - _sd_obj = value - - -def set_sd_scheduler(key): - global _sd_obj - _sd_obj.scheduler = _schedulers[key] - - -def set_sd_status(value): - global _sd_obj - _sd_obj.status = value - - -def set_cfg_obj(value): - global _config_obj - _config_obj = value - - -def set_schedulers(value): - global _schedulers - _schedulers = value - - -def get_sd_obj(): - global _sd_obj - return _sd_obj - - -def get_sd_status(): - global _sd_obj - return _sd_obj.status - - -def get_cfg_obj(): - global _config_obj - return _config_obj - - -def get_scheduler(key): - global _schedulers - return _schedulers[key] - - -def clear_cache(): - global _sd_obj - global _config_obj - global _schedulers - del _sd_obj - del _config_obj - del _schedulers - gc.collect() - _sd_obj = None - _config_obj = None - _schedulers = None diff --git a/apps/stable_diffusion/web/utils/metadata/__init__.py b/apps/stable_diffusion/web/utils/metadata/__init__.py deleted file mode 100644 index bcbcf746..00000000 --- a/apps/stable_diffusion/web/utils/metadata/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .png_metadata import ( - import_png_metadata, -) -from .display import ( - displayable_metadata, -) diff --git a/apps/stable_diffusion/web/utils/metadata/csv_metadata.py b/apps/stable_diffusion/web/utils/metadata/csv_metadata.py deleted file mode 100644 index d617e802..00000000 --- a/apps/stable_diffusion/web/utils/metadata/csv_metadata.py +++ /dev/null @@ -1,45 +0,0 @@ -import csv -import os -from .format import humanize, humanizable - - -def csv_path(image_filename: str): - return os.path.join(os.path.dirname(image_filename), "imgs_details.csv") - - -def has_csv(image_filename: str) -> bool: - return os.path.exists(csv_path(image_filename)) - - -def matching_filename(image_filename: str, row): - # we assume the final column of the csv has the original filename with full path and match that - # against the image_filename if we are given a list. Otherwise we assume a dict and and take - # the value of the OUTPUT key - return os.path.basename(image_filename) in ( - row[-1] if isinstance(row, list) else row["OUTPUT"] - ) - - -def parse_csv(image_filename: str): - csv_filename = csv_path(image_filename) - - with open(csv_filename, "r", newline="") as csv_file: - # We use a reader or DictReader here for images_details.csv depending on whether we think it - # has headers or not. Having headers means less guessing of the format. - has_header = csv.Sniffer().has_header(csv_file.read(2048)) - csv_file.seek(0) - - reader = ( - csv.DictReader(csv_file) if has_header else csv.reader(csv_file) - ) - - matches = [ - # we rely on humanize and humanizable to work out the parsing of the individual .csv rows - humanize(row) - for row in reader - if row - and (has_header or humanizable(row)) - and matching_filename(image_filename, row) - ] - - return matches[0] if matches else {} diff --git a/apps/stable_diffusion/web/utils/metadata/display.py b/apps/stable_diffusion/web/utils/metadata/display.py deleted file mode 100644 index 26234aab..00000000 --- a/apps/stable_diffusion/web/utils/metadata/display.py +++ /dev/null @@ -1,53 +0,0 @@ -import json -import os -from PIL import Image -from .png_metadata import parse_generation_parameters -from .exif_metadata import has_exif, parse_exif -from .csv_metadata import has_csv, parse_csv -from .format import compact, humanize - - -def displayable_metadata(image_filename: str) -> dict: - if not os.path.isfile(image_filename): - return {"source": "missing", "parameters": {}} - - pil_image = Image.open(image_filename) - - # we have PNG generation parameters (preferred, as it's what the txt2img dropzone reads, - # and we go via that for SendTo, and is directly tied to the image) - if "parameters" in pil_image.info: - return { - "source": "png", - "parameters": compact( - parse_generation_parameters(pil_image.info["parameters"]) - ), - } - - # we have a matching json file (next most likely to be accurate when it's there) - json_path = os.path.splitext(image_filename)[0] + ".json" - if os.path.isfile(json_path): - with open(json_path) as params_file: - return { - "source": "json", - "parameters": compact( - humanize(json.load(params_file), includes_filename=False) - ), - } - - # we have a CSV file so try that (can be different shapes, and it usually has no - # headers/param names so of the things we we *know* have parameters, it's the - # last resort) - if has_csv(image_filename): - params = parse_csv(image_filename) - if params: # we might not have found the filename in the csv - return { - "source": "csv", - "parameters": compact(params), # already humanized - } - - # EXIF data, probably a .jpeg, may well not include parameters, but at least it's *something* - if has_exif(image_filename): - return {"source": "exif", "parameters": parse_exif(pil_image)} - - # we've got nothing - return None diff --git a/apps/stable_diffusion/web/utils/metadata/exif_metadata.py b/apps/stable_diffusion/web/utils/metadata/exif_metadata.py deleted file mode 100644 index c72da8a9..00000000 --- a/apps/stable_diffusion/web/utils/metadata/exif_metadata.py +++ /dev/null @@ -1,52 +0,0 @@ -from PIL import Image -from PIL.ExifTags import Base as EXIFKeys, TAGS, IFD, GPSTAGS - - -def has_exif(image_filename: str) -> bool: - return True if Image.open(image_filename).getexif() else False - - -def parse_exif(pil_image: Image) -> dict: - img_exif = pil_image.getexif() - - # See this stackoverflow answer for where most this comes from: https://stackoverflow.com/a/75357594 - # I did try to use the exif library but it broke just as much as my initial attempt at this (albeit I - # I was probably using it wrong) so I reverted back to using PIL with more filtering and saved a - # dependency - exif_tags = { - TAGS.get(key, key): str(val) - for (key, val) in img_exif.items() - if key in TAGS - and key not in (EXIFKeys.ExifOffset, EXIFKeys.GPSInfo) - and val - and (not isinstance(val, bytes)) - and (not str(val).isspace()) - } - - def try_get_ifd(ifd_id): - try: - return img_exif.get_ifd(ifd_id).items() - except KeyError: - return {} - - ifd_tags = { - TAGS.get(key, key): str(val) - for ifd_id in IFD - for (key, val) in try_get_ifd(ifd_id) - if ifd_id != IFD.GPSInfo - and key in TAGS - and val - and (not isinstance(val, bytes)) - and (not str(val).isspace()) - } - - gps_tags = { - GPSTAGS.get(key, key): str(val) - for (key, val) in try_get_ifd(IFD.GPSInfo) - if key in GPSTAGS - and val - and (not isinstance(val, bytes)) - and (not str(val).isspace()) - } - - return {**exif_tags, **ifd_tags, **gps_tags} diff --git a/apps/stable_diffusion/web/utils/metadata/format.py b/apps/stable_diffusion/web/utils/metadata/format.py deleted file mode 100644 index f097dab5..00000000 --- a/apps/stable_diffusion/web/utils/metadata/format.py +++ /dev/null @@ -1,143 +0,0 @@ -# As SHARK has evolved more columns have been added to images_details.csv. However, since -# no version of the CSV has any headers (yet) we don't actually have anything within the -# file that tells us which parameter each column is for. So this is a list of known patterns -# indexed by length which is what we're going to have to use to guess which columns are the -# right ones for the file we're looking at. - -# The same ordering is used for JSON, but these do have key names, however they are not very -# human friendly, nor do they match up with the what is written to the .png headers - -# So these are functions to try and get something consistent out the raw input from all -# these sources - -PARAMS_FORMATS = { - 9: { - "VARIANT": "Model", - "SCHEDULER": "Sampler", - "PROMPT": "Prompt", - "NEG_PROMPT": "Negative prompt", - "SEED": "Seed", - "CFG_SCALE": "CFG scale", - "PRECISION": "Precision", - "STEPS": "Steps", - "OUTPUT": "Filename", - }, - 10: { - "MODEL": "Model", - "VARIANT": "Variant", - "SCHEDULER": "Sampler", - "PROMPT": "Prompt", - "NEG_PROMPT": "Negative prompt", - "SEED": "Seed", - "CFG_SCALE": "CFG scale", - "PRECISION": "Precision", - "STEPS": "Steps", - "OUTPUT": "Filename", - }, - 12: { - "VARIANT": "Model", - "SCHEDULER": "Sampler", - "PROMPT": "Prompt", - "NEG_PROMPT": "Negative prompt", - "SEED": "Seed", - "CFG_SCALE": "CFG scale", - "PRECISION": "Precision", - "STEPS": "Steps", - "HEIGHT": "Height", - "WIDTH": "Width", - "MAX_LENGTH": "Max Length", - "OUTPUT": "Filename", - }, -} - -PARAMS_FORMAT_CURRENT = { - "VARIANT": "Model", - "VAE": "VAE", - "LORA": "LoRA", - "SCHEDULER": "Sampler", - "PROMPT": "Prompt", - "NEG_PROMPT": "Negative prompt", - "SEED": "Seed", - "CFG_SCALE": "CFG scale", - "PRECISION": "Precision", - "STEPS": "Steps", - "HEIGHT": "Height", - "WIDTH": "Width", - "MAX_LENGTH": "Max Length", - "OUTPUT": "Filename", -} - - -def compact(metadata: dict) -> dict: - # we don't want to alter the original dictionary - result = dict(metadata) - - # discard the filename because we should already have it - if result.keys() & {"Filename"}: - result.pop("Filename") - - # make showing the sizes more compact by using only one line each - if result.keys() & {"Size-1", "Size-2"}: - result["Size"] = f"{result.pop('Size-1')}x{result.pop('Size-2')}" - elif result.keys() & {"Height", "Width"}: - result["Size"] = f"{result.pop('Height')}x{result.pop('Width')}" - - if result.keys() & {"Hires resize-1", "Hires resize-1"}: - hires_y = result.pop("Hires resize-1") - hires_x = result.pop("Hires resize-2") - - if hires_x == 0 and hires_y == 0: - result["Hires resize"] = "None" - else: - result["Hires resize"] = f"{hires_y}x{hires_x}" - - # remove VAE if it exists and is empty - if (result.keys() & {"VAE"}) and ( - not result["VAE"] or result["VAE"] == "None" - ): - result.pop("VAE") - - # remove LoRA if it exists and is empty - if (result.keys() & {"LoRA"}) and ( - not result["LoRA"] or result["LoRA"] == "None" - ): - result.pop("LoRA") - - return result - - -def humanizable(metadata: dict | list[str], includes_filename=True) -> dict: - lookup_key = len(metadata) + (0 if includes_filename else 1) - return lookup_key in PARAMS_FORMATS.keys() - - -def humanize(metadata: dict | list[str], includes_filename=True) -> dict: - lookup_key = len(metadata) + (0 if includes_filename else 1) - - # For lists we can only work based on the length, we have no other information - if isinstance(metadata, list): - if humanizable(metadata, includes_filename): - return dict(zip(PARAMS_FORMATS[lookup_key].values(), metadata)) - else: - raise KeyError( - f"Humanize could not find the format for a parameter list of length {len(metadata)}" - ) - - # For dictionaries we try to use the matching length parameter format if - # available, otherwise we just use the current format which is assumed to - # have everything currently known about. Then we swap keys in the metadata - # that match keys in the format for the friendlier name that we have set - # in the format value - if isinstance(metadata, dict): - if humanizable(metadata, includes_filename): - format = PARAMS_FORMATS[lookup_key] - else: - format = PARAMS_FORMAT_CURRENT - - return { - format[key]: metadata[key] - for key in format.keys() - if key in metadata.keys() and metadata[key] - } - - raise TypeError("Can only humanize parameter lists or dictionaries") diff --git a/apps/stable_diffusion/web/utils/metadata/png_metadata.py b/apps/stable_diffusion/web/utils/metadata/png_metadata.py deleted file mode 100644 index f83d83a5..00000000 --- a/apps/stable_diffusion/web/utils/metadata/png_metadata.py +++ /dev/null @@ -1,220 +0,0 @@ -import re -from pathlib import Path -from apps.stable_diffusion.web.ui.utils import ( - get_custom_model_pathfile, - scheduler_list, - predefined_models, -) - -re_param_code = r'\s*([\w ]+):\s*("(?:\\"[^,]|\\"|\\|[^\"])+"|[^,]*)(?:,|$)' -re_param = re.compile(re_param_code) -re_imagesize = re.compile(r"^(\d+)x(\d+)$") - - -def parse_generation_parameters(x: str): - res = {} - prompt = "" - negative_prompt = "" - done_with_prompt = False - - *lines, lastline = x.strip().split("\n") - if len(re_param.findall(lastline)) < 3: - lines.append(lastline) - lastline = "" - - for i, line in enumerate(lines): - line = line.strip() - if line.startswith("Negative prompt:"): - done_with_prompt = True - line = line[16:].strip() - - if done_with_prompt: - negative_prompt += ("" if negative_prompt == "" else "\n") + line - else: - prompt += ("" if prompt == "" else "\n") + line - - res["Prompt"] = prompt - res["Negative prompt"] = negative_prompt - - for k, v in re_param.findall(lastline): - v = v[1:-1] if v[0] == '"' and v[-1] == '"' else v - m = re_imagesize.match(v) - if m is not None: - res[k + "-1"] = m.group(1) - res[k + "-2"] = m.group(2) - else: - res[k] = v - - # Missing CLIP skip means it was set to 1 (the default) - if "Clip skip" not in res: - res["Clip skip"] = "1" - - hypernet = res.get("Hypernet", None) - if hypernet is not None: - res[ - "Prompt" - ] += f"""""" - - if "Hires resize-1" not in res: - res["Hires resize-1"] = 0 - res["Hires resize-2"] = 0 - - return res - - -def try_find_model_base_from_png_metadata( - file: str, folder: str = "models" -) -> str: - custom = "" - - # Remove extension from file info - if file.endswith(".safetensors") or file.endswith(".ckpt"): - file = Path(file).stem - # Check for the file name match with one of the local ckpt or safetensors files - if Path(get_custom_model_pathfile(file + ".ckpt", folder)).is_file(): - custom = file + ".ckpt" - if Path( - get_custom_model_pathfile(file + ".safetensors", folder) - ).is_file(): - custom = file + ".safetensors" - - return custom - - -def find_model_from_png_metadata( - key: str, metadata: dict[str, str | int] -) -> tuple[str, str]: - png_hf_id = "" - png_custom = "" - - if key in metadata: - model_file = metadata[key] - png_custom = try_find_model_base_from_png_metadata(model_file) - # Check for a model match with one of the default model list (ex: "Linaqruf/anything-v3.0") - if model_file in predefined_models: - png_custom = model_file - # If nothing had matched, check vendor/hf_model_id - if not png_custom and model_file.count("/"): - png_hf_id = model_file - # No matching model was found - if not png_custom and not png_hf_id: - print( - "Import PNG info: Unable to find a matching model for %s" - % model_file - ) - - return png_custom, png_hf_id - - -def find_vae_from_png_metadata( - key: str, metadata: dict[str, str | int] -) -> str: - vae_custom = "" - - if key in metadata: - vae_file = metadata[key] - vae_custom = try_find_model_base_from_png_metadata(vae_file, "vae") - - # VAE input is optional, should not print or throw an error if missing - - return vae_custom - - -def find_lora_from_png_metadata( - key: str, metadata: dict[str, str | int] -) -> tuple[str, str]: - lora_hf_id = "" - lora_custom = "" - - if key in metadata: - lora_file = metadata[key] - lora_custom = try_find_model_base_from_png_metadata(lora_file, "lora") - # If nothing had matched, check vendor/hf_model_id - if not lora_custom and lora_file.count("/"): - lora_hf_id = lora_file - - # LoRA input is optional, should not print or throw an error if missing - - return lora_custom, lora_hf_id - - -def import_png_metadata( - pil_data, - prompt, - negative_prompt, - steps, - sampler, - cfg_scale, - seed, - width, - height, - custom_model, - custom_lora, - hf_lora_id, - custom_vae, -): - try: - png_info = pil_data.info["parameters"] - metadata = parse_generation_parameters(png_info) - - (png_custom_model, png_hf_model_id) = find_model_from_png_metadata( - "Model", metadata - ) - (lora_custom_model, lora_hf_model_id) = find_lora_from_png_metadata( - "LoRA", metadata - ) - vae_custom_model = find_vae_from_png_metadata("VAE", metadata) - - negative_prompt = metadata["Negative prompt"] - steps = int(metadata["Steps"]) - cfg_scale = float(metadata["CFG scale"]) - seed = int(metadata["Seed"]) - width = float(metadata["Size-1"]) - height = float(metadata["Size-2"]) - - if "Model" in metadata and png_custom_model: - custom_model = png_custom_model - elif "Model" in metadata and png_hf_model_id: - custom_model = png_hf_model_id - - if "LoRA" in metadata and lora_custom_model: - custom_lora = lora_custom_model - hf_lora_id = "" - if "LoRA" in metadata and lora_hf_model_id: - custom_lora = "None" - hf_lora_id = lora_hf_model_id - - if "VAE" in metadata and vae_custom_model: - custom_vae = vae_custom_model - - if "Prompt" in metadata: - prompt = metadata["Prompt"] - if "Sampler" in metadata: - if metadata["Sampler"] in scheduler_list: - sampler = metadata["Sampler"] - else: - print( - "Import PNG info: Unable to find a scheduler for %s" - % metadata["Sampler"] - ) - - except Exception as ex: - if pil_data and pil_data.info.get("parameters"): - print("import_png_metadata failed with %s" % ex) - pass - - return ( - None, - prompt, - negative_prompt, - steps, - sampler, - cfg_scale, - seed, - width, - height, - custom_model, - custom_lora, - hf_lora_id, - custom_vae, - ) diff --git a/apps/stable_diffusion/web/utils/tmp_configs.py b/apps/stable_diffusion/web/utils/tmp_configs.py deleted file mode 100644 index 3e6ba46b..00000000 --- a/apps/stable_diffusion/web/utils/tmp_configs.py +++ /dev/null @@ -1,77 +0,0 @@ -import os -import shutil -from time import time - -shark_tmp = os.path.join(os.getcwd(), "shark_tmp/") - - -def clear_tmp_mlir(): - cleanup_start = time() - print( - "Clearing .mlir temporary files from a prior run. This may take some time..." - ) - mlir_files = [ - filename - for filename in os.listdir(shark_tmp) - if os.path.isfile(os.path.join(shark_tmp, filename)) - and filename.endswith(".mlir") - ] - for filename in mlir_files: - os.remove(shark_tmp + filename) - print( - f"Clearing .mlir temporary files took {time() - cleanup_start:.4f} seconds." - ) - - -def clear_tmp_imgs(): - # tell gradio to use a directory under shark_tmp for its temporary - # image files unless somewhere else has been set - if "GRADIO_TEMP_DIR" not in os.environ: - os.environ["GRADIO_TEMP_DIR"] = os.path.join(shark_tmp, "gradio") - - print( - f"gradio temporary image cache located at {os.environ['GRADIO_TEMP_DIR']}. " - + "You may change this by setting the GRADIO_TEMP_DIR environment variable." - ) - - # Clear all gradio tmp images from the last session - if os.path.exists(os.environ["GRADIO_TEMP_DIR"]): - cleanup_start = time() - print( - "Clearing gradio UI temporary image files from a prior run. This may take some time..." - ) - shutil.rmtree(os.environ["GRADIO_TEMP_DIR"], ignore_errors=True) - print( - f"Clearing gradio UI temporary image files took {time() - cleanup_start:.4f} seconds." - ) - - # older SHARK versions had to workaround gradio bugs and stored things differently - else: - image_files = [ - filename - for filename in os.listdir(shark_tmp) - if os.path.isfile(os.path.join(shark_tmp, filename)) - and filename.startswith("tmp") - and filename.endswith(".png") - ] - if len(image_files) > 0: - print( - "Clearing temporary image files of a prior run of a previous SHARK version. This may take some time..." - ) - cleanup_start = time() - for filename in image_files: - os.remove(shark_tmp + filename) - print( - f"Clearing temporary image files took {time() - cleanup_start:.4f} seconds." - ) - else: - print("No temporary images files to clear.") - - -def config_tmp(): - # create shark_tmp if it does not exist - if not os.path.exists(shark_tmp): - os.mkdir(shark_tmp) - - clear_tmp_mlir() - clear_tmp_imgs()