mirror of
https://github.com/nod-ai/AMD-SHARK-Studio.git
synced 2026-04-03 03:00:17 -04:00
Merge branch 'main' into ean-pytest-bench
This commit is contained in:
2
.flake8
2
.flake8
@@ -2,4 +2,4 @@
|
||||
count = 1
|
||||
show-source = 1
|
||||
select = E9,F63,F7,F82
|
||||
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py
|
||||
exclude = lit.cfg.py, apps/language_models/scripts/vicuna.py, apps/language_models/src/pipelines/minigpt4_pipeline.py, apps/language_models/langchain/h2oai_pipeline.py
|
||||
|
||||
@@ -1,406 +0,0 @@
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from evaluate_params import eval_func_param_names, eval_extra_columns
|
||||
from gen import Langchain
|
||||
from prompter import Prompter
|
||||
from utils import clear_torch_cache, NullContext, get_kwargs
|
||||
|
||||
|
||||
def run_eval(  # for local function:
    base_model=None,
    lora_weights=None,
    inference_server=None,
    prompt_type=None,
    prompt_dict=None,
    debug=None,
    chat=False,
    chat_context=None,
    stream_output=None,
    eval_filename=None,
    eval_prompts_only_num=None,
    eval_prompts_only_seed=None,
    eval_as_output=None,
    examples=None,
    memory_restriction_level=None,
    # for get_model:
    score_model=None,
    load_8bit=None,
    load_4bit=None,
    load_half=None,
    load_gptq=None,
    use_safetensors=None,
    infer_devices=None,
    tokenizer_base_model=None,
    gpu_id=None,
    local_files_only=None,
    resume_download=None,
    use_auth_token=None,
    trust_remote_code=None,
    offload_folder=None,
    compile_model=None,
    # for evaluate args beyond what's already above, or things that are always dynamic and locally created
    temperature=None,
    top_p=None,
    top_k=None,
    num_beams=None,
    max_new_tokens=None,
    min_new_tokens=None,
    early_stopping=None,
    max_time=None,
    repetition_penalty=None,
    num_return_sequences=None,
    do_sample=None,
    langchain_mode=None,
    langchain_action=None,
    top_k_docs=None,
    chunk=None,
    chunk_size=None,
    document_choice=None,
    # for evaluate kwargs:
    src_lang=None,
    tgt_lang=None,
    concurrency_count=None,
    save_dir=None,
    sanitize_bot_response=None,
    model_state0=None,
    max_max_new_tokens=None,
    is_public=None,
    max_max_time=None,
    raise_generate_gpu_exceptions=None,
    load_db_if_exists=None,
    dbs=None,
    user_path=None,
    detect_user_path_changes_every_query=None,
    use_openai_embedding=None,
    use_openai_model=None,
    hf_embedding_model=None,
    db_type=None,
    n_jobs=None,
    first_para=None,
    text_limit=None,
    verbose=None,
    cli=None,
    reverse_docs=None,
    use_cache=None,
    auto_reduce_chunks=None,
    max_chunks=None,
    model_lock=None,
    force_langchain_evaluate=None,
    model_state_none=None,
):
    """Run a scored evaluation loop over a set of example prompts.

    Builds (or loads from ``eval_filename``) a list of nochat-style examples,
    generates a response for each (either via the loaded model or, when
    ``eval_as_output`` is truthy, by echoing the dataset's reference output),
    scores each response with the reward/score model, and incrementally dumps
    scores to a parquet file plus a histogram PNG under ``scoring/``.

    NOTE: local names are significant here — the function forwards ``**locals()``
    to ``Langchain.check_locals`` / ``get_kwargs`` several times, so parameter
    and local variable names must match the callees' signatures exactly.

    Returns:
        str: path to the parquet file holding the per-example scores.
    """
    Langchain.check_locals(**locals())

    if eval_prompts_only_num > 0:
        np.random.seed(eval_prompts_only_seed)
        example1 = examples[-1]  # pick reference example
        examples = []
        responses = []
        if eval_filename is None:
            # override default examples with shareGPT ones for human-level eval purposes only
            eval_filename = (
                "ShareGPT_V3_unfiltered_cleaned_split_no_imsorry.json"
            )
            if not os.path.isfile(eval_filename):
                os.system(
                    "wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s"
                    % eval_filename
                )
            import json

            # FIX: use a context manager so the dataset file handle is closed
            with open(eval_filename, "rt") as f:
                data = json.load(f)
            # focus on data that starts with human, else likely chopped from other data
            turn_start = 0  # odd in general
            data = [
                x
                for x in data
                if len(x["conversations"]) > turn_start + 1
                and x["conversations"][turn_start]["from"] == "human"
                and x["conversations"][turn_start + 1]["from"] == "gpt"
            ]
            for i in sorted(
                np.random.randint(0, len(data), size=eval_prompts_only_num)
            ):
                assert data[i]["conversations"][turn_start]["from"] == "human"
                instruction = data[i]["conversations"][turn_start]["value"]
                assert (
                    data[i]["conversations"][turn_start + 1]["from"] == "gpt"
                )
                output = data[i]["conversations"][turn_start + 1]["value"]
                examplenew = example1.copy()
                assert (
                    not chat
                ), "No gradio must use chat=False, uses nochat instruct"
                examplenew[
                    eval_func_param_names.index("instruction_nochat")
                ] = instruction
                examplenew[
                    eval_func_param_names.index("iinput_nochat")
                ] = ""  # no input
                examplenew[
                    eval_func_param_names.index("context")
                ] = Langchain.get_context(chat_context, prompt_type)
                examples.append(examplenew)
                responses.append(output)
        else:
            # get data, assume in correct format: json of rows of dict of instruction and output
            # only instruction is required
            import json

            # FIX: use a context manager so the dataset file handle is closed
            with open(eval_filename, "rt") as f:
                data = json.load(f)
            for i in sorted(
                np.random.randint(0, len(data), size=eval_prompts_only_num)
            ):
                examplenew = example1.copy()
                instruction = data[i]["instruction"]
                output = data[i].get("output", "")  # not required
                assert (
                    not chat
                ), "No gradio must use chat=False, uses nochat instruct"
                examplenew[
                    eval_func_param_names.index("instruction_nochat")
                ] = instruction
                examplenew[
                    eval_func_param_names.index("iinput_nochat")
                ] = ""  # no input
                examplenew[
                    eval_func_param_names.index("context")
                ] = Langchain.get_context(chat_context, prompt_type)
                examples.append(examplenew)
                responses.append(output)

    num_examples = len(examples)
    scoring_path = "scoring"
    os.makedirs(scoring_path, exist_ok=True)
    if eval_as_output:
        used_base_model = "gpt35"
        used_lora_weights = ""
        used_inference_server = ""
    else:
        used_base_model = str(base_model.split("/")[-1])
        used_lora_weights = str(lora_weights.split("/")[-1])
        used_inference_server = str(inference_server.split("/")[-1])
    eval_out_filename = "df_scores_%s_%s_%s_%s_%s_%s_%s.parquet" % (
        num_examples,
        eval_prompts_only_num,
        eval_prompts_only_seed,
        eval_as_output,
        used_base_model,
        used_lora_weights,
        used_inference_server,
    )
    eval_out_filename = os.path.join(scoring_path, eval_out_filename)

    # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
    # FIX: torch.cuda.is_available is a function; the original referenced the
    # function object (always truthy) instead of calling it.
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    device = "cpu" if n_gpus == 0 else "cuda"
    context_class = NullContext if n_gpus > 1 or n_gpus == 0 else torch.device

    with context_class(device):
        # ensure was set right above before examples generated
        assert (
            not stream_output
        ), "stream_output=True does not make sense with example loop"
        import time
        from functools import partial

        # get score model
        smodel, stokenizer, sdevice = Langchain.get_score_model(
            reward_type=True,
            **get_kwargs(
                Langchain.get_score_model,
                exclude_names=["reward_type"],
                **locals()
            )
        )

        if not eval_as_output:
            model, tokenizer, device = Langchain.get_model(
                reward_type=False,
                **get_kwargs(
                    Langchain.get_model,
                    exclude_names=["reward_type"],
                    **locals()
                )
            )
            model_dict = dict(
                base_model=base_model,
                tokenizer_base_model=tokenizer_base_model,
                lora_weights=lora_weights,
                inference_server=inference_server,
                prompt_type=prompt_type,
                prompt_dict=prompt_dict,
            )
            model_state = dict(model=model, tokenizer=tokenizer, device=device)
            model_state.update(model_dict)
            my_db_state = [None]
            fun = partial(
                Langchain.evaluate,
                model_state,
                my_db_state,
                **get_kwargs(
                    Langchain.evaluate,
                    exclude_names=["model_state", "my_db_state"]
                    + eval_func_param_names,
                    **locals()
                )
            )
        else:
            # no model: replay the dataset's reference outputs instead
            assert eval_prompts_only_num > 0

            def get_response(*args, exi=0):
                # assumes same ordering of examples and responses
                yield responses[exi]

            fun = get_response
        t0 = time.time()
        score_dump = []
        score_avg = 0
        score_median = 0

        for exi, ex in enumerate(examples):
            clear_torch_cache()

            instruction = ex[eval_func_param_names.index("instruction_nochat")]
            iinput = ex[eval_func_param_names.index("iinput_nochat")]
            context = ex[eval_func_param_names.index("context")]
            clear_torch_cache()
            print("")
            print("START" + "=" * 100)
            print(
                "Question: %s %s"
                % (instruction, ("input=%s" % iinput if iinput else ""))
            )
            print("-" * 105)
            # fun yields as generator, so have to iterate over it
            # Also means likely do NOT want --stream_output=True, else would show all generations
            t1 = time.time()
            gener = (
                fun(*tuple(ex), exi=exi) if eval_as_output else fun(*tuple(ex))
            )
            for res_fun in gener:
                res = res_fun["response"]
                extra = res_fun["sources"]
                print(res)
                if smodel:
                    score_with_prompt = False
                    if score_with_prompt:
                        data_point = dict(
                            instruction=instruction,
                            input=iinput,
                            context=context,
                        )
                        prompter = Prompter(
                            prompt_type,
                            prompt_dict,
                            debug=debug,
                            chat=chat,
                            stream_output=stream_output,
                        )
                        prompt = prompter.generate_prompt(data_point)
                    else:
                        # just raw input and output
                        if eval_prompts_only_num > 0:
                            # only our own examples have this filled at moment
                            assert iinput in [
                                None,
                                "",
                            ], iinput  # should be no iinput
                        if not (chat_context and prompt_type == "human_bot"):
                            assert context in [
                                None,
                                "",
                            ], context  # should be no context
                        prompt = instruction
                    if memory_restriction_level > 0:
                        cutoff_len = (
                            768 if memory_restriction_level <= 2 else 512
                        )
                    else:
                        cutoff_len = tokenizer.model_max_length
                    inputs = stokenizer(
                        prompt,
                        res,
                        return_tensors="pt",
                        truncation=True,
                        max_length=cutoff_len,
                    )
                    try:
                        score = (
                            torch.sigmoid(smodel(**inputs).logits[0].float())
                            .cpu()
                            .detach()
                            .numpy()[0]
                        )
                    except torch.cuda.OutOfMemoryError as e:
                        print(
                            "GPU OOM 1: question: %s answer: %s exception: %s"
                            % (prompt, res, str(e)),
                            flush=True,
                        )
                        traceback.print_exc()
                        score = 0.0
                        clear_torch_cache()
                    except (Exception, RuntimeError) as e:
                        # only swallow known GPU/precision failures; re-raise the rest
                        if (
                            "Expected all tensors to be on the same device"
                            in str(e)
                            or "expected scalar type Half but found Float"
                            in str(e)
                            or "probability tensor contains either" in str(e)
                            or "cublasLt ran into an error!" in str(e)
                        ):
                            print(
                                "GPU error: question: %s answer: %s exception: %s"
                                % (prompt, res, str(e)),
                                flush=True,
                            )
                            traceback.print_exc()
                            score = 0.0
                            clear_torch_cache()
                        else:
                            raise
                    score_dump.append(ex + [prompt, res, score])
                    # dump every score in case abort
                    df_scores = pd.DataFrame(
                        score_dump,
                        columns=eval_func_param_names + eval_extra_columns,
                    )
                    df_scores.to_parquet(eval_out_filename, index=False)
                    # plot histogram so far
                    plt.figure(figsize=(10, 10))
                    plt.hist(df_scores["score"], bins=20)
                    score_avg = np.mean(df_scores["score"])
                    score_median = np.median(df_scores["score"])
                    print(
                        "SCORE %s: %s So far: AVG: %s MEDIAN: %s"
                        % (exi, score, score_avg, score_median),
                        flush=True,
                    )
                    plt.title(
                        "Score avg: %s median: %s" % (score_avg, score_median)
                    )
                    plt.savefig(eval_out_filename.replace(".parquet", ".png"))
                    plt.close()

                print("END" + "=" * 102)
                print("")
                t2 = time.time()
                print(
                    "Time taken for example: %s Time taken so far: %.4f about %.4g per example"
                    % (t2 - t1, t2 - t0, (t2 - t0) / (1 + exi))
                )
                t1 = time.time()
        print(
            "Total time taken: %.4f about %.4g per example"
            % (t1 - t0, (t1 - t0) / num_examples)
        )
        print(
            "Score avg: %s median: %s" % (score_avg, score_median), flush=True
        )
    return eval_out_filename
|
||||
432
apps/language_models/langchain/expanded_pipelines.py
Normal file
432
apps/language_models/langchain/expanded_pipelines.py
Normal file
@@ -0,0 +1,432 @@
|
||||
"""Load question answering chains."""
|
||||
from __future__ import annotations
|
||||
from typing import (
|
||||
Any,
|
||||
Mapping,
|
||||
Optional,
|
||||
Dict,
|
||||
List,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
Protocol,
|
||||
)
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.question_answering import stuff_prompt
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.docstore.document import Document
|
||||
from abc import ABC, abstractmethod
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManager,
|
||||
CallbackManagerForChainRun,
|
||||
Callbacks,
|
||||
)
|
||||
from langchain.input import get_colored_text
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
from langchain.schema import LLMResult, PromptValue
|
||||
from pydantic import Extra, Field, root_validator
|
||||
|
||||
|
||||
def format_document(doc: Document, prompt: BasePromptTemplate) -> str:
    """Render *doc* as a string by filling *prompt* with its fields.

    The prompt may reference ``page_content`` plus any of the document's
    metadata keys. Raises ``ValueError`` when the document lacks a metadata
    key the prompt requires.
    """
    # Metadata keys take precedence over page_content if they collide,
    # matching the original dict-update order.
    available = {"page_content": doc.page_content, **doc.metadata}
    missing = set(prompt.input_variables).difference(available)
    if missing:
        needed = [
            var for var in prompt.input_variables if var != "page_content"
        ]
        raise ValueError(
            f"Document prompt requires documents to have metadata variables: "
            f"{needed}. Received document with missing metadata: "
            f"{list(missing)}."
        )
    selected = {key: available[key] for key in prompt.input_variables}
    return prompt.format(**selected)
|
||||
|
||||
|
||||
class BaseCombineDocumentsChain(Chain, ABC):
    """Base interface for chains combining documents."""

    input_key: str = "input_documents"  #: :meta private:
    output_key: str = "output_text"  #: :meta private:

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def prompt_length(
        self, docs: List[Document], **kwargs: Any
    ) -> Optional[int]:
        """Return the prompt length given the documents passed in.

        Returns None if the method does not depend on the prompt length.
        """
        return None

    @abstractmethod
    def combine_docs(
        self, docs: List[Document], **kwargs: Any
    ) -> Tuple[str, dict]:
        """Combine documents into a single string."""

    def _call(
        self,
        inputs: Dict[str, List[Document]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the chain: pull docs from ``inputs[self.input_key]``, delegate
        to :meth:`combine_docs`, and store the combined text under
        ``self.output_key`` alongside any extra keys the subclass returns."""
        _run_manager = (
            run_manager or CallbackManagerForChainRun.get_noop_manager()
        )
        docs = inputs[self.input_key]
        # Other keys are assumed to be needed for LLM prediction
        other_keys = {k: v for k, v in inputs.items() if k != self.input_key}
        output, extra_return_dict = self.combine_docs(
            docs, callbacks=_run_manager.get_child(), **other_keys
        )
        extra_return_dict[self.output_key] = output
        return extra_return_dict
|
||||
|
||||
|
||||
class LLMChain(Chain):
    """Chain to run queries against LLMs.

    Example:
        .. code-block:: python

            from langchain import LLMChain, OpenAI, PromptTemplate
            prompt_template = "Tell me a {adjective} joke"
            prompt = PromptTemplate(
                input_variables=["adjective"], template=prompt_template
            )
            llm = LLMChain(llm=OpenAI(), prompt=prompt)
    """

    @property
    def lc_serializable(self) -> bool:
        # Opt into langchain's serialization (used by dumpd in apply()).
        return True

    prompt: BasePromptTemplate
    """Prompt object to use."""
    llm: BaseLanguageModel
    output_key: str = "text"  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @property
    def input_keys(self) -> List[str]:
        """Will be whatever keys the prompt expects.

        :meta private:
        """
        return self.prompt.input_variables

    @property
    def output_keys(self) -> List[str]:
        """Will always return text key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Run the chain on a single input dict; returns one output dict."""
        response = self.generate([inputs], run_manager=run_manager)
        return self.create_outputs(response)[0]

    def generate(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> LLMResult:
        """Generate LLM result from inputs."""
        prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
        return self.llm.generate_prompt(
            prompts,
            stop,
            callbacks=run_manager.get_child() if run_manager else None,
        )

    def prep_prompts(
        self,
        input_list: List[Dict[str, Any]],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Tuple[List[PromptValue], Optional[List[str]]]:
        """Prepare prompts from inputs.

        ``stop`` is taken from the first input; all inputs that carry a
        ``stop`` key must agree with it, otherwise a ``ValueError`` is raised.
        """
        stop = None
        if "stop" in input_list[0]:
            stop = input_list[0]["stop"]
        prompts = []
        for inputs in input_list:
            selected_inputs = {
                k: inputs[k] for k in self.prompt.input_variables
            }
            prompt = self.prompt.format_prompt(**selected_inputs)
            _colored_text = get_colored_text(prompt.to_string(), "green")
            _text = "Prompt after formatting:\n" + _colored_text
            if run_manager:
                run_manager.on_text(_text, end="\n", verbose=self.verbose)
            if "stop" in inputs and inputs["stop"] != stop:
                raise ValueError(
                    "If `stop` is present in any inputs, should be present in all."
                )
            prompts.append(prompt)
        return prompts, stop

    def apply(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> List[Dict[str, str]]:
        """Utilize the LLM generate method for speed gains."""
        callback_manager = CallbackManager.configure(
            callbacks, self.callbacks, self.verbose
        )
        run_manager = callback_manager.on_chain_start(
            dumpd(self),
            {"input_list": input_list},
        )
        try:
            response = self.generate(input_list, run_manager=run_manager)
        except (KeyboardInterrupt, Exception) as e:
            # report the failure to the callback manager before re-raising
            run_manager.on_chain_error(e)
            raise e
        outputs = self.create_outputs(response)
        run_manager.on_chain_end({"outputs": outputs})
        return outputs

    def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
        """Create outputs from response."""
        return [
            # Get the text of the top generated string.
            {self.output_key: generation[0].text}
            for generation in response.generations
        ]

    def predict(self, callbacks: Callbacks = None, **kwargs: Any) -> str:
        """Format prompt with kwargs and pass to LLM.

        Args:
            callbacks: Callbacks to pass to LLMChain
            **kwargs: Keys to pass to prompt template.

        Returns:
            Completion from LLM.

        Example:
            .. code-block:: python

                completion = llm.predict(adjective="funny")
        """
        return self(kwargs, callbacks=callbacks)[self.output_key]

    def predict_and_parse(
        self, callbacks: Callbacks = None, **kwargs: Any
    ) -> Union[str, List[str], Dict[str, Any]]:
        """Call predict and then parse the results."""
        result = self.predict(callbacks=callbacks, **kwargs)
        if self.prompt.output_parser is not None:
            return self.prompt.output_parser.parse(result)
        else:
            return result

    def apply_and_parse(
        self, input_list: List[Dict[str, Any]], callbacks: Callbacks = None
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
        """Call apply and then parse the results."""
        result = self.apply(input_list, callbacks=callbacks)
        return self._parse_result(result)

    def _parse_result(
        self, result: List[Dict[str, str]]
    ) -> Sequence[Union[str, List[str], Dict[str, str]]]:
        """Parse each output through the prompt's output parser, if any."""
        if self.prompt.output_parser is not None:
            return [
                self.prompt.output_parser.parse(res[self.output_key])
                for res in result
            ]
        else:
            return result

    @property
    def _chain_type(self) -> str:
        return "llm_chain"

    @classmethod
    def from_string(cls, llm: BaseLanguageModel, template: str) -> LLMChain:
        """Create LLMChain from LLM and template."""
        prompt_template = PromptTemplate.from_template(template)
        return cls(llm=llm, prompt=prompt_template)
|
||||
|
||||
|
||||
def _get_default_document_prompt() -> PromptTemplate:
    """Default document prompt: render a document as its raw page content."""
    template = "{page_content}"
    return PromptTemplate(input_variables=["page_content"], template=template)
|
||||
|
||||
|
||||
class StuffDocumentsChain(BaseCombineDocumentsChain):
    """Chain that combines documents by stuffing into context."""

    llm_chain: LLMChain
    """LLM wrapper to use after formatting documents."""
    document_prompt: BasePromptTemplate = Field(
        default_factory=_get_default_document_prompt
    )
    """Prompt to use to format each document."""
    document_variable_name: str
    """The variable name in the llm_chain to put the documents in.
    If only one variable in the llm_chain, this need not be provided."""
    document_separator: str = "\n\n"
    """The string with which to join the formatted documents"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
        arbitrary_types_allowed = True

    @root_validator(pre=True)
    def get_default_document_variable_name(cls, values: Dict) -> Dict:
        """Get default document variable name, if not provided."""
        llm_chain_variables = values["llm_chain"].prompt.input_variables
        if "document_variable_name" not in values:
            # Can only infer the name when the prompt has exactly one variable.
            if len(llm_chain_variables) == 1:
                values["document_variable_name"] = llm_chain_variables[0]
            else:
                raise ValueError(
                    "document_variable_name must be provided if there are "
                    "multiple llm_chain_variables"
                )
        else:
            if values["document_variable_name"] not in llm_chain_variables:
                raise ValueError(
                    f"document_variable_name {values['document_variable_name']} was "
                    f"not found in llm_chain input_variables: {llm_chain_variables}"
                )
        return values

    def _get_inputs(self, docs: List[Document], **kwargs: Any) -> dict:
        """Build the llm_chain input dict: formatted+joined docs plus any
        pass-through kwargs the llm_chain's prompt actually expects."""
        # Format each document according to the prompt
        doc_strings = [
            format_document(doc, self.document_prompt) for doc in docs
        ]
        # Join the documents together to put them in the prompt.
        inputs = {
            k: v
            for k, v in kwargs.items()
            if k in self.llm_chain.prompt.input_variables
        }
        inputs[self.document_variable_name] = self.document_separator.join(
            doc_strings
        )
        return inputs

    def prompt_length(
        self, docs: List[Document], **kwargs: Any
    ) -> Optional[int]:
        """Get the prompt length by formatting the prompt."""
        inputs = self._get_inputs(docs, **kwargs)
        prompt = self.llm_chain.prompt.format(**inputs)
        return self.llm_chain.llm.get_num_tokens(prompt)

    def combine_docs(
        self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any
    ) -> Tuple[str, dict]:
        """Stuff all documents into one prompt and pass to LLM."""
        inputs = self._get_inputs(docs, **kwargs)
        # Call predict on the LLM.
        return self.llm_chain.predict(callbacks=callbacks, **inputs), {}

    @property
    def _chain_type(self) -> str:
        return "stuff_documents_chain"
|
||||
|
||||
|
||||
class LoadingCallable(Protocol):
    """Interface for loading the combine documents chain.

    Structural type for the ``loader_mapping`` values in ``load_qa_chain``.
    """

    def __call__(
        self, llm: BaseLanguageModel, **kwargs: Any
    ) -> BaseCombineDocumentsChain:
        """Callable to load the combine documents chain."""
|
||||
|
||||
|
||||
def _load_stuff_chain(
    llm: BaseLanguageModel,
    prompt: Optional[BasePromptTemplate] = None,
    document_variable_name: str = "context",
    verbose: Optional[bool] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    callbacks: Callbacks = None,
    **kwargs: Any,
) -> StuffDocumentsChain:
    """Build a ``StuffDocumentsChain`` around *llm* for question answering.

    Uses *prompt* if given, otherwise selects the default QA stuff prompt
    for this LLM via ``stuff_prompt.PROMPT_SELECTOR``. Extra ``kwargs`` are
    forwarded to ``StuffDocumentsChain``.
    """
    _prompt = prompt or stuff_prompt.PROMPT_SELECTOR.get_prompt(llm)
    llm_chain = LLMChain(
        llm=llm,
        prompt=_prompt,
        verbose=verbose,
        callback_manager=callback_manager,
        callbacks=callbacks,
    )
    # TODO: document prompt
    return StuffDocumentsChain(
        llm_chain=llm_chain,
        document_variable_name=document_variable_name,
        verbose=verbose,
        callback_manager=callback_manager,
        **kwargs,
    )
|
||||
|
||||
|
||||
def load_qa_chain(
    llm: BaseLanguageModel,
    chain_type: str = "stuff",
    verbose: Optional[bool] = None,
    callback_manager: Optional[BaseCallbackManager] = None,
    **kwargs: Any,
) -> BaseCombineDocumentsChain:
    """Load question answering chain.

    Args:
        llm: Language Model to use in the chain.
        chain_type: Type of document combining chain to use. Should be one of "stuff",
            "map_reduce", "map_rerank", and "refine".
        verbose: Whether chains should be run in verbose mode or not. Note that this
            applies to all chains that make up the final chain.
        callback_manager: Callback manager to use for the chain.

    Returns:
        A chain to use for question answering.
    """
    # NOTE: only "stuff" is wired up here even though the docstring mentions
    # other chain types; anything else raises ValueError below.
    loader_mapping: Mapping[str, LoadingCallable] = {
        "stuff": _load_stuff_chain,
    }
    if chain_type not in loader_mapping:
        raise ValueError(
            f"Got unsupported chain type: {chain_type}. "
            f"Should be one of {loader_mapping.keys()}"
        )
    return loader_mapping[chain_type](
        llm, verbose=verbose, callback_manager=callback_manager, **kwargs
    )
|
||||
@@ -1,283 +0,0 @@
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
def do_export():
|
||||
BASE_MODEL = "h2oai/h2ogpt-oasst1-512-12b"
|
||||
LORA_WEIGHTS = "h2ogpt-oasst1-512-12b.h2oaih2ogpt-oig-oasst1-instruct-cleaned-v3.1_epochs.805b8e8eff369207340a5a6f90f3c833f9731254.2"
|
||||
OUTPUT_NAME = "h2ogpt-oig-oasst1-512-12b"
|
||||
|
||||
BASE_MODEL = "EleutherAI/pythia-12b-deduped"
|
||||
LORA_WEIGHTS = "pythia-12b-deduped.h2oaiopenassistant_oasst1_h2ogpt_graded.3_epochs.2ccf687ea3f3f3775a501838e81c1a0066430455.4"
|
||||
OUTPUT_NAME = "h2ogpt-oasst1-512-12b"
|
||||
|
||||
BASE_MODEL = "tiiuae/falcon-40b"
|
||||
LORA_WEIGHTS = "falcon-40b.h2oaiopenassistant_oasst1_h2ogpt.1_epochs.894d8450d35c180cd03222a45658d04c15b78d4b.9"
|
||||
OUTPUT_NAME = "h2ogpt-oasst1-2048-falcon-40b"
|
||||
|
||||
# BASE_MODEL = 'decapoda-research/llama-65b-hf'
|
||||
# LORA_WEIGHTS = 'llama-65b-hf.h2oaiopenassistant_oasst1_h2ogpt_graded.1_epochs.113510499324f0f007cbec9d9f1f8091441f2469.3'
|
||||
# OUTPUT_NAME = "h2ogpt-research-oasst1-llama-65b"
|
||||
|
||||
model = os.getenv("MODEL")
|
||||
# for testing
|
||||
if model:
|
||||
BASE_MODEL = "tiiuae/falcon-7b"
|
||||
LORA_WEIGHTS = model + ".lora"
|
||||
OUTPUT_NAME = model
|
||||
|
||||
llama_type = "llama" in BASE_MODEL
|
||||
as_pytorch = False # False -> HF
|
||||
|
||||
from loaders import get_loaders
|
||||
|
||||
model_loader, tokenizer_loader = get_loaders(
|
||||
model_name=BASE_MODEL, reward_type=False, llama_type=llama_type
|
||||
)
|
||||
|
||||
tokenizer = tokenizer_loader.from_pretrained(
|
||||
BASE_MODEL,
|
||||
local_files_only=False,
|
||||
resume_download=True,
|
||||
)
|
||||
tokenizer.save_pretrained(OUTPUT_NAME)
|
||||
|
||||
base_model = model_loader(
|
||||
BASE_MODEL,
|
||||
load_in_8bit=False,
|
||||
trust_remote_code=True,
|
||||
torch_dtype=torch.float16,
|
||||
device_map={"": "cpu"},
|
||||
)
|
||||
|
||||
print(base_model)
|
||||
if llama_type:
|
||||
layers = base_model.model.layers
|
||||
first_weight = layers[0].self_attn.q_proj.weight
|
||||
else:
|
||||
if any(
|
||||
[x in BASE_MODEL.lower() for x in ["pythia", "h2ogpt", "gpt-neox"]]
|
||||
):
|
||||
layers = base_model.gpt_neox.base_model.layers
|
||||
first_weight = layers[0].attention.query_key_value.weight
|
||||
elif any([x in BASE_MODEL.lower() for x in ["falcon"]]):
|
||||
first_weight = base_model.transformer.h._modules[
|
||||
"0"
|
||||
].self_attention.query_key_value.weight
|
||||
else:
|
||||
layers = base_model.transformer.base_model.h
|
||||
first_weight = layers[0].attn.q_proj.weight
|
||||
first_weight_old = first_weight.clone()
|
||||
|
||||
lora_model = PeftModel.from_pretrained(
|
||||
base_model,
|
||||
LORA_WEIGHTS,
|
||||
device_map={"": "cpu"},
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
|
||||
assert torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
# merge weights TODO: include all lora_target_modules, not just default ones
|
||||
if llama_type:
|
||||
lora_model = lora_model.merge_and_unload()
|
||||
# for layer in lora_model.base_model.model.model.layers:
|
||||
# layer.self_attn.q_proj.merge_weights = True
|
||||
# layer.self_attn.k_proj.merge_weights = True
|
||||
# layer.self_attn.v_proj.merge_weights = True
|
||||
# layer.self_attn.o_proj.merge_weights = True
|
||||
else:
|
||||
if any(
|
||||
[x in BASE_MODEL.lower() for x in ["pythia", "h2ogpt", "gpt-neox"]]
|
||||
):
|
||||
for layer in lora_model.base_model.gpt_neox.base_model.layers:
|
||||
layer.attention.query_key_value.merge_weights = True
|
||||
else:
|
||||
lora_model.merge_and_unload()
|
||||
# for layer in lora_model.base_model.transformer.base_model.h:
|
||||
# layer.attn.q_proj.merge_weights = True
|
||||
# layer.attn.v_proj.merge_weights = True
|
||||
|
||||
lora_model.train(False)
|
||||
|
||||
# did we do anything?
|
||||
assert not torch.allclose(first_weight_old, first_weight)
|
||||
|
||||
lora_model_sd = lora_model.state_dict()
|
||||
|
||||
if as_pytorch:
|
||||
# FIXME - might not be generic enough still
|
||||
params = {
|
||||
"dim": base_model.config.hidden_size,
|
||||
"n_heads": base_model.config.num_attention_heads,
|
||||
"n_layers": base_model.config.num_hidden_layers,
|
||||
"norm_eps": base_model.config.layer_norm_eps,
|
||||
"vocab_size": base_model.config.vocab_size,
|
||||
}
|
||||
n_layers = params["n_layers"]
|
||||
n_heads = params["n_heads"]
|
||||
dim = params["dim"]
|
||||
dims_per_head = dim // n_heads
|
||||
base = 10000.0
|
||||
inv_freq = 1.0 / (
|
||||
base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head)
|
||||
)
|
||||
|
||||
def permute(w):
|
||||
return (
|
||||
w.view(n_heads, dim // n_heads // 2, 2, dim)
|
||||
.transpose(1, 2)
|
||||
.reshape(dim, dim)
|
||||
)
|
||||
|
||||
def unpermute(w):
|
||||
return (
|
||||
w.view(n_heads, 2, dim // n_heads // 2, dim)
|
||||
.transpose(1, 2)
|
||||
.reshape(dim, dim)
|
||||
)
|
||||
|
||||
def translate_state_dict_key(k):
|
||||
if "gpt-neoxt" in BASE_MODEL.lower():
|
||||
k = k.replace("gpt_neox.model.", "")
|
||||
else:
|
||||
k = k.replace("base_model.model.", "")
|
||||
if k == "model.embed_tokens.weight":
|
||||
return "tok_embeddings.weight"
|
||||
elif k == "model.norm.weight":
|
||||
return "norm.weight"
|
||||
elif k == "lm_head.weight":
|
||||
return "output.weight"
|
||||
elif k.startswith("model.layers."):
|
||||
layer = k.split(".")[2]
|
||||
if k.endswith(".self_attn.q_proj.weight"):
|
||||
return f"layers.{layer}.attention.wq.weight"
|
||||
elif k.endswith(".self_attn.k_proj.weight"):
|
||||
return f"layers.{layer}.attention.wk.weight"
|
||||
elif k.endswith(".self_attn.v_proj.weight"):
|
||||
return f"layers.{layer}.attention.wv.weight"
|
||||
elif k.endswith(".self_attn.o_proj.weight"):
|
||||
return f"layers.{layer}.attention.wo.weight"
|
||||
elif k.endswith(".mlp.gate_proj.weight"):
|
||||
return f"layers.{layer}.feed_forward.w1.weight"
|
||||
elif k.endswith(".mlp.down_proj.weight"):
|
||||
return f"layers.{layer}.feed_forward.w2.weight"
|
||||
elif k.endswith(".mlp.up_proj.weight"):
|
||||
return f"layers.{layer}.feed_forward.w3.weight"
|
||||
elif k.endswith(".input_layernorm.weight"):
|
||||
return f"layers.{layer}.attention_norm.weight"
|
||||
elif k.endswith(".post_attention_layernorm.weight"):
|
||||
return f"layers.{layer}.ffn_norm.weight"
|
||||
elif k.endswith("rotary_emb.inv_freq") or "lora" in k:
|
||||
return None
|
||||
else:
|
||||
print(layer, k)
|
||||
raise NotImplementedError
|
||||
else:
|
||||
print(k)
|
||||
raise NotImplementedError
|
||||
|
||||
new_state_dict = {}
|
||||
for k, v in lora_model_sd.items():
|
||||
new_k = translate_state_dict_key(k)
|
||||
if new_k is not None:
|
||||
if "wq" in new_k or "wk" in new_k:
|
||||
new_state_dict[new_k] = unpermute(v)
|
||||
else:
|
||||
new_state_dict[new_k] = v
|
||||
|
||||
os.makedirs("./ckpt", exist_ok=True)
|
||||
|
||||
torch.save(new_state_dict, "./ckpt/consolidated.00.pth")
|
||||
|
||||
with open("./ckpt/params.json", "w") as f:
|
||||
json.dump(params, f)
|
||||
else:
|
||||
deloreanized_sd = {
|
||||
k.replace("base_model.model.", ""): v
|
||||
for k, v in lora_model_sd.items()
|
||||
if "lora" not in k
|
||||
}
|
||||
base_model.config.custom_pipelines = {
|
||||
"text-generation": {
|
||||
"impl": "h2oai_pipeline.H2OTextGenerationPipeline",
|
||||
"pt": "AutoModelForCausalLM",
|
||||
}
|
||||
}
|
||||
PreTrainedModel.save_pretrained(
|
||||
base_model,
|
||||
OUTPUT_NAME,
|
||||
state_dict=deloreanized_sd,
|
||||
# max_shard_size="5GB",
|
||||
)
|
||||
|
||||
do_copy(OUTPUT_NAME)
|
||||
test_copy()
|
||||
|
||||
|
||||
def do_copy(OUTPUT_NAME):
    """Assemble a self-contained h2oai_pipeline.py inside OUTPUT_NAME.

    Copies src/h2oai_pipeline.py into OUTPUT_NAME, blanks out its
    project-local imports (enums / stopping / prompter), then appends the
    bodies of those three modules (minus their own circular
    ``from enums import PromptType`` lines) so the generated file imports
    without any local dependencies.

    Pure-Python replacement for the previous ``os.system`` sed/grep/cat
    pipeline: portable (no POSIX tools required), safe for paths containing
    spaces, and failures raise instead of being silently ignored.

    :param OUTPUT_NAME: existing directory to write h2oai_pipeline.py into
    """
    dest_file = os.path.join(OUTPUT_NAME, "h2oai_pipeline.py")
    shutil.copyfile("src/h2oai_pipeline.py", dest_file)

    # Equivalent of sed -i 's/from enums.*//g' (and stopping/prompter):
    # truncate each line at the earliest local-import reference, keeping the
    # (now blank) line so line numbers stay stable.
    local_imports = ("from enums", "from stopping", "from prompter")
    with open(dest_file) as f:
        lines = f.readlines()
    cleaned = []
    for line in lines:
        hits = [i for i in (line.find(pat) for pat in local_imports) if i != -1]
        if hits:
            line = line[: min(hits)] + "\n"
        cleaned.append(line)
    with open(dest_file, "w") as f:
        f.writelines(cleaned)

    # Equivalent of: cat src/X.py | grep -v "from enums import PromptType" >> dest
    with open(dest_file, "a") as out:
        for src in ("src/enums.py", "src/prompter.py", "src/stopping.py"):
            with open(src) as f:
                for line in f:
                    if "from enums import PromptType" not in line:
                        out.write(line)
|
||||
|
||||
|
||||
# Name of the scratch directory that test_copy() wipes and recreates on each run.
TEST_OUTPUT_NAME = "test_output"
|
||||
|
||||
|
||||
def test_copy():
    """Smoke-test do_copy() end to end.

    Rebuilds the export bundle in a fresh scratch directory, then re-runs
    this script inside it as a child process; the child sees DO_COPY_TEST
    set and takes the inner_test_copy() branch in __main__, which only
    verifies that the stitched-together h2oai_pipeline.py imports cleanly.
    NOTE(review): chdir's into the scratch dir and does not restore the
    previous working directory.
    """
    target = TEST_OUTPUT_NAME
    if os.path.isdir(target):
        shutil.rmtree(target)
    os.makedirs(target, exist_ok=False)
    do_copy(target)
    shutil.copy("src/export_hf_checkpoint.py", target)
    os.environ["DO_COPY_TEST"] = "1"
    os.chdir(target)
    child_output = subprocess.check_output(["python", "export_hf_checkpoint.py"])
    print(child_output)
|
||||
|
||||
|
||||
def inner_test_copy():
    """
    pytest -s -v export_hf_checkpoint.py::test_copy

    Runs inside the exported bundle (see test_copy) and verifies that the
    generated h2oai_pipeline.py is importable and exposes its public symbols.
    :return:
    """
    # test imports
    # Deliberately imports the *generated* file, so this looks unresolved
    # in pycharm when viewed from the repo root -- don't fix!
    from h2oai_pipeline import (
        H2OTextGenerationPipeline,
        get_prompt,
        get_stopping,
    )

    for exported in (get_stopping, get_prompt, H2OTextGenerationPipeline):
        assert exported
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # DO_COPY_TEST is set by test_copy() before it re-invokes this script
    # from inside the exported bundle; in that child process we only check
    # that the generated h2oai_pipeline.py imports cleanly.
    if os.getenv("DO_COPY_TEST"):
        inner_test_copy()
    else:
        do_export()
        # uncomment for raw isolated test, but test is done every time for each export now
        # test_copy()
|
||||
@@ -98,790 +98,6 @@ class Langchain:
|
||||
self.device = device
|
||||
self.precision = precision
|
||||
|
||||
def main(
|
||||
self,
|
||||
load_8bit: bool = False,
|
||||
load_4bit: bool = True,
|
||||
load_half: bool = False,
|
||||
load_gptq: str = "",
|
||||
use_safetensors: bool = False,
|
||||
infer_devices: bool = True,
|
||||
base_model: str = "h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
tokenizer_base_model: str = "",
|
||||
lora_weights: str = "",
|
||||
gpu_id: int = 0,
|
||||
compile_model: bool = True,
|
||||
use_cache: bool = None,
|
||||
inference_server: str = "",
|
||||
prompt_type: Union[int, str] = None,
|
||||
prompt_dict: typing.Dict = None,
|
||||
model_lock: typing.List[typing.Dict[str, str]] = None,
|
||||
model_lock_columns: int = None,
|
||||
fail_if_cannot_connect: bool = False,
|
||||
# input to generation
|
||||
temperature: float = None,
|
||||
top_p: float = None,
|
||||
top_k: int = None,
|
||||
num_beams: int = None,
|
||||
repetition_penalty: float = None,
|
||||
num_return_sequences: int = None,
|
||||
do_sample: bool = None,
|
||||
max_new_tokens: int = None,
|
||||
min_new_tokens: int = None,
|
||||
early_stopping: Union[bool, str] = None,
|
||||
max_time: float = None,
|
||||
memory_restriction_level: int = None,
|
||||
debug: bool = False,
|
||||
save_dir: str = None,
|
||||
share: bool = False,
|
||||
local_files_only: bool = False,
|
||||
resume_download: bool = True,
|
||||
use_auth_token: Union[str, bool] = False,
|
||||
trust_remote_code: Union[str, bool] = True,
|
||||
offload_folder: str = "offline_folder",
|
||||
src_lang: str = "English",
|
||||
tgt_lang: str = "Russian",
|
||||
cli: bool = False,
|
||||
cli_loop: bool = True,
|
||||
gradio: bool = True,
|
||||
gradio_offline_level: int = 0,
|
||||
chat: bool = True,
|
||||
chat_context: bool = False,
|
||||
stream_output: bool = True,
|
||||
show_examples: bool = None,
|
||||
verbose: bool = False,
|
||||
h2ocolors: bool = True,
|
||||
height: int = 600,
|
||||
show_lora: bool = True,
|
||||
login_mode_if_model0: bool = False,
|
||||
block_gradio_exit: bool = True,
|
||||
concurrency_count: int = 1,
|
||||
api_open: bool = False,
|
||||
allow_api: bool = True,
|
||||
input_lines: int = 1,
|
||||
gradio_size: str = None,
|
||||
auth: typing.List[typing.Tuple[str, str]] = None,
|
||||
max_max_time=None,
|
||||
max_max_new_tokens=None,
|
||||
sanitize_user_prompt: bool = False,
|
||||
sanitize_bot_response: bool = False,
|
||||
extra_model_options: typing.List[str] = [],
|
||||
extra_lora_options: typing.List[str] = [],
|
||||
extra_server_options: typing.List[str] = [],
|
||||
score_model: str = "OpenAssistant/reward-model-deberta-v3-large-v2",
|
||||
eval_filename: str = None,
|
||||
eval_prompts_only_num: int = 0,
|
||||
eval_prompts_only_seed: int = 1234,
|
||||
eval_as_output: bool = False,
|
||||
langchain_mode: str = "UserData",
|
||||
langchain_action: str = LangChainAction.QUERY.value,
|
||||
force_langchain_evaluate: bool = False,
|
||||
visible_langchain_modes: list = ["UserData", "MyData"],
|
||||
# WIP:
|
||||
# visible_langchain_actions: list = langchain_actions.copy(),
|
||||
visible_langchain_actions: list = [
|
||||
LangChainAction.QUERY.value,
|
||||
LangChainAction.SUMMARIZE_MAP.value,
|
||||
],
|
||||
document_choice: list = [DocumentChoices.All_Relevant.name],
|
||||
user_path: str = "apps/language_models/langchain/user_path/",
|
||||
detect_user_path_changes_every_query: bool = False,
|
||||
load_db_if_exists: bool = True,
|
||||
keep_sources_in_context: bool = False,
|
||||
db_type: str = "chroma",
|
||||
use_openai_embedding: bool = False,
|
||||
use_openai_model: bool = False,
|
||||
hf_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
|
||||
allow_upload_to_user_data: bool = True,
|
||||
allow_upload_to_my_data: bool = True,
|
||||
enable_url_upload: bool = True,
|
||||
enable_text_upload: bool = True,
|
||||
enable_sources_list: bool = True,
|
||||
chunk: bool = True,
|
||||
chunk_size: int = 512,
|
||||
top_k_docs: int = None,
|
||||
reverse_docs: bool = True,
|
||||
auto_reduce_chunks: bool = True,
|
||||
max_chunks: int = 100,
|
||||
n_jobs: int = -1,
|
||||
enable_captions: bool = True,
|
||||
captions_model: str = "Salesforce/blip-image-captioning-base",
|
||||
pre_load_caption_model: bool = False,
|
||||
caption_gpu: bool = True,
|
||||
enable_ocr: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
:param load_8bit: load model in 8-bit using bitsandbytes
|
||||
:param load_4bit: load model in 4-bit using bitsandbytes
|
||||
:param load_half: load model in float16
|
||||
:param load_gptq: to load model with GPTQ, put model_basename here, e.g. gptq_model-4bit--1g
|
||||
:param use_safetensors: to use safetensors version (assumes file/HF points to safe tensors version)
|
||||
:param infer_devices: whether to control devices with gpu_id. If False, then spread across GPUs
|
||||
:param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab
|
||||
:param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model.
|
||||
:param lora_weights: LORA weights path/HF link
|
||||
:param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
|
||||
:param compile_model Whether to compile the model
|
||||
:param use_cache: Whether to use caching in model (some models fail when multiple threads use)
|
||||
:param inference_server: Consume base_model as type of model at this address
|
||||
Address can be text-generation-server hosting that base_model
|
||||
e.g. python generate.py --inference_server="http://192.168.1.46:6112" --base_model=h2oai/h2ogpt-oasst1-512-12b
|
||||
Or Address can be "openai_chat" or "openai" for OpenAI API
|
||||
e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
|
||||
e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
|
||||
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
|
||||
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
|
||||
:param model_lock: Lock models to specific combinations, for ease of use and extending to many models
|
||||
Only used if gradio = True
|
||||
List of dicts, each dict has base_model, tokenizer_base_model, lora_weights, inference_server, prompt_type, and prompt_dict
|
||||
If all models have same prompt_type, and prompt_dict, can still specify that once in CLI outside model_lock as default for dict
|
||||
Can specify model_lock instead of those items on CLI
|
||||
As with CLI itself, base_model can infer prompt_type and prompt_dict if in prompter.py.
|
||||
Also, tokenizer_base_model and lora_weights are optional.
|
||||
Also, inference_server is optional if loading model from local system.
|
||||
All models provided will automatically appear in compare model mode
|
||||
Model loading-unloading and related choices will be disabled. Model/lora/server adding will be disabled
|
||||
:param model_lock_columns: How many columns to show if locking models (and so showing all at once)
|
||||
If None, then defaults to up to 3
|
||||
if -1, then all goes into 1 row
|
||||
Maximum value is 4 due to non-dynamic gradio rendering elements
|
||||
:param fail_if_cannot_connect: if doing model locking (e.g. with many models), fail if True. Otherwise ignore.
|
||||
Useful when many endpoints and want to just see what works, but still have to wait for timeout.
|
||||
:param temperature: generation temperature
|
||||
:param top_p: generation top_p
|
||||
:param top_k: generation top_k
|
||||
:param num_beams: generation number of beams
|
||||
:param repetition_penalty: generation repetition penalty
|
||||
:param num_return_sequences: generation number of sequences (1 forced for chat)
|
||||
:param do_sample: generation sample
|
||||
:param max_new_tokens: generation max new tokens
|
||||
:param min_new_tokens: generation min tokens
|
||||
:param early_stopping: generation early stopping
|
||||
:param max_time: maximum time to allow for generation
|
||||
:param memory_restriction_level: 0 = no restriction to tokens or model, 1 = some restrictions on token 2 = HF like restriction 3 = very low memory case
|
||||
:param debug: enable debug mode
|
||||
:param save_dir: directory chat data is saved to
|
||||
:param share: whether to share the gradio app with sharable URL
|
||||
:param local_files_only: whether to only use local files instead of doing to HF for models
|
||||
:param resume_download: whether to resume downloads from HF for models
|
||||
:param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
|
||||
:param trust_remote_code: whether to use trust any code needed for HF model
|
||||
:param offload_folder: path for spilling model onto disk
|
||||
:param src_lang: source languages to include if doing translation (None = all)
|
||||
:param tgt_lang: target languages to include if doing translation (None = all)
|
||||
:param cli: whether to use CLI (non-gradio) interface.
|
||||
:param cli_loop: whether to loop for CLI (False usually only for testing)
|
||||
:param gradio: whether to enable gradio, or to enable benchmark mode
|
||||
:param gradio_offline_level: > 0, then change fonts so full offline
|
||||
== 1 means backend won't need internet for fonts, but front-end UI might if font not cached
|
||||
== 2 means backend and frontend don't need internet to download any fonts.
|
||||
Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
|
||||
This option further disables google fonts for downloading, which is less intrusive than uploading,
|
||||
but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
|
||||
Also set --share=False to avoid sharing a gradio live link.
|
||||
:param chat: whether to enable chat mode with chat history
|
||||
:param chat_context: whether to use extra helpful context if human_bot
|
||||
:param stream_output: whether to stream output
|
||||
:param show_examples: whether to show clickable examples in gradio
|
||||
:param verbose: whether to show verbose prints
|
||||
:param h2ocolors: whether to use H2O.ai theme
|
||||
:param height: height of chat window
|
||||
:param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
|
||||
:param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
|
||||
:param block_gradio_exit: whether to block gradio exit (used for testing)
|
||||
:param concurrency_count: gradio concurrency count (1 is optimal for LLMs)
|
||||
:param api_open: If False, don't let API calls skip gradio queue
|
||||
:param allow_api: whether to allow API calls at all to gradio server
|
||||
:param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
|
||||
:param gradio_size: Overall size of text and spaces: "xsmall", "small", "medium", "large".
|
||||
Small useful for many chatbots in model_lock mode
|
||||
:param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
|
||||
e.g. --auth=[('jon','password')] with no spaces
|
||||
:param max_max_time: Maximum max_time for gradio slider
|
||||
:param max_max_new_tokens: Maximum max_new_tokens for gradio slider
|
||||
:param sanitize_user_prompt: whether to remove profanity from user input (slows down input processing)
|
||||
:param sanitize_bot_response: whether to remove profanity and repeat lines from bot output (about 2x slower generation for long streaming cases due to better_profanity being slow)
|
||||
:param extra_model_options: extra models to show in list in gradio
|
||||
:param extra_lora_options: extra LORA to show in list in gradio
|
||||
:param extra_server_options: extra servers to show in list in gradio
|
||||
:param score_model: which model to score responses (None means no scoring)
|
||||
:param eval_filename: json file to use for evaluation, if None is sharegpt
|
||||
:param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
|
||||
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
||||
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
||||
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
||||
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
||||
:param langchain_action: Mode langchain operations in on documents.
|
||||
Query: Make query of document(s)
|
||||
Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce
|
||||
Summarize_all: Summarize document(s) using entire document at once
|
||||
Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary
|
||||
:param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
|
||||
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
||||
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
||||
:param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
|
||||
Expensive for large number of files, so not done by default. By default only detect changes during db loading.
|
||||
:param visible_langchain_modes: dbs to generate at launch to be ready for LLM
|
||||
Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
|
||||
But wiki_full is expensive and requires preparation
|
||||
To allow scratch space only live in session, add 'MyData' to list
|
||||
Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
|
||||
FIXME: Avoid 'All' for now, not implemented
|
||||
:param visible_langchain_actions: Which actions to allow
|
||||
:param document_choice: Default document choice when taking subset of collection
|
||||
:param load_db_if_exists: Whether to load chroma db if exists or re-generate db
|
||||
:param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
|
||||
:param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
|
||||
:param use_openai_embedding: Whether to use OpenAI embeddings for vector db
|
||||
:param use_openai_model: Whether to use OpenAI model for use with vector db
|
||||
:param hf_embedding_model: Which HF embedding model to use for vector db
|
||||
Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v1 if no GPUs
|
||||
Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
|
||||
Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
|
||||
We support automatically changing of embeddings for chroma, with a backup of db made if this is done
|
||||
:param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
|
||||
:param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
|
||||
:param enable_url_upload: Whether to allow upload from URL
|
||||
:param enable_text_upload: Whether to allow upload of text
|
||||
:param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
|
||||
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
||||
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
|
||||
:param top_k_docs: number of chunks to give LLM
|
||||
:param reverse_docs: whether to reverse docs order so most relevant is closest to question.
|
||||
Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
|
||||
But smaller 6_9 models fail to use newest context and can get stuck on old information.
|
||||
:param auto_reduce_chunks: Whether to automatically reduce top_k_docs to fit context given prompt
|
||||
:param max_chunks: If top_k_docs=-1, maximum number of chunks to allow
|
||||
:param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
|
||||
:param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
|
||||
:param captions_model: Which model to use for captions.
|
||||
captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable
|
||||
captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
|
||||
captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
|
||||
Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
|
||||
:param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
|
||||
parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
|
||||
Recommended if using larger caption model
|
||||
:param caption_gpu: If support caption, then use GPU if exists
|
||||
:param enable_ocr: Whether to support OCR on images
|
||||
:return:
|
||||
"""
|
||||
if base_model is None:
|
||||
base_model = ""
|
||||
if tokenizer_base_model is None:
|
||||
tokenizer_base_model = ""
|
||||
if lora_weights is None:
|
||||
lora_weights = ""
|
||||
if inference_server is None:
|
||||
inference_server = ""
|
||||
|
||||
# listen to env if set
|
||||
model_lock = os.getenv("model_lock", str(model_lock))
|
||||
model_lock = ast.literal_eval(model_lock)
|
||||
|
||||
if model_lock:
|
||||
assert gradio, "model_lock only supported for gradio=True"
|
||||
if len(model_lock) > 1:
|
||||
assert (
|
||||
chat
|
||||
), "model_lock only works for multiple models for chat=True"
|
||||
assert not cli, "model_lock only supported for cli=False"
|
||||
assert not (
|
||||
not cli and not gradio
|
||||
), "model_lock only supported for eval (cli=gradio=False)"
|
||||
assert not base_model, "Don't specify model_lock and base_model"
|
||||
assert (
|
||||
not tokenizer_base_model
|
||||
), "Don't specify model_lock and tokenizer_base_model"
|
||||
assert (
|
||||
not lora_weights
|
||||
), "Don't specify model_lock and lora_weights"
|
||||
assert (
|
||||
not inference_server
|
||||
), "Don't specify model_lock and inference_server"
|
||||
# assert not prompt_type, "Don't specify model_lock and prompt_type"
|
||||
# assert not prompt_dict, "Don't specify model_lock and prompt_dict"
|
||||
|
||||
n_jobs = int(os.getenv("n_jobs", str(n_jobs)))
|
||||
is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", "0")))
|
||||
is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", "0")))
|
||||
is_public = (
|
||||
is_hf or is_gpth2oai
|
||||
) # multi-user case with fixed model and disclaimer
|
||||
if memory_restriction_level is None:
|
||||
memory_restriction_level = (
|
||||
2 if is_hf else 0
|
||||
) # 2 assumes run on 24GB consumer GPU
|
||||
else:
|
||||
assert 0 <= memory_restriction_level <= 3, (
|
||||
"Bad memory_restriction_level=%s" % memory_restriction_level
|
||||
)
|
||||
if is_public and os.getenv("n_jobs") is None:
|
||||
n_jobs = max(1, min(os.cpu_count() // 2, 8))
|
||||
admin_pass = os.getenv("ADMIN_PASS")
|
||||
# will sometimes appear in UI or sometimes actual generation, but maybe better than empty result
|
||||
# but becomes unrecoverable sometimes if raise, so just be silent for now
|
||||
raise_generate_gpu_exceptions = True
|
||||
|
||||
# allow set token directly
|
||||
use_auth_token = os.environ.get(
|
||||
"HUGGINGFACE_API_TOKEN", use_auth_token
|
||||
)
|
||||
allow_upload_to_user_data = bool(
|
||||
int(
|
||||
os.environ.get(
|
||||
"allow_upload_to_user_data",
|
||||
str(int(allow_upload_to_user_data)),
|
||||
)
|
||||
)
|
||||
)
|
||||
allow_upload_to_my_data = bool(
|
||||
int(
|
||||
os.environ.get(
|
||||
"allow_upload_to_my_data",
|
||||
str(int(allow_upload_to_my_data)),
|
||||
)
|
||||
)
|
||||
)
|
||||
height = int(os.environ.get("HEIGHT", height))
|
||||
h2ocolors = bool(int(os.getenv("h2ocolors", h2ocolors)))
|
||||
|
||||
# allow enabling langchain via ENV
|
||||
# FIRST PLACE where LangChain referenced, but no imports related to it
|
||||
langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
|
||||
assert langchain_mode in langchain_modes, (
|
||||
"Invalid langchain_mode %s" % langchain_mode
|
||||
)
|
||||
visible_langchain_modes = ast.literal_eval(
|
||||
os.environ.get(
|
||||
"visible_langchain_modes", str(visible_langchain_modes)
|
||||
)
|
||||
)
|
||||
if (
|
||||
langchain_mode not in visible_langchain_modes
|
||||
and langchain_mode in langchain_modes
|
||||
):
|
||||
visible_langchain_modes += [langchain_mode]
|
||||
|
||||
assert langchain_action in langchain_actions, (
|
||||
"Invalid langchain_action %s" % langchain_action
|
||||
)
|
||||
|
||||
# if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
|
||||
if LangChainMode.MY_DATA.value not in visible_langchain_modes:
|
||||
allow_upload_to_my_data = False
|
||||
if LangChainMode.USER_DATA.value not in visible_langchain_modes:
|
||||
allow_upload_to_user_data = False
|
||||
|
||||
if is_public:
|
||||
allow_upload_to_user_data = False
|
||||
input_lines = 1 # ensure set, for ease of use
|
||||
temperature = 0.2 if temperature is None else temperature
|
||||
top_p = 0.85 if top_p is None else top_p
|
||||
top_k = 70 if top_k is None else top_k
|
||||
if is_hf:
|
||||
do_sample = True if do_sample is None else do_sample
|
||||
top_k_docs = 3 if top_k_docs is None else top_k_docs
|
||||
else:
|
||||
# by default don't sample, too chatty
|
||||
do_sample = False if do_sample is None else do_sample
|
||||
top_k_docs = 4 if top_k_docs is None else top_k_docs
|
||||
|
||||
if memory_restriction_level == 2:
|
||||
if not base_model and not inference_server and not model_lock:
|
||||
base_model = "h2oai/h2ogpt-oasst1-512-12b"
|
||||
# don't set load_8bit if passed base_model, doesn't always work so can't just override
|
||||
load_8bit = True
|
||||
load_4bit = (
|
||||
False # FIXME - consider using 4-bit instead of 8-bit
|
||||
)
|
||||
elif not inference_server:
|
||||
top_k_docs = 10 if top_k_docs is None else top_k_docs
|
||||
if memory_restriction_level >= 2:
|
||||
load_8bit = True
|
||||
load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
|
||||
if hf_embedding_model is None:
|
||||
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
top_k_docs = 3 if top_k_docs is None else top_k_docs
|
||||
if top_k_docs is None:
|
||||
top_k_docs = 3
|
||||
if is_public:
|
||||
if not max_time:
|
||||
max_time = 60 * 2
|
||||
if not max_max_time:
|
||||
max_max_time = max_time
|
||||
if not max_new_tokens:
|
||||
max_new_tokens = 256
|
||||
if not max_max_new_tokens:
|
||||
max_max_new_tokens = 256
|
||||
else:
|
||||
if not max_max_time:
|
||||
max_max_time = 60 * 20
|
||||
if not max_max_new_tokens:
|
||||
max_max_new_tokens = 512
|
||||
if is_hf:
|
||||
# must override share if in spaces
|
||||
share = False
|
||||
if not max_time:
|
||||
max_time = 60 * 1
|
||||
if not max_max_time:
|
||||
max_max_time = max_time
|
||||
# HF accounted for later in get_max_max_new_tokens()
|
||||
save_dir = os.getenv("SAVE_DIR", save_dir)
|
||||
score_model = os.getenv("SCORE_MODEL", score_model)
|
||||
if score_model == "None" or score_model is None:
|
||||
score_model = ""
|
||||
concurrency_count = int(
|
||||
os.getenv("CONCURRENCY_COUNT", concurrency_count)
|
||||
)
|
||||
api_open = bool(int(os.getenv("API_OPEN", str(int(api_open)))))
|
||||
allow_api = bool(int(os.getenv("ALLOW_API", str(int(allow_api)))))
|
||||
|
||||
n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
|
||||
if n_gpus == 0:
|
||||
gpu_id = None
|
||||
load_8bit = False
|
||||
load_4bit = False
|
||||
load_half = False
|
||||
load_gptq = ""
|
||||
use_safetensors = False
|
||||
infer_devices = False
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cudnn.enabled = False
|
||||
torch.set_default_dtype(torch.float32)
|
||||
if (
|
||||
psutil.virtual_memory().available < 94 * 1024**3
|
||||
and not inference_server
|
||||
and not model_lock
|
||||
):
|
||||
# 12B uses ~94GB
|
||||
# 6.9B uses ~47GB
|
||||
base_model = (
|
||||
"h2oai/h2ogpt-oig-oasst1-512-6_9b"
|
||||
if not base_model
|
||||
else base_model
|
||||
)
|
||||
if hf_embedding_model is None:
|
||||
# if no GPUs, use simpler embedding model to avoid cost in time
|
||||
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
else:
|
||||
if hf_embedding_model is None:
|
||||
# if still None, then set default
|
||||
hf_embedding_model = "hkunlp/instructor-large"
|
||||
|
||||
# get defaults
|
||||
if base_model:
|
||||
model_lower = base_model.lower()
|
||||
elif model_lock:
|
||||
# have 0th model be thought of as normal model
|
||||
assert len(model_lock) > 0 and model_lock[0]["base_model"]
|
||||
model_lower = model_lock[0]["base_model"].lower()
|
||||
else:
|
||||
model_lower = ""
|
||||
if not gradio:
|
||||
# force, else not single response like want to look at
|
||||
stream_output = False
|
||||
# else prompt removal can mess up output
|
||||
chat = False
|
||||
# hard-coded defaults
|
||||
first_para = False
|
||||
text_limit = None
|
||||
|
||||
if offload_folder:
|
||||
makedirs(offload_folder)
|
||||
if user_path:
|
||||
makedirs(user_path)
|
||||
|
||||
(
|
||||
placeholder_instruction,
|
||||
placeholder_input,
|
||||
stream_output,
|
||||
show_examples,
|
||||
prompt_type,
|
||||
prompt_dict,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
num_beams,
|
||||
max_new_tokens,
|
||||
min_new_tokens,
|
||||
early_stopping,
|
||||
max_time,
|
||||
repetition_penalty,
|
||||
num_return_sequences,
|
||||
do_sample,
|
||||
src_lang,
|
||||
tgt_lang,
|
||||
examples,
|
||||
task_info,
|
||||
) = self.get_generate_params(
|
||||
model_lower,
|
||||
chat,
|
||||
stream_output,
|
||||
show_examples,
|
||||
prompt_type,
|
||||
prompt_dict,
|
||||
temperature,
|
||||
top_p,
|
||||
top_k,
|
||||
num_beams,
|
||||
max_new_tokens,
|
||||
min_new_tokens,
|
||||
early_stopping,
|
||||
max_time,
|
||||
repetition_penalty,
|
||||
num_return_sequences,
|
||||
do_sample,
|
||||
top_k_docs,
|
||||
chunk,
|
||||
chunk_size,
|
||||
verbose,
|
||||
)
|
||||
|
||||
git_hash = get_githash()
|
||||
locals_dict = locals()
|
||||
locals_print = "\n".join(
|
||||
["%s: %s" % (k, v) for k, v in locals_dict.items()]
|
||||
)
|
||||
if verbose:
|
||||
print(f"Generating model with params:\n{locals_print}", flush=True)
|
||||
print(
|
||||
"Command: %s\nHash: %s" % (str(" ".join(sys.argv)), git_hash),
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if langchain_mode != "Disabled":
|
||||
# SECOND PLACE where LangChain referenced, but all imports are kept local so not required
|
||||
from gpt_langchain import prep_langchain, get_some_dbs_from_hf
|
||||
|
||||
if is_hf:
|
||||
get_some_dbs_from_hf()
|
||||
dbs = {}
|
||||
for langchain_mode1 in visible_langchain_modes:
|
||||
if langchain_mode1 in ["MyData"]:
|
||||
# don't use what is on disk, remove it instead
|
||||
for gpath1 in glob.glob(
|
||||
os.path.join(
|
||||
scratch_base_dir, "db_dir_%s*" % langchain_mode1
|
||||
)
|
||||
):
|
||||
if os.path.isdir(gpath1):
|
||||
print(
|
||||
"Removing old MyData: %s" % gpath1, flush=True
|
||||
)
|
||||
remove(gpath1)
|
||||
continue
|
||||
if langchain_mode1 in ["All"]:
|
||||
# FIXME: All should be avoided until scans over each db, shouldn't be separate db
|
||||
continue
|
||||
persist_directory1 = (
|
||||
"db_dir_%s" % langchain_mode1
|
||||
) # single place, no special names for each case
|
||||
try:
|
||||
db = prep_langchain(
|
||||
persist_directory1,
|
||||
load_db_if_exists,
|
||||
db_type,
|
||||
use_openai_embedding,
|
||||
langchain_mode1,
|
||||
user_path,
|
||||
hf_embedding_model,
|
||||
device=self.device,
|
||||
kwargs_make_db=locals(),
|
||||
)
|
||||
finally:
|
||||
# in case updated embeddings or created new embeddings
|
||||
clear_torch_cache()
|
||||
dbs[langchain_mode1] = db
|
||||
# remove None db's so can just rely upon k in dbs for if hav db
|
||||
dbs = {k: v for k, v in dbs.items() if v is not None}
|
||||
else:
|
||||
dbs = {}
|
||||
# import control
|
||||
if os.environ.get("TEST_LANGCHAIN_IMPORT"):
|
||||
assert (
|
||||
"gpt_langchain" not in sys.modules
|
||||
), "Dev bug, import of langchain when should not have"
|
||||
assert (
|
||||
"langchain" not in sys.modules
|
||||
), "Dev bug, import of langchain when should not have"
|
||||
|
||||
model_state_none = dict(
|
||||
model=None,
|
||||
tokenizer=None,
|
||||
device=None,
|
||||
base_model=None,
|
||||
tokenizer_base_model=None,
|
||||
lora_weights=None,
|
||||
inference_server=None,
|
||||
prompt_type=None,
|
||||
prompt_dict=None,
|
||||
)
|
||||
|
||||
if cli:
|
||||
from cli import run_cli
|
||||
|
||||
return run_cli(
|
||||
**get_kwargs(
|
||||
run_cli, exclude_names=["model_state0"], **locals()
|
||||
)
|
||||
)
|
||||
elif not gradio:
|
||||
from eval import run_eval
|
||||
|
||||
return run_eval(
|
||||
**get_kwargs(
|
||||
run_eval, exclude_names=["model_state0"], **locals()
|
||||
)
|
||||
)
|
||||
elif gradio:
|
||||
# imported here so don't require gradio to run generate
|
||||
from gradio_runner import go_gradio
|
||||
|
||||
# get default model
|
||||
model_states = []
|
||||
model_list = [
|
||||
dict(
|
||||
base_model=base_model,
|
||||
tokenizer_base_model=tokenizer_base_model,
|
||||
lora_weights=lora_weights,
|
||||
inference_server=inference_server,
|
||||
prompt_type=prompt_type,
|
||||
prompt_dict=prompt_dict,
|
||||
)
|
||||
]
|
||||
model_list0 = copy.deepcopy(
|
||||
model_list
|
||||
) # just strings, safe to deepcopy
|
||||
model_state0 = model_state_none.copy()
|
||||
assert len(model_state_none) == len(model_state0)
|
||||
if model_lock:
|
||||
model_list = model_lock
|
||||
for model_dict in reversed(model_list):
|
||||
# do reverse, so first is default base_model etc., so some logic works in go_gradio() more easily
|
||||
# handles defaults user didn't have to pass
|
||||
model_dict["base_model"] = base_model1 = model_dict.get(
|
||||
"base_model", ""
|
||||
)
|
||||
model_dict[
|
||||
"tokenizer_base_model"
|
||||
] = tokenizer_base_model1 = model_dict.get(
|
||||
"tokenizer_base_model", ""
|
||||
)
|
||||
model_dict["lora_weights"] = lora_weights1 = model_dict.get(
|
||||
"lora_weights", ""
|
||||
)
|
||||
model_dict[
|
||||
"inference_server"
|
||||
] = inference_server1 = model_dict.get("inference_server", "")
|
||||
prompt_type1 = model_dict.get(
|
||||
"prompt_type", model_list0[0]["prompt_type"]
|
||||
) # don't use mutated value
|
||||
# try to infer, ignore empty initial state leading to get_generate_params -> 'plain'
|
||||
if model_dict.get("prompt_type") is None:
|
||||
model_lower1 = base_model1.lower()
|
||||
if model_lower1 in inv_prompt_type_to_model_lower:
|
||||
prompt_type1 = inv_prompt_type_to_model_lower[
|
||||
model_lower1
|
||||
]
|
||||
prompt_dict1, error0 = get_prompt(
|
||||
prompt_type1,
|
||||
"",
|
||||
chat=False,
|
||||
context="",
|
||||
reduced=False,
|
||||
making_context=False,
|
||||
return_dict=True,
|
||||
)
|
||||
else:
|
||||
prompt_dict1 = prompt_dict
|
||||
else:
|
||||
prompt_dict1 = prompt_dict
|
||||
model_dict["prompt_type"] = prompt_type1
|
||||
model_dict["prompt_dict"] = prompt_dict1 = model_dict.get(
|
||||
"prompt_dict", prompt_dict1
|
||||
)
|
||||
all_kwargs = locals().copy()
|
||||
all_kwargs.update(
|
||||
dict(
|
||||
base_model=base_model1,
|
||||
tokenizer_base_model=tokenizer_base_model1,
|
||||
lora_weights=lora_weights1,
|
||||
inference_server=inference_server1,
|
||||
)
|
||||
)
|
||||
if base_model1 and not login_mode_if_model0:
|
||||
model0, tokenizer0, _ = self.get_model(
|
||||
reward_type=False,
|
||||
**get_kwargs(
|
||||
self.get_model,
|
||||
exclude_names=["reward_type"],
|
||||
**all_kwargs,
|
||||
),
|
||||
)
|
||||
else:
|
||||
# if empty model, then don't load anything, just get gradio up
|
||||
model0, tokenizer0, _ = None, None, None
|
||||
if model0 is None:
|
||||
if fail_if_cannot_connect:
|
||||
raise RuntimeError("Could not connect, see logs")
|
||||
# skip
|
||||
if isinstance(model_lock, list):
|
||||
model_lock.remove(model_dict)
|
||||
continue
|
||||
model_state_trial = dict(
|
||||
model=model0, tokenizer=tokenizer0, device=self.device
|
||||
)
|
||||
model_state_trial.update(model_dict)
|
||||
assert len(model_state_none) == len(model_state_trial)
|
||||
print("Model %s" % model_dict, flush=True)
|
||||
if model_lock:
|
||||
# last in iteration will be first
|
||||
model_states.insert(0, model_state_trial)
|
||||
# fill model_state0 so go_gradio() easier, manage model_states separately
|
||||
model_state0 = model_state_trial.copy()
|
||||
else:
|
||||
model_state0 = model_state_trial.copy()
|
||||
assert len(model_state_none) == len(model_state0)
|
||||
|
||||
# get score model
|
||||
all_kwargs = locals().copy()
|
||||
smodel, stokenizer, _ = self.get_score_model(
|
||||
reward_type=True,
|
||||
**get_kwargs(
|
||||
self.get_score_model,
|
||||
exclude_names=["reward_type"],
|
||||
**all_kwargs,
|
||||
),
|
||||
)
|
||||
score_model_state0 = dict(
|
||||
model=smodel,
|
||||
tokenizer=stokenizer,
|
||||
device=self.device,
|
||||
base_model=score_model,
|
||||
tokenizer_base_model="",
|
||||
lora_weights="",
|
||||
inference_server="",
|
||||
prompt_type="",
|
||||
prompt_dict="",
|
||||
)
|
||||
|
||||
if enable_captions:
|
||||
if pre_load_caption_model:
|
||||
from image_captions import H2OImageCaptionLoader
|
||||
|
||||
caption_loader = H2OImageCaptionLoader(
|
||||
caption_gpu=caption_gpu
|
||||
).load_model()
|
||||
else:
|
||||
caption_loader = "gpu" if caption_gpu else "cpu"
|
||||
else:
|
||||
caption_loader = False
|
||||
|
||||
# assume gradio needs everything
|
||||
go_gradio(**locals())
|
||||
|
||||
def get_config(
|
||||
self,
|
||||
base_model,
|
||||
|
||||
@@ -87,7 +87,7 @@ from langchain.document_loaders import (
|
||||
UnstructuredExcelLoader,
|
||||
)
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
|
||||
from langchain.chains.question_answering import load_qa_chain
|
||||
from expanded_pipelines import load_qa_chain
|
||||
from langchain.docstore.document import Document
|
||||
from langchain import PromptTemplate, HuggingFaceTextGenInference
|
||||
from langchain.vectorstores import Chroma
|
||||
@@ -2958,56 +2958,8 @@ def get_similarity_chain(
|
||||
template=template,
|
||||
)
|
||||
chain = load_qa_chain(llm, prompt=prompt)
|
||||
else:
|
||||
# only if use_openai_model = True, unused normally except in testing
|
||||
chain = load_qa_with_sources_chain(llm)
|
||||
if not use_context:
|
||||
chain_kwargs = dict(input_documents=[], question=query)
|
||||
else:
|
||||
chain_kwargs = dict(input_documents=docs, question=query)
|
||||
chain_kwargs = dict(input_documents=docs, question=query)
|
||||
target = wrapped_partial(chain, chain_kwargs)
|
||||
elif langchain_action in [
|
||||
LangChainAction.SUMMARIZE_MAP.value,
|
||||
LangChainAction.SUMMARIZE_REFINE,
|
||||
LangChainAction.SUMMARIZE_ALL.value,
|
||||
]:
|
||||
from langchain.chains.summarize import load_summarize_chain
|
||||
|
||||
if langchain_action == LangChainAction.SUMMARIZE_MAP.value:
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["text"], template=template
|
||||
)
|
||||
chain = load_summarize_chain(
|
||||
llm,
|
||||
chain_type="map_reduce",
|
||||
map_prompt=prompt,
|
||||
combine_prompt=prompt,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
target = wrapped_partial(
|
||||
chain, {"input_documents": docs}
|
||||
) # , return_only_outputs=True)
|
||||
elif langchain_action == LangChainAction.SUMMARIZE_ALL.value:
|
||||
assert use_template
|
||||
prompt = PromptTemplate(
|
||||
input_variables=["text"], template=template
|
||||
)
|
||||
chain = load_summarize_chain(
|
||||
llm,
|
||||
chain_type="stuff",
|
||||
prompt=prompt,
|
||||
return_intermediate_steps=True,
|
||||
)
|
||||
target = wrapped_partial(chain)
|
||||
elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value:
|
||||
chain = load_summarize_chain(
|
||||
llm, chain_type="refine", return_intermediate_steps=True
|
||||
)
|
||||
target = wrapped_partial(chain)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No such langchain_action=%s" % langchain_action
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("No such langchain_action=%s" % langchain_action)
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,225 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Iterable
|
||||
|
||||
from gradio.themes.soft import Soft
|
||||
from gradio.themes import Color, Size
|
||||
from gradio.themes.utils import colors, sizes, fonts
|
||||
|
||||
h2o_yellow = Color(
|
||||
name="yellow",
|
||||
c50="#fffef2",
|
||||
c100="#fff9e6",
|
||||
c200="#ffecb3",
|
||||
c300="#ffe28c",
|
||||
c400="#ffd659",
|
||||
c500="#fec925",
|
||||
c600="#e6ac00",
|
||||
c700="#bf8f00",
|
||||
c800="#a67c00",
|
||||
c900="#664d00",
|
||||
c950="#403000",
|
||||
)
|
||||
h2o_gray = Color(
|
||||
name="gray",
|
||||
c50="#f8f8f8",
|
||||
c100="#e5e5e5",
|
||||
c200="#cccccc",
|
||||
c300="#b2b2b2",
|
||||
c400="#999999",
|
||||
c500="#7f7f7f",
|
||||
c600="#666666",
|
||||
c700="#4c4c4c",
|
||||
c800="#333333",
|
||||
c900="#191919",
|
||||
c950="#0d0d0d",
|
||||
)
|
||||
|
||||
|
||||
text_xsm = Size(
|
||||
name="text_xsm",
|
||||
xxs="4px",
|
||||
xs="5px",
|
||||
sm="6px",
|
||||
md="7px",
|
||||
lg="8px",
|
||||
xl="10px",
|
||||
xxl="12px",
|
||||
)
|
||||
|
||||
|
||||
spacing_xsm = Size(
|
||||
name="spacing_xsm",
|
||||
xxs="1px",
|
||||
xs="1px",
|
||||
sm="1px",
|
||||
md="2px",
|
||||
lg="3px",
|
||||
xl="5px",
|
||||
xxl="7px",
|
||||
)
|
||||
|
||||
|
||||
radius_xsm = Size(
|
||||
name="radius_xsm",
|
||||
xxs="1px",
|
||||
xs="1px",
|
||||
sm="1px",
|
||||
md="2px",
|
||||
lg="3px",
|
||||
xl="5px",
|
||||
xxl="7px",
|
||||
)
|
||||
|
||||
|
||||
class H2oTheme(Soft):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
primary_hue: colors.Color | str = h2o_yellow,
|
||||
secondary_hue: colors.Color | str = h2o_yellow,
|
||||
neutral_hue: colors.Color | str = h2o_gray,
|
||||
spacing_size: sizes.Size | str = sizes.spacing_md,
|
||||
radius_size: sizes.Size | str = sizes.radius_md,
|
||||
text_size: sizes.Size | str = sizes.text_lg,
|
||||
font: fonts.Font
|
||||
| str
|
||||
| Iterable[fonts.Font | str] = (
|
||||
fonts.GoogleFont("Montserrat"),
|
||||
"ui-sans-serif",
|
||||
"system-ui",
|
||||
"sans-serif",
|
||||
),
|
||||
font_mono: fonts.Font
|
||||
| str
|
||||
| Iterable[fonts.Font | str] = (
|
||||
fonts.GoogleFont("IBM Plex Mono"),
|
||||
"ui-monospace",
|
||||
"Consolas",
|
||||
"monospace",
|
||||
),
|
||||
):
|
||||
super().__init__(
|
||||
primary_hue=primary_hue,
|
||||
secondary_hue=secondary_hue,
|
||||
neutral_hue=neutral_hue,
|
||||
spacing_size=spacing_size,
|
||||
radius_size=radius_size,
|
||||
text_size=text_size,
|
||||
font=font,
|
||||
font_mono=font_mono,
|
||||
)
|
||||
super().set(
|
||||
link_text_color="#3344DD",
|
||||
link_text_color_hover="#3344DD",
|
||||
link_text_color_visited="#3344DD",
|
||||
link_text_color_dark="#74abff",
|
||||
link_text_color_hover_dark="#a3c8ff",
|
||||
link_text_color_active_dark="#a3c8ff",
|
||||
link_text_color_visited_dark="#74abff",
|
||||
button_primary_text_color="*neutral_950",
|
||||
button_primary_text_color_dark="*neutral_950",
|
||||
button_primary_background_fill="*primary_500",
|
||||
button_primary_background_fill_dark="*primary_500",
|
||||
block_label_background_fill="*primary_500",
|
||||
block_label_background_fill_dark="*primary_500",
|
||||
block_label_text_color="*neutral_950",
|
||||
block_label_text_color_dark="*neutral_950",
|
||||
block_title_text_color="*neutral_950",
|
||||
block_title_text_color_dark="*neutral_950",
|
||||
block_background_fill_dark="*neutral_950",
|
||||
body_background_fill="*neutral_50",
|
||||
body_background_fill_dark="*neutral_900",
|
||||
background_fill_primary_dark="*block_background_fill",
|
||||
block_radius="0 0 8px 8px",
|
||||
checkbox_label_text_color_selected_dark="#000000",
|
||||
)
|
||||
|
||||
|
||||
class SoftTheme(Soft):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
primary_hue: colors.Color | str = colors.indigo,
|
||||
secondary_hue: colors.Color | str = colors.indigo,
|
||||
neutral_hue: colors.Color | str = colors.gray,
|
||||
spacing_size: sizes.Size | str = sizes.spacing_md,
|
||||
radius_size: sizes.Size | str = sizes.radius_md,
|
||||
text_size: sizes.Size | str = sizes.text_md,
|
||||
font: fonts.Font
|
||||
| str
|
||||
| Iterable[fonts.Font | str] = (
|
||||
fonts.GoogleFont("Montserrat"),
|
||||
"ui-sans-serif",
|
||||
"system-ui",
|
||||
"sans-serif",
|
||||
),
|
||||
font_mono: fonts.Font
|
||||
| str
|
||||
| Iterable[fonts.Font | str] = (
|
||||
fonts.GoogleFont("IBM Plex Mono"),
|
||||
"ui-monospace",
|
||||
"Consolas",
|
||||
"monospace",
|
||||
),
|
||||
):
|
||||
super().__init__(
|
||||
primary_hue=primary_hue,
|
||||
secondary_hue=secondary_hue,
|
||||
neutral_hue=neutral_hue,
|
||||
spacing_size=spacing_size,
|
||||
radius_size=radius_size,
|
||||
text_size=text_size,
|
||||
font=font,
|
||||
font_mono=font_mono,
|
||||
)
|
||||
|
||||
|
||||
h2o_logo = (
|
||||
'<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" width="100%" height="100%"'
|
||||
' viewBox="0 0 600.28 600.28"><defs><style>.cls-1{fill:#fec925;}.cls-2{fill:#161616;}.cls-3{fill:'
|
||||
'#54585a;}</style></defs><g id="Fill-1"><rect class="cls-1" width="600.28" height="600.28" '
|
||||
'rx="23.24"/></g><path class="cls-2" d="M174.33,246.06v92.78H152.86v-38H110.71v38H89.24V246.06h21.'
|
||||
'47v36.58h42.15V246.06Z"/><path class="cls-2" d="M259.81,321.34v17.5H189.7V324.92l35.78-33.8c8.22-7.'
|
||||
"82,9.68-12.59,9.68-17.09,0-7.29-5-11.53-14.85-11.53-7.95,0-14.71,3-19.21,9.27L185.46,261.7c7.15-10"
|
||||
'.47,20.14-17.23,36.84-17.23,20.68,0,34.46,10.6,34.46,27.44,0,9-2.52,17.22-15.51,29.29l-21.33,20.14Z"'
|
||||
'/><path class="cls-2" d="M268.69,292.45c0-27.57,21.47-48,50.76-48s50.76,20.28,50.76,48-21.6,48-50.'
|
||||
"76,48S268.69,320,268.69,292.45Zm79.78,0c0-17.63-12.46-29.69-29-29.69s-29,12.06-29,29.69,12.46,29.69"
|
||||
',29,29.69S348.47,310.08,348.47,292.45Z"/><path class="cls-3" d="M377.23,326.91c0-7.69,5.7-12.73,12.'
|
||||
'85-12.73s12.86,5,12.86,12.73a12.86,12.86,0,1,1-25.71,0Z"/><path class="cls-3" d="M481.4,298.15v40.'
|
||||
"69H462.05V330c-3.84,6.49-11.27,9.94-21.74,9.94-16.7,0-26.64-9.28-26.64-21.61,0-12.59,8.88-21.34,30."
|
||||
"62-21.34h16.43c0-8.87-5.3-14-16.43-14-7.55,0-15.37,2.51-20.54,6.62l-7.43-14.44c7.82-5.57,19.35-8."
|
||||
"62,30.75-8.62C468.81,266.47,481.4,276.54,481.4,298.15Zm-20.68,18.16V309H446.54c-9.67,0-12.72,3.57-"
|
||||
'12.72,8.35,0,5.16,4.37,8.61,11.66,8.61C452.37,326,458.34,322.8,460.72,316.31Z"/><path class="cls-3"'
|
||||
' d="M497.56,246.06c0-6.49,5.17-11.53,12.86-11.53s12.86,4.77,12.86,11.13c0,6.89-5.17,11.93-12.86,'
|
||||
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
||||
)
|
||||
|
||||
|
||||
def get_h2o_title(title, description):
|
||||
# NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
|
||||
return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
|
||||
{description}
|
||||
</div>
|
||||
<div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
|
||||
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
|
||||
<h1 style="line-height:60px">{title}</h1>
|
||||
</div>
|
||||
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
|
||||
<img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
|
||||
</div>
|
||||
"""
|
||||
|
||||
|
||||
def get_simple_title(title, description):
|
||||
return f"""{description}<h1 align="center"> {title}</h1>"""
|
||||
|
||||
|
||||
def get_dark_js():
|
||||
return """() => {
|
||||
if (document.querySelectorAll('.dark').length) {
|
||||
document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark'));
|
||||
} else {
|
||||
document.querySelector('body').classList.add('dark');
|
||||
}
|
||||
}"""
|
||||
@@ -1,53 +0,0 @@
|
||||
def get_css(kwargs) -> str:
|
||||
if kwargs["h2ocolors"]:
|
||||
css_code = """footer {visibility: hidden;}
|
||||
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
||||
body.dark{background:linear-gradient(#000000,#0d0d0d);}
|
||||
"""
|
||||
else:
|
||||
css_code = """footer {visibility: hidden}"""
|
||||
|
||||
css_code += make_css_base()
|
||||
return css_code
|
||||
|
||||
|
||||
def make_css_base() -> str:
|
||||
return """
|
||||
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
|
||||
|
||||
body.dark{#warning {background-color: #555555};}
|
||||
|
||||
#small_btn {
|
||||
margin: 0.6em 0em 0.55em 0;
|
||||
max-width: 20em;
|
||||
min-width: 5em !important;
|
||||
height: 5em;
|
||||
font-size: 14px !important;
|
||||
}
|
||||
|
||||
#prompt-form {
|
||||
border: 1px solid var(--primary-500) !important;
|
||||
}
|
||||
|
||||
#prompt-form.block {
|
||||
border-radius: var(--block-radius) !important;
|
||||
}
|
||||
|
||||
#prompt-form textarea {
|
||||
border: 1px solid rgb(209, 213, 219);
|
||||
}
|
||||
|
||||
#prompt-form label > div {
|
||||
margin-top: 4px;
|
||||
}
|
||||
|
||||
button.primary:hover {
|
||||
background-color: var(--primary-600) !important;
|
||||
transition: .2s;
|
||||
}
|
||||
|
||||
#prompt-form-area {
|
||||
margin-bottom: 2.5rem;
|
||||
}
|
||||
.chatsmall chatbot {font-size: 10px !important}
|
||||
"""
|
||||
@@ -1,185 +0,0 @@
|
||||
import os
|
||||
import math
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def make_chatbots(output_label0, output_label0_model2, **kwargs):
|
||||
text_outputs = []
|
||||
chat_kwargs = []
|
||||
for model_state_lock in kwargs["model_states"]:
|
||||
if os.environ.get("DEBUG_MODEL_LOCK"):
|
||||
model_name = (
|
||||
model_state_lock["base_model"]
|
||||
+ " : "
|
||||
+ model_state_lock["inference_server"]
|
||||
)
|
||||
else:
|
||||
model_name = model_state_lock["base_model"]
|
||||
output_label = f"h2oGPT [{model_name}]"
|
||||
min_width = (
|
||||
250
|
||||
if kwargs["gradio_size"] in ["small", "large", "medium"]
|
||||
else 160
|
||||
)
|
||||
chat_kwargs.append(
|
||||
dict(
|
||||
label=output_label,
|
||||
visible=kwargs["model_lock"],
|
||||
elem_classes="chatsmall",
|
||||
height=kwargs["height"] or 400,
|
||||
min_width=min_width,
|
||||
)
|
||||
)
|
||||
|
||||
if kwargs["model_lock_columns"] == -1:
|
||||
kwargs["model_lock_columns"] = len(kwargs["model_states"])
|
||||
if kwargs["model_lock_columns"] is None:
|
||||
kwargs["model_lock_columns"] = 3
|
||||
|
||||
ncols = kwargs["model_lock_columns"]
|
||||
if kwargs["model_states"] == 0:
|
||||
nrows = 0
|
||||
else:
|
||||
nrows = math.ceil(
|
||||
len(kwargs["model_states"]) / kwargs["model_lock_columns"]
|
||||
)
|
||||
|
||||
if kwargs["model_lock_columns"] == 0:
|
||||
# not using model_lock
|
||||
pass
|
||||
elif nrows <= 1:
|
||||
with gr.Row():
|
||||
for chat_kwargs1, model_state_lock in zip(
|
||||
chat_kwargs, kwargs["model_states"]
|
||||
):
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
elif nrows == kwargs["model_states"]:
|
||||
with gr.Row():
|
||||
for chat_kwargs1, model_state_lock in zip(
|
||||
chat_kwargs, kwargs["model_states"]
|
||||
):
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
elif nrows == 2:
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if mii >= len(kwargs["model_states"]) / 2:
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if mii < len(kwargs["model_states"]) / 2:
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
elif nrows == 3:
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if mii >= 1 * len(kwargs["model_states"]) / 3:
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if (
|
||||
mii < 1 * len(kwargs["model_states"]) / 3
|
||||
or mii >= 2 * len(kwargs["model_states"]) / 3
|
||||
):
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if mii < 2 * len(kwargs["model_states"]) / 3:
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
elif nrows >= 4:
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if mii >= 1 * len(kwargs["model_states"]) / 4:
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if (
|
||||
mii < 1 * len(kwargs["model_states"]) / 4
|
||||
or mii >= 2 * len(kwargs["model_states"]) / 4
|
||||
):
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if (
|
||||
mii < 2 * len(kwargs["model_states"]) / 4
|
||||
or mii >= 3 * len(kwargs["model_states"]) / 4
|
||||
):
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
with gr.Row():
|
||||
for mii, (chat_kwargs1, model_state_lock) in enumerate(
|
||||
zip(chat_kwargs, kwargs["model_states"])
|
||||
):
|
||||
if mii < 3 * len(kwargs["model_states"]) / 4:
|
||||
continue
|
||||
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
||||
|
||||
with gr.Row():
|
||||
text_output = gr.Chatbot(
|
||||
label=output_label0,
|
||||
visible=not kwargs["model_lock"],
|
||||
height=kwargs["height"] or 400,
|
||||
)
|
||||
text_output2 = gr.Chatbot(
|
||||
label=output_label0_model2,
|
||||
visible=False and not kwargs["model_lock"],
|
||||
height=kwargs["height"] or 400,
|
||||
)
|
||||
return text_output, text_output2, text_outputs
|
||||
|
||||
|
||||
def make_prompt_form(kwargs, LangChainMode):
|
||||
if kwargs["langchain_mode"] != LangChainMode.DISABLED.value:
|
||||
extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
|
||||
else:
|
||||
extra_prompt_form = ""
|
||||
if kwargs["input_lines"] > 1:
|
||||
instruction_label = (
|
||||
"Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form
|
||||
)
|
||||
else:
|
||||
instruction_label = (
|
||||
"Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
|
||||
)
|
||||
|
||||
with gr.Row(): # elem_id='prompt-form-area'):
|
||||
with gr.Column(scale=50):
|
||||
instruction = gr.Textbox(
|
||||
lines=kwargs["input_lines"],
|
||||
label="Ask anything",
|
||||
placeholder=instruction_label,
|
||||
info=None,
|
||||
elem_id="prompt-form",
|
||||
container=True,
|
||||
)
|
||||
with gr.Row():
|
||||
submit = gr.Button(
|
||||
value="Submit", variant="primary", scale=0, size="sm"
|
||||
)
|
||||
stop_btn = gr.Button(
|
||||
value="Stop", variant="secondary", scale=0, size="sm"
|
||||
)
|
||||
|
||||
return instruction, submit, stop_btn
|
||||
@@ -1,5 +1,7 @@
|
||||
import os
|
||||
from apps.stable_diffusion.src.utils.utils import _compile_module
|
||||
from io import BytesIO
|
||||
import torch_mlir
|
||||
|
||||
from transformers import TextGenerationPipeline
|
||||
from transformers.pipelines.text_generation import ReturnType
|
||||
@@ -20,8 +22,56 @@ import gc
|
||||
from pathlib import Path
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_downloader import download_public_file
|
||||
from shark.shark_importer import import_with_fx
|
||||
from apps.stable_diffusion.src import args
|
||||
|
||||
# Brevitas
|
||||
from typing import List, Tuple
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(
|
||||
lhs: List[int],
|
||||
rhs: List[int],
|
||||
rhs_scale: List[int],
|
||||
rhs_zero_point: List[int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
elif len(lhs) == 2 and len(rhs) == 2:
|
||||
return [lhs[0], rhs[0]]
|
||||
else:
|
||||
raise ValueError("Input shapes not supported.")
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(
|
||||
lhs_rank_dtype: Tuple[int, int],
|
||||
rhs_rank_dtype: Tuple[int, int],
|
||||
rhs_scale_rank_dtype: Tuple[int, int],
|
||||
rhs_zero_point_rank_dtype: Tuple[int, int],
|
||||
rhs_bit_width: int,
|
||||
rhs_group_size: int,
|
||||
) -> int:
|
||||
# output dtype is the dtype of the lhs float input
|
||||
lhs_rank, lhs_dtype = lhs_rank_dtype
|
||||
return lhs_dtype
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
|
||||
lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
|
||||
]
|
||||
|
||||
global_device = "cuda"
|
||||
global_precision = "fp16"
|
||||
|
||||
@@ -31,6 +81,67 @@ if not args.run_docuchat_web:
|
||||
tensor_device = "cpu" if args.device == "cpu" else "cuda"
|
||||
|
||||
|
||||
class H2OGPTModel(torch.nn.Module):
|
||||
def __init__(self, device, precision):
|
||||
super().__init__()
|
||||
torch_dtype = (
|
||||
torch.float32
|
||||
if precision == "fp32" or device == "cpu"
|
||||
else torch.float16
|
||||
)
|
||||
device_map = {"": "cpu"} if device == "cpu" else {"": 0}
|
||||
model_kwargs = {
|
||||
"local_files_only": False,
|
||||
"torch_dtype": torch_dtype,
|
||||
"resume_download": True,
|
||||
"use_auth_token": False,
|
||||
"trust_remote_code": True,
|
||||
"offload_folder": "offline_folder",
|
||||
"device_map": device_map,
|
||||
}
|
||||
config = AutoConfig.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
use_auth_token=False,
|
||||
trust_remote_code=True,
|
||||
offload_folder="offline_folder",
|
||||
)
|
||||
self.model = AutoModelForCausalLM.from_pretrained(
|
||||
"h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3",
|
||||
config=config,
|
||||
**model_kwargs,
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Applying weight quantization..")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model.transformer.h,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=128,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
input_dict = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
}
|
||||
output = self.model(
|
||||
**input_dict,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return output.logits[:, -1, :]
|
||||
|
||||
|
||||
class H2OGPTSHARKModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
@@ -42,47 +153,48 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
mlir_path = Path(model_name + "_" + args.precision + ".mlir")
|
||||
shark_module = None
|
||||
|
||||
need_to_compile = False
|
||||
if not vmfb_path.exists():
|
||||
if args.device in ["cuda", "cpu"] and args.precision in [
|
||||
"fp16",
|
||||
"fp32",
|
||||
]:
|
||||
# Downloading VMFB from shark_tank
|
||||
print("Downloading vmfb from shark tank.")
|
||||
need_to_compile = True
|
||||
# Downloading VMFB from shark_tank
|
||||
print("Trying to download pre-compiled vmfb from shark tank.")
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(vmfb_path),
|
||||
vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if vmfb_path.exists():
|
||||
print(
|
||||
"Pre-compiled vmfb downloaded from shark tank successfully."
|
||||
)
|
||||
need_to_compile = False
|
||||
|
||||
if need_to_compile:
|
||||
if not mlir_path.exists():
|
||||
print("Trying to download pre-generated mlir from shark tank.")
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(vmfb_path),
|
||||
vmfb_path.absolute(),
|
||||
"gs://shark_tank/langchain/" + str(mlir_path),
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
# Downloading MLIR from shark_tank
|
||||
download_public_file(
|
||||
"gs://shark_tank/langchain/" + str(mlir_path),
|
||||
mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if mlir_path.exists():
|
||||
with open(mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(
|
||||
shark_module, extended_model_name, []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
# Generating the mlir
|
||||
bytecode = self.get_bytecode(tensor_device, args.precision)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=bytecode,
|
||||
device=args.device,
|
||||
mlir_dialect="linalg",
|
||||
)
|
||||
print(f"[DEBUG] generating vmfb.")
|
||||
shark_module = _compile_module(
|
||||
shark_module, extended_model_name, []
|
||||
)
|
||||
print("Saved newly generated vmfb.")
|
||||
|
||||
if shark_module is None:
|
||||
if vmfb_path.exists():
|
||||
@@ -97,6 +209,72 @@ class H2OGPTSHARKModel(torch.nn.Module):
|
||||
|
||||
self.model = shark_module
|
||||
|
||||
def get_bytecode(self, device, precision):
|
||||
h2ogpt_model = H2OGPTModel(device, precision)
|
||||
|
||||
compilation_input_ids = torch.randint(
|
||||
low=1, high=10000, size=(1, 400)
|
||||
).to(device=device)
|
||||
compilation_attention_mask = torch.ones(1, 400, dtype=torch.int64).to(
|
||||
device=device
|
||||
)
|
||||
|
||||
h2ogptCompileInput = (
|
||||
compilation_input_ids,
|
||||
compilation_attention_mask,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
h2ogpt_model,
|
||||
h2ogptCompileInput,
|
||||
is_f16=False,
|
||||
precision=precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del h2ogpt_model
|
||||
del self.src_model
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if precision in ["int4", "int8"]:
|
||||
from torch_mlir.compiler_utils import (
|
||||
run_pipeline_with_repro_report,
|
||||
)
|
||||
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*h2ogptCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
print(f"[DEBUG] converting to bytecode")
|
||||
bytecode_stream = BytesIO()
|
||||
module.operation.write_bytecode(bytecode_stream)
|
||||
bytecode = bytecode_stream.getvalue()
|
||||
del module
|
||||
|
||||
return bytecode
|
||||
|
||||
def forward(self, input_ids, attention_mask):
|
||||
result = torch.from_numpy(
|
||||
self.model(
|
||||
@@ -381,6 +559,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
||||
return next_token
|
||||
|
||||
def generate_token(self, **generate_kwargs):
|
||||
del generate_kwargs["max_time"]
|
||||
self.truncated_input_ids = []
|
||||
|
||||
generation_config_ = GenerationConfig.from_model_config(
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# for generate (gradio server) and finetune
|
||||
datasets==2.13.0
|
||||
sentencepiece==0.1.99
|
||||
# gradio==3.37.0
|
||||
huggingface_hub==0.16.4
|
||||
appdirs==1.4.4
|
||||
fire==0.5.0
|
||||
docutils==0.20.1
|
||||
# torch==2.0.1; sys_platform != "darwin" and platform_machine != "arm64"
|
||||
evaluate==0.4.0
|
||||
rouge_score==0.1.2
|
||||
sacrebleu==2.3.1
|
||||
@@ -21,7 +19,7 @@ bitsandbytes==0.39.0
|
||||
accelerate==0.20.3
|
||||
peft==0.4.0
|
||||
# 4.31.0+ breaks load_in_8bit=True (https://github.com/huggingface/transformers/issues/25026)
|
||||
# transformers==4.30.2
|
||||
transformers==4.30.2
|
||||
tokenizers==0.13.3
|
||||
APScheduler==3.10.1
|
||||
|
||||
@@ -67,7 +65,7 @@ tiktoken==0.4.0
|
||||
openai==0.27.8
|
||||
|
||||
# optional for chat with PDF
|
||||
langchain==0.0.235
|
||||
langchain==0.0.202
|
||||
pypdf==3.12.2
|
||||
# avoid textract, requires old six
|
||||
#textract==1.6.5
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import gc
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
@@ -40,9 +41,6 @@ from shark.shark_inference import SharkInference
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
if __name__ == "__main__":
|
||||
import gc
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="vicuna runner",
|
||||
@@ -110,8 +108,15 @@ parser.add_argument(
|
||||
default=None,
|
||||
help="Specify your own huggingface authentication tokens for models like Llama2.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_vicunas",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="For debugging purposes, creates a first_{precision}.mlir and second_{precision}.mlir and stores on disk",
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(lhs: List[int], rhs: List[int], rhs_scale: List[int], rhs_zero_point: List[int], rhs_bit_width: int, rhs_group_size: int) -> List[int]:
|
||||
if len(lhs) == 3 and len(rhs) == 2:
|
||||
return [lhs[0], lhs[1], rhs[0]]
|
||||
@@ -135,6 +140,7 @@ brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics]
|
||||
# fmt: on
|
||||
|
||||
|
||||
class VicunaBase(SharkLLMBase):
|
||||
@@ -144,7 +150,7 @@ class VicunaBase(SharkLLMBase):
|
||||
hf_model_path="TheBloke/vicuna-7B-1.1-HF",
|
||||
max_num_tokens=512,
|
||||
device="cpu",
|
||||
precision="int8"
|
||||
precision="int8",
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
self.max_sequence_length = 256
|
||||
@@ -169,11 +175,15 @@ class VicunaBase(SharkLLMBase):
|
||||
def combine_mlir_scripts(
|
||||
self, first_vicuna_mlir, second_vicuna_mlir, output_name
|
||||
):
|
||||
print(f"[DEBUG] combining first and second mlir")
|
||||
print(f"[DEBIG] output_name = {output_name}")
|
||||
maps1 = []
|
||||
maps2 = []
|
||||
constants = set()
|
||||
f1 = []
|
||||
f2 = []
|
||||
|
||||
print(f"[DEBUG] processing first vircuna mlir")
|
||||
first_vicuna_mlir = first_vicuna_mlir.splitlines()
|
||||
while first_vicuna_mlir:
|
||||
line = first_vicuna_mlir.pop(0)
|
||||
@@ -186,6 +196,7 @@ class VicunaBase(SharkLLMBase):
|
||||
f1.append(line)
|
||||
f1 = f1[:-1]
|
||||
del first_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps1):
|
||||
map_var = map_line.split(" ")[0]
|
||||
@@ -196,6 +207,7 @@ class VicunaBase(SharkLLMBase):
|
||||
for func_line in f1
|
||||
]
|
||||
|
||||
print(f"[DEBUG] processing second vircuna mlir")
|
||||
second_vicuna_mlir = second_vicuna_mlir.splitlines()
|
||||
while second_vicuna_mlir:
|
||||
line = second_vicuna_mlir.pop(0)
|
||||
@@ -209,6 +221,8 @@ class VicunaBase(SharkLLMBase):
|
||||
line = re.sub("forward", "second_vicuna_forward", line)
|
||||
f2.append(line)
|
||||
f2 = f2[:-1]
|
||||
del second_vicuna_mlir
|
||||
gc.collect()
|
||||
|
||||
for i, map_line in enumerate(maps2):
|
||||
map_var = map_line.split(" ")[0]
|
||||
@@ -229,6 +243,7 @@ class VicunaBase(SharkLLMBase):
|
||||
global_var_loading1 = []
|
||||
global_var_loading2 = []
|
||||
|
||||
print(f"[DEBUG] processing constants")
|
||||
counter = 0
|
||||
constants = list(constants)
|
||||
while constants:
|
||||
@@ -238,15 +253,15 @@ class VicunaBase(SharkLLMBase):
|
||||
vname = vname.strip()
|
||||
vbody = re.sub("arith.constant", "", vbody)
|
||||
vbody = vbody.strip()
|
||||
if len(vbody.split(":"))<2:
|
||||
if len(vbody.split(":")) < 2:
|
||||
print(constant)
|
||||
vdtype = vbody.split(":")[-1].strip()
|
||||
fixed_vdtype = vdtype
|
||||
if "c1_i64" in vname:
|
||||
print(constant)
|
||||
counter+=1
|
||||
if counter==2:
|
||||
counter=0
|
||||
counter += 1
|
||||
if counter == 2:
|
||||
counter = 0
|
||||
print("detected duplicate")
|
||||
continue
|
||||
vnames.append(vname)
|
||||
@@ -272,6 +287,7 @@ class VicunaBase(SharkLLMBase):
|
||||
)
|
||||
new_f1, new_f2 = [], []
|
||||
|
||||
print(f"[DEBUG] processing f1")
|
||||
for line in f1:
|
||||
if "func.func" in line:
|
||||
new_f1.append(line)
|
||||
@@ -280,39 +296,63 @@ class VicunaBase(SharkLLMBase):
|
||||
else:
|
||||
new_f1.append(line)
|
||||
|
||||
print(f"[DEBUG] processing f2")
|
||||
for line in f2:
|
||||
if "func.func" in line:
|
||||
new_f2.append(line)
|
||||
for global_var in global_var_loading2:
|
||||
if "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in global_var:
|
||||
if (
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
in global_var
|
||||
):
|
||||
print(global_var)
|
||||
new_f2.append(global_var)
|
||||
else:
|
||||
if "c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in line:
|
||||
new_f2.append("%"+line)
|
||||
new_f2.append("%" + line)
|
||||
else:
|
||||
new_f2.append(line)
|
||||
|
||||
f1 = new_f1
|
||||
f2 = new_f2
|
||||
print(["c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x for x in [maps1, maps2, global_vars, f1, f2]])
|
||||
whole_string = "\n".join(
|
||||
maps1
|
||||
+ maps2
|
||||
+ [module_start]
|
||||
+ global_vars
|
||||
+ f1
|
||||
+ f2
|
||||
+ [module_end]
|
||||
|
||||
del new_f1
|
||||
del new_f2
|
||||
gc.collect()
|
||||
|
||||
print(
|
||||
[
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64" in x
|
||||
for x in [maps1, maps2, global_vars, f1, f2]
|
||||
]
|
||||
)
|
||||
|
||||
f_ = open(output_name, "w+")
|
||||
f_.write(whole_string)
|
||||
f_.close()
|
||||
# doing it this way rather than assembling the whole string
|
||||
# to prevent OOM with 64GiB RAM when encoding the file.
|
||||
|
||||
return whole_string
|
||||
print(f"[DEBUG] Saving mlir to {output_name}")
|
||||
with open(output_name, "w+") as f_:
|
||||
f_.writelines(line + "\n" for line in maps1)
|
||||
f_.writelines(line + "\n" for line in maps2)
|
||||
f_.writelines(line + "\n" for line in [module_start])
|
||||
f_.writelines(line + "\n" for line in global_vars)
|
||||
f_.writelines(line + "\n" for line in f1)
|
||||
f_.writelines(line + "\n" for line in f2)
|
||||
f_.writelines(line + "\n" for line in [module_end])
|
||||
|
||||
del maps1
|
||||
del maps2
|
||||
del module_start
|
||||
del global_vars
|
||||
del f1
|
||||
del f2
|
||||
del module_end
|
||||
gc.collect()
|
||||
|
||||
print(f"[DEBUG] Reading combined mlir back in")
|
||||
with open(output_name, "rb") as f:
|
||||
return f.read()
|
||||
|
||||
|
||||
def generate_new_token(self, params, sharded=True):
|
||||
is_first = params["is_first"]
|
||||
if is_first:
|
||||
@@ -336,13 +376,16 @@ class VicunaBase(SharkLLMBase):
|
||||
input_ids = input_ids.reshape([1, input_id_len])
|
||||
if sharded:
|
||||
output = self.shark_model.forward(
|
||||
input_ids, past_key_values=past_key_values, is_first=is_first
|
||||
input_ids,
|
||||
past_key_values=past_key_values,
|
||||
is_first=is_first,
|
||||
)
|
||||
else:
|
||||
token = token.to(torch.int64).reshape([1,1])
|
||||
token = token.to(torch.int64).reshape([1, 1])
|
||||
second_input = (token,) + tuple(past_key_values)
|
||||
output = self.shark_model("second_vicuna_forward", second_input)
|
||||
|
||||
output = self.shark_model(
|
||||
"second_vicuna_forward", second_input
|
||||
)
|
||||
|
||||
if sharded:
|
||||
_logits = output["logits"]
|
||||
@@ -367,6 +410,7 @@ class VicunaBase(SharkLLMBase):
|
||||
|
||||
return ret_dict
|
||||
|
||||
|
||||
class ShardedVicuna(VicunaBase):
|
||||
# Class representing Sharded Vicuna Model
|
||||
def __init__(
|
||||
@@ -785,7 +829,7 @@ class ShardedVicuna(VicunaBase):
|
||||
module = SharkInference(
|
||||
None,
|
||||
device=device,
|
||||
device_idx=idx % 4,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
mmap=False,
|
||||
)
|
||||
@@ -798,7 +842,7 @@ class ShardedVicuna(VicunaBase):
|
||||
module = SharkInference(
|
||||
mlirs[idx],
|
||||
device=device,
|
||||
device_idx=idx % 4,
|
||||
device_idx=device_idx,
|
||||
mlir_dialect="tm_tensor",
|
||||
mmap=False,
|
||||
)
|
||||
@@ -945,7 +989,6 @@ class ShardedVicuna(VicunaBase):
|
||||
result_output = self.tokenizer.decode(tokens_generated)
|
||||
return result_output
|
||||
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
@@ -966,6 +1009,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
low_device_memory=False,
|
||||
weight_group_size=128,
|
||||
download_vmfb=False,
|
||||
cache_vicunas=False,
|
||||
) -> None:
|
||||
super().__init__(model_name, hf_model_path, max_num_tokens)
|
||||
if "llama2" in self.model_name and hf_auth_token == None:
|
||||
@@ -992,14 +1036,15 @@ class UnshardedVicuna(VicunaBase):
|
||||
if self.vicuna_vmfb_path == None:
|
||||
self.vicuna_vmfb_path = self.get_model_path(suffix="vmfb")
|
||||
self.tokenizer = self.get_tokenizer()
|
||||
self.cache_vicunas = cache_vicunas
|
||||
self.compile()
|
||||
|
||||
def get_model_path(self, suffix="mlir"):
|
||||
safe_device = self.device.split("-")[0]
|
||||
if suffix == "mlir":
|
||||
return Path(f"vicuna_{self.precision}.{suffix}")
|
||||
return Path(f"{self.model_name}_{self.precision}.{suffix}")
|
||||
return Path(
|
||||
f"vicuna_{self.precision}_{safe_device}.{suffix}"
|
||||
f"{self.model_name}_{self.precision}_{safe_device}.{suffix}"
|
||||
)
|
||||
|
||||
def get_tokenizer(self):
|
||||
@@ -1029,9 +1074,9 @@ class UnshardedVicuna(VicunaBase):
|
||||
return vicuna_model
|
||||
|
||||
def write_in_dynamic_inputs0(self, module, dynamic_input_size):
|
||||
print("[DEBUG] writing dynamic inputs to first vicuna.")
|
||||
print("[DEBUG] writing dynamic inputs to first vicuna")
|
||||
# Current solution for ensuring mlir files support dynamic inputs
|
||||
# TODO find a more elegant way to implement this
|
||||
# TODO: find a more elegant way to implement this
|
||||
new_lines = []
|
||||
module = module.splitlines()
|
||||
while module:
|
||||
@@ -1047,26 +1092,24 @@ class UnshardedVicuna(VicunaBase):
|
||||
if "arith.cmpi" in line:
|
||||
line = re.sub(f"c{dynamic_input_size}", "dim", line)
|
||||
if "%0 = tensor.empty(%dim) : tensor<?xi64>" in line:
|
||||
new_lines.append("%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>")
|
||||
if (
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
in line
|
||||
):
|
||||
new_lines.append(
|
||||
"%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>"
|
||||
)
|
||||
if "%dim = tensor.dim %arg0, %c1 : tensor<1x?xi64>" in line:
|
||||
continue
|
||||
|
||||
new_lines.append(line)
|
||||
return '\n'.join(new_lines)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
def write_in_dynamic_inputs1(self, module):
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna.")
|
||||
print("[DEBUG] writing dynamic inputs to second vicuna")
|
||||
|
||||
def remove_constant_dim(line):
|
||||
if "c19_i64" in line:
|
||||
line = re.sub("c19_i64", "dim_i64", line)
|
||||
if "19x" in line:
|
||||
line = re.sub("19x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dim)", line
|
||||
)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dim)", line)
|
||||
if "tensor.empty" in line and "?x?" in line:
|
||||
line = re.sub(
|
||||
"tensor.empty\(%dim\)",
|
||||
@@ -1079,256 +1122,268 @@ class UnshardedVicuna(VicunaBase):
|
||||
line = re.sub(" 19,", " %dim,", line)
|
||||
if "20x" in line:
|
||||
line = re.sub("20x", "?x", line)
|
||||
line = re.sub(
|
||||
"tensor.empty\(\)", "tensor.empty(%dimp1)", line
|
||||
)
|
||||
line = re.sub("tensor.empty\(\)", "tensor.empty(%dimp1)", line)
|
||||
if " 20," in line:
|
||||
line = re.sub(" 20,", " %dimp1,", line)
|
||||
return line
|
||||
|
||||
module = module.splitlines()
|
||||
new_lines = []
|
||||
#Using a while loop and the pop method to avoid creating a copy of module
|
||||
new_lines = []
|
||||
# Using a while loop and the pop method to avoid creating a copy of module
|
||||
while module:
|
||||
line = module.pop(0)
|
||||
if "%c19_i64 = arith.constant 19 : i64" in line:
|
||||
new_lines.append("%c2 = arith.constant 2 : index")
|
||||
new_lines.append(f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>")
|
||||
new_lines.append("%dim_i64 = arith.index_cast %dim_4_int : index to i64")
|
||||
new_lines.append(
|
||||
f"%dim_4_int = tensor.dim %arg1, %c2 : tensor<1x32x?x128x{'f16' if self.precision == 'fp16' else 'f32'}>"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dim_i64 = arith.index_cast %dim_4_int : index to i64"
|
||||
)
|
||||
continue
|
||||
if "%c2 = arith.constant 2 : index" in line:
|
||||
continue
|
||||
if "%c20_i64 = arith.constant 20 : i64" in line:
|
||||
new_lines.append("%c1_i64 = arith.constant 1 : i64")
|
||||
new_lines.append("c20_i64 = arith.addi %dim_i64, %c1_i64 : i64")
|
||||
new_lines.append("%dimp1 = arith.index_cast %c20_i64 : i64 to index")
|
||||
new_lines.append(
|
||||
"c20_i64 = arith.addi %dim_i64, %c1_i64 : i64"
|
||||
)
|
||||
new_lines.append(
|
||||
"%dimp1 = arith.index_cast %c20_i64 : i64 to index"
|
||||
)
|
||||
continue
|
||||
line = remove_constant_dim(line)
|
||||
new_lines.append(line)
|
||||
|
||||
return '\n'.join(new_lines)
|
||||
return "\n".join(new_lines)
|
||||
|
||||
def compile(self, download_vmfb=False):
|
||||
# Testing : DO NOT Download Vmfbs if not found. Modify later
|
||||
# download vmfbs for A100
|
||||
supported_devices = ["cuda", "cpu-sync", "cpu-task", "cpu"]
|
||||
if (
|
||||
not self.vicuna_vmfb_path.exists()
|
||||
and self.device in supported_devices
|
||||
and self.precision in ["fp32", "fp16", "int8"]
|
||||
):
|
||||
if (self.device == "cuda" and self.precision == "fp16") or (
|
||||
self.device in ["cpu-sync", "cpu-task"]
|
||||
and self.precision == "int8" and download_vmfb
|
||||
):
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
|
||||
self.vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
if not self.vicuna_vmfb_path.exists() and download_vmfb:
|
||||
download_public_file(
|
||||
f"gs://shark_tank/{self.model_name}/unsharded/vmfb/{self.vicuna_vmfb_path.name}",
|
||||
self.vicuna_vmfb_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
self.shark_model = get_vmfb_from_path(
|
||||
self.vicuna_vmfb_path, self.device, "tm_tensor"
|
||||
)
|
||||
if self.shark_model is not None:
|
||||
return None
|
||||
print(f"[DEBUG] vmfb found at {self.vicuna_vmfb_path.absolute()}")
|
||||
return
|
||||
|
||||
print(
|
||||
f"[DEBUG] vmfb not found at {self.vicuna_vmfb_path.absolute()}. Trying to work with\n"
|
||||
f"[DEBUG] mlir path { self.vicuna_mlir_path} {'exists' if self.vicuna_mlir_path.exists() else 'does not exist'}"
|
||||
)
|
||||
print(f"[DEBUG] vmfb not found at {self.vicuna_vmfb_path.absolute()}")
|
||||
if self.vicuna_mlir_path.exists():
|
||||
print(f"[DEBUG] mlir found at {self.vicuna_mlir_path.absolute()}")
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
combined_module = f.read()
|
||||
else:
|
||||
print(
|
||||
f"[DEBUG] mlir not found at {self.vicuna_mlir_path.absolute()}"
|
||||
)
|
||||
mlir_generated = False
|
||||
if self.load_mlir_from_shark_tank:
|
||||
if self.precision in ["fp32", "fp16", "int8", "int4"]:
|
||||
# download MLIR from shark_tank
|
||||
download_public_file(
|
||||
f"gs://shark_tank/vicuna/unsharded/mlir/{self.vicuna_mlir_path.name}",
|
||||
self.vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.vicuna_mlir_path.exists():
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
bytecode = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"MLIR not found at {self.vicuna_mlir_path.absolute()}"
|
||||
" after downloading! Please check path and try again"
|
||||
)
|
||||
# download MLIR from shark tank
|
||||
download_public_file(
|
||||
f"gs://shark_tank/{self.model_name}/unsharded/mlir/{self.vicuna_mlir_path.name}",
|
||||
self.vicuna_mlir_path.absolute(),
|
||||
single_file=True,
|
||||
)
|
||||
if self.vicuna_mlir_path.exists():
|
||||
with open(self.vicuna_mlir_path, "rb") as f:
|
||||
combined_module = f.read()
|
||||
mlir_generated = True
|
||||
else:
|
||||
print(
|
||||
f"Only fp32/fp16/int8/int4 mlir added to tank, generating {self.precision} mlir on device."
|
||||
f"[DEBUG] failed to download {self.vicuna_mlir_path.name} from shark tank"
|
||||
)
|
||||
|
||||
if not mlir_generated:
|
||||
print("[DEBUG] generating mlir on device")
|
||||
# Select a compilation prompt such that the resulting input_ids
|
||||
# from the model's tokenizer has shape [1, 19]
|
||||
if self.model_name == "codegen":
|
||||
compilation_prompt = "def hello_world():\n print('Hello World')\n print('Hello World')"
|
||||
else:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
combined_module = None
|
||||
if Path("first.mlir").exists():
|
||||
print("loading first.mlir")
|
||||
with open(Path("first.mlir"), "r") as f:
|
||||
first_module = f.read()
|
||||
else:
|
||||
compilation_prompt = "".join(["0" for _ in range(17)])
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=self.precision
|
||||
== "fp16", # TODO: Remove from import_with_fx args and fix all calls
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[0] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
first_module = None
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
if Path(f"first_{self.precision}.mlir").exists():
|
||||
print(f"loading first_{self.precision}.mlir")
|
||||
with open(Path(f"first_{self.precision}.mlir"), "r") as f:
|
||||
first_module = f.read()
|
||||
else:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
# generate first vicuna
|
||||
compilation_input_ids = self.tokenizer(
|
||||
compilation_prompt,
|
||||
return_tensors="pt",
|
||||
).input_ids
|
||||
compilation_input_ids = torch.tensor(
|
||||
compilation_input_ids
|
||||
).reshape([1, 19])
|
||||
firstVicunaCompileInput = (compilation_input_ids,)
|
||||
model = FirstVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
del ts_graph
|
||||
|
||||
|
||||
first_module = self.write_in_dynamic_inputs0(str(first_module), dynamic_input_size=19)
|
||||
|
||||
with open("first.mlir", "w+") as f:
|
||||
f.write(first_module)
|
||||
|
||||
if Path("second.mlir").exists():
|
||||
print("loading second.mlir")
|
||||
with open(Path("second.mlir"), "r") as f:
|
||||
second_module = f.read()
|
||||
else:
|
||||
compilation_input_ids = torch.zeros([1, 1], dtype=torch.int64)
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
|
||||
for _ in range(64)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
model = SecondVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
if self.precision == "fp16":
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
firstVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False, False],
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[
|
||||
i
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
del model
|
||||
firstVicunaCompileInput = list(firstVicunaCompileInput)
|
||||
firstVicunaCompileInput[
|
||||
0
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
firstVicunaCompileInput[0], dynamic_axes=[1]
|
||||
)
|
||||
|
||||
firstVicunaCompileInput = tuple(firstVicunaCompileInput)
|
||||
first_module = None
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=[
|
||||
"brevitas.matmul_rhs_group_quant"
|
||||
],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
first_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
first_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*firstVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del firstVicunaCompileInput
|
||||
gc.collect()
|
||||
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
print(
|
||||
"[DEBUG] successfully generated first vicuna linalg mlir"
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
first_module = self.write_in_dynamic_inputs0(
|
||||
str(first_module), dynamic_input_size=19
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(f"first_{self.precision}.mlir", "w+") as f:
|
||||
f.write(first_module)
|
||||
|
||||
if Path(f"second_{self.precision}.mlir").exists():
|
||||
print(f"loading second_{self.precision}.mlir")
|
||||
with open(Path(f"second_{self.precision}.mlir"), "r") as f:
|
||||
second_module = f.read()
|
||||
else:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
# generate second vicuna
|
||||
compilation_input_ids = torch.zeros(
|
||||
[1, 1], dtype=torch.int64
|
||||
)
|
||||
print("[DEBUG] successfully converted second vicuna to linalg.")
|
||||
second_module = str(second_module)
|
||||
second_module = self.write_in_dynamic_inputs1(second_module)
|
||||
with open("second.mlir", "w+") as f:
|
||||
f.write(second_module)
|
||||
|
||||
combined_module = self.combine_mlir_scripts(first_module, second_module, self.vicuna_mlir_path)
|
||||
del first_module, second_module
|
||||
pkv = tuple(
|
||||
(torch.zeros([1, 32, 19, 128], dtype=torch.float32))
|
||||
for _ in range(64)
|
||||
)
|
||||
secondVicunaCompileInput = (compilation_input_ids,) + pkv
|
||||
model = SecondVicuna(
|
||||
self.hf_model_path,
|
||||
self.precision,
|
||||
self.weight_group_size,
|
||||
self.model_name,
|
||||
self.hf_auth_token,
|
||||
)
|
||||
print(f"[DEBUG] generating torchscript graph")
|
||||
ts_graph = import_with_fx(
|
||||
model,
|
||||
secondVicunaCompileInput,
|
||||
is_f16=self.precision == "fp16",
|
||||
precision=self.precision,
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
mlir_type="torchscript",
|
||||
)
|
||||
del model
|
||||
if self.precision == "fp16":
|
||||
secondVicunaCompileInput = get_f16_inputs(
|
||||
secondVicunaCompileInput,
|
||||
True,
|
||||
f16_input_mask=[False] + [True] * 64,
|
||||
)
|
||||
secondVicunaCompileInput = list(secondVicunaCompileInput)
|
||||
for i in range(len(secondVicunaCompileInput)):
|
||||
if i != 0:
|
||||
secondVicunaCompileInput[
|
||||
i
|
||||
] = torch_mlir.TensorPlaceholder.like(
|
||||
secondVicunaCompileInput[i], dynamic_axes=[2]
|
||||
)
|
||||
secondVicunaCompileInput = tuple(secondVicunaCompileInput)
|
||||
print(f"[DEBUG] generating torch mlir")
|
||||
if self.precision in ["int4", "int8"]:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
output_type=torch_mlir.OutputType.TORCH,
|
||||
backend_legal_ops=[
|
||||
"brevitas.matmul_rhs_group_quant"
|
||||
],
|
||||
extra_library=brevitas_matmul_rhs_group_quant_library,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
print(f"[DEBUG] converting torch to linalg")
|
||||
run_pipeline_with_repro_report(
|
||||
second_module,
|
||||
"builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
|
||||
description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
|
||||
)
|
||||
else:
|
||||
second_module = torch_mlir.compile(
|
||||
ts_graph,
|
||||
[*secondVicunaCompileInput],
|
||||
torch_mlir.OutputType.LINALG_ON_TENSORS,
|
||||
use_tracing=False,
|
||||
verbose=False,
|
||||
)
|
||||
del ts_graph
|
||||
del secondVicunaCompileInput
|
||||
gc.collect()
|
||||
print(
|
||||
"[DEBUG] successfully generated second vicuna linalg mlir"
|
||||
)
|
||||
second_module = self.write_in_dynamic_inputs1(
|
||||
str(second_module)
|
||||
)
|
||||
if self.cache_vicunas:
|
||||
with open(f"second_{self.precision}.mlir", "w+") as f:
|
||||
f.write(second_module)
|
||||
|
||||
combined_module = self.combine_mlir_scripts(
|
||||
first_module, second_module, self.vicuna_mlir_path
|
||||
)
|
||||
del first_module, second_module
|
||||
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module=combined_module, device=self.device, mlir_dialect="tm_tensor"
|
||||
mlir_module=combined_module,
|
||||
device=self.device,
|
||||
mlir_dialect="tm_tensor",
|
||||
)
|
||||
path = shark_module.save_module(
|
||||
self.vicuna_vmfb_path.parent.absolute(),
|
||||
@@ -1341,9 +1396,7 @@ class UnshardedVicuna(VicunaBase):
|
||||
)
|
||||
print("Saved vic vmfb at ", str(path))
|
||||
shark_module.load_module(path)
|
||||
|
||||
self.shark_module = shark_module
|
||||
|
||||
self.shark_model = shark_module
|
||||
|
||||
def decode_tokens(self, res_tokens):
|
||||
for i in range(len(res_tokens)):
|
||||
@@ -1358,16 +1411,14 @@ class UnshardedVicuna(VicunaBase):
|
||||
|
||||
def generate(self, prompt, cli=True):
|
||||
# TODO: refactor for cleaner integration
|
||||
import gc
|
||||
|
||||
if self.shark_model is None:
|
||||
self.compile()
|
||||
res_tokens = []
|
||||
params = {
|
||||
"prompt": prompt,
|
||||
"is_first": True,
|
||||
"fv": self.shark_model
|
||||
}
|
||||
params = {"prompt": prompt, "is_first": True, "fv": self.shark_model}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params, sharded=False)
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False
|
||||
)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
@@ -1379,17 +1430,18 @@ class UnshardedVicuna(VicunaBase):
|
||||
if cli:
|
||||
print(f"Assistant: {detok}", end=" ", flush=True)
|
||||
|
||||
|
||||
for _ in range(self.max_num_tokens - 2):
|
||||
params = {
|
||||
"token": token,
|
||||
"is_first": False,
|
||||
"logits": logits,
|
||||
"past_key_values": pkv,
|
||||
"sv": self.shark_model
|
||||
"sv": self.shark_model,
|
||||
}
|
||||
|
||||
generated_token_op = self.generate_new_token(params=params, sharded=False)
|
||||
generated_token_op = self.generate_new_token(
|
||||
params=params, sharded=False
|
||||
)
|
||||
|
||||
token = generated_token_op["token"]
|
||||
logits = generated_token_op["logits"]
|
||||
@@ -1410,12 +1462,10 @@ class UnshardedVicuna(VicunaBase):
|
||||
part_str = self.decode_tokens(res_tokens)
|
||||
yield part_str
|
||||
|
||||
|
||||
res_str = self.decode_tokens(res_tokens)
|
||||
# print(f"[DEBUG] final output : \n{res_str}")
|
||||
yield res_str
|
||||
|
||||
|
||||
def autocomplete(self, prompt):
|
||||
# use First vic alone to complete a story / prompt / sentence.
|
||||
pass
|
||||
@@ -1446,6 +1496,7 @@ if __name__ == "__main__":
|
||||
load_mlir_from_shark_tank=args.load_mlir_from_shark_tank,
|
||||
weight_group_size=args.weight_group_size,
|
||||
download_vmfb=args.download_vmfb,
|
||||
cache_vicunas=args.cache_vicunas,
|
||||
)
|
||||
else:
|
||||
if args.config is not None:
|
||||
|
||||
503
apps/language_models/src/model_wrappers/minigpt4.py
Normal file
503
apps/language_models/src/model_wrappers/minigpt4.py
Normal file
@@ -0,0 +1,503 @@
|
||||
import torch
|
||||
import dataclasses
|
||||
from enum import auto, Enum
|
||||
from typing import List, Any
|
||||
from transformers import StoppingCriteria
|
||||
|
||||
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
class LayerNorm(torch.nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class VisionModel(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ln_vision,
|
||||
visual_encoder,
|
||||
precision="fp32",
|
||||
weight_group_size=128,
|
||||
):
|
||||
super().__init__()
|
||||
self.ln_vision = ln_vision
|
||||
self.visual_encoder = visual_encoder
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Vision Model applying weight quantization to ln_vision")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.ln_vision,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
print(
|
||||
"Vision Model applying weight quantization to visual_encoder"
|
||||
)
|
||||
quantize_model(
|
||||
self.visual_encoder,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, image):
|
||||
image_embeds = self.ln_vision(self.visual_encoder(image))
|
||||
return image_embeds
|
||||
|
||||
|
||||
class QformerBertModel(torch.nn.Module):
|
||||
def __init__(self, qformer_bert):
|
||||
super().__init__()
|
||||
self.qformer_bert = qformer_bert
|
||||
|
||||
def forward(self, query_tokens, image_embeds, image_atts):
|
||||
query_output = self.qformer_bert(
|
||||
query_embeds=query_tokens,
|
||||
encoder_hidden_states=image_embeds,
|
||||
encoder_attention_mask=image_atts,
|
||||
return_dict=True,
|
||||
)
|
||||
return query_output.last_hidden_state
|
||||
|
||||
|
||||
class FirstLlamaModel(torch.nn.Module):
|
||||
def __init__(self, model, precision="fp32", weight_group_size=128):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
print("SHARK: Loading LLAMA Done")
|
||||
if precision in ["int4", "int8"]:
|
||||
print("First Llama applying weight quantization")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(self, inputs_embeds, position_ids, attention_mask):
|
||||
print("************************************")
|
||||
print(
|
||||
"inputs_embeds: ",
|
||||
inputs_embeds.shape,
|
||||
" dtype: ",
|
||||
inputs_embeds.dtype,
|
||||
)
|
||||
print(
|
||||
"position_ids: ",
|
||||
position_ids.shape,
|
||||
" dtype: ",
|
||||
position_ids.dtype,
|
||||
)
|
||||
print(
|
||||
"attention_mask: ",
|
||||
attention_mask.shape,
|
||||
" dtype: ",
|
||||
attention_mask.dtype,
|
||||
)
|
||||
print("************************************")
|
||||
config = {
|
||||
"inputs_embeds": inputs_embeds,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": None,
|
||||
"use_cache": True,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(
|
||||
**config,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(output.logits)
|
||||
temp_past_key_values = output.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SecondLlamaModel(torch.nn.Module):
|
||||
def __init__(self, model, precision="fp32", weight_group_size=128):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
print("SHARK: Loading LLAMA Done")
|
||||
if precision in ["int4", "int8"]:
|
||||
print("Second Llama applying weight quantization")
|
||||
weight_bit_width = 4 if precision == "int4" else 8
|
||||
quantize_model(
|
||||
self.model,
|
||||
dtype=torch.float32,
|
||||
weight_bit_width=weight_bit_width,
|
||||
weight_param_method="stats",
|
||||
weight_scale_precision="float",
|
||||
weight_quant_type="asym",
|
||||
weight_quant_granularity="per_group",
|
||||
weight_group_size=weight_group_size,
|
||||
quantize_weight_zero_point=False,
|
||||
)
|
||||
print("Weight quantization applied.")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
position_ids,
|
||||
attention_mask,
|
||||
i1,
|
||||
i2,
|
||||
i3,
|
||||
i4,
|
||||
i5,
|
||||
i6,
|
||||
i7,
|
||||
i8,
|
||||
i9,
|
||||
i10,
|
||||
i11,
|
||||
i12,
|
||||
i13,
|
||||
i14,
|
||||
i15,
|
||||
i16,
|
||||
i17,
|
||||
i18,
|
||||
i19,
|
||||
i20,
|
||||
i21,
|
||||
i22,
|
||||
i23,
|
||||
i24,
|
||||
i25,
|
||||
i26,
|
||||
i27,
|
||||
i28,
|
||||
i29,
|
||||
i30,
|
||||
i31,
|
||||
i32,
|
||||
i33,
|
||||
i34,
|
||||
i35,
|
||||
i36,
|
||||
i37,
|
||||
i38,
|
||||
i39,
|
||||
i40,
|
||||
i41,
|
||||
i42,
|
||||
i43,
|
||||
i44,
|
||||
i45,
|
||||
i46,
|
||||
i47,
|
||||
i48,
|
||||
i49,
|
||||
i50,
|
||||
i51,
|
||||
i52,
|
||||
i53,
|
||||
i54,
|
||||
i55,
|
||||
i56,
|
||||
i57,
|
||||
i58,
|
||||
i59,
|
||||
i60,
|
||||
i61,
|
||||
i62,
|
||||
i63,
|
||||
i64,
|
||||
):
|
||||
print("************************************")
|
||||
print("input_ids: ", input_ids.shape, " dtype: ", input_ids.dtype)
|
||||
print(
|
||||
"position_ids: ",
|
||||
position_ids.shape,
|
||||
" dtype: ",
|
||||
position_ids.dtype,
|
||||
)
|
||||
print(
|
||||
"attention_mask: ",
|
||||
attention_mask.shape,
|
||||
" dtype: ",
|
||||
attention_mask.dtype,
|
||||
)
|
||||
print("past_key_values: ", i1.shape, i2.shape, i63.shape, i64.shape)
|
||||
print("past_key_values dtype: ", i1.dtype)
|
||||
print("************************************")
|
||||
config = {
|
||||
"input_ids": input_ids,
|
||||
"position_ids": position_ids,
|
||||
"past_key_values": (
|
||||
(i1, i2),
|
||||
(
|
||||
i3,
|
||||
i4,
|
||||
),
|
||||
(
|
||||
i5,
|
||||
i6,
|
||||
),
|
||||
(
|
||||
i7,
|
||||
i8,
|
||||
),
|
||||
(
|
||||
i9,
|
||||
i10,
|
||||
),
|
||||
(
|
||||
i11,
|
||||
i12,
|
||||
),
|
||||
(
|
||||
i13,
|
||||
i14,
|
||||
),
|
||||
(
|
||||
i15,
|
||||
i16,
|
||||
),
|
||||
(
|
||||
i17,
|
||||
i18,
|
||||
),
|
||||
(
|
||||
i19,
|
||||
i20,
|
||||
),
|
||||
(
|
||||
i21,
|
||||
i22,
|
||||
),
|
||||
(
|
||||
i23,
|
||||
i24,
|
||||
),
|
||||
(
|
||||
i25,
|
||||
i26,
|
||||
),
|
||||
(
|
||||
i27,
|
||||
i28,
|
||||
),
|
||||
(
|
||||
i29,
|
||||
i30,
|
||||
),
|
||||
(
|
||||
i31,
|
||||
i32,
|
||||
),
|
||||
(
|
||||
i33,
|
||||
i34,
|
||||
),
|
||||
(
|
||||
i35,
|
||||
i36,
|
||||
),
|
||||
(
|
||||
i37,
|
||||
i38,
|
||||
),
|
||||
(
|
||||
i39,
|
||||
i40,
|
||||
),
|
||||
(
|
||||
i41,
|
||||
i42,
|
||||
),
|
||||
(
|
||||
i43,
|
||||
i44,
|
||||
),
|
||||
(
|
||||
i45,
|
||||
i46,
|
||||
),
|
||||
(
|
||||
i47,
|
||||
i48,
|
||||
),
|
||||
(
|
||||
i49,
|
||||
i50,
|
||||
),
|
||||
(
|
||||
i51,
|
||||
i52,
|
||||
),
|
||||
(
|
||||
i53,
|
||||
i54,
|
||||
),
|
||||
(
|
||||
i55,
|
||||
i56,
|
||||
),
|
||||
(
|
||||
i57,
|
||||
i58,
|
||||
),
|
||||
(
|
||||
i59,
|
||||
i60,
|
||||
),
|
||||
(
|
||||
i61,
|
||||
i62,
|
||||
),
|
||||
(
|
||||
i63,
|
||||
i64,
|
||||
),
|
||||
),
|
||||
"use_cache": True,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
output = self.model(
|
||||
**config,
|
||||
return_dict=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
)
|
||||
return_vals = []
|
||||
return_vals.append(output.logits)
|
||||
temp_past_key_values = output.past_key_values
|
||||
for item in temp_past_key_values:
|
||||
return_vals.append(item[0])
|
||||
return_vals.append(item[1])
|
||||
return tuple(return_vals)
|
||||
|
||||
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: str = None
|
||||
|
||||
skip_next: bool = False
|
||||
conv_id: Any = None
|
||||
|
||||
def get_prompt(self):
|
||||
if self.sep_style == SeparatorStyle.SINGLE:
|
||||
ret = self.system + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
ret += role + ": " + message + self.sep
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
elif self.sep_style == SeparatorStyle.TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
ret = self.system + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
ret += role + ": " + message + seps[i % 2]
|
||||
else:
|
||||
ret += role + ":"
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
def to_gradio_chatbot(self):
|
||||
ret = []
|
||||
for i, (role, msg) in enumerate(self.messages[self.offset :]):
|
||||
if i % 2 == 0:
|
||||
ret.append([msg, None])
|
||||
else:
|
||||
ret[-1][-1] = msg
|
||||
return ret
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
conv_id=self.conv_id,
|
||||
)
|
||||
|
||||
def dict(self):
|
||||
return {
|
||||
"system": self.system,
|
||||
"roles": self.roles,
|
||||
"messages": self.messages,
|
||||
"offset": self.offset,
|
||||
"sep": self.sep,
|
||||
"sep2": self.sep2,
|
||||
"conv_id": self.conv_id,
|
||||
}
|
||||
|
||||
|
||||
class StoppingCriteriaSub(StoppingCriteria):
|
||||
def __init__(self, stops=[], encounters=1):
|
||||
super().__init__()
|
||||
self.stops = stops
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
|
||||
for stop in self.stops:
|
||||
if torch.all((stop == input_ids[0][-len(stop) :])).item():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
CONV_VISION = Conversation(
|
||||
system="Give the following image: <Img>ImageContent</Img>. "
|
||||
"You will be able to see the image once I provide it to you. Please answer my questions.",
|
||||
roles=("Human", "Assistant"),
|
||||
messages=[],
|
||||
offset=2,
|
||||
sep_style=SeparatorStyle.SINGLE,
|
||||
sep="###",
|
||||
)
|
||||
1439
apps/language_models/src/pipelines/minigpt4_pipeline.py
Normal file
1439
apps/language_models/src/pipelines/minigpt4_pipeline.py
Normal file
File diff suppressed because it is too large
Load Diff
1297
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
Normal file
1297
apps/language_models/src/pipelines/minigpt4_utils/Qformer.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Copyright (c) 2022, salesforce.com, inc.
|
||||
All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
For full license text, see the LICENSE_Lavis file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
|
||||
|
||||
class BaseProcessor:
|
||||
def __init__(self):
|
||||
self.transform = lambda x: x
|
||||
return
|
||||
|
||||
def __call__(self, item):
|
||||
return self.transform(item)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
return cls()
|
||||
|
||||
def build(self, **kwargs):
|
||||
cfg = OmegaConf.create(kwargs)
|
||||
|
||||
return self.from_config(cfg)
|
||||
|
||||
|
||||
class BlipImageBaseProcessor(BaseProcessor):
|
||||
def __init__(self, mean=None, std=None):
|
||||
if mean is None:
|
||||
mean = (0.48145466, 0.4578275, 0.40821073)
|
||||
if std is None:
|
||||
std = (0.26862954, 0.26130258, 0.27577711)
|
||||
|
||||
self.normalize = transforms.Normalize(mean, std)
|
||||
|
||||
|
||||
class Blip2ImageEvalProcessor(BlipImageBaseProcessor):
|
||||
def __init__(self, image_size=224, mean=None, std=None):
|
||||
super().__init__(mean=mean, std=std)
|
||||
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(
|
||||
(image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
self.normalize,
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(self, item):
|
||||
return self.transform(item)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg=None):
|
||||
if cfg is None:
|
||||
cfg = OmegaConf.create()
|
||||
|
||||
image_size = cfg.get("image_size", 224)
|
||||
|
||||
mean = cfg.get("mean", None)
|
||||
std = cfg.get("std", None)
|
||||
|
||||
return cls(image_size=image_size, mean=mean, std=std)
|
||||
@@ -0,0 +1,5 @@
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
data_type: images
|
||||
build_info:
|
||||
storage: /path/to/cc_sbu_align/
|
||||
@@ -0,0 +1,33 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
|
||||
# vit encoder
|
||||
image_size: 224
|
||||
drop_path_rate: 0
|
||||
use_grad_checkpoint: False
|
||||
vit_precision: "fp16"
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
|
||||
# Q-Former
|
||||
num_query_token: 32
|
||||
|
||||
# Vicuna
|
||||
llama_model: "lmsys/vicuna-7b-v1.3"
|
||||
|
||||
# generation configs
|
||||
prompt: ""
|
||||
|
||||
preprocess:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_train"
|
||||
image_size: 224
|
||||
eval:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
eval:
|
||||
name: "blip_caption"
|
||||
@@ -0,0 +1,25 @@
|
||||
model:
|
||||
arch: mini_gpt4
|
||||
model_type: pretrain_vicuna
|
||||
freeze_vit: True
|
||||
freeze_qformer: True
|
||||
max_txt_len: 160
|
||||
end_sym: "###"
|
||||
low_resource: False
|
||||
prompt_path: "apps/language_models/src/pipelines/minigpt4_utils/prompts/alignment.txt"
|
||||
prompt_template: '###Human: {} ###Assistant: '
|
||||
ckpt: 'prerained_minigpt4_7b.pth'
|
||||
|
||||
|
||||
datasets:
|
||||
cc_sbu_align:
|
||||
vis_processor:
|
||||
train:
|
||||
name: "blip2_image_eval"
|
||||
image_size: 224
|
||||
text_processor:
|
||||
train:
|
||||
name: "blip_caption"
|
||||
|
||||
run:
|
||||
task: image_text_pretrain
|
||||
629
apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
Normal file
629
apps/language_models/src/pipelines/minigpt4_utils/eva_vit.py
Normal file
@@ -0,0 +1,629 @@
|
||||
# Based on EVA, BEIT, timm and DeiT code bases
|
||||
# https://github.com/baaivision/EVA
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
||||
# https://github.com/microsoft/unilm/tree/master/beit
|
||||
# https://github.com/facebookresearch/deit/
|
||||
# https://github.com/facebookresearch/dino
|
||||
# --------------------------------------------------------'
|
||||
import math
|
||||
import requests
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import drop_path, to_2tuple, trunc_normal_
|
||||
|
||||
|
||||
def _cfg(url="", **kwargs):
|
||||
return {
|
||||
"url": url,
|
||||
"num_classes": 1000,
|
||||
"input_size": (3, 224, 224),
|
||||
"pool_size": None,
|
||||
"crop_pct": 0.9,
|
||||
"interpolation": "bicubic",
|
||||
"mean": (0.5, 0.5, 0.5),
|
||||
"std": (0.5, 0.5, 0.5),
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return "p={}".format(self.drop_prob)
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_features,
|
||||
hidden_features=None,
|
||||
out_features=None,
|
||||
act_layer=nn.GELU,
|
||||
drop=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
# x = self.drop(x)
|
||||
# commit this for the orignal BERT implement
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads=8,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
attn_drop=0.0,
|
||||
proj_drop=0.0,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
if attn_head_dim is not None:
|
||||
head_dim = attn_head_dim
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = qk_scale or head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(
|
||||
torch.meshgrid([coords_h, coords_w])
|
||||
) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += (
|
||||
window_size[0] - 1
|
||||
) # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1
|
||||
) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index", relative_position_index
|
||||
)
|
||||
else:
|
||||
self.window_size = None
|
||||
self.relative_position_bias_table = None
|
||||
self.relative_position_index = None
|
||||
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(all_head_dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
qkv_bias = torch.cat(
|
||||
(
|
||||
self.q_bias,
|
||||
torch.zeros_like(self.v_bias, requires_grad=False),
|
||||
self.v_bias,
|
||||
)
|
||||
)
|
||||
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = (
|
||||
qkv[0],
|
||||
qkv[1],
|
||||
qkv[2],
|
||||
) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
if self.relative_position_bias_table is not None:
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
relative_position_bias = relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
attn = attn + relative_position_bias.unsqueeze(0)
|
||||
|
||||
if rel_pos_bias is not None:
|
||||
attn = attn + rel_pos_bias
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
num_heads,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop=0.0,
|
||||
attn_drop=0.0,
|
||||
drop_path=0.0,
|
||||
init_values=None,
|
||||
act_layer=nn.GELU,
|
||||
norm_layer=nn.LayerNorm,
|
||||
window_size=None,
|
||||
attn_head_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
window_size=window_size,
|
||||
attn_head_dim=attn_head_dim,
|
||||
)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = (
|
||||
DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
)
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
)
|
||||
|
||||
if init_values is not None and init_values > 0:
|
||||
self.gamma_1 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
self.gamma_2 = nn.Parameter(
|
||||
init_values * torch.ones((dim)), requires_grad=True
|
||||
)
|
||||
else:
|
||||
self.gamma_1, self.gamma_2 = None, None
|
||||
|
||||
def forward(self, x, rel_pos_bias=None):
|
||||
if self.gamma_1 is None:
|
||||
x = x + self.drop_path(
|
||||
self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
|
||||
)
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
else:
|
||||
x = x + self.drop_path(
|
||||
self.gamma_1
|
||||
* self.attn(self.norm1(x), rel_pos_bias=rel_pos_bias)
|
||||
)
|
||||
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
num_patches = (img_size[1] // patch_size[1]) * (
|
||||
img_size[0] // patch_size[0]
|
||||
)
|
||||
self.patch_shape = (
|
||||
img_size[0] // patch_size[0],
|
||||
img_size[1] // patch_size[1],
|
||||
)
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
|
||||
)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
B, C, H, W = x.shape
|
||||
# FIXME look at relaxing size constraints
|
||||
assert (
|
||||
H == self.img_size[0] and W == self.img_size[1]
|
||||
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
class RelativePositionBias(nn.Module):
|
||||
def __init__(self, window_size, num_heads):
|
||||
super().__init__()
|
||||
self.window_size = window_size
|
||||
self.num_relative_distance = (2 * window_size[0] - 1) * (
|
||||
2 * window_size[1] - 1
|
||||
) + 3
|
||||
self.relative_position_bias_table = nn.Parameter(
|
||||
torch.zeros(self.num_relative_distance, num_heads)
|
||||
) # 2*Wh-1 * 2*Ww-1, nH
|
||||
# cls to token & token 2 cls & cls to cls
|
||||
|
||||
# get pair-wise relative position index for each token inside the window
|
||||
coords_h = torch.arange(window_size[0])
|
||||
coords_w = torch.arange(window_size[1])
|
||||
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
||||
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
||||
relative_coords = (
|
||||
coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
||||
) # 2, Wh*Ww, Wh*Ww
|
||||
relative_coords = relative_coords.permute(
|
||||
1, 2, 0
|
||||
).contiguous() # Wh*Ww, Wh*Ww, 2
|
||||
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
||||
relative_coords[:, :, 1] += window_size[1] - 1
|
||||
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
||||
relative_position_index = torch.zeros(
|
||||
size=(window_size[0] * window_size[1] + 1,) * 2,
|
||||
dtype=relative_coords.dtype,
|
||||
)
|
||||
relative_position_index[1:, 1:] = relative_coords.sum(
|
||||
-1
|
||||
) # Wh*Ww, Wh*Ww
|
||||
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
||||
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
||||
relative_position_index[0, 0] = self.num_relative_distance - 1
|
||||
|
||||
self.register_buffer(
|
||||
"relative_position_index", relative_position_index
|
||||
)
|
||||
|
||||
# trunc_normal_(self.relative_position_bias_table, std=.02)
|
||||
|
||||
def forward(self):
|
||||
relative_position_bias = self.relative_position_bias_table[
|
||||
self.relative_position_index.view(-1)
|
||||
].view(
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
self.window_size[0] * self.window_size[1] + 1,
|
||||
-1,
|
||||
) # Wh*Ww,Wh*Ww,nH
|
||||
return relative_position_bias.permute(
|
||||
2, 0, 1
|
||||
).contiguous() # nH, Wh*Ww, Wh*Ww
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
"""Vision Transformer with support for patch or hybrid CNN input stage"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
num_classes=1000,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
drop_rate=0.0,
|
||||
attn_drop_rate=0.0,
|
||||
drop_path_rate=0.0,
|
||||
norm_layer=nn.LayerNorm,
|
||||
init_values=None,
|
||||
use_abs_pos_emb=True,
|
||||
use_rel_pos_bias=False,
|
||||
use_shared_rel_pos_bias=False,
|
||||
use_mean_pooling=True,
|
||||
init_scale=0.001,
|
||||
use_checkpoint=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_size = img_size
|
||||
self.num_classes = num_classes
|
||||
self.num_features = (
|
||||
self.embed_dim
|
||||
) = embed_dim # num_features for consistency with other models
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim,
|
||||
)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
if use_abs_pos_emb:
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + 1, embed_dim)
|
||||
)
|
||||
else:
|
||||
self.pos_embed = None
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
if use_shared_rel_pos_bias:
|
||||
self.rel_pos_bias = RelativePositionBias(
|
||||
window_size=self.patch_embed.patch_shape, num_heads=num_heads
|
||||
)
|
||||
else:
|
||||
self.rel_pos_bias = None
|
||||
self.use_checkpoint = use_checkpoint
|
||||
|
||||
dpr = [
|
||||
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
self.use_rel_pos_bias = use_rel_pos_bias
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
Block(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale,
|
||||
drop=drop_rate,
|
||||
attn_drop=attn_drop_rate,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
init_values=init_values,
|
||||
window_size=self.patch_embed.patch_shape
|
||||
if use_rel_pos_bias
|
||||
else None,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
# self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
|
||||
# self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
|
||||
# self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
if self.pos_embed is not None:
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
trunc_normal_(self.cls_token, std=0.02)
|
||||
# trunc_normal_(self.mask_token, std=.02)
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# trunc_normal_(self.head.weight, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
self.fix_init_weight()
|
||||
|
||||
# if isinstance(self.head, nn.Linear):
|
||||
# self.head.weight.data.mul_(init_scale)
|
||||
# self.head.bias.data.mul_(init_scale)
|
||||
|
||||
def fix_init_weight(self):
    """Rescale residual-branch projection weights by depth.

    Divides each block's attention output projection and MLP second
    linear weights by sqrt(2 * layer_index) so the residual-stream
    variance stays roughly constant as depth grows (BEiT/EVA init).
    """
    for depth_index, block in enumerate(self.blocks, start=1):
        scale = math.sqrt(2.0 * depth_index)
        block.attn.proj.weight.data.div_(scale)
        block.mlp.fc2.weight.data.div_(scale)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
nn.init.constant_(m.bias, 0)
|
||||
nn.init.constant_(m.weight, 1.0)
|
||||
|
||||
def get_classifier(self):
    """Return the model's classification head module."""
    return self.head
|
||||
|
||||
def reset_classifier(self, num_classes, global_pool=""):
    """Replace the classification head for `num_classes` outputs.

    A non-positive `num_classes` installs an identity head.
    `global_pool` is accepted for timm API compatibility and unused.
    """
    self.num_classes = num_classes
    if num_classes > 0:
        self.head = nn.Linear(self.embed_dim, num_classes)
    else:
        self.head = nn.Identity()
|
||||
|
||||
def forward_features(self, x):
    """Encode an image batch into token features.

    Args:
        x: input batch passed to `self.patch_embed` (presumably NCHW
           images — TODO confirm against the PatchEmbed definition).

    Returns:
        Token sequence of shape (batch, 1 + num_patches, embed_dim):
        a prepended CLS token followed by patch tokens, after all
        transformer blocks. No final norm or head is applied here.
    """
    x = self.patch_embed(x)
    batch_size, seq_len, _ = x.size()

    # Prepend a learnable CLS token to every sequence in the batch.
    cls_tokens = self.cls_token.expand(
        batch_size, -1, -1
    )  # stole cls_tokens impl from Phil Wang, thanks
    x = torch.cat((cls_tokens, x), dim=1)
    # Absolute position embedding is optional (use_abs_pos_emb).
    if self.pos_embed is not None:
        x = x + self.pos_embed
    x = self.pos_drop(x)

    # A shared relative position bias (if enabled) is computed once and
    # handed to every block.
    rel_pos_bias = (
        self.rel_pos_bias() if self.rel_pos_bias is not None else None
    )
    for blk in self.blocks:
        if self.use_checkpoint:
            # Activation checkpointing: trade recompute for memory.
            x = checkpoint.checkpoint(blk, x, rel_pos_bias)
        else:
            x = blk(x, rel_pos_bias)
    return x
|
||||
|
||||
# x = self.norm(x)
|
||||
|
||||
# if self.fc_norm is not None:
|
||||
# t = x[:, 1:, :]
|
||||
# return self.fc_norm(t.mean(1))
|
||||
# else:
|
||||
# return x[:, 0]
|
||||
|
||||
def forward(self, x):
    """Run the backbone and return raw token features (no head applied)."""
    features = self.forward_features(x)
    # The classification head is intentionally bypassed; callers consume
    # the token features directly.
    return features
|
||||
|
||||
def get_intermediate_layers(self, x):
    """Return token features produced after every transformer block.

    Same embedding pipeline as `forward_features`, but collects each
    block's output (a list of length `depth`) instead of only the last.
    Activation checkpointing is not used on this path.
    """
    x = self.patch_embed(x)
    batch_size, seq_len, _ = x.size()

    cls_tokens = self.cls_token.expand(
        batch_size, -1, -1
    )  # stole cls_tokens impl from Phil Wang, thanks
    x = torch.cat((cls_tokens, x), dim=1)
    if self.pos_embed is not None:
        x = x + self.pos_embed
    x = self.pos_drop(x)

    features = []
    rel_pos_bias = (
        self.rel_pos_bias() if self.rel_pos_bias is not None else None
    )
    for blk in self.blocks:
        x = blk(x, rel_pos_bias)
        features.append(x)

    return features
|
||||
|
||||
|
||||
def interpolate_pos_embed(model, checkpoint_model):
    """Resize a checkpoint's absolute position embedding to fit `model`.

    If the checkpoint was trained at a different (square) patch-grid
    resolution, its per-patch position tokens are bicubically
    interpolated to the model's grid; extra leading tokens (e.g. CLS)
    are kept unchanged. `checkpoint_model` is modified in place; it is
    left untouched when it has no "pos_embed" or sizes already match.
    """
    if "pos_embed" not in checkpoint_model:
        return

    ckpt_pos = checkpoint_model["pos_embed"].float()
    dim = ckpt_pos.shape[-1]
    num_patches = model.patch_embed.num_patches
    num_extra = model.pos_embed.shape[-2] - num_patches
    # Side length of the (square) checkpoint patch grid.
    src_side = int((ckpt_pos.shape[-2] - num_extra) ** 0.5)
    # Side length of the target patch grid.
    dst_side = int(num_patches**0.5)
    if src_side == dst_side:
        return

    print(
        "Position interpolate from %dx%d to %dx%d"
        % (src_side, src_side, dst_side, dst_side)
    )
    extra_tokens = ckpt_pos[:, :num_extra]
    # Only the position tokens are interpolated.
    grid = ckpt_pos[:, num_extra:]
    grid = grid.reshape(-1, src_side, src_side, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid,
        size=(dst_side, dst_side),
        mode="bicubic",
        align_corners=False,
    )
    grid = grid.permute(0, 2, 3, 1).flatten(1, 2)
    checkpoint_model["pos_embed"] = torch.cat((extra_tokens, grid), dim=1)
|
||||
|
||||
|
||||
def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16"""

    def _cast_layer(layer):
        # NOTE(review): the actual .half() casts are commented out in the
        # original source, so this visitor is currently a no-op that keeps
        # every weight at its existing dtype.
        if not isinstance(layer, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            return
        # layer.weight.data = layer.weight.data.half()
        layer.weight.data = layer.weight.data
        if layer.bias is not None:
            # layer.bias.data = layer.bias.data.half()
            layer.bias.data = layer.bias.data

    model.apply(_cast_layer)
|
||||
|
||||
|
||||
def create_eva_vit_g(
    img_size=224, drop_path_rate=0.4, use_checkpoint=False, precision="fp16"
):
    """Build the EVA ViT-g/14 vision tower and load BLIP-2 weights.

    Downloads the `eva_vit_g.pth` checkpoint from the LAVIS release
    bucket into the current directory, interpolates its position
    embedding to `img_size`, and loads it non-strictly into a freshly
    constructed VisionTransformer.

    Args:
        img_size: input image resolution (square).
        drop_path_rate: stochastic depth rate for the blocks.
        use_checkpoint: enable activation checkpointing in the blocks.
        precision: "fp16" triggers convert_weights_to_fp16 (which is
            currently a no-op — its half() casts are commented out).

    Returns:
        The constructed, weight-loaded VisionTransformer.
    """
    model = VisionTransformer(
        img_size=img_size,
        patch_size=14,
        use_mean_pooling=False,
        embed_dim=1408,
        depth=39,
        num_heads=1408 // 88,
        mlp_ratio=4.3637,
        qkv_bias=True,
        drop_path_rate=drop_path_rate,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        use_checkpoint=use_checkpoint,
    )
    url = "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/eva_vit_g.pth"

    local_filename = "eva_vit_g.pth"
    # NOTE(review): the checkpoint is re-downloaded on every call, and a
    # non-200 response silently falls through to torch.load of a possibly
    # stale or missing local file — consider caching and raising instead.
    response = requests.get(url)
    if response.status_code == 200:
        with open(local_filename, "wb") as f:
            f.write(response.content)
        print("File downloaded successfully.")
    state_dict = torch.load(local_filename, map_location="cpu")
    interpolate_pos_embed(model, state_dict)

    # strict=False: checkpoint entries absent from this backbone (e.g.
    # head/norm weights) are tolerated.
    incompatible_keys = model.load_state_dict(state_dict, strict=False)

    if precision == "fp16":
        # model.to("cuda")
        convert_weights_to_fp16(model)
    return model
|
||||
@@ -0,0 +1,4 @@
|
||||
<Img><ImageHere></Img> Describe this image in detail.
|
||||
<Img><ImageHere></Img> Take a look at this image and describe what you notice.
|
||||
<Img><ImageHere></Img> Please provide a detailed description of the picture.
|
||||
<Img><ImageHere></Img> Could you describe the contents of this image for me?
|
||||
@@ -3,6 +3,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
|
||||
from torch._decomp import get_decompositions
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from shark.shark_downloader import download_public_file
|
||||
|
||||
|
||||
# expects a Path / str as arg
|
||||
@@ -17,9 +18,23 @@ def get_vmfb_from_path(vmfb_path, device, mlir_dialect):
|
||||
return None
|
||||
|
||||
print("Loading vmfb from: ", vmfb_path)
|
||||
print("Device from get_vmfb_from_path - ", device)
|
||||
shark_module = SharkInference(
|
||||
None, device=device, mlir_dialect=mlir_dialect
|
||||
)
|
||||
shark_module.load_module(vmfb_path)
|
||||
print("Successfully loaded vmfb")
|
||||
return shark_module
|
||||
|
||||
|
||||
def get_vmfb_from_config(
    shark_container, model, precision, device, vmfb_path, padding=None
):
    """Download a precompiled vmfb from the shark_tank bucket and load it.

    The remote object name is derived from the container, model name,
    precision, device, and optional padding; the file is fetched to
    `vmfb_path` (a pathlib.Path) and loaded as a tm_tensor module.
    """
    name_parts = [
        f"gs://shark_tank/{shark_container}/{model}_{precision}_{device}"
    ]
    if padding:
        name_parts.append(f"_{padding}")
    name_parts.append(".vmfb")
    vmfb_url = "".join(name_parts)
    download_public_file(vmfb_url, vmfb_path.absolute(), single_file=True)
    return get_vmfb_from_path(vmfb_path, device, "tm_tensor")
|
||||
|
||||
@@ -7,7 +7,11 @@ import sys
|
||||
sys.setrecursionlimit(sys.getrecursionlimit() * 5)
|
||||
|
||||
# python path for pyinstaller
|
||||
pathex = [".", "./apps/language_models/langchain"]
|
||||
pathex = [
|
||||
".",
|
||||
"./apps/language_models/langchain",
|
||||
"./apps/language_models/src/pipelines/minigpt4_utils",
|
||||
]
|
||||
|
||||
# datafiles for pyinstaller
|
||||
datas = []
|
||||
@@ -39,6 +43,7 @@ datas += collect_data_files("gradio_client")
|
||||
datas += collect_data_files("iree")
|
||||
datas += collect_data_files("google_cloud_storage")
|
||||
datas += collect_data_files("shark")
|
||||
datas += collect_data_files("timm", include_py_files=True)
|
||||
datas += collect_data_files("tkinter")
|
||||
datas += collect_data_files("webview")
|
||||
datas += collect_data_files("sentencepiece")
|
||||
@@ -52,6 +57,14 @@ datas += [
|
||||
("src/utils/resources/base_model.json", "resources"),
|
||||
("web/ui/css/*", "ui/css"),
|
||||
("web/ui/logos/*", "logos"),
|
||||
(
|
||||
"../language_models/src/pipelines/minigpt4_utils/configs/*",
|
||||
"minigpt4_utils/configs",
|
||||
),
|
||||
(
|
||||
"../language_models/src/pipelines/minigpt4_utils/prompts/*",
|
||||
"minigpt4_utils/prompts",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@ from apps.stable_diffusion.src.utils.stable_args import args
|
||||
|
||||
# Helper function to profile the vulkan device.
|
||||
def start_profiling(file_path="foo.rdc", profiling_mode="queue"):
|
||||
if args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
from shark.parser import shark_args
|
||||
|
||||
if shark_args.vulkan_debug_utils and "vulkan" in args.device:
|
||||
import iree
|
||||
|
||||
print(f"Profiling and saving to {file_path}.")
|
||||
|
||||
@@ -425,27 +425,6 @@ p.add_argument(
|
||||
help="Specify target triple for metal.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="2073741824",
|
||||
help="Flag for setting VMA preferredLargeHeapBlockSize for "
|
||||
"vulkan device, default is 4G.",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for disabling vulkan validation layers when benchmarking.",
|
||||
)
|
||||
|
||||
##############################################################################
|
||||
# Misc. Debug and Optimization flags
|
||||
##############################################################################
|
||||
|
||||
@@ -22,6 +22,7 @@ from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
from shark.iree_utils.metal_utils import get_metal_target_triple
|
||||
from shark.iree_utils.gpu_utils import get_cuda_sm_cc
|
||||
@@ -183,10 +184,7 @@ def compile_through_fx(
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
@@ -461,7 +459,12 @@ def get_available_devices():
|
||||
device_name = (
|
||||
cpu_name if device["name"] == "default" else device["name"]
|
||||
)
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
if "local" in driver_name:
|
||||
device_list.append(
|
||||
f"{device_name} => {driver_name.replace('local', 'cpu')}"
|
||||
)
|
||||
else:
|
||||
device_list.append(f"{device_name} => {driver_name}://{i}")
|
||||
return device_list
|
||||
|
||||
set_iree_runtime_flags()
|
||||
|
||||
@@ -161,6 +161,7 @@ if __name__ == "__main__":
|
||||
modelmanager_sendto_outpaint,
|
||||
modelmanager_sendto_upscaler,
|
||||
stablelm_chat,
|
||||
minigpt4_web,
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
outputgallery_watch,
|
||||
@@ -220,14 +221,8 @@ if __name__ == "__main__":
|
||||
outpaint_web.render()
|
||||
with gr.TabItem(label="Upscaler", id=4):
|
||||
upscaler_web.render()
|
||||
with gr.TabItem(label="Model Manager", id=5):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="Chat Bot(Experimental)", id=6):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="LoRA Training(Experimental)", id=7):
|
||||
lora_train_web.render()
|
||||
if args.output_gallery:
|
||||
with gr.TabItem(label="Output Gallery", id=8) as og_tab:
|
||||
with gr.TabItem(label="Output Gallery", id=5) as og_tab:
|
||||
outputgallery_web.render()
|
||||
|
||||
# extra output gallery configuration
|
||||
@@ -241,7 +236,15 @@ if __name__ == "__main__":
|
||||
upscaler_status,
|
||||
]
|
||||
)
|
||||
with gr.TabItem(label="DocuChat(Experimental)", id=9):
|
||||
with gr.TabItem(label="Model Manager", id=6):
|
||||
model_web.render()
|
||||
with gr.TabItem(label="LoRA Training (Experimental)", id=8):
|
||||
lora_train_web.render()
|
||||
with gr.TabItem(label="Chat Bot (Experimental)", id=7):
|
||||
stablelm_chat.render()
|
||||
with gr.TabItem(label="MultiModal (Experimental)", id=9):
|
||||
minigpt4_web.render()
|
||||
with gr.TabItem(label="DocuChat(Experimental)", id=10):
|
||||
h2ogpt_web.render()
|
||||
|
||||
# send to buttons
|
||||
|
||||
@@ -79,6 +79,7 @@ from apps.stable_diffusion.web.ui.stablelm_ui import (
|
||||
llm_chat_api,
|
||||
)
|
||||
from apps.stable_diffusion.web.ui.h2ogpt import h2ogpt_web
|
||||
from apps.stable_diffusion.web.ui.minigpt4_ui import minigpt4_web
|
||||
from apps.stable_diffusion.web.ui.outputgallery_ui import (
|
||||
outputgallery_web,
|
||||
outputgallery_tab_select,
|
||||
|
||||
193
apps/stable_diffusion/web/ui/minigpt4_ui.py
Normal file
193
apps/stable_diffusion/web/ui/minigpt4_ui.py
Normal file
@@ -0,0 +1,193 @@
|
||||
# ========================================
|
||||
# Gradio Setting
|
||||
# ========================================
|
||||
import gradio as gr
|
||||
|
||||
# from apps.language_models.src.pipelines.minigpt4_pipeline import (
|
||||
# # MiniGPT4,
|
||||
# CONV_VISION,
|
||||
# )
|
||||
from pathlib import Path
|
||||
|
||||
chat = None
|
||||
|
||||
|
||||
def gradio_reset(chat_state, img_list):
    """Reset the chat UI: clear history and images, re-enable uploads."""
    if chat_state is not None:
        chat_state.messages = []
    if img_list is not None:
        img_list = []
    chatbot_update = None
    image_update = gr.update(value=None, interactive=True)
    textbox_update = gr.update(
        placeholder="Please upload your image first", interactive=False
    )
    button_update = gr.update(value="Upload & Start Chat", interactive=True)
    return (
        chatbot_update,
        image_update,
        textbox_update,
        button_update,
        chat_state,
        img_list,
    )
|
||||
|
||||
|
||||
def upload_img(gr_img, text_input, chat_state, device, precision, _compile):
    """Handle the 'Upload & Start Chat' click.

    Lazily builds the MiniGPT4 pipeline on first use, registers the
    uploaded image with it, and unlocks the text box.

    Returns gradio updates for:
    (image, text_input, upload_button, chat_state, img_list).
    """
    global chat
    # Lazy one-time construction of the expensive MiniGPT4 pipeline; the
    # import is deferred so this UI module loads without it.
    if chat is None:
        from apps.language_models.src.pipelines.minigpt4_pipeline import (
            MiniGPT4,
            CONV_VISION,
        )

        # int4/int8 quantization applies to the language model only; the
        # vision tower still runs in fp16.
        vision_model_precision = precision
        if precision in ["int4", "int8"]:
            vision_model_precision = "fp16"
        vision_model_vmfb_path = Path(
            f"vision_model_{vision_model_precision}_{device}.vmfb"
        )
        qformer_vmfb_path = Path(f"qformer_fp32_{device}.vmfb")
        chat = MiniGPT4(
            model_name="MiniGPT4",
            hf_model_path=None,
            max_new_tokens=30,
            device=device,
            precision=precision,
            _compile=_compile,
            vision_model_vmfb_path=vision_model_vmfb_path,
            qformer_vmfb_path=qformer_vmfb_path,
        )
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
    # NOTE(review): CONV_VISION is only imported inside the `chat is None`
    # branch above, so on any call after the pipeline exists this line
    # would raise NameError — confirm whether the import should be hoisted.
    chat_state = CONV_VISION.copy()
    img_list = []
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    return (
        gr.update(interactive=False),
        gr.update(interactive=True, placeholder="Type and press Enter"),
        gr.update(value="Start Chatting", interactive=False),
        chat_state,
        img_list,
    )
|
||||
|
||||
|
||||
def gradio_ask(user_message, chatbot, chat_state):
    """Queue a user message: validate, record in state, echo to the log."""
    if len(user_message) == 0:
        warning = gr.update(
            interactive=True, placeholder="Input should not be empty!"
        )
        return warning, chatbot, chat_state
    chat.ask(user_message, chat_state)
    chatbot = chatbot + [[user_message, None]]
    return "", chatbot, chat_state
|
||||
|
||||
|
||||
def gradio_answer(chatbot, chat_state, img_list, num_beams, temperature):
    """Generate the assistant reply for the last queued user message.

    Delegates to the global `chat` pipeline and writes the reply into
    the last chatbot row. Returns (chatbot, chat_state, img_list).
    """
    llm_message = chat.answer(
        conv=chat_state,
        img_list=img_list,
        num_beams=num_beams,
        temperature=temperature,
        max_new_tokens=300,
        max_length=2000,
    )[0]
    # Debug trace of the raw model reply.
    print(llm_message)
    print("************")
    chatbot[-1][1] = llm_message
    return chatbot, chat_state, img_list
|
||||
|
||||
|
||||
# Static header markup for the MultiModal (MiniGPT-4) tab.
title = """<h1 align="center">MultiModal SHARK (experimental)</h1>"""
description = """<h3>Upload your images and start chatting!</h3>"""
article = """<p><a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a></p><p><a href='https://github.com/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/Github-Code-blue'></a></p><p><a href='https://raw.githubusercontent.com/Vision-CAIR/MiniGPT-4/main/MiniGPT_4.pdf'><img src='https://img.shields.io/badge/Paper-PDF-red'></a></p>
"""

# TODO show examples below

# UI layout: left column holds the image upload and generation knobs,
# right column holds the chat itself. Event handlers are wired at the
# bottom of the block.
with gr.Blocks() as minigpt4_web:
    gr.Markdown(title)
    gr.Markdown(description)

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil")
            upload_button = gr.Button(
                value="Upload & Start Chat",
                interactive=True,
                variant="primary",
            )
            clear = gr.Button("Restart")

            num_beams = gr.Slider(
                minimum=1,
                maximum=10,
                value=1,
                step=1,
                interactive=True,
                label="beam search numbers)",
            )

            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=1.0,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

            # Only CUDA is wired up for MiniGPT-4 at the moment, so the
            # dropdown is fixed and non-interactive.
            device = gr.Dropdown(
                label="Device",
                value="cuda",
                # if enabled
                # else "Only CUDA Supported for now",
                choices=["cuda"],
                interactive=False,
            )

        with gr.Column():
            # Per-session conversation and image state.
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label="MiniGPT-4")
            text_input = gr.Textbox(
                label="User",
                placeholder="Please upload your image first",
                interactive=False,
            )
            precision = gr.Radio(
                label="Precision",
                value="int8",
                choices=[
                    "int8",
                    "fp16",
                    "fp32",
                ],
                visible=True,
            )
            _compile = gr.Checkbox(
                value=False,
                label="Compile",
                interactive=True,
            )

    # Upload builds the pipeline (lazily) and registers the image.
    upload_button.click(
        upload_img,
        [image, text_input, chat_state, device, precision, _compile],
        [image, text_input, upload_button, chat_state, img_list],
    )

    # Submitting text queues the question, then generates the answer.
    text_input.submit(
        gradio_ask,
        [text_input, chatbot, chat_state],
        [text_input, chatbot, chat_state],
    ).then(
        gradio_answer,
        [chatbot, chat_state, img_list, num_beams, temperature],
        [chatbot, chat_state, img_list],
    )
    # Restart clears all conversation/image state.
    clear.click(
        gradio_reset,
        [chat_state, img_list],
        [chatbot, image, text_input, upload_button, chat_state, img_list],
        queue=False,
    )
|
||||
@@ -56,3 +56,14 @@ for line in fileinput.input(path_to_lazy_loader, inplace=True):
|
||||
)
|
||||
else:
|
||||
print(line, end="")
|
||||
|
||||
# For getting around timm's packaging.
# Refer: https://github.com/pyinstaller/pyinstaller/issues/5673#issuecomment-808731505
path_to_timm_activations = Path(
    get_python_lib() + "/timm/layers/activations_jit.py"
)
# Rewrite @torch.jit.script decorators in the installed timm package so a
# PyInstaller bundle does not try to JIT-compile these helpers at import
# time (the file is edited in place).
for line in fileinput.input(path_to_timm_activations, inplace=True):
    if "@torch.jit.script" in line:
        print("@torch.jit._script_if_tracing", end="\n")
    else:
        print(line, end="")
|
||||
|
||||
@@ -15,3 +15,4 @@ build-backend = "setuptools.build_meta"
|
||||
line-length = 79
|
||||
include = '\.pyi?$'
|
||||
exclude = "apps/language_models/scripts/vicuna.py"
|
||||
extend-exclude = "apps/language_models/src/pipelines/minigpt4_pipeline.py"
|
||||
|
||||
@@ -34,6 +34,7 @@ sentencepiece
|
||||
py-cpuinfo
|
||||
tiktoken # for codegen
|
||||
joblib # for langchain
|
||||
timm # for MiniGPT4
|
||||
|
||||
# Keep PyInstaller at the end. Sometimes Windows Defender flags it but most folks can continue even if it errors
|
||||
pefile
|
||||
|
||||
@@ -94,18 +94,5 @@ p.add_argument(
|
||||
help="Profiles vulkan device and collects the .rdc info",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="4147483648",
|
||||
help="flag for setting VMA preferredLargeHeapBlockSize for vulkan device, default is 4G",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="flag for disabling vulkan validation layers when benchmarking",
|
||||
)
|
||||
|
||||
|
||||
args = p.parse_args()
|
||||
|
||||
@@ -6,6 +6,7 @@ from shark.shark_importer import import_with_fx
|
||||
from shark.iree_utils.vulkan_utils import (
|
||||
set_iree_vulkan_runtime_flags,
|
||||
get_vulkan_target_triple,
|
||||
get_iree_vulkan_runtime_flags,
|
||||
)
|
||||
|
||||
|
||||
@@ -75,10 +76,7 @@ def compile_through_fx(
|
||||
|
||||
|
||||
def set_iree_runtime_flags():
|
||||
vulkan_runtime_flags = [
|
||||
f"--vulkan_large_heap_block_size={args.vulkan_large_heap_block_size}",
|
||||
f"--vulkan_validation_layers={'true' if args.vulkan_validation_layers else 'false'}",
|
||||
]
|
||||
vulkan_runtime_flags = get_iree_vulkan_runtime_flags()
|
||||
if args.enable_rgp:
|
||||
vulkan_runtime_flags += [
|
||||
f"--enable_rgp=true",
|
||||
|
||||
@@ -19,6 +19,7 @@ from shark.iree_utils._common import run_cmd
|
||||
import iree.runtime as ireert
|
||||
from sys import platform
|
||||
from shark.iree_utils.vulkan_target_env_utils import get_vulkan_target_env_flag
|
||||
from shark.parser import shark_args
|
||||
|
||||
|
||||
def get_vulkan_device_name(device_num=0):
|
||||
@@ -171,6 +172,15 @@ def get_iree_vulkan_args(device_num=0, extra_args=[]):
|
||||
return res_vulkan_flag
|
||||
|
||||
|
||||
def get_iree_vulkan_runtime_flags():
    """Build the IREE Vulkan runtime flag list from global `shark_args`."""
    heap_size = shark_args.vulkan_large_heap_block_size
    validation = "true" if shark_args.vulkan_validation_layers else "false"
    vma = "true" if shark_args.vulkan_vma_allocator else "false"
    return [
        f"--vulkan_large_heap_block_size={heap_size}",
        f"--vulkan_validation_layers={validation}",
        f"--vulkan_vma_allocator={vma}",
    ]
|
||||
|
||||
|
||||
def set_iree_vulkan_runtime_flags(flags):
    """Register each runtime flag string with the IREE flag parser."""
    for runtime_flag in flags:
        ireert.flags.parse_flags(runtime_flag)
|
||||
|
||||
@@ -126,4 +126,32 @@ parser.add_argument(
|
||||
help="passthrough flag for the iree flag of the same name. If None, defaults to cpu-count",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_debug_utils",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Profiles vulkan device and collects the .rdc info.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_large_heap_block_size",
|
||||
default="2073741824",
|
||||
help="Flag for setting VMA preferredLargeHeapBlockSize for "
|
||||
"vulkan device, default is 4G.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_validation_layers",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for disabling vulkan validation layers when benchmarking.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vulkan_vma_allocator",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Flag for enabling / disabling Vulkan VMA Allocator.",
|
||||
)
|
||||
|
||||
shark_args, unknown = parser.parse_known_args()
|
||||
|
||||
@@ -2,6 +2,55 @@ import os
|
||||
import tempfile
|
||||
from shark.shark_inference import SharkInference
|
||||
from shark.shark_importer import import_with_fx
|
||||
import torch
|
||||
import torch_mlir
|
||||
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
|
||||
from typing import List, Tuple
|
||||
from io import BytesIO
|
||||
from brevitas_examples.llm.llm_quant.quantize import quantize_model
|
||||
from brevitas_examples.llm.llm_quant.run_utils import get_model_impl
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡shape(
    lhs: List[int],
    rhs: List[int],
    rhs_scale: List[int],
    rhs_zero_point: List[int],
    rhs_bit_width: int,
    rhs_group_size: int,
) -> List[int]:
    """Shape transfer function for brevitas.matmul_rhs_group_quant.

    The RHS must be rank-2 (out, in); the LHS may carry one leading
    batch dimension. The result keeps the LHS leading dims and replaces
    the contraction dim with the RHS output dim.
    """
    lhs_rank, rhs_rank = len(lhs), len(rhs)
    if rhs_rank != 2 or lhs_rank not in (2, 3):
        raise ValueError("Input shapes not supported.")
    return lhs[:-1] + [rhs[0]]
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡dtype(
    lhs_rank_dtype: Tuple[int, int],
    rhs_rank_dtype: Tuple[int, int],
    rhs_scale_rank_dtype: Tuple[int, int],
    rhs_zero_point_rank_dtype: Tuple[int, int],
    rhs_bit_width: int,
    rhs_group_size: int,
) -> int:
    """Dtype transfer function: the result inherits the LHS float dtype."""
    _, lhs_dtype = lhs_rank_dtype
    return lhs_dtype
|
||||
|
||||
|
||||
def brevitas〇matmul_rhs_group_quant〡has_value_semantics(
    lhs, rhs, rhs_scale, rhs_zero_point, rhs_bit_width, rhs_group_size
) -> None:
    """Declares the op has value semantics; intentionally a no-op."""
    return None
|
||||
|
||||
|
||||
brevitas_matmul_rhs_group_quant_library = [
|
||||
brevitas〇matmul_rhs_group_quant〡shape,
|
||||
brevitas〇matmul_rhs_group_quant〡dtype,
|
||||
brevitas〇matmul_rhs_group_quant〡has_value_semantics,
|
||||
]
|
||||
|
||||
|
||||
def load_vmfb(extended_model_name, device, mlir_dialect, extra_args=[]):
|
||||
@@ -39,11 +88,90 @@ def compile_module(
|
||||
return shark_module
|
||||
|
||||
|
||||
def compile_int_precision(
    model, inputs, precision, device, generate_vmfb, extended_model_name
):
    """Quantize `model` with Brevitas (int4/int8) and lower it to linalg IR.

    Applies per-group weight-only quantization, imports the model through
    torch-mlir with the custom brevitas matmul op kept legal, lowers to
    the linalg-on-tensors backend contract, writes the IR to
    `<extended_model_name>_linalg.mlir` in the CWD, and returns the IR
    as UTF-8 bytes.

    NOTE(review): everything after the first `return bytecode` below is
    unreachable dead code (an older path that also compiled the module to
    a vmfb); consider deleting it or restoring it behind a flag.
    """
    weight_bit_width = 4 if precision == "int4" else 8
    weight_group_size = 128
    # Weight-only, asymmetric, per-group quantization; activations are
    # left in float (input_bit_width=None).
    quantize_model(
        get_model_impl(model),
        dtype=torch.float32,
        weight_quant_type="asym",
        weight_bit_width=weight_bit_width,
        weight_param_method="stats",
        weight_scale_precision="float",
        weight_quant_granularity="per_group",
        weight_group_size=weight_group_size,
        quantize_weight_zero_point=False,
        input_bit_width=None,
        input_scale_type="float",
        input_param_method="stats",
        input_quant_type="asym",
        input_quant_granularity="per_tensor",
        quantize_input_zero_point=False,
        seqlen=2048,
    )
    print("Weight quantization applied.")
    torchscript_module = import_with_fx(
        model,
        inputs,
        precision=precision,
        mlir_type="torchscript",
    )
    # Keep the brevitas group-quant matmul as a legal op so it survives
    # the torch-level lowering.
    mlir_module = torch_mlir.compile(
        torchscript_module,
        inputs,
        output_type="torch",
        backend_legal_ops=["brevitas.matmul_rhs_group_quant"],
        extra_library=brevitas_matmul_rhs_group_quant_library,
        use_tracing=False,
        verbose=False,
    )
    print(f"[DEBUG] converting torch to linalg")
    run_pipeline_with_repro_report(
        mlir_module,
        "builtin.module(func.func(torch-unpack-torch-tensor),torch-backend-to-linalg-on-tensors-backend-pipeline)",
        description="Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
    )
    from contextlib import redirect_stdout

    # Dump the lowered IR to a file for inspection/debugging.
    mlir_file_path = os.path.join(
        os.getcwd(), f"{extended_model_name}_linalg.mlir"
    )
    with open(mlir_file_path, "w") as f:
        with redirect_stdout(f):
            print(mlir_module.operation.get_asm())
    # Serialize the module to bytes for the caller.
    mlir_module = str(mlir_module)
    mlir_module = mlir_module.encode("UTF-8")
    mlir_module = BytesIO(mlir_module)
    bytecode = mlir_module.read()
    print(f"Elided IR written for {extended_model_name}")
    return bytecode
    # NOTE(review): unreachable from here on (see docstring).
    shark_module = SharkInference(
        mlir_module=bytecode, device=device, mlir_dialect="tm_tensor"
    )
    extra_args = [
        "--iree-hal-dump-executable-sources-to=ies",
        "--iree-vm-target-truncate-unsupported-floats",
        "--iree-codegen-check-ir-before-llvm-conversion=false",
        "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
    ]
    return (
        compile_module(
            shark_module,
            extended_model_name=extended_model_name,
            generate_vmfb=generate_vmfb,
            extra_args=extra_args,
        ),
        bytecode,
    )
|
||||
|
||||
|
||||
def shark_compile_through_fx(
|
||||
model,
|
||||
inputs,
|
||||
extended_model_name,
|
||||
is_f16=False,
|
||||
precision,
|
||||
f16_input_mask=None,
|
||||
save_dir=tempfile.gettempdir(),
|
||||
debug=False,
|
||||
@@ -52,6 +180,7 @@ def shark_compile_through_fx(
|
||||
device=None,
|
||||
mlir_dialect="tm_tensor",
|
||||
):
|
||||
is_f16 = precision == "fp16"
|
||||
if generate_or_load_vmfb:
|
||||
shark_module = load_vmfb(
|
||||
extended_model_name=extended_model_name,
|
||||
@@ -70,18 +199,34 @@ def shark_compile_through_fx(
|
||||
if "cuda" in device:
|
||||
shark_args.enable_tf32 = True
|
||||
|
||||
(
|
||||
mlir_module,
|
||||
_,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
if precision in ["int4", "int8"]:
|
||||
mlir_module = compile_int_precision(
|
||||
model,
|
||||
inputs,
|
||||
precision,
|
||||
device,
|
||||
generate_or_load_vmfb,
|
||||
extended_model_name,
|
||||
)
|
||||
extra_args = [
|
||||
"--iree-hal-dump-executable-sources-to=ies",
|
||||
"--iree-vm-target-truncate-unsupported-floats",
|
||||
"--iree-codegen-check-ir-before-llvm-conversion=false",
|
||||
"--iree-vm-bytecode-module-output-format=flatbuffer-binary",
|
||||
]
|
||||
else:
|
||||
(
|
||||
mlir_module,
|
||||
_,
|
||||
) = import_with_fx(
|
||||
model=model,
|
||||
inputs=inputs,
|
||||
is_f16=is_f16,
|
||||
f16_input_mask=f16_input_mask,
|
||||
debug=debug,
|
||||
model_name=extended_model_name,
|
||||
save_dir=save_dir,
|
||||
)
|
||||
|
||||
shark_module = SharkInference(
|
||||
mlir_module,
|
||||
|
||||
@@ -488,7 +488,7 @@ def flatten_training_input(inputs):
|
||||
return tuple(flattened_input)
|
||||
|
||||
|
||||
# TODO: get rid of is_f16 by using precision
|
||||
# TODO: Remove is_f16 and fix all calls with using precision instead
|
||||
# Applies fx conversion to the model and imports the mlir.
|
||||
def import_with_fx(
|
||||
model,
|
||||
|
||||
Reference in New Issue
Block a user